diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 6997752a42..070205867a 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -27,6 +27,7 @@ from google.cloud.spanner_dbapi import parse_utils from google.cloud.spanner_dbapi.parse_utils import get_param_types from google.cloud.spanner_dbapi.utils import PeekIterator +from google.cloud.spanner_dbapi.utils import StreamedManyResultSets _UNSET_COUNT = -1 @@ -198,8 +199,19 @@ def executemany(self, operation, seq_of_params): """ self._raise_if_closed() + classification = parse_utils.classify_stmt(operation) + if classification == parse_utils.STMT_DDL: + raise ProgrammingError( + "Executing DDL statements with executemany() method is not allowed." + ) + + many_result_set = StreamedManyResultSets() for params in seq_of_params: self.execute(operation, params) + many_result_set.add_iter(self._itr) + + self._result_set = many_result_set + self._itr = many_result_set def fetchone(self): """Fetch the next row of a query result set, returning a single diff --git a/google/cloud/spanner_dbapi/utils.py b/google/cloud/spanner_dbapi/utils.py index f4769e80a4..78684f3eab 100644 --- a/google/cloud/spanner_dbapi/utils.py +++ b/google/cloud/spanner_dbapi/utils.py @@ -47,6 +47,49 @@ def __iter__(self): return self +class StreamedManyResultSets: + """Iterator to walk through several `StreamedResultsSet` iterators. + + This type of iterator is used by `Cursor.executemany()` + method to iterate through several `StreamedResultsSet` + iterators like they all are merged into single iterator. + """ + + def __init__(self): + self._iterators = [] + self._index = 0 + + def add_iter(self, iterator): + """Add new iterator into this one. + + :type iterator: :class:`google.cloud.spanner_v1.streamed.StreamedResultSet` + :param iterator: Iterator to merge into this one. + """ + self._iterators.append(iterator) + + def __next__(self): + """Return the next value from the currently streamed iterator. + + If the current iterator is streamed to the end, + start to stream the next one. + + :rtype: list + :returns: The next result row. + """ + try: + res = next(self._iterators[self._index]) + except StopIteration: + self._index += 1 + res = self.__next__() + except IndexError: + raise StopIteration + + return res + + def __iter__(self): + return self + + re_UNICODE_POINTS = re.compile(r"([^\s]*[\u0080-\uFFFF]+[^\s]*)") diff --git a/tests/system/test_system.py b/tests/system/test_system.py index f3ee345e15..ddf2fc139c 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -289,6 +289,46 @@ def test_rollback_on_connection_closing(self): cursor.close() conn.close() + def test_execute_many(self): + # connect to the test database + conn = Connection(Config.INSTANCE, self._db) + cursor = conn.cursor() + + cursor.execute( + """ +INSERT INTO contacts (contact_id, first_name, last_name, email) +VALUES (1, 'first-name', 'last-name', 'test.email@example.com'), + (2, 'first-name2', 'last-name2', 'test.email2@example.com') + """ + ) + conn.commit() + + cursor.executemany( + """ +SELECT * FROM contacts WHERE contact_id = @a1 +""", + ({"a1": 1}, {"a1": 2}), + ) + res = cursor.fetchall() + conn.commit() + + self.assertEqual(len(res), 2) + self.assertEqual(res[0][0], 1) + self.assertEqual(res[1][0], 2) + + # checking that execute() and executemany() + # results are not mixed together + cursor.execute( + """ +SELECT * FROM contacts WHERE contact_id = 1 +""", + ) + res = cursor.fetchone() + conn.commit() + + self.assertEqual(res[0], 1) + conn.close() + def clear_table(transaction): """Clear the test table.""" diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index 09288df94e..6576b567cb 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -259,6 +259,24 @@ def test_executemany_on_closed_cursor(self): """SELECT * FROM table1 WHERE "col1" = @a1""", () ) + def test_executemany_DLL(self): + from google.cloud.spanner_dbapi import connect, ProgrammingError + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", + return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + + with self.assertRaises(ProgrammingError): + cursor.executemany("""DROP DATABASE database_name""", ()) + def test_executemany(self): from google.cloud.spanner_dbapi import connect @@ -276,8 +294,11 @@ def test_executemany(self): connection = connect("test-instance", "test-database") cursor = connection.cursor() + + cursor._result_set = [1, 2, 3] + cursor._itr = iter([1, 2, 3]) with mock.patch( - "google.cloud.spanner_dbapi.cursor.Cursor.execute" + "google.cloud.spanner_dbapi.cursor.Cursor.execute", ) as execute_mock: cursor.executemany(operation, params_seq)