From 3f602b87d06bed9fc1b7daaaaec9846c6cb0c58f Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Tue, 3 Nov 2020 13:04:35 +0300 Subject: [PATCH 1/6] fix: executemany() should return every executed query result --- google/cloud/spanner_dbapi/cursor.py | 13 +++++++++ google/cloud/spanner_dbapi/utils.py | 43 ++++++++++++++++++++++++++++ tests/system/test_system.py | 40 ++++++++++++++++++++++++++ 3 files changed, 96 insertions(+) diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 6997752a42..32cdec30dc 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -27,6 +27,7 @@ from google.cloud.spanner_dbapi import parse_utils from google.cloud.spanner_dbapi.parse_utils import get_param_types from google.cloud.spanner_dbapi.utils import PeekIterator +from google.cloud.spanner_dbapi.utils import StreamedManyResultSets _UNSET_COUNT = -1 @@ -46,6 +47,7 @@ def __init__(self, connection): self._row_count = _UNSET_COUNT self.connection = connection self._is_closed = False + self._executed_many = False # the number of rows to fetch at a time with fetchmany() self.arraysize = 1 @@ -144,6 +146,7 @@ def execute(self, sql, args=None): self._raise_if_closed() self._result_set = None + self._executed_many = False # Classify whether this is a read-only SQL statement. try: @@ -198,8 +201,12 @@ def executemany(self, operation, seq_of_params): """ self._raise_if_closed() + self._many_result_set = StreamedManyResultSets() for params in seq_of_params: self.execute(operation, params) + self._many_result_set.add_iter(self._itr) + + self._executed_many = True def fetchone(self): """Fetch the next row of a query result set, returning a single @@ -293,11 +300,17 @@ def __exit__(self, etype, value, traceback): def __next__(self): if self._itr is None: raise ProgrammingError("no results to return") + if self._executed_many: + return next(self._many_result_set) + return next(self._itr) def __iter__(self): if self._itr is None: raise ProgrammingError("no results to return") + if self._executed_many: + return self._many_result_set + return self._itr def list_tables(self): diff --git a/google/cloud/spanner_dbapi/utils.py b/google/cloud/spanner_dbapi/utils.py index f4769e80a4..78684f3eab 100644 --- a/google/cloud/spanner_dbapi/utils.py +++ b/google/cloud/spanner_dbapi/utils.py @@ -47,6 +47,49 @@ def __iter__(self): return self +class StreamedManyResultSets: + """Iterator to walk through several `StreamedResultsSet` iterators. + + This type of iterator is used by `Cursor.executemany()` + method to iterate through several `StreamedResultsSet` + iterators like they all are merged into single iterator. + """ + + def __init__(self): + self._iterators = [] + self._index = 0 + + def add_iter(self, iterator): + """Add new iterator into this one. + + :type iterator: :class:`google.cloud.spanner_v1.streamed.StreamedResultSet` + :param iterator: Iterator to merge into this one. + """ + self._iterators.append(iterator) + + def __next__(self): + """Return the next value from the currently streamed iterator. + + If the current iterator is streamed to the end, + start to stream the next one. + + :rtype: list + :returns: The next result row. + """ + try: + res = next(self._iterators[self._index]) + except StopIteration: + self._index += 1 + res = self.__next__() + except IndexError: + raise StopIteration + + return res + + def __iter__(self): + return self + + re_UNICODE_POINTS = re.compile(r"([^\s]*[\u0080-\uFFFF]+[^\s]*)") diff --git a/tests/system/test_system.py b/tests/system/test_system.py index f3ee345e15..8b21bb1fdf 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -289,6 +289,46 @@ def test_rollback_on_connection_closing(self): cursor.close() conn.close() + def test_execute_many(self): + # connect to the test database + conn = Connection(Config.INSTANCE, self._db) + cursor = conn.cursor() + + cursor.execute( + """ +INSERT INTO contacts (contact_id, first_name, last_name, email) +VALUES (1, 'first-name', 'last-name', 'test.email@domen.ru'), + (2, 'first-name2', 'last-name2', 'test.email2@domen.ru') + """ + ) + conn.commit() + + cursor.executemany( + """ +SELECT * FROM contacts WHERE contact_id = @a1 +""", + ({"a1": 1}, {"a1": 2}), + ) + res = cursor.fetchall() + conn.commit() + + self.assertEqual(len(res), 2) + self.assertEqual(res[0][0], 1) + self.assertEqual(res[1][0], 2) + + # checking that execute() and executemany() + # results are not mixed together + cursor.execute( + """ +SELECT * FROM contacts WHERE contact_id = 1 +""", + ) + res = cursor.fetchone() + conn.commit() + + self.assertEqual(res[0], 1) + conn.close() + def clear_table(transaction): """Clear the test table.""" From 658ef93ec859556c728aad0e1b1995d59f9801bf Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Tue, 3 Nov 2020 13:06:42 +0300 Subject: [PATCH 2/6] add default value --- google/cloud/spanner_dbapi/cursor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 32cdec30dc..6376291631 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -44,6 +44,7 @@ class Cursor(object): def __init__(self, connection): self._itr = None self._result_set = None + self._many_result_set = None self._row_count = _UNSET_COUNT self.connection = connection self._is_closed = False From a4828b9ab177e365422c02cb65f47106029b6d6e Mon Sep 17 00:00:00 2001 From: Ilya Gurov Date: Thu, 5 Nov 2020 11:41:08 +0300 Subject: [PATCH 3/6] Update tests/system/test_system.py Co-authored-by: Chris Kleinknecht --- tests/system/test_system.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/system/test_system.py b/tests/system/test_system.py index 8b21bb1fdf..ddf2fc139c 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -297,8 +297,8 @@ def test_execute_many(self): cursor.execute( """ INSERT INTO contacts (contact_id, first_name, last_name, email) -VALUES (1, 'first-name', 'last-name', 'test.email@domen.ru'), - (2, 'first-name2', 'last-name2', 'test.email2@domen.ru') +VALUES (1, 'first-name', 'last-name', 'test.email@example.com'), + (2, 'first-name2', 'last-name2', 'test.email2@example.com') """ ) conn.commit() From 657bf43e3b00502aface57f3a902c62ea407bf8b Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Thu, 5 Nov 2020 12:19:39 +0300 Subject: [PATCH 4/6] don't use new properties --- google/cloud/spanner_dbapi/cursor.py | 16 ++++------------ tests/unit/spanner_dbapi/test_cursor.py | 5 ++++- 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 6376291631..fcb0e16296 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -44,11 +44,9 @@ class Cursor(object): def __init__(self, connection): self._itr = None self._result_set = None - self._many_result_set = None self._row_count = _UNSET_COUNT self.connection = connection self._is_closed = False - self._executed_many = False # the number of rows to fetch at a time with fetchmany() self.arraysize = 1 @@ -147,7 +145,6 @@ def execute(self, sql, args=None): self._raise_if_closed() self._result_set = None - self._executed_many = False # Classify whether this is a read-only SQL statement. try: @@ -202,12 +199,13 @@ def executemany(self, operation, seq_of_params): """ self._raise_if_closed() - self._many_result_set = StreamedManyResultSets() + many_result_set = StreamedManyResultSets() for params in seq_of_params: self.execute(operation, params) - self._many_result_set.add_iter(self._itr) + many_result_set.add_iter(self._itr) - self._executed_many = True + self._result_set = many_result_set + self._itr = many_result_set def fetchone(self): """Fetch the next row of a query result set, returning a single @@ -301,17 +299,11 @@ def __exit__(self, etype, value, traceback): def __next__(self): if self._itr is None: raise ProgrammingError("no results to return") - if self._executed_many: - return next(self._many_result_set) - return next(self._itr) def __iter__(self): if self._itr is None: raise ProgrammingError("no results to return") - if self._executed_many: - return self._many_result_set - return self._itr def list_tables(self): diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index 09288df94e..889bed20da 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -276,8 +276,11 @@ def test_executemany(self): connection = connect("test-instance", "test-database") cursor = connection.cursor() + + cursor._result_set = [1, 2, 3] + cursor._itr = iter([1, 2, 3]) with mock.patch( - "google.cloud.spanner_dbapi.cursor.Cursor.execute" + "google.cloud.spanner_dbapi.cursor.Cursor.execute", ) as execute_mock: cursor.executemany(operation, params_seq) From 286ae76454ef5eea170dc2066b09523b25ac2327 Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Mon, 16 Nov 2020 12:29:24 +0300 Subject: [PATCH 5/6] disallow executing DDLs with executemany() --- google/cloud/spanner_dbapi/cursor.py | 6 ++++++ tests/unit/spanner_dbapi/test_cursor.py | 18 ++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index fcb0e16296..124664d4f5 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -199,6 +199,12 @@ def executemany(self, operation, seq_of_params): """ self._raise_if_closed() + classification = parse_utils.classify_stmt(operation) + if classification == parse_utils.STMT_DDL: + raise ProgrammingError( + "Executing DDL statements with executemany() method is now allowed." + ) + many_result_set = StreamedManyResultSets() for params in seq_of_params: self.execute(operation, params) diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index 889bed20da..6576b567cb 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -259,6 +259,24 @@ def test_executemany_on_closed_cursor(self): """SELECT * FROM table1 WHERE "col1" = @a1""", () ) + def test_executemany_DLL(self): + from google.cloud.spanner_dbapi import connect, ProgrammingError + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", + return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + + with self.assertRaises(ProgrammingError): + cursor.executemany("""DROP DATABASE database_name""", ()) + def test_executemany(self): from google.cloud.spanner_dbapi import connect From be5cc2272ec2f1a9e834e8c31c3ef84ee86d0526 Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Mon, 16 Nov 2020 12:32:11 +0300 Subject: [PATCH 6/6] fix a typo --- google/cloud/spanner_dbapi/cursor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 124664d4f5..070205867a 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -202,7 +202,7 @@ def executemany(self, operation, seq_of_params): classification = parse_utils.classify_stmt(operation) if classification == parse_utils.STMT_DDL: raise ProgrammingError( - "Executing DDL statements with executemany() method is now allowed." + "Executing DDL statements with executemany() method is not allowed." ) many_result_set = StreamedManyResultSets()