From a27288c43389d1187ce8e3fa69214f94bdbbd900 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Thu, 17 Oct 2024 14:11:36 -0700 Subject: [PATCH 01/12] Allow the use of generated images as inputs Change-Id: I0956fb78272a8a8af2c5219d80a26dec944040a8 # Conflicts: # google/generativeai/vision_models/_vision_models.py --- google/generativeai/client.py | 97 +++++- google/generativeai/types/content_types.py | 98 +----- google/generativeai/vision_models/__init__.py | 6 +- .../vision_models/_vision_models.py | 300 +----------------- tests/test_content.py | 13 +- 5 files changed, 108 insertions(+), 406 deletions(-) diff --git a/google/generativeai/client.py b/google/generativeai/client.py index a75643f1a..69492c1cd 100644 --- a/google/generativeai/client.py +++ b/google/generativeai/client.py @@ -6,7 +6,7 @@ import dataclasses import pathlib from typing import Any, cast -from collections.abc import Sequence +from collections.abc import Sequence, Mapping import httplib2 from io import IOBase @@ -23,6 +23,11 @@ import googleapiclient.http import googleapiclient.discovery +from google.protobuf import struct_pb2 + +from proto.marshal.collections import maps +from proto.marshal.collections import repeated + try: from google.generativeai import version @@ -130,6 +135,73 @@ async def create_file(self, *args, **kwargs): ) +# This is to get around https://github.com/googleapis/proto-plus-python/issues/488 +def to_value(value) -> struct_pb2.Value: + """Return a protobuf Value object representing this value.""" + if isinstance(value, struct_pb2.Value): + return value + if value is None: + return struct_pb2.Value(null_value=0) + if isinstance(value, bool): + return struct_pb2.Value(bool_value=value) + if isinstance(value, (int, float)): + return struct_pb2.Value(number_value=float(value)) + if isinstance(value, str): + return struct_pb2.Value(string_value=value) + if isinstance(value, collections.abc.Sequence): + return struct_pb2.Value(list_value=to_list_value(value)) + if isinstance(value, collections.abc.Mapping): + return struct_pb2.Value(struct_value=to_mapping_value(value)) + raise ValueError("Unable to coerce value: %r" % value) + + +def to_list_value(value) -> struct_pb2.ListValue: + # We got a proto, or else something we sent originally. + # Preserve the instance we have. + if isinstance(value, struct_pb2.ListValue): + return value + if isinstance(value, repeated.RepeatedComposite): + return struct_pb2.ListValue(values=[v for v in value.pb]) + + # We got a list (or something list-like); convert it. + return struct_pb2.ListValue(values=[to_value(v) for v in value]) + + +def to_mapping_value(value) -> struct_pb2.Struct: + # We got a proto, or else something we sent originally. + # Preserve the instance we have. + if isinstance(value, struct_pb2.Struct): + return value + if isinstance(value, maps.MapComposite): + return struct_pb2.Struct( + fields={k: v for k, v in value.pb.items()}, + ) + + # We got a dict (or something dict-like); convert it. + return struct_pb2.Struct(fields={k: to_value(v) for k, v in value.items()}) + + +# This is to get around https://github.com/googleapis/proto-plus-python/issues/488 + + +class PredictionServiceClient(glm.PredictionServiceClient): + def predict(self, model=None, instances=None, parameters=None): + pr = protos.PredictRequest.pb() + request = pr( + model=model, instances=[to_value(i) for i in instances], parameters=to_value(parameters) + ) + return super().predict(request) + + +class PredictionServiceAsyncClient(glm.PredictionServiceAsyncClient): + async def predict(self, model=None, instances=None, parameters=None): + pr = protos.PredictRequest.pb() + request = pr( + model=model, instances=[to_value(i) for i in instances], parameters=to_value(parameters) + ) + return await super().predict(request) + + @dataclasses.dataclass class _ClientManager: client_config: dict[str, Any] = dataclasses.field(default_factory=dict) @@ -220,15 +292,20 @@ def configure( self.clients = {} def make_client(self, name): - if name == "file": - cls = FileServiceClient - elif name == "file_async": - cls = FileServiceAsyncClient - elif name.endswith("_async"): - name = name.split("_")[0] - cls = getattr(glm, name.title() + "ServiceAsyncClient") - else: - cls = getattr(glm, name.title() + "ServiceClient") + local_clients = { + "file": FileServiceClient, + "file_async": FileServiceAsyncClient, + "prediction": PredictionServiceClient, + "prediction_async": PredictionServiceAsyncClient, + } + cls = local_clients.get("name", None) + + if cls is None: + if name.endswith("_async"): + name = name.split("_")[0] + cls = getattr(glm, name.title() + "ServiceAsyncClient") + else: + cls = getattr(glm, name.title() + "ServiceClient") # Attempt to configure using defaults. if not self.client_config: diff --git a/google/generativeai/types/content_types.py b/google/generativeai/types/content_types.py index 23241a536..3eeababbb 100644 --- a/google/generativeai/types/content_types.py +++ b/google/generativeai/types/content_types.py @@ -16,45 +16,16 @@ from __future__ import annotations from collections.abc import Iterable, Mapping, Sequence -import io import inspect -import mimetypes -import pathlib -import typing from typing import Any, Callable, Union from typing_extensions import TypedDict import pydantic from google.generativeai.types import file_types +from google.generativeai.types.image_types import _image_types from google.generativeai import protos -if typing.TYPE_CHECKING: - import PIL.Image - import PIL.ImageFile - import IPython.display - - IMAGE_TYPES = (PIL.Image.Image, IPython.display.Image) - ImageType = PIL.Image.Image | IPython.display.Image -else: - IMAGE_TYPES = () - try: - import PIL.Image - import PIL.ImageFile - - IMAGE_TYPES = IMAGE_TYPES + (PIL.Image.Image,) - except ImportError: - PIL = None - - try: - import IPython.display - - IMAGE_TYPES = IMAGE_TYPES + (IPython.display.Image,) - except ImportError: - IPython = None - - ImageType = Union["PIL.Image.Image", "IPython.display.Image"] - __all__ = [ "BlobDict", @@ -97,62 +68,6 @@ def to_mode(x: ModeOptions) -> Mode: return _MODE[x] -def _pil_to_blob(image: PIL.Image.Image) -> protos.Blob: - # If the image is a local file, return a file-based blob without any modification. - # Otherwise, return a lossless WebP blob (same quality with optimized size). - def file_blob(image: PIL.Image.Image) -> protos.Blob | None: - if not isinstance(image, PIL.ImageFile.ImageFile) or image.filename is None: - return None - filename = str(image.filename) - if not pathlib.Path(filename).is_file(): - return None - - mime_type = image.get_format_mimetype() - image_bytes = pathlib.Path(filename).read_bytes() - - return protos.Blob(mime_type=mime_type, data=image_bytes) - - def webp_blob(image: PIL.Image.Image) -> protos.Blob: - # Reference: https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html#webp - image_io = io.BytesIO() - image.save(image_io, format="webp", lossless=True) - image_io.seek(0) - - mime_type = "image/webp" - image_bytes = image_io.read() - - return protos.Blob(mime_type=mime_type, data=image_bytes) - - return file_blob(image) or webp_blob(image) - - -def image_to_blob(image: ImageType) -> protos.Blob: - if PIL is not None: - if isinstance(image, PIL.Image.Image): - return _pil_to_blob(image) - - if IPython is not None: - if isinstance(image, IPython.display.Image): - name = image.filename - if name is None: - raise ValueError( - "Conversion failed. The `IPython.display.Image` can only be converted if " - "it is constructed from a local file. Please ensure you are using the format: Image(filename='...')." - ) - mime_type, _ = mimetypes.guess_type(name) - if mime_type is None: - mime_type = "image/unknown" - - return protos.Blob(mime_type=mime_type, data=image.data) - - raise TypeError( - "Image conversion failed. The input was expected to be of type `Image` " - "(either `PIL.Image.Image` or `IPython.display.Image`).\n" - f"However, received an object of type: {type(image)}.\n" - f"Object Value: {image}" - ) - - class BlobDict(TypedDict): mime_type: str data: bytes @@ -189,12 +104,7 @@ def is_blob_dict(d): return "mime_type" in d and "data" in d -if typing.TYPE_CHECKING: - BlobType = Union[ - protos.Blob, BlobDict, PIL.Image.Image, IPython.display.Image - ] # Any for the images -else: - BlobType = Union[protos.Blob, BlobDict, Any] +BlobType = Union[protos.Blob, BlobDict, _image_types.ImageType] # Any for the images def to_blob(blob: BlobType) -> protos.Blob: @@ -203,8 +113,8 @@ def to_blob(blob: BlobType) -> protos.Blob: if isinstance(blob, protos.Blob): return blob - elif isinstance(blob, IMAGE_TYPES): - return image_to_blob(blob) + elif isinstance(blob, _image_types.IMAGE_TYPES): + return _image_types.image_to_blob(blob) else: if isinstance(blob, Mapping): raise KeyError( diff --git a/google/generativeai/vision_models/__init__.py b/google/generativeai/vision_models/__init__.py index 2a4a27e32..e1b62d39b 100644 --- a/google/generativeai/vision_models/__init__.py +++ b/google/generativeai/vision_models/__init__.py @@ -14,15 +14,15 @@ # """Classes for working with vision models.""" +from google.generativeai.types.image_types import check_watermark, Image, GeneratedImage + from google.generativeai.vision_models._vision_models import ( - check_watermark, - Image, - GeneratedImage, ImageGenerationModel, ImageGenerationResponse, ) __all__ = [ + "check_watermark", "Image", "GeneratedImage", "ImageGenerationModel", diff --git a/google/generativeai/vision_models/_vision_models.py b/google/generativeai/vision_models/_vision_models.py index 52ec689a9..69e44ab5f 100644 --- a/google/generativeai/vision_models/_vision_models.py +++ b/google/generativeai/vision_models/_vision_models.py @@ -26,16 +26,12 @@ from typing import Any, Dict, List, Literal, Optional, Union from google.generativeai import client +from google.generativeai.types import image_types from google.generativeai import protos from google.generativeai.types import content_types -from google.protobuf import struct_pb2 -from proto.marshal.collections import maps -from proto.marshal.collections import repeated - - -# pylint: disable=g-import-not-at-top\ +# pylint: disable=g-import-not-at-top if typing.TYPE_CHECKING: from IPython import display as IPython_display else: @@ -53,52 +49,6 @@ PIL_Image = None -# This is to get around https://github.com/googleapis/proto-plus-python/issues/488 -def to_value(value) -> struct_pb2.Value: - """Return a protobuf Value object representing this value.""" - if isinstance(value, struct_pb2.Value): - return value - if value is None: - return struct_pb2.Value(null_value=0) - if isinstance(value, bool): - return struct_pb2.Value(bool_value=value) - if isinstance(value, (int, float)): - return struct_pb2.Value(number_value=float(value)) - if isinstance(value, str): - return struct_pb2.Value(string_value=value) - if isinstance(value, collections.abc.Sequence): - return struct_pb2.Value(list_value=to_list_value(value)) - if isinstance(value, collections.abc.Mapping): - return struct_pb2.Value(struct_value=to_mapping_value(value)) - raise ValueError("Unable to coerce value: %r" % value) - - -def to_list_value(value) -> struct_pb2.ListValue: - # We got a proto, or else something we sent originally. - # Preserve the instance we have. - if isinstance(value, struct_pb2.ListValue): - return value - if isinstance(value, repeated.RepeatedComposite): - return struct_pb2.ListValue(values=[v for v in value.pb]) - - # We got a list (or something list-like); convert it. - return struct_pb2.ListValue(values=[to_value(v) for v in value]) - - -def to_mapping_value(value) -> struct_pb2.Struct: - # We got a proto, or else something we sent originally. - # Preserve the instance we have. - if isinstance(value, struct_pb2.Struct): - return value - if isinstance(value, maps.MapComposite): - return struct_pb2.Struct( - fields={k: v for k, v in value.pb.items()}, - ) - - # We got a dict (or something dict-like); convert it. - return struct_pb2.Struct(fields={k: to_value(v) for k, v in value.items()}) - - AspectRatio = Literal["1:1", "9:16", "16:9", "4:3", "3:4"] ASPECT_RATIOS = AspectRatio.__args__ # type: ignore @@ -111,171 +61,6 @@ def to_mapping_value(value) -> struct_pb2.Struct: PersonGeneration = Literal["dont_allow", "allow_adult"] PERSON_GENERATIONS = PersonGeneration.__args__ # type: ignore -ImageLikeType = Union["Image", pathlib.Path, content_types.ImageType] - - -def check_watermark( - img: ImageLikeType, model_id: str = "models/image-verification-001" -) -> "CheckWatermarkResult": - """Checks if an image has a Google-AI watermark. - - Args: - img: can be a `pathlib.Path` or a `PIL.Image.Image`, `IPython.display.Image`, or `google.generativeai.Image`. - model_id: Which version of the image-verification model to send the image to. - - Returns: - - """ - if isinstance(img, Image): - pass - elif isinstance(img, pathlib.Path): - img = Image.load_from_file(img) - elif IPython_display is not None and isinstance(img, IPython_display.Image): - img = Image(image_bytes=img.data) - elif PIL_Image is not None and isinstance(img, PIL_Image.Image): - blob = content_types._pil_to_blob(img) - img = Image(image_bytes=blob.data) - elif isinstance(img, protos.Blob): - img = Image(image_bytes=img.data) - else: - raise TypeError( - f"Not implemented: Could not convert a {type(img)} into `Image`\n {img=}" - ) - - prediction_client = client.get_default_prediction_client() - if not model_id.startswith("models/"): - model_id = f"models/{model_id}" - - instance = {"image": {"bytesBase64Encoded": base64.b64encode(img._loaded_bytes).decode()}} - parameters = {"watermarkVerification": True} - - # This is to get around https://github.com/googleapis/proto-plus-python/issues/488 - pr = protos.PredictRequest.pb() - request = pr(model=model_id, instances=[to_value(instance)], parameters=to_value(parameters)) - - response = prediction_client.predict(request) - - return CheckWatermarkResult(response.predictions) - - -class Image: - """Image.""" - - __module__ = "vertexai.vision_models" - - _loaded_bytes: Optional[bytes] = None - _loaded_image: Optional["PIL_Image.Image"] = None - - def __init__( - self, - image_bytes: Optional[bytes], - ): - """Creates an `Image` object. - - Args: - image_bytes: Image file bytes. Image can be in PNG or JPEG format. - """ - self._image_bytes = image_bytes - - @staticmethod - def load_from_file(location: os.PathLike) -> "Image": - """Loads image from local file or Google Cloud Storage. - - Args: - location: Local path or Google Cloud Storage uri from where to load - the image. - - Returns: - Loaded image as an `Image` object. - """ - # Load image from local path - image_bytes = pathlib.Path(location).read_bytes() - image = Image(image_bytes=image_bytes) - return image - - @property - def _image_bytes(self) -> bytes: - return self._loaded_bytes - - @_image_bytes.setter - def _image_bytes(self, value: bytes): - self._loaded_bytes = value - - @property - def _pil_image(self) -> "PIL_Image.Image": # type: ignore - if self._loaded_image is None: - if not PIL_Image: - raise RuntimeError( - "The PIL module is not available. Please install the Pillow package." - ) - self._loaded_image = PIL_Image.open(io.BytesIO(self._image_bytes)) - return self._loaded_image - - @property - def _size(self): - return self._pil_image.size - - @property - def _mime_type(self) -> str: - """Returns the MIME type of the image.""" - if PIL_Image: - return PIL_Image.MIME.get(self._pil_image.format, "image/jpeg") - # Fall back to jpeg - return "image/jpeg" - - def show(self): - """Shows the image. - - This method only works when in a notebook environment. - """ - if PIL_Image and IPython_display: - IPython_display.display(self._pil_image) - - def save(self, location: str): - """Saves image to a file. - - Args: - location: Local path where to save the image. - """ - pathlib.Path(location).write_bytes(self._image_bytes) - - def _as_base64_string(self) -> str: - """Encodes image using the base64 encoding. - - Returns: - Base64 encoding of the image as a string. - """ - # ! b64encode returns `bytes` object, not `str`. - # We need to convert `bytes` to `str`, otherwise we get service error: - # "received initial metadata size exceeds limit" - return base64.b64encode(self._image_bytes).decode("ascii") - - def _repr_png_(self): - return self._pil_image._repr_png_() # type:ignore - - check_watermark = check_watermark - - -class CheckWatermarkResult: - def __init__(self, predictions): - self._predictions = predictions - - @property - def decision(self): - return self._predictions[0]["decision"] - - def __str__(self): - return f"CheckWatermarkResult([{{'decision': {self.decision!r}}}])" - - def __bool__(self): - decision = self.decision - if decision == "ACCEPT": - return True - elif decision == "REJECT": - return False - else: - raise ValueError(f"Unrecognized result: {decision}") - class ImageGenerationModel: """Generates images from text prompt. @@ -417,20 +202,16 @@ def _generate_images( parameters["personGeneration"] = person_generation shared_generation_parameters["person_generation"] = person_generation - # This is to get around https://github.com/googleapis/proto-plus-python/issues/488 - pr = protos.PredictRequest.pb() - request = pr( - model=self.model_name, instances=[to_value(instance)], parameters=to_value(parameters) + response = self._client.predict( + model=self.model_name, instances=[instance], parameters=parameters ) - response = self._client.predict(request) - generated_images: List["GeneratedImage"] = [] for idx, prediction in enumerate(response.predictions): generation_parameters = dict(shared_generation_parameters) generation_parameters["index_of_image_in_batch"] = idx encoded_bytes = prediction.get("bytesBase64Encoded") - generated_image = GeneratedImage( + generated_image = image_types.GeneratedImage( image_bytes=base64.b64decode(encoded_bytes) if encoded_bytes else None, generation_parameters=generation_parameters, ) @@ -517,74 +298,3 @@ def __getitem__(self, idx: int) -> "GeneratedImage": """Gets the generated image by index.""" return self.images[idx] - -_EXIF_USER_COMMENT_TAG_IDX = 0x9286 -_IMAGE_GENERATION_PARAMETERS_EXIF_KEY = ( - "google.cloud.vertexai.image_generation.image_generation_parameters" -) - - -class GeneratedImage(Image): - """Generated image.""" - - __module__ = "google.generativeai" - - def __init__( - self, - image_bytes: Optional[bytes], - generation_parameters: Dict[str, Any], - ): - """Creates a `GeneratedImage` object. - - Args: - image_bytes: Image file bytes. Image can be in PNG or JPEG format. - generation_parameters: Image generation parameter values. - """ - super().__init__(image_bytes=image_bytes) - self._generation_parameters = generation_parameters - - @property - def generation_parameters(self): - """Image generation parameters as a dictionary.""" - return self._generation_parameters - - @staticmethod - def load_from_file(location: os.PathLike) -> "GeneratedImage": - """Loads image from file. - - Args: - location: Local path from where to load the image. - - Returns: - Loaded image as a `GeneratedImage` object. - """ - base_image = Image.load_from_file(location=location) - exif = base_image._pil_image.getexif() # pylint: disable=protected-access - exif_comment_dict = json.loads(exif[_EXIF_USER_COMMENT_TAG_IDX]) - generation_parameters = exif_comment_dict[_IMAGE_GENERATION_PARAMETERS_EXIF_KEY] - return GeneratedImage( - image_bytes=base_image._image_bytes, # pylint: disable=protected-access - generation_parameters=generation_parameters, - ) - - def save(self, location: str, include_generation_parameters: bool = True): - """Saves image to a file. - - Args: - location: Local path where to save the image. - include_generation_parameters: Whether to include the image - generation parameters in the image's EXIF metadata. - """ - if include_generation_parameters: - if not self._generation_parameters: - raise ValueError("Image does not have generation parameters.") - if not PIL_Image: - raise ValueError("The PIL module is required for saving generation parameters.") - - exif = self._pil_image.getexif() - exif[_EXIF_USER_COMMENT_TAG_IDX] = json.dumps( - {_IMAGE_GENERATION_PARAMETERS_EXIF_KEY: self._generation_parameters} - ) - self._pil_image.save(location, exif=exif) - else: - super().save(location=location) diff --git a/tests/test_content.py b/tests/test_content.py index 2031e40ae..b4a375f8b 100644 --- a/tests/test_content.py +++ b/tests/test_content.py @@ -22,6 +22,8 @@ from absl.testing import parameterized from google.generativeai import protos from google.generativeai.types import content_types +from google.generativeai.types import image_types +from google.generativeai.types.image_types import _image_types import IPython.display import PIL.Image @@ -90,7 +92,7 @@ class UnitTests(parameterized.TestCase): ["P", PIL.Image.fromarray(np.zeros([6, 6, 3], dtype=np.uint8)).convert("P")], ) def test_numpy_to_blob(self, image): - blob = content_types.image_to_blob(image) + blob = _image_types.image_to_blob(image) self.assertIsInstance(blob, protos.Blob) self.assertEqual(blob.mime_type, "image/webp") self.assertStartsWith(blob.data, b"RIFF \x00\x00\x00WEBPVP8L") @@ -98,9 +100,10 @@ def test_numpy_to_blob(self, image): @parameterized.named_parameters( ["PIL", PIL.Image.open(TEST_PNG_PATH)], ["IPython", IPython.display.Image(filename=TEST_PNG_PATH)], + ["image_types.Image", image_types.Image.load_from_file(TEST_PNG_PATH)] ) def test_png_to_blob(self, image): - blob = content_types.image_to_blob(image) + blob = _image_types.image_to_blob(image) self.assertIsInstance(blob, protos.Blob) self.assertEqual(blob.mime_type, "image/png") self.assertStartsWith(blob.data, b"\x89PNG") @@ -108,9 +111,10 @@ def test_png_to_blob(self, image): @parameterized.named_parameters( ["PIL", PIL.Image.open(TEST_JPG_PATH)], ["IPython", IPython.display.Image(filename=TEST_JPG_PATH)], + ["image_types.Image", image_types.Image.load_from_file(TEST_JPG_PATH)] ) def test_jpg_to_blob(self, image): - blob = content_types.image_to_blob(image) + blob = _image_types.image_to_blob(image) self.assertIsInstance(blob, protos.Blob) self.assertEqual(blob.mime_type, "image/jpeg") self.assertStartsWith(blob.data, b"\xff\xd8\xff\xe0\x00\x10JFIF") @@ -118,9 +122,10 @@ def test_jpg_to_blob(self, image): @parameterized.named_parameters( ["PIL", PIL.Image.open(TEST_GIF_PATH)], ["IPython", IPython.display.Image(filename=TEST_GIF_PATH)], + ["image_types.Image", image_types.Image.load_from_file(TEST_GIF_PATH)] ) def test_gif_to_blob(self, image): - blob = content_types.image_to_blob(image) + blob = _image_types.image_to_blob(image) self.assertIsInstance(blob, protos.Blob) self.assertEqual(blob.mime_type, "image/gif") self.assertStartsWith(blob.data, b"GIF87a") From c0a5c285c07fd5f2c1a48b8c86594eab87e5e41f Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Thu, 17 Oct 2024 14:20:03 -0700 Subject: [PATCH 02/12] types + formatting Change-Id: I0cac4ba1de764d3c02c5eab7556d8324aeda1f93 --- .../vision_models/_vision_models.py | 37 +++---------------- tests/test_content.py | 6 +-- 2 files changed, 8 insertions(+), 35 deletions(-) diff --git a/google/generativeai/vision_models/_vision_models.py b/google/generativeai/vision_models/_vision_models.py index 69e44ab5f..f89ab86e6 100644 --- a/google/generativeai/vision_models/_vision_models.py +++ b/google/generativeai/vision_models/_vision_models.py @@ -16,38 +16,12 @@ """Classes for working with vision models.""" import base64 -import collections import dataclasses -import io -import json -import os -import pathlib import typing -from typing import Any, Dict, List, Literal, Optional, Union +from typing import List, Literal, Optional from google.generativeai import client from google.generativeai.types import image_types -from google.generativeai import protos -from google.generativeai.types import content_types - - -# pylint: disable=g-import-not-at-top -if typing.TYPE_CHECKING: - from IPython import display as IPython_display -else: - try: - from IPython import display as IPython_display - except ImportError: - IPython_display = None - -if typing.TYPE_CHECKING: - import PIL.Image as PIL_Image -else: - try: - from PIL import Image as PIL_Image - except ImportError: - PIL_Image = None - AspectRatio = Literal["1:1", "9:16", "16:9", "4:3", "3:4"] ASPECT_RATIOS = AspectRatio.__args__ # type: ignore @@ -206,7 +180,7 @@ def _generate_images( model=self.model_name, instances=[instance], parameters=parameters ) - generated_images: List["GeneratedImage"] = [] + generated_images: List[image_types.GeneratedImage] = [] for idx, prediction in enumerate(response.predictions): generation_parameters = dict(shared_generation_parameters) generation_parameters["index_of_image_in_batch"] = idx @@ -288,13 +262,12 @@ class ImageGenerationResponse: __module__ = "vertexai.preview.vision_models" - images: List["GeneratedImage"] + images: List[image_types.GeneratedImage] - def __iter__(self) -> typing.Iterator["GeneratedImage"]: + def __iter__(self) -> typing.Iterator[image_types.GeneratedImage]: """Iterates through the generated images.""" yield from self.images - def __getitem__(self, idx: int) -> "GeneratedImage": + def __getitem__(self, idx: int) -> image_types.GeneratedImage: """Gets the generated image by index.""" return self.images[idx] - diff --git a/tests/test_content.py b/tests/test_content.py index b4a375f8b..8bec14a9c 100644 --- a/tests/test_content.py +++ b/tests/test_content.py @@ -100,7 +100,7 @@ def test_numpy_to_blob(self, image): @parameterized.named_parameters( ["PIL", PIL.Image.open(TEST_PNG_PATH)], ["IPython", IPython.display.Image(filename=TEST_PNG_PATH)], - ["image_types.Image", image_types.Image.load_from_file(TEST_PNG_PATH)] + ["image_types.Image", image_types.Image.load_from_file(TEST_PNG_PATH)], ) def test_png_to_blob(self, image): blob = _image_types.image_to_blob(image) @@ -111,7 +111,7 @@ def test_png_to_blob(self, image): @parameterized.named_parameters( ["PIL", PIL.Image.open(TEST_JPG_PATH)], ["IPython", IPython.display.Image(filename=TEST_JPG_PATH)], - ["image_types.Image", image_types.Image.load_from_file(TEST_JPG_PATH)] + ["image_types.Image", image_types.Image.load_from_file(TEST_JPG_PATH)], ) def test_jpg_to_blob(self, image): blob = _image_types.image_to_blob(image) @@ -122,7 +122,7 @@ def test_jpg_to_blob(self, image): @parameterized.named_parameters( ["PIL", PIL.Image.open(TEST_GIF_PATH)], ["IPython", IPython.display.Image(filename=TEST_GIF_PATH)], - ["image_types.Image", image_types.Image.load_from_file(TEST_GIF_PATH)] + ["image_types.Image", image_types.Image.load_from_file(TEST_GIF_PATH)], ) def test_gif_to_blob(self, image): blob = _image_types.image_to_blob(image) From 7207eebe5e3ef4b96e293bfbbfdc5834c3e2b7ec Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Thu, 24 Oct 2024 11:12:21 -0700 Subject: [PATCH 03/12] add files Change-Id: Ie7f91cef171c1f813b52ff1b2a4daedf7ea19edd --- .../types/image_types/__init__.py | 1 + .../types/image_types/_image_types.py | 332 ++++++++++++++++++ 2 files changed, 333 insertions(+) create mode 100644 google/generativeai/types/image_types/__init__.py create mode 100644 google/generativeai/types/image_types/_image_types.py diff --git a/google/generativeai/types/image_types/__init__.py b/google/generativeai/types/image_types/__init__.py new file mode 100644 index 000000000..6e9d0a3fe --- /dev/null +++ b/google/generativeai/types/image_types/__init__.py @@ -0,0 +1 @@ +from google.generativeai.types.image_types._image_types import * diff --git a/google/generativeai/types/image_types/_image_types.py b/google/generativeai/types/image_types/_image_types.py new file mode 100644 index 000000000..ba76a2632 --- /dev/null +++ b/google/generativeai/types/image_types/_image_types.py @@ -0,0 +1,332 @@ +import base64 +import io +import json +import mimetypes +import os +import pathlib +import typing +from typing import Any, Dict, Optional, Union + +from google.generativeai import protos +from google.generativeai import client + +# pylint: disable=g-import-not-at-top +if typing.TYPE_CHECKING: + import PIL.Image + import PIL.ImageFile + import IPython.display + + IMAGE_TYPES = (PIL.Image.Image, IPython.display.Image) + ImageType = PIL.Image.Image | IPython.display.Image +else: + IMAGE_TYPES = () + try: + import PIL.Image + import PIL.ImageFile + + IMAGE_TYPES = IMAGE_TYPES + (PIL.Image.Image,) + except ImportError: + PIL = None + + try: + import IPython.display + + IMAGE_TYPES = IMAGE_TYPES + (IPython.display.Image,) + except ImportError: + IPython = None + + ImageType = Union["Image", "PIL.Image.Image", "IPython.display.Image"] +# pylint: enable=g-import-not-at-top + +__all__ = ["Image", "GeneratedImage", "check_watermark", "CheckWatermarkResult", "ImageType"] + + +def _pil_to_blob(image: PIL.Image.Image) -> protos.Blob: + # If the image is a local file, return a file-based blob without any modification. + # Otherwise, return a lossless WebP blob (same quality with optimized size). + def file_blob(image: PIL.Image.Image) -> protos.Blob | None: + if not isinstance(image, PIL.ImageFile.ImageFile) or image.filename is None: + return None + filename = str(image.filename) + if not pathlib.Path(filename).is_file(): + return None + + mime_type = image.get_format_mimetype() + image_bytes = pathlib.Path(filename).read_bytes() + + return protos.Blob(mime_type=mime_type, data=image_bytes) + + def webp_blob(image: PIL.Image.Image) -> protos.Blob: + # Reference: https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html#webp + image_io = io.BytesIO() + image.save(image_io, format="webp", lossless=True) + image_io.seek(0) + + mime_type = "image/webp" + image_bytes = image_io.read() + + return protos.Blob(mime_type=mime_type, data=image_bytes) + + return file_blob(image) or webp_blob(image) + + +def image_to_blob(image: ImageType) -> protos.Blob: + if PIL is not None: + if isinstance(image, PIL.Image.Image): + return _pil_to_blob(image) + + if IPython is not None: + if isinstance(image, IPython.display.Image): + name = image.filename + if name is None: + raise ValueError( + "Conversion failed. The `IPython.display.Image` can only be converted if " + "it is constructed from a local file. Please ensure you are using the format: Image(filename='...')." + ) + mime_type, _ = mimetypes.guess_type(name) + if mime_type is None: + mime_type = "image/unknown" + + return protos.Blob(mime_type=mime_type, data=image.data) + + if isinstance(image, Image): + return protos.Blob(mime_type=image._mime_type, data=image._image_bytes) + + raise TypeError( + "Image conversion failed. The input was expected to be of type `Image` " + "(either `PIL.Image.Image` or `IPython.display.Image`).\n" + f"However, received an object of type: {type(image)}.\n" + f"Object Value: {image}" + ) + + +class CheckWatermarkResult: + def __init__(self, predictions): + self._predictions = predictions + + @property + def decision(self): + return self._predictions[0]["decision"] + + def __str__(self): + return f"CheckWatermarkResult([{{'decision': {self.decision!r}}}])" + + def __bool__(self): + decision = self.decision + if decision == "ACCEPT": + return True + elif decision == "REJECT": + return False + else: + raise ValueError("Unrecognized result") + + +def check_watermark( + img: pathlib.Path | ImageType, model_id: str = "models/image-verification-001" +) -> "CheckWatermarkResult": + """Checks if an image has a Google-AI watermark. + + Args: + img: can be a `pathlib.Path` or a `PIL.Image.Image`, `IPythin.display.Image`, or `google.generativeai.Image`. + model_id: Which version of the image-verification model to send the image to. + + Returns: + + """ + if isinstance(img, Image): + pass + elif isinstance(img, pathlib.Path): + img = Image.load_from_file(img) + elif IPython.display is not None and isinstance(img, IPython.display.Image): + img = Image(image_bytes=img.data) + elif PIL.Image is not None and isinstance(img, PIL.Image.Image): + blob = _pil_to_blob(img) + img = Image(image_bytes=blob.data) + elif isinstance(img, protos.Blob): + img = Image(image_bytes=img.data) + else: + raise TypeError( + f"Not implemented: Could not convert a {type(img)} into `Image`\n {img=}" + ) + + prediction_client = client.get_default_prediction_client() + if not model_id.startswith("models/"): + model_id = f"models/{model_id}" + + instance = {"image": {"bytesBase64Encoded": base64.b64encode(img._loaded_bytes).decode()}} + parameters = {"watermarkVerification": True} + + response = prediction_client.predict( + model=model_id, instances=[instance], parameters=parameters + ) + + return CheckWatermarkResult(response.predictions) + + +class Image: + """Image.""" + + __module__ = "vertexai.vision_models" + + _loaded_bytes: Optional[bytes] = None + _loaded_image: Optional["PIL_Image.Image"] = None + + def __init__( + self, + image_bytes: Optional[bytes], + ): + """Creates an `Image` object. + + Args: + image_bytes: Image file bytes. Image can be in PNG or JPEG format. + """ + self._image_bytes = image_bytes + + @staticmethod + def load_from_file(location: os.PathLike) -> "Image": + """Loads image from local file or Google Cloud Storage. + + Args: + location: Local path or Google Cloud Storage uri from where to load + the image. + + Returns: + Loaded image as an `Image` object. + """ + # Load image from local path + image_bytes = pathlib.Path(location).read_bytes() + image = Image(image_bytes=image_bytes) + return image + + @property + def _image_bytes(self) -> bytes: + return self._loaded_bytes + + @_image_bytes.setter + def _image_bytes(self, value: bytes): + self._loaded_bytes = value + + @property + def _pil_image(self) -> "PIL_Image.Image": # type: ignore + if self._loaded_image is None: + if not PIL: + raise RuntimeError( + "The PIL module is not available. Please install the Pillow package." + ) + self._loaded_image = PIL.Image.open(io.BytesIO(self._image_bytes)) + return self._loaded_image + + @property + def _size(self): + return self._pil_image.size + + @property + def _mime_type(self) -> str: + """Returns the MIME type of the image.""" + import PIL + + return PIL.Image.MIME.get(self._pil_image.format, "image/jpeg") + + def show(self): + """Shows the image. + + This method only works when in a notebook environment. + """ + if PIL and IPython: + IPython.display.display(self._pil_image) + + def save(self, location: str): + """Saves image to a file. + + Args: + location: Local path where to save the image. + """ + pathlib.Path(location).write_bytes(self._image_bytes) + + def _as_base64_string(self) -> str: + """Encodes image using the base64 encoding. + + Returns: + Base64 encoding of the image as a string. + """ + # ! b64encode returns `bytes` object, not `str`. + # We need to convert `bytes` to `str`, otherwise we get service error: + # "received initial metadata size exceeds limit" + return base64.b64encode(self._image_bytes).decode("ascii") + + def _repr_png_(self): + return self._pil_image._repr_png_() # type:ignore + + check_watermark = check_watermark + + +_EXIF_USER_COMMENT_TAG_IDX = 0x9286 +_IMAGE_GENERATION_PARAMETERS_EXIF_KEY = ( + "google.cloud.vertexai.image_generation.image_generation_parameters" +) + + +class GeneratedImage(Image): + """Generated image.""" + + __module__ = "google.generativeai" + + def __init__( + self, + image_bytes: Optional[bytes], + generation_parameters: Dict[str, Any], + ): + """Creates a `GeneratedImage` object. + + Args: + image_bytes: Image file bytes. Image can be in PNG or JPEG format. + generation_parameters: Image generation parameter values. + """ + super().__init__(image_bytes=image_bytes) + self._generation_parameters = generation_parameters + + @property + def generation_parameters(self): + """Image generation parameters as a dictionary.""" + return self._generation_parameters + + @staticmethod + def load_from_file(location: os.PathLike) -> "GeneratedImage": + """Loads image from file. + + Args: + location: Local path from where to load the image. + + Returns: + Loaded image as a `GeneratedImage` object. + """ + base_image = Image.load_from_file(location=location) + exif = base_image._pil_image.getexif() # pylint: disable=protected-access + exif_comment_dict = json.loads(exif[_EXIF_USER_COMMENT_TAG_IDX]) + generation_parameters = exif_comment_dict[_IMAGE_GENERATION_PARAMETERS_EXIF_KEY] + return GeneratedImage( + image_bytes=base_image._image_bytes, # pylint: disable=protected-access + generation_parameters=generation_parameters, + ) + + def save(self, location: str, include_generation_parameters: bool = True): + """Saves image to a file. + + Args: + location: Local path where to save the image. + include_generation_parameters: Whether to include the image + generation parameters in the image's EXIF metadata. + """ + if include_generation_parameters: + if not self._generation_parameters: + raise ValueError("Image does not have generation parameters.") + if not PIL: + raise ValueError("The PIL module is required for saving generation parameters.") + + exif = self._pil_image.getexif() + exif[_EXIF_USER_COMMENT_TAG_IDX] = json.dumps( + {_IMAGE_GENERATION_PARAMETERS_EXIF_KEY: self._generation_parameters} + ) + self._pil_image.save(location, exif=exif) + else: + super().save(location=location) From 55abd7c0ce5231940c7a4b72600e5d400d137a50 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Thu, 24 Oct 2024 11:14:37 -0700 Subject: [PATCH 04/12] Fix 3.9 Change-Id: If9ff9ebc0b2bf16b91e741d862a9e2808c7a738a --- google/generativeai/types/image_types/_image_types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/generativeai/types/image_types/_image_types.py b/google/generativeai/types/image_types/_image_types.py index ba76a2632..914d015eb 100644 --- a/google/generativeai/types/image_types/_image_types.py +++ b/google/generativeai/types/image_types/_image_types.py @@ -122,7 +122,7 @@ def __bool__(self): def check_watermark( - img: pathlib.Path | ImageType, model_id: str = "models/image-verification-001" + img: Union[pathlib.Path, ImageType], model_id: str = "models/image-verification-001" ) -> "CheckWatermarkResult": """Checks if an image has a Google-AI watermark. From 7b4d0de21abca717acd73b3084d3d05bf8a7cc93 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Thu, 24 Oct 2024 11:16:26 -0700 Subject: [PATCH 05/12] Fix 3.9 Change-Id: Iee02352ca21fa66da9b097d4dfa9454b67609e79 --- google/generativeai/types/image_types/_image_types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/generativeai/types/image_types/_image_types.py b/google/generativeai/types/image_types/_image_types.py index 914d015eb..170821353 100644 --- a/google/generativeai/types/image_types/_image_types.py +++ b/google/generativeai/types/image_types/_image_types.py @@ -44,7 +44,7 @@ def _pil_to_blob(image: PIL.Image.Image) -> protos.Blob: # If the image is a local file, return a file-based blob without any modification. # Otherwise, return a lossless WebP blob (same quality with optimized size). - def file_blob(image: PIL.Image.Image) -> protos.Blob | None: + def file_blob(image: PIL.Image.Image) -> Union[protos.Blob, None]: if not isinstance(image, PIL.ImageFile.ImageFile) or image.filename is None: return None filename = str(image.filename) From 0f5e83da8091fccd883d6ee53328dcac139da918 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Thu, 24 Oct 2024 11:23:53 -0700 Subject: [PATCH 06/12] fix pytype Change-Id: Ic5c250f3f3ded2374abfbdbee6d62ea4cfb0f799 --- google/generativeai/client.py | 1 + 1 file changed, 1 insertion(+) diff --git a/google/generativeai/client.py b/google/generativeai/client.py index 69492c1cd..7e6099bf3 100644 --- a/google/generativeai/client.py +++ b/google/generativeai/client.py @@ -3,6 +3,7 @@ import os import contextlib import inspect +import collections import dataclasses import pathlib from typing import Any, cast From 1fa954bb59c7e922ecb985a0a286790cdd3d69f1 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Thu, 24 Oct 2024 11:40:21 -0700 Subject: [PATCH 07/12] fix pytype Change-Id: I431c66e45e7582218b5de7a90eeeee01b80df664 --- google/generativeai/types/image_types/_image_types.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/google/generativeai/types/image_types/_image_types.py b/google/generativeai/types/image_types/_image_types.py index 170821353..9f1c30257 100644 --- a/google/generativeai/types/image_types/_image_types.py +++ b/google/generativeai/types/image_types/_image_types.py @@ -169,7 +169,7 @@ class Image: __module__ = "vertexai.vision_models" _loaded_bytes: Optional[bytes] = None - _loaded_image: Optional["PIL_Image.Image"] = None + _loaded_image: Optional["PIL.Image.Image"] = None def __init__( self, @@ -207,7 +207,7 @@ def _image_bytes(self, value: bytes): self._loaded_bytes = value @property - def _pil_image(self) -> "PIL_Image.Image": # type: ignore + def _pil_image(self) -> "PIL.Image.Image": # type: ignore if self._loaded_image is None: if not PIL: raise RuntimeError( From 584b2ef0765a3dc6c8631455dfeb16db4d4bacd0 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Thu, 24 Oct 2024 12:05:35 -0700 Subject: [PATCH 08/12] typo Change-Id: I1bb15e1363c652f9c0b4a60dad834fce65a4f0a1 --- google/generativeai/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/generativeai/client.py b/google/generativeai/client.py index 7e6099bf3..f0dd2dedd 100644 --- a/google/generativeai/client.py +++ b/google/generativeai/client.py @@ -299,7 +299,7 @@ def make_client(self, name): "prediction": PredictionServiceClient, "prediction_async": PredictionServiceAsyncClient, } - cls = local_clients.get("name", None) + cls = local_clients.get(name, None) if cls is None: if name.endswith("_async"): From be006e7e7dc4f4c01c4b0538367fe9fa8ac2c84f Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Tue, 29 Oct 2024 11:49:42 -0700 Subject: [PATCH 09/12] reapply commits lost in merge Change-Id: I7bfebdeaa217d93ed5d11aca31cf0b20afd38c02 --- google/generativeai/types/image_types/_image_types.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/google/generativeai/types/image_types/_image_types.py b/google/generativeai/types/image_types/_image_types.py index 9f1c30257..d8935a507 100644 --- a/google/generativeai/types/image_types/_image_types.py +++ b/google/generativeai/types/image_types/_image_types.py @@ -118,7 +118,7 @@ def __bool__(self): elif decision == "REJECT": return False else: - raise ValueError("Unrecognized result") + raise ValueError(f"Unrecognized result: {decision}") def check_watermark( @@ -127,7 +127,7 @@ def check_watermark( """Checks if an image has a Google-AI watermark. Args: - img: can be a `pathlib.Path` or a `PIL.Image.Image`, `IPythin.display.Image`, or `google.generativeai.Image`. + img: can be a `pathlib.Path` or a `PIL.Image.Image`, `IPython.display.Image`, or `google.generativeai.Image`. model_id: Which version of the image-verification model to send the image to. Returns: From 650d74e7f13dc8a6262ec11154e9703907bee064 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Tue, 29 Oct 2024 15:12:15 -0700 Subject: [PATCH 10/12] Update google/generativeai/client.py --- google/generativeai/client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/google/generativeai/client.py b/google/generativeai/client.py index f0dd2dedd..755fb3ff5 100644 --- a/google/generativeai/client.py +++ b/google/generativeai/client.py @@ -182,7 +182,6 @@ def to_mapping_value(value) -> struct_pb2.Struct: return struct_pb2.Struct(fields={k: to_value(v) for k, v in value.items()}) -# This is to get around https://github.com/googleapis/proto-plus-python/issues/488 class PredictionServiceClient(glm.PredictionServiceClient): From 17ceab589435210903cdd62c7ad6f1c50fab77c4 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Tue, 29 Oct 2024 15:35:06 -0700 Subject: [PATCH 11/12] Remove GCS reference Change-Id: I5c1b8cbccee0e13d8aca70582a76e0c089e040ed --- google/generativeai/types/image_types/_image_types.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/google/generativeai/types/image_types/_image_types.py b/google/generativeai/types/image_types/_image_types.py index d8935a507..ddfea057f 100644 --- a/google/generativeai/types/image_types/_image_types.py +++ b/google/generativeai/types/image_types/_image_types.py @@ -184,10 +184,10 @@ def __init__( @staticmethod def load_from_file(location: os.PathLike) -> "Image": - """Loads image from local file or Google Cloud Storage. + """Loads image from local file. Args: - location: Local path or Google Cloud Storage uri from where to load + location: Local path from where to load the image. Returns: From 613319f08368ce7c5dd115b04664f53a323bc957 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Tue, 29 Oct 2024 16:06:14 -0700 Subject: [PATCH 12/12] black . Change-Id: I2c24f8798cb8103d35474e7e6d2e4fc3100825aa --- google/generativeai/client.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/google/generativeai/client.py b/google/generativeai/client.py index 755fb3ff5..f53007be3 100644 --- a/google/generativeai/client.py +++ b/google/generativeai/client.py @@ -182,8 +182,6 @@ def to_mapping_value(value) -> struct_pb2.Struct: return struct_pb2.Struct(fields={k: to_value(v) for k, v in value.items()}) - - class PredictionServiceClient(glm.PredictionServiceClient): def predict(self, model=None, instances=None, parameters=None): pr = protos.PredictRequest.pb()