-
Notifications
You must be signed in to change notification settings - Fork 467
Semantic retriever #168
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Semantic retriever #168
Changes from all commits
Commits
Show all changes
28 commits
Select commit
Hold shift + click to select a range
2824827
Semantic retriever
shilpakancharla 6d81564
Adding parameter to re.sub for create_chunk function
shilpakancharla 1434b7e
attempting to fix pytype error with ChunkData and CustomMetadata
shilpakancharla 6dbb461
Adding async to semantic retriever functions
shilpakancharla 8c92643
Update _flatten to flatten
shilpakancharla 9fb4758
Update google/generativeai/types/retriever_types.py
shilpakancharla ef9d47c
Update google/generativeai/client.py
shilpakancharla 1bd5fc9
Update google/generativeai/models.py
shilpakancharla 0a88492
Resolving Github precheck failures
shilpakancharla 59a4cad
Changed .data to .string_value
shilpakancharla 304a610
Update _flatten_update_paths to flatten_update_paths
shilpakancharla 3b1d55c
Updating async test cases for retriever
shilpakancharla 19a8b30
Added in client methods in async retriever test
shilpakancharla c674136
Added all async test cases
shilpakancharla 0684564
Fixed all test cases locally
shilpakancharla 3d09db2
Updated async retriever tests
shilpakancharla 29b057a
Fixing names in async test cases
shilpakancharla 2f0abdc
Fixed async method for QueryCorpus
shilpakancharla c5645b6
Added await statements
shilpakancharla 489d00e
Reformatted file
shilpakancharla df7e0a5
Added async methods to test cases
shilpakancharla ff3046d
Updated regex statements and removed redundancy from elif statements
shilpakancharla 9035f66
Updates to create_chunk
shilpakancharla afc11ac
Async code test update, dataclass updates
shilpakancharla ff76bc9
Skipping format check to resolve errors
shilpakancharla af4625a
Modified gitignore
shilpakancharla d1bddf0
Update to delete_document
shilpakancharla 79f8f33
Formatting check
shilpakancharla File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,4 +7,4 @@ | |
*.egg-info | ||
.DS_Store | ||
__pycache__ | ||
*.iml | ||
*.iml |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,224 @@ | ||
# -*- coding: utf-8 -*- | ||
# Copyright 2023 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from __future__ import annotations | ||
|
||
import re | ||
import string | ||
import dataclasses | ||
from typing import Optional | ||
|
||
import google.ai.generativelanguage as glm | ||
|
||
from google.generativeai.client import get_default_retriever_client | ||
from google.generativeai.client import get_default_retriever_async_client | ||
from google.generativeai import string_utils | ||
from google.generativeai.types import retriever_types | ||
from google.generativeai.types import model_types | ||
from google.generativeai import models | ||
from google.generativeai.types import safety_types | ||
from google.generativeai.types.model_types import idecode_time | ||
|
||
_CORPORA_NAME_REGEX = re.compile(r"^corpora/[a-z0-9-]+") | ||
_REMOVE = string.punctuation | ||
_REMOVE = _REMOVE.replace("-", "") # Don't remove hyphens | ||
_PATTERN = r"[{}]".format(_REMOVE) # Create the pattern | ||
|
||
|
||
@string_utils.prettyprint | ||
@dataclasses.dataclass(init=False) | ||
class Corpus(retriever_types.Corpus): | ||
shilpakancharla marked this conversation as resolved.
Show resolved
Hide resolved
|
||
def __init__(self, **kwargs): | ||
for key, value in kwargs.items(): | ||
setattr(self, key, value) | ||
|
||
self.result = None | ||
if self.name: | ||
self.result = self.name | ||
|
||
|
||
def create_corpus( | ||
name: Optional[str] = None, | ||
display_name: Optional[str] = None, | ||
client: glm.RetrieverServiceClient | None = None, | ||
) -> Corpus: | ||
""" | ||
Create a Corpus object. Users can specify either a name or display_name. | ||
|
||
Args: | ||
name: The corpus resource name (ID). The name must be alphanumeric and fewer | ||
than 40 characters. | ||
display_name: The human readable display name. The display name must be fewer | ||
than 128 characters. All characters, including alphanumeric, spaces, and | ||
dashes are supported. | ||
|
||
Return: | ||
Corpus object with specified name or display name. | ||
|
||
Raises: | ||
ValueError: When the name is not specified or formatted incorrectly. | ||
""" | ||
if client is None: | ||
client = get_default_retriever_client() | ||
|
||
if not name and not display_name: | ||
raise ValueError("Either the corpus name or display name must be specified.") | ||
|
||
corpus = None | ||
if name: | ||
if re.match(_CORPORA_NAME_REGEX, name): | ||
corpus = glm.Corpus(name=name, display_name=display_name) | ||
elif "corpora/" not in name: | ||
corpus_name = "corpora/" + re.sub(_PATTERN, "", name) | ||
corpus = glm.Corpus(name=corpus_name, display_name=display_name) | ||
else: | ||
raise ValueError("Corpus name must be formatted as corpora/<corpus_name>.") | ||
|
||
request = glm.CreateCorpusRequest(corpus=corpus) | ||
response = client.create_corpus(request) | ||
response = type(response).to_dict(response) | ||
idecode_time(response, "create_time") | ||
idecode_time(response, "update_time") | ||
response = Corpus(**response) | ||
return response | ||
|
||
|
||
async def create_corpus_async( | ||
name: Optional[str] = None, | ||
display_name: Optional[str] = None, | ||
client: glm.RetrieverServiceAsyncClient | None = None, | ||
) -> Corpus: | ||
"""This is the async version of `create_corpus`.""" | ||
shilpakancharla marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if client is None: | ||
client = get_default_retriever_async_client() | ||
|
||
if not name and not display_name: | ||
raise ValueError("Either the corpus name or display name must be specified.") | ||
|
||
corpus = None | ||
if name: | ||
if re.match(_CORPORA_NAME_REGEX, name): | ||
corpus = glm.Corpus(name=name, display_name=display_name) | ||
elif "corpora/" not in name: | ||
corpus_name = "corpora/" + re.sub(_PATTERN, "", name) | ||
corpus = glm.Corpus(name=corpus_name, display_name=display_name) | ||
else: | ||
raise ValueError("Corpus name must be formatted as corpora/<corpus_name>.") | ||
|
||
request = glm.CreateCorpusRequest(corpus=corpus) | ||
response = await client.create_corpus(request) | ||
response = type(response).to_dict(response) | ||
idecode_time(response, "create_time") | ||
idecode_time(response, "update_time") | ||
response = Corpus(**response) | ||
return response | ||
|
||
|
||
def get_corpus(name: str, client: glm.RetrieverServiceClient | None = None) -> Corpus: # fmt: skip | ||
""" | ||
Get information about a specific `Corpus`. | ||
|
||
Args: | ||
name: The `Corpus` name. | ||
|
||
Return: | ||
`Corpus` of interest. | ||
""" | ||
if client is None: | ||
client = get_default_retriever_client() | ||
|
||
request = glm.GetCorpusRequest(name=name) | ||
response = client.get_corpus(request) | ||
response = type(response).to_dict(response) | ||
idecode_time(response, "create_time") | ||
idecode_time(response, "update_time") | ||
response = Corpus(**response) | ||
return response | ||
|
||
|
||
async def get_corpus_async(name: str, client: glm.RetrieverServiceAsyncClient | None = None) -> Corpus: # fmt: skip | ||
"""This is the async version of `get_corpus`.""" | ||
if client is None: | ||
client = get_default_retriever_async_client() | ||
|
||
request = glm.GetCorpusRequest(name=name) | ||
response = await client.get_corpus(request) | ||
response = type(response).to_dict(response) | ||
idecode_time(response, "create_time") | ||
idecode_time(response, "update_time") | ||
response = Corpus(**response) | ||
return response | ||
|
||
|
||
def delete_corpus(name: str, force: bool, client: glm.RetrieverServiceClient | None = None): # fmt: skip | ||
""" | ||
Delete a `Corpus`. | ||
|
||
Args: | ||
name: The `Corpus` name. | ||
force: If set to true, any `Document`s and objects related to this `Corpus` will also be deleted. | ||
""" | ||
if client is None: | ||
client = get_default_retriever_client() | ||
|
||
request = glm.DeleteCorpusRequest(name=name, force=force) | ||
client.delete_corpus(request) | ||
|
||
|
||
async def delete_corpus_async(name: str, force: bool, client: glm.RetrieverServiceAsyncClient | None = None): # fmt: skip | ||
"""This is the async version of `delete_corpus`.""" | ||
if client is None: | ||
client = get_default_retriever_async_client() | ||
|
||
request = glm.DeleteCorpusRequest(name=name, force=force) | ||
await client.delete_corpus(request) | ||
|
||
|
||
def list_corpora( | ||
shilpakancharla marked this conversation as resolved.
Show resolved
Hide resolved
|
||
*, | ||
page_size: Optional[int] = None, | ||
page_token: Optional[str] = None, | ||
client: glm.RetrieverServiceClient | None = None, | ||
) -> list[Corpus]: | ||
""" | ||
List `Corpus`. | ||
|
||
Args: | ||
page_size: Maximum number of `Corpora` to request. | ||
page_token: A page token, received from a previous ListCorpora call. | ||
|
||
Return: | ||
Paginated list of `Corpora`. | ||
""" | ||
if client is None: | ||
client = get_default_retriever_client() | ||
|
||
request = glm.ListCorporaRequest(page_size=page_size, page_token=page_token) | ||
response = client.list_corpora(request) | ||
shilpakancharla marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return response | ||
|
||
|
||
async def list_corpora_async( | ||
*, | ||
page_size: Optional[int] = None, | ||
page_token: Optional[str] = None, | ||
client: glm.RetrieverServiceClient | None = None, | ||
) -> list[Corpus]: | ||
"""This is the async version of `list_corpora`.""" | ||
if client is None: | ||
client = get_default_retriever_async_client() | ||
|
||
request = glm.ListCorporaRequest(page_size=page_size, page_token=page_token) | ||
response = await client.list_corpora(request) | ||
return response |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# -*- coding: utf-8 -*- | ||
# Copyright 2023 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from __future__ import annotations | ||
|
||
import abc | ||
import dataclasses | ||
from typing import Any, Dict, List, TypedDict | ||
|
||
|
||
class EmbeddingDict(TypedDict): | ||
shilpakancharla marked this conversation as resolved.
Show resolved
Hide resolved
|
||
embedding: list[float] | ||
|
||
|
||
class BatchEmbeddingDict(TypedDict): | ||
embedding: list[list[float]] |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.