From 9ca72dbb56b2aa40c8c0a3a1773ca08d23113add Mon Sep 17 00:00:00 2001 From: Mend Renovate Date: Tue, 26 Nov 2024 22:27:54 +0100 Subject: [PATCH 01/13] chore(deps): update dependency aiohttp to v3.11.7 (#1198) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 3085b3e94..e285d4a00 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ aiofiles==24.1.0 -aiohttp==3.11.6 +aiohttp==3.11.7 cryptography==43.0.3 Requests==2.32.3 google-auth==2.36.0 From f99661e6cc4fa8a1a5322c19b17701e0ab0ca9a1 Mon Sep 17 00:00:00 2001 From: Mend Renovate Date: Thu, 28 Nov 2024 20:15:53 +0100 Subject: [PATCH 02/13] chore(deps): Update github/codeql-action action to v3.27.5 (#1197) --- .github/workflows/codeql.yml | 6 +++--- .github/workflows/scorecard.yml | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index ba6439848..1116c9986 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -46,16 +46,16 @@ jobs: # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL - uses: github/codeql-action/init@ea9e4e37992a54ee68a9622e985e60c8e8f12d9f # v3.27.4 + uses: github/codeql-action/init@f09c1c0a94de965c15400f5634aa42fac8fb8f88 # v3.27.5 with: languages: ${{ matrix.language }} # Autobuild attempts to build any compiled languages (C/C++, C#, Go, or Java). # If this step fails, then you should remove it and run the build manually - name: Autobuild - uses: github/codeql-action/autobuild@ea9e4e37992a54ee68a9622e985e60c8e8f12d9f # v3.27.4 + uses: github/codeql-action/autobuild@f09c1c0a94de965c15400f5634aa42fac8fb8f88 # v3.27.5 - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@ea9e4e37992a54ee68a9622e985e60c8e8f12d9f # v3.27.4 + uses: github/codeql-action/analyze@f09c1c0a94de965c15400f5634aa42fac8fb8f88 # v3.27.5 with: category: "/language:${{matrix.language}}" diff --git a/.github/workflows/scorecard.yml b/.github/workflows/scorecard.yml index 37c9428dd..ee3efa120 100644 --- a/.github/workflows/scorecard.yml +++ b/.github/workflows/scorecard.yml @@ -65,6 +65,6 @@ jobs: # Upload the results to GitHub's code scanning dashboard. 
- name: "Upload to code-scanning" - uses: github/codeql-action/upload-sarif@ea9e4e37992a54ee68a9622e985e60c8e8f12d9f # v3.27.4 + uses: github/codeql-action/upload-sarif@f09c1c0a94de965c15400f5634aa42fac8fb8f88 # v3.27.5 with: sarif_file: resultsFiltered.sarif From 5c9a73e8eab5d14efe7a249244162a1fcab7e0ec Mon Sep 17 00:00:00 2001 From: Mend Renovate Date: Sat, 30 Nov 2024 03:51:16 +0100 Subject: [PATCH 03/13] chore(deps): update dependency aiohttp to v3.11.8 (#1203) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index e285d4a00..55695e711 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ aiofiles==24.1.0 -aiohttp==3.11.7 +aiohttp==3.11.8 cryptography==43.0.3 Requests==2.32.3 google-auth==2.36.0 From 1dd2804638c774774ba16b1e04517dfd28c8492d Mon Sep 17 00:00:00 2001 From: Mend Renovate Date: Tue, 3 Dec 2024 02:29:04 +0100 Subject: [PATCH 04/13] chore(deps): update python-nonmajor (#1206) --- requirements-test.txt | 2 +- requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements-test.txt b/requirements-test.txt index 4aeecede7..40102fc21 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,4 +1,4 @@ -pytest==8.3.3 +pytest==8.3.4 mock==5.1.0 pytest-cov==6.0.0 pytest-asyncio==0.24.0 diff --git a/requirements.txt b/requirements.txt index 55695e711..f3c2e8d1f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ aiofiles==24.1.0 -aiohttp==3.11.8 +aiohttp==3.11.9 cryptography==43.0.3 Requests==2.32.3 google-auth==2.36.0 From 968b6b2b5691d00edd164fdc09a3054faccbc2c6 Mon Sep 17 00:00:00 2001 From: Mend Renovate Date: Tue, 3 Dec 2024 02:34:33 +0100 Subject: [PATCH 05/13] chore(deps): update dependency cryptography to v44 (#1202) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index f3c2e8d1f..466849d56 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ aiofiles==24.1.0 aiohttp==3.11.9 -cryptography==43.0.3 +cryptography==44.0.0 Requests==2.32.3 google-auth==2.36.0 From 11f9fe93617cd9a161c69d527d3e7592fcbe56af Mon Sep 17 00:00:00 2001 From: Jack Wotherspoon Date: Tue, 3 Dec 2024 16:05:13 -0500 Subject: [PATCH 06/13] feat: improve aiohttp client error messages (#1201) By default, aiohttp on bad requests will discard response body and output generic error message in response, Bad Request, Forbidden, etc. The Cloud SQL Admin APIs response body contains more detailed error messages. We need to raise these to the end user for them to be able to resolve common config issues on their own. This PR implements a more robust solution, copying the actual Cloud SQL Admin API response body's error message to the end user. 
--- google/cloud/sql/connector/client.py | 30 +++++- google/cloud/sql/connector/instance.py | 11 --- tests/unit/test_client.py | 130 +++++++++++++++++++++++++ tests/unit/test_instance.py | 57 ----------- 4 files changed, 156 insertions(+), 72 deletions(-) diff --git a/google/cloud/sql/connector/client.py b/google/cloud/sql/connector/client.py index ed305ec53..61b77d561 100644 --- a/google/cloud/sql/connector/client.py +++ b/google/cloud/sql/connector/client.py @@ -128,8 +128,19 @@ async def _get_metadata( resp = await self._client.get(url, headers=headers) if resp.status >= 500: resp = await retry_50x(self._client.get, url, headers=headers) - resp.raise_for_status() - ret_dict = await resp.json() + # try to get response json for better error message + try: + ret_dict = await resp.json() + if resp.status >= 400: + # if detailed error message is in json response, use as error message + message = ret_dict.get("error", {}).get("message") + if message: + resp.reason = message + # skip, raise_for_status will catch all errors in finally block + except Exception: + pass + finally: + resp.raise_for_status() if ret_dict["region"] != region: raise ValueError( @@ -198,8 +209,19 @@ async def _get_ephemeral( resp = await self._client.post(url, headers=headers, json=data) if resp.status >= 500: resp = await retry_50x(self._client.post, url, headers=headers, json=data) - resp.raise_for_status() - ret_dict = await resp.json() + # try to get response json for better error message + try: + ret_dict = await resp.json() + if resp.status >= 400: + # if detailed error message is in json response, use as error message + message = ret_dict.get("error", {}).get("message") + if message: + resp.reason = message + # skip, raise_for_status will catch all errors in finally block + except Exception: + pass + finally: + resp.raise_for_status() ephemeral_cert: str = ret_dict["ephemeralCert"]["cert"] diff --git a/google/cloud/sql/connector/instance.py b/google/cloud/sql/connector/instance.py index f244b8cf3..9cf9bc787 100644 --- a/google/cloud/sql/connector/instance.py +++ b/google/cloud/sql/connector/instance.py @@ -22,8 +22,6 @@ from datetime import timezone import logging -import aiohttp - from google.cloud.sql.connector.client import CloudSQLClient from google.cloud.sql.connector.connection_info import ConnectionInfo from google.cloud.sql.connector.connection_name import _parse_instance_connection_name @@ -128,15 +126,6 @@ async def _perform_refresh(self) -> ConnectionInfo: f"expiration = {connection_info.expiration.isoformat()}" ) - except aiohttp.ClientResponseError as e: - logger.debug( - f"['{self._conn_name}']: Connection info " - f"refresh operation failed: {str(e)}" - ) - if e.status == 403: - e.message = "Forbidden: Authenticated IAM principal does not seem authorized to make API request. Verify 'Cloud SQL Admin API' is enabled within your GCP project and 'Cloud SQL Client' role has been granted to IAM principal." 
- raise - except Exception as e: logger.debug( f"['{self._conn_name}']: Connection info " diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 046e8e51b..af42af0ae 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -15,6 +15,9 @@ import datetime from typing import Optional +from aiohttp import ClientResponseError +from aioresponses import aioresponses +from google.auth.credentials import Credentials from mocks import FakeCredentials import pytest @@ -138,3 +141,130 @@ async def test_CloudSQLClient_user_agent( assert client._user_agent == f"cloud-sql-python-connector/{version}+{driver}" # close client await client.close() + + +async def test_cloud_sql_error_messages_get_metadata( + fake_credentials: Credentials, +) -> None: + """ + Test that Cloud SQL Admin API error messages are raised for _get_metadata. + """ + # mock Cloud SQL Admin API calls with exceptions + client = CloudSQLClient( + sqladmin_api_endpoint="https://sqladmin.googleapis.com", + quota_project=None, + credentials=fake_credentials, + ) + get_url = "https://sqladmin.googleapis.com/sql/v1beta4/projects/my-project/instances/my-instance/connectSettings" + resp_body = { + "error": { + "code": 403, + "message": "Cloud SQL Admin API has not been used in project 123456789 before or it is disabled", + } + } + with aioresponses() as mocked: + mocked.get( + get_url, + status=403, + payload=resp_body, + repeat=True, + ) + with pytest.raises(ClientResponseError) as exc_info: + await client._get_metadata("my-project", "my-region", "my-instance") + assert exc_info.value.status == 403 + assert ( + exc_info.value.message + == "Cloud SQL Admin API has not been used in project 123456789 before or it is disabled" + ) + await client.close() + + +async def test_get_metadata_error_parsing_json( + fake_credentials: Credentials, +) -> None: + """ + Test that aiohttp default error messages are raised when _get_metadata gets + a bad JSON response. + """ + # mock Cloud SQL Admin API calls with exceptions + client = CloudSQLClient( + sqladmin_api_endpoint="https://sqladmin.googleapis.com", + quota_project=None, + credentials=fake_credentials, + ) + get_url = "https://sqladmin.googleapis.com/sql/v1beta4/projects/my-project/instances/my-instance/connectSettings" + resp_body = ["error"] # invalid JSON + with aioresponses() as mocked: + mocked.get( + get_url, + status=403, + payload=resp_body, + repeat=True, + ) + with pytest.raises(ClientResponseError) as exc_info: + await client._get_metadata("my-project", "my-region", "my-instance") + assert exc_info.value.status == 403 + assert exc_info.value.message == "Forbidden" + await client.close() + + +async def test_cloud_sql_error_messages_get_ephemeral( + fake_credentials: Credentials, +) -> None: + """ + Test that Cloud SQL Admin API error messages are raised for _get_ephemeral. 
+ """ + # mock Cloud SQL Admin API calls with exceptions + client = CloudSQLClient( + sqladmin_api_endpoint="https://sqladmin.googleapis.com", + quota_project=None, + credentials=fake_credentials, + ) + post_url = "https://sqladmin.googleapis.com/sql/v1beta4/projects/my-project/instances/my-instance:generateEphemeralCert" + resp_body = { + "error": { + "code": 404, + "message": "The Cloud SQL instance does not exist.", + } + } + with aioresponses() as mocked: + mocked.post( + post_url, + status=404, + payload=resp_body, + repeat=True, + ) + with pytest.raises(ClientResponseError) as exc_info: + await client._get_ephemeral("my-project", "my-instance", "my-key") + assert exc_info.value.status == 404 + assert exc_info.value.message == "The Cloud SQL instance does not exist." + await client.close() + + +async def test_get_ephemeral_error_parsing_json( + fake_credentials: Credentials, +) -> None: + """ + Test that aiohttp default error messages are raised when _get_ephemeral gets + a bad JSON response. + """ + # mock Cloud SQL Admin API calls with exceptions + client = CloudSQLClient( + sqladmin_api_endpoint="https://sqladmin.googleapis.com", + quota_project=None, + credentials=fake_credentials, + ) + post_url = "https://sqladmin.googleapis.com/sql/v1beta4/projects/my-project/instances/my-instance:generateEphemeralCert" + resp_body = ["error"] # invalid JSON + with aioresponses() as mocked: + mocked.post( + post_url, + status=404, + payload=resp_body, + repeat=True, + ) + with pytest.raises(ClientResponseError) as exc_info: + await client._get_ephemeral("my-project", "my-instance", "my-key") + assert exc_info.value.status == 404 + assert exc_info.value.message == "Not Found" + await client.close() diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index 5b0887aa2..3ce0386b2 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -17,10 +17,6 @@ import asyncio import datetime -from aiohttp import ClientResponseError -from aiohttp import RequestInfo -from aioresponses import aioresponses -from google.auth.credentials import Credentials from mock import patch import mocks import pytest # noqa F401 Needed to run the tests @@ -266,59 +262,6 @@ async def test_get_preferred_ip_CloudSQLIPTypeError(cache: RefreshAheadCache) -> instance_metadata.get_preferred_ip(IPTypes.PSC) -@pytest.mark.asyncio -async def test_ClientResponseError( - fake_credentials: Credentials, -) -> None: - """ - Test that detailed error message is applied to ClientResponseError. 
- """ - # mock Cloud SQL Admin API calls with exceptions - keys = asyncio.create_task(generate_keys()) - client = CloudSQLClient( - sqladmin_api_endpoint="https://sqladmin.googleapis.com", - quota_project=None, - credentials=fake_credentials, - ) - get_url = "https://sqladmin.googleapis.com/sql/v1beta4/projects/my-project/instances/my-instance/connectSettings" - post_url = "https://sqladmin.googleapis.com/sql/v1beta4/projects/my-project/instances/my-instance:generateEphemeralCert" - with aioresponses() as mocked: - mocked.get( - get_url, - status=403, - exception=ClientResponseError( - RequestInfo(get_url, "GET", headers=[]), history=[], status=403 # type: ignore - ), - repeat=True, - ) - mocked.post( - post_url, - status=403, - exception=ClientResponseError( - RequestInfo(post_url, "POST", headers=[]), history=[], status=403 # type: ignore - ), - repeat=True, - ) - cache = RefreshAheadCache( - "my-project:my-region:my-instance", - client, - keys, - ) - try: - await cache._current - except ClientResponseError as e: - assert e.status == 403 - assert ( - e.message == "Forbidden: Authenticated IAM principal does not " - "seem authorized to make API request. Verify " - "'Cloud SQL Admin API' is enabled within your GCP project and " - "'Cloud SQL Client' role has been granted to IAM principal." - ) - finally: - await cache.close() - await client.close() - - @pytest.mark.asyncio async def test_AutoIAMAuthNotSupportedError(fake_client: CloudSQLClient) -> None: """ From 1a8f2743c661806528c54110fd2a08384398816c Mon Sep 17 00:00:00 2001 From: Jack Wotherspoon Date: Tue, 3 Dec 2024 20:06:35 -0500 Subject: [PATCH 07/13] feat: add support for DNS names with Connector (#1204) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The Connector may be configured to use a DNS name to look up the instance name instead of configuring the connector with the instance connection name directly. Add a DNS TXT record for the Cloud SQL instance to a private DNS server or a private Google Cloud DNS Zone used by your application. For example: Record type: TXT Name: prod-db.mycompany.example.com – This is the domain name used by the application Value: my-project:my-region:my-instance – This is the instance connection name Configure the Connector to use a DNS name via setting resolver=DnsResolver --- README.md | 63 +++++++++++ google/cloud/sql/connector/__init__.py | 4 + google/cloud/sql/connector/connector.py | 16 ++- google/cloud/sql/connector/exceptions.py | 7 ++ google/cloud/sql/connector/instance.py | 10 +- google/cloud/sql/connector/lazy.py | 10 +- google/cloud/sql/connector/resolver.py | 67 ++++++++++++ requirements.txt | 1 + setup.py | 1 + tests/conftest.py | 3 +- tests/unit/test_connector.py | 17 +-- tests/unit/test_instance.py | 3 +- tests/unit/test_lazy.py | 5 +- tests/unit/test_resolver.py | 128 +++++++++++++++++++++++ 14 files changed, 309 insertions(+), 26 deletions(-) create mode 100644 google/cloud/sql/connector/resolver.py create mode 100644 tests/unit/test_resolver.py diff --git a/README.md b/README.md index 28553f972..1f0e633b9 100644 --- a/README.md +++ b/README.md @@ -365,6 +365,69 @@ conn = connector.connect( ) ``` +### Using DNS domain names to identify instances + +The connector can be configured to use DNS to look up an instance. This would +allow you to configure your application to connect to a database instance, and +centrally configure which instance in your DNS zone. 
+ +#### Configure your DNS Records + +Add a DNS TXT record for the Cloud SQL instance to a **private** DNS server +or a private Google Cloud DNS Zone used by your application. + +> [!NOTE] +> +> You are strongly discouraged from adding DNS records for your +> Cloud SQL instances to a public DNS server. This would allow anyone on the +> internet to discover the Cloud SQL instance name. + +For example: suppose you wanted to use the domain name +`prod-db.mycompany.example.com` to connect to your database instance +`my-project:region:my-instance`. You would create the following DNS record: + +* Record type: `TXT` +* Name: `prod-db.mycompany.example.com` – This is the domain name used by the application +* Value: `my-project:my-region:my-instance` – This is the Cloud SQL instance connection name + +#### Configure the connector + +Configure the connector to resolve DNS names by initializing it with +`resolver=DnsResolver` and replacing the instance connection name with the DNS +name in `connector.connect`: + +```python +from google.cloud.sql.connector import Connector, DnsResolver +import pymysql +import sqlalchemy + +# helper function to return SQLAlchemy connection pool +def init_connection_pool(connector: Connector) -> sqlalchemy.engine.Engine: + # function used to generate database connection + def getconn() -> pymysql.connections.Connection: + conn = connector.connect( + "prod-db.mycompany.example.com", # using DNS name + "pymysql", + user="my-user", + password="my-password", + db="my-db-name" + ) + return conn + + # create connection pool + pool = sqlalchemy.create_engine( + "mysql+pymysql://", + creator=getconn, + ) + return pool + +# initialize Cloud SQL Python Connector with `resolver=DnsResolver` +with Connector(resolver=DnsResolver) as connector: + # initialize connection pool + pool = init_connection_pool(connector) + # ... 
use SQLAlchemy engine normally +``` + ### Using the Python Connector with Python Web Frameworks The Python Connector can be used alongside popular Python web frameworks such diff --git a/google/cloud/sql/connector/__init__.py b/google/cloud/sql/connector/__init__.py index 5b06fcd7f..99a5097a2 100644 --- a/google/cloud/sql/connector/__init__.py +++ b/google/cloud/sql/connector/__init__.py @@ -18,12 +18,16 @@ from google.cloud.sql.connector.connector import create_async_connector from google.cloud.sql.connector.enums import IPTypes from google.cloud.sql.connector.enums import RefreshStrategy +from google.cloud.sql.connector.resolver import DefaultResolver +from google.cloud.sql.connector.resolver import DnsResolver from google.cloud.sql.connector.version import __version__ __all__ = [ "__version__", "create_async_connector", "Connector", + "DefaultResolver", + "DnsResolver", "IPTypes", "RefreshStrategy", ] diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 7a89d7194..1e67373eb 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -37,6 +37,8 @@ import google.cloud.sql.connector.pg8000 as pg8000 import google.cloud.sql.connector.pymysql as pymysql import google.cloud.sql.connector.pytds as pytds +from google.cloud.sql.connector.resolver import DefaultResolver +from google.cloud.sql.connector.resolver import DnsResolver from google.cloud.sql.connector.utils import format_database_user from google.cloud.sql.connector.utils import generate_keys @@ -63,6 +65,7 @@ def __init__( user_agent: Optional[str] = None, universe_domain: Optional[str] = None, refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND, + resolver: Type[DefaultResolver] | Type[DnsResolver] = DefaultResolver, ) -> None: """Initializes a Connector instance. @@ -104,6 +107,13 @@ def __init__( of the following: RefreshStrategy.LAZY ("LAZY") or RefreshStrategy.BACKGROUND ("BACKGROUND"). Default: RefreshStrategy.BACKGROUND + + resolver (DefaultResolver | DnsResolver): The class name of the + resolver to use for resolving the Cloud SQL instance connection + name. To resolve a DNS record to an instance connection name, use + DnsResolver. 
+ Default: DefaultResolver + """ # if refresh_strategy is str, convert to RefreshStrategy enum if isinstance(refresh_strategy, str): @@ -157,6 +167,7 @@ def __init__( self._enable_iam_auth = enable_iam_auth self._quota_project = quota_project self._user_agent = user_agent + self._resolver = resolver() # if ip_type is str, convert to IPTypes enum if isinstance(ip_type, str): ip_type = IPTypes._from_str(ip_type) @@ -269,13 +280,14 @@ async def connect_async( if (instance_connection_string, enable_iam_auth) in self._cache: cache = self._cache[(instance_connection_string, enable_iam_auth)] else: + conn_name = await self._resolver.resolve(instance_connection_string) if self._refresh_strategy == RefreshStrategy.LAZY: logger.debug( f"['{instance_connection_string}']: Refresh strategy is set" " to lazy refresh" ) cache = LazyRefreshCache( - instance_connection_string, + conn_name, self._client, self._keys, enable_iam_auth, @@ -286,7 +298,7 @@ async def connect_async( " to backgound refresh" ) cache = RefreshAheadCache( - instance_connection_string, + conn_name, self._client, self._keys, enable_iam_auth, diff --git a/google/cloud/sql/connector/exceptions.py b/google/cloud/sql/connector/exceptions.py index 7bff2300d..92e3e5662 100644 --- a/google/cloud/sql/connector/exceptions.py +++ b/google/cloud/sql/connector/exceptions.py @@ -70,3 +70,10 @@ class IncompatibleDriverError(Exception): Exception to be raised when the database driver given is for the wrong database engine. (i.e. asyncpg for a MySQL database) """ + + +class DnsResolutionError(Exception): + """ + Exception to be raised when an instance connection name can not be resolved + from a DNS record. + """ diff --git a/google/cloud/sql/connector/instance.py b/google/cloud/sql/connector/instance.py index 9cf9bc787..3b0b9263d 100644 --- a/google/cloud/sql/connector/instance.py +++ b/google/cloud/sql/connector/instance.py @@ -24,7 +24,7 @@ from google.cloud.sql.connector.client import CloudSQLClient from google.cloud.sql.connector.connection_info import ConnectionInfo -from google.cloud.sql.connector.connection_name import _parse_instance_connection_name +from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.exceptions import RefreshNotValidError from google.cloud.sql.connector.rate_limiter import AsyncRateLimiter from google.cloud.sql.connector.refresh_utils import _is_valid @@ -45,7 +45,7 @@ class RefreshAheadCache: def __init__( self, - instance_connection_string: str, + conn_name: ConnectionName, client: CloudSQLClient, keys: asyncio.Future, enable_iam_auth: bool = False, @@ -53,8 +53,8 @@ def __init__( """Initializes a RefreshAheadCache instance. Args: - instance_connection_string (str): The Cloud SQL Instance's - connection string (also known as an instance connection name). + conn_name (ConnectionName): The Cloud SQL instance's + connection name. client (CloudSQLClient): The Cloud SQL Client instance. keys (asyncio.Future): A future to the client's public-private key pair. @@ -62,8 +62,6 @@ def __init__( (Postgres and MySQL) as the default authentication method for all connections. 
""" - # validate and parse instance connection name - conn_name = _parse_instance_connection_name(instance_connection_string) self._project, self._region, self._instance = ( conn_name.project, conn_name.region, diff --git a/google/cloud/sql/connector/lazy.py b/google/cloud/sql/connector/lazy.py index 672f989e8..ab73785d1 100644 --- a/google/cloud/sql/connector/lazy.py +++ b/google/cloud/sql/connector/lazy.py @@ -21,7 +21,7 @@ from google.cloud.sql.connector.client import CloudSQLClient from google.cloud.sql.connector.connection_info import ConnectionInfo -from google.cloud.sql.connector.connection_name import _parse_instance_connection_name +from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.refresh_utils import _refresh_buffer logger = logging.getLogger(name=__name__) @@ -38,7 +38,7 @@ class LazyRefreshCache: def __init__( self, - instance_connection_string: str, + conn_name: ConnectionName, client: CloudSQLClient, keys: asyncio.Future, enable_iam_auth: bool = False, @@ -46,8 +46,8 @@ def __init__( """Initializes a LazyRefreshCache instance. Args: - instance_connection_string (str): The Cloud SQL Instance's - connection string (also known as an instance connection name). + conn_name (ConnectionName): The Cloud SQL instance's + connection name. client (CloudSQLClient): The Cloud SQL Client instance. keys (asyncio.Future): A future to the client's public-private key pair. @@ -55,8 +55,6 @@ def __init__( (Postgres and MySQL) as the default authentication method for all connections. """ - # validate and parse instance connection name - conn_name = _parse_instance_connection_name(instance_connection_string) self._project, self._region, self._instance = ( conn_name.project, conn_name.region, diff --git a/google/cloud/sql/connector/resolver.py b/google/cloud/sql/connector/resolver.py new file mode 100644 index 000000000..15ccd6a21 --- /dev/null +++ b/google/cloud/sql/connector/resolver.py @@ -0,0 +1,67 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dns.asyncresolver + +from google.cloud.sql.connector.connection_name import _parse_instance_connection_name +from google.cloud.sql.connector.connection_name import ConnectionName +from google.cloud.sql.connector.exceptions import DnsResolutionError + + +class DefaultResolver: + """DefaultResolver simply validates and parses instance connection name.""" + + async def resolve(self, connection_name: str) -> ConnectionName: + return _parse_instance_connection_name(connection_name) + + +class DnsResolver(dns.asyncresolver.Resolver): + """ + DnsResolver resolves domain names into instance connection names using + TXT records in DNS. + """ + + async def resolve(self, dns: str) -> ConnectionName: # type: ignore + try: + conn_name = _parse_instance_connection_name(dns) + except ValueError: + # The connection name was not project:region:instance format. + # Attempt to query a TXT record to get connection name. 
+ conn_name = await self.query_dns(dns) + return conn_name + + async def query_dns(self, dns: str) -> ConnectionName: + try: + # Attempt to query the TXT records. + records = await super().resolve(dns, "TXT", raise_on_no_answer=True) + # Sort the TXT record values alphabetically, strip quotes as record + # values can be returned as raw strings + rdata = [record.to_text().strip('"') for record in records] + rdata.sort() + # Attempt to parse records, returning the first valid record. + for record in rdata: + try: + conn_name = _parse_instance_connection_name(record) + return conn_name + except Exception: + continue + # If all records failed to parse, throw error + raise DnsResolutionError( + f"Unable to parse TXT record for `{dns}` -> `{rdata[0]}`" + ) + # Don't override above DnsResolutionError + except DnsResolutionError: + raise + except Exception as e: + raise DnsResolutionError(f"Unable to resolve TXT record for `{dns}`") from e diff --git a/requirements.txt b/requirements.txt index 466849d56..0ff0e41e7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ aiofiles==24.1.0 aiohttp==3.11.9 cryptography==44.0.0 +dnspython==2.7.0 Requests==2.32.3 google-auth==2.36.0 diff --git a/setup.py b/setup.py index bb70449a5..79c6acf74 100644 --- a/setup.py +++ b/setup.py @@ -28,6 +28,7 @@ "aiofiles", "aiohttp", "cryptography>=42.0.0", + "dnspython>=2.0.0", "Requests", "google-auth>=2.28.0", ] diff --git a/tests/conftest.py b/tests/conftest.py index 470fe19f4..3a1a38a27 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -26,6 +26,7 @@ from unit.mocks import FakeCSQLInstance # type: ignore from google.cloud.sql.connector.client import CloudSQLClient +from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.instance import RefreshAheadCache from google.cloud.sql.connector.utils import generate_keys @@ -144,7 +145,7 @@ async def fake_client( async def cache(fake_client: CloudSQLClient) -> AsyncGenerator[RefreshAheadCache, None]: keys = asyncio.create_task(generate_keys()) cache = RefreshAheadCache( - "test-project:test-region:test-instance", + ConnectionName("test-project", "test-region", "test-instance"), client=fake_client, keys=keys, ) diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index fd18f2d5e..d4f53ed51 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -26,6 +26,7 @@ from google.cloud.sql.connector import create_async_connector from google.cloud.sql.connector import IPTypes from google.cloud.sql.connector.client import CloudSQLClient +from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError from google.cloud.sql.connector.exceptions import IncompatibleDriverError from google.cloud.sql.connector.instance import RefreshAheadCache @@ -322,18 +323,18 @@ async def test_Connector_remove_cached_bad_instance( async with Connector( credentials=fake_credentials, loop=asyncio.get_running_loop() ) as connector: - conn_name = "bad-project:bad-region:bad-inst" + conn_name = ConnectionName("bad-project", "bad-region", "bad-inst") # populate cache cache = RefreshAheadCache(conn_name, fake_client, connector._keys) - connector._cache[(conn_name, False)] = cache + connector._cache[(str(conn_name), False)] = cache # aiohttp client should throw a 404 ClientResponseError with pytest.raises(ClientResponseError): await connector.connect_async( - conn_name, + str(conn_name), "pg8000", ) # check that cache has been 
removed from dict - assert (conn_name, False) not in connector._cache + assert (str(conn_name), False) not in connector._cache async def test_Connector_remove_cached_no_ip_type( @@ -348,21 +349,21 @@ async def test_Connector_remove_cached_no_ip_type( async with Connector( credentials=fake_credentials, loop=asyncio.get_running_loop() ) as connector: - conn_name = "test-project:test-region:test-instance" + conn_name = ConnectionName("test-project", "test-region", "test-instance") # populate cache cache = RefreshAheadCache(conn_name, fake_client, connector._keys) - connector._cache[(conn_name, False)] = cache + connector._cache[(str(conn_name), False)] = cache # test instance does not have Private IP, thus should invalidate cache with pytest.raises(CloudSQLIPTypeError): await connector.connect_async( - conn_name, + str(conn_name), "pg8000", user="my-user", password="my-pass", ip_type="private", ) # check that cache has been removed from dict - assert (conn_name, False) not in connector._cache + assert (str(conn_name), False) not in connector._cache def test_default_universe_domain(fake_credentials: Credentials) -> None: diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index 3ce0386b2..f80bb1494 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -24,6 +24,7 @@ from google.cloud.sql.connector import IPTypes from google.cloud.sql.connector.client import CloudSQLClient from google.cloud.sql.connector.connection_info import ConnectionInfo +from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.exceptions import AutoIAMAuthNotSupported from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError from google.cloud.sql.connector.instance import RefreshAheadCache @@ -271,7 +272,7 @@ async def test_AutoIAMAuthNotSupportedError(fake_client: CloudSQLClient) -> None # generate client key pair keys = asyncio.create_task(generate_keys()) cache = RefreshAheadCache( - "test-project:test-region:sqlserver-instance", + ConnectionName("test-project", "test-region", "sqlserver-instance"), client=fake_client, keys=keys, enable_iam_auth=True, diff --git a/tests/unit/test_lazy.py b/tests/unit/test_lazy.py index 27cd80b4f..344b073e8 100644 --- a/tests/unit/test_lazy.py +++ b/tests/unit/test_lazy.py @@ -16,6 +16,7 @@ from google.cloud.sql.connector.client import CloudSQLClient from google.cloud.sql.connector.connection_info import ConnectionInfo +from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.lazy import LazyRefreshCache from google.cloud.sql.connector.utils import generate_keys @@ -26,7 +27,7 @@ async def test_LazyRefreshCache_connect_info(fake_client: CloudSQLClient) -> Non """ keys = asyncio.create_task(generate_keys()) cache = LazyRefreshCache( - "test-project:test-region:test-instance", + ConnectionName("test-project", "test-region", "test-instance"), client=fake_client, keys=keys, enable_iam_auth=False, @@ -47,7 +48,7 @@ async def test_LazyRefreshCache_force_refresh(fake_client: CloudSQLClient) -> No """ keys = asyncio.create_task(generate_keys()) cache = LazyRefreshCache( - "test-project:test-region:test-instance", + ConnectionName("test-project", "test-region", "test-instance"), client=fake_client, keys=keys, enable_iam_auth=False, diff --git a/tests/unit/test_resolver.py b/tests/unit/test_resolver.py new file mode 100644 index 000000000..d7404890a --- /dev/null +++ b/tests/unit/test_resolver.py @@ -0,0 +1,128 @@ +# Copyright 2024 Google LLC +# +# 
Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dns.message +import dns.rdataclass +import dns.rdatatype +import dns.resolver +from mock import patch +import pytest + +from google.cloud.sql.connector.connection_name import ConnectionName +from google.cloud.sql.connector.exceptions import DnsResolutionError +from google.cloud.sql.connector.resolver import DefaultResolver +from google.cloud.sql.connector.resolver import DnsResolver + +conn_str = "my-project:my-region:my-instance" +conn_name = ConnectionName("my-project", "my-region", "my-instance") + + +async def test_DefaultResolver() -> None: + """Test DefaultResolver just parses instance connection string.""" + resolver = DefaultResolver() + result = await resolver.resolve(conn_str) + assert result == conn_name + + +async def test_DnsResolver_with_conn_str() -> None: + """Test DnsResolver with instance connection name just parses connection string.""" + resolver = DnsResolver() + result = await resolver.resolve(conn_str) + assert result == conn_name + + +query_text = """id 1234 +opcode QUERY +rcode NOERROR +flags QR AA RD RA +;QUESTION +db.example.com. IN TXT +;ANSWER +db.example.com. 0 IN TXT "test-project:test-region:test-instance" +db.example.com. 0 IN TXT "my-project:my-region:my-instance" +;AUTHORITY +;ADDITIONAL +""" + + +async def test_DnsResolver_with_dns_name() -> None: + """Test DnsResolver resolves TXT record into proper instance connection name. + + Should sort valid TXT records alphabetically and take first one. + """ + # Patch DNS resolution with valid TXT records + with patch("dns.asyncresolver.Resolver.resolve") as mock_connect: + answer = dns.resolver.Answer( + "db.example.com", + dns.rdatatype.TXT, + dns.rdataclass.IN, + dns.message.from_text(query_text), + ) + mock_connect.return_value = answer + resolver = DnsResolver() + resolver.port = 5053 + # Resolution should return first value sorted alphabetically + result = await resolver.resolve("db.example.com") + assert result == conn_name + + +query_text_malformed = """id 1234 +opcode QUERY +rcode NOERROR +flags QR AA RD RA +;QUESTION +bad.example.com. IN TXT +;ANSWER +bad.example.com. 0 IN TXT "malformed-instance-name" +;AUTHORITY +;ADDITIONAL +""" + + +async def test_DnsResolver_with_malformed_txt() -> None: + """Test DnsResolver with TXT record that holds malformed instance connection name. 
+ + Should throw DnsResolutionError + """ + # patch DNS resolution with malformed TXT record + with patch("dns.asyncresolver.Resolver.resolve") as mock_connect: + answer = dns.resolver.Answer( + "bad.example.com", + dns.rdatatype.TXT, + dns.rdataclass.IN, + dns.message.from_text(query_text_malformed), + ) + mock_connect.return_value = answer + resolver = DnsResolver() + resolver.port = 5053 + with pytest.raises(DnsResolutionError) as exc_info: + await resolver.resolve("bad.example.com") + assert ( + exc_info.value.args[0] + == "Unable to parse TXT record for `bad.example.com` -> `malformed-instance-name`" + ) + + +async def test_DnsResolver_with_bad_dns_name() -> None: + """Test DnsResolver with bad dns name. + + Should throw DnsResolutionError + """ + resolver = DnsResolver() + resolver.port = 5053 + # set lifetime to 1 second for shorter timeout + resolver.lifetime = 1 + with pytest.raises(DnsResolutionError) as exc_info: + await resolver.resolve("bad.dns.com") + assert exc_info.value.args[0] == "Unable to resolve TXT record for `bad.dns.com`" From 720d1ade26a790070b13c930b8b0c08335df7bc1 Mon Sep 17 00:00:00 2001 From: Jack Wotherspoon Date: Thu, 5 Dec 2024 13:21:52 -0500 Subject: [PATCH 08/13] refactor: add `domain_name` to `ConnectionName` class (#1209) --- google/cloud/sql/connector/connection_name.py | 12 +++- google/cloud/sql/connector/resolver.py | 8 +-- tests/unit/test_connection_name.py | 60 ++++++++++++++++--- 3 files changed, 67 insertions(+), 13 deletions(-) diff --git a/google/cloud/sql/connector/connection_name.py b/google/cloud/sql/connector/connection_name.py index d240fb565..1bf711ab7 100644 --- a/google/cloud/sql/connector/connection_name.py +++ b/google/cloud/sql/connector/connection_name.py @@ -31,12 +31,21 @@ class ConnectionName: project: str region: str instance_name: str + domain_name: str = "" def __str__(self) -> str: + if self.domain_name: + return f"{self.domain_name} -> {self.project}:{self.region}:{self.instance_name}" return f"{self.project}:{self.region}:{self.instance_name}" -def _parse_instance_connection_name(connection_name: str) -> ConnectionName: +def _parse_connection_name(connection_name: str) -> ConnectionName: + return _parse_connection_name_with_domain_name(connection_name, "") + + +def _parse_connection_name_with_domain_name( + connection_name: str, domain_name: str +) -> ConnectionName: if CONN_NAME_REGEX.fullmatch(connection_name) is None: raise ValueError( "Arg `instance_connection_string` must have " @@ -48,4 +57,5 @@ def _parse_instance_connection_name(connection_name: str) -> ConnectionName: connection_name_split[1], connection_name_split[3], connection_name_split[4], + domain_name, ) diff --git a/google/cloud/sql/connector/resolver.py b/google/cloud/sql/connector/resolver.py index 15ccd6a21..2cdcddbe2 100644 --- a/google/cloud/sql/connector/resolver.py +++ b/google/cloud/sql/connector/resolver.py @@ -14,7 +14,7 @@ import dns.asyncresolver -from google.cloud.sql.connector.connection_name import _parse_instance_connection_name +from google.cloud.sql.connector.connection_name import _parse_connection_name from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.exceptions import DnsResolutionError @@ -23,7 +23,7 @@ class DefaultResolver: """DefaultResolver simply validates and parses instance connection name.""" async def resolve(self, connection_name: str) -> ConnectionName: - return _parse_instance_connection_name(connection_name) + return _parse_connection_name(connection_name) class 
DnsResolver(dns.asyncresolver.Resolver): @@ -34,7 +34,7 @@ class DnsResolver(dns.asyncresolver.Resolver): async def resolve(self, dns: str) -> ConnectionName: # type: ignore try: - conn_name = _parse_instance_connection_name(dns) + conn_name = _parse_connection_name(dns) except ValueError: # The connection name was not project:region:instance format. # Attempt to query a TXT record to get connection name. @@ -52,7 +52,7 @@ async def query_dns(self, dns: str) -> ConnectionName: # Attempt to parse records, returning the first valid record. for record in rdata: try: - conn_name = _parse_instance_connection_name(record) + conn_name = _parse_connection_name(record) return conn_name except Exception: continue diff --git a/tests/unit/test_connection_name.py b/tests/unit/test_connection_name.py index 1e3730424..a62f88d5f 100644 --- a/tests/unit/test_connection_name.py +++ b/tests/unit/test_connection_name.py @@ -14,9 +14,14 @@ import pytest # noqa F401 Needed to run the tests -from google.cloud.sql.connector.connection_name import _parse_instance_connection_name +# fmt: off +from google.cloud.sql.connector.connection_name import _parse_connection_name +from google.cloud.sql.connector.connection_name import \ + _parse_connection_name_with_domain_name from google.cloud.sql.connector.connection_name import ConnectionName +# fmt: on + def test_ConnectionName() -> None: conn_name = ConnectionName("project", "region", "instance") @@ -24,10 +29,22 @@ def test_ConnectionName() -> None: assert conn_name.project == "project" assert conn_name.region == "region" assert conn_name.instance_name == "instance" + assert conn_name.domain_name == "" # test ConnectionName str() method prints instance connection name assert str(conn_name) == "project:region:instance" +def test_ConnectionName_with_domain_name() -> None: + conn_name = ConnectionName("project", "region", "instance", "db.example.com") + # test class attributes are set properly + assert conn_name.project == "project" + assert conn_name.region == "region" + assert conn_name.instance_name == "instance" + assert conn_name.domain_name == "db.example.com" + # test ConnectionName str() method prints with domain name + assert str(conn_name) == "db.example.com -> project:region:instance" + + @pytest.mark.parametrize( "connection_name, expected", [ @@ -38,19 +55,46 @@ def test_ConnectionName() -> None: ), ], ) -def test_parse_instance_connection_name( - connection_name: str, expected: ConnectionName -) -> None: +def test_parse_connection_name(connection_name: str, expected: ConnectionName) -> None: """ - Test that _parse_instance_connection_name works correctly on + Test that _parse_connection_name works correctly on normal instance connection names and domain-scoped projects. """ - assert expected == _parse_instance_connection_name(connection_name) + assert expected == _parse_connection_name(connection_name) -def test_parse_instance_connection_name_bad_conn_name() -> None: +def test_parse_connection_name_bad_conn_name() -> None: """ Tests that ValueError is thrown for bad instance connection names. 
""" with pytest.raises(ValueError): - _parse_instance_connection_name("project:instance") # missing region + _parse_connection_name("project:instance") # missing region + + +@pytest.mark.parametrize( + "connection_name, domain_name, expected", + [ + ( + "project:region:instance", + "db.example.com", + ConnectionName("project", "region", "instance", "db.example.com"), + ), + ( + "domain-prefix:project:region:instance", + "db.example.com", + ConnectionName( + "domain-prefix:project", "region", "instance", "db.example.com" + ), + ), + ], +) +def test_parse_connection_name_with_domain_name( + connection_name: str, domain_name: str, expected: ConnectionName +) -> None: + """ + Test that _parse_connection_name_with_domain_name works correctly on + normal instance connection names and domain-scoped projects. + """ + assert expected == _parse_connection_name_with_domain_name( + connection_name, domain_name + ) From 5af7582d94ea9733a7b9366f8d24c742ff5c6dac Mon Sep 17 00:00:00 2001 From: Jack Wotherspoon Date: Mon, 9 Dec 2024 11:18:56 -0500 Subject: [PATCH 09/13] ci: use black isort profile and update to latest (#1211) --- noxfile.py | 12 +++++++++--- tests/unit/test_connection_name.py | 8 +++----- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/noxfile.py b/noxfile.py index 8329b2de8..de1b9b81d 100644 --- a/noxfile.py +++ b/noxfile.py @@ -20,7 +20,7 @@ import nox -BLACK_VERSION = "black==23.12.1" +BLACK_VERSION = "black==24.10.0" ISORT_VERSION = "isort==5.13.2" LINT_PATHS = ["google", "tests", "noxfile.py", "setup.py"] @@ -50,7 +50,10 @@ def lint(session): "--fss", "--check-only", "--diff", - "--profile=google", + "--profile=black", + "--force-single-line-imports", + "--dont-order-by-type", + "--single-line-exclusions=typing", "-w=88", *LINT_PATHS, ) @@ -85,7 +88,10 @@ def format(session): session.run( "isort", "--fss", - "--profile=google", + "--profile=black", + "--force-single-line-imports", + "--dont-order-by-type", + "--single-line-exclusions=typing", "-w=88", *LINT_PATHS, ) diff --git a/tests/unit/test_connection_name.py b/tests/unit/test_connection_name.py index a62f88d5f..783e14fe3 100644 --- a/tests/unit/test_connection_name.py +++ b/tests/unit/test_connection_name.py @@ -14,14 +14,12 @@ import pytest # noqa F401 Needed to run the tests -# fmt: off +from google.cloud.sql.connector.connection_name import ( + _parse_connection_name_with_domain_name, +) from google.cloud.sql.connector.connection_name import _parse_connection_name -from google.cloud.sql.connector.connection_name import \ - _parse_connection_name_with_domain_name from google.cloud.sql.connector.connection_name import ConnectionName -# fmt: on - def test_ConnectionName() -> None: conn_name = ConnectionName("project", "region", "instance") From 7231c57b2437ec84c0ae87a28289ebfd8744393e Mon Sep 17 00:00:00 2001 From: Jack Wotherspoon Date: Mon, 9 Dec 2024 11:36:55 -0500 Subject: [PATCH 10/13] refactor: add `ConnectionName` to `ConnectionInfo` class (#1212) Adding ConnectionName to ConnectionInfo class Benefits: - Cleaner interface, passing connection name around instead of individual args (project, region, instance) - Consistent debug logs (give access to ConnectionName throughout) - Allows us to tell if ConnectionInfo is using DNS (if ConnectionName.domain_name is set) --- google/cloud/sql/connector/client.py | 22 +++++++++---------- google/cloud/sql/connector/connection_info.py | 2 ++ google/cloud/sql/connector/connector.py | 15 ++++--------- google/cloud/sql/connector/instance.py | 13 +++-------- 
google/cloud/sql/connector/lazy.py | 10 +-------- google/cloud/sql/connector/resolver.py | 5 ++++- tests/unit/test_instance.py | 8 +++---- tests/unit/test_resolver.py | 5 ++++- 8 files changed, 32 insertions(+), 48 deletions(-) diff --git a/google/cloud/sql/connector/client.py b/google/cloud/sql/connector/client.py index 61b77d561..8a31eb9a0 100644 --- a/google/cloud/sql/connector/client.py +++ b/google/cloud/sql/connector/client.py @@ -26,6 +26,7 @@ from google.auth.transport import requests from google.cloud.sql.connector.connection_info import ConnectionInfo +from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.exceptions import AutoIAMAuthNotSupported from google.cloud.sql.connector.refresh_utils import _downscope_credentials from google.cloud.sql.connector.refresh_utils import retry_50x @@ -245,9 +246,7 @@ async def _get_ephemeral( async def get_connection_info( self, - project: str, - region: str, - instance: str, + conn_name: ConnectionName, keys: asyncio.Future, enable_iam_auth: bool, ) -> ConnectionInfo: @@ -255,10 +254,8 @@ async def get_connection_info( Admin API. Args: - project (str): The name of the project the Cloud SQL instance is - located in. - region (str): The region the Cloud SQL instance is located in. - instance (str): Name of the Cloud SQL instance. + conn_name (ConnectionName): The Cloud SQL instance's + connection name. keys (asyncio.Future): A future to the client's public-private key pair. enable_iam_auth (bool): Whether an automatic IAM database @@ -278,16 +275,16 @@ async def get_connection_info( metadata_task = asyncio.create_task( self._get_metadata( - project, - region, - instance, + conn_name.project, + conn_name.region, + conn_name.instance_name, ) ) ephemeral_task = asyncio.create_task( self._get_ephemeral( - project, - instance, + conn_name.project, + conn_name.instance_name, pub_key, enable_iam_auth, ) @@ -311,6 +308,7 @@ async def get_connection_info( ephemeral_cert, expiration = await ephemeral_task return ConnectionInfo( + conn_name, ephemeral_cert, metadata["server_ca_cert"], priv_key, diff --git a/google/cloud/sql/connector/connection_info.py b/google/cloud/sql/connector/connection_info.py index b738063c2..82e3a9018 100644 --- a/google/cloud/sql/connector/connection_info.py +++ b/google/cloud/sql/connector/connection_info.py @@ -21,6 +21,7 @@ from aiofiles.tempfile import TemporaryDirectory +from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError from google.cloud.sql.connector.exceptions import TLSVersionError from google.cloud.sql.connector.utils import write_to_file @@ -38,6 +39,7 @@ class ConnectionInfo: """Contains all necessary information to connect securely to the server-side Proxy running on a Cloud SQL instance.""" + conn_name: ConnectionName client_cert: str server_ca_cert: str private_key: bytes diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 1e67373eb..7c5fb20de 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -113,7 +113,6 @@ def __init__( name. To resolve a DNS record to an instance connection name, use DnsResolver. 
Default: DefaultResolver - """ # if refresh_strategy is str, convert to RefreshStrategy enum if isinstance(refresh_strategy, str): @@ -283,8 +282,7 @@ async def connect_async( conn_name = await self._resolver.resolve(instance_connection_string) if self._refresh_strategy == RefreshStrategy.LAZY: logger.debug( - f"['{instance_connection_string}']: Refresh strategy is set" - " to lazy refresh" + f"['{conn_name}']: Refresh strategy is set to lazy refresh" ) cache = LazyRefreshCache( conn_name, @@ -294,8 +292,7 @@ async def connect_async( ) else: logger.debug( - f"['{instance_connection_string}']: Refresh strategy is set" - " to backgound refresh" + f"['{conn_name}']: Refresh strategy is set to backgound refresh" ) cache = RefreshAheadCache( conn_name, @@ -303,9 +300,7 @@ async def connect_async( self._keys, enable_iam_auth, ) - logger.debug( - f"['{instance_connection_string}']: Connection info added to cache" - ) + logger.debug(f"['{conn_name}']: Connection info added to cache") self._cache[(instance_connection_string, enable_iam_auth)] = cache connect_func = { @@ -344,9 +339,7 @@ async def connect_async( # the cache and re-raise the error await self._remove_cached(instance_connection_string, enable_iam_auth) raise - logger.debug( - f"['{instance_connection_string}']: Connecting to {ip_address}:3307" - ) + logger.debug(f"['{conn_info.conn_name}']: Connecting to {ip_address}:3307") # format `user` param for automatic IAM database authn if enable_iam_auth: formatted_user = format_database_user( diff --git a/google/cloud/sql/connector/instance.py b/google/cloud/sql/connector/instance.py index 3b0b9263d..5df272fe2 100644 --- a/google/cloud/sql/connector/instance.py +++ b/google/cloud/sql/connector/instance.py @@ -62,11 +62,6 @@ def __init__( (Postgres and MySQL) as the default authentication method for all connections. """ - self._project, self._region, self._instance = ( - conn_name.project, - conn_name.region, - conn_name.instance_name, - ) self._conn_name = conn_name self._enable_iam_auth = enable_iam_auth @@ -104,20 +99,18 @@ async def _perform_refresh(self) -> ConnectionInfo: """ self._refresh_in_progress.set() logger.debug( - f"['{self._conn_name}']: Connection info refresh " "operation started" + f"['{self._conn_name}']: Connection info refresh operation started" ) try: await self._refresh_rate_limiter.acquire() connection_info = await self._client.get_connection_info( - self._project, - self._region, - self._instance, + self._conn_name, self._keys, self._enable_iam_auth, ) logger.debug( - f"['{self._conn_name}']: Connection info " "refresh operation complete" + f"['{self._conn_name}']: Connection info refresh operation complete" ) logger.debug( f"['{self._conn_name}']: Current certificate " diff --git a/google/cloud/sql/connector/lazy.py b/google/cloud/sql/connector/lazy.py index ab73785d1..1bc4f90f8 100644 --- a/google/cloud/sql/connector/lazy.py +++ b/google/cloud/sql/connector/lazy.py @@ -55,13 +55,7 @@ def __init__( (Postgres and MySQL) as the default authentication method for all connections. 
""" - self._project, self._region, self._instance = ( - conn_name.project, - conn_name.region, - conn_name.instance_name, - ) self._conn_name = conn_name - self._enable_iam_auth = enable_iam_auth self._keys = keys self._client = client @@ -101,9 +95,7 @@ async def connect_info(self) -> ConnectionInfo: ) try: conn_info = await self._client.get_connection_info( - self._project, - self._region, - self._instance, + self._conn_name, self._keys, self._enable_iam_auth, ) diff --git a/google/cloud/sql/connector/resolver.py b/google/cloud/sql/connector/resolver.py index 2cdcddbe2..39efd0492 100644 --- a/google/cloud/sql/connector/resolver.py +++ b/google/cloud/sql/connector/resolver.py @@ -14,6 +14,9 @@ import dns.asyncresolver +from google.cloud.sql.connector.connection_name import ( + _parse_connection_name_with_domain_name, +) from google.cloud.sql.connector.connection_name import _parse_connection_name from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.exceptions import DnsResolutionError @@ -52,7 +55,7 @@ async def query_dns(self, dns: str) -> ConnectionName: # Attempt to parse records, returning the first valid record. for record in rdata: try: - conn_name = _parse_connection_name(record) + conn_name = _parse_connection_name_with_domain_name(record, dns) return conn_name except Exception: continue diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index f80bb1494..aeedf3399 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -47,9 +47,9 @@ async def test_Instance_init( can tell if the connection string that's passed in is formatted correctly. """ assert ( - cache._project == "test-project" - and cache._region == "test-region" - and cache._instance == "test-instance" + cache._conn_name.project == "test-project" + and cache._conn_name.region == "test-region" + and cache._conn_name.instance_name == "test-instance" ) assert cache._enable_iam_auth is False @@ -283,7 +283,7 @@ async def test_AutoIAMAuthNotSupportedError(fake_client: CloudSQLClient) -> None async def test_ConnectionInfo_caches_sslcontext() -> None: info = ConnectionInfo( - "cert", "cert", "key".encode(), {}, "POSTGRES", datetime.datetime.now() + "", "cert", "cert", "key".encode(), {}, "POSTGRES", datetime.datetime.now() ) # context should default to None assert info.context is None diff --git a/tests/unit/test_resolver.py b/tests/unit/test_resolver.py index d7404890a..a9a7f2632 100644 --- a/tests/unit/test_resolver.py +++ b/tests/unit/test_resolver.py @@ -26,6 +26,9 @@ conn_str = "my-project:my-region:my-instance" conn_name = ConnectionName("my-project", "my-region", "my-instance") +conn_name_with_domain = ConnectionName( + "my-project", "my-region", "my-instance", "db.example.com" +) async def test_DefaultResolver() -> None: @@ -74,7 +77,7 @@ async def test_DnsResolver_with_dns_name() -> None: resolver.port = 5053 # Resolution should return first value sorted alphabetically result = await resolver.resolve("db.example.com") - assert result == conn_name + assert result == conn_name_with_domain query_text_malformed = """id 1234 From 10b94fcee2af3fdd50a3cbd39446c16d7ef9214c Mon Sep 17 00:00:00 2001 From: Mend Renovate Date: Tue, 10 Dec 2024 17:13:50 +0100 Subject: [PATCH 11/13] chore(deps): update dependency aiohttp to v3.11.10 (#1210) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 0ff0e41e7..75b53b946 100644 --- a/requirements.txt +++ b/requirements.txt @@ 
-1,5 +1,5 @@ aiofiles==24.1.0 -aiohttp==3.11.9 +aiohttp==3.11.10 cryptography==44.0.0 dnspython==2.7.0 Requests==2.32.3 From f2b8fe67c5da9930e2d6636e9e53f3db3f18e308 Mon Sep 17 00:00:00 2001 From: Mend Renovate Date: Tue, 10 Dec 2024 17:25:17 +0100 Subject: [PATCH 12/13] chore(deps): Update github/codeql-action action to v3.27.7 (#1207) --- .github/workflows/codeql.yml | 6 +++--- .github/workflows/scorecard.yml | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index 1116c9986..50e6d044e 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -46,16 +46,16 @@ jobs: # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL - uses: github/codeql-action/init@f09c1c0a94de965c15400f5634aa42fac8fb8f88 # v3.27.5 + uses: github/codeql-action/init@babb554ede22fd5605947329c4d04d8e7a0b8155 # v3.27.7 with: languages: ${{ matrix.language }} # Autobuild attempts to build any compiled languages (C/C++, C#, Go, or Java). # If this step fails, then you should remove it and run the build manually - name: Autobuild - uses: github/codeql-action/autobuild@f09c1c0a94de965c15400f5634aa42fac8fb8f88 # v3.27.5 + uses: github/codeql-action/autobuild@babb554ede22fd5605947329c4d04d8e7a0b8155 # v3.27.7 - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@f09c1c0a94de965c15400f5634aa42fac8fb8f88 # v3.27.5 + uses: github/codeql-action/analyze@babb554ede22fd5605947329c4d04d8e7a0b8155 # v3.27.7 with: category: "/language:${{matrix.language}}" diff --git a/.github/workflows/scorecard.yml b/.github/workflows/scorecard.yml index ee3efa120..653a2d6fc 100644 --- a/.github/workflows/scorecard.yml +++ b/.github/workflows/scorecard.yml @@ -65,6 +65,6 @@ jobs: # Upload the results to GitHub's code scanning dashboard. 
- name: "Upload to code-scanning" - uses: github/codeql-action/upload-sarif@f09c1c0a94de965c15400f5634aa42fac8fb8f88 # v3.27.5 + uses: github/codeql-action/upload-sarif@babb554ede22fd5605947329c4d04d8e7a0b8155 # v3.27.7 with: sarif_file: resultsFiltered.sarif From fe43047298529078dc1f5869fc98534b63f28e97 Mon Sep 17 00:00:00 2001 From: "release-please[bot]" <55107282+release-please[bot]@users.noreply.github.com> Date: Tue, 10 Dec 2024 11:26:29 -0500 Subject: [PATCH 13/13] chore(main): release 1.15.0 (#1208) Co-authored-by: release-please[bot] <55107282+release-please[bot]@users.noreply.github.com> --- CHANGELOG.md | 8 ++++++++ google/cloud/sql/connector/version.py | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2e97c4830..1467f5436 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,13 @@ # Changelog +## [1.15.0](https://github.com/GoogleCloudPlatform/cloud-sql-python-connector/compare/v1.14.0...v1.15.0) (2024-12-10) + + +### Features + +* add support for DNS names with Connector ([#1204](https://github.com/GoogleCloudPlatform/cloud-sql-python-connector/issues/1204)) ([1a8f274](https://github.com/GoogleCloudPlatform/cloud-sql-python-connector/commit/1a8f2743c661806528c54110fd2a08384398816c)) +* improve aiohttp client error messages ([#1201](https://github.com/GoogleCloudPlatform/cloud-sql-python-connector/issues/1201)) ([11f9fe9](https://github.com/GoogleCloudPlatform/cloud-sql-python-connector/commit/11f9fe93617cd9a161c69d527d3e7592fcbe56af)) + ## [1.14.0](https://github.com/GoogleCloudPlatform/cloud-sql-python-connector/compare/v1.13.0...v1.14.0) (2024-11-20) diff --git a/google/cloud/sql/connector/version.py b/google/cloud/sql/connector/version.py index 7f2e3d950..587a7adbf 100644 --- a/google/cloud/sql/connector/version.py +++ b/google/cloud/sql/connector/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "1.14.0" +__version__ = "1.15.0"