diff --git a/.github/workflows/unittest.yml b/.github/workflows/unittest.yml index a8eef614..34d29b7c 100644 --- a/.github/workflows/unittest.yml +++ b/.github/workflows/unittest.yml @@ -11,7 +11,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - option: ["", "_grpc_gcp", "_wo_grpc"] + option: ["", "_grpc_gcp", "_wo_grpc", "_with_prerelease_deps"] python: - "3.7" - "3.8" diff --git a/CHANGELOG.md b/CHANGELOG.md index 9e64d357..e48409f4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,14 @@ [1]: https://pypi.org/project/google-api-core/#history +## [2.19.1](https://github.com/googleapis/python-api-core/compare/v2.19.0...v2.19.1) (2024-06-19) + + +### Bug Fixes + +* Add support for protobuf 5.x ([#644](https://github.com/googleapis/python-api-core/issues/644)) ([fda0ca6](https://github.com/googleapis/python-api-core/commit/fda0ca6f0664ac5044671591ed62618175a7393f)) +* Ignore unknown fields in rest streaming. ([#651](https://github.com/googleapis/python-api-core/issues/651)) ([1203fb9](https://github.com/googleapis/python-api-core/commit/1203fb97d2685535f89113e944c4764c1deb595e)) + ## [2.19.0](https://github.com/googleapis/python-api-core/compare/v2.18.0...v2.19.0) (2024-04-29) diff --git a/google/api_core/grpc_helpers_async.py b/google/api_core/grpc_helpers_async.py index 9423d2b6..6feb2229 100644 --- a/google/api_core/grpc_helpers_async.py +++ b/google/api_core/grpc_helpers_async.py @@ -159,7 +159,6 @@ class _WrappedStreamStreamCall( def _wrap_unary_errors(callable_): """Map errors for Unary-Unary async callables.""" - grpc_helpers._patch_callable_name(callable_) @functools.wraps(callable_) def error_remapped_callable(*args, **kwargs): @@ -169,23 +168,13 @@ def error_remapped_callable(*args, **kwargs): return error_remapped_callable -def _wrap_stream_errors(callable_): +def _wrap_stream_errors(callable_, wrapper_type): """Map errors for streaming RPC async callables.""" - grpc_helpers._patch_callable_name(callable_) @functools.wraps(callable_) async def error_remapped_callable(*args, **kwargs): call = callable_(*args, **kwargs) - - if isinstance(call, aio.UnaryStreamCall): - call = _WrappedUnaryStreamCall().with_call(call) - elif isinstance(call, aio.StreamUnaryCall): - call = _WrappedStreamUnaryCall().with_call(call) - elif isinstance(call, aio.StreamStreamCall): - call = _WrappedStreamStreamCall().with_call(call) - else: - raise TypeError("Unexpected type of call %s" % type(call)) - + call = wrapper_type().with_call(call) await call.wait_for_connection() return call @@ -207,10 +196,16 @@ def wrap_errors(callable_): Returns: Callable: The wrapped gRPC callable. """ - if isinstance(callable_, aio.UnaryUnaryMultiCallable): - return _wrap_unary_errors(callable_) + grpc_helpers._patch_callable_name(callable_) + + if isinstance(callable_, aio.UnaryStreamMultiCallable): + return _wrap_stream_errors(callable_, _WrappedUnaryStreamCall) + elif isinstance(callable_, aio.StreamUnaryMultiCallable): + return _wrap_stream_errors(callable_, _WrappedStreamUnaryCall) + elif isinstance(callable_, aio.StreamStreamMultiCallable): + return _wrap_stream_errors(callable_, _WrappedStreamStreamCall) else: - return _wrap_stream_errors(callable_) + return _wrap_unary_errors(callable_) def create_channel( diff --git a/google/api_core/operations_v1/__init__.py b/google/api_core/operations_v1/__init__.py index 61186451..8b75426b 100644 --- a/google/api_core/operations_v1/__init__.py +++ b/google/api_core/operations_v1/__init__.py @@ -14,7 +14,9 @@ """Package for interacting with the google.longrunning.operations meta-API.""" -from google.api_core.operations_v1.abstract_operations_client import AbstractOperationsClient +from google.api_core.operations_v1.abstract_operations_client import ( + AbstractOperationsClient, +) from google.api_core.operations_v1.operations_async_client import OperationsAsyncClient from google.api_core.operations_v1.operations_client import OperationsClient from google.api_core.operations_v1.transports.rest import OperationsRestTransport @@ -23,5 +25,5 @@ "AbstractOperationsClient", "OperationsAsyncClient", "OperationsClient", - "OperationsRestTransport" + "OperationsRestTransport", ] diff --git a/google/api_core/operations_v1/transports/base.py b/google/api_core/operations_v1/transports/base.py index 98cf7896..fb1d4fc9 100644 --- a/google/api_core/operations_v1/transports/base.py +++ b/google/api_core/operations_v1/transports/base.py @@ -45,7 +45,7 @@ def __init__( self, *, host: str = DEFAULT_HOST, - credentials: ga_credentials.Credentials = None, + credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, quota_project_id: Optional[str] = None, diff --git a/google/api_core/operations_v1/transports/rest.py b/google/api_core/operations_v1/transports/rest.py index 49f99d21..f37bb344 100644 --- a/google/api_core/operations_v1/transports/rest.py +++ b/google/api_core/operations_v1/transports/rest.py @@ -29,9 +29,13 @@ from google.longrunning import operations_pb2 # type: ignore from google.protobuf import empty_pb2 # type: ignore from google.protobuf import json_format # type: ignore +import google.protobuf + import grpc from .base import DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO, OperationsTransport +PROTOBUF_VERSION = google.protobuf.__version__ + OptionalRetry = Union[retries.Retry, object] DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( @@ -66,7 +70,7 @@ def __init__( self, *, host: str = "longrunning.googleapis.com", - credentials: ga_credentials.Credentials = None, + credentials: Optional[ga_credentials.Credentials] = None, credentials_file: Optional[str] = None, scopes: Optional[Sequence[str]] = None, client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, @@ -184,11 +188,7 @@ def _list_operations( "google.longrunning.Operations.ListOperations" ] - request_kwargs = json_format.MessageToDict( - request, - preserving_proto_field_name=True, - including_default_value_fields=True, - ) + request_kwargs = self._convert_protobuf_message_to_dict(request) transcoded_request = path_template.transcode(http_options, **request_kwargs) uri = transcoded_request["uri"] @@ -199,7 +199,6 @@ def _list_operations( json_format.ParseDict(transcoded_request["query_params"], query_params_request) query_params = json_format.MessageToDict( query_params_request, - including_default_value_fields=False, preserving_proto_field_name=False, use_integers_for_enums=False, ) @@ -265,11 +264,7 @@ def _get_operation( "google.longrunning.Operations.GetOperation" ] - request_kwargs = json_format.MessageToDict( - request, - preserving_proto_field_name=True, - including_default_value_fields=True, - ) + request_kwargs = self._convert_protobuf_message_to_dict(request) transcoded_request = path_template.transcode(http_options, **request_kwargs) uri = transcoded_request["uri"] @@ -280,7 +275,6 @@ def _get_operation( json_format.ParseDict(transcoded_request["query_params"], query_params_request) query_params = json_format.MessageToDict( query_params_request, - including_default_value_fields=False, preserving_proto_field_name=False, use_integers_for_enums=False, ) @@ -339,11 +333,7 @@ def _delete_operation( "google.longrunning.Operations.DeleteOperation" ] - request_kwargs = json_format.MessageToDict( - request, - preserving_proto_field_name=True, - including_default_value_fields=True, - ) + request_kwargs = self._convert_protobuf_message_to_dict(request) transcoded_request = path_template.transcode(http_options, **request_kwargs) uri = transcoded_request["uri"] @@ -354,7 +344,6 @@ def _delete_operation( json_format.ParseDict(transcoded_request["query_params"], query_params_request) query_params = json_format.MessageToDict( query_params_request, - including_default_value_fields=False, preserving_proto_field_name=False, use_integers_for_enums=False, ) @@ -411,11 +400,7 @@ def _cancel_operation( "google.longrunning.Operations.CancelOperation" ] - request_kwargs = json_format.MessageToDict( - request, - preserving_proto_field_name=True, - including_default_value_fields=True, - ) + request_kwargs = self._convert_protobuf_message_to_dict(request) transcoded_request = path_template.transcode(http_options, **request_kwargs) # Jsonify the request body @@ -423,7 +408,6 @@ def _cancel_operation( json_format.ParseDict(transcoded_request["body"], body_request) body = json_format.MessageToDict( body_request, - including_default_value_fields=False, preserving_proto_field_name=False, use_integers_for_enums=False, ) @@ -435,7 +419,6 @@ def _cancel_operation( json_format.ParseDict(transcoded_request["query_params"], query_params_request) query_params = json_format.MessageToDict( query_params_request, - including_default_value_fields=False, preserving_proto_field_name=False, use_integers_for_enums=False, ) @@ -458,6 +441,38 @@ def _cancel_operation( return empty_pb2.Empty() + def _convert_protobuf_message_to_dict( + self, message: google.protobuf.message.Message + ): + r"""Converts protobuf message to a dictionary. + + When the dictionary is encoded to JSON, it conforms to proto3 JSON spec. + + Args: + message(google.protobuf.message.Message): The protocol buffers message + instance to serialize. + + Returns: + A dict representation of the protocol buffer message. + """ + # For backwards compatibility with protobuf 3.x 4.x + # Remove once support for protobuf 3.x and 4.x is dropped + # https://github.com/googleapis/python-api-core/issues/643 + if PROTOBUF_VERSION[0:2] in ["3.", "4."]: + result = json_format.MessageToDict( + message, + preserving_proto_field_name=True, + including_default_value_fields=True, # type: ignore # backward compatibility + ) + else: + result = json_format.MessageToDict( + message, + preserving_proto_field_name=True, + always_print_fields_with_no_presence=True, + ) + + return result + @property def list_operations( self, diff --git a/google/api_core/rest_streaming.py b/google/api_core/rest_streaming.py index 3f5b6b03..88bcb31b 100644 --- a/google/api_core/rest_streaming.py +++ b/google/api_core/rest_streaming.py @@ -118,7 +118,9 @@ def __next__(self): def _grab(self): # Add extra quotes to make json.loads happy. if issubclass(self._response_message_cls, proto.Message): - return self._response_message_cls.from_json(self._ready_objs.popleft()) + return self._response_message_cls.from_json( + self._ready_objs.popleft(), ignore_unknown_fields=True + ) elif issubclass(self._response_message_cls, google.protobuf.message.Message): return Parse(self._ready_objs.popleft(), self._response_message_cls()) else: diff --git a/google/api_core/version.py b/google/api_core/version.py index 2605c08a..25b6f2f8 100644 --- a/google/api_core/version.py +++ b/google/api_core/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2.19.0" +__version__ = "2.19.1" diff --git a/noxfile.py b/noxfile.py index 8fbcaec0..593bfc85 100644 --- a/noxfile.py +++ b/noxfile.py @@ -15,7 +15,9 @@ from __future__ import absolute_import import os import pathlib +import re import shutil +import unittest # https://github.com/google/importlab/issues/25 import nox # pytype: disable=import-error @@ -26,6 +28,8 @@ # Black and flake8 clash on the syntax for ignoring flake8's F401 in this file. BLACK_EXCLUDES = ["--exclude", "^/google/api_core/operations_v1/__init__.py"] +PYTHON_VERSIONS = ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"] + DEFAULT_PYTHON_VERSION = "3.10" CURRENT_DIRECTORY = pathlib.Path(__file__).parent.absolute() @@ -72,7 +76,37 @@ def blacken(session): session.run("black", *BLACK_EXCLUDES, *BLACK_PATHS) -def default(session, install_grpc=True): +def install_prerelease_dependencies(session, constraints_path): + with open(constraints_path, encoding="utf-8") as constraints_file: + constraints_text = constraints_file.read() + # Ignore leading whitespace and comment lines. + constraints_deps = [ + match.group(1) + for match in re.finditer( + r"^\s*(\S+)(?===\S+)", constraints_text, flags=re.MULTILINE + ) + ] + session.install(*constraints_deps) + prerel_deps = [ + "google-auth", + "googleapis-common-protos", + "grpcio", + "grpcio-status", + "proto-plus", + "protobuf", + ] + + for dep in prerel_deps: + session.install("--pre", "--no-deps", "--upgrade", dep) + + # Remaining dependencies + other_deps = [ + "requests", + ] + session.install(*other_deps) + + +def default(session, install_grpc=True, prerelease=False): """Default unit test session. This is intended to be run **without** an interpreter set, so @@ -80,9 +114,8 @@ def default(session, install_grpc=True): Python corresponding to the ``nox`` binary the ``PATH`` can run the tests. """ - constraints_path = str( - CURRENT_DIRECTORY / "testing" / f"constraints-{session.python}.txt" - ) + if prerelease and not install_grpc: + unittest.skip("The pre-release session cannot be run without grpc") session.install( "dataclasses", @@ -92,10 +125,36 @@ def default(session, install_grpc=True): "pytest-xdist", ) - if install_grpc: - session.install("-e", ".[grpc]", "-c", constraints_path) + constraints_dir = str(CURRENT_DIRECTORY / "testing") + + if prerelease: + install_prerelease_dependencies( + session, f"{constraints_dir}/constraints-{PYTHON_VERSIONS[0]}.txt" + ) + # This *must* be the last install command to get the package from source. + session.install("-e", ".", "--no-deps") else: - session.install("-e", ".", "-c", constraints_path) + session.install( + "-e", + ".[grpc]" if install_grpc else ".", + "-c", + f"{constraints_dir}/constraints-{session.python}.txt", + ) + + # Print out package versions of dependencies + session.run( + "python", "-c", "import google.protobuf; print(google.protobuf.__version__)" + ) + # Support for proto.version was added in v1.23.0 + # https://github.com/googleapis/proto-plus-python/releases/tag/v1.23.0 + session.run( + "python", + "-c", + """import proto; hasattr(proto, "version") and print(proto.version.__version__)""", + ) + if install_grpc: + session.run("python", "-c", "import grpc; print(grpc.__version__)") + session.run("python", "-c", "import google.auth; print(google.auth.__version__)") pytest_args = [ "python", @@ -130,15 +189,26 @@ def default(session, install_grpc=True): session.run(*pytest_args) -@nox.session(python=["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"]) +@nox.session(python=PYTHON_VERSIONS) def unit(session): """Run the unit test suite.""" default(session) -@nox.session(python=["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"]) +@nox.session(python=PYTHON_VERSIONS) +def unit_with_prerelease_deps(session): + """Run the unit test suite.""" + default(session, prerelease=True) + + +@nox.session(python=PYTHON_VERSIONS) def unit_grpc_gcp(session): - """Run the unit test suite with grpcio-gcp installed.""" + """ + Run the unit test suite with grpcio-gcp installed. + `grpcio-gcp` doesn't support protobuf 4+. + Remove extra `grpcgcp` when protobuf 3.x is dropped. + https://github.com/googleapis/python-api-core/issues/594 + """ constraints_path = str( CURRENT_DIRECTORY / "testing" / f"constraints-{session.python}.txt" ) @@ -150,7 +220,7 @@ def unit_grpc_gcp(session): default(session) -@nox.session(python=["3.8", "3.10", "3.11", "3.12"]) +@nox.session(python=PYTHON_VERSIONS) def unit_wo_grpc(session): """Run the unit test suite w/o grpcio installed""" default(session, install_grpc=False) @@ -164,10 +234,10 @@ def lint_setup_py(session): session.run("python", "setup.py", "check", "--restructuredtext", "--strict") -@nox.session(python="3.8") +@nox.session(python=DEFAULT_PYTHON_VERSION) def pytype(session): """Run type-checking.""" - session.install(".[grpc]", "pytype >= 2019.3.21") + session.install(".[grpc]", "pytype") session.run("pytype") @@ -178,9 +248,7 @@ def mypy(session): session.install( "types-setuptools", "types-requests", - # TODO(https://github.com/googleapis/python-api-core/issues/642): - # Use the latest version of types-protobuf. - "types-protobuf<5", + "types-protobuf", "types-mock", "types-dataclasses", ) diff --git a/owlbot.py b/owlbot.py index 5a83032e..c8c76542 100644 --- a/owlbot.py +++ b/owlbot.py @@ -29,8 +29,8 @@ ".flake8", # flake8-import-order, layout ".coveragerc", # layout "CONTRIBUTING.rst", # no systests - ".github/workflows/unittest.yml", # exclude unittest gh action - ".github/workflows/lint.yml", # exclude lint gh action + ".github/workflows/unittest.yml", # exclude unittest gh action + ".github/workflows/lint.yml", # exclude lint gh action "README.rst", ] templated_files = common.py_library(microgenerator=True, cov_level=100) diff --git a/pytest.ini b/pytest.ini index 13d5bf4d..696548cf 100644 --- a/pytest.ini +++ b/pytest.ini @@ -19,3 +19,5 @@ filterwarnings = ignore:.*pkg_resources is deprecated as an API:DeprecationWarning # Remove once https://github.com/grpc/grpc/issues/35086 is fixed (and version newer than 1.60.0 is published) ignore:There is no current event loop:DeprecationWarning + # Remove after support for Python 3.7 is dropped + ignore:After January 1, 2024, new releases of this library will drop support for Python 3.7:DeprecationWarning diff --git a/setup.py b/setup.py index 47e0b454..a9e01f49 100644 --- a/setup.py +++ b/setup.py @@ -30,7 +30,7 @@ release_status = "Development Status :: 5 - Production/Stable" dependencies = [ "googleapis-common-protos >= 1.56.2, < 2.0.dev0", - "protobuf>=3.19.5,<5.0.0.dev0,!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5", + "protobuf>=3.19.5,<6.0.0.dev0,!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5", "proto-plus >= 1.22.3, <2.0.0dev", "google-auth >= 2.14.1, < 3.0.dev0", "requests >= 2.18.0, < 3.0.0.dev0", diff --git a/tests/asyncio/test_grpc_helpers_async.py b/tests/asyncio/test_grpc_helpers_async.py index 6bde59ca..1a408ccd 100644 --- a/tests/asyncio/test_grpc_helpers_async.py +++ b/tests/asyncio/test_grpc_helpers_async.py @@ -97,12 +97,40 @@ async def test_common_methods_in_wrapped_call(): assert mock_call.wait_for_connection.call_count == 1 +@pytest.mark.asyncio +@pytest.mark.parametrize( + "callable_type,expected_wrapper_type", + [ + (grpc.aio.UnaryStreamMultiCallable, grpc_helpers_async._WrappedUnaryStreamCall), + (grpc.aio.StreamUnaryMultiCallable, grpc_helpers_async._WrappedStreamUnaryCall), + ( + grpc.aio.StreamStreamMultiCallable, + grpc_helpers_async._WrappedStreamStreamCall, + ), + ], +) +async def test_wrap_errors_w_stream_type(callable_type, expected_wrapper_type): + class ConcreteMulticallable(callable_type): + def __call__(self, *args, **kwargs): + raise NotImplementedError("Should not be called") + + with mock.patch.object( + grpc_helpers_async, "_wrap_stream_errors" + ) as wrap_stream_errors: + callable_ = ConcreteMulticallable() + grpc_helpers_async.wrap_errors(callable_) + assert wrap_stream_errors.call_count == 1 + wrap_stream_errors.assert_called_once_with(callable_, expected_wrapper_type) + + @pytest.mark.asyncio async def test_wrap_stream_errors_unary_stream(): mock_call = mock.Mock(aio.UnaryStreamCall, autospec=True) multicallable = mock.Mock(return_value=mock_call) - wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable) + wrapped_callable = grpc_helpers_async._wrap_stream_errors( + multicallable, grpc_helpers_async._WrappedUnaryStreamCall + ) await wrapped_callable(1, 2, three="four") multicallable.assert_called_once_with(1, 2, three="four") @@ -114,7 +142,9 @@ async def test_wrap_stream_errors_stream_unary(): mock_call = mock.Mock(aio.StreamUnaryCall, autospec=True) multicallable = mock.Mock(return_value=mock_call) - wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable) + wrapped_callable = grpc_helpers_async._wrap_stream_errors( + multicallable, grpc_helpers_async._WrappedStreamUnaryCall + ) await wrapped_callable(1, 2, three="four") multicallable.assert_called_once_with(1, 2, three="four") @@ -126,24 +156,15 @@ async def test_wrap_stream_errors_stream_stream(): mock_call = mock.Mock(aio.StreamStreamCall, autospec=True) multicallable = mock.Mock(return_value=mock_call) - wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable) + wrapped_callable = grpc_helpers_async._wrap_stream_errors( + multicallable, grpc_helpers_async._WrappedStreamStreamCall + ) await wrapped_callable(1, 2, three="four") multicallable.assert_called_once_with(1, 2, three="four") assert mock_call.wait_for_connection.call_count == 1 -@pytest.mark.asyncio -async def test_wrap_stream_errors_type_error(): - mock_call = mock.Mock() - multicallable = mock.Mock(return_value=mock_call) - - wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable) - - with pytest.raises(TypeError): - await wrapped_callable() - - @pytest.mark.asyncio async def test_wrap_stream_errors_raised(): grpc_error = RpcErrorImpl(grpc.StatusCode.INVALID_ARGUMENT) @@ -151,7 +172,9 @@ async def test_wrap_stream_errors_raised(): mock_call.wait_for_connection = mock.AsyncMock(side_effect=[grpc_error]) multicallable = mock.Mock(return_value=mock_call) - wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable) + wrapped_callable = grpc_helpers_async._wrap_stream_errors( + multicallable, grpc_helpers_async._WrappedStreamStreamCall + ) with pytest.raises(exceptions.InvalidArgument): await wrapped_callable() @@ -166,7 +189,9 @@ async def test_wrap_stream_errors_read(): mock_call.read = mock.AsyncMock(side_effect=grpc_error) multicallable = mock.Mock(return_value=mock_call) - wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable) + wrapped_callable = grpc_helpers_async._wrap_stream_errors( + multicallable, grpc_helpers_async._WrappedStreamStreamCall + ) wrapped_call = await wrapped_callable(1, 2, three="four") multicallable.assert_called_once_with(1, 2, three="four") @@ -189,7 +214,9 @@ async def test_wrap_stream_errors_aiter(): mock_call.__aiter__ = mock.Mock(return_value=mocked_aiter) multicallable = mock.Mock(return_value=mock_call) - wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable) + wrapped_callable = grpc_helpers_async._wrap_stream_errors( + multicallable, grpc_helpers_async._WrappedStreamStreamCall + ) wrapped_call = await wrapped_callable() with pytest.raises(exceptions.InvalidArgument) as exc_info: @@ -210,7 +237,9 @@ async def test_wrap_stream_errors_aiter_non_rpc_error(): mock_call.__aiter__ = mock.Mock(return_value=mocked_aiter) multicallable = mock.Mock(return_value=mock_call) - wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable) + wrapped_callable = grpc_helpers_async._wrap_stream_errors( + multicallable, grpc_helpers_async._WrappedStreamStreamCall + ) wrapped_call = await wrapped_callable() with pytest.raises(TypeError) as exc_info: @@ -224,7 +253,9 @@ async def test_wrap_stream_errors_aiter_called_multiple_times(): mock_call = mock.Mock(aio.StreamStreamCall, autospec=True) multicallable = mock.Mock(return_value=mock_call) - wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable) + wrapped_callable = grpc_helpers_async._wrap_stream_errors( + multicallable, grpc_helpers_async._WrappedStreamStreamCall + ) wrapped_call = await wrapped_callable() assert wrapped_call.__aiter__() == wrapped_call.__aiter__() @@ -239,7 +270,9 @@ async def test_wrap_stream_errors_write(): mock_call.done_writing = mock.AsyncMock(side_effect=[None, grpc_error]) multicallable = mock.Mock(return_value=mock_call) - wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable) + wrapped_callable = grpc_helpers_async._wrap_stream_errors( + multicallable, grpc_helpers_async._WrappedStreamStreamCall + ) wrapped_call = await wrapped_callable() @@ -295,7 +328,9 @@ def test_wrap_errors_streaming(wrap_stream_errors): result = grpc_helpers_async.wrap_errors(callable_) assert result == wrap_stream_errors.return_value - wrap_stream_errors.assert_called_once_with(callable_) + wrap_stream_errors.assert_called_once_with( + callable_, grpc_helpers_async._WrappedUnaryStreamCall + ) @pytest.mark.parametrize( diff --git a/tests/unit/test_rest_streaming.py b/tests/unit/test_rest_streaming.py index b532eb1d..0f2b3b32 100644 --- a/tests/unit/test_rest_streaming.py +++ b/tests/unit/test_rest_streaming.py @@ -101,9 +101,11 @@ def _parse_responses(self, responses: List[proto.Message]) -> bytes: # json.dumps returns a string surrounded with quotes that need to be stripped # in order to be an actual JSON. json_responses = [ - self._response_message_cls.to_json(r).strip('"') - if issubclass(self._response_message_cls, proto.Message) - else MessageToJson(r).strip('"') + ( + self._response_message_cls.to_json(r).strip('"') + if issubclass(self._response_message_cls, proto.Message) + else MessageToJson(r).strip('"') + ) for r in responses ] logging.info(f"Sending JSON stream: {json_responses}")