From 020052489b144c239cefa39318ea4b65b30768cd Mon Sep 17 00:00:00 2001 From: hessjc Date: Thu, 4 Jun 2026 16:15:30 +0000 Subject: [PATCH] feat: Support Developer Edition connections This PR adds support for Developer Edition connections via the SqlDataService. It includes a fallback mechanism to standard IP connections if the SqlDataService is not supported by the instance edition. See https://github.com/GoogleCloudPlatform/cloud-sql-go-connector/pull/1108 --- build.sh | 5 + .../google/cloud/sql/connector/__init__.py | 33 + .../lib/google/cloud/sql/connector/asyncpg.py | 69 ++ .../lib/google/cloud/sql/connector/client.py | 336 +++++++++ .../cloud/sql/connector/connection_info.py | 134 ++++ .../cloud/sql/connector/connection_name.py | 75 ++ .../google/cloud/sql/connector/connector.py | 705 ++++++++++++++++++ build/lib/google/cloud/sql/connector/enums.py | 87 +++ .../google/cloud/sql/connector/exceptions.py | 93 +++ .../google/cloud/sql/connector/instance.py | 225 ++++++ build/lib/google/cloud/sql/connector/lazy.py | 135 ++++ .../cloud/sql/connector/monitored_cache.py | 146 ++++ .../lib/google/cloud/sql/connector/pg8000.py | 59 ++ .../connector/proto/sql_data_service_pb2.py | 88 +++ .../proto/sql_data_service_pb2_grpc.py | 137 ++++ build/lib/google/cloud/sql/connector/py.typed | 0 .../lib/google/cloud/sql/connector/pymysql.py | 58 ++ build/lib/google/cloud/sql/connector/pytds.py | 71 ++ .../cloud/sql/connector/rate_limiter.py | 79 ++ .../cloud/sql/connector/refresh_utils.py | 155 ++++ .../google/cloud/sql/connector/resolver.py | 91 +++ .../cloud/sql/connector/sqldata_client.py | 355 +++++++++ build/lib/google/cloud/sql/connector/utils.py | 101 +++ .../lib/google/cloud/sql/connector/version.py | 15 + google/cloud/sql/connector/asyncpg.py | 29 +- google/cloud/sql/connector/client.py | 12 +- google/cloud/sql/connector/connection_info.py | 6 +- google/cloud/sql/connector/connector.py | 308 +++++--- google/cloud/sql/connector/enums.py | 1 + 
google/cloud/sql/connector/monitored_cache.py | 16 +- .../sql/connector/proto/google/rpc/code.proto | 186 +++++ .../proto/google/rpc/error_details.proto | 200 +++++ .../connector/proto/google/rpc/status.proto | 92 +++ .../connector/proto/sql_data_service.proto | 264 +++++++ .../connector/proto/sql_data_service_pb2.py | 88 +++ .../proto/sql_data_service_pb2_grpc.py | 137 ++++ google/cloud/sql/connector/sqldata_client.py | 355 +++++++++ pyproject.toml | 3 + requirements-test.txt | 3 +- tests/system/test_sqldata_connection.py | 104 +++ tests/unit/test_connector.py | 4 +- 41 files changed, 4924 insertions(+), 136 deletions(-) create mode 100644 build/lib/google/cloud/sql/connector/__init__.py create mode 100644 build/lib/google/cloud/sql/connector/asyncpg.py create mode 100644 build/lib/google/cloud/sql/connector/client.py create mode 100644 build/lib/google/cloud/sql/connector/connection_info.py create mode 100644 build/lib/google/cloud/sql/connector/connection_name.py create mode 100644 build/lib/google/cloud/sql/connector/connector.py create mode 100644 build/lib/google/cloud/sql/connector/enums.py create mode 100644 build/lib/google/cloud/sql/connector/exceptions.py create mode 100644 build/lib/google/cloud/sql/connector/instance.py create mode 100644 build/lib/google/cloud/sql/connector/lazy.py create mode 100644 build/lib/google/cloud/sql/connector/monitored_cache.py create mode 100644 build/lib/google/cloud/sql/connector/pg8000.py create mode 100644 build/lib/google/cloud/sql/connector/proto/sql_data_service_pb2.py create mode 100644 build/lib/google/cloud/sql/connector/proto/sql_data_service_pb2_grpc.py create mode 100644 build/lib/google/cloud/sql/connector/py.typed create mode 100644 build/lib/google/cloud/sql/connector/pymysql.py create mode 100644 build/lib/google/cloud/sql/connector/pytds.py create mode 100644 build/lib/google/cloud/sql/connector/rate_limiter.py create mode 100644 build/lib/google/cloud/sql/connector/refresh_utils.py create mode 100644 
build/lib/google/cloud/sql/connector/resolver.py create mode 100644 build/lib/google/cloud/sql/connector/sqldata_client.py create mode 100644 build/lib/google/cloud/sql/connector/utils.py create mode 100644 build/lib/google/cloud/sql/connector/version.py create mode 100644 google/cloud/sql/connector/proto/google/rpc/code.proto create mode 100644 google/cloud/sql/connector/proto/google/rpc/error_details.proto create mode 100644 google/cloud/sql/connector/proto/google/rpc/status.proto create mode 100644 google/cloud/sql/connector/proto/sql_data_service.proto create mode 100644 google/cloud/sql/connector/proto/sql_data_service_pb2.py create mode 100644 google/cloud/sql/connector/proto/sql_data_service_pb2_grpc.py create mode 100644 google/cloud/sql/connector/sqldata_client.py create mode 100644 tests/system/test_sqldata_connection.py diff --git a/build.sh b/build.sh index 7e7a1c376..f7c66e19a 100755 --- a/build.sh +++ b/build.sh @@ -150,6 +150,11 @@ function write_e2e_env(){ } +## with_venv - runs a command with the venv activated +function with_venv() { + "$@" +} + ## help - prints the help details ## function help() { diff --git a/build/lib/google/cloud/sql/connector/__init__.py b/build/lib/google/cloud/sql/connector/__init__.py new file mode 100644 index 000000000..6913337d3 --- /dev/null +++ b/build/lib/google/cloud/sql/connector/__init__.py @@ -0,0 +1,33 @@ +""" +Copyright 2019 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +from google.cloud.sql.connector.connector import Connector +from google.cloud.sql.connector.connector import create_async_connector +from google.cloud.sql.connector.enums import IPTypes +from google.cloud.sql.connector.enums import RefreshStrategy +from google.cloud.sql.connector.resolver import DefaultResolver +from google.cloud.sql.connector.resolver import DnsResolver +from google.cloud.sql.connector.version import __version__ + +__all__ = [ + "__version__", + "create_async_connector", + "Connector", + "DefaultResolver", + "DnsResolver", + "IPTypes", + "RefreshStrategy", +] diff --git a/build/lib/google/cloud/sql/connector/asyncpg.py b/build/lib/google/cloud/sql/connector/asyncpg.py new file mode 100644 index 000000000..2e28dbbaf --- /dev/null +++ b/build/lib/google/cloud/sql/connector/asyncpg.py @@ -0,0 +1,69 @@ +""" +Copyright 2022 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import ssl +from typing import Any, Optional, TYPE_CHECKING + +SERVER_PROXY_PORT = 3307 + +if TYPE_CHECKING: + import asyncpg + + +async def connect( + ip_address: str, ctx: Optional[ssl.SSLContext], **kwargs: Any +) -> "asyncpg.Connection": + """Helper function to create an asyncpg DB-API connection object. + + Args: + ip_address (str): A string containing an IP address for the Cloud SQL + instance. + ctx (ssl.SSLContext): An SSLContext object created from the Cloud SQL + server CA cert and ephemeral cert. Pass None to disable SSL. 
+ kwargs: Keyword arguments for establishing asyncpg connection + object to Cloud SQL instance. + + Returns: + asyncpg.Connection: An asyncpg connection to the Cloud SQL + instance. + Raises: + ImportError: The asyncpg module cannot be imported. + """ + + try: + import asyncpg + except ImportError: + raise ImportError( + 'Unable to import module "asyncpg." Please install and try again.' + ) + user = kwargs.pop("user") + db = kwargs.pop("db") + passwd = kwargs.pop("password", None) + port = kwargs.pop("port", SERVER_PROXY_PORT) + + connect_args = { + "user": user, + "database": db, + "password": passwd, + "host": ip_address, + "port": port, + **kwargs, + } + if ctx is not None: + connect_args["ssl"] = ctx + connect_args["direct_tls"] = True + + return await asyncpg.connect(**connect_args) diff --git a/build/lib/google/cloud/sql/connector/client.py b/build/lib/google/cloud/sql/connector/client.py new file mode 100644 index 000000000..1befdb793 --- /dev/null +++ b/build/lib/google/cloud/sql/connector/client.py @@ -0,0 +1,336 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import asyncio +import datetime +import logging +from typing import Any, Optional, TYPE_CHECKING + +import aiohttp +from cryptography.hazmat.backends import default_backend +from cryptography.x509 import load_pem_x509_certificate +from google.auth.credentials import TokenState +from google.auth.transport import requests + +from google.cloud.sql.connector.connection_info import ConnectionInfo +from google.cloud.sql.connector.connection_name import ConnectionName +from google.cloud.sql.connector.exceptions import AutoIAMAuthNotSupported +from google.cloud.sql.connector.refresh_utils import _downscope_credentials +from google.cloud.sql.connector.refresh_utils import retry_50x +from google.cloud.sql.connector.version import __version__ as version + +if TYPE_CHECKING: + from google.auth.credentials import Credentials + +USER_AGENT: str = f"cloud-sql-python-connector/{version}" +API_VERSION: str = "v1beta4" +DEFAULT_SERVICE_ENDPOINT: str = "https://sqladmin.googleapis.com" + +logger = logging.getLogger(name=__name__) + + +def _format_user_agent(driver: Optional[str], custom: Optional[str]) -> str: + agent = f"{USER_AGENT}+{driver}" if driver else USER_AGENT + if custom and isinstance(custom, str): + agent = f"{agent} {custom}" + return agent + + +class CloudSQLClient: + def __init__( + self, + sqladmin_api_endpoint: Optional[str], + quota_project: Optional[str], + credentials: Credentials, + client: Optional[aiohttp.ClientSession] = None, + driver: Optional[str] = None, + user_agent: Optional[str] = None, + ) -> None: + """Establishes the client to be used for Cloud SQL Admin API requests. + + Args: + sqladmin_api_endpoint (str): Base URL to use when calling + the Cloud SQL Admin API endpoints. + quota_project (str): The Project ID for an existing Google Cloud + project. The project specified is used for quota and + billing purposes. 
+ credentials (google.auth.credentials.Credentials): + A credentials object created from the google-auth Python library. + Must have the Cloud SQL Admin scopes. For more info check out + https://google-auth.readthedocs.io/en/latest/. + client (aiohttp.ClientSession): Async client used to make requests to + Cloud SQL Admin APIs. + Optional, defaults to None and creates new client. + driver (str): Database driver to be used by the client. + """ + user_agent = _format_user_agent(driver, user_agent) + headers = { + "x-goog-api-client": user_agent, + "User-Agent": user_agent, + "Content-Type": "application/json", + } + if quota_project: + headers["x-goog-user-project"] = quota_project + + self._client = client if client else aiohttp.ClientSession(headers=headers) + self._credentials = credentials + if sqladmin_api_endpoint is None: + self._sqladmin_api_endpoint = DEFAULT_SERVICE_ENDPOINT + else: + self._sqladmin_api_endpoint = sqladmin_api_endpoint + self._user_agent = user_agent + + async def _get_metadata( + self, + project: str, + region: str, + instance: str, + ) -> dict[str, Any]: + """Requests metadata from the Cloud SQL Instance and returns a dictionary + containing the IP addresses and certificate authority of the Cloud SQL + Instance. + + Args: + project (str): A string representing the name of the project. + region (str): A string representing the name of the region. + instance (str): A string representing the name of the instance. + + Returns: + A dictionary containing a dictionary of all IP addresses + and their type and a string representing the certificate authority. + + Raises: + ValueError: Provided region does not match the region of the + Cloud SQL instance. 
+ """ + + headers = { + "Authorization": f"Bearer {self._credentials.token}", + } + + url = f"{self._sqladmin_api_endpoint}/sql/{API_VERSION}/projects/{project}/instances/{instance}/connectSettings" + + resp = await self._client.get(url, headers=headers) + if resp.status >= 500: + resp = await retry_50x(self._client.get, url, headers=headers) + # try to get response json for better error message + try: + ret_dict = await resp.json() + if resp.status >= 400: + # if detailed error message is in json response, use as error message + message = ret_dict.get("error", {}).get("message") + if message: + resp.reason = message + # skip, raise_for_status will catch all errors in finally block + except Exception: + pass + finally: + resp.raise_for_status() + + if ret_dict["region"] != region: + raise ValueError( + f'[{project}:{region}:{instance}]: Provided region was mismatched - got region {region}, expected {ret_dict["region"]}.' + ) + + ip_addresses = ( + {ip["type"]: ip["ipAddress"] for ip in ret_dict["ipAddresses"]} + if "ipAddresses" in ret_dict + else {} + ) + # resolve dnsName into IP address for PSC + # Note that we have to check for PSC enablement also because CAS + # instances also set the dnsName field. + if ret_dict.get("pscEnabled"): + # Find PSC instance DNS name in the dns_names field + psc_dns_names = [ + d["name"] + for d in ret_dict.get("dnsNames", []) + if d["connectionType"] == "PRIVATE_SERVICE_CONNECT" + and d["dnsScope"] == "INSTANCE" + ] + dns_name = psc_dns_names[0] if psc_dns_names else None + + # Fall back do dns_name field if dns_names is not set + if dns_name is None: + dns_name = ret_dict.get("dnsName", None) + + # Remove trailing period from DNS name. 
Required for SSL in Python + if dns_name: + ip_addresses["PSC"] = dns_name.rstrip(".") + + server_ca_cert = None + if "serverCaCert" in ret_dict and "cert" in ret_dict["serverCaCert"]: + server_ca_cert = ret_dict["serverCaCert"]["cert"] + + return { + "ip_addresses": ip_addresses, + "server_ca_cert": server_ca_cert, + "database_version": ret_dict["databaseVersion"], + } + + async def _get_ephemeral( + self, + project: str, + instance: str, + pub_key: str, + enable_iam_auth: bool = False, + ) -> tuple[str, datetime.datetime]: + """Asynchronously requests an ephemeral certificate from the Cloud SQL Instance. + + Args: + project (str): A string representing the name of the project. + instance (str): string representing the name of the instance. + pub_key (str): A string representing PEM-encoded RSA public key. + enable_iam_auth (bool): Enables automatic IAM database + authentication for Postgres or MySQL instances. + + Returns: + A tuple containing an ephemeral certificate from + the Cloud SQL instance as well as a datetime object + representing the expiration time of the certificate. 
+ """ + headers = { + "Authorization": f"Bearer {self._credentials.token}", + } + + url = f"{self._sqladmin_api_endpoint}/sql/{API_VERSION}/projects/{project}/instances/{instance}:generateEphemeralCert" + + data = {"public_key": pub_key} + + if enable_iam_auth: + # down-scope credentials with only IAM login scope (refreshes them too) + login_creds = _downscope_credentials(self._credentials) + data["access_token"] = login_creds.token + + resp = await self._client.post(url, headers=headers, json=data) + if resp.status >= 500: + resp = await retry_50x(self._client.post, url, headers=headers, json=data) + # try to get response json for better error message + try: + ret_dict = await resp.json() + if resp.status >= 400: + # if detailed error message is in json response, use as error message + message = ret_dict.get("error", {}).get("message") + if message: + resp.reason = message + # skip, raise_for_status will catch all errors in finally block + except Exception: + pass + finally: + resp.raise_for_status() + + try: + ephemeral_cert: str = ret_dict["ephemeralCert"]["cert"] + except KeyError as e: + logger.error(f"KeyError in _get_ephemeral parsing generateEphemeralCert: {e}. Response dict: {ret_dict}") + raise + + # decode cert to read expiration + x509 = load_pem_x509_certificate( + ephemeral_cert.encode("UTF-8"), default_backend() + ) + expiration = x509.not_valid_after_utc + # for IAM authentication OAuth2 token is embedded in cert so it + # must still be valid for successful connection + if enable_iam_auth: + token_expiration: datetime.datetime = login_creds.expiry + # google.auth library strips timezone info for backwards compatibality + # reasons with Python 2. Add it back to allow timezone aware datetimes. 
+ # Ref: https://github.com/googleapis/google-auth-library-python/blob/49a5ff7411a2ae4d32a7d11700f9f961c55406a9/google/auth/_helpers.py#L93-L99 + token_expiration = token_expiration.replace(tzinfo=datetime.timezone.utc) + + if expiration > token_expiration: + expiration = token_expiration + return ephemeral_cert, expiration + + async def get_connection_info( + self, + conn_name: ConnectionName, + keys: asyncio.Future, + enable_iam_auth: bool, + ) -> ConnectionInfo: + """Immediately performs a full refresh operation using the Cloud SQL + Admin API. + + Args: + conn_name (ConnectionName): The Cloud SQL instance's + connection name. + keys (asyncio.Future): A future to the client's public-private key + pair. + enable_iam_auth (bool): Whether an automatic IAM database + authentication connection is being requested (Postgres and MySQL). + + Returns: + ConnectionInfo: All the information required to connect securely to + the Cloud SQL instance. + Raises: + AutoIAMAuthNotSupported: Database engine does not support automatic + IAM authentication. + """ + priv_key, pub_key = await keys + # before making Cloud SQL Admin API calls, refresh creds if required + if not self._credentials.token_state == TokenState.FRESH: + self._credentials.refresh(requests.Request()) + + metadata_task = asyncio.create_task( + self._get_metadata( + conn_name.project, + conn_name.region, + conn_name.instance_name, + ) + ) + + ephemeral_task = asyncio.create_task( + self._get_ephemeral( + conn_name.project, + conn_name.instance_name, + pub_key, + enable_iam_auth, + ) + ) + try: + metadata = await metadata_task + # check if automatic IAM database authn is supported for database engine + if enable_iam_auth and not metadata["database_version"].startswith( + ("POSTGRES", "MYSQL") + ): + raise AutoIAMAuthNotSupported( + f"'{metadata['database_version']}' does not support " + "automatic IAM authentication. It is only supported with " + "Cloud SQL Postgres or MySQL instances." 
+ ) + except Exception: + # cancel ephemeral cert task if exception occurs before it is awaited + ephemeral_task.cancel() + raise + + ephemeral_cert, expiration = await ephemeral_task + + return ConnectionInfo( + conn_name, + ephemeral_cert, + metadata["server_ca_cert"], + priv_key, + metadata["ip_addresses"], + metadata["database_version"], + expiration, + ) + + async def close(self) -> None: + """Close CloudSQLClient gracefully.""" + logger.debug("Waiting for Connector's http client to close") + await self._client.close() + logger.debug("Closed Connector's http client") diff --git a/build/lib/google/cloud/sql/connector/connection_info.py b/build/lib/google/cloud/sql/connector/connection_info.py new file mode 100644 index 000000000..bf9330e1b --- /dev/null +++ b/build/lib/google/cloud/sql/connector/connection_info.py @@ -0,0 +1,134 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import abc +from dataclasses import dataclass +import logging +import ssl +from typing import Any, Optional, TYPE_CHECKING + +from aiofiles.tempfile import TemporaryDirectory + +from google.cloud.sql.connector.connection_name import ConnectionName +from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError +from google.cloud.sql.connector.exceptions import TLSVersionError +from google.cloud.sql.connector.utils import write_to_file + +if TYPE_CHECKING: + import datetime + + from google.cloud.sql.connector.enums import IPTypes + +logger = logging.getLogger(name=__name__) + + +class ConnectionInfoCache(abc.ABC): + """Abstract class for Connector connection info caches.""" + + @abc.abstractmethod + async def connect_info(self) -> ConnectionInfo: + pass + + @abc.abstractmethod + async def force_refresh(self) -> None: + pass + + @abc.abstractmethod + async def close(self) -> None: + pass + + @property + @abc.abstractmethod + def closed(self) -> bool: + pass + + +@dataclass +class ConnectionInfo: + """Contains all necessary information to connect securely to the + server-side Proxy running on a Cloud SQL instance.""" + + conn_name: ConnectionName + client_cert: str + server_ca_cert: Optional[str] + private_key: bytes + ip_addrs: dict[str, Any] + database_version: str + expiration: datetime.datetime + context: Optional[ssl.SSLContext] = None + + async def create_ssl_context(self, enable_iam_auth: bool = False) -> ssl.SSLContext: + """Constructs a SSL/TLS context for the given connection info. + + Cache the SSL context to ensure we don't read from disk repeatedly when + configuring a secure connection. 
+ """ + # if SSL context is cached, use it + if self.context is not None: + return self.context + + if self.server_ca_cert is None: + raise ValueError("Cannot create SSL context: server CA certificate is missing.") + + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + + # update ssl.PROTOCOL_TLS_CLIENT default + context.check_hostname = False + + # TODO: remove if/else when Python 3.10 is min version. PEP 644 has been + # implemented. The ssl module requires OpenSSL 1.1.1 or newer. + # verify OpenSSL version supports TLSv1.3 + if ssl.HAS_TLSv1_3: + # force TLSv1.3 if supported by client + context.minimum_version = ssl.TLSVersion.TLSv1_3 + # fallback to TLSv1.2 for older versions of OpenSSL + else: + if enable_iam_auth: + raise TLSVersionError( + f"Your current version of OpenSSL ({ssl.OPENSSL_VERSION}) does not " + "support TLSv1.3, which is required to use IAM Authentication.\n" + "Upgrade your OpenSSL version to 1.1.1 for TLSv1.3 support." + ) + logger.warning( + "TLSv1.3 is not supported with your version of OpenSSL " + f"({ssl.OPENSSL_VERSION}), falling back to TLSv1.2\n" + "Upgrade your OpenSSL version to 1.1.1 for TLSv1.3 support." + ) + context.minimum_version = ssl.TLSVersion.TLSv1_2 + + # tmpdir and its contents are automatically deleted after the CA cert + # and ephemeral cert are loaded into the SSLcontext. The values + # need to be written to files in order to be loaded by the SSLContext + async with TemporaryDirectory() as tmpdir: + ca_filename, cert_filename, key_filename = await write_to_file( + tmpdir, self.server_ca_cert, self.client_cert, self.private_key + ) + context.load_cert_chain(cert_filename, keyfile=key_filename) + context.load_verify_locations(cafile=ca_filename) + # set class attribute to cache context for subsequent calls + self.context = context + return context + + def get_preferred_ip(self, ip_type: IPTypes) -> str: + """Returns the first IP address for the instance, according to the preference + supplied by ip_type. 
If no IP addresses with the given preference are found, + an error is raised.""" + if ip_type.value in self.ip_addrs: + return self.ip_addrs[ip_type.value] + raise CloudSQLIPTypeError( + "Cloud SQL instance does not have any IP addresses matching " + f"preference: {ip_type.value}" + ) diff --git a/build/lib/google/cloud/sql/connector/connection_name.py b/build/lib/google/cloud/sql/connector/connection_name.py new file mode 100644 index 000000000..ad5dc40fb --- /dev/null +++ b/build/lib/google/cloud/sql/connector/connection_name.py @@ -0,0 +1,75 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +import re + +# Instance connection name is the format PROJECT:REGION:INSTANCE +# Additionally, we have to support legacy "domain-scoped" projects +# (e.g. "google.com:PROJECT") +CONN_NAME_REGEX = re.compile(("([^:]+(:[^:]+)?):([^:]+):([^:]+)")) +# The domain name pattern in accordance with RFC 1035, RFC 1123 and RFC 2181. +DOMAIN_NAME_REGEX = re.compile( + r"^(?:[_a-z0-9](?:[_a-z0-9-]{0,61}[a-z0-9])?\.)+(?:[a-z](?:[a-z0-9-]{0,61}[a-z0-9])?)?$" +) + + +@dataclass +class ConnectionName: + """ConnectionName represents a Cloud SQL instance's "instance connection name". + + Takes the format "PROJECT:REGION:INSTANCE".
+ """ + + project: str + region: str + instance_name: str + domain_name: str = "" + + def __str__(self) -> str: + if self.domain_name: + return f"{self.domain_name} -> {self.project}:{self.region}:{self.instance_name}" + return f"{self.project}:{self.region}:{self.instance_name}" + + def get_connection_string(self) -> str: + """Get the instance connection string for the Cloud SQL instance.""" + return f"{self.project}:{self.region}:{self.instance_name}" + + +def _is_valid_domain(domain_name: str) -> bool: + if DOMAIN_NAME_REGEX.fullmatch(domain_name) is None: + return False + return True + + +def _parse_connection_name(connection_name: str) -> ConnectionName: + return _parse_connection_name_with_domain_name(connection_name, "") + + +def _parse_connection_name_with_domain_name( + connection_name: str, domain_name: str +) -> ConnectionName: + if CONN_NAME_REGEX.fullmatch(connection_name) is None: + raise ValueError( + "Arg `instance_connection_string` must have " + "format: PROJECT:REGION:INSTANCE, " + f"got {connection_name}." + ) + connection_name_split = CONN_NAME_REGEX.split(connection_name) + return ConnectionName( + connection_name_split[1], + connection_name_split[3], + connection_name_split[4], + domain_name, + ) diff --git a/build/lib/google/cloud/sql/connector/connector.py b/build/lib/google/cloud/sql/connector/connector.py new file mode 100644 index 000000000..6d902b1e2 --- /dev/null +++ b/build/lib/google/cloud/sql/connector/connector.py @@ -0,0 +1,705 @@ +""" +Copyright 2019 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
class Connector:
    """Configure and create secure connections to Cloud SQL.

    A Connector owns one event loop (either supplied by the caller or run on
    a private daemon thread), a lazily created Cloud SQL Admin API client and
    a map of per-instance connection info caches.
    """

    def __init__(
        self,
        ip_type: str | IPTypes = IPTypes.PUBLIC,
        enable_iam_auth: bool = False,
        timeout: int = 30,
        credentials: Optional[Credentials] = None,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        quota_project: Optional[str] = None,
        sqladmin_api_endpoint: Optional[str] = None,
        user_agent: Optional[str] = None,
        universe_domain: Optional[str] = None,
        refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND,
        resolver: type[DefaultResolver] | type[DnsResolver] = DefaultResolver,
        failover_period: int = 30,
        sql_data_endpoint: str = "sqladmin.googleapis.com",
        sql_data_stream_timeout: int = 7200,
    ) -> None:
        """Initializes a Connector instance.

        Args:
            ip_type (str | IPTypes): The default IP address type used to connect to
                Cloud SQL instances. Can be one of the following:
                IPTypes.PUBLIC ("PUBLIC"), IPTypes.PRIVATE ("PRIVATE"),
                IPTypes.PSC ("PSC"), or IPTypes.SQL_DATA ("SQL_DATA").
                Default: IPTypes.PUBLIC

            enable_iam_auth (bool): Enables automatic IAM database authentication
                (Postgres and MySQL) as the default authentication method for all
                connections.

            timeout (int): The default time limit in seconds for a connection before
                raising a TimeoutError.

            credentials (google.auth.credentials.Credentials): A credentials object
                created from the google-auth Python library to be used.
                If not specified, Application Default Credentials (ADC) are used.

            loop (asyncio.AbstractEventLoop): Event loop to run asyncio tasks, if
                not specified, defaults to creating new event loop on background
                thread.

            quota_project (str): The Project ID for an existing Google Cloud
                project. The project specified is used for quota and billing
                purposes. If not specified, defaults to the value of the
                GOOGLE_CLOUD_QUOTA_PROJECT environment variable.

            sqladmin_api_endpoint (str): Base URL to use when calling the Cloud SQL
                Admin API endpoint. Defaults to "https://sqladmin.googleapis.com",
                this argument should only be used in development.

            user_agent (str): Optional user agent string forwarded to the
                Cloud SQL Admin API client.

            universe_domain (str): The universe domain for Cloud SQL API calls.
                If not specified, the GOOGLE_CLOUD_UNIVERSE_DOMAIN environment
                variable is consulted. Default: "googleapis.com".

            refresh_strategy (str | RefreshStrategy): The default refresh strategy
                used to refresh SSL/TLS cert and instance metadata. Can be one
                of the following: RefreshStrategy.LAZY ("LAZY") or
                RefreshStrategy.BACKGROUND ("BACKGROUND").
                Default: RefreshStrategy.BACKGROUND

            resolver (DefaultResolver | DnsResolver): The class name of the
                resolver to use for resolving the Cloud SQL instance connection
                name. To resolve a DNS record to an instance connection name, use
                DnsResolver.
                Default: DefaultResolver

            failover_period (int): The time interval in seconds between each
                attempt to check if a failover has occurred for a given instance.
                Must be used with `resolver=DnsResolver` to have any effect.
                Default: 30

            sql_data_endpoint (str): Endpoint handed to the SqlDataClient that
                is created for connections using IPTypes.SQL_DATA.
                Default: "sqladmin.googleapis.com"

            sql_data_stream_timeout (int): Timeout handed to the SqlDataClient
                for SQL Data Service tunnels (presumably seconds; confirm
                against SqlDataClient). Default: 7200

        Raises:
            TypeError: `credentials` is not a google.auth Credentials object.
            ValueError: The configured universe domain does not match the
                universe domain of the credentials.
        """
        # if refresh_strategy is str, convert to RefreshStrategy enum
        if isinstance(refresh_strategy, str):
            refresh_strategy = RefreshStrategy._from_str(refresh_strategy)
        self._refresh_strategy = refresh_strategy
        # if event loop is given, use for background tasks
        if loop:
            self._loop: asyncio.AbstractEventLoop = loop
            self._thread: Optional[Thread] = None
            # if lazy refresh is specified we should lazy init keys
            if self._refresh_strategy == RefreshStrategy.LAZY:
                self._keys: Optional[asyncio.Future] = None
            else:
                self._keys = loop.create_task(generate_keys())
        # if no event loop is given, spin up new loop in background thread
        else:
            self._loop = asyncio.new_event_loop()
            self._thread = Thread(target=self._loop.run_forever, daemon=True)
            self._thread.start()
            # if lazy refresh is specified we should lazy init keys
            if self._refresh_strategy == RefreshStrategy.LAZY:
                self._keys = None
            else:
                self._keys = asyncio.wrap_future(
                    asyncio.run_coroutine_threadsafe(generate_keys(), self._loop),
                    loop=self._loop,
                )
        # initialize dict to store caches, key is a tuple consisting of instance
        # connection name string and enable_iam_auth boolean flag
        self._cache: dict[tuple[str, bool], MonitoredCache] = {}
        self._client: Optional[CloudSQLClient] = None
        self._closed: bool = False

        # initialize credentials
        scopes = ["https://www.googleapis.com/auth/sqlservice.admin"]
        if credentials:
            # verify custom credentials are proper type
            # and at least base class of google.auth.credentials
            if not isinstance(credentials, Credentials):
                raise TypeError(
                    "credentials must be of type google.auth.credentials.Credentials,"
                    f" got {type(credentials)}"
                )
            self._credentials = with_scopes_if_required(credentials, scopes=scopes)
        # otherwise use application default credentials
        else:
            self._credentials, _ = google.auth.default(scopes=scopes)
        # set default params for connections
        self._timeout = timeout
        self._enable_iam_auth = enable_iam_auth
        self._user_agent = user_agent
        self._resolver = resolver()
        self._failover_period = failover_period
        # if ip_type is str, convert to IPTypes enum
        if isinstance(ip_type, str):
            ip_type = IPTypes._from_str(ip_type)
        self._ip_type = ip_type
        # check for quota project arg and then env var
        if quota_project:
            self._quota_project = quota_project
        else:
            self._quota_project = os.environ.get("GOOGLE_CLOUD_QUOTA_PROJECT")  # type: ignore
        # check for universe domain arg and then env var
        if universe_domain:
            self._universe_domain = universe_domain
        else:
            self._universe_domain = os.environ.get("GOOGLE_CLOUD_UNIVERSE_DOMAIN")  # type: ignore
        # construct service endpoint for Cloud SQL Admin API calls
        if not sqladmin_api_endpoint:
            self._sqladmin_api_endpoint = (
                _DEFAULT_SCHEME
                + _SQLADMIN_HOST_TEMPLATE.format(universe_domain=self.universe_domain)
            )
        # otherwise if endpoint override is passed in use it
        else:
            self._sqladmin_api_endpoint = sqladmin_api_endpoint

        # validate that the universe domain of the credentials matches the
        # universe domain of the service endpoint
        if self._credentials.universe_domain != self.universe_domain:
            raise ValueError(
                f"The configured universe domain ({self.universe_domain}) does "
                "not match the universe domain found in the credentials "
                f"({self._credentials.universe_domain}). If you haven't "
                "configured the universe domain explicitly, `googleapis.com` "
                "is the default."
            )
        # NOTE(review): the default SQL Data endpoint is fixed to
        # googleapis.com and does not follow `universe_domain` -- confirm
        # whether that is intended for non-default universes.
        self._sql_data_endpoint = sql_data_endpoint
        self._sql_data_stream_timeout = sql_data_stream_timeout
        # instance connection names handed to `on_fallback` by SqlDataClient
        self._sql_data_fallback_cache: set[str] = set()
        # every SqlDataClient created by this Connector; closed in close_async().
        # NOTE(review): one client is appended per SQL_DATA connection and none
        # is removed before close, so this list grows on long-lived Connectors.
        self._sqldata_clients: list[SqlDataClient] = []

    @property
    def universe_domain(self) -> str:
        """The configured universe domain, or "googleapis.com" if unset."""
        return self._universe_domain or _DEFAULT_UNIVERSE_DOMAIN

    def connect(
        self, instance_connection_string: str, driver: str, **kwargs: Any
    ) -> Any:
        """Connect to a Cloud SQL instance.

        Prepares and returns a database connection object connected to a Cloud
        SQL instance using SSL/TLS. Starts a background refresh to periodically
        retrieve up-to-date ephemeral certificate and instance metadata.

        Args:
            instance_connection_string (str): The instance connection name of the
                Cloud SQL instance to connect to. Takes the form of
                "project-id:region:instance-name"

                Example: "my-project:us-central1:my-instance"

            driver (str): A string representing the database driver to connect
                with. Supported drivers are pymysql, pg8000, and pytds.

            **kwargs: Any driver-specific arguments to pass to the underlying
                driver .connect call.

        Returns:
            A DB-API connection to the specified Cloud SQL instance.

        Raises:
            ClosedConnectorError: The Connector has already been closed.
        """
        # connect runs sync database connections on background thread.
        # Async database connections should call 'connect_async' directly to
        # avoid hanging indefinitely.

        # Check if the connector is closed before attempting to connect.
        if self._closed:
            raise ClosedConnectorError(
                "Connection attempt failed because the connector has already been closed."
            )
        connect_future = asyncio.run_coroutine_threadsafe(
            self.connect_async(instance_connection_string, driver, **kwargs),
            self._loop,
        )
        return connect_future.result()

    def _get_or_create_cache(
        self,
        conn_name: ConnectionName,
        enable_iam_auth: bool,
    ) -> MonitoredCache:
        """Return the open cache for an instance, creating one if needed.

        Caches are keyed by (connection name, enable_iam_auth). A cached entry
        that reports `closed` is replaced by a fresh one.
        """
        assert self._client is not None, "client must be initialized before creating cache"
        assert self._keys is not None, "keys must be initialized before creating cache"
        key = (str(conn_name), enable_iam_auth)
        existing = self._cache.get(key)
        if existing is not None and not existing.closed:
            return existing

        if self._refresh_strategy == RefreshStrategy.LAZY:
            logger.debug(
                f"['{conn_name}']: Refresh strategy is set to lazy refresh"
            )
            cache: Union[LazyRefreshCache, RefreshAheadCache] = LazyRefreshCache(
                conn_name,
                self._client,
                self._keys,
                enable_iam_auth,
            )
        else:
            logger.debug(
                f"['{conn_name}']: Refresh strategy is set to backgound refresh"
            )
            cache = RefreshAheadCache(
                conn_name,
                self._client,
                self._keys,
                enable_iam_auth,
            )
        # wrap cache as a MonitoredCache
        monitored_cache = MonitoredCache(
            cache,
            self._failover_period,
            self._resolver,
        )
        logger.debug(f"['{conn_name}']: Connection info added to cache")
        self._cache[key] = monitored_cache
        return monitored_cache

    async def connect_async(
        self, instance_connection_string: str, driver: str, **kwargs: Any
    ) -> Any:
        """Connect asynchronously to a Cloud SQL instance.

        Prepares and returns a database connection object connected to a Cloud
        SQL instance using SSL/TLS. Schedules a refresh to periodically
        retrieve up-to-date ephemeral certificate and instance metadata. Async
        version of Connector.connect.

        Args:
            instance_connection_string (str): The instance connection name of the
                Cloud SQL instance to connect to. Takes the form of
                "project-id:region:instance-name"

                Example: "my-project:us-central1:my-instance"

            driver (str): A string representing the database driver to connect
                with. Supported drivers are pymysql, asyncpg, pg8000, and pytds.

            **kwargs: Any driver-specific arguments to pass to the underlying
                driver .connect call. `enable_iam_auth` and `ip_type` override
                the Connector defaults for this call; `host`, `ssl` and `port`
                are discarded because the Connector supplies them.

        Returns:
            A DB-API connection to the specified Cloud SQL instance.

        Raises:
            ClosedConnectorError: The Connector has already been closed.
            ConnectorLoopError: Called from an event loop other than the one
                the Connector was initialized with.
            KeyError: Unsupported database driver. Must be one of pymysql,
                asyncpg, pg8000, and pytds.
        """
        if self._closed:
            raise ClosedConnectorError(
                "Connection attempt failed because the connector has already been closed."
            )
        # check if event loop is running in current thread
        if self._loop != asyncio.get_running_loop():
            raise ConnectorLoopError(
                "Running event loop does not match 'connector._loop'. "
                "Connector.connect_async() must be called from the event loop "
                "the Connector was initialized with. If you need to connect "
                "across event loops, please use a new Connector object."
            )

        if self._keys is None:
            self._keys = asyncio.create_task(generate_keys())
        if self._client is None:
            # lazy init client as it has to be initialized in async context
            self._client = CloudSQLClient(
                self._sqladmin_api_endpoint,
                self._quota_project,
                self._credentials,
                user_agent=self._user_agent,
                driver=driver,
            )
        enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth)
        ip_type = kwargs.pop("ip_type", self._ip_type)
        if isinstance(ip_type, str):
            ip_type = IPTypes._from_str(ip_type)

        conn_name = await self._resolver.resolve(instance_connection_string)

        # The SQL Data Service path defers cache creation until a fallback to
        # a standard IP connection is required, so no cache exists yet there.
        monitored_cache: Optional[MonitoredCache] = None
        if ip_type != IPTypes.SQL_DATA:
            monitored_cache = self._get_or_create_cache(conn_name, enable_iam_auth)

        connect_func = {
            "pymysql": pymysql.connect,
            "pg8000": pg8000.connect,
            "asyncpg": asyncpg.connect,
            "pytds": pytds.connect,
        }

        # only accept supported database drivers
        try:
            connector: Callable = connect_func[driver]  # type: ignore
        except KeyError:
            raise KeyError(f"Driver '{driver}' is not supported.") from None
        kwargs.setdefault("timeout", self._timeout)

        # Host and ssl options come from the certificates and metadata, so we don't
        # want the user to specify them.
        kwargs.pop("host", None)
        kwargs.pop("ssl", None)
        kwargs.pop("port", None)

        # attempt to establish connection
        try:
            if monitored_cache is None:
                return await self._connect_via_sql_data(
                    conn_name,
                    driver,
                    connector,
                    enable_iam_auth,
                    instance_connection_string,
                    kwargs,
                )
            return await self._connect_via_ip(
                monitored_cache,
                conn_name,
                driver,
                connector,
                ip_type,
                enable_iam_auth,
                instance_connection_string,
                kwargs,
            )
        except Exception:
            # with any exception, we attempt a force refresh, then throw the error
            cache = self._cache.get((str(conn_name), enable_iam_auth))
            if cache:
                await cache.force_refresh()
            raise

    @staticmethod
    def _format_iam_user(
        engine: str, instance_connection_string: str, kwargs: dict[str, Any]
    ) -> None:
        """Rewrite kwargs["user"] in place for automatic IAM database authn."""
        formatted_user = format_database_user(engine, kwargs["user"])
        if formatted_user != kwargs["user"]:
            logger.debug(
                f"['{instance_connection_string}']: Truncated IAM database username from {kwargs['user']} to {formatted_user}"
            )
            kwargs["user"] = formatted_user

    async def _connect_via_sql_data(
        self,
        conn_name: ConnectionName,
        driver: str,
        connector: Callable,
        enable_iam_auth: bool,
        instance_connection_string: str,
        kwargs: dict[str, Any],
    ) -> Any:
        """Open a SQL Data Service tunnel and connect the driver through it."""
        logger.debug(f"['{conn_name}']: Connecting via SQL Data Service tunnel")
        if enable_iam_auth:
            # DriverMapping member names are upper case ("PG8000"), while
            # driver names are lower case ("pg8000"); indexing with the raw
            # driver name raises KeyError.
            # NOTE(review): this passes the bare engine name ("POSTGRES")
            # where the standard path passes the full database version --
            # confirm format_database_user accepts both.
            self._format_iam_user(
                DriverMapping[driver.upper()].value,
                instance_connection_string,
                kwargs,
            )

        sqldata_client = SqlDataClient(
            endpoint=self._sql_data_endpoint,
            credentials=self._credentials,
            quota_project=self._quota_project,
            timeout=self._sql_data_stream_timeout,
        )
        self._sqldata_clients.append(sqldata_client)

        def on_fallback(name: str) -> None:
            self._sql_data_fallback_cache.add(name)

        def is_fallback_cached(name: str) -> bool:
            return name in self._sql_data_fallback_cache

        # Defer cache creation and connect_info call
        async def get_conn_info() -> Any:
            cache = self._get_or_create_cache(conn_name, enable_iam_auth)
            return await cache.connect_info()

        tunnel_port = await sqldata_client.connect_tunnel(
            instance_connection_name=str(conn_name),
            region=conn_name.region,
            project=conn_name.project,
            get_conn_info=get_conn_info,
            enable_iam_auth=enable_iam_auth,
            on_fallback=on_fallback,
            is_fallback_cached=is_fallback_cached,
        )

        if driver in ASYNC_DRIVERS:
            return await connector(
                "127.0.0.1",
                None,
                port=tunnel_port,
                **kwargs,
            )

        raw_sock = socket.create_connection(("127.0.0.1", tunnel_port))
        fallback_sock = FallbackSocket(fileno=raw_sock.detach())

        if conn_name.domain_name:
            # A cache only exists here if the tunnel fell back to a standard
            # connection (get_conn_info ran) or an earlier non-SQL_DATA
            # connect created one. Without a cache there is nothing to
            # register the socket with, so it is simply not tracked.
            tracked = self._cache.get((str(conn_name), enable_iam_auth))
            if tracked is not None:
                tracked.sockets.append(fallback_sock)

        connect_partial = partial(
            connector,
            "127.0.0.1",
            fallback_sock,
            **kwargs,
        )
        return await self._loop.run_in_executor(None, connect_partial)

    async def _connect_via_ip(
        self,
        monitored_cache: MonitoredCache,
        conn_name: ConnectionName,
        driver: str,
        connector: Callable,
        ip_type: IPTypes,
        enable_iam_auth: bool,
        instance_connection_string: str,
        kwargs: dict[str, Any],
    ) -> Any:
        """Connect directly to the instance using cached metadata and certs."""
        try:
            conn_info = await monitored_cache.connect_info()
            # validate driver matches intended database engine
            DriverMapping.validate_engine(driver, conn_info.database_version)
            ip_address = conn_info.get_preferred_ip(ip_type)
        except Exception:
            # with an error from Cloud SQL Admin API call or IP type, invalidate
            # the cache and re-raise the error
            await self._remove_cached(str(conn_name), enable_iam_auth)
            raise

        # If the connector is configured with a custom DNS name, attempt to use
        # that DNS name to connect to the instance. Fall back to the metadata IP
        # address if the DNS name does not resolve to an IP address.
        domain_name = conn_info.conn_name.domain_name
        if domain_name and isinstance(self._resolver, DnsResolver):
            try:
                ips = await self._resolver.resolve_a_record(domain_name)
                if ips:
                    ip_address = ips[0]
                    logger.debug(
                        f"['{instance_connection_string}']: Custom DNS name "
                        f"'{conn_info.conn_name.domain_name}' resolved to '{ip_address}', "
                        "using it to connect"
                    )
                else:
                    logger.debug(
                        f"['{instance_connection_string}']: Custom DNS name "
                        f"'{conn_info.conn_name.domain_name}' resolved but returned no "
                        f"entries, using '{ip_address}' from instance metadata"
                    )
            except Exception as e:
                # best effort: keep the metadata IP when DNS lookup fails
                logger.debug(
                    f"['{instance_connection_string}']: Custom DNS name "
                    f"'{conn_info.conn_name.domain_name}' did not resolve to an IP "
                    f"address: {e}, using '{ip_address}' from instance metadata"
                )

        logger.debug(
            f"['{conn_info.conn_name}']: Connecting to {ip_address}:{SERVER_PROXY_PORT}"
        )
        # format `user` param for automatic IAM database authn
        if enable_iam_auth:
            self._format_iam_user(
                conn_info.database_version, instance_connection_string, kwargs
            )

        if driver in ASYNC_DRIVERS:
            return await connector(
                ip_address,
                await conn_info.create_ssl_context(enable_iam_auth),
                **kwargs,
            )
        ctx = await conn_info.create_ssl_context(enable_iam_auth)
        ssl_sock = ctx.wrap_socket(
            socket.create_connection((ip_address, SERVER_PROXY_PORT)),
            server_hostname=ip_address,
        )
        if conn_info.conn_name.domain_name:
            monitored_cache.sockets.append(ssl_sock)
        connect_partial = partial(
            connector,
            ip_address,
            ssl_sock,
            **kwargs,
        )
        return await self._loop.run_in_executor(None, connect_partial)

    async def _remove_cached(
        self, instance_connection_string: str, enable_iam_auth: bool
    ) -> None:
        """Stops all background refreshes and deletes the connection
        info cache from the map of caches.

        Does nothing if the entry is already gone, which happens when two
        concurrent connection attempts fail for the same instance; raising
        KeyError here would hide the original connection error.
        """
        logger.debug(
            f"['{instance_connection_string}']: Removing connection info from cache"
        )
        # remove cache from stored caches and close it
        cache = self._cache.pop((instance_connection_string, enable_iam_auth), None)
        if cache is not None:
            await cache.close()

    def __enter__(self) -> Any:
        """Enter context manager by returning Connector object"""
        return self

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        """Exit context manager by closing Connector"""
        self.close()

    async def __aenter__(self) -> Any:
        """Enter async context manager by returning Connector object"""
        return self

    async def __aexit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        """Exit async context manager by closing Connector"""
        await self.close_async()

    def close(self) -> None:
        """Close Connector by stopping tasks and releasing resources."""
        if self._loop.is_running():
            close_future = asyncio.run_coroutine_threadsafe(
                self.close_async(), loop=self._loop
            )
            try:
                # Will attempt to safely shut down tasks for 3s
                close_future.result(timeout=3)
            except Exception as e:
                # best effort: still stop the background loop below
                logger.error(f"Error during close_async: {e}")
        # if background thread exists for Connector, clean it up
        if self._thread:
            if self._loop.is_running():
                # stop event loop running in background thread
                self._loop.call_soon_threadsafe(self._loop.stop)
            # wait for thread to finish closing (i.e. loop to stop)
            self._thread.join()

    async def close_async(self) -> None:
        """Helper function to cancel the cache's tasks, close the SQL Data
        clients and close the Admin API client."""
        self._closed = True
        if self._client:
            await self._client.close()
        await asyncio.gather(
            *[cache.close() for cache in self._cache.values()],
            *[client.close() for client in self._sqldata_clients],
        )
async def create_async_connector(
    ip_type: str | IPTypes = IPTypes.PUBLIC,
    enable_iam_auth: bool = False,
    timeout: int = 30,
    credentials: Optional[Credentials] = None,
    loop: Optional[asyncio.AbstractEventLoop] = None,
    quota_project: Optional[str] = None,
    sqladmin_api_endpoint: Optional[str] = None,
    user_agent: Optional[str] = None,
    universe_domain: Optional[str] = None,
    refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND,
    resolver: type[DefaultResolver] | type[DnsResolver] = DefaultResolver,
    failover_period: int = 30,
    sql_data_endpoint: str = "sqladmin.googleapis.com",
    sql_data_stream_timeout: int = 7200,
) -> Connector:
    """Helper function to create Connector object for asyncio connections.

    Force use of Connector in an asyncio context. Auto-detect and use current
    thread's running event loop.

    Args:
        ip_type (str | IPTypes): The default IP address type used to connect to
            Cloud SQL instances. Can be one of the following:
            IPTypes.PUBLIC ("PUBLIC"), IPTypes.PRIVATE ("PRIVATE"),
            IPTypes.PSC ("PSC"), or IPTypes.SQL_DATA ("SQL_DATA").
            Default: IPTypes.PUBLIC

        enable_iam_auth (bool): Enables automatic IAM database authentication
            (Postgres and MySQL) as the default authentication method for all
            connections.

        timeout (int): The default time limit in seconds for a connection before
            raising a TimeoutError.

        credentials (google.auth.credentials.Credentials): A credentials object
            created from the google-auth Python library to be used.
            If not specified, Application Default Credentials (ADC) are used.

        loop (asyncio.AbstractEventLoop): Event loop to run asyncio tasks, if
            not specified, defaults to the currently running event loop.

        quota_project (str): The Project ID for an existing Google Cloud
            project. The project specified is used for quota and billing
            purposes. If not specified, defaults to project sourced from
            environment.

        sqladmin_api_endpoint (str): Base URL to use when calling the Cloud SQL
            Admin API endpoint. Defaults to "https://sqladmin.googleapis.com",
            this argument should only be used in development.

        user_agent (str): Optional user agent string forwarded to the
            Connector.

        universe_domain (str): The universe domain for Cloud SQL API calls.
            Default: "googleapis.com".

        refresh_strategy (str | RefreshStrategy): The default refresh strategy
            used to refresh SSL/TLS cert and instance metadata. Can be one
            of the following: RefreshStrategy.LAZY ("LAZY") or
            RefreshStrategy.BACKGROUND ("BACKGROUND").
            Default: RefreshStrategy.BACKGROUND

        resolver (DefaultResolver | DnsResolver): The class name of the
            resolver to use for resolving the Cloud SQL instance connection
            name. To resolve a DNS record to an instance connection name, use
            DnsResolver.
            Default: DefaultResolver

        failover_period (int): The time interval in seconds between each
            attempt to check if a failover has occurred for a given instance.
            Must be used with `resolver=DnsResolver` to have any effect.
            Default: 30

        sql_data_endpoint (str): Endpoint forwarded to the Connector for
            connections using IPTypes.SQL_DATA.
            Default: "sqladmin.googleapis.com"

        sql_data_stream_timeout (int): SQL Data Service stream timeout
            forwarded to the Connector. Default: 7200

    Returns:
        A Connector instance configured with running event loop.

    Raises:
        RuntimeError: No `loop` was given and no event loop is running in the
            current thread (raised by asyncio.get_running_loop).
    """
    # if no loop given, automatically detect running event loop
    if loop is None:
        loop = asyncio.get_running_loop()
    return Connector(
        ip_type=ip_type,
        enable_iam_auth=enable_iam_auth,
        timeout=timeout,
        credentials=credentials,
        loop=loop,
        quota_project=quota_project,
        sqladmin_api_endpoint=sqladmin_api_endpoint,
        user_agent=user_agent,
        universe_domain=universe_domain,
        refresh_strategy=refresh_strategy,
        resolver=resolver,
        failover_period=failover_period,
        # forwarded so async users can configure the SQL Data Service path
        # exactly like users constructing Connector directly
        sql_data_endpoint=sql_data_endpoint,
        sql_data_stream_timeout=sql_data_stream_timeout,
    )
# TODO: Replace Enum with StrEnum when Python 3.11 is minimum supported version
class RefreshStrategy(Enum):
    """Strategies for refreshing certificates and instance metadata."""

    LAZY = "LAZY"
    BACKGROUND = "BACKGROUND"

    @classmethod
    def _missing_(cls, value: object) -> None:
        # Called by Enum when `value` matches no member; always rejects it.
        allowed = ", ".join(repr(member.value) for member in cls)
        raise ValueError(
            f"Incorrect value for refresh_strategy, got '{value}'. Want one of: "
            f"{allowed}."
        )

    @classmethod
    def _from_str(cls, refresh_strategy: str) -> RefreshStrategy:
        """Look up a RefreshStrategy from its case-insensitive name."""
        normalized = refresh_strategy.upper()
        return cls(normalized)


class IPTypes(Enum):
    """Kinds of address or transport usable to reach an instance."""

    PUBLIC = "PRIMARY"
    PRIVATE = "PRIVATE"
    PSC = "PSC"
    SQL_DATA = "SQL_DATA"

    @classmethod
    def _missing_(cls, value: object) -> None:
        # Called by Enum when `value` matches no member; always rejects it.
        allowed = ", ".join(repr(member.value) for member in cls)
        raise ValueError(
            f"Incorrect value for ip_type, got '{value}'. Want one of: "
            f"{allowed}, 'PUBLIC'."
        )

    @classmethod
    def _from_str(cls, ip_type_str: str) -> IPTypes:
        """Look up an IPTypes member from a case-insensitive string.

        "PUBLIC" is accepted as an alias for the "PRIMARY" value.
        """
        normalized = ip_type_str.upper()
        if normalized == "PUBLIC":
            normalized = "PRIMARY"
        return cls(normalized)


class DriverMapping(Enum):
    """Maps a given database driver to its corresponding database engine."""

    ASYNCPG = "POSTGRES"
    PG8000 = "POSTGRES"
    PYMYSQL = "MYSQL"
    PYTDS = "SQLSERVER"

    @staticmethod
    def validate_engine(driver: str, engine_version: str) -> None:
        """Check that a driver can talk to the given database engine.

        Args:
            driver (str): Database driver being used. (i.e. "pg8000")
            engine_version (str): Database engine version. (i.e. "POSTGRES_16")

        Raises:
            IncompatibleDriverError: If the given driver is not compatible with
                the given engine.
        """
        expected_engine = DriverMapping[driver.upper()].value
        if engine_version.startswith(expected_engine):
            return
        raise IncompatibleDriverError(
            f"Database driver '{driver}' is incompatible with database "
            f"version '{engine_version}'. Given driver can "
            f"only be used with Cloud SQL {expected_engine} databases."
        )
class ConnectorLoopError(Exception):
    """Raised when Connector.connect is called with Connector._loop
    in an invalid state (event loop in current thread)."""


class TLSVersionError(Exception):
    """Raised when the required TLS protocol version is not supported."""


class CloudSQLIPTypeError(Exception):
    """Raised when IP address for the preferred IP type is not found."""


class PlatformNotSupportedError(Exception):
    """Raised when a feature is not supported on the current platform."""


class AutoIAMAuthNotSupported(Exception):
    """Raised when Automatic IAM Authentication is not supported with the
    database engine version."""


class RefreshNotValidError(Exception):
    """Raised when the task returned from refresh is not valid."""


class IncompatibleDriverError(Exception):
    """Raised when the database driver given is for the wrong database
    engine. (i.e. asyncpg for a MySQL database)"""


class DnsResolutionError(Exception):
    """Raised when an instance connection name can not be resolved from a
    DNS record."""


class CacheClosedError(Exception):
    """Raised when a ConnectionInfoCache can not be accessed after it is
    closed."""


class ClosedConnectorError(Exception):
    """Raised when a Connector is closed and connect method is called on
    it."""
+""" + +from __future__ import annotations + +import asyncio +from datetime import datetime +from datetime import timedelta +from datetime import timezone +import logging + +from google.cloud.sql.connector.client import CloudSQLClient +from google.cloud.sql.connector.connection_info import ConnectionInfo +from google.cloud.sql.connector.connection_info import ConnectionInfoCache +from google.cloud.sql.connector.connection_name import ConnectionName +from google.cloud.sql.connector.exceptions import RefreshNotValidError +from google.cloud.sql.connector.rate_limiter import AsyncRateLimiter +from google.cloud.sql.connector.refresh_utils import _is_valid +from google.cloud.sql.connector.refresh_utils import _seconds_until_refresh + +logger = logging.getLogger(name=__name__) + +APPLICATION_NAME = "cloud-sql-python-connector" + + +class RefreshAheadCache(ConnectionInfoCache): + """Cache that refreshes connection info in the background prior to expiration. + + Background tasks are used to schedule refresh attempts to get a new + ephemeral certificate and Cloud SQL metadata (IP addresses, etc.) ahead of + expiration. + """ + + def __init__( + self, + conn_name: ConnectionName, + client: CloudSQLClient, + keys: asyncio.Future, + enable_iam_auth: bool = False, + ) -> None: + """Initializes a RefreshAheadCache instance. + + Args: + conn_name (ConnectionName): The Cloud SQL instance's + connection name. + client (CloudSQLClient): The Cloud SQL Client instance. + keys (asyncio.Future): A future to the client's public-private key + pair. + enable_iam_auth (bool): Enables automatic IAM database authentication + (Postgres and MySQL) as the default authentication method for all + connections. 
+ """ + self._conn_name = conn_name + + self._enable_iam_auth = enable_iam_auth + self._keys = keys + self._client = client + self._refresh_rate_limiter = AsyncRateLimiter( + max_capacity=2, + rate=1 / 30, + ) + self._refresh_in_progress = asyncio.locks.Event() + self._current: asyncio.Task = self._schedule_refresh(0) + self._next: asyncio.Task = self._current + self._closed = False + + @property + def conn_name(self) -> ConnectionName: + return self._conn_name + + @property + def closed(self) -> bool: + return self._closed + + async def force_refresh(self) -> None: + """ + Forces a new refresh attempt immediately to be used for future connection attempts. + """ + # if next refresh is not already in progress, cancel it and schedule new one immediately + if not self._refresh_in_progress.is_set(): + self._next.cancel() + self._next = self._schedule_refresh(0) + # block all sequential connection attempts on the next refresh result if current is invalid + if not await _is_valid(self._current): + self._current = self._next + + async def _perform_refresh(self) -> ConnectionInfo: + """Retrieves instance metadata and ephemeral certificate from the + Cloud SQL Instance. + + Returns: + A ConnectionInfo instance containing a string representing the + ephemeral certificate, a dict containing the instances IP adresses, + a string representing a PEM-encoded private key and a string + representing a PEM-encoded certificate authority. 
+ """ + self._refresh_in_progress.set() + logger.debug( + f"['{self._conn_name}']: Connection info refresh operation started" + ) + + try: + await self._refresh_rate_limiter.acquire() + connection_info = await self._client.get_connection_info( + self._conn_name, + self._keys, + self._enable_iam_auth, + ) + logger.debug( + f"['{self._conn_name}']: Connection info refresh operation complete" + ) + logger.debug( + f"['{self._conn_name}']: Current certificate " + f"expiration = {connection_info.expiration.isoformat()}" + ) + + except Exception as e: + logger.debug( + f"['{self._conn_name}']: Connection info " + f"refresh operation failed: {str(e)}" + ) + raise + + finally: + self._refresh_in_progress.clear() + return connection_info + + def _schedule_refresh(self, delay: int) -> asyncio.Task: + """ + Schedule task to sleep and then perform refresh to get ConnectionInfo. + + Args: + delay (int): Time in seconds to sleep before performing a refresh. + + Returns: + An asyncio.Task representing the scheduled refresh. + """ + + async def _refresh_task(self: RefreshAheadCache, delay: int) -> ConnectionInfo: + """ + A coroutine that sleeps for the specified amount of time before + running _perform_refresh. + """ + refresh_task: asyncio.Task + try: + if delay > 0: + await asyncio.sleep(delay) + refresh_task = asyncio.create_task(self._perform_refresh()) + refresh_data = await refresh_task + # check that refresh is valid + if not await _is_valid(refresh_task): + raise RefreshNotValidError( + f"['{self._conn_name}']: Invalid refresh operation. Certficate appears to be expired." + ) + except asyncio.CancelledError: + logger.debug( + f"['{self._conn_name}']: Scheduled refresh" " operation cancelled" + ) + raise + # bad refresh attempt + except Exception as e: + logger.exception( + f"['{self._conn_name}']: " + "An error occurred while performing refresh. 
" + "Scheduling another refresh attempt immediately", + exc_info=e, + ) + # check if current metadata is invalid (expired), + # don't want to replace valid metadata with invalid refresh + if not await _is_valid(self._current): + self._current = refresh_task + # schedule new refresh attempt immediately + self._next = self._schedule_refresh(0) + raise + # if valid refresh, replace current with valid metadata and schedule next refresh + self._current = refresh_task + # calculate refresh delay based on certificate expiration + delay = _seconds_until_refresh(refresh_data.expiration) + logger.debug( + f"['{self._conn_name}']: Connection info refresh" + " operation scheduled for " + f"{(datetime.now(timezone.utc) + timedelta(seconds=delay)).isoformat(timespec='seconds')} " + f"(now + {timedelta(seconds=delay)})" + ) + self._next = self._schedule_refresh(delay) + + return refresh_data + + # schedule refresh task and return it + scheduled_task = asyncio.create_task(_refresh_task(self, delay)) + return scheduled_task + + async def connect_info(self) -> ConnectionInfo: + """Retrieves ConnectionInfo instance for establishing a secure + connection to the Cloud SQL instance. + """ + return await self._current + + async def close(self) -> None: + """Cleanup function to make sure tasks have finished to have a + graceful exit. 
class LazyRefreshCache(ConnectionInfoCache):
    """Connection info cache that only refreshes on demand.

    No background work is performed. A refresh happens inside
    ``connect_info()`` when nothing is cached yet, when ``force_refresh()``
    was called, or when the cached certificate is within the refresh buffer
    of its expiration (or already past it).

    This is the recommended option for serverless environments.
    """

    def __init__(
        self,
        conn_name: ConnectionName,
        client: CloudSQLClient,
        keys: asyncio.Future,
        enable_iam_auth: bool = False,
    ) -> None:
        """Initializes a LazyRefreshCache instance.

        Args:
            conn_name (ConnectionName): The Cloud SQL instance's
                connection name.
            client (CloudSQLClient): The Cloud SQL Client instance.
            keys (asyncio.Future): A future to the client's public-private key
                pair.
            enable_iam_auth (bool): Enables automatic IAM database authentication
                (Postgres and MySQL) as the default authentication method for all
                connections.
        """
        # inputs forwarded to the client on every refresh
        self._conn_name = conn_name
        self._client = client
        self._keys = keys
        self._enable_iam_auth = enable_iam_auth
        # cache state; reads and writes happen while holding self._lock
        self._cached: Optional[ConnectionInfo] = None
        self._needs_refresh = False
        self._closed = False
        self._lock = asyncio.Lock()

    @property
    def conn_name(self) -> ConnectionName:
        """Connection name this cache was created for."""
        return self._conn_name

    @property
    def closed(self) -> bool:
        """True once close() has been called."""
        return self._closed

    def _cache_is_usable(self) -> bool:
        """Return True when the cached info may be served without a refresh."""
        cached = self._cached
        if not cached or self._needs_refresh:
            return False
        # Pad expiration with a buffer to give the client plenty of time to
        # establish a connection to the server with the certificate.
        refresh_deadline = cached.expiration - timedelta(seconds=_refresh_buffer)
        return datetime.now(timezone.utc) < refresh_deadline

    async def force_refresh(self) -> None:
        """
        Marks the cached info as stale so that the next call to
        connect_info() retrieves a fresh ConnectionInfo instance.
        """
        async with self._lock:
            self._needs_refresh = True

    async def connect_info(self) -> ConnectionInfo:
        """Retrieves ConnectionInfo instance for establishing a secure
        connection to the Cloud SQL instance.

        Serves the cached instance while it is usable; otherwise asks the
        client for new connection info, caches it and returns it. Errors
        raised by the client are logged and re-raised.
        """
        tag = f"['{self._conn_name}']: "
        async with self._lock:
            if self._cache_is_usable():
                logger.debug(tag + "Connection info is still valid, using cached info")
                return self._cached

            logger.debug(tag + "Connection info refresh operation started")
            try:
                fresh_info = await self._client.get_connection_info(
                    self._conn_name,
                    self._keys,
                    self._enable_iam_auth,
                )
            except Exception as e:
                logger.debug(tag + f"Connection info refresh operation failed: {e!s}")
                raise

            logger.debug(tag + "Connection info refresh operation completed successfully")
            logger.debug(
                tag + f"Current certificate expiration = {fresh_info.expiration!s}"
            )
            self._needs_refresh = False
            self._cached = fresh_info
            return fresh_info

    async def close(self) -> None:
        """Only records the closed state; there is nothing to tear down.

        Provided purely for a consistent interface with other cache types.
        """
        self._closed = True
class MonitoredCache(ConnectionInfoCache):
    """Wrapper around a connection info cache that watches for DNS failover.

    All cache operations are delegated to the wrapped cache. When the
    instance was configured through a domain name and a positive failover
    period is given, a background ticker periodically re-resolves the domain
    name; if it resolves to a different instance the cache and all tracked
    sockets are closed.
    """

    def __init__(
        self,
        cache: Union[RefreshAheadCache, LazyRefreshCache],
        failover_period: int,
        resolver: Union[DefaultResolver, DnsResolver],
    ) -> None:
        """Initializes a MonitoredCache instance.

        Must be called with a running event loop when polling is enabled,
        because the polling task is created here.

        Args:
            cache: The underlying connection info cache to delegate to.
            failover_period (int): Seconds between domain name checks. Values
                less than or equal to zero disable polling.
            resolver: Resolver used to turn the domain name back into a
                connection name.
        """
        self.resolver = resolver
        self.cache = cache
        self.domain_name_ticker: Optional[asyncio.Task] = None
        # sockets handed out for this instance; closed on failover / close()
        self.sockets: list[socket.socket] = []

        # If domain name is configured for instance and failover period is set,
        # poll for DNS record changes.
        if self.cache.conn_name.domain_name and failover_period > 0:
            self.domain_name_ticker = asyncio.create_task(
                ticker(failover_period, self._check_domain_name)
            )
            logger.debug(
                f"['{self.cache.conn_name}']: Configured polling of domain "
                f"name with failover period of {failover_period} seconds."
            )

    @property
    def closed(self) -> bool:
        """True when the underlying cache reports itself as closed."""
        return self.cache.closed

    def _purge_closed_sockets(self) -> None:
        """Remove closed sockets from monitored cache.

        If a socket is closed by the database driver we should remove it from
        list of sockets.
        """
        # fileno() returns -1 for a closed socket, which is used here as the
        # "socket closed" signal.
        self.sockets = [sock for sock in self.sockets if sock.fileno() != -1]

    async def _check_domain_name(self) -> None:
        """Re-resolve the domain name and close everything if it moved.

        Never raises: any error (including one raised while closing) is
        logged at debug level so that polling continues.
        """
        # remove any closed connections from cache
        self._purge_closed_sockets()
        try:
            # Resolve domain name and see if Cloud SQL instance connection name
            # has changed. If it has, close all connections.
            new_conn_name = await self.resolver.resolve(
                self.cache.conn_name.domain_name
            )
            if new_conn_name != self.cache.conn_name:
                logger.debug(
                    f"['{self.cache.conn_name}']: Cloud SQL instance changed "
                    f"from {self.cache.conn_name.get_connection_string()} to "
                    f"{new_conn_name.get_connection_string()}, closing all "
                    "connections!"
                )
                await self.close()

        except Exception as e:
            # Domain name checks should not be fatal, log error and continue.
            logger.debug(
                f"['{self.cache.conn_name}']: Unable to check domain name, "
                f"domain name {self.cache.conn_name.domain_name} did not "
                f"resolve: {e}"
            )

    async def connect_info(self) -> ConnectionInfo:
        """Return connection info from the wrapped cache.

        Raises:
            CacheClosedError: The cache has already been closed.
        """
        if self.closed:
            raise CacheClosedError(
                "Can not get connection info, cache has already been closed."
            )
        return await self.cache.connect_info()

    async def force_refresh(self) -> None:
        """Force a refresh on the wrapped cache; a no-op once closed."""
        # if cache is closed do not refresh
        if self.closed:
            return
        return await self.cache.force_refresh()

    async def close(self) -> None:
        """Stop polling, close the wrapped cache and close tracked sockets."""
        # Cancel domain name ticker task.
        if self.domain_name_ticker:
            self.domain_name_ticker.cancel()
            try:
                await self.domain_name_ticker
            except asyncio.CancelledError:
                logger.debug(
                    f"['{self.cache.conn_name}']: Cancelled domain name polling task."
                )
            finally:
                self.domain_name_ticker = None
        # If cache is already closed, no further work.
        if self.closed:
            return

        # Close underlying ConnectionInfoCache
        await self.cache.close()

        # Close any still open sockets
        for sock in self.sockets:
            # Check fileno for if socket is closed. Will return
            # -1 on failure, which will be used to signal socket closed.
            if sock.fileno() != -1:
                sock.close()


# Strong references to callbacks spawned by ticker(). The event loop itself
# only keeps weak references to tasks, so without this a callback could be
# garbage collected before it finishes. Entries remove themselves when done.
_ticker_tasks: "set[asyncio.Task]" = set()


async def ticker(interval: int, function: Callable, *args: Any, **kwargs: Any) -> None:
    """
    Ticker function to sleep for specified interval and then schedule call
    to given function.

    Runs until cancelled. Each call is scheduled as its own task and is not
    awaited, so a slow call does not delay the next tick. Tasks already
    scheduled are deliberately NOT cancelled when the ticker is cancelled:
    a scheduled call may itself be the one cancelling the ticker.

    Args:
        interval (int): Seconds to sleep between calls.
        function (Callable): Coroutine function to call on every tick.
        *args: Positional arguments passed to function.
        **kwargs: Keyword arguments passed to function.
    """
    while True:
        # Sleep for interval and then schedule task
        await asyncio.sleep(interval)
        task = asyncio.create_task(function(*args, **kwargs))
        _ticker_tasks.add(task)
        task.add_done_callback(_ticker_tasks.discard)
Please install and try again.' + ) + + user = kwargs.pop("user") + db = kwargs.pop("db") + passwd = kwargs.pop("password", None) + return pg8000.dbapi.connect( + user, + database=db, + password=passwd, + sock=sock, + **kwargs, + ) diff --git a/build/lib/google/cloud/sql/connector/proto/sql_data_service_pb2.py b/build/lib/google/cloud/sql/connector/proto/sql_data_service_pb2.py new file mode 100644 index 000000000..fb25be21d --- /dev/null +++ b/build/lib/google/cloud/sql/connector/proto/sql_data_service_pb2.py @@ -0,0 +1,88 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# NO CHECKED-IN PROTOBUF GENCODE +# source: google/cloud/sql/connector/proto/sql_data_service.proto +# Protobuf Python Version: 6.33.5 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder + +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, + 6, + 33, + 5, + '', + 'google/cloud/sql/connector/proto/sql_data_service.proto' +) +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n7google/cloud/sql/connector/proto/sql_data_service.proto\x12\x18google.cloud.sql.v1beta4\x1a\x17google/rpc/status.proto\"\x82\x04\n\x14StreamSqlDataRequest\x12\x13\n\x0blocation_id\x18\x01 \x01(\t\x12K\n\x13\x63onnection_settings\x18\x02 \x01(\x0b\x32,.google.cloud.sql.v1beta4.ConnectionSettingsH\x00\x12:\n\x07payload\x18\x03 \x01(\x0b\x32\'.google.cloud.sql.v1beta4.ClientPayloadH\x00\x12*\n\x03\x61\x63k\x18\x04 \x01(\x0b\x32\x1d.google.cloud.sql.v1beta4.Ack\x12?\n\rstart_session\x18\x05 \x01(\x0b\x32&.google.cloud.sql.v1beta4.StartSessionH\x01\x12\x45\n\x10\x63ontinue_session\x18\x06 \x01(\x0b\x32).google.cloud.sql.v1beta4.ContinueSessionH\x01\x12\x34\n\x04\x64\x61ta\x18\x07 \x01(\x0b\x32$.google.cloud.sql.v1beta4.DataPacketH\x01\x12G\n\x11terminate_session\x18\x08 \x01(\x0b\x32*.google.cloud.sql.v1beta4.TerminateSessionH\x01\x42\x0e\n\x0cmessage_typeB\t\n\x07message\"_\n\x12\x43onnectionSettings\x12;\n\x0binstance_id\x18\x01 \x01(\x0b\x32$.google.cloud.sql.v1beta4.InstanceIdH\x00:\x02\x18\x01\x42\x08\n\x06target\"L\n\x0cStartSession\x12\x13\n\x0blocation_id\x18\x01 \x01(\t\x12\x13\n\x0binstance_id\x18\x02 \x01(\t\x12\x12\n\nsession_id\x18\x03 
\x01(\t\"O\n\x0f\x43ontinueSession\x12\x13\n\x0blocation_id\x18\x01 \x01(\t\x12\x13\n\x0binstance_id\x18\x02 \x01(\t\x12\x12\n\nsession_id\x18\x03 \x01(\t\"\"\n\nInstanceId\x12\x10\n\x08instance\x18\x01 \x01(\t:\x02\x18\x01\"!\n\rClientPayload\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c:\x02\x18\x01\"\xd8\x02\n\x15StreamSqlDataResponse\x12:\n\x07payload\x18\x01 \x01(\x0b\x32\'.google.cloud.sql.v1beta4.ServerPayloadH\x00\x12*\n\x03\x61\x63k\x18\x02 \x01(\x0b\x32\x1d.google.cloud.sql.v1beta4.Ack\x12\x45\n\x10session_metadata\x18\x03 \x01(\x0b\x32).google.cloud.sql.v1beta4.SessionMetadataH\x01\x12\x34\n\x04\x64\x61ta\x18\x04 \x01(\x0b\x32$.google.cloud.sql.v1beta4.DataPacketH\x01\x12G\n\x11terminate_session\x18\x05 \x01(\x0b\x32*.google.cloud.sql.v1beta4.TerminateSessionH\x01\x42\x06\n\x04typeB\t\n\x07message\"!\n\rServerPayload\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c:\x02\x18\x01\"W\n\x0fSessionMetadata\x12\x44\n\x12supported_features\x18\x01 \x03(\x0e\x32(.google.cloud.sql.v1beta4.SqlDataFeature\"5\n\nDataPacket\x12\x19\n\x11\x66irst_byte_offset\x18\x01 \x01(\x03\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x1e\n\x03\x41\x63k\x12\x17\n\x0freceived_offset\x18\x01 \x01(\x03\"6\n\x10TerminateSession\x12\"\n\x06status\x18\x01 \x01(\x0b\x32\x12.google.rpc.Status*\xf6\x02\n\x18StreamSqlDataErrorReason\x12(\n$STREAM_SQL_DATA_ERROR_REASON_UNKNOWN\x10\x00\x12:\n6STREAM_SQL_DATA_ERROR_REASON_UNSUPPORTED_INSTANCE_TYPE\x10\x01\x12\x31\n-STREAM_SQL_DATA_ERROR_REASON_RECONNECT_FAILED\x10\x03\x12\x31\n-STREAM_SQL_DATA_ERROR_REASON_DB_CLIENT_CLOSED\x10\x04\x12\x31\n-STREAM_SQL_DATA_ERROR_REASON_DB_SERVER_CLOSED\x10\x05\x12,\n(STREAM_SQL_DATA_ERROR_REASON_INVALID_ACK\x10\x06\x12-\n)STREAM_SQL_DATA_ERROR_REASON_DISCONNECTED\x10\x07*R\n\x0eSqlDataFeature\x12 
\n\x1cSQL_DATA_FEATURE_UNSPECIFIED\x10\x00\x12\x1e\n\x1aSQL_DATA_FEATURE_RECONNECT\x10\x01\x32\x88\x01\n\x0eSqlDataService\x12v\n\rStreamSqlData\x12..google.cloud.sql.v1beta4.StreamSqlDataRequest\x1a/.google.cloud.sql.v1beta4.StreamSqlDataResponse\"\x00(\x01\x30\x01\x42\x45\n\x1c\x63om.google.cloud.sql.v1beta4B\x11\x43loudSqlDataProtoP\x01Z\x10internal/sqldatab\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'google.cloud.sql.connector.proto.sql_data_service_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + _globals['DESCRIPTOR']._loaded_options = None + _globals['DESCRIPTOR']._serialized_options = b'\n\034com.google.cloud.sql.v1beta4B\021CloudSqlDataProtoP\001Z\020internal/sqldata' + _globals['_CONNECTIONSETTINGS']._loaded_options = None + _globals['_CONNECTIONSETTINGS']._serialized_options = b'\030\001' + _globals['_INSTANCEID']._loaded_options = None + _globals['_INSTANCEID']._serialized_options = b'\030\001' + _globals['_CLIENTPAYLOAD']._loaded_options = None + _globals['_CLIENTPAYLOAD']._serialized_options = b'\030\001' + _globals['_SERVERPAYLOAD']._loaded_options = None + _globals['_SERVERPAYLOAD']._serialized_options = b'\030\001' + _globals['_STREAMSQLDATAERRORREASON']._serialized_start=1569 + _globals['_STREAMSQLDATAERRORREASON']._serialized_end=1943 + _globals['_SQLDATAFEATURE']._serialized_start=1945 + _globals['_SQLDATAFEATURE']._serialized_end=2027 + _globals['_STREAMSQLDATAREQUEST']._serialized_start=111 + _globals['_STREAMSQLDATAREQUEST']._serialized_end=625 + _globals['_CONNECTIONSETTINGS']._serialized_start=627 + _globals['_CONNECTIONSETTINGS']._serialized_end=722 + _globals['_STARTSESSION']._serialized_start=724 + _globals['_STARTSESSION']._serialized_end=800 + _globals['_CONTINUESESSION']._serialized_start=802 + _globals['_CONTINUESESSION']._serialized_end=881 + _globals['_INSTANCEID']._serialized_start=883 + 
_globals['_INSTANCEID']._serialized_end=917 + _globals['_CLIENTPAYLOAD']._serialized_start=919 + _globals['_CLIENTPAYLOAD']._serialized_end=952 + _globals['_STREAMSQLDATARESPONSE']._serialized_start=955 + _globals['_STREAMSQLDATARESPONSE']._serialized_end=1299 + _globals['_SERVERPAYLOAD']._serialized_start=1301 + _globals['_SERVERPAYLOAD']._serialized_end=1334 + _globals['_SESSIONMETADATA']._serialized_start=1336 + _globals['_SESSIONMETADATA']._serialized_end=1423 + _globals['_DATAPACKET']._serialized_start=1425 + _globals['_DATAPACKET']._serialized_end=1478 + _globals['_ACK']._serialized_start=1480 + _globals['_ACK']._serialized_end=1510 + _globals['_TERMINATESESSION']._serialized_start=1512 + _globals['_TERMINATESESSION']._serialized_end=1566 + _globals['_SQLDATASERVICE']._serialized_start=2030 + _globals['_SQLDATASERVICE']._serialized_end=2166 +# @@protoc_insertion_point(module_scope) diff --git a/build/lib/google/cloud/sql/connector/proto/sql_data_service_pb2_grpc.py b/build/lib/google/cloud/sql/connector/proto/sql_data_service_pb2_grpc.py new file mode 100644 index 000000000..42240ecbd --- /dev/null +++ b/build/lib/google/cloud/sql/connector/proto/sql_data_service_pb2_grpc.py @@ -0,0 +1,137 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 
+"""Client and server classes corresponding to protobuf-defined services.""" + +import grpc + +from google.cloud.sql.connector.proto import ( + sql_data_service_pb2 as google_dot_cloud_dot_sql_dot_connector_dot_proto_dot_sql__data__service__pb2, +) + +GRPC_GENERATED_VERSION = '1.81.0' +GRPC_VERSION = grpc.__version__ +_version_not_supported = False + +try: + from grpc._utilities import first_version_is_lower + _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) +except ImportError: + _version_not_supported = True + +if _version_not_supported: + raise RuntimeError( + f'The grpc package installed is at version {GRPC_VERSION},' + + ' but the generated code in google/cloud/sql/connector/proto/sql_data_service_pb2_grpc.py depends on' + + f' grpcio>={GRPC_GENERATED_VERSION}.' + + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' + ) + + +class SqlDataServiceStub: + """Service for streaming data to and from Cloud SQL instances. + """ + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.StreamSqlData = channel.stream_stream( + '/google.cloud.sql.v1beta4.SqlDataService/StreamSqlData', + request_serializer=google_dot_cloud_dot_sql_dot_connector_dot_proto_dot_sql__data__service__pb2.StreamSqlDataRequest.SerializeToString, + response_deserializer=google_dot_cloud_dot_sql_dot_connector_dot_proto_dot_sql__data__service__pb2.StreamSqlDataResponse.FromString, + _registered_method=True) + + +class SqlDataServiceServicer: + """Service for streaming data to and from Cloud SQL instances. + """ + + def StreamSqlData(self, request_iterator, context): + """`StreamSqlData` establishes a bidirectional stream to a Cloud SQL instance, + and then streams data to and from the instance. 
+ + The first message from the client MUST be a `StreamSqlDataRequest` request + with configuration settings, including required values for the + `connection_settings` field. Subsequent messages from the client may + contain the `payload` field. + + Messages from the server may contain the `payload` field. + + The `payload` fields of the request and response streams contain the raw + data of the database's native wire protocol (e.g., PostgreSQL wire + protocol). The database client is responsible for generating and parsing + this data. + + Any errors on initial connection (e.g., connection failure, authorization + issues, network problems) will result in the stream being terminated with + an appropriate RPC status exception. + + After a successful connection is made, if an error occurs, then the server + terminates connection and returns the appropriate RPC status exception. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_SqlDataServiceServicer_to_server(servicer, server): + rpc_method_handlers = { + 'StreamSqlData': grpc.stream_stream_rpc_method_handler( + servicer.StreamSqlData, + request_deserializer=google_dot_cloud_dot_sql_dot_connector_dot_proto_dot_sql__data__service__pb2.StreamSqlDataRequest.FromString, + response_serializer=google_dot_cloud_dot_sql_dot_connector_dot_proto_dot_sql__data__service__pb2.StreamSqlDataResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'google.cloud.sql.v1beta4.SqlDataService', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers('google.cloud.sql.v1beta4.SqlDataService', rpc_method_handlers) + + + # This class is part of an EXPERIMENTAL API. +class SqlDataService: + """Service for streaming data to and from Cloud SQL instances. 
def connect(
    ip_address: str, sock: ssl.SSLSocket, **kwargs: Any
) -> "pymysql.connections.Connection":
    """Helper function to create a pymysql DB-API connection object.

    Args:
        ip_address (str): A string containing an IP address for the Cloud SQL
            instance.
        sock (ssl.SSLSocket): An SSLSocket object created from the Cloud SQL
            server CA cert and ephemeral cert.
        kwargs: Additional arguments to pass to the pymysql Connection
            constructor. "password" defaults to None when omitted. An
            optional "timeout" is translated to pymysql's "connect_timeout"
            unless "connect_timeout" was given explicitly.

    Returns:
        pymysql.connections.Connection: A pymysql connection to the Cloud SQL
            instance.

    Raises:
        ImportError: The pymysql module cannot be imported.
    """
    try:
        import pymysql
    except ImportError as e:
        raise ImportError(
            'Unable to import module "pymysql." Please install and try again.'
        ) from e

    # allow automatic IAM database authentication to not require password
    kwargs.setdefault("password", None)

    # pop timeout as timeout arg is called 'connect_timeout' for pymysql.
    # A missing (or None) timeout leaves pymysql's own default in effect
    # instead of failing with KeyError or forwarding connect_timeout=None.
    timeout = kwargs.pop("timeout", None)
    if timeout is not None:
        kwargs.setdefault("connect_timeout", timeout)

    # Create pymysql connection object and hand in pre-made connection
    conn = pymysql.Connection(host=ip_address, defer_connect=True, **kwargs)
    conn.connect(sock)
    return conn
class AsyncRateLimiter:
    """An asyncio-compatible rate limiter which uses the Token Bucket algorithm
    (https://en.wikipedia.org/wiki/Token_bucket) to limit the number
    of function calls over a time interval.

    The bucket starts full. Every ``acquire()`` consumes one token, and tokens
    are replenished continuously at ``rate`` tokens per second, never
    exceeding ``max_capacity``.

    Args:
        max_capacity (int): The maximum capacity of tokens the bucket
            will store at any one time. Default: 1
        rate (float): The number of tokens that should be added per second.
        loop (asyncio.AbstractEventLoop): The event loop to use.
            If not provided, the default event loop will be used.
    """

    def __init__(
        self,
        max_capacity: int = 1,
        rate: float = 1 / 60,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ) -> None:
        if not loop:
            loop = asyncio.get_event_loop()
        self._loop = loop
        self.max_capacity = max_capacity
        self.rate = rate
        self._lock = asyncio.Lock()
        # start with a full bucket, timestamped with the loop's clock
        self._tokens: float = max_capacity
        self._last_token_update = self._loop.time()

    def _update_token_count(self) -> None:
        """
        Credits the bucket with the tokens earned since the last update.

        Replenishing is done lazily: nothing happens between calls, the
        elapsed loop time is simply converted to tokens when this is invoked,
        and the total is capped at max_capacity.
        """
        now = self._loop.time()
        earned = (now - self._last_token_update) * self.rate
        self._last_token_update = now
        self._tokens = min(earned + self._tokens, self.max_capacity)

    async def _wait_for_next_token(self) -> None:
        """
        Sleep for as long as it takes the bucket to reach one full token.
        Returns immediately when a full token is already available.
        """
        shortfall = 1 - self._tokens
        if shortfall > 0:
            seconds_needed = shortfall / self.rate
            await asyncio.sleep(seconds_needed)

    async def acquire(self) -> None:
        """
        Waits for a token to become available, if necessary, then subtracts token and allows
        request to go through.

        Callers are serialized by an internal lock, so waiting callers are
        released one at a time.
        """
        async with self._lock:
            self._update_token_count()
            if self._tokens < 1:
                await self._wait_for_next_token()
                # account for the time spent sleeping
                self._update_token_count()
            self._tokens = self._tokens - 1
+""" + +from __future__ import annotations + +import asyncio +import copy +import datetime +import logging +import random +from typing import Any, Callable + +import aiohttp +from google.auth.credentials import Credentials +from google.auth.credentials import Scoped +import google.auth.transport.requests + +logger = logging.getLogger(name=__name__) + +# _refresh_buffer is the amount of time before a refresh's result expires +# that a new refresh operation begins. +_refresh_buffer: int = 4 * 60 # 4 minutes + + +def _seconds_until_refresh( + expiration: datetime.datetime, +) -> int: + """ + Calculates the duration to wait before starting the next refresh. + + Usually the duration will be half of the time until certificate + expiration. + + Args: + expiration (datetime.datetime): The expiration time of the certificate. + + Returns: + int: Time in seconds to wait before performing next refresh. + """ + + duration = int( + (expiration - datetime.datetime.now(datetime.timezone.utc)).total_seconds() + ) + + # if certificate duration is less than 1 hour + if duration < 3600: + # something is wrong with certificate, refresh now + if duration < _refresh_buffer: + return 0 + # otherwise wait until 4 minutes before expiration for next refresh + return duration - _refresh_buffer + + return duration // 2 + + +async def _is_valid(task: asyncio.Task) -> bool: + try: + metadata = await task + # only valid if now is before the cert expires + if datetime.datetime.now(datetime.timezone.utc) < metadata.expiration: + return True + except Exception: + # supress any errors from task + logger.debug("Current instance metadata is invalid.") + return False + + +def _downscope_credentials( + credentials: Credentials, + scopes: list[str] = ["https://www.googleapis.com/auth/sqlservice.login"], +) -> Credentials: + """Generate a down-scoped credential. + + Args: + credentials (google.auth.credentials.Credentials): + Credentials object used to generate down-scoped credentials. 
+ scopes (list[str]): List of Google scopes to + include in down-scoped credentials object. + + Returns: + google.auth.credentials.Credentials: Down-scoped credentials object. + """ + # credentials sourced from a service account or metadata are children of + # Scoped class and are capable of being re-scoped + if isinstance(credentials, Scoped): + scoped_creds = credentials.with_scopes(scopes=scopes) + # authenticated user credentials can not be re-scoped + else: + # create shallow copy to not overwrite scopes on default credentials + scoped_creds = copy.copy(credentials) + # overwrite '_scopes' to down-scope user credentials + # Cloud SDK reference: https://github.com/google-cloud-sdk-unofficial/google-cloud-sdk/blob/93920ccb6d2cce0fe6d1ce841e9e33410551d66b/lib/googlecloudsdk/command_lib/sql/generate_login_token_util.py#L116 + scoped_creds._scopes = scopes # type: ignore[attr-defined] + # down-scoped credentials require refresh, are invalid after being re-scoped + request = google.auth.transport.requests.Request() + scoped_creds.refresh(request) + return scoped_creds + + +def _exponential_backoff(attempt: int) -> float: + """Calculates a duration to backoff in milliseconds based on the attempt i. + + The formula is: + + base * multi^(attempt + 1 + random) + + With base = 200ms and multi = 1.618, and random = [0.0, 1.0), + the backoff values would fall between the following low and high ends: + + Attempt Low (ms) High (ms) + + 0 324 524 + 1 524 847 + 2 847 1371 + 3 1371 2218 + 4 2218 3588 + + The theoretical worst case scenario would have a client wait 8.5s in total + for an API request to complete (with the first four attempts failing, and + the fifth succeeding). 
+ """ + base = 200 + multi = 1.618 + exp = attempt + 1 + random.random() + return base * pow(multi, exp) + + +async def retry_50x( + request_coro: Callable, *args: Any, **kwargs: Any +) -> aiohttp.ClientResponse: + """Retry any 50x HTTP response up to X number of times.""" + max_retries = 5 + for i in range(max_retries): + resp = await request_coro(*args, **kwargs) + # backoff for any 50X errors + if resp.status >= 500 and i < max_retries: + # calculate backoff time + backoff = _exponential_backoff(i) + await asyncio.sleep(backoff / 1000) + else: + break + return resp diff --git a/build/lib/google/cloud/sql/connector/resolver.py b/build/lib/google/cloud/sql/connector/resolver.py new file mode 100644 index 000000000..e255f328a --- /dev/null +++ b/build/lib/google/cloud/sql/connector/resolver.py @@ -0,0 +1,91 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import List + +import dns.asyncresolver + +from google.cloud.sql.connector.connection_name import _is_valid_domain +from google.cloud.sql.connector.connection_name import _parse_connection_name +from google.cloud.sql.connector.connection_name import ( + _parse_connection_name_with_domain_name, +) +from google.cloud.sql.connector.connection_name import ConnectionName +from google.cloud.sql.connector.exceptions import DnsResolutionError + + +class DefaultResolver: + """DefaultResolver simply validates and parses instance connection name.""" + + async def resolve(self, connection_name: str) -> ConnectionName: + return _parse_connection_name(connection_name) + + +class DnsResolver(dns.asyncresolver.Resolver): + """ + DnsResolver resolves domain names into instance connection names using + TXT records in DNS. + """ + + async def resolve(self, dns: str) -> ConnectionName: # type: ignore + try: + conn_name = _parse_connection_name(dns) + except ValueError: + # The connection name was not project:region:instance format. + # Check if connection name is a valid DNS domain name + if _is_valid_domain(dns): + # Attempt to query a TXT record to get connection name. + conn_name = await self.query_dns(dns) + else: + raise ValueError( + "Arg `instance_connection_string` must have " + "format: PROJECT:REGION:INSTANCE or be a valid DNS domain " + f"name, got {dns}." + ) + return conn_name + + async def resolve_a_record(self, dns: str) -> List[str]: + try: + # Attempt to query the A records. + records = await super().resolve(dns, "A", raise_on_no_answer=True) + # return IP addresses as strings + return [record.to_text() for record in records] + except Exception: + # On any error, return empty list + return [] + + async def query_dns(self, dns: str) -> ConnectionName: + try: + # Attempt to query the TXT records. 
+ records = await super().resolve(dns, "TXT", raise_on_no_answer=True) + # Sort the TXT record values alphabetically, strip quotes as record + # values can be returned as raw strings + rdata = [record.to_text().strip('"') for record in records] + rdata.sort() + # Attempt to parse records, returning the first valid record. + for record in rdata: + try: + conn_name = _parse_connection_name_with_domain_name(record, dns) + return conn_name + except Exception: + continue + # If all records failed to parse, throw error + raise DnsResolutionError( + f"Unable to parse TXT record for `{dns}` -> `{rdata[0]}`" + ) + # Don't override above DnsResolutionError + except DnsResolutionError: + raise + except Exception as e: + raise DnsResolutionError(f"Unable to resolve TXT record for `{dns}`") from e diff --git a/build/lib/google/cloud/sql/connector/sqldata_client.py b/build/lib/google/cloud/sql/connector/sqldata_client.py new file mode 100644 index 000000000..950373f4b --- /dev/null +++ b/build/lib/google/cloud/sql/connector/sqldata_client.py @@ -0,0 +1,355 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import asyncio +import logging +import socket +from typing import Any, Callable, Optional + +from google.auth.credentials import Credentials +from google.auth.transport.grpc import AuthMetadataPlugin +from google.auth.transport.requests import Request +import grpc + +import google.rpc.status_pb2 # noqa: F401 # isort: skip +from google.cloud.sql.connector.proto import sql_data_service_pb2 # type: ignore +from google.cloud.sql.connector.proto import sql_data_service_pb2_grpc # type: ignore + +logger = logging.getLogger(__name__) + + +class SqlDataClient: + def __init__( + self, + endpoint: str, + credentials: Credentials, + quota_project: Optional[str] = None, + timeout: Optional[float] = None, + ): + self._endpoint = endpoint + self._credentials = credentials + self._quota_project = quota_project + self._timeout = timeout + + async def connect_tunnel( + self, + instance_connection_name: str, + region: str, + project: str, + get_conn_info: Callable[[], Any], + enable_iam_auth: bool, + on_fallback: Callable[[str], None], + is_fallback_cached: Callable[[str], bool], + ) -> int: + """Starts a local TCP tunnel and returns the local port. + + If the instance does not support SQL Data Service, it falls back + to a direct TLS connection. 
+ """ + # Start local TCP server + server = await asyncio.start_server( + lambda r, w: self._handle_tunnel( + r, + w, + instance_connection_name, + region, + project, + get_conn_info, + enable_iam_auth, + on_fallback, + is_fallback_cached, + ), + "127.0.0.1", + 0, + ) + + port = server.sockets[0].getsockname()[1] + logger.debug(f"SQL Data tunnel listening on 127.0.0.1:{port}") + + # Keep reference to server to close it + self._server = server + return port + + async def close(self) -> None: + """Closes the local tunnel server if it is running.""" + if hasattr(self, "_server") and self._server: + self._server.close() + try: + await asyncio.wait_for(self._server.wait_closed(), timeout=2.0) + logger.debug("SQL Data tunnel server closed by client close()") + except asyncio.TimeoutError: + logger.warning("Timeout waiting for SQL Data tunnel server to close") + + async def _handle_tunnel( + self, + client_reader: asyncio.StreamReader, + client_writer: asyncio.StreamWriter, + instance_connection_name: str, + region: str, + project: str, + get_conn_info: Callable[[], Any], + enable_iam_auth: bool, + on_fallback: Callable[[str], None], + is_fallback_cached: Callable[[str], bool], + ): + logger.debug("Accepted local connection for SQL Data tunnel") + # Close the server so no more connections are accepted on this port + self._server.close() + + # Buffer to cache client writes for fallback replay + client_write_buffer = bytearray() + first_read_done = False + fallback_triggered = False + + # We need to share these streams between tasks + backend_reader: Optional[asyncio.StreamReader] = None + backend_writer: Optional[asyncio.StreamWriter] = None + grpc_stream: Optional[Any] = None + grpc_channel: Optional[grpc.aio.Channel] = None + + # Check if fallback is already cached + use_fallback = is_fallback_cached(instance_connection_name) + + async def connect_grpc() -> tuple[grpc.aio.Channel, Any]: + auth_request = Request() + plugin = AuthMetadataPlugin(self._credentials, 
auth_request) + call_creds = grpc.metadata_call_credentials(plugin) + channel_creds = grpc.composite_channel_credentials( + grpc.ssl_channel_credentials(), call_creds + ) + + endpoint = self._endpoint + if endpoint.startswith("https://"): + endpoint = endpoint[len("https://") :] + if endpoint.startswith("http://"): + endpoint = endpoint[len("http://") :] + + logger.debug(f"Creating secure channel to {endpoint}") + channel = grpc.aio.secure_channel(endpoint, channel_creds) + stub = sql_data_service_pb2_grpc.SqlDataServiceStub(channel) + + instance_id = f"projects/{project}/instances/{instance_connection_name.split(':')[-1]}" + location_id = f"locations/{region}" + + metadata = [] + quota_project_in_creds = getattr(self._credentials, "quota_project_id", None) + if self._quota_project and self._quota_project != quota_project_in_creds: + metadata.append(("x-goog-user-project", self._quota_project)) + metadata.append( + ( + "x-goog-request-params", + f"instance_id={instance_id}&location_id={location_id}", + ) + ) + + # Start stream + logger.debug(f"Starting StreamSqlData with metadata {metadata}") + stream = stub.StreamSqlData(metadata=metadata) + + # Send StartSession + start_session = sql_data_service_pb2.StartSession( # type: ignore[attr-defined] + instance_id=instance_id, location_id=location_id + ) + req = sql_data_service_pb2.StreamSqlDataRequest( # type: ignore[attr-defined] + start_session=start_session + ) + logger.debug("Writing StartSession to stream...") + await stream.write(req) + logger.debug("StartSession written successfully") + return channel, stream + + async def connect_direct() -> tuple[asyncio.StreamReader, asyncio.StreamWriter]: + logger.debug("Fallback triggered, fetching connection info...") + conn_info = await get_conn_info() + # Find a fallback IP address + fallback_ip = None + from google.cloud.sql.connector.enums import IPTypes + for t in [IPTypes.PUBLIC, IPTypes.PRIVATE, IPTypes.PSC]: + try: + fallback_ip = conn_info.get_preferred_ip(t) + 
break + except Exception: + continue + if not fallback_ip: + raise ValueError("Cannot fallback to direct connection: no IP address available.") + logger.debug(f"Connecting directly to {fallback_ip}:3307") + ssl_context = await conn_info.create_ssl_context(enable_iam_auth) + return await asyncio.open_connection( + fallback_ip, 3307, ssl=ssl_context, server_hostname=fallback_ip + ) + + # Initialize connection + if use_fallback: + logger.debug("Using cached fallback connection") + backend_reader, backend_writer = await connect_direct() + fallback_triggered = True + else: + try: + grpc_channel, grpc_stream = await connect_grpc() + except Exception as e: + logger.debug(f"Failed to initialize gRPC stream: {e}") + # Try fallback immediately + backend_reader, backend_writer = await connect_direct() + fallback_triggered = True + on_fallback(instance_connection_name) + + # Task to read from client and write to backend + async def client_to_backend(): + nonlocal first_read_done, fallback_triggered, backend_writer, grpc_stream + try: + while True: + data = await client_reader.read(4096) + if not data: + logger.debug("Client socket EOF") + break + + if not first_read_done and not fallback_triggered: + client_write_buffer.extend(data) + + if fallback_triggered: + if backend_writer: + backend_writer.write(data) + await backend_writer.drain() + else: + packet = sql_data_service_pb2.DataPacket(data=data) # type: ignore[attr-defined] + req = sql_data_service_pb2.StreamSqlDataRequest( # type: ignore[attr-defined] + data=packet + ) + if grpc_stream: + await grpc_stream.write(req) + except Exception as e: + logger.error(f"Error in client_to_backend: {e}") + raise + finally: + if fallback_triggered: + if backend_writer: + backend_writer.write_eof() + else: + if grpc_stream: + try: + await grpc_stream.done_writing() + except Exception: + pass + logger.debug("Client to backend task finished") + + # Task to read from backend and write to client + async def backend_to_client(): + nonlocal 
first_read_done, fallback_triggered, backend_reader, backend_writer, grpc_stream, grpc_channel + try: + if fallback_triggered: + # If we started with fallback, just copy + while True: + if not backend_reader: + break + data = await backend_reader.read(4096) + if not data: + break + client_writer.write(data) + await client_writer.drain() + else: + # gRPC read loop + try: + if not grpc_stream: + return + async for resp in grpc_stream: + first_read_done = True + msg_type = resp.WhichOneof("message") + if msg_type == "session_metadata": + logger.debug("Received SessionMetadata") + elif msg_type == "data": + data = resp.data.data + client_writer.write(data) + await client_writer.drain() + elif msg_type == "terminate_session": + logger.debug("Received TerminateSession") + break + except grpc.aio.AioRpcError as e: + logger.debug(f"gRPC stream error: {e}") + # Check for fallback condition + if ( + not first_read_done + and e.code() == grpc.StatusCode.FAILED_PRECONDITION + ): + logger.info( + f"SQL Data Service not supported for {instance_connection_name}. " + "Falling back to direct connection." 
+ ) + fallback_triggered = True + on_fallback(instance_connection_name) + + # Clean up gRPC + if grpc_channel: + await grpc_channel.close() + + # Connect direct + backend_reader, backend_writer = await connect_direct() + + # Replay buffered client data + if client_write_buffer: + logger.debug(f"Replaying {len(client_write_buffer)} bytes to fallback connection") + backend_writer.write(bytes(client_write_buffer)) + await backend_writer.drain() + + # Start copying from direct connection + while True: + data = await backend_reader.read(4096) + if not data: + break + client_writer.write(data) + await client_writer.drain() + else: + # Other gRPC error, re-raise to close connection + raise + except Exception as e: + logger.error(f"Error in backend_to_client: {e}") + raise + finally: + client_writer.close() + try: + await client_writer.wait_closed() + except Exception: + pass + if fallback_triggered and backend_writer: + backend_writer.close() + try: + await backend_writer.wait_closed() + except Exception: + pass + elif grpc_channel: + await grpc_channel.close() + logger.debug("Backend to client task finished") + + # Run both tasks + try: + await asyncio.gather(client_to_backend(), backend_to_client()) + finally: + logger.debug("Closing client socket in _handle_tunnel finally") + try: + client_writer.close() + sock = client_writer.get_extra_info('socket') + if sock: + sock.close() + except Exception as e: + logger.debug(f"Error closing client writer: {e}") + logger.debug("SQL Data tunnel handler finished") + + +class FallbackSocket(socket.socket): + def connect(self, *args: Any, **kwargs: Any) -> None: + # Already connected, do nothing. + # This is needed because some drivers (like pymysql) try to call connect() + # internally even if passed an already connected socket. 
+ pass diff --git a/build/lib/google/cloud/sql/connector/utils.py b/build/lib/google/cloud/sql/connector/utils.py new file mode 100644 index 000000000..dd0aec344 --- /dev/null +++ b/build/lib/google/cloud/sql/connector/utils.py @@ -0,0 +1,101 @@ +""" +Copyright 2019 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import aiofiles +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa + + +async def generate_keys() -> tuple[bytes, str]: + """A helper function to generate the private and public keys. + + backend - The value specified is default_backend(). This is because the + cryptography library used to support different backends, but now only uses + the default_backend(). + + public_exponent - The public exponent is one of the variables used in the + generation of the keys. 65537 is recommended due to being a good balance + between speed and security. + + key_size - The cryptography documentation recommended a key_size + of at least 2048. 
+ + """ + private_key_obj = rsa.generate_private_key( + backend=default_backend(), public_exponent=65537, key_size=2048 + ) + + pub_key = ( + private_key_obj.public_key() + .public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + .decode("UTF-8") + ) + + priv_key = private_key_obj.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + + return priv_key, pub_key + + +async def write_to_file( + dir_path: str, serverCaCert: str, ephemeralCert: str, priv_key: bytes +) -> tuple[str, str, str]: + """ + Helper function to write the serverCaCert, ephemeral certificate and + private key to .pem files in a given directory + """ + ca_filename = f"{dir_path}/ca.pem" + cert_filename = f"{dir_path}/cert.pem" + key_filename = f"{dir_path}/priv.pem" + + async with aiofiles.open(ca_filename, "w+") as ca_out: + await ca_out.write(serverCaCert) + async with aiofiles.open(cert_filename, "w+") as ephemeral_out: + await ephemeral_out.write(ephemeralCert) + async with aiofiles.open(key_filename, "wb") as priv_out: + await priv_out.write(priv_key) + + return (ca_filename, cert_filename, key_filename) + + +def format_database_user(database_version: str, user: str) -> str: + """Format database `user` param for Cloud SQL automatic IAM authentication. + + Args: + database_version (str): Cloud SQL database version. + user (str): Database username to connect to Cloud SQL database with. + + Returns: + str: Formatted database username. 
+ """ + # remove suffix for Postgres service accounts + if database_version.startswith("POSTGRES"): + suffix = ".gserviceaccount.com" + user = user[: -len(suffix)] if user.endswith(suffix) else user + return user + + # remove everything after and including the @ for MySQL + if database_version.startswith("MYSQL") and "@" in user: + return user.split("@")[0] + + return user diff --git a/build/lib/google/cloud/sql/connector/version.py b/build/lib/google/cloud/sql/connector/version.py new file mode 100644 index 000000000..64c323da2 --- /dev/null +++ b/build/lib/google/cloud/sql/connector/version.py @@ -0,0 +1,15 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__version__ = "1.20.3" diff --git a/google/cloud/sql/connector/asyncpg.py b/google/cloud/sql/connector/asyncpg.py index 2fbc30273..2e28dbbaf 100644 --- a/google/cloud/sql/connector/asyncpg.py +++ b/google/cloud/sql/connector/asyncpg.py @@ -15,7 +15,7 @@ """ import ssl -from typing import Any, TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING SERVER_PROXY_PORT = 3307 @@ -24,7 +24,7 @@ async def connect( - ip_address: str, ctx: ssl.SSLContext, **kwargs: Any + ip_address: str, ctx: Optional[ssl.SSLContext], **kwargs: Any ) -> "asyncpg.Connection": """Helper function to create an asyncpg DB-API connection object. @@ -32,8 +32,7 @@ async def connect( ip_address (str): A string containing an IP address for the Cloud SQL instance. 
ctx (ssl.SSLContext): An SSLContext object created from the Cloud SQL - server CA cert and ephemeral cert. - server CA cert and ephemeral cert. + server CA cert and ephemeral cert. Pass None to disable SSL. kwargs: Keyword arguments for establishing asyncpg connection object to Cloud SQL instance. @@ -53,14 +52,18 @@ async def connect( user = kwargs.pop("user") db = kwargs.pop("db") passwd = kwargs.pop("password", None) + port = kwargs.pop("port", SERVER_PROXY_PORT) - return await asyncpg.connect( - user=user, - database=db, - password=passwd, - host=ip_address, - port=SERVER_PROXY_PORT, - ssl=ctx, - direct_tls=True, + connect_args = { + "user": user, + "database": db, + "password": passwd, + "host": ip_address, + "port": port, **kwargs, - ) + } + if ctx is not None: + connect_args["ssl"] = ctx + connect_args["direct_tls"] = True + + return await asyncpg.connect(**connect_args) diff --git a/google/cloud/sql/connector/client.py b/google/cloud/sql/connector/client.py index 11508ce17..1befdb793 100644 --- a/google/cloud/sql/connector/client.py +++ b/google/cloud/sql/connector/client.py @@ -171,9 +171,13 @@ async def _get_metadata( if dns_name: ip_addresses["PSC"] = dns_name.rstrip(".") + server_ca_cert = None + if "serverCaCert" in ret_dict and "cert" in ret_dict["serverCaCert"]: + server_ca_cert = ret_dict["serverCaCert"]["cert"] + return { "ip_addresses": ip_addresses, - "server_ca_cert": ret_dict["serverCaCert"]["cert"], + "server_ca_cert": server_ca_cert, "database_version": ret_dict["databaseVersion"], } @@ -228,7 +232,11 @@ async def _get_ephemeral( finally: resp.raise_for_status() - ephemeral_cert: str = ret_dict["ephemeralCert"]["cert"] + try: + ephemeral_cert: str = ret_dict["ephemeralCert"]["cert"] + except KeyError as e: + logger.error(f"KeyError in _get_ephemeral parsing generateEphemeralCert: {e}. 
Response dict: {ret_dict}") + raise # decode cert to read expiration x509 = load_pem_x509_certificate( diff --git a/google/cloud/sql/connector/connection_info.py b/google/cloud/sql/connector/connection_info.py index c9e48935f..bf9330e1b 100644 --- a/google/cloud/sql/connector/connection_info.py +++ b/google/cloud/sql/connector/connection_info.py @@ -63,7 +63,7 @@ class ConnectionInfo: conn_name: ConnectionName client_cert: str - server_ca_cert: str + server_ca_cert: Optional[str] private_key: bytes ip_addrs: dict[str, Any] database_version: str @@ -79,6 +79,10 @@ async def create_ssl_context(self, enable_iam_auth: bool = False) -> ssl.SSLCont # if SSL context is cached, use it if self.context is not None: return self.context + + if self.server_ca_cert is None: + raise ValueError("Cannot create SSL context: server CA certificate is missing.") + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) # update ssl.PROTOCOL_TLS_CLIENT default diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 798969c2c..6d902b1e2 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -31,6 +31,7 @@ import google.cloud.sql.connector.asyncpg as asyncpg from google.cloud.sql.connector.client import CloudSQLClient +from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.enums import DriverMapping from google.cloud.sql.connector.enums import IPTypes from google.cloud.sql.connector.enums import RefreshStrategy @@ -44,6 +45,8 @@ import google.cloud.sql.connector.pytds as pytds from google.cloud.sql.connector.resolver import DefaultResolver from google.cloud.sql.connector.resolver import DnsResolver +from google.cloud.sql.connector.sqldata_client import FallbackSocket +from google.cloud.sql.connector.sqldata_client import SqlDataClient from google.cloud.sql.connector.utils import format_database_user from google.cloud.sql.connector.utils import generate_keys 
@@ -73,6 +76,8 @@ def __init__( refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND, resolver: type[DefaultResolver] | type[DnsResolver] = DefaultResolver, failover_period: int = 30, + sql_data_endpoint: str = "sqladmin.googleapis.com", + sql_data_stream_timeout: int = 7200, ) -> None: """Initializes a Connector instance. @@ -212,6 +217,12 @@ def __init__( "configured the universe domain explicitly, `googleapis.com` " "is the default." ) + self._sql_data_endpoint = sql_data_endpoint + self._sql_data_stream_timeout = sql_data_stream_timeout + self._sql_data_fallback_cache: set[str] = set() + self._sqldata_clients: list[SqlDataClient] = [] + + @property def universe_domain(self) -> str: @@ -258,6 +269,48 @@ def connect( ) return connect_future.result() + def _get_or_create_cache( + self, + conn_name: ConnectionName, + enable_iam_auth: bool, + ) -> MonitoredCache: + assert self._client is not None, "client must be initialized before creating cache" + assert self._keys is not None, "keys must be initialized before creating cache" + if (str(conn_name), enable_iam_auth) in self._cache and not self._cache[ + (str(conn_name), enable_iam_auth) + ].closed: + return self._cache[(str(conn_name), enable_iam_auth)] + + if self._refresh_strategy == RefreshStrategy.LAZY: + logger.debug( + f"['{conn_name}']: Refresh strategy is set to lazy refresh" + ) + cache: Union[LazyRefreshCache, RefreshAheadCache] = LazyRefreshCache( + conn_name, + self._client, + self._keys, + enable_iam_auth, + ) + else: + logger.debug( + f"['{conn_name}']: Refresh strategy is set to backgound refresh" + ) + cache = RefreshAheadCache( + conn_name, + self._client, + self._keys, + enable_iam_auth, + ) + # wrap cache as a MonitoredCache + monitored_cache = MonitoredCache( + cache, + self._failover_period, + self._resolver, + ) + logger.debug(f"['{conn_name}']: Connection info added to cache") + self._cache[(str(conn_name), enable_iam_auth)] = monitored_cache + return monitored_cache + async 
def connect_async( self, instance_connection_string: str, driver: str, **kwargs: Any ) -> Any: @@ -317,42 +370,14 @@ async def connect_async( driver=driver, ) enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) + ip_type = kwargs.pop("ip_type", self._ip_type) + if isinstance(ip_type, str): + ip_type = IPTypes._from_str(ip_type) conn_name = await self._resolver.resolve(instance_connection_string) - # Cache entry must exist and not be closed - if (str(conn_name), enable_iam_auth) in self._cache and not self._cache[ - (str(conn_name), enable_iam_auth) - ].closed: - monitored_cache = self._cache[(str(conn_name), enable_iam_auth)] - else: - if self._refresh_strategy == RefreshStrategy.LAZY: - logger.debug( - f"['{conn_name}']: Refresh strategy is set to lazy refresh" - ) - cache: Union[LazyRefreshCache, RefreshAheadCache] = LazyRefreshCache( - conn_name, - self._client, - self._keys, - enable_iam_auth, - ) - else: - logger.debug( - f"['{conn_name}']: Refresh strategy is set to backgound refresh" - ) - cache = RefreshAheadCache( - conn_name, - self._client, - self._keys, - enable_iam_auth, - ) - # wrap cache as a MonitoredCache - monitored_cache = MonitoredCache( - cache, - self._failover_period, - self._resolver, - ) - logger.debug(f"['{conn_name}']: Connection info added to cache") - self._cache[(str(conn_name), enable_iam_auth)] = monitored_cache + + if ip_type != IPTypes.SQL_DATA: + monitored_cache = self._get_or_create_cache(conn_name, enable_iam_auth) connect_func = { "pymysql": pymysql.connect, @@ -366,11 +391,6 @@ async def connect_async( connector: Callable = connect_func[driver] # type: ignore except KeyError: raise KeyError(f"Driver '{driver}' is not supported.") - - ip_type = kwargs.pop("ip_type", self._ip_type) - # if ip_type is str, convert to IPTypes enum - if isinstance(ip_type, str): - ip_type = IPTypes._from_str(ip_type) kwargs["timeout"] = kwargs.get("timeout", self._timeout) # Host and ssl options come from the certificates and 
metadata, so we don't @@ -379,85 +399,149 @@ async def connect_async( kwargs.pop("ssl", None) kwargs.pop("port", None) - # attempt to get connection info for Cloud SQL instance + # attempt to establish connection try: - conn_info = await monitored_cache.connect_info() - # validate driver matches intended database engine - DriverMapping.validate_engine(driver, conn_info.database_version) - ip_address = conn_info.get_preferred_ip(ip_type) - except Exception: - # with an error from Cloud SQL Admin API call or IP type, invalidate - # the cache and re-raise the error - await self._remove_cached(str(conn_name), enable_iam_auth) - raise + if ip_type == IPTypes.SQL_DATA: + logger.debug(f"['{conn_name}']: Connecting via SQL Data Service tunnel") + if enable_iam_auth: + engine = DriverMapping[driver].value + formatted_user = format_database_user( + engine, kwargs["user"] + ) + if formatted_user != kwargs["user"]: + logger.debug( + f"['{instance_connection_string}']: Truncated IAM database username from {kwargs['user']} to {formatted_user}" + ) + kwargs["user"] = formatted_user + + sqldata_client = SqlDataClient( + endpoint=self._sql_data_endpoint, + credentials=self._credentials, + quota_project=self._quota_project, + timeout=self._sql_data_stream_timeout, + ) + self._sqldata_clients.append(sqldata_client) + + def on_fallback(name): + self._sql_data_fallback_cache.add(name) + + def is_fallback_cached(name): + return name in self._sql_data_fallback_cache + + # Defer cache creation and connect_info call + async def get_conn_info(): + cache = self._get_or_create_cache(conn_name, enable_iam_auth) + return await cache.connect_info() + + tunnel_port = await sqldata_client.connect_tunnel( + instance_connection_name=str(conn_name), + region=conn_name.region, + project=conn_name.project, + get_conn_info=get_conn_info, + enable_iam_auth=enable_iam_auth, + on_fallback=on_fallback, + is_fallback_cached=is_fallback_cached, + ) - # If the connector is configured with a custom DNS name, 
attempt to use - # that DNS name to connect to the instance. Fall back to the metadata IP - # address if the DNS name does not resolve to an IP address. - if conn_info.conn_name.domain_name and isinstance(self._resolver, DnsResolver): - try: - ips = await self._resolver.resolve_a_record(conn_info.conn_name.domain_name) - if ips: - ip_address = ips[0] - logger.debug( - f"['{instance_connection_string}']: Custom DNS name " - f"'{conn_info.conn_name.domain_name}' resolved to '{ip_address}', " - "using it to connect" + if driver in ASYNC_DRIVERS: + return await connector( + "127.0.0.1", + None, + port=tunnel_port, + **kwargs, ) else: - logger.debug( - f"['{instance_connection_string}']: Custom DNS name " - f"'{conn_info.conn_name.domain_name}' resolved but returned no " - f"entries, using '{ip_address}' from instance metadata" + raw_sock = socket.create_connection(("127.0.0.1", tunnel_port)) + fd = raw_sock.detach() + fallback_sock = FallbackSocket(fileno=fd) + + if conn_name.domain_name: + monitored_cache.sockets.append(fallback_sock) + + connect_partial = partial( + connector, + "127.0.0.1", + fallback_sock, + **kwargs, ) - except Exception as e: - logger.debug( - f"['{instance_connection_string}']: Custom DNS name " - f"'{conn_info.conn_name.domain_name}' did not resolve to an IP " - f"address: {e}, using '{ip_address}' from instance metadata" - ) - - logger.debug(f"['{conn_info.conn_name}']: Connecting to {ip_address}:3307") - # format `user` param for automatic IAM database authn - if enable_iam_auth: - formatted_user = format_database_user( - conn_info.database_version, kwargs["user"] - ) - if formatted_user != kwargs["user"]: - logger.debug( - f"['{instance_connection_string}']: Truncated IAM database username from {kwargs['user']} to {formatted_user}" + return await self._loop.run_in_executor(None, connect_partial) + else: + # Standard path (requires metadata and certs) + try: + conn_info = await monitored_cache.connect_info() + # validate driver matches 
intended database engine + DriverMapping.validate_engine(driver, conn_info.database_version) + ip_address = conn_info.get_preferred_ip(ip_type) + except Exception: + # with an error from Cloud SQL Admin API call or IP type, invalidate + # the cache and re-raise the error + await self._remove_cached(str(conn_name), enable_iam_auth) + raise + + # If the connector is configured with a custom DNS name, attempt to use + # that DNS name to connect to the instance. Fall back to the metadata IP + # address if the DNS name does not resolve to an IP address. + if conn_info.conn_name.domain_name and isinstance(self._resolver, DnsResolver): + try: + ips = await self._resolver.resolve_a_record(conn_info.conn_name.domain_name) + if ips: + ip_address = ips[0] + logger.debug( + f"['{instance_connection_string}']: Custom DNS name " + f"'{conn_info.conn_name.domain_name}' resolved to '{ip_address}', " + "using it to connect" + ) + else: + logger.debug( + f"['{instance_connection_string}']: Custom DNS name " + f"'{conn_info.conn_name.domain_name}' resolved but returned no " + f"entries, using '{ip_address}' from instance metadata" + ) + except Exception as e: + logger.debug( + f"['{instance_connection_string}']: Custom DNS name " + f"'{conn_info.conn_name.domain_name}' did not resolve to an IP " + f"address: {e}, using '{ip_address}' from instance metadata" + ) + + logger.debug(f"['{conn_info.conn_name}']: Connecting to {ip_address}:3307") + # format `user` param for automatic IAM database authn + if enable_iam_auth: + formatted_user = format_database_user( + conn_info.database_version, kwargs["user"] + ) + if formatted_user != kwargs["user"]: + logger.debug( + f"['{instance_connection_string}']: Truncated IAM database username from {kwargs['user']} to {formatted_user}" + ) + kwargs["user"] = formatted_user + + if driver in ASYNC_DRIVERS: + return await connector( + ip_address, + await conn_info.create_ssl_context(enable_iam_auth), + **kwargs, + ) + ctx = await 
conn_info.create_ssl_context(enable_iam_auth) + ssl_sock = ctx.wrap_socket( + socket.create_connection((ip_address, SERVER_PROXY_PORT)), + server_hostname=ip_address, ) - kwargs["user"] = formatted_user - try: - # async drivers are unblocking and can be awaited directly - if driver in ASYNC_DRIVERS: - return await connector( + if conn_info.conn_name.domain_name: + monitored_cache.sockets.append(ssl_sock) + connect_partial = partial( + connector, ip_address, - await conn_info.create_ssl_context(enable_iam_auth), + ssl_sock, **kwargs, ) - # Create socket with SSLContext for sync drivers - ctx = await conn_info.create_ssl_context(enable_iam_auth) - sock = ctx.wrap_socket( - socket.create_connection((ip_address, SERVER_PROXY_PORT)), - server_hostname=ip_address, - ) - # If this connection was opened using a domain name, then store it - # for later in case we need to forcibly close it on failover. - if conn_info.conn_name.domain_name: - monitored_cache.sockets.append(sock) - # Synchronous drivers are blocking and run using executor - connect_partial = partial( - connector, - ip_address, - sock, - **kwargs, - ) - return await self._loop.run_in_executor(None, connect_partial) + return await self._loop.run_in_executor(None, connect_partial) except Exception: # with any exception, we attempt a force refresh, then throw the error - await monitored_cache.force_refresh() + cache = self._cache.get((str(conn_name), enable_iam_auth)) + if cache: + await cache.force_refresh() raise async def _remove_cached( @@ -505,8 +589,11 @@ def close(self) -> None: close_future = asyncio.run_coroutine_threadsafe( self.close_async(), loop=self._loop ) - # Will attempt to safely shut down tasks for 3s - close_future.result(timeout=3) + try: + # Will attempt to safely shut down tasks for 3s + close_future.result(timeout=3) + except Exception as e: + logger.error(f"Error during close_async: {e}") # if background thread exists for Connector, clean it up if self._thread: if self._loop.is_running(): 
@@ -521,7 +608,10 @@ async def close_async(self) -> None: self._closed = True if self._client: await self._client.close() - await asyncio.gather(*[cache.close() for cache in self._cache.values()]) + await asyncio.gather( + *[cache.close() for cache in self._cache.values()], + *[client.close() for client in self._sqldata_clients], + ) async def create_async_connector( diff --git a/google/cloud/sql/connector/enums.py b/google/cloud/sql/connector/enums.py index e6b56af0e..f936dba84 100644 --- a/google/cloud/sql/connector/enums.py +++ b/google/cloud/sql/connector/enums.py @@ -41,6 +41,7 @@ class IPTypes(Enum): PUBLIC = "PRIMARY" PRIVATE = "PRIVATE" PSC = "PSC" + SQL_DATA = "SQL_DATA" @classmethod def _missing_(cls, value: object) -> None: diff --git a/google/cloud/sql/connector/monitored_cache.py b/google/cloud/sql/connector/monitored_cache.py index 0c3fc4d03..79a77aeda 100644 --- a/google/cloud/sql/connector/monitored_cache.py +++ b/google/cloud/sql/connector/monitored_cache.py @@ -14,7 +14,7 @@ import asyncio import logging -import ssl +import socket from typing import Any, Callable, Optional, Union from google.cloud.sql.connector.connection_info import ConnectionInfo @@ -38,7 +38,7 @@ def __init__( self.resolver = resolver self.cache = cache self.domain_name_ticker: Optional[asyncio.Task] = None - self.sockets: list[ssl.SSLSocket] = [] + self.sockets: list[socket.socket] = [] # If domain name is configured for instance and failover period is set, # poll for DNS record changes. @@ -62,11 +62,11 @@ def _purge_closed_sockets(self) -> None: list of sockets. """ open_sockets = [] - for socket in self.sockets: + for sock in self.sockets: # Check fileno for if socket is closed. Will return # -1 on failure, which will be used to signal socket closed. 
- if socket.fileno() != -1: - open_sockets.append(socket) + if sock.fileno() != -1: + open_sockets.append(sock) self.sockets = open_sockets async def _check_domain_name(self) -> None: @@ -128,11 +128,11 @@ async def close(self) -> None: await self.cache.close() # Close any still open sockets - for socket in self.sockets: + for sock in self.sockets: # Check fileno for if socket is closed. Will return # -1 on failure, which will be used to signal socket closed. - if socket.fileno() != -1: - socket.close() + if sock.fileno() != -1: + sock.close() async def ticker(interval: int, function: Callable, *args: Any, **kwargs: Any) -> None: diff --git a/google/cloud/sql/connector/proto/google/rpc/code.proto b/google/cloud/sql/connector/proto/google/rpc/code.proto new file mode 100644 index 000000000..8fef41170 --- /dev/null +++ b/google/cloud/sql/connector/proto/google/rpc/code.proto @@ -0,0 +1,186 @@ +// Copyright 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package google.rpc; + +option go_package = "google.golang.org/genproto/googleapis/rpc/code;code"; +option java_multiple_files = true; +option java_outer_classname = "CodeProto"; +option java_package = "com.google.rpc"; +option objc_class_prefix = "RPC"; + + +// The canonical error codes for Google APIs. +// +// +// Sometimes multiple error codes may apply. Services should return +// the most specific error code that applies. 
For example, prefer +// `OUT_OF_RANGE` over `FAILED_PRECONDITION` if both codes apply. +// Similarly prefer `NOT_FOUND` or `ALREADY_EXISTS` over `FAILED_PRECONDITION`. +enum Code { + // Not an error; returned on success + // + // HTTP Mapping: 200 OK + OK = 0; + + // The operation was cancelled, typically by the caller. + // + // HTTP Mapping: 499 Client Closed Request + CANCELLED = 1; + + // Unknown error. For example, this error may be returned when + // a `Status` value received from another address space belongs to + // an error space that is not known in this address space. Also + // errors raised by APIs that do not return enough error information + // may be converted to this error. + // + // HTTP Mapping: 500 Internal Server Error + UNKNOWN = 2; + + // The client specified an invalid argument. Note that this differs + // from `FAILED_PRECONDITION`. `INVALID_ARGUMENT` indicates arguments + // that are problematic regardless of the state of the system + // (e.g., a malformed file name). + // + // HTTP Mapping: 400 Bad Request + INVALID_ARGUMENT = 3; + + // The deadline expired before the operation could complete. For operations + // that change the state of the system, this error may be returned + // even if the operation has completed successfully. For example, a + // successful response from a server could have been delayed long + // enough for the deadline to expire. + // + // HTTP Mapping: 504 Gateway Timeout + DEADLINE_EXCEEDED = 4; + + // Some requested entity (e.g., file or directory) was not found. + // + // Note to server developers: if a request is denied for an entire class + // of users, such as gradual feature rollout or undocumented whitelist, + // `NOT_FOUND` may be used. If a request is denied for some users within + // a class of users, such as user-based access control, `PERMISSION_DENIED` + // must be used. 
+ // + // HTTP Mapping: 404 Not Found + NOT_FOUND = 5; + + // The entity that a client attempted to create (e.g., file or directory) + // already exists. + // + // HTTP Mapping: 409 Conflict + ALREADY_EXISTS = 6; + + // The caller does not have permission to execute the specified + // operation. `PERMISSION_DENIED` must not be used for rejections + // caused by exhausting some resource (use `RESOURCE_EXHAUSTED` + // instead for those errors). `PERMISSION_DENIED` must not be + // used if the caller can not be identified (use `UNAUTHENTICATED` + // instead for those errors). This error code does not imply the + // request is valid or the requested entity exists or satisfies + // other pre-conditions. + // + // HTTP Mapping: 403 Forbidden + PERMISSION_DENIED = 7; + + // The request does not have valid authentication credentials for the + // operation. + // + // HTTP Mapping: 401 Unauthorized + UNAUTHENTICATED = 16; + + // Some resource has been exhausted, perhaps a per-user quota, or + // perhaps the entire file system is out of space. + // + // HTTP Mapping: 429 Too Many Requests + RESOURCE_EXHAUSTED = 8; + + // The operation was rejected because the system is not in a state + // required for the operation's execution. For example, the directory + // to be deleted is non-empty, an rmdir operation is applied to + // a non-directory, etc. + // + // Service implementors can use the following guidelines to decide + // between `FAILED_PRECONDITION`, `ABORTED`, and `UNAVAILABLE`: + // (a) Use `UNAVAILABLE` if the client can retry just the failing call. + // (b) Use `ABORTED` if the client should retry at a higher level + // (e.g., when a client-specified test-and-set fails, indicating the + // client should restart a read-modify-write sequence). + // (c) Use `FAILED_PRECONDITION` if the client should not retry until + // the system state has been explicitly fixed. 
E.g., if an "rmdir" + // fails because the directory is non-empty, `FAILED_PRECONDITION` + // should be returned since the client should not retry unless + // the files are deleted from the directory. + // + // HTTP Mapping: 400 Bad Request + FAILED_PRECONDITION = 9; + + // The operation was aborted, typically due to a concurrency issue such as + // a sequencer check failure or transaction abort. + // + // See the guidelines above for deciding between `FAILED_PRECONDITION`, + // `ABORTED`, and `UNAVAILABLE`. + // + // HTTP Mapping: 409 Conflict + ABORTED = 10; + + // The operation was attempted past the valid range. E.g., seeking or + // reading past end-of-file. + // + // Unlike `INVALID_ARGUMENT`, this error indicates a problem that may + // be fixed if the system state changes. For example, a 32-bit file + // system will generate `INVALID_ARGUMENT` if asked to read at an + // offset that is not in the range [0,2^32-1], but it will generate + // `OUT_OF_RANGE` if asked to read from an offset past the current + // file size. + // + // There is a fair bit of overlap between `FAILED_PRECONDITION` and + // `OUT_OF_RANGE`. We recommend using `OUT_OF_RANGE` (the more specific + // error) when it applies so that callers who are iterating through + // a space can easily look for an `OUT_OF_RANGE` error to detect when + // they are done. + // + // HTTP Mapping: 400 Bad Request + OUT_OF_RANGE = 11; + + // The operation is not implemented or is not supported/enabled in this + // service. + // + // HTTP Mapping: 501 Not Implemented + UNIMPLEMENTED = 12; + + // Internal errors. This means that some invariants expected by the + // underlying system have been broken. This error code is reserved + // for serious errors. + // + // HTTP Mapping: 500 Internal Server Error + INTERNAL = 13; + + // The service is currently unavailable. This is most likely a + // transient condition, which can be corrected by retrying with + // a backoff. 
+ // + // See the guidelines above for deciding between `FAILED_PRECONDITION`, + // `ABORTED`, and `UNAVAILABLE`. + // + // HTTP Mapping: 503 Service Unavailable + UNAVAILABLE = 14; + + // Unrecoverable data loss or corruption. + // + // HTTP Mapping: 500 Internal Server Error + DATA_LOSS = 15; +} diff --git a/google/cloud/sql/connector/proto/google/rpc/error_details.proto b/google/cloud/sql/connector/proto/google/rpc/error_details.proto new file mode 100644 index 000000000..f24ae0099 --- /dev/null +++ b/google/cloud/sql/connector/proto/google/rpc/error_details.proto @@ -0,0 +1,200 @@ +// Copyright 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package google.rpc; + +import "google/protobuf/duration.proto"; + +option go_package = "google.golang.org/genproto/googleapis/rpc/errdetails;errdetails"; +option java_multiple_files = true; +option java_outer_classname = "ErrorDetailsProto"; +option java_package = "com.google.rpc"; +option objc_class_prefix = "RPC"; + + +// Describes when the clients can retry a failed request. Clients could ignore +// the recommendation here or retry when this information is missing from error +// responses. +// +// It's always recommended that clients should use exponential backoff when +// retrying. +// +// Clients should wait until `retry_delay` amount of time has passed since +// receiving the error response before retrying. 
If retrying requests also +// fail, clients should use an exponential backoff scheme to gradually increase +// the delay between retries based on `retry_delay`, until either a maximum +// number of retries have been reached or a maximum retry delay cap has been +// reached. +message RetryInfo { + // Clients should wait at least this long between retrying the same request. + google.protobuf.Duration retry_delay = 1; +} + +// Describes additional debugging info. +message DebugInfo { + // The stack trace entries indicating where the error occurred. + repeated string stack_entries = 1; + + // Additional debugging information provided by the server. + string detail = 2; +} + +// Describes how a quota check failed. +// +// For example if a daily limit was exceeded for the calling project, +// a service could respond with a QuotaFailure detail containing the project +// id and the description of the quota limit that was exceeded. If the +// calling project hasn't enabled the service in the developer console, then +// a service could respond with the project id and set `service_disabled` +// to true. +// +// Also see RetryInfo and Help types for other details about handling a +// quota failure. +message QuotaFailure { + // A message type used to describe a single quota violation. For example, a + // daily quota or a custom quota that was exceeded. + message Violation { + // The subject on which the quota check failed. + // For example, "clientip:" or "project:". + string subject = 1; + + // A description of how the quota check failed. Clients can use this + // description to find more about the quota configuration in the service's + // public documentation, or find the relevant quota limit to adjust through + // developer console. + // + // For example: "Service disabled" or "Daily Limit for read operations + // exceeded". + string description = 2; + } + + // Describes all quota violations. 
+ repeated Violation violations = 1; +} + +// Describes what preconditions have failed. +// +// For example, if an RPC failed because it required the Terms of Service to be +// acknowledged, it could list the terms of service violation in the +// PreconditionFailure message. +message PreconditionFailure { + // A message type used to describe a single precondition failure. + message Violation { + // The type of PreconditionFailure. We recommend using a service-specific + // enum type to define the supported precondition violation types. For + // example, "TOS" for "Terms of Service violation". + string type = 1; + + // The subject, relative to the type, that failed. + // For example, "google.com/cloud" relative to the "TOS" type would + // indicate which terms of service is being referenced. + string subject = 2; + + // A description of how the precondition failed. Developers can use this + // description to understand how to fix the failure. + // + // For example: "Terms of service not accepted". + string description = 3; + } + + // Describes all precondition violations. + repeated Violation violations = 1; +} + +// Describes violations in a client request. This error type focuses on the +// syntactic aspects of the request. +message BadRequest { + // A message type used to describe a single bad request field. + message FieldViolation { + // A path leading to a field in the request body. The value will be a + // sequence of dot-separated identifiers that identify a protocol buffer + // field. E.g., "field_violations.field" would identify this field. + string field = 1; + + // A description of why the request element is bad. + string description = 2; + } + + // Describes all violations in a client request. + repeated FieldViolation field_violations = 1; +} + +// Contains metadata about the request that clients can attach when filing a bug +// or providing other forms of feedback. 
+message RequestInfo { + // An opaque string that should only be interpreted by the service generating + // it. For example, it can be used to identify requests in the service's logs. + string request_id = 1; + + // Any data that was used to serve this request. For example, an encrypted + // stack trace that can be sent back to the service provider for debugging. + string serving_data = 2; +} + +// Describes the resource that is being accessed. +message ResourceInfo { + // A name for the type of resource being accessed, e.g. "sql table", + // "cloud storage bucket", "file", "Google calendar"; or the type URL + // of the resource: e.g. "type.googleapis.com/google.pubsub.v1.Topic". + string resource_type = 1; + + // The name of the resource being accessed. For example, a shared calendar + // name: "example.com_4fghdhgsrgh@group.calendar.google.com", if the current + // error is [google.rpc.Code.PERMISSION_DENIED][google.rpc.Code.PERMISSION_DENIED]. + string resource_name = 2; + + // The owner of the resource (optional). + // For example, "user:" or "project:". + string owner = 3; + + // Describes what error is encountered when accessing this resource. + // For example, updating a cloud project may require the `writer` permission + // on the developer console project. + string description = 4; +} + +// Provides links to documentation or for performing an out of band action. +// +// For example, if a quota check failed with an error indicating the calling +// project hasn't enabled the accessed service, this can contain a URL pointing +// directly to the right place in the developer console to flip the bit. +message Help { + // Describes a URL link. + message Link { + // Describes what the link offers. + string description = 1; + + // The URL of the link. + string url = 2; + } + + // URL(s) pointing to additional information on handling the current error. 
+ repeated Link links = 1; +} + +// Provides a localized error message that is safe to return to the user +// which can be attached to an RPC error. +message LocalizedMessage { + // The locale used following the specification defined at + // http://www.rfc-editor.org/rfc/bcp/bcp47.txt. + // Examples are: "en-US", "fr-CH", "es-MX" + string locale = 1; + + // The localized error message in the above locale. + string message = 2; +} diff --git a/google/cloud/sql/connector/proto/google/rpc/status.proto b/google/cloud/sql/connector/proto/google/rpc/status.proto new file mode 100644 index 000000000..0839ee966 --- /dev/null +++ b/google/cloud/sql/connector/proto/google/rpc/status.proto @@ -0,0 +1,92 @@ +// Copyright 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package google.rpc; + +import "google/protobuf/any.proto"; + +option go_package = "google.golang.org/genproto/googleapis/rpc/status;status"; +option java_multiple_files = true; +option java_outer_classname = "StatusProto"; +option java_package = "com.google.rpc"; +option objc_class_prefix = "RPC"; + + +// The `Status` type defines a logical error model that is suitable for different +// programming environments, including REST APIs and RPC APIs. It is used by +// [gRPC](https://github.com/grpc). 
The error model is designed to be: +// +// - Simple to use and understand for most users +// - Flexible enough to meet unexpected needs +// +// # Overview +// +// The `Status` message contains three pieces of data: error code, error message, +// and error details. The error code should be an enum value of +// [google.rpc.Code][google.rpc.Code], but it may accept additional error codes if needed. The +// error message should be a developer-facing English message that helps +// developers *understand* and *resolve* the error. If a localized user-facing +// error message is needed, put the localized message in the error details or +// localize it in the client. The optional error details may contain arbitrary +// information about the error. There is a predefined set of error detail types +// in the package `google.rpc` that can be used for common error conditions. +// +// # Language mapping +// +// The `Status` message is the logical representation of the error model, but it +// is not necessarily the actual wire format. When the `Status` message is +// exposed in different client libraries and different wire protocols, it can be +// mapped differently. For example, it will likely be mapped to some exceptions +// in Java, but more likely mapped to some error codes in C. +// +// # Other uses +// +// The error model and the `Status` message can be used in a variety of +// environments, either with or without APIs, to provide a +// consistent developer experience across different environments. +// +// Example uses of this error model include: +// +// - Partial errors. If a service needs to return partial errors to the client, +// it may embed the `Status` in the normal response to indicate the partial +// errors. +// +// - Workflow errors. A typical workflow has multiple steps. Each step may +// have a `Status` message for error reporting. +// +// - Batch operations. 
If a client uses batch request and batch response, the +// `Status` message should be used directly inside batch response, one for +// each error sub-response. +// +// - Asynchronous operations. If an API call embeds asynchronous operation +// results in its response, the status of those operations should be +// represented directly using the `Status` message. +// +// - Logging. If some API errors are stored in logs, the message `Status` could +// be used directly after any stripping needed for security/privacy reasons. +message Status { + // The status code, which should be an enum value of [google.rpc.Code][google.rpc.Code]. + int32 code = 1; + + // A developer-facing error message, which should be in English. Any + // user-facing error message should be localized and sent in the + // [google.rpc.Status.details][google.rpc.Status.details] field, or localized by the client. + string message = 2; + + // A list of messages that carry the error details. There is a common set of + // message types for APIs to use. + repeated google.protobuf.Any details = 3; +} diff --git a/google/cloud/sql/connector/proto/sql_data_service.proto b/google/cloud/sql/connector/proto/sql_data_service.proto new file mode 100644 index 000000000..98d688cd3 --- /dev/null +++ b/google/cloud/sql/connector/proto/sql_data_service.proto @@ -0,0 +1,264 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +syntax = "proto3"; + +package google.cloud.sql.v1beta4; + +option go_package = "internal/sqldata"; +option java_package = "com.google.cloud.sql.v1beta4"; +option java_outer_classname = "CloudSqlDataProto"; +option java_multiple_files = true; + +import "google/rpc/status.proto"; + +// Service for streaming data to and from Cloud SQL instances. +service SqlDataService { + // `StreamSqlData` establishes a bidirectional stream to a Cloud SQL instance, + // and then streams data to and from the instance. + // + // The first message from the client MUST be a `StreamSqlDataRequest` request + // with configuration settings, including required values for the + // `connection_settings` field. Subsequent messages from the client may + // contain the `payload` field. + // + // Messages from the server may contain the `payload` field. + // + // The `payload` fields of the request and response streams contain the raw + // data of the database's native wire protocol (e.g., PostgreSQL wire + // protocol). The database client is responsible for generating and parsing + // this data. + // + // Any errors on initial connection (e.g., connection failure, authorization + // issues, network problems) will result in the stream being terminated with + // an appropriate RPC status exception. + // + // After a successful connection is made, if an error occurs, then the server + // terminates connection and returns the appropriate RPC status exception. + rpc StreamSqlData(stream StreamSqlDataRequest) + returns (stream StreamSqlDataResponse) {} +} + +// Message sent from the client to `SqlDataService`. +message StreamSqlDataRequest { + // Deprecated: Use `StartSession.location_id` or `ContinueSession.location_id` + // instead. `location_id` is used to route the request to a specific region. + // Use the same region which was used to create the instance. Use the format + // `locations/{location}`, for example: `locations/us-central1`. 
+ string location_id = 1; + + // Deprecated: Use the `message` oneof instead. The type of message sent + // within the stream. + oneof message_type { + // Deprecated: Use `start_session` or `continue_session` instead. + // Parameters for establishing the connection. MUST be sent as the first + // message on the stream. + ConnectionSettings connection_settings = 2; + + // Deprecated: Use `DataPacket` instead. + // Data to be forwarded to the database. + ClientPayload payload = 3; + } + + // Acknowledges data received by the client. + Ack ack = 4; + + // The message to the server. + oneof message { + // Starts a new session. When starting a new session, this is the first + // message the client sends. + StartSession start_session = 5; + // Continues an existing session. When reconnecting to a session, this is the + // first message the client sends. + ContinueSession continue_session = 6; + // Database data. + DataPacket data = 7; + // Terminates the session. This closes the connection to the database. + TerminateSession terminate_session = 8; + } +} + +// Deprecated: New schema structure. Initial connection parameters. +message ConnectionSettings { + option deprecated = true; + + // The target of the connection. + oneof target { + // The identifier of the Cloud SQL instance. + InstanceId instance_id = 1; + } +} + +// Start a new session. The client must send this as the first message to the +// server to start a new session. The client may immediately send Data messages +// without waiting for a reply from the server. +message StartSession { + //`location_id` is used to route the + // request to a specific region. Use the same region which was used to create + // the instance. Use the format `locations/{location}`, for example: + // `locations/us-central1`. + string location_id = 1; + // The Cloud SQL instance resource name, for example: + // projects/example-project/instances/example-instance + string instance_id = 2; + // The session id, chosen by the client.
This should be an unguessable string. + // If the client does not intend to reconnect to this session, the client may + // leave session_id unset. + string session_id = 3; +} + +// Reconnects to an existing session. The client must send this as the first +// message to the server to reconnect to an existing session. The client may +// immediately send Data messages without waiting for a reply from the server. +message ContinueSession { + //`location_id` is used to route the + // request to a specific region. Use the same region which was used to create + // the instance. Use the format `locations/{location}`, for example: + // `locations/us-central1`. + string location_id = 1; + + // The Cloud SQL instance resource name, for example: + // projects/example-project/instances/example-instance + string instance_id = 2; + + // The id of the session to reconnect. + string session_id = 3; +} + +// Deprecated: New schema structure. The identifier of the Cloud SQL instance. +message InstanceId { + option deprecated = true; + + // Full resource name of the Cloud SQL instance, in the form: + // `projects/{project}/instances/{instance}`, for example: + // `projects/foo-project/instances/bar-instance`. + string instance = 1; +} + +// Deprecated: New schema structure. Wrapper for data being sent to the +// database. +message ClientPayload { + option deprecated = true; + + // Raw data to be sent to the database. See the documentation for + // `StreamSqlData` for details on the expected wire format. + bytes data = 1; +} + +// Message sent from SqlDataService back to the client. +message StreamSqlDataResponse { + // Deprecated: New schema structure. The type of the message received from + // `SqlDataService`. + oneof type { + // Raw data received from the database. + ServerPayload payload = 1; + } + + // Acknowledges data received by the server. + Ack ack = 2; + // A message from the server to the client. 
+ oneof message { + // The first message from the server to the client, containing metadata + // about this session. + SessionMetadata session_metadata = 3; + // Data from the database. + DataPacket data = 4; + // Terminates the session. This indicates that the database connection + // is closed. When the client receives this message, it should not + // attempt to reconnect. + TerminateSession terminate_session = 5; + } +} + +// Deprecated: New schema structure. Wrapper for data being received from the +// database. +message ServerPayload { + option deprecated = true; + + // Raw data received from the database. See the documentation for + // `StreamSqlData` for details on the expected wire format. + bytes data = 1; +} +// Metadata from the server to the client about the session. The server will +// always send this as the first message. +message SessionMetadata { + // The features supported by the server for this session. This field is + // used by the client to determine which features are available on the + // server. + repeated SqlDataFeature supported_features = 1; +} + +// Contains data being sent or received by the database. +message DataPacket { + // The absolute byte offset of the first byte in this payload. + // 0 for new connections or resumed connections that haven't acked any bytes + // from the server. Non-zero for resumed connections. + int64 first_byte_offset = 1; + // Raw data being sent or received by the database. + bytes data = 2; +} +// Acknowledges data received by the client or server. +message Ack { + // The absolute number of bytes processed in the session. + int64 received_offset = 1; +} +// Indicates that the session is permanently ended. +message TerminateSession { + // The session termination status. + google.rpc.Status status = 1; +} + +// Error reasons for `StreamSqlData`.
+// Typically used with standard error codes, with the error info/reason field +// set to the string representation of the enum value. +enum StreamSqlDataErrorReason { + // Indicates that the error reason is unknown. + STREAM_SQL_DATA_ERROR_REASON_UNKNOWN = 0; + + // Indicates that the operation is not supported for the given instance type. + // Used with status code `google.rpc.Code.FAILED_PRECONDITION`. + STREAM_SQL_DATA_ERROR_REASON_UNSUPPORTED_INSTANCE_TYPE = 1; + + // Indicates that reconnect failed and should not be retried. + // Used with status code `google.rpc.Code.INTERNAL`. + STREAM_SQL_DATA_ERROR_REASON_RECONNECT_FAILED = 3; + + // Indicates that the database client closed its connection normally. + // Used with status code `google.rpc.Code.CANCELLED`. + STREAM_SQL_DATA_ERROR_REASON_DB_CLIENT_CLOSED = 4; + + // Indicates that the database server closed its connection normally. + // Used with status code `google.rpc.Code.CANCELLED`. + STREAM_SQL_DATA_ERROR_REASON_DB_SERVER_CLOSED = 5; + + // Indicates that the peer sent an ACK message that was not within an + // acceptable range. Used with the status code + // `google.rpc.Code.FAILED_PRECONDITION`. + STREAM_SQL_DATA_ERROR_REASON_INVALID_ACK = 6; + + // Indicates that the SqlDataService lost its connection to the + // database instance. This is a retryable error. + // Used with status code `google.rpc.Code.ABORTED`. + STREAM_SQL_DATA_ERROR_REASON_DISCONNECTED = 7; +} + +// The session features. The server must send the supported features in its +// first message to the client. +enum SqlDataFeature { + // The feature is not specified. This value should not be used. + SQL_DATA_FEATURE_UNSPECIFIED = 0; + // The server supports reconnecting to the session. If this feature is not + // present, the client should not attempt to reconnect to the session.
+ SQL_DATA_FEATURE_RECONNECT = 1; +} diff --git a/google/cloud/sql/connector/proto/sql_data_service_pb2.py b/google/cloud/sql/connector/proto/sql_data_service_pb2.py new file mode 100644 index 000000000..fb25be21d --- /dev/null +++ b/google/cloud/sql/connector/proto/sql_data_service_pb2.py @@ -0,0 +1,88 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# NO CHECKED-IN PROTOBUF GENCODE +# source: google/cloud/sql/connector/proto/sql_data_service.proto +# Protobuf Python Version: 6.33.5 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder + +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, + 6, + 33, + 5, + '', + 'google/cloud/sql/connector/proto/sql_data_service.proto' +) +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n7google/cloud/sql/connector/proto/sql_data_service.proto\x12\x18google.cloud.sql.v1beta4\x1a\x17google/rpc/status.proto\"\x82\x04\n\x14StreamSqlDataRequest\x12\x13\n\x0blocation_id\x18\x01 
\x01(\t\x12K\n\x13\x63onnection_settings\x18\x02 \x01(\x0b\x32,.google.cloud.sql.v1beta4.ConnectionSettingsH\x00\x12:\n\x07payload\x18\x03 \x01(\x0b\x32\'.google.cloud.sql.v1beta4.ClientPayloadH\x00\x12*\n\x03\x61\x63k\x18\x04 \x01(\x0b\x32\x1d.google.cloud.sql.v1beta4.Ack\x12?\n\rstart_session\x18\x05 \x01(\x0b\x32&.google.cloud.sql.v1beta4.StartSessionH\x01\x12\x45\n\x10\x63ontinue_session\x18\x06 \x01(\x0b\x32).google.cloud.sql.v1beta4.ContinueSessionH\x01\x12\x34\n\x04\x64\x61ta\x18\x07 \x01(\x0b\x32$.google.cloud.sql.v1beta4.DataPacketH\x01\x12G\n\x11terminate_session\x18\x08 \x01(\x0b\x32*.google.cloud.sql.v1beta4.TerminateSessionH\x01\x42\x0e\n\x0cmessage_typeB\t\n\x07message\"_\n\x12\x43onnectionSettings\x12;\n\x0binstance_id\x18\x01 \x01(\x0b\x32$.google.cloud.sql.v1beta4.InstanceIdH\x00:\x02\x18\x01\x42\x08\n\x06target\"L\n\x0cStartSession\x12\x13\n\x0blocation_id\x18\x01 \x01(\t\x12\x13\n\x0binstance_id\x18\x02 \x01(\t\x12\x12\n\nsession_id\x18\x03 \x01(\t\"O\n\x0f\x43ontinueSession\x12\x13\n\x0blocation_id\x18\x01 \x01(\t\x12\x13\n\x0binstance_id\x18\x02 \x01(\t\x12\x12\n\nsession_id\x18\x03 \x01(\t\"\"\n\nInstanceId\x12\x10\n\x08instance\x18\x01 \x01(\t:\x02\x18\x01\"!\n\rClientPayload\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c:\x02\x18\x01\"\xd8\x02\n\x15StreamSqlDataResponse\x12:\n\x07payload\x18\x01 \x01(\x0b\x32\'.google.cloud.sql.v1beta4.ServerPayloadH\x00\x12*\n\x03\x61\x63k\x18\x02 \x01(\x0b\x32\x1d.google.cloud.sql.v1beta4.Ack\x12\x45\n\x10session_metadata\x18\x03 \x01(\x0b\x32).google.cloud.sql.v1beta4.SessionMetadataH\x01\x12\x34\n\x04\x64\x61ta\x18\x04 \x01(\x0b\x32$.google.cloud.sql.v1beta4.DataPacketH\x01\x12G\n\x11terminate_session\x18\x05 \x01(\x0b\x32*.google.cloud.sql.v1beta4.TerminateSessionH\x01\x42\x06\n\x04typeB\t\n\x07message\"!\n\rServerPayload\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c:\x02\x18\x01\"W\n\x0fSessionMetadata\x12\x44\n\x12supported_features\x18\x01 
\x03(\x0e\x32(.google.cloud.sql.v1beta4.SqlDataFeature\"5\n\nDataPacket\x12\x19\n\x11\x66irst_byte_offset\x18\x01 \x01(\x03\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x1e\n\x03\x41\x63k\x12\x17\n\x0freceived_offset\x18\x01 \x01(\x03\"6\n\x10TerminateSession\x12\"\n\x06status\x18\x01 \x01(\x0b\x32\x12.google.rpc.Status*\xf6\x02\n\x18StreamSqlDataErrorReason\x12(\n$STREAM_SQL_DATA_ERROR_REASON_UNKNOWN\x10\x00\x12:\n6STREAM_SQL_DATA_ERROR_REASON_UNSUPPORTED_INSTANCE_TYPE\x10\x01\x12\x31\n-STREAM_SQL_DATA_ERROR_REASON_RECONNECT_FAILED\x10\x03\x12\x31\n-STREAM_SQL_DATA_ERROR_REASON_DB_CLIENT_CLOSED\x10\x04\x12\x31\n-STREAM_SQL_DATA_ERROR_REASON_DB_SERVER_CLOSED\x10\x05\x12,\n(STREAM_SQL_DATA_ERROR_REASON_INVALID_ACK\x10\x06\x12-\n)STREAM_SQL_DATA_ERROR_REASON_DISCONNECTED\x10\x07*R\n\x0eSqlDataFeature\x12 \n\x1cSQL_DATA_FEATURE_UNSPECIFIED\x10\x00\x12\x1e\n\x1aSQL_DATA_FEATURE_RECONNECT\x10\x01\x32\x88\x01\n\x0eSqlDataService\x12v\n\rStreamSqlData\x12..google.cloud.sql.v1beta4.StreamSqlDataRequest\x1a/.google.cloud.sql.v1beta4.StreamSqlDataResponse\"\x00(\x01\x30\x01\x42\x45\n\x1c\x63om.google.cloud.sql.v1beta4B\x11\x43loudSqlDataProtoP\x01Z\x10internal/sqldatab\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'google.cloud.sql.connector.proto.sql_data_service_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + _globals['DESCRIPTOR']._loaded_options = None + _globals['DESCRIPTOR']._serialized_options = b'\n\034com.google.cloud.sql.v1beta4B\021CloudSqlDataProtoP\001Z\020internal/sqldata' + _globals['_CONNECTIONSETTINGS']._loaded_options = None + _globals['_CONNECTIONSETTINGS']._serialized_options = b'\030\001' + _globals['_INSTANCEID']._loaded_options = None + _globals['_INSTANCEID']._serialized_options = b'\030\001' + _globals['_CLIENTPAYLOAD']._loaded_options = None + _globals['_CLIENTPAYLOAD']._serialized_options = b'\030\001' + 
_globals['_SERVERPAYLOAD']._loaded_options = None + _globals['_SERVERPAYLOAD']._serialized_options = b'\030\001' + _globals['_STREAMSQLDATAERRORREASON']._serialized_start=1569 + _globals['_STREAMSQLDATAERRORREASON']._serialized_end=1943 + _globals['_SQLDATAFEATURE']._serialized_start=1945 + _globals['_SQLDATAFEATURE']._serialized_end=2027 + _globals['_STREAMSQLDATAREQUEST']._serialized_start=111 + _globals['_STREAMSQLDATAREQUEST']._serialized_end=625 + _globals['_CONNECTIONSETTINGS']._serialized_start=627 + _globals['_CONNECTIONSETTINGS']._serialized_end=722 + _globals['_STARTSESSION']._serialized_start=724 + _globals['_STARTSESSION']._serialized_end=800 + _globals['_CONTINUESESSION']._serialized_start=802 + _globals['_CONTINUESESSION']._serialized_end=881 + _globals['_INSTANCEID']._serialized_start=883 + _globals['_INSTANCEID']._serialized_end=917 + _globals['_CLIENTPAYLOAD']._serialized_start=919 + _globals['_CLIENTPAYLOAD']._serialized_end=952 + _globals['_STREAMSQLDATARESPONSE']._serialized_start=955 + _globals['_STREAMSQLDATARESPONSE']._serialized_end=1299 + _globals['_SERVERPAYLOAD']._serialized_start=1301 + _globals['_SERVERPAYLOAD']._serialized_end=1334 + _globals['_SESSIONMETADATA']._serialized_start=1336 + _globals['_SESSIONMETADATA']._serialized_end=1423 + _globals['_DATAPACKET']._serialized_start=1425 + _globals['_DATAPACKET']._serialized_end=1478 + _globals['_ACK']._serialized_start=1480 + _globals['_ACK']._serialized_end=1510 + _globals['_TERMINATESESSION']._serialized_start=1512 + _globals['_TERMINATESESSION']._serialized_end=1566 + _globals['_SQLDATASERVICE']._serialized_start=2030 + _globals['_SQLDATASERVICE']._serialized_end=2166 +# @@protoc_insertion_point(module_scope) diff --git a/google/cloud/sql/connector/proto/sql_data_service_pb2_grpc.py b/google/cloud/sql/connector/proto/sql_data_service_pb2_grpc.py new file mode 100644 index 000000000..42240ecbd --- /dev/null +++ b/google/cloud/sql/connector/proto/sql_data_service_pb2_grpc.py @@ -0,0 
+1,137 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" + +import grpc + +from google.cloud.sql.connector.proto import ( + sql_data_service_pb2 as google_dot_cloud_dot_sql_dot_connector_dot_proto_dot_sql__data__service__pb2, +) + +GRPC_GENERATED_VERSION = '1.81.0' +GRPC_VERSION = grpc.__version__ +_version_not_supported = False + +try: + from grpc._utilities import first_version_is_lower + _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) +except ImportError: + _version_not_supported = True + +if _version_not_supported: + raise RuntimeError( + f'The grpc package installed is at version {GRPC_VERSION},' + + ' but the generated code in google/cloud/sql/connector/proto/sql_data_service_pb2_grpc.py depends on' + + f' grpcio>={GRPC_GENERATED_VERSION}.' + + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' + ) + + +class SqlDataServiceStub: + """Service for streaming data to and from Cloud SQL instances. + """ + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. 
+ """ + self.StreamSqlData = channel.stream_stream( + '/google.cloud.sql.v1beta4.SqlDataService/StreamSqlData', + request_serializer=google_dot_cloud_dot_sql_dot_connector_dot_proto_dot_sql__data__service__pb2.StreamSqlDataRequest.SerializeToString, + response_deserializer=google_dot_cloud_dot_sql_dot_connector_dot_proto_dot_sql__data__service__pb2.StreamSqlDataResponse.FromString, + _registered_method=True) + + +class SqlDataServiceServicer: + """Service for streaming data to and from Cloud SQL instances. + """ + + def StreamSqlData(self, request_iterator, context): + """`StreamSqlData` establishes a bidirectional stream to a Cloud SQL instance, + and then streams data to and from the instance. + + The first message from the client MUST be a `StreamSqlDataRequest` request + with configuration settings, including required values for the + `connection_settings` field. Subsequent messages from the client may + contain the `payload` field. + + Messages from the server may contain the `payload` field. + + The `payload` fields of the request and response streams contain the raw + data of the database's native wire protocol (e.g., PostgreSQL wire + protocol). The database client is responsible for generating and parsing + this data. + + Any errors on initial connection (e.g., connection failure, authorization + issues, network problems) will result in the stream being terminated with + an appropriate RPC status exception. + + After a successful connection is made, if an error occurs, then the server + terminates connection and returns the appropriate RPC status exception. 
+ """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_SqlDataServiceServicer_to_server(servicer, server): + rpc_method_handlers = { + 'StreamSqlData': grpc.stream_stream_rpc_method_handler( + servicer.StreamSqlData, + request_deserializer=google_dot_cloud_dot_sql_dot_connector_dot_proto_dot_sql__data__service__pb2.StreamSqlDataRequest.FromString, + response_serializer=google_dot_cloud_dot_sql_dot_connector_dot_proto_dot_sql__data__service__pb2.StreamSqlDataResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'google.cloud.sql.v1beta4.SqlDataService', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers('google.cloud.sql.v1beta4.SqlDataService', rpc_method_handlers) + + + # This class is part of an EXPERIMENTAL API. +class SqlDataService: + """Service for streaming data to and from Cloud SQL instances. 
+ """ + + @staticmethod + def StreamSqlData(request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.stream_stream( + request_iterator, + target, + '/google.cloud.sql.v1beta4.SqlDataService/StreamSqlData', + google_dot_cloud_dot_sql_dot_connector_dot_proto_dot_sql__data__service__pb2.StreamSqlDataRequest.SerializeToString, + google_dot_cloud_dot_sql_dot_connector_dot_proto_dot_sql__data__service__pb2.StreamSqlDataResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) diff --git a/google/cloud/sql/connector/sqldata_client.py b/google/cloud/sql/connector/sqldata_client.py new file mode 100644 index 000000000..950373f4b --- /dev/null +++ b/google/cloud/sql/connector/sqldata_client.py @@ -0,0 +1,355 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import asyncio +import logging +import socket +from typing import Any, Callable, Optional + +from google.auth.credentials import Credentials +from google.auth.transport.grpc import AuthMetadataPlugin +from google.auth.transport.requests import Request +import grpc + +import google.rpc.status_pb2 # noqa: F401 # isort: skip +from google.cloud.sql.connector.proto import sql_data_service_pb2 # type: ignore +from google.cloud.sql.connector.proto import sql_data_service_pb2_grpc # type: ignore + +logger = logging.getLogger(__name__) + + +class SqlDataClient: + def __init__( + self, + endpoint: str, + credentials: Credentials, + quota_project: Optional[str] = None, + timeout: Optional[float] = None, + ): + self._endpoint = endpoint + self._credentials = credentials + self._quota_project = quota_project + self._timeout = timeout + + async def connect_tunnel( + self, + instance_connection_name: str, + region: str, + project: str, + get_conn_info: Callable[[], Any], + enable_iam_auth: bool, + on_fallback: Callable[[str], None], + is_fallback_cached: Callable[[str], bool], + ) -> int: + """Starts a local TCP tunnel and returns the local port. + + If the instance does not support SQL Data Service, it falls back + to a direct TLS connection. 
+ """ + # Start local TCP server + server = await asyncio.start_server( + lambda r, w: self._handle_tunnel( + r, + w, + instance_connection_name, + region, + project, + get_conn_info, + enable_iam_auth, + on_fallback, + is_fallback_cached, + ), + "127.0.0.1", + 0, + ) + + port = server.sockets[0].getsockname()[1] + logger.debug(f"SQL Data tunnel listening on 127.0.0.1:{port}") + + # Keep reference to server to close it + self._server = server + return port + + async def close(self) -> None: + """Closes the local tunnel server if it is running.""" + if hasattr(self, "_server") and self._server: + self._server.close() + try: + await asyncio.wait_for(self._server.wait_closed(), timeout=2.0) + logger.debug("SQL Data tunnel server closed by client close()") + except asyncio.TimeoutError: + logger.warning("Timeout waiting for SQL Data tunnel server to close") + + async def _handle_tunnel( + self, + client_reader: asyncio.StreamReader, + client_writer: asyncio.StreamWriter, + instance_connection_name: str, + region: str, + project: str, + get_conn_info: Callable[[], Any], + enable_iam_auth: bool, + on_fallback: Callable[[str], None], + is_fallback_cached: Callable[[str], bool], + ): + logger.debug("Accepted local connection for SQL Data tunnel") + # Close the server so no more connections are accepted on this port + self._server.close() + + # Buffer to cache client writes for fallback replay + client_write_buffer = bytearray() + first_read_done = False + fallback_triggered = False + + # We need to share these streams between tasks + backend_reader: Optional[asyncio.StreamReader] = None + backend_writer: Optional[asyncio.StreamWriter] = None + grpc_stream: Optional[Any] = None + grpc_channel: Optional[grpc.aio.Channel] = None + + # Check if fallback is already cached + use_fallback = is_fallback_cached(instance_connection_name) + + async def connect_grpc() -> tuple[grpc.aio.Channel, Any]: + auth_request = Request() + plugin = AuthMetadataPlugin(self._credentials, 
auth_request) + call_creds = grpc.metadata_call_credentials(plugin) + channel_creds = grpc.composite_channel_credentials( + grpc.ssl_channel_credentials(), call_creds + ) + + endpoint = self._endpoint + if endpoint.startswith("https://"): + endpoint = endpoint[len("https://") :] + if endpoint.startswith("http://"): + endpoint = endpoint[len("http://") :] + + logger.debug(f"Creating secure channel to {endpoint}") + channel = grpc.aio.secure_channel(endpoint, channel_creds) + stub = sql_data_service_pb2_grpc.SqlDataServiceStub(channel) + + instance_id = f"projects/{project}/instances/{instance_connection_name.split(':')[-1]}" + location_id = f"locations/{region}" + + metadata = [] + quota_project_in_creds = getattr(self._credentials, "quota_project_id", None) + if self._quota_project and self._quota_project != quota_project_in_creds: + metadata.append(("x-goog-user-project", self._quota_project)) + metadata.append( + ( + "x-goog-request-params", + f"instance_id={instance_id}&location_id={location_id}", + ) + ) + + # Start stream + logger.debug(f"Starting StreamSqlData with metadata {metadata}") + stream = stub.StreamSqlData(metadata=metadata) + + # Send StartSession + start_session = sql_data_service_pb2.StartSession( # type: ignore[attr-defined] + instance_id=instance_id, location_id=location_id + ) + req = sql_data_service_pb2.StreamSqlDataRequest( # type: ignore[attr-defined] + start_session=start_session + ) + logger.debug("Writing StartSession to stream...") + await stream.write(req) + logger.debug("StartSession written successfully") + return channel, stream + + async def connect_direct() -> tuple[asyncio.StreamReader, asyncio.StreamWriter]: + logger.debug("Fallback triggered, fetching connection info...") + conn_info = await get_conn_info() + # Find a fallback IP address + fallback_ip = None + from google.cloud.sql.connector.enums import IPTypes + for t in [IPTypes.PUBLIC, IPTypes.PRIVATE, IPTypes.PSC]: + try: + fallback_ip = conn_info.get_preferred_ip(t) + 
break + except Exception: + continue + if not fallback_ip: + raise ValueError("Cannot fallback to direct connection: no IP address available.") + logger.debug(f"Connecting directly to {fallback_ip}:3307") + ssl_context = await conn_info.create_ssl_context(enable_iam_auth) + return await asyncio.open_connection( + fallback_ip, 3307, ssl=ssl_context, server_hostname=fallback_ip + ) + + # Initialize connection + if use_fallback: + logger.debug("Using cached fallback connection") + backend_reader, backend_writer = await connect_direct() + fallback_triggered = True + else: + try: + grpc_channel, grpc_stream = await connect_grpc() + except Exception as e: + logger.debug(f"Failed to initialize gRPC stream: {e}") + # Try fallback immediately + backend_reader, backend_writer = await connect_direct() + fallback_triggered = True + on_fallback(instance_connection_name) + + # Task to read from client and write to backend + async def client_to_backend(): + nonlocal first_read_done, fallback_triggered, backend_writer, grpc_stream + try: + while True: + data = await client_reader.read(4096) + if not data: + logger.debug("Client socket EOF") + break + + if not first_read_done and not fallback_triggered: + client_write_buffer.extend(data) + + if fallback_triggered: + if backend_writer: + backend_writer.write(data) + await backend_writer.drain() + else: + packet = sql_data_service_pb2.DataPacket(data=data) # type: ignore[attr-defined] + req = sql_data_service_pb2.StreamSqlDataRequest( # type: ignore[attr-defined] + data=packet + ) + if grpc_stream: + await grpc_stream.write(req) + except Exception as e: + logger.error(f"Error in client_to_backend: {e}") + raise + finally: + if fallback_triggered: + if backend_writer: + backend_writer.write_eof() + else: + if grpc_stream: + try: + await grpc_stream.done_writing() + except Exception: + pass + logger.debug("Client to backend task finished") + + # Task to read from backend and write to client + async def backend_to_client(): + nonlocal 
first_read_done, fallback_triggered, backend_reader, backend_writer, grpc_stream, grpc_channel + try: + if fallback_triggered: + # If we started with fallback, just copy + while True: + if not backend_reader: + break + data = await backend_reader.read(4096) + if not data: + break + client_writer.write(data) + await client_writer.drain() + else: + # gRPC read loop + try: + if not grpc_stream: + return + async for resp in grpc_stream: + first_read_done = True + msg_type = resp.WhichOneof("message") + if msg_type == "session_metadata": + logger.debug("Received SessionMetadata") + elif msg_type == "data": + data = resp.data.data + client_writer.write(data) + await client_writer.drain() + elif msg_type == "terminate_session": + logger.debug("Received TerminateSession") + break + except grpc.aio.AioRpcError as e: + logger.debug(f"gRPC stream error: {e}") + # Check for fallback condition + if ( + not first_read_done + and e.code() == grpc.StatusCode.FAILED_PRECONDITION + ): + logger.info( + f"SQL Data Service not supported for {instance_connection_name}. " + "Falling back to direct connection." 
+ ) + fallback_triggered = True + on_fallback(instance_connection_name) + + # Clean up gRPC + if grpc_channel: + await grpc_channel.close() + + # Connect direct + backend_reader, backend_writer = await connect_direct() + + # Replay buffered client data + if client_write_buffer: + logger.debug(f"Replaying {len(client_write_buffer)} bytes to fallback connection") + backend_writer.write(bytes(client_write_buffer)) + await backend_writer.drain() + + # Start copying from direct connection + while True: + data = await backend_reader.read(4096) + if not data: + break + client_writer.write(data) + await client_writer.drain() + else: + # Other gRPC error, re-raise to close connection + raise + except Exception as e: + logger.error(f"Error in backend_to_client: {e}") + raise + finally: + client_writer.close() + try: + await client_writer.wait_closed() + except Exception: + pass + if fallback_triggered and backend_writer: + backend_writer.close() + try: + await backend_writer.wait_closed() + except Exception: + pass + elif grpc_channel: + await grpc_channel.close() + logger.debug("Backend to client task finished") + + # Run both tasks + try: + await asyncio.gather(client_to_backend(), backend_to_client()) + finally: + logger.debug("Closing client socket in _handle_tunnel finally") + try: + client_writer.close() + sock = client_writer.get_extra_info('socket') + if sock: + sock.close() + except Exception as e: + logger.debug(f"Error closing client writer: {e}") + logger.debug("SQL Data tunnel handler finished") + + +class FallbackSocket(socket.socket): + def connect(self, *args: Any, **kwargs: Any) -> None: + # Already connected, do nothing. + # This is needed because some drivers (like pymysql) try to call connect() + # internally even if passed an already connected socket. 
+ pass diff --git a/pyproject.toml b/pyproject.toml index cbf0dd10f..4a2894a51 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,9 @@ dependencies = [ "dnspython>=2.0.0", "Requests", "google-auth>=2.28.0", + "grpcio", + "protobuf", + "googleapis-common-protos", ] dynamic = ["version"] diff --git a/requirements-test.txt b/requirements-test.txt index 296878dd8..e1a150b21 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -10,4 +10,5 @@ pg8000==1.31.5 asyncpg==0.31.0 python-tds==1.17.1 aioresponses==0.7.8 -pytest-aiohttp==1.1.0 +pytest-aiohttp<1.1.0 +aiohttp==3.10.11 diff --git a/tests/system/test_sqldata_connection.py b/tests/system/test_sqldata_connection.py new file mode 100644 index 000000000..15a9fe425 --- /dev/null +++ b/tests/system/test_sqldata_connection.py @@ -0,0 +1,104 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import asyncio
from contextlib import closing
import os
from typing import Dict, Optional

import pytest

from google.cloud.sql.connector import Connector
from google.cloud.sql.connector import IPTypes

# AI Developer Edition instance connection details
DB_USER = "postgres"
DB_NAME = "postgres"

# Sandbox endpoints
ADMIN_ENDPOINT = "https://coreltest-sqladmin.mtls.sandbox.googleapis.com"
SQL_DATA_ENDPOINT = "coreltest-sqladmin.mtls.sandbox.googleapis.com:443"


@pytest.fixture(name="config")
def config_fixture() -> Dict[str, Optional[str]]:
    """Collect connection settings from the environment.

    Reads ``SQL_DATA_CONNECTION_NAME``, ``POSTGRES_CUSTOMER_CAS_PASS`` and
    ``SQL_DATA_PROJECT``. The requesting test is skipped when either of the
    first two is unset or empty; ``quota_project`` may be ``None``.

    Returns:
        A dict with the keys ``conn_name``, ``quota_project`` and
        ``password``.
    """
    conn_name = os.environ.get("SQL_DATA_CONNECTION_NAME")
    if not conn_name:
        pytest.skip("SQL_DATA_CONNECTION_NAME env var not set")

    password = os.environ.get("POSTGRES_CUSTOMER_CAS_PASS")
    if not password:
        pytest.skip("POSTGRES_CUSTOMER_CAS_PASS env var not set")

    return {
        "conn_name": conn_name,
        "quota_project": os.environ.get("SQL_DATA_PROJECT"),
        "password": password,
    }


@pytest.mark.asyncio
async def test_asyncpg_sqldata_connect(config: Dict[str, Optional[str]]) -> None:
    """Connect with asyncpg using ``IPTypes.SQL_DATA`` and run a query."""
    connector = Connector(
        loop=asyncio.get_running_loop(),
        sqladmin_api_endpoint=ADMIN_ENDPOINT,
        sql_data_endpoint=SQL_DATA_ENDPOINT,
        quota_project=config["quota_project"],
    )
    try:
        conn = await connector.connect_async(
            config["conn_name"],
            "asyncpg",
            user=DB_USER,
            password=config["password"],
            db=DB_NAME,
            ip_type=IPTypes.SQL_DATA,
        )
        try:
            val = await conn.fetchval("SELECT NOW()")
            assert val is not None
        finally:
            await conn.close()
    finally:
        # Outer ``finally`` so the connector is released even when the
        # connection attempt, the query, or ``conn.close()`` itself raises.
        await connector.close_async()


def test_pg8000_sqldata_connect(config: Dict[str, Optional[str]]) -> None:
    """Connect with pg8000 using ``IPTypes.SQL_DATA`` and run a query."""
    # ``closing`` calls ``.close()`` on exit; nesting the managers guarantees
    # cursor -> connection -> connector teardown order, and a failure while
    # closing an inner resource no longer prevents the outer ones from closing.
    with closing(
        Connector(
            sqladmin_api_endpoint=ADMIN_ENDPOINT,
            sql_data_endpoint=SQL_DATA_ENDPOINT,
            quota_project=config["quota_project"],
        )
    ) as connector:
        with closing(
            connector.connect(
                config["conn_name"],
                "pg8000",
                user=DB_USER,
                password=config["password"],
                db=DB_NAME,
                ip_type=IPTypes.SQL_DATA,
            )
        ) as conn, closing(conn.cursor()) as cursor:
            cursor.execute("SELECT NOW()")
            val = cursor.fetchone()
            assert val is not None