From 1b1d8833d4e25877256a21693d161d32ecd48d2e Mon Sep 17 00:00:00 2001 From: Ethan Wang Date: Fri, 17 May 2024 07:13:54 -0700 Subject: [PATCH 01/17] Restrict Harm category to the sublist only Gemini support (#295) * Restrict Harm category to the sublist only Gemini support * Update text.py * Update safety_types.py * Update safety_types.py * Update safety_types.py * split module Change-Id: Ia94b262d4e27511ca2e4eeb02cb5bd617a772463 * add palm safety Change-Id: Ia1cb199148619ebbc26638d5983b435245904971 * switch imports Change-Id: I2853a88d7acc51a78174c97e30bde8eb24e1d457 --------- Co-authored-by: Mark Daoust --- google/generativeai/answer.py | 4 +- google/generativeai/discuss.py | 4 +- google/generativeai/generative_models.py | 8 +- google/generativeai/text.py | 16 +- google/generativeai/types/discuss_types.py | 4 +- .../generativeai/types/palm_safety_types.py | 286 ++++++++++++++++++ google/generativeai/types/safety_types.py | 179 +++++------ google/generativeai/types/text_types.py | 8 +- tests/test_discuss.py | 16 +- tests/test_text.py | 54 ++-- 10 files changed, 418 insertions(+), 161 deletions(-) create mode 100644 google/generativeai/types/palm_safety_types.py diff --git a/google/generativeai/answer.py b/google/generativeai/answer.py index f17a82a17..637002052 100644 --- a/google/generativeai/answer.py +++ b/google/generativeai/answer.py @@ -206,9 +206,7 @@ def _make_generate_answer_request( contents = content_types.to_contents(contents) if safety_settings: - safety_settings = safety_types.normalize_safety_settings( - safety_settings, harm_category_set="new" - ) + safety_settings = safety_types.normalize_safety_settings(safety_settings) if inline_passages is not None and semantic_retriever is not None: raise ValueError( diff --git a/google/generativeai/discuss.py b/google/generativeai/discuss.py index 0cc342096..81e087aa0 100644 --- a/google/generativeai/discuss.py +++ b/google/generativeai/discuss.py @@ -27,7 +27,7 @@ from google.generativeai import string_utils from google.generativeai.types import discuss_types from google.generativeai.types import model_types -from google.generativeai.types import safety_types +from google.generativeai.types import palm_safety_types def _make_message(content: discuss_types.MessageOptions) -> glm.Message: @@ -521,7 +521,7 @@ def _build_chat_response( response = type(response).to_dict(response) response.pop("messages") - response["filters"] = safety_types.convert_filters_to_enums(response["filters"]) + response["filters"] = palm_safety_types.convert_filters_to_enums(response["filters"]) if response["candidates"]: last = response["candidates"][0] diff --git a/google/generativeai/generative_models.py b/google/generativeai/generative_models.py index a0e7df1e2..4d71baf48 100644 --- a/google/generativeai/generative_models.py +++ b/google/generativeai/generative_models.py @@ -79,9 +79,7 @@ def __init__( if "/" not in model_name: model_name = "models/" + model_name self._model_name = model_name - self._safety_settings = safety_types.to_easy_safety_dict( - safety_settings, harm_category_set="new" - ) + self._safety_settings = safety_types.to_easy_safety_dict(safety_settings) self._generation_config = generation_types.to_generation_config_dict(generation_config) self._tools = content_types.to_function_library(tools) @@ -149,10 +147,10 @@ def _prepare_request( merged_gc = self._generation_config.copy() merged_gc.update(generation_config) - safety_settings = safety_types.to_easy_safety_dict(safety_settings, harm_category_set="new") + safety_settings = safety_types.to_easy_safety_dict(safety_settings) merged_ss = self._safety_settings.copy() merged_ss.update(safety_settings) - merged_ss = safety_types.normalize_safety_settings(merged_ss, harm_category_set="new") + merged_ss = safety_types.normalize_safety_settings(merged_ss) return glm.GenerateContentRequest( model=self._model_name, diff --git a/google/generativeai/text.py b/google/generativeai/text.py index 3a147f945..e51090e1f 100644 --- a/google/generativeai/text.py +++ b/google/generativeai/text.py @@ -26,7 +26,7 @@ from google.generativeai.types import text_types from google.generativeai.types import model_types from google.generativeai import models -from google.generativeai.types import safety_types +from google.generativeai.types import palm_safety_types DEFAULT_TEXT_MODEL = "models/text-bison-001" EMBEDDING_MAX_BATCH_SIZE = 100 @@ -81,7 +81,7 @@ def _make_generate_text_request( max_output_tokens: int | None = None, top_p: int | None = None, top_k: int | None = None, - safety_settings: safety_types.SafetySettingOptions | None = None, + safety_settings: palm_safety_types.SafetySettingOptions | None = None, stop_sequences: str | Iterable[str] | None = None, ) -> glm.GenerateTextRequest: """ @@ -108,9 +108,7 @@ def _make_generate_text_request( """ model = model_types.make_model_name(model) prompt = _make_text_prompt(prompt=prompt) - safety_settings = safety_types.normalize_safety_settings( - safety_settings, harm_category_set="old" - ) + safety_settings = palm_safety_types.normalize_safety_settings(safety_settings) if isinstance(stop_sequences, str): stop_sequences = [stop_sequences] if stop_sequences: @@ -138,7 +136,7 @@ def generate_text( max_output_tokens: int | None = None, top_p: float | None = None, top_k: float | None = None, - safety_settings: safety_types.SafetySettingOptions | None = None, + safety_settings: palm_safety_types.SafetySettingOptions | None = None, stop_sequences: str | Iterable[str] | None = None, client: glm.TextServiceClient | None = None, request_options: dict[str, Any] | None = None, @@ -240,11 +238,11 @@ def _generate_response( response = client.generate_text(request, **request_options) response = type(response).to_dict(response) - response["filters"] = safety_types.convert_filters_to_enums(response["filters"]) - response["safety_feedback"] = safety_types.convert_safety_feedback_to_enums( + response["filters"] = palm_safety_types.convert_filters_to_enums(response["filters"]) + response["safety_feedback"] = palm_safety_types.convert_safety_feedback_to_enums( response["safety_feedback"] ) - response["candidates"] = safety_types.convert_candidate_enums(response["candidates"]) + response["candidates"] = palm_safety_types.convert_candidate_enums(response["candidates"]) return Completion(_client=client, **response) diff --git a/google/generativeai/types/discuss_types.py b/google/generativeai/types/discuss_types.py index 0cb393e5c..fa777d1d1 100644 --- a/google/generativeai/types/discuss_types.py +++ b/google/generativeai/types/discuss_types.py @@ -22,7 +22,7 @@ import google.ai.generativelanguage as glm from google.generativeai import string_utils -from google.generativeai.types import safety_types +from google.generativeai.types import palm_safety_types from google.generativeai.types import citation_types @@ -169,7 +169,7 @@ class ChatResponse(abc.ABC): temperature: Optional[float] candidate_count: Optional[int] candidates: List[MessageDict] - filters: List[safety_types.ContentFilterDict] + filters: List[palm_safety_types.ContentFilterDict] top_p: Optional[float] = None top_k: Optional[float] = None diff --git a/google/generativeai/types/palm_safety_types.py b/google/generativeai/types/palm_safety_types.py new file mode 100644 index 000000000..9fb88cd67 --- /dev/null +++ b/google/generativeai/types/palm_safety_types.py @@ -0,0 +1,286 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from collections.abc import Mapping + +import enum +import typing +from typing import Dict, Iterable, List, Union + +from typing_extensions import TypedDict + + +from google.ai import generativelanguage as glm +from google.generativeai import string_utils + + +__all__ = [ + "HarmCategory", + "HarmProbability", + "HarmBlockThreshold", + "BlockedReason", + "ContentFilterDict", + "SafetyRatingDict", + "SafetySettingDict", + "SafetyFeedbackDict", +] + +# These are basic python enums, it's okay to expose them +HarmProbability = glm.SafetyRating.HarmProbability +HarmBlockThreshold = glm.SafetySetting.HarmBlockThreshold +BlockedReason = glm.ContentFilter.BlockedReason + + +class HarmCategory: + """ + Harm Categories supported by the palm-family models + """ + + HARM_CATEGORY_UNSPECIFIED = glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED.value + HARM_CATEGORY_DEROGATORY = glm.HarmCategory.HARM_CATEGORY_DEROGATORY.value + HARM_CATEGORY_TOXICITY = glm.HarmCategory.HARM_CATEGORY_TOXICITY.value + HARM_CATEGORY_VIOLENCE = glm.HarmCategory.HARM_CATEGORY_VIOLENCE.value + HARM_CATEGORY_SEXUAL = glm.HarmCategory.HARM_CATEGORY_SEXUAL.value + HARM_CATEGORY_MEDICAL = glm.HarmCategory.HARM_CATEGORY_MEDICAL.value + HARM_CATEGORY_DANGEROUS = glm.HarmCategory.HARM_CATEGORY_DANGEROUS.value + + +HarmCategoryOptions = Union[str, int, HarmCategory] + +# fmt: off +_HARM_CATEGORIES: Dict[HarmCategoryOptions, glm.HarmCategory] = { + glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED: glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + HarmCategory.HARM_CATEGORY_UNSPECIFIED: glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + 0: glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + "harm_category_unspecified": glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + "unspecified": glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + + glm.HarmCategory.HARM_CATEGORY_DEROGATORY: glm.HarmCategory.HARM_CATEGORY_DEROGATORY, + HarmCategory.HARM_CATEGORY_DEROGATORY: glm.HarmCategory.HARM_CATEGORY_DEROGATORY, + 1: glm.HarmCategory.HARM_CATEGORY_DEROGATORY, + "harm_category_derogatory": glm.HarmCategory.HARM_CATEGORY_DEROGATORY, + "derogatory": glm.HarmCategory.HARM_CATEGORY_DEROGATORY, + + glm.HarmCategory.HARM_CATEGORY_TOXICITY: glm.HarmCategory.HARM_CATEGORY_TOXICITY, + HarmCategory.HARM_CATEGORY_TOXICITY: glm.HarmCategory.HARM_CATEGORY_TOXICITY, + 2: glm.HarmCategory.HARM_CATEGORY_TOXICITY, + "harm_category_toxicity": glm.HarmCategory.HARM_CATEGORY_TOXICITY, + "toxicity": glm.HarmCategory.HARM_CATEGORY_TOXICITY, + "toxic": glm.HarmCategory.HARM_CATEGORY_TOXICITY, + + glm.HarmCategory.HARM_CATEGORY_VIOLENCE: glm.HarmCategory.HARM_CATEGORY_VIOLENCE, + HarmCategory.HARM_CATEGORY_VIOLENCE: glm.HarmCategory.HARM_CATEGORY_VIOLENCE, + 3: glm.HarmCategory.HARM_CATEGORY_VIOLENCE, + "harm_category_violence": glm.HarmCategory.HARM_CATEGORY_VIOLENCE, + "violence": glm.HarmCategory.HARM_CATEGORY_VIOLENCE, + "violent": glm.HarmCategory.HARM_CATEGORY_VIOLENCE, + + glm.HarmCategory.HARM_CATEGORY_SEXUAL: glm.HarmCategory.HARM_CATEGORY_SEXUAL, + HarmCategory.HARM_CATEGORY_SEXUAL: glm.HarmCategory.HARM_CATEGORY_SEXUAL, + 4: glm.HarmCategory.HARM_CATEGORY_SEXUAL, + "harm_category_sexual": glm.HarmCategory.HARM_CATEGORY_SEXUAL, + "sexual": glm.HarmCategory.HARM_CATEGORY_SEXUAL, + "sex": glm.HarmCategory.HARM_CATEGORY_SEXUAL, + + glm.HarmCategory.HARM_CATEGORY_MEDICAL: glm.HarmCategory.HARM_CATEGORY_MEDICAL, + HarmCategory.HARM_CATEGORY_MEDICAL: glm.HarmCategory.HARM_CATEGORY_MEDICAL, + 5: glm.HarmCategory.HARM_CATEGORY_MEDICAL, + "harm_category_medical": glm.HarmCategory.HARM_CATEGORY_MEDICAL, + "medical": glm.HarmCategory.HARM_CATEGORY_MEDICAL, + "med": glm.HarmCategory.HARM_CATEGORY_MEDICAL, + + glm.HarmCategory.HARM_CATEGORY_DANGEROUS: glm.HarmCategory.HARM_CATEGORY_DANGEROUS, + HarmCategory.HARM_CATEGORY_DANGEROUS: glm.HarmCategory.HARM_CATEGORY_DANGEROUS, + 6: glm.HarmCategory.HARM_CATEGORY_DANGEROUS, + "harm_category_dangerous": glm.HarmCategory.HARM_CATEGORY_DANGEROUS, + "dangerous": glm.HarmCategory.HARM_CATEGORY_DANGEROUS, + "danger": glm.HarmCategory.HARM_CATEGORY_DANGEROUS, +} +# fmt: on + + +def to_harm_category(x: HarmCategoryOptions) -> glm.HarmCategory: + if isinstance(x, str): + x = x.lower() + return _HARM_CATEGORIES[x] + + +HarmBlockThresholdOptions = Union[str, int, HarmBlockThreshold] + +# fmt: off +_BLOCK_THRESHOLDS: Dict[HarmBlockThresholdOptions, HarmBlockThreshold] = { + HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED: HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED, + 0: HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED, + "harm_block_threshold_unspecified": HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED, + "block_threshold_unspecified": HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED, + "unspecified": HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED, + + HarmBlockThreshold.BLOCK_LOW_AND_ABOVE: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + 1: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + "block_low_and_above": HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + "low": HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + + HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + 2: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + "block_medium_and_above": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + "medium": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + "med": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + + HarmBlockThreshold.BLOCK_ONLY_HIGH: HarmBlockThreshold.BLOCK_ONLY_HIGH, + 3: HarmBlockThreshold.BLOCK_ONLY_HIGH, + "block_only_high": HarmBlockThreshold.BLOCK_ONLY_HIGH, + "high": HarmBlockThreshold.BLOCK_ONLY_HIGH, + + HarmBlockThreshold.BLOCK_NONE: HarmBlockThreshold.BLOCK_NONE, + 4: HarmBlockThreshold.BLOCK_NONE, + "block_none": HarmBlockThreshold.BLOCK_NONE, +} +# fmt: on + + +def to_block_threshold(x: HarmBlockThresholdOptions) -> HarmBlockThreshold: + if isinstance(x, str): + x = x.lower() + return _BLOCK_THRESHOLDS[x] + + +class ContentFilterDict(TypedDict): + reason: BlockedReason + message: str + + __doc__ = string_utils.strip_oneof(glm.ContentFilter.__doc__) + + +def convert_filters_to_enums( + filters: Iterable[dict], +) -> List[ContentFilterDict]: + result = [] + for f in filters: + f = f.copy() + f["reason"] = BlockedReason(f["reason"]) + f = typing.cast(ContentFilterDict, f) + result.append(f) + return result + + +class SafetyRatingDict(TypedDict): + category: glm.HarmCategory + probability: HarmProbability + + __doc__ = string_utils.strip_oneof(glm.SafetyRating.__doc__) + + +def convert_rating_to_enum(rating: dict) -> SafetyRatingDict: + return { + "category": glm.HarmCategory(rating["category"]), + "probability": HarmProbability(rating["probability"]), + } + + +def convert_ratings_to_enum(ratings: Iterable[dict]) -> List[SafetyRatingDict]: + result = [] + for r in ratings: + result.append(convert_rating_to_enum(r)) + return result + + +class SafetySettingDict(TypedDict): + category: glm.HarmCategory + threshold: HarmBlockThreshold + + __doc__ = string_utils.strip_oneof(glm.SafetySetting.__doc__) + + +class LooseSafetySettingDict(TypedDict): + category: HarmCategoryOptions + threshold: HarmBlockThresholdOptions + + +EasySafetySetting = Mapping[HarmCategoryOptions, HarmBlockThresholdOptions] +EasySafetySettingDict = dict[HarmCategoryOptions, HarmBlockThresholdOptions] + +SafetySettingOptions = Union[EasySafetySetting, Iterable[LooseSafetySettingDict], None] + + +def to_easy_safety_dict(settings: SafetySettingOptions) -> EasySafetySettingDict: + if settings is None: + return {} + elif isinstance(settings, Mapping): + return {to_harm_category(key): to_block_threshold(value) for key, value in settings.items()} + else: # Iterable + return { + to_harm_category(d["category"]): to_block_threshold(d["threshold"]) for d in settings + } + + +def normalize_safety_settings( + settings: SafetySettingOptions, +) -> list[SafetySettingDict] | None: + if settings is None: + return None + if isinstance(settings, Mapping): + return [ + { + "category": to_harm_category(key), + "threshold": to_block_threshold(value), + } + for key, value in settings.items() + ] + else: + return [ + { + "category": to_harm_category(d["category"]), + "threshold": to_block_threshold(d["threshold"]), + } + for d in settings + ] + + +def convert_setting_to_enum(setting: dict) -> SafetySettingDict: + return { + "category": glm.HarmCategory(setting["category"]), + "threshold": HarmBlockThreshold(setting["threshold"]), + } + + +class SafetyFeedbackDict(TypedDict): + rating: SafetyRatingDict + setting: SafetySettingDict + + __doc__ = string_utils.strip_oneof(glm.SafetyFeedback.__doc__) + + +def convert_safety_feedback_to_enums( + safety_feedback: Iterable[dict], +) -> List[SafetyFeedbackDict]: + result = [] + for sf in safety_feedback: + result.append( + { + "rating": convert_rating_to_enum(sf["rating"]), + "setting": convert_setting_to_enum(sf["setting"]), + } + ) + return result + + +def convert_candidate_enums(candidates): + result = [] + for candidate in candidates: + candidate = candidate.copy() + candidate["safety_ratings"] = convert_ratings_to_enum(candidate["safety_ratings"]) + result.append(candidate) + return result diff --git a/google/generativeai/types/safety_types.py b/google/generativeai/types/safety_types.py index 7d94a5bb0..85e57c8f6 100644 --- a/google/generativeai/types/safety_types.py +++ b/google/generativeai/types/safety_types.py @@ -1,7 +1,22 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from __future__ import annotations from collections.abc import Mapping +import enum import typing from typing import Dict, Iterable, List, Union @@ -24,105 +39,72 @@ ] # These are basic python enums, it's okay to expose them -HarmCategory = glm.HarmCategory HarmProbability = glm.SafetyRating.HarmProbability HarmBlockThreshold = glm.SafetySetting.HarmBlockThreshold BlockedReason = glm.ContentFilter.BlockedReason +import proto + + +class HarmCategory(proto.Enum): + """ + Harm Categories supported by the gemini-family model + """ + + HARM_CATEGORY_UNSPECIFIED = glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED.value + HARM_CATEGORY_HARASSMENT = glm.HarmCategory.HARM_CATEGORY_HARASSMENT.value + HARM_CATEGORY_HATE_SPEECH = glm.HarmCategory.HARM_CATEGORY_HATE_SPEECH.value + HARM_CATEGORY_SEXUALLY_EXPLICIT = glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT.value + HARM_CATEGORY_DANGEROUS_CONTENT = glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT.value + + HarmCategoryOptions = Union[str, int, HarmCategory] # fmt: off -_OLD_HARM_CATEGORIES: Dict[HarmCategoryOptions, HarmCategory] = { - HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmCategory.HARM_CATEGORY_UNSPECIFIED, - 0: HarmCategory.HARM_CATEGORY_UNSPECIFIED, - "harm_category_unspecified": HarmCategory.HARM_CATEGORY_UNSPECIFIED, - "unspecified": HarmCategory.HARM_CATEGORY_UNSPECIFIED, - - HarmCategory.HARM_CATEGORY_DEROGATORY: HarmCategory.HARM_CATEGORY_DEROGATORY, - 1: HarmCategory.HARM_CATEGORY_DEROGATORY, - "harm_category_derogatory": HarmCategory.HARM_CATEGORY_DEROGATORY, - "derogatory": HarmCategory.HARM_CATEGORY_DEROGATORY, - - HarmCategory.HARM_CATEGORY_TOXICITY: HarmCategory.HARM_CATEGORY_TOXICITY, - 2: HarmCategory.HARM_CATEGORY_TOXICITY, - "harm_category_toxicity": HarmCategory.HARM_CATEGORY_TOXICITY, - "toxicity": HarmCategory.HARM_CATEGORY_TOXICITY, - "toxic": HarmCategory.HARM_CATEGORY_TOXICITY, - - HarmCategory.HARM_CATEGORY_VIOLENCE: HarmCategory.HARM_CATEGORY_VIOLENCE, - 3: HarmCategory.HARM_CATEGORY_VIOLENCE, - "harm_category_violence": HarmCategory.HARM_CATEGORY_VIOLENCE, - "violence": HarmCategory.HARM_CATEGORY_VIOLENCE, - "violent": HarmCategory.HARM_CATEGORY_VIOLENCE, - - HarmCategory.HARM_CATEGORY_SEXUAL: HarmCategory.HARM_CATEGORY_SEXUAL, - 4: HarmCategory.HARM_CATEGORY_SEXUAL, - "harm_category_sexual": HarmCategory.HARM_CATEGORY_SEXUAL, - "sexual": HarmCategory.HARM_CATEGORY_SEXUAL, - "sex": HarmCategory.HARM_CATEGORY_SEXUAL, - - HarmCategory.HARM_CATEGORY_MEDICAL: HarmCategory.HARM_CATEGORY_MEDICAL, - 5: HarmCategory.HARM_CATEGORY_MEDICAL, - "harm_category_medical": HarmCategory.HARM_CATEGORY_MEDICAL, - "medical": HarmCategory.HARM_CATEGORY_MEDICAL, - "med": HarmCategory.HARM_CATEGORY_MEDICAL, - - HarmCategory.HARM_CATEGORY_DANGEROUS: HarmCategory.HARM_CATEGORY_DANGEROUS, - 6: HarmCategory.HARM_CATEGORY_DANGEROUS, - "harm_category_dangerous": HarmCategory.HARM_CATEGORY_DANGEROUS, - "dangerous": HarmCategory.HARM_CATEGORY_DANGEROUS, - "danger": HarmCategory.HARM_CATEGORY_DANGEROUS, -} - -_NEW_HARM_CATEGORIES = { - 7: HarmCategory.HARM_CATEGORY_HARASSMENT, - HarmCategory.HARM_CATEGORY_HARASSMENT: HarmCategory.HARM_CATEGORY_HARASSMENT, - "harm_category_harassment": HarmCategory.HARM_CATEGORY_HARASSMENT, - "harassment": HarmCategory.HARM_CATEGORY_HARASSMENT, - - 8: HarmCategory.HARM_CATEGORY_HATE_SPEECH, - HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmCategory.HARM_CATEGORY_HATE_SPEECH, - 'harm_category_hate_speech': HarmCategory.HARM_CATEGORY_HATE_SPEECH, - 'hate_speech': HarmCategory.HARM_CATEGORY_HATE_SPEECH, - 'hate': HarmCategory.HARM_CATEGORY_HATE_SPEECH, - - 9: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - "harm_category_sexually_explicit": HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - "harm_category_sexual": HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - "sexually_explicit": HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - "sexual": HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - "sex": HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - - 10: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - "harm_category_dangerous_content": HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - "harm_category_dangerous": HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - "dangerous": HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - "danger": HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, +_HARM_CATEGORIES: Dict[HarmCategoryOptions, glm.HarmCategory] = { + glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED: glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + HarmCategory.HARM_CATEGORY_UNSPECIFIED: glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + 0: glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + "harm_category_unspecified": glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + "unspecified": glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + + 7: glm.HarmCategory.HARM_CATEGORY_HARASSMENT, + glm.HarmCategory.HARM_CATEGORY_HARASSMENT: glm.HarmCategory.HARM_CATEGORY_HARASSMENT, + HarmCategory.HARM_CATEGORY_HARASSMENT: glm.HarmCategory.HARM_CATEGORY_HARASSMENT, + "harm_category_harassment": glm.HarmCategory.HARM_CATEGORY_HARASSMENT, + "harassment": glm.HarmCategory.HARM_CATEGORY_HARASSMENT, + + 8: glm.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + glm.HarmCategory.HARM_CATEGORY_HATE_SPEECH: glm.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + HarmCategory.HARM_CATEGORY_HATE_SPEECH: glm.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + 'harm_category_hate_speech': glm.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + 'hate_speech': glm.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + 'hate': glm.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + + 9: glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + "harm_category_sexually_explicit": glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + "harm_category_sexual": glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + "sexually_explicit": glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + "sexual": glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + "sex": glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + + 10: glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + "harm_category_dangerous_content": glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + "harm_category_dangerous": glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + "dangerous": glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + "danger": glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, } # fmt: on -def to_old_harm_category(x: HarmCategoryOptions) -> HarmCategory: - if isinstance(x, str): - x = x.lower() - return _OLD_HARM_CATEGORIES[x] - - -def to_new_harm_category(x: HarmCategoryOptions) -> HarmCategory: +def to_harm_category(x: HarmCategoryOptions) -> glm.HarmCategory: if isinstance(x, str): x = x.lower() - return _NEW_HARM_CATEGORIES[x] - - -def to_harm_category(x, harm_category_set): - if harm_category_set == "old": - return to_old_harm_category(x) - elif harm_category_set == "new": - return to_new_harm_category(x) - else: - raise ValueError("harm_category_set must be 'new' or 'old'") + return _HARM_CATEGORIES[x] HarmBlockThresholdOptions = Union[str, int, HarmBlockThreshold] @@ -158,7 +140,7 @@ def to_harm_category(x, harm_category_set): # fmt: on -def to_block_threshold(x: HarmBlockThresholdOptions) -> HarmCategory: +def to_block_threshold(x: HarmBlockThresholdOptions) -> HarmBlockThreshold: if isinstance(x, str): x = x.lower() return _BLOCK_THRESHOLDS[x] @@ -184,7 +166,7 @@ def convert_filters_to_enums( class SafetyRatingDict(TypedDict): - category: HarmCategory + category: glm.HarmCategory probability: HarmProbability __doc__ = string_utils.strip_oneof(glm.SafetyRating.__doc__) @@ -192,7 +174,7 @@ class SafetyRatingDict(TypedDict): def convert_rating_to_enum(rating: dict) -> SafetyRatingDict: return { - "category": HarmCategory(rating["category"]), + "category": glm.HarmCategory(rating["category"]), "probability": HarmProbability(rating["probability"]), } @@ -205,7 +187,7 @@ def convert_ratings_to_enum(ratings: Iterable[dict]) -> List[SafetyRatingDict]: class SafetySettingDict(TypedDict): - category: HarmCategory + category: glm.HarmCategory threshold: HarmBlockThreshold __doc__ = string_utils.strip_oneof(glm.SafetySetting.__doc__) @@ -222,31 +204,26 @@ class LooseSafetySettingDict(TypedDict): SafetySettingOptions = Union[EasySafetySetting, Iterable[LooseSafetySettingDict], None] -def to_easy_safety_dict(settings: SafetySettingOptions, harm_category_set) -> EasySafetySettingDict: +def to_easy_safety_dict(settings: SafetySettingOptions) -> EasySafetySettingDict: if settings is None: return {} elif isinstance(settings, Mapping): - return { - to_harm_category(key, harm_category_set): to_block_threshold(value) - for key, value in settings.items() - } + return {to_harm_category(key): to_block_threshold(value) for key, value in settings.items()} else: # Iterable return { - to_harm_category(d["category"], harm_category_set): to_block_threshold(d["threshold"]) - for d in settings + to_harm_category(d["category"]): to_block_threshold(d["threshold"]) for d in settings } def normalize_safety_settings( settings: SafetySettingOptions, - harm_category_set, ) -> list[SafetySettingDict] | None: if settings is None: return None if isinstance(settings, Mapping): return [ { - "category": to_harm_category(key, harm_category_set), + "category": to_harm_category(key), "threshold": to_block_threshold(value), } for key, value in settings.items() @@ -254,7 +231,7 @@ def normalize_safety_settings( else: return [ { - "category": to_harm_category(d["category"], harm_category_set), + "category": to_harm_category(d["category"]), "threshold": to_block_threshold(d["threshold"]), } for d in settings @@ -263,7 +240,7 @@ def normalize_safety_settings( def convert_setting_to_enum(setting: dict) -> SafetySettingDict: return { - "category": HarmCategory(setting["category"]), + "category": glm.HarmCategory(setting["category"]), "threshold": HarmBlockThreshold(setting["threshold"]), } diff --git a/google/generativeai/types/text_types.py b/google/generativeai/types/text_types.py index f66c0fb32..61804fcaa 100644 --- a/google/generativeai/types/text_types.py +++ b/google/generativeai/types/text_types.py @@ -21,7 +21,7 @@ from typing_extensions import TypedDict from google.generativeai import string_utils -from google.generativeai.types import safety_types +from google.generativeai.types import palm_safety_types from google.generativeai.types import citation_types @@ -42,7 +42,7 @@ class BatchEmbeddingDict(TypedDict): class TextCompletion(TypedDict, total=False): output: str - safety_ratings: List[safety_types.SafetyRatingDict | None] + safety_ratings: List[palm_safety_types.SafetyRatingDict | None] citation_metadata: citation_types.CitationMetadataDict | None @@ -63,8 +63,8 @@ class Completion(abc.ABC): candidates: List[TextCompletion] result: str | None - filters: List[safety_types.ContentFilterDict | None] - safety_feedback: List[safety_types.SafetyFeedbackDict | None] + filters: List[palm_safety_types.ContentFilterDict | None] + safety_feedback: List[palm_safety_types.SafetyFeedbackDict | None] def to_dict(self) -> Dict[str, Any]: result = { diff --git a/tests/test_discuss.py b/tests/test_discuss.py index 9d628a42c..183ccd0c3 100644 --- a/tests/test_discuss.py +++ b/tests/test_discuss.py @@ -22,7 +22,7 @@ from google.generativeai import discuss from google.generativeai import client import google.generativeai as genai -from google.generativeai.types import safety_types +from google.generativeai.types import palm_safety_types from absl.testing import absltest from absl.testing import parameterized @@ -289,32 +289,32 @@ def test_receive_and_reply_with_filters(self): self.mock_response = mock_response = glm.GenerateMessageResponse( candidates=[glm.Message(content="a", author="1")], filters=[ - glm.ContentFilter(reason=safety_types.BlockedReason.SAFETY, message="unsafe"), - glm.ContentFilter(reason=safety_types.BlockedReason.OTHER), + glm.ContentFilter(reason=palm_safety_types.BlockedReason.SAFETY, message="unsafe"), + glm.ContentFilter(reason=palm_safety_types.BlockedReason.OTHER), ], ) response = discuss.chat(messages="do filters work?") filters = response.filters self.assertLen(filters, 2) - self.assertIsInstance(filters[0]["reason"], safety_types.BlockedReason) - self.assertEqual(filters[0]["reason"], safety_types.BlockedReason.SAFETY) + self.assertIsInstance(filters[0]["reason"], palm_safety_types.BlockedReason) + self.assertEqual(filters[0]["reason"], palm_safety_types.BlockedReason.SAFETY) self.assertEqual(filters[0]["message"], "unsafe") self.mock_response = glm.GenerateMessageResponse( candidates=[glm.Message(content="a", author="1")], filters=[ - glm.ContentFilter(reason=safety_types.BlockedReason.BLOCKED_REASON_UNSPECIFIED) + glm.ContentFilter(reason=palm_safety_types.BlockedReason.BLOCKED_REASON_UNSPECIFIED) ], ) response = response.reply("Does reply work?") filters = response.filters self.assertLen(filters, 1) - self.assertIsInstance(filters[0]["reason"], safety_types.BlockedReason) + self.assertIsInstance(filters[0]["reason"], palm_safety_types.BlockedReason) self.assertEqual( filters[0]["reason"], - safety_types.BlockedReason.BLOCKED_REASON_UNSPECIFIED, + palm_safety_types.BlockedReason.BLOCKED_REASON_UNSPECIFIED, ) def test_chat_citations(self): diff --git a/tests/test_text.py b/tests/test_text.py index 0bc1d4e59..5dcda93b9 100644 --- a/tests/test_text.py +++ b/tests/test_text.py @@ -22,7 +22,7 @@ from google.generativeai import text as text_service from google.generativeai import client -from google.generativeai.types import safety_types +from google.generativeai.types import palm_safety_types from google.generativeai.types import model_types from absl.testing import absltest from absl.testing import parameterized @@ -246,12 +246,12 @@ def test_stop_string(self): testcase_name="basic", safety_settings=[ { - "category": safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, - "threshold": safety_types.HarmBlockThreshold.BLOCK_NONE, + "category": palm_safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, + "threshold": palm_safety_types.HarmBlockThreshold.BLOCK_NONE, }, { - "category": safety_types.HarmCategory.HARM_CATEGORY_VIOLENCE, - "threshold": safety_types.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + "category": palm_safety_types.HarmCategory.HARM_CATEGORY_VIOLENCE, + "threshold": palm_safety_types.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, }, ], ), @@ -275,8 +275,8 @@ def test_stop_string(self): dict( testcase_name="mixed", safety_settings={ - "medical": safety_types.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, - safety_types.HarmCategory.HARM_CATEGORY_VIOLENCE: 1, + "medical": palm_safety_types.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + palm_safety_types.HarmCategory.HARM_CATEGORY_VIOLENCE: 1, }, ), ] @@ -294,7 +294,7 @@ def test_safety_settings(self, safety_settings): self.assertEqual( self.observed_requests[-1].safety_settings[0].category, - safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, + palm_safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, ) def test_filters(self): @@ -302,15 +302,15 @@ def test_filters(self): candidates=[{"output": "hello"}], filters=[ { - "reason": safety_types.BlockedReason.SAFETY, + "reason": palm_safety_types.BlockedReason.SAFETY, "message": "not safe", } ], ) response = text_service.generate_text(prompt="do filters work?") - self.assertIsInstance(response.filters[0]["reason"], safety_types.BlockedReason) - self.assertEqual(response.filters[0]["reason"], safety_types.BlockedReason.SAFETY) + self.assertIsInstance(response.filters[0]["reason"], palm_safety_types.BlockedReason) + self.assertEqual(response.filters[0]["reason"], palm_safety_types.BlockedReason.SAFETY) def test_safety_feedback(self): self.responses["generate_text"] = glm.GenerateTextResponse( @@ -318,12 +318,12 @@ def test_safety_feedback(self): safety_feedback=[ { "rating": { - "category": safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, - "probability": safety_types.HarmProbability.HIGH, + "category": palm_safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, + "probability": palm_safety_types.HarmProbability.HIGH, }, "setting": { - "category": safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, - "threshold": safety_types.HarmBlockThreshold.BLOCK_NONE, + "category": palm_safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, + "threshold": palm_safety_types.HarmBlockThreshold.BLOCK_NONE, }, } ], @@ -332,20 +332,20 @@ def test_safety_feedback(self): response = text_service.generate_text(prompt="does safety feedback work?") self.assertIsInstance( response.safety_feedback[0]["rating"]["probability"], - safety_types.HarmProbability, + palm_safety_types.HarmProbability, ) self.assertEqual( response.safety_feedback[0]["rating"]["probability"], - safety_types.HarmProbability.HIGH, + palm_safety_types.HarmProbability.HIGH, ) self.assertIsInstance( response.safety_feedback[0]["setting"]["category"], - safety_types.HarmCategory, + glm.HarmCategory, ) self.assertEqual( response.safety_feedback[0]["setting"]["category"], - safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, + palm_safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, ) def test_candidate_safety_feedback(self): @@ -355,12 +355,12 @@ def test_candidate_safety_feedback(self): "output": "hello", "safety_ratings": [ { - "category": safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, - "probability": safety_types.HarmProbability.HIGH, + "category": palm_safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, + "probability": palm_safety_types.HarmProbability.HIGH, }, { - "category": safety_types.HarmCategory.HARM_CATEGORY_VIOLENCE, - "probability": safety_types.HarmProbability.LOW, + "category": palm_safety_types.HarmCategory.HARM_CATEGORY_VIOLENCE, + "probability": palm_safety_types.HarmProbability.LOW, }, ], } @@ -370,20 +370,20 @@ def test_candidate_safety_feedback(self): result = text_service.generate_text(prompt="Write a story from the ER.") self.assertIsInstance( result.candidates[0]["safety_ratings"][0]["category"], - safety_types.HarmCategory, + glm.HarmCategory, ) self.assertEqual( result.candidates[0]["safety_ratings"][0]["category"], - safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, + palm_safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, ) self.assertIsInstance( result.candidates[0]["safety_ratings"][0]["probability"], - safety_types.HarmProbability, + palm_safety_types.HarmProbability, ) self.assertEqual( result.candidates[0]["safety_ratings"][0]["probability"], - safety_types.HarmProbability.HIGH, + palm_safety_types.HarmProbability.HIGH, ) def test_candidate_citations(self): From 51d806d7c9f121696ef10f85b5101b1b9d7a8091 Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Fri, 17 May 2024 21:52:46 +0300 Subject: [PATCH 02/17] Fix bugs, improve code clarity, and enhance overall reliability across several files. (#339) * Fix and improve * Fix `_make_grounding_passages` , `_make_generate_answer_request` * fix get_default_permission_client and get_default_permission_async_client * Add how to test all in CONTRIBUTING.md * fix back support for `tunedModels/` in `get_model` function * Add pytest to CONTRIBUTING.md * Break down test_generate_text for better debugging. * Add pip install nose2 to CONTRIBUTING.md * Format Change-Id: I4e222f3e01cb8d350ae293b35a88fd5f718fe3dc * fix sloppy types in tests Change-Id: I3ad717ca26e5d170e4bbef23076e528badaaaacb * Update CONTRIBUTING.md * Update CONTRIBUTING.md --------- Co-authored-by: Mark Daoust --- CONTRIBUTING.md | 30 +++++++++++++-- google/generativeai/answer.py | 6 +-- google/generativeai/client.py | 4 +- google/generativeai/generative_models.py | 10 +++-- google/generativeai/models.py | 24 ++++++------ google/generativeai/types/permission_types.py | 28 +++++++------- tests/notebook/text_model_test.py | 38 ++++++++++++++++--- 7 files changed, 97 insertions(+), 43 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index e51ac7205..9415df2a8 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -62,17 +62,41 @@ This "editable" mode lets you edit the source without needing to reinstall the p ### Testing -Use the builtin unittest package: +To ensure the integrity of the codebase, we have a suite of tests located in the `generative-ai-python/tests` directory. +You can run all these tests using Python's built-in `unittest` module or the `pytest` library. + +For `unittest`, open a terminal and navigate to the root directory of the project. Then, execute the following command: + +``` +python -m unittest discover -s tests + +# or more simply +python -m unittest ``` - python -m unittest + +Alternatively, if you prefer using `pytest`, you can install it using pip: + ``` +pip install pytest +``` + +Then, run the tests with the following command: + +``` +pytest tests + +# or more simply +pytest +``` + Or to debug, use: ```commandline +pip install nose2 + nose2 --debugger -``` ### Type checking diff --git a/google/generativeai/answer.py b/google/generativeai/answer.py index 637002052..d1af3adf0 100644 --- a/google/generativeai/answer.py +++ b/google/generativeai/answer.py @@ -94,7 +94,7 @@ def _make_grounding_passages(source: GroundingPassagesOptions) -> glm.GroundingP if not isinstance(source, Iterable): raise TypeError( - f"`source` must be a valid `GroundingPassagesOptions` type object got a: `{type(source)}`." + f"The 'source' argument must be an instance of 'GroundingPassagesOptions', but got a '{type(source).__name__}' object instead." ) passages = [] @@ -182,7 +182,7 @@ def _make_generate_answer_request( temperature: float | None = None, ) -> glm.GenerateAnswerRequest: """ - Calls the API to generate a grounded answer from the model. + constructs a glm.GenerateAnswerRequest object by organizing the input parameters for the API call to generate a grounded answer from the model. Args: model: Name of the model used to generate the grounded response. @@ -217,7 +217,7 @@ def _make_generate_answer_request( elif semantic_retriever is not None: semantic_retriever = _make_semantic_retriever_config(semantic_retriever, contents[-1]) else: - TypeError( + raise TypeError( f"The source must be either an `inline_passages` xor `semantic_retriever_config`, but both are `None`" ) diff --git a/google/generativeai/client.py b/google/generativeai/client.py index e8e91ae7e..31160757f 100644 --- a/google/generativeai/client.py +++ b/google/generativeai/client.py @@ -328,9 +328,9 @@ def get_default_retriever_async_client() -> glm.RetrieverAsyncClient: return _client_manager.get_default_client("retriever_async") -def get_dafault_permission_client() -> glm.PermissionServiceClient: +def get_default_permission_client() -> glm.PermissionServiceClient: return _client_manager.get_default_client("permission") -def get_dafault_permission_async_client() -> glm.PermissionServiceAsyncClient: +def get_default_permission_async_client() -> glm.PermissionServiceAsyncClient: return _client_manager.get_default_client("permission_async") diff --git a/google/generativeai/generative_models.py b/google/generativeai/generative_models.py index 4d71baf48..258baf0d3 100644 --- a/google/generativeai/generative_models.py +++ b/google/generativeai/generative_models.py @@ -387,7 +387,7 @@ def start_chat( >>> response = chat.send_message("Hello?") Arguments: - history: An iterable of `glm.Content` objects, or equvalents to initialize the session. + history: An iterable of `glm.Content` objects, or equivalents to initialize the session. """ if self._generation_config.get("candidate_count", 1) > 1: raise ValueError("Can't chat with `candidate_count > 1`") @@ -401,11 +401,13 @@ def start_chat( class ChatSession: """Contains an ongoing conversation with the model. - >>> model = genai.GenerativeModel(model="gemini-pro") + >>> model = genai.GenerativeModel('models/gemini-pro') >>> chat = model.start_chat() >>> response = chat.send_message("Hello") >>> print(response.text) - >>> response = chat.send_message(...) + >>> response = chat.send_message("Hello again") + >>> print(response.text) + >>> response = chat.send_message(... This `ChatSession` object collects the messages sent and received, in its `ChatSession.history` attribute. @@ -444,7 +446,7 @@ def send_message( Appends the request and response to the conversation history. - >>> model = genai.GenerativeModel(model="gemini-pro") + >>> model = genai.GenerativeModel('models/gemini-pro') >>> chat = model.start_chat() >>> response = chat.send_message("Hello") >>> print(response.text) diff --git a/google/generativeai/models.py b/google/generativeai/models.py index 7c7b8a5cf..16932921a 100644 --- a/google/generativeai/models.py +++ b/google/generativeai/models.py @@ -33,21 +33,21 @@ def get_model( client=None, request_options: dict[str, Any] | None = None, ) -> model_types.Model | model_types.TunedModel: - """Given a model name, fetch the `types.Model` or `types.TunedModel` object. + """Given a model name, fetch the `types.Model` ``` import pprint - model = genai.get_tuned_model(model_name): + model = genai.get_model('models/gemini-pro') pprint.pprint(model) ``` Args: - name: The name of the model to fetch. + name: The name of the model to fetch. Should start with `models/` client: The client to use. request_options: Options for the request. Returns: - A `types.Model` or `types.TunedModel` object. + A `types.Model` """ name = model_types.make_model_name(name) if name.startswith("models/"): @@ -55,7 +55,9 @@ def get_model( elif name.startswith("tunedModels/"): return get_tuned_model(name, client=client, request_options=request_options) else: - raise ValueError("Model names must start with `models/` or `tunedModels/`") + raise ValueError( + f"Model names must start with `models/` or `tunedModels/`. Received: {name}" + ) def get_base_model( @@ -68,12 +70,12 @@ def get_base_model( ``` import pprint - model = genai.get_model('models/chat-bison-001'): + model = genai.get_base_model('models/chat-bison-001') pprint.pprint(model) ``` Args: - name: The name of the model to fetch. + name: The name of the model to fetch. Should start with `models/` client: The client to use. request_options: Options for the request. @@ -88,7 +90,7 @@ def get_base_model( name = model_types.make_model_name(name) if not name.startswith("models/"): - raise ValueError(f"Base model names must start with `models/`, got: {name}") + raise ValueError(f"Base model names must start with `models/`, received: {name}") result = client.get_model(name=name, **request_options) result = type(result).to_dict(result) @@ -105,12 +107,12 @@ def get_tuned_model( ``` import pprint - model = genai.get_tuned_model('tunedModels/my-model-1234'): + model = genai.get_tuned_model('tunedModels/gemini-1.0-pro-001') pprint.pprint(model) ``` Args: - name: The name of the model to fetch. + name: The name of the model to fetch. Should start with `tunedModels/` client: The client to use. request_options: Options for the request. @@ -126,7 +128,7 @@ def get_tuned_model( name = model_types.make_model_name(name) if not name.startswith("tunedModels/"): - raise ValueError("Tuned model names must start with `tunedModels/`") + raise ValueError("Tuned model names must start with `tunedModels/` received: {name}") result = client.get_tuned_model(name=name, **request_options) diff --git a/google/generativeai/types/permission_types.py b/google/generativeai/types/permission_types.py index ef9242999..db1867695 100644 --- a/google/generativeai/types/permission_types.py +++ b/google/generativeai/types/permission_types.py @@ -22,8 +22,8 @@ from google.protobuf import field_mask_pb2 -from google.generativeai.client import get_dafault_permission_client -from google.generativeai.client import get_dafault_permission_async_client +from google.generativeai.client import get_default_permission_client +from google.generativeai.client import get_default_permission_async_client from google.generativeai.utils import flatten_update_paths from google.generativeai import string_utils @@ -107,7 +107,7 @@ def delete( Delete permission (self). """ if client is None: - client = get_dafault_permission_client() + client = get_default_permission_client() delete_request = glm.DeletePermissionRequest(name=self.name) client.delete_permission(request=delete_request) @@ -119,7 +119,7 @@ async def delete_async( This is the async version of `Permission.delete`. """ if client is None: - client = get_dafault_permission_async_client() + client = get_default_permission_async_client() delete_request = glm.DeletePermissionRequest(name=self.name) await client.delete_permission(request=delete_request) @@ -146,7 +146,7 @@ def update( `Permission` object with specified updates. """ if client is None: - client = get_dafault_permission_client() + client = get_default_permission_client() updates = flatten_update_paths(updates) for update_path in updates: @@ -176,7 +176,7 @@ async def update_async( This is the async version of `Permission.update`. """ if client is None: - client = get_dafault_permission_async_client() + client = get_default_permission_async_client() updates = flatten_update_paths(updates) for update_path in updates: @@ -224,7 +224,7 @@ def get( Requested permission as an instance of `Permission`. """ if client is None: - client = get_dafault_permission_client() + client = get_default_permission_client() get_perm_request = glm.GetPermissionRequest(name=name) get_perm_response = client.get_permission(request=get_perm_request) get_perm_response = type(get_perm_response).to_dict(get_perm_response) @@ -240,7 +240,7 @@ async def get_async( This is the async version of `Permission.get`. """ if client is None: - client = get_dafault_permission_async_client() + client = get_default_permission_async_client() get_perm_request = glm.GetPermissionRequest(name=name) get_perm_response = await client.get_permission(request=get_perm_request) get_perm_response = type(get_perm_response).to_dict(get_perm_response) @@ -313,7 +313,7 @@ def create( ValueError: When email_address is not specified and grantee_type is not set to EVERYONE. """ if client is None: - client = get_dafault_permission_client() + client = get_default_permission_client() request = self._make_create_permission_request( role=role, grantee_type=grantee_type, email_address=email_address @@ -333,7 +333,7 @@ async def create_async( This is the async version of `PermissionAdapter.create_permission`. """ if client is None: - client = get_dafault_permission_async_client() + client = get_default_permission_async_client() request = self._make_create_permission_request( role=role, grantee_type=grantee_type, email_address=email_address @@ -358,7 +358,7 @@ def list( Paginated list of `Permission` objects. """ if client is None: - client = get_dafault_permission_client() + client = get_default_permission_client() request = glm.ListPermissionsRequest( parent=self.parent, page_size=page_size # pytype: disable=attribute-error @@ -376,7 +376,7 @@ async def list_async( This is the async version of `PermissionAdapter.list_permissions`. """ if client is None: - client = get_dafault_permission_async_client() + client = get_default_permission_async_client() request = glm.ListPermissionsRequest( parent=self.parent, page_size=page_size # pytype: disable=attribute-error @@ -400,7 +400,7 @@ def transfer_ownership( if self.parent.startswith("corpora"): raise NotImplementedError("Can'/t transfer_ownership for a Corpus") if client is None: - client = get_dafault_permission_client() + client = get_default_permission_client() transfer_request = glm.TransferOwnershipRequest( name=self.parent, email_address=email_address # pytype: disable=attribute-error ) @@ -415,7 +415,7 @@ async def transfer_ownership_async( if self.parent.startswith("corpora"): raise NotImplementedError("Can'/t transfer_ownership for a Corpus") if client is None: - client = get_dafault_permission_async_client() + client = get_default_permission_async_client() transfer_request = glm.TransferOwnershipRequest( name=self.parent, email_address=email_address # pytype: disable=attribute-error ) diff --git a/tests/notebook/text_model_test.py b/tests/notebook/text_model_test.py index 9239ac9c3..428d44b26 100644 --- a/tests/notebook/text_model_test.py +++ b/tests/notebook/text_model_test.py @@ -68,21 +68,47 @@ def _generate_text( class TextModelTestCase(absltest.TestCase): - def test_generate_text(self): + def test_generate_text_without_args(self): model = TestModel() result = model.call_model("prompt goes in") self.assertEqual(result.text_results[0], "prompt goes in_1") - self.assertIsNone(result.text_results[1]) - self.assertIsNone(result.text_results[2]) - self.assertIsNone(result.text_results[3]) + def test_generate_text_without_args_none_results(self): + model = TestModel() + + result = model.call_model("prompt goes in") + self.assertEqual(result.text_results[1], "None") + self.assertEqual(result.text_results[2], "None") + self.assertEqual(result.text_results[3], "None") + + def test_generate_text_with_args_first_result(self): + model = TestModel() args = model_lib.ModelArguments(model="model_name", temperature=0.42, candidate_count=5) + result = model.call_model("prompt goes in", args) self.assertEqual(result.text_results[0], "prompt goes in_1") + + def test_generate_text_with_args_model_name(self): + model = TestModel() + args = model_lib.ModelArguments(model="model_name", temperature=0.42, candidate_count=5) + + result = model.call_model("prompt goes in", args) self.assertEqual(result.text_results[1], "model_name") - self.assertEqual(result.text_results[2], 0.42) - self.assertEqual(result.text_results[3], 5) + + def test_generate_text_with_args_temperature(self): + model = TestModel() + args = model_lib.ModelArguments(model="model_name", temperature=0.42, candidate_count=5) + result = model.call_model("prompt goes in", args) + + self.assertEqual(result.text_results[2], str(0.42)) + + def test_generate_text_with_args_candidate_count(self): + model = TestModel() + args = model_lib.ModelArguments(model="model_name", temperature=0.42, candidate_count=5) + + result = model.call_model("prompt goes in", args) + self.assertEqual(result.text_results[3], str(5)) def test_retry(self): model = TestModel() From 30337c2bb735e6882ea977166d7216382d3f24f1 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Fri, 17 May 2024 16:35:38 -0700 Subject: [PATCH 03/17] Improve request_options (#297) * Working on request_options * Add helper_types Change-Id: Idc3e813616413f4ce085c05b771c0127e4dfc886 * format Change-Id: I186e015de97ceece56ee5a97f6edef47ef223d18 * UpdateRequestOptions Change-Id: I9f92466967fb1aa605d442cb143699da4308409b * Add docs Change-Id: I209b2b2ad8d783001b1828cbcac84ca301c11bec * work Change-Id: I00a2e2edb1e9bf3d4f51c0a868a34e044be3c6ff * Fix Py3.9 Change-Id: I8cf0ccac90ba3c4548e7549fec7d0b9b58925e7e * use RequestOptions in tests Change-Id: I92b68bc86330ad874c3765f428a2e64ba220750f * annotations Change-Id: Idbc428075729255d66d2ba8b3bcce0a1d6e8f048 * Update tests/test_discuss.py Co-authored-by: Mark McDonald * tests Change-Id: Ife30e2cc47bd4c52d2dddafdd85a51df0e42e160 --------- Co-authored-by: Mark McDonald --- google/generativeai/answer.py | 8 +- google/generativeai/discuss.py | 13 +-- google/generativeai/embedding.py | 16 ++-- google/generativeai/generative_models.py | 10 +-- google/generativeai/models.py | 23 +++--- google/generativeai/retriever.py | 17 ++-- google/generativeai/text.py | 13 +-- google/generativeai/types/__init__.py | 1 + google/generativeai/types/helper_types.py | 84 ++++++++++++++++++++ google/generativeai/types/retriever_types.py | 66 +++++++-------- tests/test_answer.py | 3 +- tests/test_discuss.py | 1 - tests/test_helpers.py | 83 +++++++++++++++++++ tests/test_models.py | 3 +- 14 files changed, 257 insertions(+), 84 deletions(-) create mode 100644 google/generativeai/types/helper_types.py create mode 100644 tests/test_helpers.py diff --git a/google/generativeai/answer.py b/google/generativeai/answer.py index d1af3adf0..1b419be57 100644 --- a/google/generativeai/answer.py +++ b/google/generativeai/answer.py @@ -26,12 +26,10 @@ get_default_generative_client, get_default_generative_async_client, ) -from google.generativeai import string_utils from google.generativeai.types import model_types -from google.generativeai import models +from google.generativeai.types import helper_types from google.generativeai.types import safety_types from google.generativeai.types import content_types -from google.generativeai.types import answer_types from google.generativeai.types import retriever_types from google.generativeai.types.retriever_types import MetadataFilter @@ -245,7 +243,7 @@ def generate_answer( safety_settings: safety_types.SafetySettingOptions | None = None, temperature: float | None = None, client: glm.GenerativeServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """ Calls the GenerateAnswer API and returns a `types.Answer` containing the response. @@ -318,7 +316,7 @@ async def generate_answer_async( safety_settings: safety_types.SafetySettingOptions | None = None, temperature: float | None = None, client: glm.GenerativeServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """ Calls the API and returns a `types.Answer` containing the answer. diff --git a/google/generativeai/discuss.py b/google/generativeai/discuss.py index 81e087aa0..35611ae69 100644 --- a/google/generativeai/discuss.py +++ b/google/generativeai/discuss.py @@ -26,6 +26,7 @@ from google.generativeai.client import get_default_discuss_async_client from google.generativeai import string_utils from google.generativeai.types import discuss_types +from google.generativeai.types import helper_types from google.generativeai.types import model_types from google.generativeai.types import palm_safety_types @@ -316,7 +317,7 @@ def chat( top_k: float | None = None, prompt: discuss_types.MessagePromptOptions | None = None, client: glm.DiscussServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> discuss_types.ChatResponse: """Calls the API and returns a `types.ChatResponse` containing the response. @@ -416,7 +417,7 @@ async def chat_async( top_k: float | None = None, prompt: discuss_types.MessagePromptOptions | None = None, client: glm.DiscussServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> discuss_types.ChatResponse: request = _make_generate_message_request( model=model, @@ -469,7 +470,7 @@ def last(self, message: discuss_types.MessageOptions): def reply( self, message: discuss_types.MessageOptions, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> discuss_types.ChatResponse: if isinstance(self._client, glm.DiscussServiceAsyncClient): raise TypeError(f"reply can't be called on an async client, use reply_async instead.") @@ -537,7 +538,7 @@ def _build_chat_response( def _generate_response( request: glm.GenerateMessageRequest, client: glm.DiscussServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> ChatResponse: if request_options is None: request_options = {} @@ -553,7 +554,7 @@ def _generate_response( async def _generate_response_async( request: glm.GenerateMessageRequest, client: glm.DiscussServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> ChatResponse: if request_options is None: request_options = {} @@ -574,7 +575,7 @@ def count_message_tokens( messages: discuss_types.MessagesOptions | None = None, model: model_types.AnyModelNameOptions = DEFAULT_DISCUSS_MODEL, client: glm.DiscussServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> discuss_types.TokenCount: model = model_types.make_model_name(model) prompt = _make_message_prompt(prompt, context=context, examples=examples, messages=messages) diff --git a/google/generativeai/embedding.py b/google/generativeai/embedding.py index 375d5dcb4..14fff1737 100644 --- a/google/generativeai/embedding.py +++ b/google/generativeai/embedding.py @@ -14,8 +14,6 @@ # limitations under the License. from __future__ import annotations -import dataclasses -from collections.abc import Iterable, Sequence, Mapping import itertools from typing import Any, Iterable, overload, TypeVar, Union, Mapping @@ -24,7 +22,7 @@ from google.generativeai.client import get_default_generative_client from google.generativeai.client import get_default_generative_async_client -from google.generativeai import string_utils +from google.generativeai.types import helper_types from google.generativeai.types import text_types from google.generativeai.types import model_types from google.generativeai.types import content_types @@ -104,7 +102,7 @@ def embed_content( title: str | None = None, output_dimensionality: int | None = None, client: glm.GenerativeServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.EmbeddingDict: ... @@ -116,7 +114,7 @@ def embed_content( title: str | None = None, output_dimensionality: int | None = None, client: glm.GenerativeServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.BatchEmbeddingDict: ... @@ -127,7 +125,7 @@ def embed_content( title: str | None = None, output_dimensionality: int | None = None, client: glm.GenerativeServiceClient = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict: """Calls the API to create embeddings for content passed in. @@ -224,7 +222,7 @@ async def embed_content_async( title: str | None = None, output_dimensionality: int | None = None, client: glm.GenerativeServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.EmbeddingDict: ... @@ -236,7 +234,7 @@ async def embed_content_async( title: str | None = None, output_dimensionality: int | None = None, client: glm.GenerativeServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.BatchEmbeddingDict: ... @@ -247,7 +245,7 @@ async def embed_content_async( title: str | None = None, output_dimensionality: int | None = None, client: glm.GenerativeServiceAsyncClient = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict: """The async version of `genai.embed_content`.""" model = model_types.make_model_name(model) diff --git a/google/generativeai/generative_models.py b/google/generativeai/generative_models.py index 258baf0d3..86b87ee90 100644 --- a/google/generativeai/generative_models.py +++ b/google/generativeai/generative_models.py @@ -15,9 +15,9 @@ import google.api_core.exceptions from google.ai import generativelanguage as glm from google.generativeai import client -from google.generativeai import string_utils from google.generativeai.types import content_types from google.generativeai.types import generation_types +from google.generativeai.types import helper_types from google.generativeai.types import safety_types @@ -179,7 +179,7 @@ def generate_content( stream: bool = False, tools: content_types.FunctionLibraryType | None = None, tool_config: content_types.ToolConfigType | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> generation_types.GenerateContentResponse: """A multipurpose function to generate responses from the model. @@ -279,7 +279,7 @@ async def generate_content_async( stream: bool = False, tools: content_types.FunctionLibraryType | None = None, tool_config: content_types.ToolConfigType | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> generation_types.AsyncGenerateContentResponse: """The async version of `GenerativeModel.generate_content`.""" request = self._prepare_request( @@ -326,7 +326,7 @@ def count_tokens( safety_settings: safety_types.SafetySettingOptions | None = None, tools: content_types.FunctionLibraryType | None = None, tool_config: content_types.ToolConfigType | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> glm.CountTokensResponse: if request_options is None: request_options = {} @@ -353,7 +353,7 @@ async def count_tokens_async( safety_settings: safety_types.SafetySettingOptions | None = None, tools: content_types.FunctionLibraryType | None = None, tool_config: content_types.ToolConfigType | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> glm.CountTokensResponse: if request_options is None: request_options = {} diff --git a/google/generativeai/models.py b/google/generativeai/models.py index 16932921a..46861e7f5 100644 --- a/google/generativeai/models.py +++ b/google/generativeai/models.py @@ -21,6 +21,7 @@ from google.generativeai import operations from google.generativeai.client import get_default_model_client from google.generativeai.types import model_types +from google.generativeai.types import helper_types from google.api_core import operation from google.api_core import protobuf_helpers from google.protobuf import field_mask_pb2 @@ -31,7 +32,7 @@ def get_model( name: model_types.AnyModelNameOptions, *, client=None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> model_types.Model | model_types.TunedModel: """Given a model name, fetch the `types.Model` @@ -64,7 +65,7 @@ def get_base_model( name: model_types.BaseModelNameOptions, *, client=None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> model_types.Model: """Get the `types.Model` for the given base model name. @@ -101,7 +102,7 @@ def get_tuned_model( name: model_types.TunedModelNameOptions, *, client=None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> model_types.TunedModel: """Get the `types.TunedModel` for the given tuned model name. @@ -164,7 +165,7 @@ def list_models( *, page_size: int | None = 50, client: glm.ModelServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> model_types.ModelsIterable: """Lists available models. @@ -198,7 +199,7 @@ def list_tuned_models( *, page_size: int | None = 50, client: glm.ModelServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> model_types.TunedModelsIterable: """Lists available models. @@ -246,7 +247,7 @@ def create_tuned_model( input_key: str = "text_input", output_key: str = "output", client: glm.ModelServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> operations.CreateTunedModelOperation: """Launches a tuning job to create a TunedModel. @@ -346,6 +347,7 @@ def create_tuned_model( top_k=top_k, tuning_task=tuning_task, ) + operation = client.create_tuned_model( dict(tuned_model_id=id, tuned_model=tuned_model), **request_options ) @@ -359,7 +361,7 @@ def update_tuned_model( updates: None = None, *, client: glm.ModelServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> model_types.TunedModel: pass @@ -370,7 +372,7 @@ def update_tuned_model( updates: dict[str, Any], *, client: glm.ModelServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> model_types.TunedModel: pass @@ -380,7 +382,7 @@ def update_tuned_model( updates: dict[str, Any] | None = None, *, client: glm.ModelServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> model_types.TunedModel: """Push updates to the tuned model. Only certain attributes are updatable.""" if request_options is None: @@ -397,6 +399,7 @@ def update_tuned_model( "`updates` must be a `dict`.\n" f"got: {type(updates)}" ) + tuned_model = client.get_tuned_model(name=name, **request_options) updates = flatten_update_paths(updates) @@ -438,7 +441,7 @@ def _apply_update(thing, path, value): def delete_tuned_model( tuned_model: model_types.TunedModelNameOptions, client: glm.ModelServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> None: if request_options is None: request_options = {} diff --git a/google/generativeai/retriever.py b/google/generativeai/retriever.py index dfd5e9026..190a222a6 100644 --- a/google/generativeai/retriever.py +++ b/google/generativeai/retriever.py @@ -23,6 +23,7 @@ from google.generativeai.client import get_default_retriever_client from google.generativeai.client import get_default_retriever_async_client +from google.generativeai.types import helper_types from google.generativeai.types.model_types import idecode_time from google.generativeai.types import retriever_types @@ -31,7 +32,7 @@ def create_corpus( name: str | None = None, display_name: str | None = None, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> retriever_types.Corpus: """ Create a new `Corpus` in the retriever service, and return it as a `retriever_types.Corpus` instance. @@ -78,7 +79,7 @@ async def create_corpus_async( name: str | None = None, display_name: str | None = None, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> retriever_types.Corpus: """This is the async version of `retriever.create_corpus`.""" if request_options is None: @@ -106,7 +107,7 @@ async def create_corpus_async( def get_corpus( name: str, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> retriever_types.Corpus: # fmt: skip """ Fetch a specific `Corpus` from the retriever service. @@ -139,7 +140,7 @@ def get_corpus( async def get_corpus_async( name: str, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> retriever_types.Corpus: # fmt: skip """This is the async version of `retriever.get_corpus`.""" if request_options is None: @@ -164,7 +165,7 @@ def delete_corpus( name: str, force: bool = False, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): # fmt: skip """ Delete a `Corpus` from the service. @@ -191,7 +192,7 @@ async def delete_corpus_async( name: str, force: bool = False, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): # fmt: skip """This is the async version of `retriever.delete_corpus`.""" if request_options is None: @@ -211,7 +212,7 @@ def list_corpora( *, page_size: Optional[int] = None, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> Iterable[retriever_types.Corpus]: """ List the Corpuses you own in the service. @@ -242,7 +243,7 @@ async def list_corpora_async( *, page_size: Optional[int] = None, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> AsyncIterable[retriever_types.Corpus]: """This is the async version of `retriever.list_corpora`.""" if request_options is None: diff --git a/google/generativeai/text.py b/google/generativeai/text.py index e51090e1f..bb5ec4bdd 100644 --- a/google/generativeai/text.py +++ b/google/generativeai/text.py @@ -23,6 +23,7 @@ from google.generativeai.client import get_default_text_client from google.generativeai import string_utils +from google.generativeai.types import helper_types from google.generativeai.types import text_types from google.generativeai.types import model_types from google.generativeai import models @@ -139,7 +140,7 @@ def generate_text( safety_settings: palm_safety_types.SafetySettingOptions | None = None, stop_sequences: str | Iterable[str] | None = None, client: glm.TextServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.Completion: """Calls the API and returns a `types.Completion` containing the response. @@ -215,7 +216,7 @@ def __init__(self, **kwargs): def _generate_response( request: glm.GenerateTextRequest, client: glm.TextServiceClient = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> Completion: """ Generates a response using the provided `glm.GenerateTextRequest` and client. @@ -251,7 +252,7 @@ def count_text_tokens( model: model_types.AnyModelNameOptions, prompt: str, client: glm.TextServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.TokenCount: base_model = models.get_base_model_name(model) @@ -274,7 +275,7 @@ def generate_embeddings( model: model_types.BaseModelNameOptions, text: str, client: glm.TextServiceClient = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.EmbeddingDict: ... @@ -283,7 +284,7 @@ def generate_embeddings( model: model_types.BaseModelNameOptions, text: Sequence[str], client: glm.TextServiceClient = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.BatchEmbeddingDict: ... @@ -291,7 +292,7 @@ def generate_embeddings( model: model_types.BaseModelNameOptions, text: str | Sequence[str], client: glm.TextServiceClient = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict: """Calls the API to create an embedding for the text passed in. diff --git a/google/generativeai/types/__init__.py b/google/generativeai/types/__init__.py index dc0a76761..21768bbe6 100644 --- a/google/generativeai/types/__init__.py +++ b/google/generativeai/types/__init__.py @@ -19,6 +19,7 @@ from google.generativeai.types.discuss_types import * from google.generativeai.types.file_types import * from google.generativeai.types.generation_types import * +from google.generativeai.types.helper_types import * from google.generativeai.types.model_types import * from google.generativeai.types.safety_types import * from google.generativeai.types.text_types import * diff --git a/google/generativeai/types/helper_types.py b/google/generativeai/types/helper_types.py new file mode 100644 index 000000000..3eba4d3f9 --- /dev/null +++ b/google/generativeai/types/helper_types.py @@ -0,0 +1,84 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import google.api_core.timeout +import google.api_core.retry + +import collections +import dataclasses + +from typing import Union +from typing_extensions import TypedDict + +__all__ = ["RequestOptions", "RequestOptionsType"] + + +class RequestOptionsDict(TypedDict, total=False): + retry: google.api_core.retry.Retry + timeout: Union[int, float, google.api_core.timeout.TimeToDeadlineTimeout] + + +@dataclasses.dataclass(init=False) +class RequestOptions(collections.abc.Mapping): + """Request options + + >>> import google.generativeai as genai + >>> from google.generativeai.types import RequestOptions + >>> from google.api_core import retry + >>> + >>> model = genai.GenerativeModel() + >>> response = model.generate_content('Hello', + ... request_options=RequestOptions( + ... retry=retry.Retry(initial=10, multiplier=2, maximum=60, timeout=300))) + >>> response = model.generate_content('Hello', + ... request_options=RequestOptions(timeout=600))) + + Args: + retry: Refer to [retry docs](https://googleapis.dev/python/google-api-core/latest/retry.html) for details. + timeout: In seconds (or provide a [TimeToDeadlineTimeout](https://googleapis.dev/python/google-api-core/latest/timeout.html) object). + """ + + retry: google.api_core.retry.Retry | None + timeout: int | float | google.api_core.timeout.TimeToDeadlineTimeout | None + + def __init__( + self, + *, + retry: google.api_core.retry.Retry | None = None, + timeout: int | float | google.api_core.timeout.TimeToDeadlineTimeout | None = None, + ): + self.retry = retry + self.timeout = timeout + + # Inherit from Mapping for **unpacking + def __getitem__(self, item): + if item == "retry": + return self.retry + elif item == "timeout": + return self.timeout + else: + raise KeyError(f'RequestOptions does not have a "{item}" key') + + def __iter__(self): + yield "retry" + yield "timeout" + + def __len__(self): + return 2 + + +RequestOptionsType = Union[RequestOptions, RequestOptionsDict] diff --git a/google/generativeai/types/retriever_types.py b/google/generativeai/types/retriever_types.py index 72859f207..538d3924a 100644 --- a/google/generativeai/types/retriever_types.py +++ b/google/generativeai/types/retriever_types.py @@ -27,6 +27,8 @@ from google.generativeai.client import get_default_retriever_client from google.generativeai.client import get_default_retriever_async_client from google.generativeai import string_utils +from google.generativeai.types import helper_types + from google.generativeai.types import permission_types from google.generativeai.types.model_types import idecode_time from google.generativeai.utils import flatten_update_paths @@ -261,7 +263,7 @@ def create_document( display_name: str | None = None, custom_metadata: Iterable[CustomMetadata] | None = None, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> Document: """ Request to create a `Document`. @@ -312,7 +314,7 @@ async def create_document_async( display_name: str | None = None, custom_metadata: Iterable[CustomMetadata] | None = None, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> Document: """This is the async version of `Corpus.create_document`.""" if request_options is None: @@ -346,7 +348,7 @@ def get_document( self, name: str, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> Document: """ Get information about a specific `Document`. @@ -375,7 +377,7 @@ async def get_document_async( self, name: str, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> Document: """This is the async version of `Corpus.get_document`.""" if request_options is None: @@ -401,7 +403,7 @@ def update( self, updates: dict[str, Any], client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """ Update a list of fields for a specified `Corpus`. @@ -439,7 +441,7 @@ async def update_async( self, updates: dict[str, Any], client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """This is the async version of `Corpus.update`.""" if request_options is None: @@ -470,7 +472,7 @@ def query( metadata_filters: Iterable[MetadataFilter] | None = None, results_count: int | None = None, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> Iterable[RelevantChunk]: """ Query a corpus for information. @@ -524,7 +526,7 @@ async def query_async( metadata_filters: Iterable[MetadataFilter] | None = None, results_count: int | None = None, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> Iterable[RelevantChunk]: """This is the async version of `Corpus.query`.""" if request_options is None: @@ -566,7 +568,7 @@ def delete_document( name: str, force: bool = False, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """ Delete a document in the corpus. @@ -593,7 +595,7 @@ async def delete_document_async( name: str, force: bool = False, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """This is the async version of `Corpus.delete_document`.""" if request_options is None: @@ -612,7 +614,7 @@ def list_documents( self, page_size: int | None = None, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> Iterable[Document]: """ List documents in corpus. @@ -642,7 +644,7 @@ async def list_documents_async( self, page_size: int | None = None, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> AsyncIterable[Document]: """This is the async version of `Corpus.list_documents`.""" if request_options is None: @@ -744,7 +746,7 @@ def create_chunk( name: str | None = None, custom_metadata: Iterable[CustomMetadata] | None = None, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> Chunk: """ Create a `Chunk` object which has textual data. @@ -801,7 +803,7 @@ async def create_chunk_async( name: str | None = None, custom_metadata: Iterable[CustomMetadata] | None = None, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> Chunk: """This is the async version of `Document.create_chunk`.""" if request_options is None: @@ -900,7 +902,7 @@ def batch_create_chunks( self, chunks: BatchCreateChunkOptions, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """ Create chunks within the given document. @@ -926,7 +928,7 @@ async def batch_create_chunks_async( self, chunks: BatchCreateChunkOptions, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """This is the async version of `Document.batch_create_chunk`.""" if request_options is None: @@ -943,7 +945,7 @@ def get_chunk( self, name: str, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """ Get information about a specific chunk. @@ -972,7 +974,7 @@ async def get_chunk_async( self, name: str, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """This is the async version of `Document.get_chunk`.""" if request_options is None: @@ -992,7 +994,7 @@ def list_chunks( self, page_size: int | None = None, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> Iterable[Chunk]: """ List chunks of a document. @@ -1018,7 +1020,7 @@ async def list_chunks_async( self, page_size: int | None = None, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> AsyncIterable[Chunk]: """This is the async version of `Document.list_chunks`.""" if request_options is None: @@ -1037,7 +1039,7 @@ def query( metadata_filters: Iterable[MetadataFilter] | None = None, results_count: int | None = None, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> list[RelevantChunk]: """ Query a `Document` in the `Corpus` for information. @@ -1090,7 +1092,7 @@ async def query_async( metadata_filters: Iterable[MetadataFilter] | None = None, results_count: int | None = None, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> list[RelevantChunk]: """This is the async version of `Document.query`.""" if request_options is None: @@ -1137,7 +1139,7 @@ def update( self, updates: dict[str, Any], client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """ Update a list of fields for a specified document. @@ -1174,7 +1176,7 @@ async def update_async( self, updates: dict[str, Any], client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """This is the async version of `Document.update`.""" if request_options is None: @@ -1202,7 +1204,7 @@ def batch_update_chunks( self, chunks: BatchUpdateChunksOptions, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """ Update multiple chunks within the same document. @@ -1299,7 +1301,7 @@ async def batch_update_chunks_async( self, chunks: BatchUpdateChunksOptions, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """This is the async version of `Document.batch_update_chunks`.""" if request_options is None: @@ -1387,7 +1389,7 @@ def delete_chunk( self, name: str, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, # fmt: {} + request_options: helper_types.RequestOptionsType | None = None, # fmt: {} ): """ Delete a `Chunk`. @@ -1412,7 +1414,7 @@ async def delete_chunk_async( self, name: str, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, # fmt: {} + request_options: helper_types.RequestOptionsType | None = None, # fmt: {} ): """This is the async version of `Document.delete_chunk`.""" if request_options is None: @@ -1431,7 +1433,7 @@ def batch_delete_chunks( self, chunks: BatchDeleteChunkOptions, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """ Delete multiple `Chunk`s from a document. @@ -1464,7 +1466,7 @@ async def batch_delete_chunks_async( self, chunks: BatchDeleteChunkOptions, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """This is the async version of `Document.batch_delete_chunks`.""" if request_options is None: @@ -1570,7 +1572,7 @@ def update( self, updates: dict[str, Any], client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """ Update a list of fields for a specified `Chunk`. @@ -1619,7 +1621,7 @@ async def update_async( self, updates: dict[str, Any], client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """This is the async version of `Chunk.update`.""" if request_options is None: diff --git a/tests/test_answer.py b/tests/test_answer.py index 6fa12603c..4128567f4 100644 --- a/tests/test_answer.py +++ b/tests/test_answer.py @@ -21,6 +21,7 @@ import google.ai.generativelanguage as glm from google.generativeai import answer +from google.generativeai import types as genai_types from google.generativeai import client from absl.testing import absltest from absl.testing import parameterized @@ -239,7 +240,7 @@ def test_generate_answer(self): def test_generate_answer_called_with_request_options(self): self.client.generate_answer = mock.MagicMock() request = mock.ANY - request_options = {"timeout": 120} + request_options = genai_types.RequestOptions(timeout=120) answer.generate_answer(contents=[], inline_passages=[], request_options=request_options) diff --git a/tests/test_discuss.py b/tests/test_discuss.py index 183ccd0c3..7db0a63d8 100644 --- a/tests/test_discuss.py +++ b/tests/test_discuss.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import copy -from typing import Any import unittest.mock diff --git a/tests/test_helpers.py b/tests/test_helpers.py new file mode 100644 index 000000000..0c2de7f29 --- /dev/null +++ b/tests/test_helpers.py @@ -0,0 +1,83 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pathlib +import copy +import collections +from typing import Union + +from absl.testing import parameterized + +import google.ai.generativelanguage as glm + +from google.generativeai import client +from google.generativeai import models +from google.generativeai.types import model_types +from google.generativeai.types import helper_types + +from google.api_core import retry + + +class MockModelClient: + def __init__(self, test): + self.test = test + + def get_model( + self, + request: Union[glm.GetModelRequest, None] = None, + *, + name=None, + timeout=None, + retry=None + ) -> glm.Model: + if request is None: + request = glm.GetModelRequest(name=name) + self.test.assertIsInstance(request, glm.GetModelRequest) + self.test.observed_requests.append(request) + self.test.observed_timeout.append(timeout) + self.test.observed_retry.append(retry) + response = copy.copy(self.test.responses["get_model"]) + return response + + +class HelperTests(parameterized.TestCase): + + def setUp(self): + self.client = MockModelClient(self) + client._client_manager.clients["model"] = self.client + + self.observed_requests = [] + self.observed_retry = [] + self.observed_timeout = [] + self.responses = collections.defaultdict(list) + + @parameterized.named_parameters( + ["None", None, None, None], + ["Empty", {}, None, None], + ["Timeout", {"timeout": 7}, 7, None], + ["Retry", {"retry": retry.Retry(timeout=7)}, None, retry.Retry(timeout=7)], + [ + "RequestOptions", + helper_types.RequestOptions(timeout=7, retry=retry.Retry(multiplier=3)), + 7, + retry.Retry(multiplier=3), + ], + ) + def test_get_model(self, request_options, expected_timeout, expected_retry): + self.responses = {"get_model": glm.Model(name="models/fake-bison-001")} + + _ = models.get_model("models/fake-bison-001", request_options=request_options) + + self.assertEqual(self.observed_timeout[0], expected_timeout) + self.assertEqual(str(self.observed_retry[0]), str(expected_retry)) diff --git a/tests/test_models.py b/tests/test_models.py index e971ef86d..f39ed3a2c 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -31,6 +31,7 @@ from google.generativeai import models from google.generativeai import client from google.generativeai.types import model_types +from google.generativeai import types as genai_types import pandas as pd @@ -470,7 +471,7 @@ def test_get_model_called_with_request_options(self): def test_get_tuned_model_called_with_request_options(self): self.client.get_tuned_model = unittest.mock.MagicMock() name = unittest.mock.ANY - request_options = {"timeout": 120} + request_options = genai_types.RequestOptions(timeout=120) try: models.get_model(name="tunedModels/", request_options=request_options) From 3193c3e80bd812427833d56a4337b22f3b6ce3e5 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Fri, 17 May 2024 16:38:04 -0700 Subject: [PATCH 04/17] Cleanup file data handling. (#321) * Fix typing Change-Id: I09fb7df098da08bb24337b08cbbc997c4c62af1e * format Change-Id: I7923e2257a6d935bb1b60f4f81dd7910387292ae --- google/generativeai/types/content_types.py | 44 ++++++---------------- google/generativeai/types/file_types.py | 32 ++++++++++++++++ 2 files changed, 43 insertions(+), 33 deletions(-) diff --git a/google/generativeai/types/content_types.py b/google/generativeai/types/content_types.py index 67c1338bf..ce72dddbc 100644 --- a/google/generativeai/types/content_types.py +++ b/google/generativeai/types/content_types.py @@ -128,7 +128,7 @@ def _convert_dict(d: Mapping) -> glm.Content | glm.Part | glm.Blob: if "inline_data" in part: part["inline_data"] = to_blob(part["inline_data"]) if "file_data" in part: - part["file_data"] = to_file_data(part["file_data"]) + part["file_data"] = file_types.to_file_data(part["file_data"]) return glm.Part(part) elif is_blob_dict(d): blob = d @@ -176,43 +176,21 @@ def to_blob(blob: BlobType) -> glm.Blob: ) -class FileDataDict(TypedDict): - mime_type: str - file_uri: str - - -FileDataType = Union[FileDataDict, glm.FileData, file_types.File] - - -def to_file_data(file_data: FileDataType): - if isinstance(file_data, dict): - if "file_uri" in file_data: - file_data = glm.FileData(file_data) - else: - file_data = glm.File(file_data) - - if isinstance(file_data, file_types.File): - file_data = file_data.to_proto() - - if isinstance(file_data, (glm.File, file_types.File)): - file_data = glm.FileData( - mime_type=file_data.mime_type, - file_uri=file_data.uri, - ) - - if isinstance(file_data, glm.FileData): - return file_data - else: - raise TypeError(f"Could not convert a {type(file_data)} to `FileData`") - - class PartDict(TypedDict): text: str inline_data: BlobType # When you need a `Part` accept a part object, part-dict, blob or string -PartType = Union[glm.Part, PartDict, BlobType, str, glm.FunctionCall, glm.FunctionResponse] +PartType = Union[ + glm.Part, + PartDict, + BlobType, + str, + glm.FunctionCall, + glm.FunctionResponse, + file_types.FileDataType, +] def is_part_dict(d): @@ -236,7 +214,7 @@ def to_part(part: PartType): elif isinstance(part, glm.FileData): return glm.Part(file_data=part) elif isinstance(part, (glm.File, file_types.File)): - return glm.Part(file_data=to_file_data(part)) + return glm.Part(file_data=file_types.to_file_data(part)) elif isinstance(part, glm.FunctionCall): return glm.Part(function_call=part) elif isinstance(part, glm.FunctionResponse): diff --git a/google/generativeai/types/file_types.py b/google/generativeai/types/file_types.py index d18404871..46b0f37b9 100644 --- a/google/generativeai/types/file_types.py +++ b/google/generativeai/types/file_types.py @@ -15,6 +15,8 @@ from __future__ import annotations import datetime +from typing import Union +from typing_extensions import TypedDict from google.generativeai.client import get_default_file_client @@ -73,3 +75,33 @@ def state(self) -> glm.File.State: def delete(self): client = get_default_file_client() client.delete_file(name=self.name) + + +class FileDataDict(TypedDict): + mime_type: str + file_uri: str + + +FileDataType = Union[FileDataDict, glm.FileData, glm.File, File] + + +def to_file_data(file_data: FileDataType): + if isinstance(file_data, dict): + if "file_uri" in file_data: + file_data = glm.FileData(file_data) + else: + file_data = glm.File(file_data) + + if isinstance(file_data, File): + file_data = file_data.to_proto() + + if isinstance(file_data, glm.File): + file_data = glm.FileData( + mime_type=file_data.mime_type, + file_uri=file_data.uri, + ) + + if isinstance(file_data, glm.FileData): + return file_data + else: + raise TypeError(f"Could not convert a {type(file_data)} to `FileData`") From 472a3e34e61be91a72265b1f569328d9235fddfa Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Fri, 17 May 2024 16:39:57 -0700 Subject: [PATCH 05/17] Add GenerateContentResponse.to_dict() (#337) * GenerateContentResponse.to_dict() Change-Id: I2042d4387fe216f28b9c48b98eab8a71447fb98f * docstring Change-Id: Icec923e7d782ef5e6c2e36faefe790ea01ed93ad --- google/generativeai/types/generation_types.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/google/generativeai/types/generation_types.py b/google/generativeai/types/generation_types.py index b7a342b37..f0c9de4c7 100644 --- a/google/generativeai/types/generation_types.py +++ b/google/generativeai/types/generation_types.py @@ -356,6 +356,18 @@ def __init__( else: self._error = None + def to_dict(self): + """Returns the result as a JSON-compatible dict. + + Note: This doesn't capture the iterator state when streaming, it only captures the accumulated + `GenerateContentResponse` fields. + + >>> import json + >>> response = model.generate_content('Hello?') + >>> json.dumps(response.to_dict()) + """ + return type(self._result).to_dict(self._result) + @property def candidates(self): """The list of candidate responses. @@ -428,7 +440,7 @@ def __str__(self) -> str: else: _iterator = f"<{self._iterator.__class__.__name__}>" - as_dict = type(self._result).to_dict(self._result) + as_dict = self.to_dict() json_str = json.dumps(as_dict, indent=2) _result = f"glm.GenerateContentResponse({json_str})" From 88f7ab3c0d5e529b2e2ccc4aa049ef3aabc389c5 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Fri, 17 May 2024 16:41:57 -0700 Subject: [PATCH 06/17] Fix argument description. (#338) --- google/generativeai/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/generativeai/models.py b/google/generativeai/models.py index 46861e7f5..f25be57c6 100644 --- a/google/generativeai/models.py +++ b/google/generativeai/models.py @@ -276,7 +276,7 @@ def create_tuned_model( * A `glm.Dataset`, or * An `Iterable` of: *`glm.TuningExample`, - * {'text_input': text_input, 'output': output} dicts, or + * `{'text_input': text_input, 'output': output}` dicts * `(text_input, output)` tuples. * A `Mapping` of `Iterable[str]` - use `input_key` and `output_key` to choose which columns to use as the input/output From 05877f721794205c9757325e01874fee6e4c653d Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Tue, 21 May 2024 06:38:54 -0700 Subject: [PATCH 07/17] Allow empty contents with count_tokens (#342) Change-Id: Ic20e2f88427d2e4fbc97847cf5c2df1f80a9a5a1 --- google/generativeai/generative_models.py | 9 ++++--- tests/test_generative_models.py | 33 +++++++++++++++++++----- 2 files changed, 32 insertions(+), 10 deletions(-) diff --git a/google/generativeai/generative_models.py b/google/generativeai/generative_models.py index 86b87ee90..4ef1a3608 100644 --- a/google/generativeai/generative_models.py +++ b/google/generativeai/generative_models.py @@ -129,9 +129,6 @@ def _prepare_request( tool_config: content_types.ToolConfigType | None, ) -> glm.GenerateContentRequest: """Creates a `glm.GenerateContentRequest` from raw inputs.""" - if not contents: - raise TypeError("contents must not be empty") - tools_lib = self._get_tools_lib(tools) if tools_lib is not None: tools_lib = tools_lib.to_proto() @@ -235,6 +232,9 @@ def generate_content( tools: `glm.Tools` more info coming soon. request_options: Options for the request. """ + if not contents: + raise TypeError("contents must not be empty") + request = self._prepare_request( contents=contents, generation_config=generation_config, @@ -282,6 +282,9 @@ async def generate_content_async( request_options: helper_types.RequestOptionsType | None = None, ) -> generation_types.AsyncGenerateContentResponse: """The async version of `GenerativeModel.generate_content`.""" + if not contents: + raise TypeError("contents must not be empty") + request = self._prepare_request( contents=contents, generation_config=generation_config, diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py index 6fabd59e9..3b0c27814 100644 --- a/tests/test_generative_models.py +++ b/tests/test_generative_models.py @@ -21,6 +21,10 @@ TEST_IMAGE_DATA = TEST_IMAGE_PATH.read_bytes() +def noop(x: int): + return x + + def simple_part(text: str) -> glm.Content: return glm.Content({"parts": [{"text": text}]}) @@ -725,18 +729,33 @@ def test_system_instruction(self, instruction, expected_instr): self.assertEqual(req.system_instruction, expected_instr) @parameterized.named_parameters( - ["basic", "Hello"], - ["list", ["Hello"]], + ["basic", {"contents": "Hello"}], + ["list", {"contents": ["Hello"]}], [ "list2", - [{"text": "Hello"}, {"inline_data": {"data": b"PNG!", "mime_type": "image/png"}}], + { + "contents": [ + {"text": "Hello"}, + {"inline_data": {"data": b"PNG!", "mime_type": "image/png"}}, + ] + }, ], - ["contents", [{"role": "user", "parts": ["hello"]}]], + [ + "contents", + {"contents": [{"role": "user", "parts": ["hello"]}]}, + ], + ["empty", {}], + [ + "system_instruction", + {"system_instruction": ["You are a cat"]}, + ], + ["tools", {"tools": [noop]}], ) - def test_count_tokens_smoke(self, contents): + def test_count_tokens_smoke(self, kwargs): + si = kwargs.pop("system_instruction", None) self.responses["count_tokens"] = [glm.CountTokensResponse(total_tokens=7)] - model = generative_models.GenerativeModel("gemini-pro-vision") - response = model.count_tokens(contents) + model = generative_models.GenerativeModel("gemini-pro-vision", system_instruction=si) + response = model.count_tokens(**kwargs) self.assertEqual(type(response).to_dict(response), {"total_tokens": 7}) @parameterized.named_parameters( From f987fde53eb95fb25520c8c96f09284680461258 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Tue, 21 May 2024 08:53:22 -0700 Subject: [PATCH 08/17] improve the no-credentials error message, fail fast for no-credentials in colab. (#352) * improve the no-credentials error message Change-Id: I294bd094b56287ed923716dce9ea705ef3135f5b * patch colab credentials Change-Id: I5a3cb3168448a565eb3cdc8a0063ae041c41a260 * format Change-Id: I013d506bdcb64092daddedcf3e30f3728a8f3e30 --- google/generativeai/client.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/google/generativeai/client.py b/google/generativeai/client.py index 31160757f..c0d0b3304 100644 --- a/google/generativeai/client.py +++ b/google/generativeai/client.py @@ -1,9 +1,9 @@ from __future__ import annotations import os +import contextlib import dataclasses import pathlib -import re import types from typing import Any, cast from collections.abc import Sequence @@ -12,6 +12,8 @@ import google.ai.generativelanguage as glm from google.auth import credentials as ga_credentials +from google.auth import exceptions as ga_exceptions +from google import auth from google.api_core import client_options as client_options_lib from google.api_core import gapic_v1 from google.api_core import operations_v1 @@ -30,6 +32,18 @@ GENAI_API_DISCOVERY_URL = "https://generativelanguage.googleapis.com/$discovery/rest" +@contextlib.contextmanager +def patch_colab_gce_credentials(): + get_gce = auth._default._get_gce_credentials + if "COLAB_RELEASE_TAG" in os.environ: + auth._default._get_gce_credentials = lambda *args, **kwargs: (None, None) + + try: + yield + finally: + auth._default._get_gce_credentials = get_gce + + class FileServiceClient(glm.FileServiceClient): def __init__(self, *args, **kwargs): self._discovery_api = None @@ -183,7 +197,17 @@ def make_client(self, name): if not self.client_config: configure() - client = cls(**self.client_config) + try: + with patch_colab_gce_credentials(): + client = cls(**self.client_config) + except ga_exceptions.DefaultCredentialsError as e: + e.args = ( + "\n No API_KEY or ADC found. Please either:\n" + " - Set the `GOOGLE_API_KEY` environment variable.\n" + " - Manually pass the key with `genai.configure(api_key=my_api_key)`.\n" + " - Or set up Application Default Credentials, see https://ai.google.dev/gemini-api/docs/oauth for more information.", + ) + raise e if not self.default_metadata: return client From f3616428f55a8964e8c13025b4c6e9b05dc1f63f Mon Sep 17 00:00:00 2001 From: Faisal Alsrheed <47912291+Faisal-Alsrheed@users.noreply.github.com> Date: Tue, 21 May 2024 23:53:09 +0300 Subject: [PATCH 09/17] Enhance functions Documentation and Improve Error Messages (#350) * Add Functions docstring to be https://ai.google.dev/api/python/google/generativeai * Fix and improve Error messages * Fix invalid keys and improve error messages * Improve error messages * Refix to_blob --- google/generativeai/answer.py | 17 ++--- google/generativeai/client.py | 14 ++-- google/generativeai/discuss.py | 35 ++++++---- google/generativeai/embedding.py | 20 ++++-- google/generativeai/files.py | 5 +- google/generativeai/generative_models.py | 48 +++++++------- google/generativeai/models.py | 48 ++++++++------ google/generativeai/operations.py | 8 ++- google/generativeai/permission.py | 8 +-- google/generativeai/responder.py | 12 ++-- google/generativeai/retriever.py | 16 ++--- google/generativeai/text.py | 8 ++- google/generativeai/types/content_types.py | 66 +++++++++++-------- google/generativeai/types/file_types.py | 6 +- google/generativeai/types/generation_types.py | 33 ++++------ google/generativeai/types/helper_types.py | 5 +- google/generativeai/types/model_types.py | 18 +++-- google/generativeai/types/permission_types.py | 9 ++- google/generativeai/types/retriever_types.py | 63 +++++++++++------- google/generativeai/utils.py | 2 + tests/test_client.py | 2 +- 21 files changed, 259 insertions(+), 184 deletions(-) diff --git a/google/generativeai/answer.py b/google/generativeai/answer.py index 1b419be57..4b9d9f97c 100644 --- a/google/generativeai/answer.py +++ b/google/generativeai/answer.py @@ -92,7 +92,7 @@ def _make_grounding_passages(source: GroundingPassagesOptions) -> glm.GroundingP if not isinstance(source, Iterable): raise TypeError( - f"The 'source' argument must be an instance of 'GroundingPassagesOptions', but got a '{type(source).__name__}' object instead." + f"Invalid input: The 'source' argument must be an instance of 'GroundingPassagesOptions'. Received a '{type(source).__name__}' object instead." ) passages = [] @@ -156,9 +156,9 @@ def _make_semantic_retriever_config( source["source"] = _maybe_get_source_name(source["source"]) else: raise TypeError( - "Could create a `glm.SemanticRetrieverConfig` from:\n" - f" type: {type(source)}\n" - f" value: {source}" + f"Invalid input: Failed to create a 'glm.SemanticRetrieverConfig' from the provided source. " + f"Received type: {type(source).__name__}, " + f"Received value: {source}" ) if source["query"] is None: @@ -208,7 +208,8 @@ def _make_generate_answer_request( if inline_passages is not None and semantic_retriever is not None: raise ValueError( - "Either `inline_passages` or `semantic_retriever_config` must be set, not both." + f"Invalid configuration: Please set either 'inline_passages' or 'semantic_retriever_config', but not both. " + f"Received for inline_passages: {inline_passages}, and for semantic_retriever: {semantic_retriever}." ) elif inline_passages is not None: inline_passages = _make_grounding_passages(inline_passages) @@ -216,7 +217,8 @@ def _make_generate_answer_request( semantic_retriever = _make_semantic_retriever_config(semantic_retriever, contents[-1]) else: raise TypeError( - f"The source must be either an `inline_passages` xor `semantic_retriever_config`, but both are `None`" + f"Invalid configuration: Either 'inline_passages' or 'semantic_retriever_config' must be provided, but currently both are 'None'. " + f"Received for inline_passages: {inline_passages}, and for semantic_retriever: {semantic_retriever}." ) if answer_style: @@ -245,8 +247,7 @@ def generate_answer( client: glm.GenerativeServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ): - """ - Calls the GenerateAnswer API and returns a `types.Answer` containing the response. + """Calls the GenerateAnswer API and returns a `types.Answer` containing the response. You can pass a literal list of text chunks: diff --git a/google/generativeai/client.py b/google/generativeai/client.py index c0d0b3304..d969889d0 100644 --- a/google/generativeai/client.py +++ b/google/generativeai/client.py @@ -52,7 +52,9 @@ def __init__(self, *args, **kwargs): def _setup_discovery_api(self): api_key = self._client_options.api_key if api_key is None: - raise ValueError("Uploading to the File API requires an API key.") + raise ValueError( + "Invalid operation: Uploading to the File API requires an API key. Please provide a valid API key." + ) request = googleapiclient.http.HttpRequest( http=httplib2.Http(), @@ -95,7 +97,9 @@ def create_file( class FileServiceAsyncClient(glm.FileServiceAsyncClient): async def create_file(self, *args, **kwargs): - raise NotImplementedError("`create_file` is not yet implemented for the async client.") + raise NotImplementedError( + "The `create_file` method is currently not supported for the asynchronous client." + ) @dataclasses.dataclass @@ -123,7 +127,7 @@ def configure( client_info: gapic_v1.client_info.ClientInfo | None = None, default_metadata: Sequence[tuple[str, str]] = (), ) -> None: - """Captures default client configuration. + """Initializes default client configurations using specified parameters or environment variables. If no API key has been provided (either directly, or on `client_options`) and the `GOOGLE_API_KEY` environment variable is set, it will be used as the API key. @@ -149,7 +153,9 @@ def configure( if had_api_key_value: if api_key is not None: - raise ValueError("You can't set both `api_key` and `client_options['api_key']`.") + raise ValueError( + "Invalid configuration: Please set either `api_key` or `client_options['api_key']`, but not both." + ) else: if api_key is None: # If no key is provided explicitly, attempt to load one from the diff --git a/google/generativeai/discuss.py b/google/generativeai/discuss.py index 35611ae69..b084ccad8 100644 --- a/google/generativeai/discuss.py +++ b/google/generativeai/discuss.py @@ -69,7 +69,9 @@ def _make_messages( elif len(even_authors) == 1: even_author = even_authors.pop() else: - raise discuss_types.AuthorError("Authors are not strictly alternating") + raise discuss_types.AuthorError( + "Invalid sequence: Authors in the discussion must alternate strictly." + ) odd_authors = set(msg.author for msg in messages[1::2] if msg.author) if not odd_authors: @@ -77,7 +79,9 @@ def _make_messages( elif len(odd_authors) == 1: odd_author = odd_authors.pop() else: - raise discuss_types.AuthorError("Authors are not strictly alternating") + raise discuss_types.AuthorError( + "Invalid sequence: Authors in the discussion must alternate strictly." + ) if all(msg.author for msg in messages): return messages @@ -130,8 +134,8 @@ def _make_examples_from_flat( raise ValueError( textwrap.dedent( f"""\ - You must pass `Primer` objects, pairs of messages, or an *even* number of messages, got: - {len(examples)} messages""" + Invalid input: You must pass either `Primer` objects, pairs of messages, or an even number of messages. + Currently, {len(examples)} messages were provided, which is an odd number.""" ) ) result = [] @@ -186,7 +190,7 @@ def _make_examples( else: if not ("input" in first and "output" in first): raise TypeError( - "To create an `Example` from a dict you must supply both `input` and an `output` keys" + "Invalid dictionary format: To create an `Example` instance, the dictionary must contain both `input` and `output` keys." ) else: if isinstance(first, discuss_types.MESSAGE_OPTIONS): @@ -232,8 +236,7 @@ def _make_message_prompt_dict( flat_prompt = (context is not None) or (examples is not None) or (messages is not None) if flat_prompt: raise ValueError( - "You can't set `prompt`, and its fields `(context, examples, messages)`" - " at the same time" + "Invalid configuration: Either `prompt` or its fields `(context, examples, messages)` should be set, but not both simultaneously." ) if isinstance(prompt, glm.MessagePrompt): return prompt @@ -245,7 +248,7 @@ def _make_message_prompt_dict( keys = set(prompt.keys()) if not keys.issubset(discuss_types.MESSAGE_PROMPT_KEYS): raise KeyError( - f"Found extra entries in the prompt dictionary: {keys - discuss_types.MESSAGE_PROMPT_KEYS}" + f"Invalid prompt dictionary: Extra entries found that are not recognized: {keys - discuss_types.MESSAGE_PROMPT_KEYS}. Please check the keys." ) examples = prompt.get("examples", None) @@ -319,7 +322,7 @@ def chat( client: glm.DiscussServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> discuss_types.ChatResponse: - """Calls the API and returns a `types.ChatResponse` containing the response. + """Calls the API to initiate a chat with a model using provided parameters Args: model: Which model to call, as a string or a `types.Model`. @@ -419,6 +422,7 @@ async def chat_async( client: glm.DiscussServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> discuss_types.ChatResponse: + """Calls the API asynchronously to initiate a chat with a model using provided parameters""" request = _make_generate_message_request( model=model, context=context, @@ -473,12 +477,13 @@ def reply( request_options: helper_types.RequestOptionsType | None = None, ) -> discuss_types.ChatResponse: if isinstance(self._client, glm.DiscussServiceAsyncClient): - raise TypeError(f"reply can't be called on an async client, use reply_async instead.") + raise TypeError( + "Invalid operation: The 'reply' method cannot be called on an asynchronous client. Please use the 'reply_async' method instead." + ) if self.last is None: raise ValueError( - "The last response from the model did not return any candidates.\n" - "Check the `.filters` attribute to see why the responses were filtered:\n" - f"{self.filters}" + f"Invalid operation: No candidates returned from the model's last response. " + f"Please inspect the '.filters' attribute to understand why responses were filtered out. Current filters: {self.filters}" ) request = self.to_dict() @@ -497,7 +502,7 @@ async def reply_async( ) -> discuss_types.ChatResponse: if isinstance(self._client, glm.DiscussServiceClient): raise TypeError( - f"reply_async can't be called on a non-async client, use reply instead." + "Invalid method call: `reply_async` is not supported on a non-async client. Please use the `reply` method instead." ) request = self.to_dict() request.pop("candidates") @@ -577,6 +582,8 @@ def count_message_tokens( client: glm.DiscussServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> discuss_types.TokenCount: + """Calls the API to calculate the number of tokens used in the prompt.""" + model = model_types.make_model_name(model) prompt = _make_message_prompt(prompt, context=context, examples=examples, messages=messages) diff --git a/google/generativeai/embedding.py b/google/generativeai/embedding.py index 14fff1737..8218ec11d 100644 --- a/google/generativeai/embedding.py +++ b/google/generativeai/embedding.py @@ -82,7 +82,9 @@ def to_task_type(x: EmbeddingTaskTypeOptions) -> EmbeddingTaskType: def _batched(iterable: Iterable[T], n: int) -> Iterable[list[T]]: if n < 1: - raise ValueError(f"Batch size `n` must be >0, got: {n}") + raise ValueError( + f"Invalid input: The batch size 'n' must be a positive integer. You entered: {n}. Please enter a number greater than 0." + ) batch = [] for item in iterable: batch.append(item) @@ -167,11 +169,13 @@ def embed_content( if title and to_task_type(task_type) is not EmbeddingTaskType.RETRIEVAL_DOCUMENT: raise ValueError( - "If a title is specified, the task must be a retrieval document type task." + f"Invalid task type: When a title is specified, the task must be of a 'retrieval document' type. Received task type: {task_type} and title: {title}." ) if output_dimensionality and output_dimensionality < 0: - raise ValueError("`output_dimensionality` must be a non-negative integer.") + raise ValueError( + f"Invalid value: `output_dimensionality` must be a non-negative integer. Received: {output_dimensionality}." + ) if task_type: task_type = to_task_type(task_type) @@ -247,7 +251,8 @@ async def embed_content_async( client: glm.GenerativeServiceAsyncClient = None, request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict: - """The async version of `genai.embed_content`.""" + """Calls the API to create async embeddings for content passed in.""" + model = model_types.make_model_name(model) if request_options is None: @@ -258,11 +263,12 @@ async def embed_content_async( if title and to_task_type(task_type) is not EmbeddingTaskType.RETRIEVAL_DOCUMENT: raise ValueError( - "If a title is specified, the task must be a retrieval document type task." + f"Invalid task type: When a title is specified, the task must be of a 'retrieval document' type. Received task type: {task_type} and title: {title}." ) - if output_dimensionality and output_dimensionality < 0: - raise ValueError("`output_dimensionality` must be a non-negative integer.") + raise ValueError( + f"Invalid value: `output_dimensionality` must be a non-negative integer. Received: {output_dimensionality}." + ) if task_type: task_type = to_task_type(task_type) diff --git a/google/generativeai/files.py b/google/generativeai/files.py index 13535c47f..386592225 100644 --- a/google/generativeai/files.py +++ b/google/generativeai/files.py @@ -37,7 +37,7 @@ def upload_file( display_name: str | None = None, resumable: bool = True, ) -> file_types.File: - """Uploads a file using a supported file service. + """Calls the API to upload a file using a supported file service. Args: path: The path to the file to be uploaded. @@ -73,6 +73,7 @@ def upload_file( def list_files(page_size=100) -> Iterable[file_types.File]: + """Calls the API to list files using a supported file service.""" client = get_default_file_client() response = client.list_files(glm.ListFilesRequest(page_size=page_size)) @@ -81,11 +82,13 @@ def list_files(page_size=100) -> Iterable[file_types.File]: def get_file(name) -> file_types.File: + """Calls the API to retrieve a specified file using a supported file service.""" client = get_default_file_client() return file_types.File(client.get_file(name=name)) def delete_file(name): + """Calls the API to permanently delete a specified file using a supported file service.""" if isinstance(name, (file_types.File, glm.File)): name = name.name request = glm.DeleteFileRequest(name=name) diff --git a/google/generativeai/generative_models.py b/google/generativeai/generative_models.py index 4ef1a3608..6fc5554c4 100644 --- a/google/generativeai/generative_models.py +++ b/google/generativeai/generative_models.py @@ -265,8 +265,8 @@ def generate_content( except google.api_core.exceptions.InvalidArgument as e: if e.message.startswith("Request payload size exceeds the limit:"): e.message += ( - " Please upload your files with the File API instead." - "`f = genai.upload_file(path); m.generate_content(['tell me about this file:', f])`" + " The file size is too large. Please use the File API to upload your files instead. " + "Example: `f = genai.upload_file(path); m.generate_content(['tell me about this file:', f])`" ) raise @@ -315,8 +315,8 @@ async def generate_content_async( except google.api_core.exceptions.InvalidArgument as e: if e.message.startswith("Request payload size exceeds the limit:"): e.message += ( - " Please upload your files with the File API instead." - "`f = genai.upload_file(path); m.generate_content(['tell me about this file:', f])`" + " The file size is too large. Please use the File API to upload your files instead. " + "Example: `f = genai.upload_file(path); m.generate_content(['tell me about this file:', f])`" ) raise @@ -393,7 +393,9 @@ def start_chat( history: An iterable of `glm.Content` objects, or equivalents to initialize the session. """ if self._generation_config.get("candidate_count", 1) > 1: - raise ValueError("Can't chat with `candidate_count > 1`") + raise ValueError( + "Invalid configuration: The chat functionality does not support `candidate_count` greater than 1." + ) return ChatSession( model=self, history=history, @@ -478,8 +480,7 @@ def send_message( """ if self.enable_automatic_function_calling and stream: raise NotImplementedError( - "The `google.generativeai` SDK does not yet support `stream=True` with " - "`enable_automatic_function_calling=True`" + "Unsupported configuration: The `google.generativeai` SDK currently does not support the combination of `stream=True` and `enable_automatic_function_calling=True`." ) tools_lib = self.model._get_tools_lib(tools) @@ -494,7 +495,9 @@ def send_message( generation_config = generation_types.to_generation_config_dict(generation_config) if generation_config.get("candidate_count", 1) > 1: - raise ValueError("Can't chat with `candidate_count > 1`") + raise ValueError( + "Invalid configuration: The chat functionality does not support `candidate_count` greater than 1." + ) response = self.model.generate_content( contents=history, @@ -538,7 +541,7 @@ def _get_function_calls(self, response) -> list[glm.FunctionCall]: candidates = response.candidates if len(candidates) != 1: raise ValueError( - f"Automatic function calling only works with 1 candidate, got: {len(candidates)}" + f"Invalid number of candidates: Automatic function calling only works with 1 candidate, but {len(candidates)} were provided." ) parts = candidates[0].content.parts function_calls = [part.function_call for part in parts if part and "function_call" in part] @@ -557,8 +560,8 @@ def _handle_afc( for fc in function_calls: fr = tools_lib(fc) assert fr is not None, ( - "This should never happen, it should only return None if the declaration" - "is not callable, and that's guarded against above." + "Unexpected state: The function reference (fr) should never be None. It should only return None if the declaration " + "is not callable, which is checked earlier in the code." ) function_response_parts.append(fr) @@ -591,8 +594,7 @@ async def send_message_async( """The async version of `ChatSession.send_message`.""" if self.enable_automatic_function_calling and stream: raise NotImplementedError( - "The `google.generativeai` SDK does not yet support `stream=True` with " - "`enable_automatic_function_calling=True`" + "Unsupported configuration: The `google.generativeai` SDK currently does not support the combination of `stream=True` and `enable_automatic_function_calling=True`." ) tools_lib = self.model._get_tools_lib(tools) @@ -607,7 +609,9 @@ async def send_message_async( generation_config = generation_types.to_generation_config_dict(generation_config) if generation_config.get("candidate_count", 1) > 1: - raise ValueError("Can't chat with `candidate_count > 1`") + raise ValueError( + "Invalid configuration: The chat functionality does not support `candidate_count` greater than 1." + ) response = await self.model.generate_content_async( contents=history, @@ -648,8 +652,8 @@ async def _handle_afc_async( for fc in function_calls: fr = tools_lib(fc) assert fr is not None, ( - "This should never happen, it should only return None if the declaration" - "is not callable, and that's guarded against above." + "Unexpected state: The function reference (fr) should never be None. It should only return None if the declaration " + "is not callable, which is checked earlier in the code." ) function_response_parts.append(fr) @@ -709,13 +713,11 @@ def history(self) -> list[glm.Content]: if last._error is not None: raise generation_types.BrokenResponseError( - "Can not build a coherent chat history after a broken " - "streaming response " - "(See the previous Exception fro details). " - "To inspect the last response object, use `chat.last`." - "To remove the last request/response `Content` objects from the chat " - "call `last_send, last_received = chat.rewind()` and continue " - "without it." + "Unable to build a coherent chat history due to a broken streaming response. " + "Refer to the previous exception for details. " + "To inspect the last response object, use `chat.last`. " + "To remove the last request/response `Content` objects from the chat, " + "call `last_send, last_received = chat.rewind()` and continue without it." ) from last._error sent = self._last_sent diff --git a/google/generativeai/models.py b/google/generativeai/models.py index f25be57c6..1f9e836e7 100644 --- a/google/generativeai/models.py +++ b/google/generativeai/models.py @@ -34,7 +34,7 @@ def get_model( client=None, request_options: helper_types.RequestOptionsType | None = None, ) -> model_types.Model | model_types.TunedModel: - """Given a model name, fetch the `types.Model` + """Calls the API to fetch a model by name. ``` import pprint @@ -57,7 +57,7 @@ def get_model( return get_tuned_model(name, client=client, request_options=request_options) else: raise ValueError( - f"Model names must start with `models/` or `tunedModels/`. Received: {name}" + f"Invalid model name: Model names must start with `models/` or `tunedModels/`. Received: {name}" ) @@ -67,7 +67,7 @@ def get_base_model( client=None, request_options: helper_types.RequestOptionsType | None = None, ) -> model_types.Model: - """Get the `types.Model` for the given base model name. + """Calls the API to fetch a base model by name. ``` import pprint @@ -91,7 +91,9 @@ def get_base_model( name = model_types.make_model_name(name) if not name.startswith("models/"): - raise ValueError(f"Base model names must start with `models/`, received: {name}") + raise ValueError( + f"Invalid model name: Base model names must start with `models/`. Received: {name}" + ) result = client.get_model(name=name, **request_options) result = type(result).to_dict(result) @@ -104,7 +106,7 @@ def get_tuned_model( client=None, request_options: helper_types.RequestOptionsType | None = None, ) -> model_types.TunedModel: - """Get the `types.TunedModel` for the given tuned model name. + """Calls the API to fetch a tuned model by name. ``` import pprint @@ -129,7 +131,9 @@ def get_tuned_model( name = model_types.make_model_name(name) if not name.startswith("tunedModels/"): - raise ValueError("Tuned model names must start with `tunedModels/` received: {name}") + raise ValueError( + f"Invalid model name: Tuned model names must start with `tunedModels/`. Received: {name}" + ) result = client.get_tuned_model(name=name, **request_options) @@ -139,6 +143,8 @@ def get_tuned_model( def get_base_model_name( model: model_types.AnyModelNameOptions, client: glm.ModelServiceClient | None = None ): + """Calls the API to fetch the base model name of a model.""" + if isinstance(model, str): if model.startswith("tunedModels/"): model = get_model(model, client=client) @@ -156,7 +162,10 @@ def get_base_model_name( if not base_model: base_model = model.tuned_model_source.base_model else: - raise TypeError(f"Cannot understand model: {model}") + raise TypeError( + f"Invalid model: The provided model '{model}' is not recognized or supported. " + "Supported types are: str, model_types.TunedModel, model_types.Model, glm.Model, and glm.TunedModel." + ) return base_model @@ -167,7 +176,7 @@ def list_models( client: glm.ModelServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> model_types.ModelsIterable: - """Lists available models. + """Calls the API to list all available models. ``` import pprint @@ -201,7 +210,7 @@ def list_tuned_models( client: glm.ModelServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> model_types.TunedModelsIterable: - """Lists available models. + """Calls the API to list all tuned models. ``` import pprint @@ -249,7 +258,7 @@ def create_tuned_model( client: glm.ModelServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> operations.CreateTunedModelOperation: - """Launches a tuning job to create a TunedModel. + """Calls the API to initiate a tuning process that optimizes a model for specific data, returning an operation object to track and manage the tuning progress. Since tuning a model can take significant time, this API doesn't wait for the tuning to complete. Instead, it returns a `google.api_core.operation.Operation` object that lets you check on the @@ -322,7 +331,9 @@ def create_tuned_model( } } else: - ValueError(f"Not understood: `{source_model=}`") + raise ValueError( + f"Invalid model name: The provided model '{source_model}' does not match any known model patterns such as 'models/' or 'tunedModels/'" + ) training_data = model_types.encode_tuning_data( training_data, input_key=input_key, output_key=output_key @@ -384,7 +395,8 @@ def update_tuned_model( client: glm.ModelServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> model_types.TunedModel: - """Push updates to the tuned model. Only certain attributes are updatable.""" + """Calls the API to puch updates to a specified tuned model where only certain attributes are updatable.""" + if request_options is None: request_options = {} @@ -395,9 +407,7 @@ def update_tuned_model( name = tuned_model if not isinstance(updates, dict): raise TypeError( - "When calling `update_tuned_model(name:str, updates: dict)`,\n" - "`updates` must be a `dict`.\n" - f"got: {type(updates)}" + f"Invalid argument type: In the function `update_tuned_model(name:str, updates: dict)`, the `updates` argument must be of type `dict`. Received type: {type(updates).__name__}." ) tuned_model = client.get_tuned_model(name=name, **request_options) @@ -411,8 +421,7 @@ def update_tuned_model( elif isinstance(tuned_model, glm.TunedModel): if updates is not None: raise ValueError( - "When calling `update_tuned_model(tuned_model:glm.TunedModel, updates=None)`," - "`updates` must not be set." + "Invalid argument: When calling `update_tuned_model(tuned_model:glm.TunedModel, updates=None)`, the `updates` argument must not be set." ) name = tuned_model.name @@ -420,8 +429,7 @@ def update_tuned_model( field_mask = protobuf_helpers.field_mask(was._pb, tuned_model._pb) else: raise TypeError( - "For `update_tuned_model(tuned_model:dict|glm.TunedModel)`," - f"`tuned_model` must be a `dict` or a `glm.TunedModel`. Got a: `{type(tuned_model)}`" + f"Invalid argument type: In the function `update_tuned_model(tuned_model:dict|glm.TunedModel)`, the `tuned_model` argument must be of type `dict` or `glm.TunedModel`. Received type: {type(tuned_model).__name__}." ) result = client.update_tuned_model( @@ -443,6 +451,8 @@ def delete_tuned_model( client: glm.ModelServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> None: + """Calls the API to delete a specified tuned model""" + if request_options is None: request_options = {} diff --git a/google/generativeai/operations.py b/google/generativeai/operations.py index d492a9dee..01c0a6b14 100644 --- a/google/generativeai/operations.py +++ b/google/generativeai/operations.py @@ -27,6 +27,8 @@ def list_operations(*, client=None) -> Iterator[CreateTunedModelOperation]: + """Calls the API to list all operations""" + if client is None: client = client_lib.get_default_operations_client() @@ -41,6 +43,7 @@ def list_operations(*, client=None) -> Iterator[CreateTunedModelOperation]: def get_operation(name: str, *, client=None) -> CreateTunedModelOperation: + """Calls the API to get a specific operation""" if client is None: client = client_lib.get_default_operations_client() @@ -49,8 +52,9 @@ def get_operation(name: str, *, client=None) -> CreateTunedModelOperation: def delete_operation(name: str, *, client=None): - """Raises: - google.api_core.exceptions.MethodNotImplemented: Not implemented.""" + """Calls the API to delete a specific operation""" + + # Raises:google.api_core.exceptions.MethodNotImplemented: Not implemented. if client is None: client = client_lib.get_default_operations_client() diff --git a/google/generativeai/permission.py b/google/generativeai/permission.py index b502f9a60..b2b7c15e1 100644 --- a/google/generativeai/permission.py +++ b/google/generativeai/permission.py @@ -90,9 +90,8 @@ def _construct_name( # if name is not provided, then try to construct name via provided resource_name and permission_id. if not (resource_name and permission_id): raise ValueError( - "Either `name` or (`resource_name` and `permission_id`) must be provided." + f"Invalid arguments: Either `name` or both `resource_name` and `permission_id` must be provided. Received name: {name}, resource_name: {resource_name}, permission_id: {permission_id}." ) - if resource_type: resource_type = _to_resource_type(resource_type) else: @@ -100,8 +99,7 @@ def _construct_name( resource_path_components = resource_name.split("/") if len(resource_path_components) != 2: raise ValueError( - f"Invalid `resource_name` format. Expected format: \ - `resource_type/resource_name`. Got: `{resource_name}` instead." + f"Invalid `resource_name` format: Expected format is `resource_type/resource_name` (2 components). Received: `{resource_name}` with {len(resource_path_components)} components." ) resource_type = _to_resource_type(resource_path_components[0]) @@ -128,7 +126,7 @@ def get_permission( permission_id: str | int | None = None, resource_type: str | None = None, ) -> permission_types.Permission: - """Get information about a permission by name. + """Calls the API to retrieve detailed information about a specific permission based on resource type and permission identifiers Args: name: The name of the permission. diff --git a/google/generativeai/responder.py b/google/generativeai/responder.py index 238e7e13a..814bf3581 100644 --- a/google/generativeai/responder.py +++ b/google/generativeai/responder.py @@ -284,7 +284,7 @@ def _make_function_declaration( return CallableFunctionDeclaration.from_function(fun) else: raise TypeError( - "Expected an instance of `genai.FunctionDeclaraionType`. Got a:\n" f" {type(fun)=}\n", + f"Invalid argument type: Expected an instance of `genai.FunctionDeclarationType`. Received type: {type(fun).__name__}.", fun, ) @@ -363,7 +363,7 @@ def _make_tool(tool: ToolType) -> Tool: return Tool(function_declarations=[tool]) except Exception as e: raise TypeError( - "Expected an instance of `genai.ToolType`. Got a:\n" f" {type(tool)=}", + f"Invalid argument type: Expected an instance of `genai.ToolType`. Received type: {type(tool).__name__}.", tool, ) from e @@ -380,8 +380,7 @@ def __init__(self, tools: Iterable[ToolType]): name = declaration.name if name in self._index: raise ValueError( - f"A `FunctionDeclaration` named {name} is already defined. " - "Each `FunctionDeclaration` must be uniquely named." + f"Invalid operation: A `FunctionDeclaration` named '{name}' is already defined. Each `FunctionDeclaration` must have a unique name." ) self._index[declaration.name] = declaration @@ -483,7 +482,7 @@ def to_function_calling_config(obj: FunctionCallingConfigType) -> glm.FunctionCa obj["mode"] = to_function_calling_mode(mode) else: raise TypeError( - f"Could not convert input to `glm.FunctionCallingConfig`: \n'" f" type: {type(obj)}\n", + f"Invalid argument type: Could not convert input to `glm.FunctionCallingConfig`. Received type: {type(obj).__name__}.", obj, ) @@ -507,5 +506,6 @@ def to_tool_config(obj: ToolConfigType) -> glm.ToolConfig: return glm.ToolConfig(**obj) else: raise TypeError( - f"Could not convert input to `glm.ToolConfig`: \n'" f" type: {type(obj)}\n", obj + f"Invalid argument type: Could not convert input to `glm.ToolConfig`. Received type: {type(obj).__name__}.", + obj, ) diff --git a/google/generativeai/retriever.py b/google/generativeai/retriever.py index 190a222a6..e295bc5b7 100644 --- a/google/generativeai/retriever.py +++ b/google/generativeai/retriever.py @@ -34,10 +34,7 @@ def create_corpus( client: glm.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> retriever_types.Corpus: - """ - Create a new `Corpus` in the retriever service, and return it as a `retriever_types.Corpus` instance. - - Users can specify either a name or display_name. + """Calls the API to create a new `Corpus` by specifying either a corpus resource name as an ID or a display name, and returns the created `Corpus`. Args: name: The corpus resource name (ID). The name must be alphanumeric and fewer @@ -109,8 +106,7 @@ def get_corpus( client: glm.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> retriever_types.Corpus: # fmt: skip - """ - Fetch a specific `Corpus` from the retriever service. + """Calls the API to fetch a `Corpus` by name and returns the `Corpus`. Args: name: The `Corpus` name. @@ -143,6 +139,7 @@ async def get_corpus_async( request_options: helper_types.RequestOptionsType | None = None, ) -> retriever_types.Corpus: # fmt: skip """This is the async version of `retriever.get_corpus`.""" + if request_options is None: request_options = {} @@ -167,13 +164,13 @@ def delete_corpus( client: glm.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ): # fmt: skip - """ - Delete a `Corpus` from the service. + """Calls the API to remove a `Corpus` from the service, optionally deleting associated `Document`s and objects if the `force` parameter is set to true. Args: name: The `Corpus` name. force: If set to true, any `Document`s and objects related to this `Corpus` will also be deleted. request_options: Options for the request. + """ if request_options is None: request_options = {} @@ -214,8 +211,7 @@ def list_corpora( client: glm.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> Iterable[retriever_types.Corpus]: - """ - List the Corpuses you own in the service. + """Calls the API to list all `Corpora` in the service and returns a list of paginated `Corpora`. Args: page_size: Maximum number of `Corpora` to request. diff --git a/google/generativeai/text.py b/google/generativeai/text.py index bb5ec4bdd..b8b814754 100644 --- a/google/generativeai/text.py +++ b/google/generativeai/text.py @@ -70,7 +70,9 @@ def _make_text_prompt(prompt: str | dict[str, str]) -> glm.TextPrompt: elif isinstance(prompt, dict): return glm.TextPrompt(prompt) else: - TypeError("Expected string or dictionary for text prompt.") + raise TypeError( + "Invalid argument type: Expected a string or dictionary for the text prompt." + ) def _make_generate_text_request( @@ -142,7 +144,7 @@ def generate_text( client: glm.TextServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.Completion: - """Calls the API and returns a `types.Completion` containing the response. + """Calls the API to generate text based on the provided prompt. Args: model: Which model to call, as a string or a `types.Model`. @@ -254,6 +256,8 @@ def count_text_tokens( client: glm.TextServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.TokenCount: + """Calls the API to count the number of tokens in the text prompt.""" + base_model = models.get_base_model_name(model) if request_options is None: diff --git a/google/generativeai/types/content_types.py b/google/generativeai/types/content_types.py index ce72dddbc..169683608 100644 --- a/google/generativeai/types/content_types.py +++ b/google/generativeai/types/content_types.py @@ -93,10 +93,9 @@ def image_to_blob(image) -> glm.Blob: name = image.filename if name is None: raise ValueError( - "Can only convert `IPython.display.Image` if " - "it is constructed from a local file (Image(filename=...))." + "Conversion failed. The `IPython.display.Image` can only be converted if " + "it is constructed from a local file. Please ensure you are using the format: Image(filename='...')." ) - mime_type, _ = mimetypes.guess_type(name) if mime_type is None: mime_type = "image/unknown" @@ -104,10 +103,10 @@ def image_to_blob(image) -> glm.Blob: return glm.Blob(mime_type=mime_type, data=image.data) raise TypeError( - "Could not convert image. expected an `Image` type" - "(`PIL.Image.Image` or `IPython.display.Image`).\n" - f"Got a: {type(image)}\n" - f"Value: {image}" + "Image conversion failed. The input was expected to be of type `Image` " + "(either `PIL.Image.Image` or `IPython.display.Image`).\n" + f"However, received an object of type: {type(image)}.\n" + f"Object Value: {image}" ) @@ -135,11 +134,11 @@ def _convert_dict(d: Mapping) -> glm.Content | glm.Part | glm.Blob: return glm.Blob(blob) else: raise KeyError( - "Could not recognize the intended type of the `dict`. " - "A `Content` should have a 'parts' key. " - "A `Part` should have a 'inline_data' or a 'text' key. " - "A `Blob` should have 'mime_type' and 'data' keys. " - f"Got keys: {list(d.keys())}" + "Unable to determine the intended type of the `dict`. " + "For `Content`, a 'parts' key is expected. " + "For `Part`, either an 'inline_data' or a 'text' key is expected. " + "For `Blob`, both 'mime_type' and 'data' keys are expected. " + f"However, the provided dictionary has the following keys: {list(d.keys())}" ) @@ -244,7 +243,9 @@ def is_content_dict(d): def to_content(content: ContentType): if not content: - raise ValueError("content must not be empty") + raise ValueError( + "Invalid input: 'content' argument must not be empty. Please provide a non-empty value." + ) if isinstance(content, Mapping): content = _convert_dict(content) @@ -266,9 +267,9 @@ def strict_to_content(content: StrictContentType): return content else: raise TypeError( - "Expected a `glm.Content` or a `dict(parts=...)`.\n" - f"Got type: {type(content)}\n" - f"Value: {content}\n" + "Invalid input type. Expected a `glm.Content` or a `dict` with a 'parts' key.\n" + f"However, received an object of type: {type(content)}.\n" + f"Object Value: {content}" ) @@ -455,14 +456,20 @@ def convert_to_nullable(schema): anyof = schema.pop("anyOf", None) if anyof is not None: if len(anyof) != 2: - raise ValueError("Type Unions are not supported (except for Optional)") + raise ValueError( + "Invalid input: Type Unions are not supported, except for `Optional` types. " + "Please provide an `Optional` type or a non-Union type." + ) a, b = anyof if a == {"type": "null"}: schema.update(b) elif b == {"type": "null"}: schema.update(a) else: - raise ValueError("Type Unions are not supported (except for Optional)") + raise ValueError( + "Invalid input: Type Unions are not supported, except for `Optional` types. " + "Please provide an `Optional` type or a non-Union type." + ) schema["nullable"] = True properties = schema.get("properties", None) @@ -600,8 +607,9 @@ def _make_function_declaration( return CallableFunctionDeclaration.from_function(fun) else: raise TypeError( - "Expected an instance of `genai.FunctionDeclaraionType`. Got a:\n" f" {type(fun)=}\n", - fun, + "Invalid input type. Expected an instance of `genai.FunctionDeclarationType`.\n" + f"However, received an object of type: {type(fun)}.\n" + f"Object Value: {fun}" ) @@ -679,8 +687,9 @@ def _make_tool(tool: ToolType) -> Tool: return Tool(function_declarations=[tool]) except Exception as e: raise TypeError( - "Expected an instance of `genai.ToolType`. Got a:\n" f" {type(tool)=}", - tool, + "Invalid input type. Expected an instance of `genai.ToolType`.\n" + f"However, received an object of type: {type(tool)}.\n" + f"Object Value: {tool}" ) from e @@ -696,8 +705,8 @@ def __init__(self, tools: Iterable[ToolType]): name = declaration.name if name in self._index: raise ValueError( - f"A `FunctionDeclaration` named {name} is already defined. " - "Each `FunctionDeclaration` must be uniquely named." + f"Invalid operation: A `FunctionDeclaration` named '{name}' is already defined. " + "Each `FunctionDeclaration` must have a unique name. Please use a different name." ) self._index[declaration.name] = declaration @@ -799,8 +808,9 @@ def to_function_calling_config(obj: FunctionCallingConfigType) -> glm.FunctionCa obj["mode"] = to_function_calling_mode(mode) else: raise TypeError( - f"Could not convert input to `glm.FunctionCallingConfig`: \n'" f" type: {type(obj)}\n", - obj, + "Invalid input type. Failed to convert input to `glm.FunctionCallingConfig`.\n" + f"Received an object of type: {type(obj)}.\n" + f"Object Value: {obj}" ) return glm.FunctionCallingConfig(obj) @@ -823,5 +833,7 @@ def to_tool_config(obj: ToolConfigType) -> glm.ToolConfig: return glm.ToolConfig(**obj) else: raise TypeError( - f"Could not convert input to `glm.ToolConfig`: \n'" f" type: {type(obj)}\n", obj + "Invalid input type. Failed to convert input to `glm.ToolConfig`.\n" + f"Received an object of type: {type(obj)}.\n" + f"Object Value: {obj}" ) diff --git a/google/generativeai/types/file_types.py b/google/generativeai/types/file_types.py index 46b0f37b9..eb4c0902d 100644 --- a/google/generativeai/types/file_types.py +++ b/google/generativeai/types/file_types.py @@ -104,4 +104,8 @@ def to_file_data(file_data: FileDataType): if isinstance(file_data, glm.FileData): return file_data else: - raise TypeError(f"Could not convert a {type(file_data)} to `FileData`") + raise TypeError( + f"Invalid input type. Failed to convert input to `FileData`.\n" + f"Received an object of type: {type(file_data)}.\n" + f"Object Value: {file_data}" + ) diff --git a/google/generativeai/types/generation_types.py b/google/generativeai/types/generation_types.py index f0c9de4c7..119f97162 100644 --- a/google/generativeai/types/generation_types.py +++ b/google/generativeai/types/generation_types.py @@ -185,8 +185,8 @@ def _normalize_schema(generation_config): elif isinstance(response_schema, types.GenericAlias): if not str(response_schema).startswith("list["): raise ValueError( - f"Could not understand {response_schema}, expected: `int`, `float`, `str`, `bool`, " - "`typing_extensions.TypedDict`, `dataclass`, or `list[...]`" + f"Invalid input: Could not understand the type of '{response_schema}'. " + "Expected one of the following types: `int`, `float`, `str`, `bool`, `typing_extensions.TypedDict`, `dataclass`, or `list[...]`." ) response_schema = content_types._schema_for_class(response_schema) @@ -214,9 +214,9 @@ def to_generation_config_dict(generation_config: GenerationConfigType): return generation_config else: raise TypeError( - "Did not understand `generation_config`, expected a `dict` or" - f" `GenerationConfig`\nGot type: {type(generation_config)}\nValue:" - f" {generation_config}" + "Invalid input type. Expected a `dict` or `GenerationConfig` for `generation_config`.\n" + f"However, received an object of type: {type(generation_config)}.\n" + f"Object Value: {generation_config}" ) @@ -389,14 +389,13 @@ def parts(self): candidates = self.candidates if not candidates: raise ValueError( - "The `response.parts` quick accessor only works for a single candidate, " - "but none were returned. Check the `response.prompt_feedback` to see if the prompt was blocked." + "Invalid operation: The `response.parts` quick accessor requires a single candidate, " + "but none were returned. Please check the `response.prompt_feedback` to determine if the prompt was blocked." ) if len(candidates) > 1: raise ValueError( - "The `response.parts` quick accessor only works with a " - "single candidate. With multiple candidates use " - "result.candidates[index].text" + "Invalid operation: The `response.parts` quick accessor requires a single candidate. " + "For multiple candidates, please use `result.candidates[index].text`." ) parts = candidates[0].content.parts return parts @@ -411,18 +410,14 @@ def text(self): parts = self.parts if not parts: raise ValueError( - "The `response.text` quick accessor only works when the response contains a valid " - "`Part`, but none was returned. Check the `candidate.safety_ratings` to see if the " - "response was blocked." + "Invalid operation: The `response.text` quick accessor requires the response to contain a valid `Part`, " + "but none were returned. Please check the `candidate.safety_ratings` to determine if the response was blocked." ) - if len(parts) != 1 or "text" not in parts[0]: raise ValueError( - "The `response.text` quick accessor only works for " - "simple (single-`Part`) text responses. This response is not simple text. " - "Use the `result.parts` accessor or the full " - "`result.candidates[index].content.parts` lookup " - "instead." + "Invalid operation: The `response.text` quick accessor requires a simple (single-`Part`) text response. " + "This response is not simple text. Please use the `result.parts` accessor or the full " + "`result.candidates[index].content.parts` lookup instead." ) return parts[0].text diff --git a/google/generativeai/types/helper_types.py b/google/generativeai/types/helper_types.py index 3eba4d3f9..fd8c1882b 100644 --- a/google/generativeai/types/helper_types.py +++ b/google/generativeai/types/helper_types.py @@ -71,7 +71,10 @@ def __getitem__(self, item): elif item == "timeout": return self.timeout else: - raise KeyError(f'RequestOptions does not have a "{item}" key') + raise KeyError( + f"Invalid key: 'RequestOptions' does not contain a key named '{item}'. " + "Please use a valid key." + ) def __iter__(self): yield "retry" diff --git a/google/generativeai/types/model_types.py b/google/generativeai/types/model_types.py index 0f85acfe8..32b3bddae 100644 --- a/google/generativeai/types/model_types.py +++ b/google/generativeai/types/model_types.py @@ -287,12 +287,18 @@ def _convert_dict(data, input_key, output_key): try: inputs = data[input_key] except KeyError: - raise KeyError(f'input_key is "{input_key}", but data has keys: {sorted(data.keys())}') + raise KeyError( + f"Invalid key: The input key '{input_key}' does not exist in the data. " + f"Available keys are: {sorted(data.keys())}." + ) try: outputs = data[output_key] except KeyError: - raise KeyError(f'output_key is "{output_key}", but data has keys: {sorted(data.keys())}') + raise KeyError( + f"Invalid key: The output key '{output_key}' does not exist in the data. " + f"Available keys are: {sorted(data.keys())}." + ) for i, o in zip(inputs, outputs): new_data.append(glm.TuningExample({"text_input": str(i), "output": str(o)})) @@ -347,10 +353,14 @@ def make_model_name(name: AnyModelNameOptions): elif isinstance(name, str): name = name else: - raise TypeError("Expected: str, Model, or TunedModel") + raise TypeError( + "Invalid input type. Expected one of the following types: `str`, `Model`, or `TunedModel`." + ) if not (name.startswith("models/") or name.startswith("tunedModels/")): - raise ValueError(f"Model names should start with `models/` or `tunedModels/`, got: {name}") + raise ValueError( + f"Invalid model name: '{name}'. Model names should start with 'models/' or 'tunedModels/'." + ) return name diff --git a/google/generativeai/types/permission_types.py b/google/generativeai/types/permission_types.py index db1867695..fde2ddacc 100644 --- a/google/generativeai/types/permission_types.py +++ b/google/generativeai/types/permission_types.py @@ -152,7 +152,7 @@ def update( for update_path in updates: if update_path != "role": raise ValueError( - f"As of now, only `role` can be updated for `Permission`. Got: `{update_path}` instead." + f"Invalid update path: '{update_path}'. Currently, only the 'role' attribute can be updated for 'Permission'." ) field_mask = field_mask_pb2.FieldMask() @@ -182,7 +182,7 @@ async def update_async( for update_path in updates: if update_path != "role": raise ValueError( - f"As of now, only `role` can be updated for `Permission`. Got: `{update_path}` instead." + f"Invalid update path: '{update_path}'. Currently, only the 'role' attribute can be updated for 'Permission'." ) field_mask = field_mask_pb2.FieldMask() @@ -271,12 +271,11 @@ def _make_create_permission_request( if email_address and grantee_type == GranteeType.EVERYONE: raise ValueError( - f"Cannot limit access for: `{email_address}` when `grantee_type` is set to `EVERYONE`." + f"Invalid operation: Access cannot be limited for a specific email address ('{email_address}') when 'grantee_type' is set to 'EVERYONE'." ) - if not email_address and grantee_type != GranteeType.EVERYONE: raise ValueError( - f"`email_address` must be specified unless `grantee_type` is set to `EVERYONE`." + f"Invalid operation: An 'email_address' must be provided when 'grantee_type' is not set to 'EVERYONE'. Currently, 'grantee_type' is set to '{grantee_type}' and 'email_address' is '{email_address if email_address else 'not provided'}'." ) permission = glm.Permission( diff --git a/google/generativeai/types/retriever_types.py b/google/generativeai/types/retriever_types.py index 538d3924a..294e0b64c 100644 --- a/google/generativeai/types/retriever_types.py +++ b/google/generativeai/types/retriever_types.py @@ -158,8 +158,8 @@ def _to_proto(self): elif isinstance(c.value, (int, float)): kwargs["numeric_value"] = float(c.value) else: - ValueError( - f"The value for the condition must be either a string or an integer/float, but got {c.value}." + raise ValueError( + f"Invalid value type: The value for the condition must be either a string or an integer/float. Received: '{c.value}' of type {type(c.value).__name__}." ) kwargs["operation"] = c.operation @@ -195,10 +195,9 @@ def _to_proto(self): elif isinstance(self.value, (int, float)): kwargs["numeric_value"] = float(self.value) else: - ValueError( - f"The value for a custom_metadata specification must be either a list of string values, a string, or an integer/float, but got {self.value}." + raise ValueError( + f"Invalid value type: The value for a custom_metadata specification must be either a list of string values, a string, or an integer/float. Received: '{self.value}' of type {type(self.value).__name__}." ) - return glm.CustomMetadata(key=self.key, **kwargs) @classmethod @@ -231,7 +230,7 @@ def make_custom_metadata(cm: CustomMetadataOptions) -> CustomMetadata: return CustomMetadata._from_dict(cm) else: raise ValueError( # nofmt - "Could not create a `CustomMetadata` from:\n" f" type: {type(cm)}\n" f" value: {cm}" + f"Invalid input: Could not create a 'CustomMetadata' from the provided input. Received type: '{type(cm).__name__}', value: '{cm}'." ) @@ -425,7 +424,9 @@ def update( # At this time, only `display_name` can be updated for item in updates: if item != "display_name": - raise ValueError("At this time, only `display_name` can be updated for `Corpus`.") + raise ValueError( + "Invalid operation: Currently, only the 'display_name' attribute can be updated for a 'Corpus'." + ) field_mask = field_mask_pb2.FieldMask() for path in updates.keys(): @@ -454,7 +455,9 @@ async def update_async( # At this time, only `display_name` can be updated for item in updates: if item != "display_name": - raise ValueError("At this time, only `display_name` can be updated for `Corpus`.") + raise ValueError( + "Invalid operation: Currently, only the 'display_name' attribute can be updated for a 'Corpus'." + ) field_mask = field_mask_pb2.FieldMask() for path in updates.keys(): @@ -494,7 +497,9 @@ def query( if results_count: if results_count > 100: - raise ValueError("Number of results returned must be between 1 and 100.") + raise ValueError( + "Invalid operation: The number of results returned must be between 1 and 100." + ) m_f_ = [] if metadata_filters: @@ -537,7 +542,9 @@ async def query_async( if results_count: if results_count > 100: - raise ValueError("Number of results returned must be between 1 and 100.") + raise ValueError( + "Invalid operation: The number of results returned must be between 1 and 100." + ) m_f_ = [] if metadata_filters: @@ -869,7 +876,7 @@ def _make_chunk(self, chunk: ChunkOptions) -> glm.Chunk: return glm.Chunk(chunk) else: raise TypeError( - f"Could not convert instance of `{type(chunk)}` chunk:" f"value: {chunk}" + f"Invalid input: Could not convert instance of type '{type(chunk).__name__}' to a chunk. Received value: '{chunk}'." ) def _make_batch_create_chunk_request( @@ -1060,7 +1067,9 @@ def query( if results_count: if results_count < 0 or results_count >= 100: - raise ValueError("Number of results returned must be between 1 and 100.") + raise ValueError( + "Invalid operation: The number of results returned must be between 1 and 100." + ) m_f_ = [] if metadata_filters: @@ -1103,7 +1112,9 @@ async def query_async( if results_count: if results_count < 0 or results_count >= 100: - raise ValueError("Number of results returned must be between 1 and 100.") + raise ValueError( + "Invalid operation: The number of results returned must be between 1 and 100." + ) m_f_ = [] if metadata_filters: @@ -1161,7 +1172,9 @@ def update( # At this time, only `display_name` can be updated for item in updates: if item != "display_name": - raise ValueError("At this time, only `display_name` can be updated for `Document`.") + raise ValueError( + "Invalid operation: Currently, only the 'display_name' attribute can be updated for a 'Document'." + ) field_mask = field_mask_pb2.FieldMask() for path in updates.keys(): field_mask.paths.append(path) @@ -1189,7 +1202,9 @@ async def update_async( # At this time, only `display_name` can be updated for item in updates: if item != "display_name": - raise ValueError("At this time, only `display_name` can be updated for `Document`.") + raise ValueError( + "Invalid operation: Currently, only the 'display_name' attribute can be updated for a 'Document'." + ) field_mask = field_mask_pb2.FieldMask() for path in updates.keys(): field_mask.paths.append(path) @@ -1247,7 +1262,7 @@ def batch_update_chunks( for item in updates: if item != "data.string_value": raise ValueError( - f"At this time, only `data` can be updated for `Chunk`. Got {item}." + f"Invalid operation: Currently, only the 'data' attribute can be updated for a 'Chunk'. Attempted to update '{item}'." ) field_mask = field_mask_pb2.FieldMask() for path in updates.keys(): @@ -1289,8 +1304,7 @@ def batch_update_chunks( ) else: raise TypeError( - "The `chunks` parameter must be a list of glm.UpdateChunkRequests," - "dictionaries, or tuples of dictionaries." + "Invalid input: The 'chunks' parameter must be a list of 'glm.UpdateChunkRequests', dictionaries, or tuples of dictionaries." ) request = glm.BatchUpdateChunksRequest(parent=self.name, requests=_requests) response = client.batch_update_chunks(request, **request_options) @@ -1335,7 +1349,7 @@ async def batch_update_chunks_async( for item in updates: if item != "data.string_value": raise ValueError( - f"At this time, only `data` can be updated for `Chunk`. Got {item}." + f"Invalid operation: Currently, only the 'data' attribute can be updated for a 'Chunk'. Attempted to update '{item}'." ) field_mask = field_mask_pb2.FieldMask() for path in updates.keys(): @@ -1377,8 +1391,7 @@ async def batch_update_chunks_async( ) else: raise TypeError( - "The `chunks` parameter must be a list of glm.UpdateChunkRequests," - "dictionaries, or tuples of dictionaries." + "Invalid input: The 'chunks' parameter must be a list of 'glm.UpdateChunkRequests', dictionaries, or tuples of dictionaries." ) request = glm.BatchUpdateChunksRequest(parent=self.name, requests=_requests) response = await client.batch_update_chunks(request, **request_options) @@ -1459,7 +1472,7 @@ def batch_delete_chunks( client.batch_delete_chunks(request, **request_options) else: raise ValueError( - "To delete chunks, you must pass in either the names of the chunks as an iterable, or multiple `glm.DeleteChunkRequest`s." + "Invalid operation: To delete chunks, you must pass in either the names of the chunks as an iterable, or multiple 'glm.DeleteChunkRequest's." ) async def batch_delete_chunks_async( @@ -1486,7 +1499,7 @@ async def batch_delete_chunks_async( await client.batch_delete_chunks(request, **request_options) else: raise ValueError( - "To delete chunks, you must pass in either the names of the chunks as an iterable, or multiple `glm.DeleteChunkRequest`s." + "Invalid operation: To delete chunks, you must pass in either the names of the chunks as an iterable, or multiple 'glm.DeleteChunkRequest's." ) def to_dict(self) -> dict[str, Any]: @@ -1604,7 +1617,7 @@ def update( for item in updates: if item != "data.string_value": raise ValueError( - f"At this time, only `data` can be updated for `Chunk`. Got {item}." + f"Invalid operation: Currently, only the 'data' attribute can be updated for a 'Chunk'. Attempted to update '{item}'." ) field_mask = field_mask_pb2.FieldMask() @@ -1644,7 +1657,7 @@ async def update_async( for item in updates: if item != "data.string_value": raise ValueError( - f"At this time, only `data` can be updated for `Chunk`. Got {item}." + f"Invalid operation: Currently, only the 'data' attribute can be updated for a 'Chunk'. Attempted to update '{item}'." ) field_mask = field_mask_pb2.FieldMask() diff --git a/google/generativeai/utils.py b/google/generativeai/utils.py index 6dc2b6a20..cd2c4cbf7 100644 --- a/google/generativeai/utils.py +++ b/google/generativeai/utils.py @@ -16,6 +16,8 @@ def flatten_update_paths(updates): + """Flattens a nested dictionary into a single level dictionary, with keys representing the original path.""" + new_updates = {} for key, value in updates.items(): if isinstance(value, dict): diff --git a/tests/test_client.py b/tests/test_client.py index 34a0f9fc3..0256edac3 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -42,7 +42,7 @@ def test_api_key_from_environment(self): def test_api_key_cannot_be_set_twice(self): client_opts = client_options.ClientOptions(api_key="AIzA_client_opts") - with self.assertRaisesRegex(ValueError, "You can't set both"): + with self.assertRaisesRegex(ValueError, "Invalid configuration: Please set either"): client.configure(api_key="AIzA_client", client_options=client_opts) def test_api_key_and_client_options(self): From 6df10a77034e44b2e2d57ad7528ad4478f84a639 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Wed, 22 May 2024 01:18:43 -0700 Subject: [PATCH 10/17] Make printing less verbose. (#340) * Make printing less verbose. Change-Id: Ie37c1b75e57427f7256eb3c323a60a2947dcf6b1 * Use string values for enums when printing. Change-Id: I9b07bd5998445cc4ac59abcae8429fd7dacc824a * format Change-Id: Ibbdfca4be9370acd246721fd42629a10dc4ca612 --- google/generativeai/generative_models.py | 2 - google/generativeai/types/generation_types.py | 4 +- tests/test_generation.py | 18 +-- tests/test_generative_models.py | 113 ++++-------------- 4 files changed, 30 insertions(+), 107 deletions(-) diff --git a/google/generativeai/generative_models.py b/google/generativeai/generative_models.py index 6fc5554c4..a92801f2f 100644 --- a/google/generativeai/generative_models.py +++ b/google/generativeai/generative_models.py @@ -3,10 +3,8 @@ from __future__ import annotations from collections.abc import Iterable -import dataclasses import textwrap from typing import Any -from typing import Union import reprlib # pylint: disable=bad-continuation, line-too-long diff --git a/google/generativeai/types/generation_types.py b/google/generativeai/types/generation_types.py index 119f97162..8d39f76c7 100644 --- a/google/generativeai/types/generation_types.py +++ b/google/generativeai/types/generation_types.py @@ -435,7 +435,9 @@ def __str__(self) -> str: else: _iterator = f"<{self._iterator.__class__.__name__}>" - as_dict = self.to_dict() + as_dict = type(self._result).to_dict( + self._result, use_integers_for_enums=False, including_default_value_fields=False + ) json_str = json.dumps(as_dict, indent=2) _result = f"glm.GenerateContentResponse({json_str})" diff --git a/tests/test_generation.py b/tests/test_generation.py index 82beac16b..b256a1029 100644 --- a/tests/test_generation.py +++ b/tests/test_generation.py @@ -531,13 +531,8 @@ def test_repr_for_generate_content_response_from_response(self): { "text": "Hello world!" } - ], - "role": "" - }, - "finish_reason": 0, - "safety_ratings": [], - "token_count": 0, - "grounding_attributions": [] + ] + } } ] }), @@ -567,13 +562,8 @@ def test_repr_for_generate_content_response_from_iterator(self): { "text": "a" } - ], - "role": "" - }, - "finish_reason": 0, - "safety_ratings": [], - "token_count": 0, - "grounding_attributions": [] + ] + } } ] }), diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py index 3b0c27814..b7393a388 100644 --- a/tests/test_generative_models.py +++ b/tests/test_generative_models.py @@ -828,13 +828,8 @@ def test_repr_for_unary_non_streamed_response(self): { "text": "world!" } - ], - "role": "" - }, - "finish_reason": 0, - "safety_ratings": [], - "token_count": 0, - "grounding_attributions": [] + ] + } } ] }), @@ -866,13 +861,8 @@ def test_repr_for_streaming_start_to_finish(self): { "text": "first" } - ], - "role": "" - }, - "finish_reason": 0, - "safety_ratings": [], - "token_count": 0, - "grounding_attributions": [] + ] + } } ] }), @@ -896,28 +886,14 @@ def test_repr_for_streaming_start_to_finish(self): { "text": "first second" } - ], - "role": "" + ] }, "index": 0, - "citation_metadata": { - "citation_sources": [] - }, - "finish_reason": 0, - "safety_ratings": [], - "token_count": 0, - "grounding_attributions": [] + "citation_metadata": {} } ], - "prompt_feedback": { - "block_reason": 0, - "safety_ratings": [] - }, - "usage_metadata": { - "prompt_token_count": 0, - "candidates_token_count": 0, - "total_token_count": 0 - } + "prompt_feedback": {}, + "usage_metadata": {} }), )""" ) @@ -939,28 +915,14 @@ def test_repr_for_streaming_start_to_finish(self): { "text": "first second third" } - ], - "role": "" + ] }, "index": 0, - "citation_metadata": { - "citation_sources": [] - }, - "finish_reason": 0, - "safety_ratings": [], - "token_count": 0, - "grounding_attributions": [] + "citation_metadata": {} } ], - "prompt_feedback": { - "block_reason": 0, - "safety_ratings": [] - }, - "usage_metadata": { - "prompt_token_count": 0, - "candidates_token_count": 0, - "total_token_count": 0 - } + "prompt_feedback": {}, + "usage_metadata": {} }), )""" ) @@ -989,10 +951,8 @@ def test_repr_error_info_for_stream_prompt_feedback_blocked(self): iterator=, result=glm.GenerateContentResponse({ "prompt_feedback": { - "block_reason": 1, - "safety_ratings": [] - }, - "candidates": [] + "block_reason": "SAFETY" + } }), ), error= prompt_feedback { @@ -1047,28 +1007,14 @@ def no_throw(): { "text": "123" } - ], - "role": "" + ] }, "index": 0, - "citation_metadata": { - "citation_sources": [] - }, - "finish_reason": 0, - "safety_ratings": [], - "token_count": 0, - "grounding_attributions": [] + "citation_metadata": {} } ], - "prompt_feedback": { - "block_reason": 0, - "safety_ratings": [] - }, - "usage_metadata": { - "prompt_token_count": 0, - "candidates_token_count": 0, - "total_token_count": 0 - } + "prompt_feedback": {}, + "usage_metadata": {} }), ), error= """ @@ -1120,28 +1066,15 @@ def test_repr_error_info_for_chat_streaming_unexpected_stop(self): { "text": "abc" } - ], - "role": "" + ] }, - "finish_reason": 3, + "finish_reason": "SAFETY", "index": 0, - "citation_metadata": { - "citation_sources": [] - }, - "safety_ratings": [], - "token_count": 0, - "grounding_attributions": [] + "citation_metadata": {} } ], - "prompt_feedback": { - "block_reason": 0, - "safety_ratings": [] - }, - "usage_metadata": { - "prompt_token_count": 0, - "candidates_token_count": 0, - "total_token_count": 0 - } + "prompt_feedback": {}, + "usage_metadata": {} }), ), error= index: 0 From 75b97dbbbef61d2c9fa4d242b76f30571922bf16 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Wed, 22 May 2024 01:27:22 -0700 Subject: [PATCH 11/17] Add video metadata and error to File (#348) * Add video metadata and error to File Change-Id: I721147d4e9abf526c7f0a60346761591d63ebb2f * add tests Change-Id: I41a7af34a3068549cee3c45aead9a042415219ee * fix tests Change-Id: I005e30219f49830f73658488e58588d6ed7ccd88 --- google/generativeai/types/file_types.py | 11 +- tests/test_files.py | 140 ++++++++++++++++++++++++ 2 files changed, 150 insertions(+), 1 deletion(-) create mode 100644 tests/test_files.py diff --git a/google/generativeai/types/file_types.py b/google/generativeai/types/file_types.py index eb4c0902d..0fdf05322 100644 --- a/google/generativeai/types/file_types.py +++ b/google/generativeai/types/file_types.py @@ -18,6 +18,7 @@ from typing import Union from typing_extensions import TypedDict +from google.rpc.status_pb2 import Status from google.generativeai.client import get_default_file_client import google.ai.generativelanguage as glm @@ -29,7 +30,7 @@ def __init__(self, proto: glm.File | File | dict): proto = proto.to_proto() self._proto = glm.File(proto) - def to_proto(self): + def to_proto(self) -> glm.File: return self._proto @property @@ -72,6 +73,14 @@ def uri(self) -> str: def state(self) -> glm.File.State: return self._proto.state + @property + def video_metadata(self) -> glm.VideoMetadata: + return self._proto.video_metadata + + @property + def error(self) -> Status: + return self._proto.error + def delete(self): client = get_default_file_client() client.delete_file(name=self.name) diff --git a/tests/test_files.py b/tests/test_files.py new file mode 100644 index 000000000..333ec1e2a --- /dev/null +++ b/tests/test_files.py @@ -0,0 +1,140 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.generativeai.types import file_types + +import collections +import datetime +import os +from typing import Iterable, Union +import pathlib + +import google +import google.ai.generativelanguage as glm + +import google.generativeai as genai +from google.generativeai import client as client_lib +from absl.testing import parameterized + + +class FileServiceClient(client_lib.FileServiceClient): + def __init__(self, test): + self.test = test + self.observed_requests = [] + self.responses = collections.defaultdict(list) + + def create_file( + self, + path: Union[str, pathlib.Path, os.PathLike], + *, + mime_type: Union[str, None] = None, + name: Union[str, None] = None, + display_name: Union[str, None] = None, + resumable: bool = True, + ) -> glm.File: + self.observed_requests.append( + dict( + path=path, + mime_type=mime_type, + name=name, + display_name=display_name, + resumable=resumable, + ) + ) + return self.responses["create_file"].pop(0) + + def get_file( + self, + request: glm.GetFileRequest, + **kwargs, + ) -> glm.File: + self.observed_requests.append(request) + return self.responses["get_file"].pop(0) + + def list_files( + self, + request: glm.ListFilesRequest, + **kwargs, + ) -> Iterable[glm.File]: + self.observed_requests.append(request) + for f in self.responses["list_files"].pop(0): + yield f + + def delete_file( + self, + request: glm.DeleteFileRequest, + **kwargs, + ): + self.observed_requests.append(request) + return + + +class UnitTests(parameterized.TestCase): + def setUp(self): + self.client = FileServiceClient(self) + + client_lib._client_manager.clients["file"] = self.client + + @property + def observed_requests(self): + return self.client.observed_requests + + @property + def responses(self): + return self.client.responses + + def test_video_metadata(self): + self.responses["create_file"].append( + glm.File( + uri="https://test", + state="ACTIVE", + video_metadata=dict(video_duration=datetime.timedelta(seconds=30)), + error=dict(code=7, message="ok?"), + ) + ) + + f = genai.upload_file(path="dummy") + self.assertEqual(google.rpc.status_pb2.Status(code=7, message="ok?"), f.error) + self.assertEqual( + glm.VideoMetadata(dict(video_duration=datetime.timedelta(seconds=30))), f.video_metadata + ) + + @parameterized.named_parameters( + [ + dict( + testcase_name="FileDataDict", + file_data=dict(file_uri="https://test_uri"), + ), + dict( + testcase_name="FileDict", + file_data=dict(uri="https://test_uri"), + ), + dict( + testcase_name="FileData", + file_data=glm.FileData(file_uri="https://test_uri"), + ), + dict( + testcase_name="glm.File", + file_data=glm.File(uri="https://test_uri"), + ), + dict( + testcase_name="file_types.File", + file_data=file_types.File(dict(uri="https://test_uri")), + ), + ] + ) + def test_to_file_data(self, file_data): + file_data = file_types.to_file_data(file_data) + self.assertEqual(glm.FileData(file_uri="https://test_uri"), file_data) From 386994a6ba8b798610cc388d157b17ca3417b70f Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Wed, 22 May 2024 01:27:35 -0700 Subject: [PATCH 12/17] Quick safety filtering: Allow `safety_settings="block_none"` (#347) * allow safety_settings='off' Change-Id: Ica10b399177301073424a98cb3a8b0736dc216b4 * Fix tests. Change-Id: I06cfd07397e984b9fb757b2831b419eefb8aff98 * license Change-Id: Ifa4843831b9c1479198c2b45c5b5abad8410f448 * format Change-Id: I534837c309121cda9c8947acdd6c126c9c730d62 * add test Change-Id: I9bce66322d64b3d6296d4db7cc0a7b7b9a78763b --- google/generativeai/types/safety_types.py | 37 +++++++++++++-- tests/test_generative_models.py | 36 +++++++------- tests/test_safety.py | 57 +++++++++++++++++++++++ 3 files changed, 109 insertions(+), 21 deletions(-) create mode 100644 tests/test_safety.py diff --git a/google/generativeai/types/safety_types.py b/google/generativeai/types/safety_types.py index 85e57c8f6..c8368da7f 100644 --- a/google/generativeai/types/safety_types.py +++ b/google/generativeai/types/safety_types.py @@ -201,18 +201,41 @@ class LooseSafetySettingDict(TypedDict): EasySafetySetting = Mapping[HarmCategoryOptions, HarmBlockThresholdOptions] EasySafetySettingDict = dict[HarmCategoryOptions, HarmBlockThresholdOptions] -SafetySettingOptions = Union[EasySafetySetting, Iterable[LooseSafetySettingDict], None] +SafetySettingOptions = Union[ + HarmBlockThresholdOptions, EasySafetySetting, Iterable[LooseSafetySettingDict], None +] + + +def _expand_block_threshold(block_threshold: HarmBlockThresholdOptions): + block_threshold = to_block_threshold(block_threshold) + set(_HARM_CATEGORIES.values()) + return {category: block_threshold for category in set(_HARM_CATEGORIES.values())} def to_easy_safety_dict(settings: SafetySettingOptions) -> EasySafetySettingDict: if settings is None: return {} - elif isinstance(settings, Mapping): + + if isinstance(settings, (int, str, HarmBlockThreshold)): + settings = _expand_block_threshold(settings) + + if isinstance(settings, Mapping): return {to_harm_category(key): to_block_threshold(value) for key, value in settings.items()} + else: # Iterable - return { - to_harm_category(d["category"]): to_block_threshold(d["threshold"]) for d in settings - } + result = {} + for setting in settings: + if isinstance(setting, glm.SafetySetting): + result[to_harm_category(setting.category)] = to_block_threshold(setting.threshold) + elif isinstance(setting, dict): + result[to_harm_category(setting["category"])] = to_block_threshold( + setting["threshold"] + ) + else: + raise ValueError( + f"Could not understand safety setting:\n {type(setting)=}\n {setting=}" + ) + return result def normalize_safety_settings( @@ -220,6 +243,10 @@ def normalize_safety_settings( ) -> list[SafetySettingDict] | None: if settings is None: return None + + if isinstance(settings, (int, str, HarmBlockThreshold)): + settings = _expand_block_threshold(settings) + if isinstance(settings, Mapping): return [ { diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py index b7393a388..ff0512031 100644 --- a/tests/test_generative_models.py +++ b/tests/test_generative_models.py @@ -155,11 +155,12 @@ def test_generation_config_overwrite(self, config1, config2): @parameterized.named_parameters( ["dict", {"danger": "low"}, {"danger": "high"}], + ["quick", "low", "high"], [ "list-dict", [ dict( - category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS, + category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, ), ], @@ -171,21 +172,21 @@ def test_generation_config_overwrite(self, config1, config2): "object", [ glm.SafetySetting( - category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS, + category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, ), ], [ glm.SafetySetting( - category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS, - threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH, ), ], ], ) def test_safety_overwrite(self, safe1, safe2): # Safety - model = generative_models.GenerativeModel("gemini-pro", safety_settings={"danger": "low"}) + model = generative_models.GenerativeModel("gemini-pro", safety_settings=safe1) self.responses["generate_content"] = [ simple_response(" world!"), @@ -193,22 +194,25 @@ def test_safety_overwrite(self, safe1, safe2): ] _ = model.generate_content("hello") + + danger = [ + s + for s in self.observed_requests[-1].safety_settings + if s.category == glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT + ] self.assertEqual( - self.observed_requests[-1].safety_settings[0].category, - glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - ) - self.assertEqual( - self.observed_requests[-1].safety_settings[0].threshold, + danger[0].threshold, glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, ) - _ = model.generate_content("hello", safety_settings={"danger": "high"}) - self.assertEqual( - self.observed_requests[-1].safety_settings[0].category, - glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - ) + _ = model.generate_content("hello", safety_settings=safe2) + danger = [ + s + for s in self.observed_requests[-1].safety_settings + if s.category == glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT + ] self.assertEqual( - self.observed_requests[-1].safety_settings[0].threshold, + danger[0].threshold, glm.SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH, ) diff --git a/tests/test_safety.py b/tests/test_safety.py new file mode 100644 index 000000000..f3efc4aca --- /dev/null +++ b/tests/test_safety.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest +from absl.testing import parameterized +import google.ai.generativelanguage as glm +from google.generativeai.types import safety_types + + +class SafetyTests(parameterized.TestCase): + """Tests are in order with the design doc.""" + + @parameterized.named_parameters( + ["block_threshold", glm.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE], + ["block_threshold2", "medium"], + ["block_threshold3", 2], + ["dict", {"danger": "medium"}], + ["dict2", {"danger": 2}], + ["dict3", {"danger": glm.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE}], + [ + "list-dict", + [ + dict( + category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + ), + ], + ], + [ + "list-dict2", + [ + dict(category="danger", threshold="med"), + ], + ], + ) + def test_safety_overwrite(self, setting): + setting = safety_types.to_easy_safety_dict(setting) + self.assertEqual( + setting[glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT], + glm.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + ) + + +if __name__ == "__main__": + absltest.main() From 0dca4ce880ac10ad19adde1fd8a56fdacb92618a Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Wed, 22 May 2024 01:29:32 -0700 Subject: [PATCH 13/17] Add request options to chat. (#341) * Add request options to chat Change-Id: I6f7e4c980fd7e2a14fec4c3e2d837ad745c69c9a * fix async Change-Id: Ia224e9e8327443a9920ce5d9a877ebb8c272e583 * fix Change-Id: I7eed70131346c7d7ffe435c8f6909f7eb3f7e9f7 * merge from main Change-Id: I4b92a5bc25aa7bf11bfaf31aa6c029096f3e68bc * add tests Change-Id: I368315f220413ba9508012721e64093372555590 * format Change-Id: I26c7fa1f040e7d1ea16068034d78fb9f6cc13db0 --- google/generativeai/generative_models.py | 34 ++++++- tests/test_generative_models.py | 117 ++++++++++++++--------- 2 files changed, 106 insertions(+), 45 deletions(-) diff --git a/google/generativeai/generative_models.py b/google/generativeai/generative_models.py index a92801f2f..873d2fcb4 100644 --- a/google/generativeai/generative_models.py +++ b/google/generativeai/generative_models.py @@ -444,6 +444,7 @@ def send_message( stream: bool = False, tools: content_types.FunctionLibraryType | None = None, tool_config: content_types.ToolConfigType | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> generation_types.GenerateContentResponse: """Sends the conversation history with the added message and returns the model's response. @@ -476,6 +477,9 @@ def send_message( safety_settings: Overrides for the model's safety settings. stream: If True, yield response chunks as they are generated. """ + if request_options is None: + request_options = {} + if self.enable_automatic_function_calling and stream: raise NotImplementedError( "Unsupported configuration: The `google.generativeai` SDK currently does not support the combination of `stream=True` and `enable_automatic_function_calling=True`." @@ -504,6 +508,7 @@ def send_message( stream=stream, tools=tools_lib, tool_config=tool_config, + request_options=request_options, ) self._check_response(response=response, stream=stream) @@ -516,6 +521,7 @@ def send_message( safety_settings=safety_settings, stream=stream, tools_lib=tools_lib, + request_options=request_options, ) self._last_sent = content @@ -546,7 +552,15 @@ def _get_function_calls(self, response) -> list[glm.FunctionCall]: return function_calls def _handle_afc( - self, *, response, history, generation_config, safety_settings, stream, tools_lib + self, + *, + response, + history, + generation_config, + safety_settings, + stream, + tools_lib, + request_options, ) -> tuple[list[glm.Content], glm.Content, generation_types.BaseGenerateContentResponse]: while function_calls := self._get_function_calls(response): @@ -572,6 +586,7 @@ def _handle_afc( safety_settings=safety_settings, stream=stream, tools=tools_lib, + request_options=request_options, ) self._check_response(response=response, stream=stream) @@ -588,8 +603,12 @@ async def send_message_async( stream: bool = False, tools: content_types.FunctionLibraryType | None = None, tool_config: content_types.ToolConfigType | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> generation_types.AsyncGenerateContentResponse: """The async version of `ChatSession.send_message`.""" + if request_options is None: + request_options = {} + if self.enable_automatic_function_calling and stream: raise NotImplementedError( "Unsupported configuration: The `google.generativeai` SDK currently does not support the combination of `stream=True` and `enable_automatic_function_calling=True`." @@ -618,6 +637,7 @@ async def send_message_async( stream=stream, tools=tools_lib, tool_config=tool_config, + request_options=request_options, ) self._check_response(response=response, stream=stream) @@ -630,6 +650,7 @@ async def send_message_async( safety_settings=safety_settings, stream=stream, tools_lib=tools_lib, + request_options=request_options, ) self._last_sent = content @@ -638,7 +659,15 @@ async def send_message_async( return response async def _handle_afc_async( - self, *, response, history, generation_config, safety_settings, stream, tools_lib + self, + *, + response, + history, + generation_config, + safety_settings, + stream, + tools_lib, + request_options, ) -> tuple[list[glm.Content], glm.Content, generation_types.BaseGenerateContentResponse]: while function_calls := self._get_function_calls(response): @@ -664,6 +693,7 @@ async def _handle_afc_async( safety_settings=safety_settings, stream=stream, tools=tools_lib, + request_options=request_options, ) self._check_response(response=response, stream=stream) diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py index ff0512031..4a0f86991 100644 --- a/tests/test_generative_models.py +++ b/tests/test_generative_models.py @@ -12,6 +12,8 @@ from google.generativeai import generative_models from google.generativeai.types import content_types from google.generativeai.types import generation_types +from google.generativeai.types import helper_types + import PIL.Image @@ -37,49 +39,63 @@ def simple_response(text: str) -> glm.GenerateContentResponse: return glm.GenerateContentResponse({"candidates": [{"content": simple_part(text)}]}) +class MockGenerativeServiceClient: + def __init__(self, test): + self.test = test + self.observed_requests = [] + self.observed_kwargs = [] + self.responses = collections.defaultdict(list) + + def generate_content( + self, + request: glm.GenerateContentRequest, + **kwargs, + ) -> glm.GenerateContentResponse: + self.test.assertIsInstance(request, glm.GenerateContentRequest) + self.observed_requests.append(request) + self.observed_kwargs.append(kwargs) + response = self.responses["generate_content"].pop(0) + return response + + def stream_generate_content( + self, + request: glm.GetModelRequest, + **kwargs, + ) -> Iterable[glm.GenerateContentResponse]: + self.observed_requests.append(request) + self.observed_kwargs.append(kwargs) + response = self.responses["stream_generate_content"].pop(0) + return response + + def count_tokens( + self, + request: glm.CountTokensRequest, + **kwargs, + ) -> Iterable[glm.GenerateContentResponse]: + self.observed_requests.append(request) + self.observed_kwargs.append(kwargs) + response = self.responses["count_tokens"].pop(0) + return response + + class CUJTests(parameterized.TestCase): """Tests are in order with the design doc.""" - def setUp(self): - self.client = unittest.mock.MagicMock() + @property + def observed_requests(self): + return self.client.observed_requests - client_lib._client_manager.clients["generative"] = self.client - - def add_client_method(f): - name = f.__name__ - setattr(self.client, name, f) - return f + @property + def observed_kwargs(self): + return self.client.observed_kwargs - self.observed_requests = [] - self.responses = collections.defaultdict(list) + @property + def responses(self): + return self.client.responses - @add_client_method - def generate_content( - request: glm.GenerateContentRequest, - **kwargs, - ) -> glm.GenerateContentResponse: - self.assertIsInstance(request, glm.GenerateContentRequest) - self.observed_requests.append(request) - response = self.responses["generate_content"].pop(0) - return response - - @add_client_method - def stream_generate_content( - request: glm.GetModelRequest, - **kwargs, - ) -> Iterable[glm.GenerateContentResponse]: - self.observed_requests.append(request) - response = self.responses["stream_generate_content"].pop(0) - return response - - @add_client_method - def count_tokens( - request: glm.CountTokensRequest, - **kwargs, - ) -> Iterable[glm.GenerateContentResponse]: - self.observed_requests.append(request) - response = self.responses["count_tokens"].pop(0) - return response + def setUp(self): + self.client = MockGenerativeServiceClient(self) + client_lib._client_manager.clients["generative"] = self.client def test_hello(self): # Generate text from text prompt @@ -451,7 +467,7 @@ def test_copy_history(self): chat1 = model.start_chat() chat1.send_message("hello1") - chat2 = copy.deepcopy(chat1) + chat2 = copy.copy(chat1) chat2.send_message("hello2") chat1.send_message("hello3") @@ -810,7 +826,7 @@ def test_async_code_match(self, obj, aobj): ) asource = re.sub(" *?# type: ignore", "", asource) - self.assertEqual(source, asource) + self.assertEqual(source, asource, f"error in {obj=}") def test_repr_for_unary_non_streamed_response(self): model = generative_models.GenerativeModel(model_name="gemini-pro") @@ -1208,15 +1224,30 @@ def test_repr_for_system_instruction(self): self.assertIn("system_instruction='Be excellent.'", result) def test_count_tokens_called_with_request_options(self): - self.client.count_tokens = unittest.mock.MagicMock() - request = unittest.mock.ANY + self.responses["count_tokens"].append(glm.CountTokensResponse()) request_options = {"timeout": 120} - self.responses["count_tokens"] = [glm.CountTokensResponse(total_tokens=7)] model = generative_models.GenerativeModel("gemini-pro-vision") model.count_tokens([{"role": "user", "parts": ["hello"]}], request_options=request_options) - self.client.count_tokens.assert_called_once_with(request, **request_options) + self.assertEqual(request_options, self.observed_kwargs[0]) + + def test_chat_with_request_options(self): + self.responses["generate_content"].append( + glm.GenerateContentResponse( + { + "candidates": [{"finish_reason": "STOP"}], + } + ) + ) + request_options = {"timeout": 120} + + model = generative_models.GenerativeModel("gemini-pro") + chat = model.start_chat() + chat.send_message("hello", request_options=helper_types.RequestOptions(**request_options)) + + request_options["retry"] = None + self.assertEqual(request_options, self.observed_kwargs[0]) if __name__ == "__main__": From 2e62faebceaf496d9a511f930f23c579669af5a1 Mon Sep 17 00:00:00 2001 From: Logan Kilpatrick <23kilpatrick23@gmail.com> Date: Sun, 26 May 2024 21:34:06 -0500 Subject: [PATCH 14/17] Update __init__.py to use the latest model (#362) * Update __init__.py * Grammar --------- Co-authored-by: Mark McDonald --- google/generativeai/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/google/generativeai/__init__.py b/google/generativeai/__init__.py index 53383a1b3..57e848298 100644 --- a/google/generativeai/__init__.py +++ b/google/generativeai/__init__.py @@ -30,8 +30,8 @@ genai.configure(api_key=os.environ['API_KEY']) -model = genai.GenerativeModel(name='gemini-pro') -response = model.generate_content('Please summarise this document: ...') +model = genai.GenerativeModel(name='gemini-1.5-flash') +response = model.generate_content('Teach me about how an LLM works') print(response.text) ``` From f08c789741f30e49ecfb822540fd749920d62bcc Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Wed, 29 May 2024 20:46:12 -0700 Subject: [PATCH 15/17] Add genai.protos (#354) * Add genai.protos Change-Id: I21cfada033c6ffbed7a20e117e61582fde925f61 * Add genai.protos Change-Id: I9c8473d4ca1a0e92489f145a18ef1abd29af22b3 * test_protos.py Change-Id: I576080fb80cf9dc9345d8bb2178eb4b9ac59ce97 * fix docs + format Change-Id: I5f9aa3f8e3ae780e5cec2078d3eb153157b195fe * fix merge Change-Id: I17014791d966d797b481bca17df69558b23a9a1a * format Change-Id: I51d30f6568640456bcf28db2bd338a58a82346de * Fix client references Change-Id: I4899231706c9624a0f189b22b6f70aeeb4cbea29 * Fix tests Change-Id: I8a636fb634fd079a892cb99170a12c0613887ccf * add import Change-Id: I517171389801ef249cd478f98798181da83bef69 * fix import Change-Id: I8921c0caaa9b902ebde682ead31a2444298c2c9c * Update docstring Change-Id: I1f6b3b9b9521baa8812a908431bf58c623860733 * spelling Change-Id: I0421a35687ed14b1a5ca3b496cafd91514c4de92 * remove unused imports Change-Id: Ifc791796e36668eb473fd0fffea4833b1a062188 * Resolve review coments. Change-Id: Ieb900190f42e883337028ae25da3be819507db4a * Update docstring. Change-Id: I805473f9aaeb04e922a9f66bb5f40716d42fb738 * Fix typo --------- Co-authored-by: Mark McDonald --- docs/build_docs.py | 131 +------------ google/generativeai/__init__.py | 1 + google/generativeai/answer.py | 59 +++--- google/generativeai/client.py | 3 +- google/generativeai/discuss.py | 87 ++++----- google/generativeai/embedding.py | 15 +- google/generativeai/files.py | 8 +- google/generativeai/generative_models.py | 72 +++---- google/generativeai/models.py | 32 ++-- google/generativeai/operations.py | 12 +- google/generativeai/protos.py | 75 ++++++++ google/generativeai/responder.py | 77 ++++---- google/generativeai/retriever.py | 31 ++-- google/generativeai/text.py | 32 ++-- google/generativeai/types/answer_types.py | 4 +- google/generativeai/types/citation_types.py | 6 +- google/generativeai/types/content_types.py | 138 +++++++------- google/generativeai/types/discuss_types.py | 24 +-- google/generativeai/types/file_types.py | 24 +-- google/generativeai/types/generation_types.py | 60 +++--- google/generativeai/types/model_types.py | 42 ++--- .../generativeai/types/palm_safety_types.py | 134 +++++++------- google/generativeai/types/permission_types.py | 35 ++-- google/generativeai/types/retriever_types.py | 175 ++++++++++-------- google/generativeai/types/safety_types.py | 108 +++++------ tests/test_answer.py | 142 ++++++++------ tests/test_client.py | 4 +- tests/test_content.py | 146 +++++++-------- tests/test_discuss.py | 76 ++++---- tests/test_discuss_async.py | 22 +-- tests/test_embedding.py | 21 ++- tests/test_embedding_async.py | 21 ++- tests/test_files.py | 27 +-- tests/test_generation.py | 155 +++++++++------- tests/test_generative_models.py | 124 ++++++------- tests/test_generative_models_async.py | 34 ++-- tests/test_helpers.py | 12 +- tests/test_models.py | 120 ++++++------ tests/test_operations.py | 18 +- tests/test_permission.py | 70 +++---- tests/test_permission_async.py | 70 +++---- tests/test_protos.py | 34 ++++ tests/test_responder.py | 58 +++--- tests/test_retriever.py | 140 +++++++------- tests/test_retriever_async.py | 134 +++++++------- tests/test_safety.py | 14 +- tests/test_text.py | 96 +++++----- 47 files changed, 1499 insertions(+), 1424 deletions(-) create mode 100644 google/generativeai/protos.py create mode 100644 tests/test_protos.py diff --git a/docs/build_docs.py b/docs/build_docs.py index eaa6a1ba4..012cd3441 100644 --- a/docs/build_docs.py +++ b/docs/build_docs.py @@ -44,77 +44,13 @@ # For showing the conditional imports and types in `content_types.py` # grpc must be imported first. typing.TYPE_CHECKING = True -from google import generativeai as palm - +from google import generativeai as genai from tensorflow_docs.api_generator import generate_lib from tensorflow_docs.api_generator import public_api import yaml -glm.__doc__ = """\ -This package, `google.ai.generativelanguage`, is a low-level auto-generated client library for the PaLM API. - -```posix-terminal -pip install google.ai.generativelanguage -``` - -It is built using the same tooling as Google Cloud client libraries, and will be quite familiar if you've used -those before. - -While we encourage Python users to access the PaLM API using the `google.generativeai` package (aka `palm`), -this lower level package is also available. - -Each method in the PaLM API is connected to one of the client classes. Pass your API-key to the class' `client_options` -when initializing a client: - -``` -from google.ai import generativelanguage as glm - -client = glm.DiscussServiceClient( - client_options={'api_key':'YOUR_API_KEY'}) -``` - -To call the api, pass an appropriate request-proto-object. For the `DiscussServiceClient.generate_message` pass -a `generativelanguage.GenerateMessageRequest` instance: - -``` -request = glm.GenerateMessageRequest( - model='models/chat-bison-001', - prompt=glm.MessagePrompt( - messages=[glm.Message(content='Hello!')])) - -client.generate_message(request) -``` -``` -candidates { - author: "1" - content: "Hello! How can I help you today?" -} -... -``` - -For simplicity: - -* The API methods also accept key-word arguments. -* Anywhere you might pass a proto-object, the library will also accept simple python structures. - -So the following is equivalent to the previous example: - -``` -client.generate_message( - model='models/chat-bison-001', - prompt={'messages':[{'content':'Hello!'}]}) -``` -``` -candidates { - author: "1" - content: "Hello! How can I help you today?" -} -... -``` -""" - HERE = pathlib.Path(__file__).parent PROJECT_SHORT_NAME = "genai" @@ -139,43 +75,6 @@ ) -class MyFilter: - def __init__(self, base_dirs): - self.filter_base_dirs = public_api.FilterBaseDirs(base_dirs) - - def drop_staticmethods(self, parent, children): - parent = dict(parent.__dict__) - for name, value in children: - if not isinstance(parent.get(name, None), staticmethod): - yield name, value - - def __call__(self, path, parent, children): - if any("generativelanguage" in part for part in path) or "generativeai" in path: - children = self.filter_base_dirs(path, parent, children) - children = public_api.explicit_package_contents_filter(path, parent, children) - - if any("generativelanguage" in part for part in path): - if "ServiceClient" in path[-1] or "ServiceAsyncClient" in path[-1]: - children = list(self.drop_staticmethods(parent, children)) - - return children - - -class MyDocGenerator(generate_lib.DocGenerator): - def make_default_filters(self): - return [ - # filter the api. - public_api.FailIfNestedTooDeep(10), - public_api.filter_module_all, - public_api.add_proto_fields, - public_api.filter_private_symbols, - MyFilter(self._base_dir), # Replaces: public_api.FilterBaseDirs(self._base_dir), - public_api.FilterPrivateMap(self._private_map), - public_api.filter_doc_controls_skip, - public_api.ignore_typing, - ] - - def gen_api_docs(): """Generates api docs for the generative-ai package.""" for name in dir(google): @@ -188,11 +87,11 @@ def gen_api_docs(): """ ) - doc_generator = MyDocGenerator( + doc_generator = generate_lib.DocGenerator( root_title=PROJECT_FULL_NAME, - py_modules=[("google", google)], + py_modules=[("google.generativeai", genai)], base_dir=( - pathlib.Path(palm.__file__).parent, + pathlib.Path(genai.__file__).parent, pathlib.Path(glm.__file__).parent.parent, ), code_url_prefix=( @@ -201,32 +100,12 @@ def gen_api_docs(): ), search_hints=_SEARCH_HINTS.value, site_path=_SITE_PATH.value, - callbacks=[], + callbacks=[public_api.explicit_package_contents_filter], ) out_path = pathlib.Path(_OUTPUT_DIR.value) doc_generator.build(out_path) - # Fixup the toc file. - toc_path = out_path / "google/_toc.yaml" - toc = yaml.safe_load(toc_path.read_text()) - assert toc["toc"][0]["title"] == "google" - toc["toc"] = toc["toc"][1:] - toc["toc"][0]["title"] = "google.ai.generativelanguage" - toc["toc"][0]["section"] = toc["toc"][0]["section"][1]["section"] - toc["toc"][0], toc["toc"][1] = toc["toc"][1], toc["toc"][0] - toc_path.write_text(yaml.dump(toc)) - - # remove some dummy files and redirect them to `api/` - (out_path / "google.md").unlink() - (out_path / "google/ai.md").unlink() - redirects_path = out_path / "_redirects.yaml" - redirects = {"redirects": []} - redirects["redirects"].insert(0, {"from": "/api/python/google/ai", "to": "/api/"}) - redirects["redirects"].insert(0, {"from": "/api/python/google", "to": "/api/"}) - redirects["redirects"].insert(0, {"from": "/api/python", "to": "/api/"}) - redirects_path.write_text(yaml.dump(redirects)) - # clear `oneof` junk from proto pages for fpath in out_path.rglob("*.md"): old_content = fpath.read_text() diff --git a/google/generativeai/__init__.py b/google/generativeai/__init__.py index 57e848298..2b93fc1ce 100644 --- a/google/generativeai/__init__.py +++ b/google/generativeai/__init__.py @@ -42,6 +42,7 @@ from google.generativeai import version +from google.generativeai import protos from google.generativeai import types from google.generativeai.types import GenerationConfig diff --git a/google/generativeai/answer.py b/google/generativeai/answer.py index 4b9d9f97c..4bfabbf23 100644 --- a/google/generativeai/answer.py +++ b/google/generativeai/answer.py @@ -21,6 +21,7 @@ from typing_extensions import TypedDict import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai.client import ( get_default_generative_client, @@ -35,7 +36,7 @@ DEFAULT_ANSWER_MODEL = "models/aqa" -AnswerStyle = glm.GenerateAnswerRequest.AnswerStyle +AnswerStyle = protos.GenerateAnswerRequest.AnswerStyle AnswerStyleOptions = Union[int, str, AnswerStyle] @@ -66,28 +67,30 @@ def to_answer_style(x: AnswerStyleOptions) -> AnswerStyle: GroundingPassageOptions = ( - Union[glm.GroundingPassage, tuple[str, content_types.ContentType], content_types.ContentType], + Union[ + protos.GroundingPassage, tuple[str, content_types.ContentType], content_types.ContentType + ], ) GroundingPassagesOptions = Union[ - glm.GroundingPassages, + protos.GroundingPassages, Iterable[GroundingPassageOptions], Mapping[str, content_types.ContentType], ] -def _make_grounding_passages(source: GroundingPassagesOptions) -> glm.GroundingPassages: +def _make_grounding_passages(source: GroundingPassagesOptions) -> protos.GroundingPassages: """ - Converts the `source` into a `glm.GroundingPassage`. A `GroundingPassages` contains a list of - `glm.GroundingPassage` objects, which each contain a `glm.Contant` and a string `id`. + Converts the `source` into a `protos.GroundingPassage`. A `GroundingPassages` contains a list of + `protos.GroundingPassage` objects, which each contain a `protos.Contant` and a string `id`. Args: - source: `Content` or a `GroundingPassagesOptions` that will be converted to glm.GroundingPassages. + source: `Content` or a `GroundingPassagesOptions` that will be converted to protos.GroundingPassages. Return: - `glm.GroundingPassages` to be passed into `glm.GenerateAnswer`. + `protos.GroundingPassages` to be passed into `protos.GenerateAnswer`. """ - if isinstance(source, glm.GroundingPassages): + if isinstance(source, protos.GroundingPassages): return source if not isinstance(source, Iterable): @@ -100,7 +103,7 @@ def _make_grounding_passages(source: GroundingPassagesOptions) -> glm.GroundingP source = source.items() for n, data in enumerate(source): - if isinstance(data, glm.GroundingPassage): + if isinstance(data, protos.GroundingPassage): passages.append(data) elif isinstance(data, tuple): id, content = data # tuple must have exactly 2 items. @@ -108,11 +111,11 @@ def _make_grounding_passages(source: GroundingPassagesOptions) -> glm.GroundingP else: passages.append({"id": str(n), "content": content_types.to_content(data)}) - return glm.GroundingPassages(passages=passages) + return protos.GroundingPassages(passages=passages) SourceNameType = Union[ - str, retriever_types.Corpus, glm.Corpus, retriever_types.Document, glm.Document + str, retriever_types.Corpus, protos.Corpus, retriever_types.Document, protos.Document ] @@ -127,7 +130,7 @@ class SemanticRetrieverConfigDict(TypedDict): SemanticRetrieverConfigOptions = Union[ SourceNameType, SemanticRetrieverConfigDict, - glm.SemanticRetrieverConfig, + protos.SemanticRetrieverConfig, ] @@ -135,7 +138,7 @@ def _maybe_get_source_name(source) -> str | None: if isinstance(source, str): return source elif isinstance( - source, (retriever_types.Corpus, glm.Corpus, retriever_types.Document, glm.Document) + source, (retriever_types.Corpus, protos.Corpus, retriever_types.Document, protos.Document) ): return source.name else: @@ -145,8 +148,8 @@ def _maybe_get_source_name(source) -> str | None: def _make_semantic_retriever_config( source: SemanticRetrieverConfigOptions, query: content_types.ContentsType, -) -> glm.SemanticRetrieverConfig: - if isinstance(source, glm.SemanticRetrieverConfig): +) -> protos.SemanticRetrieverConfig: + if isinstance(source, protos.SemanticRetrieverConfig): return source name = _maybe_get_source_name(source) @@ -156,7 +159,7 @@ def _make_semantic_retriever_config( source["source"] = _maybe_get_source_name(source["source"]) else: raise TypeError( - f"Invalid input: Failed to create a 'glm.SemanticRetrieverConfig' from the provided source. " + f"Invalid input: Failed to create a 'protos.SemanticRetrieverConfig' from the provided source. " f"Received type: {type(source).__name__}, " f"Received value: {source}" ) @@ -166,7 +169,7 @@ def _make_semantic_retriever_config( elif isinstance(source["query"], str): source["query"] = content_types.to_content(source["query"]) - return glm.SemanticRetrieverConfig(source) + return protos.SemanticRetrieverConfig(source) def _make_generate_answer_request( @@ -178,9 +181,9 @@ def _make_generate_answer_request( answer_style: AnswerStyle | None = None, safety_settings: safety_types.SafetySettingOptions | None = None, temperature: float | None = None, -) -> glm.GenerateAnswerRequest: +) -> protos.GenerateAnswerRequest: """ - constructs a glm.GenerateAnswerRequest object by organizing the input parameters for the API call to generate a grounded answer from the model. + constructs a protos.GenerateAnswerRequest object by organizing the input parameters for the API call to generate a grounded answer from the model. Args: model: Name of the model used to generate the grounded response. @@ -188,16 +191,16 @@ def _make_generate_answer_request( single question to answer. For multi-turn queries, this is a repeated field that contains conversation history and the last `Content` in the list containing the question. inline_passages: Grounding passages (a list of `Content`-like objects or `(id, content)` pairs, - or a `glm.GroundingPassages`) to send inline with the request. Exclusive with `semantic_retreiver`, + or a `protos.GroundingPassages`) to send inline with the request. Exclusive with `semantic_retreiver`, one must be set, but not both. - semantic_retriever: A Corpus, Document, or `glm.SemanticRetrieverConfig` to use for grounding. Exclusive with + semantic_retriever: A Corpus, Document, or `protos.SemanticRetrieverConfig` to use for grounding. Exclusive with `inline_passages`, one must be set, but not both. answer_style: Style for grounded answers. safety_settings: Safety settings for generated output. temperature: The temperature for randomness in the output. Returns: - Call for glm.GenerateAnswerRequest(). + Call for protos.GenerateAnswerRequest(). """ model = model_types.make_model_name(model) @@ -224,7 +227,7 @@ def _make_generate_answer_request( if answer_style: answer_style = to_answer_style(answer_style) - return glm.GenerateAnswerRequest( + return protos.GenerateAnswerRequest( model=model, contents=contents, inline_passages=inline_passages, @@ -273,9 +276,9 @@ def generate_answer( contents: The question to be answered by the model, grounded in the provided source. inline_passages: Grounding passages (a list of `Content`-like objects or (id, content) pairs, - or a `glm.GroundingPassages`) to send inline with the request. Exclusive with `semantic_retreiver`, + or a `protos.GroundingPassages`) to send inline with the request. Exclusive with `semantic_retreiver`, one must be set, but not both. - semantic_retriever: A Corpus, Document, or `glm.SemanticRetrieverConfig` to use for grounding. Exclusive with + semantic_retriever: A Corpus, Document, or `protos.SemanticRetrieverConfig` to use for grounding. Exclusive with `inline_passages`, one must be set, but not both. answer_style: Style in which the grounded answer should be returned. safety_settings: Safety settings for generated output. Defaults to None. @@ -327,9 +330,9 @@ async def generate_answer_async( contents: The question to be answered by the model, grounded in the provided source. inline_passages: Grounding passages (a list of `Content`-like objects or (id, content) pairs, - or a `glm.GroundingPassages`) to send inline with the request. Exclusive with `semantic_retreiver`, + or a `protos.GroundingPassages`) to send inline with the request. Exclusive with `semantic_retreiver`, one must be set, but not both. - semantic_retriever: A Corpus, Document, or `glm.SemanticRetrieverConfig` to use for grounding. Exclusive with + semantic_retriever: A Corpus, Document, or `protos.SemanticRetrieverConfig` to use for grounding. Exclusive with `inline_passages`, one must be set, but not both. answer_style: Style in which the grounded answer should be returned. safety_settings: Safety settings for generated output. Defaults to None. diff --git a/google/generativeai/client.py b/google/generativeai/client.py index d969889d0..40c2bdcaf 100644 --- a/google/generativeai/client.py +++ b/google/generativeai/client.py @@ -10,6 +10,7 @@ import httplib2 import google.ai.generativelanguage as glm +import google.generativeai.protos as protos from google.auth import credentials as ga_credentials from google.auth import exceptions as ga_exceptions @@ -76,7 +77,7 @@ def create_file( name: str | None = None, display_name: str | None = None, resumable: bool = True, - ) -> glm.File: + ) -> protos.File: if self._discovery_api is None: self._setup_discovery_api() diff --git a/google/generativeai/discuss.py b/google/generativeai/discuss.py index b084ccad8..448347b41 100644 --- a/google/generativeai/discuss.py +++ b/google/generativeai/discuss.py @@ -18,37 +18,38 @@ import sys import textwrap -from typing import Any, Iterable, List, Optional, Union +from typing import Iterable, List import google.ai.generativelanguage as glm from google.generativeai.client import get_default_discuss_client from google.generativeai.client import get_default_discuss_async_client from google.generativeai import string_utils +from google.generativeai import protos from google.generativeai.types import discuss_types from google.generativeai.types import helper_types from google.generativeai.types import model_types from google.generativeai.types import palm_safety_types -def _make_message(content: discuss_types.MessageOptions) -> glm.Message: - """Creates a `glm.Message` object from the provided content.""" - if isinstance(content, glm.Message): +def _make_message(content: discuss_types.MessageOptions) -> protos.Message: + """Creates a `protos.Message` object from the provided content.""" + if isinstance(content, protos.Message): return content if isinstance(content, str): - return glm.Message(content=content) + return protos.Message(content=content) else: - return glm.Message(content) + return protos.Message(content) def _make_messages( messages: discuss_types.MessagesOptions, -) -> List[glm.Message]: +) -> List[protos.Message]: """ - Creates a list of `glm.Message` objects from the provided messages. + Creates a list of `protos.Message` objects from the provided messages. This function takes a variety of message content inputs, such as strings, dictionaries, - or `glm.Message` objects, and creates a list of `glm.Message` objects. It ensures that + or `protos.Message` objects, and creates a list of `protos.Message` objects. It ensures that the authors of the messages alternate appropriately. If authors are not provided, default authors are assigned based on their position in the list. @@ -56,9 +57,9 @@ def _make_messages( messages: The messages to convert. Returns: - A list of `glm.Message` objects with alternating authors. + A list of `protos.Message` objects with alternating authors. """ - if isinstance(messages, (str, dict, glm.Message)): + if isinstance(messages, (str, dict, protos.Message)): messages = [_make_message(messages)] else: messages = [_make_message(message) for message in messages] @@ -93,39 +94,39 @@ def _make_messages( return messages -def _make_example(item: discuss_types.ExampleOptions) -> glm.Example: - """Creates a `glm.Example` object from the provided item.""" - if isinstance(item, glm.Example): +def _make_example(item: discuss_types.ExampleOptions) -> protos.Example: + """Creates a `protos.Example` object from the provided item.""" + if isinstance(item, protos.Example): return item if isinstance(item, dict): item = item.copy() item["input"] = _make_message(item["input"]) item["output"] = _make_message(item["output"]) - return glm.Example(item) + return protos.Example(item) if isinstance(item, Iterable): input, output = list(item) - return glm.Example(input=_make_message(input), output=_make_message(output)) + return protos.Example(input=_make_message(input), output=_make_message(output)) # try anyway - return glm.Example(item) + return protos.Example(item) def _make_examples_from_flat( examples: List[discuss_types.MessageOptions], -) -> List[glm.Example]: +) -> List[protos.Example]: """ - Creates a list of `glm.Example` objects from a list of message options. + Creates a list of `protos.Example` objects from a list of message options. This function takes a list of `discuss_types.MessageOptions` and pairs them into - `glm.Example` objects. The input examples must be in pairs to create valid examples. + `protos.Example` objects. The input examples must be in pairs to create valid examples. Args: examples: The list of `discuss_types.MessageOptions`. Returns: - A list of `glm.Example objects` created by pairing up the provided messages. + A list of `protos.Example objects` created by pairing up the provided messages. Raises: ValueError: If the provided list of examples is not of even length. @@ -145,7 +146,7 @@ def _make_examples_from_flat( pair.append(msg) if n % 2 == 0: continue - primer = glm.Example( + primer = protos.Example( input=pair[0], output=pair[1], ) @@ -156,21 +157,21 @@ def _make_examples_from_flat( def _make_examples( examples: discuss_types.ExamplesOptions, -) -> List[glm.Example]: +) -> List[protos.Example]: """ - Creates a list of `glm.Example` objects from the provided examples. + Creates a list of `protos.Example` objects from the provided examples. This function takes various types of example content inputs and creates a list - of `glm.Example` objects. It handles the conversion of different input types and ensures + of `protos.Example` objects. It handles the conversion of different input types and ensures the appropriate structure for creating valid examples. Args: examples: The examples to convert. Returns: - A list of `glm.Example` objects created from the provided examples. + A list of `protos.Example` objects created from the provided examples. """ - if isinstance(examples, glm.Example): + if isinstance(examples, protos.Example): return [examples] if isinstance(examples, dict): @@ -208,11 +209,11 @@ def _make_message_prompt_dict( context: str | None = None, examples: discuss_types.ExamplesOptions | None = None, messages: discuss_types.MessagesOptions | None = None, -) -> glm.MessagePrompt: +) -> protos.MessagePrompt: """ - Creates a `glm.MessagePrompt` object from the provided prompt components. + Creates a `protos.MessagePrompt` object from the provided prompt components. - This function constructs a `glm.MessagePrompt` object using the provided `context`, `examples`, + This function constructs a `protos.MessagePrompt` object using the provided `context`, `examples`, or `messages`. It ensures the proper structure and handling of the input components. Either pass a `prompt` or it's component `context`, `examples`, `messages`. @@ -224,7 +225,7 @@ def _make_message_prompt_dict( messages: The messages for the prompt. Returns: - A `glm.MessagePrompt` object created from the provided prompt components. + A `protos.MessagePrompt` object created from the provided prompt components. """ if prompt is None: prompt = dict( @@ -238,7 +239,7 @@ def _make_message_prompt_dict( raise ValueError( "Invalid configuration: Either `prompt` or its fields `(context, examples, messages)` should be set, but not both simultaneously." ) - if isinstance(prompt, glm.MessagePrompt): + if isinstance(prompt, protos.MessagePrompt): return prompt elif isinstance(prompt, dict): # Always check dict before Iterable. pass @@ -268,12 +269,12 @@ def _make_message_prompt( context: str | None = None, examples: discuss_types.ExamplesOptions | None = None, messages: discuss_types.MessagesOptions | None = None, -) -> glm.MessagePrompt: - """Creates a `glm.MessagePrompt` object from the provided prompt components.""" +) -> protos.MessagePrompt: + """Creates a `protos.MessagePrompt` object from the provided prompt components.""" prompt = _make_message_prompt_dict( prompt=prompt, context=context, examples=examples, messages=messages ) - return glm.MessagePrompt(prompt) + return protos.MessagePrompt(prompt) def _make_generate_message_request( @@ -287,15 +288,15 @@ def _make_generate_message_request( top_p: float | None = None, top_k: float | None = None, prompt: discuss_types.MessagePromptOptions | None = None, -) -> glm.GenerateMessageRequest: - """Creates a `glm.GenerateMessageRequest` object for generating messages.""" +) -> protos.GenerateMessageRequest: + """Creates a `protos.GenerateMessageRequest` object for generating messages.""" model = model_types.make_model_name(model) prompt = _make_message_prompt( prompt=prompt, context=context, examples=examples, messages=messages ) - return glm.GenerateMessageRequest( + return protos.GenerateMessageRequest( model=model, prompt=prompt, temperature=temperature, @@ -514,9 +515,9 @@ async def reply_async( def _build_chat_response( - request: glm.GenerateMessageRequest, - response: glm.GenerateMessageResponse, - client: glm.DiscussServiceClient | glm.DiscussServiceAsyncClient, + request: protos.GenerateMessageRequest, + response: protos.GenerateMessageResponse, + client: glm.DiscussServiceClient | protos.DiscussServiceAsyncClient, ) -> ChatResponse: request = type(request).to_dict(request) prompt = request.pop("prompt") @@ -541,7 +542,7 @@ def _build_chat_response( def _generate_response( - request: glm.GenerateMessageRequest, + request: protos.GenerateMessageRequest, client: glm.DiscussServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> ChatResponse: @@ -557,7 +558,7 @@ def _generate_response( async def _generate_response_async( - request: glm.GenerateMessageRequest, + request: protos.GenerateMessageRequest, client: glm.DiscussServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> ChatResponse: diff --git a/google/generativeai/embedding.py b/google/generativeai/embedding.py index 8218ec11d..616fa07bf 100644 --- a/google/generativeai/embedding.py +++ b/google/generativeai/embedding.py @@ -18,6 +18,7 @@ from typing import Any, Iterable, overload, TypeVar, Union, Mapping import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai.client import get_default_generative_client from google.generativeai.client import get_default_generative_async_client @@ -30,7 +31,7 @@ DEFAULT_EMB_MODEL = "models/embedding-001" EMBEDDING_MAX_BATCH_SIZE = 100 -EmbeddingTaskType = glm.TaskType +EmbeddingTaskType = protos.TaskType EmbeddingTaskTypeOptions = Union[int, str, EmbeddingTaskType] @@ -183,7 +184,7 @@ def embed_content( if isinstance(content, Iterable) and not isinstance(content, (str, Mapping)): result = {"embedding": []} requests = ( - glm.EmbedContentRequest( + protos.EmbedContentRequest( model=model, content=content_types.to_content(c), task_type=task_type, @@ -193,7 +194,7 @@ def embed_content( for c in content ) for batch in _batched(requests, EMBEDDING_MAX_BATCH_SIZE): - embedding_request = glm.BatchEmbedContentsRequest(model=model, requests=batch) + embedding_request = protos.BatchEmbedContentsRequest(model=model, requests=batch) embedding_response = client.batch_embed_contents( embedding_request, **request_options, @@ -202,7 +203,7 @@ def embed_content( result["embedding"].extend(e["values"] for e in embedding_dict["embeddings"]) return result else: - embedding_request = glm.EmbedContentRequest( + embedding_request = protos.EmbedContentRequest( model=model, content=content_types.to_content(content), task_type=task_type, @@ -276,7 +277,7 @@ async def embed_content_async( if isinstance(content, Iterable) and not isinstance(content, (str, Mapping)): result = {"embedding": []} requests = ( - glm.EmbedContentRequest( + protos.EmbedContentRequest( model=model, content=content_types.to_content(c), task_type=task_type, @@ -286,7 +287,7 @@ async def embed_content_async( for c in content ) for batch in _batched(requests, EMBEDDING_MAX_BATCH_SIZE): - embedding_request = glm.BatchEmbedContentsRequest(model=model, requests=batch) + embedding_request = protos.BatchEmbedContentsRequest(model=model, requests=batch) embedding_response = await client.batch_embed_contents( embedding_request, **request_options, @@ -295,7 +296,7 @@ async def embed_content_async( result["embedding"].extend(e["values"] for e in embedding_dict["embeddings"]) return result else: - embedding_request = glm.EmbedContentRequest( + embedding_request = protos.EmbedContentRequest( model=model, content=content_types.to_content(content), task_type=task_type, diff --git a/google/generativeai/files.py b/google/generativeai/files.py index 386592225..4028d37f7 100644 --- a/google/generativeai/files.py +++ b/google/generativeai/files.py @@ -19,7 +19,7 @@ import mimetypes from typing import Iterable import logging -import google.ai.generativelanguage as glm +from google.generativeai import protos from itertools import islice from google.generativeai.types import file_types @@ -76,7 +76,7 @@ def list_files(page_size=100) -> Iterable[file_types.File]: """Calls the API to list files using a supported file service.""" client = get_default_file_client() - response = client.list_files(glm.ListFilesRequest(page_size=page_size)) + response = client.list_files(protos.ListFilesRequest(page_size=page_size)) for proto in response: yield file_types.File(proto) @@ -89,8 +89,8 @@ def get_file(name) -> file_types.File: def delete_file(name): """Calls the API to permanently delete a specified file using a supported file service.""" - if isinstance(name, (file_types.File, glm.File)): + if isinstance(name, (file_types.File, protos.File)): name = name.name - request = glm.DeleteFileRequest(name=name) + request = protos.DeleteFileRequest(name=name) client = get_default_file_client() client.delete_file(request=request) diff --git a/google/generativeai/generative_models.py b/google/generativeai/generative_models.py index 873d2fcb4..7d69ae8f9 100644 --- a/google/generativeai/generative_models.py +++ b/google/generativeai/generative_models.py @@ -11,7 +11,7 @@ import google.api_core.exceptions -from google.ai import generativelanguage as glm +from google.generativeai import protos from google.generativeai import client from google.generativeai.types import content_types from google.generativeai.types import generation_types @@ -125,8 +125,8 @@ def _prepare_request( safety_settings: safety_types.SafetySettingOptions | None = None, tools: content_types.FunctionLibraryType | None, tool_config: content_types.ToolConfigType | None, - ) -> glm.GenerateContentRequest: - """Creates a `glm.GenerateContentRequest` from raw inputs.""" + ) -> protos.GenerateContentRequest: + """Creates a `protos.GenerateContentRequest` from raw inputs.""" tools_lib = self._get_tools_lib(tools) if tools_lib is not None: tools_lib = tools_lib.to_proto() @@ -147,7 +147,7 @@ def _prepare_request( merged_ss.update(safety_settings) merged_ss = safety_types.normalize_safety_settings(merged_ss) - return glm.GenerateContentRequest( + return protos.GenerateContentRequest( model=self._model_name, contents=contents, generation_config=merged_gc, @@ -209,25 +209,25 @@ def generate_content( ### Input type flexibility - While the underlying API strictly expects a `list[glm.Content]` objects, this method + While the underlying API strictly expects a `list[protos.Content]` objects, this method will convert the user input into the correct type. The hierarchy of types that can be converted is below. Any of these objects can be passed as an equivalent `dict`. - * `Iterable[glm.Content]` - * `glm.Content` - * `Iterable[glm.Part]` - * `glm.Part` - * `str`, `Image`, or `glm.Blob` + * `Iterable[protos.Content]` + * `protos.Content` + * `Iterable[protos.Part]` + * `protos.Part` + * `str`, `Image`, or `protos.Blob` - In an `Iterable[glm.Content]` each `content` is a separate message. - But note that an `Iterable[glm.Part]` is taken as the parts of a single message. + In an `Iterable[protos.Content]` each `content` is a separate message. + But note that an `Iterable[protos.Part]` is taken as the parts of a single message. Arguments: contents: The contents serving as the model's prompt. generation_config: Overrides for the model's generation config. safety_settings: Overrides for the model's safety settings. stream: If True, yield response chunks as they are generated. - tools: `glm.Tools` more info coming soon. + tools: `protos.Tools` more info coming soon. request_options: Options for the request. """ if not contents: @@ -328,14 +328,14 @@ def count_tokens( tools: content_types.FunctionLibraryType | None = None, tool_config: content_types.ToolConfigType | None = None, request_options: helper_types.RequestOptionsType | None = None, - ) -> glm.CountTokensResponse: + ) -> protos.CountTokensResponse: if request_options is None: request_options = {} if self._client is None: self._client = client.get_default_generative_client() - request = glm.CountTokensRequest( + request = protos.CountTokensRequest( model=self.model_name, generate_content_request=self._prepare_request( contents=contents, @@ -355,14 +355,14 @@ async def count_tokens_async( tools: content_types.FunctionLibraryType | None = None, tool_config: content_types.ToolConfigType | None = None, request_options: helper_types.RequestOptionsType | None = None, - ) -> glm.CountTokensResponse: + ) -> protos.CountTokensResponse: if request_options is None: request_options = {} if self._async_client is None: self._async_client = client.get_default_generative_async_client() - request = glm.CountTokensRequest( + request = protos.CountTokensRequest( model=self.model_name, generate_content_request=self._prepare_request( contents=contents, @@ -388,7 +388,7 @@ def start_chat( >>> response = chat.send_message("Hello?") Arguments: - history: An iterable of `glm.Content` objects, or equivalents to initialize the session. + history: An iterable of `protos.Content` objects, or equivalents to initialize the session. """ if self._generation_config.get("candidate_count", 1) > 1: raise ValueError( @@ -430,8 +430,8 @@ def __init__( enable_automatic_function_calling: bool = False, ): self.model: GenerativeModel = model - self._history: list[glm.Content] = content_types.to_contents(history) - self._last_sent: glm.Content | None = None + self._history: list[protos.Content] = content_types.to_contents(history) + self._last_sent: protos.Content | None = None self._last_received: generation_types.BaseGenerateContentResponse | None = None self.enable_automatic_function_calling = enable_automatic_function_calling @@ -535,13 +535,13 @@ def _check_response(self, *, response, stream): if not stream: if response.candidates[0].finish_reason not in ( - glm.Candidate.FinishReason.FINISH_REASON_UNSPECIFIED, - glm.Candidate.FinishReason.STOP, - glm.Candidate.FinishReason.MAX_TOKENS, + protos.Candidate.FinishReason.FINISH_REASON_UNSPECIFIED, + protos.Candidate.FinishReason.STOP, + protos.Candidate.FinishReason.MAX_TOKENS, ): raise generation_types.StopCandidateException(response.candidates[0]) - def _get_function_calls(self, response) -> list[glm.FunctionCall]: + def _get_function_calls(self, response) -> list[protos.FunctionCall]: candidates = response.candidates if len(candidates) != 1: raise ValueError( @@ -561,14 +561,14 @@ def _handle_afc( stream, tools_lib, request_options, - ) -> tuple[list[glm.Content], glm.Content, generation_types.BaseGenerateContentResponse]: + ) -> tuple[list[protos.Content], protos.Content, generation_types.BaseGenerateContentResponse]: while function_calls := self._get_function_calls(response): if not all(callable(tools_lib[fc]) for fc in function_calls): break history.append(response.candidates[0].content) - function_response_parts: list[glm.Part] = [] + function_response_parts: list[protos.Part] = [] for fc in function_calls: fr = tools_lib(fc) assert fr is not None, ( @@ -577,7 +577,7 @@ def _handle_afc( ) function_response_parts.append(fr) - send = glm.Content(role=self._USER_ROLE, parts=function_response_parts) + send = protos.Content(role=self._USER_ROLE, parts=function_response_parts) history.append(send) response = self.model.generate_content( @@ -668,14 +668,14 @@ async def _handle_afc_async( stream, tools_lib, request_options, - ) -> tuple[list[glm.Content], glm.Content, generation_types.BaseGenerateContentResponse]: + ) -> tuple[list[protos.Content], protos.Content, generation_types.BaseGenerateContentResponse]: while function_calls := self._get_function_calls(response): if not all(callable(tools_lib[fc]) for fc in function_calls): break history.append(response.candidates[0].content) - function_response_parts: list[glm.Part] = [] + function_response_parts: list[protos.Part] = [] for fc in function_calls: fr = tools_lib(fc) assert fr is not None, ( @@ -684,7 +684,7 @@ async def _handle_afc_async( ) function_response_parts.append(fr) - send = glm.Content(role=self._USER_ROLE, parts=function_response_parts) + send = protos.Content(role=self._USER_ROLE, parts=function_response_parts) history.append(send) response = await self.model.generate_content_async( @@ -708,7 +708,7 @@ def __copy__(self): history=list(self.history), ) - def rewind(self) -> tuple[glm.Content, glm.Content]: + def rewind(self) -> tuple[protos.Content, protos.Content]: """Removes the last request/response pair from the chat history.""" if self._last_received is None: result = self._history.pop(-2), self._history.pop() @@ -725,16 +725,16 @@ def last(self) -> generation_types.BaseGenerateContentResponse | None: return self._last_received @property - def history(self) -> list[glm.Content]: + def history(self) -> list[protos.Content]: """The chat history.""" last = self._last_received if last is None: return self._history if last.candidates[0].finish_reason not in ( - glm.Candidate.FinishReason.FINISH_REASON_UNSPECIFIED, - glm.Candidate.FinishReason.STOP, - glm.Candidate.FinishReason.MAX_TOKENS, + protos.Candidate.FinishReason.FINISH_REASON_UNSPECIFIED, + protos.Candidate.FinishReason.STOP, + protos.Candidate.FinishReason.MAX_TOKENS, ): error = generation_types.StopCandidateException(last.candidates[0]) last._error = error @@ -770,7 +770,7 @@ def __repr__(self) -> str: _model = str(self.model).replace("\n", "\n" + " " * 4) def content_repr(x): - return f"glm.Content({_dict_repr.repr(type(x).to_dict(x))})" + return f"protos.Content({_dict_repr.repr(type(x).to_dict(x))})" try: history = list(self.history) diff --git a/google/generativeai/models.py b/google/generativeai/models.py index 1f9e836e7..9ba0745c1 100644 --- a/google/generativeai/models.py +++ b/google/generativeai/models.py @@ -18,6 +18,8 @@ from typing import Any, Literal import google.ai.generativelanguage as glm + +from google.generativeai import protos from google.generativeai import operations from google.generativeai.client import get_default_model_client from google.generativeai.types import model_types @@ -155,16 +157,16 @@ def get_base_model_name( base_model = model.base_model elif isinstance(model, model_types.Model): base_model = model.name - elif isinstance(model, glm.Model): + elif isinstance(model, protos.Model): base_model = model.name - elif isinstance(model, glm.TunedModel): + elif isinstance(model, protos.TunedModel): base_model = getattr(model, "base_model", None) if not base_model: base_model = model.tuned_model_source.base_model else: raise TypeError( f"Invalid model: The provided model '{model}' is not recognized or supported. " - "Supported types are: str, model_types.TunedModel, model_types.Model, glm.Model, and glm.TunedModel." + "Supported types are: str, model_types.TunedModel, model_types.Model, protos.Model, and protos.TunedModel." ) return base_model @@ -282,9 +284,9 @@ def create_tuned_model( Args: source_model: The name of the model to tune. training_data: The dataset to tune the model on. This must be either: - * A `glm.Dataset`, or + * A `protos.Dataset`, or * An `Iterable` of: - *`glm.TuningExample`, + *`protos.TuningExample`, * `{'text_input': text_input, 'output': output}` dicts * `(text_input, output)` tuples. * A `Mapping` of `Iterable[str]` - use `input_key` and `output_key` to choose which @@ -339,17 +341,17 @@ def create_tuned_model( training_data, input_key=input_key, output_key=output_key ) - hyperparameters = glm.Hyperparameters( + hyperparameters = protos.Hyperparameters( epoch_count=epoch_count, batch_size=batch_size, learning_rate=learning_rate, ) - tuning_task = glm.TuningTask( + tuning_task = protos.TuningTask( training_data=training_data, hyperparameters=hyperparameters, ) - tuned_model = glm.TunedModel( + tuned_model = protos.TunedModel( **source_model, display_name=display_name, description=description, @@ -368,7 +370,7 @@ def create_tuned_model( @typing.overload def update_tuned_model( - tuned_model: glm.TunedModel, + tuned_model: protos.TunedModel, updates: None = None, *, client: glm.ModelServiceClient | None = None, @@ -389,7 +391,7 @@ def update_tuned_model( def update_tuned_model( - tuned_model: str | glm.TunedModel, + tuned_model: str | protos.TunedModel, updates: dict[str, Any] | None = None, *, client: glm.ModelServiceClient | None = None, @@ -418,10 +420,11 @@ def update_tuned_model( field_mask.paths.append(path) for path, value in updates.items(): _apply_update(tuned_model, path, value) - elif isinstance(tuned_model, glm.TunedModel): + elif isinstance(tuned_model, protos.TunedModel): if updates is not None: raise ValueError( - "Invalid argument: When calling `update_tuned_model(tuned_model:glm.TunedModel, updates=None)`, the `updates` argument must not be set." + "Invalid argument: When calling `update_tuned_model(tuned_model:protos.TunedModel, updates=None)`, " + "the `updates` argument must not be set." ) name = tuned_model.name @@ -429,11 +432,12 @@ def update_tuned_model( field_mask = protobuf_helpers.field_mask(was._pb, tuned_model._pb) else: raise TypeError( - f"Invalid argument type: In the function `update_tuned_model(tuned_model:dict|glm.TunedModel)`, the `tuned_model` argument must be of type `dict` or `glm.TunedModel`. Received type: {type(tuned_model).__name__}." + "Invalid argument type: In the function `update_tuned_model(tuned_model:dict|protos.TunedModel)`, the " + f"`tuned_model` argument must be of type `dict` or `protos.TunedModel`. Received type: {type(tuned_model).__name__}." ) result = client.update_tuned_model( - glm.UpdateTunedModelRequest(tuned_model=tuned_model, update_mask=field_mask), + protos.UpdateTunedModelRequest(tuned_model=tuned_model, update_mask=field_mask), **request_options, ) return model_types.decode_tuned_model(result) diff --git a/google/generativeai/operations.py b/google/generativeai/operations.py index 01c0a6b14..52fd8a1b8 100644 --- a/google/generativeai/operations.py +++ b/google/generativeai/operations.py @@ -17,7 +17,7 @@ import functools from typing import Iterator -from google.ai import generativelanguage as glm +from google.generativeai import protos from google.generativeai import client as client_lib from google.generativeai.types import model_types @@ -75,8 +75,8 @@ def from_proto(cls, proto, client): cls=CreateTunedModelOperation, operation=proto, operations_client=client, - result_type=glm.TunedModel, - metadata_type=glm.CreateTunedModelMetadata, + result_type=protos.TunedModel, + metadata_type=protos.CreateTunedModelMetadata, ) @classmethod @@ -111,14 +111,14 @@ def update(self): """Refresh the current statuses in metadata/result/error""" self._refresh_and_update() - def wait_bar(self, **kwargs) -> Iterator[glm.CreateTunedModelMetadata]: + def wait_bar(self, **kwargs) -> Iterator[protos.CreateTunedModelMetadata]: """A tqdm wait bar, yields `Operation` statuses until complete. Args: **kwargs: passed through to `tqdm.auto.tqdm(..., **kwargs)` Yields: - Operation statuses as `glm.CreateTunedModelMetadata` objects. + Operation statuses as `protos.CreateTunedModelMetadata` objects. """ bar = tqdm.tqdm(total=self.metadata.total_steps, initial=0, **kwargs) @@ -131,7 +131,7 @@ def wait_bar(self, **kwargs) -> Iterator[glm.CreateTunedModelMetadata]: bar.update(self.metadata.completed_steps - bar.n) return self.result() - def set_result(self, result: glm.TunedModel): + def set_result(self, result: protos.TunedModel): result = model_types.decode_tuned_model(result) super().set_result(result) diff --git a/google/generativeai/protos.py b/google/generativeai/protos.py new file mode 100644 index 000000000..010396c75 --- /dev/null +++ b/google/generativeai/protos.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This module provides low level access to the ProtoBuffer "Message" classes used by the API. + +**For typical usage of this SDK you do not need to use any of these classes.** + +ProtoBufers are Google API's serilization format. They are strongly typed and efficient. + +The `genai` SDK tries to be permissive about what objects it will accept from a user, but in the end +the SDK always converts input to an appropriate Proto Message object to send as the request. Each API request +has a `*Request` and `*Response` Message defined here. + +If you have any uncertainty about what the API may accept or return, these classes provide the +complete/unambiguous answer. They come from the `google-ai-generativelanguage` package which is +generated from a snapshot of the API definition. + +>>> from google.generativeai import protos +>>> import inspect +>>> print(inspect.getsource(protos.Part)) + +Proto classes can have "oneof" fields. Use `in` to check which `oneof` field is set. + +>>> p = protos.Part(text='hello') +>>> 'text' in p +True +>>> p.inline_data = {'mime_type':'image/png', 'data': b'PNG'} +>>> type(p.inline_data) is protos.Blob +True +>>> 'inline_data' in p +True +>>> 'text' in p +False + +Instances of all Message classes can be converted into JSON compatible dictionaries with the following construct +(Bytes are base64 encoded): + +>>> p_dict = type(p).to_dict(p) +>>> p_dict +{'inline_data': {'mime_type': 'image/png', 'data': 'UE5H'}} + +A compatible dict can be converted to an instance of a Message class by passing it as the first argument to the +constructor: + +>>> p = protos.Part(p_dict) +inline_data { + mime_type: "image/png" + data: "PNG" +} + +Note when converting that `to_dict` accepts additional arguments: + +- `use_integers_for_enums:bool = True`, Set it to `False` to replace enum int values with their string + names in the output +- ` including_default_value_fields:bool = True`, Set it to `False` to reduce the verbosity of the output. + +Additional arguments are described in the docstring: + +>>> help(proto.Part.to_dict) +""" + +from google.ai.generativelanguage_v1beta.types import * +from google.ai.generativelanguage_v1beta.types import __all__ diff --git a/google/generativeai/responder.py b/google/generativeai/responder.py index 814bf3581..bb85167ad 100644 --- a/google/generativeai/responder.py +++ b/google/generativeai/responder.py @@ -22,9 +22,9 @@ import pydantic -from google.ai import generativelanguage as glm +from google.generativeai import protos -Type = glm.Type +Type = protos.Type TypeOptions = Union[int, str, Type] @@ -186,8 +186,8 @@ def _rename_schema_fields(schema: dict[str, Any]): class FunctionDeclaration: def __init__(self, *, name: str, description: str, parameters: dict[str, Any] | None = None): - """A class wrapping a `glm.FunctionDeclaration`, describes a function for `genai.GenerativeModel`'s `tools`.""" - self._proto = glm.FunctionDeclaration( + """A class wrapping a `protos.FunctionDeclaration`, describes a function for `genai.GenerativeModel`'s `tools`.""" + self._proto = protos.FunctionDeclaration( name=name, description=description, parameters=_rename_schema_fields(parameters) ) @@ -200,7 +200,7 @@ def description(self) -> str: return self._proto.description @property - def parameters(self) -> glm.Schema: + def parameters(self) -> protos.Schema: return self._proto.parameters @classmethod @@ -209,7 +209,7 @@ def from_proto(cls, proto) -> FunctionDeclaration: self._proto = proto return self - def to_proto(self) -> glm.FunctionDeclaration: + def to_proto(self) -> protos.FunctionDeclaration: return self._proto @staticmethod @@ -255,16 +255,16 @@ def __init__( super().__init__(name=name, description=description, parameters=parameters) self.function = function - def __call__(self, fc: glm.FunctionCall) -> glm.FunctionResponse: + def __call__(self, fc: protos.FunctionCall) -> protos.FunctionResponse: result = self.function(**fc.args) if not isinstance(result, dict): result = {"result": result} - return glm.FunctionResponse(name=fc.name, response=result) + return protos.FunctionResponse(name=fc.name, response=result) FunctionDeclarationType = Union[ FunctionDeclaration, - glm.FunctionDeclaration, + protos.FunctionDeclaration, dict[str, Any], Callable[..., Any], ] @@ -272,8 +272,8 @@ def __call__(self, fc: glm.FunctionCall) -> glm.FunctionResponse: def _make_function_declaration( fun: FunctionDeclarationType, -) -> FunctionDeclaration | glm.FunctionDeclaration: - if isinstance(fun, (FunctionDeclaration, glm.FunctionDeclaration)): +) -> FunctionDeclaration | protos.FunctionDeclaration: + if isinstance(fun, (FunctionDeclaration, protos.FunctionDeclaration)): return fun elif isinstance(fun, dict): if "function" in fun: @@ -289,15 +289,15 @@ def _make_function_declaration( ) -def _encode_fd(fd: FunctionDeclaration | glm.FunctionDeclaration) -> glm.FunctionDeclaration: - if isinstance(fd, glm.FunctionDeclaration): +def _encode_fd(fd: FunctionDeclaration | protos.FunctionDeclaration) -> protos.FunctionDeclaration: + if isinstance(fd, protos.FunctionDeclaration): return fd return fd.to_proto() class Tool: - """A wrapper for `glm.Tool`, Contains a collection of related `FunctionDeclaration` objects.""" + """A wrapper for `protos.Tool`, Contains a collection of related `FunctionDeclaration` objects.""" def __init__(self, function_declarations: Iterable[FunctionDeclarationType]): # The main path doesn't use this but is seems useful. @@ -309,23 +309,23 @@ def __init__(self, function_declarations: Iterable[FunctionDeclarationType]): raise ValueError("") self._index[fd.name] = fd - self._proto = glm.Tool( + self._proto = protos.Tool( function_declarations=[_encode_fd(fd) for fd in self._function_declarations] ) @property - def function_declarations(self) -> list[FunctionDeclaration | glm.FunctionDeclaration]: + def function_declarations(self) -> list[FunctionDeclaration | protos.FunctionDeclaration]: return self._function_declarations def __getitem__( - self, name: str | glm.FunctionCall - ) -> FunctionDeclaration | glm.FunctionDeclaration: + self, name: str | protos.FunctionCall + ) -> FunctionDeclaration | protos.FunctionDeclaration: if not isinstance(name, str): name = name.name return self._index[name] - def __call__(self, fc: glm.FunctionCall) -> glm.FunctionResponse | None: + def __call__(self, fc: protos.FunctionCall) -> protos.FunctionResponse | None: declaration = self[fc] if not callable(declaration): return None @@ -341,21 +341,21 @@ class ToolDict(TypedDict): ToolType = Union[ - Tool, glm.Tool, ToolDict, Iterable[FunctionDeclarationType], FunctionDeclarationType + Tool, protos.Tool, ToolDict, Iterable[FunctionDeclarationType], FunctionDeclarationType ] def _make_tool(tool: ToolType) -> Tool: if isinstance(tool, Tool): return tool - elif isinstance(tool, glm.Tool): + elif isinstance(tool, protos.Tool): return Tool(function_declarations=tool.function_declarations) elif isinstance(tool, dict): if "function_declarations" in tool: return Tool(**tool) else: fd = tool - return Tool(function_declarations=[glm.FunctionDeclaration(**fd)]) + return Tool(function_declarations=[protos.FunctionDeclaration(**fd)]) elif isinstance(tool, Iterable): return Tool(function_declarations=tool) else: @@ -385,20 +385,20 @@ def __init__(self, tools: Iterable[ToolType]): self._index[declaration.name] = declaration def __getitem__( - self, name: str | glm.FunctionCall - ) -> FunctionDeclaration | glm.FunctionDeclaration: + self, name: str | protos.FunctionCall + ) -> FunctionDeclaration | protos.FunctionDeclaration: if not isinstance(name, str): name = name.name return self._index[name] - def __call__(self, fc: glm.FunctionCall) -> glm.Part | None: + def __call__(self, fc: protos.FunctionCall) -> protos.Part | None: declaration = self[fc] if not callable(declaration): return None response = declaration(fc) - return glm.Part(function_response=response) + return protos.Part(function_response=response) def to_proto(self): return [tool.to_proto() for tool in self._tools] @@ -431,7 +431,7 @@ def to_function_library(lib: FunctionLibraryType | None) -> FunctionLibrary | No return FunctionLibrary(tools=lib) -FunctionCallingMode = glm.FunctionCallingConfig.Mode +FunctionCallingMode = protos.FunctionCallingConfig.Mode # fmt: off _FUNCTION_CALLING_MODE = { @@ -467,12 +467,12 @@ class FunctionCallingConfigDict(TypedDict): FunctionCallingConfigType = Union[ - FunctionCallingModeType, FunctionCallingConfigDict, glm.FunctionCallingConfig + FunctionCallingModeType, FunctionCallingConfigDict, protos.FunctionCallingConfig ] -def to_function_calling_config(obj: FunctionCallingConfigType) -> glm.FunctionCallingConfig: - if isinstance(obj, glm.FunctionCallingConfig): +def to_function_calling_config(obj: FunctionCallingConfigType) -> protos.FunctionCallingConfig: + if isinstance(obj, protos.FunctionCallingConfig): return obj elif isinstance(obj, (FunctionCallingMode, str, int)): obj = {"mode": to_function_calling_mode(obj)} @@ -482,30 +482,31 @@ def to_function_calling_config(obj: FunctionCallingConfigType) -> glm.FunctionCa obj["mode"] = to_function_calling_mode(mode) else: raise TypeError( - f"Invalid argument type: Could not convert input to `glm.FunctionCallingConfig`. Received type: {type(obj).__name__}.", + "Invalid argument type: Could not convert input to `protos.FunctionCallingConfig`." + f" Received type: {type(obj).__name__}.", obj, ) - return glm.FunctionCallingConfig(obj) + return protos.FunctionCallingConfig(obj) class ToolConfigDict: function_calling_config: FunctionCallingConfigType -ToolConfigType = Union[ToolConfigDict, glm.ToolConfig] +ToolConfigType = Union[ToolConfigDict, protos.ToolConfig] -def to_tool_config(obj: ToolConfigType) -> glm.ToolConfig: - if isinstance(obj, glm.ToolConfig): +def to_tool_config(obj: ToolConfigType) -> protos.ToolConfig: + if isinstance(obj, protos.ToolConfig): return obj elif isinstance(obj, dict): fcc = obj.pop("function_calling_config") fcc = to_function_calling_config(fcc) obj["function_calling_config"] = fcc - return glm.ToolConfig(**obj) + return protos.ToolConfig(**obj) else: raise TypeError( - f"Invalid argument type: Could not convert input to `glm.ToolConfig`. Received type: {type(obj).__name__}.", - obj, + "Invalid argument type: Could not convert input to `protos.ToolConfig`. " + f"Received type: {type(obj).__name__}.", ) diff --git a/google/generativeai/retriever.py b/google/generativeai/retriever.py index e295bc5b7..53c90140a 100644 --- a/google/generativeai/retriever.py +++ b/google/generativeai/retriever.py @@ -14,12 +14,11 @@ # limitations under the License. from __future__ import annotations -import re -import string -import dataclasses -from typing import Any, AsyncIterable, Iterable, Optional + +from typing import AsyncIterable, Iterable, Optional import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai.client import get_default_retriever_client from google.generativeai.client import get_default_retriever_async_client @@ -57,13 +56,13 @@ def create_corpus( client = get_default_retriever_client() if name is None: - corpus = glm.Corpus(display_name=display_name) + corpus = protos.Corpus(display_name=display_name) elif retriever_types.valid_name(name): - corpus = glm.Corpus(name=f"corpora/{name}", display_name=display_name) + corpus = protos.Corpus(name=f"corpora/{name}", display_name=display_name) else: raise ValueError(retriever_types.NAME_ERROR_MSG.format(length=len(name), name=name)) - request = glm.CreateCorpusRequest(corpus=corpus) + request = protos.CreateCorpusRequest(corpus=corpus) response = client.create_corpus(request, **request_options) response = type(response).to_dict(response) idecode_time(response, "create_time") @@ -86,13 +85,13 @@ async def create_corpus_async( client = get_default_retriever_async_client() if name is None: - corpus = glm.Corpus(display_name=display_name) + corpus = protos.Corpus(display_name=display_name) elif retriever_types.valid_name(name): - corpus = glm.Corpus(name=f"corpora/{name}", display_name=display_name) + corpus = protos.Corpus(name=f"corpora/{name}", display_name=display_name) else: raise ValueError(retriever_types.NAME_ERROR_MSG.format(length=len(name), name=name)) - request = glm.CreateCorpusRequest(corpus=corpus) + request = protos.CreateCorpusRequest(corpus=corpus) response = await client.create_corpus(request, **request_options) response = type(response).to_dict(response) idecode_time(response, "create_time") @@ -124,7 +123,7 @@ def get_corpus( if "/" not in name: name = "corpora/" + name - request = glm.GetCorpusRequest(name=name) + request = protos.GetCorpusRequest(name=name) response = client.get_corpus(request, **request_options) response = type(response).to_dict(response) idecode_time(response, "create_time") @@ -149,7 +148,7 @@ async def get_corpus_async( if "/" not in name: name = "corpora/" + name - request = glm.GetCorpusRequest(name=name) + request = protos.GetCorpusRequest(name=name) response = await client.get_corpus(request, **request_options) response = type(response).to_dict(response) idecode_time(response, "create_time") @@ -181,7 +180,7 @@ def delete_corpus( if "/" not in name: name = "corpora/" + name - request = glm.DeleteCorpusRequest(name=name, force=force) + request = protos.DeleteCorpusRequest(name=name, force=force) client.delete_corpus(request, **request_options) @@ -201,7 +200,7 @@ async def delete_corpus_async( if "/" not in name: name = "corpora/" + name - request = glm.DeleteCorpusRequest(name=name, force=force) + request = protos.DeleteCorpusRequest(name=name, force=force) await client.delete_corpus(request, **request_options) @@ -227,7 +226,7 @@ def list_corpora( if client is None: client = get_default_retriever_client() - request = glm.ListCorporaRequest(page_size=page_size) + request = protos.ListCorporaRequest(page_size=page_size) for corpus in client.list_corpora(request, **request_options): corpus = type(corpus).to_dict(corpus) idecode_time(corpus, "create_time") @@ -248,7 +247,7 @@ async def list_corpora_async( if client is None: client = get_default_retriever_async_client() - request = glm.ListCorporaRequest(page_size=page_size) + request = protos.ListCorporaRequest(page_size=page_size) async for corpus in await client.list_corpora(request, **request_options): corpus = type(corpus).to_dict(corpus) idecode_time(corpus, "create_time") diff --git a/google/generativeai/text.py b/google/generativeai/text.py index b8b814754..2a6267661 100644 --- a/google/generativeai/text.py +++ b/google/generativeai/text.py @@ -21,6 +21,8 @@ import google.ai.generativelanguage as glm +from google.generativeai import protos + from google.generativeai.client import get_default_text_client from google.generativeai import string_utils from google.generativeai.types import helper_types @@ -52,23 +54,23 @@ def _batched(iterable: Iterable[T], n: int) -> Iterable[list[T]]: yield batch -def _make_text_prompt(prompt: str | dict[str, str]) -> glm.TextPrompt: +def _make_text_prompt(prompt: str | dict[str, str]) -> protos.TextPrompt: """ - Creates a `glm.TextPrompt` object based on the provided prompt input. + Creates a `protos.TextPrompt` object based on the provided prompt input. Args: prompt: The prompt input, either a string or a dictionary. Returns: - glm.TextPrompt: A TextPrompt object containing the prompt text. + protos.TextPrompt: A TextPrompt object containing the prompt text. Raises: TypeError: If the provided prompt is neither a string nor a dictionary. """ if isinstance(prompt, str): - return glm.TextPrompt(text=prompt) + return protos.TextPrompt(text=prompt) elif isinstance(prompt, dict): - return glm.TextPrompt(prompt) + return protos.TextPrompt(prompt) else: raise TypeError( "Invalid argument type: Expected a string or dictionary for the text prompt." @@ -86,11 +88,11 @@ def _make_generate_text_request( top_k: int | None = None, safety_settings: palm_safety_types.SafetySettingOptions | None = None, stop_sequences: str | Iterable[str] | None = None, -) -> glm.GenerateTextRequest: +) -> protos.GenerateTextRequest: """ - Creates a `glm.GenerateTextRequest` object based on the provided parameters. + Creates a `protos.GenerateTextRequest` object based on the provided parameters. - This function generates a `glm.GenerateTextRequest` object with the specified + This function generates a `protos.GenerateTextRequest` object with the specified parameters. It prepares the input parameters and creates a request that can be used for generating text using the chosen model. @@ -107,7 +109,7 @@ def _make_generate_text_request( or iterable of strings. Defaults to None. Returns: - `glm.GenerateTextRequest`: A `GenerateTextRequest` object configured with the specified parameters. + `protos.GenerateTextRequest`: A `GenerateTextRequest` object configured with the specified parameters. """ model = model_types.make_model_name(model) prompt = _make_text_prompt(prompt=prompt) @@ -117,7 +119,7 @@ def _make_generate_text_request( if stop_sequences: stop_sequences = list(stop_sequences) - return glm.GenerateTextRequest( + return protos.GenerateTextRequest( model=model, prompt=prompt, temperature=temperature, @@ -216,12 +218,12 @@ def __init__(self, **kwargs): def _generate_response( - request: glm.GenerateTextRequest, + request: protos.GenerateTextRequest, client: glm.TextServiceClient = None, request_options: helper_types.RequestOptionsType | None = None, ) -> Completion: """ - Generates a response using the provided `glm.GenerateTextRequest` and client. + Generates a response using the provided `protos.GenerateTextRequest` and client. Args: request: The text generation request. @@ -267,7 +269,7 @@ def count_text_tokens( client = get_default_text_client() result = client.count_text_tokens( - glm.CountTextTokensRequest(model=base_model, prompt={"text": prompt}), + protos.CountTextTokensRequest(model=base_model, prompt={"text": prompt}), **request_options, ) @@ -322,7 +324,7 @@ def generate_embeddings( client = get_default_text_client() if isinstance(text, str): - embedding_request = glm.EmbedTextRequest(model=model, text=text) + embedding_request = protos.EmbedTextRequest(model=model, text=text) embedding_response = client.embed_text( embedding_request, **request_options, @@ -333,7 +335,7 @@ def generate_embeddings( result = {"embedding": []} for batch in _batched(text, EMBEDDING_MAX_BATCH_SIZE): # TODO(markdaoust): This could use an option for returning an iterator or wait-bar. - embedding_request = glm.BatchEmbedTextRequest(model=model, texts=batch) + embedding_request = protos.BatchEmbedTextRequest(model=model, texts=batch) embedding_response = client.batch_embed_text( embedding_request, **request_options, diff --git a/google/generativeai/types/answer_types.py b/google/generativeai/types/answer_types.py index 18bd11d62..143a578a4 100644 --- a/google/generativeai/types/answer_types.py +++ b/google/generativeai/types/answer_types.py @@ -16,11 +16,11 @@ from typing import Union -import google.ai.generativelanguage as glm +from google.generativeai import protos __all__ = ["Answer"] -FinishReason = glm.Candidate.FinishReason +FinishReason = protos.Candidate.FinishReason FinishReasonOptions = Union[int, str, FinishReason] diff --git a/google/generativeai/types/citation_types.py b/google/generativeai/types/citation_types.py index ae857c35b..9f169703f 100644 --- a/google/generativeai/types/citation_types.py +++ b/google/generativeai/types/citation_types.py @@ -17,7 +17,7 @@ from typing_extensions import TypedDict -from google.ai import generativelanguage as glm +from google.generativeai import protos from google.generativeai import string_utils @@ -33,10 +33,10 @@ class CitationSourceDict(TypedDict): uri: str | None license: str | None - __doc__ = string_utils.strip_oneof(glm.CitationSource.__doc__) + __doc__ = string_utils.strip_oneof(protos.CitationSource.__doc__) class CitationMetadataDict(TypedDict): citation_sources: List[CitationSourceDict | None] - __doc__ = string_utils.strip_oneof(glm.CitationMetadata.__doc__) + __doc__ = string_utils.strip_oneof(protos.CitationMetadata.__doc__) diff --git a/google/generativeai/types/content_types.py b/google/generativeai/types/content_types.py index 169683608..b8966b005 100644 --- a/google/generativeai/types/content_types.py +++ b/google/generativeai/types/content_types.py @@ -26,7 +26,7 @@ import pydantic from google.generativeai.types import file_types -from google.ai import generativelanguage as glm +from google.generativeai import protos if typing.TYPE_CHECKING: import PIL.Image @@ -80,10 +80,10 @@ def pil_to_blob(img): mime_type = "image/jpeg" bytesio.seek(0) data = bytesio.read() - return glm.Blob(mime_type=mime_type, data=data) + return protos.Blob(mime_type=mime_type, data=data) -def image_to_blob(image) -> glm.Blob: +def image_to_blob(image) -> protos.Blob: if PIL is not None: if isinstance(image, PIL.Image.Image): return pil_to_blob(image) @@ -100,7 +100,7 @@ def image_to_blob(image) -> glm.Blob: if mime_type is None: mime_type = "image/unknown" - return glm.Blob(mime_type=mime_type, data=image.data) + return protos.Blob(mime_type=mime_type, data=image.data) raise TypeError( "Image conversion failed. The input was expected to be of type `Image` " @@ -115,23 +115,23 @@ class BlobDict(TypedDict): data: bytes -def _convert_dict(d: Mapping) -> glm.Content | glm.Part | glm.Blob: +def _convert_dict(d: Mapping) -> protos.Content | protos.Part | protos.Blob: if is_content_dict(d): content = dict(d) if isinstance(parts := content["parts"], str): content["parts"] = [parts] content["parts"] = [to_part(part) for part in content["parts"]] - return glm.Content(content) + return protos.Content(content) elif is_part_dict(d): part = dict(d) if "inline_data" in part: part["inline_data"] = to_blob(part["inline_data"]) if "file_data" in part: part["file_data"] = file_types.to_file_data(part["file_data"]) - return glm.Part(part) + return protos.Part(part) elif is_blob_dict(d): blob = d - return glm.Blob(blob) + return protos.Blob(blob) else: raise KeyError( "Unable to determine the intended type of the `dict`. " @@ -148,17 +148,17 @@ def is_blob_dict(d): if typing.TYPE_CHECKING: BlobType = Union[ - glm.Blob, BlobDict, PIL.Image.Image, IPython.display.Image + protos.Blob, BlobDict, PIL.Image.Image, IPython.display.Image ] # Any for the images else: - BlobType = Union[glm.Blob, BlobDict, Any] + BlobType = Union[protos.Blob, BlobDict, Any] -def to_blob(blob: BlobType) -> glm.Blob: +def to_blob(blob: BlobType) -> protos.Blob: if isinstance(blob, Mapping): blob = _convert_dict(blob) - if isinstance(blob, glm.Blob): + if isinstance(blob, protos.Blob): return blob elif isinstance(blob, IMAGE_TYPES): return image_to_blob(blob) @@ -182,12 +182,12 @@ class PartDict(TypedDict): # When you need a `Part` accept a part object, part-dict, blob or string PartType = Union[ - glm.Part, + protos.Part, PartDict, BlobType, str, - glm.FunctionCall, - glm.FunctionResponse, + protos.FunctionCall, + protos.FunctionResponse, file_types.FileDataType, ] @@ -206,22 +206,22 @@ def to_part(part: PartType): if isinstance(part, Mapping): part = _convert_dict(part) - if isinstance(part, glm.Part): + if isinstance(part, protos.Part): return part elif isinstance(part, str): - return glm.Part(text=part) - elif isinstance(part, glm.FileData): - return glm.Part(file_data=part) - elif isinstance(part, (glm.File, file_types.File)): - return glm.Part(file_data=file_types.to_file_data(part)) - elif isinstance(part, glm.FunctionCall): - return glm.Part(function_call=part) - elif isinstance(part, glm.FunctionResponse): - return glm.Part(function_response=part) + return protos.Part(text=part) + elif isinstance(part, protos.FileData): + return protos.Part(file_data=part) + elif isinstance(part, (protos.File, file_types.File)): + return protos.Part(file_data=file_types.to_file_data(part)) + elif isinstance(part, protos.FunctionCall): + return protos.Part(function_call=part) + elif isinstance(part, protos.FunctionResponse): + return protos.Part(function_response=part) else: # Maybe it can be turned into a blob? - return glm.Part(inline_data=to_blob(part)) + return protos.Part(inline_data=to_blob(part)) class ContentDict(TypedDict): @@ -235,10 +235,10 @@ def is_content_dict(d): # When you need a message accept a `Content` object or dict, a list of parts, # or a single part -ContentType = Union[glm.Content, ContentDict, Iterable[PartType], PartType] +ContentType = Union[protos.Content, ContentDict, Iterable[PartType], PartType] # For generate_content, we're not guessing roles for [[parts],[parts],[parts]] yet. -StrictContentType = Union[glm.Content, ContentDict] +StrictContentType = Union[protos.Content, ContentDict] def to_content(content: ContentType): @@ -250,24 +250,24 @@ def to_content(content: ContentType): if isinstance(content, Mapping): content = _convert_dict(content) - if isinstance(content, glm.Content): + if isinstance(content, protos.Content): return content elif isinstance(content, Iterable) and not isinstance(content, str): - return glm.Content(parts=[to_part(part) for part in content]) + return protos.Content(parts=[to_part(part) for part in content]) else: # Maybe this is a Part? - return glm.Content(parts=[to_part(content)]) + return protos.Content(parts=[to_part(content)]) def strict_to_content(content: StrictContentType): if isinstance(content, Mapping): content = _convert_dict(content) - if isinstance(content, glm.Content): + if isinstance(content, protos.Content): return content else: raise TypeError( - "Invalid input type. Expected a `glm.Content` or a `dict` with a 'parts' key.\n" + "Invalid input type. Expected a `protos.Content` or a `dict` with a 'parts' key.\n" f"However, received an object of type: {type(content)}.\n" f"Object Value: {content}" ) @@ -276,7 +276,7 @@ def strict_to_content(content: StrictContentType): ContentsType = Union[ContentType, Iterable[StrictContentType], None] -def to_contents(contents: ContentsType) -> list[glm.Content]: +def to_contents(contents: ContentsType) -> list[protos.Content]: if contents is None: return [] @@ -509,8 +509,8 @@ def _rename_schema_fields(schema): class FunctionDeclaration: def __init__(self, *, name: str, description: str, parameters: dict[str, Any] | None = None): - """A class wrapping a `glm.FunctionDeclaration`, describes a function for `genai.GenerativeModel`'s `tools`.""" - self._proto = glm.FunctionDeclaration( + """A class wrapping a `protos.FunctionDeclaration`, describes a function for `genai.GenerativeModel`'s `tools`.""" + self._proto = protos.FunctionDeclaration( name=name, description=description, parameters=_rename_schema_fields(parameters) ) @@ -523,7 +523,7 @@ def description(self) -> str: return self._proto.description @property - def parameters(self) -> glm.Schema: + def parameters(self) -> protos.Schema: return self._proto.parameters @classmethod @@ -532,7 +532,7 @@ def from_proto(cls, proto) -> FunctionDeclaration: self._proto = proto return self - def to_proto(self) -> glm.FunctionDeclaration: + def to_proto(self) -> protos.FunctionDeclaration: return self._proto @staticmethod @@ -578,16 +578,16 @@ def __init__( super().__init__(name=name, description=description, parameters=parameters) self.function = function - def __call__(self, fc: glm.FunctionCall) -> glm.FunctionResponse: + def __call__(self, fc: protos.FunctionCall) -> protos.FunctionResponse: result = self.function(**fc.args) if not isinstance(result, dict): result = {"result": result} - return glm.FunctionResponse(name=fc.name, response=result) + return protos.FunctionResponse(name=fc.name, response=result) FunctionDeclarationType = Union[ FunctionDeclaration, - glm.FunctionDeclaration, + protos.FunctionDeclaration, dict[str, Any], Callable[..., Any], ] @@ -595,8 +595,8 @@ def __call__(self, fc: glm.FunctionCall) -> glm.FunctionResponse: def _make_function_declaration( fun: FunctionDeclarationType, -) -> FunctionDeclaration | glm.FunctionDeclaration: - if isinstance(fun, (FunctionDeclaration, glm.FunctionDeclaration)): +) -> FunctionDeclaration | protos.FunctionDeclaration: + if isinstance(fun, (FunctionDeclaration, protos.FunctionDeclaration)): return fun elif isinstance(fun, dict): if "function" in fun: @@ -613,15 +613,15 @@ def _make_function_declaration( ) -def _encode_fd(fd: FunctionDeclaration | glm.FunctionDeclaration) -> glm.FunctionDeclaration: - if isinstance(fd, glm.FunctionDeclaration): +def _encode_fd(fd: FunctionDeclaration | protos.FunctionDeclaration) -> protos.FunctionDeclaration: + if isinstance(fd, protos.FunctionDeclaration): return fd return fd.to_proto() class Tool: - """A wrapper for `glm.Tool`, Contains a collection of related `FunctionDeclaration` objects.""" + """A wrapper for `protos.Tool`, Contains a collection of related `FunctionDeclaration` objects.""" def __init__(self, function_declarations: Iterable[FunctionDeclarationType]): # The main path doesn't use this but is seems useful. @@ -633,23 +633,23 @@ def __init__(self, function_declarations: Iterable[FunctionDeclarationType]): raise ValueError("") self._index[fd.name] = fd - self._proto = glm.Tool( + self._proto = protos.Tool( function_declarations=[_encode_fd(fd) for fd in self._function_declarations] ) @property - def function_declarations(self) -> list[FunctionDeclaration | glm.FunctionDeclaration]: + def function_declarations(self) -> list[FunctionDeclaration | protos.FunctionDeclaration]: return self._function_declarations def __getitem__( - self, name: str | glm.FunctionCall - ) -> FunctionDeclaration | glm.FunctionDeclaration: + self, name: str | protos.FunctionCall + ) -> FunctionDeclaration | protos.FunctionDeclaration: if not isinstance(name, str): name = name.name return self._index[name] - def __call__(self, fc: glm.FunctionCall) -> glm.FunctionResponse | None: + def __call__(self, fc: protos.FunctionCall) -> protos.FunctionResponse | None: declaration = self[fc] if not callable(declaration): return None @@ -665,21 +665,21 @@ class ToolDict(TypedDict): ToolType = Union[ - Tool, glm.Tool, ToolDict, Iterable[FunctionDeclarationType], FunctionDeclarationType + Tool, protos.Tool, ToolDict, Iterable[FunctionDeclarationType], FunctionDeclarationType ] def _make_tool(tool: ToolType) -> Tool: if isinstance(tool, Tool): return tool - elif isinstance(tool, glm.Tool): + elif isinstance(tool, protos.Tool): return Tool(function_declarations=tool.function_declarations) elif isinstance(tool, dict): if "function_declarations" in tool: return Tool(**tool) else: fd = tool - return Tool(function_declarations=[glm.FunctionDeclaration(**fd)]) + return Tool(function_declarations=[protos.FunctionDeclaration(**fd)]) elif isinstance(tool, Iterable): return Tool(function_declarations=tool) else: @@ -711,20 +711,20 @@ def __init__(self, tools: Iterable[ToolType]): self._index[declaration.name] = declaration def __getitem__( - self, name: str | glm.FunctionCall - ) -> FunctionDeclaration | glm.FunctionDeclaration: + self, name: str | protos.FunctionCall + ) -> FunctionDeclaration | protos.FunctionDeclaration: if not isinstance(name, str): name = name.name return self._index[name] - def __call__(self, fc: glm.FunctionCall) -> glm.Part | None: + def __call__(self, fc: protos.FunctionCall) -> protos.Part | None: declaration = self[fc] if not callable(declaration): return None response = declaration(fc) - return glm.Part(function_response=response) + return protos.Part(function_response=response) def to_proto(self): return [tool.to_proto() for tool in self._tools] @@ -757,7 +757,7 @@ def to_function_library(lib: FunctionLibraryType | None) -> FunctionLibrary | No return FunctionLibrary(tools=lib) -FunctionCallingMode = glm.FunctionCallingConfig.Mode +FunctionCallingMode = protos.FunctionCallingConfig.Mode # fmt: off _FUNCTION_CALLING_MODE = { @@ -793,12 +793,12 @@ class FunctionCallingConfigDict(TypedDict): FunctionCallingConfigType = Union[ - FunctionCallingModeType, FunctionCallingConfigDict, glm.FunctionCallingConfig + FunctionCallingModeType, FunctionCallingConfigDict, protos.FunctionCallingConfig ] -def to_function_calling_config(obj: FunctionCallingConfigType) -> glm.FunctionCallingConfig: - if isinstance(obj, glm.FunctionCallingConfig): +def to_function_calling_config(obj: FunctionCallingConfigType) -> protos.FunctionCallingConfig: + if isinstance(obj, protos.FunctionCallingConfig): return obj elif isinstance(obj, (FunctionCallingMode, str, int)): obj = {"mode": to_function_calling_mode(obj)} @@ -808,32 +808,32 @@ def to_function_calling_config(obj: FunctionCallingConfigType) -> glm.FunctionCa obj["mode"] = to_function_calling_mode(mode) else: raise TypeError( - "Invalid input type. Failed to convert input to `glm.FunctionCallingConfig`.\n" + "Invalid input type. Failed to convert input to `protos.FunctionCallingConfig`.\n" f"Received an object of type: {type(obj)}.\n" f"Object Value: {obj}" ) - return glm.FunctionCallingConfig(obj) + return protos.FunctionCallingConfig(obj) class ToolConfigDict: function_calling_config: FunctionCallingConfigType -ToolConfigType = Union[ToolConfigDict, glm.ToolConfig] +ToolConfigType = Union[ToolConfigDict, protos.ToolConfig] -def to_tool_config(obj: ToolConfigType) -> glm.ToolConfig: - if isinstance(obj, glm.ToolConfig): +def to_tool_config(obj: ToolConfigType) -> protos.ToolConfig: + if isinstance(obj, protos.ToolConfig): return obj elif isinstance(obj, dict): fcc = obj.pop("function_calling_config") fcc = to_function_calling_config(fcc) obj["function_calling_config"] = fcc - return glm.ToolConfig(**obj) + return protos.ToolConfig(**obj) else: raise TypeError( - "Invalid input type. Failed to convert input to `glm.ToolConfig`.\n" + "Invalid input type. Failed to convert input to `protos.ToolConfig`.\n" f"Received an object of type: {type(obj)}.\n" f"Object Value: {obj}" ) diff --git a/google/generativeai/types/discuss_types.py b/google/generativeai/types/discuss_types.py index fa777d1d1..a538da65c 100644 --- a/google/generativeai/types/discuss_types.py +++ b/google/generativeai/types/discuss_types.py @@ -19,7 +19,7 @@ from typing import Any, Dict, Union, Iterable, Optional, Tuple, List from typing_extensions import TypedDict -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import string_utils from google.generativeai.types import palm_safety_types @@ -46,15 +46,15 @@ class TokenCount(TypedDict): class MessageDict(TypedDict): - """A dict representation of a `glm.Message`.""" + """A dict representation of a `protos.Message`.""" author: str content: str citation_metadata: Optional[citation_types.CitationMetadataDict] -MessageOptions = Union[str, MessageDict, glm.Message] -MESSAGE_OPTIONS = (str, dict, glm.Message) +MessageOptions = Union[str, MessageDict, protos.Message] +MESSAGE_OPTIONS = (str, dict, protos.Message) MessagesOptions = Union[ MessageOptions, @@ -64,7 +64,7 @@ class MessageDict(TypedDict): class ExampleDict(TypedDict): - """A dict representation of a `glm.Example`.""" + """A dict representation of a `protos.Example`.""" input: MessageOptions output: MessageOptions @@ -74,14 +74,14 @@ class ExampleDict(TypedDict): Tuple[MessageOptions, MessageOptions], Iterable[MessageOptions], ExampleDict, - glm.Example, + protos.Example, ] -EXAMPLE_OPTIONS = (glm.Example, dict, Iterable) +EXAMPLE_OPTIONS = (protos.Example, dict, Iterable) ExamplesOptions = Union[ExampleOptions, Iterable[ExampleOptions]] class MessagePromptDict(TypedDict, total=False): - """A dict representation of a `glm.MessagePrompt`.""" + """A dict representation of a `protos.MessagePrompt`.""" context: str examples: ExamplesOptions @@ -90,16 +90,16 @@ class MessagePromptDict(TypedDict, total=False): MessagePromptOptions = Union[ str, - glm.Message, - Iterable[Union[str, glm.Message]], + protos.Message, + Iterable[Union[str, protos.Message]], MessagePromptDict, - glm.MessagePrompt, + protos.MessagePrompt, ] MESSAGE_PROMPT_KEYS = {"context", "examples", "messages"} class ResponseDict(TypedDict): - """A dict representation of a `glm.GenerateMessageResponse`.""" + """A dict representation of a `protos.GenerateMessageResponse`.""" messages: List[MessageDict] candidates: List[MessageDict] diff --git a/google/generativeai/types/file_types.py b/google/generativeai/types/file_types.py index 0fdf05322..ef251e296 100644 --- a/google/generativeai/types/file_types.py +++ b/google/generativeai/types/file_types.py @@ -21,16 +21,16 @@ from google.rpc.status_pb2 import Status from google.generativeai.client import get_default_file_client -import google.ai.generativelanguage as glm +from google.generativeai import protos class File: - def __init__(self, proto: glm.File | File | dict): + def __init__(self, proto: protos.File | File | dict): if isinstance(proto, File): proto = proto.to_proto() - self._proto = glm.File(proto) + self._proto = protos.File(proto) - def to_proto(self) -> glm.File: + def to_proto(self) -> protos.File: return self._proto @property @@ -70,11 +70,11 @@ def uri(self) -> str: return self._proto.uri @property - def state(self) -> glm.File.State: + def state(self) -> protos.File.State: return self._proto.state @property - def video_metadata(self) -> glm.VideoMetadata: + def video_metadata(self) -> protos.VideoMetadata: return self._proto.video_metadata @property @@ -91,26 +91,26 @@ class FileDataDict(TypedDict): file_uri: str -FileDataType = Union[FileDataDict, glm.FileData, glm.File, File] +FileDataType = Union[FileDataDict, protos.FileData, protos.File, File] def to_file_data(file_data: FileDataType): if isinstance(file_data, dict): if "file_uri" in file_data: - file_data = glm.FileData(file_data) + file_data = protos.FileData(file_data) else: - file_data = glm.File(file_data) + file_data = protos.File(file_data) if isinstance(file_data, File): file_data = file_data.to_proto() - if isinstance(file_data, glm.File): - file_data = glm.FileData( + if isinstance(file_data, protos.File): + file_data = protos.FileData( mime_type=file_data.mime_type, file_uri=file_data.uri, ) - if isinstance(file_data, glm.FileData): + if isinstance(file_data, protos.FileData): return file_data else: raise TypeError( diff --git a/google/generativeai/types/generation_types.py b/google/generativeai/types/generation_types.py index 8d39f76c7..20686a156 100644 --- a/google/generativeai/types/generation_types.py +++ b/google/generativeai/types/generation_types.py @@ -30,7 +30,7 @@ import google.protobuf.json_format import google.api_core.exceptions -from google.ai import generativelanguage as glm +from google.generativeai import protos from google.generativeai import string_utils from google.generativeai.types import content_types from google.generativeai.responder import _rename_schema_fields @@ -85,7 +85,7 @@ class GenerationConfigDict(TypedDict, total=False): max_output_tokens: int temperature: float response_mime_type: str - response_schema: glm.Schema | Mapping[str, Any] # fmt: off + response_schema: protos.Schema | Mapping[str, Any] # fmt: off @dataclasses.dataclass @@ -165,19 +165,19 @@ class GenerationConfig: top_p: float | None = None top_k: int | None = None response_mime_type: str | None = None - response_schema: glm.Schema | Mapping[str, Any] | None = None + response_schema: protos.Schema | Mapping[str, Any] | None = None -GenerationConfigType = Union[glm.GenerationConfig, GenerationConfigDict, GenerationConfig] +GenerationConfigType = Union[protos.GenerationConfig, GenerationConfigDict, GenerationConfig] def _normalize_schema(generation_config): - # Convert response_schema to glm.Schema for request + # Convert response_schema to protos.Schema for request response_schema = generation_config.get("response_schema", None) if response_schema is None: return - if isinstance(response_schema, glm.Schema): + if isinstance(response_schema, protos.Schema): return if isinstance(response_schema, type): @@ -191,13 +191,13 @@ def _normalize_schema(generation_config): response_schema = content_types._schema_for_class(response_schema) response_schema = _rename_schema_fields(response_schema) - generation_config["response_schema"] = glm.Schema(response_schema) + generation_config["response_schema"] = protos.Schema(response_schema) def to_generation_config_dict(generation_config: GenerationConfigType): if generation_config is None: return {} - elif isinstance(generation_config, glm.GenerationConfig): + elif isinstance(generation_config, protos.GenerationConfig): schema = generation_config.response_schema generation_config = type(generation_config).to_dict( generation_config @@ -221,14 +221,14 @@ def to_generation_config_dict(generation_config: GenerationConfigType): def _join_citation_metadatas( - citation_metadatas: Iterable[glm.CitationMetadata], + citation_metadatas: Iterable[protos.CitationMetadata], ): citation_metadatas = list(citation_metadatas) return citation_metadatas[-1] def _join_safety_ratings_lists( - safety_ratings_lists: Iterable[list[glm.SafetyRating]], + safety_ratings_lists: Iterable[list[protos.SafetyRating]], ): ratings = {} blocked = collections.defaultdict(list) @@ -243,13 +243,13 @@ def _join_safety_ratings_lists( safety_list = [] for (category, probability), blocked in zip(ratings.items(), blocked.values()): safety_list.append( - glm.SafetyRating(category=category, probability=probability, blocked=blocked) + protos.SafetyRating(category=category, probability=probability, blocked=blocked) ) return safety_list -def _join_contents(contents: Iterable[glm.Content]): +def _join_contents(contents: Iterable[protos.Content]): contents = tuple(contents) roles = [c.role for c in contents if c.role] if roles: @@ -271,22 +271,22 @@ def _join_contents(contents: Iterable[glm.Content]): merged_parts.append(part) continue - merged_part = glm.Part(merged_parts[-1]) + merged_part = protos.Part(merged_parts[-1]) merged_part.text += part.text merged_parts[-1] = merged_part - return glm.Content( + return protos.Content( role=role, parts=merged_parts, ) -def _join_candidates(candidates: Iterable[glm.Candidate]): +def _join_candidates(candidates: Iterable[protos.Candidate]): candidates = tuple(candidates) index = candidates[0].index # These should all be the same. - return glm.Candidate( + return protos.Candidate( index=index, content=_join_contents([c.content for c in candidates]), finish_reason=candidates[-1].finish_reason, @@ -296,7 +296,7 @@ def _join_candidates(candidates: Iterable[glm.Candidate]): ) -def _join_candidate_lists(candidate_lists: Iterable[list[glm.Candidate]]): +def _join_candidate_lists(candidate_lists: Iterable[list[protos.Candidate]]): # Assuming that is a candidate ends, it is no longer returned in the list of # candidates and that's why candidates have an index candidates = collections.defaultdict(list) @@ -312,15 +312,15 @@ def _join_candidate_lists(candidate_lists: Iterable[list[glm.Candidate]]): def _join_prompt_feedbacks( - prompt_feedbacks: Iterable[glm.GenerateContentResponse.PromptFeedback], + prompt_feedbacks: Iterable[protos.GenerateContentResponse.PromptFeedback], ): # Always return the first prompt feedback. return next(iter(prompt_feedbacks)) -def _join_chunks(chunks: Iterable[glm.GenerateContentResponse]): +def _join_chunks(chunks: Iterable[protos.GenerateContentResponse]): chunks = tuple(chunks) - return glm.GenerateContentResponse( + return protos.GenerateContentResponse( candidates=_join_candidate_lists(c.candidates for c in chunks), prompt_feedback=_join_prompt_feedbacks(c.prompt_feedback for c in chunks), usage_metadata=chunks[-1].usage_metadata, @@ -338,11 +338,11 @@ def __init__( done: bool, iterator: ( None - | Iterable[glm.GenerateContentResponse] - | AsyncIterable[glm.GenerateContentResponse] + | Iterable[protos.GenerateContentResponse] + | AsyncIterable[protos.GenerateContentResponse] ), - result: glm.GenerateContentResponse, - chunks: Iterable[glm.GenerateContentResponse] | None = None, + result: protos.GenerateContentResponse, + chunks: Iterable[protos.GenerateContentResponse] | None = None, ): self._done = done self._iterator = iterator @@ -440,7 +440,7 @@ def __str__(self) -> str: ) json_str = json.dumps(as_dict, indent=2) - _result = f"glm.GenerateContentResponse({json_str})" + _result = f"protos.GenerateContentResponse({json_str})" _result = _result.replace("\n", "\n ") if self._error: @@ -478,7 +478,7 @@ def rewrite_stream_error(): GENERATE_CONTENT_RESPONSE_DOC = """Instances of this class manage the response of the `generate_content` method. These are returned by `GenerativeModel.generate_content` and `ChatSession.send_message`. - This object is based on the low level `glm.GenerateContentResponse` class which just has `prompt_feedback` + This object is based on the low level `protos.GenerateContentResponse` class which just has `prompt_feedback` and `candidates` attributes. This class adds several quick accessors for common use cases. The same object type is returned for both `stream=True/False`. @@ -507,7 +507,7 @@ def rewrite_stream_error(): @string_utils.set_doc(GENERATE_CONTENT_RESPONSE_DOC) class GenerateContentResponse(BaseGenerateContentResponse): @classmethod - def from_iterator(cls, iterator: Iterable[glm.GenerateContentResponse]): + def from_iterator(cls, iterator: Iterable[protos.GenerateContentResponse]): iterator = iter(iterator) with rewrite_stream_error(): response = next(iterator) @@ -519,7 +519,7 @@ def from_iterator(cls, iterator: Iterable[glm.GenerateContentResponse]): ) @classmethod - def from_response(cls, response: glm.GenerateContentResponse): + def from_response(cls, response: protos.GenerateContentResponse): return cls( done=True, iterator=None, @@ -574,7 +574,7 @@ def resolve(self): @string_utils.set_doc(ASYNC_GENERATE_CONTENT_RESPONSE_DOC) class AsyncGenerateContentResponse(BaseGenerateContentResponse): @classmethod - async def from_aiterator(cls, iterator: AsyncIterable[glm.GenerateContentResponse]): + async def from_aiterator(cls, iterator: AsyncIterable[protos.GenerateContentResponse]): iterator = aiter(iterator) # type: ignore with rewrite_stream_error(): response = await anext(iterator) # type: ignore @@ -586,7 +586,7 @@ async def from_aiterator(cls, iterator: AsyncIterable[glm.GenerateContentRespons ) @classmethod - def from_response(cls, response: glm.GenerateContentResponse): + def from_response(cls, response: protos.GenerateContentResponse): return cls( done=True, iterator=None, diff --git a/google/generativeai/types/model_types.py b/google/generativeai/types/model_types.py index 32b3bddae..81a545b30 100644 --- a/google/generativeai/types/model_types.py +++ b/google/generativeai/types/model_types.py @@ -28,7 +28,7 @@ import urllib.request from typing_extensions import TypedDict -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai.types import permission_types from google.generativeai import string_utils @@ -44,7 +44,7 @@ "TunedModelState", ] -TunedModelState = glm.TunedModel.State +TunedModelState = protos.TunedModel.State TunedModelStateOptions = Union[None, str, int, TunedModelState] @@ -91,7 +91,7 @@ def to_tuned_model_state(x: TunedModelStateOptions) -> TunedModelState: @string_utils.prettyprint @dataclasses.dataclass class Model: - """A dataclass representation of a `glm.Model`. + """A dataclass representation of a `protos.Model`. Attributes: name: The resource name of the `Model`. Format: `models/{model}` with a `{model}` naming @@ -140,8 +140,8 @@ def idecode_time(parent: dict["str", Any], name: str): parent[name] = dt -def decode_tuned_model(tuned_model: glm.TunedModel | dict["str", Any]) -> TunedModel: - if isinstance(tuned_model, glm.TunedModel): +def decode_tuned_model(tuned_model: protos.TunedModel | dict["str", Any]) -> TunedModel: + if isinstance(tuned_model, protos.TunedModel): tuned_model = type(tuned_model).to_dict(tuned_model) # pytype: disable=attribute-error tuned_model["state"] = to_tuned_model_state(tuned_model.pop("state", None)) @@ -180,7 +180,7 @@ def decode_tuned_model(tuned_model: glm.TunedModel | dict["str", Any]) -> TunedM @string_utils.prettyprint @dataclasses.dataclass class TunedModel: - """A dataclass representation of a `glm.TunedModel`.""" + """A dataclass representation of a `protos.TunedModel`.""" name: str | None = None source_model: str | None = None @@ -214,13 +214,13 @@ class TuningExampleDict(TypedDict): output: str -TuningExampleOptions = Union[TuningExampleDict, glm.TuningExample, tuple[str, str], list[str]] +TuningExampleOptions = Union[TuningExampleDict, protos.TuningExample, tuple[str, str], list[str]] # TODO(markdaoust): gs:// URLS? File-type argument for files without extension? TuningDataOptions = Union[ pathlib.Path, str, - glm.Dataset, + protos.Dataset, Mapping[str, Iterable[str]], Iterable[TuningExampleOptions], ] @@ -228,8 +228,8 @@ class TuningExampleDict(TypedDict): def encode_tuning_data( data: TuningDataOptions, input_key="text_input", output_key="output" -) -> glm.Dataset: - if isinstance(data, glm.Dataset): +) -> protos.Dataset: + if isinstance(data, protos.Dataset): return data if isinstance(data, str): @@ -301,8 +301,8 @@ def _convert_dict(data, input_key, output_key): ) for i, o in zip(inputs, outputs): - new_data.append(glm.TuningExample({"text_input": str(i), "output": str(o)})) - return glm.Dataset(examples=glm.TuningExamples(examples=new_data)) + new_data.append(protos.TuningExample({"text_input": str(i), "output": str(o)})) + return protos.Dataset(examples=protos.TuningExamples(examples=new_data)) def _convert_iterable(data, input_key, output_key): @@ -310,17 +310,17 @@ def _convert_iterable(data, input_key, output_key): for example in data: example = encode_tuning_example(example, input_key, output_key) new_data.append(example) - return glm.Dataset(examples=glm.TuningExamples(examples=new_data)) + return protos.Dataset(examples=protos.TuningExamples(examples=new_data)) def encode_tuning_example(example: TuningExampleOptions, input_key, output_key): - if isinstance(example, glm.TuningExample): + if isinstance(example, protos.TuningExample): return example elif isinstance(example, (tuple, list)): a, b = example - example = glm.TuningExample(text_input=a, output=b) + example = protos.TuningExample(text_input=a, output=b) else: # dict - example = glm.TuningExample(text_input=example[input_key], output=example[output_key]) + example = protos.TuningExample(text_input=example[input_key], output=example[output_key]) return example @@ -341,14 +341,14 @@ class Hyperparameters: learning_rate: float = 0.0 -BaseModelNameOptions = Union[str, Model, glm.Model] -TunedModelNameOptions = Union[str, TunedModel, glm.TunedModel] -AnyModelNameOptions = Union[str, Model, glm.Model, TunedModel, glm.TunedModel] +BaseModelNameOptions = Union[str, Model, protos.Model] +TunedModelNameOptions = Union[str, TunedModel, protos.TunedModel] +AnyModelNameOptions = Union[str, Model, protos.Model, TunedModel, protos.TunedModel] ModelNameOptions = AnyModelNameOptions def make_model_name(name: AnyModelNameOptions): - if isinstance(name, (Model, glm.Model, TunedModel, glm.TunedModel)): + if isinstance(name, (Model, protos.Model, TunedModel, protos.TunedModel)): name = name.name # pytype: disable=attribute-error elif isinstance(name, str): name = name @@ -372,7 +372,7 @@ def make_model_name(name: AnyModelNameOptions): @string_utils.prettyprint @dataclasses.dataclass class TokenCount: - """A dataclass representation of a `glm.TokenCountResponse`. + """A dataclass representation of a `protos.TokenCountResponse`. Attributes: token_count: The number of tokens returned by the model's tokenizer for the `input_text`. diff --git a/google/generativeai/types/palm_safety_types.py b/google/generativeai/types/palm_safety_types.py index 9fb88cd67..0ab85e1b2 100644 --- a/google/generativeai/types/palm_safety_types.py +++ b/google/generativeai/types/palm_safety_types.py @@ -23,7 +23,7 @@ from typing_extensions import TypedDict -from google.ai import generativelanguage as glm +from google.generativeai import protos from google.generativeai import string_utils @@ -39,9 +39,9 @@ ] # These are basic python enums, it's okay to expose them -HarmProbability = glm.SafetyRating.HarmProbability -HarmBlockThreshold = glm.SafetySetting.HarmBlockThreshold -BlockedReason = glm.ContentFilter.BlockedReason +HarmProbability = protos.SafetyRating.HarmProbability +HarmBlockThreshold = protos.SafetySetting.HarmBlockThreshold +BlockedReason = protos.ContentFilter.BlockedReason class HarmCategory: @@ -49,70 +49,70 @@ class HarmCategory: Harm Categories supported by the palm-family models """ - HARM_CATEGORY_UNSPECIFIED = glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED.value - HARM_CATEGORY_DEROGATORY = glm.HarmCategory.HARM_CATEGORY_DEROGATORY.value - HARM_CATEGORY_TOXICITY = glm.HarmCategory.HARM_CATEGORY_TOXICITY.value - HARM_CATEGORY_VIOLENCE = glm.HarmCategory.HARM_CATEGORY_VIOLENCE.value - HARM_CATEGORY_SEXUAL = glm.HarmCategory.HARM_CATEGORY_SEXUAL.value - HARM_CATEGORY_MEDICAL = glm.HarmCategory.HARM_CATEGORY_MEDICAL.value - HARM_CATEGORY_DANGEROUS = glm.HarmCategory.HARM_CATEGORY_DANGEROUS.value + HARM_CATEGORY_UNSPECIFIED = protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED.value + HARM_CATEGORY_DEROGATORY = protos.HarmCategory.HARM_CATEGORY_DEROGATORY.value + HARM_CATEGORY_TOXICITY = protos.HarmCategory.HARM_CATEGORY_TOXICITY.value + HARM_CATEGORY_VIOLENCE = protos.HarmCategory.HARM_CATEGORY_VIOLENCE.value + HARM_CATEGORY_SEXUAL = protos.HarmCategory.HARM_CATEGORY_SEXUAL.value + HARM_CATEGORY_MEDICAL = protos.HarmCategory.HARM_CATEGORY_MEDICAL.value + HARM_CATEGORY_DANGEROUS = protos.HarmCategory.HARM_CATEGORY_DANGEROUS.value HarmCategoryOptions = Union[str, int, HarmCategory] # fmt: off -_HARM_CATEGORIES: Dict[HarmCategoryOptions, glm.HarmCategory] = { - glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED: glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, - HarmCategory.HARM_CATEGORY_UNSPECIFIED: glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, - 0: glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, - "harm_category_unspecified": glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, - "unspecified": glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, - - glm.HarmCategory.HARM_CATEGORY_DEROGATORY: glm.HarmCategory.HARM_CATEGORY_DEROGATORY, - HarmCategory.HARM_CATEGORY_DEROGATORY: glm.HarmCategory.HARM_CATEGORY_DEROGATORY, - 1: glm.HarmCategory.HARM_CATEGORY_DEROGATORY, - "harm_category_derogatory": glm.HarmCategory.HARM_CATEGORY_DEROGATORY, - "derogatory": glm.HarmCategory.HARM_CATEGORY_DEROGATORY, - - glm.HarmCategory.HARM_CATEGORY_TOXICITY: glm.HarmCategory.HARM_CATEGORY_TOXICITY, - HarmCategory.HARM_CATEGORY_TOXICITY: glm.HarmCategory.HARM_CATEGORY_TOXICITY, - 2: glm.HarmCategory.HARM_CATEGORY_TOXICITY, - "harm_category_toxicity": glm.HarmCategory.HARM_CATEGORY_TOXICITY, - "toxicity": glm.HarmCategory.HARM_CATEGORY_TOXICITY, - "toxic": glm.HarmCategory.HARM_CATEGORY_TOXICITY, - - glm.HarmCategory.HARM_CATEGORY_VIOLENCE: glm.HarmCategory.HARM_CATEGORY_VIOLENCE, - HarmCategory.HARM_CATEGORY_VIOLENCE: glm.HarmCategory.HARM_CATEGORY_VIOLENCE, - 3: glm.HarmCategory.HARM_CATEGORY_VIOLENCE, - "harm_category_violence": glm.HarmCategory.HARM_CATEGORY_VIOLENCE, - "violence": glm.HarmCategory.HARM_CATEGORY_VIOLENCE, - "violent": glm.HarmCategory.HARM_CATEGORY_VIOLENCE, - - glm.HarmCategory.HARM_CATEGORY_SEXUAL: glm.HarmCategory.HARM_CATEGORY_SEXUAL, - HarmCategory.HARM_CATEGORY_SEXUAL: glm.HarmCategory.HARM_CATEGORY_SEXUAL, - 4: glm.HarmCategory.HARM_CATEGORY_SEXUAL, - "harm_category_sexual": glm.HarmCategory.HARM_CATEGORY_SEXUAL, - "sexual": glm.HarmCategory.HARM_CATEGORY_SEXUAL, - "sex": glm.HarmCategory.HARM_CATEGORY_SEXUAL, - - glm.HarmCategory.HARM_CATEGORY_MEDICAL: glm.HarmCategory.HARM_CATEGORY_MEDICAL, - HarmCategory.HARM_CATEGORY_MEDICAL: glm.HarmCategory.HARM_CATEGORY_MEDICAL, - 5: glm.HarmCategory.HARM_CATEGORY_MEDICAL, - "harm_category_medical": glm.HarmCategory.HARM_CATEGORY_MEDICAL, - "medical": glm.HarmCategory.HARM_CATEGORY_MEDICAL, - "med": glm.HarmCategory.HARM_CATEGORY_MEDICAL, - - glm.HarmCategory.HARM_CATEGORY_DANGEROUS: glm.HarmCategory.HARM_CATEGORY_DANGEROUS, - HarmCategory.HARM_CATEGORY_DANGEROUS: glm.HarmCategory.HARM_CATEGORY_DANGEROUS, - 6: glm.HarmCategory.HARM_CATEGORY_DANGEROUS, - "harm_category_dangerous": glm.HarmCategory.HARM_CATEGORY_DANGEROUS, - "dangerous": glm.HarmCategory.HARM_CATEGORY_DANGEROUS, - "danger": glm.HarmCategory.HARM_CATEGORY_DANGEROUS, +_HARM_CATEGORIES: Dict[HarmCategoryOptions, protos.HarmCategory] = { + protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED: protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + HarmCategory.HARM_CATEGORY_UNSPECIFIED: protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + 0: protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + "harm_category_unspecified": protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + "unspecified": protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + + protos.HarmCategory.HARM_CATEGORY_DEROGATORY: protos.HarmCategory.HARM_CATEGORY_DEROGATORY, + HarmCategory.HARM_CATEGORY_DEROGATORY: protos.HarmCategory.HARM_CATEGORY_DEROGATORY, + 1: protos.HarmCategory.HARM_CATEGORY_DEROGATORY, + "harm_category_derogatory": protos.HarmCategory.HARM_CATEGORY_DEROGATORY, + "derogatory": protos.HarmCategory.HARM_CATEGORY_DEROGATORY, + + protos.HarmCategory.HARM_CATEGORY_TOXICITY: protos.HarmCategory.HARM_CATEGORY_TOXICITY, + HarmCategory.HARM_CATEGORY_TOXICITY: protos.HarmCategory.HARM_CATEGORY_TOXICITY, + 2: protos.HarmCategory.HARM_CATEGORY_TOXICITY, + "harm_category_toxicity": protos.HarmCategory.HARM_CATEGORY_TOXICITY, + "toxicity": protos.HarmCategory.HARM_CATEGORY_TOXICITY, + "toxic": protos.HarmCategory.HARM_CATEGORY_TOXICITY, + + protos.HarmCategory.HARM_CATEGORY_VIOLENCE: protos.HarmCategory.HARM_CATEGORY_VIOLENCE, + HarmCategory.HARM_CATEGORY_VIOLENCE: protos.HarmCategory.HARM_CATEGORY_VIOLENCE, + 3: protos.HarmCategory.HARM_CATEGORY_VIOLENCE, + "harm_category_violence": protos.HarmCategory.HARM_CATEGORY_VIOLENCE, + "violence": protos.HarmCategory.HARM_CATEGORY_VIOLENCE, + "violent": protos.HarmCategory.HARM_CATEGORY_VIOLENCE, + + protos.HarmCategory.HARM_CATEGORY_SEXUAL: protos.HarmCategory.HARM_CATEGORY_SEXUAL, + HarmCategory.HARM_CATEGORY_SEXUAL: protos.HarmCategory.HARM_CATEGORY_SEXUAL, + 4: protos.HarmCategory.HARM_CATEGORY_SEXUAL, + "harm_category_sexual": protos.HarmCategory.HARM_CATEGORY_SEXUAL, + "sexual": protos.HarmCategory.HARM_CATEGORY_SEXUAL, + "sex": protos.HarmCategory.HARM_CATEGORY_SEXUAL, + + protos.HarmCategory.HARM_CATEGORY_MEDICAL: protos.HarmCategory.HARM_CATEGORY_MEDICAL, + HarmCategory.HARM_CATEGORY_MEDICAL: protos.HarmCategory.HARM_CATEGORY_MEDICAL, + 5: protos.HarmCategory.HARM_CATEGORY_MEDICAL, + "harm_category_medical": protos.HarmCategory.HARM_CATEGORY_MEDICAL, + "medical": protos.HarmCategory.HARM_CATEGORY_MEDICAL, + "med": protos.HarmCategory.HARM_CATEGORY_MEDICAL, + + protos.HarmCategory.HARM_CATEGORY_DANGEROUS: protos.HarmCategory.HARM_CATEGORY_DANGEROUS, + HarmCategory.HARM_CATEGORY_DANGEROUS: protos.HarmCategory.HARM_CATEGORY_DANGEROUS, + 6: protos.HarmCategory.HARM_CATEGORY_DANGEROUS, + "harm_category_dangerous": protos.HarmCategory.HARM_CATEGORY_DANGEROUS, + "dangerous": protos.HarmCategory.HARM_CATEGORY_DANGEROUS, + "danger": protos.HarmCategory.HARM_CATEGORY_DANGEROUS, } # fmt: on -def to_harm_category(x: HarmCategoryOptions) -> glm.HarmCategory: +def to_harm_category(x: HarmCategoryOptions) -> protos.HarmCategory: if isinstance(x, str): x = x.lower() return _HARM_CATEGORIES[x] @@ -161,7 +161,7 @@ class ContentFilterDict(TypedDict): reason: BlockedReason message: str - __doc__ = string_utils.strip_oneof(glm.ContentFilter.__doc__) + __doc__ = string_utils.strip_oneof(protos.ContentFilter.__doc__) def convert_filters_to_enums( @@ -177,15 +177,15 @@ def convert_filters_to_enums( class SafetyRatingDict(TypedDict): - category: glm.HarmCategory + category: protos.HarmCategory probability: HarmProbability - __doc__ = string_utils.strip_oneof(glm.SafetyRating.__doc__) + __doc__ = string_utils.strip_oneof(protos.SafetyRating.__doc__) def convert_rating_to_enum(rating: dict) -> SafetyRatingDict: return { - "category": glm.HarmCategory(rating["category"]), + "category": protos.HarmCategory(rating["category"]), "probability": HarmProbability(rating["probability"]), } @@ -198,10 +198,10 @@ def convert_ratings_to_enum(ratings: Iterable[dict]) -> List[SafetyRatingDict]: class SafetySettingDict(TypedDict): - category: glm.HarmCategory + category: protos.HarmCategory threshold: HarmBlockThreshold - __doc__ = string_utils.strip_oneof(glm.SafetySetting.__doc__) + __doc__ = string_utils.strip_oneof(protos.SafetySetting.__doc__) class LooseSafetySettingDict(TypedDict): @@ -251,7 +251,7 @@ def normalize_safety_settings( def convert_setting_to_enum(setting: dict) -> SafetySettingDict: return { - "category": glm.HarmCategory(setting["category"]), + "category": protos.HarmCategory(setting["category"]), "threshold": HarmBlockThreshold(setting["threshold"]), } @@ -260,7 +260,7 @@ class SafetyFeedbackDict(TypedDict): rating: SafetyRatingDict setting: SafetySettingDict - __doc__ = string_utils.strip_oneof(glm.SafetyFeedback.__doc__) + __doc__ = string_utils.strip_oneof(protos.SafetyFeedback.__doc__) def convert_safety_feedback_to_enums( diff --git a/google/generativeai/types/permission_types.py b/google/generativeai/types/permission_types.py index fde2ddacc..1df831db0 100644 --- a/google/generativeai/types/permission_types.py +++ b/google/generativeai/types/permission_types.py @@ -19,6 +19,7 @@ import re import google.ai.generativelanguage as glm +from google.generativeai import protos from google.protobuf import field_mask_pb2 @@ -28,8 +29,8 @@ from google.generativeai import string_utils -GranteeType = glm.Permission.GranteeType -Role = glm.Permission.Role +GranteeType = protos.Permission.GranteeType +Role = protos.Permission.Role GranteeTypeOptions = Union[str, int, GranteeType] RoleOptions = Union[str, int, Role] @@ -108,7 +109,7 @@ def delete( """ if client is None: client = get_default_permission_client() - delete_request = glm.DeletePermissionRequest(name=self.name) + delete_request = protos.DeletePermissionRequest(name=self.name) client.delete_permission(request=delete_request) async def delete_async( @@ -120,7 +121,7 @@ async def delete_async( """ if client is None: client = get_default_permission_async_client() - delete_request = glm.DeletePermissionRequest(name=self.name) + delete_request = protos.DeletePermissionRequest(name=self.name) await client.delete_permission(request=delete_request) # TODO (magashe): Add a method to validate update value. As of now only `role` is supported as a mask path @@ -161,7 +162,7 @@ def update( for path, value in updates.items(): self._apply_update(path, value) - update_request = glm.UpdatePermissionRequest( + update_request = protos.UpdatePermissionRequest( permission=self._to_proto(), update_mask=field_mask ) client.update_permission(request=update_request) @@ -191,14 +192,14 @@ async def update_async( for path, value in updates.items(): self._apply_update(path, value) - update_request = glm.UpdatePermissionRequest( + update_request = protos.UpdatePermissionRequest( permission=self._to_proto(), update_mask=field_mask ) await client.update_permission(request=update_request) return self - def _to_proto(self) -> glm.Permission: - return glm.Permission( + def _to_proto(self) -> protos.Permission: + return protos.Permission( name=self.name, role=self.role, grantee_type=self.grantee_type, @@ -225,7 +226,7 @@ def get( """ if client is None: client = get_default_permission_client() - get_perm_request = glm.GetPermissionRequest(name=name) + get_perm_request = protos.GetPermissionRequest(name=name) get_perm_response = client.get_permission(request=get_perm_request) get_perm_response = type(get_perm_response).to_dict(get_perm_response) return cls(**get_perm_response) @@ -241,7 +242,7 @@ async def get_async( """ if client is None: client = get_default_permission_async_client() - get_perm_request = glm.GetPermissionRequest(name=name) + get_perm_request = protos.GetPermissionRequest(name=name) get_perm_response = await client.get_permission(request=get_perm_request) get_perm_response = type(get_perm_response).to_dict(get_perm_response) return cls(**get_perm_response) @@ -263,7 +264,7 @@ def _make_create_permission_request( role: RoleOptions, grantee_type: Optional[GranteeTypeOptions] = None, email_address: Optional[str] = None, - ) -> glm.CreatePermissionRequest: + ) -> protos.CreatePermissionRequest: role = to_role(role) if grantee_type: @@ -278,12 +279,12 @@ def _make_create_permission_request( f"Invalid operation: An 'email_address' must be provided when 'grantee_type' is not set to 'EVERYONE'. Currently, 'grantee_type' is set to '{grantee_type}' and 'email_address' is '{email_address if email_address else 'not provided'}'." ) - permission = glm.Permission( + permission = protos.Permission( role=role, grantee_type=grantee_type, email_address=email_address, ) - return glm.CreatePermissionRequest( + return protos.CreatePermissionRequest( parent=self.parent, permission=permission, ) @@ -359,7 +360,7 @@ def list( if client is None: client = get_default_permission_client() - request = glm.ListPermissionsRequest( + request = protos.ListPermissionsRequest( parent=self.parent, page_size=page_size # pytype: disable=attribute-error ) for permission in client.list_permissions(request): @@ -377,7 +378,7 @@ async def list_async( if client is None: client = get_default_permission_async_client() - request = glm.ListPermissionsRequest( + request = protos.ListPermissionsRequest( parent=self.parent, page_size=page_size # pytype: disable=attribute-error ) async for permission in await client.list_permissions(request): @@ -400,7 +401,7 @@ def transfer_ownership( raise NotImplementedError("Can'/t transfer_ownership for a Corpus") if client is None: client = get_default_permission_client() - transfer_request = glm.TransferOwnershipRequest( + transfer_request = protos.TransferOwnershipRequest( name=self.parent, email_address=email_address # pytype: disable=attribute-error ) return client.transfer_ownership(request=transfer_request) @@ -415,7 +416,7 @@ async def transfer_ownership_async( raise NotImplementedError("Can'/t transfer_ownership for a Corpus") if client is None: client = get_default_permission_async_client() - transfer_request = glm.TransferOwnershipRequest( + transfer_request = protos.TransferOwnershipRequest( name=self.parent, email_address=email_address # pytype: disable=attribute-error ) return await client.transfer_ownership(request=transfer_request) diff --git a/google/generativeai/types/retriever_types.py b/google/generativeai/types/retriever_types.py index 294e0b64c..9931ee58d 100644 --- a/google/generativeai/types/retriever_types.py +++ b/google/generativeai/types/retriever_types.py @@ -22,6 +22,7 @@ from typing_extensions import deprecated # type: ignore import google.ai.generativelanguage as glm +from google.generativeai import protos from google.protobuf import field_mask_pb2 from google.generativeai.client import get_default_retriever_client @@ -44,14 +45,14 @@ def valid_name(name): return re.match(_VALID_NAME, name) and len(name) < 40 -Operator = glm.Condition.Operator -State = glm.Chunk.State +Operator = protos.Condition.Operator +State = protos.Chunk.State OperatorOptions = Union[str, int, Operator] StateOptions = Union[str, int, State] ChunkOptions = Union[ - glm.Chunk, + protos.Chunk, str, tuple[str, str], tuple[str, str, Any], @@ -59,17 +60,17 @@ def valid_name(name): ] # fmt: no BatchCreateChunkOptions = Union[ - glm.BatchCreateChunksRequest, + protos.BatchCreateChunksRequest, Mapping[str, str], Mapping[str, tuple[str, str]], Iterable[ChunkOptions], ] # fmt: no -UpdateChunkOptions = Union[glm.UpdateChunkRequest, Mapping[str, Any], tuple[str, Any]] +UpdateChunkOptions = Union[protos.UpdateChunkRequest, Mapping[str, Any], tuple[str, Any]] -BatchUpdateChunksOptions = Union[glm.BatchUpdateChunksRequest, Iterable[UpdateChunkOptions]] +BatchUpdateChunksOptions = Union[protos.BatchUpdateChunksRequest, Iterable[UpdateChunkOptions]] -BatchDeleteChunkOptions = Union[list[glm.DeleteChunkRequest], Iterable[str]] +BatchDeleteChunkOptions = Union[list[protos.DeleteChunkRequest], Iterable[str]] _OPERATOR: dict[OperatorOptions, Operator] = { Operator.OPERATOR_UNSPECIFIED: Operator.OPERATOR_UNSPECIFIED, @@ -163,10 +164,10 @@ def _to_proto(self): ) kwargs["operation"] = c.operation - condition = glm.Condition(**kwargs) + condition = protos.Condition(**kwargs) conditions.append(condition) - return glm.MetadataFilter(key=self.key, conditions=conditions) + return protos.MetadataFilter(key=self.key, conditions=conditions) @string_utils.prettyprint @@ -188,17 +189,17 @@ def _to_proto(self): kwargs["string_value"] = self.value elif isinstance(self.value, Iterable): if isinstance(self.value, Mapping): - # If already converted to a glm.StringList, get the values + # If already converted to a protos.StringList, get the values kwargs["string_list_value"] = self.value else: - kwargs["string_list_value"] = glm.StringList(values=self.value) + kwargs["string_list_value"] = protos.StringList(values=self.value) elif isinstance(self.value, (int, float)): kwargs["numeric_value"] = float(self.value) else: raise ValueError( f"Invalid value type: The value for a custom_metadata specification must be either a list of string values, a string, or an integer/float. Received: '{self.value}' of type {type(self.value).__name__}." ) - return glm.CustomMetadata(key=self.key, **kwargs) + return protos.CustomMetadata(key=self.key, **kwargs) @classmethod def _from_dict(cls, cm): @@ -216,14 +217,14 @@ def _to_dict(self): return type(proto).to_dict(proto) -CustomMetadataOptions = Union[CustomMetadata, glm.CustomMetadata, dict] +CustomMetadataOptions = Union[CustomMetadata, protos.CustomMetadata, dict] def make_custom_metadata(cm: CustomMetadataOptions) -> CustomMetadata: if isinstance(cm, CustomMetadata): return cm - if isinstance(cm, glm.CustomMetadata): + if isinstance(cm, protos.CustomMetadata): cm = type(cm).to_dict(cm) if isinstance(cm, dict): @@ -293,9 +294,9 @@ def create_document( c_data.append(cm._to_proto()) if name is None: - document = glm.Document(display_name=display_name, custom_metadata=c_data) + document = protos.Document(display_name=display_name, custom_metadata=c_data) elif valid_name(name): - document = glm.Document( + document = protos.Document( name=f"{self.name}/documents/{name}", display_name=display_name, custom_metadata=c_data, @@ -303,7 +304,7 @@ def create_document( else: raise ValueError(NAME_ERROR_MSG.format(length=len(name), name=name)) - request = glm.CreateDocumentRequest(parent=self.name, document=document) + request = protos.CreateDocumentRequest(parent=self.name, document=document) response = client.create_document(request, **request_options) return decode_document(response) @@ -329,9 +330,9 @@ async def create_document_async( c_data.append(cm._to_proto()) if name is None: - document = glm.Document(display_name=display_name, custom_metadata=c_data) + document = protos.Document(display_name=display_name, custom_metadata=c_data) elif valid_name(name): - document = glm.Document( + document = protos.Document( name=f"{self.name}/documents/{name}", display_name=display_name, custom_metadata=c_data, @@ -339,7 +340,7 @@ async def create_document_async( else: raise ValueError(NAME_ERROR_MSG.format(length=len(name), name=name)) - request = glm.CreateDocumentRequest(parent=self.name, document=document) + request = protos.CreateDocumentRequest(parent=self.name, document=document) response = await client.create_document(request, **request_options) return decode_document(response) @@ -368,7 +369,7 @@ def get_document( if "/" not in name: name = f"{self.name}/documents/{name}" - request = glm.GetDocumentRequest(name=name) + request = protos.GetDocumentRequest(name=name) response = client.get_document(request, **request_options) return decode_document(response) @@ -388,7 +389,7 @@ async def get_document_async( if "/" not in name: name = f"{self.name}/documents/{name}" - request = glm.GetDocumentRequest(name=name) + request = protos.GetDocumentRequest(name=name) response = await client.get_document(request, **request_options) return decode_document(response) @@ -434,7 +435,7 @@ def update( for path, value in updates.items(): self._apply_update(path, value) - request = glm.UpdateCorpusRequest(corpus=self.to_dict(), update_mask=field_mask) + request = protos.UpdateCorpusRequest(corpus=self.to_dict(), update_mask=field_mask) client.update_corpus(request, **request_options) return self @@ -465,7 +466,7 @@ async def update_async( for path, value in updates.items(): self._apply_update(path, value) - request = glm.UpdateCorpusRequest(corpus=self.to_dict(), update_mask=field_mask) + request = protos.UpdateCorpusRequest(corpus=self.to_dict(), update_mask=field_mask) await client.update_corpus(request, **request_options) return self @@ -506,7 +507,7 @@ def query( for mf in metadata_filters: m_f_.append(mf._to_proto()) - request = glm.QueryCorpusRequest( + request = protos.QueryCorpusRequest( name=self.name, query=query, metadata_filters=m_f_, @@ -551,7 +552,7 @@ async def query_async( for mf in metadata_filters: m_f_.append(mf._to_proto()) - request = glm.QueryCorpusRequest( + request = protos.QueryCorpusRequest( name=self.name, query=query, metadata_filters=m_f_, @@ -594,7 +595,7 @@ def delete_document( if "/" not in name: name = f"{self.name}/documents/{name}" - request = glm.DeleteDocumentRequest(name=name, force=bool(force)) + request = protos.DeleteDocumentRequest(name=name, force=bool(force)) client.delete_document(request, **request_options) async def delete_document_async( @@ -614,7 +615,7 @@ async def delete_document_async( if "/" not in name: name = f"{self.name}/documents/{name}" - request = glm.DeleteDocumentRequest(name=name, force=bool(force)) + request = protos.DeleteDocumentRequest(name=name, force=bool(force)) await client.delete_document(request, **request_options) def list_documents( @@ -640,7 +641,7 @@ def list_documents( if client is None: client = get_default_retriever_client() - request = glm.ListDocumentsRequest( + request = protos.ListDocumentsRequest( parent=self.name, page_size=page_size, ) @@ -660,7 +661,7 @@ async def list_documents_async( if client is None: client = get_default_retriever_async_client() - request = glm.ListDocumentsRequest( + request = protos.ListDocumentsRequest( parent=self.name, page_size=page_size, ) @@ -792,15 +793,17 @@ def create_chunk( chunk_name = name if isinstance(data, str): - chunk = glm.Chunk(name=chunk_name, data={"string_value": data}, custom_metadata=c_data) + chunk = protos.Chunk( + name=chunk_name, data={"string_value": data}, custom_metadata=c_data + ) else: - chunk = glm.Chunk( + chunk = protos.Chunk( name=chunk_name, data={"string_value": data.string_value}, custom_metadata=c_data, ) - request = glm.CreateChunkRequest(parent=self.name, chunk=chunk) + request = protos.CreateChunkRequest(parent=self.name, chunk=chunk) response = client.create_chunk(request, **request_options) return decode_chunk(response) @@ -834,24 +837,26 @@ async def create_chunk_async( chunk_name = name if isinstance(data, str): - chunk = glm.Chunk(name=chunk_name, data={"string_value": data}, custom_metadata=c_data) + chunk = protos.Chunk( + name=chunk_name, data={"string_value": data}, custom_metadata=c_data + ) else: - chunk = glm.Chunk( + chunk = protos.Chunk( name=chunk_name, data={"string_value": data.string_value}, custom_metadata=c_data, ) - request = glm.CreateChunkRequest(parent=self.name, chunk=chunk) + request = protos.CreateChunkRequest(parent=self.name, chunk=chunk) response = await client.create_chunk(request, **request_options) return decode_chunk(response) - def _make_chunk(self, chunk: ChunkOptions) -> glm.Chunk: + def _make_chunk(self, chunk: ChunkOptions) -> protos.Chunk: # del self - if isinstance(chunk, glm.Chunk): - return glm.Chunk(chunk) + if isinstance(chunk, protos.Chunk): + return protos.Chunk(chunk) elif isinstance(chunk, str): - return glm.Chunk(data={"string_value": chunk}) + return protos.Chunk(data={"string_value": chunk}) elif isinstance(chunk, tuple): if len(chunk) == 2: name, data = chunk # pytype: disable=bad-unpacking @@ -864,7 +869,7 @@ def _make_chunk(self, chunk: ChunkOptions) -> glm.Chunk: f"value: {chunk}" ) - return glm.Chunk( + return protos.Chunk( name=name, data={"string_value": data}, custom_metadata=custom_metadata, @@ -873,7 +878,7 @@ def _make_chunk(self, chunk: ChunkOptions) -> glm.Chunk: if isinstance(chunk["data"], str): chunk = dict(chunk) chunk["data"] = {"string_value": chunk["data"]} - return glm.Chunk(chunk) + return protos.Chunk(chunk) else: raise TypeError( f"Invalid input: Could not convert instance of type '{type(chunk).__name__}' to a chunk. Received value: '{chunk}'." @@ -881,8 +886,8 @@ def _make_chunk(self, chunk: ChunkOptions) -> glm.Chunk: def _make_batch_create_chunk_request( self, chunks: BatchCreateChunkOptions - ) -> glm.BatchCreateChunksRequest: - if isinstance(chunks, glm.BatchCreateChunksRequest): + ) -> protos.BatchCreateChunksRequest: + if isinstance(chunks, protos.BatchCreateChunksRequest): return chunks if isinstance(chunks, Mapping): @@ -901,9 +906,9 @@ def _make_batch_create_chunk_request( chunk.name = f"{self.name}/chunks/{chunk.name}" - requests.append(glm.CreateChunkRequest(parent=self.name, chunk=chunk)) + requests.append(protos.CreateChunkRequest(parent=self.name, chunk=chunk)) - return glm.BatchCreateChunksRequest(parent=self.name, requests=requests) + return protos.BatchCreateChunksRequest(parent=self.name, requests=requests) def batch_create_chunks( self, @@ -973,7 +978,7 @@ def get_chunk( if "/" not in name: name = f"{self.name}/chunks/{name}" - request = glm.GetChunkRequest(name=name) + request = protos.GetChunkRequest(name=name) response = client.get_chunk(request, **request_options) return decode_chunk(response) @@ -993,7 +998,7 @@ async def get_chunk_async( if "/" not in name: name = f"{self.name}/chunks/{name}" - request = glm.GetChunkRequest(name=name) + request = protos.GetChunkRequest(name=name) response = await client.get_chunk(request, **request_options) return decode_chunk(response) @@ -1019,7 +1024,7 @@ def list_chunks( if client is None: client = get_default_retriever_client() - request = glm.ListChunksRequest(parent=self.name, page_size=page_size) + request = protos.ListChunksRequest(parent=self.name, page_size=page_size) for chunk in client.list_chunks(request, **request_options): yield decode_chunk(chunk) @@ -1036,7 +1041,7 @@ async def list_chunks_async( if client is None: client = get_default_retriever_async_client() - request = glm.ListChunksRequest(parent=self.name, page_size=page_size) + request = protos.ListChunksRequest(parent=self.name, page_size=page_size) async for chunk in await client.list_chunks(request, **request_options): yield decode_chunk(chunk) @@ -1076,7 +1081,7 @@ def query( for mf in metadata_filters: m_f_.append(mf._to_proto()) - request = glm.QueryDocumentRequest( + request = protos.QueryDocumentRequest( name=self.name, query=query, metadata_filters=m_f_, @@ -1121,7 +1126,7 @@ async def query_async( for mf in metadata_filters: m_f_.append(mf._to_proto()) - request = glm.QueryDocumentRequest( + request = protos.QueryDocumentRequest( name=self.name, query=query, metadata_filters=m_f_, @@ -1181,7 +1186,7 @@ def update( for path, value in updates.items(): self._apply_update(path, value) - request = glm.UpdateDocumentRequest(document=self.to_dict(), update_mask=field_mask) + request = protos.UpdateDocumentRequest(document=self.to_dict(), update_mask=field_mask) client.update_document(request, **request_options) return self @@ -1211,7 +1216,7 @@ async def update_async( for path, value in updates.items(): self._apply_update(path, value) - request = glm.UpdateDocumentRequest(document=self.to_dict(), update_mask=field_mask) + request = protos.UpdateDocumentRequest(document=self.to_dict(), update_mask=field_mask) await client.update_document(request, **request_options) return self @@ -1237,7 +1242,7 @@ def batch_update_chunks( if client is None: client = get_default_retriever_client() - if isinstance(chunks, glm.BatchUpdateChunksRequest): + if isinstance(chunks, protos.BatchUpdateChunksRequest): response = client.batch_update_chunks(chunks) response = type(response).to_dict(response) return response @@ -1270,15 +1275,17 @@ def batch_update_chunks( for path, value in updates.items(): chunk_to_update._apply_update(path, value) _requests.append( - glm.UpdateChunkRequest(chunk=chunk_to_update.to_dict(), update_mask=field_mask) + protos.UpdateChunkRequest( + chunk=chunk_to_update.to_dict(), update_mask=field_mask + ) ) - request = glm.BatchUpdateChunksRequest(parent=self.name, requests=_requests) + request = protos.BatchUpdateChunksRequest(parent=self.name, requests=_requests) response = client.batch_update_chunks(request, **request_options) response = type(response).to_dict(response) return response if isinstance(chunks, Iterable) and not isinstance(chunks, Mapping): for chunk in chunks: - if isinstance(chunk, glm.UpdateChunkRequest): + if isinstance(chunk, protos.UpdateChunkRequest): _requests.append(chunk) elif isinstance(chunk, tuple): # First element is name of chunk, second element contains updates @@ -1304,9 +1311,10 @@ def batch_update_chunks( ) else: raise TypeError( - "Invalid input: The 'chunks' parameter must be a list of 'glm.UpdateChunkRequests', dictionaries, or tuples of dictionaries." + "Invalid input: The 'chunks' parameter must be a list of 'protos.UpdateChunkRequests'," + " dictionaries, or tuples of dictionaries." ) - request = glm.BatchUpdateChunksRequest(parent=self.name, requests=_requests) + request = protos.BatchUpdateChunksRequest(parent=self.name, requests=_requests) response = client.batch_update_chunks(request, **request_options) response = type(response).to_dict(response) return response @@ -1324,7 +1332,7 @@ async def batch_update_chunks_async( if client is None: client = get_default_retriever_async_client() - if isinstance(chunks, glm.BatchUpdateChunksRequest): + if isinstance(chunks, protos.BatchUpdateChunksRequest): response = client.batch_update_chunks(chunks) response = type(response).to_dict(response) return response @@ -1357,15 +1365,17 @@ async def batch_update_chunks_async( for path, value in updates.items(): chunk_to_update._apply_update(path, value) _requests.append( - glm.UpdateChunkRequest(chunk=chunk_to_update.to_dict(), update_mask=field_mask) + protos.UpdateChunkRequest( + chunk=chunk_to_update.to_dict(), update_mask=field_mask + ) ) - request = glm.BatchUpdateChunksRequest(parent=self.name, requests=_requests) + request = protos.BatchUpdateChunksRequest(parent=self.name, requests=_requests) response = await client.batch_update_chunks(request, **request_options) response = type(response).to_dict(response) return response if isinstance(chunks, Iterable) and not isinstance(chunks, Mapping): for chunk in chunks: - if isinstance(chunk, glm.UpdateChunkRequest): + if isinstance(chunk, protos.UpdateChunkRequest): _requests.append(chunk) elif isinstance(chunk, tuple): # First element is name of chunk, second element contains updates @@ -1391,9 +1401,10 @@ async def batch_update_chunks_async( ) else: raise TypeError( - "Invalid input: The 'chunks' parameter must be a list of 'glm.UpdateChunkRequests', dictionaries, or tuples of dictionaries." + "Invalid input: The 'chunks' parameter must be a list of 'protos.UpdateChunkRequests', " + "dictionaries, or tuples of dictionaries." ) - request = glm.BatchUpdateChunksRequest(parent=self.name, requests=_requests) + request = protos.BatchUpdateChunksRequest(parent=self.name, requests=_requests) response = await client.batch_update_chunks(request, **request_options) response = type(response).to_dict(response) return response @@ -1420,7 +1431,7 @@ def delete_chunk( if "/" not in name: name = f"{self.name}/chunks/{name}" - request = glm.DeleteChunkRequest(name=name) + request = protos.DeleteChunkRequest(name=name) client.delete_chunk(request, **request_options) async def delete_chunk_async( @@ -1439,7 +1450,7 @@ async def delete_chunk_async( if "/" not in name: name = f"{self.name}/chunks/{name}" - request = glm.DeleteChunkRequest(name=name) + request = protos.DeleteChunkRequest(name=name) await client.delete_chunk(request, **request_options) def batch_delete_chunks( @@ -1461,18 +1472,19 @@ def batch_delete_chunks( if client is None: client = get_default_retriever_client() - if all(isinstance(x, glm.DeleteChunkRequest) for x in chunks): - request = glm.BatchDeleteChunksRequest(parent=self.name, requests=chunks) + if all(isinstance(x, protos.DeleteChunkRequest) for x in chunks): + request = protos.BatchDeleteChunksRequest(parent=self.name, requests=chunks) client.batch_delete_chunks(request, **request_options) elif isinstance(chunks, Iterable): _request_list = [] for chunk_name in chunks: - _request_list.append(glm.DeleteChunkRequest(name=chunk_name)) - request = glm.BatchDeleteChunksRequest(parent=self.name, requests=_request_list) + _request_list.append(protos.DeleteChunkRequest(name=chunk_name)) + request = protos.BatchDeleteChunksRequest(parent=self.name, requests=_request_list) client.batch_delete_chunks(request, **request_options) else: raise ValueError( - "Invalid operation: To delete chunks, you must pass in either the names of the chunks as an iterable, or multiple 'glm.DeleteChunkRequest's." + "Invalid operation: To delete chunks, you must pass in either the names of the chunks as an iterable, " + "or multiple 'protos.DeleteChunkRequest's." ) async def batch_delete_chunks_async( @@ -1488,18 +1500,19 @@ async def batch_delete_chunks_async( if client is None: client = get_default_retriever_async_client() - if all(isinstance(x, glm.DeleteChunkRequest) for x in chunks): - request = glm.BatchDeleteChunksRequest(parent=self.name, requests=chunks) + if all(isinstance(x, protos.DeleteChunkRequest) for x in chunks): + request = protos.BatchDeleteChunksRequest(parent=self.name, requests=chunks) await client.batch_delete_chunks(request, **request_options) elif isinstance(chunks, Iterable): _request_list = [] for chunk_name in chunks: - _request_list.append(glm.DeleteChunkRequest(name=chunk_name)) - request = glm.BatchDeleteChunksRequest(parent=self.name, requests=_request_list) + _request_list.append(protos.DeleteChunkRequest(name=chunk_name)) + request = protos.BatchDeleteChunksRequest(parent=self.name, requests=_request_list) await client.batch_delete_chunks(request, **request_options) else: raise ValueError( - "Invalid operation: To delete chunks, you must pass in either the names of the chunks as an iterable, or multiple 'glm.DeleteChunkRequest's." + "Invalid operation: To delete chunks, you must pass in either the names of the chunks as an iterable, " + "or multiple 'protos.DeleteChunkRequest's." ) def to_dict(self) -> dict[str, Any]: @@ -1511,7 +1524,7 @@ def to_dict(self) -> dict[str, Any]: return result -def decode_chunk(chunk: glm.Chunk) -> Chunk: +def decode_chunk(chunk: protos.Chunk) -> Chunk: chunk = type(chunk).to_dict(chunk) idecode_time(chunk, "create_time") idecode_time(chunk, "update_time") @@ -1625,7 +1638,7 @@ def update( field_mask.paths.append(path) for path, value in updates.items(): self._apply_update(path, value) - request = glm.UpdateChunkRequest(chunk=self.to_dict(), update_mask=field_mask) + request = protos.UpdateChunkRequest(chunk=self.to_dict(), update_mask=field_mask) client.update_chunk(request, **request_options) return self @@ -1665,7 +1678,7 @@ async def update_async( field_mask.paths.append(path) for path, value in updates.items(): self._apply_update(path, value) - request = glm.UpdateChunkRequest(chunk=self.to_dict(), update_mask=field_mask) + request = protos.UpdateChunkRequest(chunk=self.to_dict(), update_mask=field_mask) await client.update_chunk(request, **request_options) return self diff --git a/google/generativeai/types/safety_types.py b/google/generativeai/types/safety_types.py index c8368da7f..74da06e45 100644 --- a/google/generativeai/types/safety_types.py +++ b/google/generativeai/types/safety_types.py @@ -23,7 +23,7 @@ from typing_extensions import TypedDict -from google.ai import generativelanguage as glm +from google.generativeai import protos from google.generativeai import string_utils @@ -39,9 +39,9 @@ ] # These are basic python enums, it's okay to expose them -HarmProbability = glm.SafetyRating.HarmProbability -HarmBlockThreshold = glm.SafetySetting.HarmBlockThreshold -BlockedReason = glm.ContentFilter.BlockedReason +HarmProbability = protos.SafetyRating.HarmProbability +HarmBlockThreshold = protos.SafetySetting.HarmBlockThreshold +BlockedReason = protos.ContentFilter.BlockedReason import proto @@ -51,57 +51,57 @@ class HarmCategory(proto.Enum): Harm Categories supported by the gemini-family model """ - HARM_CATEGORY_UNSPECIFIED = glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED.value - HARM_CATEGORY_HARASSMENT = glm.HarmCategory.HARM_CATEGORY_HARASSMENT.value - HARM_CATEGORY_HATE_SPEECH = glm.HarmCategory.HARM_CATEGORY_HATE_SPEECH.value - HARM_CATEGORY_SEXUALLY_EXPLICIT = glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT.value - HARM_CATEGORY_DANGEROUS_CONTENT = glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT.value + HARM_CATEGORY_UNSPECIFIED = protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED.value + HARM_CATEGORY_HARASSMENT = protos.HarmCategory.HARM_CATEGORY_HARASSMENT.value + HARM_CATEGORY_HATE_SPEECH = protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH.value + HARM_CATEGORY_SEXUALLY_EXPLICIT = protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT.value + HARM_CATEGORY_DANGEROUS_CONTENT = protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT.value HarmCategoryOptions = Union[str, int, HarmCategory] # fmt: off -_HARM_CATEGORIES: Dict[HarmCategoryOptions, glm.HarmCategory] = { - glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED: glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, - HarmCategory.HARM_CATEGORY_UNSPECIFIED: glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, - 0: glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, - "harm_category_unspecified": glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, - "unspecified": glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, +_HARM_CATEGORIES: Dict[HarmCategoryOptions, protos.HarmCategory] = { + protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED: protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + HarmCategory.HARM_CATEGORY_UNSPECIFIED: protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + 0: protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + "harm_category_unspecified": protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + "unspecified": protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED, - 7: glm.HarmCategory.HARM_CATEGORY_HARASSMENT, - glm.HarmCategory.HARM_CATEGORY_HARASSMENT: glm.HarmCategory.HARM_CATEGORY_HARASSMENT, - HarmCategory.HARM_CATEGORY_HARASSMENT: glm.HarmCategory.HARM_CATEGORY_HARASSMENT, - "harm_category_harassment": glm.HarmCategory.HARM_CATEGORY_HARASSMENT, - "harassment": glm.HarmCategory.HARM_CATEGORY_HARASSMENT, - - 8: glm.HarmCategory.HARM_CATEGORY_HATE_SPEECH, - glm.HarmCategory.HARM_CATEGORY_HATE_SPEECH: glm.HarmCategory.HARM_CATEGORY_HATE_SPEECH, - HarmCategory.HARM_CATEGORY_HATE_SPEECH: glm.HarmCategory.HARM_CATEGORY_HATE_SPEECH, - 'harm_category_hate_speech': glm.HarmCategory.HARM_CATEGORY_HATE_SPEECH, - 'hate_speech': glm.HarmCategory.HARM_CATEGORY_HATE_SPEECH, - 'hate': glm.HarmCategory.HARM_CATEGORY_HATE_SPEECH, - - 9: glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - "harm_category_sexually_explicit": glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - "harm_category_sexual": glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - "sexually_explicit": glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - "sexual": glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - "sex": glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - - 10: glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - "harm_category_dangerous_content": glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - "harm_category_dangerous": glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - "dangerous": glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - "danger": glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + 7: protos.HarmCategory.HARM_CATEGORY_HARASSMENT, + protos.HarmCategory.HARM_CATEGORY_HARASSMENT: protos.HarmCategory.HARM_CATEGORY_HARASSMENT, + HarmCategory.HARM_CATEGORY_HARASSMENT: protos.HarmCategory.HARM_CATEGORY_HARASSMENT, + "harm_category_harassment": protos.HarmCategory.HARM_CATEGORY_HARASSMENT, + "harassment": protos.HarmCategory.HARM_CATEGORY_HARASSMENT, + + 8: protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH: protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + HarmCategory.HARM_CATEGORY_HATE_SPEECH: protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + 'harm_category_hate_speech': protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + 'hate_speech': protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + 'hate': protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + + 9: protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + "harm_category_sexually_explicit": protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + "harm_category_sexual": protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + "sexually_explicit": protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + "sexual": protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + "sex": protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + + 10: protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + "harm_category_dangerous_content": protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + "harm_category_dangerous": protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + "dangerous": protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + "danger": protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, } # fmt: on -def to_harm_category(x: HarmCategoryOptions) -> glm.HarmCategory: +def to_harm_category(x: HarmCategoryOptions) -> protos.HarmCategory: if isinstance(x, str): x = x.lower() return _HARM_CATEGORIES[x] @@ -150,7 +150,7 @@ class ContentFilterDict(TypedDict): reason: BlockedReason message: str - __doc__ = string_utils.strip_oneof(glm.ContentFilter.__doc__) + __doc__ = string_utils.strip_oneof(protos.ContentFilter.__doc__) def convert_filters_to_enums( @@ -166,15 +166,15 @@ def convert_filters_to_enums( class SafetyRatingDict(TypedDict): - category: glm.HarmCategory + category: protos.HarmCategory probability: HarmProbability - __doc__ = string_utils.strip_oneof(glm.SafetyRating.__doc__) + __doc__ = string_utils.strip_oneof(protos.SafetyRating.__doc__) def convert_rating_to_enum(rating: dict) -> SafetyRatingDict: return { - "category": glm.HarmCategory(rating["category"]), + "category": protos.HarmCategory(rating["category"]), "probability": HarmProbability(rating["probability"]), } @@ -187,10 +187,10 @@ def convert_ratings_to_enum(ratings: Iterable[dict]) -> List[SafetyRatingDict]: class SafetySettingDict(TypedDict): - category: glm.HarmCategory + category: protos.HarmCategory threshold: HarmBlockThreshold - __doc__ = string_utils.strip_oneof(glm.SafetySetting.__doc__) + __doc__ = string_utils.strip_oneof(protos.SafetySetting.__doc__) class LooseSafetySettingDict(TypedDict): @@ -225,7 +225,7 @@ def to_easy_safety_dict(settings: SafetySettingOptions) -> EasySafetySettingDict else: # Iterable result = {} for setting in settings: - if isinstance(setting, glm.SafetySetting): + if isinstance(setting, protos.SafetySetting): result[to_harm_category(setting.category)] = to_block_threshold(setting.threshold) elif isinstance(setting, dict): result[to_harm_category(setting["category"])] = to_block_threshold( @@ -267,7 +267,7 @@ def normalize_safety_settings( def convert_setting_to_enum(setting: dict) -> SafetySettingDict: return { - "category": glm.HarmCategory(setting["category"]), + "category": protos.HarmCategory(setting["category"]), "threshold": HarmBlockThreshold(setting["threshold"]), } @@ -276,7 +276,7 @@ class SafetyFeedbackDict(TypedDict): rating: SafetyRatingDict setting: SafetySettingDict - __doc__ = string_utils.strip_oneof(glm.SafetyFeedback.__doc__) + __doc__ = string_utils.strip_oneof(protos.SafetyFeedback.__doc__) def convert_safety_feedback_to_enums( diff --git a/tests/test_answer.py b/tests/test_answer.py index 4128567f4..2669b207c 100644 --- a/tests/test_answer.py +++ b/tests/test_answer.py @@ -18,7 +18,7 @@ import unittest import unittest.mock as mock -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import answer from google.generativeai import types as genai_types @@ -47,14 +47,14 @@ def add_client_method(f): @add_client_method def generate_answer( - request: glm.GenerateAnswerRequest, + request: protos.GenerateAnswerRequest, **kwargs, - ) -> glm.GenerateAnswerResponse: + ) -> protos.GenerateAnswerResponse: self.observed_requests.append(request) - return glm.GenerateAnswerResponse( - answer=glm.Candidate( + return protos.GenerateAnswerResponse( + answer=protos.Candidate( index=1, - content=(glm.Content(parts=[glm.Part(text="Demo answer.")])), + content=(protos.Content(parts=[protos.Part(text="Demo answer.")])), ), answerable_probability=0.500, ) @@ -62,17 +62,23 @@ def generate_answer( def test_make_grounding_passages_mixed_types(self): inline_passages = [ "I am a chicken", - glm.Content(parts=[glm.Part(text="I am a bird.")]), - glm.Content(parts=[glm.Part(text="I can fly!")]), + protos.Content(parts=[protos.Part(text="I am a bird.")]), + protos.Content(parts=[protos.Part(text="I can fly!")]), ] x = answer._make_grounding_passages(inline_passages) - self.assertIsInstance(x, glm.GroundingPassages) + self.assertIsInstance(x, protos.GroundingPassages) self.assertEqual( - glm.GroundingPassages( + protos.GroundingPassages( passages=[ - {"id": "0", "content": glm.Content(parts=[glm.Part(text="I am a chicken")])}, - {"id": "1", "content": glm.Content(parts=[glm.Part(text="I am a bird.")])}, - {"id": "2", "content": glm.Content(parts=[glm.Part(text="I can fly!")])}, + { + "id": "0", + "content": protos.Content(parts=[protos.Part(text="I am a chicken")]), + }, + { + "id": "1", + "content": protos.Content(parts=[protos.Part(text="I am a bird.")]), + }, + {"id": "2", "content": protos.Content(parts=[protos.Part(text="I can fly!")])}, ] ), x, @@ -82,23 +88,29 @@ def test_make_grounding_passages_mixed_types(self): [ dict( testcase_name="grounding_passage", - inline_passages=glm.GroundingPassages( + inline_passages=protos.GroundingPassages( passages=[ { "id": "0", - "content": glm.Content(parts=[glm.Part(text="I am a chicken")]), + "content": protos.Content(parts=[protos.Part(text="I am a chicken")]), + }, + { + "id": "1", + "content": protos.Content(parts=[protos.Part(text="I am a bird.")]), + }, + { + "id": "2", + "content": protos.Content(parts=[protos.Part(text="I can fly!")]), }, - {"id": "1", "content": glm.Content(parts=[glm.Part(text="I am a bird.")])}, - {"id": "2", "content": glm.Content(parts=[glm.Part(text="I can fly!")])}, ] ), ), dict( testcase_name="content_object", inline_passages=[ - glm.Content(parts=[glm.Part(text="I am a chicken")]), - glm.Content(parts=[glm.Part(text="I am a bird.")]), - glm.Content(parts=[glm.Part(text="I can fly!")]), + protos.Content(parts=[protos.Part(text="I am a chicken")]), + protos.Content(parts=[protos.Part(text="I am a bird.")]), + protos.Content(parts=[protos.Part(text="I can fly!")]), ], ), dict( @@ -109,13 +121,19 @@ def test_make_grounding_passages_mixed_types(self): ) def test_make_grounding_passages(self, inline_passages): x = answer._make_grounding_passages(inline_passages) - self.assertIsInstance(x, glm.GroundingPassages) + self.assertIsInstance(x, protos.GroundingPassages) self.assertEqual( - glm.GroundingPassages( + protos.GroundingPassages( passages=[ - {"id": "0", "content": glm.Content(parts=[glm.Part(text="I am a chicken")])}, - {"id": "1", "content": glm.Content(parts=[glm.Part(text="I am a bird.")])}, - {"id": "2", "content": glm.Content(parts=[glm.Part(text="I can fly!")])}, + { + "id": "0", + "content": protos.Content(parts=[protos.Part(text="I am a chicken")]), + }, + { + "id": "1", + "content": protos.Content(parts=[protos.Part(text="I am a bird.")]), + }, + {"id": "2", "content": protos.Content(parts=[protos.Part(text="I can fly!")])}, ] ), x, @@ -133,27 +151,33 @@ def test_make_grounding_passages(self, inline_passages): dict( testcase_name="list_of_grounding_passages", inline_passages=[ - glm.GroundingPassage( - id="4", content=glm.Content(parts=[glm.Part(text="I am a chicken")]) + protos.GroundingPassage( + id="4", content=protos.Content(parts=[protos.Part(text="I am a chicken")]) ), - glm.GroundingPassage( - id="5", content=glm.Content(parts=[glm.Part(text="I am a bird.")]) + protos.GroundingPassage( + id="5", content=protos.Content(parts=[protos.Part(text="I am a bird.")]) ), - glm.GroundingPassage( - id="6", content=glm.Content(parts=[glm.Part(text="I can fly!")]) + protos.GroundingPassage( + id="6", content=protos.Content(parts=[protos.Part(text="I can fly!")]) ), ], ), ) def test_make_grounding_passages_different_id(self, inline_passages): x = answer._make_grounding_passages(inline_passages) - self.assertIsInstance(x, glm.GroundingPassages) + self.assertIsInstance(x, protos.GroundingPassages) self.assertEqual( - glm.GroundingPassages( + protos.GroundingPassages( passages=[ - {"id": "4", "content": glm.Content(parts=[glm.Part(text="I am a chicken")])}, - {"id": "5", "content": glm.Content(parts=[glm.Part(text="I am a bird.")])}, - {"id": "6", "content": glm.Content(parts=[glm.Part(text="I can fly!")])}, + { + "id": "4", + "content": protos.Content(parts=[protos.Part(text="I am a chicken")]), + }, + { + "id": "5", + "content": protos.Content(parts=[protos.Part(text="I am a bird.")]), + }, + {"id": "6", "content": protos.Content(parts=[protos.Part(text="I can fly!")])}, ] ), x, @@ -167,16 +191,22 @@ def test_make_grounding_passages_key_strings(self): } x = answer._make_grounding_passages(inline_passages) - self.assertIsInstance(x, glm.GroundingPassages) + self.assertIsInstance(x, protos.GroundingPassages) self.assertEqual( - glm.GroundingPassages( + protos.GroundingPassages( passages=[ { "id": "first", - "content": glm.Content(parts=[glm.Part(text="I am a chicken")]), + "content": protos.Content(parts=[protos.Part(text="I am a chicken")]), + }, + { + "id": "second", + "content": protos.Content(parts=[protos.Part(text="I am a bird.")]), + }, + { + "id": "third", + "content": protos.Content(parts=[protos.Part(text="I can fly!")]), }, - {"id": "second", "content": glm.Content(parts=[glm.Part(text="I am a bird.")])}, - {"id": "third", "content": glm.Content(parts=[glm.Part(text="I can fly!")])}, ] ), x, @@ -184,14 +214,14 @@ def test_make_grounding_passages_key_strings(self): def test_generate_answer_request(self): # Should be a list of contents to use to_contents() function. - contents = [glm.Content(parts=[glm.Part(text="I have wings.")])] + contents = [protos.Content(parts=[protos.Part(text="I have wings.")])] inline_passages = ["I am a chicken", "I am a bird.", "I can fly!"] - grounding_passages = glm.GroundingPassages( + grounding_passages = protos.GroundingPassages( passages=[ - {"id": "0", "content": glm.Content(parts=[glm.Part(text="I am a chicken")])}, - {"id": "1", "content": glm.Content(parts=[glm.Part(text="I am a bird.")])}, - {"id": "2", "content": glm.Content(parts=[glm.Part(text="I can fly!")])}, + {"id": "0", "content": protos.Content(parts=[protos.Part(text="I am a chicken")])}, + {"id": "1", "content": protos.Content(parts=[protos.Part(text="I am a bird.")])}, + {"id": "2", "content": protos.Content(parts=[protos.Part(text="I can fly!")])}, ] ) @@ -200,7 +230,7 @@ def test_generate_answer_request(self): ) self.assertEqual( - glm.GenerateAnswerRequest( + protos.GenerateAnswerRequest( model=DEFAULT_ANSWER_MODEL, contents=contents, inline_passages=grounding_passages ), x, @@ -208,13 +238,13 @@ def test_generate_answer_request(self): def test_generate_answer(self): # Test handling return value of generate_answer(). - contents = [glm.Content(parts=[glm.Part(text="I have wings.")])] + contents = [protos.Content(parts=[protos.Part(text="I have wings.")])] - grounding_passages = glm.GroundingPassages( + grounding_passages = protos.GroundingPassages( passages=[ - {"id": "0", "content": glm.Content(parts=[glm.Part(text="I am a chicken")])}, - {"id": "1", "content": glm.Content(parts=[glm.Part(text="I am a bird.")])}, - {"id": "2", "content": glm.Content(parts=[glm.Part(text="I can fly!")])}, + {"id": "0", "content": protos.Content(parts=[protos.Part(text="I am a chicken")])}, + {"id": "1", "content": protos.Content(parts=[protos.Part(text="I am a bird.")])}, + {"id": "2", "content": protos.Content(parts=[protos.Part(text="I can fly!")])}, ] ) @@ -225,13 +255,13 @@ def test_generate_answer(self): answer_style="ABSTRACTIVE", ) - self.assertIsInstance(a, glm.GenerateAnswerResponse) + self.assertIsInstance(a, protos.GenerateAnswerResponse) self.assertEqual( a, - glm.GenerateAnswerResponse( - answer=glm.Candidate( + protos.GenerateAnswerResponse( + answer=protos.Candidate( index=1, - content=(glm.Content(parts=[glm.Part(text="Demo answer.")])), + content=(protos.Content(parts=[protos.Part(text="Demo answer.")])), ), answerable_probability=0.500, ), diff --git a/tests/test_client.py b/tests/test_client.py index 0256edac3..0cc3e05eb 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -4,8 +4,10 @@ from absl.testing import absltest from absl.testing import parameterized -from google.api_core import client_options import google.ai.generativelanguage as glm + +from google.api_core import client_options +from google.generativeai import protos from google.generativeai import client diff --git a/tests/test_content.py b/tests/test_content.py index 5f22b93a1..3829ebc86 100644 --- a/tests/test_content.py +++ b/tests/test_content.py @@ -19,7 +19,7 @@ from absl.testing import absltest from absl.testing import parameterized -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai.types import content_types import IPython.display import PIL.Image @@ -71,7 +71,7 @@ class UnitTests(parameterized.TestCase): ) def test_png_to_blob(self, image): blob = content_types.image_to_blob(image) - self.assertIsInstance(blob, glm.Blob) + self.assertIsInstance(blob, protos.Blob) self.assertEqual(blob.mime_type, "image/png") self.assertStartsWith(blob.data, b"\x89PNG") @@ -81,29 +81,29 @@ def test_png_to_blob(self, image): ) def test_jpg_to_blob(self, image): blob = content_types.image_to_blob(image) - self.assertIsInstance(blob, glm.Blob) + self.assertIsInstance(blob, protos.Blob) self.assertEqual(blob.mime_type, "image/jpeg") self.assertStartsWith(blob.data, b"\xff\xd8\xff\xe0\x00\x10JFIF") @parameterized.named_parameters( ["BlobDict", {"mime_type": "image/png", "data": TEST_PNG_DATA}], - ["glm.Blob", glm.Blob(mime_type="image/png", data=TEST_PNG_DATA)], + ["protos.Blob", protos.Blob(mime_type="image/png", data=TEST_PNG_DATA)], ["Image", IPython.display.Image(filename=TEST_PNG_PATH)], ) def test_to_blob(self, example): blob = content_types.to_blob(example) - self.assertIsInstance(blob, glm.Blob) + self.assertIsInstance(blob, protos.Blob) self.assertEqual(blob.mime_type, "image/png") self.assertStartsWith(blob.data, b"\x89PNG") @parameterized.named_parameters( ["dict", {"text": "Hello world!"}], - ["glm.Part", glm.Part(text="Hello world!")], + ["protos.Part", protos.Part(text="Hello world!")], ["str", "Hello world!"], ) def test_to_part(self, example): part = content_types.to_part(example) - self.assertIsInstance(part, glm.Part) + self.assertIsInstance(part, protos.Part) self.assertEqual(part.text, "Hello world!") @parameterized.named_parameters( @@ -116,12 +116,12 @@ def test_to_part(self, example): ) def test_img_to_part(self, example): blob = content_types.to_part(example).inline_data - self.assertIsInstance(blob, glm.Blob) + self.assertIsInstance(blob, protos.Blob) self.assertEqual(blob.mime_type, "image/png") self.assertStartsWith(blob.data, b"\x89PNG") @parameterized.named_parameters( - ["glm.Content", glm.Content(parts=[{"text": "Hello world!"}])], + ["protos.Content", protos.Content(parts=[{"text": "Hello world!"}])], ["ContentDict", {"parts": [{"text": "Hello world!"}]}], ["ContentDict-str", {"parts": ["Hello world!"]}], ["list[parts]", [{"text": "Hello world!"}]], @@ -135,7 +135,7 @@ def test_to_content(self, example): part = content.parts[0] self.assertLen(content.parts, 1) - self.assertIsInstance(part, glm.Part) + self.assertIsInstance(part, protos.Part) self.assertEqual(part.text, "Hello world!") @parameterized.named_parameters( @@ -147,12 +147,12 @@ def test_img_to_content(self, example): content = content_types.to_content(example) blob = content.parts[0].inline_data self.assertLen(content.parts, 1) - self.assertIsInstance(blob, glm.Blob) + self.assertIsInstance(blob, protos.Blob) self.assertEqual(blob.mime_type, "image/png") self.assertStartsWith(blob.data, b"\x89PNG") @parameterized.named_parameters( - ["glm.Content", glm.Content(parts=[{"text": "Hello world!"}])], + ["protos.Content", protos.Content(parts=[{"text": "Hello world!"}])], ["ContentDict", {"parts": [{"text": "Hello world!"}]}], ["ContentDict-str", {"parts": ["Hello world!"]}], ) @@ -161,7 +161,7 @@ def test_strict_to_content(self, example): part = content.parts[0] self.assertLen(content.parts, 1) - self.assertIsInstance(part, glm.Part) + self.assertIsInstance(part, protos.Part) self.assertEqual(part.text, "Hello world!") @parameterized.named_parameters( @@ -176,7 +176,7 @@ def test_strict_to_contents_fails(self, examples): content_types.strict_to_content(examples) @parameterized.named_parameters( - ["glm.Content", [glm.Content(parts=[{"text": "Hello world!"}])]], + ["protos.Content", [protos.Content(parts=[{"text": "Hello world!"}])]], ["ContentDict", [{"parts": [{"text": "Hello world!"}]}]], ["ContentDict-unwraped", [{"parts": ["Hello world!"]}]], ["ContentDict+str-part", [{"parts": "Hello world!"}]], @@ -188,7 +188,7 @@ def test_to_contents(self, example): self.assertLen(contents, 1) self.assertLen(contents[0].parts, 1) - self.assertIsInstance(part, glm.Part) + self.assertIsInstance(part, protos.Part) self.assertEqual(part.text, "Hello world!") def test_dict_to_content_fails(self): @@ -209,7 +209,7 @@ def test_img_to_contents(self, example): self.assertLen(contents, 1) self.assertLen(contents[0].parts, 1) - self.assertIsInstance(blob, glm.Blob) + self.assertIsInstance(blob, protos.Blob) self.assertEqual(blob.mime_type, "image/png") self.assertStartsWith(blob.data, b"\x89PNG") @@ -217,9 +217,9 @@ def test_img_to_contents(self, example): [ "FunctionLibrary", content_types.FunctionLibrary( - tools=glm.Tool( + tools=protos.Tool( function_declarations=[ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time." ) ] @@ -231,7 +231,7 @@ def test_img_to_contents(self, example): [ content_types.Tool( function_declarations=[ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time." ) ] @@ -239,11 +239,11 @@ def test_img_to_contents(self, example): ], ], [ - "IterableTool-glm.Tool", + "IterableTool-protos.Tool", [ - glm.Tool( + protos.Tool( function_declarations=[ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time.", ) @@ -268,7 +268,7 @@ def test_img_to_contents(self, example): "IterableTool-IterableFD", [ [ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time.", ) @@ -278,7 +278,7 @@ def test_img_to_contents(self, example): [ "IterableTool-FD", [ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time.", ) @@ -288,17 +288,17 @@ def test_img_to_contents(self, example): "Tool", content_types.Tool( function_declarations=[ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time." ) ] ), ], [ - "glm.Tool", - glm.Tool( + "protos.Tool", + protos.Tool( function_declarations=[ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time." ) ] @@ -350,8 +350,8 @@ def test_img_to_contents(self, example): ), ], [ - "glm.FD", - glm.FunctionDeclaration( + "protos.FD", + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time." ), ], @@ -391,83 +391,83 @@ def b(): self.assertLen(tools[0].function_declarations, 2) @parameterized.named_parameters( - ["int", int, glm.Schema(type=glm.Type.INTEGER)], - ["float", float, glm.Schema(type=glm.Type.NUMBER)], - ["str", str, glm.Schema(type=glm.Type.STRING)], - ["nullable_str", Union[str, None], glm.Schema(type=glm.Type.STRING, nullable=True)], + ["int", int, protos.Schema(type=protos.Type.INTEGER)], + ["float", float, protos.Schema(type=protos.Type.NUMBER)], + ["str", str, protos.Schema(type=protos.Type.STRING)], + ["nullable_str", Union[str, None], protos.Schema(type=protos.Type.STRING, nullable=True)], [ "list", list[str], - glm.Schema( - type=glm.Type.ARRAY, - items=glm.Schema(type=glm.Type.STRING), + protos.Schema( + type=protos.Type.ARRAY, + items=protos.Schema(type=protos.Type.STRING), ), ], [ "list-list-int", list[list[int]], - glm.Schema( - type=glm.Type.ARRAY, - items=glm.Schema( - glm.Schema( - type=glm.Type.ARRAY, - items=glm.Schema(type=glm.Type.INTEGER), + protos.Schema( + type=protos.Type.ARRAY, + items=protos.Schema( + protos.Schema( + type=protos.Type.ARRAY, + items=protos.Schema(type=protos.Type.INTEGER), ), ), ), ], - ["dict", dict, glm.Schema(type=glm.Type.OBJECT)], - ["dict-str-any", dict[str, Any], glm.Schema(type=glm.Type.OBJECT)], + ["dict", dict, protos.Schema(type=protos.Type.OBJECT)], + ["dict-str-any", dict[str, Any], protos.Schema(type=protos.Type.OBJECT)], [ "dataclass", ADataClass, - glm.Schema( - type=glm.Type.OBJECT, - properties={"a": {"type_": glm.Type.INTEGER}}, + protos.Schema( + type=protos.Type.OBJECT, + properties={"a": {"type_": protos.Type.INTEGER}}, ), ], [ "nullable_dataclass", Union[ADataClass, None], - glm.Schema( - type=glm.Type.OBJECT, + protos.Schema( + type=protos.Type.OBJECT, nullable=True, - properties={"a": {"type_": glm.Type.INTEGER}}, + properties={"a": {"type_": protos.Type.INTEGER}}, ), ], [ "list_of_dataclass", list[ADataClass], - glm.Schema( + protos.Schema( type="ARRAY", - items=glm.Schema( - type=glm.Type.OBJECT, - properties={"a": {"type_": glm.Type.INTEGER}}, + items=protos.Schema( + type=protos.Type.OBJECT, + properties={"a": {"type_": protos.Type.INTEGER}}, ), ), ], [ "dataclass_with_nullable", ADataClassWithNullable, - glm.Schema( - type=glm.Type.OBJECT, - properties={"a": {"type_": glm.Type.INTEGER, "nullable": True}}, + protos.Schema( + type=protos.Type.OBJECT, + properties={"a": {"type_": protos.Type.INTEGER, "nullable": True}}, ), ], [ "dataclass_with_list", ADataClassWithList, - glm.Schema( - type=glm.Type.OBJECT, + protos.Schema( + type=protos.Type.OBJECT, properties={"a": {"type_": "ARRAY", "items": {"type_": "INTEGER"}}}, ), ], [ "list_of_dataclass_with_list", list[ADataClassWithList], - glm.Schema( - items=glm.Schema( - type=glm.Type.OBJECT, + protos.Schema( + items=protos.Schema( + type=protos.Type.OBJECT, properties={"a": {"type_": "ARRAY", "items": {"type_": "INTEGER"}}}, ), type="ARRAY", @@ -476,31 +476,31 @@ def b(): [ "list_of_nullable", list[Union[int, None]], - glm.Schema( + protos.Schema( type="ARRAY", - items={"type_": glm.Type.INTEGER, "nullable": True}, + items={"type_": protos.Type.INTEGER, "nullable": True}, ), ], [ "TypedDict", ATypedDict, - glm.Schema( - type=glm.Type.OBJECT, + protos.Schema( + type=protos.Type.OBJECT, properties={ - "a": {"type_": glm.Type.INTEGER}, + "a": {"type_": protos.Type.INTEGER}, }, ), ], [ "nested", Nested, - glm.Schema( - type=glm.Type.OBJECT, + protos.Schema( + type=protos.Type.OBJECT, properties={ - "x": glm.Schema( - type=glm.Type.OBJECT, + "x": protos.Schema( + type=protos.Type.OBJECT, properties={ - "a": {"type_": glm.Type.INTEGER}, + "a": {"type_": protos.Type.INTEGER}, }, ), }, diff --git a/tests/test_discuss.py b/tests/test_discuss.py index 7db0a63d8..4e54cf754 100644 --- a/tests/test_discuss.py +++ b/tests/test_discuss.py @@ -16,7 +16,7 @@ import unittest.mock -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import discuss from google.generativeai import client @@ -37,18 +37,18 @@ def setUp(self): self.observed_request = None - self.mock_response = glm.GenerateMessageResponse( + self.mock_response = protos.GenerateMessageResponse( candidates=[ - glm.Message(content="a", author="1"), - glm.Message(content="b", author="1"), - glm.Message(content="c", author="1"), + protos.Message(content="a", author="1"), + protos.Message(content="b", author="1"), + protos.Message(content="c", author="1"), ], ) def fake_generate_message( - request: glm.GenerateMessageRequest, + request: protos.GenerateMessageRequest, **kwargs, - ) -> glm.GenerateMessageResponse: + ) -> protos.GenerateMessageResponse: self.observed_request = request response = copy.copy(self.mock_response) response.messages = request.prompt.messages @@ -60,22 +60,22 @@ def fake_generate_message( ["string", "Hello", ""], ["dict", {"content": "Hello"}, ""], ["dict_author", {"content": "Hello", "author": "me"}, "me"], - ["proto", glm.Message(content="Hello"), ""], - ["proto_author", glm.Message(content="Hello", author="me"), "me"], + ["proto", protos.Message(content="Hello"), ""], + ["proto_author", protos.Message(content="Hello", author="me"), "me"], ) def test_make_message(self, message, author): x = discuss._make_message(message) - self.assertIsInstance(x, glm.Message) + self.assertIsInstance(x, protos.Message) self.assertEqual("Hello", x.content) self.assertEqual(author, x.author) @parameterized.named_parameters( ["string", "Hello", ["Hello"]], ["dict", {"content": "Hello"}, ["Hello"]], - ["proto", glm.Message(content="Hello"), ["Hello"]], + ["proto", protos.Message(content="Hello"), ["Hello"]], [ "list", - ["hello0", {"content": "hello1"}, glm.Message(content="hello2")], + ["hello0", {"content": "hello1"}, protos.Message(content="hello2")], ["hello0", "hello1", "hello2"], ], ) @@ -90,15 +90,15 @@ def test_make_messages(self, messages, expected_contents): ["dict", {"input": "hello", "output": "goodbye"}], [ "proto", - glm.Example( - input=glm.Message(content="hello"), - output=glm.Message(content="goodbye"), + protos.Example( + input=protos.Message(content="hello"), + output=protos.Message(content="goodbye"), ), ], ) def test_make_example(self, example): x = discuss._make_example(example) - self.assertIsInstance(x, glm.Example) + self.assertIsInstance(x, protos.Example) self.assertEqual("hello", x.input.content) self.assertEqual("goodbye", x.output.content) return @@ -110,7 +110,7 @@ def test_make_example(self, example): "Hi", {"content": "Hello!"}, "what's your name?", - glm.Message(content="Dave, what's yours"), + protos.Message(content="Dave, what's yours"), ], ], [ @@ -145,15 +145,15 @@ def test_make_examples_from_example(self): @parameterized.named_parameters( ["str", "hello"], - ["message", glm.Message(content="hello")], + ["message", protos.Message(content="hello")], ["messages", ["hello"]], ["dict", {"messages": "hello"}], ["dict2", {"messages": ["hello"]}], - ["proto", glm.MessagePrompt(messages=[glm.Message(content="hello")])], + ["proto", protos.MessagePrompt(messages=[protos.Message(content="hello")])], ) def test_make_message_prompt_from_messages(self, prompt): x = discuss._make_message_prompt(prompt) - self.assertIsInstance(x, glm.MessagePrompt) + self.assertIsInstance(x, protos.MessagePrompt) self.assertEqual(x.messages[0].content, "hello") return @@ -181,15 +181,15 @@ def test_make_message_prompt_from_messages(self, prompt): [ "proto", [ - glm.MessagePrompt( + protos.MessagePrompt( context="you are a cat", examples=[ - glm.Example( - input=glm.Message(content="are you hungry?"), - output=glm.Message(content="meow!"), + protos.Example( + input=protos.Message(content="are you hungry?"), + output=protos.Message(content="meow!"), ) ], - messages=[glm.Message(content="hello")], + messages=[protos.Message(content="hello")], ) ], {}, @@ -197,7 +197,7 @@ def test_make_message_prompt_from_messages(self, prompt): ) def test_make_message_prompt_from_prompt(self, args, kwargs): x = discuss._make_message_prompt(*args, **kwargs) - self.assertIsInstance(x, glm.MessagePrompt) + self.assertIsInstance(x, protos.MessagePrompt) self.assertEqual(x.context, "you are a cat") self.assertEqual(x.examples[0].input.content, "are you hungry?") self.assertEqual(x.examples[0].output.content, "meow!") @@ -229,8 +229,8 @@ def test_make_generate_message_request_nested( } ) - self.assertIsInstance(request0, glm.GenerateMessageRequest) - self.assertIsInstance(request1, glm.GenerateMessageRequest) + self.assertIsInstance(request0, protos.GenerateMessageRequest) + self.assertIsInstance(request1, protos.GenerateMessageRequest) self.assertEqual(request0, request1) @parameterized.parameters( @@ -285,11 +285,13 @@ def test_reply(self, kwargs): response = response.reply("again") def test_receive_and_reply_with_filters(self): - self.mock_response = mock_response = glm.GenerateMessageResponse( - candidates=[glm.Message(content="a", author="1")], + self.mock_response = mock_response = protos.GenerateMessageResponse( + candidates=[protos.Message(content="a", author="1")], filters=[ - glm.ContentFilter(reason=palm_safety_types.BlockedReason.SAFETY, message="unsafe"), - glm.ContentFilter(reason=palm_safety_types.BlockedReason.OTHER), + protos.ContentFilter( + reason=palm_safety_types.BlockedReason.SAFETY, message="unsafe" + ), + protos.ContentFilter(reason=palm_safety_types.BlockedReason.OTHER), ], ) response = discuss.chat(messages="do filters work?") @@ -300,10 +302,12 @@ def test_receive_and_reply_with_filters(self): self.assertEqual(filters[0]["reason"], palm_safety_types.BlockedReason.SAFETY) self.assertEqual(filters[0]["message"], "unsafe") - self.mock_response = glm.GenerateMessageResponse( - candidates=[glm.Message(content="a", author="1")], + self.mock_response = protos.GenerateMessageResponse( + candidates=[protos.Message(content="a", author="1")], filters=[ - glm.ContentFilter(reason=palm_safety_types.BlockedReason.BLOCKED_REASON_UNSPECIFIED) + protos.ContentFilter( + reason=palm_safety_types.BlockedReason.BLOCKED_REASON_UNSPECIFIED + ) ], ) @@ -317,7 +321,7 @@ def test_receive_and_reply_with_filters(self): ) def test_chat_citations(self): - self.mock_response = mock_response = glm.GenerateMessageResponse( + self.mock_response = mock_response = protos.GenerateMessageResponse( candidates=[ { "content": "Hello google!", diff --git a/tests/test_discuss_async.py b/tests/test_discuss_async.py index 7e1f7947c..d35d03525 100644 --- a/tests/test_discuss_async.py +++ b/tests/test_discuss_async.py @@ -17,7 +17,7 @@ from typing import Any import unittest -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import discuss from absl.testing import absltest @@ -31,14 +31,14 @@ async def test_chat_async(self): observed_request = None async def fake_generate_message( - request: glm.GenerateMessageRequest, + request: protos.GenerateMessageRequest, **kwargs, - ) -> glm.GenerateMessageResponse: + ) -> protos.GenerateMessageResponse: nonlocal observed_request observed_request = request - return glm.GenerateMessageResponse( + return protos.GenerateMessageResponse( candidates=[ - glm.Message( + protos.Message( author="1", content="Why did the chicken cross the road?", ) @@ -59,17 +59,17 @@ async def fake_generate_message( self.assertEqual( observed_request, - glm.GenerateMessageRequest( + protos.GenerateMessageRequest( model="models/bard", - prompt=glm.MessagePrompt( + prompt=protos.MessagePrompt( context="Example Prompt", examples=[ - glm.Example( - input=glm.Message(content="Example from human"), - output=glm.Message(content="Example response from AI"), + protos.Example( + input=protos.Message(content="Example from human"), + output=protos.Message(content="Example response from AI"), ) ], - messages=[glm.Message(author="0", content="Tell me a joke")], + messages=[protos.Message(author="0", content="Tell me a joke")], ), temperature=0.75, candidate_count=1, diff --git a/tests/test_embedding.py b/tests/test_embedding.py index 5f6aa8d89..a208a4743 100644 --- a/tests/test_embedding.py +++ b/tests/test_embedding.py @@ -18,7 +18,7 @@ import unittest import unittest.mock as mock -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import embedding @@ -45,20 +45,20 @@ def add_client_method(f): @add_client_method def embed_content( - request: glm.EmbedContentRequest, + request: protos.EmbedContentRequest, **kwargs, - ) -> glm.EmbedContentResponse: + ) -> protos.EmbedContentResponse: self.observed_requests.append(request) - return glm.EmbedContentResponse(embedding=glm.ContentEmbedding(values=[1, 2, 3])) + return protos.EmbedContentResponse(embedding=protos.ContentEmbedding(values=[1, 2, 3])) @add_client_method def batch_embed_contents( - request: glm.BatchEmbedContentsRequest, + request: protos.BatchEmbedContentsRequest, **kwargs, - ) -> glm.BatchEmbedContentsResponse: + ) -> protos.BatchEmbedContentsResponse: self.observed_requests.append(request) - return glm.BatchEmbedContentsResponse( - embeddings=[glm.ContentEmbedding(values=[1, 2, 3])] * len(request.requests) + return protos.BatchEmbedContentsResponse( + embeddings=[protos.ContentEmbedding(values=[1, 2, 3])] * len(request.requests) ) def test_embed_content(self): @@ -68,8 +68,9 @@ def test_embed_content(self): self.assertIsInstance(emb, dict) self.assertEqual( self.observed_requests[-1], - glm.EmbedContentRequest( - model=DEFAULT_EMB_MODEL, content=glm.Content(parts=[glm.Part(text="What are you?")]) + protos.EmbedContentRequest( + model=DEFAULT_EMB_MODEL, + content=protos.Content(parts=[protos.Part(text="What are you?")]), ), ) self.assertIsInstance(emb["embedding"][0], float) diff --git a/tests/test_embedding_async.py b/tests/test_embedding_async.py index d4ca16c08..367cf7ded 100644 --- a/tests/test_embedding_async.py +++ b/tests/test_embedding_async.py @@ -18,7 +18,7 @@ import unittest import unittest.mock as mock -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import embedding @@ -44,20 +44,20 @@ def add_client_method(f): @add_client_method async def embed_content( - request: glm.EmbedContentRequest, + request: protos.EmbedContentRequest, **kwargs, - ) -> glm.EmbedContentResponse: + ) -> protos.EmbedContentResponse: self.observed_requests.append(request) - return glm.EmbedContentResponse(embedding=glm.ContentEmbedding(values=[1, 2, 3])) + return protos.EmbedContentResponse(embedding=protos.ContentEmbedding(values=[1, 2, 3])) @add_client_method async def batch_embed_contents( - request: glm.BatchEmbedContentsRequest, + request: protos.BatchEmbedContentsRequest, **kwargs, - ) -> glm.BatchEmbedContentsResponse: + ) -> protos.BatchEmbedContentsResponse: self.observed_requests.append(request) - return glm.BatchEmbedContentsResponse( - embeddings=[glm.ContentEmbedding(values=[1, 2, 3])] * len(request.requests) + return protos.BatchEmbedContentsResponse( + embeddings=[protos.ContentEmbedding(values=[1, 2, 3])] * len(request.requests) ) async def test_embed_content_async(self): @@ -67,8 +67,9 @@ async def test_embed_content_async(self): self.assertIsInstance(emb, dict) self.assertEqual( self.observed_requests[-1], - glm.EmbedContentRequest( - model=DEFAULT_EMB_MODEL, content=glm.Content(parts=[glm.Part(text="What are you?")]) + protos.EmbedContentRequest( + model=DEFAULT_EMB_MODEL, + content=protos.Content(parts=[protos.Part(text="What are you?")]), ), ) self.assertIsInstance(emb["embedding"][0], float) diff --git a/tests/test_files.py b/tests/test_files.py index 333ec1e2a..7d9139450 100644 --- a/tests/test_files.py +++ b/tests/test_files.py @@ -22,10 +22,10 @@ import pathlib import google -import google.ai.generativelanguage as glm import google.generativeai as genai from google.generativeai import client as client_lib +from google.generativeai import protos from absl.testing import parameterized @@ -43,7 +43,7 @@ def create_file( name: Union[str, None] = None, display_name: Union[str, None] = None, resumable: bool = True, - ) -> glm.File: + ) -> protos.File: self.observed_requests.append( dict( path=path, @@ -57,24 +57,24 @@ def create_file( def get_file( self, - request: glm.GetFileRequest, + request: protos.GetFileRequest, **kwargs, - ) -> glm.File: + ) -> protos.File: self.observed_requests.append(request) return self.responses["get_file"].pop(0) def list_files( self, - request: glm.ListFilesRequest, + request: protos.ListFilesRequest, **kwargs, - ) -> Iterable[glm.File]: + ) -> Iterable[protos.File]: self.observed_requests.append(request) for f in self.responses["list_files"].pop(0): yield f def delete_file( self, - request: glm.DeleteFileRequest, + request: protos.DeleteFileRequest, **kwargs, ): self.observed_requests.append(request) @@ -97,7 +97,7 @@ def responses(self): def test_video_metadata(self): self.responses["create_file"].append( - glm.File( + protos.File( uri="https://test", state="ACTIVE", video_metadata=dict(video_duration=datetime.timedelta(seconds=30)), @@ -108,7 +108,8 @@ def test_video_metadata(self): f = genai.upload_file(path="dummy") self.assertEqual(google.rpc.status_pb2.Status(code=7, message="ok?"), f.error) self.assertEqual( - glm.VideoMetadata(dict(video_duration=datetime.timedelta(seconds=30))), f.video_metadata + protos.VideoMetadata(dict(video_duration=datetime.timedelta(seconds=30))), + f.video_metadata, ) @parameterized.named_parameters( @@ -123,11 +124,11 @@ def test_video_metadata(self): ), dict( testcase_name="FileData", - file_data=glm.FileData(file_uri="https://test_uri"), + file_data=protos.FileData(file_uri="https://test_uri"), ), dict( - testcase_name="glm.File", - file_data=glm.File(uri="https://test_uri"), + testcase_name="protos.File", + file_data=protos.File(uri="https://test_uri"), ), dict( testcase_name="file_types.File", @@ -137,4 +138,4 @@ def test_video_metadata(self): ) def test_to_file_data(self, file_data): file_data = file_types.to_file_data(file_data) - self.assertEqual(glm.FileData(file_uri="https://test_uri"), file_data) + self.assertEqual(protos.FileData(file_uri="https://test_uri"), file_data) diff --git a/tests/test_generation.py b/tests/test_generation.py index b256a1029..828577d21 100644 --- a/tests/test_generation.py +++ b/tests/test_generation.py @@ -5,7 +5,7 @@ from absl.testing import absltest from absl.testing import parameterized -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai.types import generation_types @@ -24,9 +24,11 @@ class Person(TypedDict): class UnitTests(parameterized.TestCase): @parameterized.named_parameters( [ - "glm.GenerationConfig", - glm.GenerationConfig( - temperature=0.1, stop_sequences=["end"], response_schema=glm.Schema(type="STRING") + "protos.GenerationConfig", + protos.GenerationConfig( + temperature=0.1, + stop_sequences=["end"], + response_schema=protos.Schema(type="STRING"), ), ], [ @@ -48,15 +50,15 @@ def test_to_generation_config(self, config): def test_join_citation_metadatas(self): citations = [ - glm.CitationMetadata( + protos.CitationMetadata( citation_sources=[ - glm.CitationSource(start_index=3, end_index=21, uri="https://google.com"), + protos.CitationSource(start_index=3, end_index=21, uri="https://google.com"), ] ), - glm.CitationMetadata( + protos.CitationMetadata( citation_sources=[ - glm.CitationSource(start_index=3, end_index=33, uri="https://google.com"), - glm.CitationSource(start_index=55, end_index=92, uri="https://google.com"), + protos.CitationSource(start_index=3, end_index=33, uri="https://google.com"), + protos.CitationSource(start_index=55, end_index=92, uri="https://google.com"), ] ), ] @@ -74,14 +76,14 @@ def test_join_citation_metadatas(self): def test_join_safety_ratings_list(self): ratings = [ [ - glm.SafetyRating(category="HARM_CATEGORY_DANGEROUS", probability="LOW"), - glm.SafetyRating(category="HARM_CATEGORY_MEDICAL", probability="LOW"), - glm.SafetyRating(category="HARM_CATEGORY_SEXUAL", probability="MEDIUM"), + protos.SafetyRating(category="HARM_CATEGORY_DANGEROUS", probability="LOW"), + protos.SafetyRating(category="HARM_CATEGORY_MEDICAL", probability="LOW"), + protos.SafetyRating(category="HARM_CATEGORY_SEXUAL", probability="MEDIUM"), ], [ - glm.SafetyRating(category="HARM_CATEGORY_DEROGATORY", probability="LOW"), - glm.SafetyRating(category="HARM_CATEGORY_SEXUAL", probability="LOW"), - glm.SafetyRating( + protos.SafetyRating(category="HARM_CATEGORY_DEROGATORY", probability="LOW"), + protos.SafetyRating(category="HARM_CATEGORY_SEXUAL", probability="LOW"), + protos.SafetyRating( category="HARM_CATEGORY_DANGEROUS", probability="HIGH", blocked=True, @@ -101,14 +103,14 @@ def test_join_safety_ratings_list(self): def test_join_contents(self): contents = [ - glm.Content(role="assistant", parts=[glm.Part(text="Tell me a story about a ")]), - glm.Content( + protos.Content(role="assistant", parts=[protos.Part(text="Tell me a story about a ")]), + protos.Content( role="assistant", - parts=[glm.Part(text="magic backpack that looks like this: ")], + parts=[protos.Part(text="magic backpack that looks like this: ")], ), - glm.Content( + protos.Content( role="assistant", - parts=[glm.Part(inline_data=glm.Blob(mime_type="image/png", data=b"DATA!"))], + parts=[protos.Part(inline_data=protos.Blob(mime_type="image/png", data=b"DATA!"))], ), ] result = generation_types._join_contents(contents) @@ -126,7 +128,8 @@ def test_many_join_contents(self): import string contents = [ - glm.Content(role="assistant", parts=[glm.Part(text=a)]) for a in string.ascii_lowercase + protos.Content(role="assistant", parts=[protos.Part(text=a)]) + for a in string.ascii_lowercase ] result = generation_types._join_contents(contents) @@ -139,41 +142,53 @@ def test_many_join_contents(self): def test_join_candidates(self): candidates = [ - glm.Candidate( + protos.Candidate( index=0, - content=glm.Content( + content=protos.Content( role="assistant", - parts=[glm.Part(text="Tell me a story about a ")], + parts=[protos.Part(text="Tell me a story about a ")], ), - citation_metadata=glm.CitationMetadata( + citation_metadata=protos.CitationMetadata( citation_sources=[ - glm.CitationSource(start_index=55, end_index=85, uri="https://google.com"), + protos.CitationSource( + start_index=55, end_index=85, uri="https://google.com" + ), ] ), ), - glm.Candidate( + protos.Candidate( index=0, - content=glm.Content( + content=protos.Content( role="assistant", - parts=[glm.Part(text="magic backpack that looks like this: ")], + parts=[protos.Part(text="magic backpack that looks like this: ")], ), - citation_metadata=glm.CitationMetadata( + citation_metadata=protos.CitationMetadata( citation_sources=[ - glm.CitationSource(start_index=55, end_index=92, uri="https://google.com"), - glm.CitationSource(start_index=3, end_index=21, uri="https://google.com"), + protos.CitationSource( + start_index=55, end_index=92, uri="https://google.com" + ), + protos.CitationSource( + start_index=3, end_index=21, uri="https://google.com" + ), ] ), ), - glm.Candidate( + protos.Candidate( index=0, - content=glm.Content( + content=protos.Content( role="assistant", - parts=[glm.Part(inline_data=glm.Blob(mime_type="image/png", data=b"DATA!"))], + parts=[ + protos.Part(inline_data=protos.Blob(mime_type="image/png", data=b"DATA!")) + ], ), - citation_metadata=glm.CitationMetadata( + citation_metadata=protos.CitationMetadata( citation_sources=[ - glm.CitationSource(start_index=55, end_index=92, uri="https://google.com"), - glm.CitationSource(start_index=3, end_index=21, uri="https://google.com"), + protos.CitationSource( + start_index=55, end_index=92, uri="https://google.com" + ), + protos.CitationSource( + start_index=3, end_index=21, uri="https://google.com" + ), ] ), finish_reason="STOP", @@ -213,17 +228,17 @@ def test_join_candidates(self): def test_join_prompt_feedbacks(self): feedbacks = [ - glm.GenerateContentResponse.PromptFeedback( + protos.GenerateContentResponse.PromptFeedback( block_reason="SAFETY", safety_ratings=[ - glm.SafetyRating(category="HARM_CATEGORY_DANGEROUS", probability="LOW"), + protos.SafetyRating(category="HARM_CATEGORY_DANGEROUS", probability="LOW"), ], ), - glm.GenerateContentResponse.PromptFeedback(), - glm.GenerateContentResponse.PromptFeedback(), - glm.GenerateContentResponse.PromptFeedback( + protos.GenerateContentResponse.PromptFeedback(), + protos.GenerateContentResponse.PromptFeedback(), + protos.GenerateContentResponse.PromptFeedback( safety_ratings=[ - glm.SafetyRating(category="HARM_CATEGORY_MEDICAL", probability="HIGH"), + protos.SafetyRating(category="HARM_CATEGORY_MEDICAL", probability="HIGH"), ] ), ] @@ -396,23 +411,23 @@ def test_join_prompt_feedbacks(self): ] def test_join_candidates(self): - candidate_lists = [[glm.Candidate(c) for c in cl] for cl in self.CANDIDATE_LISTS] + candidate_lists = [[protos.Candidate(c) for c in cl] for cl in self.CANDIDATE_LISTS] result = generation_types._join_candidate_lists(candidate_lists) self.assertEqual(self.MERGED_CANDIDATES, [type(r).to_dict(r) for r in result]) def test_join_chunks(self): - chunks = [glm.GenerateContentResponse(candidates=cl) for cl in self.CANDIDATE_LISTS] + chunks = [protos.GenerateContentResponse(candidates=cl) for cl in self.CANDIDATE_LISTS] - chunks[0].prompt_feedback = glm.GenerateContentResponse.PromptFeedback( + chunks[0].prompt_feedback = protos.GenerateContentResponse.PromptFeedback( block_reason="SAFETY", safety_ratings=[ - glm.SafetyRating(category="HARM_CATEGORY_DANGEROUS", probability="LOW"), + protos.SafetyRating(category="HARM_CATEGORY_DANGEROUS", probability="LOW"), ], ) result = generation_types._join_chunks(chunks) - expected = glm.GenerateContentResponse( + expected = protos.GenerateContentResponse( { "candidates": self.MERGED_CANDIDATES, "prompt_feedback": { @@ -431,7 +446,7 @@ def test_join_chunks(self): self.assertEqual(type(expected).to_dict(expected), type(result).to_dict(expected)) def test_generate_content_response_iterator_end_to_end(self): - chunks = [glm.GenerateContentResponse(candidates=cl) for cl in self.CANDIDATE_LISTS] + chunks = [protos.GenerateContentResponse(candidates=cl) for cl in self.CANDIDATE_LISTS] merged = generation_types._join_chunks(chunks) response = generation_types.GenerateContentResponse.from_iterator(iter(chunks)) @@ -453,7 +468,7 @@ def test_generate_content_response_iterator_end_to_end(self): def test_generate_content_response_multiple_iterators(self): chunks = [ - glm.GenerateContentResponse({"candidates": [{"content": {"parts": [{"text": a}]}}]}) + protos.GenerateContentResponse({"candidates": [{"content": {"parts": [{"text": a}]}}]}) for a in string.ascii_lowercase ] response = generation_types.GenerateContentResponse.from_iterator(iter(chunks)) @@ -483,7 +498,7 @@ def test_generate_content_response_multiple_iterators(self): def test_generate_content_response_resolve(self): chunks = [ - glm.GenerateContentResponse({"candidates": [{"content": {"parts": [{"text": a}]}}]}) + protos.GenerateContentResponse({"candidates": [{"content": {"parts": [{"text": a}]}}]}) for a in "abcd" ] response = generation_types.GenerateContentResponse.from_iterator(iter(chunks)) @@ -497,7 +512,7 @@ def test_generate_content_response_resolve(self): self.assertEqual(response.candidates[0].content.parts[0].text, "abcd") def test_generate_content_response_from_response(self): - raw_response = glm.GenerateContentResponse( + raw_response = protos.GenerateContentResponse( {"candidates": [{"content": {"parts": [{"text": "Hello world!"}]}}]} ) response = generation_types.GenerateContentResponse.from_response(raw_response) @@ -511,7 +526,7 @@ def test_generate_content_response_from_response(self): ) def test_repr_for_generate_content_response_from_response(self): - raw_response = glm.GenerateContentResponse( + raw_response = protos.GenerateContentResponse( {"candidates": [{"content": {"parts": [{"text": "Hello world!"}]}}]} ) response = generation_types.GenerateContentResponse.from_response(raw_response) @@ -523,7 +538,7 @@ def test_repr_for_generate_content_response_from_response(self): GenerateContentResponse( done=True, iterator=None, - result=glm.GenerateContentResponse({ + result=protos.GenerateContentResponse({ "candidates": [ { "content": { @@ -542,7 +557,7 @@ def test_repr_for_generate_content_response_from_response(self): def test_repr_for_generate_content_response_from_iterator(self): chunks = [ - glm.GenerateContentResponse({"candidates": [{"content": {"parts": [{"text": a}]}}]}) + protos.GenerateContentResponse({"candidates": [{"content": {"parts": [{"text": a}]}}]}) for a in "abcd" ] response = generation_types.GenerateContentResponse.from_iterator(iter(chunks)) @@ -554,7 +569,7 @@ def test_repr_for_generate_content_response_from_iterator(self): GenerateContentResponse( done=False, iterator=, - result=glm.GenerateContentResponse({ + result=protos.GenerateContentResponse({ "candidates": [ { "content": { @@ -573,35 +588,35 @@ def test_repr_for_generate_content_response_from_iterator(self): @parameterized.named_parameters( [ - "glm.Schema", - glm.Schema(type="STRING"), - glm.Schema(type="STRING"), + "protos.Schema", + protos.Schema(type="STRING"), + protos.Schema(type="STRING"), ], [ "SchemaDict", {"type": "STRING"}, - glm.Schema(type="STRING"), + protos.Schema(type="STRING"), ], [ "str", str, - glm.Schema(type="STRING"), + protos.Schema(type="STRING"), ], - ["list_of_str", list[str], glm.Schema(type="ARRAY", items=glm.Schema(type="STRING"))], + ["list_of_str", list[str], protos.Schema(type="ARRAY", items=protos.Schema(type="STRING"))], [ "fancy", Person, - glm.Schema( + protos.Schema( type="OBJECT", properties=dict( - name=glm.Schema(type="STRING"), - favorite_color=glm.Schema(type="STRING"), - birthday=glm.Schema( + name=protos.Schema(type="STRING"), + favorite_color=protos.Schema(type="STRING"), + birthday=protos.Schema( type="OBJECT", properties=dict( - day=glm.Schema(type="INTEGER"), - month=glm.Schema(type="INTEGER"), - year=glm.Schema(type="INTEGER"), + day=protos.Schema(type="INTEGER"), + month=protos.Schema(type="INTEGER"), + year=protos.Schema(type="INTEGER"), ), ), ), diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py index 4a0f86991..0ece77e94 100644 --- a/tests/test_generative_models.py +++ b/tests/test_generative_models.py @@ -7,7 +7,7 @@ import unittest.mock from absl.testing import absltest from absl.testing import parameterized -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import client as client_lib from google.generativeai import generative_models from google.generativeai.types import content_types @@ -23,20 +23,20 @@ TEST_IMAGE_DATA = TEST_IMAGE_PATH.read_bytes() -def noop(x: int): - return x +def simple_part(text: str) -> protos.Content: + return protos.Content({"parts": [{"text": text}]}) -def simple_part(text: str) -> glm.Content: - return glm.Content({"parts": [{"text": text}]}) +def noop(x: int): + return x -def iter_part(texts: Iterable[str]) -> glm.Content: - return glm.Content({"parts": [{"text": t} for t in texts]}) +def iter_part(texts: Iterable[str]) -> protos.Content: + return protos.Content({"parts": [{"text": t} for t in texts]}) -def simple_response(text: str) -> glm.GenerateContentResponse: - return glm.GenerateContentResponse({"candidates": [{"content": simple_part(text)}]}) +def simple_response(text: str) -> protos.GenerateContentResponse: + return protos.GenerateContentResponse({"candidates": [{"content": simple_part(text)}]}) class MockGenerativeServiceClient: @@ -48,10 +48,10 @@ def __init__(self, test): def generate_content( self, - request: glm.GenerateContentRequest, + request: protos.GenerateContentRequest, **kwargs, - ) -> glm.GenerateContentResponse: - self.test.assertIsInstance(request, glm.GenerateContentRequest) + ) -> protos.GenerateContentResponse: + self.test.assertIsInstance(request, protos.GenerateContentRequest) self.observed_requests.append(request) self.observed_kwargs.append(kwargs) response = self.responses["generate_content"].pop(0) @@ -59,9 +59,9 @@ def generate_content( def stream_generate_content( self, - request: glm.GetModelRequest, + request: protos.GetModelRequest, **kwargs, - ) -> Iterable[glm.GenerateContentResponse]: + ) -> Iterable[protos.GenerateContentResponse]: self.observed_requests.append(request) self.observed_kwargs.append(kwargs) response = self.responses["stream_generate_content"].pop(0) @@ -69,9 +69,9 @@ def stream_generate_content( def count_tokens( self, - request: glm.CountTokensRequest, + request: protos.CountTokensRequest, **kwargs, - ) -> Iterable[glm.GenerateContentResponse]: + ) -> Iterable[protos.GenerateContentResponse]: self.observed_requests.append(request) self.observed_kwargs.append(kwargs) response = self.responses["count_tokens"].pop(0) @@ -149,9 +149,9 @@ def test_image(self, content): generation_types.GenerationConfig(temperature=0.5), ], [ - "glm", - glm.GenerationConfig(temperature=0.0), - glm.GenerationConfig(temperature=0.5), + "protos", + protos.GenerationConfig(temperature=0.0), + protos.GenerationConfig(temperature=0.5), ], ) def test_generation_config_overwrite(self, config1, config2): @@ -176,8 +176,8 @@ def test_generation_config_overwrite(self, config1, config2): "list-dict", [ dict( - category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + category=protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=protos.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, ), ], [ @@ -187,15 +187,15 @@ def test_generation_config_overwrite(self, config1, config2): [ "object", [ - glm.SafetySetting( - category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + protos.SafetySetting( + category=protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=protos.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, ), ], [ - glm.SafetySetting( - category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH, + protos.SafetySetting( + category=protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=protos.SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH, ), ], ], @@ -214,22 +214,22 @@ def test_safety_overwrite(self, safe1, safe2): danger = [ s for s in self.observed_requests[-1].safety_settings - if s.category == glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT + if s.category == protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT ] self.assertEqual( danger[0].threshold, - glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + protos.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, ) _ = model.generate_content("hello", safety_settings=safe2) danger = [ s for s in self.observed_requests[-1].safety_settings - if s.category == glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT + if s.category == protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT ] self.assertEqual( danger[0].threshold, - glm.SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH, + protos.SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH, ) def test_stream_basic(self): @@ -263,7 +263,7 @@ def test_stream_lookahead(self): def test_stream_prompt_feedback_blocked(self): chunks = [ - glm.GenerateContentResponse( + protos.GenerateContentResponse( { "prompt_feedback": {"block_reason": "SAFETY"}, } @@ -276,7 +276,7 @@ def test_stream_prompt_feedback_blocked(self): self.assertEqual( response.prompt_feedback.block_reason, - glm.GenerateContentResponse.PromptFeedback.BlockReason.SAFETY, + protos.GenerateContentResponse.PromptFeedback.BlockReason.SAFETY, ) with self.assertRaises(generation_types.BlockedPromptException): @@ -285,20 +285,20 @@ def test_stream_prompt_feedback_blocked(self): def test_stream_prompt_feedback_not_blocked(self): chunks = [ - glm.GenerateContentResponse( + protos.GenerateContentResponse( { "prompt_feedback": { "safety_ratings": [ { - "category": glm.HarmCategory.HARM_CATEGORY_DANGEROUS, - "probability": glm.SafetyRating.HarmProbability.NEGLIGIBLE, + "category": protos.HarmCategory.HARM_CATEGORY_DANGEROUS, + "probability": protos.SafetyRating.HarmProbability.NEGLIGIBLE, } ] }, "candidates": [{"content": {"parts": [{"text": "first"}]}}], } ), - glm.GenerateContentResponse( + protos.GenerateContentResponse( { "candidates": [{"content": {"parts": [{"text": " second"}]}}], } @@ -311,7 +311,7 @@ def test_stream_prompt_feedback_not_blocked(self): self.assertEqual( response.prompt_feedback.safety_ratings[0].category, - glm.HarmCategory.HARM_CATEGORY_DANGEROUS, + protos.HarmCategory.HARM_CATEGORY_DANGEROUS, ) text = "".join(chunk.text for chunk in response) @@ -544,7 +544,7 @@ def no_throw(): def test_chat_prompt_blocked(self): self.responses["generate_content"] = [ - glm.GenerateContentResponse( + protos.GenerateContentResponse( { "prompt_feedback": {"block_reason": "SAFETY"}, } @@ -562,7 +562,7 @@ def test_chat_prompt_blocked(self): def test_chat_candidate_blocked(self): # I feel like chat needs a .last so you can look at the partial results. self.responses["generate_content"] = [ - glm.GenerateContentResponse( + protos.GenerateContentResponse( { "candidates": [{"finish_reason": "SAFETY"}], } @@ -582,7 +582,7 @@ def test_chat_streaming_unexpected_stop(self): simple_response("a"), simple_response("b"), simple_response("c"), - glm.GenerateContentResponse( + protos.GenerateContentResponse( { "candidates": [{"finish_reason": "SAFETY"}], } @@ -669,9 +669,9 @@ def test_tools(self): }, ), dict( - testcase_name="test_glm_FunctionCallingConfig", + testcase_name="test_protos.FunctionCallingConfig", tool_config={ - "function_calling_config": glm.FunctionCallingConfig( + "function_calling_config": protos.FunctionCallingConfig( mode=content_types.FunctionCallingMode.AUTO ) }, @@ -698,9 +698,9 @@ def test_tools(self): }, ), dict( - testcase_name="test_glm_ToolConfig", - tool_config=glm.ToolConfig( - function_calling_config=glm.FunctionCallingConfig( + testcase_name="test_protos.ToolConfig", + tool_config=protos.ToolConfig( + function_calling_config=protos.FunctionCallingConfig( mode=content_types.FunctionCallingMode.NONE ) ), @@ -773,7 +773,7 @@ def test_system_instruction(self, instruction, expected_instr): ) def test_count_tokens_smoke(self, kwargs): si = kwargs.pop("system_instruction", None) - self.responses["count_tokens"] = [glm.CountTokensResponse(total_tokens=7)] + self.responses["count_tokens"] = [protos.CountTokensResponse(total_tokens=7)] model = generative_models.GenerativeModel("gemini-pro-vision", system_instruction=si) response = model.count_tokens(**kwargs) self.assertEqual(type(response).to_dict(response), {"total_tokens": 7}) @@ -840,7 +840,7 @@ def test_repr_for_unary_non_streamed_response(self): GenerateContentResponse( done=True, iterator=None, - result=glm.GenerateContentResponse({ + result=protos.GenerateContentResponse({ "candidates": [ { "content": { @@ -873,7 +873,7 @@ def test_repr_for_streaming_start_to_finish(self): GenerateContentResponse( done=False, iterator=, - result=glm.GenerateContentResponse({ + result=protos.GenerateContentResponse({ "candidates": [ { "content": { @@ -898,7 +898,7 @@ def test_repr_for_streaming_start_to_finish(self): GenerateContentResponse( done=False, iterator=, - result=glm.GenerateContentResponse({ + result=protos.GenerateContentResponse({ "candidates": [ { "content": { @@ -927,7 +927,7 @@ def test_repr_for_streaming_start_to_finish(self): GenerateContentResponse( done=False, iterator=, - result=glm.GenerateContentResponse({ + result=protos.GenerateContentResponse({ "candidates": [ { "content": { @@ -951,7 +951,7 @@ def test_repr_for_streaming_start_to_finish(self): def test_repr_error_info_for_stream_prompt_feedback_blocked(self): # response._error => BlockedPromptException chunks = [ - glm.GenerateContentResponse( + protos.GenerateContentResponse( { "prompt_feedback": {"block_reason": "SAFETY"}, } @@ -969,7 +969,7 @@ def test_repr_error_info_for_stream_prompt_feedback_blocked(self): GenerateContentResponse( done=False, iterator=, - result=glm.GenerateContentResponse({ + result=protos.GenerateContentResponse({ "prompt_feedback": { "block_reason": "SAFETY" } @@ -1019,7 +1019,7 @@ def no_throw(): GenerateContentResponse( done=True, iterator=None, - result=glm.GenerateContentResponse({ + result=protos.GenerateContentResponse({ "candidates": [ { "content": { @@ -1049,7 +1049,7 @@ def test_repr_error_info_for_chat_streaming_unexpected_stop(self): simple_response("a"), simple_response("b"), simple_response("c"), - glm.GenerateContentResponse( + protos.GenerateContentResponse( { "candidates": [{"finish_reason": "SAFETY"}], } @@ -1078,7 +1078,7 @@ def test_repr_error_info_for_chat_streaming_unexpected_stop(self): GenerateContentResponse( done=True, iterator=None, - result=glm.GenerateContentResponse({ + result=protos.GenerateContentResponse({ "candidates": [ { "content": { @@ -1141,7 +1141,7 @@ def test_repr_for_multi_turn_chat(self): tools=None, system_instruction=None, ), - history=[glm.Content({'parts': [{'text': 'I really like fantasy books.'}], 'role': 'user'}), glm.Content({'parts': [{'text': 'first'}], 'role': 'model'}), glm.Content({'parts': [{'text': 'I also like this image.'}, {'inline_data': {'data': 'iVBORw0KGgoA...AAElFTkSuQmCC', 'mime_type': 'image/png'}}], 'role': 'user'}), glm.Content({'parts': [{'text': 'second'}], 'role': 'model'}), glm.Content({'parts': [{'text': 'What things do I like?.'}], 'role': 'user'}), glm.Content({'parts': [{'text': 'third'}], 'role': 'model'})] + history=[protos.Content({'parts': [{'text': 'I really like fantasy books.'}], 'role': 'user'}), protos.Content({'parts': [{'text': 'first'}], 'role': 'model'}), protos.Content({'parts': [{'text': 'I also like this image.'}, {'inline_data': {'data': 'iVBORw0KGgoA...AAElFTkSuQmCC', 'mime_type': 'image/png'}}], 'role': 'user'}), protos.Content({'parts': [{'text': 'second'}], 'role': 'model'}), protos.Content({'parts': [{'text': 'What things do I like?.'}], 'role': 'user'}), protos.Content({'parts': [{'text': 'third'}], 'role': 'model'})] )""" ) self.assertEqual(expected, result) @@ -1169,7 +1169,7 @@ def test_repr_for_incomplete_streaming_chat(self): tools=None, system_instruction=None, ), - history=[glm.Content({'parts': [{'text': 'I really like fantasy books.'}], 'role': 'user'}), ] + history=[protos.Content({'parts': [{'text': 'I really like fantasy books.'}], 'role': 'user'}), ] )""" ) self.assertEqual(expected, result) @@ -1185,7 +1185,7 @@ def test_repr_for_broken_streaming_chat(self): for chunk in [ simple_response("first"), # FinishReason.SAFETY = 3 - glm.GenerateContentResponse( + protos.GenerateContentResponse( { "candidates": [ {"finish_reason": 3, "content": {"parts": [{"text": "second"}]}} @@ -1213,7 +1213,7 @@ def test_repr_for_broken_streaming_chat(self): tools=None, system_instruction=None, ), - history=[glm.Content({'parts': [{'text': 'I really like fantasy books.'}], 'role': 'user'}), ] + history=[protos.Content({'parts': [{'text': 'I really like fantasy books.'}], 'role': 'user'}), ] )""" ) self.assertEqual(expected, result) @@ -1224,7 +1224,7 @@ def test_repr_for_system_instruction(self): self.assertIn("system_instruction='Be excellent.'", result) def test_count_tokens_called_with_request_options(self): - self.responses["count_tokens"].append(glm.CountTokensResponse()) + self.responses["count_tokens"].append(protos.CountTokensResponse(total_tokens=7)) request_options = {"timeout": 120} model = generative_models.GenerativeModel("gemini-pro-vision") @@ -1234,7 +1234,7 @@ def test_count_tokens_called_with_request_options(self): def test_chat_with_request_options(self): self.responses["generate_content"].append( - glm.GenerateContentResponse( + protos.GenerateContentResponse( { "candidates": [{"finish_reason": "STOP"}], } diff --git a/tests/test_generative_models_async.py b/tests/test_generative_models_async.py index 2c465d1d3..03055ffb3 100644 --- a/tests/test_generative_models_async.py +++ b/tests/test_generative_models_async.py @@ -24,14 +24,16 @@ from google.generativeai import client as client_lib from google.generativeai import generative_models from google.generativeai.types import content_types -import google.ai.generativelanguage as glm +from google.generativeai import protos from absl.testing import absltest from absl.testing import parameterized -def simple_response(text: str) -> glm.GenerateContentResponse: - return glm.GenerateContentResponse({"candidates": [{"content": {"parts": [{"text": text}]}}]}) +def simple_response(text: str) -> protos.GenerateContentResponse: + return protos.GenerateContentResponse( + {"candidates": [{"content": {"parts": [{"text": text}]}}]} + ) class AsyncTests(parameterized.TestCase, unittest.IsolatedAsyncioTestCase): @@ -50,28 +52,28 @@ def add_client_method(f): @add_client_method async def generate_content( - request: glm.GenerateContentRequest, + request: protos.GenerateContentRequest, **kwargs, - ) -> glm.GenerateContentResponse: - self.assertIsInstance(request, glm.GenerateContentRequest) + ) -> protos.GenerateContentResponse: + self.assertIsInstance(request, protos.GenerateContentRequest) self.observed_requests.append(request) response = self.responses["generate_content"].pop(0) return response @add_client_method async def stream_generate_content( - request: glm.GetModelRequest, + request: protos.GetModelRequest, **kwargs, - ) -> Iterable[glm.GenerateContentResponse]: + ) -> Iterable[protos.GenerateContentResponse]: self.observed_requests.append(request) response = self.responses["stream_generate_content"].pop(0) return response @add_client_method async def count_tokens( - request: glm.CountTokensRequest, + request: protos.CountTokensRequest, **kwargs, - ) -> Iterable[glm.GenerateContentResponse]: + ) -> Iterable[protos.GenerateContentResponse]: self.observed_requests.append(request) response = self.responses["count_tokens"].pop(0) return response @@ -140,9 +142,9 @@ async def responses(): }, ), dict( - testcase_name="test_glm_FunctionCallingConfig", + testcase_name="test_protos.FunctionCallingConfig", tool_config={ - "function_calling_config": glm.FunctionCallingConfig( + "function_calling_config": protos.FunctionCallingConfig( mode=content_types.FunctionCallingMode.AUTO ) }, @@ -169,9 +171,9 @@ async def responses(): }, ), dict( - testcase_name="test_glm_ToolConfig", - tool_config=glm.ToolConfig( - function_calling_config=glm.FunctionCallingConfig( + testcase_name="test_protos.ToolConfig", + tool_config=protos.ToolConfig( + function_calling_config=protos.FunctionCallingConfig( mode=content_types.FunctionCallingMode.NONE ) ), @@ -211,7 +213,7 @@ async def test_tool_config(self, tool_config, expected_tool_config): ["contents", [{"role": "user", "parts": ["hello"]}]], ) async def test_count_tokens_smoke(self, contents): - self.responses["count_tokens"] = [glm.CountTokensResponse(total_tokens=7)] + self.responses["count_tokens"] = [protos.CountTokensResponse(total_tokens=7)] model = generative_models.GenerativeModel("gemini-pro-vision") response = await model.count_tokens_async(contents) self.assertEqual(type(response).to_dict(response), {"total_tokens": 7}) diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 0c2de7f29..f060caf88 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -19,7 +19,7 @@ from absl.testing import parameterized -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import client from google.generativeai import models @@ -35,15 +35,15 @@ def __init__(self, test): def get_model( self, - request: Union[glm.GetModelRequest, None] = None, + request: Union[protos.GetModelRequest, None] = None, *, name=None, timeout=None, retry=None - ) -> glm.Model: + ) -> protos.Model: if request is None: - request = glm.GetModelRequest(name=name) - self.test.assertIsInstance(request, glm.GetModelRequest) + request = protos.GetModelRequest(name=name) + self.test.assertIsInstance(request, protos.GetModelRequest) self.test.observed_requests.append(request) self.test.observed_timeout.append(timeout) self.test.observed_retry.append(retry) @@ -75,7 +75,7 @@ def setUp(self): ], ) def test_get_model(self, request_options, expected_timeout, expected_retry): - self.responses = {"get_model": glm.Model(name="models/fake-bison-001")} + self.responses = {"get_model": protos.Model(name="models/fake-bison-001")} _ = models.get_model("models/fake-bison-001", request_options=request_options) diff --git a/tests/test_models.py b/tests/test_models.py index f39ed3a2c..23f80913a 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -25,7 +25,7 @@ from absl.testing import absltest from absl.testing import parameterized -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.api_core import operation from google.generativeai import models @@ -45,7 +45,7 @@ def setUp(self): client._client_manager.clients["model"] = self.client # TODO(markdaoust): Check if typechecking works better if wee define this as a - # subclass of `glm.ModelServiceClient`, would pyi files for `glm` help? + # subclass of `glm.ModelServiceClient`, would pyi files for `glm`. help? def add_client_method(f): name = f.__name__ setattr(self.client, name, f) @@ -55,63 +55,65 @@ def add_client_method(f): self.responses = {} @add_client_method - def get_model(request: Union[glm.GetModelRequest, None] = None, *, name=None) -> glm.Model: + def get_model( + request: Union[protos.GetModelRequest, None] = None, *, name=None + ) -> protos.Model: if request is None: - request = glm.GetModelRequest(name=name) - self.assertIsInstance(request, glm.GetModelRequest) + request = protos.GetModelRequest(name=name) + self.assertIsInstance(request, protos.GetModelRequest) self.observed_requests.append(request) response = copy.copy(self.responses["get_model"]) return response @add_client_method def get_tuned_model( - request: Union[glm.GetTunedModelRequest, None] = None, + request: Union[protos.GetTunedModelRequest, None] = None, *, name=None, **kwargs, - ) -> glm.TunedModel: + ) -> protos.TunedModel: if request is None: - request = glm.GetTunedModelRequest(name=name) - self.assertIsInstance(request, glm.GetTunedModelRequest) + request = protos.GetTunedModelRequest(name=name) + self.assertIsInstance(request, protos.GetTunedModelRequest) self.observed_requests.append(request) response = copy.copy(self.responses["get_tuned_model"]) return response @add_client_method def list_models( - request: Union[glm.ListModelsRequest, None] = None, + request: Union[protos.ListModelsRequest, None] = None, *, page_size=None, page_token=None, **kwargs, - ) -> glm.ListModelsResponse: + ) -> protos.ListModelsResponse: if request is None: - request = glm.ListModelsRequest(page_size=page_size, page_token=page_token) - self.assertIsInstance(request, glm.ListModelsRequest) + request = protos.ListModelsRequest(page_size=page_size, page_token=page_token) + self.assertIsInstance(request, protos.ListModelsRequest) self.observed_requests.append(request) response = self.responses["list_models"] return (item for item in response) @add_client_method def list_tuned_models( - request: glm.ListTunedModelsRequest = None, + request: protos.ListTunedModelsRequest = None, *, page_size=None, page_token=None, **kwargs, - ) -> Iterable[glm.TunedModel]: + ) -> Iterable[protos.TunedModel]: if request is None: - request = glm.ListTunedModelsRequest(page_size=page_size, page_token=page_token) - self.assertIsInstance(request, glm.ListTunedModelsRequest) + request = protos.ListTunedModelsRequest(page_size=page_size, page_token=page_token) + self.assertIsInstance(request, protos.ListTunedModelsRequest) self.observed_requests.append(request) response = self.responses["list_tuned_models"] return (item for item in response) @add_client_method def update_tuned_model( - request: glm.UpdateTunedModelRequest, + request: protos.UpdateTunedModelRequest, **kwargs, - ) -> glm.TunedModel: + ) -> protos.TunedModel: self.observed_requests.append(request) response = self.responses.get("update_tuned_model", None) if response is None: @@ -120,7 +122,7 @@ def update_tuned_model( @add_client_method def delete_tuned_model(name): - request = glm.DeleteTunedModelRequest(name=name) + request = protos.DeleteTunedModelRequest(name=name) self.observed_requests.append(request) response = True return response @@ -130,26 +132,26 @@ def create_tuned_model( request, **kwargs, ): - request = glm.CreateTunedModelRequest(request) + request = protos.CreateTunedModelRequest(request) self.observed_requests.append(request) return self.responses["create_tuned_model"] def test_decode_tuned_model_time_round_trip(self): example_dt = datetime.datetime(2000, 1, 2, 3, 4, 5, 600_000, pytz.UTC) - tuned_model = glm.TunedModel(name="tunedModels/house-mouse-001", create_time=example_dt) + tuned_model = protos.TunedModel(name="tunedModels/house-mouse-001", create_time=example_dt) tuned_model = model_types.decode_tuned_model(tuned_model) self.assertEqual(tuned_model.create_time, example_dt) @parameterized.named_parameters( ["simple", "models/fake-bison-001"], ["simple-tuned", "tunedModels/my-pig-001"], - ["model-instance", glm.Model(name="models/fake-bison-001")], - ["tuned-model-instance", glm.TunedModel(name="tunedModels/my-pig-001")], + ["model-instance", protos.Model(name="models/fake-bison-001")], + ["tuned-model-instance", protos.TunedModel(name="tunedModels/my-pig-001")], ) def test_get_model(self, name): self.responses = { - "get_model": glm.Model(name="models/fake-bison-001"), - "get_tuned_model": glm.TunedModel(name="tunedModels/my-pig-001"), + "get_model": protos.Model(name="models/fake-bison-001"), + "get_tuned_model": protos.TunedModel(name="tunedModels/my-pig-001"), } model = models.get_model(name) @@ -160,7 +162,7 @@ def test_get_model(self, name): @parameterized.named_parameters( ["simple", "mystery-bison-001"], - ["model-instance", glm.Model(name="how?-bison-001")], + ["model-instance", protos.Model(name="how?-bison-001")], ) def test_fail_with_unscoped_model_name(self, name): with self.assertRaises(ValueError): @@ -170,9 +172,9 @@ def test_list_models(self): # The low level lib wraps the response in an iterable, so this is a fair test. self.responses = { "list_models": [ - glm.Model(name="models/fake-bison-001"), - glm.Model(name="models/fake-bison-002"), - glm.Model(name="models/fake-bison-003"), + protos.Model(name="models/fake-bison-001"), + protos.Model(name="models/fake-bison-002"), + protos.Model(name="models/fake-bison-003"), ] } @@ -185,9 +187,9 @@ def test_list_tuned_models(self): self.responses = { # The low level lib wraps the response in an iterable, so this is a fair test. "list_tuned_models": [ - glm.TunedModel(name="tunedModels/my-pig-001"), - glm.TunedModel(name="tunedModels/my-pig-002"), - glm.TunedModel(name="tunedModels/my-pig-003"), + protos.TunedModel(name="tunedModels/my-pig-001"), + protos.TunedModel(name="tunedModels/my-pig-002"), + protos.TunedModel(name="tunedModels/my-pig-003"), ] } found_models = list(models.list_tuned_models()) @@ -197,8 +199,8 @@ def test_list_tuned_models(self): @parameterized.named_parameters( [ - "edited-glm-model", - glm.TunedModel( + "edited-protos.model", + protos.TunedModel( name="tunedModels/my-pig-001", description="Trained on my data", ), @@ -211,7 +213,7 @@ def test_list_tuned_models(self): ], ) def test_update_tuned_model_basics(self, tuned_model, updates): - self.responses["get_tuned_model"] = glm.TunedModel(name="tunedModels/my-pig-001") + self.responses["get_tuned_model"] = protos.TunedModel(name="tunedModels/my-pig-001") # No self.responses['update_tuned_model'] the mock just returns the input. updated_model = models.update_tuned_model(tuned_model, updates) updated_model.description = "Trained on my data" @@ -227,7 +229,7 @@ def test_update_tuned_model_basics(self, tuned_model, updates): ], ) def test_update_tuned_model_nested_fields(self, updates): - self.responses["get_tuned_model"] = glm.TunedModel( + self.responses["get_tuned_model"] = protos.TunedModel( name="tunedModels/my-pig-001", base_model="models/dance-monkey-007" ) @@ -250,8 +252,8 @@ def test_update_tuned_model_nested_fields(self, updates): @parameterized.named_parameters( ["name", "tunedModels/bipedal-pangolin-223"], [ - "glm.TunedModel", - glm.TunedModel(name="tunedModels/bipedal-pangolin-223"), + "protos.TunedModel", + protos.TunedModel(name="tunedModels/bipedal-pangolin-223"), ], [ "models.TunedModel", @@ -275,23 +277,23 @@ def test_decode_micros(self, time_str, micros): self.assertEqual(time["time"].microsecond, micros) def test_decode_tuned_model(self): - out_fields = glm.TunedModel( - state=glm.TunedModel.State.CREATING, + out_fields = protos.TunedModel( + state=protos.TunedModel.State.CREATING, create_time="2000-01-01T01:01:01.0Z", update_time="2001-01-01T01:01:01.0Z", - tuning_task=glm.TuningTask( - hyperparameters=glm.Hyperparameters( + tuning_task=protos.TuningTask( + hyperparameters=protos.Hyperparameters( batch_size=72, epoch_count=1, learning_rate=0.1 ), start_time="2002-01-01T01:01:01.0Z", complete_time="2003-01-01T01:01:01.0Z", snapshots=[ - glm.TuningSnapshot( + protos.TuningSnapshot( step=1, epoch=1, compute_time="2004-01-01T01:01:01.0Z", ), - glm.TuningSnapshot( + protos.TuningSnapshot( step=2, epoch=1, compute_time="2005-01-01T01:01:01.0Z", @@ -301,7 +303,7 @@ def test_decode_tuned_model(self): ) decoded = model_types.decode_tuned_model(out_fields) - self.assertEqual(decoded.state, glm.TunedModel.State.CREATING) + self.assertEqual(decoded.state, protos.TunedModel.State.CREATING) self.assertEqual(decoded.create_time.year, 2000) self.assertEqual(decoded.update_time.year, 2001) self.assertIsInstance(decoded.tuning_task.hyperparameters, model_types.Hyperparameters) @@ -314,10 +316,10 @@ def test_decode_tuned_model(self): self.assertEqual(decoded.tuning_task.snapshots[1]["compute_time"].year, 2005) @parameterized.named_parameters( - ["simple", glm.TunedModel(base_model="models/swim-fish-000")], + ["simple", protos.TunedModel(base_model="models/swim-fish-000")], [ "nested", - glm.TunedModel( + protos.TunedModel( tuned_model_source={ "tuned_model": "tunedModels/hidden-fish-55", "base_model": "models/swim-fish-000", @@ -341,7 +343,7 @@ def test_smoke_create_tuned_model(self): training_data=[ ("in", "out"), {"text_input": "in", "output": "out"}, - glm.TuningExample(text_input="in", output="out"), + protos.TuningExample(text_input="in", output="out"), ], ) req = self.observed_requests[-1] @@ -351,10 +353,10 @@ def test_smoke_create_tuned_model(self): self.assertLen(req.tuned_model.tuning_task.training_data.examples.examples, 3) @parameterized.named_parameters( - ["simple", glm.TunedModel(base_model="models/swim-fish-000")], + ["simple", protos.TunedModel(base_model="models/swim-fish-000")], [ "nested", - glm.TunedModel( + protos.TunedModel( tuned_model_source={ "tuned_model": "tunedModels/hidden-fish-55", "base_model": "models/swim-fish-000", @@ -380,9 +382,9 @@ def test_create_tuned_model_on_tuned_model(self, tuned_source): @parameterized.named_parameters( [ - "glm", - glm.Dataset( - examples=glm.TuningExamples( + "protos", + protos.Dataset( + examples=protos.TuningExamples( examples=[ {"text_input": "a", "output": "1"}, {"text_input": "b", "output": "2"}, @@ -396,7 +398,7 @@ def test_create_tuned_model_on_tuned_model(self, tuned_source): [ ("a", "1"), {"text_input": "b", "output": "2"}, - glm.TuningExample({"text_input": "c", "output": "3"}), + protos.TuningExample({"text_input": "c", "output": "3"}), ], ], ["dict", {"text_input": ["a", "b", "c"], "output": ["1", "2", "3"]}], @@ -445,8 +447,8 @@ def test_create_tuned_model_on_tuned_model(self, tuned_source): def test_create_dataset(self, data, ik="text_input", ok="output"): ds = model_types.encode_tuning_data(data, input_key=ik, output_key=ok) - expect = glm.Dataset( - examples=glm.TuningExamples( + expect = protos.Dataset( + examples=protos.TuningExamples( examples=[ {"text_input": "a", "output": "1"}, {"text_input": "b", "output": "2"}, @@ -502,7 +504,7 @@ def test_update_tuned_model_called_with_request_options(self): self.client.update_tuned_model = unittest.mock.MagicMock() request = unittest.mock.ANY request_options = {"timeout": 120} - self.responses["get_tuned_model"] = glm.TunedModel(name="tunedModels/") + self.responses["get_tuned_model"] = protos.TunedModel(name="tunedModels/") try: models.update_tuned_model( @@ -534,7 +536,7 @@ def test_create_tuned_model_called_with_request_options(self): training_data=[ ("in", "out"), {"text_input": "in", "output": "out"}, - glm.TuningExample(text_input="in", output="out"), + protos.TuningExample(text_input="in", output="out"), ], request_options=request_options, ) diff --git a/tests/test_operations.py b/tests/test_operations.py index 80262db88..6529b77e5 100644 --- a/tests/test_operations.py +++ b/tests/test_operations.py @@ -16,7 +16,7 @@ from contextlib import redirect_stderr import io -import google.ai.generativelanguage as glm +from google.generativeai import protos import google.protobuf.any_pb2 import google.generativeai.operations as genai_operation @@ -41,7 +41,7 @@ def test_end_to_end(self): # `Any` takes a type name and a serialized proto. metadata = google.protobuf.any_pb2.Any( type_url=self.metadata_type, - value=glm.CreateTunedModelMetadata(tuned_model=name)._pb.SerializeToString(), + value=protos.CreateTunedModelMetadata(tuned_model=name)._pb.SerializeToString(), ) # Initially the `Operation` is not `done`, so it only gives a metadata. @@ -58,7 +58,7 @@ def test_end_to_end(self): metadata=metadata, response=google.protobuf.any_pb2.Any( type_url=self.result_type, - value=glm.TunedModel(name=name)._pb.SerializeToString(), + value=protos.TunedModel(name=name)._pb.SerializeToString(), ), ) @@ -72,8 +72,8 @@ def refresh(*_, **__): operation=initial_pb, refresh=refresh, cancel=lambda: print(f"cancel!"), - result_type=glm.TunedModel, - metadata_type=glm.CreateTunedModelMetadata, + result_type=protos.TunedModel, + metadata_type=protos.CreateTunedModelMetadata, ) # Use our wrapper instead. @@ -99,7 +99,7 @@ def gen_operations(): def make_metadata(completed_steps): return google.protobuf.any_pb2.Any( type_url=self.metadata_type, - value=glm.CreateTunedModelMetadata( + value=protos.CreateTunedModelMetadata( tuned_model=name, total_steps=total_steps, completed_steps=completed_steps, @@ -122,7 +122,7 @@ def make_metadata(completed_steps): metadata=make_metadata(total_steps), response=google.protobuf.any_pb2.Any( type_url=self.result_type, - value=glm.TunedModel(name=name)._pb.SerializeToString(), + value=protos.TunedModel(name=name)._pb.SerializeToString(), ), ) @@ -142,8 +142,8 @@ def refresh(*_, **__): operation=initial_pb, refresh=refresh, cancel=None, - result_type=glm.TunedModel, - metadata_type=glm.CreateTunedModelMetadata, + result_type=protos.TunedModel, + metadata_type=protos.CreateTunedModelMetadata, ) # Use our wrapper instead. diff --git a/tests/test_permission.py b/tests/test_permission.py index 55ad7a2f0..66b396977 100644 --- a/tests/test_permission.py +++ b/tests/test_permission.py @@ -17,7 +17,7 @@ import unittest import unittest.mock as mock -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import retriever from google.generativeai import permission @@ -50,11 +50,11 @@ def add_client_method(f): @add_client_method def create_corpus( - request: glm.CreateCorpusRequest, + request: protos.CreateCorpusRequest, **kwargs, - ) -> glm.Corpus: + ) -> protos.Corpus: self.observed_requests.append(request) - return glm.Corpus( + return protos.Corpus( name="corpora/demo-corpus", display_name="demo-corpus", create_time="2000-01-01T01:01:01.123456Z", @@ -63,24 +63,24 @@ def create_corpus( @add_client_method def get_tuned_model( - request: Optional[glm.GetTunedModelRequest] = None, + request: Optional[protos.GetTunedModelRequest] = None, *, name=None, **kwargs, - ) -> glm.TunedModel: + ) -> protos.TunedModel: if request is None: - request = glm.GetTunedModelRequest(name=name) - self.assertIsInstance(request, glm.GetTunedModelRequest) + request = protos.GetTunedModelRequest(name=name) + self.assertIsInstance(request, protos.GetTunedModelRequest) self.observed_requests.append(request) response = copy.copy(self.responses["get_tuned_model"]) return response @add_client_method def create_permission( - request: glm.CreatePermissionRequest, - ) -> glm.Permission: + request: protos.CreatePermissionRequest, + ) -> protos.Permission: self.observed_requests.append(request) - return glm.Permission( + return protos.Permission( name="corpora/demo-corpus/permissions/123456789", role=permission_services.to_role("writer"), grantee_type=permission_services.to_grantee_type("everyone"), @@ -88,17 +88,17 @@ def create_permission( @add_client_method def delete_permission( - request: glm.DeletePermissionRequest, + request: protos.DeletePermissionRequest, ) -> None: self.observed_requests.append(request) return None @add_client_method def get_permission( - request: glm.GetPermissionRequest, - ) -> glm.Permission: + request: protos.GetPermissionRequest, + ) -> protos.Permission: self.observed_requests.append(request) - return glm.Permission( + return protos.Permission( name="corpora/demo-corpus/permissions/123456789", role=permission_services.to_role("writer"), grantee_type=permission_services.to_grantee_type("everyone"), @@ -106,16 +106,16 @@ def get_permission( @add_client_method def list_permissions( - request: glm.ListPermissionsRequest, - ) -> glm.ListPermissionsResponse: + request: protos.ListPermissionsRequest, + ) -> protos.ListPermissionsResponse: self.observed_requests.append(request) return [ - glm.Permission( + protos.Permission( name="corpora/demo-corpus/permissions/123456789", role=permission_services.to_role("writer"), grantee_type=permission_services.to_grantee_type("everyone"), ), - glm.Permission( + protos.Permission( name="corpora/demo-corpus/permissions/987654321", role=permission_services.to_role("reader"), grantee_type=permission_services.to_grantee_type("everyone"), @@ -125,10 +125,10 @@ def list_permissions( @add_client_method def update_permission( - request: glm.UpdatePermissionRequest, - ) -> glm.Permission: + request: protos.UpdatePermissionRequest, + ) -> protos.Permission: self.observed_requests.append(request) - return glm.Permission( + return protos.Permission( name="corpora/demo-corpus/permissions/123456789", role=permission_services.to_role("reader"), grantee_type=permission_services.to_grantee_type("everyone"), @@ -136,16 +136,16 @@ def update_permission( @add_client_method def transfer_ownership( - request: glm.TransferOwnershipRequest, - ) -> glm.TransferOwnershipResponse: + request: protos.TransferOwnershipRequest, + ) -> protos.TransferOwnershipResponse: self.observed_requests.append(request) - return glm.TransferOwnershipResponse() + return protos.TransferOwnershipResponse() def test_create_permission_success(self): x = retriever.create_corpus("demo-corpus") perm = x.permissions.create(role="writer", grantee_type="everyone", email_address=None) self.assertIsInstance(perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.CreatePermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.CreatePermissionRequest) def test_create_permission_failure_email_set_when_grantee_type_is_everyone(self): x = retriever.create_corpus("demo-corpus") @@ -161,14 +161,14 @@ def test_delete_permission(self): x = retriever.create_corpus("demo-corpus") perm = x.permissions.create("writer", "everyone") perm.delete() - self.assertIsInstance(self.observed_requests[-1], glm.DeletePermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.DeletePermissionRequest) def test_get_permission_with_full_name(self): x = retriever.create_corpus("demo-corpus") perm = x.permissions.create("writer", "everyone") fetch_perm = permission.get_permission(name=perm.name) self.assertIsInstance(fetch_perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.GetPermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.GetPermissionRequest) self.assertEqual(fetch_perm, perm) def test_get_permission_with_resource_name_and_id_1(self): @@ -178,7 +178,7 @@ def test_get_permission_with_resource_name_and_id_1(self): resource_name="corpora/demo-corpus", permission_id=123456789 ) self.assertIsInstance(fetch_perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.GetPermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.GetPermissionRequest) self.assertEqual(fetch_perm, perm) def test_get_permission_with_resource_name_name_and_id_2(self): @@ -186,14 +186,14 @@ def test_get_permission_with_resource_name_name_and_id_2(self): resource_name="tunedModels/demo-corpus", permission_id=123456789 ) self.assertIsInstance(fetch_perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.GetPermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.GetPermissionRequest) def test_get_permission_with_resource_type(self): fetch_perm = permission.get_permission( resource_name="demo-model", permission_id=123456789, resource_type="tunedModels" ) self.assertIsInstance(fetch_perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.GetPermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.GetPermissionRequest) @parameterized.named_parameters( dict( @@ -257,14 +257,14 @@ def test_list_permission(self): self.assertEqual(perms[1].email_address, "_") for perm in perms: self.assertIsInstance(perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.ListPermissionsRequest) + self.assertIsInstance(self.observed_requests[-1], protos.ListPermissionsRequest) def test_update_permission_success(self): x = retriever.create_corpus("demo-corpus") perm = x.permissions.create("writer", "everyone") updated_perm = perm.update({"role": permission_services.to_role("reader")}) self.assertIsInstance(updated_perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.UpdatePermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.UpdatePermissionRequest) def test_update_permission_failure_restricted_update_path(self): x = retriever.create_corpus("demo-corpus") @@ -275,12 +275,12 @@ def test_update_permission_failure_restricted_update_path(self): ) def test_transfer_ownership(self): - self.responses["get_tuned_model"] = glm.TunedModel( + self.responses["get_tuned_model"] = protos.TunedModel( name="tunedModels/fake-pig-001", base_model="models/dance-monkey-007" ) x = models.get_tuned_model("tunedModels/fake-pig-001") response = x.permissions.transfer_ownership(email_address="_") - self.assertIsInstance(self.observed_requests[-1], glm.TransferOwnershipRequest) + self.assertIsInstance(self.observed_requests[-1], protos.TransferOwnershipRequest) def test_transfer_ownership_on_corpora(self): x = retriever.create_corpus("demo-corpus") diff --git a/tests/test_permission_async.py b/tests/test_permission_async.py index 165039122..ddc9c22a2 100644 --- a/tests/test_permission_async.py +++ b/tests/test_permission_async.py @@ -17,7 +17,7 @@ import unittest import unittest.mock as mock -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import retriever from google.generativeai import permission @@ -49,11 +49,11 @@ def add_client_method(f): @add_client_method async def create_corpus( - request: glm.CreateCorpusRequest, + request: protos.CreateCorpusRequest, **kwargs, - ) -> glm.Corpus: + ) -> protos.Corpus: self.observed_requests.append(request) - return glm.Corpus( + return protos.Corpus( name="corpora/demo-corpus", display_name="demo-corpus", create_time="2000-01-01T01:01:01.123456Z", @@ -62,24 +62,24 @@ async def create_corpus( @add_client_method def get_tuned_model( - request: Optional[glm.GetTunedModelRequest] = None, + request: Optional[protos.GetTunedModelRequest] = None, *, name=None, **kwargs, - ) -> glm.TunedModel: + ) -> protos.TunedModel: if request is None: - request = glm.GetTunedModelRequest(name=name) - self.assertIsInstance(request, glm.GetTunedModelRequest) + request = protos.GetTunedModelRequest(name=name) + self.assertIsInstance(request, protos.GetTunedModelRequest) self.observed_requests.append(request) response = copy.copy(self.responses["get_tuned_model"]) return response @add_client_method async def create_permission( - request: glm.CreatePermissionRequest, - ) -> glm.Permission: + request: protos.CreatePermissionRequest, + ) -> protos.Permission: self.observed_requests.append(request) - return glm.Permission( + return protos.Permission( name="corpora/demo-corpus/permissions/123456789", role=permission_services.to_role("writer"), grantee_type=permission_services.to_grantee_type("everyone"), @@ -87,17 +87,17 @@ async def create_permission( @add_client_method async def delete_permission( - request: glm.DeletePermissionRequest, + request: protos.DeletePermissionRequest, ) -> None: self.observed_requests.append(request) return None @add_client_method async def get_permission( - request: glm.GetPermissionRequest, - ) -> glm.Permission: + request: protos.GetPermissionRequest, + ) -> protos.Permission: self.observed_requests.append(request) - return glm.Permission( + return protos.Permission( name="corpora/demo-corpus/permissions/123456789", role=permission_services.to_role("writer"), grantee_type=permission_services.to_grantee_type("everyone"), @@ -105,17 +105,17 @@ async def get_permission( @add_client_method async def list_permissions( - request: glm.ListPermissionsRequest, - ) -> glm.ListPermissionsResponse: + request: protos.ListPermissionsRequest, + ) -> protos.ListPermissionsResponse: self.observed_requests.append(request) async def results(): - yield glm.Permission( + yield protos.Permission( name="corpora/demo-corpus/permissions/123456789", role=permission_services.to_role("writer"), grantee_type=permission_services.to_grantee_type("everyone"), ) - yield glm.Permission( + yield protos.Permission( name="corpora/demo-corpus/permissions/987654321", role=permission_services.to_role("reader"), grantee_type=permission_services.to_grantee_type("everyone"), @@ -126,10 +126,10 @@ async def results(): @add_client_method async def update_permission( - request: glm.UpdatePermissionRequest, - ) -> glm.Permission: + request: protos.UpdatePermissionRequest, + ) -> protos.Permission: self.observed_requests.append(request) - return glm.Permission( + return protos.Permission( name="corpora/demo-corpus/permissions/123456789", role=permission_services.to_role("reader"), grantee_type=permission_services.to_grantee_type("everyone"), @@ -137,10 +137,10 @@ async def update_permission( @add_client_method async def transfer_ownership( - request: glm.TransferOwnershipRequest, - ) -> glm.TransferOwnershipResponse: + request: protos.TransferOwnershipRequest, + ) -> protos.TransferOwnershipResponse: self.observed_requests.append(request) - return glm.TransferOwnershipResponse() + return protos.TransferOwnershipResponse() async def test_create_permission_success(self): x = await retriever.create_corpus_async("demo-corpus") @@ -148,7 +148,7 @@ async def test_create_permission_success(self): role="writer", grantee_type="everyone", email_address=None ) self.assertIsInstance(perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.CreatePermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.CreatePermissionRequest) async def test_create_permission_failure_email_set_when_grantee_type_is_everyone(self): x = await retriever.create_corpus_async("demo-corpus") @@ -168,14 +168,14 @@ async def test_delete_permission(self): x = await retriever.create_corpus_async("demo-corpus") perm = await x.permissions.create_async("writer", "everyone") await perm.delete_async() - self.assertIsInstance(self.observed_requests[-1], glm.DeletePermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.DeletePermissionRequest) async def test_get_permission_with_full_name(self): x = await retriever.create_corpus_async("demo-corpus") perm = await x.permissions.create_async("writer", "everyone") fetch_perm = await permission.get_permission_async(name=perm.name) self.assertIsInstance(fetch_perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.GetPermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.GetPermissionRequest) self.assertEqual(fetch_perm, perm) async def test_get_permission_with_resource_name_and_id_1(self): @@ -185,7 +185,7 @@ async def test_get_permission_with_resource_name_and_id_1(self): resource_name="corpora/demo-corpus", permission_id=123456789 ) self.assertIsInstance(fetch_perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.GetPermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.GetPermissionRequest) self.assertEqual(fetch_perm, perm) async def test_get_permission_with_resource_name_name_and_id_2(self): @@ -193,14 +193,14 @@ async def test_get_permission_with_resource_name_name_and_id_2(self): resource_name="tunedModels/demo-corpus", permission_id=123456789 ) self.assertIsInstance(fetch_perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.GetPermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.GetPermissionRequest) async def test_get_permission_with_resource_type(self): fetch_perm = await permission.get_permission_async( resource_name="demo-model", permission_id=123456789, resource_type="tunedModels" ) self.assertIsInstance(fetch_perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.GetPermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.GetPermissionRequest) @parameterized.named_parameters( dict( @@ -264,14 +264,14 @@ async def test_list_permission(self): self.assertEqual(perms[1].email_address, "_") for perm in perms: self.assertIsInstance(perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.ListPermissionsRequest) + self.assertIsInstance(self.observed_requests[-1], protos.ListPermissionsRequest) async def test_update_permission_success(self): x = await retriever.create_corpus_async("demo-corpus") perm = await x.permissions.create_async("writer", "everyone") updated_perm = await perm.update_async({"role": permission_services.to_role("reader")}) self.assertIsInstance(updated_perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.UpdatePermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.UpdatePermissionRequest) async def test_update_permission_failure_restricted_update_path(self): x = await retriever.create_corpus_async("demo-corpus") @@ -282,12 +282,12 @@ async def test_update_permission_failure_restricted_update_path(self): ) async def test_transfer_ownership(self): - self.responses["get_tuned_model"] = glm.TunedModel( + self.responses["get_tuned_model"] = protos.TunedModel( name="tunedModels/fake-pig-001", base_model="models/dance-monkey-007" ) x = models.get_tuned_model("tunedModels/fake-pig-001") response = await x.permissions.transfer_ownership_async(email_address="_") - self.assertIsInstance(self.observed_requests[-1], glm.TransferOwnershipRequest) + self.assertIsInstance(self.observed_requests[-1], protos.TransferOwnershipRequest) async def test_transfer_ownership_on_corpora(self): x = await retriever.create_corpus_async("demo-corpus") diff --git a/tests/test_protos.py b/tests/test_protos.py new file mode 100644 index 000000000..1b59b0c6e --- /dev/null +++ b/tests/test_protos.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pathlib +import re + +from absl.testing import parameterized + +ROOT = pathlib.Path(__file__).parent.parent + + +class UnitTests(parameterized.TestCase): + def test_check_glm_imports(self): + for fpath in ROOT.rglob("*.py"): + if fpath.name == "build_docs.py": + continue + content = fpath.read_text() + for match in re.findall("glm\.\w+", content): + self.assertIn( + "Client", + match, + msg=f"Bad `glm.` usage, use `genai.protos` instead,\n in {fpath}", + ) diff --git a/tests/test_responder.py b/tests/test_responder.py index 4eb310815..c075fc65a 100644 --- a/tests/test_responder.py +++ b/tests/test_responder.py @@ -17,7 +17,7 @@ from absl.testing import absltest from absl.testing import parameterized -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import responder import IPython.display import PIL.Image @@ -42,9 +42,9 @@ class UnitTests(parameterized.TestCase): [ "FunctionLibrary", responder.FunctionLibrary( - tools=glm.Tool( + tools=protos.Tool( function_declarations=[ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time." ) ] @@ -56,7 +56,7 @@ class UnitTests(parameterized.TestCase): [ responder.Tool( function_declarations=[ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time." ) ] @@ -64,11 +64,11 @@ class UnitTests(parameterized.TestCase): ], ], [ - "IterableTool-glm.Tool", + "IterableTool-protos.Tool", [ - glm.Tool( + protos.Tool( function_declarations=[ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time.", ) @@ -93,7 +93,7 @@ class UnitTests(parameterized.TestCase): "IterableTool-IterableFD", [ [ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time.", ) @@ -103,7 +103,7 @@ class UnitTests(parameterized.TestCase): [ "IterableTool-FD", [ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time.", ) @@ -113,17 +113,17 @@ class UnitTests(parameterized.TestCase): "Tool", responder.Tool( function_declarations=[ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time." ) ] ), ], [ - "glm.Tool", - glm.Tool( + "protos.Tool", + protos.Tool( function_declarations=[ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time." ) ] @@ -175,8 +175,8 @@ class UnitTests(parameterized.TestCase): ), ], [ - "glm.FD", - glm.FunctionDeclaration( + "protos.FD", + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time." ), ], @@ -216,32 +216,32 @@ def b(): self.assertLen(tools[0].function_declarations, 2) @parameterized.named_parameters( - ["int", int, glm.Schema(type=glm.Type.INTEGER)], - ["float", float, glm.Schema(type=glm.Type.NUMBER)], - ["str", str, glm.Schema(type=glm.Type.STRING)], + ["int", int, protos.Schema(type=protos.Type.INTEGER)], + ["float", float, protos.Schema(type=protos.Type.NUMBER)], + ["str", str, protos.Schema(type=protos.Type.STRING)], [ "list", list[str], - glm.Schema( - type=glm.Type.ARRAY, - items=glm.Schema(type=glm.Type.STRING), + protos.Schema( + type=protos.Type.ARRAY, + items=protos.Schema(type=protos.Type.STRING), ), ], [ "list-list-int", list[list[int]], - glm.Schema( - type=glm.Type.ARRAY, - items=glm.Schema( - glm.Schema( - type=glm.Type.ARRAY, - items=glm.Schema(type=glm.Type.INTEGER), + protos.Schema( + type=protos.Type.ARRAY, + items=protos.Schema( + protos.Schema( + type=protos.Type.ARRAY, + items=protos.Schema(type=protos.Type.INTEGER), ), ), ), ], - ["dict", dict, glm.Schema(type=glm.Type.OBJECT)], - ["dict-str-any", dict[str, Any], glm.Schema(type=glm.Type.OBJECT)], + ["dict", dict, protos.Schema(type=protos.Type.OBJECT)], + ["dict-str-any", dict[str, Any], protos.Schema(type=protos.Type.OBJECT)], ) def test_auto_schema(self, annotation, expected): def fun(a: annotation): diff --git a/tests/test_retriever.py b/tests/test_retriever.py index 910183789..bce9a402b 100644 --- a/tests/test_retriever.py +++ b/tests/test_retriever.py @@ -16,7 +16,7 @@ import unittest import unittest.mock as mock -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import retriever from google.generativeai import client @@ -42,11 +42,11 @@ def add_client_method(f): @add_client_method def create_corpus( - request: glm.CreateCorpusRequest, + request: protos.CreateCorpusRequest, **kwargs, - ) -> glm.Corpus: + ) -> protos.Corpus: self.observed_requests.append(request) - return glm.Corpus( + return protos.Corpus( name="corpora/demo_corpus", display_name="demo-corpus", create_time="2000-01-01T01:01:01.123456Z", @@ -55,11 +55,11 @@ def create_corpus( @add_client_method def get_corpus( - request: glm.GetCorpusRequest, + request: protos.GetCorpusRequest, **kwargs, - ) -> glm.Corpus: + ) -> protos.Corpus: self.observed_requests.append(request) - return glm.Corpus( + return protos.Corpus( name="corpora/demo_corpus", display_name="demo-corpus", create_time="2000-01-01T01:01:01.123456Z", @@ -68,11 +68,11 @@ def get_corpus( @add_client_method def update_corpus( - request: glm.UpdateCorpusRequest, + request: protos.UpdateCorpusRequest, **kwargs, - ) -> glm.Corpus: + ) -> protos.Corpus: self.observed_requests.append(request) - return glm.Corpus( + return protos.Corpus( name="corpora/demo-corpus", display_name="demo-corpus-1", create_time="2000-01-01T01:01:01.123456Z", @@ -81,18 +81,18 @@ def update_corpus( @add_client_method def list_corpora( - request: glm.ListCorporaRequest, + request: protos.ListCorporaRequest, **kwargs, - ) -> glm.ListCorporaResponse: + ) -> protos.ListCorporaResponse: self.observed_requests.append(request) return [ - glm.Corpus( + protos.Corpus( name="corpora/demo_corpus-1", display_name="demo-corpus-1", create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", ), - glm.Corpus( + protos.Corpus( name="corpora/demo-corpus-2", display_name="demo-corpus-2", create_time="2000-01-01T01:01:01.123456Z", @@ -102,15 +102,15 @@ def list_corpora( @add_client_method def query_corpus( - request: glm.QueryCorpusRequest, + request: protos.QueryCorpusRequest, **kwargs, - ) -> glm.QueryCorpusResponse: + ) -> protos.QueryCorpusResponse: self.observed_requests.append(request) - return glm.QueryCorpusResponse( + return protos.QueryCorpusResponse( relevant_chunks=[ - glm.RelevantChunk( + protos.RelevantChunk( chunk_relevance_score=0.08, - chunk=glm.Chunk( + chunk=protos.Chunk( name="corpora/demo-corpus/documents/demo-doc/chunks/demo-chunk", data={"string_value": "This is a demo chunk."}, custom_metadata=[], @@ -124,18 +124,18 @@ def query_corpus( @add_client_method def delete_corpus( - request: glm.DeleteCorpusRequest, + request: protos.DeleteCorpusRequest, **kwargs, ) -> None: self.observed_requests.append(request) @add_client_method def create_document( - request: glm.CreateDocumentRequest, + request: protos.CreateDocumentRequest, **kwargs, ) -> retriever_service.Document: self.observed_requests.append(request) - return glm.Document( + return protos.Document( name="corpora/demo-corpus/documents/demo-doc", display_name="demo-doc", create_time="2000-01-01T01:01:01.123456Z", @@ -144,11 +144,11 @@ def create_document( @add_client_method def get_document( - request: glm.GetDocumentRequest, + request: protos.GetDocumentRequest, **kwargs, ) -> retriever_service.Document: self.observed_requests.append(request) - return glm.Document( + return protos.Document( name="corpora/demo-corpus/documents/demo_doc", display_name="demo-doc", create_time="2000-01-01T01:01:01.123456Z", @@ -157,11 +157,11 @@ def get_document( @add_client_method def update_document( - request: glm.UpdateDocumentRequest, + request: protos.UpdateDocumentRequest, **kwargs, - ) -> glm.Document: + ) -> protos.Document: self.observed_requests.append(request) - return glm.Document( + return protos.Document( name="corpora/demo-corpus/documents/demo_doc", display_name="demo-doc-1", create_time="2000-01-01T01:01:01.123456Z", @@ -170,18 +170,18 @@ def update_document( @add_client_method def list_documents( - request: glm.ListDocumentsRequest, + request: protos.ListDocumentsRequest, **kwargs, - ) -> glm.ListDocumentsResponse: + ) -> protos.ListDocumentsResponse: self.observed_requests.append(request) return [ - glm.Document( + protos.Document( name="corpora/demo-corpus/documents/demo_doc_1", display_name="demo-doc-1", create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", ), - glm.Document( + protos.Document( name="corpora/demo-corpus/documents/demo_doc_2", display_name="demo-doc-2", create_time="2000-01-01T01:01:01.123456Z", @@ -191,22 +191,22 @@ def list_documents( @add_client_method def delete_document( - request: glm.DeleteDocumentRequest, + request: protos.DeleteDocumentRequest, **kwargs, ) -> None: self.observed_requests.append(request) @add_client_method def query_document( - request: glm.QueryDocumentRequest, + request: protos.QueryDocumentRequest, **kwargs, - ) -> glm.QueryDocumentResponse: + ) -> protos.QueryDocumentResponse: self.observed_requests.append(request) - return glm.QueryDocumentResponse( + return protos.QueryDocumentResponse( relevant_chunks=[ - glm.RelevantChunk( + protos.RelevantChunk( chunk_relevance_score=0.08, - chunk=glm.Chunk( + chunk=protos.Chunk( name="demo-chunk", data={"string_value": "This is a demo chunk."}, custom_metadata=[], @@ -220,11 +220,11 @@ def query_document( @add_client_method def create_chunk( - request: glm.CreateChunkRequest, + request: protos.CreateChunkRequest, **kwargs, ) -> retriever_service.Chunk: self.observed_requests.append(request) - return glm.Chunk( + return protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is a demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -233,19 +233,19 @@ def create_chunk( @add_client_method def batch_create_chunks( - request: glm.BatchCreateChunksRequest, + request: protos.BatchCreateChunksRequest, **kwargs, - ) -> glm.BatchCreateChunksResponse: + ) -> protos.BatchCreateChunksResponse: self.observed_requests.append(request) - return glm.BatchCreateChunksResponse( + return protos.BatchCreateChunksResponse( chunks=[ - glm.Chunk( + protos.Chunk( name="corpora/demo-corpus/documents/demo-doc/chunks/dc", data={"string_value": "This is a demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", ), - glm.Chunk( + protos.Chunk( name="corpora/demo-corpus/documents/demo-doc/chunks/dc1", data={"string_value": "This is another demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -256,11 +256,11 @@ def batch_create_chunks( @add_client_method def get_chunk( - request: glm.GetChunkRequest, + request: protos.GetChunkRequest, **kwargs, ) -> retriever_service.Chunk: self.observed_requests.append(request) - return glm.Chunk( + return protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is a demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -269,18 +269,18 @@ def get_chunk( @add_client_method def list_chunks( - request: glm.ListChunksRequest, + request: protos.ListChunksRequest, **kwargs, - ) -> glm.ListChunksResponse: + ) -> protos.ListChunksResponse: self.observed_requests.append(request) return [ - glm.Chunk( + protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is a demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", ), - glm.Chunk( + protos.Chunk( name="corpora/demo-corpus/documents/demo-doc/chunks/demo-chunk-1", data={"string_value": "This is another demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -290,17 +290,17 @@ def list_chunks( @add_client_method def update_chunk( - request: glm.UpdateChunkRequest, + request: protos.UpdateChunkRequest, **kwargs, - ) -> glm.Chunk: + ) -> protos.Chunk: self.observed_requests.append(request) - return glm.Chunk( + return protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is an updated demo chunk."}, custom_metadata=[ - glm.CustomMetadata( + protos.CustomMetadata( key="tags", - string_list_value=glm.StringList( + string_list_value=protos.StringList( values=["Google For Developers", "Project IDX", "Blog", "Announcement"] ), ) @@ -311,19 +311,19 @@ def update_chunk( @add_client_method def batch_update_chunks( - request: glm.BatchUpdateChunksRequest, + request: protos.BatchUpdateChunksRequest, **kwargs, - ) -> glm.BatchUpdateChunksResponse: + ) -> protos.BatchUpdateChunksResponse: self.observed_requests.append(request) - return glm.BatchUpdateChunksResponse( + return protos.BatchUpdateChunksResponse( chunks=[ - glm.Chunk( + protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is an updated chunk."}, create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", ), - glm.Chunk( + protos.Chunk( name="corpora/demo-corpus/documents/demo-doc/chunks/demo-chunk-1", data={"string_value": "This is another updated chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -334,14 +334,14 @@ def batch_update_chunks( @add_client_method def delete_chunk( - request: glm.DeleteChunkRequest, + request: protos.DeleteChunkRequest, **kwargs, ) -> None: self.observed_requests.append(request) @add_client_method def batch_delete_chunks( - request: glm.BatchDeleteChunksRequest, + request: protos.BatchDeleteChunksRequest, **kwargs, ) -> None: self.observed_requests.append(request) @@ -366,7 +366,7 @@ def test_get_corpus(self, name="demo-corpus"): def test_update_corpus(self): demo_corpus = retriever.create_corpus(name="demo-corpus") update_request = demo_corpus.update(updates={"display_name": "demo-corpus_1"}) - self.assertIsInstance(self.observed_requests[-1], glm.UpdateCorpusRequest) + self.assertIsInstance(self.observed_requests[-1], protos.UpdateCorpusRequest) self.assertEqual("demo-corpus_1", demo_corpus.display_name) def test_list_corpora(self): @@ -402,7 +402,7 @@ def test_delete_corpus(self): demo_corpus = retriever.create_corpus(name="demo-corpus") demo_document = demo_corpus.create_document(name="demo-doc") delete_request = retriever.delete_corpus(name="corpora/demo_corpus", force=True) - self.assertIsInstance(self.observed_requests[-1], glm.DeleteCorpusRequest) + self.assertIsInstance(self.observed_requests[-1], protos.DeleteCorpusRequest) def test_create_document(self, display_name="demo-doc"): demo_corpus = retriever.create_corpus(name="demo-corpus") @@ -433,7 +433,7 @@ def test_delete_document(self): demo_document = demo_corpus.create_document(name="demo-doc") demo_doc2 = demo_corpus.create_document(name="demo-doc-2") delete_request = demo_corpus.delete_document(name="corpora/demo-corpus/documents/demo_doc") - self.assertIsInstance(self.observed_requests[-1], glm.DeleteDocumentRequest) + self.assertIsInstance(self.observed_requests[-1], protos.DeleteDocumentRequest) def test_list_documents(self): demo_corpus = retriever.create_corpus(name="demo-corpus") @@ -521,7 +521,7 @@ def test_batch_create_chunks(self, chunks): demo_corpus = retriever.create_corpus(name="demo-corpus") demo_document = demo_corpus.create_document(name="demo-doc") chunks = demo_document.batch_create_chunks(chunks=chunks) - self.assertIsInstance(self.observed_requests[-1], glm.BatchCreateChunksRequest) + self.assertIsInstance(self.observed_requests[-1], protos.BatchCreateChunksRequest) self.assertEqual("This is a demo chunk.", chunks[0].data.string_value) self.assertEqual("This is another demo chunk.", chunks[1].data.string_value) @@ -548,7 +548,7 @@ def test_list_chunks(self): ) list_req = list(demo_document.list_chunks()) - self.assertIsInstance(self.observed_requests[-1], glm.ListChunksRequest) + self.assertIsInstance(self.observed_requests[-1], protos.ListChunksRequest) self.assertLen(list_req, 2) def test_update_chunk(self): @@ -615,7 +615,7 @@ def test_batch_update_chunks_data_structures(self, updates): data="This is another demo chunk.", ) update_request = demo_document.batch_update_chunks(chunks=updates) - self.assertIsInstance(self.observed_requests[-1], glm.BatchUpdateChunksRequest) + self.assertIsInstance(self.observed_requests[-1], protos.BatchUpdateChunksRequest) self.assertEqual( "This is an updated chunk.", update_request["chunks"][0]["data"]["string_value"] ) @@ -631,7 +631,7 @@ def test_delete_chunk(self): data="This is a demo chunk.", ) delete_request = demo_document.delete_chunk(name="demo-chunk") - self.assertIsInstance(self.observed_requests[-1], glm.DeleteChunkRequest) + self.assertIsInstance(self.observed_requests[-1], protos.DeleteChunkRequest) def test_batch_delete_chunks(self): demo_corpus = retriever.create_corpus(name="demo-corpus") @@ -645,7 +645,7 @@ def test_batch_delete_chunks(self): data="This is another demo chunk.", ) delete_request = demo_document.batch_delete_chunks(chunks=[x.name, y.name]) - self.assertIsInstance(self.observed_requests[-1], glm.BatchDeleteChunksRequest) + self.assertIsInstance(self.observed_requests[-1], protos.BatchDeleteChunksRequest) @parameterized.parameters( {"method": "create_corpus"}, diff --git a/tests/test_retriever_async.py b/tests/test_retriever_async.py index b764c23b2..bb0c862d1 100644 --- a/tests/test_retriever_async.py +++ b/tests/test_retriever_async.py @@ -19,7 +19,7 @@ import unittest import unittest.mock as mock -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import retriever from google.generativeai import client as client_lib @@ -44,11 +44,11 @@ def add_client_method(f): @add_client_method async def create_corpus( - request: glm.CreateCorpusRequest, + request: protos.CreateCorpusRequest, **kwargs, - ) -> glm.Corpus: + ) -> protos.Corpus: self.observed_requests.append(request) - return glm.Corpus( + return protos.Corpus( name="corpora/demo-corpus", display_name="demo-corpus", create_time="2000-01-01T01:01:01.123456Z", @@ -57,11 +57,11 @@ async def create_corpus( @add_client_method async def get_corpus( - request: glm.GetCorpusRequest, + request: protos.GetCorpusRequest, **kwargs, - ) -> glm.Corpus: + ) -> protos.Corpus: self.observed_requests.append(request) - return glm.Corpus( + return protos.Corpus( name="corpora/demo-corpus", display_name="demo-corpus", create_time="2000-01-01T01:01:01.123456Z", @@ -70,11 +70,11 @@ async def get_corpus( @add_client_method async def update_corpus( - request: glm.UpdateCorpusRequest, + request: protos.UpdateCorpusRequest, **kwargs, - ) -> glm.Corpus: + ) -> protos.Corpus: self.observed_requests.append(request) - return glm.Corpus( + return protos.Corpus( name="corpora/demo-corpus", display_name="demo-corpus-1", create_time="2000-01-01T01:01:01.123456Z", @@ -83,19 +83,19 @@ async def update_corpus( @add_client_method async def list_corpora( - request: glm.ListCorporaRequest, + request: protos.ListCorporaRequest, **kwargs, - ) -> glm.ListCorporaResponse: + ) -> protos.ListCorporaResponse: self.observed_requests.append(request) async def results(): - yield glm.Corpus( + yield protos.Corpus( name="corpora/demo-corpus-1", display_name="demo-corpus-1", create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", ) - yield glm.Corpus( + yield protos.Corpus( name="corpora/demo-corpus_2", display_name="demo-corpus-2", create_time="2000-01-01T01:01:01.123456Z", @@ -106,15 +106,15 @@ async def results(): @add_client_method async def query_corpus( - request: glm.QueryCorpusRequest, + request: protos.QueryCorpusRequest, **kwargs, - ) -> glm.QueryCorpusResponse: + ) -> protos.QueryCorpusResponse: self.observed_requests.append(request) - return glm.QueryCorpusResponse( + return protos.QueryCorpusResponse( relevant_chunks=[ - glm.RelevantChunk( + protos.RelevantChunk( chunk_relevance_score=0.08, - chunk=glm.Chunk( + chunk=protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is a demo chunk."}, custom_metadata=[], @@ -128,18 +128,18 @@ async def query_corpus( @add_client_method async def delete_corpus( - request: glm.DeleteCorpusRequest, + request: protos.DeleteCorpusRequest, **kwargs, ) -> None: self.observed_requests.append(request) @add_client_method async def create_document( - request: glm.CreateDocumentRequest, + request: protos.CreateDocumentRequest, **kwargs, ) -> retriever_service.Document: self.observed_requests.append(request) - return glm.Document( + return protos.Document( name="corpora/demo-corpus/documents/demo-doc", display_name="demo-doc", create_time="2000-01-01T01:01:01.123456Z", @@ -148,11 +148,11 @@ async def create_document( @add_client_method async def get_document( - request: glm.GetDocumentRequest, + request: protos.GetDocumentRequest, **kwargs, ) -> retriever_service.Document: self.observed_requests.append(request) - return glm.Document( + return protos.Document( name="corpora/demo-corpus/documents/demo-doc", display_name="demo-doc", create_time="2000-01-01T01:01:01.123456Z", @@ -161,11 +161,11 @@ async def get_document( @add_client_method async def update_document( - request: glm.UpdateDocumentRequest, + request: protos.UpdateDocumentRequest, **kwargs, - ) -> glm.Document: + ) -> protos.Document: self.observed_requests.append(request) - return glm.Document( + return protos.Document( name="corpora/demo-corpus/documents/demo-doc", display_name="demo-doc-1", create_time="2000-01-01T01:01:01.123456Z", @@ -174,19 +174,19 @@ async def update_document( @add_client_method async def list_documents( - request: glm.ListDocumentsRequest, + request: protos.ListDocumentsRequest, **kwargs, - ) -> glm.ListDocumentsResponse: + ) -> protos.ListDocumentsResponse: self.observed_requests.append(request) async def results(): - yield glm.Document( + yield protos.Document( name="corpora/demo-corpus/documents/dem-doc_1", display_name="demo-doc-1", create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", ) - yield glm.Document( + yield protos.Document( name="corpora/demo-corpus/documents/dem-doc_2", display_name="demo-doc_2", create_time="2000-01-01T01:01:01.123456Z", @@ -197,22 +197,22 @@ async def results(): @add_client_method async def delete_document( - request: glm.DeleteDocumentRequest, + request: protos.DeleteDocumentRequest, **kwargs, ) -> None: self.observed_requests.append(request) @add_client_method async def query_document( - request: glm.QueryDocumentRequest, + request: protos.QueryDocumentRequest, **kwargs, - ) -> glm.QueryDocumentResponse: + ) -> protos.QueryDocumentResponse: self.observed_requests.append(request) - return glm.QueryDocumentResponse( + return protos.QueryDocumentResponse( relevant_chunks=[ - glm.RelevantChunk( + protos.RelevantChunk( chunk_relevance_score=0.08, - chunk=glm.Chunk( + chunk=protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is a demo chunk."}, custom_metadata=[], @@ -226,11 +226,11 @@ async def query_document( @add_client_method async def create_chunk( - request: glm.CreateChunkRequest, + request: protos.CreateChunkRequest, **kwargs, ) -> retriever_service.Chunk: self.observed_requests.append(request) - return glm.Chunk( + return protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is a demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -239,19 +239,19 @@ async def create_chunk( @add_client_method async def batch_create_chunks( - request: glm.BatchCreateChunksRequest, + request: protos.BatchCreateChunksRequest, **kwargs, - ) -> glm.BatchCreateChunksResponse: + ) -> protos.BatchCreateChunksResponse: self.observed_requests.append(request) - return glm.BatchCreateChunksResponse( + return protos.BatchCreateChunksResponse( chunks=[ - glm.Chunk( + protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/dc", data={"string_value": "This is a demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", ), - glm.Chunk( + protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/dc1", data={"string_value": "This is another demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -262,11 +262,11 @@ async def batch_create_chunks( @add_client_method async def get_chunk( - request: glm.GetChunkRequest, + request: protos.GetChunkRequest, **kwargs, ) -> retriever_service.Chunk: self.observed_requests.append(request) - return glm.Chunk( + return protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is a demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -275,19 +275,19 @@ async def get_chunk( @add_client_method async def list_chunks( - request: glm.ListChunksRequest, + request: protos.ListChunksRequest, **kwargs, - ) -> glm.ListChunksResponse: + ) -> protos.ListChunksResponse: self.observed_requests.append(request) async def results(): - yield glm.Chunk( + yield protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is a demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", ) - yield glm.Chunk( + yield protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk-1", data={"string_value": "This is another demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -298,11 +298,11 @@ async def results(): @add_client_method async def update_chunk( - request: glm.UpdateChunkRequest, + request: protos.UpdateChunkRequest, **kwargs, - ) -> glm.Chunk: + ) -> protos.Chunk: self.observed_requests.append(request) - return glm.Chunk( + return protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is an updated demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -311,19 +311,19 @@ async def update_chunk( @add_client_method async def batch_update_chunks( - request: glm.BatchUpdateChunksRequest, + request: protos.BatchUpdateChunksRequest, **kwargs, - ) -> glm.BatchUpdateChunksResponse: + ) -> protos.BatchUpdateChunksResponse: self.observed_requests.append(request) - return glm.BatchUpdateChunksResponse( + return protos.BatchUpdateChunksResponse( chunks=[ - glm.Chunk( + protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is an updated chunk."}, create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", ), - glm.Chunk( + protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk-1", data={"string_value": "This is another updated chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -334,14 +334,14 @@ async def batch_update_chunks( @add_client_method async def delete_chunk( - request: glm.DeleteChunkRequest, + request: protos.DeleteChunkRequest, **kwargs, ) -> None: self.observed_requests.append(request) @add_client_method async def batch_delete_chunks( - request: glm.BatchDeleteChunksRequest, + request: protos.BatchDeleteChunksRequest, **kwargs, ) -> None: self.observed_requests.append(request) @@ -398,7 +398,7 @@ async def test_delete_corpus(self): demo_corpus = await retriever.create_corpus_async(name="demo-corpus") demo_document = await demo_corpus.create_document_async(name="demo-doc") delete_request = await retriever.delete_corpus_async(name="corpora/demo-corpus", force=True) - self.assertIsInstance(self.observed_requests[-1], glm.DeleteCorpusRequest) + self.assertIsInstance(self.observed_requests[-1], protos.DeleteCorpusRequest) async def test_create_document(self, display_name="demo-doc"): demo_corpus = await retriever.create_corpus_async(name="demo-corpus") @@ -425,7 +425,7 @@ async def test_delete_document(self): delete_request = await demo_corpus.delete_document_async( name="corpora/demo-corpus/documents/demo-doc" ) - self.assertIsInstance(self.observed_requests[-1], glm.DeleteDocumentRequest) + self.assertIsInstance(self.observed_requests[-1], protos.DeleteDocumentRequest) async def test_list_documents(self): demo_corpus = await retriever.create_corpus_async(name="demo-corpus") @@ -513,7 +513,7 @@ async def test_batch_create_chunks(self, chunks): demo_corpus = await retriever.create_corpus_async(name="demo-corpus") demo_document = await demo_corpus.create_document_async(name="demo-doc") chunks = await demo_document.batch_create_chunks_async(chunks=chunks) - self.assertIsInstance(self.observed_requests[-1], glm.BatchCreateChunksRequest) + self.assertIsInstance(self.observed_requests[-1], protos.BatchCreateChunksRequest) self.assertEqual("This is a demo chunk.", chunks[0].data.string_value) self.assertEqual("This is another demo chunk.", chunks[1].data.string_value) @@ -541,7 +541,7 @@ async def test_list_chunks(self): chunks = [] async for chunk in demo_document.list_chunks_async(): chunks.append(chunk) - self.assertIsInstance(self.observed_requests[-1], glm.ListChunksRequest) + self.assertIsInstance(self.observed_requests[-1], protos.ListChunksRequest) self.assertLen(chunks, 2) async def test_update_chunk(self): @@ -597,7 +597,7 @@ async def test_batch_update_chunks_data_structures(self, updates): data="This is another demo chunk.", ) update_request = await demo_document.batch_update_chunks_async(chunks=updates) - self.assertIsInstance(self.observed_requests[-1], glm.BatchUpdateChunksRequest) + self.assertIsInstance(self.observed_requests[-1], protos.BatchUpdateChunksRequest) self.assertEqual( "This is an updated chunk.", update_request["chunks"][0]["data"]["string_value"] ) @@ -615,7 +615,7 @@ async def test_delete_chunk(self): delete_request = await demo_document.delete_chunk_async( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk" ) - self.assertIsInstance(self.observed_requests[-1], glm.DeleteChunkRequest) + self.assertIsInstance(self.observed_requests[-1], protos.DeleteChunkRequest) async def test_batch_delete_chunks(self): demo_corpus = await retriever.create_corpus_async(name="demo-corpus") @@ -629,7 +629,7 @@ async def test_batch_delete_chunks(self): data="This is another demo chunk.", ) delete_request = await demo_document.batch_delete_chunks_async(chunks=[x.name, y.name]) - self.assertIsInstance(self.observed_requests[-1], glm.BatchDeleteChunksRequest) + self.assertIsInstance(self.observed_requests[-1], protos.BatchDeleteChunksRequest) async def test_get_corpus_called_with_request_options(self): self.client.get_corpus = unittest.mock.AsyncMock() diff --git a/tests/test_safety.py b/tests/test_safety.py index f3efc4aca..2ac8aca46 100644 --- a/tests/test_safety.py +++ b/tests/test_safety.py @@ -15,26 +15,26 @@ from absl.testing import absltest from absl.testing import parameterized -import google.ai.generativelanguage as glm from google.generativeai.types import safety_types +from google.generativeai import protos class SafetyTests(parameterized.TestCase): """Tests are in order with the design doc.""" @parameterized.named_parameters( - ["block_threshold", glm.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE], + ["block_threshold", protos.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE], ["block_threshold2", "medium"], ["block_threshold3", 2], ["dict", {"danger": "medium"}], ["dict2", {"danger": 2}], - ["dict3", {"danger": glm.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE}], + ["dict3", {"danger": protos.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE}], [ "list-dict", [ dict( - category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + category=protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=protos.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, ), ], ], @@ -48,8 +48,8 @@ class SafetyTests(parameterized.TestCase): def test_safety_overwrite(self, setting): setting = safety_types.to_easy_safety_dict(setting) self.assertEqual( - setting[glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT], - glm.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + setting[protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT], + protos.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, ) diff --git a/tests/test_text.py b/tests/test_text.py index 5dcda93b9..795c3dfcd 100644 --- a/tests/test_text.py +++ b/tests/test_text.py @@ -18,7 +18,7 @@ import unittest import unittest.mock as mock -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import text as text_service from google.generativeai import client @@ -46,42 +46,42 @@ def add_client_method(f): @add_client_method def generate_text( - request: glm.GenerateTextRequest, + request: protos.GenerateTextRequest, **kwargs, - ) -> glm.GenerateTextResponse: + ) -> protos.GenerateTextResponse: self.observed_requests.append(request) return self.responses["generate_text"] @add_client_method def embed_text( - request: glm.EmbedTextRequest, + request: protos.EmbedTextRequest, **kwargs, - ) -> glm.EmbedTextResponse: + ) -> protos.EmbedTextResponse: self.observed_requests.append(request) return self.responses["embed_text"] @add_client_method def batch_embed_text( - request: glm.EmbedTextRequest, + request: protos.EmbedTextRequest, **kwargs, - ) -> glm.EmbedTextResponse: + ) -> protos.EmbedTextResponse: self.observed_requests.append(request) - return glm.BatchEmbedTextResponse( - embeddings=[glm.Embedding(value=[1, 2, 3])] * len(request.texts) + return protos.BatchEmbedTextResponse( + embeddings=[protos.Embedding(value=[1, 2, 3])] * len(request.texts) ) @add_client_method def count_text_tokens( - request: glm.CountTextTokensRequest, + request: protos.CountTextTokensRequest, **kwargs, - ) -> glm.CountTextTokensResponse: + ) -> protos.CountTextTokensResponse: self.observed_requests.append(request) return self.responses["count_text_tokens"] @add_client_method - def get_tuned_model(name) -> glm.TunedModel: - request = glm.GetTunedModelRequest(name=name) + def get_tuned_model(name) -> protos.TunedModel: + request = protos.GetTunedModelRequest(name=name) self.observed_requests.append(request) response = copy.copy(self.responses["get_tuned_model"]) return response @@ -93,7 +93,7 @@ def get_tuned_model(name) -> glm.TunedModel: ) def test_make_prompt(self, prompt): x = text_service._make_text_prompt(prompt) - self.assertIsInstance(x, glm.TextPrompt) + self.assertIsInstance(x, protos.TextPrompt) self.assertEqual("Hello how are", x.text) @parameterized.named_parameters( @@ -104,7 +104,7 @@ def test_make_prompt(self, prompt): def test_make_generate_text_request(self, prompt): x = text_service._make_generate_text_request(model="models/chat-bison-001", prompt=prompt) self.assertEqual("models/chat-bison-001", x.model) - self.assertIsInstance(x, glm.GenerateTextRequest) + self.assertIsInstance(x, protos.GenerateTextRequest) @parameterized.named_parameters( [ @@ -116,14 +116,16 @@ def test_make_generate_text_request(self, prompt): ] ) def test_generate_embeddings(self, model, text): - self.responses["embed_text"] = glm.EmbedTextResponse( - embedding=glm.Embedding(value=[1, 2, 3]) + self.responses["embed_text"] = protos.EmbedTextResponse( + embedding=protos.Embedding(value=[1, 2, 3]) ) emb = text_service.generate_embeddings(model=model, text=text) self.assertIsInstance(emb, dict) - self.assertEqual(self.observed_requests[-1], glm.EmbedTextRequest(model=model, text=text)) + self.assertEqual( + self.observed_requests[-1], protos.EmbedTextRequest(model=model, text=text) + ) self.assertIsInstance(emb["embedding"][0], float) @parameterized.named_parameters( @@ -191,11 +193,11 @@ def test_generate_embeddings_batch(self, model, text): ] ) def test_generate_response(self, *, prompt, **kwargs): - self.responses["generate_text"] = glm.GenerateTextResponse( + self.responses["generate_text"] = protos.GenerateTextResponse( candidates=[ - glm.TextCompletion(output=" road?"), - glm.TextCompletion(output=" bridge?"), - glm.TextCompletion(output=" river?"), + protos.TextCompletion(output=" road?"), + protos.TextCompletion(output=" bridge?"), + protos.TextCompletion(output=" river?"), ] ) @@ -203,8 +205,8 @@ def test_generate_response(self, *, prompt, **kwargs): self.assertEqual( self.observed_requests[-1], - glm.GenerateTextRequest( - model="models/text-bison-001", prompt=glm.TextPrompt(text=prompt), **kwargs + protos.GenerateTextRequest( + model="models/text-bison-001", prompt=protos.TextPrompt(text=prompt), **kwargs ), ) @@ -220,20 +222,20 @@ def test_generate_response(self, *, prompt, **kwargs): ) def test_stop_string(self): - self.responses["generate_text"] = glm.GenerateTextResponse( + self.responses["generate_text"] = protos.GenerateTextResponse( candidates=[ - glm.TextCompletion(output="Hello world?"), - glm.TextCompletion(output="Hell!"), - glm.TextCompletion(output="I'm going to stop"), + protos.TextCompletion(output="Hello world?"), + protos.TextCompletion(output="Hell!"), + protos.TextCompletion(output="I'm going to stop"), ] ) complete = text_service.generate_text(prompt="Hello", stop_sequences="stop") self.assertEqual( self.observed_requests[-1], - glm.GenerateTextRequest( + protos.GenerateTextRequest( model="models/text-bison-001", - prompt=glm.TextPrompt(text="Hello"), + prompt=protos.TextPrompt(text="Hello"), stop_sequences=["stop"], ), ) @@ -282,9 +284,9 @@ def test_stop_string(self): ] ) def test_safety_settings(self, safety_settings): - self.responses["generate_text"] = glm.GenerateTextResponse( + self.responses["generate_text"] = protos.GenerateTextResponse( candidates=[ - glm.TextCompletion(output="No"), + protos.TextCompletion(output="No"), ] ) # This test really just checks that the safety_settings get converted to a proto. @@ -298,7 +300,7 @@ def test_safety_settings(self, safety_settings): ) def test_filters(self): - self.responses["generate_text"] = glm.GenerateTextResponse( + self.responses["generate_text"] = protos.GenerateTextResponse( candidates=[{"output": "hello"}], filters=[ { @@ -313,7 +315,7 @@ def test_filters(self): self.assertEqual(response.filters[0]["reason"], palm_safety_types.BlockedReason.SAFETY) def test_safety_feedback(self): - self.responses["generate_text"] = glm.GenerateTextResponse( + self.responses["generate_text"] = protos.GenerateTextResponse( candidates=[{"output": "hello"}], safety_feedback=[ { @@ -341,7 +343,7 @@ def test_safety_feedback(self): self.assertIsInstance( response.safety_feedback[0]["setting"]["category"], - glm.HarmCategory, + protos.HarmCategory, ) self.assertEqual( response.safety_feedback[0]["setting"]["category"], @@ -349,7 +351,7 @@ def test_safety_feedback(self): ) def test_candidate_safety_feedback(self): - self.responses["generate_text"] = glm.GenerateTextResponse( + self.responses["generate_text"] = protos.GenerateTextResponse( candidates=[ { "output": "hello", @@ -370,7 +372,7 @@ def test_candidate_safety_feedback(self): result = text_service.generate_text(prompt="Write a story from the ER.") self.assertIsInstance( result.candidates[0]["safety_ratings"][0]["category"], - glm.HarmCategory, + protos.HarmCategory, ) self.assertEqual( result.candidates[0]["safety_ratings"][0]["category"], @@ -387,7 +389,7 @@ def test_candidate_safety_feedback(self): ) def test_candidate_citations(self): - self.responses["generate_text"] = glm.GenerateTextResponse( + self.responses["generate_text"] = protos.GenerateTextResponse( candidates=[ { "output": "Hello Google!", @@ -434,21 +436,21 @@ def test_candidate_citations(self): ), ), dict( - testcase_name="glm_model", - model=glm.Model( + testcase_name="protos.model", + model=protos.Model( name="models/text-bison-001", ), ), dict( - testcase_name="glm_tuned_model", - model=glm.TunedModel( + testcase_name="protos.tuned_model", + model=protos.TunedModel( name="tunedModels/bipedal-pangolin-001", base_model="models/text-bison-001", ), ), dict( - testcase_name="glm_tuned_model_nested", - model=glm.TunedModel( + testcase_name="protos.tuned_model_nested", + model=protos.TunedModel( name="tunedModels/bipedal-pangolin-002", tuned_model_source={ "tuned_model": "tunedModels/bipedal-pangolin-002", @@ -459,10 +461,10 @@ def test_candidate_citations(self): ] ) def test_count_message_tokens(self, model): - self.responses["get_tuned_model"] = glm.TunedModel( + self.responses["get_tuned_model"] = protos.TunedModel( name="tunedModels/bipedal-pangolin-001", base_model="models/text-bison-001" ) - self.responses["count_text_tokens"] = glm.CountTextTokensResponse(token_count=7) + self.responses["count_text_tokens"] = protos.CountTextTokensResponse(token_count=7) response = text_service.count_text_tokens(model, "Tell me a story about a magic backpack.") self.assertEqual({"token_count": 7}, response) @@ -472,7 +474,7 @@ def test_count_message_tokens(self, model): self.assertLen(self.observed_requests, 2) self.assertEqual( self.observed_requests[0], - glm.GetTunedModelRequest(name="tunedModels/bipedal-pangolin-001"), + protos.GetTunedModelRequest(name="tunedModels/bipedal-pangolin-001"), ) def test_count_text_tokens_called_with_request_options(self): From e2263cb7b786c110d2e3704f86d27883d8ab535e Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Mon, 3 Jun 2024 09:18:46 -0700 Subject: [PATCH 16/17] Handle image mode (#374) * Handle image mode Change-Id: Idbd0d65f6359557adbf812190048514080d48e6f * format Change-Id: Ia8a386ef959907650dd4fa5b5ab7401e75b9484a --- google/generativeai/types/content_types.py | 2 +- tests/test_content.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/google/generativeai/types/content_types.py b/google/generativeai/types/content_types.py index b8966b005..7e343a5c0 100644 --- a/google/generativeai/types/content_types.py +++ b/google/generativeai/types/content_types.py @@ -72,7 +72,7 @@ def pil_to_blob(img): bytesio = io.BytesIO() - if isinstance(img, PIL.PngImagePlugin.PngImageFile): + if isinstance(img, PIL.PngImagePlugin.PngImageFile) or img.mode == "RGBA": img.save(bytesio, format="PNG") mime_type = "image/png" else: diff --git a/tests/test_content.py b/tests/test_content.py index 3829ebc86..6df5faad4 100644 --- a/tests/test_content.py +++ b/tests/test_content.py @@ -24,6 +24,7 @@ import IPython.display import PIL.Image +import numpy as np HERE = pathlib.Path(__file__).parent TEST_PNG_PATH = HERE / "test_img.png" @@ -67,6 +68,7 @@ class ADataClassWithList: class UnitTests(parameterized.TestCase): @parameterized.named_parameters( ["PIL", PIL.Image.open(TEST_PNG_PATH)], + ["RGBA", PIL.Image.fromarray(np.zeros([6, 6, 4], dtype=np.uint8))], ["IPython", IPython.display.Image(filename=TEST_PNG_PATH)], ) def test_png_to_blob(self, image): @@ -77,6 +79,7 @@ def test_png_to_blob(self, image): @parameterized.named_parameters( ["PIL", PIL.Image.open(TEST_JPG_PATH)], + ["RGB", PIL.Image.fromarray(np.zeros([6, 6, 3], dtype=np.uint8))], ["IPython", IPython.display.Image(filename=TEST_JPG_PATH)], ) def test_jpg_to_blob(self, image): From 7b9758f54180b525d355393302100b0dea082a12 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Mon, 3 Jun 2024 09:32:07 -0700 Subject: [PATCH 17/17] Update version.py (#375) --- google/generativeai/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/generativeai/version.py b/google/generativeai/version.py index e1ce17d66..8018b67ac 100644 --- a/google/generativeai/version.py +++ b/google/generativeai/version.py @@ -14,4 +14,4 @@ # limitations under the License. from __future__ import annotations -__version__ = "0.5.4" +__version__ = "0.6.0"