diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index ba6439848..50e6d044e 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -46,16 +46,16 @@ jobs: # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL - uses: github/codeql-action/init@ea9e4e37992a54ee68a9622e985e60c8e8f12d9f # v3.27.4 + uses: github/codeql-action/init@babb554ede22fd5605947329c4d04d8e7a0b8155 # v3.27.7 with: languages: ${{ matrix.language }} # Autobuild attempts to build any compiled languages (C/C++, C#, Go, or Java). # If this step fails, then you should remove it and run the build manually - name: Autobuild - uses: github/codeql-action/autobuild@ea9e4e37992a54ee68a9622e985e60c8e8f12d9f # v3.27.4 + uses: github/codeql-action/autobuild@babb554ede22fd5605947329c4d04d8e7a0b8155 # v3.27.7 - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@ea9e4e37992a54ee68a9622e985e60c8e8f12d9f # v3.27.4 + uses: github/codeql-action/analyze@babb554ede22fd5605947329c4d04d8e7a0b8155 # v3.27.7 with: category: "/language:${{matrix.language}}" diff --git a/.github/workflows/scorecard.yml b/.github/workflows/scorecard.yml index 37c9428dd..653a2d6fc 100644 --- a/.github/workflows/scorecard.yml +++ b/.github/workflows/scorecard.yml @@ -65,6 +65,6 @@ jobs: # Upload the results to GitHub's code scanning dashboard. - name: "Upload to code-scanning" - uses: github/codeql-action/upload-sarif@ea9e4e37992a54ee68a9622e985e60c8e8f12d9f # v3.27.4 + uses: github/codeql-action/upload-sarif@babb554ede22fd5605947329c4d04d8e7a0b8155 # v3.27.7 with: sarif_file: resultsFiltered.sarif diff --git a/CHANGELOG.md b/CHANGELOG.md index 2e97c4830..1467f5436 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,13 @@ # Changelog +## [1.15.0](https://github.com/GoogleCloudPlatform/cloud-sql-python-connector/compare/v1.14.0...v1.15.0) (2024-12-10) + + +### Features + +* add support for DNS names with Connector ([#1204](https://github.com/GoogleCloudPlatform/cloud-sql-python-connector/issues/1204)) ([1a8f274](https://github.com/GoogleCloudPlatform/cloud-sql-python-connector/commit/1a8f2743c661806528c54110fd2a08384398816c)) +* improve aiohttp client error messages ([#1201](https://github.com/GoogleCloudPlatform/cloud-sql-python-connector/issues/1201)) ([11f9fe9](https://github.com/GoogleCloudPlatform/cloud-sql-python-connector/commit/11f9fe93617cd9a161c69d527d3e7592fcbe56af)) + ## [1.14.0](https://github.com/GoogleCloudPlatform/cloud-sql-python-connector/compare/v1.13.0...v1.14.0) (2024-11-20) diff --git a/README.md b/README.md index 28553f972..1f0e633b9 100644 --- a/README.md +++ b/README.md @@ -365,6 +365,69 @@ conn = connector.connect( ) ``` +### Using DNS domain names to identify instances + +The connector can be configured to use DNS to look up an instance. This would +allow you to configure your application to connect to a database instance, and +centrally configure which instance in your DNS zone. + +#### Configure your DNS Records + +Add a DNS TXT record for the Cloud SQL instance to a **private** DNS server +or a private Google Cloud DNS Zone used by your application. + +> [!NOTE] +> +> You are strongly discouraged from adding DNS records for your +> Cloud SQL instances to a public DNS server. This would allow anyone on the +> internet to discover the Cloud SQL instance name. + +For example: suppose you wanted to use the domain name +`prod-db.mycompany.example.com` to connect to your database instance +`my-project:region:my-instance`. 
You would create the following DNS record: + +* Record type: `TXT` +* Name: `prod-db.mycompany.example.com` – This is the domain name used by the application +* Value: `my-project:my-region:my-instance` – This is the Cloud SQL instance connection name + +#### Configure the connector + +Configure the connector to resolve DNS names by initializing it with +`resolver=DnsResolver` and replacing the instance connection name with the DNS +name in `connector.connect`: + +```python +from google.cloud.sql.connector import Connector, DnsResolver +import pymysql +import sqlalchemy + +# helper function to return SQLAlchemy connection pool +def init_connection_pool(connector: Connector) -> sqlalchemy.engine.Engine: + # function used to generate database connection + def getconn() -> pymysql.connections.Connection: + conn = connector.connect( + "prod-db.mycompany.example.com", # using DNS name + "pymysql", + user="my-user", + password="my-password", + db="my-db-name" + ) + return conn + + # create connection pool + pool = sqlalchemy.create_engine( + "mysql+pymysql://", + creator=getconn, + ) + return pool + +# initialize Cloud SQL Python Connector with `resolver=DnsResolver` +with Connector(resolver=DnsResolver) as connector: + # initialize connection pool + pool = init_connection_pool(connector) + # ... use SQLAlchemy engine normally +``` + ### Using the Python Connector with Python Web Frameworks The Python Connector can be used alongside popular Python web frameworks such diff --git a/google/cloud/sql/connector/__init__.py b/google/cloud/sql/connector/__init__.py index 5b06fcd7f..99a5097a2 100644 --- a/google/cloud/sql/connector/__init__.py +++ b/google/cloud/sql/connector/__init__.py @@ -18,12 +18,16 @@ from google.cloud.sql.connector.connector import create_async_connector from google.cloud.sql.connector.enums import IPTypes from google.cloud.sql.connector.enums import RefreshStrategy +from google.cloud.sql.connector.resolver import DefaultResolver +from google.cloud.sql.connector.resolver import DnsResolver from google.cloud.sql.connector.version import __version__ __all__ = [ "__version__", "create_async_connector", "Connector", + "DefaultResolver", + "DnsResolver", "IPTypes", "RefreshStrategy", ] diff --git a/google/cloud/sql/connector/client.py b/google/cloud/sql/connector/client.py index ed305ec53..8a31eb9a0 100644 --- a/google/cloud/sql/connector/client.py +++ b/google/cloud/sql/connector/client.py @@ -26,6 +26,7 @@ from google.auth.transport import requests from google.cloud.sql.connector.connection_info import ConnectionInfo +from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.exceptions import AutoIAMAuthNotSupported from google.cloud.sql.connector.refresh_utils import _downscope_credentials from google.cloud.sql.connector.refresh_utils import retry_50x @@ -128,8 +129,19 @@ async def _get_metadata( resp = await self._client.get(url, headers=headers) if resp.status >= 500: resp = await retry_50x(self._client.get, url, headers=headers) - resp.raise_for_status() - ret_dict = await resp.json() + # try to get response json for better error message + try: + ret_dict = await resp.json() + if resp.status >= 400: + # if detailed error message is in json response, use as error message + message = ret_dict.get("error", {}).get("message") + if message: + resp.reason = message + # skip, raise_for_status will catch all errors in finally block + except Exception: + pass + finally: + resp.raise_for_status() if ret_dict["region"] != region: raise 
ValueError( @@ -198,8 +210,19 @@ async def _get_ephemeral( resp = await self._client.post(url, headers=headers, json=data) if resp.status >= 500: resp = await retry_50x(self._client.post, url, headers=headers, json=data) - resp.raise_for_status() - ret_dict = await resp.json() + # try to get response json for better error message + try: + ret_dict = await resp.json() + if resp.status >= 400: + # if detailed error message is in json response, use as error message + message = ret_dict.get("error", {}).get("message") + if message: + resp.reason = message + # skip, raise_for_status will catch all errors in finally block + except Exception: + pass + finally: + resp.raise_for_status() ephemeral_cert: str = ret_dict["ephemeralCert"]["cert"] @@ -223,9 +246,7 @@ async def _get_ephemeral( async def get_connection_info( self, - project: str, - region: str, - instance: str, + conn_name: ConnectionName, keys: asyncio.Future, enable_iam_auth: bool, ) -> ConnectionInfo: @@ -233,10 +254,8 @@ async def get_connection_info( Admin API. Args: - project (str): The name of the project the Cloud SQL instance is - located in. - region (str): The region the Cloud SQL instance is located in. - instance (str): Name of the Cloud SQL instance. + conn_name (ConnectionName): The Cloud SQL instance's + connection name. keys (asyncio.Future): A future to the client's public-private key pair. enable_iam_auth (bool): Whether an automatic IAM database @@ -256,16 +275,16 @@ async def get_connection_info( metadata_task = asyncio.create_task( self._get_metadata( - project, - region, - instance, + conn_name.project, + conn_name.region, + conn_name.instance_name, ) ) ephemeral_task = asyncio.create_task( self._get_ephemeral( - project, - instance, + conn_name.project, + conn_name.instance_name, pub_key, enable_iam_auth, ) @@ -289,6 +308,7 @@ async def get_connection_info( ephemeral_cert, expiration = await ephemeral_task return ConnectionInfo( + conn_name, ephemeral_cert, metadata["server_ca_cert"], priv_key, diff --git a/google/cloud/sql/connector/connection_info.py b/google/cloud/sql/connector/connection_info.py index b738063c2..82e3a9018 100644 --- a/google/cloud/sql/connector/connection_info.py +++ b/google/cloud/sql/connector/connection_info.py @@ -21,6 +21,7 @@ from aiofiles.tempfile import TemporaryDirectory +from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError from google.cloud.sql.connector.exceptions import TLSVersionError from google.cloud.sql.connector.utils import write_to_file @@ -38,6 +39,7 @@ class ConnectionInfo: """Contains all necessary information to connect securely to the server-side Proxy running on a Cloud SQL instance.""" + conn_name: ConnectionName client_cert: str server_ca_cert: str private_key: bytes diff --git a/google/cloud/sql/connector/connection_name.py b/google/cloud/sql/connector/connection_name.py index d240fb565..1bf711ab7 100644 --- a/google/cloud/sql/connector/connection_name.py +++ b/google/cloud/sql/connector/connection_name.py @@ -31,12 +31,21 @@ class ConnectionName: project: str region: str instance_name: str + domain_name: str = "" def __str__(self) -> str: + if self.domain_name: + return f"{self.domain_name} -> {self.project}:{self.region}:{self.instance_name}" return f"{self.project}:{self.region}:{self.instance_name}" -def _parse_instance_connection_name(connection_name: str) -> ConnectionName: +def _parse_connection_name(connection_name: str) -> ConnectionName: + return 
_parse_connection_name_with_domain_name(connection_name, "") + + +def _parse_connection_name_with_domain_name( + connection_name: str, domain_name: str +) -> ConnectionName: if CONN_NAME_REGEX.fullmatch(connection_name) is None: raise ValueError( "Arg `instance_connection_string` must have " @@ -48,4 +57,5 @@ def _parse_instance_connection_name(connection_name: str) -> ConnectionName: connection_name_split[1], connection_name_split[3], connection_name_split[4], + domain_name, ) diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 7a89d7194..7c5fb20de 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -37,6 +37,8 @@ import google.cloud.sql.connector.pg8000 as pg8000 import google.cloud.sql.connector.pymysql as pymysql import google.cloud.sql.connector.pytds as pytds +from google.cloud.sql.connector.resolver import DefaultResolver +from google.cloud.sql.connector.resolver import DnsResolver from google.cloud.sql.connector.utils import format_database_user from google.cloud.sql.connector.utils import generate_keys @@ -63,6 +65,7 @@ def __init__( user_agent: Optional[str] = None, universe_domain: Optional[str] = None, refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND, + resolver: Type[DefaultResolver] | Type[DnsResolver] = DefaultResolver, ) -> None: """Initializes a Connector instance. @@ -104,6 +107,12 @@ def __init__( of the following: RefreshStrategy.LAZY ("LAZY") or RefreshStrategy.BACKGROUND ("BACKGROUND"). Default: RefreshStrategy.BACKGROUND + + resolver (DefaultResolver | DnsResolver): The class name of the + resolver to use for resolving the Cloud SQL instance connection + name. To resolve a DNS record to an instance connection name, use + DnsResolver. 
+ Default: DefaultResolver """ # if refresh_strategy is str, convert to RefreshStrategy enum if isinstance(refresh_strategy, str): @@ -157,6 +166,7 @@ def __init__( self._enable_iam_auth = enable_iam_auth self._quota_project = quota_project self._user_agent = user_agent + self._resolver = resolver() # if ip_type is str, convert to IPTypes enum if isinstance(ip_type, str): ip_type = IPTypes._from_str(ip_type) @@ -269,31 +279,28 @@ async def connect_async( if (instance_connection_string, enable_iam_auth) in self._cache: cache = self._cache[(instance_connection_string, enable_iam_auth)] else: + conn_name = await self._resolver.resolve(instance_connection_string) if self._refresh_strategy == RefreshStrategy.LAZY: logger.debug( - f"['{instance_connection_string}']: Refresh strategy is set" - " to lazy refresh" + f"['{conn_name}']: Refresh strategy is set to lazy refresh" ) cache = LazyRefreshCache( - instance_connection_string, + conn_name, self._client, self._keys, enable_iam_auth, ) else: logger.debug( - f"['{instance_connection_string}']: Refresh strategy is set" - " to backgound refresh" + f"['{conn_name}']: Refresh strategy is set to backgound refresh" ) cache = RefreshAheadCache( - instance_connection_string, + conn_name, self._client, self._keys, enable_iam_auth, ) - logger.debug( - f"['{instance_connection_string}']: Connection info added to cache" - ) + logger.debug(f"['{conn_name}']: Connection info added to cache") self._cache[(instance_connection_string, enable_iam_auth)] = cache connect_func = { @@ -332,9 +339,7 @@ async def connect_async( # the cache and re-raise the error await self._remove_cached(instance_connection_string, enable_iam_auth) raise - logger.debug( - f"['{instance_connection_string}']: Connecting to {ip_address}:3307" - ) + logger.debug(f"['{conn_info.conn_name}']: Connecting to {ip_address}:3307") # format `user` param for automatic IAM database authn if enable_iam_auth: formatted_user = format_database_user( diff --git a/google/cloud/sql/connector/exceptions.py b/google/cloud/sql/connector/exceptions.py index 7bff2300d..92e3e5662 100644 --- a/google/cloud/sql/connector/exceptions.py +++ b/google/cloud/sql/connector/exceptions.py @@ -70,3 +70,10 @@ class IncompatibleDriverError(Exception): Exception to be raised when the database driver given is for the wrong database engine. (i.e. asyncpg for a MySQL database) """ + + +class DnsResolutionError(Exception): + """ + Exception to be raised when an instance connection name can not be resolved + from a DNS record. 
+ """ diff --git a/google/cloud/sql/connector/instance.py b/google/cloud/sql/connector/instance.py index f244b8cf3..5df272fe2 100644 --- a/google/cloud/sql/connector/instance.py +++ b/google/cloud/sql/connector/instance.py @@ -22,11 +22,9 @@ from datetime import timezone import logging -import aiohttp - from google.cloud.sql.connector.client import CloudSQLClient from google.cloud.sql.connector.connection_info import ConnectionInfo -from google.cloud.sql.connector.connection_name import _parse_instance_connection_name +from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.exceptions import RefreshNotValidError from google.cloud.sql.connector.rate_limiter import AsyncRateLimiter from google.cloud.sql.connector.refresh_utils import _is_valid @@ -47,7 +45,7 @@ class RefreshAheadCache: def __init__( self, - instance_connection_string: str, + conn_name: ConnectionName, client: CloudSQLClient, keys: asyncio.Future, enable_iam_auth: bool = False, @@ -55,8 +53,8 @@ def __init__( """Initializes a RefreshAheadCache instance. Args: - instance_connection_string (str): The Cloud SQL Instance's - connection string (also known as an instance connection name). + conn_name (ConnectionName): The Cloud SQL instance's + connection name. client (CloudSQLClient): The Cloud SQL Client instance. keys (asyncio.Future): A future to the client's public-private key pair. @@ -64,13 +62,6 @@ def __init__( (Postgres and MySQL) as the default authentication method for all connections. """ - # validate and parse instance connection name - conn_name = _parse_instance_connection_name(instance_connection_string) - self._project, self._region, self._instance = ( - conn_name.project, - conn_name.region, - conn_name.instance_name, - ) self._conn_name = conn_name self._enable_iam_auth = enable_iam_auth @@ -108,35 +99,24 @@ async def _perform_refresh(self) -> ConnectionInfo: """ self._refresh_in_progress.set() logger.debug( - f"['{self._conn_name}']: Connection info refresh " "operation started" + f"['{self._conn_name}']: Connection info refresh operation started" ) try: await self._refresh_rate_limiter.acquire() connection_info = await self._client.get_connection_info( - self._project, - self._region, - self._instance, + self._conn_name, self._keys, self._enable_iam_auth, ) logger.debug( - f"['{self._conn_name}']: Connection info " "refresh operation complete" + f"['{self._conn_name}']: Connection info refresh operation complete" ) logger.debug( f"['{self._conn_name}']: Current certificate " f"expiration = {connection_info.expiration.isoformat()}" ) - except aiohttp.ClientResponseError as e: - logger.debug( - f"['{self._conn_name}']: Connection info " - f"refresh operation failed: {str(e)}" - ) - if e.status == 403: - e.message = "Forbidden: Authenticated IAM principal does not seem authorized to make API request. Verify 'Cloud SQL Admin API' is enabled within your GCP project and 'Cloud SQL Client' role has been granted to IAM principal." 
- raise - except Exception as e: logger.debug( f"['{self._conn_name}']: Connection info " diff --git a/google/cloud/sql/connector/lazy.py b/google/cloud/sql/connector/lazy.py index 672f989e8..1bc4f90f8 100644 --- a/google/cloud/sql/connector/lazy.py +++ b/google/cloud/sql/connector/lazy.py @@ -21,7 +21,7 @@ from google.cloud.sql.connector.client import CloudSQLClient from google.cloud.sql.connector.connection_info import ConnectionInfo -from google.cloud.sql.connector.connection_name import _parse_instance_connection_name +from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.refresh_utils import _refresh_buffer logger = logging.getLogger(name=__name__) @@ -38,7 +38,7 @@ class LazyRefreshCache: def __init__( self, - instance_connection_string: str, + conn_name: ConnectionName, client: CloudSQLClient, keys: asyncio.Future, enable_iam_auth: bool = False, @@ -46,8 +46,8 @@ def __init__( """Initializes a LazyRefreshCache instance. Args: - instance_connection_string (str): The Cloud SQL Instance's - connection string (also known as an instance connection name). + conn_name (ConnectionName): The Cloud SQL instance's + connection name. client (CloudSQLClient): The Cloud SQL Client instance. keys (asyncio.Future): A future to the client's public-private key pair. @@ -55,15 +55,7 @@ def __init__( (Postgres and MySQL) as the default authentication method for all connections. """ - # validate and parse instance connection name - conn_name = _parse_instance_connection_name(instance_connection_string) - self._project, self._region, self._instance = ( - conn_name.project, - conn_name.region, - conn_name.instance_name, - ) self._conn_name = conn_name - self._enable_iam_auth = enable_iam_auth self._keys = keys self._client = client @@ -103,9 +95,7 @@ async def connect_info(self) -> ConnectionInfo: ) try: conn_info = await self._client.get_connection_info( - self._project, - self._region, - self._instance, + self._conn_name, self._keys, self._enable_iam_auth, ) diff --git a/google/cloud/sql/connector/resolver.py b/google/cloud/sql/connector/resolver.py new file mode 100644 index 000000000..39efd0492 --- /dev/null +++ b/google/cloud/sql/connector/resolver.py @@ -0,0 +1,70 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import dns.asyncresolver + +from google.cloud.sql.connector.connection_name import ( + _parse_connection_name_with_domain_name, +) +from google.cloud.sql.connector.connection_name import _parse_connection_name +from google.cloud.sql.connector.connection_name import ConnectionName +from google.cloud.sql.connector.exceptions import DnsResolutionError + + +class DefaultResolver: + """DefaultResolver simply validates and parses instance connection name.""" + + async def resolve(self, connection_name: str) -> ConnectionName: + return _parse_connection_name(connection_name) + + +class DnsResolver(dns.asyncresolver.Resolver): + """ + DnsResolver resolves domain names into instance connection names using + TXT records in DNS. + """ + + async def resolve(self, dns: str) -> ConnectionName: # type: ignore + try: + conn_name = _parse_connection_name(dns) + except ValueError: + # The connection name was not project:region:instance format. + # Attempt to query a TXT record to get connection name. + conn_name = await self.query_dns(dns) + return conn_name + + async def query_dns(self, dns: str) -> ConnectionName: + try: + # Attempt to query the TXT records. + records = await super().resolve(dns, "TXT", raise_on_no_answer=True) + # Sort the TXT record values alphabetically, strip quotes as record + # values can be returned as raw strings + rdata = [record.to_text().strip('"') for record in records] + rdata.sort() + # Attempt to parse records, returning the first valid record. + for record in rdata: + try: + conn_name = _parse_connection_name_with_domain_name(record, dns) + return conn_name + except Exception: + continue + # If all records failed to parse, throw error + raise DnsResolutionError( + f"Unable to parse TXT record for `{dns}` -> `{rdata[0]}`" + ) + # Don't override above DnsResolutionError + except DnsResolutionError: + raise + except Exception as e: + raise DnsResolutionError(f"Unable to resolve TXT record for `{dns}`") from e diff --git a/google/cloud/sql/connector/version.py b/google/cloud/sql/connector/version.py index 7f2e3d950..587a7adbf 100644 --- a/google/cloud/sql/connector/version.py +++ b/google/cloud/sql/connector/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-__version__ = "1.14.0" +__version__ = "1.15.0" diff --git a/noxfile.py b/noxfile.py index 8329b2de8..de1b9b81d 100644 --- a/noxfile.py +++ b/noxfile.py @@ -20,7 +20,7 @@ import nox -BLACK_VERSION = "black==23.12.1" +BLACK_VERSION = "black==24.10.0" ISORT_VERSION = "isort==5.13.2" LINT_PATHS = ["google", "tests", "noxfile.py", "setup.py"] @@ -50,7 +50,10 @@ def lint(session): "--fss", "--check-only", "--diff", - "--profile=google", + "--profile=black", + "--force-single-line-imports", + "--dont-order-by-type", + "--single-line-exclusions=typing", "-w=88", *LINT_PATHS, ) @@ -85,7 +88,10 @@ def format(session): session.run( "isort", "--fss", - "--profile=google", + "--profile=black", + "--force-single-line-imports", + "--dont-order-by-type", + "--single-line-exclusions=typing", "-w=88", *LINT_PATHS, ) diff --git a/requirements-test.txt b/requirements-test.txt index 4aeecede7..40102fc21 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,4 +1,4 @@ -pytest==8.3.3 +pytest==8.3.4 mock==5.1.0 pytest-cov==6.0.0 pytest-asyncio==0.24.0 diff --git a/requirements.txt b/requirements.txt index 3085b3e94..75b53b946 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ aiofiles==24.1.0 -aiohttp==3.11.6 -cryptography==43.0.3 +aiohttp==3.11.10 +cryptography==44.0.0 +dnspython==2.7.0 Requests==2.32.3 google-auth==2.36.0 diff --git a/setup.py b/setup.py index bb70449a5..79c6acf74 100644 --- a/setup.py +++ b/setup.py @@ -28,6 +28,7 @@ "aiofiles", "aiohttp", "cryptography>=42.0.0", + "dnspython>=2.0.0", "Requests", "google-auth>=2.28.0", ] diff --git a/tests/conftest.py b/tests/conftest.py index 470fe19f4..3a1a38a27 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -26,6 +26,7 @@ from unit.mocks import FakeCSQLInstance # type: ignore from google.cloud.sql.connector.client import CloudSQLClient +from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.instance import RefreshAheadCache from google.cloud.sql.connector.utils import generate_keys @@ -144,7 +145,7 @@ async def fake_client( async def cache(fake_client: CloudSQLClient) -> AsyncGenerator[RefreshAheadCache, None]: keys = asyncio.create_task(generate_keys()) cache = RefreshAheadCache( - "test-project:test-region:test-instance", + ConnectionName("test-project", "test-region", "test-instance"), client=fake_client, keys=keys, ) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 046e8e51b..af42af0ae 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -15,6 +15,9 @@ import datetime from typing import Optional +from aiohttp import ClientResponseError +from aioresponses import aioresponses +from google.auth.credentials import Credentials from mocks import FakeCredentials import pytest @@ -138,3 +141,130 @@ async def test_CloudSQLClient_user_agent( assert client._user_agent == f"cloud-sql-python-connector/{version}+{driver}" # close client await client.close() + + +async def test_cloud_sql_error_messages_get_metadata( + fake_credentials: Credentials, +) -> None: + """ + Test that Cloud SQL Admin API error messages are raised for _get_metadata. 
+ """ + # mock Cloud SQL Admin API calls with exceptions + client = CloudSQLClient( + sqladmin_api_endpoint="https://sqladmin.googleapis.com", + quota_project=None, + credentials=fake_credentials, + ) + get_url = "https://sqladmin.googleapis.com/sql/v1beta4/projects/my-project/instances/my-instance/connectSettings" + resp_body = { + "error": { + "code": 403, + "message": "Cloud SQL Admin API has not been used in project 123456789 before or it is disabled", + } + } + with aioresponses() as mocked: + mocked.get( + get_url, + status=403, + payload=resp_body, + repeat=True, + ) + with pytest.raises(ClientResponseError) as exc_info: + await client._get_metadata("my-project", "my-region", "my-instance") + assert exc_info.value.status == 403 + assert ( + exc_info.value.message + == "Cloud SQL Admin API has not been used in project 123456789 before or it is disabled" + ) + await client.close() + + +async def test_get_metadata_error_parsing_json( + fake_credentials: Credentials, +) -> None: + """ + Test that aiohttp default error messages are raised when _get_metadata gets + a bad JSON response. + """ + # mock Cloud SQL Admin API calls with exceptions + client = CloudSQLClient( + sqladmin_api_endpoint="https://sqladmin.googleapis.com", + quota_project=None, + credentials=fake_credentials, + ) + get_url = "https://sqladmin.googleapis.com/sql/v1beta4/projects/my-project/instances/my-instance/connectSettings" + resp_body = ["error"] # invalid JSON + with aioresponses() as mocked: + mocked.get( + get_url, + status=403, + payload=resp_body, + repeat=True, + ) + with pytest.raises(ClientResponseError) as exc_info: + await client._get_metadata("my-project", "my-region", "my-instance") + assert exc_info.value.status == 403 + assert exc_info.value.message == "Forbidden" + await client.close() + + +async def test_cloud_sql_error_messages_get_ephemeral( + fake_credentials: Credentials, +) -> None: + """ + Test that Cloud SQL Admin API error messages are raised for _get_ephemeral. + """ + # mock Cloud SQL Admin API calls with exceptions + client = CloudSQLClient( + sqladmin_api_endpoint="https://sqladmin.googleapis.com", + quota_project=None, + credentials=fake_credentials, + ) + post_url = "https://sqladmin.googleapis.com/sql/v1beta4/projects/my-project/instances/my-instance:generateEphemeralCert" + resp_body = { + "error": { + "code": 404, + "message": "The Cloud SQL instance does not exist.", + } + } + with aioresponses() as mocked: + mocked.post( + post_url, + status=404, + payload=resp_body, + repeat=True, + ) + with pytest.raises(ClientResponseError) as exc_info: + await client._get_ephemeral("my-project", "my-instance", "my-key") + assert exc_info.value.status == 404 + assert exc_info.value.message == "The Cloud SQL instance does not exist." + await client.close() + + +async def test_get_ephemeral_error_parsing_json( + fake_credentials: Credentials, +) -> None: + """ + Test that aiohttp default error messages are raised when _get_ephemeral gets + a bad JSON response. 
+ """ + # mock Cloud SQL Admin API calls with exceptions + client = CloudSQLClient( + sqladmin_api_endpoint="https://sqladmin.googleapis.com", + quota_project=None, + credentials=fake_credentials, + ) + post_url = "https://sqladmin.googleapis.com/sql/v1beta4/projects/my-project/instances/my-instance:generateEphemeralCert" + resp_body = ["error"] # invalid JSON + with aioresponses() as mocked: + mocked.post( + post_url, + status=404, + payload=resp_body, + repeat=True, + ) + with pytest.raises(ClientResponseError) as exc_info: + await client._get_ephemeral("my-project", "my-instance", "my-key") + assert exc_info.value.status == 404 + assert exc_info.value.message == "Not Found" + await client.close() diff --git a/tests/unit/test_connection_name.py b/tests/unit/test_connection_name.py index 1e3730424..783e14fe3 100644 --- a/tests/unit/test_connection_name.py +++ b/tests/unit/test_connection_name.py @@ -14,7 +14,10 @@ import pytest # noqa F401 Needed to run the tests -from google.cloud.sql.connector.connection_name import _parse_instance_connection_name +from google.cloud.sql.connector.connection_name import ( + _parse_connection_name_with_domain_name, +) +from google.cloud.sql.connector.connection_name import _parse_connection_name from google.cloud.sql.connector.connection_name import ConnectionName @@ -24,10 +27,22 @@ def test_ConnectionName() -> None: assert conn_name.project == "project" assert conn_name.region == "region" assert conn_name.instance_name == "instance" + assert conn_name.domain_name == "" # test ConnectionName str() method prints instance connection name assert str(conn_name) == "project:region:instance" +def test_ConnectionName_with_domain_name() -> None: + conn_name = ConnectionName("project", "region", "instance", "db.example.com") + # test class attributes are set properly + assert conn_name.project == "project" + assert conn_name.region == "region" + assert conn_name.instance_name == "instance" + assert conn_name.domain_name == "db.example.com" + # test ConnectionName str() method prints with domain name + assert str(conn_name) == "db.example.com -> project:region:instance" + + @pytest.mark.parametrize( "connection_name, expected", [ @@ -38,19 +53,46 @@ def test_ConnectionName() -> None: ), ], ) -def test_parse_instance_connection_name( - connection_name: str, expected: ConnectionName -) -> None: +def test_parse_connection_name(connection_name: str, expected: ConnectionName) -> None: """ - Test that _parse_instance_connection_name works correctly on + Test that _parse_connection_name works correctly on normal instance connection names and domain-scoped projects. """ - assert expected == _parse_instance_connection_name(connection_name) + assert expected == _parse_connection_name(connection_name) -def test_parse_instance_connection_name_bad_conn_name() -> None: +def test_parse_connection_name_bad_conn_name() -> None: """ Tests that ValueError is thrown for bad instance connection names. 
""" with pytest.raises(ValueError): - _parse_instance_connection_name("project:instance") # missing region + _parse_connection_name("project:instance") # missing region + + +@pytest.mark.parametrize( + "connection_name, domain_name, expected", + [ + ( + "project:region:instance", + "db.example.com", + ConnectionName("project", "region", "instance", "db.example.com"), + ), + ( + "domain-prefix:project:region:instance", + "db.example.com", + ConnectionName( + "domain-prefix:project", "region", "instance", "db.example.com" + ), + ), + ], +) +def test_parse_connection_name_with_domain_name( + connection_name: str, domain_name: str, expected: ConnectionName +) -> None: + """ + Test that _parse_connection_name_with_domain_name works correctly on + normal instance connection names and domain-scoped projects. + """ + assert expected == _parse_connection_name_with_domain_name( + connection_name, domain_name + ) diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index fd18f2d5e..d4f53ed51 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -26,6 +26,7 @@ from google.cloud.sql.connector import create_async_connector from google.cloud.sql.connector import IPTypes from google.cloud.sql.connector.client import CloudSQLClient +from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError from google.cloud.sql.connector.exceptions import IncompatibleDriverError from google.cloud.sql.connector.instance import RefreshAheadCache @@ -322,18 +323,18 @@ async def test_Connector_remove_cached_bad_instance( async with Connector( credentials=fake_credentials, loop=asyncio.get_running_loop() ) as connector: - conn_name = "bad-project:bad-region:bad-inst" + conn_name = ConnectionName("bad-project", "bad-region", "bad-inst") # populate cache cache = RefreshAheadCache(conn_name, fake_client, connector._keys) - connector._cache[(conn_name, False)] = cache + connector._cache[(str(conn_name), False)] = cache # aiohttp client should throw a 404 ClientResponseError with pytest.raises(ClientResponseError): await connector.connect_async( - conn_name, + str(conn_name), "pg8000", ) # check that cache has been removed from dict - assert (conn_name, False) not in connector._cache + assert (str(conn_name), False) not in connector._cache async def test_Connector_remove_cached_no_ip_type( @@ -348,21 +349,21 @@ async def test_Connector_remove_cached_no_ip_type( async with Connector( credentials=fake_credentials, loop=asyncio.get_running_loop() ) as connector: - conn_name = "test-project:test-region:test-instance" + conn_name = ConnectionName("test-project", "test-region", "test-instance") # populate cache cache = RefreshAheadCache(conn_name, fake_client, connector._keys) - connector._cache[(conn_name, False)] = cache + connector._cache[(str(conn_name), False)] = cache # test instance does not have Private IP, thus should invalidate cache with pytest.raises(CloudSQLIPTypeError): await connector.connect_async( - conn_name, + str(conn_name), "pg8000", user="my-user", password="my-pass", ip_type="private", ) # check that cache has been removed from dict - assert (conn_name, False) not in connector._cache + assert (str(conn_name), False) not in connector._cache def test_default_universe_domain(fake_credentials: Credentials) -> None: diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index 5b0887aa2..aeedf3399 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -17,10 
+17,6 @@ import asyncio import datetime -from aiohttp import ClientResponseError -from aiohttp import RequestInfo -from aioresponses import aioresponses -from google.auth.credentials import Credentials from mock import patch import mocks import pytest # noqa F401 Needed to run the tests @@ -28,6 +24,7 @@ from google.cloud.sql.connector import IPTypes from google.cloud.sql.connector.client import CloudSQLClient from google.cloud.sql.connector.connection_info import ConnectionInfo +from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.exceptions import AutoIAMAuthNotSupported from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError from google.cloud.sql.connector.instance import RefreshAheadCache @@ -50,9 +47,9 @@ async def test_Instance_init( can tell if the connection string that's passed in is formatted correctly. """ assert ( - cache._project == "test-project" - and cache._region == "test-region" - and cache._instance == "test-instance" + cache._conn_name.project == "test-project" + and cache._conn_name.region == "test-region" + and cache._conn_name.instance_name == "test-instance" ) assert cache._enable_iam_auth is False @@ -266,59 +263,6 @@ async def test_get_preferred_ip_CloudSQLIPTypeError(cache: RefreshAheadCache) -> instance_metadata.get_preferred_ip(IPTypes.PSC) -@pytest.mark.asyncio -async def test_ClientResponseError( - fake_credentials: Credentials, -) -> None: - """ - Test that detailed error message is applied to ClientResponseError. - """ - # mock Cloud SQL Admin API calls with exceptions - keys = asyncio.create_task(generate_keys()) - client = CloudSQLClient( - sqladmin_api_endpoint="https://sqladmin.googleapis.com", - quota_project=None, - credentials=fake_credentials, - ) - get_url = "https://sqladmin.googleapis.com/sql/v1beta4/projects/my-project/instances/my-instance/connectSettings" - post_url = "https://sqladmin.googleapis.com/sql/v1beta4/projects/my-project/instances/my-instance:generateEphemeralCert" - with aioresponses() as mocked: - mocked.get( - get_url, - status=403, - exception=ClientResponseError( - RequestInfo(get_url, "GET", headers=[]), history=[], status=403 # type: ignore - ), - repeat=True, - ) - mocked.post( - post_url, - status=403, - exception=ClientResponseError( - RequestInfo(post_url, "POST", headers=[]), history=[], status=403 # type: ignore - ), - repeat=True, - ) - cache = RefreshAheadCache( - "my-project:my-region:my-instance", - client, - keys, - ) - try: - await cache._current - except ClientResponseError as e: - assert e.status == 403 - assert ( - e.message == "Forbidden: Authenticated IAM principal does not " - "seem authorized to make API request. Verify " - "'Cloud SQL Admin API' is enabled within your GCP project and " - "'Cloud SQL Client' role has been granted to IAM principal." 
- ) - finally: - await cache.close() - await client.close() - - @pytest.mark.asyncio async def test_AutoIAMAuthNotSupportedError(fake_client: CloudSQLClient) -> None: """ @@ -328,7 +272,7 @@ async def test_AutoIAMAuthNotSupportedError(fake_client: CloudSQLClient) -> None # generate client key pair keys = asyncio.create_task(generate_keys()) cache = RefreshAheadCache( - "test-project:test-region:sqlserver-instance", + ConnectionName("test-project", "test-region", "sqlserver-instance"), client=fake_client, keys=keys, enable_iam_auth=True, @@ -339,7 +283,7 @@ async def test_AutoIAMAuthNotSupportedError(fake_client: CloudSQLClient) -> None async def test_ConnectionInfo_caches_sslcontext() -> None: info = ConnectionInfo( - "cert", "cert", "key".encode(), {}, "POSTGRES", datetime.datetime.now() + "", "cert", "cert", "key".encode(), {}, "POSTGRES", datetime.datetime.now() ) # context should default to None assert info.context is None diff --git a/tests/unit/test_lazy.py b/tests/unit/test_lazy.py index 27cd80b4f..344b073e8 100644 --- a/tests/unit/test_lazy.py +++ b/tests/unit/test_lazy.py @@ -16,6 +16,7 @@ from google.cloud.sql.connector.client import CloudSQLClient from google.cloud.sql.connector.connection_info import ConnectionInfo +from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.lazy import LazyRefreshCache from google.cloud.sql.connector.utils import generate_keys @@ -26,7 +27,7 @@ async def test_LazyRefreshCache_connect_info(fake_client: CloudSQLClient) -> Non """ keys = asyncio.create_task(generate_keys()) cache = LazyRefreshCache( - "test-project:test-region:test-instance", + ConnectionName("test-project", "test-region", "test-instance"), client=fake_client, keys=keys, enable_iam_auth=False, @@ -47,7 +48,7 @@ async def test_LazyRefreshCache_force_refresh(fake_client: CloudSQLClient) -> No """ keys = asyncio.create_task(generate_keys()) cache = LazyRefreshCache( - "test-project:test-region:test-instance", + ConnectionName("test-project", "test-region", "test-instance"), client=fake_client, keys=keys, enable_iam_auth=False, diff --git a/tests/unit/test_resolver.py b/tests/unit/test_resolver.py new file mode 100644 index 000000000..a9a7f2632 --- /dev/null +++ b/tests/unit/test_resolver.py @@ -0,0 +1,131 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import dns.message +import dns.rdataclass +import dns.rdatatype +import dns.resolver +from mock import patch +import pytest + +from google.cloud.sql.connector.connection_name import ConnectionName +from google.cloud.sql.connector.exceptions import DnsResolutionError +from google.cloud.sql.connector.resolver import DefaultResolver +from google.cloud.sql.connector.resolver import DnsResolver + +conn_str = "my-project:my-region:my-instance" +conn_name = ConnectionName("my-project", "my-region", "my-instance") +conn_name_with_domain = ConnectionName( + "my-project", "my-region", "my-instance", "db.example.com" +) + + +async def test_DefaultResolver() -> None: + """Test DefaultResolver just parses instance connection string.""" + resolver = DefaultResolver() + result = await resolver.resolve(conn_str) + assert result == conn_name + + +async def test_DnsResolver_with_conn_str() -> None: + """Test DnsResolver with instance connection name just parses connection string.""" + resolver = DnsResolver() + result = await resolver.resolve(conn_str) + assert result == conn_name + + +query_text = """id 1234 +opcode QUERY +rcode NOERROR +flags QR AA RD RA +;QUESTION +db.example.com. IN TXT +;ANSWER +db.example.com. 0 IN TXT "test-project:test-region:test-instance" +db.example.com. 0 IN TXT "my-project:my-region:my-instance" +;AUTHORITY +;ADDITIONAL +""" + + +async def test_DnsResolver_with_dns_name() -> None: + """Test DnsResolver resolves TXT record into proper instance connection name. + + Should sort valid TXT records alphabetically and take first one. + """ + # Patch DNS resolution with valid TXT records + with patch("dns.asyncresolver.Resolver.resolve") as mock_connect: + answer = dns.resolver.Answer( + "db.example.com", + dns.rdatatype.TXT, + dns.rdataclass.IN, + dns.message.from_text(query_text), + ) + mock_connect.return_value = answer + resolver = DnsResolver() + resolver.port = 5053 + # Resolution should return first value sorted alphabetically + result = await resolver.resolve("db.example.com") + assert result == conn_name_with_domain + + +query_text_malformed = """id 1234 +opcode QUERY +rcode NOERROR +flags QR AA RD RA +;QUESTION +bad.example.com. IN TXT +;ANSWER +bad.example.com. 0 IN TXT "malformed-instance-name" +;AUTHORITY +;ADDITIONAL +""" + + +async def test_DnsResolver_with_malformed_txt() -> None: + """Test DnsResolver with TXT record that holds malformed instance connection name. + + Should throw DnsResolutionError + """ + # patch DNS resolution with malformed TXT record + with patch("dns.asyncresolver.Resolver.resolve") as mock_connect: + answer = dns.resolver.Answer( + "bad.example.com", + dns.rdatatype.TXT, + dns.rdataclass.IN, + dns.message.from_text(query_text_malformed), + ) + mock_connect.return_value = answer + resolver = DnsResolver() + resolver.port = 5053 + with pytest.raises(DnsResolutionError) as exc_info: + await resolver.resolve("bad.example.com") + assert ( + exc_info.value.args[0] + == "Unable to parse TXT record for `bad.example.com` -> `malformed-instance-name`" + ) + + +async def test_DnsResolver_with_bad_dns_name() -> None: + """Test DnsResolver with bad dns name. + + Should throw DnsResolutionError + """ + resolver = DnsResolver() + resolver.port = 5053 + # set lifetime to 1 second for shorter timeout + resolver.lifetime = 1 + with pytest.raises(DnsResolutionError) as exc_info: + await resolver.resolve("bad.dns.com") + assert exc_info.value.args[0] == "Unable to resolve TXT record for `bad.dns.com`"
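
---

Note: the README addition in this diff shows the new `resolver=DnsResolver` option only on the synchronous pymysql path. Below is a minimal, hypothetical sketch of the same option on the async path, using `connect_async` with the asyncpg driver (both of which appear in this repository's tests and docs). The DNS name `prod-db.mycompany.example.com` is the one from the README example; the user, password, and database values are placeholders, and the sketch assumes Application Default Credentials, a reachable PostgreSQL instance, and the `asyncpg` package are available.

```python
import asyncio

from google.cloud.sql.connector import Connector, DnsResolver


async def main() -> None:
    loop = asyncio.get_running_loop()
    # resolver=DnsResolver tells the connector to treat a name that is not in
    # "project:region:instance" form as a DNS name and resolve the instance
    # connection name from its TXT record.
    async with Connector(loop=loop, resolver=DnsResolver) as connector:
        conn = await connector.connect_async(
            "prod-db.mycompany.example.com",  # DNS name from the README example
            "asyncpg",
            user="my-user",          # placeholder credentials
            password="my-password",
            db="my-db-name",         # placeholder database name
        )
        # simple sanity query against the resolved instance
        print(await conn.fetchval("SELECT NOW()"))
        await conn.close()


if __name__ == "__main__":
    asyncio.run(main())
```

As with the synchronous example, leaving the resolver at its default (`DefaultResolver`) restores plain `project:region:instance` parsing with no DNS lookup.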