From 3d63168a784471ddae6a9968f66706ccbc38661f Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Thu, 14 Dec 2023 14:32:01 -0800 Subject: [PATCH] Add basic tools support. --- google/generativeai/generative_models.py | 9 +++-- google/generativeai/types/content_types.py | 15 ++++++++ tests/test_content.py | 42 ++++++++++++++++++++++ tests/test_generative_models.py | 28 +++++++++++++++ 4 files changed, 92 insertions(+), 2 deletions(-) diff --git a/google/generativeai/generative_models.py b/google/generativeai/generative_models.py index d2b02b1d8..a783ea9de 100644 --- a/google/generativeai/generative_models.py +++ b/google/generativeai/generative_models.py @@ -2,14 +2,14 @@ from __future__ import annotations +from collections.abc import Iterable import dataclasses import textwrap +from typing import Union # pylint: disable=bad-continuation, line-too-long -from collections.abc import Iterable - from google.ai import generativelanguage as glm from google.generativeai import client from google.generativeai import string_utils @@ -70,6 +70,7 @@ generation_config: Overrides for the model's generation config. safety_settings: Overrides for the model's safety settings. stream: If True, yield response chunks as they are generated. + tools: `glm.Tools` more info coming soon. """ _SEND_MESSAGE_ASYNC_DOC = """The async version of `ChatSession.send_message`.""" @@ -158,6 +159,7 @@ def __init__( model_name: str = "gemini-m", safety_settings: safety_types.SafetySettingOptions | None = None, generation_config: generation_types.GenerationConfigType | None = None, + tools: content_types.ToolsType = None, ): if "/" not in model_name: model_name = "models/" + model_name @@ -166,6 +168,8 @@ def __init__( safety_settings, harm_category_set="new" ) self._generation_config = generation_types.to_generation_config_dict(generation_config) + self._tools = content_types.to_tools(tools) + self._client = None self._async_client = None @@ -213,6 +217,7 @@ def _prepare_request( contents=contents, generation_config=merged_gc, safety_settings=merged_ss, + tools=self._tools, **kwargs, ) diff --git a/google/generativeai/types/content_types.py b/google/generativeai/types/content_types.py index e6503ed63..56a34b1ba 100644 --- a/google/generativeai/types/content_types.py +++ b/google/generativeai/types/content_types.py @@ -40,6 +40,7 @@ "ContentType", "StrictContentType", "ContentsType", + "ToolsType", ] @@ -234,3 +235,17 @@ def to_contents(contents: ContentsType) -> list[glm.Content]: contents = [to_content(contents)] return contents + + +ToolsType = Union[Iterable[glm.Tool], glm.Tool, dict[str, Any], None] + + +def to_tools(tools: ToolsType) -> list[glm.Tool]: + if tools is None: + return [] + elif isinstance(tools, Mapping): + return [glm.Tool(tools)] + elif isinstance(tools, Iterable): + return [glm.Tool(t) for t in tools] + else: + return [glm.Tool(tools)] diff --git a/tests/test_content.py b/tests/test_content.py index 3b82b1982..122bbe224 100644 --- a/tests/test_content.py +++ b/tests/test_content.py @@ -169,6 +169,48 @@ def test_img_to_contents(self, example): self.assertEqual(blob.mime_type, "image/png") self.assertStartsWith(blob.data, b"\x89PNG") + @parameterized.named_parameters( + [ + "OneTool", + glm.Tool( + function_declarations=[ + glm.FunctionDeclaration( + name="datetime", description="Returns the current UTC date and time." + ) + ] + ), + ], + [ + "ToolDict", + dict( + function_declarations=[ + dict(name="datetime", description="Returns the current UTC date and time.") + ] + ), + ], + [ + "ListOfTools", + [ + glm.Tool( + function_declarations=[ + glm.FunctionDeclaration( + name="datetime", + description="Returns the current UTC date and time.", + ) + ] + ) + ], + ], + ) + def test_img_to_contents(self, tools): + tools = content_types.to_tools(tools) + expected = dict( + function_declarations=[ + dict(name="datetime", description="Returns the current UTC date and time.") + ] + ) + self.assertEqual(type(tools[0]).to_dict(tools[0]), expected) + if __name__ == "__main__": absltest.main() diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py index 01608eb97..142893abe 100644 --- a/tests/test_generative_models.py +++ b/tests/test_generative_models.py @@ -564,6 +564,34 @@ def test_chat_streaming_unexpected_stop(self): chat.rewind() self.assertLen(chat.history, 0) + def test_tools(self): + tools = dict( + function_declarations=[ + dict(name="datetime", description="Returns the current UTC date and time.") + ] + ) + model = generative_models.GenerativeModel("gemini-mm-m", tools=tools) + + self.responses["generate_content"] = [ + simple_response("a"), + simple_response("b"), + ] + + response = model.generate_content("Hello") + + chat = model.start_chat() + response = chat.send_message("Hello") + + expect_tools = dict( + function_declarations=[ + dict(name="datetime", description="Returns the current UTC date and time.") + ] + ) + + for obr in self.observed_requests: + self.assertLen(obr.tools, 1) + self.assertEqual(type(obr.tools[0]).to_dict(obr.tools[0]), tools) + @parameterized.named_parameters( [ "GenerateContentResponse",