Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit b3fa336

Browse files
Updated naming regex, made naming process simpler. (google-gemini#191)
* Updated naming regex, made naming process simpler. * Fixed non async test cases for naming * updated regex and made non async test cases pass * Updated regex to match proto * Fixing missing await keywords * Update async tests * Update error messages * Fixed async test case for create_document * Add a valid_name function. * Added missing colon to valid_name function * Removed duplicate import, updated error message * Added `force: bool = False` to delete functions * Optional string for chunk name * Change to optional param for chunk name * Add option for empty name for chunk and test cases * add _async suffix to test case * test async for empty chunk name * Pass type checking * Fixed test cases * updated naming error msg --------- Co-authored-by: Mark Daoust <[email protected]>
1 parent 80b171a commit b3fa336

File tree

4 files changed

+277
-409
lines changed

4 files changed

+277
-409
lines changed

google/generativeai/retriever.py

Lines changed: 15 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,12 @@
2323

2424
from google.generativeai.client import get_default_retriever_client
2525
from google.generativeai.client import get_default_retriever_async_client
26-
from google.generativeai.types import retriever_types
2726
from google.generativeai.types.model_types import idecode_time
28-
29-
_CORPORA_NAME_REGEX = re.compile(r"^corpora/[a-z0-9-]+")
30-
_REMOVE = string.punctuation
31-
_REMOVE = _REMOVE.replace("-", "") # Don't remove hyphens
32-
_PATTERN = r"[{}]".format(_REMOVE) # Create the pattern
27+
from google.generativeai.types import retriever_types
3328

3429

3530
def create_corpus(
36-
name: Optional[str] = None,
31+
name: str,
3732
display_name: Optional[str] = None,
3833
client: glm.RetrieverServiceClient | None = None,
3934
) -> retriever_types.Corpus:
@@ -58,18 +53,12 @@ def create_corpus(
5853
if client is None:
5954
client = get_default_retriever_client()
6055

61-
if not name and not display_name:
62-
raise ValueError("Either the corpus name or display name must be specified.")
63-
6456
corpus = None
65-
if name:
66-
if re.match(_CORPORA_NAME_REGEX, name):
67-
corpus = glm.Corpus(name=name, display_name=display_name)
68-
elif "corpora/" not in name:
69-
corpus_name = "corpora/" + re.sub(_PATTERN, "", name)
70-
corpus = glm.Corpus(name=corpus_name, display_name=display_name)
71-
else:
72-
raise ValueError("Corpus name must be formatted as corpora/<corpus_name>.")
57+
if retriever_types.valid_name(name):
58+
corpus_name = "corpora/" + name # Construct the name
59+
corpus = glm.Corpus(name=corpus_name, display_name=display_name)
60+
else:
61+
raise ValueError(retriever_types.NAME_ERROR_MSG.format(length=len(name), name=name))
7362

7463
request = glm.CreateCorpusRequest(corpus=corpus)
7564
response = client.create_corpus(request)
@@ -81,26 +70,20 @@ def create_corpus(
8170

8271

8372
async def create_corpus_async(
84-
name: Optional[str] = None,
73+
name: str,
8574
display_name: Optional[str] = None,
8675
client: glm.RetrieverServiceAsyncClient | None = None,
8776
) -> retriever_types.Corpus:
8877
"""This is the async version of `retriever.create_corpus`."""
8978
if client is None:
9079
client = get_default_retriever_async_client()
9180

92-
if not name and not display_name:
93-
raise ValueError("Either the corpus name or display name must be specified.")
94-
9581
corpus = None
96-
if name:
97-
if re.match(_CORPORA_NAME_REGEX, name):
98-
corpus = glm.Corpus(name=name, display_name=display_name)
99-
elif "corpora/" not in name:
100-
corpus_name = "corpora/" + re.sub(_PATTERN, "", name)
101-
corpus = glm.Corpus(name=corpus_name, display_name=display_name)
102-
else:
103-
raise ValueError("Corpus name must be formatted as corpora/<corpus_name>.")
82+
if retriever_types.valid_name(name):
83+
corpus_name = "corpora/" + name # Construct the name
84+
corpus = glm.Corpus(name=corpus_name, display_name=display_name)
85+
else:
86+
raise ValueError(retriever_types.NAME_ERROR_MSG.format(length=len(name), name=name))
10487

10588
request = glm.CreateCorpusRequest(corpus=corpus)
10689
response = await client.create_corpus(request)
@@ -147,7 +130,7 @@ async def get_corpus_async(name: str, client: glm.RetrieverServiceAsyncClient |
147130
return response
148131

149132

150-
def delete_corpus(name: str, force: bool, client: glm.RetrieverServiceClient | None = None): # fmt: skip
133+
def delete_corpus(name: str, force: bool = False, client: glm.RetrieverServiceClient | None = None): # fmt: skip
151134
"""
152135
Delete a `Corpus` from the service.
153136
@@ -162,7 +145,7 @@ def delete_corpus(name: str, force: bool, client: glm.RetrieverServiceClient | N
162145
client.delete_corpus(request)
163146

164147

165-
async def delete_corpus_async(name: str, force: bool, client: glm.RetrieverServiceAsyncClient | None = None): # fmt: skip
148+
async def delete_corpus_async(name: str, force: bool = False, client: glm.RetrieverServiceAsyncClient | None = None): # fmt: skip
166149
"""This is the async version of `retriever.delete_corpus`."""
167150
if client is None:
168151
client = get_default_retriever_async_client()

google/generativeai/types/retriever_types.py

Lines changed: 44 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,16 @@
3232
from google.generativeai.types.model_types import idecode_time
3333
from google.generativeai.utils import flatten_update_paths
3434

35+
# Pattern for a resource ID (the part after the "corpora/", "documents/" or
# "chunks/" prefix): lowercase alphanumerics and dashes, 2 to 40 characters,
# neither starting nor ending with a dash.
_VALID_NAME = r"[a-z0-9]([a-z0-9-]{0,38}[a-z0-9])$"
NAME_ERROR_MSG = """The `name` must consist of alphanumeric characters (or -) and be 40 or fewer characters. The name you entered:
\tlen(name)== {length}
\tname={name}
"""


def valid_name(name):
    """Return True if `name` is an acceptable resource ID, False otherwise.

    A name is accepted when it is at most 40 characters long and matches
    `_VALID_NAME` in its entirety: lowercase letters, digits and dashes only,
    with an alphanumeric first and last character. The pattern requires at
    least two characters.

    Args:
        name: The candidate ID string, without any "corpora/"-style prefix.

    Returns:
        A plain `bool`, so callers can use the result in conditions or
        compare it directly.
    """
    # `<= 40` agrees with both the regex (1 + 38 + 1 characters) and
    # NAME_ERROR_MSG ("40 or fewer"); the previous `< 40` rejected valid
    # 40-character names.
    if len(name) > 40:
        return False
    # `fullmatch` (rather than `match`) is needed because `$` also matches just
    # before a trailing newline, which would let names like "abc\n" through.
    return re.fullmatch(_VALID_NAME, name) is not None
3544

36-
_DOCUMENT_NAME_REGEX = re.compile(r"^corpora/[a-z0-9-]+/documents/[a-z0-9-]+$")
37-
_CHUNK_NAME_REGEX = re.compile(r"^corpora/([^/]+?)(/documents/([^/]+?)(/chunks/([^/]+?))?)?$")
38-
_REMOVE = string.punctuation
39-
_REMOVE = _REMOVE.replace("-", "") # Don't remove hyphens
40-
_PATTERN = r"[{}]".format(_REMOVE) # Create the pattern
4145

4246
Operator = glm.Condition.Operator
4347
State = glm.Chunk.State
@@ -180,7 +184,7 @@ class Corpus:
180184

181185
def create_document(
182186
self,
183-
name: Optional[str] = None,
187+
name: str,
184188
display_name: Optional[str] = None,
185189
custom_metadata: Optional[list[CustomMetadata]] = None,
186190
client: glm.RetrieverServiceClient | None = None,
@@ -203,32 +207,22 @@ def create_document(
203207
if client is None:
204208
client = get_default_retriever_client()
205209

206-
if not name and not display_name:
207-
raise ValueError("Either the document name or display name must be specified.")
208-
209210
document = None
210-
if name:
211-
if re.match(_DOCUMENT_NAME_REGEX, name):
212-
document = glm.Document(
213-
name=name, display_name=display_name, custom_metadata=custom_metadata
214-
)
215-
elif f"corpora/{self.name}/documents/" not in name:
216-
document_name = f"{self.name}/documents/" + re.sub(_PATTERN, "", name)
217-
document = glm.Document(
218-
name=document_name, display_name=display_name, custom_metadata=custom_metadata
219-
)
220-
else:
221-
raise ValueError(
222-
f"Document name must be formatted as {self.name}/document/<document_name>."
223-
)
211+
if valid_name(name):
212+
document_name = f"{self.name}/documents/{name}"
213+
document = glm.Document(
214+
name=document_name, display_name=display_name, custom_metadata=custom_metadata
215+
)
216+
else:
217+
raise ValueError(NAME_ERROR_MSG.format(length=len(name), name=name))
224218

225219
request = glm.CreateDocumentRequest(parent=self.name, document=document)
226220
response = client.create_document(request)
227221
return decode_document(response)
228222

229223
async def create_document_async(
230224
self,
231-
name: Optional[str] = None,
225+
name: str,
232226
display_name: Optional[str] = None,
233227
custom_metadata: Optional[list[CustomMetadata]] = None,
234228
client: glm.RetrieverServiceAsyncClient | None = None,
@@ -237,24 +231,14 @@ async def create_document_async(
237231
if client is None:
238232
client = get_default_retriever_async_client()
239233

240-
if not name and not display_name:
241-
raise ValueError("Either the document name or display name must be specified.")
242-
243234
document = None
244-
if name:
245-
if re.match(_DOCUMENT_NAME_REGEX, name):
246-
document = glm.Document(
247-
name=name, display_name=display_name, custom_metadata=custom_metadata
248-
)
249-
elif f"corpora/{self.name}/documents/" not in name:
250-
document_name = f"{self.name}/documents/" + re.sub(_PATTERN, "", name)
251-
document = glm.Document(
252-
name=document_name, display_name=display_name, custom_metadata=custom_metadata
253-
)
254-
else:
255-
raise ValueError(
256-
f"Document name must be formatted as {self.name}/document/<document_name>."
257-
)
235+
if valid_name(name):
236+
document_name = f"{self.name}/documents/{name}"
237+
document = glm.Document(
238+
name=document_name, display_name=display_name, custom_metadata=custom_metadata
239+
)
240+
else:
241+
raise ValueError(NAME_ERROR_MSG.format(length=len(name), name=name))
258242

259243
request = glm.CreateDocumentRequest(parent=self.name, document=document)
260244
response = await client.create_document(request)
@@ -431,7 +415,7 @@ async def query_async(
431415
def delete_document(
432416
self,
433417
name: str,
434-
force: Optional[bool] = None,
418+
force: bool = False,
435419
client: glm.RetrieverServiceClient | None = None,
436420
):
437421
"""
@@ -450,7 +434,7 @@ def delete_document(
450434
async def delete_document_async(
451435
self,
452436
name: str,
453-
force: Optional[bool] = None,
437+
force: bool = False,
454438
client: glm.RetrieverServiceAsyncClient | None = None,
455439
):
456440
"""This is the async version of `Corpus.delete_document`."""
@@ -528,17 +512,17 @@ class Document(abc.ABC):
528512

529513
def create_chunk(
530514
self,
531-
name: Optional[str],
532515
data: str | ChunkData,
516+
name: Optional[str] = None,
533517
custom_metadata: Optional[list[CustomMetadata]] = None,
534518
client: glm.RetrieverServiceClient | None = None,
535519
) -> Chunk:
536520
"""
537521
Create a `Chunk` object which has textual data.
538522
539523
Args:
540-
name: The `Chunk` resource name. The ID (name excluding the "corpora/*/documents/*/chunks/" prefix) can contain up to 40 characters that are lowercase alphanumeric or dashes (-).
541524
data: The content for the `Chunk`, such as the text string.
525+
name: The `Chunk` resource name. The ID (name excluding the "corpora/*/documents/*/chunks/" prefix) can contain up to 40 characters that are lowercase alphanumeric or dashes (-).
542526
custom_metadata: User provided custom metadata stored as key-value pairs.
543527
state: States for the lifecycle of a `Chunk`.
544528
@@ -551,17 +535,13 @@ def create_chunk(
551535
if client is None:
552536
client = get_default_retriever_client()
553537

554-
chunk_name, chunk = "", None
555-
if name:
556-
if re.match(_CHUNK_NAME_REGEX, name):
557-
chunk_name = name
558-
559-
elif "chunks/" not in name:
560-
chunk_name = f"{self.name}/chunks/" + re.sub(_PATTERN, "", name)
561-
else:
562-
raise ValueError(
563-
f"Chunk name must be formatted as {self.name}/chunks/<chunk_name>."
564-
)
538+
chunk_name, chunk = None, None
539+
if name is None:
540+
chunk_name = None
541+
elif valid_name(name):
542+
chunk_name = f"{self.name}/chunks/{name}"
543+
else:
544+
raise ValueError(NAME_ERROR_MSG.format(length=len(name), name=name))
565545

566546
if isinstance(data, str):
567547
chunk = glm.Chunk(
@@ -580,26 +560,22 @@ def create_chunk(
580560

581561
async def create_chunk_async(
582562
self,
583-
name: Optional[str],
584563
data: str | ChunkData,
564+
name: Optional[str] = None,
585565
custom_metadata: Optional[list[CustomMetadata]] = None,
586566
client: glm.RetrieverServiceAsyncClient | None = None,
587567
) -> Chunk:
588568
"""This is the async version of `Document.create_chunk`."""
589569
if client is None:
590570
client = get_default_retriever_async_client()
591571

592-
chunk_name, chunk = "", None
593-
if name:
594-
if re.match(_CHUNK_NAME_REGEX, name):
595-
chunk_name = name
596-
597-
elif "chunks/" not in name:
598-
chunk_name = f"{self.name}/chunks/" + re.sub(_PATTERN, "", name)
599-
else:
600-
raise ValueError(
601-
f"Chunk name must be formatted as {self.name}/chunks/<chunk_name>."
602-
)
572+
chunk_name, chunk = None, None
573+
if name is None:
574+
chunk_name = None
575+
elif valid_name(name):
576+
chunk_name = f"{self.name}/chunks/{name}"
577+
else:
578+
raise ValueError(NAME_ERROR_MSG.format(length=len(name), name=name))
603579

604580
if isinstance(data, str):
605581
chunk = glm.Chunk(

0 commit comments

Comments
 (0)