Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
1c8bdc7
Updated naming regex, made naming process simpler.
shilpakancharla Feb 1, 2024
e62c90d
Merge branch 'main' into retriever_naming
shilpakancharla Feb 1, 2024
7e5caeb
Fixed non async test cases for naming
shilpakancharla Feb 5, 2024
0a51355
Merge with main
shilpakancharla Feb 5, 2024
a2346d4
updated regex and made non async test cases pass
shilpakancharla Feb 5, 2024
25101ee
Updated regex to match proto
shilpakancharla Feb 5, 2024
ca853a9
Fixing missing await keywords
shilpakancharla Feb 5, 2024
2da63ee
Update async tests
shilpakancharla Feb 5, 2024
8d7cdf6
Update error messages
shilpakancharla Feb 5, 2024
c8a7530
Fixed async test case for create_document
shilpakancharla Feb 5, 2024
1c590c8
Add a valid_name function.
MarkDaoust Feb 5, 2024
66d6aa9
Added missing colon to valid_name function
shilpakancharla Feb 5, 2024
da1eafa
Removed duplicate import, updated error message
shilpakancharla Feb 5, 2024
49062c4
Added bool = Force to delete functions
shilpakancharla Feb 5, 2024
df4b0b6
Optional string for chunk name
shilpakancharla Feb 5, 2024
62f684c
Change to optional param for hcunk name
shilpakancharla Feb 5, 2024
9787e79
Add option for empty name for chunk and test cases
shilpakancharla Feb 5, 2024
9a33c6d
add _async suffix to test case
shilpakancharla Feb 5, 2024
bd93a7e
test async for empty chunk name
shilpakancharla Feb 5, 2024
ed95b1c
Pass type checking
shilpakancharla Feb 5, 2024
c3c2b10
Fixed test cases
shilpakancharla Feb 5, 2024
8d33a22
updated naming error msg
shilpakancharla Feb 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 15 additions & 32 deletions google/generativeai/retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,12 @@

from google.generativeai.client import get_default_retriever_client
from google.generativeai.client import get_default_retriever_async_client
from google.generativeai.types import retriever_types
from google.generativeai.types.model_types import idecode_time

_CORPORA_NAME_REGEX = re.compile(r"^corpora/[a-z0-9-]+")
_REMOVE = string.punctuation
_REMOVE = _REMOVE.replace("-", "") # Don't remove hyphens
_PATTERN = r"[{}]".format(_REMOVE) # Create the pattern
from google.generativeai.types import retriever_types


def create_corpus(
name: Optional[str] = None,
name: str,
display_name: Optional[str] = None,
client: glm.RetrieverServiceClient | None = None,
) -> retriever_types.Corpus:
Expand All @@ -58,18 +53,12 @@ def create_corpus(
if client is None:
client = get_default_retriever_client()

if not name and not display_name:
raise ValueError("Either the corpus name or display name must be specified.")

corpus = None
if name:
if re.match(_CORPORA_NAME_REGEX, name):
corpus = glm.Corpus(name=name, display_name=display_name)
elif "corpora/" not in name:
corpus_name = "corpora/" + re.sub(_PATTERN, "", name)
corpus = glm.Corpus(name=corpus_name, display_name=display_name)
else:
raise ValueError("Corpus name must be formatted as corpora/<corpus_name>.")
if retriever_types.valid_name(name):
corpus_name = "corpora/" + name # Construct the name
corpus = glm.Corpus(name=corpus_name, display_name=display_name)
else:
raise ValueError(retriever_types.NAME_ERROR_MSG.format(length=len(name), name=name))

request = glm.CreateCorpusRequest(corpus=corpus)
response = client.create_corpus(request)
Expand All @@ -81,26 +70,20 @@ def create_corpus(


async def create_corpus_async(
name: Optional[str] = None,
name: str,
display_name: Optional[str] = None,
client: glm.RetrieverServiceAsyncClient | None = None,
) -> retriever_types.Corpus:
"""This is the async version of `retriever.create_corpus`."""
if client is None:
client = get_default_retriever_async_client()

if not name and not display_name:
raise ValueError("Either the corpus name or display name must be specified.")

corpus = None
if name:
if re.match(_CORPORA_NAME_REGEX, name):
corpus = glm.Corpus(name=name, display_name=display_name)
elif "corpora/" not in name:
corpus_name = "corpora/" + re.sub(_PATTERN, "", name)
corpus = glm.Corpus(name=corpus_name, display_name=display_name)
else:
raise ValueError("Corpus name must be formatted as corpora/<corpus_name>.")
if retriever_types.valid_name(name):
corpus_name = "corpora/" + name # Construct the name
corpus = glm.Corpus(name=corpus_name, display_name=display_name)
else:
raise ValueError(retriever_types.NAME_ERROR_MSG.format(length=len(name), name=name))

request = glm.CreateCorpusRequest(corpus=corpus)
response = await client.create_corpus(request)
Expand Down Expand Up @@ -147,7 +130,7 @@ async def get_corpus_async(name: str, client: glm.RetrieverServiceAsyncClient |
return response


def delete_corpus(name: str, force: bool, client: glm.RetrieverServiceClient | None = None): # fmt: skip
def delete_corpus(name: str, force: bool = False, client: glm.RetrieverServiceClient | None = None): # fmt: skip
"""
Delete a `Corpus` from the service.

Expand All @@ -162,7 +145,7 @@ def delete_corpus(name: str, force: bool, client: glm.RetrieverServiceClient | N
client.delete_corpus(request)


async def delete_corpus_async(name: str, force: bool, client: glm.RetrieverServiceAsyncClient | None = None): # fmt: skip
async def delete_corpus_async(name: str, force: bool = False, client: glm.RetrieverServiceAsyncClient | None = None): # fmt: skip
"""This is the async version of `retriever.delete_corpus`."""
if client is None:
client = get_default_retriever_async_client()
Expand Down
112 changes: 44 additions & 68 deletions google/generativeai/types/retriever_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,16 @@
from google.generativeai.types.model_types import idecode_time
from google.generativeai.utils import flatten_update_paths

_VALID_NAME = r"[a-z0-9]([a-z0-9-]{0,38}[a-z0-9])$"
NAME_ERROR_MSG = """The `name` must consist of alphanumeric characters (or -) and be 40 or fewer characters. The name you entered:
\tlen(name)== {length}
\tname={name}
"""


def valid_name(name):
return re.match(_VALID_NAME, name) and len(name) < 40

_DOCUMENT_NAME_REGEX = re.compile(r"^corpora/[a-z0-9-]+/documents/[a-z0-9-]+$")
_CHUNK_NAME_REGEX = re.compile(r"^corpora/([^/]+?)(/documents/([^/]+?)(/chunks/([^/]+?))?)?$")
_REMOVE = string.punctuation
_REMOVE = _REMOVE.replace("-", "") # Don't remove hyphens
_PATTERN = r"[{}]".format(_REMOVE) # Create the pattern

Operator = glm.Condition.Operator
State = glm.Chunk.State
Expand Down Expand Up @@ -180,7 +184,7 @@ class Corpus:

def create_document(
self,
name: Optional[str] = None,
name: str,
display_name: Optional[str] = None,
custom_metadata: Optional[list[CustomMetadata]] = None,
client: glm.RetrieverServiceClient | None = None,
Expand All @@ -203,32 +207,22 @@ def create_document(
if client is None:
client = get_default_retriever_client()

if not name and not display_name:
raise ValueError("Either the document name or display name must be specified.")

document = None
if name:
if re.match(_DOCUMENT_NAME_REGEX, name):
document = glm.Document(
name=name, display_name=display_name, custom_metadata=custom_metadata
)
elif f"corpora/{self.name}/documents/" not in name:
document_name = f"{self.name}/documents/" + re.sub(_PATTERN, "", name)
document = glm.Document(
name=document_name, display_name=display_name, custom_metadata=custom_metadata
)
else:
raise ValueError(
f"Document name must be formatted as {self.name}/document/<document_name>."
)
if valid_name(name):
document_name = f"{self.name}/documents/{name}"
document = glm.Document(
name=document_name, display_name=display_name, custom_metadata=custom_metadata
)
else:
raise ValueError(NAME_ERROR_MSG.format(length=len(name), name=name))

request = glm.CreateDocumentRequest(parent=self.name, document=document)
response = client.create_document(request)
return decode_document(response)

async def create_document_async(
self,
name: Optional[str] = None,
name: str,
display_name: Optional[str] = None,
custom_metadata: Optional[list[CustomMetadata]] = None,
client: glm.RetrieverServiceAsyncClient | None = None,
Expand All @@ -237,24 +231,14 @@ async def create_document_async(
if client is None:
client = get_default_retriever_async_client()

if not name and not display_name:
raise ValueError("Either the document name or display name must be specified.")

document = None
if name:
if re.match(_DOCUMENT_NAME_REGEX, name):
document = glm.Document(
name=name, display_name=display_name, custom_metadata=custom_metadata
)
elif f"corpora/{self.name}/documents/" not in name:
document_name = f"{self.name}/documents/" + re.sub(_PATTERN, "", name)
document = glm.Document(
name=document_name, display_name=display_name, custom_metadata=custom_metadata
)
else:
raise ValueError(
f"Document name must be formatted as {self.name}/document/<document_name>."
)
if valid_name(name):
document_name = f"{self.name}/documents/{name}"
document = glm.Document(
name=document_name, display_name=display_name, custom_metadata=custom_metadata
)
else:
raise ValueError(NAME_ERROR_MSG.format(length=len(name), name=name))

request = glm.CreateDocumentRequest(parent=self.name, document=document)
response = await client.create_document(request)
Expand Down Expand Up @@ -431,7 +415,7 @@ async def query_async(
def delete_document(
self,
name: str,
force: Optional[bool] = None,
force: bool = False,
client: glm.RetrieverServiceClient | None = None,
):
"""
Expand All @@ -450,7 +434,7 @@ def delete_document(
async def delete_document_async(
self,
name: str,
force: Optional[bool] = None,
force: bool = False,
client: glm.RetrieverServiceAsyncClient | None = None,
):
"""This is the async version of `Corpus.delete_document`."""
Expand Down Expand Up @@ -528,17 +512,17 @@ class Document(abc.ABC):

def create_chunk(
self,
name: Optional[str],
data: str | ChunkData,
name: Optional[str] = None,
custom_metadata: Optional[list[CustomMetadata]] = None,
client: glm.RetrieverServiceClient | None = None,
) -> Chunk:
"""
Create a `Chunk` object which has textual data.

Args:
name: The `Chunk` resource name. The ID (name excluding the "corpora/*/documents/*/chunks/" prefix) can contain up to 40 characters that are lowercase alphanumeric or dashes (-).
data: The content for the `Chunk`, such as the text string.
name: The `Chunk` resource name. The ID (name excluding the "corpora/*/documents/*/chunks/" prefix) can contain up to 40 characters that are lowercase alphanumeric or dashes (-).
custom_metadata: User provided custom metadata stored as key-value pairs.
state: States for the lifecycle of a `Chunk`.

Expand All @@ -551,17 +535,13 @@ def create_chunk(
if client is None:
client = get_default_retriever_client()

chunk_name, chunk = "", None
if name:
if re.match(_CHUNK_NAME_REGEX, name):
chunk_name = name

elif "chunks/" not in name:
chunk_name = f"{self.name}/chunks/" + re.sub(_PATTERN, "", name)
else:
raise ValueError(
f"Chunk name must be formatted as {self.name}/chunks/<chunk_name>."
)
chunk_name, chunk = None, None
if name is None:
chunk_name = None
elif valid_name(name):
chunk_name = f"{self.name}/chunks/{name}"
else:
raise ValueError(NAME_ERROR_MSG.format(length=len(name), name=name))

if isinstance(data, str):
chunk = glm.Chunk(
Expand All @@ -580,26 +560,22 @@ def create_chunk(

async def create_chunk_async(
self,
name: Optional[str],
data: str | ChunkData,
name: Optional[str] = None,
custom_metadata: Optional[list[CustomMetadata]] = None,
client: glm.RetrieverServiceAsyncClient | None = None,
) -> Chunk:
"""This is the async version of `Document.create_chunk`."""
if client is None:
client = get_default_retriever_async_client()

chunk_name, chunk = "", None
if name:
if re.match(_CHUNK_NAME_REGEX, name):
chunk_name = name

elif "chunks/" not in name:
chunk_name = f"{self.name}/chunks/" + re.sub(_PATTERN, "", name)
else:
raise ValueError(
f"Chunk name must be formatted as {self.name}/chunks/<chunk_name>."
)
chunk_name, chunk = None, None
if name is None:
chunk_name = None
elif valid_name(name):
chunk_name = f"{self.name}/chunks/{name}"
else:
raise ValueError(NAME_ERROR_MSG.format(length=len(name), name=name))

if isinstance(data, str):
chunk = glm.Chunk(
Expand Down
Loading