From 3d082e408a06c88b08a14cac37ce27fdab777088 Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Fri, 21 Aug 2020 12:02:59 +0300 Subject: [PATCH 01/12] feat: refactor connect() function, cover it with unit tests --- spanner_dbapi/__init__.py | 149 ++++++++++++++++++---------- tests/spanner_dbapi/test_connect.py | 116 ++++++++++++++++++++++ 2 files changed, 210 insertions(+), 55 deletions(-) create mode 100644 tests/spanner_dbapi/test_connect.py diff --git a/spanner_dbapi/__init__.py b/spanner_dbapi/__init__.py index f5d349a655..cf88da598f 100644 --- a/spanner_dbapi/__init__.py +++ b/spanner_dbapi/__init__.py @@ -4,83 +4,122 @@ # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd -from google.cloud import spanner_v1 as spanner +"""Connection-based DB API for Cloud Spanner.""" + +from google.cloud import spanner_v1 from .connection import Connection -# These need to be included in the top-level package for PEP-0249 DB API v2. from .exceptions import ( - DatabaseError, DataError, Error, IntegrityError, InterfaceError, - InternalError, NotSupportedError, OperationalError, ProgrammingError, + DatabaseError, + DataError, + Error, + IntegrityError, + InterfaceError, + InternalError, + NotSupportedError, + OperationalError, + ProgrammingError, Warning, ) from .parse_utils import get_param_types from .types import ( - BINARY, DATETIME, NUMBER, ROWID, STRING, Binary, Date, DateFromTicks, Time, - TimeFromTicks, Timestamp, TimestampFromTicks, + BINARY, + DATETIME, + NUMBER, + ROWID, + STRING, + Binary, + Date, + DateFromTicks, + Time, + TimeFromTicks, + Timestamp, + TimestampFromTicks, ) from .version import google_client_info -# Globals that MUST be defined ### -apilevel = "2.0" # Implements the Python Database API specification 2.0 version. -# We accept arguments in the format '%s' aka ANSI C print codes. 
-# as per https://www.python.org/dev/peps/pep-0249/#paramstyle -paramstyle = 'format' -# Threads may share the module but not connections. This is a paranoid threadsafety level, -# but it is necessary for starters to use when debugging failures. Eventually once transactions -# are working properly, we'll update the threadsafety level. +apilevel = "2.0" # supports DP-API 2.0 level. +paramstyle = "format" # ANSI C printf format codes, e.g. ...WHERE name=%s. + +# Threads may share the module, but not connections. This is a paranoid threadsafety +# level, but it is necessary for starters to use when debugging failures. +# Eventually once transactions are working properly, we'll update the +# threadsafety level. threadsafety = 1 -def connect(project=None, instance=None, database=None, credentials_uri=None, user_agent=None): +def connect(instance_id, database_id, project=None, credentials=None, user_agent=None): """ - Connect to Cloud Spanner. + Create a connection to Cloud Spanner database. - Args: - project: The id of a project that already exists. - instance: The id of an instance that already exists. - database: The name of a database that already exists. - credentials_uri: An optional string specifying where to retrieve the service - account JSON for the credentials to connect to Cloud Spanner. + :type instance_id: :class:`str` + :param instance_id: ID of the instance to connect to. - Returns: - The Connection object associated to the Cloud Spanner instance. + :type database_id: :class:`str` + :param database_id: The name of the database to connect to. - Raises: - Error if it encounters any unexpected inputs. - """ - if not project: - raise Error("'project' is required.") - if not instance: - raise Error("'instance' is required.") - if not database: - raise Error("'database' is required.") + :type project: :class:`str` + :param project: (Optional) The ID of the project which owns the + instances, tables and data. 
If not provided, will + attempt to determine from the environment. - client_kwargs = { - 'project': project, - 'client_info': google_client_info(user_agent), - } - if credentials_uri: - client = spanner.Client.from_service_account_json(credentials_uri, **client_kwargs) - else: - client = spanner.Client(**client_kwargs) + :type credentials: :class:`google.auth.credentials.Credentials` + :param credentials: (Optional) The authorization credentials to attach to requests. + These credentials identify this application to the service. + If none are specified, the client will attempt to ascertain + the credentials from the environment. + + :rtype: :class:`google.cloud.spanner_dbapi.connection.Connection` + :returns: Connection object associated with the given Cloud Spanner resource. + + :raises: :class:`ProgrammingError` in case of given instance/database + doesn't exist. + """ + client = spanner_v1.Client( + project=project, + credentials=credentials, + client_info=google_client_info(user_agent), + ) - client_instance = client.instance(instance) - if not client_instance.exists(): - raise ProgrammingError("instance '%s' does not exist." % instance) + instance = client.instance(instance_id) + if not instance.exists(): + raise ProgrammingError("instance '%s' does not exist." % instance_id) - db = client_instance.database(database, pool=spanner.pool.BurstyPool()) - if not db.exists(): - raise ProgrammingError("database '%s' does not exist." % database) + database = instance.database(database_id, pool=spanner_v1.pool.BurstyPool()) + if not database.exists(): + raise ProgrammingError("database '%s' does not exist." 
% database_id) - return Connection(db) + return Connection(database) __all__ = [ - 'DatabaseError', 'DataError', 'Error', 'IntegrityError', 'InterfaceError', - 'InternalError', 'NotSupportedError', 'OperationalError', 'ProgrammingError', - 'Warning', 'DEFAULT_USER_AGENT', 'apilevel', 'connect', 'paramstyle', 'threadsafety', - 'get_param_types', - 'Binary', 'Date', 'DateFromTicks', 'Time', 'TimeFromTicks', 'Timestamp', - 'TimestampFromTicks', - 'BINARY', 'STRING', 'NUMBER', 'DATETIME', 'ROWID', 'TimestampStr', + "DatabaseError", + "DataError", + "Error", + "IntegrityError", + "InterfaceError", + "InternalError", + "NotSupportedError", + "OperationalError", + "ProgrammingError", + "Warning", + "DEFAULT_USER_AGENT", + "apilevel", + "connect", + "paramstyle", + "threadsafety", + "get_param_types", + "Binary", + "Date", + "DateFromTicks", + "Time", + "TimeFromTicks", + "Timestamp", + "TimestampFromTicks", + "BINARY", + "STRING", + "NUMBER", + "DATETIME", + "ROWID", + "TimestampStr", ] diff --git a/tests/spanner_dbapi/test_connect.py b/tests/spanner_dbapi/test_connect.py new file mode 100644 index 0000000000..faa3c84fe1 --- /dev/null +++ b/tests/spanner_dbapi/test_connect.py @@ -0,0 +1,116 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +"""connect() module function unit tests.""" + +import mock +import unittest + + +def _make_credentials(): + import google.auth.credentials + + class _CredentialsWithScopes( + google.auth.credentials.Credentials, google.auth.credentials.Scoped + ): + pass + + return mock.Mock(spec=_CredentialsWithScopes) + + +class Testconnect(unittest.TestCase): + def _callFUT(self, *args, **kw): + from google.cloud.spanner_dbapi import connect + + return connect(*args, **kw) + + def test_connect(self): + from google.api_core.gapic_v1.client_info import ClientInfo + from 
google.cloud.spanner_dbapi.connection import Connection + + PROJECT = "test-project" + USER_AGENT = "user-agent" + CREDENTIALS = _make_credentials() + CLIENT_INFO = ClientInfo(user_agent=USER_AGENT) + + with mock.patch("google.cloud.spanner_dbapi.spanner_v1.Client") as client_mock: + with mock.patch( + "google.cloud.spanner_dbapi.google_client_info", + return_value=CLIENT_INFO, + ) as client_info_mock: + + connection = self._callFUT( + "test-instance", "test-database", PROJECT, CREDENTIALS, USER_AGENT + ) + + self.assertIsInstance(connection, Connection) + client_info_mock.assert_called_once_with(USER_AGENT) + + client_mock.assert_called_once_with( + project=PROJECT, credentials=CREDENTIALS, client_info=CLIENT_INFO + ) + + def test_instance_not_found(self): + from google.cloud.spanner_dbapi.exceptions import ProgrammingError + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=False + ) as exists_mock: + + with self.assertRaises(ProgrammingError): + self._callFUT("test-instance", "test-database") + + exists_mock.assert_called_once() + + def test_database_not_found(self): + from google.cloud.spanner_dbapi.exceptions import ProgrammingError + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", return_value=False + ) as exists_mock: + + with self.assertRaises(ProgrammingError): + self._callFUT("test-instance", "test-database") + + exists_mock.assert_called_once() + + def test_connect_instance_id(self): + from google.cloud.spanner_dbapi.connection import Connection + + INSTANCE = "test-instance" + + with mock.patch( + "google.cloud.spanner_v1.client.Client.instance" + ) as instance_mock: + connection = self._callFUT(INSTANCE, "test-database") + + instance_mock.assert_called_once_with(INSTANCE) + + self.assertIsInstance(connection, Connection) + + def test_connect_database_id(self): + from 
google.cloud.spanner_dbapi.connection import Connection + + DATABASE = "test-database" + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.database" + ) as database_mock: + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True + ): + connection = self._callFUT("test-instance", DATABASE) + + database_mock.assert_called_once_with(DATABASE, pool=mock.ANY) + + self.assertIsInstance(connection, Connection) + + +if __name__ == "__main__": + unittest.main() From 13d672b156ab7d11bb213c23815931e0e359e5cb Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Fri, 21 Aug 2020 12:19:22 +0300 Subject: [PATCH 02/12] fix mock import --- tests/spanner_dbapi/test_connect.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/spanner_dbapi/test_connect.py b/tests/spanner_dbapi/test_connect.py index faa3c84fe1..5d7f193b58 100644 --- a/tests/spanner_dbapi/test_connect.py +++ b/tests/spanner_dbapi/test_connect.py @@ -6,8 +6,8 @@ """connect() module function unit tests.""" -import mock import unittest +from unittest import mock def _make_credentials(): From 6eaa1f0848573e793e8b2385deab02ec493f04ec Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Fri, 21 Aug 2020 12:27:00 +0300 Subject: [PATCH 03/12] change imports to the db_api package instead of google.cloud --- tests/spanner_dbapi/test_connect.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/tests/spanner_dbapi/test_connect.py b/tests/spanner_dbapi/test_connect.py index 5d7f193b58..a57d3ca0ff 100644 --- a/tests/spanner_dbapi/test_connect.py +++ b/tests/spanner_dbapi/test_connect.py @@ -23,23 +23,22 @@ class _CredentialsWithScopes( class Testconnect(unittest.TestCase): def _callFUT(self, *args, **kw): - from google.cloud.spanner_dbapi import connect + from spanner_dbapi import connect return connect(*args, **kw) def test_connect(self): from google.api_core.gapic_v1.client_info import ClientInfo - from google.cloud.spanner_dbapi.connection 
import Connection + from spanner_dbapi.connection import Connection PROJECT = "test-project" USER_AGENT = "user-agent" CREDENTIALS = _make_credentials() CLIENT_INFO = ClientInfo(user_agent=USER_AGENT) - with mock.patch("google.cloud.spanner_dbapi.spanner_v1.Client") as client_mock: + with mock.patch("spanner_dbapi.spanner_v1.Client") as client_mock: with mock.patch( - "google.cloud.spanner_dbapi.google_client_info", - return_value=CLIENT_INFO, + "spanner_dbapi.google_client_info", return_value=CLIENT_INFO ) as client_info_mock: connection = self._callFUT( @@ -54,7 +53,7 @@ def test_connect(self): ) def test_instance_not_found(self): - from google.cloud.spanner_dbapi.exceptions import ProgrammingError + from spanner_dbapi.exceptions import ProgrammingError with mock.patch( "google.cloud.spanner_v1.instance.Instance.exists", return_value=False @@ -66,7 +65,7 @@ def test_instance_not_found(self): exists_mock.assert_called_once() def test_database_not_found(self): - from google.cloud.spanner_dbapi.exceptions import ProgrammingError + from spanner_dbapi.exceptions import ProgrammingError with mock.patch( "google.cloud.spanner_v1.instance.Instance.exists", return_value=True @@ -81,7 +80,7 @@ def test_database_not_found(self): exists_mock.assert_called_once() def test_connect_instance_id(self): - from google.cloud.spanner_dbapi.connection import Connection + from spanner_dbapi.connection import Connection INSTANCE = "test-instance" @@ -95,7 +94,7 @@ def test_connect_instance_id(self): self.assertIsInstance(connection, Connection) def test_connect_database_id(self): - from google.cloud.spanner_dbapi.connection import Connection + from spanner_dbapi.connection import Connection DATABASE = "test-database" From 7607d9e2aae297a0c8c4723ecc8856382de194f3 Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Mon, 24 Aug 2020 12:47:57 +0300 Subject: [PATCH 04/12] feat: cursor must detect if the parent connection is closed --- spanner_dbapi/connection.py | 54 ++++++----- 
spanner_dbapi/cursor.py | 138 ++++++++++++++++++----------- tests/spanner_dbapi/test_cursor.py | 54 +++++++++++ 3 files changed, 172 insertions(+), 74 deletions(-) create mode 100644 tests/spanner_dbapi/test_cursor.py diff --git a/spanner_dbapi/connection.py b/spanner_dbapi/connection.py index 20b707adb0..0ae8c84d27 100644 --- a/spanner_dbapi/connection.py +++ b/spanner_dbapi/connection.py @@ -11,26 +11,31 @@ from .cursor import Cursor from .exceptions import InterfaceError -ColumnDetails = namedtuple('column_details', ['null_ok', 'spanner_type']) +ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"]) class Connection: def __init__(self, db_handle): self._dbhandle = db_handle - self._closed = False self._ddl_statements = [] + self.is_closed = False + def cursor(self): - self.__raise_if_already_closed() + self._raise_if_already_closed() return Cursor(self) - def __raise_if_already_closed(self): - """ - Raise an exception if attempting to use an already closed connection. + def _raise_if_already_closed(self): + """Raise an exception if this connection is closed. + + Helper to check the connection state before + running a SQL/DDL/DML query. + + :raises: :class:`InterfaceError` if this connection is closed. """ - if self._closed: - raise InterfaceError('connection already closed') + if self.is_closed: + raise InterfaceError("connection is already closed") def __handle_update_ddl(self, ddl_statements): """ @@ -41,24 +46,24 @@ def __handle_update_ddl(self, ddl_statements): Returns: google.api_core.operation.Operation.result() """ - self.__raise_if_already_closed() + self._raise_if_already_closed() # Synchronously wait on the operation's completion. 
return self._dbhandle.update_ddl(ddl_statements).result() def read_snapshot(self): - self.__raise_if_already_closed() + self._raise_if_already_closed() return self._dbhandle.snapshot() def in_transaction(self, fn, *args, **kwargs): - self.__raise_if_already_closed() + self._raise_if_already_closed() return self._dbhandle.run_in_transaction(fn, *args, **kwargs) def append_ddl_statement(self, ddl_statement): - self.__raise_if_already_closed() + self._raise_if_already_closed() self._ddl_statements.append(ddl_statement) def run_prior_DDL_statements(self): - self.__raise_if_already_closed() + self._raise_if_already_closed() if not self._ddl_statements: return @@ -69,14 +74,16 @@ def run_prior_DDL_statements(self): return self.__handle_update_ddl(ddl_statements) def list_tables(self): - return self.run_sql_in_snapshot(""" + return self.run_sql_in_snapshot( + """ SELECT t.table_name FROM information_schema.tables AS t WHERE t.table_catalog = '' and t.table_schema = '' - """) + """ + ) def run_sql_in_snapshot(self, sql, params=None, param_types=None): # Some SQL e.g. 
for INFORMATION_SCHEMA cannot be run in read-write transactions @@ -89,38 +96,37 @@ def run_sql_in_snapshot(self, sql, params=None, param_types=None): def get_table_column_schema(self, table_name): rows = self.run_sql_in_snapshot( - '''SELECT + """SELECT COLUMN_NAME, IS_NULLABLE, SPANNER_TYPE FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA = '' AND - TABLE_NAME = @table_name''', - params={'table_name': table_name}, - param_types={'table_name': spanner.param_types.STRING}, + TABLE_NAME = @table_name""", + params={"table_name": table_name}, + param_types={"table_name": spanner.param_types.STRING}, ) column_details = {} for column_name, is_nullable, spanner_type in rows: column_details[column_name] = ColumnDetails( - null_ok=is_nullable == 'YES', - spanner_type=spanner_type, + null_ok=is_nullable == "YES", spanner_type=spanner_type ) return column_details def close(self): self.rollback() self.__dbhandle = None - self._closed = True + self.is_closed = True def commit(self): - self.__raise_if_already_closed() + self._raise_if_already_closed() self.run_prior_DDL_statements() def rollback(self): - self.__raise_if_already_closed() + self._raise_if_already_closed() # TODO: to be added. 
diff --git a/spanner_dbapi/cursor.py b/spanner_dbapi/cursor.py index d5f08c4e93..ebf0cb66f0 100644 --- a/spanner_dbapi/cursor.py +++ b/spanner_dbapi/cursor.py @@ -8,11 +8,19 @@ from google.cloud.spanner_v1 import param_types from .exceptions import ( - IntegrityError, InterfaceError, OperationalError, ProgrammingError, + IntegrityError, + InterfaceError, + OperationalError, + ProgrammingError, ) from .parse_utils import ( - STMT_DDL, STMT_INSERT, STMT_NON_UPDATING, classify_stmt, - ensure_where_clause, get_param_types, parse_insert, + STMT_DDL, + STMT_INSERT, + STMT_NON_UPDATING, + classify_stmt, + ensure_where_clause, + get_param_types, + parse_insert, sql_pyformat_args_to_spanner, ) from .utils import PeekIterator @@ -44,12 +52,9 @@ def __init__(self, connection): self._res = None self._row_count = _UNSET_COUNT self._connection = connection - self._closed = False + self._is_closed = False - # arraysize is a readable and writable property mandated - # by PEP-0249 https://www.python.org/dev/peps/pep-0249/#arraysize - # It determines the results of .fetchmany - self.arraysize = 1 + self.arraysize = 1 # the number of rows to fetch at a time with fetchmany() def execute(self, sql, args=None): """ @@ -64,7 +69,7 @@ def execute(self, sql, args=None): self._raise_if_already_closed() if not self._connection: - raise ProgrammingError('Cursor is not connected to the database') + raise ProgrammingError("Cursor is not connected to the database") self._res = None @@ -86,23 +91,22 @@ def execute(self, sql, args=None): else: self.__handle_update(sql, args or None) except (grpc_exceptions.AlreadyExists, grpc_exceptions.FailedPrecondition) as e: - raise IntegrityError(e.details if hasattr(e, 'details') else e) + raise IntegrityError(e.details if hasattr(e, "details") else e) except grpc_exceptions.InvalidArgument as e: - raise ProgrammingError(e.details if hasattr(e, 'details') else e) + raise ProgrammingError(e.details if hasattr(e, "details") else e) except 
grpc_exceptions.InternalServerError as e: - raise OperationalError(e.details if hasattr(e, 'details') else e) + raise OperationalError(e.details if hasattr(e, "details") else e) def __handle_update(self, sql, params): - self._connection.in_transaction( - self.__do_execute_update, - sql, params, - ) + self._connection.in_transaction(self.__do_execute_update, sql, params) def __do_execute_update(self, transaction, sql, params, param_types=None): sql = ensure_where_clause(sql) sql, params = sql_pyformat_args_to_spanner(sql, params) - res = transaction.execute_update(sql, params=params, param_types=get_param_types(params)) + res = transaction.execute_update( + sql, params=params, param_types=get_param_types(params) + ) self._itr = None if type(res) == int: self._row_count = res @@ -125,20 +129,18 @@ def __handle_insert(self, sql, params): # transaction.execute_sql(sql, params, param_types) # which invokes more RPCs and is more costly. - if parts.get('homogenous'): + if parts.get("homogenous"): # The common case of multiple values being passed in # non-complex pyformat args and need to be uploaded in one RPC. return self._connection.in_transaction( - self.__do_execute_insert_homogenous, - parts, + self.__do_execute_insert_homogenous, parts ) else: # All the other cases that are esoteric and need # transaction.execute_sql - sql_params_list = parts.get('sql_params_list') + sql_params_list = parts.get("sql_params_list") return self._connection.in_transaction( - self.__do_execute_insert_heterogenous, - sql_params_list, + self.__do_execute_insert_heterogenous, sql_params_list ) def __do_execute_insert_heterogenous(self, transaction, sql_params_list): @@ -152,9 +154,9 @@ def __do_execute_insert_heterogenous(self, transaction, sql_params_list): def __do_execute_insert_homogenous(self, transaction, parts): # Perform an insert in one shot. 
- table = parts.get('table') - columns = parts.get('columns') - values = parts.get('values') + table = parts.get("table") + columns = parts.get("columns") + values = parts.get("values") return transaction.insert(table, columns, values) def __handle_DQL(self, sql, params): @@ -162,7 +164,9 @@ def __handle_DQL(self, sql, params): # Reference # https://googleapis.dev/python/spanner/latest/session-api.html#google.cloud.spanner_v1.session.Session.execute_sql sql, params = sql_pyformat_args_to_spanner(sql, params) - res = snapshot.execute_sql(sql, params=params, param_types=get_param_types(params)) + res = snapshot.execute_sql( + sql, params=params, param_types=get_param_types(params) + ) if type(res) == int: self._row_count = res self._itr = None @@ -216,32 +220,48 @@ def description(self): def rowcount(self): return self._row_count - def _raise_if_already_closed(self): + @property + def is_closed(self): + """The cursor close indicator. + + Returns: + bool: + True if this cursor or it's parent connection + is closed, False otherwise. """ - Raise an exception if attempting to use an already closed connection. + return self._is_closed or self._connection.is_closed + + def _raise_if_already_closed(self): + """Raise an exception if this cursor is closed. + + Helper to check this cursor's state before running a + SQL/DDL/DML query. If the parent connection is + already closed it also raises an error. + + :raises: :class:`InterfaceError` if this cursor is closed. 
""" - if self._closed: - raise InterfaceError('cursor already closed') + if self.is_closed: + raise InterfaceError("cursor is already closed") def close(self): self.__clear() - self._closed = True + self._is_closed = True def executemany(self, operation, seq_of_params): if not self._connection: - raise ProgrammingError('Cursor is not connected to the database') + raise ProgrammingError("Cursor is not connected to the database") for params in seq_of_params: self.execute(operation, params) def __next__(self): if self._itr is None: - raise ProgrammingError('no results to return') + raise ProgrammingError("no results to return") return next(self._itr) def __iter__(self): if self._itr is None: - raise ProgrammingError('no results to return') + raise ProgrammingError("no results to return") return self._itr def fetchone(self): @@ -289,10 +309,10 @@ def lastrowid(self): return None def setinputsizes(sizes): - raise ProgrammingError('Unimplemented') + raise ProgrammingError("Unimplemented") def setoutputsize(size, column=None): - raise ProgrammingError('Unimplemented') + raise ProgrammingError("Unimplemented") def _run_prior_DDL_statements(self): return self._connection.run_prior_DDL_statements() @@ -308,8 +328,16 @@ def get_table_column_schema(self, table_name): class Column: - def __init__(self, name, type_code, display_size=None, internal_size=None, - precision=None, scale=None, null_ok=False): + def __init__( + self, + name, + type_code, + display_size=None, + internal_size=None, + precision=None, + scale=None, + null_ok=False, + ): self.name = name self.type_code = type_code self.display_size = display_size @@ -338,14 +366,24 @@ def __getitem__(self, index): return self.null_ok def __str__(self): - rstr = ', '.join([field for field in [ - "name='%s'" % self.name, - "type_code=%d" % self.type_code, - None if not self.display_size else "display_size=%d" % self.display_size, - None if not self.internal_size else "internal_size=%d" % self.internal_size, - None if not 
self.precision else "precision='%s'" % self.precision, - None if not self.scale else "scale='%s'" % self.scale, - None if not self.null_ok else "null_ok='%s'" % self.null_ok, - ] if field]) - - return 'Column(%s)' % rstr + rstr = ", ".join( + [ + field + for field in [ + "name='%s'" % self.name, + "type_code=%d" % self.type_code, + None + if not self.display_size + else "display_size=%d" % self.display_size, + None + if not self.internal_size + else "internal_size=%d" % self.internal_size, + None if not self.precision else "precision='%s'" % self.precision, + None if not self.scale else "scale='%s'" % self.scale, + None if not self.null_ok else "null_ok='%s'" % self.null_ok, + ] + if field + ] + ) + + return "Column(%s)" % rstr diff --git a/tests/spanner_dbapi/test_cursor.py b/tests/spanner_dbapi/test_cursor.py new file mode 100644 index 0000000000..99014ea254 --- /dev/null +++ b/tests/spanner_dbapi/test_cursor.py @@ -0,0 +1,54 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +"""Cursor() class unit tests.""" + +import unittest +from unittest import mock + + +class TestCursor(unittest.TestCase): + def test_close(self): + from spanner_dbapi import connect + from spanner_dbapi.exceptions import InterfaceError + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", return_value=True + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + cursor.close() + + self.assertTrue(cursor.is_closed) + with self.assertRaises(InterfaceError): + cursor.execute("SELECT * FROM database") + + def test_connection_closed(self): + from spanner_dbapi import connect + from spanner_dbapi.exceptions import InterfaceError + + with mock.patch( + 
"google.cloud.spanner_v1.instance.Instance.exists", return_value=True + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", return_value=True + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + connection.close() + + self.assertTrue(cursor.is_closed) + with self.assertRaises(InterfaceError): + cursor.execute("SELECT * FROM database") + + +if __name__ == "__main__": + unittest.main() From 0f53b1d855f56a1fee01e82f619a70ead6754f8f Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Mon, 24 Aug 2020 14:45:54 +0300 Subject: [PATCH 05/12] feat: avoid adding a dummy WHERE clause into UPDATE/DELETE queries --- spanner_dbapi/cursor.py | 22 +- spanner_dbapi/parse_utils.py | 336 +++++++++-------- tests/spanner_dbapi/test_parse_utils.py | 473 +++++++++++++----------- 3 files changed, 441 insertions(+), 390 deletions(-) diff --git a/spanner_dbapi/cursor.py b/spanner_dbapi/cursor.py index ebf0cb66f0..a43dabb096 100644 --- a/spanner_dbapi/cursor.py +++ b/spanner_dbapi/cursor.py @@ -63,13 +63,11 @@ def execute(self, sql, args=None): sql: A SQL statement *args: variadic argument list **kwargs: key worded arguments - Returns: - None """ self._raise_if_already_closed() if not self._connection: - raise ProgrammingError("Cursor is not connected to the database") + raise ProgrammingError("Cursor is not connected to a database") self._res = None @@ -82,14 +80,14 @@ def execute(self, sql, args=None): # For every other operation, we've got to ensure that # any prior DDL statements were run. 
- self._run_prior_DDL_statements() + self._run_prior_ddl_statements() if classification == STMT_NON_UPDATING: - self.__handle_DQL(sql, args or None) + self._handle_dql(sql, args) elif classification == STMT_INSERT: - self.__handle_insert(sql, args or None) + self._handle_insert(sql, args) else: - self.__handle_update(sql, args or None) + self._handle_update(sql, args) except (grpc_exceptions.AlreadyExists, grpc_exceptions.FailedPrecondition) as e: raise IntegrityError(e.details if hasattr(e, "details") else e) except grpc_exceptions.InvalidArgument as e: @@ -97,11 +95,11 @@ def execute(self, sql, args=None): except grpc_exceptions.InternalServerError as e: raise OperationalError(e.details if hasattr(e, "details") else e) - def __handle_update(self, sql, params): + def _handle_update(self, sql, params): self._connection.in_transaction(self.__do_execute_update, sql, params) def __do_execute_update(self, transaction, sql, params, param_types=None): - sql = ensure_where_clause(sql) + ensure_where_clause(sql) sql, params = sql_pyformat_args_to_spanner(sql, params) res = transaction.execute_update( @@ -113,7 +111,7 @@ def __do_execute_update(self, transaction, sql, params, param_types=None): return res - def __handle_insert(self, sql, params): + def _handle_insert(self, sql, params): parts = parse_insert(sql, params) # The split between the two styles exists because: @@ -159,7 +157,7 @@ def __do_execute_insert_homogenous(self, transaction, parts): values = parts.get("values") return transaction.insert(table, columns, values) - def __handle_DQL(self, sql, params): + def _handle_dql(self, sql, params): with self._connection.read_snapshot() as snapshot: # Reference # https://googleapis.dev/python/spanner/latest/session-api.html#google.cloud.spanner_v1.session.Session.execute_sql @@ -314,7 +312,7 @@ def setinputsizes(sizes): def setoutputsize(size, column=None): raise ProgrammingError("Unimplemented") - def _run_prior_DDL_statements(self): + def 
_run_prior_ddl_statements(self): return self._connection.run_prior_DDL_statements() def list_tables(self): diff --git a/spanner_dbapi/parse_utils.py b/spanner_dbapi/parse_utils.py index c6299a3c87..e01678758a 100644 --- a/spanner_dbapi/parse_utils.py +++ b/spanner_dbapi/parse_utils.py @@ -17,20 +17,20 @@ from .types import DateStr, TimestampStr from .utils import sanitize_literals_for_upload -STMT_DDL = 'DDL' -STMT_NON_UPDATING = 'NON_UPDATING' -STMT_UPDATING = 'UPDATING' -STMT_INSERT = 'INSERT' +STMT_DDL = "DDL" +STMT_NON_UPDATING = "NON_UPDATING" +STMT_UPDATING = "UPDATING" +STMT_INSERT = "INSERT" # Heuristic for identifying statements that don't need to be run as updates. -re_NON_UPDATE = re.compile(r'^\s*(SELECT)', re.IGNORECASE) +re_NON_UPDATE = re.compile(r"^\s*(SELECT)", re.IGNORECASE) -re_WITH = re.compile(r'^\s*(WITH)', re.IGNORECASE) +re_WITH = re.compile(r"^\s*(WITH)", re.IGNORECASE) # DDL statements follow https://cloud.google.com/spanner/docs/data-definition-language -re_DDL = re.compile(r'^\s*(CREATE|ALTER|DROP)', re.IGNORECASE | re.DOTALL) +re_DDL = re.compile(r"^\s*(CREATE|ALTER|DROP)", re.IGNORECASE | re.DOTALL) -re_IS_INSERT = re.compile(r'^\s*(INSERT)', re.IGNORECASE | re.DOTALL) +re_IS_INSERT = re.compile(r"^\s*(INSERT)", re.IGNORECASE | re.DOTALL) def classify_stmt(sql): @@ -53,18 +53,15 @@ def classify_stmt(sql): # Only match the `INSERT INTO (columns...) # otherwise the rest of the statement could be a complex # operation. 
- r'^\s*INSERT INTO (?P[^\s\(\)]+)\s*\((?P[^\(\)]+)\)', + r"^\s*INSERT INTO (?P[^\s\(\)]+)\s*\((?P[^\(\)]+)\)", re.IGNORECASE | re.DOTALL, ) -re_VALUES_TILL_END = re.compile( - r'VALUES\s*\(.+$', - re.IGNORECASE | re.DOTALL, -) +re_VALUES_TILL_END = re.compile(r"VALUES\s*\(.+$", re.IGNORECASE | re.DOTALL) re_VALUES_PYFORMAT = re.compile( # To match: (%s, %s,....%s) - r'(\(\s*%s[^\(\)]+\))', + r"(\(\s*%s[^\(\)]+\))", re.DOTALL, ) @@ -74,7 +71,7 @@ def strip_backticks(name): Strip backticks off of quoted names For example, '`no`' (a Spanner reserved word) becomes 'no'. """ - has_quotes = name.startswith('`') and name.endswith('`') + has_quotes = name.startswith("`") and name.endswith("`") return name[1:-1] if has_quotes else name @@ -139,30 +136,30 @@ def parse_insert(insert_sql, params): match = re_INSERT.search(insert_sql) if not match: - raise ProgrammingError('Could not parse an INSERT statement from %s' % insert_sql) + raise ProgrammingError( + "Could not parse an INSERT statement from %s" % insert_sql + ) after_values_sql = re_VALUES_TILL_END.findall(insert_sql) if not after_values_sql: # Case b) insert_sql = sanitize_literals_for_upload(insert_sql) - return { - 'sql_params_list': [(insert_sql, None,)], - } + return {"sql_params_list": [(insert_sql, None)]} if not params: # Case a) perhaps? # Check if any %s exists. 
- pyformat_str_count = after_values_sql.count('%s') + pyformat_str_count = after_values_sql.count("%s") if pyformat_str_count > 0: - raise ProgrammingError('no params yet there are %d "%s" tokens' % pyformat_str_count) + raise ProgrammingError( + 'no params yet there are %d "%s" tokens' % pyformat_str_count + ) insert_sql = sanitize_literals_for_upload(insert_sql) # Confirmed case of: # SQL: INSERT INTO T (a1, a2) VALUES (1, 2) # Params: None - return { - 'sql_params_list': [(insert_sql, None,)], - } + return {"sql_params_list": [(insert_sql, None)]} values_str = after_values_sql[0] _, values = parse_values(values_str) @@ -171,22 +168,21 @@ def parse_insert(insert_sql, params): # Case c) columns = [ - strip_backticks(mi.strip()) - for mi in match.group('columns').split(',') + strip_backticks(mi.strip()) for mi in match.group("columns").split(",") ] sql_params_list = [] - insert_sql_preamble = 'INSERT INTO %s (%s) VALUES %s' % ( - match.group('table_name'), match.group('columns'), values.argv[0], + insert_sql_preamble = "INSERT INTO %s (%s) VALUES %s" % ( + match.group("table_name"), + match.group("columns"), + values.argv[0], ) values_pyformat = [str(arg) for arg in values.argv] rows_list = rows_for_insert_or_update(columns, params, values_pyformat) insert_sql_preamble = sanitize_literals_for_upload(insert_sql_preamble) for row in rows_list: - sql_params_list.append((insert_sql_preamble, row,)) + sql_params_list.append((insert_sql_preamble, row)) - return { - 'sql_params_list': sql_params_list, - } + return {"sql_params_list": sql_params_list} # Case d) # insert_sql is of the form: @@ -194,10 +190,11 @@ def parse_insert(insert_sql, params): # Sanity check: # length(all_args) == len(params) - args_len = reduce(lambda a, b: a+b, [len(arg) for arg in values.argv]) + args_len = reduce(lambda a, b: a + b, [len(arg) for arg in values.argv]) if args_len != len(params): - raise ProgrammingError('Invalid length: VALUES(...) 
len: %d != len(params): %d' % ( - args_len, len(params)), + raise ProgrammingError( + "Invalid length: VALUES(...) len: %d != len(params): %d" + % (args_len, len(params)) ) trim_index = insert_sql.find(values_str) @@ -205,14 +202,12 @@ def parse_insert(insert_sql, params): sql_param_tuples = [] for token_arg in values.argv: - row_sql = before_values_sql + ' VALUES%s' % token_arg + row_sql = before_values_sql + " VALUES%s" % token_arg row_sql = sanitize_literals_for_upload(row_sql) - row_params, params = tuple(params[0:len(token_arg)]), params[len(token_arg):] - sql_param_tuples.append((row_sql, row_params,)) + row_params, params = tuple(params[0 : len(token_arg)]), params[len(token_arg) :] + sql_param_tuples.append((row_sql, row_params)) - return { - 'sql_params_list': sql_param_tuples, - } + return {"sql_params_list": sql_param_tuples} def rows_for_insert_or_update(columns, params, pyformat_args=None): @@ -247,8 +242,10 @@ def rows_for_insert_or_update(columns, params, pyformat_args=None): columns_len = len(columns) for param in params: if columns_len != len(param): - raise Error('\nlen(`%s`)=%d\n!=\ncolum_len(`%s`)=%d' % ( - param, len(param), columns, columns_len)) + raise Error( + "\nlen(`%s`)=%d\n!=\ncolum_len(`%s`)=%d" + % (param, len(param), columns, columns_len) + ) return params else: # The case with Params B: [1, 2, 3] @@ -267,12 +264,14 @@ def rows_for_insert_or_update(columns, params, pyformat_args=None): # Sanity check 1: all the pyformat_values should have the exact same length. first, rest = pyformat_args[0], pyformat_args[1:] - n_stride = first.count('%s') + n_stride = first.count("%s") for pyfmt_value in rest: - n = pyfmt_value.count('%s') + n = pyfmt_value.count("%s") if n_stride != n: - raise Error('\nlen(`%s`)=%d\n!=\nlen(`%s`)=%d' % ( - first, n_stride, pyfmt_value, n)) + raise Error( + "\nlen(`%s`)=%d\n!=\nlen(`%s`)=%d" + % (first, n_stride, pyfmt_value, n) + ) # Sanity check 2: len(params) MUST be a multiple of n_stride aka len(count of %s). 
# so that we can properly group for example: @@ -283,20 +282,21 @@ def rows_for_insert_or_update(columns, params, pyformat_args=None): # into # [(1, 2, 3), (4, 5, 6), (7, 8, 9)] if (len(params) % n_stride) != 0: - raise ProgrammingError('Invalid length: len(params)=%d MUST be a multiple of len(pyformat_args)=%d' % ( - len(params), n_stride), + raise ProgrammingError( + "Invalid length: len(params)=%d MUST be a multiple of len(pyformat_args)=%d" + % (len(params), n_stride) ) # Now chop up the strides. strides = [] for step in range(0, len(params), n_stride): - stride = tuple(params[step:step+n_stride:]) + stride = tuple(params[step : step + n_stride :]) strides.append(stride) return strides -re_PYFORMAT = re.compile(r'(%s|%\([^\(\)]+\)s)+', re.DOTALL) +re_PYFORMAT = re.compile(r"(%s|%\([^\(\)]+\)s)+", re.DOTALL) def sql_pyformat_args_to_spanner(sql, params): @@ -332,8 +332,10 @@ def sql_pyformat_args_to_spanner(sql, params): n_matches = len(found_pyformat_placeholders) if n_matches != n_params: raise Error( - 'pyformat_args mismatch\ngot %d args from %s\n' - 'want %d args in %s' % (n_matches, found_pyformat_placeholders, n_params, params)) + "pyformat_args mismatch\ngot %d args from %s\n" + "want %d args in %s" + % (n_matches, found_pyformat_placeholders, n_params, params) + ) if len(params) == 0: return sanitize_literals_for_upload(sql), params @@ -345,8 +347,8 @@ def sql_pyformat_args_to_spanner(sql, params): # Params: ('a', 23, '888***') # Case b) Params is a dict and the matches are %(value)s' for i, pyfmt in enumerate(found_pyformat_placeholders): - key = 'a%d' % i - sql = sql.replace(pyfmt, '@'+key, 1) + key = "a%d" % i + sql = sql.replace(pyfmt, "@" + key, 1) if params_is_dict: # The '%(key)s' case, so interpolate it. 
resolved_value = pyfmt % params @@ -379,9 +381,9 @@ def get_param_types(params): param_types[key] = spanner.param_types.FLOAT64 elif isinstance(value, int): param_types[key] = spanner.param_types.INT64 - elif isinstance(value, (TimestampStr, datetime.datetime,)): + elif isinstance(value, (TimestampStr, datetime.datetime)): param_types[key] = spanner.param_types.TIMESTAMP - elif isinstance(value, (DateStr, datetime.date,)): + elif isinstance(value, (DateStr, datetime.date)): param_types[key] = spanner.param_types.DATE elif isinstance(value, str): param_types[key] = spanner.param_types.STRING @@ -391,112 +393,124 @@ def get_param_types(params): def ensure_where_clause(sql): + """Check if the given sql query includes WHERE clause. + + Cloud Spanner requires a WHERE clause with UPDATE + and DELETE statements to avoid accidental delete + or update of all the table rows. + + :type sql: :class:`str` + :param sql: SQL query to check for WHERE clause presence. + + :raises: :class:`ProgrammingError` if there is + no WHERE clause in the given query. """ - Cloud Spanner requires a WHERE clause on UPDATE and DELETE statements. - Add a dummy WHERE clause if necessary. 
- """ - if any(isinstance(token, sqlparse.sql.Where) for token in sqlparse.parse(sql)[0]): - return sql - return sql + ' WHERE 1=1' + for token in sqlparse.parse(sql)[0]: + if isinstance(token, sqlparse.sql.Where): + return + + raise ProgrammingError( + "Cloud Spanner requires a WHERE clause when executing DELETE or UPDATE query" + ) SPANNER_RESERVED_KEYWORDS = { - 'ALL', - 'AND', - 'ANY', - 'ARRAY', - 'AS', - 'ASC', - 'ASSERT_ROWS_MODIFIED', - 'AT', - 'BETWEEN', - 'BY', - 'CASE', - 'CAST', - 'COLLATE', - 'CONTAINS', - 'CREATE', - 'CROSS', - 'CUBE', - 'CURRENT', - 'DEFAULT', - 'DEFINE', - 'DESC', - 'DISTINCT', - 'DROP', - 'ELSE', - 'END', - 'ENUM', - 'ESCAPE', - 'EXCEPT', - 'EXCLUDE', - 'EXISTS', - 'EXTRACT', - 'FALSE', - 'FETCH', - 'FOLLOWING', - 'FOR', - 'FROM', - 'FULL', - 'GROUP', - 'GROUPING', - 'GROUPS', - 'HASH', - 'HAVING', - 'IF', - 'IGNORE', - 'IN', - 'INNER', - 'INTERSECT', - 'INTERVAL', - 'INTO', - 'IS', - 'JOIN', - 'LATERAL', - 'LEFT', - 'LIKE', - 'LIMIT', - 'LOOKUP', - 'MERGE', - 'NATURAL', - 'NEW', - 'NO', - 'NOT', - 'NULL', - 'NULLS', - 'OF', - 'ON', - 'OR', - 'ORDER', - 'OUTER', - 'OVER', - 'PARTITION', - 'PRECEDING', - 'PROTO', - 'RANGE', - 'RECURSIVE', - 'RESPECT', - 'RIGHT', - 'ROLLUP', - 'ROWS', - 'SELECT', - 'SET', - 'SOME', - 'STRUCT', - 'TABLESAMPLE', - 'THEN', - 'TO', - 'TREAT', - 'TRUE', - 'UNBOUNDED', - 'UNION', - 'UNNEST', - 'USING', - 'WHEN', - 'WHERE', - 'WINDOW', - 'WITH', - 'WITHIN', + "ALL", + "AND", + "ANY", + "ARRAY", + "AS", + "ASC", + "ASSERT_ROWS_MODIFIED", + "AT", + "BETWEEN", + "BY", + "CASE", + "CAST", + "COLLATE", + "CONTAINS", + "CREATE", + "CROSS", + "CUBE", + "CURRENT", + "DEFAULT", + "DEFINE", + "DESC", + "DISTINCT", + "DROP", + "ELSE", + "END", + "ENUM", + "ESCAPE", + "EXCEPT", + "EXCLUDE", + "EXISTS", + "EXTRACT", + "FALSE", + "FETCH", + "FOLLOWING", + "FOR", + "FROM", + "FULL", + "GROUP", + "GROUPING", + "GROUPS", + "HASH", + "HAVING", + "IF", + "IGNORE", + "IN", + "INNER", + "INTERSECT", + "INTERVAL", + "INTO", + 
"IS", + "JOIN", + "LATERAL", + "LEFT", + "LIKE", + "LIMIT", + "LOOKUP", + "MERGE", + "NATURAL", + "NEW", + "NO", + "NOT", + "NULL", + "NULLS", + "OF", + "ON", + "OR", + "ORDER", + "OUTER", + "OVER", + "PARTITION", + "PRECEDING", + "PROTO", + "RANGE", + "RECURSIVE", + "RESPECT", + "RIGHT", + "ROLLUP", + "ROWS", + "SELECT", + "SET", + "SOME", + "STRUCT", + "TABLESAMPLE", + "THEN", + "TO", + "TREAT", + "TRUE", + "UNBOUNDED", + "UNION", + "UNNEST", + "USING", + "WHEN", + "WHERE", + "WINDOW", + "WITH", + "WITHIN", } @@ -505,6 +519,6 @@ def escape_name(name): Escape name by applying backticks to value that either contain '-' or are any of Cloud Spanner's reserved keywords. """ - if '-' in name or ' ' in name or name.upper() in SPANNER_RESERVED_KEYWORDS: - return '`' + name + '`' + if "-" in name or " " in name or name.upper() in SPANNER_RESERVED_KEYWORDS: + return "`" + name + "`" return name diff --git a/tests/spanner_dbapi/test_parse_utils.py b/tests/spanner_dbapi/test_parse_utils.py index f0da0145c7..8445174210 100644 --- a/tests/spanner_dbapi/test_parse_utils.py +++ b/tests/spanner_dbapi/test_parse_utils.py @@ -11,9 +11,18 @@ from google.cloud.spanner_v1 import param_types from spanner_dbapi.exceptions import Error, ProgrammingError from spanner_dbapi.parse_utils import ( - STMT_DDL, STMT_NON_UPDATING, DateStr, TimestampStr, classify_stmt, - ensure_where_clause, escape_name, get_param_types, parse_insert, - rows_for_insert_or_update, sql_pyformat_args_to_spanner, strip_backticks, + STMT_DDL, + STMT_NON_UPDATING, + DateStr, + TimestampStr, + classify_stmt, + ensure_where_clause, + escape_name, + get_param_types, + parse_insert, + rows_for_insert_or_update, + sql_pyformat_args_to_spanner, + strip_backticks, ) from spanner_dbapi.utils import backtick_unicode @@ -21,139 +30,126 @@ class ParseUtilsTests(TestCase): def test_classify_stmt(self): cases = [ - ('SELECT 1', STMT_NON_UPDATING,), - ('SELECT s.SongName FROM Songs AS s', STMT_NON_UPDATING,), - ('WITH sq AS (SELECT 
SchoolID FROM Roster) SELECT * from sq', STMT_NON_UPDATING,), - ( - 'CREATE TABLE django_content_type (id STRING(64) NOT NULL, name STRING(100) ' - 'NOT NULL, app_label STRING(100) NOT NULL, model STRING(100) NOT NULL) PRIMARY KEY(id)', - STMT_DDL, - ), - ( - 'CREATE INDEX SongsBySingerAlbumSongNameDesc ON ' - 'Songs(SingerId, AlbumId, SongName DESC), INTERLEAVE IN Albums', - STMT_DDL, - ), - ('CREATE INDEX SongsBySongName ON Songs(SongName)', STMT_DDL,), - ('CREATE INDEX AlbumsByAlbumTitle2 ON Albums(AlbumTitle) STORING (MarketingBudget)', STMT_DDL,), + ("SELECT 1", STMT_NON_UPDATING), + ("SELECT s.SongName FROM Songs AS s", STMT_NON_UPDATING), + ( + "WITH sq AS (SELECT SchoolID FROM Roster) SELECT * from sq", + STMT_NON_UPDATING, + ), + ( + "CREATE TABLE django_content_type (id STRING(64) NOT NULL, name STRING(100) " + "NOT NULL, app_label STRING(100) NOT NULL, model STRING(100) NOT NULL) PRIMARY KEY(id)", + STMT_DDL, + ), + ( + "CREATE INDEX SongsBySingerAlbumSongNameDesc ON " + "Songs(SingerId, AlbumId, SongName DESC), INTERLEAVE IN Albums", + STMT_DDL, + ), + ("CREATE INDEX SongsBySongName ON Songs(SongName)", STMT_DDL), + ( + "CREATE INDEX AlbumsByAlbumTitle2 ON Albums(AlbumTitle) STORING (MarketingBudget)", + STMT_DDL, + ), ] for tt in cases: sql, want_classification = tt got_classification = classify_stmt(sql) - self.assertEqual(got_classification, want_classification, 'Classification mismatch') + self.assertEqual( + got_classification, want_classification, "Classification mismatch" + ) def test_parse_insert(self): cases = [ ( - 'INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)', + "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)", [1, 2, 3, 4, 5, 6], { - 'sql_params_list': [ + "sql_params_list": [ ( - 'INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)', - (1, 2, 3,), + "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)", + (1, 2, 3), ), ( - 'INSERT INTO django_migrations 
(app, name, applied) VALUES (%s, %s, %s)', - (4, 5, 6,), + "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)", + (4, 5, 6), ), - ], + ] }, ), ( - 'INSERT INTO django_migrations(app, name, applied) VALUES (%s, %s, %s)', + "INSERT INTO django_migrations(app, name, applied) VALUES (%s, %s, %s)", [1, 2, 3, 4, 5, 6], { - 'sql_params_list': [ + "sql_params_list": [ ( - 'INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)', - (1, 2, 3,), + "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)", + (1, 2, 3), ), ( - 'INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)', - (4, 5, 6,), + "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)", + (4, 5, 6), ), - ], + ] }, ), ( - 'INSERT INTO sales.addresses (street, city, state, zip_code) ' - 'SELECT street, city, state, zip_code FROM sales.customers' - 'ORDER BY first_name, last_name', + "INSERT INTO sales.addresses (street, city, state, zip_code) " + "SELECT street, city, state, zip_code FROM sales.customers" + "ORDER BY first_name, last_name", None, { - 'sql_params_list': [( - 'INSERT INTO sales.addresses (street, city, state, zip_code) ' - 'SELECT street, city, state, zip_code FROM sales.customers' - 'ORDER BY first_name, last_name', - None, - )], - } + "sql_params_list": [ + ( + "INSERT INTO sales.addresses (street, city, state, zip_code) " + "SELECT street, city, state, zip_code FROM sales.customers" + "ORDER BY first_name, last_name", + None, + ) + ] + }, ), ( - - 'INSERT INTO ap (n, ct, cn) ' - 'VALUES (%s, %s, %s), (%s, %s, %s), (%s, %s, %s),(%s, %s, %s)', + "INSERT INTO ap (n, ct, cn) " + "VALUES (%s, %s, %s), (%s, %s, %s), (%s, %s, %s),(%s, %s, %s)", (1, 2, 3, 4, 5, 6, 7, 8, 9), { - 'sql_params_list': [ - ( - 'INSERT INTO ap (n, ct, cn) VALUES (%s, %s, %s)', - (1, 2, 3,), - ), - ( - 'INSERT INTO ap (n, ct, cn) VALUES (%s, %s, %s)', - (4, 5, 6,), - ), - ( - 'INSERT INTO ap (n, ct, cn) VALUES (%s, %s, %s)', - (7, 8, 9,), - 
), - ], + "sql_params_list": [ + ("INSERT INTO ap (n, ct, cn) VALUES (%s, %s, %s)", (1, 2, 3)), + ("INSERT INTO ap (n, ct, cn) VALUES (%s, %s, %s)", (4, 5, 6)), + ("INSERT INTO ap (n, ct, cn) VALUES (%s, %s, %s)", (7, 8, 9)), + ] }, ), ( - 'INSERT INTO `no` (`yes`) VALUES (%s)', + "INSERT INTO `no` (`yes`) VALUES (%s)", (1, 4, 5), { - 'sql_params_list': [ - ( - 'INSERT INTO `no` (`yes`) VALUES (%s)', - (1,), - ), - ( - 'INSERT INTO `no` (`yes`) VALUES (%s)', - (4,), - ), - ( - 'INSERT INTO `no` (`yes`) VALUES (%s)', - (5,), - ), - ], + "sql_params_list": [ + ("INSERT INTO `no` (`yes`) VALUES (%s)", (1,)), + ("INSERT INTO `no` (`yes`) VALUES (%s)", (4,)), + ("INSERT INTO `no` (`yes`) VALUES (%s)", (5,)), + ] }, ), ( - 'INSERT INTO T (f1, f2) VALUES (1, 2)', + "INSERT INTO T (f1, f2) VALUES (1, 2)", None, - { - 'sql_params_list': [ - ( - 'INSERT INTO T (f1, f2) VALUES (1, 2)', - None, - ), - ], - }, + {"sql_params_list": [("INSERT INTO T (f1, f2) VALUES (1, 2)", None)]}, ), ( - 'INSERT INTO `no` (`yes`, tiff) VALUES (%s, LOWER(%s)), (%s, %s), (%s, %s)', - (1, 'FOO', 5, 10, 11, 29), + "INSERT INTO `no` (`yes`, tiff) VALUES (%s, LOWER(%s)), (%s, %s), (%s, %s)", + (1, "FOO", 5, 10, 11, 29), { - 'sql_params_list': [ - ('INSERT INTO `no` (`yes`, tiff) VALUES(%s, LOWER(%s))', (1, 'FOO',)), - ('INSERT INTO `no` (`yes`, tiff) VALUES(%s, %s)', (5, 10)), - ('INSERT INTO `no` (`yes`, tiff) VALUES(%s, %s)', (11, 29)), - ], + "sql_params_list": [ + ( + "INSERT INTO `no` (`yes`, tiff) VALUES(%s, LOWER(%s))", + (1, "FOO"), + ), + ("INSERT INTO `no` (`yes`, tiff) VALUES(%s, %s)", (5, 10)), + ("INSERT INTO `no` (`yes`, tiff) VALUES(%s, %s)", (11, 29)), + ] }, ), ] @@ -161,71 +157,68 @@ def test_parse_insert(self): for sql, params, want in cases: with self.subTest(sql=sql): got = parse_insert(sql, params) - self.assertEqual(got, want, 'Mismatch with parse_insert of `%s`' % sql) + self.assertEqual(got, want, "Mismatch with parse_insert of `%s`" % sql) def 
test_parse_insert_invalid(self): cases = [ ( - 'INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s), (%s, %s, %s)', + "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s), (%s, %s, %s)", [1, 2, 3, 4, 5, 6, 7], - 'len\\(params\\)=7 MUST be a multiple of len\\(pyformat_args\\)=3', + "len\\(params\\)=7 MUST be a multiple of len\\(pyformat_args\\)=3", ), ( - 'INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s), (%s, %s, LOWER(%s))', + "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s), (%s, %s, LOWER(%s))", [1, 2, 3, 4, 5, 6, 7], - 'Invalid length: VALUES\\(...\\) len: 6 != len\\(params\\): 7', + "Invalid length: VALUES\\(...\\) len: 6 != len\\(params\\): 7", ), ( - 'INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s), (%s, %s, LOWER(%s)))', + "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s), (%s, %s, LOWER(%s)))", [1, 2, 3, 4, 5, 6], - 'VALUES: expected `,` got \\) in \\)', + "VALUES: expected `,` got \\) in \\)", ), ] for sql, params, wantException in cases: with self.subTest(sql=sql): self.assertRaisesRegex( - ProgrammingError, - wantException, - lambda: parse_insert(sql, params), + ProgrammingError, wantException, lambda: parse_insert(sql, params) ) def test_rows_for_insert_or_update(self): cases = [ ( - ['id', 'app', 'name'], - [(5, 'ap', 'n',), (6, 'bp', 'm',)], + ["id", "app", "name"], + [(5, "ap", "n"), (6, "bp", "m")], None, - [(5, 'ap', 'n',), (6, 'bp', 'm',)], + [(5, "ap", "n"), (6, "bp", "m")], ), ( - ['app', 'name'], - [('ap', 'n',), ('bp', 'm',)], + ["app", "name"], + [("ap", "n"), ("bp", "m")], None, - [('ap', 'n'), ('bp', 'm',)] + [("ap", "n"), ("bp", "m")], ), ( - ['app', 'name', 'fn'], - ['ap', 'n', 'f1', 'bp', 'm', 'f2', 'cp', 'o', 'f3'], - ['(%s, %s, %s)', '(%s, %s, %s)', '(%s, %s, %s)'], - [('ap', 'n', 'f1',), ('bp', 'm', 'f2',), ('cp', 'o', 'f3',)] + ["app", "name", "fn"], + ["ap", "n", "f1", "bp", "m", "f2", "cp", "o", "f3"], + 
["(%s, %s, %s)", "(%s, %s, %s)", "(%s, %s, %s)"], + [("ap", "n", "f1"), ("bp", "m", "f2"), ("cp", "o", "f3")], ), ( - ['app', 'name', 'fn', 'ln'], - [('ap', 'n', (45, 'nested',), 'll',), ('bp', 'm', 'f2', 'mt',), ('fp', 'cp', 'o', 'f3',)], - None, + ["app", "name", "fn", "ln"], [ - ('ap', 'n', (45, 'nested',), 'll',), - ('bp', 'm', 'f2', 'mt',), - ('fp', 'cp', 'o', 'f3',), + ("ap", "n", (45, "nested"), "ll"), + ("bp", "m", "f2", "mt"), + ("fp", "cp", "o", "f3"), ], - ), - ( - ['app', 'name', 'fn'], - ['ap', 'n', 'f1'], None, - [('ap', 'n', 'f1',)] + [ + ("ap", "n", (45, "nested"), "ll"), + ("bp", "m", "f2", "mt"), + ("fp", "cp", "o", "f3"), + ], ), + (["app", "name", "fn"], ["ap", "n", "f1"], None, [("ap", "n", "f1")]), ] for i, (columns, params, pyformat_args, want) in enumerate(cases): @@ -236,118 +229,181 @@ def test_rows_for_insert_or_update(self): def test_sql_pyformat_args_to_spanner(self): cases = [ ( - ('SELECT * from t WHERE f1=%s, f2 = %s, f3=%s', (10, 'abc', 'y**$22l3f',)), - ('SELECT * from t WHERE f1=@a0, f2 = @a1, f3=@a2', {'a0': 10, 'a1': 'abc', 'a2': 'y**$22l3f'}), + ( + "SELECT * from t WHERE f1=%s, f2 = %s, f3=%s", + (10, "abc", "y**$22l3f"), + ), + ( + "SELECT * from t WHERE f1=@a0, f2 = @a1, f3=@a2", + {"a0": 10, "a1": "abc", "a2": "y**$22l3f"}, + ), ), ( - ('INSERT INTO t (f1, f2, f2) VALUES (%s, %s, %s)', ('app', 'name', 'applied',)), - ('INSERT INTO t (f1, f2, f2) VALUES (@a0, @a1, @a2)', {'a0': 'app', 'a1': 'name', 'a2': 'applied'}), + ( + "INSERT INTO t (f1, f2, f2) VALUES (%s, %s, %s)", + ("app", "name", "applied"), + ), + ( + "INSERT INTO t (f1, f2, f2) VALUES (@a0, @a1, @a2)", + {"a0": "app", "a1": "name", "a2": "applied"}, + ), ), ( ( - 'INSERT INTO t (f1, f2, f2) VALUES (%(f1)s, %(f2)s, %(f3)s)', - {'f1': 'app', 'f2': 'name', 'f3': 'applied'}, + "INSERT INTO t (f1, f2, f2) VALUES (%(f1)s, %(f2)s, %(f3)s)", + {"f1": "app", "f2": "name", "f3": "applied"}, ), ( - 'INSERT INTO t (f1, f2, f2) VALUES (@a0, @a1, @a2)', - {'a0': 'app', 'a1': 
'name', 'a2': 'applied'}, + "INSERT INTO t (f1, f2, f2) VALUES (@a0, @a1, @a2)", + {"a0": "app", "a1": "name", "a2": "applied"}, ), ), ( # Intentionally using a dict with more keys than will be resolved. - ('SELECT * from t WHERE f1=%(f1)s', {'f1': 'app', 'f2': 'name'}), - ('SELECT * from t WHERE f1=@a0', {'a0': 'app'}), + ("SELECT * from t WHERE f1=%(f1)s", {"f1": "app", "f2": "name"}), + ("SELECT * from t WHERE f1=@a0", {"a0": "app"}), ), ( # No args to replace, we MUST return the original params dict # since it might be useful to pass to the next user. - ('SELECT * from t WHERE id=10', {'f1': 'app', 'f2': 'name'}), - ('SELECT * from t WHERE id=10', {'f1': 'app', 'f2': 'name'}), + ("SELECT * from t WHERE id=10", {"f1": "app", "f2": "name"}), + ("SELECT * from t WHERE id=10", {"f1": "app", "f2": "name"}), ), ( - ('SELECT (an.p + %s) AS np FROM an WHERE (an.p + %s) = %s', (1, 1.0, decimal.Decimal('31'),)), - ('SELECT (an.p + @a0) AS np FROM an WHERE (an.p + @a1) = @a2', {'a0': 1, 'a1': 1.0, 'a2': 31.0}), + ( + "SELECT (an.p + %s) AS np FROM an WHERE (an.p + %s) = %s", + (1, 1.0, decimal.Decimal("31")), + ), + ( + "SELECT (an.p + @a0) AS np FROM an WHERE (an.p + @a1) = @a2", + {"a0": 1, "a1": 1.0, "a2": 31.0}, + ), ), ] for ((sql_in, params), sql_want) in cases: with self.subTest(sql=sql_in): got_sql, got_named_args = sql_pyformat_args_to_spanner(sql_in, params) want_sql, want_named_args = sql_want - self.assertEqual(got_sql, want_sql, 'SQL does not match') - self.assertEqual(got_named_args, want_named_args, 'Named args do not match') + self.assertEqual(got_sql, want_sql, "SQL does not match") + self.assertEqual( + got_named_args, want_named_args, "Named args do not match" + ) def test_sql_pyformat_args_to_spanner_invalid(self): cases = [ - ('SELECT * from t WHERE f1=%s, f2 = %s, f3=%s, extra=%s', (10, 'abc', 'y**$22l3f',)), + ( + "SELECT * from t WHERE f1=%s, f2 = %s, f3=%s, extra=%s", + (10, "abc", "y**$22l3f"), + ) ] for sql, params in cases: with 
self.subTest(sql=sql): - self.assertRaisesRegex(Error, 'pyformat_args mismatch', - lambda: sql_pyformat_args_to_spanner(sql, params), - ) + self.assertRaisesRegex( + Error, + "pyformat_args mismatch", + lambda: sql_pyformat_args_to_spanner(sql, params), + ) def test_get_param_types(self): cases = [ ( - {'a1': 10, 'b1': '2005-08-30T01:01:01.000001Z', 'c1': '2019-12-05', 'd1': 10.39}, { - 'a1': param_types.INT64, - 'b1': param_types.STRING, - 'c1': param_types.STRING, - 'd1': param_types.FLOAT64, + "a1": 10, + "b1": "2005-08-30T01:01:01.000001Z", + "c1": "2019-12-05", + "d1": 10.39, + }, + { + "a1": param_types.INT64, + "b1": param_types.STRING, + "c1": param_types.STRING, + "d1": param_types.FLOAT64, }, ), ( - {'a1': 10, 'b1': TimestampStr('2005-08-30T01:01:01.000001Z'), 'c1': '2019-12-05'}, - {'a1': param_types.INT64, 'b1': param_types.TIMESTAMP, 'c1': param_types.STRING}, + { + "a1": 10, + "b1": TimestampStr("2005-08-30T01:01:01.000001Z"), + "c1": "2019-12-05", + }, + { + "a1": param_types.INT64, + "b1": param_types.TIMESTAMP, + "c1": param_types.STRING, + }, ), ( - {'a1': 10, 'b1': '2005-08-30T01:01:01.000001Z', 'c1': DateStr('2019-12-05')}, - {'a1': param_types.INT64, 'b1': param_types.STRING, 'c1': param_types.DATE}, + { + "a1": 10, + "b1": "2005-08-30T01:01:01.000001Z", + "c1": DateStr("2019-12-05"), + }, + { + "a1": param_types.INT64, + "b1": param_types.STRING, + "c1": param_types.DATE, + }, ), ( - {'a1': 10, 'b1': '2005-08-30T01:01:01.000001Z'}, - {'a1': param_types.INT64, 'b1': param_types.STRING}, + {"a1": 10, "b1": "2005-08-30T01:01:01.000001Z"}, + {"a1": param_types.INT64, "b1": param_types.STRING}, ), ( - {'a1': 10, 'b1': TimestampStr('2005-08-30T01:01:01.000001Z'), 'c1': DateStr('2005-08-30')}, - {'a1': param_types.INT64, 'b1': param_types.TIMESTAMP, 'c1': param_types.DATE}, + { + "a1": 10, + "b1": TimestampStr("2005-08-30T01:01:01.000001Z"), + "c1": DateStr("2005-08-30"), + }, + { + "a1": param_types.INT64, + "b1": param_types.TIMESTAMP, + "c1": 
param_types.DATE, + }, ), ( - {'a1': 10, 'b1': 'aaaaa08-30T01:01:01.000001Z'}, - {'a1': param_types.INT64, 'b1': param_types.STRING}, + {"a1": 10, "b1": "aaaaa08-30T01:01:01.000001Z"}, + {"a1": param_types.INT64, "b1": param_types.STRING}, ), ( - {'a1': 10, 'b1': '2005-08-30T01:01:01.000001', 't1': True, 't2': False, 'f1': 99e9}, { - 'a1': param_types.INT64, - 'b1': param_types.STRING, - 't1': param_types.BOOL, - 't2': param_types.BOOL, - 'f1': param_types.FLOAT64, + "a1": 10, + "b1": "2005-08-30T01:01:01.000001", + "t1": True, + "t2": False, + "f1": 99e9, + }, + { + "a1": param_types.INT64, + "b1": param_types.STRING, + "t1": param_types.BOOL, + "t2": param_types.BOOL, + "f1": param_types.FLOAT64, }, ), ( - {'a1': 10, 'b1': '2019-11-26T02:55:41.000000Z'}, - {'a1': param_types.INT64, 'b1': param_types.STRING}, + {"a1": 10, "b1": "2019-11-26T02:55:41.000000Z"}, + {"a1": param_types.INT64, "b1": param_types.STRING}, ), ( { - 'a1': 10, 'b1': TimestampStr('2019-11-26T02:55:41.000000Z'), - 'dt1': datetime.datetime(2011, 9, 1, 13, 20, 30), - 'd1': datetime.date(2011, 9, 1), + "a1": 10, + "b1": TimestampStr("2019-11-26T02:55:41.000000Z"), + "dt1": datetime.datetime(2011, 9, 1, 13, 20, 30), + "d1": datetime.date(2011, 9, 1), }, { - 'a1': param_types.INT64, 'b1': param_types.TIMESTAMP, - 'dt1': param_types.TIMESTAMP, 'd1': param_types.DATE, + "a1": param_types.INT64, + "b1": param_types.TIMESTAMP, + "dt1": param_types.TIMESTAMP, + "d1": param_types.DATE, }, ), ( - {'a1': 10, 'b1': TimestampStr('2019-11-26T02:55:41.000000Z')}, - {'a1': param_types.INT64, 'b1': param_types.TIMESTAMP}, + {"a1": 10, "b1": TimestampStr("2019-11-26T02:55:41.000000Z")}, + {"a1": param_types.INT64, "b1": param_types.TIMESTAMP}, ), - ({'a1': b'bytes'}, {'a1': param_types.BYTES}), - ({'a1': 10, 'b1': None}, {'a1': param_types.INT64}), + ({"a1": b"bytes"}, {"a1": param_types.BYTES}), + ({"a1": 10, "b1": None}, {"a1": param_types.INT64}), (None, None), ] @@ -357,45 +413,31 @@ def 
test_get_param_types(self): self.assertEqual(got_param_types, want_param_types) def test_ensure_where_clause(self): - cases = [ - ( - 'UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1', - 'UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1', - ), - ( - 'UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5', - 'UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5 WHERE 1=1', - ), - ( - 'UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2', - 'UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2', - ), - ( - 'UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)', - 'UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)', - ), - ( - 'UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)', - 'UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)', - ), - ( - 'DELETE * FROM TABLE', - 'DELETE * FROM TABLE WHERE 1=1', - ), - ] + cases = ( + "UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1", + "UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2", + "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)", + ) + err_cases = ( + "UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5", + "DELETE * FROM TABLE", + ) + for sql in cases: + with self.subTest(sql=sql): + ensure_where_clause(sql) - for sql, want in cases: + for sql in err_cases: with self.subTest(sql=sql): - got = ensure_where_clause(sql) - self.assertEqual(got, want) + with self.assertRaises(ProgrammingError): + ensure_where_clause(sql) def test_escape_name(self): cases = [ - ('SELECT', '`SELECT`'), - ('id', 'id'), - ('', ''), - ('dashed-value', '`dashed-value`'), - ('with space', '`with space`'), + ("SELECT", "`SELECT`"), + ("id", "id"), + ("", ""), + ("dashed-value", "`dashed-value`"), + ("with space", "`with space`"), ] for name, want in cases: 
@@ -404,10 +446,7 @@ def test_escape_name(self): self.assertEqual(got, want) def test_strip_backticks(self): - cases = [ - ('foo', 'foo'), - ('`foo`', 'foo'), - ] + cases = [("foo", "foo"), ("`foo`", "foo")] for name, want in cases: with self.subTest(name=name): got = strip_backticks(name) @@ -415,11 +454,11 @@ def test_strip_backticks(self): def test_backtick_unicode(self): cases = [ - ('SELECT (1) as foo WHERE 1=1', 'SELECT (1) as foo WHERE 1=1'), - ('SELECT (1) as föö', 'SELECT (1) as `föö`'), - ('SELECT (1) as `föö`', 'SELECT (1) as `föö`'), - ('SELECT (1) as `föö` `umläut', 'SELECT (1) as `föö` `umläut'), - ('SELECT (1) as `föö', 'SELECT (1) as `föö'), + ("SELECT (1) as foo WHERE 1=1", "SELECT (1) as foo WHERE 1=1"), + ("SELECT (1) as föö", "SELECT (1) as `föö`"), + ("SELECT (1) as `föö`", "SELECT (1) as `föö`"), + ("SELECT (1) as `föö` `umläut", "SELECT (1) as `föö` `umläut"), + ("SELECT (1) as `föö", "SELECT (1) as `föö"), ] for sql, want in cases: with self.subTest(sql=sql): From 0ca360da0dbace356728dc3a684632a07d14058d Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Wed, 26 Aug 2020 12:22:44 +0300 Subject: [PATCH 06/12] move WHERE clause check higher up the function stack to avoid starting a transaction --- spanner_dbapi/cursor.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/spanner_dbapi/cursor.py b/spanner_dbapi/cursor.py index a43dabb096..776e050594 100644 --- a/spanner_dbapi/cursor.py +++ b/spanner_dbapi/cursor.py @@ -50,7 +50,7 @@ class Cursor: def __init__(self, connection): self._itr = None self._res = None - self._row_count = _UNSET_COUNT + self._rowcount = _UNSET_COUNT self._connection = connection self._is_closed = False @@ -96,10 +96,10 @@ def execute(self, sql, args=None): raise OperationalError(e.details if hasattr(e, "details") else e) def _handle_update(self, sql, params): + ensure_where_clause(sql) self._connection.in_transaction(self.__do_execute_update, sql, params) def __do_execute_update(self, transaction, 
sql, params, param_types=None): - ensure_where_clause(sql) sql, params = sql_pyformat_args_to_spanner(sql, params) res = transaction.execute_update( @@ -107,7 +107,7 @@ def __do_execute_update(self, transaction, sql, params, param_types=None): ) self._itr = None if type(res) == int: - self._row_count = res + self._rowcount = res return res @@ -166,7 +166,7 @@ def _handle_dql(self, sql, params): sql, params=params, param_types=get_param_types(params) ) if type(res) == int: - self._row_count = res + self._rowcount = res self._itr = None else: # Immediately using: @@ -183,7 +183,7 @@ def _handle_dql(self, sql, params): self._itr = PeekIterator(self._res) # Unfortunately, Spanner doesn't seem to send back # information about the number of rows available. - self._row_count = _UNSET_COUNT + self._rowcount = _UNSET_COUNT def __enter__(self): return self @@ -216,7 +216,7 @@ def description(self): @property def rowcount(self): - return self._row_count + return self._rowcount @property def is_closed(self): From 3235e0cc93f93dd9cdd1b776edd94da764be2119 Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Thu, 3 Sep 2020 16:56:19 +0300 Subject: [PATCH 07/12] merge conflicts resolve --- tests/spanner_dbapi/test_parse_utils.py | 47 +++++++++++-------------- 1 file changed, 21 insertions(+), 26 deletions(-) diff --git a/tests/spanner_dbapi/test_parse_utils.py b/tests/spanner_dbapi/test_parse_utils.py index 202f762a55..6cd7834805 100644 --- a/tests/spanner_dbapi/test_parse_utils.py +++ b/tests/spanner_dbapi/test_parse_utils.py @@ -57,7 +57,9 @@ def test_classify_stmt(self): sql, want_classification = tt got_classification = classify_stmt(sql) self.assertEqual( - got_classification, want_classification, "Classification mismatch" + got_classification, + want_classification, + "Classification mismatch", ) def test_parse_insert(self): @@ -202,7 +204,9 @@ def test_parse_insert_invalid(self): for sql, params, wantException in cases: with self.subTest(sql=sql): self.assertRaisesRegex( - 
ProgrammingError, wantException, lambda: parse_insert(sql, params) + ProgrammingError, + wantException, + lambda: parse_insert(sql, params), ) def test_rows_for_insert_or_update(self): @@ -234,6 +238,9 @@ def test_rows_for_insert_or_update(self): ], None, [ + ("ap", "n", (45, "nested"), "ll"), + ("bp", "m", "f2", "mt"), + ("fp", "cp", "o", "f3"), ], ), ( @@ -242,7 +249,6 @@ def test_rows_for_insert_or_update(self): None, [("ap", "n", "f1")], ), - (["app", "name", "fn"], ["ap", "n", "f1"], None, [("ap", "n", "f1")]), ] for i, (columns, params, pyformat_args, want) in enumerate(cases): @@ -442,29 +448,18 @@ def test_get_param_types(self): self.assertEqual(got_param_types, want_param_types) def test_ensure_where_clause(self): - cases = [ - ( - "UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1", - "UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1", - ), - ( - "UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5", - "UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5 WHERE 1=1", - ), - ( - "UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2", - "UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2", - ), - ( - "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)", - "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)", - ), - ( - "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)", - "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)", - ), - ("DELETE * FROM TABLE", "DELETE * FROM TABLE WHERE 1=1"), - ] + cases = ( + "UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1", + "UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2", + "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)", + ) + err_cases = ( + "UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5", + "DELETE * FROM TABLE", + ) 
+ for sql in cases: + with self.subTest(sql=sql): + ensure_where_clause(sql) for sql in err_cases: with self.subTest(sql=sql): From 1edf8a50cbc42c173c81b882cc4776737a918387 Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Thu, 3 Sep 2020 17:03:03 +0300 Subject: [PATCH 08/12] lint fix --- spanner_dbapi/cursor.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/spanner_dbapi/cursor.py b/spanner_dbapi/cursor.py index 04e314a394..f976d7eb82 100644 --- a/spanner_dbapi/cursor.py +++ b/spanner_dbapi/cursor.py @@ -103,8 +103,11 @@ def execute(self, sql, args=None): self._handle_insert(sql, args) else: self._handle_update(sql, args) - except (grpc_exceptions.AlreadyExists, grpc_exceptions.FailedPrecondition) as e: - self.__handle_update(sql, args or None) + except ( + grpc_exceptions.AlreadyExists, + grpc_exceptions.FailedPrecondition, + ) as e: + self.__handle_update(sql, args or None) except (AlreadyExists, FailedPrecondition) as e: raise IntegrityError(e.details if hasattr(e, "details") else e) except InvalidArgument as e: From 3cb6f530574d9885d34942eed6f486ea93faf745 Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Thu, 3 Sep 2020 17:15:19 +0300 Subject: [PATCH 09/12] fix test merge errors --- spanner_dbapi/cursor.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/spanner_dbapi/cursor.py b/spanner_dbapi/cursor.py index f976d7eb82..1c3cbfada8 100644 --- a/spanner_dbapi/cursor.py +++ b/spanner_dbapi/cursor.py @@ -103,11 +103,8 @@ def execute(self, sql, args=None): self._handle_insert(sql, args) else: self._handle_update(sql, args) - except ( - grpc_exceptions.AlreadyExists, - grpc_exceptions.FailedPrecondition, - ) as e: - self.__handle_update(sql, args or None) + except (AlreadyExists, FailedPrecondition) as e: + raise IntegrityError(e.details if hasattr(e, "details") else e) except (AlreadyExists, FailedPrecondition) as e: raise IntegrityError(e.details if hasattr(e, "details") else e) except InvalidArgument as e: @@ -116,6 
+113,7 @@ def execute(self, sql, args=None): raise OperationalError(e.details if hasattr(e, "details") else e) def __handle_update(self, sql, params): + ensure_where_clause(sql) self._connection.in_transaction(self.__do_execute_update, sql, params) def __do_execute_update(self, transaction, sql, params, param_types=None): From 316f6378010d8b494b8dac0e1ba5c0ab5fc9b894 Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Thu, 3 Sep 2020 20:47:26 +0300 Subject: [PATCH 10/12] args condition --- spanner_dbapi/parse_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spanner_dbapi/parse_utils.py b/spanner_dbapi/parse_utils.py index 5b79116348..897f1c6756 100644 --- a/spanner_dbapi/parse_utils.py +++ b/spanner_dbapi/parse_utils.py @@ -380,7 +380,7 @@ def get_param_types(params): """ Return a dictionary of spanner.param_types for a dictionary of parameters. """ - if params is None: + if not params: return None param_types = {} for key, value in params.items(): From e855e933e308ad2a195fa20d48650d78e630ffaa Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Thu, 3 Sep 2020 21:21:58 +0300 Subject: [PATCH 11/12] return back or None statements --- spanner_dbapi/cursor.py | 6 +++--- spanner_dbapi/parse_utils.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/spanner_dbapi/cursor.py b/spanner_dbapi/cursor.py index 1c3cbfada8..e8cf165356 100644 --- a/spanner_dbapi/cursor.py +++ b/spanner_dbapi/cursor.py @@ -98,11 +98,11 @@ def execute(self, sql, args=None): self._run_prior_ddl_statements() if classification == STMT_NON_UPDATING: - self._handle_dql(sql, args) + self._handle_dql(sql, args or None) elif classification == STMT_INSERT: - self._handle_insert(sql, args) + self._handle_insert(sql, args or None) else: - self._handle_update(sql, args) + self._handle_update(sql, args or None) except (AlreadyExists, FailedPrecondition) as e: raise IntegrityError(e.details if hasattr(e, "details") else e) except (AlreadyExists, FailedPrecondition) as e: diff 
--git a/spanner_dbapi/parse_utils.py b/spanner_dbapi/parse_utils.py index 897f1c6756..5b79116348 100644 --- a/spanner_dbapi/parse_utils.py +++ b/spanner_dbapi/parse_utils.py @@ -380,7 +380,7 @@ def get_param_types(params): """ Return a dictionary of spanner.param_types for a dictionary of parameters. """ - if not params: + if params is None: return None param_types = {} for key, value in params.items(): From 52ff0263445fb20ff159835a689ed1b07e06494f Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Thu, 3 Sep 2020 21:45:25 +0300 Subject: [PATCH 12/12] fix method name --- spanner_dbapi/cursor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spanner_dbapi/cursor.py b/spanner_dbapi/cursor.py index e8cf165356..4a0a923236 100644 --- a/spanner_dbapi/cursor.py +++ b/spanner_dbapi/cursor.py @@ -112,7 +112,7 @@ def execute(self, sql, args=None): except InternalServerError as e: raise OperationalError(e.details if hasattr(e, "details") else e) - def __handle_update(self, sql, params): + def _handle_update(self, sql, params): ensure_where_clause(sql) self._connection.in_transaction(self.__do_execute_update, sql, params)