diff --git a/.gitignore b/.gitignore index f6e6bdba8..9d726916a 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,4 @@ *.egg-info .DS_Store __pycache__ -*.iml +*.iml \ No newline at end of file diff --git a/google/generativeai/client.py b/google/generativeai/client.py index 2a8b15a20..7d655a5b3 100644 --- a/google/generativeai/client.py +++ b/google/generativeai/client.py @@ -168,7 +168,6 @@ def get_default_operations_client(self) -> operations_v1.OperationsClient: model_client = self.get_default_client("Model") client = model_client._transport.operations_client self.clients["operations"] = client - return client @@ -244,3 +243,11 @@ def get_default_operations_client() -> operations_v1.OperationsClient: def get_default_model_client() -> glm.ModelServiceAsyncClient: return _client_manager.get_default_client("model") + + +def get_default_retriever_client() -> glm.RetrieverClient: + return _client_manager.get_default_client("retriever") + + +def get_default_retriever_async_client() -> glm.RetrieverAsyncClient: + return _client_manager.get_default_client("retriever_async") diff --git a/google/generativeai/models.py b/google/generativeai/models.py index 46a412f8e..b975df236 100644 --- a/google/generativeai/models.py +++ b/google/generativeai/models.py @@ -24,6 +24,7 @@ from google.api_core import operation from google.api_core import protobuf_helpers from google.protobuf import field_mask_pb2 +from google.generativeai.utils import flatten_update_paths def get_model( @@ -351,7 +352,7 @@ def update_tuned_model( ) tuned_model = client.get_tuned_model(name=name) - updates = _flatten_update_paths(updates) + updates = flatten_update_paths(updates) field_mask = field_mask_pb2.FieldMask() for path in updates.keys(): field_mask.paths.append(path) @@ -379,18 +380,6 @@ def update_tuned_model( return model_types.decode_tuned_model(result) -def _flatten_update_paths(updates): - new_updates = {} - for key, value in updates.items(): - if isinstance(value, dict): - for sub_key, sub_value in _flatten_update_paths(value).items(): - new_updates[f"{key}.{sub_key}"] = sub_value - else: - new_updates[key] = value - - return new_updates - - def _apply_update(thing, path, value): parts = path.split(".") for part in parts[:-1]: diff --git a/google/generativeai/retriever.py b/google/generativeai/retriever.py new file mode 100644 index 000000000..ece43e207 --- /dev/null +++ b/google/generativeai/retriever.py @@ -0,0 +1,224 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations + +import re +import string +import dataclasses +from typing import Optional + +import google.ai.generativelanguage as glm + +from google.generativeai.client import get_default_retriever_client +from google.generativeai.client import get_default_retriever_async_client +from google.generativeai import string_utils +from google.generativeai.types import retriever_types +from google.generativeai.types import model_types +from google.generativeai import models +from google.generativeai.types import safety_types +from google.generativeai.types.model_types import idecode_time + +_CORPORA_NAME_REGEX = re.compile(r"^corpora/[a-z0-9-]+") +_REMOVE = string.punctuation +_REMOVE = _REMOVE.replace("-", "") # Don't remove hyphens +_PATTERN = r"[{}]".format(_REMOVE) # Create the pattern + + +@string_utils.prettyprint +@dataclasses.dataclass(init=False) +class Corpus(retriever_types.Corpus): + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + self.result = None + if self.name: + self.result = self.name + + +def create_corpus( + name: Optional[str] = None, + display_name: Optional[str] = None, + client: glm.RetrieverServiceClient | None = None, +) -> Corpus: + """ + Create a Corpus object. Users can specify either a name or display_name. + + Args: + name: The corpus resource name (ID). The name must be alphanumeric and fewer + than 40 characters. + display_name: The human readable display name. The display name must be fewer + than 128 characters. All characters, including alphanumeric, spaces, and + dashes are supported. + + Return: + Corpus object with specified name or display name. + + Raises: + ValueError: When the name is not specified or formatted incorrectly. + """ + if client is None: + client = get_default_retriever_client() + + if not name and not display_name: + raise ValueError("Either the corpus name or display name must be specified.") + + corpus = None + if name: + if re.match(_CORPORA_NAME_REGEX, name): + corpus = glm.Corpus(name=name, display_name=display_name) + elif "corpora/" not in name: + corpus_name = "corpora/" + re.sub(_PATTERN, "", name) + corpus = glm.Corpus(name=corpus_name, display_name=display_name) + else: + raise ValueError("Corpus name must be formatted as corpora/.") + + request = glm.CreateCorpusRequest(corpus=corpus) + response = client.create_corpus(request) + response = type(response).to_dict(response) + idecode_time(response, "create_time") + idecode_time(response, "update_time") + response = Corpus(**response) + return response + + +async def create_corpus_async( + name: Optional[str] = None, + display_name: Optional[str] = None, + client: glm.RetrieverServiceAsyncClient | None = None, +) -> Corpus: + """This is the async version of `create_corpus`.""" + if client is None: + client = get_default_retriever_async_client() + + if not name and not display_name: + raise ValueError("Either the corpus name or display name must be specified.") + + corpus = None + if name: + if re.match(_CORPORA_NAME_REGEX, name): + corpus = glm.Corpus(name=name, display_name=display_name) + elif "corpora/" not in name: + corpus_name = "corpora/" + re.sub(_PATTERN, "", name) + corpus = glm.Corpus(name=corpus_name, display_name=display_name) + else: + raise ValueError("Corpus name must be formatted as corpora/.") + + request = glm.CreateCorpusRequest(corpus=corpus) + response = await client.create_corpus(request) + response = type(response).to_dict(response) + idecode_time(response, "create_time") + 
idecode_time(response, "update_time") + response = Corpus(**response) + return response + + +def get_corpus(name: str, client: glm.RetrieverServiceClient | None = None) -> Corpus: # fmt: skip + """ + Get information about a specific `Corpus`. + + Args: + name: The `Corpus` name. + + Return: + `Corpus` of interest. + """ + if client is None: + client = get_default_retriever_client() + + request = glm.GetCorpusRequest(name=name) + response = client.get_corpus(request) + response = type(response).to_dict(response) + idecode_time(response, "create_time") + idecode_time(response, "update_time") + response = Corpus(**response) + return response + + +async def get_corpus_async(name: str, client: glm.RetrieverServiceAsyncClient | None = None) -> Corpus: # fmt: skip + """This is the async version of `get_corpus`.""" + if client is None: + client = get_default_retriever_async_client() + + request = glm.GetCorpusRequest(name=name) + response = await client.get_corpus(request) + response = type(response).to_dict(response) + idecode_time(response, "create_time") + idecode_time(response, "update_time") + response = Corpus(**response) + return response + + +def delete_corpus(name: str, force: bool, client: glm.RetrieverServiceClient | None = None): # fmt: skip + """ + Delete a `Corpus`. + + Args: + name: The `Corpus` name. + force: If set to true, any `Document`s and objects related to this `Corpus` will also be deleted. + """ + if client is None: + client = get_default_retriever_client() + + request = glm.DeleteCorpusRequest(name=name, force=force) + client.delete_corpus(request) + + +async def delete_corpus_async(name: str, force: bool, client: glm.RetrieverServiceAsyncClient | None = None): # fmt: skip + """This is the async version of `delete_corpus`.""" + if client is None: + client = get_default_retriever_async_client() + + request = glm.DeleteCorpusRequest(name=name, force=force) + await client.delete_corpus(request) + + +def list_corpora( + *, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + client: glm.RetrieverServiceClient | None = None, +) -> list[Corpus]: + """ + List `Corpus`. + + Args: + page_size: Maximum number of `Corpora` to request. + page_token: A page token, received from a previous ListCorpora call. + + Return: + Paginated list of `Corpora`. + """ + if client is None: + client = get_default_retriever_client() + + request = glm.ListCorporaRequest(page_size=page_size, page_token=page_token) + response = client.list_corpora(request) + return response + + +async def list_corpora_async( + *, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + client: glm.RetrieverServiceClient | None = None, +) -> list[Corpus]: + """This is the async version of `list_corpora`.""" + if client is None: + client = get_default_retriever_async_client() + + request = glm.ListCorporaRequest(page_size=page_size, page_token=page_token) + response = await client.list_corpora(request) + return response diff --git a/google/generativeai/types/embedding_types.py b/google/generativeai/types/embedding_types.py new file mode 100644 index 000000000..86e996674 --- /dev/null +++ b/google/generativeai/types/embedding_types.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import abc +import dataclasses +from typing import Any, Dict, List, TypedDict + + +class EmbeddingDict(TypedDict): + embedding: list[float] + + +class BatchEmbeddingDict(TypedDict): + embedding: list[list[float]] diff --git a/google/generativeai/types/retriever_types.py b/google/generativeai/types/retriever_types.py new file mode 100644 index 000000000..19c8e8134 --- /dev/null +++ b/google/generativeai/types/retriever_types.py @@ -0,0 +1,1279 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import re +import string +import abc +import dataclasses +from typing import Any, Optional, Union, Iterable, Mapping + +import google.ai.generativelanguage as glm + +from google.protobuf import field_mask_pb2 +from google.generativeai.client import get_default_retriever_client +from google.generativeai.client import get_default_retriever_async_client +from google.generativeai import string_utils +from google.generativeai.types import safety_types +from google.generativeai.types import citation_types +from google.generativeai.types.model_types import idecode_time +from google.generativeai.utils import flatten_update_paths + + +_DOCUMENT_NAME_REGEX = re.compile(r"^corpora/[a-z0-9-]+/documents/[a-z0-9-]+$") +_CHUNK_NAME_REGEX = re.compile(r"^corpora/([^/]+?)(/documents/([^/]+?)(/chunks/([^/]+?))?)?$") +_REMOVE = string.punctuation +_REMOVE = _REMOVE.replace("-", "") # Don't remove hyphens +_PATTERN = r"[{}]".format(_REMOVE) # Create the pattern + +Operator = glm.Condition.Operator +State = glm.Chunk.State + +OperatorOptions = Union[str, int, Operator] +StateOptions = Union[str, int, State] + +CreateChunkOptions = Union[Mapping[str, str], tuple[str, str]] + +BatchCreateChunkOptions = Union[ + glm.BatchCreateChunksRequest, + list[glm.CreateChunkRequest], + Iterable[str], + Iterable[CreateChunkOptions], +] + +UpdateChunkOptions = Union[glm.UpdateChunkRequest, Mapping[str, Any], tuple[str, Any]] + +BatchUpdateChunksOptions = Union[glm.BatchUpdateChunksRequest, Iterable[UpdateChunkOptions]] + +BatchDeleteChunkOptions = Union[list[glm.DeleteChunkRequest], Iterable[str]] + +_OPERATOR: dict[OperatorOptions, Operator] = { + Operator.OPERATOR_UNSPECIFIED: Operator.OPERATOR_UNSPECIFIED, + 0: Operator.OPERATOR_UNSPECIFIED, + "operator_unspecified": Operator.OPERATOR_UNSPECIFIED, + "unspecified": Operator.OPERATOR_UNSPECIFIED, + Operator.LESS: Operator.LESS, + 1: Operator.LESS, + "operator_less": Operator.LESS, + "less": Operator.LESS, + "<": Operator.LESS, + Operator.LESS_EQUAL: 
Operator.LESS_EQUAL, + 2: Operator.LESS_EQUAL, + "operator_less_equal": Operator.LESS_EQUAL, + "less_equal": Operator.LESS_EQUAL, + "<=": Operator.LESS_EQUAL, + Operator.EQUAL: Operator.EQUAL, + 3: Operator.EQUAL, + "operator_equal": Operator.EQUAL, + "equal": Operator.EQUAL, + "==": Operator.EQUAL, + Operator.GREATER_EQUAL: Operator.GREATER_EQUAL, + 4: Operator.GREATER_EQUAL, + "operator_greater_equal": Operator.GREATER_EQUAL, + "greater_equal": Operator.GREATER_EQUAL, + Operator.NOT_EQUAL: Operator.NOT_EQUAL, + 5: Operator.NOT_EQUAL, + "operator_not_equal": Operator.NOT_EQUAL, + "not_equal": Operator.NOT_EQUAL, + "!=": Operator.NOT_EQUAL, + Operator.INCLUDES: Operator.INCLUDES, + 6: Operator.INCLUDES, + "operator_includes": Operator.INCLUDES, + "includes": Operator.INCLUDES, + Operator.EXCLUDES: Operator.EXCLUDES, + 6: Operator.EXCLUDES, + "operator_excludes": Operator.EXCLUDES, + "excludes": Operator.EXCLUDES, + "not in": Operator.EXCLUDES, +} + +_STATE: dict[StateOptions, State] = { + State.STATE_UNSPECIFIED: State.STATE_UNSPECIFIED, + "0": State.STATE_UNSPECIFIED, + "state_unspecifed": State.STATE_UNSPECIFIED, + "unspecified": State.STATE_UNSPECIFIED, + State.STATE_PENDING_PROCESSING: State.STATE_PENDING_PROCESSING, + "1": State.STATE_PENDING_PROCESSING, + "pending_processing": State.STATE_PENDING_PROCESSING, + "pending": State.STATE_PENDING_PROCESSING, + State.STATE_ACTIVE: State.STATE_ACTIVE, + "2": State.STATE_ACTIVE, + "state_active": State.STATE_ACTIVE, + "active": State.STATE_ACTIVE, + State.STATE_FAILED: State.STATE_FAILED, + "10": State.STATE_FAILED, # TODO: This is specified as 10 in the proto, should it be 3 or 10? + "state_failed": State.STATE_FAILED, + "failed": State.STATE_FAILED, +} + + +def to_operator(x: OperatorOptions) -> Operator: + if isinstance(x, str): + x = x.lower() + return _OPERATOR[x] + + +def to_state(x: StateOptions) -> State: + if isinstance(x, str): + x = x.lower() + return _STATE[x] + + +@string_utils.prettyprint +@dataclasses.dataclass +class MetadataFilters: + key: str + conditions: Condition + + +@string_utils.prettyprint +@dataclasses.dataclass +class Condition: + value: str | float + + +@string_utils.prettyprint +@dataclasses.dataclass +class CustomMetadata: + key: str + string_value: str + string_list_value: list[str] + numeric_value: float + + +@string_utils.prettyprint +@dataclasses.dataclass +class ChunkData: + string_value: str + + +@string_utils.prettyprint +@dataclasses.dataclass() +class Corpus: + """ + A `Corpus` is a collection of `Documents`. + """ + + name: str + display_name: str + + def create_document( + self, + name: Optional[str] = None, + display_name: Optional[str] = None, + custom_metadata: Optional[list[CustomMetadata]] = None, + client: glm.RetrieverServiceClient | None = None, + ) -> Document: + """ + Request to create a `Document`. + + Args: + name: The `Document` resource name. The ID (name excluding the "corpora/*/documents/" prefix) can contain up to 40 characters + that are lowercase alphanumeric or dashes (-). The ID cannot start or end with a dash. + display_name: The human-readable display name for the `Document`. + custom_metadata: User provided custom metadata stored as key-value pairs used for querying. + + Return: + Document object with specified name or display name. + + Raises: + ValueError: When the name is not specified or formatted incorrectly. 
+        """
+        if client is None:
+            client = get_default_retriever_client()
+
+        if not name and not display_name:
+            raise ValueError("Either the document name or display name must be specified.")
+
+        document = None
+        if name:
+            if re.match(_DOCUMENT_NAME_REGEX, name):
+                document = glm.Document(
+                    name=name, display_name=display_name, custom_metadata=custom_metadata
+                )
+            elif f"{self.name}/documents/" not in name:
+                document_name = f"{self.name}/documents/" + re.sub(_PATTERN, "", name)
+                document = glm.Document(
+                    name=document_name, display_name=display_name, custom_metadata=custom_metadata
+                )
+            else:
+                raise ValueError(
+                    f"Document name must be formatted as {self.name}/documents/."
+                )
+
+        request = glm.CreateDocumentRequest(parent=self.name, document=document)
+        response = client.create_document(request)
+        response = type(response).to_dict(response)
+        idecode_time(response, "create_time")
+        idecode_time(response, "update_time")
+        response = Document(**response)
+        return response
+
+    async def create_document_async(
+        self,
+        name: Optional[str] = None,
+        display_name: Optional[str] = None,
+        custom_metadata: Optional[list[CustomMetadata]] = None,
+        client: glm.RetrieverServiceAsyncClient | None = None,
+    ) -> Document:
+        """This is the async version of `Corpus.create_document`."""
+        if client is None:
+            client = get_default_retriever_async_client()
+
+        if not name and not display_name:
+            raise ValueError("Either the document name or display name must be specified.")
+
+        document = None
+        if name:
+            if re.match(_DOCUMENT_NAME_REGEX, name):
+                document = glm.Document(
+                    name=name, display_name=display_name, custom_metadata=custom_metadata
+                )
+            elif f"{self.name}/documents/" not in name:
+                document_name = f"{self.name}/documents/" + re.sub(_PATTERN, "", name)
+                document = glm.Document(
+                    name=document_name, display_name=display_name, custom_metadata=custom_metadata
+                )
+            else:
+                raise ValueError(
+                    f"Document name must be formatted as {self.name}/documents/."
+                )
+
+        request = glm.CreateDocumentRequest(parent=self.name, document=document)
+        response = await client.create_document(request)
+        response = type(response).to_dict(response)
+        idecode_time(response, "create_time")
+        idecode_time(response, "update_time")
+        response = Document(**response)
+        return response
+
+    def get_document(
+        self,
+        name: str,
+        client: glm.RetrieverServiceClient | None = None,
+    ) -> Document:
+        """
+        Get information about a specific `Document`.
+
+        Args:
+            name: The `Document` name.
+
+        Return:
+            `Document` of interest.
+ """ + if client is None: + client = get_default_retriever_client() + + request = glm.GetDocumentRequest(name=name) + response = client.get_document(request) + response = type(response).to_dict(response) + idecode_time(response, "create_time") + idecode_time(response, "update_time") + response = Document(**response) + return response + + async def get_document_async( + self, + name: str, + client: glm.RetrieverServiceAsyncClient | None = None, + ) -> Document: + """This is the async version of `Corpus.get_document`.""" + if client is None: + client = get_default_retriever_async_client() + + request = glm.GetDocumentRequest(name=name) + response = await client.get_document(request) + response = type(response).to_dict(response) + idecode_time(response, "create_time") + idecode_time(response, "update_time") + response = Document(**response) + return response + + def _apply_update(self, path, value): + parts = path.split(".") + for part in parts[:-1]: + self = getattr(self, part) + setattr(self, parts[-1], value) + + def update( + self, + updates: dict[str, Any], + client: glm.RetrieverServiceClient | None = None, + ): + """ + Update a list of fields for a specified `Corpus`. + + Args: + updates: List of fields to update in a `Corpus`. + + Return: + Updated version of the `Corpus` object. + """ + if client is None: + client = get_default_retriever_client() + + updates = flatten_update_paths(updates) + field_mask = field_mask_pb2.FieldMask() + + for path in updates.keys(): + field_mask.paths.append(path) + for path, value in updates.items(): + self._apply_update(path, value) + + request = glm.UpdateCorpusRequest(corpus=self.to_dict(), update_mask=field_mask) + response = client.update_corpus(request) + response = type(response).to_dict(response) + idecode_time(response, "create_time") + idecode_time(response, "update_time") + return self + + async def update_async( + self, + updates: dict[str, Any], + client: glm.RetrieverServiceAsyncClient | None = None, + ): + """This is the async version of `Corpus.update`.""" + if client is None: + client = get_default_retriever_async_client() + + updates = flatten_update_paths(updates) + field_mask = field_mask_pb2.FieldMask() + + for path in updates.keys(): + field_mask.paths.append(path) + for path, value in updates.items(): + self._apply_update(path, value) + + request = glm.UpdateCorpusRequest(corpus=self.to_dict(), update_mask=field_mask) + response = await client.update_corpus(request) + response = type(response).to_dict(response) + idecode_time(response, "create_time") + idecode_time(response, "update_time") + return self + + def query( + self, + query: str, + metadata_filters: Optional[list[str]] = None, + results_count: Optional[int] = None, + client: glm.RetrieverServiceClient | None = None, + ): + """ + Query a corpus for information. + + Args: + query: Query string to perform semantic search. + metadata_filters: Filter for `Chunk` metadata. + results_count: The maximum number of `Chunk`s to return; must be less than 100. + + Returns: + List of relevant chunks. 
+ """ + if client is None: + client = get_default_retriever_client() + + if results_count: + if results_count > 100: + raise ValueError("Number of results returned must be between 1 and 100.") + + request = glm.QueryCorpusRequest( + name=self.name, + query=query, + metadata_filters=metadata_filters, + results_count=results_count, + ) + response = client.query_corpus(request) + response = type(response).to_dict(response) + + return response + + async def query_async( + self, + query: str, + metadata_filters: Optional[list[str]] = None, + results_count: Optional[int] = None, + client: glm.RetrieverServiceAsyncClient | None = None, + ): + """This is the async version of `Corpus.query`.""" + if client is None: + client = get_default_retriever_async_client() + + if results_count: + if results_count > 100: + raise ValueError("Number of results returned must be between 1 and 100.") + + request = glm.QueryCorpusRequest( + name=self.name, + query=query, + metadata_filters=metadata_filters, + results_count=results_count, + ) + response = await client.query_corpus(request) + response = type(response).to_dict(response) + + return response + + def delete_document( + self, + name: str, + force: Optional[bool] = None, + client: glm.RetrieverServiceClient | None = None, + ): + """ + Delete a document in the corpus. + + Args: + name: The `Document` name. + force: If set to true, any `Chunk`s and objects related to this `Document` will also be deleted. + """ + if client is None: + client = get_default_retriever_client() + + if force: + request = glm.DeleteDocumentRequest(name=name, force=force) + else: + request = glm.DeleteDocumentRequest(name=name) + + client.delete_document(request) + + async def delete_document_async( + self, + name: str, + force: Optional[bool] = None, + client: glm.RetrieverServiceAsyncClient | None = None, + ): + """This is the async version of `Corpus.delete_document`.""" + if client is None: + client = get_default_retriever_async_client() + + if force: + request = glm.DeleteDocumentRequest(name=name, force=force) + else: + request = glm.DeleteDocumentRequest(name=name) + + await client.delete_document(request) + + def list_documents( + self, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + client: glm.RetrieverServiceClient | None = None, + ) -> list[Document]: + """ + List documents in corpus. + + Args: + name: The name of the `Corpus` containing `Document`s. + page_size: The maximum number of `Document`s to return (per page). The service may return fewer `Document`s. + page_token: A page token, received from a previous `ListDocuments` call. + + Return: + Paginated list of `Document`s. 
+        """
+        if client is None:
+            client = get_default_retriever_client()
+
+        request = glm.ListDocumentsRequest(
+            parent=self.name, page_size=page_size, page_token=page_token
+        )
+        response = client.list_documents(request)
+        return response
+
+    async def list_documents_async(
+        self,
+        page_size: Optional[int] = None,
+        page_token: Optional[str] = None,
+        client: glm.RetrieverServiceAsyncClient | None = None,
+    ) -> list[Document]:
+        """This is the async version of `Corpus.list_documents`."""
+        if client is None:
+            client = get_default_retriever_async_client()
+
+        request = glm.ListDocumentsRequest(
+            parent=self.name, page_size=page_size, page_token=page_token
+        )
+        response = await client.list_documents(request)
+        return response
+
+    def to_dict(self) -> dict[str, Any]:
+        result = {"name": self.name, "display_name": self.display_name}
+        return result
+
+
+@string_utils.prettyprint
+@dataclasses.dataclass()
+class Document(abc.ABC):
+    """
+    A `Document` is a collection of `Chunk`s.
+    """
+
+    name: str
+    display_name: str
+    custom_metadata: list[CustomMetadata]
+
+    def create_chunk(
+        self,
+        name: Optional[str],
+        data: str | ChunkData,
+        custom_metadata: Optional[list[CustomMetadata]] = None,
+        client: glm.RetrieverServiceClient | None = None,
+    ) -> Chunk:
+        """
+        Create a `Chunk` object which has textual data.
+
+        Args:
+            name: The `Chunk` resource name. The ID (name excluding the "corpora/*/documents/*/chunks/" prefix) can contain up to 40 characters that are lowercase alphanumeric or dashes (-).
+            data: The content for the `Chunk`, such as the text string.
+            custom_metadata: User provided custom metadata stored as key-value pairs.
+
+        Return:
+            `Chunk` object with specified data.
+
+        Raises:
+            ValueError: When the chunk name is specified incorrectly.
+        """
+        if client is None:
+            client = get_default_retriever_client()
+
+        chunk_name, chunk = "", None
+        if name:
+            if re.match(_CHUNK_NAME_REGEX, name):
+                chunk_name = name
+            elif "chunks/" not in name:
+                chunk_name = f"{self.name}/chunks/" + re.sub(_PATTERN, "", name)
+            else:
+                raise ValueError(
+                    f"Chunk name must be formatted as {self.name}/chunks/."
+                )
+
+        if isinstance(data, str):
+            chunk = glm.Chunk(
+                name=chunk_name, data={"string_value": data}, custom_metadata=custom_metadata
+            )
+        else:
+            chunk = glm.Chunk(
+                name=chunk_name,
+                data={"string_value": data.string_value},
+                custom_metadata=custom_metadata,
+            )
+
+        request = glm.CreateChunkRequest(parent=self.name, chunk=chunk)
+        response = client.create_chunk(request)
+        response = type(response).to_dict(response)
+        idecode_time(response, "create_time")
+        idecode_time(response, "update_time")
+        response = Chunk(**response)
+        return response
+
+    async def create_chunk_async(
+        self,
+        name: Optional[str],
+        data: str | ChunkData,
+        custom_metadata: Optional[list[CustomMetadata]] = None,
+        client: glm.RetrieverServiceAsyncClient | None = None,
+    ) -> Chunk:
+        """This is the async version of `Document.create_chunk`."""
+        if client is None:
+            client = get_default_retriever_async_client()
+
+        chunk_name, chunk = "", None
+        if name:
+            if re.match(_CHUNK_NAME_REGEX, name):
+                chunk_name = name
+            elif "chunks/" not in name:
+                chunk_name = f"{self.name}/chunks/" + re.sub(_PATTERN, "", name)
+            else:
+                raise ValueError(
+                    f"Chunk name must be formatted as {self.name}/chunks/."
+                )
+
+        if isinstance(data, str):
+            chunk = glm.Chunk(
+                name=chunk_name, data={"string_value": data}, custom_metadata=custom_metadata
+            )
+        else:
+            chunk = glm.Chunk(
+                name=chunk_name,
+                data={"string_value": data.string_value},
+                custom_metadata=custom_metadata,
+            )
+
+        request = glm.CreateChunkRequest(parent=self.name, chunk=chunk)
+        response = await client.create_chunk(request)
+        response = type(response).to_dict(response)
+        idecode_time(response, "create_time")
+        idecode_time(response, "update_time")
+        response = Chunk(**response)
+        return response
+
+    def batch_create_chunks(
+        self,
+        chunks: BatchCreateChunkOptions,
+        client: glm.RetrieverServiceClient | None = None,
+    ):
+        """
+        Create chunks within the given document.
+
+        Args:
+            chunks: `Chunks` to create.
+
+        Return:
+            Information about the created chunks.
+        """
+        if client is None:
+            client = get_default_retriever_client()
+
+        if isinstance(chunks, glm.BatchCreateChunksRequest):
+            response = client.batch_create_chunks(chunks)
+            response = type(response).to_dict(response)
+            return response
+
+        _requests = []
+        name, data, custom_metadata = None, None, None
+        if isinstance(chunks, Iterable):
+            for chunk in chunks:
+                if isinstance(chunk, glm.CreateChunkRequest):
+                    _requests.append(chunk)
+                elif isinstance(chunk, str):
+                    c = glm.CreateChunkRequest(
+                        parent=self.name, chunk=glm.Chunk(data={"string_value": chunk})
+                    )
+                    _requests.append(c)
+                elif isinstance(chunk, Mapping):
+                    for key, value in chunk.items():
+                        if re.match(_CHUNK_NAME_REGEX, value):
+                            name = value
+                        elif isinstance(value, str):
+                            data = chunk[key]
+                        elif isinstance(value, Iterable):
+                            custom_metadata = value
+                    c = glm.CreateChunkRequest(  # Create a glm.CreateChunkRequest
+                        parent=self.name,
+                        chunk=glm.Chunk(
+                            name=name,
+                            data={"string_value": data},
+                            custom_metadata=custom_metadata,
+                        ),
+                    )
+                    _requests.append(c)
+                elif isinstance(chunk, tuple):
+                    for item in chunk:
+                        if re.match(_CHUNK_NAME_REGEX, item):
+                            name = item
+                        elif isinstance(item, str):
+                            data = item
+                        elif isinstance(item, Iterable):
+                            custom_metadata = item
+                    c = glm.CreateChunkRequest(  # Create a glm.CreateChunkRequest
+                        parent=self.name,
+                        chunk=glm.Chunk(
+                            name=name,
+                            data={"string_value": data},
+                            custom_metadata=custom_metadata,
+                        ),
+                    )
+                    _requests.append(c)
+                else:
+                    raise TypeError(
+                        "Batched chunk requests must be in the format of a dictionary or tuple, "
+                        "with the name as the key and the data as the value."
+                    )
+
+        request = glm.BatchCreateChunksRequest(parent=self.name, requests=_requests)
+        response = client.batch_create_chunks(request)
+        response = type(response).to_dict(response)
+        return response
+
+    async def batch_create_chunks_async(
+        self,
+        chunks: BatchCreateChunkOptions,
+        client: glm.RetrieverServiceAsyncClient | None = None,
+    ):
+        """This is the async version of `Document.batch_create_chunk`."""
+        if client is None:
+            client = get_default_retriever_async_client()
+
+        if isinstance(chunks, glm.BatchCreateChunksRequest):
+            response = await client.batch_create_chunks(chunks)
+            response = type(response).to_dict(response)
+            return response
+
+        _requests = []
+        name, data, custom_metadata = None, None, None
+        if isinstance(chunks, Iterable):
+            for chunk in chunks:
+                if isinstance(chunk, glm.CreateChunkRequest):
+                    _requests.append(chunk)
+                elif isinstance(chunk, str):
+                    c = glm.CreateChunkRequest(
+                        parent=self.name, chunk=glm.Chunk(data={"string_value": chunk})
+                    )
+                    _requests.append(c)
+                elif isinstance(chunk, Mapping):
+                    for key, value in chunk.items():
+                        if re.match(_CHUNK_NAME_REGEX, value):
+                            name = value
+                        elif isinstance(value, str):
+                            data = chunk[key]
+                        elif isinstance(value, Iterable):
+                            custom_metadata = value
+                    c = glm.CreateChunkRequest(  # Create a glm.CreateChunkRequest
+                        parent=self.name,
+                        chunk=glm.Chunk(
+                            name=name,
+                            data={"string_value": data},
+                            custom_metadata=custom_metadata,
+                        ),
+                    )
+                    _requests.append(c)
+                elif isinstance(chunk, tuple):
+                    for item in chunk:
+                        if re.match(_CHUNK_NAME_REGEX, item):
+                            name = item
+                        elif isinstance(item, str):
+                            data = item
+                        elif isinstance(item, Iterable):
+                            custom_metadata = item
+                    c = glm.CreateChunkRequest(  # Create a glm.CreateChunkRequest
+                        parent=self.name,
+                        chunk=glm.Chunk(
+                            name=name,
+                            data={"string_value": data},
+                            custom_metadata=custom_metadata,
+                        ),
+                    )
+                    _requests.append(c)
+                else:
+                    raise TypeError(
+                        "Batched chunk requests must be in the format of a dictionary or tuple, "
+                        "with the name as the key and the data as the value."
+                    )
+
+        request = glm.BatchCreateChunksRequest(parent=self.name, requests=_requests)
+        response = await client.batch_create_chunks(request)
+        response = type(response).to_dict(response)
+        return response
+
+    def get_chunk(
+        self,
+        name: str,
+        client: glm.RetrieverServiceClient | None = None,
+    ):
+        """
+        Get information about a specific chunk.
+
+        Args:
+            name: Name of `Chunk`.
+
+        Returns:
+            `Chunk` that was requested.
+        """
+        if client is None:
+            client = get_default_retriever_client()
+
+        request = glm.GetChunkRequest(name=name)
+        response = client.get_chunk(request)
+        response = type(response).to_dict(response)
+        idecode_time(response, "create_time")
+        idecode_time(response, "update_time")
+        response = Chunk(**response)
+        return response
+
+    async def get_chunk_async(
+        self,
+        name: str,
+        client: glm.RetrieverServiceAsyncClient | None = None,
+    ):
+        """This is the async version of `Document.get_chunk`."""
+        if client is None:
+            client = get_default_retriever_async_client()
+
+        request = glm.GetChunkRequest(name=name)
+        response = await client.get_chunk(request)
+        response = type(response).to_dict(response)
+        idecode_time(response, "create_time")
+        idecode_time(response, "update_time")
+        response = Chunk(**response)
+        return response
+
+    def list_chunks(
+        self,
+        page_size: Optional[int] = None,
+        page_token: Optional[str] = None,
+        client: glm.RetrieverServiceClient | None = None,
+    ):
+        """
+        List chunks of a document.
+
+        Args:
+            page_size: Maximum number of `Chunk`s to request.
+ page_token: A page token, received from a previous ListChunks call. + + Return: + List of chunks in the document. + """ + if client is None: + client = get_default_retriever_client() + + request = glm.ListChunksRequest( + parent=self.name, page_size=page_size, page_token=page_token + ) + response = client.list_chunks(request) + return response + + async def list_chunks_async( + self, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + client: glm.RetrieverServiceClient | None = None, + ): + """This is the async version of `Document.list_chunks`.""" + if client is None: + client = get_default_retriever_async_client() + + request = glm.ListChunksRequest( + parent=self.name, page_size=page_size, page_token=page_token + ) + response = await client.list_chunks(request) + return response + + def query( + self, + query: str, + metadata_filters: Optional[list[str]] = None, + results_count: Optional[int] = None, + client: glm.RetrieverServiceClient | None = None, + ): + """ + Query a `Document` in the `Corpus` for information. + + Args: + query: Query string to perform semantic search. + metadata_filters: Filter for `Chunk` metadata. + results_count: The maximum number of `Chunk`s to return. + + Returns: + List of relevant chunks. + """ + if client is None: + client = get_default_retriever_client() + + if results_count: + if results_count < 0 or results_count >= 100: + raise ValueError("Number of results returned must be between 1 and 100.") + + request = glm.QueryDocumentRequest( + name=self.name, + query=query, + metadata_filters=metadata_filters, + results_count=results_count, + ) + response = client.query_document(request) + response = type(response).to_dict(response) + + return response + + async def query_async( + self, + query: str, + metadata_filters: Optional[list[str]] = None, + results_count: Optional[int] = None, + client: glm.RetrieverServiceAsyncClient | None = None, + ): + """This is the async version of `Document.query`.""" + if client is None: + client = get_default_retriever_async_client() + + if results_count: + if results_count < 0 or results_count >= 100: + raise ValueError("Number of results returned must be between 1 and 100.") + + request = glm.QueryDocumentRequest( + name=self.name, + query=query, + metadata_filters=metadata_filters, + results_count=results_count, + ) + response = await client.query_document(request) + response = type(response).to_dict(response) + + return response + + def _apply_update(self, path, value): + parts = path.split(".") + for part in parts[:-1]: + self = getattr(self, part) + setattr(self, parts[-1], value) + + def update( + self, + updates: dict[str, Any], + client: glm.RetrieverServiceClient | None = None, + ): + """ + Update a list of fields for a specified document. + + Args: + updates: The list of fields to update. + + Return: + `Chunk` object with specified updates. 
+        """
+        if client is None:
+            client = get_default_retriever_client()
+
+        updates = flatten_update_paths(updates)
+        field_mask = field_mask_pb2.FieldMask()
+        for path in updates.keys():
+            field_mask.paths.append(path)
+        for path, value in updates.items():
+            self._apply_update(path, value)
+
+        request = glm.UpdateDocumentRequest(document=self.to_dict(), update_mask=field_mask)
+        response = client.update_document(request)
+        response = type(response).to_dict(response)
+        idecode_time(response, "create_time")
+        idecode_time(response, "update_time")
+        return self
+
+    async def update_async(
+        self,
+        updates: dict[str, Any],
+        client: glm.RetrieverServiceAsyncClient | None = None,
+    ):
+        """This is the async version of `Document.update`."""
+        if client is None:
+            client = get_default_retriever_async_client()
+
+        updates = flatten_update_paths(updates)
+        field_mask = field_mask_pb2.FieldMask()
+        for path in updates.keys():
+            field_mask.paths.append(path)
+        for path, value in updates.items():
+            self._apply_update(path, value)
+
+        request = glm.UpdateDocumentRequest(document=self.to_dict(), update_mask=field_mask)
+        response = await client.update_document(request)
+        response = type(response).to_dict(response)
+        idecode_time(response, "create_time")
+        idecode_time(response, "update_time")
+        return self
+
+    def batch_update_chunks(
+        self,
+        chunks: BatchUpdateChunksOptions,
+        client: glm.RetrieverServiceClient | None = None,
+    ):
+        """
+        Update multiple chunks within the same document.
+
+        Args:
+            chunks: Data structure specifying which `Chunk`s to update and what the required updates are.
+
+        Return:
+            Updated `Chunk`s.
+        """
+        if client is None:
+            client = get_default_retriever_client()
+
+        # TODO (@snkancharla): Add idecode_time here in each conditional loop?
+        if isinstance(chunks, glm.BatchUpdateChunksRequest):
+            response = client.batch_update_chunks(chunks)
+            response = type(response).to_dict(response)
+            return response
+
+        _requests = []
+        if isinstance(chunks, Mapping):
+            # Key is name of chunk, value is a dictionary of updates
+            for key, value in chunks.items():
+                c = self.get_chunk(name=key)
+                updates = flatten_update_paths(value)
+                field_mask = field_mask_pb2.FieldMask()
+                for path in updates.keys():
+                    field_mask.paths.append(path)
+                for path, value in updates.items():
+                    c._apply_update(path, value)
+                _requests.append(glm.UpdateChunkRequest(chunk=c.to_dict(), update_mask=field_mask))
+            request = glm.BatchUpdateChunksRequest(parent=self.name, requests=_requests)
+            response = client.batch_update_chunks(request)
+            response = type(response).to_dict(response)
+            return response
+        if isinstance(chunks, Iterable) and not isinstance(chunks, Mapping):
+            for chunk in chunks:
+                if isinstance(chunk, glm.UpdateChunkRequest):
+                    _requests.append(chunk)
+                elif isinstance(chunk, tuple):
+                    # First element is name of chunk, second element contains updates
+                    c = self.get_chunk(name=chunk[0])
+                    updates = flatten_update_paths(chunk[1])
+                    field_mask = field_mask_pb2.FieldMask()
+                    for path in updates.keys():
+                        field_mask.paths.append(path)
+                    for path, value in updates.items():
+                        c._apply_update(path, value)
+                    _requests.append({"chunk": c.to_dict(), "update_mask": field_mask})
+                else:
+                    raise TypeError(
+                        "The `chunks` parameter must be a list of glm.UpdateChunkRequests, "
+                        "dictionaries, or tuples of dictionaries."
+ ) + request = glm.BatchUpdateChunksRequest(parent=self.name, requests=_requests) + response = client.batch_update_chunks(request) + response = type(response).to_dict(response) + return response + + async def batch_update_chunks_async( + self, + chunks: BatchUpdateChunksOptions, + client: glm.RetrieverServiceAsyncClient | None = None, + ): + """This is the async version of `Document.batch_update_chunks`.""" + if client is None: + client = get_default_retriever_async_client() + + # TODO (@snkancharla): Add idecode_time here in each conditional loop? + if isinstance(chunks, glm.BatchUpdateChunksRequest): + response = await client.batch_update_chunks(chunks) + response = type(response).to_dict(response) + return response + + _requests = [] + if isinstance(chunks, Mapping): + # Key is name of chunk, value is a dictionary of updates + for key, value in chunks.items(): + c = self.get_chunk(name=key) + updates = flatten_update_paths(value) + field_mask = field_mask_pb2.FieldMask() + for path in updates.keys(): + field_mask.paths.append(path) + for path, value in updates.items(): + c._apply_update(path, value) + _requests.append(glm.UpdateChunkRequest(chunk=c.to_dict(), update_mask=field_mask)) + request = glm.BatchUpdateChunksRequest(parent=self.name, requests=_requests) + response = await client.batch_update_chunks(request) + response = type(response).to_dict(response) + return response + if isinstance(chunks, Iterable) and not isinstance(chunks, Mapping): + for chunk in chunks: + if isinstance(chunk, glm.UpdateChunkRequest): + _requests.append(chunk) + elif isinstance(chunk, tuple): + # First element is name of chunk, second element contains updates + c = self.get_chunk(name=chunk[0]) + updates = flatten_update_paths(chunk[1]) + field_mask = field_mask_pb2.FieldMask() + for path in updates.keys(): + field_mask.paths.append(path) + for path, value in updates.items(): + c._apply_update(path, value) + _requests.append({"chunk": c.to_dict(), "update_mask": field_mask}) + else: + raise TypeError( + "The `chunks` parameter must be a list of glm.UpdateChunkRequests," + "dictionaries, or tuples of dictionaries." + ) + request = glm.BatchUpdateChunksRequest(parent=self.name, requests=_requests) + response = await client.batch_update_chunks(request) + response = type(response).to_dict(response) + return response + + def delete_chunk( + self, name: str, client: glm.RetrieverServiceClient | None = None, # fmt: skip + ): + """ + Delete a `Chunk`. + + Args: + name: The `Chunk` name. + """ + if client is None: + client = get_default_retriever_client() + + request = glm.DeleteChunkRequest(name=name) + client.delete_chunk(request) + + async def delete_chunk_async( + self, name: str, client: glm.RetrieverServiceAsyncClient | None = None, # fmt: skip + ): + """This is the async version of `Document.delete_chunk`.""" + if client is None: + client = get_default_retriever_async_client() + + request = glm.DeleteChunkRequest(name=name) + await client.delete_chunk(request) + + def batch_delete_chunks( + self, + chunks: BatchDeleteChunkOptions, + client: glm.RetrieverServiceClient | None = None, + ): + """ + Delete multiple `Chunk`s from a document. + + Args: + chunks: Names of `Chunks` to delete. 
+ """ + if client is None: + client = get_default_retriever_client() + + if all(isinstance(x, glm.DeleteChunkRequest) for x in chunks): + request = glm.BatchDeleteChunksRequest(parent=self.name, requests=chunks) + client.batch_delete_chunks(request) + elif isinstance(chunks, Iterable): + _request_list = [] + for chunk_name in chunks: + _request_list.append(glm.DeleteChunkRequest(name=chunk_name)) + request = glm.BatchDeleteChunksRequest(parent=self.name, requests=_request_list) + client.batch_delete_chunks(request) + else: + raise ValueError( + "To delete chunks, you must pass in either the names of the chunks as an iterable, or multiple `glm.DeleteChunkRequest`s." + ) + + async def batch_delete_chunks_async( + self, + chunks: BatchDeleteChunkOptions, + client: glm.RetrieverServiceAsyncClient | None = None, + ): + """This is the async version of `Document.batch_delete_chunks`.""" + if client is None: + client = get_default_retriever_async_client() + + if all(isinstance(x, glm.DeleteChunkRequest) for x in chunks): + request = glm.BatchDeleteChunksRequest(parent=self.name, requests=chunks) + await client.batch_delete_chunks(request) + elif isinstance(chunks, Iterable): + _request_list = [] + for chunk_name in chunks: + _request_list.append(glm.DeleteChunkRequest(name=chunk_name)) + request = glm.BatchDeleteChunksRequest(parent=self.name, requests=_request_list) + await client.batch_delete_chunks(request) + else: + raise ValueError( + "To delete chunks, you must pass in either the names of the chunks as an iterable, or multiple `glm.DeleteChunkRequest`s." + ) + + def to_dict(self) -> dict[str, Any]: + result = { + "name": self.name, + "display_name": self.display_name, + "custom_metadata": self.custom_metadata, + } + return result + + +@string_utils.prettyprint +@dataclasses.dataclass(init=False) +class Chunk(abc.ABC): + """ + A `Chunk` is part of the `Document`, or the actual text. + """ + + name: str + data: ChunkData + custom_metadata: list[CustomMetadata] | None + state: State + + def __init__( + self, + name: str, + data: ChunkData | str, + custom_metadata: list[CustomMetadata] | None, + state: State, + ): + self.name = name + if isinstance(data, str): + self.data = ChunkData(string_value=data) + elif isinstance(data, dict): + self.data = ChunkData(string_value=data["string_value"]) + if custom_metadata is None: + self.custom_metadata = [] + else: + self.custom_metadata = [CustomMetadata(*cm) for cm in custom_metadata] + self.state = state + + def _apply_update(self, path, value): + parts = path.split(".") + for part in parts[:-1]: + self = getattr(self, part) + setattr(self, parts[-1], value) + + def update( + self, + updates: dict[str, Any], + client: glm.RetrieverServiceClient | None = None, + ): + """ + Update a list of fields for a specified `Chunk`. + + Args: + updates: List of fields to update for a `Chunk`. + + Return: + Updated `Chunk` object. 
+ """ + if client is None: + client = get_default_retriever_client() + + updates = flatten_update_paths(updates) + field_mask = field_mask_pb2.FieldMask() + for path in updates.keys(): + field_mask.paths.append(path) + for path, value in updates.items(): + self._apply_update(path, value) + request = glm.UpdateChunkRequest(chunk=self.to_dict(), update_mask=field_mask) + response = client.update_chunk(request) + response = type(response).to_dict(response) + + idecode_time(response, "create_time") + idecode_time(response, "update_time") + + return self + + async def update_async( + self, + updates: dict[str, Any], + client: glm.RetrieverServiceAsyncClient | None = None, + ): + """This is the async version of `Chunk.update`.""" + if client is None: + client = get_default_retriever_async_client() + + updates = flatten_update_paths(updates) + field_mask = field_mask_pb2.FieldMask() + for path in updates.keys(): + field_mask.paths.append(path) + for path, value in updates.items(): + self._apply_update(path, value) + request = glm.UpdateChunkRequest(chunk=self.to_dict(), update_mask=field_mask) + response = await client.update_chunk(request) + response = type(response).to_dict(response) + + idecode_time(response, "create_time") + idecode_time(response, "update_time") + + return self + + def to_dict(self) -> dict[str, Any]: + result = { + "name": self.name, + "data": dataclasses.asdict(self.data), + "custom_metadata": [dataclasses.asdict(cm) for cm in self.custom_metadata], + "state": self.state, + } + return result diff --git a/google/generativeai/utils.py b/google/generativeai/utils.py new file mode 100644 index 000000000..6dc2b6a20 --- /dev/null +++ b/google/generativeai/utils.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + + +def flatten_update_paths(updates): + new_updates = {} + for key, value in updates.items(): + if isinstance(value, dict): + for sub_key, sub_value in flatten_update_paths(value).items(): + new_updates[f"{key}.{sub_key}"] = sub_value + else: + new_updates[key] = value + + return new_updates diff --git a/tests/test_retriever.py b/tests/test_retriever.py new file mode 100644 index 000000000..b1ad1bcfc --- /dev/null +++ b/tests/test_retriever.py @@ -0,0 +1,706 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import unittest +import unittest.mock as mock + +import google.ai.generativelanguage as glm + +from google.generativeai import retriever +from google.generativeai import client +from google.generativeai.types import retriever_types as retriever_service +from absl.testing import absltest +from absl.testing import parameterized + + +class UnitTests(parameterized.TestCase): + def setUp(self): + self.client = unittest.mock.MagicMock() + + client._client_manager.clients["retriever"] = self.client + + self.observed_requests = [] + + self.responses = {} + + def add_client_method(f): + name = f.__name__ + setattr(self.client, name, f) + return f + + @add_client_method + def create_corpus( + request: glm.CreateCorpusRequest, + ) -> glm.Corpus: + self.observed_requests.append(request) + return glm.Corpus(name="corpora/demo_corpus", display_name="demo_corpus") + + @add_client_method + def get_corpus( + request: glm.GetCorpusRequest, + ) -> glm.Corpus: + self.observed_requests.append(request) + return glm.Corpus(name="corpora/demo_corpus", display_name="demo_corpus") + + @add_client_method + def update_corpus( + request: glm.UpdateCorpusRequest, + ) -> glm.Corpus: + self.observed_requests.append(request) + return glm.Corpus(name="corpora/demo_corpus", display_name="demo_corpus_1") + + @add_client_method + def list_corpora( + request: glm.ListCorporaRequest, + ) -> glm.ListCorporaResponse: + self.observed_requests.append(request) + return [ + glm.Corpus(name="corpora/demo_corpus_1", display_name="demo_corpus_1"), + glm.Corpus(name="corpora/demo_corpus_2", display_name="demo_corpus_2"), + ] + + @add_client_method + def query_corpus( + request: glm.QueryCorpusRequest, + ) -> glm.QueryCorpusResponse: + self.observed_requests.append(request) + return glm.QueryCorpusResponse( + relevant_chunks=[ + glm.RelevantChunk( + chunk_relevance_score=0.08, + chunk=glm.Chunk( + name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk", + data={"string_value": "This is a demo chunk."}, + ), + ) + ] + ) + + @add_client_method + def delete_corpus(request: glm.DeleteCorpusRequest) -> None: + self.observed_requests.append(request) + + @add_client_method + def create_document( + request: glm.CreateDocumentRequest, + ) -> retriever_service.Document: + self.observed_requests.append(request) + return glm.Document( + name="corpora/demo_corpus/documents/demo_doc", display_name="demo_doc" + ) + + @add_client_method + def get_document( + request: glm.GetDocumentRequest, + ) -> retriever_service.Document: + self.observed_requests.append(request) + return glm.Document( + name="corpora/demo_corpus/documents/demo_doc", display_name="demo_doc" + ) + + @add_client_method + def update_document( + request: glm.UpdateDocumentRequest, + ) -> glm.Document: + self.observed_requests.append(request) + return glm.Document( + name="corpora/demo_corpus/documents/demo_doc", display_name="demo_doc_1" + ) + + @add_client_method + def list_documents( + request: glm.ListDocumentsRequest, + ) -> glm.ListDocumentsResponse: + self.observed_requests.append(request) + return [ + glm.Document( + name="corpora/demo_corpus/documents/demo_doc_1", display_name="demo_doc_1" + ), + glm.Document( + name="corpora/demo_corpus/documents/demo_doc_2", display_name="demo_doc_2" + ), + ] + + @add_client_method + def delete_document( + request: glm.DeleteDocumentRequest, + ) -> None: + self.observed_requests.append(request) + + @add_client_method + def query_document( + request: glm.QueryDocumentRequest, + ) -> glm.QueryDocumentResponse: + 
self.observed_requests.append(request) + return glm.QueryCorpusResponse( + relevant_chunks=[ + glm.RelevantChunk( + chunk_relevance_score=0.08, + chunk=glm.Chunk( + name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk", + data={"string_value": "This is a demo chunk."}, + ), + ) + ] + ) + + @add_client_method + def create_chunk( + request: glm.CreateChunkRequest, + ) -> retriever_service.Chunk: + self.observed_requests.append(request) + return glm.Chunk( + name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk", + data={"string_value": "This is a demo chunk."}, + ) + + @add_client_method + def batch_create_chunks( + request: glm.BatchCreateChunksRequest, + ) -> glm.BatchCreateChunksResponse: + self.observed_requests.append(request) + return glm.BatchCreateChunksResponse( + chunks=[ + glm.Chunk( + name="corpora/demo_corpus/documents/demo_doc/chunks/dc", + data={"string_value": "This is a demo chunk."}, + ), + glm.Chunk( + name="corpora/demo_corpus/documents/demo_doc/chunks/dc1", + data={"string_value": "This is another demo chunk."}, + ), + ] + ) + + @add_client_method + def get_chunk( + request: glm.GetChunkRequest, + ) -> retriever_service.Chunk: + self.observed_requests.append(request) + return glm.Chunk( + name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk", + data={"string_value": "This is a demo chunk."}, + ) + + @add_client_method + def list_chunks( + request: glm.ListChunksRequest, + ) -> glm.ListChunksResponse: + self.observed_requests.append(request) + return [ + glm.Chunk( + name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk", + data={"string_value": "This is a demo chunk."}, + ), + glm.Chunk( + name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk_1", + data={"string_value": "This is another demo chunk."}, + ), + ] + + @add_client_method + def update_chunk(request: glm.UpdateChunkRequest) -> glm.Chunk: + self.observed_requests.append(request) + return glm.Chunk( + name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk", + data={"string_value": "This is an updated demo chunk."}, + ) + + @add_client_method + def batch_update_chunks( + request: glm.BatchUpdateChunksRequest, + ) -> glm.BatchUpdateChunksResponse: + self.observed_requests.append(request) + return glm.BatchUpdateChunksResponse( + chunks=[ + glm.Chunk( + name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk", + data={"string_value": "This is an updated chunk."}, + ), + glm.Chunk( + name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk_1", + data={"string_value": "This is another updated chunk."}, + ), + ] + ) + + @add_client_method + def delete_chunk( + request: glm.DeleteChunkRequest, + ) -> None: + self.observed_requests.append(request) + + @add_client_method + def batch_delete_chunks( + request: glm.BatchDeleteChunksRequest, + ) -> None: + self.observed_requests.append(request) + + def test_create_corpus(self, display_name="demo_corpus"): + x = retriever.create_corpus(display_name=display_name) + self.assertIsInstance(x, retriever_service.Corpus) + self.assertEqual("demo_corpus", x.display_name) + self.assertEqual("corpora/demo_corpus", x.name) + + @parameterized.named_parameters( + [ + dict(testcase_name="match_corpora_regex", name="corpora/demo_corpus"), + dict(testcase_name="no_corpora", name="demo_corpus"), + dict(testcase_name="with_punctuation", name="corpora/demo_corpus*(*)"), + dict(testcase_name="dash_at_start", name="-demo_corpus"), + ] + ) + def test_create_corpus_names(self, name): + x = retriever.create_corpus(name=name) + 
+        self.assertEqual("demo_corpus", x.display_name)
+        self.assertEqual("corpora/demo_corpus", x.name)
+
+    def test_get_corpus(self, display_name="demo_corpus"):
+        x = retriever.create_corpus(display_name=display_name)
+        c = retriever.get_corpus(name=x.name)
+        self.assertEqual("demo_corpus", c.display_name)
+
+    def test_update_corpus(self):
+        demo_corpus = retriever.create_corpus(display_name="demo_corpus")
+        update_request = demo_corpus.update(updates={"display_name": "demo_corpus_1"})
+        self.assertIsInstance(self.observed_requests[-1], glm.UpdateCorpusRequest)
+        self.assertEqual("demo_corpus_1", demo_corpus.display_name)
+
+    def test_list_corpora(self):
+        x = retriever.list_corpora(page_size=1)
+        self.assertIsInstance(x, list)
+        self.assertEqual(len(x), 2)
+
+    def test_query_corpus(self):
+        demo_corpus = retriever.create_corpus(display_name="demo_corpus")
+        demo_document = demo_corpus.create_document(display_name="demo_doc")
+        demo_chunk = demo_document.create_chunk(
+            name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk",
+            data="This is a demo chunk.",
+        )
+        q = demo_corpus.query(query="What kind of chunk is this?")
+        self.assertIsInstance(q, dict)
+        self.assertEqual(
+            q,
+            {
+                "relevant_chunks": [
+                    {
+                        "chunk_relevance_score": 0.08,
+                        "chunk": {
+                            "name": "corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk",
+                            "data": {"string_value": "This is a demo chunk."},
+                            "custom_metadata": [],
+                            "state": 0,
+                        },
+                    }
+                ]
+            },
+        )
+
+    def test_delete_corpus(self):
+        demo_corpus = retriever.create_corpus(display_name="demo_corpus")
+        demo_document = demo_corpus.create_document(display_name="demo_doc")
+        delete_request = retriever.delete_corpus(name="corpora/demo_corpus", force=True)
+        self.assertIsInstance(self.observed_requests[-1], glm.DeleteCorpusRequest)
+
+    def test_create_document(self, display_name="demo_doc"):
+        demo_corpus = retriever.create_corpus(display_name="demo_corpus")
+        x = demo_corpus.create_document(display_name=display_name)
+        self.assertIsInstance(x, retriever_service.Document)
+        self.assertEqual("demo_doc", x.display_name)
+
+    @parameterized.named_parameters(
+        [
+            dict(
+                testcase_name="match_document_regex", name="corpora/demo_corpus/documents/demo_doc"
+            ),
+            dict(testcase_name="no_document", name="corpora/demo_corpus/demo_document"),
+            dict(
+                testcase_name="with_punctuation", name="corpora/demo_corpus*(*)/documents/demo_doc"
+            ),
+            dict(testcase_name="dash_at_start", name="-demo_doc"),
+        ]
+    )
+    def test_create_document_name(self, name):
+        demo_corpus = retriever.create_corpus(display_name="demo_corpus")
+        x = demo_corpus.create_document(name=name)
+        self.assertEqual("corpora/demo_corpus/documents/demo_doc", x.name)
+        self.assertEqual("demo_doc", x.display_name)
+
+    def test_get_document(self, display_name="demo_doc"):
+        demo_corpus = retriever.create_corpus(display_name="demo_corpus")
+        x = demo_corpus.create_document(display_name=display_name)
+        d = demo_corpus.get_document(name=x.name)
+        self.assertEqual("demo_doc", d.display_name)
+
+    def test_update_document(self):
+        demo_corpus = retriever.create_corpus(display_name="demo_corpus")
+        demo_document = demo_corpus.create_document(display_name="demo_doc")
+        update_request = demo_document.update(updates={"display_name": "demo_doc_1"})
+        self.assertEqual("demo_doc_1", demo_document.display_name)
+
+    def test_delete_document(self):
+        demo_corpus = retriever.create_corpus(display_name="demo_corpus")
+        demo_document = demo_corpus.create_document(display_name="demo_doc")
+        demo_doc2 = demo_corpus.create_document(display_name="demo_doc_2")
+        delete_request = demo_corpus.delete_document(name="corpora/demo_corpus/documents/demo_doc")
+        self.assertIsInstance(self.observed_requests[-1], glm.DeleteDocumentRequest)
+
+    def test_list_documents(self):
+        demo_corpus = retriever.create_corpus(display_name="demo_corpus")
+        demo_document = demo_corpus.create_document(display_name="demo_doc")
+        demo_doc2 = demo_corpus.create_document(display_name="demo_doc_2")
+        self.assertLen(demo_corpus.list_documents(), 2)
+
+    def test_query_document(self):
+        demo_corpus = retriever.create_corpus(display_name="demo_corpus")
+        demo_document = demo_corpus.create_document(display_name="demo_doc")
+        demo_chunk = demo_document.create_chunk(
+            name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk",
+            data="This is a demo chunk.",
+        )
+        q = demo_document.query(query="What kind of chunk is this?")
+        self.assertIsInstance(q, dict)
+        self.assertEqual(
+            q,
+            {
+                "relevant_chunks": [
+                    {
+                        "chunk_relevance_score": 0.08,
+                        "chunk": {
+                            "name": "corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk",
+                            "data": {"string_value": "This is a demo chunk."},
+                            "custom_metadata": [],
+                            "state": 0,
+                        },
+                    }
+                ]
+            },
+        )
+
+    def test_create_chunk(self):
+        demo_corpus = retriever.create_corpus(display_name="demo_corpus")
+        demo_document = demo_corpus.create_document(display_name="demo_doc")
+        x = demo_document.create_chunk(
+            name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk",
+            data="This is a demo chunk.",
+        )
+        self.assertIsInstance(x, retriever_service.Chunk)
+        self.assertEqual("corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk", x.name)
+        self.assertEqual(retriever_service.ChunkData("This is a demo chunk."), x.data)
+
+    @parameterized.named_parameters(
+        [
+            dict(
+                testcase_name="match_chunk_regex",
+                name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk",
+            ),
+            dict(testcase_name="no_chunk", name="corpora/demo_corpus/demo_document/demo_chunk"),
+            dict(
+                testcase_name="with_punctuation",
+                name="corpora/demo_corpus*(*)/documents/demo_doc/chunks*****/demo_chunk",
+            ),
+            dict(testcase_name="dash_at_start", name="-demo_chunk"),
+            dict(testcase_name="empty_value", name=""),
+        ]
+    )
+    def test_create_chunk_name(self, name):
+        demo_corpus = retriever.create_corpus(display_name="demo_corpus")
+        demo_document = demo_corpus.create_document(display_name="demo_doc")
+        x = demo_document.create_chunk(
+            name=name,
+            data="This is a demo chunk.",
+        )
+        self.assertEqual("corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk", x.name)
+
+    @parameterized.named_parameters(
+        [
+            dict(
+                testcase_name="dictionaries",
+                chunks=[
+                    {
+                        "name": "corpora/demo_corpus/documents/demo_doc/chunks/dc",
+                        "data": "This is a demo chunk.",
+                    },
+                    {
+                        "name": "corpora/demo_corpus/documents/demo_doc/chunks/dc1",
+                        "data": "This is another demo chunk.",
+                    },
+                ],
+            ),
+            dict(
+                testcase_name="tuples",
+                chunks=[
+                    (
+                        "corpora/demo_corpus/documents/demo_doc/chunks/dc",
+                        "This is a demo chunk.",
+                    ),
+                    (
+                        "corpora/demo_corpus/documents/demo_doc/chunks/dc1",
+                        "This is another demo chunk.",
+                    ),
+                ],
+            ),
+        ]
+    )
+    def test_batch_create_chunks(self, chunks):
+        demo_corpus = retriever.create_corpus(display_name="demo_corpus")
+        demo_document = demo_corpus.create_document(display_name="demo_doc")
+        creation_req = demo_document.batch_create_chunks(chunks=chunks)
+        self.assertIsInstance(self.observed_requests[-1], glm.BatchCreateChunksRequest)
+        self.assertEqual("This is a demo chunk.", creation_req["chunks"][0]["data"]["string_value"])
+        self.assertEqual(
+            "This is another demo chunk.", creation_req["chunks"][1]["data"]["string_value"]
+        )
+
+    def test_get_chunk(self):
+        demo_corpus = retriever.create_corpus(display_name="demo_corpus")
+        demo_document = demo_corpus.create_document(display_name="demo_doc")
+        x = demo_document.create_chunk(
+            name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk",
+            data="This is a demo chunk.",
+        )
+        ch = demo_document.get_chunk(name=x.name)
+        self.assertEqual(retriever_service.ChunkData("This is a demo chunk."), ch.data)
+
+    def test_list_chunks(self):
+        demo_corpus = retriever.create_corpus(display_name="demo_corpus")
+        demo_document = demo_corpus.create_document(display_name="demo_doc")
+        x = demo_document.create_chunk(
+            name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk",
+            data="This is a demo chunk.",
+        )
+        y = demo_document.create_chunk(
+            name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk_1",
+            data="This is another demo chunk.",
+        )
+        list_req = demo_document.list_chunks()
+        self.assertIsInstance(self.observed_requests[-1], glm.ListChunksRequest)
+        self.assertLen(list_req, 2)
+
+    def test_update_chunk(self):
+        demo_corpus = retriever.create_corpus(display_name="demo_corpus")
+        demo_document = demo_corpus.create_document(display_name="demo_doc")
+        x = demo_document.create_chunk(
+            name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk",
+            data="This is a demo chunk.",
+        )
+        update_request = x.update(
+            updates={"data": {"string_value": "This is an updated demo chunk."}}
+        )
+        self.assertEqual(
+            retriever_service.ChunkData("This is an updated demo chunk."),
+            update_request.data,
+        )
+
+    @parameterized.named_parameters(
+        [
+            dict(
+                testcase_name="dictionary_of_updates",
+                updates={
+                    "corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk": {
+                        "data": {"string_value": "This is an updated chunk."}
+                    },
+                    "corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk_1": {
+                        "data": {"string_value": "This is another updated chunk."}
+                    },
+                },
+            ),
+            dict(
+                testcase_name="list_of_tuples",
+                updates=[
+                    (
+                        "corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk",
+                        {"data": {"string_value": "This is an updated chunk."}},
+                    ),
+                    (
+                        "corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk_1",
+                        {"data": {"string_value": "This is another updated chunk."}},
+                    ),
+                ],
+            ),
+        ],
+    )
+    def test_batch_update_chunks_data_structures(self, updates):
+        demo_corpus = retriever.create_corpus(display_name="demo_corpus")
+        demo_document = demo_corpus.create_document(display_name="demo_doc")
+        x = demo_document.create_chunk(
+            name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk",
+            data="This is a demo chunk.",
+        )
+        y = demo_document.create_chunk(
+            name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk_1",
+            data="This is another demo chunk.",
+        )
+        update_request = demo_document.batch_update_chunks(chunks=updates)
+        self.assertIsInstance(self.observed_requests[-1], glm.BatchUpdateChunksRequest)
+        self.assertEqual(
+            "This is an updated chunk.", update_request["chunks"][0]["data"]["string_value"]
+        )
+        self.assertEqual(
+            "This is another updated chunk.", update_request["chunks"][1]["data"]["string_value"]
+        )
+
+    def test_delete_chunk(self):
+        demo_corpus = retriever.create_corpus(display_name="demo_corpus")
+        demo_document = demo_corpus.create_document(display_name="demo_doc")
+        x = demo_document.create_chunk(
+            name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk",
+            data="This is a demo chunk.",
+        )
+        delete_request = demo_document.delete_chunk(
+            name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk"
+        )
+        self.assertIsInstance(self.observed_requests[-1], glm.DeleteChunkRequest)
+
+    def test_batch_delete_chunks(self):
+        demo_corpus = retriever.create_corpus(display_name="demo_corpus")
+        demo_document = demo_corpus.create_document(display_name="demo_doc")
+        x = demo_document.create_chunk(
+            name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk",
+            data="This is a demo chunk.",
+        )
+        y = demo_document.create_chunk(
+            name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk",
+            data="This is another demo chunk.",
+        )
+        delete_request = demo_document.batch_delete_chunks(chunks=[x.name, y.name])
+        self.assertIsInstance(self.observed_requests[-1], glm.BatchDeleteChunksRequest)
+
+    @parameterized.named_parameters(
+        [
+            "create_corpus",
+            retriever.create_corpus,
+            retriever.create_corpus_async,
+        ],
+        [
+            "get_corpus",
+            retriever.get_corpus,
+            retriever.get_corpus_async,
+        ],
+        [
+            "delete_corpus",
+            retriever.delete_corpus,
+            retriever.delete_corpus_async,
+        ],
+        [
+            "list_corpora",
+            retriever.list_corpora,
+            retriever.list_corpora_async,
+        ],
+        [
+            "Corpus.create_document",
+            retriever_service.Corpus.create_document,
+            retriever_service.Corpus.create_document_async,
+        ],
+        [
+            "Corpus.get_document",
+            retriever_service.Corpus.get_document,
+            retriever_service.Corpus.get_document_async,
+        ],
+        [
+            "Corpus.update",
+            retriever_service.Corpus.update,
+            retriever_service.Corpus.update_async,
+        ],
+        [
+            "Corpus.query",
+            retriever_service.Corpus.query,
+            retriever_service.Corpus.query_async,
+        ],
+        [
+            "Corpus.list_documents",
+            retriever_service.Corpus.list_documents,
+            retriever_service.Corpus.list_documents_async,
+        ],
+        [
+            "Corpus.delete_document",
+            retriever_service.Corpus.delete_document,
+            retriever_service.Corpus.delete_document_async,
+        ],
+        [
+            "Document.create_chunk",
+            retriever_service.Document.create_chunk,
+            retriever_service.Document.create_chunk_async,
+        ],
+        [
+            "Document.get_chunk",
+            retriever_service.Document.get_chunk,
+            retriever_service.Document.get_chunk_async,
+        ],
+        [
+            "Document.batch_create_chunks",
+            retriever_service.Document.batch_create_chunks,
+            retriever_service.Document.batch_create_chunks_async,
+        ],
+        [
+            "Document.list_chunks",
+            retriever_service.Document.list_chunks,
+            retriever_service.Document.list_chunks_async,
+        ],
+        [
+            "Document.query",
+            retriever_service.Document.query,
+            retriever_service.Document.query_async,
+        ],
+        [
+            "Document.update",
+            retriever_service.Document.update,
+            retriever_service.Document.update_async,
+        ],
+        [
+            "Document.batch_update_chunks",
+            retriever_service.Document.batch_update_chunks,
+            retriever_service.Document.batch_update_chunks_async,
+        ],
+        [
+            "Document.delete_chunk",
+            retriever_service.Document.delete_chunk,
+            retriever_service.Document.delete_chunk_async,
+        ],
+        [
+            "Document.batch_delete_chunks",
+            retriever_service.Document.batch_delete_chunks,
+            retriever_service.Document.batch_delete_chunks_async,
+        ],
+        [
+            "Chunk.update",
+            retriever_service.Chunk.update,
+            retriever_service.Chunk.update_async,
+        ],
+    )
+    def test_async_code_match(self, obj, aobj):
+        import inspect
+        import re
+
+        source = inspect.getsource(obj)
+        asource = inspect.getsource(aobj)
+        source = re.sub('""".*"""', "", source, flags=re.DOTALL)
+        asource = re.sub('""".*"""', "", asource, flags=re.DOTALL)
+        asource = (
+            asource.replace("anext", "next")
+            .replace("aiter", "iter")
+            .replace("_async", "")
+            .replace("async ", "")
+            .replace("await ", "")
+            .replace("Async", "")
+            .replace("ASYNC_", "")
+        )
+
+        asource = re.sub(" *?# type: ignore", "", asource)
+        self.assertEqual(source, asource)
+
+
+if __name__ == "__main__":
+    absltest.main()
diff --git a/tests/test_retriever_async.py b/tests/test_retriever_async.py
new file mode 100644
index 000000000..b9e1acea3
--- /dev/null
+++ b/tests/test_retriever_async.py
@@ -0,0 +1,581 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import collections
+import copy
+import math
+import unittest
+import unittest.mock as mock
+
+import google.ai.generativelanguage as glm
+
+from google.generativeai import retriever
+from google.generativeai import client as client_lib
+from google.generativeai.types import retriever_types as retriever_service
+from absl.testing import absltest
+from absl.testing import parameterized
+
+
+class AsyncTests(parameterized.TestCase, unittest.IsolatedAsyncioTestCase):
+    def setUp(self):
+        self.client = unittest.mock.AsyncMock()
+
+        client_lib._client_manager.clients["retriever_async"] = self.client
+
+        def add_client_method(f):
+            name = f.__name__
+            setattr(self.client, name, f)
+            return f
+
+        self.observed_requests = []
+        self.responses = collections.defaultdict(list)
+
+        @add_client_method
+        async def create_corpus(
+            request: glm.CreateCorpusRequest,
+        ) -> glm.Corpus:
+            self.observed_requests.append(request)
+            return glm.Corpus(name="corpora/demo_corpus", display_name="demo_corpus")
+
+        @add_client_method
+        async def get_corpus(
+            request: glm.GetCorpusRequest,
+        ) -> glm.Corpus:
+            self.observed_requests.append(request)
+            return glm.Corpus(name="corpora/demo_corpus", display_name="demo_corpus")
+
+        @add_client_method
+        async def update_corpus(request: glm.UpdateCorpusRequest) -> glm.Corpus:
+            self.observed_requests.append(request)
+            return glm.Corpus(name="corpora/demo_corpus", display_name="demo_corpus_1")
+
+        @add_client_method
+        async def list_corpora(request: glm.ListCorporaRequest) -> glm.ListCorporaResponse:
+            self.observed_requests.append(request)
+            return [
+                glm.Corpus(name="corpora/demo_corpus_1", display_name="demo_corpus_1"),
+                glm.Corpus(name="corpora/demo_corpus_2", display_name="demo_corpus_2"),
+            ]
+
+        @add_client_method
+        async def query_corpus(
+            request: glm.QueryCorpusRequest,
+        ) -> glm.QueryCorpusResponse:
+            self.observed_requests.append(request)
+            return glm.QueryCorpusResponse(
+                relevant_chunks=[
+                    glm.RelevantChunk(
+                        chunk_relevance_score=0.08,
+                        chunk=glm.Chunk(
+                            name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk",
+                            data={"string_value": "This is a demo chunk."},
+                        ),
+                    )
+                ]
+            )
+
+        @add_client_method
+        async def delete_corpus(request: glm.DeleteCorpusRequest) -> None:
+            self.observed_requests.append(request)
+
+        @add_client_method
+        async def create_document(
+            request: glm.CreateDocumentRequest,
+        ) -> retriever_service.Document:
+            self.observed_requests.append(request)
+            return glm.Document(
+                name="corpora/demo_corpus/documents/demo_doc", display_name="demo_doc"
+            )
+
+        @add_client_method
+        async def get_document(
+            request: glm.GetDocumentRequest,
+        ) -> retriever_service.Document:
+            self.observed_requests.append(request)
+            return glm.Document(
+                name="corpora/demo_corpus/documents/demo_doc", display_name="demo_doc"
+            )
+
+        @add_client_method
+        async def update_document(
+            request: glm.UpdateDocumentRequest,
+        ) -> glm.Document:
+            self.observed_requests.append(request)
+            return glm.Document(
+                name="corpora/demo_corpus/documents/demo_doc", display_name="demo_doc_1"
+            )
+
+        @add_client_method
+        async def list_documents(
+            request: glm.ListDocumentsRequest,
+        ) -> glm.ListDocumentsResponse:
+            self.observed_requests.append(request)
+            return [
+                glm.Document(
+                    name="corpora/demo_corpus/documents/demo_doc_1", display_name="demo_doc_1"
+                ),
+                glm.Document(
+                    name="corpora/demo_corpus/documents/demo_doc_2", display_name="demo_doc_2"
+                ),
+            ]
+
+        @add_client_method
+        async def delete_document(
+            request: glm.DeleteDocumentRequest,
+        ) -> None:
+            self.observed_requests.append(request)
+
+        @add_client_method
+        async def query_document(
+            request: glm.QueryDocumentRequest,
+        ) -> glm.QueryDocumentResponse:
+            self.observed_requests.append(request)
+            return glm.QueryDocumentResponse(
+                relevant_chunks=[
+                    glm.RelevantChunk(
+                        chunk_relevance_score=0.08,
+                        chunk=glm.Chunk(
+                            name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk",
+                            data={"string_value": "This is a demo chunk."},
+                        ),
+                    )
+                ]
+            )
+
+        @add_client_method
+        async def create_chunk(
+            request: glm.CreateChunkRequest,
+        ) -> retriever_service.Chunk:
+            self.observed_requests.append(request)
+            return glm.Chunk(
+                name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk",
+                data={"string_value": "This is a demo chunk."},
+            )
+
+        @add_client_method
+        async def batch_create_chunks(
+            request: glm.BatchCreateChunksRequest,
+        ) -> glm.BatchCreateChunksResponse:
+            self.observed_requests.append(request)
+            return glm.BatchCreateChunksResponse(
+                chunks=[
+                    glm.Chunk(
+                        name="corpora/demo_corpus/documents/demo_doc/chunks/dc",
+                        data={"string_value": "This is a demo chunk."},
+                    ),
+                    glm.Chunk(
+                        name="corpora/demo_corpus/documents/demo_doc/chunks/dc1",
+                        data={"string_value": "This is another demo chunk."},
+                    ),
+                ]
+            )
+
+        @add_client_method
+        async def get_chunk(
+            request: glm.GetChunkRequest,
+        ) -> retriever_service.Chunk:
+            self.observed_requests.append(request)
+            return glm.Chunk(
+                name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk",
+                data={"string_value": "This is a demo chunk."},
+            )
+
+        @add_client_method
+        async def list_chunks(
+            request: glm.ListChunksRequest,
+        ) -> glm.ListChunksResponse:
+            self.observed_requests.append(request)
+            return [
+                glm.Chunk(
+                    name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk",
+                    data={"string_value": "This is a demo chunk."},
+                ),
+                glm.Chunk(
+                    name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk_1",
+                    data={"string_value": "This is another demo chunk."},
+                ),
+            ]
+
+        @add_client_method
+        async def update_chunk(request: glm.UpdateChunkRequest) -> glm.Chunk:
+            self.observed_requests.append(request)
+            return glm.Chunk(
+                name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk",
+                data={"string_value": "This is an updated demo chunk."},
+            )
+
+        @add_client_method
+        async def batch_update_chunks(
+            request: glm.BatchUpdateChunksRequest,
+        ) -> glm.BatchUpdateChunksResponse:
+            self.observed_requests.append(request)
+            return glm.BatchUpdateChunksResponse(
+                chunks=[
+                    glm.Chunk(
+                        name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk",
+                        data={"string_value": "This is an updated chunk."},
+                    ),
+                    glm.Chunk(
+                        name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk_1",
+                        data={"string_value": "This is another updated chunk."},
+                    ),
+                ]
+            )
+
+        @add_client_method
+        async def delete_chunk(
+            request: glm.DeleteChunkRequest,
+        ) -> None:
+            self.observed_requests.append(request)
+
+        @add_client_method
+        async def batch_delete_chunks(
+            request: glm.BatchDeleteChunksRequest,
+        ) -> None:
+            self.observed_requests.append(request)
+
+    async def test_create_corpus(self, display_name="demo_corpus"):
+        x = await retriever.create_corpus_async(display_name=display_name)
+        self.assertIsInstance(x, retriever_service.Corpus)
+        self.assertEqual("demo_corpus", x.display_name)
+        self.assertEqual("corpora/demo_corpus", x.name)
+
+    @parameterized.named_parameters(
+        [
+            dict(testcase_name="match_corpora_regex", name="corpora/demo_corpus"),
+            dict(testcase_name="no_corpora", name="demo_corpus"),
+            dict(testcase_name="with_punctuation", name="corpora/demo_corpus*(*)"),
+            dict(testcase_name="dash_at_start", name="-demo_corpus"),
+        ]
+    )
+    async def test_create_corpus_names(self, name):
+        x = await retriever.create_corpus_async(name=name)
+        self.assertEqual("demo_corpus", x.display_name)
+        self.assertEqual("corpora/demo_corpus", x.name)
+
+    async def test_get_corpus(self, display_name="demo_corpus"):
+        x = await retriever.create_corpus_async(display_name=display_name)
+        c = await retriever.get_corpus_async(name=x.name)
+        self.assertEqual("demo_corpus", c.display_name)
+
+    async def test_update_corpus(self):
+        demo_corpus = await retriever.create_corpus_async(display_name="demo_corpus")
+        update_request = await demo_corpus.update_async(updates={"display_name": "demo_corpus_1"})
+        self.assertEqual("demo_corpus_1", demo_corpus.display_name)
+
+    async def test_list_corpora(self):
+        x = await retriever.list_corpora_async(page_size=1)
+        self.assertIsInstance(x, list)
+        self.assertEqual(len(x), 2)
+
+    async def test_query_corpus(self):
+        demo_corpus = await retriever.create_corpus_async(display_name="demo_corpus")
+        demo_document = await demo_corpus.create_document_async(display_name="demo_doc")
+        demo_chunk = await demo_document.create_chunk_async(
+            name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk",
+            data="This is a demo chunk.",
+        )
+        q = await demo_corpus.query_async(query="What kind of chunk is this?")
+        self.assertIsInstance(q, dict)
+        self.assertEqual(
+            q,
+            {
+                "relevant_chunks": [
+                    {
+                        "chunk_relevance_score": 0.08,
+                        "chunk": {
+                            "name": "corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk",
+                            "data": {"string_value": "This is a demo chunk."},
+                            "custom_metadata": [],
+                            "state": 0,
+                        },
+                    }
+                ]
+            },
+        )
+
+    async def test_delete_corpus(self):
+        demo_corpus = await retriever.create_corpus_async(display_name="demo_corpus")
+        demo_document = await demo_corpus.create_document_async(display_name="demo_doc")
+        delete_request = await retriever.delete_corpus_async(name="corpora/demo_corpus", force=True)
+        self.assertIsInstance(self.observed_requests[-1], glm.DeleteCorpusRequest)
+
+    async def test_create_document(self, display_name="demo_doc"):
+        demo_corpus = await retriever.create_corpus_async(display_name="demo_corpus")
+        x = await demo_corpus.create_document_async(display_name=display_name)
+        self.assertIsInstance(x, retriever_service.Document)
+        self.assertEqual("demo_doc", x.display_name)
+
+    @parameterized.named_parameters(
+        [
+            dict(
+                testcase_name="match_document_regex", name="corpora/demo_corpus/documents/demo_doc"
+            ),
+            dict(testcase_name="no_document", name="corpora/demo_corpus/demo_document"),
+            dict(
+                testcase_name="with_punctuation", name="corpora/demo_corpus*(*)/documents/demo_doc"
+            ),
+            dict(testcase_name="dash_at_start", name="-demo_doc"),
+        ]
+    )
+    async def test_create_document_name(self, name):
+        demo_corpus = await retriever.create_corpus_async(display_name="demo_corpus")
+        x = await demo_corpus.create_document_async(name=name)
+        self.assertEqual("corpora/demo_corpus/documents/demo_doc", x.name)
+        self.assertEqual("demo_doc", x.display_name)
+
+    async def test_get_document(self, display_name="demo_doc"):
+        demo_corpus = await retriever.create_corpus_async(display_name="demo_corpus")
+        x = await demo_corpus.create_document_async(display_name=display_name)
+        d = await demo_corpus.get_document_async(name=x.name)
+        self.assertEqual("demo_doc", d.display_name)
+
+    async def test_update_document(self):
+        demo_corpus = await retriever.create_corpus_async(display_name="demo_corpus")
+        demo_document = await demo_corpus.create_document_async(display_name="demo_doc")
+        update_request = await demo_document.update_async(updates={"display_name": "demo_doc_1"})
+        self.assertEqual("demo_doc_1", demo_document.display_name)
+
+    async def test_delete_document(self):
+        demo_corpus = await retriever.create_corpus_async(display_name="demo_corpus")
+        demo_document = await demo_corpus.create_document_async(display_name="demo_doc")
+        demo_doc2 = await demo_corpus.create_document_async(display_name="demo_doc_2")
+        delete_request = await demo_corpus.delete_document_async(
+            name="corpora/demo_corpus/documents/demo_doc"
+        )
+        self.assertIsInstance(self.observed_requests[-1], glm.DeleteDocumentRequest)
+
+    async def test_list_documents(self):
+        demo_corpus = await retriever.create_corpus_async(display_name="demo_corpus")
+        demo_document = await demo_corpus.create_document_async(display_name="demo_doc")
+        demo_doc2 = await demo_corpus.create_document_async(display_name="demo_doc_2")
+        self.assertLen(await demo_corpus.list_documents_async(), 2)
+
+    async def test_query_document(self):
+        demo_corpus = await retriever.create_corpus_async(display_name="demo_corpus")
+        demo_document = await demo_corpus.create_document_async(display_name="demo_doc")
+        demo_chunk = await demo_document.create_chunk_async(
+            name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk",
+            data="This is a demo chunk.",
+        )
+        q = await demo_document.query_async(query="What kind of chunk is this?")
+        self.assertIsInstance(q, dict)
+        self.assertEqual(
+            q,
+            {
+                "relevant_chunks": [
+                    {
+                        "chunk_relevance_score": 0.08,
+                        "chunk": {
+                            "name": "corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk",
+                            "data": {"string_value": "This is a demo chunk."},
+                            "custom_metadata": [],
+                            "state": 0,
+                        },
+                    }
+                ]
+            },
+        )
+
+    async def test_create_chunk(self):
+        demo_corpus = await retriever.create_corpus_async(display_name="demo_corpus")
+        demo_document = await demo_corpus.create_document_async(display_name="demo_doc")
+        x = await demo_document.create_chunk_async(
+            name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk",
+            data="This is a demo chunk.",
+        )
+        self.assertIsInstance(x, retriever_service.Chunk)
+        self.assertEqual("corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk", x.name)
+        self.assertEqual(retriever_service.ChunkData("This is a demo chunk."), x.data)
+
+    @parameterized.named_parameters(
+        [
+            dict(
+                testcase_name="match_chunk_regex",
+                name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk",
+            ),
+            dict(testcase_name="no_chunk", name="corpora/demo_corpus/demo_document/demo_chunk"),
+            dict(
+                testcase_name="with_punctuation",
+                name="corpora/demo_corpus*(*)/documents/demo_doc/chunks*****/demo_chunk",
+            ),
+            dict(testcase_name="dash_at_start", name="-demo_chunk"),
+        ]
+    )
+    async def test_create_chunk_name(self, name):
+        demo_corpus = await retriever.create_corpus_async(display_name="demo_corpus")
+        demo_document = await demo_corpus.create_document_async(display_name="demo_doc")
+        x = await demo_document.create_chunk_async(
+            name=name,
+            data="This is a demo chunk.",
+        )
+        self.assertEqual("corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk", x.name)
+
+    @parameterized.named_parameters(
+        [
+            dict(
+                testcase_name="dictionaries",
+                chunks=[
+                    {
+                        "name": "corpora/demo_corpus/documents/demo_doc/chunks/dc",
+                        "data": "This is a demo chunk.",
+                    },
+                    {
+                        "name": "corpora/demo_corpus/documents/demo_doc/chunks/dc1",
+                        "data": "This is another demo chunk.",
+                    },
+                ],
+            ),
+            dict(
+                testcase_name="tuples",
+                chunks=[
+                    (
+                        "corpora/demo_corpus/documents/demo_doc/chunks/dc",
+                        "This is a demo chunk.",
+                    ),
+                    (
+                        "corpora/demo_corpus/documents/demo_doc/chunks/dc1",
+                        "This is another demo chunk.",
+                    ),
+                ],
+            ),
+        ]
+    )
+    async def test_batch_create_chunks(self, chunks):
+        demo_corpus = await retriever.create_corpus_async(display_name="demo_corpus")
+        demo_document = await demo_corpus.create_document_async(display_name="demo_doc")
+        creation_req = await demo_document.batch_create_chunks_async(chunks=chunks)
+        self.assertIsInstance(self.observed_requests[-1], glm.BatchCreateChunksRequest)
+        self.assertEqual("This is a demo chunk.", creation_req["chunks"][0]["data"]["string_value"])
+        self.assertEqual(
+            "This is another demo chunk.", creation_req["chunks"][1]["data"]["string_value"]
+        )
+
+    async def test_get_chunk(self):
+        demo_corpus = await retriever.create_corpus_async(display_name="demo_corpus")
+        demo_document = await demo_corpus.create_document_async(display_name="demo_doc")
+        x = await demo_document.create_chunk_async(
+            name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk",
+            data="This is a demo chunk.",
+        )
+        ch = await demo_document.get_chunk_async(name=x.name)
+        self.assertEqual(retriever_service.ChunkData("This is a demo chunk."), ch.data)
+
+    async def test_list_chunks(self):
+        demo_corpus = await retriever.create_corpus_async(display_name="demo_corpus")
+        demo_document = await demo_corpus.create_document_async(display_name="demo_doc")
+        x = await demo_document.create_chunk_async(
+            name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk",
+            data="This is a demo chunk.",
+        )
+        y = await demo_document.create_chunk_async(
+            name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk_1",
+            data="This is another demo chunk.",
+        )
+        list_req = await demo_document.list_chunks_async()
+        self.assertIsInstance(self.observed_requests[-1], glm.ListChunksRequest)
+        self.assertLen(list_req, 2)
+
+    async def test_update_chunk(self):
+        demo_corpus = await retriever.create_corpus_async(display_name="demo_corpus")
+        demo_document = await demo_corpus.create_document_async(display_name="demo_doc")
+        x = await demo_document.create_chunk_async(
+            name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk",
+            data="This is a demo chunk.",
+        )
+        update_request = await x.update_async(
+            updates={"data": {"string_value": "This is an updated demo chunk."}}
+        )
+        self.assertEqual(
+            retriever_service.ChunkData("This is an updated demo chunk."),
+            update_request.data,
+        )
+
+    @parameterized.named_parameters(
+        [
+            dict(
+                testcase_name="dictionary_of_updates",
+                updates={
+                    "corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk": {
+                        "data": {"string_value": "This is an updated chunk."}
+                    },
+                    "corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk_1": {
+                        "data": {"string_value": "This is another updated chunk."}
+                    },
+                },
+            ),
+            dict(
+                testcase_name="list_of_tuples",
+                updates=[
+                    (
+                        "corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk",
+                        {"data": {"string_value": "This is an updated chunk."}},
+                    ),
+                    (
+                        "corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk_1",
+                        {"data": {"string_value": "This is another updated chunk."}},
+                    ),
+                ],
+            ),
+        ],
+    )
+    async def test_batch_update_chunks_data_structures(self, updates):
+        demo_corpus = await retriever.create_corpus_async(display_name="demo_corpus")
+        demo_document = await demo_corpus.create_document_async(display_name="demo_doc")
+        x = await demo_document.create_chunk_async(
+            name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk",
+            data="This is a demo chunk.",
+        )
+        y = await demo_document.create_chunk_async(
+            name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk_1",
+            data="This is another demo chunk.",
+        )
+        update_request = await demo_document.batch_update_chunks_async(chunks=updates)
+        self.assertIsInstance(self.observed_requests[-1], glm.BatchUpdateChunksRequest)
+        self.assertEqual(
+            "This is an updated chunk.", update_request["chunks"][0]["data"]["string_value"]
+        )
+        self.assertEqual(
+            "This is another updated chunk.", update_request["chunks"][1]["data"]["string_value"]
+        )
+
+    async def test_delete_chunk(self):
+        demo_corpus = await retriever.create_corpus_async(display_name="demo_corpus")
+        demo_document = await demo_corpus.create_document_async(display_name="demo_doc")
+        x = await demo_document.create_chunk_async(
+            name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk",
+            data="This is a demo chunk.",
+        )
+        delete_request = await demo_document.delete_chunk_async(
+            name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk"
+        )
+        self.assertIsInstance(self.observed_requests[-1], glm.DeleteChunkRequest)
+
+    async def test_batch_delete_chunks(self):
+        demo_corpus = await retriever.create_corpus_async(display_name="demo_corpus")
+        demo_document = await demo_corpus.create_document_async(display_name="demo_doc")
+        x = await demo_document.create_chunk_async(
+            name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk",
+            data="This is a demo chunk.",
+        )
+        y = await demo_document.create_chunk_async(
+            name="corpora/demo_corpus/documents/demo_doc/chunks/demo_chunk",
+            data="This is another demo chunk.",
+        )
+        delete_request = await demo_document.batch_delete_chunks_async(chunks=[x.name, y.name])
+        self.assertIsInstance(self.observed_requests[-1], glm.BatchDeleteChunksRequest)
+
+
+if __name__ == "__main__":
+    absltest.main()