From 11aa0365959da9ab4369510c778e35dedacb9d98 Mon Sep 17 00:00:00 2001
From: Mark Daoust
Date: Fri, 5 May 2023 17:05:18 -0700
Subject: [PATCH 01/15] Update to google.ai.generativelanguage 0.2.0

---
 google/generativeai/version.py | 2 +-
 setup.py                       | 2 +-
 tests/test_text.py             | 6 +++---
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/google/generativeai/version.py b/google/generativeai/version.py
index fa14c226f..3ca712222 100644
--- a/google/generativeai/version.py
+++ b/google/generativeai/version.py
@@ -13,4 +13,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__version__ = "0.1.0rc1"
+__version__ = "0.1.0rc2"
diff --git a/setup.py b/setup.py
index a2dbe719a..d937de24f 100644
--- a/setup.py
+++ b/setup.py
@@ -35,7 +35,7 @@
 release_status = "Development Status :: 5 - Production/Stable"
 
 dependencies = [
-    "google-ai-generativelanguage==0.1.0"
+    "google-ai-generativelanguage==0.2.0"
 ]
 
 extras_require = {
diff --git a/tests/test_text.py b/tests/test_text.py
index cf41d128b..508437d19 100644
--- a/tests/test_text.py
+++ b/tests/test_text.py
@@ -133,9 +133,9 @@ def test_generate_response(self, *, prompt, **kwargs):
         self.assertEqual(
             complete.candidates,
             [
-                {"output": " road?"},
-                {"output": " bridge?"},
-                {"output": " river?"},
+                {"output": " road?", 'safety_ratings': []},
+                {"output": " bridge?", 'safety_ratings': []},
+                {"output": " river?", 'safety_ratings': []},
             ],
         )

From b717d4fd6c944974b66d82e1a734c49663973597 Mon Sep 17 00:00:00 2001
From: Mark Daoust
Date: Sun, 7 May 2023 21:11:06 -0700
Subject: [PATCH 02/15] Add citation and safety types

---
 google/generativeai/docstring_utils.py      | 20 +++++++
 google/generativeai/types/citation_types.py | 38 +++++++++++++
 google/generativeai/types/safety_types.py   | 63 +++++++++++++++++++++
 3 files changed, 121 insertions(+)
 create mode 100644 google/generativeai/docstring_utils.py
 create mode 100644 google/generativeai/types/citation_types.py
 create mode 100644 google/generativeai/types/safety_types.py

diff --git a/google/generativeai/docstring_utils.py b/google/generativeai/docstring_utils.py
new file mode 100644
index 000000000..0356563ca
--- /dev/null
+++ b/google/generativeai/docstring_utils.py
@@ -0,0 +1,20 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+def strip_oneof(docstring):
+    lines = docstring.splitlines()
+    lines = [line for line in lines if ".. _oneof:" not in line]
+    lines = [line for line in lines if "This field is a member of `oneof`_" not in line]
+    return "\n".join(lines)
\ No newline at end of file
diff --git a/google/generativeai/types/citation_types.py b/google/generativeai/types/citation_types.py
new file mode 100644
index 000000000..fa8d5f296
--- /dev/null
+++ b/google/generativeai/types/citation_types.py
@@ -0,0 +1,38 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from google.ai import generativelanguage as glm
+from google.generativeai import docstring_utils
+from typing import TypedDict
+
+__all__ = [
+    "CitationMetadataDict",
+    "CitationSourceDict",
+]
+
+
+class CitationSourceDict(TypedDict):
+    start_index: Optional[int]
+    end_index: Optional[int]
+    uri: Optional[str]
+    license: Optional[str]
+
+    __doc__ = docstring_utils.strip_oneof(glm.CitationSource.__doc__)
+
+
+class CitationMetadataDict(TypedDict):
+    citation_sources = Optional[List[CitationSourceDict]]
+
+    __doc__ = docstring_utils.strip_oneof(glm.CitationMetadata.__doc__)
diff --git a/google/generativeai/types/safety_types.py b/google/generativeai/types/safety_types.py
new file mode 100644
index 000000000..aa159cd12
--- /dev/null
+++ b/google/generativeai/types/safety_types.py
@@ -0,0 +1,63 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import enum
+from google.ai import generativelanguage as glm
+from typing import TypedDict
+
+__all__ = [
+    "HarmCategory",
+    "HarmProbability",
+    "HarmBlockThreshold",
+    "BlockReason",
+    "ContentFilter",
+    "SafetyRatingDict",
+    "SafetySetting",
+    "SafetyFeedbackDict",
+]
+
+# These are basic python enums, it's okay to expose them
+HarmCategory = glm.HarmCategory
+HarmProbability = glm.SafetyRating.HarmProbability
+HarmBlockThreshold = glm.SafetySetting
+BlockReason = glm.ContentFilter.BlockedReason
+
+
+class ContentFilter(TypedDict):
+    reason: BlockReason
+    message: str
+
+    __doc__ = docstring_utils.strip_oneof(glm.ContentFilter.__doc__)
+
+
+class SafetyRatingDict(TypedDict):
+    category: HarmCategory
+    probability: HarmProbability
+
+    __doc__ = docstring_utils.strip_oneof(glm.SafetyRating.__doc__)
+
+
+class SafetySetting(TypedDict):
+    category: HarmCategory
+    threshold: HarmBlockThreshold
+
+    __doc__ = docstring_utils.strip_oneof(glm.SafetySetting.__doc__)
+
+
+class SafetyFeedbackDict(TypedDict):
+    rating: SafetyRatingDict
+    setting: SafetySetting
+
+    __doc__ = docstring_utils.strip_oneof(glm.SafetyFeedback.__doc__)
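The two filters above key off the reST markers that proto-plus inserts for `oneof` fields. A rough sketch of what `strip_oneof` does (the docstring below is invented for illustration, not a real `glm` docstring):

    from google.generativeai import docstring_utils

    doc = "\n".join([
        "Attributes:",
        "    .. _oneof: https://example.com/oneof",
        "    uri (str):",
        "        This field is a member of `oneof`_ ``source``.",
    ])
    # Both marker lines are dropped; "Attributes:" and "uri (str):" survive.
    print(docstring_utils.strip_oneof(doc))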
From ca312f96be70f81102e4675b99a4a19f0c25a4a4 Mon Sep 17 00:00:00 2001
From: Mark Daoust
Date: Sun, 7 May 2023 21:16:20 -0700
Subject: [PATCH 03/15] Update discuss with citation and safety

---
 google/generativeai/discuss.py             |  6 +++++-
 google/generativeai/types/discuss_types.py | 15 +++++++++++++--
 tests/test_discuss.py                      | 10 +++++-----
 3 files changed, 23 insertions(+), 8 deletions(-)

diff --git a/google/generativeai/discuss.py b/google/generativeai/discuss.py
index f463e2abd..e63d9ff6a 100644
--- a/google/generativeai/discuss.py
+++ b/google/generativeai/discuss.py
@@ -25,6 +25,7 @@
 from google.generativeai.client import get_default_discuss_async_client
 from google.generativeai.types import discuss_types
 from google.generativeai.types import model_types
+from google.generativeai.types import safety_types
 
 
 def _make_message(content: discuss_types.MessageOptions) -> glm.Message:
@@ -407,6 +408,7 @@ def reply(
         )
         request = self.to_dict()
         request.pop("candidates")
+        request.pop("filters")
         request["messages"] = list(request["messages"])
         request["messages"].append(_make_message(message))
         request = _make_generate_message_request(**request)
@@ -440,12 +442,14 @@ def _build_chat_response(
     request["messages"] = prompt["messages"]
 
     response = type(response).to_dict(response)
+    response.pop("messages")
+
     request["messages"].append(response["candidates"][0])
     request.setdefault("temperature", None)
     request.setdefault("candidate_count", None)
 
     return ChatResponse(
-        _client=client, candidates=response["candidates"], **request
+        _client=client, **response, **request
     )  # pytype: disable=missing-parameter
diff --git a/google/generativeai/types/discuss_types.py b/google/generativeai/types/discuss_types.py
index 88c6ebeb5..071f6788a 100644
--- a/google/generativeai/types/discuss_types.py
+++ b/google/generativeai/types/discuss_types.py
@@ -19,6 +19,7 @@
 from typing import Any, Dict, TypedDict, Union, Iterable, Optional, Tuple, List
 
 import google.ai.generativelanguage as glm
+from google.generativeai.types import safety_types
 
 __all__ = [
     "MessageDict",
@@ -35,11 +36,12 @@
 ]
 
 
-class MessageDict(TypedDict, total=False):
+class MessageDict(TypedDict):
     """A dict representation of a `glm.Message`."""
 
     author: str
     content: str
+    citation_metadata: Optional[citation_types.CitationMetadataDict]
 
 
 MessageOptions = Union[str, MessageDict, glm.Message]
@@ -129,7 +131,14 @@ class ChatResponse(abc.ABC):
 
     Note:
         The `temperature` field affects the variability of the responses. Low
        temperatures will return few candidates. Setting `temperature=0` is
        deterministic, so it will only ever return one candidate.
-
+        filters: This indicates which `types.SafetyCategory`(s) blocked a
+            candidate from this response, the lowest `types.HarmProbability`
+            that triggered a block, and the `types.HarmThreshold` setting for that category.
+            This indicates the smallest change to the `types.SafetySettings` that would be
+            necessary to unblock at least 1 response.
+
+            The blocking is configured by the `types.SafetySettings` in the request (or the
+            default `types.SafetySettings` of the API).
         messages: Contains all the `messages` that were passed when the model was called,
            plus the top `candidate` message.
@@ -140,6 +149,7 @@ class ChatResponse(abc.ABC):
         model: The model name.
        temperature: The `temperature` specified when the model was called.
        candidate_count: The **maximum** number of generated response messages to return.
        top_k: The maximum number of tokens to consider when sampling.
        top_p: The maximum cumulative probability of tokens to consider when sampling.
+
     """
 
     model: str
@@ -151,6 +161,7 @@ class ChatResponse(abc.ABC):
     candidates: List[MessageDict]
     top_p: Optional[float] = None
     top_k: Optional[float] = None
+    filters: List[safety.ContentFilter]
 
     @property
     @abc.abstractmethod
diff --git a/tests/test_discuss.py b/tests/test_discuss.py
index 32938c884..dd871df15 100644
--- a/tests/test_discuss.py
+++ b/tests/test_discuss.py
@@ -14,8 +14,8 @@
 # limitations under the License.
 import unittest.mock
 
-import asynctest
-from asynctest import mock as async_mock
+#import asynctest
+#from asynctest import mock as async_mock
 
 import google.ai.generativelanguage as glm
 
@@ -28,7 +28,7 @@
 # TODO: replace returns with 'assert' statements
 
 
-class UnitTests(parameterized.TestCase, asynctest.TestCase):
+class UnitTests(parameterized.TestCase):
     def setUp(self):
         self.client = unittest.mock.MagicMock()
 
@@ -271,7 +271,7 @@ def test_reply(self, kwargs):
 
         response = response.reply("again")
 
-
+'''
 class AsyncTests(parameterized.TestCase, asynctest.TestCase):
     async def test_chat_async(self):
         client = async_mock.MagicMock()
@@ -325,7 +325,7 @@ async def fake_generate_message(
             observed_response.candidates,
             [{"author": "1", "content": "Why did the chicken cross the road?"}],
         )
-
+'''
 
 if __name__ == "__main__":
     absltest.main()

From 908b86eb7084d817b529b1abd54553303c21abcb Mon Sep 17 00:00:00 2001
From: Mark Daoust
Date: Sun, 7 May 2023 21:18:26 -0700
Subject: [PATCH 04/15] split test_discuss for async

---
 tests/test_discuss_async.py | 331 ++++++++++++++++++++++++++++++++++++
 1 file changed, 331 insertions(+)
 create mode 100644 tests/test_discuss_async.py

diff --git a/tests/test_discuss_async.py b/tests/test_discuss_async.py
new file mode 100644
index 000000000..dd871df15
--- /dev/null
+++ b/tests/test_discuss_async.py
@@ -0,0 +1,331 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import unittest.mock
+
+#import asynctest
+#from asynctest import mock as async_mock
+
+import google.ai.generativelanguage as glm
+
+from google.generativeai import discuss
+from google.generativeai import client
+import google.generativeai as genai
+from absl.testing import absltest
+from absl.testing import parameterized
+
+# TODO: replace returns with 'assert' statements
+
+
+class UnitTests(parameterized.TestCase):
+    def setUp(self):
+        self.client = unittest.mock.MagicMock()
+
+        client.default_discuss_client = self.client
+
+        self.observed_request = None
+
+        def fake_generate_message(
+            request: glm.GenerateMessageRequest,
+        ) -> glm.GenerateMessageResponse:
+            self.observed_request = request
+            return glm.GenerateMessageResponse(
+                messages=request.prompt.messages,
+                candidates=[
+                    glm.Message(content="a", author="1"),
+                    glm.Message(content="b", author="1"),
+                    glm.Message(content="c", author="1"),
+                ],
+            )
+
+        self.client.generate_message = fake_generate_message
+
+    @parameterized.named_parameters(
+        ["string", "Hello", ""],
+        ["dict", {"content": "Hello"}, ""],
+        ["dict_author", {"content": "Hello", "author": "me"}, "me"],
+        ["proto", glm.Message(content="Hello"), ""],
+        ["proto_author", glm.Message(content="Hello", author="me"), "me"],
+    )
+    def test_make_message(self, message, author):
+        x = discuss._make_message(message)
+        self.assertIsInstance(x, glm.Message)
+        self.assertEqual("Hello", x.content)
+        self.assertEqual(author, x.author)
+
+    @parameterized.named_parameters(
+        ["string", "Hello", ["Hello"]],
+        ["dict", {"content": "Hello"}, ["Hello"]],
+        ["proto", glm.Message(content="Hello"), ["Hello"]],
+        [
+            "list",
+            ["hello0", {"content": "hello1"}, glm.Message(content="hello2")],
+            ["hello0", "hello1", "hello2"],
+        ],
+    )
+    def test_make_messages(self, messages, expected_contents):
+        messages = discuss._make_messages(messages)
+        for expected, message in zip(expected_contents, messages):
+            self.assertEqual(expected, message.content)
+
+    @parameterized.named_parameters(
+        ["tuple", ("hello", {"content": "goodbye"})],
+        ["iterable", iter(["hello", "goodbye"])],
+        ["dict", {"input": "hello", "output": "goodbye"}],
+        [
+            "proto",
+            glm.Example(
+                input=glm.Message(content="hello"),
+                output=glm.Message(content="goodbye"),
+            ),
+        ],
+    )
+    def test_make_example(self, example):
+        x = discuss._make_example(example)
+        self.assertIsInstance(x, glm.Example)
+        self.assertEqual("hello", x.input.content)
+        self.assertEqual("goodbye", x.output.content)
+        return
+
+    @parameterized.named_parameters(
+        [
+            "messages",
+            [
+                "Hi",
+                {"content": "Hello!"},
+                "what's your name?",
+                glm.Message(content="Dave, what's yours"),
+            ],
+        ],
+        [
+            "examples",
+            [
+                ("Hi", "Hello!"),
+                {
+                    "input": "what's your name?",
+                    "output": {"content": "Dave, what's yours"},
+                },
+            ],
+        ],
+    )
+    def test_make_examples(self, examples):
+        examples = discuss._make_examples(examples)
+        self.assertLen(examples, 2)
+        self.assertEqual(examples[0].input.content, "Hi")
+        self.assertEqual(examples[0].output.content, "Hello!")
+        self.assertEqual(examples[1].input.content, "what's your name?")
+        self.assertEqual(examples[1].output.content, "Dave, what's yours")
+
+        return
+
+    def test_make_examples_from_example(self):
+        ex_dict = {"input": "hello", "output": "meow!"}
+        example = discuss._make_example(ex_dict)
+        examples1 = discuss._make_examples(ex_dict)
+        examples2 = discuss._make_examples(discuss._make_example(ex_dict))
+
+        self.assertEqual(example, examples1[0])
+        self.assertEqual(example, examples2[0])
+
+    @parameterized.named_parameters(
+        ["str", "hello"],
+        ["message", glm.Message(content="hello")],
+        ["messages", ["hello"]],
+        ["dict", {"messages": "hello"}],
+        ["dict2", {"messages": ["hello"]}],
+        ["proto", glm.MessagePrompt(messages=[glm.Message(content="hello")])],
+    )
+    def test_make_message_prompt_from_messages(self, prompt):
+        x = discuss._make_message_prompt(prompt)
+        self.assertIsInstance(x, glm.MessagePrompt)
+        self.assertEqual(x.messages[0].content, "hello")
+        return
+
+    @parameterized.named_parameters(
+        [
+            "dict",
+            [
+                {
+                    "context": "you are a cat",
+                    "examples": ["are you hungry?", "meow!"],
+                    "messages": "hello",
+                }
+            ],
+            {},
+        ],
+        [
+            "kwargs",
+            [],
+            {
+                "context": "you are a cat",
+                "examples": ["are you hungry?", "meow!"],
+                "messages": "hello",
+            },
+        ],
+        [
+            "proto",
+            [
+                glm.MessagePrompt(
+                    context="you are a cat",
+                    examples=[
+                        glm.Example(
+                            input=glm.Message(content="are you hungry?"),
+                            output=glm.Message(content="meow!"),
+                        )
+                    ],
+                    messages=[glm.Message(content="hello")],
+                )
+            ],
+            {},
+        ],
+    )
+    def test_make_message_prompt_from_prompt(self, args, kwargs):
+        x = discuss._make_message_prompt(*args, **kwargs)
+        self.assertIsInstance(x, glm.MessagePrompt)
+        self.assertEqual(x.context, "you are a cat")
+        self.assertEqual(x.examples[0].input.content, "are you hungry?")
+        self.assertEqual(x.examples[0].output.content, "meow!")
+        self.assertEqual(x.messages[0].content, "hello")
+
+    def test_make_generate_message_request_nested(
+        self,
+    ):
+        request0 = discuss._make_generate_message_request(
+            **{
+                "model": "Dave",
+                "context": "you are a cat",
+                "examples": ["hello", "meow", "are you hungry?", "meow!"],
+                "messages": "Please catch that mouse.",
+                "temperature": 0.2,
+                "candidate_count": 7,
+            }
+        )
+        request1 = discuss._make_generate_message_request(
+            **{
+                "model": "Dave",
+                "prompt": {
+                    "context": "you are a cat",
+                    "examples": ["hello", "meow", "are you hungry?", "meow!"],
+                    "messages": "Please catch that mouse.",
+                },
+                "temperature": 0.2,
+                "candidate_count": 7,
+            }
+        )
+
+        self.assertIsInstance(request0, glm.GenerateMessageRequest)
+        self.assertIsInstance(request1, glm.GenerateMessageRequest)
+        self.assertEqual(request0, request1)
+
+    @parameterized.parameters(
+        {"prompt": {}, "context": "You are a cat."},
+        {"prompt": {"context": "You are a cat."}, "examples": ["hello", "meow"]},
+        {"prompt": {"examples": ["hello", "meow"]}, "messages": "hello"},
+    )
+    def test_make_generate_message_request_flat_prompt_conflict(
+        self,
+        context=None,
+        examples=None,
+        messages=None,
+        prompt=None,
+    ):
+        with self.assertRaises(ValueError):
+            x = discuss._make_generate_message_request(
+                model="test",
+                context=context,
+                examples=examples,
+                messages=messages,
+                prompt=prompt,
+            )
+
+    @parameterized.parameters(
+        {"kwargs": {"context": "You are a cat."}},
+        {"kwargs": {"messages": "hello"}},
+        {"kwargs": {"examples": [["a", "b"], ["c", "d"]]}},
+        {"kwargs": {"messages": ["hello"], "examples": [["a", "b"], ["c", "d"]]}},
+    )
+    def test_reply(self, kwargs):
+        response = genai.chat(**kwargs)
+        first_messages = response.messages
+
+        self.assertEqual("a", response.last)
+        self.assertEqual(
+            [
+                {"author": "1", "content": "a"},
+                {"author": "1", "content": "b"},
+                {"author": "1", "content": "c"},
+            ],
+            response.candidates,
+        )
+
+        response = response.reply("again")
+
+'''
+class AsyncTests(parameterized.TestCase, asynctest.TestCase):
+    async def test_chat_async(self):
+        client = async_mock.MagicMock()
+
+        observed_request = None
+
+        async def fake_generate_message(
+            request: glm.GenerateMessageRequest,
+        ) -> glm.GenerateMessageResponse:
+            nonlocal observed_request
+            observed_request = request
+            return glm.GenerateMessageResponse(
+                candidates=[
+                    glm.Message(
+                        author="1", content="Why did the chicken cross the road?"
+                    )
+                ]
+            )
+
+        client.generate_message = fake_generate_message
+
+        observed_response = await discuss.chat_async(
+            model="models/bard",
+            context="Example Prompt",
+            examples=[["Example from human", "Example response from AI"]],
+            messages=["Tell me a joke"],
+            temperature=0.75,
+            candidate_count=1,
+            client=client,
+        )
+
+        self.assertEqual(
+            observed_request,
+            glm.GenerateMessageRequest(
+                model="models/bard",
+                prompt=glm.MessagePrompt(
+                    context="Example Prompt",
+                    examples=[
+                        glm.Example(
+                            input=glm.Message(content="Example from human"),
+                            output=glm.Message(content="Example response from AI"),
+                        )
+                    ],
+                    messages=[glm.Message(author="0", content="Tell me a joke")],
+                ),
+                temperature=0.75,
+                candidate_count=1,
+            ),
+        )
+        self.assertEqual(
+            observed_response.candidates,
+            [{"author": "1", "content": "Why did the chicken cross the road?"}],
+        )
+'''
+
+if __name__ == "__main__":
+    absltest.main()

From b020ca78d986b79f1f1308069ae1b97e7d5bec43 Mon Sep 17 00:00:00 2001
From: Mark Daoust
Date: Sun, 7 May 2023 21:20:12 -0700
Subject: [PATCH 05/15] Split async tests for discuss.

---
 tests/test_discuss.py       |  59 ---------
 tests/test_discuss_async.py | 250 +-----------------------------------
 2 files changed, 3 insertions(+), 306 deletions(-)

diff --git a/tests/test_discuss.py b/tests/test_discuss.py
index dd871df15..c6947c266 100644
--- a/tests/test_discuss.py
+++ b/tests/test_discuss.py
@@ -14,9 +14,6 @@
 # limitations under the License.
 import unittest.mock
 
-#import asynctest
-#from asynctest import mock as async_mock
-
 import google.ai.generativelanguage as glm
 
 from google.generativeai import discuss
@@ -271,61 +268,5 @@ def test_reply(self, kwargs):
 
         response = response.reply("again")
 
-'''
-class AsyncTests(parameterized.TestCase, asynctest.TestCase):
-    async def test_chat_async(self):
-        client = async_mock.MagicMock()
-
-        observed_request = None
-
-        async def fake_generate_message(
-            request: glm.GenerateMessageRequest,
-        ) -> glm.GenerateMessageResponse:
-            nonlocal observed_request
-            observed_request = request
-            return glm.GenerateMessageResponse(
-                candidates=[
-                    glm.Message(
-                        author="1", content="Why did the chicken cross the road?"
-                    )
-                ]
-            )
-
-        client.generate_message = fake_generate_message
-
-        observed_response = await discuss.chat_async(
-            model="models/bard",
-            context="Example Prompt",
-            examples=[["Example from human", "Example response from AI"]],
-            messages=["Tell me a joke"],
-            temperature=0.75,
-            candidate_count=1,
-            client=client,
-        )
-
-        self.assertEqual(
-            observed_request,
-            glm.GenerateMessageRequest(
-                model="models/bard",
-                prompt=glm.MessagePrompt(
-                    context="Example Prompt",
-                    examples=[
-                        glm.Example(
-                            input=glm.Message(content="Example from human"),
-                            output=glm.Message(content="Example response from AI"),
-                        )
-                    ],
-                    messages=[glm.Message(author="0", content="Tell me a joke")],
-                ),
-                temperature=0.75,
-                candidate_count=1,
-            ),
-        )
-        self.assertEqual(
-            observed_response.candidates,
-            [{"author": "1", "content": "Why did the chicken cross the road?"}],
-        )
-'''
-
 if __name__ == "__main__":
     absltest.main()
diff --git a/tests/test_discuss_async.py b/tests/test_discuss_async.py
index dd871df15..5b5423fc8 100644
--- a/tests/test_discuss_async.py
+++ b/tests/test_discuss_async.py
@@ -14,8 +14,8 @@
 # limitations under the License.
 import unittest.mock
 
-#import asynctest
-#from asynctest import mock as async_mock
+import asynctest
+from asynctest import mock as async_mock
 
 import google.ai.generativelanguage as glm
 
@@ -28,250 +28,6 @@
 # TODO: replace returns with 'assert' statements
 
 
-class UnitTests(parameterized.TestCase):
-    def setUp(self):
-        self.client = unittest.mock.MagicMock()
-
-        client.default_discuss_client = self.client
-
-        self.observed_request = None
-
-        def fake_generate_message(
-            request: glm.GenerateMessageRequest,
-        ) -> glm.GenerateMessageResponse:
-            self.observed_request = request
-            return glm.GenerateMessageResponse(
-                messages=request.prompt.messages,
-                candidates=[
-                    glm.Message(content="a", author="1"),
-                    glm.Message(content="b", author="1"),
-                    glm.Message(content="c", author="1"),
-                ],
-            )
-
-        self.client.generate_message = fake_generate_message
-
-    @parameterized.named_parameters(
-        ["string", "Hello", ""],
-        ["dict", {"content": "Hello"}, ""],
-        ["dict_author", {"content": "Hello", "author": "me"}, "me"],
-        ["proto", glm.Message(content="Hello"), ""],
-        ["proto_author", glm.Message(content="Hello", author="me"), "me"],
-    )
-    def test_make_message(self, message, author):
-        x = discuss._make_message(message)
-        self.assertIsInstance(x, glm.Message)
-        self.assertEqual("Hello", x.content)
-        self.assertEqual(author, x.author)
-
-    @parameterized.named_parameters(
-        ["string", "Hello", ["Hello"]],
-        ["dict", {"content": "Hello"}, ["Hello"]],
-        ["proto", glm.Message(content="Hello"), ["Hello"]],
-        [
-            "list",
-            ["hello0", {"content": "hello1"}, glm.Message(content="hello2")],
-            ["hello0", "hello1", "hello2"],
-        ],
-    )
-    def test_make_messages(self, messages, expected_contents):
-        messages = discuss._make_messages(messages)
-        for expected, message in zip(expected_contents, messages):
-            self.assertEqual(expected, message.content)
-
-    @parameterized.named_parameters(
-        ["tuple", ("hello", {"content": "goodbye"})],
-        ["iterable", iter(["hello", "goodbye"])],
-        ["dict", {"input": "hello", "output": "goodbye"}],
-        [
-            "proto",
-            glm.Example(
-                input=glm.Message(content="hello"),
-                output=glm.Message(content="goodbye"),
-            ),
-        ],
-    )
-    def test_make_example(self, example):
-        x = discuss._make_example(example)
-        self.assertIsInstance(x, glm.Example)
-        self.assertEqual("hello", x.input.content)
-        self.assertEqual("goodbye", x.output.content)
-        return
-
-    @parameterized.named_parameters(
-        [
-            "messages",
-            [
-                "Hi",
-                {"content": "Hello!"},
-                "what's your name?",
-                glm.Message(content="Dave, what's yours"),
-            ],
-        ],
-        [
-            "examples",
-            [
-                ("Hi", "Hello!"),
-                {
-                    "input": "what's your name?",
-                    "output": {"content": "Dave, what's yours"},
-                },
-            ],
-        ],
-    )
-    def test_make_examples(self, examples):
-        examples = discuss._make_examples(examples)
-        self.assertLen(examples, 2)
-        self.assertEqual(examples[0].input.content, "Hi")
-        self.assertEqual(examples[0].output.content, "Hello!")
-        self.assertEqual(examples[1].input.content, "what's your name?")
-        self.assertEqual(examples[1].output.content, "Dave, what's yours")
-
-        return
-
-    def test_make_examples_from_example(self):
-        ex_dict = {"input": "hello", "output": "meow!"}
-        example = discuss._make_example(ex_dict)
-        examples1 = discuss._make_examples(ex_dict)
-        examples2 = discuss._make_examples(discuss._make_example(ex_dict))
-
-        self.assertEqual(example, examples1[0])
-        self.assertEqual(example, examples2[0])
-
-    @parameterized.named_parameters(
-        ["str", "hello"],
-        ["message", glm.Message(content="hello")],
-        ["messages", ["hello"]],
-        ["dict", {"messages": "hello"}],
-        ["dict2", {"messages": ["hello"]}],
-        ["proto", glm.MessagePrompt(messages=[glm.Message(content="hello")])],
-    )
-    def test_make_message_prompt_from_messages(self, prompt):
-        x = discuss._make_message_prompt(prompt)
-        self.assertIsInstance(x, glm.MessagePrompt)
-        self.assertEqual(x.messages[0].content, "hello")
-        return
-
-    @parameterized.named_parameters(
-        [
-            "dict",
-            [
-                {
-                    "context": "you are a cat",
-                    "examples": ["are you hungry?", "meow!"],
-                    "messages": "hello",
-                }
-            ],
-            {},
-        ],
-        [
-            "kwargs",
-            [],
-            {
-                "context": "you are a cat",
-                "examples": ["are you hungry?", "meow!"],
-                "messages": "hello",
-            },
-        ],
-        [
-            "proto",
-            [
-                glm.MessagePrompt(
-                    context="you are a cat",
-                    examples=[
-                        glm.Example(
-                            input=glm.Message(content="are you hungry?"),
-                            output=glm.Message(content="meow!"),
-                        )
-                    ],
-                    messages=[glm.Message(content="hello")],
-                )
-            ],
-            {},
-        ],
-    )
-    def test_make_message_prompt_from_prompt(self, args, kwargs):
-        x = discuss._make_message_prompt(*args, **kwargs)
-        self.assertIsInstance(x, glm.MessagePrompt)
-        self.assertEqual(x.context, "you are a cat")
-        self.assertEqual(x.examples[0].input.content, "are you hungry?")
-        self.assertEqual(x.examples[0].output.content, "meow!")
-        self.assertEqual(x.messages[0].content, "hello")
-
-    def test_make_generate_message_request_nested(
-        self,
-    ):
-        request0 = discuss._make_generate_message_request(
-            **{
-                "model": "Dave",
-                "context": "you are a cat",
-                "examples": ["hello", "meow", "are you hungry?", "meow!"],
-                "messages": "Please catch that mouse.",
-                "temperature": 0.2,
-                "candidate_count": 7,
-            }
-        )
-        request1 = discuss._make_generate_message_request(
-            **{
-                "model": "Dave",
-                "prompt": {
-                    "context": "you are a cat",
-                    "examples": ["hello", "meow", "are you hungry?", "meow!"],
-                    "messages": "Please catch that mouse.",
-                },
-                "temperature": 0.2,
-                "candidate_count": 7,
-            }
-        )
-
-        self.assertIsInstance(request0, glm.GenerateMessageRequest)
-        self.assertIsInstance(request1, glm.GenerateMessageRequest)
-        self.assertEqual(request0, request1)
-
-    @parameterized.parameters(
-        {"prompt": {}, "context": "You are a cat."},
-        {"prompt": {"context": "You are a cat."}, "examples": ["hello", "meow"]},
-        {"prompt": {"examples": ["hello", "meow"]}, "messages": "hello"},
-    )
-    def test_make_generate_message_request_flat_prompt_conflict(
-        self,
-        context=None,
-        examples=None,
-        messages=None,
-        prompt=None,
-    ):
-        with self.assertRaises(ValueError):
-            x = discuss._make_generate_message_request(
-                model="test",
-                context=context,
-                examples=examples,
-                messages=messages,
-                prompt=prompt,
-            )
-
-    @parameterized.parameters(
-        {"kwargs": {"context": "You are a cat."}},
-        {"kwargs": {"messages": "hello"}},
-        {"kwargs": {"examples": [["a", "b"], ["c", "d"]]}},
-        {"kwargs": {"messages": ["hello"], "examples": [["a", "b"], ["c", "d"]]}},
-    )
-    def test_reply(self, kwargs):
-        response = genai.chat(**kwargs)
-        first_messages = response.messages
-
-        self.assertEqual("a", response.last)
-        self.assertEqual(
-            [
-                {"author": "1", "content": "a"},
-                {"author": "1", "content": "b"},
-                {"author": "1", "content": "c"},
-            ],
-            response.candidates,
-        )
-
-        response = response.reply("again")
-
-'''
 class AsyncTests(parameterized.TestCase, asynctest.TestCase):
     async def test_chat_async(self):
         client = async_mock.MagicMock()
@@ -325,7 +81,7 @@ async def fake_generate_message(
             observed_response.candidates,
             [{"author": "1", "content": "Why did the chicken cross the road?"}],
         )
-'''
+
 
 if __name__ == "__main__":
     absltest.main()

From afbfae45ae988f225f6965c89daf31c50c03fb45 Mon Sep 17 00:00:00 2001
From: Mark Daoust
Date: Sun, 7 May 2023 22:51:41 -0700
Subject: [PATCH 06/15] Discuss: handle case where all messages are filtered

---
 google/generativeai/discuss.py | 21 +++++++++++++++++----
 tests/test_discuss.py          |  1 +
 2 files changed, 18 insertions(+), 4 deletions(-)

diff --git a/google/generativeai/discuss.py b/google/generativeai/discuss.py
index e63d9ff6a..8ce9210b7 100644
--- a/google/generativeai/discuss.py
+++ b/google/generativeai/discuss.py
@@ -390,8 +390,11 @@ def __init__(self, **kwargs):
 
     @property
     @set_doc(discuss_types.ChatResponse.last.__doc__)
-    def last(self) -> str:
-        return self.messages[-1]["content"]
+    def last(self) -> Optional[str]:
+        if self.messages[-1]:
+          return self.messages[-1]["content"]
+        else:
+          return None
 
     @last.setter
     def last(self, message: discuss_types.MessageOptions):
@@ -406,9 +409,14 @@ def reply(
             raise TypeError(
                 f"reply can't be called on an async client, use reply_async instead."
             )
+        if self.last is None:
+            raise ValueError('The last response from the model did not return any candidates.\n'
+                             'Check the `.filters` attribute to see why the responses were filtered:\n'
+                             f'{self.filters}')
+
         request = self.to_dict()
         request.pop("candidates")
-        request.pop("filters")
+        request.pop("filters", None)
         request["messages"] = list(request["messages"])
         request["messages"].append(_make_message(message))
         request = _make_generate_message_request(**request)
@@ -424,6 +432,7 @@ async def reply_async(
         )
         request = self.to_dict()
         request.pop("candidates")
+        request.pop("filters")
         request["messages"] = list(request["messages"])
         request["messages"].append(_make_message(message))
         request = _make_generate_message_request(**request)
@@ -444,7 +453,11 @@ def _build_chat_response(
     response = type(response).to_dict(response)
     response.pop("messages")
 
-    request["messages"].append(response["candidates"][0])
+    if response["candidates"]:
+        last = response["candidates"][0]
+    else:
+        last = None
+    request["messages"].append(last)
     request.setdefault("temperature", None)
     request.setdefault("candidate_count", None)
diff --git a/tests/test_discuss.py b/tests/test_discuss.py
index c6947c266..c244792d2 100644
--- a/tests/test_discuss.py
+++ b/tests/test_discuss.py
@@ -268,5 +268,6 @@ def test_reply(self, kwargs):
 
         response = response.reply("again")
 
+
 if __name__ == "__main__":
     absltest.main()
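Taken together with the `filters` plumbing from the earlier commits, the intended calling pattern now looks roughly like this (a sketch with an invented prompt; `genai.chat` is the public entry point for `discuss.chat`):

    import google.generativeai as genai

    response = genai.chat(messages=["hello"])
    if response.last is None:
        # All candidates were filtered out; `filters` records why, and
        # calling reply() in this state raises the ValueError added above.
        print(response.filters)
    else:
        response = response.reply("tell me more")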
From aa1ffd955933345835b51510cab352bf50204ec8 Mon Sep 17 00:00:00 2001
From: Mark Daoust
Date: Sun, 7 May 2023 22:55:26 -0700
Subject: [PATCH 07/15] fix imports

---
 google/generativeai/types/__init__.py       |  5 +++++
 google/generativeai/types/citation_types.py |  1 +
 google/generativeai/types/discuss_types.py  |  7 ++++---
 google/generativeai/types/safety_types.py   |  3 ++-
 google/generativeai/types/text_types.py     | 20 +++++++++++++++-----
 5 files changed, 27 insertions(+), 9 deletions(-)

diff --git a/google/generativeai/types/__init__.py b/google/generativeai/types/__init__.py
index 463d9bdb3..818d85be8 100644
--- a/google/generativeai/types/__init__.py
+++ b/google/generativeai/types/__init__.py
@@ -17,6 +17,11 @@
 from google.generativeai.types.discuss_types import *
 from google.generativeai.types.model_types import *
 from google.generativeai.types.text_types import *
+from google.generativeai.types.citation_types import *
+from google.generativeai.types.safety_types import *
 
 del discuss_types
 del model_types
+del text_types
+del citation_types
+del safety_types
\ No newline at end of file
diff --git a/google/generativeai/types/citation_types.py b/google/generativeai/types/citation_types.py
index fa8d5f296..c79bde621 100644
--- a/google/generativeai/types/citation_types.py
+++ b/google/generativeai/types/citation_types.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Optional, List
 
 from google.ai import generativelanguage as glm
 from google.generativeai import docstring_utils
diff --git a/google/generativeai/types/discuss_types.py b/google/generativeai/types/discuss_types.py
index 071f6788a..ceda56613 100644
--- a/google/generativeai/types/discuss_types.py
+++ b/google/generativeai/types/discuss_types.py
@@ -20,6 +20,7 @@
 
 import google.ai.generativelanguage as glm
 from google.generativeai.types import safety_types
+from google.generativeai.types import citation_types
 
 __all__ = [
     "MessageDict",
@@ -155,17 +156,17 @@ class ChatResponse(abc.ABC):
     model: str
     context: str
     examples: List[ExampleDict]
-    messages: List[MessageDict]
+    messages: List[Optional[MessageDict]]
     temperature: Optional[float]
     candidate_count: Optional[int]
     candidates: List[MessageDict]
     top_p: Optional[float] = None
     top_k: Optional[float] = None
-    filters: List[safety.ContentFilter]
+    filters: List[safety_types.ContentFilter]
 
     @property
     @abc.abstractmethod
-    def last(self) -> str:
+    def last(self) -> Optional[str]:
         """A settable property that provides simple access to the last response string
 
         A shortcut for `response.messages[0]['content']`.
diff --git a/google/generativeai/types/safety_types.py b/google/generativeai/types/safety_types.py
index aa159cd12..9e0e8f3d9 100644
--- a/google/generativeai/types/safety_types.py
+++ b/google/generativeai/types/safety_types.py
@@ -15,6 +15,7 @@
 
 import enum
 from google.ai import generativelanguage as glm
+from google.generativeai import docstring_utils
 from typing import TypedDict
 
 __all__ = [
@@ -31,7 +32,7 @@
 # These are basic python enums, it's okay to expose them
 HarmCategory = glm.HarmCategory
 HarmProbability = glm.SafetyRating.HarmProbability
-HarmBlockThreshold = glm.SafetySetting
+HarmBlockThreshold = glm.SafetySetting.HarmBlockThreshold
 BlockReason = glm.ContentFilter.BlockedReason
diff --git a/google/generativeai/types/text_types.py b/google/generativeai/types/text_types.py
index 0abca273c..a42af928d 100644
--- a/google/generativeai/types/text_types.py
+++ b/google/generativeai/types/text_types.py
@@ -17,28 +17,38 @@
 import dataclasses
 from typing import Any, Dict, Optional, List, Iterator, TypedDict
 
-__all__ = ["Completion"]
+from google.generativeai.types import safety_types
+from google.generativeai.types import citation_types
 
-class TextCandidate(TypedDict, total=False):
+
+__all__ = ["TextResponse"]
+
+
+class TextCompletion(TypedDict, total=False):
     output: str
+    safety_ratings: Optional[List[safety_types.SafetyRatingDict]]
+    citation_metadata: Optional[citation_types.CitationMetadataDict]
 
 
 @dataclasses.dataclass(init=False)
-class Completion(abc.ABC):
+class TextResponse(abc.ABC):
     """A text completion given a prompt from the model.
 
-    * Use `completion.candidates` to access all of the text completion options generated by the model.
+    Use `GenerateTextResponse.candidates` to access all the completions generated by the model.
 
     Attributes:
        candidates: A list of candidate text completions generated by the model.
     """
 
-    candidates: List[TextCandidate]
+    candidates: List[TextCompletion]
     result: Optional[str]
+    filters: Optional[list[safety_types.ContentFilter]]
+    safety_feedback: Optional[list[safety_types.SafetyFeedbackDict]]
 
     def to_dict(self) -> Dict[str, Any]:
         result = {
             "candidates": self.candidates,
+            "filters": self.filters,
+            "safety_feedback": self.safety_feedback,
         }
         return result

From 528b746e945047614c016ba18a23214823c0e028 Mon Sep 17 00:00:00 2001
From: Mark Daoust
Date: Sun, 7 May 2023 23:08:50 -0700
Subject: [PATCH 08/15] Add filters and safety settings to text

---
 google/generativeai/text.py             | 14 ++++++++++++++
 google/generativeai/types/text_types.py | 11 +++++++----
 2 files changed, 21 insertions(+), 4 deletions(-)

diff --git a/google/generativeai/text.py b/google/generativeai/text.py
index ae17e422d..6373555e2 100644
--- a/google/generativeai/text.py
+++ b/google/generativeai/text.py
@@ -24,6 +24,7 @@
 from google.generativeai.client import get_default_text_client
 from google.generativeai.types import text_types
 from google.generativeai.types import model_types
+from google.generativeai.types import safety_types
 
 
 def _make_text_prompt(prompt: Union[str, dict[str, str]]) -> glm.TextPrompt:
@@ -44,6 +45,7 @@ def _make_generate_text_request(
     max_output_tokens: Optional[int] = None,
     top_p: Optional[int] = None,
     top_k: Optional[int] = None,
+    safety_settings: Optional[List[safety_types.SafetySetting]] = None,
     stop_sequences: Union[str, Iterable[str]] = None,
 ) -> glm.GenerateTextRequest:
     model = model_types.make_model_name(model)
@@ -61,6 +63,7 @@ def _make_generate_text_request(
         max_output_tokens=max_output_tokens,
         top_p=top_p,
         top_k=top_k,
+        safety_settings=safety_settings,
         stop_sequences=stop_sequences,
     )
 
@@ -74,6 +77,7 @@ def generate_text(
     max_output_tokens: Optional[int] = None,
     top_p: Optional[float] = None,
     top_k: Optional[float] = None,
+    safety_settings: Optional[Iterable[safety.SafetySetting]] = None,
     stop_sequences: Union[str, Iterable[str]] = None,
     client: Optional[glm.TextServiceClient] = None,
 ) -> text_types.Completion:
@@ -103,6 +107,15 @@ def generate_text(
            For example, if the sorted probabilities are
            `[0.5, 0.2, 0.1, 0.1, 0.05, 0.05]` a `top_p` of `0.8` will sample
            as `[0.625, 0.25, 0.125, 0, 0, 0].
+        safety_settings: A list of unique `types.SafetySetting` instances for blocking unsafe content.
+           These will be enforced on the `prompt` and
+           `candidates`. There should not be more than one
+           setting for each `types.SafetyCategory` type. The API will block any prompts and
+           responses that fail to meet the thresholds set by these settings. This list
+           overrides the default settings for each `SafetyCategory` specified in the
+           safety_settings. If there is no `types.SafetySetting` for a given
+           `SafetyCategory` provided in the list, the API will use the default safety
+           setting for that category.
         stop_sequences: A set of up to 5 character sequences that will stop output generation.
           If specified, the API will stop at the first appearance of a stop
          sequence. The stop sequence will not be included as part of the response.
@@ -119,6 +132,7 @@ def generate_text(
         max_output_tokens=max_output_tokens,
         top_p=top_p,
         top_k=top_k,
+        safety_settings=safety_settings,
         stop_sequences=stop_sequences,
     )
diff --git a/google/generativeai/types/text_types.py b/google/generativeai/types/text_types.py
index a42af928d..ebd5a6f25 100644
--- a/google/generativeai/types/text_types.py
+++ b/google/generativeai/types/text_types.py
@@ -21,7 +21,7 @@
 from google.generativeai.types import safety_types
 from google.generativeai.types import citation_types
 
-__all__ = ["TextResponse"]
+__all__ = ["Completion"]
 
 
 class TextCompletion(TypedDict, total=False):
@@ -31,15 +31,18 @@ class TextCompletion(TypedDict, total=False):
 
 @dataclasses.dataclass(init=False)
-class TextResponse(abc.ABC):
-    """A text completion given a prompt from the model.
+class Completion(abc.ABC):
+    """The result of the `1 given a prompt from the model.
 
     Use `GenerateTextResponse.candidates` to access all the completions generated by the model.
 
     Attributes:
        candidates: A list of candidate text completions generated by the model.
+       result: The output of the first candidate,
+       filters: Indicates the reasons why content may have been blocked
+         Either Unspecified, Safety, or Other. See `types.ContentFilter`.
+       safety_feedback: Indicates which safety settings blocked content in this result.
     """
 
     candidates: List[TextCompletion]
     result: Optional[str]
     filters: Optional[list[safety_types.ContentFilter]]
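A sketch of the call this enables (the model name and enum members here are illustrative placeholders; check `types.HarmCategory` and `types.HarmBlockThreshold` for the real values):

    import google.generativeai as genai
    from google.generativeai import types

    completion = genai.generate_text(
        model="models/text-bison-001",  # assumed model name
        prompt="Write a story about a magic backpack.",
        safety_settings=[
            # Each dict follows the SafetySetting shape: one entry per category.
            {
                "category": types.HarmCategory.HARM_CATEGORY_VIOLENCE,
                "threshold": types.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
            },
        ],
    )
    if completion.result is None:
        # Blocked prompts/candidates are explained by these two fields.
        print(completion.filters, completion.safety_feedback)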
From 0b55411fe6485cdc3fb31b699269ae2c860ee5cf Mon Sep 17 00:00:00 2001
From: Mark Daoust
Date: Mon, 8 May 2023 08:55:40 -0700
Subject: [PATCH 09/15] Cleanup async tests

---
 tests/test_discuss_async.py | 69 ++++++++++++++++++++-----------------
 tests/test_text.py          |  3 ---
 2 files changed, 37 insertions(+), 35 deletions(-)

diff --git a/tests/test_discuss_async.py b/tests/test_discuss_async.py
index 5b5423fc8..2185f73c5 100644
--- a/tests/test_discuss_async.py
+++ b/tests/test_discuss_async.py
@@ -12,52 +12,57 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import unittest.mock
 
-import asynctest
-from asynctest import mock as async_mock
+import sys
+import unittest
+
+if sys.version_info < (3, 11):
+    import asynctest
+    from asynctest import mock as async_mock
 
 import google.ai.generativelanguage as glm
 
 from google.generativeai import discuss
-from google.generativeai import client
-import google.generativeai as genai
 from absl.testing import absltest
 from absl.testing import parameterized
 
-# TODO: replace returns with 'assert' statements
+bases = (parameterized.TestCase,)
+if sys.version_info < (3, 11):
+    bases = bases + (asynctest.TestCase,)
 
+unittest.skipIf(sys.version_info >= (3,11), "asynctest is not supported on python 3.11+")
+class AsyncTests(*bases):
+    if sys.version_info < (3, 11):
+        async def test_chat_async(self):
+            client = async_mock.MagicMock()
 
-class AsyncTests(parameterized.TestCase, asynctest.TestCase):
-    async def test_chat_async(self):
-        client = async_mock.MagicMock()
+            observed_request = None
 
-        observed_request = None
-
-        async def fake_generate_message(
-            request: glm.GenerateMessageRequest,
-        ) -> glm.GenerateMessageResponse:
-            nonlocal observed_request
-            observed_request = request
-            return glm.GenerateMessageResponse(
-                candidates=[
-                    glm.Message(
-                        author="1", content="Why did the chicken cross the road?"
-                    )
-                ]
-            )
+            async def fake_generate_message(
+                request: glm.GenerateMessageRequest,
+            ) -> glm.GenerateMessageResponse:
+                nonlocal observed_request
+                observed_request = request
+                return glm.GenerateMessageResponse(
+                    candidates=[
+                        glm.Message(
+                            author="1", content="Why did the chicken cross the road?"
+                        )
+                    ]
+                )
 
-        client.generate_message = fake_generate_message
+            client.generate_message = fake_generate_message
 
-        observed_response = await discuss.chat_async(
-            model="models/bard",
-            context="Example Prompt",
-            examples=[["Example from human", "Example response from AI"]],
-            messages=["Tell me a joke"],
-            temperature=0.75,
-            candidate_count=1,
-            client=client,
-        )
+            observed_response = await discuss.chat_async(
+                model="models/bard",
+                context="Example Prompt",
+                examples=[["Example from human", "Example response from AI"]],
+                messages=["Tell me a joke"],
+                temperature=0.75,
+                candidate_count=1,
+                client=client,
+            )
 
         self.assertEqual(
             observed_request,
             glm.GenerateMessageRequest(
                 model="models/bard",
                 prompt=glm.MessagePrompt(
                     context="Example Prompt",
                     examples=[
                         glm.Example(
                             input=glm.Message(content="Example from human"),
                             output=glm.Message(content="Example response from AI"),
                         )
                     ],
                     messages=[glm.Message(author="0", content="Tell me a joke")],
                 ),
                 temperature=0.75,
                 candidate_count=1,
             ),
         )
         self.assertEqual(
             observed_response.candidates,
             [{"author": "1", "content": "Why did the chicken cross the road?"}],
         )
diff --git a/tests/test_text.py b/tests/test_text.py
index 508437d19..e4521a67c 100644
--- a/tests/test_text.py
+++ b/tests/test_text.py
@@ -17,9 +17,6 @@
 import unittest
 import unittest.mock as mock
 
-import asynctest
-from asynctest import mock as async_mock
-
 import google.ai.generativelanguage as glm
 
 from google.generativeai import text as text_service

From 862e7ecc3dd22ebca9fc8cd90858410f52f0fc86 Mon Sep 17 00:00:00 2001
From: Mark Daoust
Date: Mon, 8 May 2023 09:08:47 -0700
Subject: [PATCH 10/15] Update test instructions.

---
 CONTRIBUTING.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 08553673f..613fc896c 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -65,7 +65,7 @@ This "editable" mode lets you edit the source without needing to reinstall the p
 Use the builtin unittest package:
 
 ```
-python -m unittest
+python -m unittest discover --pattern '*test*.py'
 ```
 
 Or to debug, use:

From d301de6d422a52cf3e02b40a7028b570f096c86d Mon Sep 17 00:00:00 2001
From: Mark Daoust
Date: Mon, 8 May 2023 13:18:36 -0700
Subject: [PATCH 11/15] Test filters in chat

---
 google/generativeai/discuss.py            |  5 +++
 google/generativeai/types/safety_types.py |  6 +--
 tests/test_discuss.py                     | 45 ++++++++++++++++++++---
 tests/test_discuss_async.py               | 42 ++++++++++-----------
 4 files changed, 68 insertions(+), 30 deletions(-)

diff --git a/google/generativeai/discuss.py b/google/generativeai/discuss.py
index 8ce9210b7..98a8ae5cd 100644
--- a/google/generativeai/discuss.py
+++ b/google/generativeai/discuss.py
@@ -438,6 +438,9 @@ async def reply_async(
         request = _make_generate_message_request(**request)
         return await _generate_response_async(request=request, client=self._client)
 
+def _convert_filters_to_enums(filters):
+    for f in filters:
+        f['reason'] = safety_types.BlockedReason(f['reason'])
 
 def _build_chat_response(
     request: glm.GenerateMessageRequest,
@@ -453,6 +456,8 @@ def _build_chat_response(
     response = type(response).to_dict(response)
     response.pop("messages")
 
+    _convert_filters_to_enums(response['filters'])
+
     if response["candidates"]:
         last = response["candidates"][0]
     else:
diff --git a/google/generativeai/types/safety_types.py b/google/generativeai/types/safety_types.py
index 9e0e8f3d9..a9e569cf7 100644
--- a/google/generativeai/types/safety_types.py
+++ b/google/generativeai/types/safety_types.py
@@ -22,7 +22,7 @@
     "HarmCategory",
     "HarmProbability",
     "HarmBlockThreshold",
-    "BlockReason",
+    "BlockedReason",
     "ContentFilter",
     "SafetyRatingDict",
     "SafetySetting",
@@ -33,11 +33,11 @@
 HarmCategory = glm.HarmCategory
 HarmProbability = glm.SafetyRating.HarmProbability
 HarmBlockThreshold = glm.SafetySetting.HarmBlockThreshold
-BlockReason = glm.ContentFilter.BlockedReason
+BlockedReason = glm.ContentFilter.BlockedReason
 
 
 class ContentFilter(TypedDict):
-    reason: BlockReason
+    reason: BlockedReason
     message: str
 
     __doc__ = docstring_utils.strip_oneof(glm.ContentFilter.__doc__)
diff --git a/tests/test_discuss.py b/tests/test_discuss.py
index c244792d2..e01fb47f2 100644
--- a/tests/test_discuss.py
+++ b/tests/test_discuss.py
@@ -12,6 +12,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
+
 import unittest.mock
 
 import google.ai.generativelanguage as glm
@@ -33,12 +35,7 @@ def setUp(self):
 
         self.observed_request = None
 
-        def fake_generate_message(
-            request: glm.GenerateMessageRequest,
-        ) -> glm.GenerateMessageResponse:
-            self.observed_request = request
-            return glm.GenerateMessageResponse(
-                messages=request.prompt.messages,
+        self.mock_response = glm.GenerateMessageResponse(
                 candidates=[
                     glm.Message(content="a", author="1"),
                     glm.Message(content="b", author="1"),
                     glm.Message(content="c", author="1"),
                 ],
             )
 
+        def fake_generate_message(
+            request: glm.GenerateMessageRequest,
+        ) -> glm.GenerateMessageResponse:
+            self.observed_request = request
+            response = copy.copy(self.mock_response)
+            response.messages = request.prompt.messages
+            return response
+
         self.client.generate_message = fake_generate_message
@@ -268,6 +273,34 @@ def test_reply(self, kwargs):
 
         response = response.reply("again")
 
+    def test_receive_and_reply_with_filters(self):
+
+        self.mock_response = mock_response = glm.GenerateMessageResponse(
+            candidates=[glm.Message(content="a", author="1")],
+            filters=[
+                glm.ContentFilter(reason=glm.ContentFilter.BlockedReason.SAFETY, message='unsafe'),
+                glm.ContentFilter(reason=glm.ContentFilter.BlockedReason.OTHER),]
+        )
+        response = discuss.chat(messages="do filters work?")
+
+        filters = response.filters
+        self.assertLen(filters, 2)
+        self.assertIsInstance(filters[0]['reason'], glm.ContentFilter.BlockedReason)
+        self.assertEquals(filters[0]['reason'], glm.ContentFilter.BlockedReason.SAFETY)
+        self.assertEquals(filters[0]['message'], 'unsafe')
+
+        self.mock_response = glm.GenerateMessageResponse(
+            candidates=[glm.Message(content="a", author="1")],
+            filters=[
+                glm.ContentFilter(reason=glm.ContentFilter.BlockedReason.BLOCKED_REASON_UNSPECIFIED)]
+        )
+
+        response = response.reply('Does reply work?')
+        filters = response.filters
+        self.assertLen(filters, 1)
+        self.assertIsInstance(filters[0]['reason'], glm.ContentFilter.BlockedReason)
+        self.assertEquals(filters[0]['reason'], glm.ContentFilter.BlockedReason.BLOCKED_REASON_UNSPECIFIED)
+
 
 if __name__ == "__main__":
     absltest.main()
diff --git a/tests/test_discuss_async.py b/tests/test_discuss_async.py
index 2185f73c5..48e9a7c52 100644
--- a/tests/test_discuss_async.py
+++ b/tests/test_discuss_async.py
@@ -64,28 +64,28 @@ async def fake_generate_message(
                 client=client,
             )
 
-        self.assertEqual(
-            observed_request,
-            glm.GenerateMessageRequest(
-                model="models/bard",
-                prompt=glm.MessagePrompt(
-                    context="Example Prompt",
-                    examples=[
-                        glm.Example(
-                            input=glm.Message(content="Example from human"),
-                            output=glm.Message(content="Example response from AI"),
-                        )
-                    ],
-                    messages=[glm.Message(author="0", content="Tell me a joke")],
-                ),
-                temperature=0.75,
-                candidate_count=1,
-            ),
-        )
-        self.assertEqual(
-            observed_response.candidates,
-            [{"author": "1", "content": "Why did the chicken cross the road?"}],
-        )
+            self.assertEqual(
+                observed_request,
+                glm.GenerateMessageRequest(
+                    model="models/bard",
+                    prompt=glm.MessagePrompt(
+                        context="Example Prompt",
+                        examples=[
+                            glm.Example(
+                                input=glm.Message(content="Example from human"),
+                                output=glm.Message(content="Example response from AI"),
+                            )
+                        ],
+                        messages=[glm.Message(author="0", content="Tell me a joke")],
+                    ),
+                    temperature=0.75,
+                    candidate_count=1,
+                ),
+            )
+            self.assertEqual(
+                observed_response.candidates,
+                [{"author": "1", "content": "Why did the chicken cross the road?"}],
+            )
 
 
 if __name__ == "__main__":
     absltest.main()
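With this conversion in place, `filters[...]["reason"]` comes back as the `BlockedReason` enum rather than a bare integer, so callers can compare against enum members directly. A minimal sketch (the prompt and message text are invented):

    import google.generativeai as genai
    from google.generativeai.types import safety_types

    response = genai.chat(messages=["do filters work?"])
    for f in response.filters:
        if f["reason"] == safety_types.BlockedReason.SAFETY:
            print("Blocked for safety:", f.get("message"))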
+ def strip_oneof(docstring): lines = docstring.splitlines() lines = [line for line in lines if ".. _oneof:" not in line] lines = [line for line in lines if "This field is a member of `oneof`_" not in line] - return "\n".join(lines) \ No newline at end of file + return "\n".join(lines) diff --git a/google/generativeai/text.py b/google/generativeai/text.py index 6373555e2..b709b3e96 100644 --- a/google/generativeai/text.py +++ b/google/generativeai/text.py @@ -45,7 +45,7 @@ def _make_generate_text_request( max_output_tokens: Optional[int] = None, top_p: Optional[int] = None, top_k: Optional[int] = None, - safety_settings: Optional[List[safety_types.SafetySetting]] = None, + safety_settings: Optional[List[safety_types.SafetySettingDict]] = None, stop_sequences: Union[str, Iterable[str]] = None, ) -> glm.GenerateTextRequest: model = model_types.make_model_name(model) @@ -77,7 +77,7 @@ def generate_text( max_output_tokens: Optional[int] = None, top_p: Optional[float] = None, top_k: Optional[float] = None, - safety_settings: Optional[Iterable[safety.SafetySetting]] = None, + safety_settings: Optional[Iterable[safety.SafetySettingDict]] = None, stop_sequences: Union[str, Iterable[str]] = None, client: Optional[glm.TextServiceClient] = None, ) -> text_types.Completion: @@ -159,6 +159,9 @@ def _generate_response( response = client.generate_text(request) response = type(response).to_dict(response) + safety_types.convert_filters_to_enums(response["filters"]) + safety_types.convert_safety_feedback_to_enums(response["safety_feedback"]) + return Completion(_client=client, **response) diff --git a/google/generativeai/types/__init__.py b/google/generativeai/types/__init__.py index 818d85be8..0bdf3a713 100644 --- a/google/generativeai/types/__init__.py +++ b/google/generativeai/types/__init__.py @@ -24,4 +24,4 @@ del model_types del text_types del citation_types -del safety_types \ No newline at end of file +del safety_types diff --git a/google/generativeai/types/discuss_types.py b/google/generativeai/types/discuss_types.py index ceda56613..ae6cea84d 100644 --- a/google/generativeai/types/discuss_types.py +++ b/google/generativeai/types/discuss_types.py @@ -162,7 +162,7 @@ class ChatResponse(abc.ABC): candidates: List[MessageDict] top_p: Optional[float] = None top_k: Optional[float] = None - filters: List[safety_types.ContentFilter] + filters: List[safety_types.ContentFilterDict] @property @abc.abstractmethod diff --git a/google/generativeai/types/safety_types.py b/google/generativeai/types/safety_types.py index a9e569cf7..d4e6e2615 100644 --- a/google/generativeai/types/safety_types.py +++ b/google/generativeai/types/safety_types.py @@ -23,9 +23,9 @@ "HarmProbability", "HarmBlockThreshold", "BlockedReason", - "ContentFilter", + "ContentFilterDict", "SafetyRatingDict", - "SafetySetting", + "SafetySettingDict", "SafetyFeedbackDict", ] @@ -36,13 +36,18 @@ BlockedReason = glm.ContentFilter.BlockedReason -class ContentFilter(TypedDict): +class ContentFilterDict(TypedDict): reason: BlockedReason message: str __doc__ = docstring_utils.strip_oneof(glm.ContentFilter.__doc__) +def convert_filters_to_enums(filters): + for f in filters: + f["reason"] = BlockedReason(f["reason"]) + + class SafetyRatingDict(TypedDict): category: HarmCategory probability: HarmProbability @@ -50,15 +55,31 @@ class SafetyRatingDict(TypedDict): __doc__ = docstring_utils.strip_oneof(glm.SafetyRating.__doc__) -class SafetySetting(TypedDict): +def convert_rating_to_enum(setting): + setting["category"] = HarmCategory(setting["category"]) + 
setting["probability"] = HarmProbability(setting["probability"]) + + +class SafetySettingDict(TypedDict): category: HarmCategory threshold: HarmBlockThreshold __doc__ = docstring_utils.strip_oneof(glm.SafetySetting.__doc__) +def convert_setting_to_enum(setting): + setting["category"] = HarmCategory(setting["category"]) + setting["threshold"] = HarmBlockThreshold(setting["threshold"]) + + class SafetyFeedbackDict(TypedDict): rating: SafetyRatingDict - setting: SafetySetting + setting: SafetySettingDict __doc__ = docstring_utils.strip_oneof(glm.SafetyFeedback.__doc__) + + +def convert_safety_feedback_to_enums(safety_feedback): + for sf in safety_feedback: + convert_rating_to_enum(sf["rating"]) + convert_setting_to_enum(sf["setting"]) diff --git a/google/generativeai/types/text_types.py b/google/generativeai/types/text_types.py index ebd5a6f25..ba9e26381 100644 --- a/google/generativeai/types/text_types.py +++ b/google/generativeai/types/text_types.py @@ -32,7 +32,7 @@ class TextCompletion(TypedDict, total=False): @dataclasses.dataclass(init=False) class Completion(abc.ABC): - """The result of the `1 given a prompt from the model. + """The result returned by `generativeai.generate_text`. Use `GenerateTextResponse.candidates` to access all the completions generated by the model. @@ -43,9 +43,10 @@ class Completion(abc.ABC): Either Unspecified, Safety, or Other. See `types.ContentFilter`. safety_feedback: Indicates which safety settings blocked content in this result. """ + candidates: List[TextCompletion] result: Optional[str] - filters: Optional[list[safety_types.ContentFilter]] + filters: Optional[list[safety_types.ContentFilterDict]] safety_feedback: Optional[list[safety_types.SafetyFeedbackDict]] def to_dict(self) -> Dict[str, Any]: diff --git a/setup.py b/setup.py index d937de24f..5a420ddaf 100644 --- a/setup.py +++ b/setup.py @@ -34,9 +34,7 @@ else: release_status = "Development Status :: 5 - Production/Stable" -dependencies = [ - "google-ai-generativelanguage==0.2.0" -] +dependencies = ["google-ai-generativelanguage==0.2.0"] extras_require = { "dev": [ diff --git a/tests/test_discuss.py b/tests/test_discuss.py index e01fb47f2..e3205aa28 100644 --- a/tests/test_discuss.py +++ b/tests/test_discuss.py @@ -21,6 +21,8 @@ from google.generativeai import discuss from google.generativeai import client import google.generativeai as genai +from google.generativeai.types import safety_types + from absl.testing import absltest from absl.testing import parameterized @@ -36,12 +38,12 @@ def setUp(self): self.observed_request = None self.mock_response = glm.GenerateMessageResponse( - candidates=[ - glm.Message(content="a", author="1"), - glm.Message(content="b", author="1"), - glm.Message(content="c", author="1"), - ], - ) + candidates=[ + glm.Message(content="a", author="1"), + glm.Message(content="b", author="1"), + glm.Message(content="c", author="1"), + ], + ) def fake_generate_message( request: glm.GenerateMessageRequest, @@ -274,32 +276,39 @@ def test_reply(self, kwargs): response = response.reply("again") def test_receive_and_reply_with_filters(self): - self.mock_response = mock_response = glm.GenerateMessageResponse( candidates=[glm.Message(content="a", author="1")], filters=[ - glm.ContentFilter(reason=glm.ContentFilter.BlockedReason.SAFETY, message='unsafe'), - glm.ContentFilter(reason=glm.ContentFilter.BlockedReason.OTHER),] + glm.ContentFilter( + reason=safety_types.BlockedReason.SAFETY, message="unsafe" + ), + glm.ContentFilter(reason=safety_types.BlockedReason.OTHER), + ], ) response = 
discuss.chat(messages="do filters work?")

        filters = response.filters
        self.assertLen(filters, 2)
-        self.assertIsInstance(filters[0]['reason'], glm.ContentFilter.BlockedReason)
-        self.assertEquals(filters[0]['reason'], glm.ContentFilter.BlockedReason.SAFETY)
-        self.assertEquals(filters[0]['message'], 'unsafe')
+        self.assertIsInstance(filters[0]["reason"], safety_types.BlockedReason)
+        self.assertEqual(filters[0]["reason"], safety_types.BlockedReason.SAFETY)
+        self.assertEqual(filters[0]["message"], "unsafe")

        self.mock_response = glm.GenerateMessageResponse(
            candidates=[glm.Message(content="a", author="1")],
            filters=[
-                glm.ContentFilter(reason=glm.ContentFilter.BlockedReason.BLOCKED_REASON_UNSPECIFIED)]
+                glm.ContentFilter(
+                    reason=safety_types.BlockedReason.BLOCKED_REASON_UNSPECIFIED
+                )
+            ],
        )

-        response = response.reply('Does reply work?')
+        response = response.reply("Does reply work?")
        filters = response.filters
        self.assertLen(filters, 1)
-        self.assertIsInstance(filters[0]['reason'], glm.ContentFilter.BlockedReason)
-        self.assertEquals(filters[0]['reason'], glm.ContentFilter.BlockedReason.BLOCKED_REASON_UNSPECIFIED)
+        self.assertIsInstance(filters[0]["reason"], safety_types.BlockedReason)
+        self.assertEqual(
+            filters[0]["reason"], safety_types.BlockedReason.BLOCKED_REASON_UNSPECIFIED
+        )


 if __name__ == "__main__":
diff --git a/tests/test_discuss_async.py b/tests/test_discuss_async.py
index 48e9a7c52..ac364f34c 100644
--- a/tests/test_discuss_async.py
+++ b/tests/test_discuss_async.py
@@ -17,8 +17,8 @@
 import unittest

 if sys.version_info < (3, 11):
-  import asynctest
-  from asynctest import mock as async_mock
+    import asynctest
+    from asynctest import mock as async_mock

 import google.ai.generativelanguage as glm

@@ -29,11 +29,16 @@
 bases = (parameterized.TestCase,)

 if sys.version_info < (3, 11):
-  bases = bases + (asynctest.TestCase,)
+    bases = bases + (asynctest.TestCase,)
+
+unittest.skipIf(
+    sys.version_info >= (3, 11), "asynctest is not supported on python 3.11+"
+)
+

-unittest.skipIf(sys.version_info >= (3,11), "asynctest is not suported on python 3.11+")
 class AsyncTests(*bases):
     if sys.version_info < (3, 11):
+
         async def test_chat_async(self):
             client = async_mock.MagicMock()

diff --git a/tests/test_text.py b/tests/test_text.py
index e4521a67c..6ece720f6 100644
--- a/tests/test_text.py
+++ b/tests/test_text.py
@@ -21,6 +21,7 @@

 from google.generativeai import text as text_service
 from google.generativeai import client
+from google.generativeai.types import safety_types

 from absl.testing import absltest
 from absl.testing import parameterized
@@ -33,17 +34,19 @@ def setUp(self):

         self.observed_request = None

+        self.mock_response = glm.GenerateTextResponse(
+            candidates=[
+                glm.TextCompletion(output=" road?"),
+                glm.TextCompletion(output=" bridge?"),
+                glm.TextCompletion(output=" river?"),
+            ]
+        )
+
         def fake_generate_completion(
             request: glm.GenerateTextRequest,
         ) -> glm.GenerateTextResponse:
             self.observed_request = request
-            return glm.GenerateTextResponse(
-                candidates=[
-                    glm.TextCompletion(output=" road?"),
-                    glm.TextCompletion(output=" bridge?"),
-                    glm.TextCompletion(output=" river?"),
-                ]
-            )
+            return self.mock_response

         self.client.generate_text = fake_generate_completion

@@ -75,7 +78,6 @@ def test_make_generate_text_request(self, prompt):
         self.assertEqual("models/chat-lamda-001", x.model)
         self.assertIsInstance(x, glm.GenerateTextRequest)

-    # @unittest.skipUnless(os.getenv('API_KEY'), "No API key set")
     @parameterized.named_parameters(
         [
             dict(
@@ -130,9 +132,9 @@ def
test_generate_response(self, *, prompt, **kwargs): self.assertEqual( complete.candidates, [ - {"output": " road?", 'safety_ratings': []}, - {"output": " bridge?", 'safety_ratings': []}, - {"output": " river?", 'safety_ratings': []}, + {"output": " road?", "safety_ratings": []}, + {"output": " bridge?", "safety_ratings": []}, + {"output": " river?", "safety_ratings": []}, ], ) @@ -147,6 +149,83 @@ def test_stop_string(self): stop_sequences=["stop"], ), ) + # Just make sure it made it into the request object. + self.assertEqual(self.observed_request.stop_sequences, ["stop"]) + + def test_safety_settings(self): + result = text_service.generate_text( + prompt="Say something wicked.", + safety_settings=[ + { + "category": safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, + "threshold": safety_types.HarmBlockThreshold.BLOCK_NONE, + }, + { + "category": safety_types.HarmCategory.HARM_CATEGORY_VIOLENCE, + "threshold": safety_types.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + }, + ], + ) + + # Just make sure it made it into the request object. + self.assertEqual( + self.observed_request.safety_settings[0].category, + safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, + ) + + def test_filters(self): + self.mock_response = glm.GenerateTextResponse( + candidates=[{"output": "hello"}], + filters=[ + {"reason": safety_types.BlockedReason.SAFETY, "message": "not safe"} + ], + ) + + response = text_service.generate_text(prompt="do filters work?") + self.assertIsInstance(response.filters[0]["reason"], safety_types.BlockedReason) + self.assertEqual( + response.filters[0]["reason"], safety_types.BlockedReason.SAFETY + ) + + def test_safety_feedback(self): + self.mock_response = glm.GenerateTextResponse( + candidates=[{"output": "hello"}], + safety_feedback=[ + { + "rating": { + "category": safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, + "probability": safety_types.HarmProbability.HIGH, + }, + "setting": { + "category": safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, + "threshold": safety_types.HarmBlockThreshold.BLOCK_NONE, + }, + } + ], + ) + + response = text_service.generate_text(prompt="does safety feedback work?") + self.assertIsInstance( + response.safety_feedback[0]["rating"]["probability"], + safety_types.HarmProbability, + ) + self.assertEqual( + response.safety_feedback[0]["rating"]["probability"], + safety_types.HarmProbability.HIGH, + ) + + self.assertIsInstance( + response.safety_feedback[0]["setting"]["category"], + safety_types.HarmCategory, + ) + self.assertEqual( + response.safety_feedback[0]["setting"]["category"], + safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, + ) + + #def test_candidate_safety_feedback(self): + + #def test_candidate_citations(self): if __name__ == "__main__": From a4f57ea97d7eace309920711bcf2352472b65567 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Mon, 8 May 2023 14:46:35 -0700 Subject: [PATCH 13/15] Test candidates['safety_ratings']. --- google/generativeai/client.py | 14 +++--- google/generativeai/discuss.py | 2 +- google/generativeai/text.py | 7 ++- google/generativeai/types/safety_types.py | 54 ++++++++++++++++++----- tests/test_text.py | 28 +++++++++++- 5 files changed, 82 insertions(+), 23 deletions(-) diff --git a/google/generativeai/client.py b/google/generativeai/client.py index af239d1c7..830b229d9 100644 --- a/google/generativeai/client.py +++ b/google/generativeai/client.py @@ -45,7 +45,7 @@ def configure( # but that seems rare. Users that need it can just switch to the low level API. 
transport: Union[str, None] = None, client_options: Union[client_options_lib.ClientOptions, dict, None] = None, - client_info: Optional[gapic_v1.client_info.ClientInfo] = None + client_info: Optional[gapic_v1.client_info.ClientInfo] = None, ): """Captures default client configuration. @@ -86,13 +86,13 @@ def configure( user_agent = f"{USER_AGENT}/{version.__version__}" if client_info: - # Be respectful of any existing agent setting. - if client_info.user_agent: - client_info.user_agent += f" {user_agent}" - else: - client_info.user_agent = user_agent + # Be respectful of any existing agent setting. + if client_info.user_agent: + client_info.user_agent += f" {user_agent}" + else: + client_info.user_agent = user_agent else: - client_info = gapic_v1.client_info.ClientInfo(user_agent=user_agent) + client_info = gapic_v1.client_info.ClientInfo(user_agent=user_agent) new_default_client_config = { "credentials": credentials, diff --git a/google/generativeai/discuss.py b/google/generativeai/discuss.py index e094dde65..e18da84d7 100644 --- a/google/generativeai/discuss.py +++ b/google/generativeai/discuss.py @@ -455,7 +455,7 @@ def _build_chat_response( response = type(response).to_dict(response) response.pop("messages") - safety_types.convert_filters_to_enums(response["filters"]) + response["filters"] = safety_types.convert_filters_to_enums(response["filters"]) if response["candidates"]: last = response["candidates"][0] diff --git a/google/generativeai/text.py b/google/generativeai/text.py index b709b3e96..aa13a24a9 100644 --- a/google/generativeai/text.py +++ b/google/generativeai/text.py @@ -159,8 +159,11 @@ def _generate_response( response = client.generate_text(request) response = type(response).to_dict(response) - safety_types.convert_filters_to_enums(response["filters"]) - safety_types.convert_safety_feedback_to_enums(response["safety_feedback"]) + response["filters"] = safety_types.convert_filters_to_enums(response["filters"]) + response["safety_feedback"] = safety_types.convert_safety_feedback_to_enums( + response["safety_feedback"] + ) + response['candidates'] = safety_types.convert_candidate_enums(response['candidates']) return Completion(_client=client, **response) diff --git a/google/generativeai/types/safety_types.py b/google/generativeai/types/safety_types.py index d4e6e2615..5edce558a 100644 --- a/google/generativeai/types/safety_types.py +++ b/google/generativeai/types/safety_types.py @@ -16,7 +16,7 @@ import enum from google.ai import generativelanguage as glm from google.generativeai import docstring_utils -from typing import TypedDict +from typing import Iterable, List, TypedDict __all__ = [ "HarmCategory", @@ -43,9 +43,13 @@ class ContentFilterDict(TypedDict): __doc__ = docstring_utils.strip_oneof(glm.ContentFilter.__doc__) -def convert_filters_to_enums(filters): +def convert_filters_to_enums(filters: Iterable[dict]) -> List[ContentFilterDict]: + result = [] for f in filters: + f = f.copy() f["reason"] = BlockedReason(f["reason"]) + result.append(f) + return result class SafetyRatingDict(TypedDict): @@ -55,9 +59,18 @@ class SafetyRatingDict(TypedDict): __doc__ = docstring_utils.strip_oneof(glm.SafetyRating.__doc__) -def convert_rating_to_enum(setting): - setting["category"] = HarmCategory(setting["category"]) - setting["probability"] = HarmProbability(setting["probability"]) +def convert_rating_to_enum(rating: dict) -> SafetyRatingDict: + return { + "category": HarmCategory(rating["category"]), + "probability": HarmProbability(rating["probability"]), + } + + +def 
convert_ratings_to_enum(ratings: Iterable[dict]) -> List[SafetyRatingDict]: + result = [] + for r in ratings: + result.append(convert_rating_to_enum(r)) + return result class SafetySettingDict(TypedDict): @@ -67,9 +80,11 @@ class SafetySettingDict(TypedDict): __doc__ = docstring_utils.strip_oneof(glm.SafetySetting.__doc__) -def convert_setting_to_enum(setting): - setting["category"] = HarmCategory(setting["category"]) - setting["threshold"] = HarmBlockThreshold(setting["threshold"]) +def convert_setting_to_enum(setting: dict) -> SafetySettingDict: + return { + "category": HarmCategory(setting["category"]), + "threshold": HarmBlockThreshold(setting["threshold"]), + } class SafetyFeedbackDict(TypedDict): @@ -79,7 +94,24 @@ class SafetyFeedbackDict(TypedDict): __doc__ = docstring_utils.strip_oneof(glm.SafetyFeedback.__doc__) -def convert_safety_feedback_to_enums(safety_feedback): +def convert_safety_feedback_to_enums( + safety_feedback: Iterable[dict], +) -> List[SafetyFeedbackDict]: + result = [] for sf in safety_feedback: - convert_rating_to_enum(sf["rating"]) - convert_setting_to_enum(sf["setting"]) + result.append( + { + "rating": convert_rating_to_enum(sf["rating"]), + "setting": convert_setting_to_enum(sf["setting"]), + } + ) + return result + + +def convert_candidate_enums(candidates): + result = [] + for candidate in candidates: + candidate = candidate.copy() + candidate['safety_ratings'] = convert_ratings_to_enum(candidate['safety_ratings']) + result.append(candidate) + return result \ No newline at end of file diff --git a/tests/test_text.py b/tests/test_text.py index 6ece720f6..3a4a5c91a 100644 --- a/tests/test_text.py +++ b/tests/test_text.py @@ -223,9 +223,33 @@ def test_safety_feedback(self): safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, ) - #def test_candidate_safety_feedback(self): + def test_candidate_safety_feedback(self): + self.mock_response = glm.GenerateTextResponse( + candidates=[ + { + "output": "hello", + "safety_ratings": [ + { + "category": safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, + "probability": safety_types.HarmProbability.HIGH, + }, + { + "category": safety_types.HarmCategory.HARM_CATEGORY_VIOLENCE, + "probability": safety_types.HarmProbability.LOW, + }, + ], + } + ] + ) + + result = text_service.generate_text(prompt="Write a story from the ER.") + self.assertIsInstance(result.candidates[0]['safety_ratings'][0]['category'], safety_types.HarmCategory) + self.assertEqual(result.candidates[0]['safety_ratings'][0]['category'], safety_types.HarmCategory.HARM_CATEGORY_MEDICAL) + + self.assertIsInstance(result.candidates[0]['safety_ratings'][0]['probability'], safety_types.HarmProbability) + self.assertEqual(result.candidates[0]['safety_ratings'][0]['probability'], safety_types.HarmProbability.HIGH) - #def test_candidate_citations(self): + # def test_candidate_citations(self): if __name__ == "__main__": From 2c8ce6feb8a06721a482794b8d59e3a21bc8f10e Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Mon, 8 May 2023 15:02:38 -0700 Subject: [PATCH 14/15] Test text candidate citations and safety ratings. 
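For reference, a rough sketch of the caller-facing behavior these tests
pin down (the prompt is illustrative, and it assumes an API key has
already been set via genai.configure):

    import google.generativeai as genai
    from google.generativeai.types import safety_types

    completion = genai.generate_text(prompt="Write a story from the ER.")
    candidate = completion.candidates[0]

    # After conversion, the ratings carry real enums rather than bare ints.
    for rating in candidate["safety_ratings"]:
        assert isinstance(rating["category"], safety_types.HarmCategory)
        assert isinstance(rating["probability"], safety_types.HarmProbability)

    # Citation metadata stays a plain dict of sources, when present.
    for source in candidate.get("citation_metadata", {}).get("citation_sources", []):
        print(source["uri"], source["start_index"], source["end_index"])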
--- google/generativeai/text.py | 4 +- google/generativeai/types/safety_types.py | 6 ++- tests/test_text.py | 46 ++++++++++++++++++++--- 3 files changed, 47 insertions(+), 9 deletions(-) diff --git a/google/generativeai/text.py b/google/generativeai/text.py index aa13a24a9..171597725 100644 --- a/google/generativeai/text.py +++ b/google/generativeai/text.py @@ -163,7 +163,9 @@ def _generate_response( response["safety_feedback"] = safety_types.convert_safety_feedback_to_enums( response["safety_feedback"] ) - response['candidates'] = safety_types.convert_candidate_enums(response['candidates']) + response["candidates"] = safety_types.convert_candidate_enums( + response["candidates"] + ) return Completion(_client=client, **response) diff --git a/google/generativeai/types/safety_types.py b/google/generativeai/types/safety_types.py index 5edce558a..7dbd76dca 100644 --- a/google/generativeai/types/safety_types.py +++ b/google/generativeai/types/safety_types.py @@ -112,6 +112,8 @@ def convert_candidate_enums(candidates): result = [] for candidate in candidates: candidate = candidate.copy() - candidate['safety_ratings'] = convert_ratings_to_enum(candidate['safety_ratings']) + candidate["safety_ratings"] = convert_ratings_to_enum( + candidate["safety_ratings"] + ) result.append(candidate) - return result \ No newline at end of file + return result diff --git a/tests/test_text.py b/tests/test_text.py index 3a4a5c91a..63eccb8c6 100644 --- a/tests/test_text.py +++ b/tests/test_text.py @@ -167,7 +167,6 @@ def test_safety_settings(self): ], ) - # Just make sure it made it into the request object. self.assertEqual( self.observed_request.safety_settings[0].category, safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, @@ -243,13 +242,48 @@ def test_candidate_safety_feedback(self): ) result = text_service.generate_text(prompt="Write a story from the ER.") - self.assertIsInstance(result.candidates[0]['safety_ratings'][0]['category'], safety_types.HarmCategory) - self.assertEqual(result.candidates[0]['safety_ratings'][0]['category'], safety_types.HarmCategory.HARM_CATEGORY_MEDICAL) + self.assertIsInstance( + result.candidates[0]["safety_ratings"][0]["category"], + safety_types.HarmCategory, + ) + self.assertEqual( + result.candidates[0]["safety_ratings"][0]["category"], + safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, + ) - self.assertIsInstance(result.candidates[0]['safety_ratings'][0]['probability'], safety_types.HarmProbability) - self.assertEqual(result.candidates[0]['safety_ratings'][0]['probability'], safety_types.HarmProbability.HIGH) + self.assertIsInstance( + result.candidates[0]["safety_ratings"][0]["probability"], + safety_types.HarmProbability, + ) + self.assertEqual( + result.candidates[0]["safety_ratings"][0]["probability"], + safety_types.HarmProbability.HIGH, + ) - # def test_candidate_citations(self): + def test_candidate_citations(self): + self.mock_response = glm.GenerateTextResponse( + candidates=[ + { + "output": "Hello Google!", + "citation_metadata": { + "citation_sources": [ + { + "start_index": 6, + "end_index": 12, + "uri": "https://google.com", + } + ] + }, + } + ] + ) + result = text_service.generate_text(prompt="Hi my name is Google") + self.assertEqual( + result.candidates[0]["citation_metadata"]["citation_sources"][0][ + "start_index" + ], + 6, + ) if __name__ == "__main__": From bbd3ae0a5be9622e1027f5845cc8fc1bb2d9ad07 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Mon, 8 May 2023 15:10:07 -0700 Subject: [PATCH 15/15] Test chat citations. 
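A rough usage sketch of what this asserts (the prompts are illustrative,
and it assumes genai.configure has been called with an API key):

    import google.generativeai as genai

    response = genai.chat(messages="Do citations work?")
    candidate = response.candidates[0]
    for source in candidate.get("citation_metadata", {}).get("citation_sources", []):
        print(source["uri"], source["start_index"], source["end_index"])

    # chat() records the prompt and the answer, and each reply() appends
    # two more entries, so one round of reply() leaves four messages.
    response = response.reply("What about a second time?")
    assert len(response.messages) == 4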
--- tests/test_discuss.py | 34 ++++++++++++++++++++++++++++++++++
 1 file changed, 34 insertions(+)

diff --git a/tests/test_discuss.py b/tests/test_discuss.py
index e3205aa28..c2dff55e5 100644
--- a/tests/test_discuss.py
+++ b/tests/test_discuss.py
@@ -311,5 +311,39 @@ def test_receive_and_reply_with_filters(self):
         )


+    def test_chat_citations(self):
+        self.mock_response = mock_response = glm.GenerateMessageResponse(
+            candidates=[{"content": "Hello google!", "author": "1", "citation_metadata": {
+                "citation_sources": [
+                    {
+                        "start_index": 6,
+                        "end_index": 12,
+                        "uri": "https://google.com",
+                    }
+                ]
+            },
+            }],
+        )
+
+        response = discuss.chat(messages="Do citations work?")
+
+        self.assertEqual(
+            response.candidates[0]["citation_metadata"]["citation_sources"][0][
+                "start_index"
+            ],
+            6,
+        )
+
+        response = response.reply("What about a second time?")
+
+        self.assertEqual(
+            response.candidates[0]["citation_metadata"]["citation_sources"][0][
+                "start_index"
+            ],
+            6,
+        )
+        self.assertLen(response.messages, 4)
+
+
 if __name__ == "__main__":
     absltest.main()
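Taken together with patches 12 and 13, safety types are now enum-typed in
both directions: settings on the request, filters and feedback on the
response. A minimal round-trip sketch (threshold and prompt are
illustrative; an API key is assumed to be configured):

    import google.generativeai as genai
    from google.generativeai.types import safety_types

    completion = genai.generate_text(
        prompt="Say something wicked.",
        safety_settings=[
            {
                "category": safety_types.HarmCategory.HARM_CATEGORY_VIOLENCE,
                "threshold": safety_types.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
            }
        ],
    )

    # Both collections come back enum-typed after conversion.
    for f in completion.filters:
        print(f["reason"].name, f.get("message", ""))
    for fb in completion.safety_feedback:
        print(fb["rating"]["probability"].name, fb["setting"]["category"].name)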