Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 46 additions & 7 deletions google/cloud/sql/connector/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ async def _get_metadata(
)

ip_addresses = (
{ip["type"]: ip["ipAddress"] for ip in ret_dict["ipAddresses"]}
{ip["type"]: [ip["ipAddress"]] for ip in ret_dict["ipAddresses"]}
if "ipAddresses" in ret_dict
else {}
)
Expand All @@ -156,27 +156,66 @@ async def _get_metadata(
if ret_dict.get("pscEnabled"):
# Find PSC instance DNS name in the dns_names field
psc_dns_names = [
d["name"]
d["name"].rstrip(".")
for d in ret_dict.get("dnsNames", [])
if d["connectionType"] == "PRIVATE_SERVICE_CONNECT"
and d["dnsScope"] == "INSTANCE"
]
dns_name = psc_dns_names[0] if psc_dns_names else None
# Sort: .sql-psc.goog first
psc_dns_names.sort(key=lambda x: x.endswith(".sql-psc.goog"), reverse=True)

# Fall back do dns_name field if dns_names is not set
if dns_name is None:
if not psc_dns_names:
dns_name = ret_dict.get("dnsName", None)
if dns_name:
psc_dns_names = [dns_name.rstrip(".")]

# Remove trailing period from DNS name. Required for SSL in Python
if dns_name:
ip_addresses["PSC"] = dns_name.rstrip(".")
if psc_dns_names:
ip_addresses["PSC"] = psc_dns_names

return {
"ip_addresses": ip_addresses,
"server_ca_cert": ret_dict["serverCaCert"]["cert"],
"database_version": ret_dict["databaseVersion"],
}

async def resolve_connect_settings(
self,
dns_name: str,
location: str,
) -> dict[str, Any]:
"""Asynchronously calls the resolveConnectSettings endpoint to resolve a
PSC DNS name to a connection name.

Args:
dns_name (str): The DNS name of the Cloud SQL instance.
location (str): The region/location of the instance.

Returns:
A dictionary containing the resolve response (e.g. connectionName).
"""
headers = {
"Authorization": f"Bearer {self._credentials.token}",
}

url = f"{self._sqladmin_api_endpoint}/sql/{API_VERSION}/dns/{dns_name}/locations/{location}:resolveConnectSettings"

resp = await self._client.get(url, headers=headers)
if resp.status >= 500:
resp = await retry_50x(self._client.get, url, headers=headers)
try:
ret_dict = await resp.json()
if resp.status >= 400:
message = ret_dict.get("error", {}).get("message")
if message:
resp.reason = message
except Exception:
pass
finally:
resp.raise_for_status()

return ret_dict

async def _get_ephemeral(
self,
project: str,
Expand Down
11 changes: 11 additions & 0 deletions google/cloud/sql/connector/connection_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,17 @@ def get_preferred_ip(self, ip_type: IPTypes) -> str:
"""Returns the first IP address for the instance, according to the preference
supplied by ip_type. If no IP addressess with the given preference are found,
an error is raised."""
if ip_type.value in self.ip_addrs:
return self.ip_addrs[ip_type.value][0]
raise CloudSQLIPTypeError(
"Cloud SQL instance does not have any IP addresses matching "
f"preference: {ip_type.value}"
)

def get_preferred_ips(self, ip_type: IPTypes) -> list[str]:
"""Returns all IP addresses for the instance, according to the preference
supplied by ip_type. If no IP addressess with the given preference are found,
an error is raised."""
if ip_type.value in self.ip_addrs:
return self.ip_addrs[ip_type.value]
raise CloudSQLIPTypeError(
Expand Down
120 changes: 89 additions & 31 deletions google/cloud/sql/connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import asyncio
from functools import partial
import ipaddress
import logging
import os
import socket
Expand Down Expand Up @@ -49,6 +50,27 @@

logger = logging.getLogger(name=__name__)


def _is_ip_address(ip: str) -> bool:
try:
ipaddress.ip_address(ip)
return True
except ValueError:
return False


def _get_fallback_ips(
current_ips: list[str], ip_addresses: dict[str, list[str]]
) -> list[str]:
if not current_ips:
return current_ips
if _is_ip_address(current_ips[0]):
return current_ips
fallback = ip_addresses.get("PRIVATE")
if not fallback:
fallback = ip_addresses.get("PRIMARY")
return fallback if fallback else current_ips

ASYNC_DRIVERS = ["asyncpg"]
SERVER_PROXY_PORT = 3307
_DEFAULT_SCHEME = "https://"
Expand Down Expand Up @@ -316,6 +338,8 @@ async def connect_async(
user_agent=self._user_agent,
driver=driver,
)
if hasattr(self._resolver, "set_client"):
self._resolver.set_client(self._client)
enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth)

conn_name = await self._resolver.resolve(instance_connection_string)
Expand Down Expand Up @@ -384,40 +408,50 @@ async def connect_async(
conn_info = await monitored_cache.connect_info()
# validate driver matches intended database engine
DriverMapping.validate_engine(driver, conn_info.database_version)
ip_address = conn_info.get_preferred_ip(ip_type)
preferred_ips = conn_info.get_preferred_ips(ip_type)
except Exception:
# with an error from Cloud SQL Admin API call or IP type, invalidate
# the cache and re-raise the error
await self._remove_cached(str(conn_name), enable_iam_auth)
raise

targets = []
# If the connector is configured with a custom DNS name, attempt to use
# that DNS name to connect to the instance. Fall back to the metadata IP
# address if the DNS name does not resolve to an IP address.
if conn_info.conn_name.domain_name and isinstance(self._resolver, DnsResolver):
try:
ips = await self._resolver.resolve_a_record(conn_info.conn_name.domain_name)
if ips:
ip_address = ips[0]
targets.extend(ips)
logger.debug(
f"['{instance_connection_string}']: Custom DNS name "
f"'{conn_info.conn_name.domain_name}' resolved to '{ip_address}', "
f"'{conn_info.conn_name.domain_name}' resolved to '{ips}', "
"using it to connect"
)
else:
fallback_ips = _get_fallback_ips(
preferred_ips, conn_info.ip_addrs
)
logger.debug(
f"['{instance_connection_string}']: Custom DNS name "
f"'{conn_info.conn_name.domain_name}' resolved but returned no "
f"entries, using '{ip_address}' from instance metadata"
f"entries, using '{fallback_ips[0]}' from instance metadata"
)
targets.extend(fallback_ips)
except Exception as e:
fallback_ips = _get_fallback_ips(
preferred_ips, conn_info.ip_addrs
)
logger.debug(
f"['{instance_connection_string}']: Custom DNS name "
f"'{conn_info.conn_name.domain_name}' did not resolve to an IP "
f"address: {e}, using '{ip_address}' from instance metadata"
f"address: {e}, using '{fallback_ips[0]}' from instance metadata"
)
targets.extend(fallback_ips)
else:
targets.extend(preferred_ips)

logger.debug(f"['{conn_info.conn_name}']: Connecting to {ip_address}:3307")
# format `user` param for automatic IAM database authn
if enable_iam_auth:
formatted_user = format_database_user(
Expand All @@ -428,32 +462,56 @@ async def connect_async(
f"['{instance_connection_string}']: Truncated IAM database username from {kwargs['user']} to {formatted_user}"
)
kwargs["user"] = formatted_user

try:
# async drivers are unblocking and can be awaited directly
if driver in ASYNC_DRIVERS:
return await connector(
ip_address,
await conn_info.create_ssl_context(enable_iam_auth),
**kwargs,
)
# Create socket with SSLContext for sync drivers
ctx = await conn_info.create_ssl_context(enable_iam_auth)
sock = ctx.wrap_socket(
socket.create_connection((ip_address, SERVER_PROXY_PORT)),
server_hostname=ip_address,
)
# If this connection was opened using a domain name, then store it
# for later in case we need to forcibly close it on failover.
if conn_info.conn_name.domain_name:
monitored_cache.sockets.append(sock)
# Synchronous drivers are blocking and run using executor
connect_partial = partial(
connector,
ip_address,
sock,
**kwargs,
)
return await self._loop.run_in_executor(None, connect_partial)
last_ex = None
for target_ip in targets:
logger.debug(f"['{conn_info.conn_name}']: Connecting to {target_ip}:3307")
try:
# async drivers are unblocking and can be awaited directly
if driver in ASYNC_DRIVERS:
conn = await connector(
target_ip,
await conn_info.create_ssl_context(enable_iam_auth),
**kwargs,
)
last_ex = None
return conn

# Create socket with SSLContext for sync drivers
ctx = await conn_info.create_ssl_context(enable_iam_auth)
raw_sock = socket.create_connection((target_ip, SERVER_PROXY_PORT))
try:
sock = ctx.wrap_socket(
raw_sock,
server_hostname=target_ip,
)
except Exception:
raw_sock.close()
raise

# If this connection was opened using a domain name, then store it
# for later in case we need to forcibly close it on failover.
if conn_info.conn_name.domain_name:
monitored_cache.sockets.append(sock)
# Synchronous drivers are blocking and run using executor
connect_partial = partial(
connector,
target_ip,
sock,
**kwargs,
)
conn = await self._loop.run_in_executor(None, connect_partial)
last_ex = None
return conn
except Exception as e:
logger.debug(
f"['{conn_info.conn_name}']: Connection to {target_ip} failed: {e}"
)
last_ex = e

if last_ex:
raise last_ex

except Exception:
# with any exception, we attempt a force refresh, then throw the error
Expand Down
Loading
Loading