diff --git a/google/cloud/spanner_dbapi/batch_dml_executor.py b/google/cloud/spanner_dbapi/batch_executor.py
similarity index 72%
rename from google/cloud/spanner_dbapi/batch_dml_executor.py
rename to google/cloud/spanner_dbapi/batch_executor.py
index 7c4272a0ca..d9b350a104 100644
--- a/google/cloud/spanner_dbapi/batch_dml_executor.py
+++ b/google/cloud/spanner_dbapi/batch_executor.py
@@ -16,6 +16,8 @@
 from enum import Enum
 from typing import TYPE_CHECKING, List
+
+from google.cloud.spanner_dbapi import parse_utils
 from google.cloud.spanner_dbapi.parsed_statement import (
     ParsedStatement,
     StatementType,
@@ -28,6 +30,46 @@
 if TYPE_CHECKING:
     from google.cloud.spanner_dbapi.cursor import Cursor
+    from google.cloud.spanner_dbapi.connection import Connection
+
+
+class BatchDdlExecutor:
+    """Executor that is used when a DDL batch is started. These batches only
+    accept DDL statements. All DDL statements are buffered locally and sent to
+    Spanner when run_batch() is called.
+
+    :type connection: :class:`~google.cloud.spanner_dbapi.connection.Connection`
+    :param connection: The connection on which the DDL batch was started.
+    """
+
+    def __init__(self, connection: "Connection"):
+        self._connection = connection
+        self._statements: List[str] = []
+
+    def execute_statement(self, parsed_statement: ParsedStatement):
+        """Executes the statement when a DDL batch is active by buffering the
+        statement in-memory.
+
+        This method is internal and not for public use
+
+        :type parsed_statement: ParsedStatement
+        :param parsed_statement: parsed statement containing sql query
+        """
+        from google.cloud.spanner_dbapi import ProgrammingError
+
+        if parsed_statement.statement_type != StatementType.DDL:
+            raise ProgrammingError("Only DDL statements are allowed in batch DDL mode.")
+        self._statements.extend(
+            parse_utils.parse_and_get_ddl_statements(parsed_statement.statement.sql)
+        )
+
+    def run_batch(self):
+        """Executes all the buffered statements on the active DDL batch by
+        making a call to Spanner.
+
+        This method is internal and not for public use
+        """
+        return self._connection.database.update_ddl(self._statements).result()
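For orientation, here is a minimal sketch of how the new `BatchDdlExecutor` buffers and flushes statements. The class is internal rather than public API, and `conn` stands in for an open `google.cloud.spanner_dbapi` `Connection`, so this is illustrative only:

```python
from google.cloud.spanner_dbapi.batch_executor import BatchDdlExecutor
from google.cloud.spanner_dbapi.parse_utils import classify_statement

# `conn` is assumed to be an open google.cloud.spanner_dbapi Connection.
executor = BatchDdlExecutor(conn)

# execute_statement() only buffers locally; nothing is sent to Spanner yet.
executor.execute_statement(
    classify_statement("CREATE TABLE T (Id INT64) PRIMARY KEY (Id)")
)

# run_batch() sends the whole buffer in a single update_ddl() call and
# waits for the resulting long-running operation to complete.
executor.run_batch()
```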
 
 
 class BatchDmlExecutor:
@@ -48,10 +90,13 @@ def execute_statement(self, parsed_statement: ParsedStatement):
         """Executes the statement when dml batch is active by buffering the
         statement in-memory.
 
+        This method is internal and not for public use
+
         :type parsed_statement: ParsedStatement
         :param parsed_statement: parsed statement containing sql query and query params
         """
+
         from google.cloud.spanner_dbapi import ProgrammingError
 
         if (
@@ -61,9 +106,11 @@ def execute_statement(self, parsed_statement: ParsedStatement):
             raise ProgrammingError("Only DML statements are allowed in batch DML mode.")
         self._statements.append(parsed_statement.statement)
 
-    def run_batch_dml(self):
+    def run_batch(self):
         """Executes all the buffered statements on the active dml batch by
         making a call to Spanner.
+
+        This method is internal and not for public use
         """
         return run_batch_dml(self._cursor, self._statements)
 
@@ -71,6 +118,8 @@ def run_batch_dml(self):
 def run_batch_dml(cursor: "Cursor", statements: List[Statement]):
     """Executes all the dml statements by making a batch call to Spanner.
 
+    This method is internal and not for public use
+
     :type cursor: Cursor
     :param cursor: Database Cursor object
diff --git a/google/cloud/spanner_dbapi/client_side_statement_executor.py b/google/cloud/spanner_dbapi/client_side_statement_executor.py
index dfbf33c1e8..735452a49b 100644
--- a/google/cloud/spanner_dbapi/client_side_statement_executor.py
+++ b/google/cloud/spanner_dbapi/client_side_statement_executor.py
@@ -85,6 +85,9 @@ def execute(cursor: "Cursor", parsed_statement: ParsedStatement):
             TypeCode.TIMESTAMP,
             column_values,
         )
+    if statement_type == ClientSideStatementType.START_BATCH_DDL:
+        connection.start_batch_ddl()
+        return None
     if statement_type == ClientSideStatementType.START_BATCH_DML:
         connection.start_batch_dml(cursor)
         return None
diff --git a/google/cloud/spanner_dbapi/client_side_statement_parser.py b/google/cloud/spanner_dbapi/client_side_statement_parser.py
index 63188a032a..2dcacc00f8 100644
--- a/google/cloud/spanner_dbapi/client_side_statement_parser.py
+++ b/google/cloud/spanner_dbapi/client_side_statement_parser.py
@@ -30,6 +30,7 @@
 RE_SHOW_READ_TIMESTAMP = re.compile(
     r"^\s*(SHOW)\s+(VARIABLE)\s+(READ_TIMESTAMP)", re.IGNORECASE
 )
+RE_START_BATCH_DDL = re.compile(r"^\s*(START)\s+(BATCH)\s+(DDL)", re.IGNORECASE)
 RE_START_BATCH_DML = re.compile(r"^\s*(START)\s+(BATCH)\s+(DML)", re.IGNORECASE)
 RE_RUN_BATCH = re.compile(r"^\s*(RUN)\s+(BATCH)", re.IGNORECASE)
 RE_ABORT_BATCH = re.compile(r"^\s*(ABORT)\s+(BATCH)", re.IGNORECASE)
@@ -62,6 +63,8 @@ def parse_stmt(query):
         client_side_statement_type = ClientSideStatementType.SHOW_COMMIT_TIMESTAMP
     elif RE_SHOW_READ_TIMESTAMP.match(query):
         client_side_statement_type = ClientSideStatementType.SHOW_READ_TIMESTAMP
+    elif RE_START_BATCH_DDL.match(query):
+        client_side_statement_type = ClientSideStatementType.START_BATCH_DDL
     elif RE_START_BATCH_DML.match(query):
         client_side_statement_type = ClientSideStatementType.START_BATCH_DML
     elif RE_BEGIN.match(query):
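With the parser and executor wired up, a DDL batch can be driven end to end through client-side statements, mirroring the new system test further down. A sketch (instance, database, and table names are illustrative; autocommit is enabled because a DDL batch cannot be started inside an active transaction):

```python
from google.cloud.spanner_dbapi import connect

conn = connect("my-instance", "my-database")
conn.autocommit = True
cursor = conn.cursor()

cursor.execute("START BATCH DDL")  # matched by RE_START_BATCH_DDL
cursor.execute("CREATE TABLE Albums (AlbumId INT64) PRIMARY KEY (AlbumId)")
cursor.execute("CREATE TABLE Songs (SongId INT64) PRIMARY KEY (SongId)")
cursor.execute("RUN BATCH")  # sends both statements as one DDL batch
```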
diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py
index 27983b8bd5..f894bdeb67 100644
--- a/google/cloud/spanner_dbapi/connection.py
+++ b/google/cloud/spanner_dbapi/connection.py
@@ -19,7 +19,11 @@
 from google.api_core.gapic_v1.client_info import ClientInfo
 from google.cloud import spanner_v1 as spanner
 from google.cloud.spanner_dbapi import partition_helper
-from google.cloud.spanner_dbapi.batch_dml_executor import BatchMode, BatchDmlExecutor
+from google.cloud.spanner_dbapi.batch_executor import (
+    BatchMode,
+    BatchDmlExecutor,
+    BatchDdlExecutor,
+)
 from google.cloud.spanner_dbapi.parse_utils import _get_statement_type
 from google.cloud.spanner_dbapi.parsed_statement import (
     StatementType,
@@ -91,7 +95,9 @@ class Connection:
     should end a that a new one should be started when the next statement is executed.
     """
 
-    def __init__(self, instance, database=None, read_only=False):
+    def __init__(
+        self, instance, database=None, read_only=False, buffer_ddl_statements=False
+    ):
         self._instance = instance
         self._database = database
         self._ddl_statements = []
@@ -114,8 +120,10 @@ def __init__(self, instance, database=None, read_only=False):
         # made atleast one call to Spanner.
         self._spanner_transaction_started = False
         self._batch_mode = BatchMode.NONE
+        self._batch_ddl_executor: BatchDdlExecutor = None
         self._batch_dml_executor: BatchDmlExecutor = None
         self._transaction_helper = TransactionRetryHelper(self)
+        self._buffer_ddl_statements = buffer_ddl_statements
 
     @property
     def autocommit(self):
@@ -126,6 +134,30 @@ def autocommit(self):
         """
         return self._autocommit
 
+    @property
+    def buffer_ddl_statements(self):
+        """Whether to buffer DDL statements for this connection.
+        This flag determines how DDL statements are handled when auto_commit=False:
+
+        1. buffer_ddl_statements=True: DDL statements are buffered in the client until
+           the next non-DDL statement, or until the transaction is committed. Executing
+           a non-DDL statement causes the connection to send all buffered DDL statements
+           to Spanner, and then to execute the non-DDL statement. Note that although the
+           DDL statements are sent as one batch to Spanner, they are not guaranteed to be
+           atomic. See https://cloud.google.com/spanner/docs/reference/rpc/google.spanner.admin.database.v1#google.spanner.admin.database.v1.UpdateDatabaseDdlRequest
+           for more details on DDL batches.
+        2. buffer_ddl_statements=False: Executing a DDL statement raises an error when
+           auto_commit=False and a transaction is active, as DDL statements cannot be
+           executed in a transaction.
+
+        This flag is ignored when auto_commit=True.
+
+        :rtype: bool
+        :returns: _buffer_ddl_statements flag value.
+        """
+        return self._buffer_ddl_statements
+
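A sketch of the buffering mode this property describes (names are illustrative; autocommit stays at its default of False):

```python
from google.cloud.spanner_dbapi import connect

conn = connect("my-instance", "my-database", buffer_ddl_statements=True)
cursor = conn.cursor()

# Buffered client-side; not yet sent to Spanner.
cursor.execute("CREATE TABLE Venues (VenueId INT64) PRIMARY KEY (VenueId)")

# A non-DDL statement flushes the buffered DDL first, then runs the query.
cursor.execute("SELECT * FROM Venues")

# commit() would likewise flush any remaining buffered DDL.
conn.commit()
```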
     @autocommit.setter
     def autocommit(self, value):
         """Change this connection autocommit mode. Setting this value to True
@@ -134,6 +166,10 @@ def autocommit(self, value):
         :type value: bool
         :param value: New autocommit mode state.
         """
+        if self._batch_mode is not BatchMode.NONE:
+            raise ProgrammingError(
+                "Can't change the autocommit mode as the batch is already active."
+            )
         if value and not self._autocommit and self._spanner_transaction_started:
             self.commit()
 
@@ -365,7 +401,8 @@ def commit(self):
             )
             return
 
-        self.run_prior_DDL_statements()
+        if self.buffer_ddl_statements:
+            self.run_prior_DDL_statements()
         try:
             if self._spanner_transaction_started and not self._read_only:
                 self._transaction.commit()
@@ -463,8 +500,46 @@ def validate(self):
                     "Expected: [[1]]" % result
                 )
 
+    @check_not_closed
+    def start_batch_ddl(self):
+        """
+        This method is internal and not for public use
+        """
+        if self._batch_mode is not BatchMode.NONE:
+            raise ProgrammingError(
+                "Cannot start a DDL batch when a batch is already active"
+            )
+        if self.read_only:
+            raise ProgrammingError(
+                "Cannot start a DDL batch when the connection is in read-only mode"
+            )
+        if self.buffer_ddl_statements:
+            raise ProgrammingError(
+                "Cannot start a DDL batch when the _buffer_ddl_statements flag is True"
+            )
+        if self._client_transaction_started:
+            raise ProgrammingError(
+                "Cannot start a DDL batch when a transaction is already active."
+            )
+        self._batch_mode = BatchMode.DDL
+        self._batch_ddl_executor = BatchDdlExecutor(self)
+
+    @check_not_closed
+    def execute_batch_ddl_statement(self, parsed_statement: ParsedStatement):
+        """
+        This method is internal and not for public use
+        """
+        if self._batch_mode is not BatchMode.DDL:
+            raise ProgrammingError(
+                "Cannot execute statement when the BatchMode is not DDL"
+            )
+        self._batch_ddl_executor.execute_statement(parsed_statement)
+
     @check_not_closed
     def start_batch_dml(self, cursor):
+        """
+        This method is internal and not for public use
+        """
         if self._batch_mode is not BatchMode.NONE:
             raise ProgrammingError(
                 "Cannot start a DML batch when a batch is already active"
@@ -478,6 +553,9 @@ def start_batch_dml(self, cursor):
 
     @check_not_closed
     def execute_batch_dml_statement(self, parsed_statement: ParsedStatement):
+        """
+        This method is internal and not for public use
+        """
         if self._batch_mode is not BatchMode.DML:
             raise ProgrammingError(
                 "Cannot execute statement when the BatchMode is not DML"
@@ -486,22 +564,34 @@ def execute_batch_dml_statement(self, parsed_statement: ParsedStatement):
 
     @check_not_closed
     def run_batch(self):
+        """
+        This method is internal and not for public use
+        """
+        result_set = None
         if self._batch_mode is BatchMode.NONE:
             raise ProgrammingError("Cannot run a batch when the BatchMode is not set")
         try:
             if self._batch_mode is BatchMode.DML:
-                many_result_set = self._batch_dml_executor.run_batch_dml()
+                result_set = self._batch_dml_executor.run_batch()
+            elif self._batch_mode is BatchMode.DDL:
+                self._batch_ddl_executor.run_batch()
         finally:
             self._batch_mode = BatchMode.NONE
             self._batch_dml_executor = None
-        return many_result_set
+            self._batch_ddl_executor = None
+        return result_set
 
     @check_not_closed
     def abort_batch(self):
+        """
+        This method is internal and not for public use
+        """
         if self._batch_mode is BatchMode.NONE:
             raise ProgrammingError("Cannot abort a batch when the BatchMode is not set")
         if self._batch_mode is BatchMode.DML:
             self._batch_dml_executor = None
+        if self._batch_mode is BatchMode.DDL:
+            self._batch_ddl_executor = None
         self._batch_mode = BatchMode.NONE
 
     @check_not_closed
@@ -510,6 +600,9 @@ def partition_query(
         parsed_statement: ParsedStatement,
         query_options=None,
     ):
+        """
+        This method is internal and not for public use
+        """
         statement = parsed_statement.statement
         partitioned_query = parsed_statement.client_side_statement_params[0]
         self._partitioned_query_validation(partitioned_query, statement)
@@ -534,6 +627,9 @@ def partition_query(
 
     @check_not_closed
     def run_partition(self, encoded_partition_id):
+        """
+        This method is internal and not for public use
+        """
         partition_id: PartitionId = partition_helper.decode_from_string(
             encoded_partition_id
         )
@@ -550,6 +646,9 @@ def run_partitioned_query(
         self,
         parsed_statement: ParsedStatement,
     ):
+        """
+        This method is internal and not for public use
+        """
         statement = parsed_statement.statement
         partitioned_query = parsed_statement.client_side_statement_params[0]
         self._partitioned_query_validation(partitioned_query, statement)
@@ -584,10 +683,23 @@ def connect(
     pool=None,
     user_agent=None,
     client=None,
+    buffer_ddl_statements=False,
     route_to_leader_enabled=True,
 ):
     """Creates a connection to a Google Cloud Spanner database.
 
+    :type buffer_ddl_statements: bool
+    :param buffer_ddl_statements: Whether to buffer DDL statements client-side.
+        If the connection is in autocommit mode, this flag has no effect, as
+        DDL statements are executed as they arrive.
+
+        For connections not in autocommit mode:
+        If enabled, DDL statements are buffered client-side instead of being
+        executed on Cloud Spanner immediately. When a non-DDL statement arrives,
+        or when the transaction is committed, all buffered DDL statements are
+        executed first.
+
+        If disabled, executing a DDL statement while a transaction is active
+        raises an error, as DDL statements cannot be executed in a transaction.
+
     :type instance_id: str
     :param instance_id: The ID of the instance to connect to.
 
@@ -658,7 +770,9 @@ def connect(
     instance = client.instance(instance_id)
 
     conn = Connection(
-        instance, instance.database(database_id, pool=pool) if database_id else None
+        instance,
+        instance.database(database_id, pool=pool) if database_id else None,
+        buffer_ddl_statements=buffer_ddl_statements,
     )
     if pool is not None:
         conn._own_pool = False
diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py
index c8cb450394..aa552070a0 100644
--- a/google/cloud/spanner_dbapi/cursor.py
+++ b/google/cloud/spanner_dbapi/cursor.py
@@ -15,8 +15,6 @@
 """Database cursor for Google Cloud Spanner DB API."""
 
 from collections import namedtuple
 
-import sqlparse
-
 from google.api_core.exceptions import Aborted
 from google.api_core.exceptions import AlreadyExists
 from google.api_core.exceptions import FailedPrecondition
@@ -25,7 +23,7 @@
 from google.api_core.exceptions import OutOfRange
 
 from google.cloud import spanner_v1 as spanner
-from google.cloud.spanner_dbapi.batch_dml_executor import BatchMode
+from google.cloud.spanner_dbapi.batch_executor import BatchMode
 from google.cloud.spanner_dbapi.exceptions import IntegrityError
 from google.cloud.spanner_dbapi.exceptions import InterfaceError
 from google.cloud.spanner_dbapi.exceptions import OperationalError
@@ -34,7 +32,7 @@
 from google.cloud.spanner_dbapi import (
     _helpers,
     client_side_statement_executor,
-    batch_dml_executor,
+    batch_executor,
 )
 from google.cloud.spanner_dbapi._helpers import ColumnInfo
 from google.cloud.spanner_dbapi._helpers import CODE_TO_DISPLAY_SIZE
@@ -210,18 +208,8 @@ def _batch_DDLs(self, sql):
         :raises: :class:`ValueError` in case not a DDL statement
                  present in the operation.
         """
-        statements = []
-        for ddl in sqlparse.split(sql):
-            if ddl:
-                ddl = ddl.rstrip(";")
-                if (
-                    parse_utils.classify_statement(ddl).statement_type
-                    != StatementType.DDL
-                ):
-                    raise ValueError("Only DDL statements may be batched.")
-
-                statements.append(ddl)
+        statements = parse_utils.parse_and_get_ddl_statements(sql)
 
         # Only queue DDL statements if they are all correctly classified.
         self.connection._ddl_statements.extend(statements)
@@ -261,6 +249,8 @@ def _execute(self, sql, args=None, call_from_execute_many=False):
                 self._itr = self._result_set
             else:
                 self._itr = PeekIterator(self._result_set)
+        elif self.connection._batch_mode == BatchMode.DDL:
+            self.connection.execute_batch_ddl_statement(self._parsed_statement)
         elif self.connection._batch_mode == BatchMode.DML:
             self.connection.execute_batch_dml_statement(self._parsed_statement)
         elif self.connection.read_only or (
@@ -269,9 +259,18 @@ def _execute(self, sql, args=None, call_from_execute_many=False):
         ):
             self._handle_DQL(sql, args or None)
         elif self._parsed_statement.statement_type == StatementType.DDL:
-            self._batch_DDLs(sql)
-            if not self.connection._client_transaction_started:
-                self.connection.run_prior_DDL_statements()
+            if not self.connection.buffer_ddl_statements:
+                if not self.connection._client_transaction_started:
+                    self._batch_DDLs(sql)
+                    self.connection.run_prior_DDL_statements()
+                else:
+                    raise ProgrammingError(
+                        "Cannot execute DDL statement when a transaction is already active"
+                    )
+            else:
+                self._batch_DDLs(sql)
+                if not self.connection._client_transaction_started:
+                    self.connection.run_prior_DDL_statements()
         else:
             self._execute_in_rw_transaction()
 
@@ -296,9 +295,8 @@ def _execute(self, sql, args=None, call_from_execute_many=False):
             self.connection._spanner_transaction_started = False
 
     def _execute_in_rw_transaction(self):
-        # For every other operation, we've got to ensure that
-        # any prior DDL statements were run.
-        self.connection.run_prior_DDL_statements()
+        if self.connection.buffer_ddl_statements:
+            self.connection.run_prior_DDL_statements()
         statement = self._parsed_statement.statement
         if self.connection._client_transaction_started:
             while True:
@@ -347,9 +345,8 @@ def executemany(self, operation, seq_of_params):
                 + ", with executemany() method is not allowed."
             )
 
-        # For every operation, we've got to ensure that any prior DDL
-        # statements were run.
-        self.connection.run_prior_DDL_statements()
+        if self.connection.buffer_ddl_statements:
+            self.connection.run_prior_DDL_statements()
 
         if self._parsed_statement.statement_type in (
             StatementType.INSERT,
             StatementType.UPDATE,
@@ -360,7 +357,7 @@ def executemany(self, operation, seq_of_params):
                     operation, params
                 )
                 statements.append(Statement(sql, params, get_param_types(params)))
-            many_result_set = batch_dml_executor.run_batch_dml(self, statements)
+            many_result_set = batch_executor.run_batch_dml(self, statements)
         else:
             many_result_set = StreamedManyResultSets()
             for params in seq_of_params:
@@ -523,7 +520,8 @@ def run_sql_in_snapshot(self, sql, params=None, param_types=None):
         # hence this method exists to circumvent that limit.
         if self.connection.database is None:
             raise ValueError("Database needs to be passed for this operation")
-        self.connection.run_prior_DDL_statements()
+        if self.connection.buffer_ddl_statements:
+            self.connection.run_prior_DDL_statements()
 
         with self.connection.database.snapshot() as snapshot:
             return list(snapshot.execute_sql(sql, params, param_types))
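And the corresponding error path with buffering disabled, mirroring the new `_execute` branch above (illustrative names):

```python
from google.cloud.spanner_dbapi import connect, ProgrammingError

conn = connect("my-instance", "my-database", buffer_ddl_statements=False)
cursor = conn.cursor()

try:
    # autocommit defaults to False, so a transaction is considered active;
    # the DDL is neither buffered nor executable here, and this raises.
    cursor.execute("CREATE TABLE T (Id INT64) PRIMARY KEY (Id)")
except ProgrammingError:
    pass
```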
diff --git a/google/cloud/spanner_dbapi/parse_utils.py b/google/cloud/spanner_dbapi/parse_utils.py
index b642daf084..929f08a206 100644
--- a/google/cloud/spanner_dbapi/parse_utils.py
+++ b/google/cloud/spanner_dbapi/parse_utils.py
@@ -205,6 +205,17 @@ def classify_stmt(query):
         return STMT_UPDATING
 
 
+def parse_and_get_ddl_statements(sql):
+    """Splits ``sql`` into individual DDL statements, raising
+    :class:`ValueError` if any statement is not DDL."""
+    statements = []
+    for ddl in sqlparse.split(sql):
+        if ddl:
+            ddl = ddl.rstrip(";")
+            if classify_statement(ddl).statement_type != StatementType.DDL:
+                raise ValueError("Only DDL statements may be batched.")
+            statements.append(ddl)
+    return statements
+
+
 def classify_statement(query, args=None):
     """Determine SQL query type.
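A quick illustration of the extracted helper (the expected values assume `sqlparse.split` strips surrounding whitespace, which the previous inline implementation relied on as well):

```python
from google.cloud.spanner_dbapi import parse_utils

ddl = (
    "CREATE TABLE A (Id INT64) PRIMARY KEY (Id);"
    "CREATE TABLE B (Id INT64) PRIMARY KEY (Id);"
)

# Each statement is returned with its trailing semicolon stripped:
# ['CREATE TABLE A (Id INT64) PRIMARY KEY (Id)',
#  'CREATE TABLE B (Id INT64) PRIMARY KEY (Id)']
print(parse_utils.parse_and_get_ddl_statements(ddl))

# A non-DDL statement anywhere in the string raises ValueError, e.g.:
# parse_utils.parse_and_get_ddl_statements("SELECT 1")
```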
diff --git a/google/cloud/spanner_dbapi/parsed_statement.py b/google/cloud/spanner_dbapi/parsed_statement.py
index 1bb0ed25f4..ac4390502e 100644
--- a/google/cloud/spanner_dbapi/parsed_statement.py
+++ b/google/cloud/spanner_dbapi/parsed_statement.py
@@ -36,6 +36,7 @@ class ClientSideStatementType(Enum):
     PARTITION_QUERY = 9
     RUN_PARTITION = 10
     RUN_PARTITIONED_QUERY = 11
+    START_BATCH_DDL = 12
 
 
 @dataclass
diff --git a/google/cloud/spanner_dbapi/transaction_helper.py b/google/cloud/spanner_dbapi/transaction_helper.py
index bc896009c7..f4b0f1a81e 100644
--- a/google/cloud/spanner_dbapi/transaction_helper.py
+++ b/google/cloud/spanner_dbapi/transaction_helper.py
@@ -18,7 +18,7 @@
 import time
 
-from google.cloud.spanner_dbapi.batch_dml_executor import BatchMode
+from google.cloud.spanner_dbapi.batch_executor import BatchMode
 from google.cloud.spanner_dbapi.exceptions import RetryAborted
 from google.cloud.spanner_v1.session import _get_retry_delay
diff --git a/tests/system/test_dbapi.py b/tests/system/test_dbapi.py
index 52a80d5714..9e3a2e6eb8 100644
--- a/tests/system/test_dbapi.py
+++ b/tests/system/test_dbapi.py
@@ -86,7 +86,9 @@ def dbapi_database(self, raw_database):
     @pytest.fixture(autouse=True)
     def init_connection(self, request, shared_instance, dbapi_database):
         if "noautofixt" not in request.keywords:
-            self._conn = Connection(shared_instance, dbapi_database)
+            self._conn = Connection(
+                shared_instance, dbapi_database, buffer_ddl_statements=True
+            )
             self._cursor = self._conn.cursor()
         yield
         if "noautofixt" not in request.keywords:
@@ -152,7 +154,10 @@ def test_commit_exception(self):
         try:
             self._conn.commit()
         except Exception:
-            pass
+            # The new transaction might try to check out the same session that
+            # was deleted earlier, which would result in a "Session not found"
+            # exception. So we remove that session from the client-side session
+            # pool by taking the first session from the queue.
+            self._conn.database._pool._sessions.get()
 
         # Testing that the connection and Cursor are in proper state post commit
         # and a new transaction is started
@@ -177,7 +182,10 @@ def test_rollback_exception(self):
         try:
             self._conn.rollback()
         except Exception:
-            pass
+            # The new transaction might try to check out the same session that
+            # was deleted earlier, which would result in a "Session not found"
+            # exception. So we remove that session from the client-side session
+            # pool by taking the first session from the queue.
+            self._conn.database._pool._sessions.get()
 
         # Testing that the connection and Cursor are in proper state post
         # exception in rollback and a new transaction is started
@@ -454,6 +462,52 @@ def test_read_timestamp_client_side_autocommit(self):
         read_timestamp_query_result_2 = self._cursor.fetchall()
         assert read_timestamp_query_result_1 != read_timestamp_query_result_2
 
+    @pytest.mark.parametrize("auto_commit", [False, True])
+    def test_batch_ddl(self, auto_commit, dbapi_database):
+        """Test batch ddl."""
+
+        self._conn._buffer_ddl_statements = False
+        self._conn.autocommit = auto_commit
+
+        exception_raised = False
+        try:
+            self._cursor.execute("start batch ddl")
+            self._cursor.execute(
+                """
+                CREATE TABLE Table_1 (
+                    SingerId INT64 NOT NULL,
+                    Name STRING(1024),
+                ) PRIMARY KEY (SingerId)
+                """
+            )
+            self._cursor.execute(
+                """
+                CREATE TABLE Table_2 (
+                    SingerId INT64 NOT NULL,
+                    Name STRING(1024),
+                ) PRIMARY KEY (SingerId)
+                """
+            )
+            self._cursor.execute("run batch")
+        except ProgrammingError:
+            exception_raised = True
+
+        table_1 = dbapi_database.table("Table_1")
+        table_2 = dbapi_database.table("Table_2")
+        if auto_commit:
+            assert exception_raised is False
+            assert table_1.exists() is True
+            assert table_2.exists() is True
+
+            self._cursor.execute("start batch ddl")
+            self._cursor.execute("DROP TABLE Table_1")
+            self._cursor.execute("DROP TABLE Table_2")
+            self._cursor.execute("run batch")
+        else:
+            assert exception_raised is True
+            assert table_1.exists() is False
+            assert table_2.exists() is False
+
     @pytest.mark.parametrize("auto_commit", [False, True])
     def test_batch_dml(self, auto_commit, dbapi_database):
         """Test batch dml."""
@@ -1113,44 +1167,14 @@ def test_execute_many(self):
 
         assert res[0] == 1
 
-    @pytest.mark.noautofixt
-    def test_DDL_autocommit(self, shared_instance, dbapi_database):
-        """Check that DDLs in autocommit mode are immediately executed."""
-
-        try:
-            conn = Connection(shared_instance, dbapi_database)
-            conn.autocommit = True
-
-            cur = conn.cursor()
-            cur.execute(
-                """
-                CREATE TABLE Singers (
-                    SingerId INT64 NOT NULL,
-                    Name STRING(1024),
-                ) PRIMARY KEY (SingerId)
-                """
-            )
-            conn.close()
-
-            # if previous DDL wasn't committed, the next DROP TABLE
-            # statement will fail with a ProgrammingError
-            conn = Connection(shared_instance, dbapi_database)
-            cur = conn.cursor()
-
-            cur.execute("DROP TABLE Singers")
-            conn.commit()
-        finally:
-            # Delete table
-            table = dbapi_database.table("Singers")
-            if table.exists():
-                op = dbapi_database.update_ddl(["DROP TABLE Singers"])
-                op.result()
-
-    def test_ddl_execute_autocommit_true(self, dbapi_database):
-        """Check that DDL statement in autocommit mode results in successful
-        DDL statement execution for execute method."""
+    @pytest.mark.parametrize("autocommit", [True, False])
+    def test_ddl_execute(self, autocommit, dbapi_database):
+        """Check that a DDL statement executes successfully via the execute
+        method in autocommit mode, and is a no-op in non-autocommit mode, when
+        the buffer_ddl_statements flag is enabled."""
 
-        self._conn.autocommit = True
+        if autocommit:
+            self._conn.autocommit = True
         self._cursor.execute(
             """
             CREATE TABLE DdlExecuteAutocommit (
@@ -1160,29 +1184,52 @@ def test_ddl_execute_autocommit_true(self, dbapi_database):
             """
         )
         table = dbapi_database.table("DdlExecuteAutocommit")
-        assert table.exists() is True
-
-    def test_ddl_executemany_autocommit_true(self, dbapi_database):
-        """Check that DDL statement in autocommit mode results in exception for
-        executemany method ."""
-
-        self._conn.autocommit = True
-        with pytest.raises(ProgrammingError):
-            self._cursor.executemany(
+        if autocommit:
+            assert table.exists() is True
+            self._cursor.execute("DROP TABLE DdlExecuteAutocommit")
+        else:
+            assert table.exists() is False
+            assert len(self._conn._ddl_statements) == 1
+
+    @pytest.mark.parametrize("autocommit", [True, False])
+    def test_ddl_execute_without_buffer_ddl_enabled(self, autocommit, dbapi_database):
+        """Check that a DDL statement executes successfully via the execute
+        method in autocommit mode, and results in an error in non-autocommit
+        mode, when the buffer_ddl_statements flag is disabled."""
+
+        self._conn._buffer_ddl_statements = False
+        exception = False
+        if autocommit:
+            self._conn.autocommit = True
+        try:
+            self._cursor.execute(
                 """
-                CREATE TABLE DdlExecuteManyAutocommit (
+                CREATE TABLE DdlExecuteAutocommit (
                     SingerId INT64 NOT NULL,
                     Name STRING(1024),
                 ) PRIMARY KEY (SingerId)
-                """,
-                [],
+                """
             )
-        table = dbapi_database.table("DdlExecuteManyAutocommit")
-        assert table.exists() is False
-
-    def test_ddl_executemany_autocommit_false(self, dbapi_database):
-        """Check that DDL statement in non-autocommit mode results in exception for
-        executemany method ."""
+        except ProgrammingError:
+            exception = True
+        table = dbapi_database.table("DdlExecuteAutocommit")
+        if autocommit:
+            assert table.exists() is True
+            self._cursor.execute("DROP TABLE DdlExecuteAutocommit")
+        else:
+            assert table.exists() is False
+            assert exception is True
+
+    @pytest.mark.parametrize("autocommit", [True, False])
+    @pytest.mark.parametrize("buffer_ddl", [True, False])
+    def test_ddl_executemany(self, buffer_ddl, autocommit, dbapi_database):
+        """Check that a DDL statement always results in an exception when
+        executed via the executemany method."""
+
+        if not buffer_ddl:
+            self._conn._buffer_ddl_statements = False
+        if autocommit:
+            self._conn.autocommit = True
         with pytest.raises(ProgrammingError):
             self._cursor.executemany(
                 """
@@ -1196,9 +1243,10 @@ def test_ddl_executemany_autocommit_false(self, dbapi_database):
         table = dbapi_database.table("DdlExecuteManyAutocommit")
         assert table.exists() is False
 
-    def test_ddl_execute(self, dbapi_database):
+    def test_ddl_then_non_ddl_execute(self, dbapi_database):
         """Check that DDL statement followed by non-DDL execute statement in
-        non autocommit mode results in successful DDL statement execution."""
+        non autocommit mode results in successful DDL statement execution
+        when the buffer_ddl_statements flag is enabled."""
 
         want_row = (
             1,
@@ -1230,7 +1278,7 @@ def test_ddl_execute(self, dbapi_database):
 
         assert got_rows == [want_row]
 
-    def test_ddl_executemany(self, dbapi_database):
+    def test_ddl_then_non_ddl_executemany(self, dbapi_database):
         """Check that DDL statement followed by non-DDL executemany statement in
         non autocommit mode results in successful DDL statement execution."""
 
@@ -1339,14 +1387,11 @@ def test_json_array(self, dbapi_database):
             op = dbapi_database.update_ddl(["DROP TABLE JsonDetails"])
             op.result()
 
-    @pytest.mark.noautofixt
-    def test_DDL_commit(self, shared_instance, dbapi_database):
-        """Check that DDLs in commit mode are executed on calling `commit()`."""
+    def test_DDL_commit(self, dbapi_database):
+        """Check that DDLs in commit mode are executed on calling `commit()`
+        when the buffer_ddl_statements flag is enabled."""
 
         try:
-            conn = Connection(shared_instance, dbapi_database)
-            cur = conn.cursor()
-
-            cur.execute(
+            self._cursor.execute(
                 """
                 CREATE TABLE Singers (
                     SingerId INT64 NOT NULL,
                    Name STRING(1024),
@@ -1354,16 +1399,12 @@ def test_DDL_commit(self, shared_instance, dbapi_database):
                 ) PRIMARY KEY (SingerId)
                 """
             )
-            conn.commit()
-            conn.close()
+            self._conn.commit()
 
             # if previous DDL wasn't committed, the next DROP TABLE
             # statement will fail with a ProgrammingError
-            conn = Connection(shared_instance, dbapi_database)
-            cur = conn.cursor()
-
-            cur.execute("DROP TABLE Singers")
-            conn.commit()
+            self._cursor.execute("DROP TABLE Singers")
+            self._conn.commit()
         finally:
             # Delete table
             table = dbapi_database.table("Singers")
@@ -1406,7 +1447,7 @@ def test_read_only_dml(self):
         """
         self._conn.read_only = True
 
-        with pytest.raises(ProgrammingError):
+        with pytest.raises(Exception):
             self._cursor.execute(
                 """
                 UPDATE contacts
diff --git a/tests/unit/spanner_dbapi/test_batch_dml_executor.py b/tests/unit/spanner_dbapi/test_batch_executor.py
similarity index 64%
rename from tests/unit/spanner_dbapi/test_batch_dml_executor.py
rename to tests/unit/spanner_dbapi/test_batch_executor.py
index 3dc387bcb6..58cec5555e 100644
--- a/tests/unit/spanner_dbapi/test_batch_dml_executor.py
+++ b/tests/unit/spanner_dbapi/test_batch_executor.py
@@ -16,7 +16,7 @@
 from unittest import mock
 
 from google.cloud.spanner_dbapi import ProgrammingError
-from google.cloud.spanner_dbapi.batch_dml_executor import BatchDmlExecutor
+from google.cloud.spanner_dbapi.batch_executor import BatchDmlExecutor, BatchDdlExecutor
 from google.cloud.spanner_dbapi.parsed_statement import (
     ParsedStatement,
     Statement,
@@ -52,3 +52,28 @@ def test_execute_statement_update_statement_type(self):
         )
 
         self.assertEqual(self._under_test._statements, [statement])
+
+
+class TestBatchDdlExecutor(unittest.TestCase):
+    @mock.patch("google.cloud.spanner_dbapi.connection.Connection")
+    def setUp(self, mock_connection):
+        self._under_test = BatchDdlExecutor(mock_connection)
+
+    def test_execute_statement_non_ddl_statement_type(self):
+        parsed_statement = ParsedStatement(StatementType.QUERY, Statement("sql"))
+
+        with self.assertRaises(ProgrammingError):
+            self._under_test.execute_statement(parsed_statement)
+
+    def test_execute_statement_ddl_statement_type(self):
+        sql = """CREATE TABLE Singers (
+            SingerId INT64 NOT NULL,
+            Name STRING(1024),
+        ) PRIMARY KEY (SingerId)"""
+        statement = Statement(sql)
+
+        self._under_test.execute_statement(
+            ParsedStatement(StatementType.DDL, statement)
+        )
+
+        self.assertEqual(self._under_test._statements, [sql])
diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py
index dec32285d4..a5658831f9 100644
--- a/tests/unit/spanner_dbapi/test_connection.py
+++ b/tests/unit/spanner_dbapi/test_connection.py
@@ -20,7 +20,7 @@
 import warnings
 import pytest
 
-from google.cloud.spanner_dbapi.batch_dml_executor import BatchMode
+from google.cloud.spanner_dbapi.batch_executor import BatchMode
 from google.cloud.spanner_dbapi.exceptions import (
     InterfaceError,
     OperationalError,
@@ -326,6 +326,52 @@ def test_rollback_in_autocommit_mode(self, mock_warn):
             CLIENT_TRANSACTION_NOT_STARTED_WARNING, UserWarning, stacklevel=2
         )
 
+    def test_start_batch_ddl_batch_mode_active(self):
+        self._under_test._batch_mode = BatchMode.DDL
+
+        with self.assertRaises(ProgrammingError):
+            self._under_test.start_batch_ddl()
+
+    def test_start_batch_ddl_connection_read_only(self):
+        self._under_test.read_only = True
+
+        with self.assertRaises(ProgrammingError):
+            self._under_test.start_batch_ddl()
+
+    def test_start_batch_ddl_buffer_ddl_active(self):
+        self._under_test._buffer_ddl_statements = True
+
+        with self.assertRaises(ProgrammingError):
+            self._under_test.start_batch_ddl()
+
+    def test_start_batch_ddl(self):
+        self._under_test.autocommit = True
+        self._under_test.start_batch_ddl()
+
+        self.assertEqual(self._under_test._batch_mode, BatchMode.DDL)
+
+    def test_execute_batch_ddl_batch_mode_inactive(self):
+        self._under_test._batch_mode = BatchMode.NONE
+
+        with self.assertRaises(ProgrammingError):
+            self._under_test.execute_batch_ddl_statement(
+                ParsedStatement(StatementType.DDL, Statement("sql"))
+            )
+
+    @mock.patch(
+        "google.cloud.spanner_dbapi.batch_executor.BatchDdlExecutor", autospec=True
+    )
+    def test_execute_batch_ddl(self, mock_batch_ddl_executor):
+        self._under_test._batch_mode = BatchMode.DDL
+        self._under_test._batch_ddl_executor = mock_batch_ddl_executor
+
+        parsed_statement = ParsedStatement(StatementType.DDL, Statement("sql"))
+        self._under_test.execute_batch_ddl_statement(parsed_statement)
+
+        mock_batch_ddl_executor.execute_statement.assert_called_once_with(
+            parsed_statement
+        )
+
     def test_start_batch_dml_batch_mode_active(self):
         self._under_test._batch_mode = BatchMode.DML
         cursor = self._under_test.cursor()
@@ -356,7 +402,7 @@ def test_execute_batch_dml_batch_mode_inactive(self):
         )
 
     @mock.patch(
-        "google.cloud.spanner_dbapi.batch_dml_executor.BatchDmlExecutor", autospec=True
+        "google.cloud.spanner_dbapi.batch_executor.BatchDmlExecutor", autospec=True
     )
     def test_execute_batch_dml(self, mock_batch_dml_executor):
         self._under_test._batch_mode = BatchMode.DML
@@ -370,7 +416,7 @@ def test_execute_batch_dml(self, mock_batch_dml_executor):
     )
 
     @mock.patch(
-        "google.cloud.spanner_dbapi.batch_dml_executor.BatchDmlExecutor", autospec=True
+        "google.cloud.spanner_dbapi.batch_executor.BatchDmlExecutor", autospec=True
     )
     def test_run_batch_batch_mode_inactive(self, mock_batch_dml_executor):
         self._under_test._batch_mode = BatchMode.NONE
@@ -380,20 +426,33 @@ def test_run_batch_batch_mode_inactive(self, mock_batch_dml_executor):
         self._under_test.run_batch()
 
     @mock.patch(
-        "google.cloud.spanner_dbapi.batch_dml_executor.BatchDmlExecutor", autospec=True
+        "google.cloud.spanner_dbapi.batch_executor.BatchDmlExecutor", autospec=True
     )
-    def test_run_batch(self, mock_batch_dml_executor):
+    def test_run_dml_batch(self, mock_batch_dml_executor):
         self._under_test._batch_mode = BatchMode.DML
         self._under_test._batch_dml_executor = mock_batch_dml_executor
 
         self._under_test.run_batch()
 
-        mock_batch_dml_executor.run_batch_dml.assert_called_once_with()
+        mock_batch_dml_executor.run_batch.assert_called_once_with()
         self.assertEqual(self._under_test._batch_mode, BatchMode.NONE)
         self.assertEqual(self._under_test._batch_dml_executor, None)
 
     @mock.patch(
-        "google.cloud.spanner_dbapi.batch_dml_executor.BatchDmlExecutor", autospec=True
+        "google.cloud.spanner_dbapi.batch_executor.BatchDdlExecutor", autospec=True
+    )
+    def test_run_ddl_batch(self, mock_batch_ddl_executor):
+        self._under_test._batch_mode = BatchMode.DDL
+        self._under_test._batch_ddl_executor = mock_batch_ddl_executor
+
+        self._under_test.run_batch()
+
+        mock_batch_ddl_executor.run_batch.assert_called_once_with()
+        self.assertEqual(self._under_test._batch_mode, BatchMode.NONE)
+        self.assertEqual(self._under_test._batch_ddl_executor, None)
+
+    @mock.patch(
+        "google.cloud.spanner_dbapi.batch_executor.BatchDmlExecutor", autospec=True
     )
     def test_abort_batch_batch_mode_inactive(self, mock_batch_dml_executor):
         self._under_test._batch_mode = BatchMode.NONE
@@ -403,7 +462,7 @@ def test_abort_batch_batch_mode_inactive(self, mock_batch_dml_executor):
         self._under_test.abort_batch()
"google.cloud.spanner_dbapi.batch_dml_executor.BatchDmlExecutor", autospec=True + "google.cloud.spanner_dbapi.batch_executor.BatchDmlExecutor", autospec=True ) def test_abort_dml_batch(self, mock_batch_dml_executor): self._under_test._batch_mode = BatchMode.DML diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index 9735185a5c..7c88da484f 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -203,7 +203,9 @@ def test_execute_insert_statement_autocommit_off(self): self.assertIsInstance(cursor._result_set, mock.MagicMock) def test_execute_statement(self): - connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + connection = self._make_connection( + self.INSTANCE, mock.MagicMock(), buffer_ddl_statements=True + ) cursor = self._make_one(connection) sql = "sql" @@ -1163,7 +1165,9 @@ def test_ddls_with_semicolon(self, mock_client): "DROP TABLE table_name", ] - connection = connect("test-instance", "test-database") + connection = connect( + "test-instance", "test-database", buffer_ddl_statements=True + ) cursor = connection.cursor() cursor.execute(