diff --git a/google/generativeai/discuss.py b/google/generativeai/discuss.py index cd35f0928..e49ace3d6 100644 --- a/google/generativeai/discuss.py +++ b/google/generativeai/discuss.py @@ -24,6 +24,7 @@ from google.generativeai.client import get_default_discuss_client from google.generativeai.client import get_default_discuss_async_client +from google.generativeai import string_utils from google.generativeai.types import discuss_types from google.generativeai.types import model_types from google.generativeai.types import safety_types @@ -445,6 +446,7 @@ async def chat_async( DATACLASS_KWARGS = {} +@string_utils.prettyprint @set_doc(discuss_types.ChatResponse.__doc__) @dataclasses.dataclass(**DATACLASS_KWARGS, init=False) class ChatResponse(discuss_types.ChatResponse): diff --git a/google/generativeai/docstring_utils.py b/google/generativeai/docstring_utils.py deleted file mode 100644 index 66b1da5a5..000000000 --- a/google/generativeai/docstring_utils.py +++ /dev/null @@ -1,22 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from __future__ import annotations - - -def strip_oneof(docstring): - lines = docstring.splitlines() - lines = [line for line in lines if ".. _oneof:" not in line] - lines = [line for line in lines if "This field is a member of `oneof`_" not in line] - return "\n".join(lines) diff --git a/google/generativeai/operations.py b/google/generativeai/operations.py index ffa0f237f..d492a9dee 100644 --- a/google/generativeai/operations.py +++ b/google/generativeai/operations.py @@ -18,6 +18,7 @@ from typing import Iterator from google.ai import generativelanguage as glm + from google.generativeai import client as client_lib from google.generativeai.types import model_types from google.api_core import operation as operation_lib diff --git a/google/generativeai/string_utils.py b/google/generativeai/string_utils.py new file mode 100644 index 000000000..6cda23635 --- /dev/null +++ b/google/generativeai/string_utils.py @@ -0,0 +1,74 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import dataclasses +import pprint +import re +import reprlib +import textwrap + + +def strip_oneof(docstring): + lines = docstring.splitlines() + lines = [line for line in lines if ".. _oneof:" not in line] + lines = [line for line in lines if "This field is a member of `oneof`_" not in line] + return "\n".join(lines) + + +def prettyprint(cls): + cls.__str__ = _prettyprint + cls.__repr__ = _prettyprint + return cls + + +repr = reprlib.Repr() + + +@reprlib.recursive_repr() +def _prettyprint(self): + """A dataclass prettyprint function you can use in __str__or __repr__. + + Note: You can't set `__str__ = pprint.pformat` because it causes a recursion error. + + Mostly identical to pprint but: + + * This will contract long lists and dicts (> 10lines) to [...] and {...}. + * This will contract long object reprs to ClassName(...). + """ + fields = [] + for f in dataclasses.fields(self): + s = pprint.pformat(getattr(self, f.name)) + class_re = r"^(\w+)\(.*\)$" + if s.count("\n") >= 10: + if s.startswith("["): + s = "[...]" + elif s.startswith("{"): + s = "{...}" + elif re.match(class_re, s, flags=re.DOTALL): + s = re.sub(class_re, r"\1(...)", s, flags=re.DOTALL) + else: + s = "..." + else: + width = len(f.name) + 1 + s = textwrap.indent(s, " " * width).lstrip(" ") + fields.append(f"{f.name}={s}") + attrs = ",\n".join(fields) + + name = self.__class__.__name__ + width = len(name) + 1 + + attrs = textwrap.indent(attrs, " " * width).lstrip(" ") + return f"{name}({attrs})" diff --git a/google/generativeai/text.py b/google/generativeai/text.py index 7a8bf90b2..d02b7b7e4 100644 --- a/google/generativeai/text.py +++ b/google/generativeai/text.py @@ -21,6 +21,7 @@ import google.ai.generativelanguage as glm from google.generativeai.client import get_default_text_client +from google.generativeai import string_utils from google.generativeai.types import text_types from google.generativeai.types import model_types from google.generativeai.types import safety_types @@ -175,6 +176,7 @@ def generate_text( return _generate_response(client=client, request=request) +@string_utils.prettyprint @dataclasses.dataclass(init=False) class Completion(text_types.Completion): def __init__(self, **kwargs): diff --git a/google/generativeai/types/citation_types.py b/google/generativeai/types/citation_types.py index 757220ce1..e720fcd88 100644 --- a/google/generativeai/types/citation_types.py +++ b/google/generativeai/types/citation_types.py @@ -17,7 +17,8 @@ from typing import Optional, List from google.ai import generativelanguage as glm -from google.generativeai import docstring_utils +from google.generativeai import string_utils + from typing import TypedDict __all__ = [ @@ -32,10 +33,10 @@ class CitationSourceDict(TypedDict): uri: str | None license: str | None - __doc__ = docstring_utils.strip_oneof(glm.CitationSource.__doc__) + __doc__ = string_utils.strip_oneof(glm.CitationSource.__doc__) class CitationMetadataDict(TypedDict): citation_sources: List[CitationSourceDict | None] - __doc__ = docstring_utils.strip_oneof(glm.CitationMetadata.__doc__) + __doc__ = string_utils.strip_oneof(glm.CitationMetadata.__doc__) diff --git a/google/generativeai/types/discuss_types.py b/google/generativeai/types/discuss_types.py index 7add0e59d..7790bd5e6 100644 --- a/google/generativeai/types/discuss_types.py +++ b/google/generativeai/types/discuss_types.py @@ -19,6 +19,8 @@ from typing import Any, Dict, TypedDict, Union, Iterable, Optional, Tuple, List import google.ai.generativelanguage as glm +from google.generativeai import string_utils + from google.generativeai.types import safety_types from google.generativeai.types import citation_types @@ -97,6 +99,7 @@ class ResponseDict(TypedDict): candidates: List[MessageDict] +@string_utils.prettyprint @dataclasses.dataclass(init=False) class ChatResponse(abc.ABC): """A chat response from the model. diff --git a/google/generativeai/types/model_types.py b/google/generativeai/types/model_types.py index 463ff4651..039488cd1 100644 --- a/google/generativeai/types/model_types.py +++ b/google/generativeai/types/model_types.py @@ -21,6 +21,7 @@ from typing import Any, Iterable, TypedDict, Union import google.ai.generativelanguage as glm +from google.generativeai import string_utils __all__ = [ "Model", @@ -65,6 +66,7 @@ def to_tuned_model_state(x: TunedModelStateOptions) -> TunedModelState: return _TUNED_MODEL_STATES[x] +@string_utils.prettyprint @dataclasses.dataclass class Model: """A dataclass representation of a `glm.Model`. @@ -152,6 +154,7 @@ def decode_tuned_model(tuned_model: glm.TunedModel | dict["str", Any]) -> TunedM return TunedModel(**tuned_model) +@string_utils.prettyprint @dataclasses.dataclass class TunedModel: """A dataclass representation of a `glm.TunedModel`.""" @@ -170,6 +173,7 @@ class TunedModel: tuning_task: TuningTask | None = None +@string_utils.prettyprint @dataclasses.dataclass class TuningTask: start_time: datetime.datetime | None = None @@ -208,6 +212,7 @@ def encode_tuning_example(example: TuningExampleOptions): return example +@string_utils.prettyprint @dataclasses.dataclass class TuningSnapshot: step: int @@ -216,6 +221,7 @@ class TuningSnapshot: compute_time: datetime.datetime +@string_utils.prettyprint @dataclasses.dataclass class Hyperparameters: epoch_count: int = 0 @@ -246,6 +252,7 @@ def make_model_name(name: AnyModelNameOptions): TunedModelsIterable = Iterable[TunedModel] +@string_utils.prettyprint @dataclasses.dataclass class TokenCount: """A dataclass representation of a `glm.TokenCountResponse`. diff --git a/google/generativeai/types/safety_types.py b/google/generativeai/types/safety_types.py index ddd2172ee..bedd65317 100644 --- a/google/generativeai/types/safety_types.py +++ b/google/generativeai/types/safety_types.py @@ -17,7 +17,8 @@ from collections.abc import Mapping from google.ai import generativelanguage as glm -from google.generativeai import docstring_utils +from google.generativeai import string_utils + import typing from typing import Iterable, Dict, Iterable, List, TypedDict, Union @@ -134,7 +135,7 @@ class ContentFilterDict(TypedDict): reason: BlockedReason message: str - __doc__ = docstring_utils.strip_oneof(glm.ContentFilter.__doc__) + __doc__ = string_utils.strip_oneof(glm.ContentFilter.__doc__) def convert_filters_to_enums( @@ -153,7 +154,7 @@ class SafetyRatingDict(TypedDict): category: HarmCategory probability: HarmProbability - __doc__ = docstring_utils.strip_oneof(glm.SafetyRating.__doc__) + __doc__ = string_utils.strip_oneof(glm.SafetyRating.__doc__) def convert_rating_to_enum(rating: dict) -> SafetyRatingDict: @@ -174,7 +175,7 @@ class SafetySettingDict(TypedDict): category: HarmCategory threshold: HarmBlockThreshold - __doc__ = docstring_utils.strip_oneof(glm.SafetySetting.__doc__) + __doc__ = string_utils.strip_oneof(glm.SafetySetting.__doc__) class LooseSafetySettingDict(TypedDict): @@ -220,7 +221,7 @@ class SafetyFeedbackDict(TypedDict): rating: SafetyRatingDict setting: SafetySettingDict - __doc__ = docstring_utils.strip_oneof(glm.SafetyFeedback.__doc__) + __doc__ = string_utils.strip_oneof(glm.SafetyFeedback.__doc__) def convert_safety_feedback_to_enums( diff --git a/google/generativeai/types/text_types.py b/google/generativeai/types/text_types.py index d729db4d6..3e14c5e73 100644 --- a/google/generativeai/types/text_types.py +++ b/google/generativeai/types/text_types.py @@ -18,6 +18,7 @@ import dataclasses from typing import Any, Dict, List, TypedDict +from google.generativeai import string_utils from google.generativeai.types import safety_types from google.generativeai.types import citation_types @@ -39,6 +40,7 @@ class TextCompletion(TypedDict, total=False): citation_metadata: citation_types.CitationMetadataDict | None +@string_utils.prettyprint @dataclasses.dataclass(init=False) class Completion(abc.ABC): """The result returned by `generativeai.generate_text`. diff --git a/tests/test_string_utils.py b/tests/test_string_utils.py new file mode 100644 index 000000000..48d5003c3 --- /dev/null +++ b/tests/test_string_utils.py @@ -0,0 +1,113 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dataclasses +import pprint +import textwrap +from typing import Any +import unittest + + +from google.generativeai import string_utils + +from absl.testing import parameterized + + +@string_utils.prettyprint +@dataclasses.dataclass +class MyClass: + a: int + b: float + c: list[int] + d: Any + + +class OperationsTests(parameterized.TestCase): + def test_simple(self): + m = MyClass(a=1, b=1 / 3, c=[0, 1, 2, 3, 4, 5], d={"a": 1, "b": 2}) + + result = str(m) + expected = textwrap.dedent( + """ + MyClass(a=1, + b=0.3333333333333333, + c=[0, 1, 2, 3, 4, 5], + d={'a': 1, 'b': 2})""" + )[1:] + self.assertEqual(expected, result) + self.assertEqual(pprint.pformat(m), result) + self.assertEqual(repr(m), result) + + def test_long_list(self): + m = MyClass(a=1, b=1 / 3, c=[1, 2, 3] * 10, d={"a": 1, "b": 2}) + expected = textwrap.dedent( + """ + MyClass(a=1, + b=0.3333333333333333, + c=[...], + d={'a': 1, 'b': 2})""" + )[1:] + self.assertEqual(expected, str(m)) + + def test_nested(self): + m1 = MyClass(a=1, b=1 / 3, c=[1, 2, 3], d={"a": 1, "b": 2}) + m2 = MyClass(a=1, b=1 / 3, c=[1, 2, 3], d=m1) + + expected = textwrap.dedent( + """ + MyClass(a=1, + b=0.3333333333333333, + c=[1, 2, 3], + d=MyClass(a=1, + b=0.3333333333333333, + c=[1, 2, 3], + d={'a': 1, 'b': 2}))""" + )[1:] + result = str(m2) + self.assertEqual(expected, result) + self.assertEqual(pprint.pformat(m2), result) + self.assertEqual(repr(m2), result) + + def test_long_obj(self): + m = MyClass(a=1, b=1 / 3, c=[1, 2, 3], d=None) + m = MyClass(a=1, b=1 / 3, c=[1, 2, 3], d=m) + m = MyClass(a=1, b=1 / 3, c=[1, 2, 3], d=m) + m = MyClass(a=1, b=1 / 3, c=[1, 2, 3], d=m) + m = MyClass(a=1, b=1 / 3, c=[1, 2, 3], d=m) + + expected = textwrap.dedent( + """ + MyClass(a=1, + b=0.3333333333333333, + c=[1, 2, 3], + d=MyClass(...))""" + )[1:] + self.assertEqual(expected, str(m)) + + def test_recursive(self): + m = MyClass(a=1, b=1 / 3, c=[1, 2, 3], d=None) + m.d = m + + expected = textwrap.dedent( + """ + MyClass(a=1, + b=0.3333333333333333, + c=[1, 2, 3], + d=...)""" + )[1:] + result = str(m) + self.assertEqual(expected, result) + self.assertEqual(pprint.pformat(m), result) + self.assertEqual(repr(m), result)