Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 51d806d

Browse files
Fix bugs, improve code clarity, and enhance overall reliability across several files. (google-gemini#339)
* Fix and improve * Fix `_make_grounding_passages` , `_make_generate_answer_request` * fix get_default_permission_client and get_default_permission_async_client * Add how to test all in CONTRIBUTING.md * fix back support for `tunedModels/` in `get_model` function * Add pytest to CONTRIBUTING.md * Break down test_generate_text for better debugging. * Add pip install nose2 to CONTRIBUTING.md * Format Change-Id: I4e222f3e01cb8d350ae293b35a88fd5f718fe3dc * fix sloppy types in tests Change-Id: I3ad717ca26e5d170e4bbef23076e528badaaaacb * Update CONTRIBUTING.md * Update CONTRIBUTING.md --------- Co-authored-by: Mark Daoust <[email protected]>
1 parent 1b1d883 commit 51d806d

File tree

7 files changed

+97
-43
lines changed

7 files changed

+97
-43
lines changed

CONTRIBUTING.md

+27-3
Original file line number | Diff line number | Diff line change
@@ -62,17 +62,41 @@ This "editable" mode lets you edit the source without needing to reinstall the p
6262

6363
### Testing
6464

65-
Use the builtin unittest package:
65+
To ensure the integrity of the codebase, we have a suite of tests located in the `generative-ai-python/tests` directory.
6666

67+
You can run all these tests using Python's built-in `unittest` module or the `pytest` library.
68+
69+
For `unittest`, open a terminal and navigate to the root directory of the project. Then, execute the following command:
70+
71+
```
72+
python -m unittest discover -s tests
73+
74+
# or more simply
75+
python -m unittest
6776
```
68-
python -m unittest
77+
78+
Alternatively, if you prefer using `pytest`, you can install it using pip:
79+
6980
```
81+
pip install pytest
82+
```
83+
84+
Then, run the tests with the following command:
85+
86+
```
87+
pytest tests
88+
89+
# or more simply
90+
pytest
91+
```
92+
7093

7194
Or to debug, use:
7295

7396
```commandline
97+
pip install nose2
98+
7499
nose2 --debugger
75-
```
76100
77101
### Type checking
78102

google/generativeai/answer.py

+3-3
Original file line number | Diff line number | Diff line change
@@ -94,7 +94,7 @@ def _make_grounding_passages(source: GroundingPassagesOptions) -> glm.GroundingP
9494

9595
if not isinstance(source, Iterable):
9696
raise TypeError(
97-
f"`source` must be a valid `GroundingPassagesOptions` type object got a: `{type(source)}`."
97+
f"The 'source' argument must be an instance of 'GroundingPassagesOptions', but got a '{type(source).__name__}' object instead."
9898
)
9999

100100
passages = []
@@ -182,7 +182,7 @@ def _make_generate_answer_request(
182182
temperature: float | None = None,
183183
) -> glm.GenerateAnswerRequest:
184184
"""
185-
Calls the API to generate a grounded answer from the model.
185+
constructs a glm.GenerateAnswerRequest object by organizing the input parameters for the API call to generate a grounded answer from the model.
186186
187187
Args:
188188
model: Name of the model used to generate the grounded response.
@@ -217,7 +217,7 @@ def _make_generate_answer_request(
217217
elif semantic_retriever is not None:
218218
semantic_retriever = _make_semantic_retriever_config(semantic_retriever, contents[-1])
219219
else:
220-
TypeError(
220+
raise TypeError(
221221
f"The source must be either an `inline_passages` xor `semantic_retriever_config`, but both are `None`"
222222
)
223223

google/generativeai/client.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -328,9 +328,9 @@ def get_default_retriever_async_client() -> glm.RetrieverAsyncClient:
328328
return _client_manager.get_default_client("retriever_async")
329329

330330

331-
def get_dafault_permission_client() -> glm.PermissionServiceClient:
331+
def get_default_permission_client() -> glm.PermissionServiceClient:
332332
return _client_manager.get_default_client("permission")
333333

334334

335-
def get_dafault_permission_async_client() -> glm.PermissionServiceAsyncClient:
335+
def get_default_permission_async_client() -> glm.PermissionServiceAsyncClient:
336336
return _client_manager.get_default_client("permission_async")

google/generativeai/generative_models.py

+6-4
Original file line number | Diff line number | Diff line change
@@ -387,7 +387,7 @@ def start_chat(
387387
>>> response = chat.send_message("Hello?")
388388
389389
Arguments:
390-
history: An iterable of `glm.Content` objects, or equvalents to initialize the session.
390+
history: An iterable of `glm.Content` objects, or equivalents to initialize the session.
391391
"""
392392
if self._generation_config.get("candidate_count", 1) > 1:
393393
raise ValueError("Can't chat with `candidate_count > 1`")
@@ -401,11 +401,13 @@ def start_chat(
401401
class ChatSession:
402402
"""Contains an ongoing conversation with the model.
403403
404-
>>> model = genai.GenerativeModel(model="gemini-pro")
404+
>>> model = genai.GenerativeModel('models/gemini-pro')
405405
>>> chat = model.start_chat()
406406
>>> response = chat.send_message("Hello")
407407
>>> print(response.text)
408-
>>> response = chat.send_message(...)
408+
>>> response = chat.send_message("Hello again")
409+
>>> print(response.text)
410+
>>> response = chat.send_message(...
409411
410412
This `ChatSession` object collects the messages sent and received, in its
411413
`ChatSession.history` attribute.
@@ -444,7 +446,7 @@ def send_message(
444446
445447
Appends the request and response to the conversation history.
446448
447-
>>> model = genai.GenerativeModel(model="gemini-pro")
449+
>>> model = genai.GenerativeModel('models/gemini-pro')
448450
>>> chat = model.start_chat()
449451
>>> response = chat.send_message("Hello")
450452
>>> print(response.text)

google/generativeai/models.py

+13-11
Original file line number | Diff line number | Diff line change
@@ -33,29 +33,31 @@ def get_model(
3333
client=None,
3434
request_options: dict[str, Any] | None = None,
3535
) -> model_types.Model | model_types.TunedModel:
36-
"""Given a model name, fetch the `types.Model` or `types.TunedModel` object.
36+
"""Given a model name, fetch the `types.Model`
3737
3838
```
3939
import pprint
40-
model = genai.get_tuned_model(model_name):
40+
model = genai.get_model('models/gemini-pro')
4141
pprint.pprint(model)
4242
```
4343
4444
Args:
45-
name: The name of the model to fetch.
45+
name: The name of the model to fetch. Should start with `models/`
4646
client: The client to use.
4747
request_options: Options for the request.
4848
4949
Returns:
50-
A `types.Model` or `types.TunedModel` object.
50+
A `types.Model`
5151
"""
5252
name = model_types.make_model_name(name)
5353
if name.startswith("models/"):
5454
return get_base_model(name, client=client, request_options=request_options)
5555
elif name.startswith("tunedModels/"):
5656
return get_tuned_model(name, client=client, request_options=request_options)
5757
else:
58-
raise ValueError("Model names must start with `models/` or `tunedModels/`")
58+
raise ValueError(
59+
f"Model names must start with `models/` or `tunedModels/`. Received: {name}"
60+
)
5961

6062

6163
def get_base_model(
@@ -68,12 +70,12 @@ def get_base_model(
6870
6971
```
7072
import pprint
71-
model = genai.get_model('models/chat-bison-001'):
73+
model = genai.get_base_model('models/chat-bison-001')
7274
pprint.pprint(model)
7375
```
7476
7577
Args:
76-
name: The name of the model to fetch.
78+
name: The name of the model to fetch. Should start with `models/`
7779
client: The client to use.
7880
request_options: Options for the request.
7981
@@ -88,7 +90,7 @@ def get_base_model(
8890

8991
name = model_types.make_model_name(name)
9092
if not name.startswith("models/"):
91-
raise ValueError(f"Base model names must start with `models/`, got: {name}")
93+
raise ValueError(f"Base model names must start with `models/`, received: {name}")
9294

9395
result = client.get_model(name=name, **request_options)
9496
result = type(result).to_dict(result)
@@ -105,12 +107,12 @@ def get_tuned_model(
105107
106108
```
107109
import pprint
108-
model = genai.get_tuned_model('tunedModels/my-model-1234'):
110+
model = genai.get_tuned_model('tunedModels/gemini-1.0-pro-001')
109111
pprint.pprint(model)
110112
```
111113
112114
Args:
113-
name: The name of the model to fetch.
115+
name: The name of the model to fetch. Should start with `tunedModels/`
114116
client: The client to use.
115117
request_options: Options for the request.
116118
@@ -126,7 +128,7 @@ def get_tuned_model(
126128
name = model_types.make_model_name(name)
127129

128130
if not name.startswith("tunedModels/"):
129-
raise ValueError("Tuned model names must start with `tunedModels/`")
131+
raise ValueError("Tuned model names must start with `tunedModels/` received: {name}")
130132

131133
result = client.get_tuned_model(name=name, **request_options)
132134

google/generativeai/types/permission_types.py

+14-14
Original file line number | Diff line number | Diff line change
@@ -22,8 +22,8 @@
2222

2323
from google.protobuf import field_mask_pb2
2424

25-
from google.generativeai.client import get_dafault_permission_client
26-
from google.generativeai.client import get_dafault_permission_async_client
25+
from google.generativeai.client import get_default_permission_client
26+
from google.generativeai.client import get_default_permission_async_client
2727
from google.generativeai.utils import flatten_update_paths
2828
from google.generativeai import string_utils
2929

@@ -107,7 +107,7 @@ def delete(
107107
Delete permission (self).
108108
"""
109109
if client is None:
110-
client = get_dafault_permission_client()
110+
client = get_default_permission_client()
111111
delete_request = glm.DeletePermissionRequest(name=self.name)
112112
client.delete_permission(request=delete_request)
113113

@@ -119,7 +119,7 @@ async def delete_async(
119119
This is the async version of `Permission.delete`.
120120
"""
121121
if client is None:
122-
client = get_dafault_permission_async_client()
122+
client = get_default_permission_async_client()
123123
delete_request = glm.DeletePermissionRequest(name=self.name)
124124
await client.delete_permission(request=delete_request)
125125

@@ -146,7 +146,7 @@ def update(
146146
`Permission` object with specified updates.
147147
"""
148148
if client is None:
149-
client = get_dafault_permission_client()
149+
client = get_default_permission_client()
150150

151151
updates = flatten_update_paths(updates)
152152
for update_path in updates:
@@ -176,7 +176,7 @@ async def update_async(
176176
This is the async version of `Permission.update`.
177177
"""
178178
if client is None:
179-
client = get_dafault_permission_async_client()
179+
client = get_default_permission_async_client()
180180

181181
updates = flatten_update_paths(updates)
182182
for update_path in updates:
@@ -224,7 +224,7 @@ def get(
224224
Requested permission as an instance of `Permission`.
225225
"""
226226
if client is None:
227-
client = get_dafault_permission_client()
227+
client = get_default_permission_client()
228228
get_perm_request = glm.GetPermissionRequest(name=name)
229229
get_perm_response = client.get_permission(request=get_perm_request)
230230
get_perm_response = type(get_perm_response).to_dict(get_perm_response)
@@ -240,7 +240,7 @@ async def get_async(
240240
This is the async version of `Permission.get`.
241241
"""
242242
if client is None:
243-
client = get_dafault_permission_async_client()
243+
client = get_default_permission_async_client()
244244
get_perm_request = glm.GetPermissionRequest(name=name)
245245
get_perm_response = await client.get_permission(request=get_perm_request)
246246
get_perm_response = type(get_perm_response).to_dict(get_perm_response)
@@ -313,7 +313,7 @@ def create(
313313
ValueError: When email_address is not specified and grantee_type is not set to EVERYONE.
314314
"""
315315
if client is None:
316-
client = get_dafault_permission_client()
316+
client = get_default_permission_client()
317317

318318
request = self._make_create_permission_request(
319319
role=role, grantee_type=grantee_type, email_address=email_address
@@ -333,7 +333,7 @@ async def create_async(
333333
This is the async version of `PermissionAdapter.create_permission`.
334334
"""
335335
if client is None:
336-
client = get_dafault_permission_async_client()
336+
client = get_default_permission_async_client()
337337

338338
request = self._make_create_permission_request(
339339
role=role, grantee_type=grantee_type, email_address=email_address
@@ -358,7 +358,7 @@ def list(
358358
Paginated list of `Permission` objects.
359359
"""
360360
if client is None:
361-
client = get_dafault_permission_client()
361+
client = get_default_permission_client()
362362

363363
request = glm.ListPermissionsRequest(
364364
parent=self.parent, page_size=page_size # pytype: disable=attribute-error
@@ -376,7 +376,7 @@ async def list_async(
376376
This is the async version of `PermissionAdapter.list_permissions`.
377377
"""
378378
if client is None:
379-
client = get_dafault_permission_async_client()
379+
client = get_default_permission_async_client()
380380

381381
request = glm.ListPermissionsRequest(
382382
parent=self.parent, page_size=page_size # pytype: disable=attribute-error
@@ -400,7 +400,7 @@ def transfer_ownership(
400400
if self.parent.startswith("corpora"):
401401
raise NotImplementedError("Can'/t transfer_ownership for a Corpus")
402402
if client is None:
403-
client = get_dafault_permission_client()
403+
client = get_default_permission_client()
404404
transfer_request = glm.TransferOwnershipRequest(
405405
name=self.parent, email_address=email_address # pytype: disable=attribute-error
406406
)
@@ -415,7 +415,7 @@ async def transfer_ownership_async(
415415
if self.parent.startswith("corpora"):
416416
raise NotImplementedError("Can'/t transfer_ownership for a Corpus")
417417
if client is None:
418-
client = get_dafault_permission_async_client()
418+
client = get_default_permission_async_client()
419419
transfer_request = glm.TransferOwnershipRequest(
420420
name=self.parent, email_address=email_address # pytype: disable=attribute-error
421421
)

tests/notebook/text_model_test.py

+32-6
Original file line number | Diff line number | Diff line change
@@ -68,21 +68,47 @@ def _generate_text(
6868

6969

7070
class TextModelTestCase(absltest.TestCase):
71-
def test_generate_text(self):
71+
def test_generate_text_without_args(self):
7272
model = TestModel()
7373

7474
result = model.call_model("prompt goes in")
7575
self.assertEqual(result.text_results[0], "prompt goes in_1")
76-
self.assertIsNone(result.text_results[1])
77-
self.assertIsNone(result.text_results[2])
78-
self.assertIsNone(result.text_results[3])
7976

77+
def test_generate_text_without_args_none_results(self):
78+
model = TestModel()
79+
80+
result = model.call_model("prompt goes in")
81+
self.assertEqual(result.text_results[1], "None")
82+
self.assertEqual(result.text_results[2], "None")
83+
self.assertEqual(result.text_results[3], "None")
84+
85+
def test_generate_text_with_args_first_result(self):
86+
model = TestModel()
8087
args = model_lib.ModelArguments(model="model_name", temperature=0.42, candidate_count=5)
88+
8189
result = model.call_model("prompt goes in", args)
8290
self.assertEqual(result.text_results[0], "prompt goes in_1")
91+
92+
def test_generate_text_with_args_model_name(self):
93+
model = TestModel()
94+
args = model_lib.ModelArguments(model="model_name", temperature=0.42, candidate_count=5)
95+
96+
result = model.call_model("prompt goes in", args)
8397
self.assertEqual(result.text_results[1], "model_name")
84-
self.assertEqual(result.text_results[2], 0.42)
85-
self.assertEqual(result.text_results[3], 5)
98+
99+
def test_generate_text_with_args_temperature(self):
100+
model = TestModel()
101+
args = model_lib.ModelArguments(model="model_name", temperature=0.42, candidate_count=5)
102+
result = model.call_model("prompt goes in", args)
103+
104+
self.assertEqual(result.text_results[2], str(0.42))
105+
106+
def test_generate_text_with_args_candidate_count(self):
107+
model = TestModel()
108+
args = model_lib.ModelArguments(model="model_name", temperature=0.42, candidate_count=5)
109+
110+
result = model.call_model("prompt goes in", args)
111+
self.assertEqual(result.text_results[3], str(5))
86112

87113
def test_retry(self):
88114
model = TestModel()

0 commit comments

Comments
 (0)