From d247e58f6eb4f103b6d76a62f51699d174afcc9a Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Mon, 2 Oct 2023 15:44:55 -0700 Subject: [PATCH 1/4] Add dataclass prettyprinting. --- google/generativeai/discuss.py | 2 + google/generativeai/operations.py | 1 + .../{docstring_utils.py => string_utils.py} | 33 +++++++ google/generativeai/text.py | 2 + google/generativeai/types/citation_types.py | 7 +- google/generativeai/types/discuss_types.py | 3 + google/generativeai/types/model_types.py | 7 ++ google/generativeai/types/safety_types.py | 11 ++- google/generativeai/types/text_types.py | 2 + tests/test_string_utils.py | 98 +++++++++++++++++++ 10 files changed, 158 insertions(+), 8 deletions(-) rename google/generativeai/{docstring_utils.py => string_utils.py} (51%) create mode 100644 tests/test_string_utils.py diff --git a/google/generativeai/discuss.py b/google/generativeai/discuss.py index cd35f0928..e49ace3d6 100644 --- a/google/generativeai/discuss.py +++ b/google/generativeai/discuss.py @@ -24,6 +24,7 @@ from google.generativeai.client import get_default_discuss_client from google.generativeai.client import get_default_discuss_async_client +from google.generativeai import string_utils from google.generativeai.types import discuss_types from google.generativeai.types import model_types from google.generativeai.types import safety_types @@ -445,6 +446,7 @@ async def chat_async( DATACLASS_KWARGS = {} +@string_utils.prettyprint @set_doc(discuss_types.ChatResponse.__doc__) @dataclasses.dataclass(**DATACLASS_KWARGS, init=False) class ChatResponse(discuss_types.ChatResponse): diff --git a/google/generativeai/operations.py b/google/generativeai/operations.py index ffa0f237f..d492a9dee 100644 --- a/google/generativeai/operations.py +++ b/google/generativeai/operations.py @@ -18,6 +18,7 @@ from typing import Iterator from google.ai import generativelanguage as glm + from google.generativeai import client as client_lib from google.generativeai.types import model_types from google.api_core import operation as operation_lib diff --git a/google/generativeai/docstring_utils.py b/google/generativeai/string_utils.py similarity index 51% rename from google/generativeai/docstring_utils.py rename to google/generativeai/string_utils.py index 66b1da5a5..f176fa321 100644 --- a/google/generativeai/docstring_utils.py +++ b/google/generativeai/string_utils.py @@ -14,9 +14,42 @@ # limitations under the License. from __future__ import annotations +import pprint +import textwrap +import dataclasses + def strip_oneof(docstring): lines = docstring.splitlines() lines = [line for line in lines if ".. _oneof:" not in line] lines = [line for line in lines if "This field is a member of `oneof`_" not in line] return "\n".join(lines) + + +def prettyprint(cls): + cls.__str__ = _prettyprint + cls.__repr__ = _prettyprint + return cls + + +def _prettyprint(self): + """You can't use `__str__ = pprint.pformat`. That causes a recursion error. + + This works, but it doesn't handle objects that reference themselves. + """ + fields = [] + for f in dataclasses.fields(self): + s = pprint.pformat(getattr(self, f.name)) + if s.count("\n") >= 10: + s = "..." + else: + width = len(f.name) + 1 + s = textwrap.indent(s, " " * width).lstrip(" ") + fields.append(f"{f.name}={s}") + attrs = ",\n".join(fields) + + name = self.__class__.__name__ + width = len(name) + 1 + + attrs = textwrap.indent(attrs, " " * width).lstrip(" ") + return f"{name}({attrs})" diff --git a/google/generativeai/text.py b/google/generativeai/text.py index 7a8bf90b2..d02b7b7e4 100644 --- a/google/generativeai/text.py +++ b/google/generativeai/text.py @@ -21,6 +21,7 @@ import google.ai.generativelanguage as glm from google.generativeai.client import get_default_text_client +from google.generativeai import string_utils from google.generativeai.types import text_types from google.generativeai.types import model_types from google.generativeai.types import safety_types @@ -175,6 +176,7 @@ def generate_text( return _generate_response(client=client, request=request) +@string_utils.prettyprint @dataclasses.dataclass(init=False) class Completion(text_types.Completion): def __init__(self, **kwargs): diff --git a/google/generativeai/types/citation_types.py b/google/generativeai/types/citation_types.py index 757220ce1..e720fcd88 100644 --- a/google/generativeai/types/citation_types.py +++ b/google/generativeai/types/citation_types.py @@ -17,7 +17,8 @@ from typing import Optional, List from google.ai import generativelanguage as glm -from google.generativeai import docstring_utils +from google.generativeai import string_utils + from typing import TypedDict __all__ = [ @@ -32,10 +33,10 @@ class CitationSourceDict(TypedDict): uri: str | None license: str | None - __doc__ = docstring_utils.strip_oneof(glm.CitationSource.__doc__) + __doc__ = string_utils.strip_oneof(glm.CitationSource.__doc__) class CitationMetadataDict(TypedDict): citation_sources: List[CitationSourceDict | None] - __doc__ = docstring_utils.strip_oneof(glm.CitationMetadata.__doc__) + __doc__ = string_utils.strip_oneof(glm.CitationMetadata.__doc__) diff --git a/google/generativeai/types/discuss_types.py b/google/generativeai/types/discuss_types.py index 7add0e59d..7790bd5e6 100644 --- a/google/generativeai/types/discuss_types.py +++ b/google/generativeai/types/discuss_types.py @@ -19,6 +19,8 @@ from typing import Any, Dict, TypedDict, Union, Iterable, Optional, Tuple, List import google.ai.generativelanguage as glm +from google.generativeai import string_utils + from google.generativeai.types import safety_types from google.generativeai.types import citation_types @@ -97,6 +99,7 @@ class ResponseDict(TypedDict): candidates: List[MessageDict] +@string_utils.prettyprint @dataclasses.dataclass(init=False) class ChatResponse(abc.ABC): """A chat response from the model. diff --git a/google/generativeai/types/model_types.py b/google/generativeai/types/model_types.py index 463ff4651..039488cd1 100644 --- a/google/generativeai/types/model_types.py +++ b/google/generativeai/types/model_types.py @@ -21,6 +21,7 @@ from typing import Any, Iterable, TypedDict, Union import google.ai.generativelanguage as glm +from google.generativeai import string_utils __all__ = [ "Model", @@ -65,6 +66,7 @@ def to_tuned_model_state(x: TunedModelStateOptions) -> TunedModelState: return _TUNED_MODEL_STATES[x] +@string_utils.prettyprint @dataclasses.dataclass class Model: """A dataclass representation of a `glm.Model`. @@ -152,6 +154,7 @@ def decode_tuned_model(tuned_model: glm.TunedModel | dict["str", Any]) -> TunedM return TunedModel(**tuned_model) +@string_utils.prettyprint @dataclasses.dataclass class TunedModel: """A dataclass representation of a `glm.TunedModel`.""" @@ -170,6 +173,7 @@ class TunedModel: tuning_task: TuningTask | None = None +@string_utils.prettyprint @dataclasses.dataclass class TuningTask: start_time: datetime.datetime | None = None @@ -208,6 +212,7 @@ def encode_tuning_example(example: TuningExampleOptions): return example +@string_utils.prettyprint @dataclasses.dataclass class TuningSnapshot: step: int @@ -216,6 +221,7 @@ class TuningSnapshot: compute_time: datetime.datetime +@string_utils.prettyprint @dataclasses.dataclass class Hyperparameters: epoch_count: int = 0 @@ -246,6 +252,7 @@ def make_model_name(name: AnyModelNameOptions): TunedModelsIterable = Iterable[TunedModel] +@string_utils.prettyprint @dataclasses.dataclass class TokenCount: """A dataclass representation of a `glm.TokenCountResponse`. diff --git a/google/generativeai/types/safety_types.py b/google/generativeai/types/safety_types.py index ddd2172ee..bedd65317 100644 --- a/google/generativeai/types/safety_types.py +++ b/google/generativeai/types/safety_types.py @@ -17,7 +17,8 @@ from collections.abc import Mapping from google.ai import generativelanguage as glm -from google.generativeai import docstring_utils +from google.generativeai import string_utils + import typing from typing import Iterable, Dict, Iterable, List, TypedDict, Union @@ -134,7 +135,7 @@ class ContentFilterDict(TypedDict): reason: BlockedReason message: str - __doc__ = docstring_utils.strip_oneof(glm.ContentFilter.__doc__) + __doc__ = string_utils.strip_oneof(glm.ContentFilter.__doc__) def convert_filters_to_enums( @@ -153,7 +154,7 @@ class SafetyRatingDict(TypedDict): category: HarmCategory probability: HarmProbability - __doc__ = docstring_utils.strip_oneof(glm.SafetyRating.__doc__) + __doc__ = string_utils.strip_oneof(glm.SafetyRating.__doc__) def convert_rating_to_enum(rating: dict) -> SafetyRatingDict: @@ -174,7 +175,7 @@ class SafetySettingDict(TypedDict): category: HarmCategory threshold: HarmBlockThreshold - __doc__ = docstring_utils.strip_oneof(glm.SafetySetting.__doc__) + __doc__ = string_utils.strip_oneof(glm.SafetySetting.__doc__) class LooseSafetySettingDict(TypedDict): @@ -220,7 +221,7 @@ class SafetyFeedbackDict(TypedDict): rating: SafetyRatingDict setting: SafetySettingDict - __doc__ = docstring_utils.strip_oneof(glm.SafetyFeedback.__doc__) + __doc__ = string_utils.strip_oneof(glm.SafetyFeedback.__doc__) def convert_safety_feedback_to_enums( diff --git a/google/generativeai/types/text_types.py b/google/generativeai/types/text_types.py index d729db4d6..3e14c5e73 100644 --- a/google/generativeai/types/text_types.py +++ b/google/generativeai/types/text_types.py @@ -18,6 +18,7 @@ import dataclasses from typing import Any, Dict, List, TypedDict +from google.generativeai import string_utils from google.generativeai.types import safety_types from google.generativeai.types import citation_types @@ -39,6 +40,7 @@ class TextCompletion(TypedDict, total=False): citation_metadata: citation_types.CitationMetadataDict | None +@string_utils.prettyprint @dataclasses.dataclass(init=False) class Completion(abc.ABC): """The result returned by `generativeai.generate_text`. diff --git a/tests/test_string_utils.py b/tests/test_string_utils.py new file mode 100644 index 000000000..a50cb1bb4 --- /dev/null +++ b/tests/test_string_utils.py @@ -0,0 +1,98 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dataclasses +import pprint +import textwrap +from typing import Any +import unittest + + +from google.generativeai import string_utils + +from absl.testing import parameterized + + +@string_utils.prettyprint +@dataclasses.dataclass +class MyClass: + a: int + b: float + c: list[int] + d: Any + + +class OperationsTests(parameterized.TestCase): + def test_simple(self): + m = MyClass(a=1, b=1 / 3, c=[0, 1, 2, 3, 4, 5], d={"a": 1, "b": 2}) + + result = str(m) + expected = textwrap.dedent( + """ + MyClass(a=1, + b=0.3333333333333333, + c=[0, 1, 2, 3, 4, 5], + d={'a': 1, 'b': 2})""" + )[1:] + self.assertEqual(expected, result) + self.assertEqual(pprint.pformat(m), result) + self.assertEqual(repr(m), result) + + def test_long(self): + m = MyClass(a=1, b=1 / 3, c=[1, 2, 3] * 10, d={"a": 1, "b": 2}) + expected = textwrap.dedent( + """ + MyClass(a=1, + b=0.3333333333333333, + c=..., + d={'a': 1, 'b': 2})""" + )[1:] + self.assertEqual(expected, str(m)) + + def test_nested(self): + m1 = MyClass(a=1, b=1 / 3, c=[1, 2, 3], d={"a": 1, "b": 2}) + m2 = MyClass(a=1, b=1 / 3, c=[1, 2, 3], d=m1) + + expected = textwrap.dedent( + """ + MyClass(a=1, + b=0.3333333333333333, + c=[1, 2, 3], + d=MyClass(a=1, + b=0.3333333333333333, + c=[1, 2, 3], + d={'a': 1, 'b': 2}))""" + )[1:] + result = str(m2) + self.assertEqual(expected, result) + self.assertEqual(pprint.pformat(m2), result) + self.assertEqual(repr(m2), result) + + @unittest.skip("I don't have a solution for this.") + def test_recursive(self): + m = MyClass(a=1, b=1 / 3, c=[1, 2, 3], d=None) + m.d = m + + expected = textwrap.dedent( + """ + MyClass(a=1, + b=0.3333333333333333, + c=[1, 2, 3], + d=...)""" + )[1:] + result = str(m) + self.assertEqual(expected, result) + self.assertEqual(pprint.pformat(m), result) + self.assertEqual(repr(m), result) From 90c5a70111d61c2f21b21ee3ef4af50aef2a4f5f Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Thu, 5 Oct 2023 11:17:30 -0700 Subject: [PATCH 2/4] use reprlib.recursive_repr --- google/generativeai/string_utils.py | 5 ++++- tests/test_string_utils.py | 1 - 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/google/generativeai/string_utils.py b/google/generativeai/string_utils.py index f176fa321..8f5c1cd46 100644 --- a/google/generativeai/string_utils.py +++ b/google/generativeai/string_utils.py @@ -14,9 +14,10 @@ # limitations under the License. from __future__ import annotations +import dataclasses import pprint +import reprlib import textwrap -import dataclasses def strip_oneof(docstring): @@ -32,6 +33,8 @@ def prettyprint(cls): return cls + +@reprlib.recursive_repr() def _prettyprint(self): """You can't use `__str__ = pprint.pformat`. That causes a recursion error. diff --git a/tests/test_string_utils.py b/tests/test_string_utils.py index a50cb1bb4..ad4ced46e 100644 --- a/tests/test_string_utils.py +++ b/tests/test_string_utils.py @@ -80,7 +80,6 @@ def test_nested(self): self.assertEqual(pprint.pformat(m2), result) self.assertEqual(repr(m2), result) - @unittest.skip("I don't have a solution for this.") def test_recursive(self): m = MyClass(a=1, b=1 / 3, c=[1, 2, 3], d=None) m.d = m From 839a854740bb61512fff84988121678abcfa78ae Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Thu, 5 Oct 2023 11:18:25 -0700 Subject: [PATCH 3/4] format --- google/generativeai/string_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/google/generativeai/string_utils.py b/google/generativeai/string_utils.py index 8f5c1cd46..1e5b454d8 100644 --- a/google/generativeai/string_utils.py +++ b/google/generativeai/string_utils.py @@ -33,7 +33,6 @@ def prettyprint(cls): return cls - @reprlib.recursive_repr() def _prettyprint(self): """You can't use `__str__ = pprint.pformat`. That causes a recursion error. From 43a457683120b2406f9b3b06b032f27dcd91ffd5 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Thu, 5 Oct 2023 12:43:05 -0700 Subject: [PATCH 4/4] Improve contractions. --- google/generativeai/string_utils.py | 23 ++++++++++++++++++++--- tests/test_string_utils.py | 20 ++++++++++++++++++-- 2 files changed, 38 insertions(+), 5 deletions(-) diff --git a/google/generativeai/string_utils.py b/google/generativeai/string_utils.py index 1e5b454d8..6cda23635 100644 --- a/google/generativeai/string_utils.py +++ b/google/generativeai/string_utils.py @@ -16,6 +16,7 @@ import dataclasses import pprint +import re import reprlib import textwrap @@ -33,17 +34,33 @@ def prettyprint(cls): return cls +repr = reprlib.Repr() + + @reprlib.recursive_repr() def _prettyprint(self): - """You can't use `__str__ = pprint.pformat`. That causes a recursion error. + """A dataclass prettyprint function you can use in __str__or __repr__. + + Note: You can't set `__str__ = pprint.pformat` because it causes a recursion error. + + Mostly identical to pprint but: - This works, but it doesn't handle objects that reference themselves. + * This will contract long lists and dicts (> 10lines) to [...] and {...}. + * This will contract long object reprs to ClassName(...). """ fields = [] for f in dataclasses.fields(self): s = pprint.pformat(getattr(self, f.name)) + class_re = r"^(\w+)\(.*\)$" if s.count("\n") >= 10: - s = "..." + if s.startswith("["): + s = "[...]" + elif s.startswith("{"): + s = "{...}" + elif re.match(class_re, s, flags=re.DOTALL): + s = re.sub(class_re, r"\1(...)", s, flags=re.DOTALL) + else: + s = "..." else: width = len(f.name) + 1 s = textwrap.indent(s, " " * width).lstrip(" ") diff --git a/tests/test_string_utils.py b/tests/test_string_utils.py index ad4ced46e..48d5003c3 100644 --- a/tests/test_string_utils.py +++ b/tests/test_string_utils.py @@ -50,13 +50,13 @@ def test_simple(self): self.assertEqual(pprint.pformat(m), result) self.assertEqual(repr(m), result) - def test_long(self): + def test_long_list(self): m = MyClass(a=1, b=1 / 3, c=[1, 2, 3] * 10, d={"a": 1, "b": 2}) expected = textwrap.dedent( """ MyClass(a=1, b=0.3333333333333333, - c=..., + c=[...], d={'a': 1, 'b': 2})""" )[1:] self.assertEqual(expected, str(m)) @@ -80,6 +80,22 @@ def test_nested(self): self.assertEqual(pprint.pformat(m2), result) self.assertEqual(repr(m2), result) + def test_long_obj(self): + m = MyClass(a=1, b=1 / 3, c=[1, 2, 3], d=None) + m = MyClass(a=1, b=1 / 3, c=[1, 2, 3], d=m) + m = MyClass(a=1, b=1 / 3, c=[1, 2, 3], d=m) + m = MyClass(a=1, b=1 / 3, c=[1, 2, 3], d=m) + m = MyClass(a=1, b=1 / 3, c=[1, 2, 3], d=m) + + expected = textwrap.dedent( + """ + MyClass(a=1, + b=0.3333333333333333, + c=[1, 2, 3], + d=MyClass(...))""" + )[1:] + self.assertEqual(expected, str(m)) + def test_recursive(self): m = MyClass(a=1, b=1 / 3, c=[1, 2, 3], d=None) m.d = m