diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 08553673f..613fc896c 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -65,7 +65,7 @@ This "editable" mode lets you edit the source without needing to reinstall the p
 Use the builtin unittest package:
 
 ```
-python -m unittest
+python -m unittest discover --pattern '*test*.py'
 ```
 
 Or to debug, use:
diff --git a/google/generativeai/client.py b/google/generativeai/client.py
index af239d1c7..830b229d9 100644
--- a/google/generativeai/client.py
+++ b/google/generativeai/client.py
@@ -45,7 +45,7 @@ def configure(
     # but that seems rare. Users that need it can just switch to the low level API.
     transport: Union[str, None] = None,
     client_options: Union[client_options_lib.ClientOptions, dict, None] = None,
-    client_info: Optional[gapic_v1.client_info.ClientInfo] = None
+    client_info: Optional[gapic_v1.client_info.ClientInfo] = None,
 ):
     """Captures default client configuration.
 
@@ -86,13 +86,13 @@ def configure(
     user_agent = f"{USER_AGENT}/{version.__version__}"
     if client_info:
-      # Be respectful of any existing agent setting.
-      if client_info.user_agent:
-        client_info.user_agent += f" {user_agent}"
-      else:
-        client_info.user_agent = user_agent
+        # Be respectful of any existing agent setting.
+        if client_info.user_agent:
+            client_info.user_agent += f" {user_agent}"
+        else:
+            client_info.user_agent = user_agent
     else:
-      client_info = gapic_v1.client_info.ClientInfo(user_agent=user_agent)
+        client_info = gapic_v1.client_info.ClientInfo(user_agent=user_agent)
 
     new_default_client_config = {
         "credentials": credentials,
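A minimal sketch of how the `configure` change above merges user-agent strings. The `my-app/1.2.3` token, the placeholder API key, and the printed value are illustrative; the SDK's own token is assumed to be `USER_AGENT/__version__` (e.g. `genai-py/0.1.0rc2`):

```python
from google.api_core import gapic_v1

import google.generativeai as genai

# A caller-supplied ClientInfo keeps its existing agent string;
# configure() appends the SDK's own "<USER_AGENT>/<version>" token
# instead of overwriting it.
info = gapic_v1.client_info.ClientInfo(user_agent="my-app/1.2.3")
genai.configure(api_key="...", client_info=info)
print(info.user_agent)  # e.g. "my-app/1.2.3 genai-py/0.1.0rc2"
```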
diff --git a/google/generativeai/discuss.py b/google/generativeai/discuss.py
index f463e2abd..e18da84d7 100644
--- a/google/generativeai/discuss.py
+++ b/google/generativeai/discuss.py
@@ -25,6 +25,7 @@
 from google.generativeai.client import get_default_discuss_async_client
 from google.generativeai.types import discuss_types
 from google.generativeai.types import model_types
+from google.generativeai.types import safety_types
 
 
 def _make_message(content: discuss_types.MessageOptions) -> glm.Message:
@@ -389,8 +390,11 @@ def __init__(self, **kwargs):
 
     @property
     @set_doc(discuss_types.ChatResponse.last.__doc__)
-    def last(self) -> str:
-        return self.messages[-1]["content"]
+    def last(self) -> Optional[str]:
+        if self.messages[-1]:
+            return self.messages[-1]["content"]
+        else:
+            return None
 
     @last.setter
     def last(self, message: discuss_types.MessageOptions):
@@ -405,8 +409,16 @@ def reply(
             raise TypeError(
                 f"reply can't be called on an async client, use reply_async instead."
             )
+        if self.last is None:
+            raise ValueError(
+                "The last response from the model did not return any candidates.\n"
+                "Check the `.filters` attribute to see why the responses were filtered:\n"
+                f"{self.filters}"
+            )
+
         request = self.to_dict()
         request.pop("candidates")
+        request.pop("filters", None)
         request["messages"] = list(request["messages"])
         request["messages"].append(_make_message(message))
         request = _make_generate_message_request(**request)
@@ -422,6 +434,7 @@ async def reply_async(
         )
         request = self.to_dict()
         request.pop("candidates")
+        request.pop("filters")
         request["messages"] = list(request["messages"])
         request["messages"].append(_make_message(message))
         request = _make_generate_message_request(**request)
@@ -440,12 +453,20 @@ def _build_chat_response(
     request["messages"] = prompt["messages"]
 
     response = type(response).to_dict(response)
-    request["messages"].append(response["candidates"][0])
+    response.pop("messages")
+
+    response["filters"] = safety_types.convert_filters_to_enums(response["filters"])
+
+    if response["candidates"]:
+        last = response["candidates"][0]
+    else:
+        last = None
+    request["messages"].append(last)
 
     request.setdefault("temperature", None)
     request.setdefault("candidate_count", None)
 
     return ChatResponse(
-        _client=client, candidates=response["candidates"], **request
+        _client=client, **response, **request
     )  # pytype: disable=missing-parameter
diff --git a/google/generativeai/docstring_utils.py b/google/generativeai/docstring_utils.py
new file mode 100644
index 000000000..f403316c6
--- /dev/null
+++ b/google/generativeai/docstring_utils.py
@@ -0,0 +1,21 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+def strip_oneof(docstring):
+    lines = docstring.splitlines()
+    lines = [line for line in lines if ".. _oneof:" not in line]
+    lines = [line for line in lines if "This field is a member of `oneof`_" not in line]
+    return "\n".join(lines)
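To illustrate, `strip_oneof` simply drops the proto-plus `oneof` boilerplate lines from a docstring and keeps everything else verbatim. The sample docstring and URL below are made up:

```python
from google.generativeai import docstring_utils

doc = """The start index.

    This field is a member of `oneof`_ ``_start_index``.

    .. _oneof: https://example.com/oneof-docs
"""
print(docstring_utils.strip_oneof(doc))
# Only "The start index." (plus blank lines) survives; both lines
# mentioning the `oneof` group are filtered out.
```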
_oneof:" not in line] + lines = [line for line in lines if "This field is a member of `oneof`_" not in line] + return "\n".join(lines) diff --git a/google/generativeai/text.py b/google/generativeai/text.py index ae17e422d..171597725 100644 --- a/google/generativeai/text.py +++ b/google/generativeai/text.py @@ -24,6 +24,7 @@ from google.generativeai.client import get_default_text_client from google.generativeai.types import text_types from google.generativeai.types import model_types +from google.generativeai.types import safety_types def _make_text_prompt(prompt: Union[str, dict[str, str]]) -> glm.TextPrompt: @@ -44,6 +45,7 @@ def _make_generate_text_request( max_output_tokens: Optional[int] = None, top_p: Optional[int] = None, top_k: Optional[int] = None, + safety_settings: Optional[List[safety_types.SafetySettingDict]] = None, stop_sequences: Union[str, Iterable[str]] = None, ) -> glm.GenerateTextRequest: model = model_types.make_model_name(model) @@ -61,6 +63,7 @@ def _make_generate_text_request( max_output_tokens=max_output_tokens, top_p=top_p, top_k=top_k, + safety_settings=safety_settings, stop_sequences=stop_sequences, ) @@ -74,6 +77,7 @@ def generate_text( max_output_tokens: Optional[int] = None, top_p: Optional[float] = None, top_k: Optional[float] = None, + safety_settings: Optional[Iterable[safety.SafetySettingDict]] = None, stop_sequences: Union[str, Iterable[str]] = None, client: Optional[glm.TextServiceClient] = None, ) -> text_types.Completion: @@ -103,6 +107,15 @@ def generate_text( For example, if the sorted probabilities are `[0.5, 0.2, 0.1, 0.1, 0.05, 0.05]` a `top_p` of `0.8` will sample as `[0.625, 0.25, 0.125, 0, 0, 0]. + safety_settings: A list of unique `types.SafetySetting` instances for blocking unsafe content. + These will be enforced on the `prompt` and + `candidates`. There should not be more than one + setting for each `types.SafetyCategory` type. The API will block any prompts and + responses that fail to meet the thresholds set by these settings. This list + overrides the default settings for each `SafetyCategory` specified in the + safety_settings. If there is no `types.SafetySetting` for a given + `SafetyCategory` provided in the list, the API will use the default safety + setting for that category. stop_sequences: A set of up to 5 character sequences that will stop output generation. If specified, the API will stop at the first appearance of a stop sequence. The stop sequence will not be included as part of the response. 
diff --git a/google/generativeai/types/__init__.py b/google/generativeai/types/__init__.py
index 463d9bdb3..0bdf3a713 100644
--- a/google/generativeai/types/__init__.py
+++ b/google/generativeai/types/__init__.py
@@ -17,6 +17,11 @@
 from google.generativeai.types.discuss_types import *
 from google.generativeai.types.model_types import *
 from google.generativeai.types.text_types import *
+from google.generativeai.types.citation_types import *
+from google.generativeai.types.safety_types import *
 
 del discuss_types
 del model_types
+del text_types
+del citation_types
+del safety_types
diff --git a/google/generativeai/types/citation_types.py b/google/generativeai/types/citation_types.py
new file mode 100644
index 000000000..c79bde621
--- /dev/null
+++ b/google/generativeai/types/citation_types.py
@@ -0,0 +1,39 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, List
+
+from google.ai import generativelanguage as glm
+from google.generativeai import docstring_utils
+from typing import TypedDict
+
+__all__ = [
+    "CitationMetadataDict",
+    "CitationSourceDict",
+]
+
+
+class CitationSourceDict(TypedDict):
+    start_index: Optional[int]
+    end_index: Optional[int]
+    uri: Optional[str]
+    license: Optional[str]
+
+    __doc__ = docstring_utils.strip_oneof(glm.CitationSource.__doc__)
+
+
+class CitationMetadataDict(TypedDict):
+    citation_sources: Optional[List[CitationSourceDict]]
+
+    __doc__ = docstring_utils.strip_oneof(glm.CitationMetadata.__doc__)
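Since TypedDicts are plain dicts at runtime, the citation classes above only document a shape. A hypothetical literal matching the fixture values used in the tests later in this diff:

```python
from google.generativeai import types

# The keys mirror glm.CitationSource; values here are illustrative.
metadata: types.CitationMetadataDict = {
    "citation_sources": [
        {
            "start_index": 6,
            "end_index": 12,
            "uri": "https://google.com",
            "license": None,
        }
    ]
}
```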
diff --git a/google/generativeai/types/discuss_types.py b/google/generativeai/types/discuss_types.py
index 88c6ebeb5..ae6cea84d 100644
--- a/google/generativeai/types/discuss_types.py
+++ b/google/generativeai/types/discuss_types.py
@@ -19,6 +19,8 @@
 from typing import Any, Dict, TypedDict, Union, Iterable, Optional, Tuple, List
 
 import google.ai.generativelanguage as glm
+from google.generativeai.types import safety_types
+from google.generativeai.types import citation_types
 
 __all__ = [
     "MessageDict",
@@ -35,11 +37,12 @@
 ]
 
 
-class MessageDict(TypedDict, total=False):
+class MessageDict(TypedDict):
     """A dict representation of a `glm.Message`."""
 
     author: str
     content: str
+    citation_metadata: Optional[citation_types.CitationMetadataDict]
 
 
 MessageOptions = Union[str, MessageDict, glm.Message]
@@ -129,7 +132,14 @@ class ChatResponse(abc.ABC):
         Note: The `temperature` field affects the variability of the responses. Low
           temperatures will return few candidates. Setting `temperature=0` is deterministic,
           so it will only ever return one candidate.
-
+        filters: This indicates which `types.SafetyCategory`(s) blocked a
+          candidate from this response, the lowest `types.HarmProbability`
+          that triggered a block, and the `types.HarmBlockThreshold` setting for that category.
+          This indicates the smallest change to the `types.SafetySettings` that would be
+          necessary to unblock at least one response.
+
+          The blocking is configured by the `types.SafetySettings` in the request (or the
+          default `types.SafetySettings` of the API).
         messages: Contains all the `messages` that were passed when the model was called,
           plus the top `candidate` message.
         model: The model name.
@@ -140,21 +150,23 @@ class ChatResponse(abc.ABC):
         candidate_count: The **maximum** number of generated response messages to return.
         top_k: The maximum number of tokens to consider when sampling.
         top_p: The maximum cumulative probability of tokens to consider when sampling.
+
     """
 
     model: str
     context: str
     examples: List[ExampleDict]
-    messages: List[MessageDict]
+    messages: List[Optional[MessageDict]]
     temperature: Optional[float]
     candidate_count: Optional[int]
     candidates: List[MessageDict]
     top_p: Optional[float] = None
    top_k: Optional[float] = None
+    filters: List[safety_types.ContentFilterDict]
 
     @property
     @abc.abstractmethod
-    def last(self) -> str:
+    def last(self) -> Optional[str]:
         """A settable property that provides simple access to the last response string
 
         A shortcut for `response.messages[-1]['content']`.
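With `last` now `Optional`, callers are expected to check for a fully blocked response before replying. A minimal sketch of that pattern, assuming an API key is already configured:

```python
import google.generativeai as genai

response = genai.chat(messages="hello")
if response.last is None:
    # Every candidate was blocked; `filters` reports the category, the
    # probability that triggered the block, and the active threshold.
    print(response.filters)
else:
    print(response.last)
    response = response.reply("tell me more")
```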
diff --git a/google/generativeai/types/safety_types.py b/google/generativeai/types/safety_types.py
new file mode 100644
index 000000000..7dbd76dca
--- /dev/null
+++ b/google/generativeai/types/safety_types.py
@@ -0,0 +1,119 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import enum
+from google.ai import generativelanguage as glm
+from google.generativeai import docstring_utils
+from typing import Iterable, List, TypedDict
+
+__all__ = [
+    "HarmCategory",
+    "HarmProbability",
+    "HarmBlockThreshold",
+    "BlockedReason",
+    "ContentFilterDict",
+    "SafetyRatingDict",
+    "SafetySettingDict",
+    "SafetyFeedbackDict",
+]
+
+# These are basic python enums, it's okay to expose them
+HarmCategory = glm.HarmCategory
+HarmProbability = glm.SafetyRating.HarmProbability
+HarmBlockThreshold = glm.SafetySetting.HarmBlockThreshold
+BlockedReason = glm.ContentFilter.BlockedReason
+
+
+class ContentFilterDict(TypedDict):
+    reason: BlockedReason
+    message: str
+
+    __doc__ = docstring_utils.strip_oneof(glm.ContentFilter.__doc__)
+
+
+def convert_filters_to_enums(filters: Iterable[dict]) -> List[ContentFilterDict]:
+    result = []
+    for f in filters:
+        f = f.copy()
+        f["reason"] = BlockedReason(f["reason"])
+        result.append(f)
+    return result
+
+
+class SafetyRatingDict(TypedDict):
+    category: HarmCategory
+    probability: HarmProbability
+
+    __doc__ = docstring_utils.strip_oneof(glm.SafetyRating.__doc__)
+
+
+def convert_rating_to_enum(rating: dict) -> SafetyRatingDict:
+    return {
+        "category": HarmCategory(rating["category"]),
+        "probability": HarmProbability(rating["probability"]),
+    }
+
+
+def convert_ratings_to_enum(ratings: Iterable[dict]) -> List[SafetyRatingDict]:
+    result = []
+    for r in ratings:
+        result.append(convert_rating_to_enum(r))
+    return result
+
+
+class SafetySettingDict(TypedDict):
+    category: HarmCategory
+    threshold: HarmBlockThreshold
+
+    __doc__ = docstring_utils.strip_oneof(glm.SafetySetting.__doc__)
+
+
+def convert_setting_to_enum(setting: dict) -> SafetySettingDict:
+    return {
+        "category": HarmCategory(setting["category"]),
+        "threshold": HarmBlockThreshold(setting["threshold"]),
+    }
+
+
+class SafetyFeedbackDict(TypedDict):
+    rating: SafetyRatingDict
+    setting: SafetySettingDict
+
+    __doc__ = docstring_utils.strip_oneof(glm.SafetyFeedback.__doc__)
+
+
+def convert_safety_feedback_to_enums(
+    safety_feedback: Iterable[dict],
+) -> List[SafetyFeedbackDict]:
+    result = []
+    for sf in safety_feedback:
+        result.append(
+            {
+                "rating": convert_rating_to_enum(sf["rating"]),
+                "setting": convert_setting_to_enum(sf["setting"]),
+            }
+        )
+    return result
+
+
+def convert_candidate_enums(candidates):
+    result = []
+    for candidate in candidates:
+        candidate = candidate.copy()
+        candidate["safety_ratings"] = convert_ratings_to_enum(
+            candidate["safety_ratings"]
+        )
+        result.append(candidate)
+    return result
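The converters above re-wrap the plain ints carried by wire-format dicts as enums, copying each dict rather than mutating its input. For example:

```python
from google.generativeai.types import safety_types

# A wire-format filter dict holds a plain int for "reason".
raw = [{"reason": int(safety_types.BlockedReason.SAFETY), "message": "unsafe"}]
converted = safety_types.convert_filters_to_enums(raw)

assert isinstance(converted[0]["reason"], safety_types.BlockedReason)
assert converted[0]["reason"] == safety_types.BlockedReason.SAFETY
# The input list is left untouched: still a plain int, not the enum.
assert not isinstance(raw[0]["reason"], safety_types.BlockedReason)
```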
diff --git a/google/generativeai/types/text_types.py b/google/generativeai/types/text_types.py
index 0abca273c..ba9e26381 100644
--- a/google/generativeai/types/text_types.py
+++ b/google/generativeai/types/text_types.py
@@ -17,28 +17,42 @@
 import dataclasses
 from typing import Any, Dict, Optional, List, Iterator, TypedDict
 
+from google.generativeai.types import safety_types
+from google.generativeai.types import citation_types
+
+
 __all__ = ["Completion"]
 
 
-class TextCandidate(TypedDict, total=False):
+class TextCompletion(TypedDict, total=False):
     output: str
+    safety_ratings: Optional[List[safety_types.SafetyRatingDict]]
+    citation_metadata: Optional[citation_types.CitationMetadataDict]
 
 
 @dataclasses.dataclass(init=False)
 class Completion(abc.ABC):
-    """A text completion given a prompt from the model.
+    """The result returned by `generativeai.generate_text`.
 
-    * Use `completion.candidates` to access all of the text completion options generated by the model.
+    Use `Completion.candidates` to access all the completions generated by the model.
 
     Attributes:
         candidates: A list of candidate text completions generated by the model.
+        result: The output of the first candidate.
+        filters: Indicates the reasons why content may have been blocked:
+          unspecified, safety, or other. See `types.ContentFilter`.
+        safety_feedback: Indicates which safety settings blocked content in this result.
     """
 
-    candidates: List[TextCandidate]
+    candidates: List[TextCompletion]
     result: Optional[str]
+    filters: Optional[List[safety_types.ContentFilterDict]]
+    safety_feedback: Optional[List[safety_types.SafetyFeedbackDict]]
 
     def to_dict(self) -> Dict[str, Any]:
         result = {
             "candidates": self.candidates,
+            "filters": self.filters,
+            "safety_feedback": self.safety_feedback,
         }
         return result
diff --git a/google/generativeai/version.py b/google/generativeai/version.py
index fa14c226f..3ca712222 100644
--- a/google/generativeai/version.py
+++ b/google/generativeai/version.py
@@ -13,4 +13,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__version__ = "0.1.0rc1"
+__version__ = "0.1.0rc2"
diff --git a/setup.py b/setup.py
index a2dbe719a..5a420ddaf 100644
--- a/setup.py
+++ b/setup.py
@@ -34,9 +34,7 @@
 else:
     release_status = "Development Status :: 5 - Production/Stable"
 
-dependencies = [
-    "google-ai-generativelanguage==0.1.0"
-]
+dependencies = ["google-ai-generativelanguage==0.2.0"]
 
 extras_require = {
     "dev": [
diff --git a/tests/test_discuss.py b/tests/test_discuss.py
index 32938c884..c2dff55e5 100644
--- a/tests/test_discuss.py
+++ b/tests/test_discuss.py
@@ -12,23 +12,24 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import unittest.mock
+import copy
 
-import asynctest
-from asynctest import mock as async_mock
+import unittest.mock
 
 import google.ai.generativelanguage as glm
 
 from google.generativeai import discuss
 from google.generativeai import client
 import google.generativeai as genai
+from google.generativeai.types import safety_types
+
 from absl.testing import absltest
 from absl.testing import parameterized
 
 
 # TODO: replace returns with 'assert' statements
-class UnitTests(parameterized.TestCase, asynctest.TestCase):
+class UnitTests(parameterized.TestCase):
     def setUp(self):
         self.client = unittest.mock.MagicMock()
 
@@ -36,18 +37,21 @@ def setUp(self):
 
         self.observed_request = None
 
+        self.mock_response = glm.GenerateMessageResponse(
+            candidates=[
+                glm.Message(content="a", author="1"),
+                glm.Message(content="b", author="1"),
+                glm.Message(content="c", author="1"),
+            ],
+        )
+
         def fake_generate_message(
             request: glm.GenerateMessageRequest,
         ) -> glm.GenerateMessageResponse:
             self.observed_request = request
-            return glm.GenerateMessageResponse(
-                messages=request.prompt.messages,
-                candidates=[
-                    glm.Message(content="a", author="1"),
-                    glm.Message(content="b", author="1"),
-                    glm.Message(content="c", author="1"),
-                ],
-            )
+            response = copy.copy(self.mock_response)
+            response.messages = request.prompt.messages
+            return response
 
         self.client.generate_message = fake_generate_message
 
@@ -271,60 +275,74 @@ def test_reply(self, kwargs):
 
         response = response.reply("again")
 
+    def test_receive_and_reply_with_filters(self):
+        self.mock_response = glm.GenerateMessageResponse(
+            candidates=[glm.Message(content="a", author="1")],
+            filters=[
+                glm.ContentFilter(
+                    reason=safety_types.BlockedReason.SAFETY, message="unsafe"
+                ),
+                glm.ContentFilter(reason=safety_types.BlockedReason.OTHER),
+            ],
+        )
+        response = discuss.chat(messages="do filters work?")
+
+        filters = response.filters
+        self.assertLen(filters, 2)
+        self.assertIsInstance(filters[0]["reason"], safety_types.BlockedReason)
+        self.assertEqual(filters[0]["reason"], safety_types.BlockedReason.SAFETY)
+        self.assertEqual(filters[0]["message"], "unsafe")
+
+        self.mock_response = glm.GenerateMessageResponse(
+            candidates=[glm.Message(content="a", author="1")],
+            filters=[
+                glm.ContentFilter(
+                    reason=safety_types.BlockedReason.BLOCKED_REASON_UNSPECIFIED
+                )
+            ],
+        )
+
+        response = response.reply("Does reply work?")
+        filters = response.filters
+        self.assertLen(filters, 1)
+        self.assertIsInstance(filters[0]["reason"], safety_types.BlockedReason)
+        self.assertEqual(
+            filters[0]["reason"], safety_types.BlockedReason.BLOCKED_REASON_UNSPECIFIED
+        )
+
+    def test_chat_citations(self):
+        self.mock_response = glm.GenerateMessageResponse(
+            candidates=[
+                {
+                    "content": "Hello google!",
+                    "author": "1",
+                    "citation_metadata": {
+                        "citation_sources": [
+                            {
+                                "start_index": 6,
+                                "end_index": 12,
+                                "uri": "https://google.com",
+                            }
+                        ]
+                    },
+                }
+            ],
+        )
+        response = discuss.chat(messages="Do citations work?")
+
+        self.assertEqual(
+            response.candidates[0]["citation_metadata"]["citation_sources"][0][
+                "start_index"
+            ],
+            6,
+        )
+
+        response = response.reply("What about a second time?")
+
+        self.assertEqual(
+            response.candidates[0]["citation_metadata"]["citation_sources"][0][
+                "start_index"
+            ],
+            6,
+        )
+        self.assertLen(response.messages, 4)
 
-class AsyncTests(parameterized.TestCase, asynctest.TestCase):
-    async def test_chat_async(self):
-        client = async_mock.MagicMock()
-
-        observed_request = None
-
-        async def fake_generate_message(
-            request: glm.GenerateMessageRequest,
-        ) -> glm.GenerateMessageResponse:
-            nonlocal observed_request
-            observed_request = request
-            return glm.GenerateMessageResponse(
-                candidates=[
-                    glm.Message(
-                        author="1", content="Why did the chicken cross the road?"
-                    )
-                ]
-            )
-
-        client.generate_message = fake_generate_message
-
-        observed_response = await discuss.chat_async(
-            model="models/bard",
-            context="Example Prompt",
-            examples=[["Example from human", "Example response from AI"]],
-            messages=["Tell me a joke"],
-            temperature=0.75,
-            candidate_count=1,
-            client=client,
-        )
-
-        self.assertEqual(
-            observed_request,
-            glm.GenerateMessageRequest(
-                model="models/bard",
-                prompt=glm.MessagePrompt(
-                    context="Example Prompt",
-                    examples=[
-                        glm.Example(
-                            input=glm.Message(content="Example from human"),
-                            output=glm.Message(content="Example response from AI"),
-                        )
-                    ],
-                    messages=[glm.Message(author="0", content="Tell me a joke")],
-                ),
-                temperature=0.75,
-                candidate_count=1,
-            ),
-        )
-        self.assertEqual(
-            observed_response.candidates,
-            [{"author": "1", "content": "Why did the chicken cross the road?"}],
-        )
 
 
 if __name__ == "__main__":
diff --git a/tests/test_discuss_async.py b/tests/test_discuss_async.py
new file mode 100644
index 000000000..ac364f34c
--- /dev/null
+++ b/tests/test_discuss_async.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import unittest
+
+if sys.version_info < (3, 11):
+    import asynctest
+    from asynctest import mock as async_mock
+
+import google.ai.generativelanguage as glm
+
+from google.generativeai import discuss
+from absl.testing import absltest
+from absl.testing import parameterized
+
+bases = (parameterized.TestCase,)
+
+if sys.version_info < (3, 11):
+    bases = bases + (asynctest.TestCase,)
+
+
+@unittest.skipIf(
+    sys.version_info >= (3, 11), "asynctest is not supported on python 3.11+"
+)
+class AsyncTests(*bases):
+    if sys.version_info < (3, 11):
+
+        async def test_chat_async(self):
+            client = async_mock.MagicMock()
+
+            observed_request = None
+
+            async def fake_generate_message(
+                request: glm.GenerateMessageRequest,
+            ) -> glm.GenerateMessageResponse:
+                nonlocal observed_request
+                observed_request = request
+                return glm.GenerateMessageResponse(
+                    candidates=[
+                        glm.Message(
+                            author="1", content="Why did the chicken cross the road?"
+                        )
+                    ]
+                )
+
+            client.generate_message = fake_generate_message
+
+            observed_response = await discuss.chat_async(
+                model="models/bard",
+                context="Example Prompt",
+                examples=[["Example from human", "Example response from AI"]],
+                messages=["Tell me a joke"],
+                temperature=0.75,
+                candidate_count=1,
+                client=client,
+            )
+
+            self.assertEqual(
+                observed_request,
+                glm.GenerateMessageRequest(
+                    model="models/bard",
+                    prompt=glm.MessagePrompt(
+                        context="Example Prompt",
+                        examples=[
+                            glm.Example(
+                                input=glm.Message(content="Example from human"),
+                                output=glm.Message(content="Example response from AI"),
+                            )
+                        ],
+                        messages=[glm.Message(author="0", content="Tell me a joke")],
+                    ),
+                    temperature=0.75,
+                    candidate_count=1,
+                ),
+            )
+            self.assertEqual(
+                observed_response.candidates,
+                [{"author": "1", "content": "Why did the chicken cross the road?"}],
+            )
+
+
+if __name__ == "__main__":
+    absltest.main()
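The version-gated pattern above keeps the module importable on Python 3.11+, where `asynctest` fails to import. Distilled to a standalone sketch (class and test names are placeholders):

```python
import sys
import unittest

from absl.testing import parameterized

bases = (parameterized.TestCase,)
if sys.version_info < (3, 11):
    import asynctest  # only import where it still works

    bases = bases + (asynctest.TestCase,)


@unittest.skipIf(sys.version_info >= (3, 11), "asynctest requires Python < 3.11")
class SomeAsyncTests(*bases):
    # Guard the test bodies too, so the class body parses without
    # asynctest's machinery on newer interpreters.
    if sys.version_info < (3, 11):

        async def test_something(self):
            ...
```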
diff --git a/tests/test_text.py b/tests/test_text.py
index cf41d128b..63eccb8c6 100644
--- a/tests/test_text.py
+++ b/tests/test_text.py
@@ -17,13 +17,11 @@
 import unittest
 import unittest.mock as mock
 
-import asynctest
-from asynctest import mock as async_mock
-
 import google.ai.generativelanguage as glm
 
 from google.generativeai import text as text_service
 from google.generativeai import client
+from google.generativeai.types import safety_types
 
 from absl.testing import absltest
 from absl.testing import parameterized
@@ -36,17 +34,19 @@ def setUp(self):
 
         self.observed_request = None
 
+        self.mock_response = glm.GenerateTextResponse(
+            candidates=[
+                glm.TextCompletion(output=" road?"),
+                glm.TextCompletion(output=" bridge?"),
+                glm.TextCompletion(output=" river?"),
+            ]
+        )
+
         def fake_generate_completion(
             request: glm.GenerateTextRequest,
         ) -> glm.GenerateTextResponse:
             self.observed_request = request
-            return glm.GenerateTextResponse(
-                candidates=[
-                    glm.TextCompletion(output=" road?"),
-                    glm.TextCompletion(output=" bridge?"),
-                    glm.TextCompletion(output=" river?"),
-                ]
-            )
+            return self.mock_response
 
         self.client.generate_text = fake_generate_completion
 
@@ -78,7 +78,6 @@ def test_make_generate_text_request(self, prompt):
         self.assertEqual("models/chat-lamda-001", x.model)
         self.assertIsInstance(x, glm.GenerateTextRequest)
 
-    # @unittest.skipUnless(os.getenv('API_KEY'), "No API key set")
     @parameterized.named_parameters(
         [
             dict(
@@ -133,9 +132,9 @@ def test_generate_response(self, *, prompt, **kwargs):
         self.assertEqual(
             complete.candidates,
             [
-                {"output": " road?"},
-                {"output": " bridge?"},
-                {"output": " river?"},
+                {"output": " road?", "safety_ratings": []},
+                {"output": " bridge?", "safety_ratings": []},
+                {"output": " river?", "safety_ratings": []},
             ],
         )
 
@@ -150,6 +149,141 @@ def test_stop_string(self):
                 stop_sequences=["stop"],
             ),
         )
+        # Just make sure it made it into the request object.
+        self.assertEqual(self.observed_request.stop_sequences, ["stop"])
+
+    def test_safety_settings(self):
+        result = text_service.generate_text(
+            prompt="Say something wicked.",
+            safety_settings=[
+                {
+                    "category": safety_types.HarmCategory.HARM_CATEGORY_MEDICAL,
+                    "threshold": safety_types.HarmBlockThreshold.BLOCK_NONE,
+                },
+                {
+                    "category": safety_types.HarmCategory.HARM_CATEGORY_VIOLENCE,
+                    "threshold": safety_types.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
+                },
+            ],
+        )
+
+        self.assertEqual(
+            self.observed_request.safety_settings[0].category,
+            safety_types.HarmCategory.HARM_CATEGORY_MEDICAL,
+        )
+
+    def test_filters(self):
+        self.mock_response = glm.GenerateTextResponse(
+            candidates=[{"output": "hello"}],
+            filters=[
+                {"reason": safety_types.BlockedReason.SAFETY, "message": "not safe"}
+            ],
+        )
+
+        response = text_service.generate_text(prompt="do filters work?")
+        self.assertIsInstance(
+            response.filters[0]["reason"], safety_types.BlockedReason
+        )
+        self.assertEqual(
+            response.filters[0]["reason"], safety_types.BlockedReason.SAFETY
+        )
+
+    def test_safety_feedback(self):
+        self.mock_response = glm.GenerateTextResponse(
+            candidates=[{"output": "hello"}],
+            safety_feedback=[
+                {
+                    "rating": {
+                        "category": safety_types.HarmCategory.HARM_CATEGORY_MEDICAL,
+                        "probability": safety_types.HarmProbability.HIGH,
+                    },
+                    "setting": {
+                        "category": safety_types.HarmCategory.HARM_CATEGORY_MEDICAL,
+                        "threshold": safety_types.HarmBlockThreshold.BLOCK_NONE,
+                    },
+                }
+            ],
+        )
+
+        response = text_service.generate_text(prompt="does safety feedback work?")
+        self.assertIsInstance(
+            response.safety_feedback[0]["rating"]["probability"],
+            safety_types.HarmProbability,
+        )
+        self.assertEqual(
+            response.safety_feedback[0]["rating"]["probability"],
+            safety_types.HarmProbability.HIGH,
+        )
+
+        self.assertIsInstance(
+            response.safety_feedback[0]["setting"]["category"],
+            safety_types.HarmCategory,
+        )
+        self.assertEqual(
+            response.safety_feedback[0]["setting"]["category"],
+            safety_types.HarmCategory.HARM_CATEGORY_MEDICAL,
+        )
+
+    def test_candidate_safety_feedback(self):
+        self.mock_response = glm.GenerateTextResponse(
+            candidates=[
+                {
+                    "output": "hello",
+                    "safety_ratings": [
+                        {
+                            "category": safety_types.HarmCategory.HARM_CATEGORY_MEDICAL,
+                            "probability": safety_types.HarmProbability.HIGH,
+                        },
+                        {
+                            "category": safety_types.HarmCategory.HARM_CATEGORY_VIOLENCE,
+                            "probability": safety_types.HarmProbability.LOW,
+                        },
+                    ],
+                }
+            ]
+        )
+
+        result = text_service.generate_text(prompt="Write a story from the ER.")
+        self.assertIsInstance(
+            result.candidates[0]["safety_ratings"][0]["category"],
+            safety_types.HarmCategory,
+        )
+        self.assertEqual(
+            result.candidates[0]["safety_ratings"][0]["category"],
+            safety_types.HarmCategory.HARM_CATEGORY_MEDICAL,
+        )
+
+        self.assertIsInstance(
+            result.candidates[0]["safety_ratings"][0]["probability"],
+            safety_types.HarmProbability,
+        )
+        self.assertEqual(
+            result.candidates[0]["safety_ratings"][0]["probability"],
+            safety_types.HarmProbability.HIGH,
+        )
+
+    def test_candidate_citations(self):
+        self.mock_response = glm.GenerateTextResponse(
+            candidates=[
+                {
+                    "output": "Hello Google!",
+                    "citation_metadata": {
+                        "citation_sources": [
+                            {
+                                "start_index": 6,
+                                "end_index": 12,
+                                "uri": "https://google.com",
+                            }
+                        ]
+                    },
+                }
+            ]
+        )
+        result = text_service.generate_text(prompt="Hi my name is Google")
+        self.assertEqual(
+            result.candidates[0]["citation_metadata"]["citation_sources"][0][
+                "start_index"
+            ],
+            6,
+        )
 
 
 if __name__ == "__main__":
     absltest.main()