Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
2824827
Semantic retriever
shilpakancharla Oct 30, 2023
6d81564
Adding parameter to re.sub for create_chunk function
shilpakancharla Jan 12, 2024
1434b7e
attempting to fix pytype error with ChunkData and CustomMetadata
shilpakancharla Jan 12, 2024
6dbb461
Adding async to semantic retriever functions
shilpakancharla Jan 17, 2024
8c92643
Update _flatten to flatten
shilpakancharla Jan 17, 2024
9fb4758
Update google/generativeai/types/retriever_types.py
shilpakancharla Jan 17, 2024
ef9d47c
Update google/generativeai/client.py
shilpakancharla Jan 17, 2024
1bd5fc9
Update google/generativeai/models.py
shilpakancharla Jan 17, 2024
0a88492
Resolving Github precheck failures
shilpakancharla Jan 19, 2024
59a4cad
Changed .data to .string_value
shilpakancharla Jan 19, 2024
304a610
Update _flatten_update_paths to flatten_update_paths
shilpakancharla Jan 19, 2024
3b1d55c
Updating async test cases for retriever
shilpakancharla Jan 23, 2024
19a8b30
Added in client methods in async retriever test
shilpakancharla Jan 23, 2024
c674136
Added all async test cases
shilpakancharla Jan 24, 2024
0684564
Fixed all test cases locally
shilpakancharla Jan 24, 2024
3d09db2
Updated async retriever tests
shilpakancharla Jan 24, 2024
29b057a
Fixing names in async test cases
shilpakancharla Jan 24, 2024
2f0abdc
Fixed async method for QueryCorpus
shilpakancharla Jan 24, 2024
c5645b6
Added await statements
shilpakancharla Jan 24, 2024
489d00e
Reformatted file
shilpakancharla Jan 24, 2024
df7e0a5
Added async methods to test cases
shilpakancharla Jan 24, 2024
ff3046d
Updated regex statements and removed redundancy from elif statements
shilpakancharla Jan 24, 2024
9035f66
Updates to create_chunk
shilpakancharla Jan 25, 2024
afc11ac
Async code test update, dataclass updates
shilpakancharla Jan 29, 2024
ff76bc9
Skipping format check to resolve errors
shilpakancharla Jan 30, 2024
af4625a
Modified gitignore
shilpakancharla Jan 30, 2024
d1bddf0
Update to delete_document
shilpakancharla Jan 30, 2024
79f8f33
Formatting check
shilpakancharla Jan 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@
*.egg-info
.DS_Store
__pycache__
*.iml
*.iml
9 changes: 8 additions & 1 deletion google/generativeai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,6 @@ def get_default_operations_client(self) -> operations_v1.OperationsClient:
model_client = self.get_default_client("Model")
client = model_client._transport.operations_client
self.clients["operations"] = client

return client


Expand Down Expand Up @@ -244,3 +243,11 @@ def get_default_operations_client() -> operations_v1.OperationsClient:

def get_default_model_client() -> glm.ModelServiceAsyncClient:
return _client_manager.get_default_client("model")


def get_default_retriever_client() -> glm.RetrieverClient:
return _client_manager.get_default_client("retriever")


def get_default_retriever_async_client() -> glm.RetrieverAsyncClient:
return _client_manager.get_default_client("retriever_async")
15 changes: 2 additions & 13 deletions google/generativeai/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from google.api_core import operation
from google.api_core import protobuf_helpers
from google.protobuf import field_mask_pb2
from google.generativeai.utils import flatten_update_paths


def get_model(
Expand Down Expand Up @@ -351,7 +352,7 @@ def update_tuned_model(
)
tuned_model = client.get_tuned_model(name=name)

updates = _flatten_update_paths(updates)
updates = flatten_update_paths(updates)
field_mask = field_mask_pb2.FieldMask()
for path in updates.keys():
field_mask.paths.append(path)
Expand Down Expand Up @@ -379,18 +380,6 @@ def update_tuned_model(
return model_types.decode_tuned_model(result)


def _flatten_update_paths(updates):
new_updates = {}
for key, value in updates.items():
if isinstance(value, dict):
for sub_key, sub_value in _flatten_update_paths(value).items():
new_updates[f"{key}.{sub_key}"] = sub_value
else:
new_updates[key] = value

return new_updates


def _apply_update(thing, path, value):
parts = path.split(".")
for part in parts[:-1]:
Expand Down
224 changes: 224 additions & 0 deletions google/generativeai/retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import re
import string
import dataclasses
from typing import Optional

import google.ai.generativelanguage as glm

from google.generativeai.client import get_default_retriever_client
from google.generativeai.client import get_default_retriever_async_client
from google.generativeai import string_utils
from google.generativeai.types import retriever_types
from google.generativeai.types import model_types
from google.generativeai import models
from google.generativeai.types import safety_types
from google.generativeai.types.model_types import idecode_time

_CORPORA_NAME_REGEX = re.compile(r"^corpora/[a-z0-9-]+")
_REMOVE = string.punctuation
_REMOVE = _REMOVE.replace("-", "") # Don't remove hyphens
_PATTERN = r"[{}]".format(_REMOVE) # Create the pattern


@string_utils.prettyprint
@dataclasses.dataclass(init=False)
class Corpus(retriever_types.Corpus):
def __init__(self, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)

self.result = None
if self.name:
self.result = self.name


def create_corpus(
name: Optional[str] = None,
display_name: Optional[str] = None,
client: glm.RetrieverServiceClient | None = None,
) -> Corpus:
"""
Create a Corpus object. Users can specify either a name or display_name.

Args:
name: The corpus resource name (ID). The name must be alphanumeric and fewer
than 40 characters.
display_name: The human readable display name. The display name must be fewer
than 128 characters. All characters, including alphanumeric, spaces, and
dashes are supported.

Return:
Corpus object with specified name or display name.

Raises:
ValueError: When the name is not specified or formatted incorrectly.
"""
if client is None:
client = get_default_retriever_client()

if not name and not display_name:
raise ValueError("Either the corpus name or display name must be specified.")

corpus = None
if name:
if re.match(_CORPORA_NAME_REGEX, name):
corpus = glm.Corpus(name=name, display_name=display_name)
elif "corpora/" not in name:
corpus_name = "corpora/" + re.sub(_PATTERN, "", name)
corpus = glm.Corpus(name=corpus_name, display_name=display_name)
else:
raise ValueError("Corpus name must be formatted as corpora/<corpus_name>.")

request = glm.CreateCorpusRequest(corpus=corpus)
response = client.create_corpus(request)
response = type(response).to_dict(response)
idecode_time(response, "create_time")
idecode_time(response, "update_time")
response = Corpus(**response)
return response


async def create_corpus_async(
name: Optional[str] = None,
display_name: Optional[str] = None,
client: glm.RetrieverServiceAsyncClient | None = None,
) -> Corpus:
"""This is the async version of `create_corpus`."""
if client is None:
client = get_default_retriever_async_client()

if not name and not display_name:
raise ValueError("Either the corpus name or display name must be specified.")

corpus = None
if name:
if re.match(_CORPORA_NAME_REGEX, name):
corpus = glm.Corpus(name=name, display_name=display_name)
elif "corpora/" not in name:
corpus_name = "corpora/" + re.sub(_PATTERN, "", name)
corpus = glm.Corpus(name=corpus_name, display_name=display_name)
else:
raise ValueError("Corpus name must be formatted as corpora/<corpus_name>.")

request = glm.CreateCorpusRequest(corpus=corpus)
response = await client.create_corpus(request)
response = type(response).to_dict(response)
idecode_time(response, "create_time")
idecode_time(response, "update_time")
response = Corpus(**response)
return response


def get_corpus(name: str, client: glm.RetrieverServiceClient | None = None) -> Corpus: # fmt: skip
"""
Get information about a specific `Corpus`.

Args:
name: The `Corpus` name.

Return:
`Corpus` of interest.
"""
if client is None:
client = get_default_retriever_client()

request = glm.GetCorpusRequest(name=name)
response = client.get_corpus(request)
response = type(response).to_dict(response)
idecode_time(response, "create_time")
idecode_time(response, "update_time")
response = Corpus(**response)
return response


async def get_corpus_async(name: str, client: glm.RetrieverServiceAsyncClient | None = None) -> Corpus: # fmt: skip
"""This is the async version of `get_corpus`."""
if client is None:
client = get_default_retriever_async_client()

request = glm.GetCorpusRequest(name=name)
response = await client.get_corpus(request)
response = type(response).to_dict(response)
idecode_time(response, "create_time")
idecode_time(response, "update_time")
response = Corpus(**response)
return response


def delete_corpus(name: str, force: bool, client: glm.RetrieverServiceClient | None = None): # fmt: skip
"""
Delete a `Corpus`.

Args:
name: The `Corpus` name.
force: If set to true, any `Document`s and objects related to this `Corpus` will also be deleted.
"""
if client is None:
client = get_default_retriever_client()

request = glm.DeleteCorpusRequest(name=name, force=force)
client.delete_corpus(request)


async def delete_corpus_async(name: str, force: bool, client: glm.RetrieverServiceAsyncClient | None = None): # fmt: skip
"""This is the async version of `delete_corpus`."""
if client is None:
client = get_default_retriever_async_client()

request = glm.DeleteCorpusRequest(name=name, force=force)
await client.delete_corpus(request)


def list_corpora(
*,
page_size: Optional[int] = None,
page_token: Optional[str] = None,
client: glm.RetrieverServiceClient | None = None,
) -> list[Corpus]:
"""
List `Corpus`.

Args:
page_size: Maximum number of `Corpora` to request.
page_token: A page token, received from a previous ListCorpora call.

Return:
Paginated list of `Corpora`.
"""
if client is None:
client = get_default_retriever_client()

request = glm.ListCorporaRequest(page_size=page_size, page_token=page_token)
response = client.list_corpora(request)
return response


async def list_corpora_async(
*,
page_size: Optional[int] = None,
page_token: Optional[str] = None,
client: glm.RetrieverServiceClient | None = None,
) -> list[Corpus]:
"""This is the async version of `list_corpora`."""
if client is None:
client = get_default_retriever_async_client()

request = glm.ListCorporaRequest(page_size=page_size, page_token=page_token)
response = await client.list_corpora(request)
return response
27 changes: 27 additions & 0 deletions google/generativeai/types/embedding_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import abc
import dataclasses
from typing import Any, Dict, List, TypedDict


class EmbeddingDict(TypedDict):
embedding: list[float]


class BatchEmbeddingDict(TypedDict):
embedding: list[list[float]]
Loading