diff --git a/docs/spanner/snapshot-usage.rst b/docs/spanner/snapshot-usage.rst index ba31425a54b4..90c6d0322b4d 100644 --- a/docs/spanner/snapshot-usage.rst +++ b/docs/spanner/snapshot-usage.rst @@ -62,26 +62,6 @@ fails if the result set is too large, manually, perform all iteration within the context of the ``with database.snapshot()`` block. -.. note:: - - If streaming a chunk raises an exception, the application can - retry the ``read``, passing the ``resume_token`` from ``StreamingResultSet`` - which raised the error. E.g.: - - .. code:: python - - result = snapshot.read(table, columns, keys) - while True: - try: - for row in result.rows: - print row - except Exception: - result = snapshot.read( - table, columns, keys, resume_token=result.resume_token) - continue - else: - break - Execute a SQL Select Statement @@ -112,26 +92,6 @@ fails if the result set is too large, manually, perform all iteration within the context of the ``with database.snapshot()`` block. -.. note:: - - If streaming a chunk raises an exception, the application can - retry the query, passing the ``resume_token`` from ``StreamingResultSet`` - which raised the error. E.g.: - - .. code:: python - - result = snapshot.execute_sql(QUERY) - while True: - try: - for row in result.rows: - print row - except Exception: - result = snapshot.execute_sql( - QUERY, resume_token=result.resume_token) - continue - else: - break - Next Step --------- diff --git a/docs/spanner/transaction-usage.rst b/docs/spanner/transaction-usage.rst index 0577bc2093b8..5c2e4a9bb5a2 100644 --- a/docs/spanner/transaction-usage.rst +++ b/docs/spanner/transaction-usage.rst @@ -32,12 +32,6 @@ fails if the result set is too large, for row in result.rows: print(row) -.. note:: - - If streaming a chunk fails due to a "resumable" error, - :meth:`Session.read` retries the ``StreamingRead`` API reqeust, - passing the ``resume_token`` from the last partial result streamed. - Execute a SQL Select Statement ------------------------------ diff --git a/spanner/google/cloud/spanner/session.py b/spanner/google/cloud/spanner/session.py index d513889053a7..94fd0f092366 100644 --- a/spanner/google/cloud/spanner/session.py +++ b/spanner/google/cloud/spanner/session.py @@ -165,8 +165,7 @@ def snapshot(self, **kw): return Snapshot(self, **kw) - def read(self, table, columns, keyset, index='', limit=0, - resume_token=b''): + def read(self, table, columns, keyset, index='', limit=0): """Perform a ``StreamingRead`` API request for rows in a table. :type table: str @@ -185,17 +184,12 @@ def read(self, table, columns, keyset, index='', limit=0, :type limit: int :param limit: (Optional) maxiumn number of rows to return - :type resume_token: bytes - :param resume_token: token for resuming previously-interrupted read - :rtype: :class:`~google.cloud.spanner.streamed.StreamedResultSet` :returns: a result set instance which can be used to consume rows. """ - return self.snapshot().read( - table, columns, keyset, index, limit, resume_token) + return self.snapshot().read(table, columns, keyset, index, limit) - def execute_sql(self, sql, params=None, param_types=None, query_mode=None, - resume_token=b''): + def execute_sql(self, sql, params=None, param_types=None, query_mode=None): """Perform an ``ExecuteStreamingSql`` API request. :type sql: str @@ -216,14 +210,11 @@ def execute_sql(self, sql, params=None, param_types=None, query_mode=None, :param query_mode: Mode governing return of results / query plan. 
See https://cloud.google.com/spanner/reference/rpc/google.spanner.v1#google.spanner.v1.ExecuteSqlRequest.QueryMode1 - :type resume_token: bytes - :param resume_token: token for resuming previously-interrupted query - :rtype: :class:`~google.cloud.spanner.streamed.StreamedResultSet` :returns: a result set instance which can be used to consume rows. """ return self.snapshot().execute_sql( - sql, params, param_types, query_mode, resume_token) + sql, params, param_types, query_mode) def batch(self): """Factory to create a batch for this session. diff --git a/spanner/google/cloud/spanner/snapshot.py b/spanner/google/cloud/spanner/snapshot.py index 89bd840000dc..9636bd9762e3 100644 --- a/spanner/google/cloud/spanner/snapshot.py +++ b/spanner/google/cloud/spanner/snapshot.py @@ -14,6 +14,8 @@ """Model a set of read-only queries to a database as a snapshot.""" +import functools + from google.protobuf.struct_pb2 import Struct from google.cloud.proto.spanner.v1.transaction_pb2 import TransactionOptions from google.cloud.proto.spanner.v1.transaction_pb2 import TransactionSelector @@ -49,8 +51,7 @@ def _make_txn_selector(self): # pylint: disable=redundant-returns-doc """ raise NotImplementedError - def read(self, table, columns, keyset, index='', limit=0, - resume_token=b''): + def read(self, table, columns, keyset, index='', limit=0): """Perform a ``StreamingRead`` API request for rows in a table. :type table: str @@ -69,9 +70,6 @@ def read(self, table, columns, keyset, index='', limit=0, :type limit: int :param limit: (Optional) maxiumn number of rows to return - :type resume_token: bytes - :param resume_token: token for resuming previously-interrupted read - :rtype: :class:`~google.cloud.spanner.streamed.StreamedResultSet` :returns: a result set instance which can be used to consume rows. :raises ValueError: @@ -92,17 +90,20 @@ def read(self, table, columns, keyset, index='', limit=0, iterator = api.streaming_read( self._session.name, table, columns, keyset.to_pb(), transaction=transaction, index=index, limit=limit, - resume_token=resume_token, options=options) + options=options) self._read_request_count += 1 + restart = functools.partial( + api.streaming_read, self._session.name, table, columns, keyset, + index=index, limit=limit) + if self._multi_use: - return StreamedResultSet(iterator, source=self) + return StreamedResultSet(iterator, restart, source=self) else: - return StreamedResultSet(iterator) + return StreamedResultSet(iterator, restart) - def execute_sql(self, sql, params=None, param_types=None, query_mode=None, - resume_token=b''): + def execute_sql(self, sql, params=None, param_types=None, query_mode=None): """Perform an ``ExecuteStreamingSql`` API request for rows in a table. :type sql: str @@ -122,9 +123,6 @@ def execute_sql(self, sql, params=None, param_types=None, query_mode=None, :param query_mode: Mode governing return of results / query plan. See https://cloud.google.com/spanner/reference/rpc/google.spanner.v1#google.spanner.v1.ExecuteSqlRequest.QueryMode1 - :type resume_token: bytes - :param resume_token: token for resuming previously-interrupted query - :rtype: :class:`~google.cloud.spanner.streamed.StreamedResultSet` :returns: a result set instance which can be used to consume rows. 
:raises ValueError: @@ -153,14 +151,18 @@ def execute_sql(self, sql, params=None, param_types=None, query_mode=None, iterator = api.execute_streaming_sql( self._session.name, sql, transaction=transaction, params=params_pb, param_types=param_types, - query_mode=query_mode, resume_token=resume_token, options=options) + query_mode=query_mode, options=options) self._read_request_count += 1 + restart = functools.partial( + api.execute_streaming_sql, self._session.name, sql, + params=params, param_types=param_types, query_mode=query_mode) + if self._multi_use: - return StreamedResultSet(iterator, source=self) + return StreamedResultSet(iterator, restart, source=self) else: - return StreamedResultSet(iterator) + return StreamedResultSet(iterator, restart) class Snapshot(_SnapshotBase): diff --git a/spanner/google/cloud/spanner/streamed.py b/spanner/google/cloud/spanner/streamed.py index c7d950d766d7..1c0dbe7dc89a 100644 --- a/spanner/google/cloud/spanner/streamed.py +++ b/spanner/google/cloud/spanner/streamed.py @@ -14,9 +14,10 @@ """Wrapper for streaming results.""" +from google.api.core import exceptions +from google.api.core import retry from google.protobuf.struct_pb2 import ListValue from google.protobuf.struct_pb2 import Value -from google.cloud import exceptions from google.cloud.proto.spanner.v1 import type_pb2 import six @@ -24,6 +25,28 @@ from google.cloud.spanner._helpers import _parse_value_pb # pylint: enable=ungrouped-imports +_RESTART_DEADLINE = 30.0 # seconds + + +# pylint: disable=invalid-name +# Pylint sees this as a constant, but it is also an alias that should be +# considered a function. +if_unavailable_error = retry.if_exception_type(( + exceptions.ServiceUnavailable, +)) +"""A predicate that checks if an exception is a transient API error. + +For streaming the result of ``read`` / ``execute_sql`` requests, only +the following server errors are considered transient: + +- :class:`google.api.core.exceptions.ServiceUnavailable` - HTTP 503, gRPC + ``UNAVAILABLE``. +""" + +retry_unavailable = retry.Retry(predicate=if_unavailable_error) +"""Used by `StreamedResultSet.consume_next`.""" +# pylint: enable=invalid-name + class StreamedResultSet(object): """Process a sequence of partial result sets into a single set of row data. @@ -34,11 +57,18 @@ class StreamedResultSet(object): :class:`google.cloud.proto.spanner.v1.result_set_pb2.PartialResultSet` instances. + :type restart: callable + :param restart: + Function (typically curried via :func:`functools.partial`) used to + restart the initial request if a retriable error is raised during + streaming. + :type source: :class:`~google.cloud.spanner.snapshot.Snapshot` :param source: Snapshot from which the result set was fetched. 
""" - def __init__(self, response_iterator, source=None): + def __init__(self, response_iterator, restart, source=None): self._response_iterator = response_iterator + self._restart = restart self._rows = [] # Fully-processed rows self._counter = 0 # Counter for processed responses self._metadata = None # Until set from first PRS @@ -125,12 +155,29 @@ def _merge_values(self, values): self._rows.append(self._current_row) self._current_row = [] + def _restart_iterator(self, _exc_ignored): + """Helper for :meth:`consume_next`.""" + if self._resume_token in (None, b''): + raise + + self._response_iterator = self._restart( + resume_token=self._resume_token) + + def _bump_iterator(self): + """Helper for :meth:`consume_next`.""" + return six.next(self._response_iterator) + def consume_next(self): """Consume the next partial result set from the stream. Parse the result set into new/existing rows in :attr:`_rows` + + :raises ValueError: + if the sleep generator somehow does not yield values. """ - response = six.next(self._response_iterator) + response = retry_unavailable( + self._bump_iterator, on_error=self._restart_iterator)() + self._counter += 1 self._resume_token = response.resume_token diff --git a/spanner/setup.py b/spanner/setup.py index 7498d54abfd6..d101de5fc238 100644 --- a/spanner/setup.py +++ b/spanner/setup.py @@ -51,7 +51,7 @@ REQUIREMENTS = [ - 'google-cloud-core >= 0.27.0, < 0.28dev', + 'google-cloud-core >= 0.27.1, < 0.28dev', 'grpcio >= 1.2.0, < 2.0dev', 'gapic-google-cloud-spanner-v1 >= 0.15.0, < 0.16dev', 'gapic-google-cloud-spanner-admin-database-v1 >= 0.15.0, < 0.16dev', diff --git a/spanner/tests/unit/test_session.py b/spanner/tests/unit/test_session.py index 826369079d29..3c9d9e74af47 100644 --- a/spanner/tests/unit/test_session.py +++ b/spanner/tests/unit/test_session.py @@ -265,7 +265,6 @@ def test_read(self): KEYSET = KeySet(keys=KEYS) INDEX = 'email-address-index' LIMIT = 20 - TOKEN = b'DEADBEEF' database = _Database(self.DATABASE_NAME) session = self._make_one(database) session._session_id = 'DEADBEEF' @@ -279,28 +278,26 @@ def __init__(self, session, **kwargs): self._session = session self._kwargs = kwargs.copy() - def read(self, table, columns, keyset, index='', limit=0, - resume_token=b''): + def read(self, table, columns, keyset, index='', limit=0): _read_with.append( - (table, columns, keyset, index, limit, resume_token)) + (table, columns, keyset, index, limit)) return expected with _Monkey(MUT, Snapshot=_Snapshot): found = session.read( TABLE_NAME, COLUMNS, KEYSET, - index=INDEX, limit=LIMIT, resume_token=TOKEN) + index=INDEX, limit=LIMIT) self.assertIs(found, expected) self.assertEqual(len(_read_with), 1) - (table, columns, key_set, index, limit, resume_token) = _read_with[0] + (table, columns, key_set, index, limit) = _read_with[0] self.assertEqual(table, TABLE_NAME) self.assertEqual(columns, COLUMNS) self.assertEqual(key_set, KEYSET) self.assertEqual(index, INDEX) self.assertEqual(limit, LIMIT) - self.assertEqual(resume_token, TOKEN) def test_execute_sql_not_created(self): SQL = 'SELECT first_name, age FROM citizens' @@ -315,7 +312,6 @@ def test_execute_sql_defaults(self): from google.cloud._testing import _Monkey SQL = 'SELECT first_name, age FROM citizens' - TOKEN = b'DEADBEEF' database = _Database(self.DATABASE_NAME) session = self._make_one(database) session._session_id = 'DEADBEEF' @@ -330,25 +326,23 @@ def __init__(self, session, **kwargs): self._kwargs = kwargs.copy() def execute_sql( - self, sql, params=None, param_types=None, query_mode=None, - 
resume_token=None): + self, sql, params=None, param_types=None, query_mode=None): _executed_sql_with.append( - (sql, params, param_types, query_mode, resume_token)) + (sql, params, param_types, query_mode)) return expected with _Monkey(MUT, Snapshot=_Snapshot): - found = session.execute_sql(SQL, resume_token=TOKEN) + found = session.execute_sql(SQL) self.assertIs(found, expected) self.assertEqual(len(_executed_sql_with), 1) - sql, params, param_types, query_mode, token = _executed_sql_with[0] + sql, params, param_types, query_mode = _executed_sql_with[0] self.assertEqual(sql, SQL) self.assertEqual(params, None) self.assertEqual(param_types, None) self.assertEqual(query_mode, None) - self.assertEqual(token, TOKEN) def test_batch_not_created(self): database = _Database(self.DATABASE_NAME) diff --git a/spanner/tests/unit/test_snapshot.py b/spanner/tests/unit/test_snapshot.py index 4717a14c2f24..dee8fdecefd7 100644 --- a/spanner/tests/unit/test_snapshot.py +++ b/spanner/tests/unit/test_snapshot.py @@ -15,6 +15,8 @@ import unittest +import mock + from google.cloud._testing import _GAXBaseAPI @@ -149,7 +151,6 @@ def _read_helper(self, multi_use, first=True, count=0): KEYSET = KeySet(keys=KEYS) INDEX = 'email-address-index' LIMIT = 20 - TOKEN = b'DEADBEEF' database = _Database() api = database.spanner_api = _FauxSpannerAPI( _streaming_read_response=_MockCancellableIterator(*result_sets)) @@ -160,9 +161,17 @@ def _read_helper(self, multi_use, first=True, count=0): if not first: derived._transaction_id = TXN_ID - result_set = derived.read( - TABLE_NAME, COLUMNS, KEYSET, - index=INDEX, limit=LIMIT, resume_token=TOKEN) + partial_patch = mock.patch('functools.partial') + + with partial_patch as patch: + result_set = derived.read( + TABLE_NAME, COLUMNS, KEYSET, + index=INDEX, limit=LIMIT) + + self.assertIs(result_set._restart, patch.return_value) + patch.assert_called_once_with( + api.streaming_read, session.name, TABLE_NAME, COLUMNS, KEYSET, + index=INDEX, limit=LIMIT) self.assertEqual(derived._read_request_count, count + 1) @@ -193,7 +202,7 @@ def _read_helper(self, multi_use, first=True, count=0): self.assertTrue(transaction.single_use.read_only.strong) self.assertEqual(index, INDEX) self.assertEqual(limit, LIMIT) - self.assertEqual(resume_token, TOKEN) + self.assertEqual(resume_token, b'') self.assertEqual(options.kwargs['metadata'], [('google-cloud-resource-prefix', database.name)]) @@ -273,7 +282,6 @@ def _execute_sql_helper(self, multi_use, first=True, count=0): for row in VALUES ] MODE = 2 # PROFILE - TOKEN = b'DEADBEEF' struct_type_pb = StructType(fields=[ StructType.Field(name='first_name', type=Type(code=STRING)), StructType.Field(name='last_name', type=Type(code=STRING)), @@ -299,9 +307,17 @@ def _execute_sql_helper(self, multi_use, first=True, count=0): if not first: derived._transaction_id = TXN_ID - result_set = derived.execute_sql( - SQL_QUERY_WITH_PARAM, PARAMS, PARAM_TYPES, - query_mode=MODE, resume_token=TOKEN) + partial_patch = mock.patch('functools.partial') + + with partial_patch as patch: + result_set = derived.execute_sql( + SQL_QUERY_WITH_PARAM, PARAMS, PARAM_TYPES, + query_mode=MODE) + + self.assertIs(result_set._restart, patch.return_value) + patch.assert_called_once_with( + api.execute_streaming_sql, session.name, SQL_QUERY_WITH_PARAM, + params=PARAMS, param_types=PARAM_TYPES, query_mode=MODE) self.assertEqual(derived._read_request_count, count + 1) @@ -333,7 +349,7 @@ def _execute_sql_helper(self, multi_use, first=True, count=0): self.assertEqual(params, expected_params) 
self.assertEqual(param_types, PARAM_TYPES) self.assertEqual(query_mode, MODE) - self.assertEqual(resume_token, TOKEN) + self.assertEqual(resume_token, b'') self.assertEqual(options.kwargs['metadata'], [('google-cloud-resource-prefix', database.name)]) @@ -360,8 +376,6 @@ def test_execute_sql_w_multi_use_w_first_w_count_gt_0(self): class _MockCancellableIterator(object): - cancel_calls = 0 - def __init__(self, *values): self.iter_values = iter(values) @@ -725,7 +739,7 @@ def begin_transaction(self, session, options_, options=None): # pylint: disable=too-many-arguments def streaming_read(self, session, table, columns, key_set, transaction=None, index='', limit=0, - resume_token='', options=None): + resume_token=b'', options=None): from google.gax.errors import GaxError self._streaming_read_with = ( @@ -738,7 +752,7 @@ def streaming_read(self, session, table, columns, key_set, def execute_streaming_sql(self, session, sql, transaction=None, params=None, param_types=None, - resume_token='', query_mode=None, options=None): + resume_token=b'', query_mode=None, options=None): from google.gax.errors import GaxError self._executed_streaming_sql_with = ( diff --git a/spanner/tests/unit/test_streamed.py b/spanner/tests/unit/test_streamed.py index 0e0bcb7aff6b..ebb293dc5a0a 100644 --- a/spanner/tests/unit/test_streamed.py +++ b/spanner/tests/unit/test_streamed.py @@ -25,13 +25,16 @@ def _getTargetClass(self): return StreamedResultSet - def _make_one(self, *args, **kwargs): - return self._getTargetClass()(*args, **kwargs) + def _make_one(self, response_iterator, restart=object(), source=None): + return self._getTargetClass()( + response_iterator, restart, source=source) def test_ctor_defaults(self): iterator = _MockCancellableIterator() - streamed = self._make_one(iterator) + restart = object() + streamed = self._make_one(iterator, restart) self.assertIs(streamed._response_iterator, iterator) + self.assertIs(streamed._restart, restart) self.assertIsNone(streamed._source) self.assertEqual(streamed.rows, []) self.assertIsNone(streamed.metadata) @@ -40,9 +43,11 @@ def test_ctor_defaults(self): def test_ctor_w_source(self): iterator = _MockCancellableIterator() + restart = object() source = object() - streamed = self._make_one(iterator, source=source) + streamed = self._make_one(iterator, restart, source=source) self.assertIs(streamed._response_iterator, iterator) + self.assertIs(streamed._restart, restart) self.assertIs(streamed._source, source) self.assertEqual(streamed.rows, []) self.assertIsNone(streamed.metadata) @@ -650,6 +655,18 @@ def test_consume_next_empty(self): with self.assertRaises(StopIteration): streamed.consume_next() + def test_consume_next_w_retryable_exception_wo_token(self): + from google.api.core.exceptions import ServiceUnavailable + + failing_iterator = _FailingIterator() + restart = mock.Mock() + streamed = self._make_one(failing_iterator, restart) + + with self.assertRaises(ServiceUnavailable): + streamed.consume_next() + + restart.assert_not_called() + def test_consume_next_first_set_partial(self): TXN_ID = b'DEADBEEF' FIELDS = [ @@ -739,7 +756,8 @@ def test_consume_next_w_pending_chunk(self): self.assertIsNone(streamed._pending_chunk) self.assertEqual(streamed.resume_token, result_set.resume_token) - def test_consume_next_last_set(self): + def test_consume_next_last_set_w_restart(self): + TOKEN = b'BECADEAF' FIELDS = [ self._make_scalar_field('full_name', 'STRING'), self._make_scalar_field('age', 'INT64'), @@ -754,15 +772,22 @@ def test_consume_next_last_set(self): BARE = 
[u'Phred Phlyntstone', 42, True] VALUES = [self._make_value(bare) for bare in BARE] result_set = self._make_partial_result_set(VALUES, stats=stats) + failing_iterator = _FailingIterator() iterator = _MockCancellableIterator(result_set) - streamed = self._make_one(iterator) + restart = mock.Mock(return_value=iterator) + streamed = self._make_one(failing_iterator, restart) streamed._metadata = metadata + streamed._resume_token = TOKEN + streamed.consume_next() + self.assertEqual(streamed.rows, [BARE]) self.assertEqual(streamed._current_row, []) self.assertEqual(streamed._stats, stats) self.assertEqual(streamed.resume_token, result_set.resume_token) + restart.assert_called_once_with(resume_token=TOKEN) + def test_consume_all_empty(self): iterator = _MockCancellableIterator() streamed = self._make_one(iterator) @@ -900,18 +925,28 @@ def test___iter___w_existing_rows_read(self): class _MockCancellableIterator(object): - cancel_calls = 0 - def __init__(self, *values): self.iter_values = iter(values) def next(self): + return next(self.iter_values) def __next__(self): # pragma: NO COVER Py3k return self.next() +class _FailingIterator(object): + + def next(self): + from google.api.core.exceptions import ServiceUnavailable + + raise ServiceUnavailable('testing') + + def __next__(self): # pragma: NO COVER Py3k + return self.next() + + class TestStreamedResultSet_JSON_acceptance_tests(unittest.TestCase): _json_tests = None @@ -921,8 +956,9 @@ def _getTargetClass(self): return StreamedResultSet - def _make_one(self, *args, **kwargs): - return self._getTargetClass()(*args, **kwargs) + def _make_one(self, response_iterator, restart=object(), source=None): + return self._getTargetClass()( + response_iterator, restart, source=source) def _load_json_test(self, test_name): import os
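
The net effect of this patch is that ``StreamedResultSet.consume_next`` now resumes interrupted streams itself: the originating request is curried with ``functools.partial`` and stored as ``restart``, iteration is retried only for ``ServiceUnavailable`` (HTTP 503 / gRPC ``UNAVAILABLE``), and on such an error the stream is re-opened with the last ``resume_token`` received. Callers therefore no longer pass ``resume_token`` to ``read`` / ``execute_sql`` or write the manual retry loops removed from the docs above. The following is a minimal, self-contained sketch of that resumption strategy, not the library code; every name in it (``Unavailable``, ``streaming_read``, ``flaky``, ``consume``) is a hypothetical stand-in.

.. code:: python

    import functools


    class Unavailable(Exception):
        """Hypothetical stand-in for google.api.core.exceptions.ServiceUnavailable."""


    def streaming_read(resume_token=b''):
        """Hypothetical RPC stub: yield (rows, resume_token) pairs, honoring resume_token."""
        start = int(resume_token or b'0')
        for index in range(start, 5):
            yield ['row-%d' % index], str(index + 1).encode()


    def flaky(stream_factory, fail_after):
        """Wrap a stream factory so its first stream raises Unavailable after N items."""
        state = {'failed': False}

        def factory(resume_token=b''):
            def generator():
                count = 0
                for item in stream_factory(resume_token=resume_token):
                    if not state['failed'] and count == fail_after:
                        state['failed'] = True
                        raise Unavailable('transient error while streaming')
                    count += 1
                    yield item
            return generator()

        return factory


    def consume(restart):
        """Drain a stream, re-opening it from the last token if it becomes unavailable."""
        iterator = restart()
        resume_token = b''
        rows = []
        while True:
            try:
                partial_result = next(iterator)
            except StopIteration:
                return rows
            except Unavailable:
                if not resume_token:
                    raise  # nothing streamed yet, so there is no token to resume from
                iterator = restart(resume_token=resume_token)
                continue
            row_data, resume_token = partial_result
            rows.extend(row_data)


    # The stream fails once after two items; consume() re-opens it from the last
    # resume_token and still returns every row exactly once.
    print(consume(flaky(streaming_read, fail_after=2)))

In the patched library the retry predicate and back-off come from ``google.api.core.retry`` (see ``retry_unavailable`` and ``if_unavailable_error`` above) rather than a hand-written ``try``/``except`` loop, and, as in the sketch, resumption is skipped when no ``resume_token`` has been received yet, in which case the original error is re-raised.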