From 7c73eb4b07ff902e77da664257d880d75f1caf1b Mon Sep 17 00:00:00 2001
From: Mark Daoust
Date: Fri, 15 Dec 2023 11:07:41 -0800
Subject: [PATCH 1/2] Add smoke tests for count_tokens.

---
 google/generativeai/generative_models.py |  8 ++++++--
 tests/test_generative_models.py          | 23 +++++++++++++++++++++++
 tests/test_generative_models_async.py    | 23 +++++++++++++++++++++++
 3 files changed, 52 insertions(+), 2 deletions(-)

diff --git a/google/generativeai/generative_models.py b/google/generativeai/generative_models.py
index f1725815b..9d9520137 100644
--- a/google/generativeai/generative_models.py
+++ b/google/generativeai/generative_models.py
@@ -274,14 +274,18 @@ async def generate_content_async(
     def count_tokens(
         self, contents: content_types.ContentsType
     ) -> glm.CountTokensResponse:
+        if self._client is None:
+            self._client = client.get_default_generative_client()
         contents = content_types.to_contents(contents)
-        return self._client.count_tokens(model=self.model_name, contents=contents)
+        return self._client.count_tokens(glm.CountTokensRequest(model=self.model_name, contents=contents))
 
     async def count_tokens_async(
         self, contents: content_types.ContentsType
     ) -> glm.CountTokensResponse:
+        if self._async_client is None:
+            self._async_client = client.get_default_generative_async_client()
         contents = content_types.to_contents(contents)
-        return await self._client.count_tokens(model=self.model_name, contents=contents)
+        return await self._async_client.count_tokens(glm.CountTokensRequest(model=self.model_name, contents=contents))
 
     # fmt: on
     def start_chat(
diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py
index 01608eb97..07527b166 100644
--- a/tests/test_generative_models.py
+++ b/tests/test_generative_models.py
@@ -56,6 +56,14 @@ def stream_generate_content(
             response = self.responses["stream_generate_content"].pop(0)
             return response
 
+        @add_client_method
+        def count_tokens(
+            request: glm.CountTokensRequest,
+        ) -> glm.CountTokensResponse:
+            self.observed_requests.append(request)
+            response = self.responses["count_tokens"].pop(0)
+            return response
+
     def test_hello(self):
         # Generate text from text prompt
         model = generative_models.GenerativeModel(model_name="gemini-m")
@@ -564,6 +572,21 @@ def test_chat_streaming_unexpected_stop(self):
         chat.rewind()
         self.assertLen(chat.history, 0)
 
+    @parameterized.named_parameters(
+        ["basic", "Hello"],
+        ["list", ["Hello"]],
+        [
+            "list2",
+            [{"text": "Hello"}, {"inline_data": {"data": b"PNG!", "mime_type": "image/png"}}],
+        ],
+        ["contents", [{"role": "user", "parts": ["hello"]}]],
+    )
+    def test_count_tokens_smoke(self, contents):
+        self.responses["count_tokens"] = [glm.CountTokensResponse(total_tokens=7)]
+        model = generative_models.GenerativeModel("gemini-mm-m")
+        response = model.count_tokens(contents)
+        self.assertEqual(type(response).to_dict(response), {"total_tokens": 7})
+
     @parameterized.named_parameters(
         [
             "GenerateContentResponse",
diff --git a/tests/test_generative_models_async.py b/tests/test_generative_models_async.py
index 1c48f3476..313cc0493 100644
--- a/tests/test_generative_models_async.py
+++ b/tests/test_generative_models_async.py
@@ -62,6 +62,14 @@ async def stream_generate_content(
             response = self.responses["stream_generate_content"].pop(0)
             return response
 
+        @add_client_method
+        async def count_tokens(
+            request: glm.CountTokensRequest,
+        ) -> glm.CountTokensResponse:
+            self.observed_requests.append(request)
+            response = self.responses["count_tokens"].pop(0)
+            return response
+
     async def test_basic(self):
         # Generate text from text prompt
         model = generative_models.GenerativeModel(model_name="gemini-m")
@@ -98,6 +106,21 @@ async def responses():
 
         self.assertEqual(response.text, "world!")
 
+    @parameterized.named_parameters(
+        ["basic", "Hello"],
+        ["list", ["Hello"]],
+        [
+            "list2",
+            [{"text": "Hello"}, {"inline_data": {"data": b"PNG!", "mime_type": "image/png"}}],
+        ],
+        ["contents", [{"role": "user", "parts": ["hello"]}]],
+    )
+    async def test_count_tokens_smoke(self, contents):
+        self.responses["count_tokens"] = [glm.CountTokensResponse(total_tokens=7)]
+        model = generative_models.GenerativeModel("gemini-mm-m")
+        response = await model.count_tokens_async(contents)
+        self.assertEqual(type(response).to_dict(response), {"total_tokens": 7})
+
 
 if __name__ == "__main__":
     absltest.main()

From f14eaec71fdf2c4557b6d2c2c345af563f9a7cb1 Mon Sep 17 00:00:00 2001
From: Mark Daoust
Date: Fri, 15 Dec 2023 11:29:55 -0800
Subject: [PATCH 2/2] fix docstring

---
 google/generativeai/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/google/generativeai/__init__.py b/google/generativeai/__init__.py
index 6a41ef975..1709bb822 100644
--- a/google/generativeai/__init__.py
+++ b/google/generativeai/__init__.py
@@ -30,7 +30,7 @@
 
 genai.configure(api_key=os.environ['API_KEY'])
 
-model = genai.Model(name='gemini-pro')
+model = genai.GenerativeModel(name='gemini-pro')
 response = model.generate_content('Please summarise this document: ...')
 
 print(response.text)
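
Note (not part of the patches above): a minimal sketch of how the reworked count_tokens entry points are exercised end to end. It assumes the package is installed, a valid key is exported as API_KEY, and 'gemini-pro' is a placeholder model name.

    # Sketch only; API_KEY and 'gemini-pro' are placeholder assumptions.
    import asyncio
    import os

    import google.generativeai as genai

    genai.configure(api_key=os.environ["API_KEY"])
    model = genai.GenerativeModel("gemini-pro")

    # Sync path: count_tokens now wraps its arguments in a
    # glm.CountTokensRequest and lazily fetches the default client
    # if none was set.
    response = model.count_tokens("Hello world")
    print(response.total_tokens)

    # Async path: count_tokens_async now calls the async client rather
    # than the sync one (the bug the first patch fixes).
    async def main():
        response = await model.count_tokens_async(["Hello", "world"])
        print(response.total_tokens)

    asyncio.run(main())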