From 57fd446d4a83277b4ddeab9fb02b2a95b9797f27 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Mon, 2 Oct 2023 17:48:12 -0700 Subject: [PATCH 1/3] Add metadata handling. --- google/generativeai/client.py | 273 ++++++++++++++++++++++------------ 1 file changed, 174 insertions(+), 99 deletions(-) diff --git a/google/generativeai/client.py b/google/generativeai/client.py index 1d0e3c16a..e6f8b4477 100644 --- a/google/generativeai/client.py +++ b/google/generativeai/client.py @@ -15,7 +15,9 @@ from __future__ import annotations import os -from typing import cast, Optional, Union +import dataclasses +import types +from typing import Any, cast, Sequence import google.ai.generativelanguage as glm @@ -26,15 +28,160 @@ from google.generativeai import version - USER_AGENT = "genai-py" -default_client_config = {} -default_discuss_client = None -default_discuss_async_client = None -default_model_client = None -default_text_client = None -default_operations_client = None + +@dataclasses.dataclass +class _ClientManager: + client_config: dict[str, Any] = dataclasses.field(default_factory=dict) + metadata: Sequence[tuple[str, str]] = () + discuss_client: glm.DiscussServiceClient | None = None + discuss_async_client: glm.DiscussServiceAsyncClient | None = None + model_client: glm.ModelServiceClient | None = None + text_client: glm.TextServiceClient | None = None + operations_client = None + + def configure( + self, + *, + api_key: str | None = None, + credentials: ga_credentials.Credentials | dict | None = None, + # The user can pass a string to choose `rest` or `grpc` or 'grpc_asyncio'. + # See `_transport_registry` in `DiscussServiceClientMeta`. + # Since the transport classes align with the client classes it wouldn't make + # sense to accept a `Transport` object here even though the client classes can. + # We could accept a dict since all the `Transport` classes take the same args, + # but that seems rare. Users that need it can just switch to the low level API. + transport: str | None = None, + client_options: client_options_lib.ClientOptions | dict | None = None, + client_info: gapic_v1.client_info.ClientInfo | None = None, + default_metadata: Sequence[tuple[str, str]] = (), + ): + """Captures default client configuration. + + If no API key has been provided (either directly, or on `client_options`) and the + `GOOGLE_API_KEY` environment variable is set, it will be used as the API key. + + Args: + Refer to `glm.DiscussServiceClient`, and `glm.ModelsServiceClient` for details on additional arguments. + transport: A string, one of: [`rest`, `grpc`, `grpc_asyncio`]. + api_key: The API-Key to use when creating the default clients (each service uses + a separate client). This is a shortcut for `client_options={"api_key": api_key}`. + If omitted, and the `GOOGLE_API_KEY` environment variable is set, it will be + used. + default_metadata: Default (key, value) metadata pairs to send with every request. + """ + if isinstance(client_options, dict): + client_options = client_options_lib.from_dict(client_options) + if client_options is None: + client_options = client_options_lib.ClientOptions() + client_options = cast(client_options_lib.ClientOptions, client_options) + had_api_key_value = getattr(client_options, "api_key", None) + + if had_api_key_value: + if api_key is not None: + raise ValueError("You can't set both `api_key` and `client_options['api_key']`.") + else: + if api_key is None: + # If no key is provided explicitly, attempt to load one from the + # environment. + api_key = os.getenv("GOOGLE_API_KEY") + + client_options.api_key = api_key + + user_agent = f"{USER_AGENT}/{version.__version__}" + if client_info: + # Be respectful of any existing agent setting. + if client_info.user_agent: + client_info.user_agent += f" {user_agent}" + else: + client_info.user_agent = user_agent + else: + client_info = gapic_v1.client_info.ClientInfo(user_agent=user_agent) + + client_config = { + "credentials": credentials, + "transport": transport, + "client_options": client_options, + "client_info": client_info, + } + + client_config = {key: value for key, value in client_config.items() if value is not None} + + self.client_config = client_config + self.default_metadata = default_metadata + self.discuss_client = None + self.text_client = None + self.model_client = None + self.operations_client = None + + def make_client(self, cls): + # Attempt to configure using defaults. + if self.client_config is None: + configure() + + client = cls(**self.client_config) + + if not self.default_metadata: + return client + + def keep(name, f): + if name.startswith("_"): + return False + if not isinstance(f, types.FunctionType): + return False + if isinstance(f, classmethod): + return False + if isinstance(f, staticmethod): + False + + return True + + def add_default_metadata_wrapper(f): + def call(*args, metadata=(), **kwargs): + metadata = list(metadata) + list(self.default_metadata) + return f(*args, **kwargs, metadata=metadata) + + return call + + for name, value in cls.__dict__.items(): + if not keep(name, value): + continue + f = getattr(client, name) + f = add_default_metadata_wrapper(f) + setattr(client, name, f) + + return client + + def get_default_discuss_client(self) -> glm.DiscussServiceClient: + if self.discuss_client is None: + self.discuss_client = self.make_client(glm.DiscussServiceClient) + return self.discuss_client + + def get_default_text_client(self) -> glm.TextServiceClient: + if self.text_client is None: + self.text_client = self.make_client(glm.TextServiceClient) + return self.text_client + + def get_default_discuss_async_client(self) -> glm.DiscussServiceAsyncClient: + if self.discuss_async_client is None: + self.discuss_async_client = self.make_client(glm.DiscussServiceAsyncClient) + return self.discuss_async_client + + def get_default_model_client(self) -> glm.ModelServiceClient: + if self.model_client is None: + self.model_client = self.make_client(glm.ModelServiceClient) + return self.model_client + + def get_default_operations_client(self) -> operations_v1.OperationsClient: + if self.operations_client is None: + self.model_client = get_default_model_client() + self.operations_client = model_client._transport.operations_client + + return self.operations_client + + +_client_manager = _ClientManager() def configure( @@ -50,6 +197,7 @@ def configure( transport: str | None = None, client_options: client_options_lib.ClientOptions | dict | None = None, client_info: gapic_v1.client_info.ClientInfo | None = None, + default_metadata: Sequence[tuple[str, str]] = (), ): """Captures default client configuration. @@ -58,111 +206,38 @@ def configure( Args: Refer to `glm.DiscussServiceClient`, and `glm.ModelsServiceClient` for details on additional arguments. + transport: A string, one of: [`rest`, `grpc`, `grpc_asyncio`]. api_key: The API-Key to use when creating the default clients (each service uses a separate client). This is a shortcut for `client_options={"api_key": api_key}`. If omitted, and the `GOOGLE_API_KEY` environment variable is set, it will be used. + default_metadata: Default `(key, value)` metadata pairs to send with every request. """ - global default_client_config - global default_discuss_client - global default_model_client - global default_text_client - global default_operations_client - - if isinstance(client_options, dict): - client_options = client_options_lib.from_dict(client_options) - if client_options is None: - client_options = client_options_lib.ClientOptions() - client_options = cast(client_options_lib.ClientOptions, client_options) - had_api_key_value = getattr(client_options, "api_key", None) - - if had_api_key_value: - if api_key is not None: - raise ValueError("You can't set both `api_key` and `client_options['api_key']`.") - else: - if api_key is None: - # If no key is provided explicitly, attempt to load one from the - # environment. - api_key = os.getenv("GOOGLE_API_KEY") - - client_options.api_key = api_key - - user_agent = f"{USER_AGENT}/{version.__version__}" - if client_info: - # Be respectful of any existing agent setting. - if client_info.user_agent: - client_info.user_agent += f" {user_agent}" - else: - client_info.user_agent = user_agent - else: - client_info = gapic_v1.client_info.ClientInfo(user_agent=user_agent) - - new_default_client_config = { - "credentials": credentials, - "transport": transport, - "client_options": client_options, - "client_info": client_info, - } - - new_default_client_config = { - key: value for key, value in new_default_client_config.items() if value is not None - } - - default_client_config = new_default_client_config - default_discuss_client = None - default_text_client = None - default_model_client = None - default_operations_client = None + return _client_manager.configure( + api_key=api_key, + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + default_metadata=default_metadata, + ) def get_default_discuss_client() -> glm.DiscussServiceClient: - global default_discuss_client - if default_discuss_client is None: - # Attempt to configure using defaults. - if not default_client_config: - configure() - default_discuss_client = glm.DiscussServiceClient(**default_client_config) - - return default_discuss_client + return _client_manager.get_default_discuss_client() def get_default_text_client() -> glm.TextServiceClient: - global default_text_client - if default_text_client is None: - # Attempt to configure using defaults. - if not default_client_config: - configure() - default_text_client = glm.TextServiceClient(**default_client_config) - - return default_text_client + return _client_manager.get_default_discuss_client() -def get_default_discuss_async_client() -> glm.DiscussServiceAsyncClient: - global default_discuss_async_client - if default_discuss_async_client is None: - # Attempt to configure using defaults. - if not default_client_config: - configure() - default_discuss_async_client = glm.DiscussServiceAsyncClient(**default_client_config) - - return default_discuss_async_client - - -def get_default_model_client() -> glm.ModelServiceClient: - global default_model_client - if default_model_client is None: - # Attempt to configure using defaults. - if not default_client_config: - configure() - default_model_client = glm.ModelServiceClient(**default_client_config) +def get_default_operations_client() -> operations_v1.OperationsClient: + return _client_manager.get_default_operations_client() - return default_model_client +def get_default_discuss_async_client() -> glm.DiscussServiceAsyncClient: + return _client_manager.get_default_discuss_async_client() -def get_default_operations_client() -> operations_v1.OperationsClient: - global default_operations_client - if default_operations_client is None: - model_client = get_default_model_client() - default_operations_client = model_client._transport.operations_client - return default_operations_client +def get_default_model_client() -> glm.ModelServiceAsyncClient: + return _client_manager.get_default_model_client() From ff232cfbdb8642507a246002b1fd625679c9662e Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Thu, 5 Oct 2023 14:53:35 -0700 Subject: [PATCH 2/3] Add and fix tests. --- google/generativeai/client.py | 8 ++++---- tests/test_client.py | 33 ++++++++++++++++++++++++++------- tests/test_discuss.py | 2 +- tests/test_models.py | 2 +- tests/test_text.py | 2 +- 5 files changed, 33 insertions(+), 14 deletions(-) diff --git a/google/generativeai/client.py b/google/generativeai/client.py index e6f8b4477..eb5bb024a 100644 --- a/google/generativeai/client.py +++ b/google/generativeai/client.py @@ -34,7 +34,7 @@ @dataclasses.dataclass class _ClientManager: client_config: dict[str, Any] = dataclasses.field(default_factory=dict) - metadata: Sequence[tuple[str, str]] = () + default_metadata: Sequence[tuple[str, str]] = () discuss_client: glm.DiscussServiceClient | None = None discuss_async_client: glm.DiscussServiceAsyncClient | None = None model_client: glm.ModelServiceClient | None = None @@ -117,7 +117,7 @@ def configure( def make_client(self, cls): # Attempt to configure using defaults. - if self.client_config is None: + if not self.client_config: configure() client = cls(**self.client_config) @@ -176,7 +176,7 @@ def get_default_model_client(self) -> glm.ModelServiceClient: def get_default_operations_client(self) -> operations_v1.OperationsClient: if self.operations_client is None: self.model_client = get_default_model_client() - self.operations_client = model_client._transport.operations_client + self.operations_client = self.model_client._transport.operations_client return self.operations_client @@ -228,7 +228,7 @@ def get_default_discuss_client() -> glm.DiscussServiceClient: def get_default_text_client() -> glm.TextServiceClient: - return _client_manager.get_default_discuss_client() + return _client_manager.get_default_text_client() def get_default_operations_client() -> operations_v1.OperationsClient: diff --git a/tests/test_client.py b/tests/test_client.py index a5b5a5e78..b512a949f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -20,37 +20,38 @@ from absl.testing import parameterized from google.api_core import client_options +import google.ai.generativelanguage as glm from google.generativeai import client class ClientTests(parameterized.TestCase): def setUp(self): super().setUp() - client.default_client_config = {} + client._client_manager = client._ClientManager() def test_api_key_passed_directly(self): client.configure(api_key="AIzA_direct") - client_opts = client.default_client_config["client_options"] + client_opts = client._client_manager.client_config["client_options"] self.assertEqual(client_opts.api_key, "AIzA_direct") def test_api_key_passed_via_client_options(self): client_opts = client_options.ClientOptions(api_key="AIzA_client_opts") client.configure(client_options=client_opts) - client_opts = client.default_client_config["client_options"] + client_opts = client._client_manager.client_config["client_options"] self.assertEqual(client_opts.api_key, "AIzA_client_opts") @mock.patch.dict(os.environ, {"GOOGLE_API_KEY": "AIzA_env"}) def test_api_key_from_environment(self): # Default to API key loaded from environment. client.configure() - client_opts = client.default_client_config["client_options"] + client_opts = client._client_manager.client_config["client_options"] self.assertEqual(client_opts.api_key, "AIzA_env") # But not when a key is provided explicitly. client.configure(api_key="AIzA_client") - client_opts = client.default_client_config["client_options"] + client_opts = client._client_manager.client_config["client_options"] self.assertEqual(client_opts.api_key, "AIzA_client") def test_api_key_cannot_be_set_twice(self): @@ -65,7 +66,7 @@ def test_api_key_and_client_options(self): client_opts = client_options.ClientOptions(api_endpoint="web.site") client.configure(api_key="AIzA_client", client_options=client_opts) - actual_client_opts = client.default_client_config["client_options"] + actual_client_opts = client._client_manager.client_config["client_options"] self.assertEqual(actual_client_opts.api_key, "AIzA_client") self.assertEqual(actual_client_opts.api_endpoint, "web.site") @@ -74,15 +75,33 @@ def test_api_key_and_client_options(self): client.get_default_text_client, client.get_default_discuss_async_client, client.get_default_model_client, + client.get_default_operations_client, ) @mock.patch.dict(os.environ, {"GOOGLE_API_KEY": "AIzA_env"}) def test_configureless_client_with_key(self, factory_fn): _ = factory_fn() # And ensure that it has set the default options. - actual_client_opts = client.default_client_config["client_options"] + actual_client_opts = client._client_manager.client_config["client_options"] self.assertEqual(actual_client_opts.api_key, "AIzA_env") + class DummyClient: + def __init__(self, *args, **kwargs): + pass + + def generate_text(self, metadata=None): + self.metadata = metadata + + @mock.patch.object(glm, "TextServiceClient", DummyClient) + def test_default_metadata(self): + metadata = [("hello", "world")] + client.configure(default_metadata=metadata) + + text_client = client.get_default_text_client() + text_client.generate_text() + + self.assertEqual(metadata, text_client.metadata) + if __name__ == "__main__": absltest.main() diff --git a/tests/test_discuss.py b/tests/test_discuss.py index cb0455663..8021d972a 100644 --- a/tests/test_discuss.py +++ b/tests/test_discuss.py @@ -33,7 +33,7 @@ class UnitTests(parameterized.TestCase): def setUp(self): self.client = unittest.mock.MagicMock() - client.default_discuss_client = self.client + client._client_manager.discuss_client = self.client self.observed_request = None diff --git a/tests/test_models.py b/tests/test_models.py index 646e31969..c06daa9cc 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -36,7 +36,7 @@ class UnitTests(parameterized.TestCase): def setUp(self): self.client = unittest.mock.MagicMock() - client.default_model_client = self.client + client._client_manager.model_client = self.client def add_client_method(f): name = f.__name__ diff --git a/tests/test_text.py b/tests/test_text.py index 0e5381677..36b822e67 100644 --- a/tests/test_text.py +++ b/tests/test_text.py @@ -30,7 +30,7 @@ class UnitTests(parameterized.TestCase): def setUp(self): self.client = unittest.mock.MagicMock() - client.default_text_client = self.client + client._client_manager.text_client = self.client self.observed_request = None From 6eb1f26034269e50264b12a3308ba073cd929c01 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Thu, 5 Oct 2023 16:28:20 -0700 Subject: [PATCH 3/3] Resolve comments --- google/generativeai/client.py | 29 ++++++++++++++++++----------- tests/test_client.py | 25 +++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 11 deletions(-) diff --git a/google/generativeai/client.py b/google/generativeai/client.py index eb5bb024a..dead136c9 100644 --- a/google/generativeai/client.py +++ b/google/generativeai/client.py @@ -17,7 +17,8 @@ import os import dataclasses import types -from typing import Any, cast, Sequence +from typing import Any, cast +from collections.abc import Sequence import google.ai.generativelanguage as glm @@ -56,20 +57,23 @@ def configure( client_options: client_options_lib.ClientOptions | dict | None = None, client_info: gapic_v1.client_info.ClientInfo | None = None, default_metadata: Sequence[tuple[str, str]] = (), - ): + ) -> None: """Captures default client configuration. If no API key has been provided (either directly, or on `client_options`) and the `GOOGLE_API_KEY` environment variable is set, it will be used as the API key. + Note: Not all arguments are detailed below. Refer to the `*ServiceClient` classes in + `google.ai.generativelanguage` for details on the other arguments. + Args: - Refer to `glm.DiscussServiceClient`, and `glm.ModelsServiceClient` for details on additional arguments. transport: A string, one of: [`rest`, `grpc`, `grpc_asyncio`]. api_key: The API-Key to use when creating the default clients (each service uses a separate client). This is a shortcut for `client_options={"api_key": api_key}`. If omitted, and the `GOOGLE_API_KEY` environment variable is set, it will be used. default_metadata: Default (key, value) metadata pairs to send with every request. + when using `transport="rest"` these are sent as HTTP headers. """ if isinstance(client_options, dict): client_options = client_options_lib.from_dict(client_options) @@ -128,14 +132,14 @@ def make_client(self, cls): def keep(name, f): if name.startswith("_"): return False - if not isinstance(f, types.FunctionType): + elif not isinstance(f, types.FunctionType): return False - if isinstance(f, classmethod): + elif isinstance(f, classmethod): return False - if isinstance(f, staticmethod): - False - - return True + elif isinstance(f, staticmethod): + return False + else: + return True def add_default_metadata_wrapper(f): def call(*args, metadata=(), **kwargs): @@ -204,14 +208,17 @@ def configure( If no API key has been provided (either directly, or on `client_options`) and the `GOOGLE_API_KEY` environment variable is set, it will be used as the API key. + Note: Not all arguments are detailed below. Refer to the `*ServiceClient` classes in + `google.ai.generativelanguage` for details on the other arguments. + Args: - Refer to `glm.DiscussServiceClient`, and `glm.ModelsServiceClient` for details on additional arguments. transport: A string, one of: [`rest`, `grpc`, `grpc_asyncio`]. api_key: The API-Key to use when creating the default clients (each service uses a separate client). This is a shortcut for `client_options={"api_key": api_key}`. If omitted, and the `GOOGLE_API_KEY` environment variable is set, it will be used. - default_metadata: Default `(key, value)` metadata pairs to send with every request. + default_metadata: Default (key, value) metadata pairs to send with every request. + when using `transport="rest"` these are sent as HTTP headers. """ return _client_manager.configure( api_key=api_key, diff --git a/tests/test_client.py b/tests/test_client.py index b512a949f..29c14ea51 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -92,8 +92,22 @@ def __init__(self, *args, **kwargs): def generate_text(self, metadata=None): self.metadata = metadata + not_a_function = 7 + + def _hidden(self): + self.called_hidden = True + + @staticmethod + def static(): + pass + + @classmethod + def classm(cls): + cls.called_classm = True + @mock.patch.object(glm, "TextServiceClient", DummyClient) def test_default_metadata(self): + # The metadata wrapper injects this argument. metadata = [("hello", "world")] client.configure(default_metadata=metadata) @@ -102,6 +116,17 @@ def test_default_metadata(self): self.assertEqual(metadata, text_client.metadata) + self.assertEqual(text_client.not_a_function, ClientTests.DummyClient.not_a_function) + + # Since these don't have a metadata arg, they'll fail if the wrapper is applied. + text_client._hidden() + self.assertTrue(text_client.called_hidden) + + text_client.static() + + text_client.classm() + self.assertTrue(ClientTests.DummyClient.called_classm) + if __name__ == "__main__": absltest.main()