diff --git a/.ci/cloudbuild.yaml b/.ci/cloudbuild.yaml new file mode 100644 index 000000000..e31e8bebb --- /dev/null +++ b/.ci/cloudbuild.yaml @@ -0,0 +1,108 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +steps: + - id: run integration tests + name: python:${_VERSION} + entrypoint: bash + env: + - "IP_TYPE=${_IP_TYPE}" + secretEnv: + [ + "MYSQL_CONNECTION_NAME", + "MYSQL_USER", + "MYSQL_IAM_USER", + "MYSQL_PASS", + "MYSQL_DB", + "MYSQL_MCP_CONNECTION_NAME", + "MYSQL_MCP_PASS", + "POSTGRES_CONNECTION_NAME", + "POSTGRES_USER", + "POSTGRES_IAM_USER", + "POSTGRES_PASS", + "POSTGRES_DB", + "POSTGRES_CAS_CONNECTION_NAME", + "POSTGRES_CAS_PASS", + "POSTGRES_CUSTOMER_CAS_CONNECTION_NAME", + "POSTGRES_CUSTOMER_CAS_PASS", + "POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME", + "POSTGRES_MCP_CONNECTION_NAME", + "POSTGRES_MCP_PASS", + "SQLSERVER_CONNECTION_NAME", + "SQLSERVER_USER", + "SQLSERVER_PASS", + "SQLSERVER_DB", + ] + args: + - "-c" + - | + pip install nox + nox -s system-${_VERSION} +availableSecrets: + secretManager: + - versionName: "projects/$PROJECT_ID/secrets/MYSQL_CONNECTION_NAME/versions/latest" + env: "MYSQL_CONNECTION_NAME" + - versionName: "projects/$PROJECT_ID/secrets/MYSQL_USER/versions/latest" + env: "MYSQL_USER" + - versionName: "projects/$PROJECT_ID/secrets/CLOUD_BUILD_MYSQL_IAM_USER/versions/latest" + env: "MYSQL_IAM_USER" + - versionName: "projects/$PROJECT_ID/secrets/MYSQL_PASS/versions/latest" + env: "MYSQL_PASS" + - versionName: "projects/$PROJECT_ID/secrets/MYSQL_DB/versions/latest" + env: "MYSQL_DB" + - versionName: "projects/$PROJECT_ID/secrets/MYSQL_MCP_CONNECTION_NAME/versions/latest" + env: "MYSQL_MCP_CONNECTION_NAME" + - versionName: "projects/$PROJECT_ID/secrets/MYSQL_MCP_PASS/versions/latest" + env: "MYSQL_MCP_PASS" + - versionName: "projects/$PROJECT_ID/secrets/POSTGRES_CONNECTION_NAME/versions/latest" + env: "POSTGRES_CONNECTION_NAME" + - versionName: "projects/$PROJECT_ID/secrets/POSTGRES_USER/versions/latest" + env: "POSTGRES_USER" + - versionName: "projects/$PROJECT_ID/secrets/CLOUD_BUILD_POSTGRES_IAM_USER/versions/latest" + env: "POSTGRES_IAM_USER" + - versionName: "projects/$PROJECT_ID/secrets/POSTGRES_PASS/versions/latest" + env: "POSTGRES_PASS" + - versionName: "projects/$PROJECT_ID/secrets/POSTGRES_DB/versions/latest" + env: "POSTGRES_DB" + - versionName: "projects/$PROJECT_ID/secrets/POSTGRES_CAS_CONNECTION_NAME/versions/latest" + env: "POSTGRES_CAS_CONNECTION_NAME" + - versionName: "projects/$PROJECT_ID/secrets/POSTGRES_CAS_PASS/versions/latest" + env: "POSTGRES_CAS_PASS" + - versionName: "projects/$PROJECT_ID/secrets/POSTGRES_CUSTOMER_CAS_CONNECTION_NAME/versions/latest" + env: "POSTGRES_CUSTOMER_CAS_CONNECTION_NAME" + - versionName: "projects/$PROJECT_ID/secrets/POSTGRES_CUSTOMER_CAS_PASS/versions/latest" + env: "POSTGRES_CUSTOMER_CAS_PASS" + - versionName: "projects/$PROJECT_ID/secrets/POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME/versions/latest" + env: "POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME" + - versionName: 
"projects/$PROJECT_ID/secrets/POSTGRES_MCP_CONNECTION_NAME/versions/latest" + env: "POSTGRES_MCP_CONNECTION_NAME" + - versionName: "projects/$PROJECT_ID/secrets/POSTGRES_MCP_PASS/versions/latest" + env: "POSTGRES_MCP_PASS" + - versionName: "projects/$PROJECT_ID/secrets/SQLSERVER_CONNECTION_NAME/versions/latest" + env: "SQLSERVER_CONNECTION_NAME" + - versionName: "projects/$PROJECT_ID/secrets/SQLSERVER_USER/versions/latest" + env: "SQLSERVER_USER" + - versionName: "projects/$PROJECT_ID/secrets/SQLSERVER_PASS/versions/latest" + env: "SQLSERVER_PASS" + - versionName: "projects/$PROJECT_ID/secrets/SQLSERVER_DB/versions/latest" + env: "SQLSERVER_DB" +substitutions: + _VERSION: ${_VERSION} + _IP_TYPE: ${_IP_TYPE} + +options: + dynamicSubstitutions: true + pool: + name: ${_POOL_NAME} + logging: CLOUD_LOGGING_ONLY diff --git a/.github/blunderbuss.yml b/.github/blunderbuss.yml index ded14a19b..bc2a8baf0 100644 --- a/.github/blunderbuss.yml +++ b/.github/blunderbuss.yml @@ -14,6 +14,7 @@ assign_issues: - jackwotherspoon + - kgala2 assign_prs: diff --git a/.github/trusted-contribution.yml b/.github/trusted-contribution.yml new file mode 100644 index 000000000..18580d069 --- /dev/null +++ b/.github/trusted-contribution.yml @@ -0,0 +1,28 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Trigger presubmit tests for trusted contributors +# https://github.com/googleapis/repo-automation-bots/tree/main/packages/trusted-contribution +# Install: https://github.com/apps/trusted-contributions-gcf + +trustedContributors: + - "dependabot[bot]" + - "renovate-bot" + - "renovate[bot]" + - "forking-renovate[bot]" + - "release-please[bot]" +annotations: + # Trigger Cloud Build tests + - type: comment + text: "/gcbrun" diff --git a/.github/workflows/cloud_build_failure_reporter.yml b/.github/workflows/cloud_build_failure_reporter.yml new file mode 100644 index 000000000..a07e3a676 --- /dev/null +++ b/.github/workflows/cloud_build_failure_reporter.yml @@ -0,0 +1,180 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: Cloud Build Failure Reporter + +on: + workflow_call: + inputs: + trigger_names: + required: true + type: string + workflow_dispatch: + inputs: + trigger_names: + description: 'Cloud Build trigger names separated by comma.' 
+ required: true + default: '' + +jobs: + report: + + permissions: + issues: 'write' + checks: 'read' + contents: 'read' + + runs-on: 'ubuntu-latest' + + steps: + - uses: 'actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea' # v7 + with: + script: |- + // parse test names + const testNameSubstring = '${{ inputs.trigger_names }}'; + const testNameFound = new Map(); //keeps track of whether each test is found + testNameSubstring.split(',').forEach(testName => { + testNameFound.set(testName, false); + }); + + // label for all issues opened by reporter + const periodicLabel = 'periodic-failure'; + + // check if any reporter opened any issues previously + const prevIssues = await github.paginate(github.rest.issues.listForRepo, { + ...context.repo, + state: 'open', + creator: 'github-actions[bot]', + labels: [periodicLabel] + }); + + // createOrCommentIssue creates a new issue or comments on an existing issue. + const createOrCommentIssue = async function (title, txt) { + if (prevIssues.length < 1) { + console.log('no previous issues found, creating one'); + await github.rest.issues.create({ + ...context.repo, + title: title, + body: txt, + labels: [periodicLabel] + }); + return; + } + // only comment on issue related to the current test + for (const prevIssue of prevIssues) { + if (prevIssue.title.includes(title)){ + console.log( + `found previous issue ${prevIssue.html_url}, adding comment` + ); + + await github.rest.issues.createComment({ + ...context.repo, + issue_number: prevIssue.number, + body: txt + }); + return; + } + } + }; + + // updateIssues comments on any existing issues. No-op if no issue exists. + const updateIssues = async function (checkName, txt) { + if (prevIssues.length < 1) { + console.log('no previous issues found.'); + return; + } + // only comment on issue related to the current test + for (const prevIssue of prevIssues) { + if (prevIssue.title.includes(checkName)){ + console.log(`found previous issue ${prevIssue.html_url}, adding comment`); + await github.rest.issues.createComment({ + ...context.repo, + issue_number: prevIssue.number, + body: txt + }); + } + } + }; + + // Find status of check runs. + // We will find check runs for each commit and then filter for the periodic. + // Checks API only allows for ref and if we use main there could be edge cases where + // the check run happened on a SHA that is different from head. 
+ const commits = await github.paginate(github.rest.repos.listCommits, { + ...context.repo + }); + + const relevantChecks = new Map(); + for (const commit of commits) { + console.log( + `checking runs at ${commit.html_url}: ${commit.commit.message}` + ); + const checks = await github.rest.checks.listForRef({ + ...context.repo, + ref: commit.sha + }); + + // Iterate through each check and find matching names + for (const check of checks.data.check_runs) { + console.log(`Handling test name ${check.name}`); + for (const testName of testNameFound.keys()) { + if (testNameFound.get(testName) === true){ + //skip if a check is already found for this name + continue; + } + if (check.name.includes(testName)) { + relevantChecks.set(check, commit); + testNameFound.set(testName, true); + } + } + } + // Break out of the loop early if all tests are found + const allTestsFound = Array.from(testNameFound.values()).every(value => value === true); + if (allTestsFound){ + break; + } + } + + // Handle each relevant check + relevantChecks.forEach((commit, check) => { + if ( + check.status === 'completed' && + check.conclusion === 'success' + ) { + updateIssues( + check.name, + `[Tests are passing](${check.html_url}) for commit [${commit.sha}](${commit.html_url}).` + ); + } else if (check.status === 'in_progress') { + console.log( + `Check is pending ${check.html_url} for ${commit.html_url}. Retry again later.` + ); + } else { + createOrCommentIssue( + `Cloud Build Failure Reporter: ${check.name} failed`, + `Cloud Build Failure Reporter found test failure for [**${check.name}** ](${check.html_url}) at [${commit.sha}](${commit.html_url}). Please fix the error and then close the issue after the **${check.name}** test passes.` + ); + } + }); + + // no periodic checks found across all commits, report it + const noTestFound = Array.from(testNameFound.values()).every(value => value === false); + if (noTestFound){ + createOrCommentIssue( + 'Missing periodic tests: ${{ inputs.trigger_names }}', + `No periodic test is found for triggers: ${{ inputs.trigger_names }}. Last checked from ${ + commits[0].html_url + } to ${commits[commits.length - 1].html_url}.` + ); + } diff --git a/.github/workflows/schedule_reporter.yml b/.github/workflows/schedule_reporter.yml new file mode 100644 index 000000000..bad7e46c2 --- /dev/null +++ b/.github/workflows/schedule_reporter.yml @@ -0,0 +1,29 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +name: Schedule Reporter + +on: + schedule: + - cron: '0 6 * * *' # Runs at 6 AM every morning + +jobs: + run_reporter: + permissions: + issues: 'write' + checks: 'read' + contents: 'read' + uses: ./.github/workflows/cloud_build_failure_reporter.yml + with: + trigger_names: "py-continuous-test-on-merge,py-integration-test-nightly" diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index b8e6eb58d..8e5d0b002 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -72,6 +72,8 @@ jobs: MYSQL_IAM_USER:${{ vars.GOOGLE_CLOUD_PROJECT }}/MYSQL_USER_IAM_PYTHON MYSQL_PASS:${{ vars.GOOGLE_CLOUD_PROJECT }}/MYSQL_PASS MYSQL_DB:${{ vars.GOOGLE_CLOUD_PROJECT }}/MYSQL_DB + MYSQL_MCP_CONNECTION_NAME:${{ vars.GOOGLE_CLOUD_PROJECT }}/MYSQL_MCP_CONNECTION_NAME + MYSQL_MCP_PASS:${{ vars.GOOGLE_CLOUD_PROJECT }}/MYSQL_MCP_PASS POSTGRES_CONNECTION_NAME:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CONNECTION_NAME POSTGRES_USER:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_USER POSTGRES_IAM_USER:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_USER_IAM_PYTHON @@ -82,6 +84,8 @@ jobs: POSTGRES_CUSTOMER_CAS_CONNECTION_NAME:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CUSTOMER_CAS_CONNECTION_NAME POSTGRES_CUSTOMER_CAS_PASS:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CUSTOMER_CAS_PASS POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME + POSTGRES_MCP_CONNECTION_NAME:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_MCP_CONNECTION_NAME + POSTGRES_MCP_PASS:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_MCP_PASS SQLSERVER_CONNECTION_NAME:${{ vars.GOOGLE_CLOUD_PROJECT }}/SQLSERVER_CONNECTION_NAME SQLSERVER_USER:${{ vars.GOOGLE_CLOUD_PROJECT }}/SQLSERVER_USER SQLSERVER_PASS:${{ vars.GOOGLE_CLOUD_PROJECT }}/SQLSERVER_PASS @@ -94,6 +98,8 @@ jobs: MYSQL_IAM_USER: "${{ steps.secrets.outputs.MYSQL_IAM_USER }}" MYSQL_PASS: "${{ steps.secrets.outputs.MYSQL_PASS }}" MYSQL_DB: "${{ steps.secrets.outputs.MYSQL_DB }}" + MYSQL_MCP_CONNECTION_NAME: "${{ steps.secrets.outputs.MYSQL_MCP_CONNECTION_NAME }}" + MYSQL_MCP_PASS: "${{ steps.secrets.outputs.MYSQL_MCP_PASS }}" POSTGRES_CONNECTION_NAME: "${{ steps.secrets.outputs.POSTGRES_CONNECTION_NAME }}" POSTGRES_USER: "${{ steps.secrets.outputs.POSTGRES_USER }}" POSTGRES_IAM_USER: "${{ steps.secrets.outputs.POSTGRES_IAM_USER }}" @@ -104,6 +110,8 @@ jobs: POSTGRES_CUSTOMER_CAS_CONNECTION_NAME: "${{ steps.secrets.outputs.POSTGRES_CUSTOMER_CAS_CONNECTION_NAME }}" POSTGRES_CUSTOMER_CAS_PASS: "${{ steps.secrets.outputs.POSTGRES_CUSTOMER_CAS_PASS }}" POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME: "${{ steps.secrets.outputs.POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME }}" + POSTGRES_MCP_CONNECTION_NAME: "${{ steps.secrets.outputs.POSTGRES_MCP_CONNECTION_NAME }}" + POSTGRES_MCP_PASS: "${{ steps.secrets.outputs.POSTGRES_MCP_PASS }}" SQLSERVER_CONNECTION_NAME: "${{ steps.secrets.outputs.SQLSERVER_CONNECTION_NAME }}" SQLSERVER_USER: "${{ steps.secrets.outputs.SQLSERVER_USER }}" SQLSERVER_PASS: "${{ steps.secrets.outputs.SQLSERVER_PASS }}" diff --git a/CHANGELOG.md b/CHANGELOG.md index 47c853bca..823a77a01 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,17 @@ # Changelog +## [1.18.1](https://github.com/GoogleCloudPlatform/cloud-sql-python-connector/compare/v1.18.0...v1.18.1) (2025-04-16) + + +### Bug Fixes + +* bump dependencies to latest ([#1283](https://github.com/GoogleCloudPlatform/cloud-sql-python-connector/issues/1283)) 
([f29b639](https://github.com/GoogleCloudPlatform/cloud-sql-python-connector/commit/f29b6396f8eb4bed9070b3a67711fe6698ed0d51)) + + +### Documentation + +* use lambda over getconn func ([#1251](https://github.com/GoogleCloudPlatform/cloud-sql-python-connector/issues/1251)) ([6ecf894](https://github.com/GoogleCloudPlatform/cloud-sql-python-connector/commit/6ecf894759bd44c729a5a53f34f3f161923d1e33)) + ## [1.18.0](https://github.com/GoogleCloudPlatform/cloud-sql-python-connector/compare/v1.17.0...v1.18.0) (2025-03-21) diff --git a/README.md b/README.md index d79e706d3..1c5489e04 100644 --- a/README.md +++ b/README.md @@ -126,21 +126,16 @@ import sqlalchemy # initialize Connector object connector = Connector() -# function to return the database connection -def getconn() -> pymysql.connections.Connection: - conn: pymysql.connections.Connection = connector.connect( +# initialize SQLAlchemy connection pool with Connector +pool = sqlalchemy.create_engine( + "mysql+pymysql://", + creator=lambda: connector.connect( "project:region:instance", "pymysql", user="my-user", password="my-password", db="my-db-name" - ) - return conn - -# create connection pool -pool = sqlalchemy.create_engine( - "mysql+pymysql://", - creator=getconn, + ), ) ``` @@ -207,33 +202,21 @@ Connector as a context manager: ```python from google.cloud.sql.connector import Connector -import pymysql import sqlalchemy -# helper function to return SQLAlchemy connection pool -def init_connection_pool(connector: Connector) -> sqlalchemy.engine.Engine: - # function used to generate database connection - def getconn() -> pymysql.connections.Connection: - conn = connector.connect( +# initialize Cloud SQL Python Connector as context manager +with Connector() as connector: + # initialize SQLAlchemy connection pool with Connector + pool = sqlalchemy.create_engine( + "mysql+pymysql://", + creator=lambda: connector.connect( "project:region:instance", "pymysql", user="my-user", password="my-password", db="my-db-name" - ) - return conn - - # create connection pool - pool = sqlalchemy.create_engine( - "mysql+pymysql://", - creator=getconn, + ), ) - return pool - -# initialize Cloud SQL Python Connector as context manager -with Connector() as connector: - # initialize connection pool - pool = init_connection_pool(connector) # insert statement insert_stmt = sqlalchemy.text( "INSERT INTO my_table (id, title) VALUES (:id, :title)", @@ -401,30 +384,19 @@ from google.cloud.sql.connector import Connector, DnsResolver import pymysql import sqlalchemy -# helper function to return SQLAlchemy connection pool -def init_connection_pool(connector: Connector) -> sqlalchemy.engine.Engine: - # function used to generate database connection - def getconn() -> pymysql.connections.Connection: - conn = connector.connect( +# initialize Cloud SQL Python Connector with `resolver=DnsResolver` +with Connector(resolver=DnsResolver) as connector: + # initialize SQLAlchemy connection pool with Connector + pool = sqlalchemy.create_engine( + "mysql+pymysql://", + creator=lambda: connector.connect( "prod-db.mycompany.example.com", # using DNS name "pymysql", user="my-user", password="my-password", db="my-db-name" - ) - return conn - - # create connection pool - pool = sqlalchemy.create_engine( - "mysql+pymysql://", - creator=getconn, + ), ) - return pool - -# initialize Cloud SQL Python Connector with `resolver=DnsResolver` -with Connector(resolver=DnsResolver) as connector: - # initialize connection pool - pool = init_connection_pool(connector) # ... 
use SQLAlchemy engine normally ``` @@ -501,9 +473,12 @@ from google.cloud.sql.connector import Connector # initialize Python Connector object connector = Connector() -# Python Connector database connection function -def getconn(): - conn = connector.connect( +app = Flask(__name__) + +# configure Flask-SQLAlchemy to use Python Connector +app.config['SQLALCHEMY_DATABASE_URI'] = "postgresql+pg8000://" +app.config['SQLALCHEMY_ENGINE_OPTIONS'] = { + "creator": lambda: connector.connect( "project:region:instance-name", # Cloud SQL Instance Connection Name "pg8000", user="my-user", @@ -511,15 +486,6 @@ def getconn(): db="my-database", ip_type="public" # "private" for private IP ) - return conn - - -app = Flask(__name__) - -# configure Flask-SQLAlchemy to use Python Connector -app.config['SQLALCHEMY_DATABASE_URI'] = "postgresql+pg8000://" -app.config['SQLALCHEMY_ENGINE_OPTIONS'] = { - "creator": getconn } # initialize the app with the extension @@ -540,38 +506,27 @@ your web application using [SQLAlchemy ORM](https://docs.sqlalchemy.org/en/14/or through the following: ```python -from sqlalchemy import create_engine -from sqlalchemy.engine import Engine +import sqlalchemy from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker from google.cloud.sql.connector import Connector -# helper function to return SQLAlchemy connection pool -def init_connection_pool(connector: Connector) -> Engine: - # Python Connector database connection function - def getconn(): - conn = connector.connect( - "project:region:instance-name", # Cloud SQL Instance Connection Name - "pg8000", - user="my-user", - password="my-password", - db="my-database", - ip_type="public" # "private" for private IP - ) - return conn - - SQLALCHEMY_DATABASE_URL = "postgresql+pg8000://" - - engine = create_engine( - SQLALCHEMY_DATABASE_URL , creator=getconn - ) - return engine # initialize Cloud SQL Python Connector connector = Connector() # create connection pool engine -engine = init_connection_pool(connector) +engine = sqlalchemy.create_engine( + "postgresql+pg8000://", + creator=lambda: connector.connect( + "project:region:instance-name", # Cloud SQL Instance Connection Name + "pg8000", + user="my-user", + password="my-password", + db="my-database", + ip_type="public" # "private" for private IP + ), +) # create SQLAlchemy ORM session SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) @@ -640,40 +595,29 @@ async def main(): #### SQLAlchemy Async Engine ```python -import asyncpg - import sqlalchemy from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine from google.cloud.sql.connector import Connector, create_async_connector -async def init_connection_pool(connector: Connector) -> AsyncEngine: - # creation function to generate asyncpg connections as 'async_creator' arg - async def getconn() -> asyncpg.Connection: - conn: asyncpg.Connection = await connector.connect_async( + +async def main(): + # initialize Connector object for connections to Cloud SQL + connector = await create_async_connector() + + # The Cloud SQL Python Connector can be used along with SQLAlchemy using the + # 'async_creator' argument to 'create_async_engine' + pool = create_async_engine( + "postgresql+asyncpg://", + async_creator=lambda: connector.connect_async( "project:region:instance", # Cloud SQL instance connection name "asyncpg", user="my-user", password="my-password", db="my-db-name" # ...
additional database driver args - ) - return conn - - # The Cloud SQL Python Connector can be used along with SQLAlchemy using the - # 'async_creator' argument to 'create_async_engine' - pool = create_async_engine( - "postgresql+asyncpg://", - async_creator=getconn, + ), ) - return pool - - async def main(): - # initialize Connector object for connections to Cloud SQL - connector = await create_async_connector() - - # initialize connection pool - pool = await init_connection_pool(connector) # example query async with pool.connect() as conn: @@ -744,33 +688,24 @@ from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine from google.cloud.sql.connector import Connector -async def init_connection_pool(connector: Connector) -> AsyncEngine: - # creation function to generate asyncpg connections as 'async_creator' arg - async def getconn() -> asyncpg.Connection: - conn: asyncpg.Connection = await connector.connect_async( - "project:region:instance", # Cloud SQL instance connection name - "asyncpg", - user="my-user", - password="my-password", - db="my-db-name" - # ... additional database driver args - ) - return conn - - # The Cloud SQL Python Connector can be used along with SQLAlchemy using the - # 'async_creator' argument to 'create_async_engine' - pool = create_async_engine( - "postgresql+asyncpg://", - async_creator=getconn, - ) - return pool async def main(): # initialize Connector object for connections to Cloud SQL loop = asyncio.get_running_loop() async with Connector(loop=loop) as connector: - # initialize connection pool - pool = await init_connection_pool(connector) + # The Cloud SQL Python Connector can be used along with SQLAlchemy using the + # 'async_creator' argument to 'create_async_engine' + pool = create_async_engine( + "postgresql+asyncpg://", + async_creator=lambda: connector.connect_async( + "project:region:instance", # Cloud SQL instance connection name + "asyncpg", + user="my-user", + password="my-password", + db="my-db-name" + # ... additional database driver args + ), + ) # example query async with pool.connect() as conn: diff --git a/google/cloud/sql/connector/asyncpg.py b/google/cloud/sql/connector/asyncpg.py index 47d34895f..2fbc30273 100644 --- a/google/cloud/sql/connector/asyncpg.py +++ b/google/cloud/sql/connector/asyncpg.py @@ -28,21 +28,22 @@ async def connect( ) -> "asyncpg.Connection": """Helper function to create an asyncpg DB-API connection object. - :type ip_address: str - :param ip_address: A string containing an IP address for the Cloud SQL - instance. + Args: + ip_address (str): A string containing an IP address for the Cloud SQL + instance. + ctx (ssl.SSLContext): An SSLContext object created from the Cloud SQL + server CA cert and ephemeral cert. + kwargs: Keyword arguments for establishing asyncpg connection + object to Cloud SQL instance. - :type ctx: ssl.SSLContext - :param ctx: An SSLContext object created from the Cloud SQL server CA - cert and ephemeral cert. - - :type kwargs: Any - :param kwargs: Keyword arguments for establishing asyncpg connection - object to Cloud SQL instance. - - :rtype: asyncpg.Connection - :returns: An asyncpg.Connection object to a Cloud SQL instance. + Returns: + asyncpg.Connection: An asyncpg connection to the Cloud SQL + instance. + + Raises: + ImportError: The asyncpg module cannot be imported.
""" + try: import asyncpg except ImportError: diff --git a/google/cloud/sql/connector/client.py b/google/cloud/sql/connector/client.py index 556a01bde..11508ce17 100644 --- a/google/cloud/sql/connector/client.py +++ b/google/cloud/sql/connector/client.py @@ -59,8 +59,7 @@ def __init__( driver: Optional[str] = None, user_agent: Optional[str] = None, ) -> None: - """ - Establish the client to be used for Cloud SQL Admin API requests. + """Establishes the client to be used for Cloud SQL Admin API requests. Args: sqladmin_api_endpoint (str): Base URL to use when calling @@ -100,24 +99,22 @@ async def _get_metadata( region: str, instance: str, ) -> dict[str, Any]: - """Requests metadata from the Cloud SQL Instance - and returns a dictionary containing the IP addresses and certificate - authority of the Cloud SQL Instance. - - :type project: str - :param project: - A string representing the name of the project. + """Requests metadata from the Cloud SQL Instance and returns a dictionary + containing the IP addresses and certificate authority of the Cloud SQL + Instance. - :type region: str - :param region : A string representing the name of the region. + Args: + project (str): A string representing the name of the project. + region (str): A string representing the name of the region. + instance (str): A string representing the name of the instance. - :type instance: str - :param instance: A string representing the name of the instance. + Returns: + A dictionary containing a dictionary of all IP addresses + and their type and a string representing the certificate authority. - :rtype: dict[str: Union[dict, str]] - :returns: Returns a dictionary containing a dictionary of all IP - addresses and their type and a string representing the - certificate authority. + Raises: + ValueError: Provided region does not match the region of the + Cloud SQL instance. """ headers = { @@ -189,23 +186,17 @@ async def _get_ephemeral( ) -> tuple[str, datetime.datetime]: """Asynchronously requests an ephemeral certificate from the Cloud SQL Instance. - :type project: str - :param project : A string representing the name of the project. - - :type instance: str - :param instance: A string representing the name of the instance. - - :type pub_key: - :param str: A string representing PEM-encoded RSA public key. - - :type enable_iam_auth: bool - :param enable_iam_auth - Enables automatic IAM database authentication for Postgres or MySQL - instances. + Args: + project (str): A string representing the name of the project. + instance (str): string representing the name of the instance. + pub_key (str): A string representing PEM-encoded RSA public key. + enable_iam_auth (bool): Enables automatic IAM database + authentication for Postgres or MySQL instances. - :rtype: str - :returns: An ephemeral certificate from the Cloud SQL instance that allows - authorized connections to the instance. + Returns: + A tuple containing an ephemeral certificate from + the Cloud SQL instance as well as a datetime object + representing the expiration time of the certificate. """ headers = { "Authorization": f"Bearer {self._credentials.token}", diff --git a/google/cloud/sql/connector/pg8000.py b/google/cloud/sql/connector/pg8000.py index baaee6615..5a43ad319 100644 --- a/google/cloud/sql/connector/pg8000.py +++ b/google/cloud/sql/connector/pg8000.py @@ -26,17 +26,19 @@ def connect( ) -> "pg8000.dbapi.Connection": """Helper function to create a pg8000 DB-API connection object. 
- :type ip_address: str - :param ip_address: A string containing an IP address for the Cloud SQL - instance. - - :type sock: ssl.SSLSocket - :param sock: An SSLSocket object created from the Cloud SQL server CA - cert and ephemeral cert. - - - :rtype: pg8000.dbapi.Connection - :returns: A pg8000 Connection object for the Cloud SQL instance. + Args: + ip_address (str): A string containing an IP address for the Cloud SQL + instance. + sock (ssl.SSLSocket): An SSLSocket object created from the Cloud SQL + server CA cert and ephemeral cert. + kwargs: Additional arguments to pass to the pg8000 connect method. + + Returns: + pg8000.dbapi.Connection: A pg8000 connection to the Cloud SQL + instance. + + Raises: + ImportError: The pg8000 module cannot be imported. """ try: import pg8000 diff --git a/google/cloud/sql/connector/pymysql.py b/google/cloud/sql/connector/pymysql.py index f83f7076c..e01cfed08 100644 --- a/google/cloud/sql/connector/pymysql.py +++ b/google/cloud/sql/connector/pymysql.py @@ -26,16 +26,18 @@ def connect( ) -> "pymysql.connections.Connection": """Helper function to create a pymysql DB-API connection object. - :type ip_address: str - :param ip_address: A string containing an IP address for the Cloud SQL - instance. - - :type sock: ssl.SSLSocket - :param sock: An SSLSocket object created from the Cloud SQL server CA - cert and ephemeral cert. - - :rtype: pymysql.Connection - :returns: A PyMySQL Connection object for the Cloud SQL instance. + Args: + ip_address (str): A string containing an IP address for the Cloud SQL + instance. + sock (ssl.SSLSocket): An SSLSocket object created from the Cloud SQL + server CA cert and ephemeral cert. + + Returns: + pymysql.connections.Connection: A pymysql connection to the Cloud SQL + instance. + + Raises: + ImportError: The pymysql module cannot be imported. """ try: import pymysql diff --git a/google/cloud/sql/connector/pytds.py b/google/cloud/sql/connector/pytds.py index 3128fdb6a..6cc3c0934 100644 --- a/google/cloud/sql/connector/pytds.py +++ b/google/cloud/sql/connector/pytds.py @@ -27,17 +27,18 @@ def connect(ip_address: str, sock: ssl.SSLSocket, **kwargs: Any) -> "pytds.Connection": """Helper function to create a pytds DB-API connection object. - :type ip_address: str - :param ip_address: A string containing an IP address for the Cloud SQL - instance. + Args: + ip_address (str): A string containing an IP address for the Cloud SQL + instance. + sock (ssl.SSLSocket): An SSLSocket object created from the Cloud SQL + server CA cert and ephemeral cert. - :type sock: ssl.SSLSocket - :param sock: An SSLSocket object created from the Cloud SQL server CA - cert and ephemeral cert. + Returns: + pytds.Connection: A pytds connection to the Cloud SQL + instance. - - :rtype: pytds.Connection - :returns: A pytds Connection object for the Cloud SQL instance. + Raises: + ImportError: The pytds module cannot be imported. """ try: import pytds diff --git a/google/cloud/sql/connector/rate_limiter.py b/google/cloud/sql/connector/rate_limiter.py index 38e7b94bc..be9c68c64 100644 --- a/google/cloud/sql/connector/rate_limiter.py +++ b/google/cloud/sql/connector/rate_limiter.py @@ -19,24 +19,16 @@ class AsyncRateLimiter(object): - """ - An asyncio-compatible rate limiter which uses the Token Bucket algorithm - (https://en.wikipedia.org/wiki/Token_bucket) to limit the number of function calls over a time interval using an event queue. - - :type max_capacity: int - :param: max_capacity: - The maximum capacity of tokens the bucket will store at any one time. 
- Default: 1 - - :type rate: float - :param: rate: - The number of tokens that should be added per second. - - :type loop: asyncio.AbstractEventLoop - :param: loop: - The event loop to use. If not provided, the default event loop will be used. - + """An asyncio-compatible rate limiter which uses the Token Bucket algorithm + (https://en.wikipedia.org/wiki/Token_bucket) to limit the number + of function calls over a time interval using an event queue. + Args: + max_capacity (int): The maximum capacity of tokens the bucket + will store at any one time. Default: 1 + rate (float): The number of tokens that should be added per second. + loop (asyncio.AbstractEventLoop): The event loop to use. + If not provided, the default event loop will be used. """ def __init__( diff --git a/google/cloud/sql/connector/refresh_utils.py b/google/cloud/sql/connector/refresh_utils.py index 173f0d2ee..a90d40536 100644 --- a/google/cloud/sql/connector/refresh_utils.py +++ b/google/cloud/sql/connector/refresh_utils.py @@ -44,8 +44,11 @@ def _seconds_until_refresh( Usually the duration will be half of the time until certificate expiration. - :rtype: int - :returns: Time in seconds to wait before performing next refresh. + Args: + expiration (datetime.datetime): The expiration time of the certificate. + + Returns: + int: Time in seconds to wait before performing next refresh. """ duration = int( @@ -81,16 +84,14 @@ def _downscope_credentials( ) -> Credentials: """Generate a down-scoped credential. - :type credentials: google.auth.credentials.Credentials - :param credentials - Credentials object used to generate down-scoped credentials. - - :type scopes: list[str] - :param scopes - List of Google scopes to include in down-scoped credentials object. + Args: + credentials (google.auth.credentials.Credentials): + Credentials object used to generate down-scoped credentials. + scopes (list[str]): List of Google scopes to + include in down-scoped credentials object. - :rtype: google.auth.credentials.Credentials - :returns: Down-scoped credentials object. + Returns: + google.auth.credentials.Credentials: Down-scoped credentials object. """ # credentials sourced from a service account or metadata are children of # Scoped class and are capable of being re-scoped diff --git a/google/cloud/sql/connector/utils.py b/google/cloud/sql/connector/utils.py index 8caa73af6..dd0aec344 100755 --- a/google/cloud/sql/connector/utils.py +++ b/google/cloud/sql/connector/utils.py @@ -79,16 +79,14 @@ async def write_to_file( def format_database_user(database_version: str, user: str) -> str: - """ - Format database `user` param for Cloud SQL automatic IAM authentication. + """Format database `user` param for Cloud SQL automatic IAM authentication. - :type database_version: str - :param database_version - Cloud SQL database version. (i.e. POSTGRES_14, MYSQL8_0, etc.) + Args: + database_version (str): Cloud SQL database version. + user (str): Database username to connect to Cloud SQL database with. - :type user: str - :param user - Database username to connect to Cloud SQL database with. + Returns: + str: Formatted database username. 
""" # remove suffix for Postgres service accounts if database_version.startswith("POSTGRES"): diff --git a/google/cloud/sql/connector/version.py b/google/cloud/sql/connector/version.py index f89ebde3c..edbff2e69 100644 --- a/google/cloud/sql/connector/version.py +++ b/google/cloud/sql/connector/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "1.18.0" +__version__ = "1.18.1" diff --git a/pyproject.toml b/pyproject.toml index dec2ff489..8a694369b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,8 @@ build-backend = "setuptools.build_meta" description = "Google Cloud SQL Python Connector library" name = "cloud-sql-python-connector" authors = [{ name = "Google LLC", email = "googleapis-packages@google.com" }] -license = { text = "Apache 2.0" } +license = "Apache-2.0" +license-files = ["LICENSE"] requires-python = ">=3.9" readme = "README.md" classifiers = [ @@ -30,7 +31,6 @@ classifiers = [ # "Development Status :: 5 - Production/Stable" "Development Status :: 5 - Production/Stable", "Intended Audience :: Developers", - "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.9", diff --git a/tests/conftest.py b/tests/conftest.py index c75de48cb..83d7a78f3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,11 +17,15 @@ import asyncio import os import socket +import ssl from threading import Thread -from typing import Any, AsyncGenerator, Generator +from typing import Any, AsyncGenerator +from aiofiles.tempfile import TemporaryDirectory from aiohttp import web +from cryptography.hazmat.primitives import serialization import pytest # noqa F401 Needed to run the tests +from unit.mocks import create_ssl_context # type: ignore from unit.mocks import FakeCredentials # type: ignore from unit.mocks import FakeCSQLInstance # type: ignore @@ -29,6 +33,7 @@ from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.instance import RefreshAheadCache from google.cloud.sql.connector.utils import generate_keys +from google.cloud.sql.connector.utils import write_to_file SCOPES = ["https://www.googleapis.com/auth/sqlservice.admin"] @@ -79,25 +84,60 @@ def fake_credentials() -> FakeCredentials: return FakeCredentials() -def mock_server(server_sock: socket.socket) -> None: - """Create mock server listening on specified ip_address and port.""" +async def start_proxy_server(instance: FakeCSQLInstance) -> None: + """Run local proxy server capable of performing mTLS""" ip_address = "127.0.0.1" port = 3307 - server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - server_sock.bind((ip_address, port)) - server_sock.listen(0) - server_sock.accept() + # create socket + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + # create SSL/TLS context + context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + context.minimum_version = ssl.TLSVersion.TLSv1_3 + # tmpdir and its contents are automatically deleted after the CA cert + # and cert chain are loaded into the SSLcontext. 
The values + # need to be written to files in order to be loaded by the SSLContext + server_key_bytes = instance.server_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + async with TemporaryDirectory() as tmpdir: + server_filename, _, key_filename = await write_to_file( + tmpdir, instance.server_cert_pem, "", server_key_bytes + ) + context.load_cert_chain(server_filename, key_filename) + # allow socket to be re-used + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + # bind socket to Cloud SQL proxy server port on localhost + sock.bind((ip_address, port)) + # listen for incoming connections + sock.listen(5) + + with context.wrap_socket(sock, server_side=True) as ssock: + while True: + conn, _ = ssock.accept() + conn.close() + + +@pytest.fixture(scope="session") +def proxy_server(fake_instance: FakeCSQLInstance) -> None: + """Run local proxy server capable of performing mTLS""" + thread = Thread( + target=asyncio.run, + args=( + start_proxy_server( + fake_instance, + ), + ), + daemon=True, + ) + thread.start() + thread.join(1.0) # add a delay to allow the proxy server to start @pytest.fixture -def server() -> Generator: - """Create thread with server listening on proper port""" - server_sock = socket.socket() - thread = Thread(target=mock_server, args=(server_sock,), daemon=True) - thread.start() - yield thread - server_sock.close() - thread.join() +async def context(fake_instance: FakeCSQLInstance) -> ssl.SSLContext: + return await create_ssl_context(fake_instance) @pytest.fixture @@ -107,7 +147,7 @@ def kwargs() -> Any: return kwargs -@pytest.fixture +@pytest.fixture(scope="session") def fake_instance() -> FakeCSQLInstance: return FakeCSQLInstance() diff --git a/tests/system/test_asyncpg_connection.py b/tests/system/test_asyncpg_connection.py index dfcc3941b..1aae1f7e6 100644 --- a/tests/system/test_asyncpg_connection.py +++ b/tests/system/test_asyncpg_connection.py @@ -32,8 +32,10 @@ async def create_sqlalchemy_engine( user: str, password: str, db: str, + ip_type: str = "public", refresh_strategy: str = "background", resolver: Union[type[DefaultResolver], type[DnsResolver]] = DefaultResolver, + **kwargs: Any, ) -> tuple[sqlalchemy.ext.asyncio.engine.AsyncEngine, Connector]: """Creates a connection pool for a Cloud SQL instance and returns the pool and the connector. Callers are responsible for closing the pool and the @@ -63,6 +65,9 @@ async def create_sqlalchemy_engine( The database user's password, e.g., secret-password db (str): The name of the database, e.g., mydb + ip_type (str): + The IP type of the Cloud SQL instance to connect to. Can be one + of "public", "private", or "psc". refresh_strategy (Optional[str]): Refresh strategy for the Cloud SQL Connector. Can be one of "lazy" or "background". 
For serverless environments use "lazy" to avoid @@ -78,21 +83,18 @@ async def create_sqlalchemy_engine( loop=loop, refresh_strategy=refresh_strategy, resolver=resolver ) - async def getconn() -> asyncpg.Connection: - conn: asyncpg.Connection = await connector.connect_async( + # create SQLAlchemy connection pool + engine = sqlalchemy.ext.asyncio.create_async_engine( + "postgresql+asyncpg://", + async_creator=lambda: connector.connect_async( instance_connection_name, "asyncpg", user=user, password=password, db=db, - ip_type="public", # can also be "private" or "psc" - ) - return conn - - # create SQLAlchemy connection pool - engine = sqlalchemy.ext.asyncio.create_async_engine( - "postgresql+asyncpg://", - async_creator=getconn, + ip_type=ip_type, # can be "public", "private" or "psc" + **kwargs, # additional asyncpg connection args + ), execution_options={"isolation_level": "AUTOCOMMIT"}, ) return engine, connector @@ -103,6 +105,7 @@ async def create_asyncpg_pool( user: str, password: str, db: str, + ip_type: str = "public", refresh_strategy: str = "background", ) -> tuple[asyncpg.Pool, Connector]: """Creates a native asyncpg connection pool for a Cloud SQL instance and @@ -132,6 +135,9 @@ async def create_asyncpg_pool( The database user's password, e.g., secret-password db (str): The name of the database, e.g., mydb + ip_type (str): + The IP type of the Cloud SQL instance to connect to. Can be one + of "public", "private", or "psc". refresh_strategy (Optional[str]): Refresh strategy for the Cloud SQL Connector. Can be one of "lazy" or "background". For serverless environments use "lazy" to avoid @@ -149,7 +155,7 @@ async def getconn( user=user, password=password, db=db, - ip_type="public", # can also be "private" or "psc", + ip_type=ip_type, # can be "public", "private" or "psc" **kwargs, ) return conn @@ -165,8 +171,11 @@ async def test_sqlalchemy_connection_with_asyncpg() -> None: user = os.environ["POSTGRES_USER"] password = os.environ["POSTGRES_PASS"] db = os.environ["POSTGRES_DB"] + ip_type = os.environ.get("IP_TYPE", "public") - pool, connector = await create_sqlalchemy_engine(inst_conn_name, user, password, db) + pool, connector = await create_sqlalchemy_engine( + inst_conn_name, user, password, db, ip_type + ) async with pool.connect() as conn: res = (await conn.execute(sqlalchemy.text("SELECT 1"))).fetchone() @@ -181,9 +190,10 @@ async def test_lazy_sqlalchemy_connection_with_asyncpg() -> None: user = os.environ["POSTGRES_USER"] password = os.environ["POSTGRES_PASS"] db = os.environ["POSTGRES_DB"] + ip_type = os.environ.get("IP_TYPE", "public") pool, connector = await create_sqlalchemy_engine( - inst_conn_name, user, password, db, "lazy" + inst_conn_name, user, password, db, ip_type, "lazy" ) async with pool.connect() as conn: @@ -199,9 +209,34 @@ async def test_custom_SAN_with_dns_sqlalchemy_connection_with_asyncpg() -> None: user = os.environ["POSTGRES_USER"] password = os.environ["POSTGRES_CUSTOMER_CAS_PASS"] db = os.environ["POSTGRES_DB"] + ip_type = os.environ.get("IP_TYPE", "public") pool, connector = await create_sqlalchemy_engine( - inst_conn_name, user, password, db, resolver=DnsResolver + inst_conn_name, user, password, db, ip_type, resolver=DnsResolver + ) + + async with pool.connect() as conn: + res = (await conn.execute(sqlalchemy.text("SELECT 1"))).fetchone() + assert res[0] == 1 + + await connector.close_async() + + +async def test_MCP_sqlalchemy_connection_with_asyncpg() -> None: + """Basic test to get time from database using MCP enabled instance.""" + 
inst_conn_name = os.environ["POSTGRES_MCP_CONNECTION_NAME"] + user = os.environ["POSTGRES_USER"] + password = os.environ["POSTGRES_MCP_PASS"] + db = os.environ["POSTGRES_DB"] + ip_type = os.environ.get("IP_TYPE", "public") + + pool, connector = await create_sqlalchemy_engine( + inst_conn_name, + user, + password, + db, + ip_type, + statement_cache_size=0, ) async with pool.connect() as conn: @@ -217,8 +252,11 @@ async def test_connection_with_asyncpg() -> None: user = os.environ["POSTGRES_USER"] password = os.environ["POSTGRES_PASS"] db = os.environ["POSTGRES_DB"] + ip_type = os.environ.get("IP_TYPE", "public") - pool, connector = await create_asyncpg_pool(inst_conn_name, user, password, db) + pool, connector = await create_asyncpg_pool( + inst_conn_name, user, password, db, ip_type + ) async with pool.acquire() as conn: res = await conn.fetch("SELECT 1") @@ -233,9 +271,10 @@ async def test_lazy_connection_with_asyncpg() -> None: user = os.environ["POSTGRES_USER"] password = os.environ["POSTGRES_PASS"] db = os.environ["POSTGRES_DB"] + ip_type = os.environ.get("IP_TYPE", "public") pool, connector = await create_asyncpg_pool( - inst_conn_name, user, password, db, "lazy" + inst_conn_name, user, password, db, ip_type, "lazy" ) async with pool.acquire() as conn: diff --git a/tests/system/test_asyncpg_iam_auth.py b/tests/system/test_asyncpg_iam_auth.py index 6e96d96bd..0e0c01e83 100644 --- a/tests/system/test_asyncpg_iam_auth.py +++ b/tests/system/test_asyncpg_iam_auth.py @@ -17,7 +17,6 @@ import asyncio import os -import asyncpg import sqlalchemy import sqlalchemy.ext.asyncio @@ -28,6 +27,7 @@ async def create_sqlalchemy_engine( instance_connection_name: str, user: str, db: str, + ip_type: str = "public", refresh_strategy: str = "background", ) -> tuple[sqlalchemy.ext.asyncio.engine.AsyncEngine, Connector]: """Creates a connection pool for a Cloud SQL instance and returns the pool @@ -56,6 +56,9 @@ async def create_sqlalchemy_engine( e.g., my-email@test.com, service-account@project-id.iam db (str): The name of the database, e.g., mydb + ip_type (str): + The IP type of the Cloud SQL instance to connect to. Can be one + of "public", "private", or "psc". refresh_strategy (Optional[str]): Refresh strategy for the Cloud SQL Connector. Can be one of "lazy" or "background". 
For serverless environments use "lazy" to avoid @@ -64,21 +67,17 @@ async def create_sqlalchemy_engine( loop = asyncio.get_running_loop() connector = Connector(loop=loop, refresh_strategy=refresh_strategy) - async def getconn() -> asyncpg.Connection: - conn: asyncpg.Connection = await connector.connect_async( + # create SQLAlchemy connection pool + engine = sqlalchemy.ext.asyncio.create_async_engine( + "postgresql+asyncpg://", + async_creator=lambda: connector.connect_async( instance_connection_name, "asyncpg", user=user, db=db, - ip_type="public", # can also be "private" or "psc" + ip_type=ip_type, # can be "public", "private" or "psc" enable_iam_auth=True, - ) - return conn - - # create SQLAlchemy connection pool - engine = sqlalchemy.ext.asyncio.create_async_engine( - "postgresql+asyncpg://", - async_creator=getconn, + ), execution_options={"isolation_level": "AUTOCOMMIT"}, ) return engine, connector @@ -89,8 +88,9 @@ async def test_iam_authn_connection_with_asyncpg() -> None: inst_conn_name = os.environ["POSTGRES_CONNECTION_NAME"] user = os.environ["POSTGRES_IAM_USER"] db = os.environ["POSTGRES_DB"] + ip_type = os.environ.get("IP_TYPE", "public") - pool, connector = await create_sqlalchemy_engine(inst_conn_name, user, db) + pool, connector = await create_sqlalchemy_engine(inst_conn_name, user, db, ip_type) async with pool.connect() as conn: res = (await conn.execute(sqlalchemy.text("SELECT 1"))).fetchone() @@ -104,8 +104,11 @@ async def test_lazy_iam_authn_connection_with_asyncpg() -> None: inst_conn_name = os.environ["POSTGRES_CONNECTION_NAME"] user = os.environ["POSTGRES_IAM_USER"] db = os.environ["POSTGRES_DB"] + ip_type = os.environ.get("IP_TYPE", "public") - pool, connector = await create_sqlalchemy_engine(inst_conn_name, user, db, "lazy") + pool, connector = await create_sqlalchemy_engine( + inst_conn_name, user, db, ip_type, "lazy" + ) async with pool.connect() as conn: res = (await conn.execute(sqlalchemy.text("SELECT 1"))).fetchone() diff --git a/tests/system/test_connector_object.py b/tests/system/test_connector_object.py index 258b80aaf..3dd22fda5 100644 --- a/tests/system/test_connector_object.py +++ b/tests/system/test_connector_object.py @@ -20,7 +20,6 @@ import os from threading import Thread -import google.auth import pymysql import pytest import sqlalchemy @@ -39,6 +38,7 @@ def getconn() -> pymysql.connections.Connection: user=os.environ["MYSQL_USER"], password=os.environ["MYSQL_PASS"], db=os.environ["MYSQL_DB"], + ip_type=os.environ.get("IP_TYPE", "public"), ) return conn @@ -50,20 +50,6 @@ def getconn() -> pymysql.connections.Connection: return pool -def test_connector_with_credentials() -> None: - """Test Connector object connection with credentials loaded from file.""" - credentials, _ = google.auth.load_credentials_from_file( - os.environ["GOOGLE_APPLICATION_CREDENTIALS"] - ) - with Connector(credentials=credentials) as connector: - pool = init_connection_engine(connector) - - with pool.connect() as conn: - result = conn.execute(sqlalchemy.text("SELECT 1")).fetchone() - assert isinstance(result[0], int) - assert result[0] == 1 - - def test_multiple_connectors() -> None: """Test that same Cloud SQL instance can connect with two Connector objects.""" first_connector = Connector() diff --git a/tests/system/test_pg8000_connection.py b/tests/system/test_pg8000_connection.py index c47b860c9..c7074b0cf 100644 --- a/tests/system/test_pg8000_connection.py +++ b/tests/system/test_pg8000_connection.py @@ -20,7 +20,6 @@ # [START cloud_sql_connector_postgres_pg8000] from 
typing import Union -import pg8000 import sqlalchemy from google.cloud.sql.connector import Connector @@ -33,6 +32,7 @@ def create_sqlalchemy_engine( user: str, password: str, db: str, + ip_type: str = "public", refresh_strategy: str = "background", resolver: Union[type[DefaultResolver], type[DnsResolver]] = DefaultResolver, ) -> tuple[sqlalchemy.engine.Engine, Connector]: @@ -65,6 +65,9 @@ def create_sqlalchemy_engine( The database user's password, e.g., secret-password db (str): The name of the database, e.g., mydb + ip_type (str): + The IP type of the Cloud SQL instance to connect to. Can be one + of "public", "private", or "psc". refresh_strategy (Optional[str]): Refresh strategy for the Cloud SQL Connector. Can be one of "lazy" or "background". For serverless environments use "lazy" to avoid @@ -77,21 +80,17 @@ def create_sqlalchemy_engine( """ connector = Connector(refresh_strategy=refresh_strategy, resolver=resolver) - def getconn() -> pg8000.dbapi.Connection: - conn: pg8000.dbapi.Connection = connector.connect( + # create SQLAlchemy connection pool + engine = sqlalchemy.create_engine( + "postgresql+pg8000://", + creator=lambda: connector.connect( instance_connection_name, "pg8000", user=user, password=password, db=db, - ip_type="public", # can also be "private" or "psc" - ) - return conn - - # create SQLAlchemy connection pool - engine = sqlalchemy.create_engine( - "postgresql+pg8000://", - creator=getconn, + ip_type=ip_type, # can be "public", "private" or "psc" + ), ) return engine, connector @@ -105,8 +104,11 @@ def test_pg8000_connection() -> None: user = os.environ["POSTGRES_USER"] password = os.environ["POSTGRES_PASS"] db = os.environ["POSTGRES_DB"] + ip_type = os.environ.get("IP_TYPE", "public") - engine, connector = create_sqlalchemy_engine(inst_conn_name, user, password, db) + engine, connector = create_sqlalchemy_engine( + inst_conn_name, user, password, db, ip_type + ) with engine.connect() as conn: time = conn.execute(sqlalchemy.text("SELECT NOW()")).fetchone() conn.commit() @@ -121,9 +123,10 @@ def test_lazy_pg8000_connection() -> None: user = os.environ["POSTGRES_USER"] password = os.environ["POSTGRES_PASS"] db = os.environ["POSTGRES_DB"] + ip_type = os.environ.get("IP_TYPE", "public") engine, connector = create_sqlalchemy_engine( - inst_conn_name, user, password, db, "lazy" + inst_conn_name, user, password, db, ip_type, "lazy" ) with engine.connect() as conn: time = conn.execute(sqlalchemy.text("SELECT NOW()")).fetchone() @@ -139,8 +142,11 @@ def test_CAS_pg8000_connection() -> None: user = os.environ["POSTGRES_USER"] password = os.environ["POSTGRES_CAS_PASS"] db = os.environ["POSTGRES_DB"] + ip_type = os.environ.get("IP_TYPE", "public") - engine, connector = create_sqlalchemy_engine(inst_conn_name, user, password, db) + engine, connector = create_sqlalchemy_engine( + inst_conn_name, user, password, db, ip_type + ) with engine.connect() as conn: time = conn.execute(sqlalchemy.text("SELECT NOW()")).fetchone() conn.commit() @@ -155,8 +161,11 @@ def test_customer_managed_CAS_pg8000_connection() -> None: user = os.environ["POSTGRES_USER"] password = os.environ["POSTGRES_CUSTOMER_CAS_PASS"] db = os.environ["POSTGRES_DB"] + ip_type = os.environ.get("IP_TYPE", "public") - engine, connector = create_sqlalchemy_engine(inst_conn_name, user, password, db) + engine, connector = create_sqlalchemy_engine( + inst_conn_name, user, password, db, ip_type + ) with engine.connect() as conn: time = conn.execute(sqlalchemy.text("SELECT NOW()")).fetchone() conn.commit() @@ -171,9 +180,29 
@@ def test_custom_SAN_with_dns_pg8000_connection() -> None: user = os.environ["POSTGRES_USER"] password = os.environ["POSTGRES_CUSTOMER_CAS_PASS"] db = os.environ["POSTGRES_DB"] + ip_type = os.environ.get("IP_TYPE", "public") + + engine, connector = create_sqlalchemy_engine( + inst_conn_name, user, password, db, ip_type, resolver=DnsResolver + ) + with engine.connect() as conn: + time = conn.execute(sqlalchemy.text("SELECT NOW()")).fetchone() + conn.commit() + curr_time = time[0] + assert type(curr_time) is datetime + connector.close() + + +def test_MCP_pg8000_connection() -> None: + """Basic test to get time from database using MCP enabled instance.""" + inst_conn_name = os.environ["POSTGRES_MCP_CONNECTION_NAME"] + user = os.environ["POSTGRES_USER"] + password = os.environ["POSTGRES_MCP_PASS"] + db = os.environ["POSTGRES_DB"] + ip_type = os.environ.get("IP_TYPE", "public") engine, connector = create_sqlalchemy_engine( - inst_conn_name, user, password, db, resolver=DnsResolver + inst_conn_name, user, password, db, ip_type ) with engine.connect() as conn: time = conn.execute(sqlalchemy.text("SELECT NOW()")).fetchone() diff --git a/tests/system/test_pg8000_iam_auth.py b/tests/system/test_pg8000_iam_auth.py index 9a8607bcb..38ee76612 100644 --- a/tests/system/test_pg8000_iam_auth.py +++ b/tests/system/test_pg8000_iam_auth.py @@ -17,7 +17,6 @@ from datetime import datetime import os -import pg8000 import sqlalchemy from google.cloud.sql.connector import Connector @@ -27,6 +26,7 @@ def create_sqlalchemy_engine( instance_connection_name: str, user: str, db: str, + ip_type: str = "public", refresh_strategy: str = "background", ) -> tuple[sqlalchemy.engine.Engine, Connector]: """Creates a connection pool for a Cloud SQL instance and returns the pool @@ -56,6 +56,9 @@ def create_sqlalchemy_engine( e.g., my-email@test.com, service-account@project-id.iam db (str): The name of the database, e.g., mydb + ip_type (str): + The IP type of the Cloud SQL instance to connect to. Can be one + of "public", "private", or "psc". refresh_strategy (Optional[str]): Refresh strategy for the Cloud SQL Connector. Can be one of "lazy" or "background". 
diff --git a/tests/system/test_pymysql_connection.py b/tests/system/test_pymysql_connection.py
index 1e7e26830..f5dd8a1eb 100644
--- a/tests/system/test_pymysql_connection.py
+++ b/tests/system/test_pymysql_connection.py
@@ -18,7 +18,6 @@
 import os
 
 # [START cloud_sql_connector_mysql_pymysql]
-import pymysql
 import sqlalchemy
 
 from google.cloud.sql.connector import Connector
@@ -29,6 +28,7 @@ def create_sqlalchemy_engine(
     user: str,
     password: str,
     db: str,
+    ip_type: str = "public",
     refresh_strategy: str = "background",
 ) -> tuple[sqlalchemy.engine.Engine, Connector]:
     """Creates a connection pool for a Cloud SQL instance and returns the pool
@@ -60,6 +60,9 @@ def create_sqlalchemy_engine(
             The database user's password, e.g., secret-password
         db (str):
             The name of the database, e.g., mydb
+        ip_type (str):
+            The IP type of the Cloud SQL instance to connect to. Can be one
+            of "public", "private", or "psc".
         refresh_strategy (Optional[str]):
             Refresh strategy for the Cloud SQL Connector. Can be one of "lazy"
             or "background". For serverless environments use "lazy" to avoid
@@ -67,21 +70,17 @@
     """
     connector = Connector(refresh_strategy=refresh_strategy)
 
-    def getconn() -> pymysql.Connection:
-        conn: pymysql.Connection = connector.connect(
+    # create SQLAlchemy connection pool
+    engine = sqlalchemy.create_engine(
+        "mysql+pymysql://",
+        creator=lambda: connector.connect(
             instance_connection_name,
             "pymysql",
             user=user,
             password=password,
             db=db,
-            ip_type="public",  # can also be "private" or "psc"
-        )
-        return conn
-
-    # create SQLAlchemy connection pool
-    engine = sqlalchemy.create_engine(
-        "mysql+pymysql://",
-        creator=getconn,
+            ip_type=ip_type,  # can be "public", "private" or "psc"
+        ),
     )
     return engine, connector
@@ -95,8 +94,11 @@ def test_pymysql_connection() -> None:
     user = os.environ["MYSQL_USER"]
     password = os.environ["MYSQL_PASS"]
     db = os.environ["MYSQL_DB"]
+    ip_type = os.environ.get("IP_TYPE", "public")
 
-    engine, connector = create_sqlalchemy_engine(inst_conn_name, user, password, db)
+    engine, connector = create_sqlalchemy_engine(
+        inst_conn_name, user, password, db, ip_type
+    )
     with engine.connect() as conn:
         time = conn.execute(sqlalchemy.text("SELECT NOW()")).fetchone()
         conn.commit()
@@ -111,9 +113,29 @@ def test_lazy_pymysql_connection() -> None:
     user = os.environ["MYSQL_USER"]
     password = os.environ["MYSQL_PASS"]
     db = os.environ["MYSQL_DB"]
+    ip_type = os.environ.get("IP_TYPE", "public")
+
+    engine, connector = create_sqlalchemy_engine(
+        inst_conn_name, user, password, db, ip_type, "lazy"
+    )
+    with engine.connect() as conn:
+        time = conn.execute(sqlalchemy.text("SELECT NOW()")).fetchone()
+        conn.commit()
+        curr_time = time[0]
+        assert type(curr_time) is datetime
+    connector.close()
+
+
+def test_MCP_pymysql_connection() -> None:
+    """Basic test to get time from database using MCP enabled instance."""
+    inst_conn_name = os.environ["MYSQL_MCP_CONNECTION_NAME"]
+    user = os.environ["MYSQL_USER"]
+    password = os.environ["MYSQL_MCP_PASS"]
+    db = os.environ["MYSQL_DB"]
+    ip_type = os.environ.get("IP_TYPE", "public")
 
     engine, connector = create_sqlalchemy_engine(
-        inst_conn_name, user, password, db, "lazy"
+        inst_conn_name, user, password, db, ip_type
     )
     with engine.connect() as conn:
         time = conn.execute(sqlalchemy.text("SELECT NOW()")).fetchone()
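Because the engine is built with a creator= callable, the usual SQLAlchemy pool options still compose with the connector. A sketch with illustrative (not prescribed) pool settings and placeholder credentials:

import sqlalchemy
from google.cloud.sql.connector import Connector

connector = Connector()


def getconn():
    # placeholder instance and credentials
    return connector.connect(
        "my-project:my-region:my-instance",
        "pymysql",
        user="my-user",
        password="my-password",
        db="my-db",
    )


engine = sqlalchemy.create_engine(
    "mysql+pymysql://",
    creator=getconn,
    pool_size=5,        # connections held open in the pool
    max_overflow=2,     # extra connections allowed under load
    pool_recycle=1800,  # proactively recycle connections after 30 minutes
)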
diff --git a/tests/system/test_pymysql_iam_auth.py b/tests/system/test_pymysql_iam_auth.py
index 9a617b6f7..da676e66c 100644
--- a/tests/system/test_pymysql_iam_auth.py
+++ b/tests/system/test_pymysql_iam_auth.py
@@ -17,7 +17,6 @@
 from datetime import datetime
 import os
 
-import pymysql
 import sqlalchemy
 
 from google.cloud.sql.connector import Connector
@@ -27,6 +26,7 @@ def create_sqlalchemy_engine(
     instance_connection_name: str,
     user: str,
     db: str,
+    ip_type: str = "public",
     refresh_strategy: str = "background",
 ) -> tuple[sqlalchemy.engine.Engine, Connector]:
     """Creates a connection pool for a Cloud SQL instance and returns the pool
@@ -56,6 +56,9 @@ def create_sqlalchemy_engine(
             e.g., my-email@test.com -> my-email
         db (str):
             The name of the database, e.g., mydb
+        ip_type (str):
+            The IP type of the Cloud SQL instance to connect to. Can be one
+            of "public", "private", or "psc".
         refresh_strategy (Optional[str]):
             Refresh strategy for the Cloud SQL Connector. Can be one of "lazy"
             or "background". For serverless environments use "lazy" to avoid
@@ -63,21 +66,17 @@
     """
     connector = Connector(refresh_strategy=refresh_strategy)
 
-    def getconn() -> pymysql.Connection:
-        conn: pymysql.Connection = connector.connect(
+    # create SQLAlchemy connection pool
+    engine = sqlalchemy.create_engine(
+        "mysql+pymysql://",
+        creator=lambda: connector.connect(
             instance_connection_name,
             "pymysql",
             user=user,
             db=db,
-            ip_type="public",  # can also be "private" or "psc"
+            ip_type=ip_type,  # can be "public", "private" or "psc"
             enable_iam_auth=True,
-        )
-        return conn
-
-    # create SQLAlchemy connection pool
-    engine = sqlalchemy.create_engine(
-        "mysql+pymysql://",
-        creator=getconn,
+        ),
     )
     return engine, connector
@@ -87,8 +86,9 @@ def test_pymysql_iam_authn_connection() -> None:
     inst_conn_name = os.environ["MYSQL_CONNECTION_NAME"]
     user = os.environ["MYSQL_IAM_USER"]
     db = os.environ["MYSQL_DB"]
+    ip_type = os.environ.get("IP_TYPE", "public")
 
-    engine, connector = create_sqlalchemy_engine(inst_conn_name, user, db)
+    engine, connector = create_sqlalchemy_engine(inst_conn_name, user, db, ip_type)
     with engine.connect() as conn:
         time = conn.execute(sqlalchemy.text("SELECT NOW()")).fetchone()
         conn.commit()
@@ -102,8 +102,27 @@ def test_lazy_pymysql_iam_authn_connection() -> None:
     inst_conn_name = os.environ["MYSQL_CONNECTION_NAME"]
     user = os.environ["MYSQL_IAM_USER"]
     db = os.environ["MYSQL_DB"]
+    ip_type = os.environ.get("IP_TYPE", "public")
+
+    engine, connector = create_sqlalchemy_engine(
+        inst_conn_name, user, db, ip_type, "lazy"
+    )
+    with engine.connect() as conn:
+        time = conn.execute(sqlalchemy.text("SELECT NOW()")).fetchone()
+        conn.commit()
+        curr_time = time[0]
+        assert type(curr_time) is datetime
+    connector.close()
+
+
+def test_MCP_pymysql_iam_authn_connection() -> None:
+    """Basic test to get time from database using MCP enabled instance."""
+    inst_conn_name = os.environ["MYSQL_MCP_CONNECTION_NAME"]
+    user = os.environ["MYSQL_IAM_USER"]
+    db = os.environ["MYSQL_DB"]
+    ip_type = os.environ.get("IP_TYPE", "public")
 
-    engine, connector = create_sqlalchemy_engine(inst_conn_name, user, db, "lazy")
+    engine, connector = create_sqlalchemy_engine(inst_conn_name, user, db, ip_type)
     with engine.connect() as conn:
         time = conn.execute(sqlalchemy.text("SELECT NOW()")).fetchone()
         conn.commit()
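The refactor throughout these files swaps the named getconn closure for an inline lambda. If a named callable is still wanted, an equivalent formulation is functools.partial over connector.connect; the values below are placeholders:

import functools

import sqlalchemy
from google.cloud.sql.connector import Connector

connector = Connector()
getconn = functools.partial(
    connector.connect,
    "my-project:my-region:my-instance",  # placeholder connection name
    "pymysql",
    user="iam-user@my-project.iam",  # placeholder IAM database user
    db="my-db",                      # placeholder
    enable_iam_auth=True,
)
engine = sqlalchemy.create_engine("mysql+pymysql://", creator=getconn)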
diff --git a/tests/system/test_pytds_connection.py b/tests/system/test_pytds_connection.py
index fd88d230f..198c3307f 100644
--- a/tests/system/test_pytds_connection.py
+++ b/tests/system/test_pytds_connection.py
@@ -17,7 +17,6 @@
 import os
 
 # [START cloud_sql_connector_mysql_pytds]
-import pytds
 import sqlalchemy
 
 from google.cloud.sql.connector import Connector
@@ -28,6 +27,7 @@ def create_sqlalchemy_engine(
     user: str,
     password: str,
     db: str,
+    ip_type: str = "public",
     refresh_strategy: str = "background",
 ) -> tuple[sqlalchemy.engine.Engine, Connector]:
     """Creates a connection pool for a Cloud SQL instance and returns the pool
@@ -58,6 +58,9 @@ def create_sqlalchemy_engine(
             The database user's password, e.g., secret-password
         db (str):
             The name of the database, e.g., mydb
+        ip_type (str):
+            The IP type of the Cloud SQL instance to connect to. Can be one
+            of "public", "private", or "psc".
         refresh_strategy (Optional[str]):
             Refresh strategy for the Cloud SQL Connector. Can be one of "lazy"
             or "background". For serverless environments use "lazy" to avoid
@@ -65,21 +68,17 @@
     """
     connector = Connector(refresh_strategy=refresh_strategy)
 
-    def getconn() -> pytds.Connection:
-        conn: pytds.Connection = connector.connect(
+    # create SQLAlchemy connection pool
+    engine = sqlalchemy.create_engine(
+        "mssql+pytds://",
+        creator=lambda: connector.connect(
             instance_connection_name,
             "pytds",
             user=user,
             password=password,
             db=db,
-            ip_type="public",  # can also be "private" or "psc"
-        )
-        return conn
-
-    # create SQLAlchemy connection pool
-    engine = sqlalchemy.create_engine(
-        "mssql+pytds://",
-        creator=getconn,
+            ip_type=ip_type,  # can be "public", "private" or "psc"
+        ),
     )
     return engine, connector
@@ -93,8 +92,11 @@ def test_pytds_connection() -> None:
     user = os.environ["SQLSERVER_USER"]
     password = os.environ["SQLSERVER_PASS"]
     db = os.environ["SQLSERVER_DB"]
+    ip_type = os.environ.get("IP_TYPE", "public")
 
-    engine, connector = create_sqlalchemy_engine(inst_conn_name, user, password, db)
+    engine, connector = create_sqlalchemy_engine(
+        inst_conn_name, user, password, db, ip_type
+    )
     with engine.connect() as conn:
         res = conn.execute(sqlalchemy.text("SELECT 1")).fetchone()
         conn.commit()
@@ -108,9 +110,10 @@ def test_lazy_pytds_connection() -> None:
     user = os.environ["SQLSERVER_USER"]
     password = os.environ["SQLSERVER_PASS"]
     db = os.environ["SQLSERVER_DB"]
+    ip_type = os.environ.get("IP_TYPE", "public")
 
     engine, connector = create_sqlalchemy_engine(
-        inst_conn_name, user, password, db, "lazy"
+        inst_conn_name, user, password, db, ip_type, "lazy"
     )
     with engine.connect() as conn:
         res = conn.execute(sqlalchemy.text("SELECT 1")).fetchone()
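The SQL Server tests above probe liveness with SELECT 1 rather than SELECT NOW(), presumably because NOW() is not valid T-SQL. A sketch of the same check against the mssql+pytds dialect, with placeholder values throughout:

import sqlalchemy
from google.cloud.sql.connector import Connector

connector = Connector()
engine = sqlalchemy.create_engine(
    "mssql+pytds://",
    creator=lambda: connector.connect(
        "my-project:my-region:my-instance",  # placeholder connection name
        "pytds",
        user="my-user",          # placeholder
        password="my-password",  # placeholder
        db="my-db",              # placeholder
    ),
)
with engine.connect() as conn:
    # T-SQL liveness probe; GETDATE() would be the T-SQL timestamp equivalent
    assert conn.execute(sqlalchemy.text("SELECT 1")).fetchone()[0] == 1
connector.close()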
diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py
index cd3299b7f..66bf64a32 100644
--- a/tests/unit/mocks.py
+++ b/tests/unit/mocks.py
@@ -1,4 +1,4 @@
-""""
+"""
 Copyright 2022 Google LLC
 
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,6 +16,8 @@
 
 # file containing all mocks used for Cloud SQL Python Connector unit tests
 
+from __future__ import annotations
+
 import datetime
 import json
 import ssl
@@ -184,28 +186,28 @@ def client_key_signed_cert(
         .not_valid_after(cert_expiration)  # type: ignore
     )
     return (
-        cert.sign(priv_key, hashes.SHA256(), default_backend())
+        cert.sign(priv_key, hashes.SHA256())
         .public_bytes(encoding=serialization.Encoding.PEM)
         .decode("UTF-8")
     )
 
 
-async def create_ssl_context() -> ssl.SSLContext:
+async def create_ssl_context(instance: FakeCSQLInstance) -> ssl.SSLContext:
     """Helper method to build an ssl.SSLContext for tests"""
-    # generate keys and certs for test
-    cert, private_key = generate_cert("my-project", "my-instance")
-    server_ca_cert = self_signed_cert(cert, private_key)
     client_private, client_bytes = await generate_keys()
     client_key: rsa.RSAPublicKey = serialization.load_pem_public_key(
-        client_bytes.encode("UTF-8"), default_backend()
+        client_bytes.encode("UTF-8"),
     )  # type: ignore
-    ephemeral_cert = client_key_signed_cert(cert, private_key, client_key)
-    # build default ssl.SSLContext
-    context = ssl.create_default_context()
+    ephemeral_cert = client_key_signed_cert(
+        instance.server_ca, instance.server_key, client_key
+    )
+    # create SSL/TLS context
+    context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+    context.check_hostname = False
     # load ssl.SSLContext with certs
     async with TemporaryDirectory() as tmpdir:
         ca_filename, cert_filename, key_filename = await write_to_file(
-            tmpdir, server_ca_cert, ephemeral_cert, client_private
+            tmpdir, instance.server_cert_pem, ephemeral_cert, client_private
         )
         context.load_cert_chain(cert_filename, keyfile=key_filename)
         context.load_verify_locations(cafile=ca_filename)
@@ -279,8 +281,8 @@ async def generate_ephemeral(self, request: Any) -> web.Response:
         body = await request.json()
         pub_key = body["public_key"]
         client_key: rsa.RSAPublicKey = serialization.load_pem_public_key(
-            pub_key.encode("UTF-8"), default_backend()
-        )  # type: ignore
+            pub_key.encode("UTF-8"),
+        )
         ephemeral_cert = client_key_signed_cert(
             self.server_ca,
             self.server_key,
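create_ssl_context now derives the test context from the fake instance's CA and server key instead of generating throwaway certs, and builds the client context by hand: ssl.PROTOCOL_TLS_CLIENT keeps certificate verification required, while check_hostname = False permits connecting to 127.0.0.1, which would never match the certificate's subject. A standalone sketch of the same idea; the file paths are assumed for illustration:

import socket
import ssl

# Verify the peer against a known test CA, but skip hostname matching
# because the test dials a bare IP rather than a DNS name.
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
context.check_hostname = False  # 127.0.0.1 will not match the cert's SAN
context.load_verify_locations(cafile="ca.pem")            # assumed path
context.load_cert_chain("client.pem", keyfile="key.pem")  # assumed paths

sock = context.wrap_socket(
    socket.create_connection(("127.0.0.1", 3307)),
    server_hostname="127.0.0.1",
)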
diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py
index 1a3d60917..3699ddc2d 100644
--- a/tests/unit/test_instance.py
+++ b/tests/unit/test_instance.py
@@ -186,17 +186,18 @@ async def test_RefreshAheadCache_close(cache: RefreshAheadCache) -> None:
 @pytest.mark.asyncio
 async def test_perform_refresh(
     cache: RefreshAheadCache,
-    fake_instance: mocks.FakeCSQLInstance,
 ) -> None:
     """
     Test that _perform_refresh returns valid ConnectionInfo object.
     """
     instance_metadata = await cache._perform_refresh()
-
     # verify instance metadata object is returned
     assert isinstance(instance_metadata, ConnectionInfo)
     # verify instance metadata expiration
-    assert fake_instance.server_cert.not_valid_after_utc == instance_metadata.expiration
+    assert (
+        cache._client.instance.cert_expiration.replace(microsecond=0)
+        == instance_metadata.expiration
+    )
 
 
 @pytest.mark.asyncio
diff --git a/tests/unit/test_monitored_cache.py b/tests/unit/test_monitored_cache.py
index 1eea4eb46..1c1f1df86 100644
--- a/tests/unit/test_monitored_cache.py
+++ b/tests/unit/test_monitored_cache.py
@@ -14,13 +14,13 @@
 
 import asyncio
 import socket
+import ssl
 
 import dns.message
 import dns.rdataclass
 import dns.rdatatype
 import dns.resolver
 from mock import patch
-from mocks import create_ssl_context
 import pytest
 
 from google.cloud.sql.connector.client import CloudSQLClient
@@ -149,8 +149,10 @@ async def test_MonitoredCache_with_disabled_failover(
     assert monitored_cache.closed is True
 
 
-@pytest.mark.usefixtures("server")
-async def test_MonitoredCache_check_domain_name(fake_client: CloudSQLClient) -> None:
+@pytest.mark.usefixtures("proxy_server")
+async def test_MonitoredCache_check_domain_name(
+    context: ssl.SSLContext, fake_client: CloudSQLClient
+) -> None:
     """
     Test that MonitoredCache is closed when _check_domain_name has domain change.
     """
@@ -177,11 +179,9 @@ async def test_MonitoredCache_check_domain_name(fake_client: CloudSQLClient) ->
 
     # configure a local socket
     ip_addr = "127.0.0.1"
-    context = await create_ssl_context()
     sock = context.wrap_socket(
         socket.create_connection((ip_addr, 3307)),
         server_hostname=ip_addr,
-        do_handshake_on_connect=False,
     )
     # verify socket is open
     assert sock.fileno() != -1
@@ -198,8 +198,10 @@ async def test_MonitoredCache_check_domain_name(fake_client: CloudSQLClient) ->
     assert sock.fileno() == -1
 
 
-@pytest.mark.usefixtures("server")
-async def test_MonitoredCache_purge_closed_sockets(fake_client: CloudSQLClient) -> None:
+@pytest.mark.usefixtures("proxy_server")
+async def test_MonitoredCache_purge_closed_sockets(
+    context: ssl.SSLContext, fake_client: CloudSQLClient
+) -> None:
     """
     Test that MonitoredCache._purge_closed_sockets removes closed
     sockets from cache.
@@ -215,11 +217,9 @@ async def test_MonitoredCache_purge_closed_sockets(fake_client: CloudSQLClient)
     )
     # configure a local socket
     ip_addr = "127.0.0.1"
-    context = await create_ssl_context()
     sock = context.wrap_socket(
         socket.create_connection((ip_addr, 3307)),
         server_hostname=ip_addr,
-        do_handshake_on_connect=False,
     )
     # set failover to 0 to disable polling
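The unit tests above (and those that follow) now consume ready-made context and proxy_server fixtures instead of each building a context via create_ssl_context. The fixture definitions are outside this diff, so the wiring below is an assumption for illustration only; the fixture names match the tests, but the body is not taken from this change:

# Hypothetical conftest.py sketch
import ssl

import pytest
from mocks import FakeCSQLInstance, create_ssl_context


@pytest.fixture
async def context(fake_instance: FakeCSQLInstance) -> ssl.SSLContext:
    # build the client-side TLS context from the fake instance's certs
    return await create_ssl_context(fake_instance)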
diff --git a/tests/unit/test_pg8000.py b/tests/unit/test_pg8000.py
index e01a53445..2c003b8a9 100644
--- a/tests/unit/test_pg8000.py
+++ b/tests/unit/test_pg8000.py
@@ -15,26 +15,22 @@
 """
 
 import socket
+import ssl
 from typing import Any
 
 from mock import patch
-from mocks import create_ssl_context
 import pytest
 
 from google.cloud.sql.connector.pg8000 import connect
 
 
-@pytest.mark.usefixtures("server")
-@pytest.mark.asyncio
-async def test_pg8000(kwargs: Any) -> None:
+@pytest.mark.usefixtures("proxy_server")
+async def test_pg8000(context: ssl.SSLContext, kwargs: Any) -> None:
     """Test to verify that pg8000 gets to proper connection call."""
     ip_addr = "127.0.0.1"
-    # build ssl.SSLContext
-    context = await create_ssl_context()
     sock = context.wrap_socket(
         socket.create_connection((ip_addr, 3307)),
         server_hostname=ip_addr,
-        do_handshake_on_connect=False,
     )
     with patch("pg8000.dbapi.connect") as mock_connect:
         mock_connect.return_value = True
diff --git a/tests/unit/test_pymysql.py b/tests/unit/test_pymysql.py
index 66b1f22a3..13cd8e98a 100644
--- a/tests/unit/test_pymysql.py
+++ b/tests/unit/test_pymysql.py
@@ -19,7 +19,6 @@
 from typing import Any
 
 from mock import patch
-from mocks import create_ssl_context
 import pytest
 
 from google.cloud.sql.connector.pymysql import connect as pymysql_connect
@@ -33,17 +32,14 @@ def connect(sock: ssl.SSLSocket) -> None:  # type: ignore
     assert isinstance(sock, ssl.SSLSocket)
 
 
-@pytest.mark.usefixtures("server")
+@pytest.mark.usefixtures("proxy_server")
 @pytest.mark.asyncio
-async def test_pymysql(kwargs: Any) -> None:
+async def test_pymysql(context: ssl.SSLContext, kwargs: Any) -> None:
     """Test to verify that pymysql gets to proper connection call."""
     ip_addr = "127.0.0.1"
-    # build ssl.SSLContext
-    context = await create_ssl_context()
     sock = context.wrap_socket(
         socket.create_connection((ip_addr, 3307)),
         server_hostname=ip_addr,
-        do_handshake_on_connect=False,
     )
     kwargs["timeout"] = 30
     with patch("pymysql.Connection") as mock_connect:
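The pytds unit tests that follow stub platform.system() with a bare setattr, which leaks across tests unless restored. pytest's built-in monkeypatch fixture is a possible alternative that undoes the patch automatically once the test finishes; a minimal sketch:

import platform


def test_platform_is_stubbed(monkeypatch) -> None:
    # monkeypatch restores platform.system after the test finishes
    monkeypatch.setattr(platform, "system", lambda: "Linux")
    assert platform.system() == "Linux"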
diff --git a/tests/unit/test_pytds.py b/tests/unit/test_pytds.py
index 9efe00ee5..faa20ad8c 100644
--- a/tests/unit/test_pytds.py
+++ b/tests/unit/test_pytds.py
@@ -16,10 +16,10 @@
 
 import platform
 import socket
+import ssl
 from typing import Any
 
 from mock import patch
-from mocks import create_ssl_context
 import pytest
 
 from google.cloud.sql.connector.exceptions import PlatformNotSupportedError
@@ -36,17 +36,13 @@ def stub_platform_windows() -> str:
     return "Windows"
 
 
-@pytest.mark.usefixtures("server")
-@pytest.mark.asyncio
-async def test_pytds(kwargs: Any) -> None:
+@pytest.mark.usefixtures("proxy_server")
+async def test_pytds(context: ssl.SSLContext, kwargs: Any) -> None:
     """Test to verify that pytds gets to proper connection call."""
     ip_addr = "127.0.0.1"
-    # build ssl.SSLContext
-    context = await create_ssl_context()
     sock = context.wrap_socket(
         socket.create_connection((ip_addr, 3307)),
         server_hostname=ip_addr,
-        do_handshake_on_connect=False,
     )
 
     with patch("pytds.connect") as mock_connect:
@@ -57,20 +53,16 @@ async def test_pytds(kwargs: Any) -> None:
     assert mock_connect.assert_called_once
 
 
-@pytest.mark.usefixtures("server")
-@pytest.mark.asyncio
-async def test_pytds_platform_error(kwargs: Any) -> None:
+@pytest.mark.usefixtures("proxy_server")
+async def test_pytds_platform_error(context: ssl.SSLContext, kwargs: Any) -> None:
     """Test to verify that pytds.connect throws proper PlatformNotSupportedError."""
     ip_addr = "127.0.0.1"
     # stub operating system to Linux
     setattr(platform, "system", stub_platform_linux)
     assert platform.system() == "Linux"
-    # build ssl.SSLContext
-    context = await create_ssl_context()
     sock = context.wrap_socket(
         socket.create_connection((ip_addr, 3307)),
         server_hostname=ip_addr,
-        do_handshake_on_connect=False,
     )
     # add active_directory_auth to kwargs
     kwargs["active_directory_auth"] = True
@@ -79,9 +71,10 @@ async def test_pytds_platform_error(kwargs: Any) -> None:
         connect(ip_addr, sock, **kwargs)
 
 
-@pytest.mark.usefixtures("server")
-@pytest.mark.asyncio
-async def test_pytds_windows_active_directory_auth(kwargs: Any) -> None:
+@pytest.mark.usefixtures("proxy_server")
+async def test_pytds_windows_active_directory_auth(
+    context: ssl.SSLContext, kwargs: Any
+) -> None:
     """
     Test to verify that pytds gets to connection call on Windows with
     active_directory_auth arg set.
@@ -90,12 +83,9 @@ async def test_pytds_windows_active_directory_auth(kwargs: Any) -> None:
     # stub operating system to Windows
     setattr(platform, "system", stub_platform_windows)
     assert platform.system() == "Windows"
-    # build ssl.SSLContext
-    context = await create_ssl_context()
     sock = context.wrap_socket(
         socket.create_connection((ip_addr, 3307)),
         server_hostname=ip_addr,
-        do_handshake_on_connect=False,
     )
     # add active_directory_auth and server_name to kwargs
     kwargs["active_directory_auth"] = True