Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit a4ac7a5

Browse files
Merge branch 'main' into caching
2 parents f13228d + f987fde commit a4ac7a5

39 files changed

+2225
-564
lines changed

CONTRIBUTING.md

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,17 +62,41 @@ This "editable" mode lets you edit the source without needing to reinstall the p
6262

6363
### Testing
6464

65-
Use the builtin unittest package:
65+
To ensure the integrity of the codebase, we have a suite of tests located in the `generative-ai-python/tests` directory.
6666

67+
You can run all these tests using Python's built-in `unittest` module or the `pytest` library.
68+
69+
For `unittest`, open a terminal and navigate to the root directory of the project. Then, execute the following command:
70+
71+
```
72+
python -m unittest discover -s tests
73+
74+
# or more simply
75+
python -m unittest
6776
```
68-
python -m unittest
77+
78+
Alternatively, if you prefer using `pytest`, you can install it using pip:
79+
6980
```
81+
pip install pytest
82+
```
83+
84+
Then, run the tests with the following command:
85+
86+
```
87+
pytest tests
88+
89+
# or more simply
90+
pytest
91+
```
92+
7093

7194
Or to debug, use:
7295

7396
```commandline
97+
pip install nose2
98+
7499
nose2 --debugger
75-
```
76100
77101
### Type checking
78102

google/generativeai/answer.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,10 @@
2626
get_default_generative_client,
2727
get_default_generative_async_client,
2828
)
29-
from google.generativeai import string_utils
3029
from google.generativeai.types import model_types
31-
from google.generativeai import models
30+
from google.generativeai.types import helper_types
3231
from google.generativeai.types import safety_types
3332
from google.generativeai.types import content_types
34-
from google.generativeai.types import answer_types
3533
from google.generativeai.types import retriever_types
3634
from google.generativeai.types.retriever_types import MetadataFilter
3735

@@ -94,7 +92,7 @@ def _make_grounding_passages(source: GroundingPassagesOptions) -> glm.GroundingP
9492

9593
if not isinstance(source, Iterable):
9694
raise TypeError(
97-
f"`source` must be a valid `GroundingPassagesOptions` type object got a: `{type(source)}`."
95+
f"The 'source' argument must be an instance of 'GroundingPassagesOptions', but got a '{type(source).__name__}' object instead."
9896
)
9997

10098
passages = []
@@ -182,7 +180,7 @@ def _make_generate_answer_request(
182180
temperature: float | None = None,
183181
) -> glm.GenerateAnswerRequest:
184182
"""
185-
Calls the API to generate a grounded answer from the model.
183+
constructs a glm.GenerateAnswerRequest object by organizing the input parameters for the API call to generate a grounded answer from the model.
186184
187185
Args:
188186
model: Name of the model used to generate the grounded response.
@@ -206,9 +204,7 @@ def _make_generate_answer_request(
206204
contents = content_types.to_contents(contents)
207205

208206
if safety_settings:
209-
safety_settings = safety_types.normalize_safety_settings(
210-
safety_settings, harm_category_set="new"
211-
)
207+
safety_settings = safety_types.normalize_safety_settings(safety_settings)
212208

213209
if inline_passages is not None and semantic_retriever is not None:
214210
raise ValueError(
@@ -219,7 +215,7 @@ def _make_generate_answer_request(
219215
elif semantic_retriever is not None:
220216
semantic_retriever = _make_semantic_retriever_config(semantic_retriever, contents[-1])
221217
else:
222-
TypeError(
218+
raise TypeError(
223219
f"The source must be either an `inline_passages` xor `semantic_retriever_config`, but both are `None`"
224220
)
225221

@@ -247,7 +243,7 @@ def generate_answer(
247243
safety_settings: safety_types.SafetySettingOptions | None = None,
248244
temperature: float | None = None,
249245
client: glm.GenerativeServiceClient | None = None,
250-
request_options: dict[str, Any] | None = None,
246+
request_options: helper_types.RequestOptionsType | None = None,
251247
):
252248
"""
253249
Calls the GenerateAnswer API and returns a `types.Answer` containing the response.
@@ -320,7 +316,7 @@ async def generate_answer_async(
320316
safety_settings: safety_types.SafetySettingOptions | None = None,
321317
temperature: float | None = None,
322318
client: glm.GenerativeServiceClient | None = None,
323-
request_options: dict[str, Any] | None = None,
319+
request_options: helper_types.RequestOptionsType | None = None,
324320
):
325321
"""
326322
Calls the API and returns a `types.Answer` containing the answer.

google/generativeai/client.py

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from __future__ import annotations
22

33
import os
4+
import contextlib
45
import dataclasses
56
import pathlib
6-
import re
77
import types
88
from typing import Any, cast
99
from collections.abc import Sequence
@@ -12,6 +12,8 @@
1212
import google.ai.generativelanguage as glm
1313

1414
from google.auth import credentials as ga_credentials
15+
from google.auth import exceptions as ga_exceptions
16+
from google import auth
1517
from google.api_core import client_options as client_options_lib
1618
from google.api_core import gapic_v1
1719
from google.api_core import operations_v1
@@ -30,6 +32,18 @@
3032
GENAI_API_DISCOVERY_URL = "https://generativelanguage.googleapis.com/$discovery/rest"
3133

3234

35+
@contextlib.contextmanager
36+
def patch_colab_gce_credentials():
37+
get_gce = auth._default._get_gce_credentials
38+
if "COLAB_RELEASE_TAG" in os.environ:
39+
auth._default._get_gce_credentials = lambda *args, **kwargs: (None, None)
40+
41+
try:
42+
yield
43+
finally:
44+
auth._default._get_gce_credentials = get_gce
45+
46+
3347
class FileServiceClient(glm.FileServiceClient):
3448
def __init__(self, *args, **kwargs):
3549
self._discovery_api = None
@@ -59,6 +73,7 @@ def create_file(
5973
mime_type: str | None = None,
6074
name: str | None = None,
6175
display_name: str | None = None,
76+
resumable: bool = True,
6277
) -> glm.File:
6378
if self._discovery_api is None:
6479
self._setup_discovery_api()
@@ -69,19 +84,13 @@ def create_file(
6984
if display_name is not None:
7085
file["displayName"] = display_name
7186

72-
media = googleapiclient.http.MediaFileUpload(filename=path, mimetype=mime_type)
87+
media = googleapiclient.http.MediaFileUpload(
88+
filename=path, mimetype=mime_type, resumable=resumable
89+
)
7390
request = self._discovery_api.media().upload(body={"file": file}, media_body=media)
7491
result = request.execute()
7592

76-
allowed_keys = set(glm.File.__annotations__)
77-
78-
return glm.File(
79-
{
80-
re.sub("[A-Z]", lambda ch: f"_{ch.group(0).lower()}", key): value
81-
for key, value in result["file"].items()
82-
if key in allowed_keys
83-
}
84-
)
93+
return self.get_file({"name": result["file"]["name"]})
8594

8695

8796
class FileServiceAsyncClient(glm.FileServiceAsyncClient):
@@ -188,7 +197,17 @@ def make_client(self, name):
188197
if not self.client_config:
189198
configure()
190199

191-
client = cls(**self.client_config)
200+
try:
201+
with patch_colab_gce_credentials():
202+
client = cls(**self.client_config)
203+
except ga_exceptions.DefaultCredentialsError as e:
204+
e.args = (
205+
"\n No API_KEY or ADC found. Please either:\n"
206+
" - Set the `GOOGLE_API_KEY` environment variable.\n"
207+
" - Manually pass the key with `genai.configure(api_key=my_api_key)`.\n"
208+
" - Or set up Application Default Credentials, see https://ai.google.dev/gemini-api/docs/oauth for more information.",
209+
)
210+
raise e
192211

193212
if not self.default_metadata:
194213
return client
@@ -337,9 +356,9 @@ def get_default_retriever_async_client() -> glm.RetrieverAsyncClient:
337356
return _client_manager.get_default_client("retriever_async")
338357

339358

340-
def get_dafault_permission_client() -> glm.PermissionServiceClient:
359+
def get_default_permission_client() -> glm.PermissionServiceClient:
341360
return _client_manager.get_default_client("permission")
342361

343362

344-
def get_dafault_permission_async_client() -> glm.PermissionServiceAsyncClient:
363+
def get_default_permission_async_client() -> glm.PermissionServiceAsyncClient:
345364
return _client_manager.get_default_client("permission_async")

google/generativeai/discuss.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,9 @@
2626
from google.generativeai.client import get_default_discuss_async_client
2727
from google.generativeai import string_utils
2828
from google.generativeai.types import discuss_types
29+
from google.generativeai.types import helper_types
2930
from google.generativeai.types import model_types
30-
from google.generativeai.types import safety_types
31+
from google.generativeai.types import palm_safety_types
3132

3233

3334
def _make_message(content: discuss_types.MessageOptions) -> glm.Message:
@@ -316,7 +317,7 @@ def chat(
316317
top_k: float | None = None,
317318
prompt: discuss_types.MessagePromptOptions | None = None,
318319
client: glm.DiscussServiceClient | None = None,
319-
request_options: dict[str, Any] | None = None,
320+
request_options: helper_types.RequestOptionsType | None = None,
320321
) -> discuss_types.ChatResponse:
321322
"""Calls the API and returns a `types.ChatResponse` containing the response.
322323
@@ -416,7 +417,7 @@ async def chat_async(
416417
top_k: float | None = None,
417418
prompt: discuss_types.MessagePromptOptions | None = None,
418419
client: glm.DiscussServiceAsyncClient | None = None,
419-
request_options: dict[str, Any] | None = None,
420+
request_options: helper_types.RequestOptionsType | None = None,
420421
) -> discuss_types.ChatResponse:
421422
request = _make_generate_message_request(
422423
model=model,
@@ -469,7 +470,7 @@ def last(self, message: discuss_types.MessageOptions):
469470
def reply(
470471
self,
471472
message: discuss_types.MessageOptions,
472-
request_options: dict[str, Any] | None = None,
473+
request_options: helper_types.RequestOptionsType | None = None,
473474
) -> discuss_types.ChatResponse:
474475
if isinstance(self._client, glm.DiscussServiceAsyncClient):
475476
raise TypeError(f"reply can't be called on an async client, use reply_async instead.")
@@ -521,7 +522,7 @@ def _build_chat_response(
521522
response = type(response).to_dict(response)
522523
response.pop("messages")
523524

524-
response["filters"] = safety_types.convert_filters_to_enums(response["filters"])
525+
response["filters"] = palm_safety_types.convert_filters_to_enums(response["filters"])
525526

526527
if response["candidates"]:
527528
last = response["candidates"][0]
@@ -537,7 +538,7 @@ def _build_chat_response(
537538
def _generate_response(
538539
request: glm.GenerateMessageRequest,
539540
client: glm.DiscussServiceClient | None = None,
540-
request_options: dict[str, Any] | None = None,
541+
request_options: helper_types.RequestOptionsType | None = None,
541542
) -> ChatResponse:
542543
if request_options is None:
543544
request_options = {}
@@ -553,7 +554,7 @@ def _generate_response(
553554
async def _generate_response_async(
554555
request: glm.GenerateMessageRequest,
555556
client: glm.DiscussServiceAsyncClient | None = None,
556-
request_options: dict[str, Any] | None = None,
557+
request_options: helper_types.RequestOptionsType | None = None,
557558
) -> ChatResponse:
558559
if request_options is None:
559560
request_options = {}
@@ -574,7 +575,7 @@ def count_message_tokens(
574575
messages: discuss_types.MessagesOptions | None = None,
575576
model: model_types.AnyModelNameOptions = DEFAULT_DISCUSS_MODEL,
576577
client: glm.DiscussServiceAsyncClient | None = None,
577-
request_options: dict[str, Any] | None = None,
578+
request_options: helper_types.RequestOptionsType | None = None,
578579
) -> discuss_types.TokenCount:
579580
model = model_types.make_model_name(model)
580581
prompt = _make_message_prompt(prompt, context=context, examples=examples, messages=messages)

google/generativeai/embedding.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414
# limitations under the License.
1515
from __future__ import annotations
1616

17-
import dataclasses
18-
from collections.abc import Iterable, Sequence, Mapping
1917
import itertools
2018
from typing import Any, Iterable, overload, TypeVar, Union, Mapping
2119

@@ -24,7 +22,7 @@
2422
from google.generativeai.client import get_default_generative_client
2523
from google.generativeai.client import get_default_generative_async_client
2624

27-
from google.generativeai import string_utils
25+
from google.generativeai.types import helper_types
2826
from google.generativeai.types import text_types
2927
from google.generativeai.types import model_types
3028
from google.generativeai.types import content_types
@@ -104,7 +102,7 @@ def embed_content(
104102
title: str | None = None,
105103
output_dimensionality: int | None = None,
106104
client: glm.GenerativeServiceClient | None = None,
107-
request_options: dict[str, Any] | None = None,
105+
request_options: helper_types.RequestOptionsType | None = None,
108106
) -> text_types.EmbeddingDict: ...
109107

110108

@@ -116,7 +114,7 @@ def embed_content(
116114
title: str | None = None,
117115
output_dimensionality: int | None = None,
118116
client: glm.GenerativeServiceClient | None = None,
119-
request_options: dict[str, Any] | None = None,
117+
request_options: helper_types.RequestOptionsType | None = None,
120118
) -> text_types.BatchEmbeddingDict: ...
121119

122120

@@ -127,7 +125,7 @@ def embed_content(
127125
title: str | None = None,
128126
output_dimensionality: int | None = None,
129127
client: glm.GenerativeServiceClient = None,
130-
request_options: dict[str, Any] | None = None,
128+
request_options: helper_types.RequestOptionsType | None = None,
131129
) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict:
132130
"""Calls the API to create embeddings for content passed in.
133131
@@ -224,7 +222,7 @@ async def embed_content_async(
224222
title: str | None = None,
225223
output_dimensionality: int | None = None,
226224
client: glm.GenerativeServiceAsyncClient | None = None,
227-
request_options: dict[str, Any] | None = None,
225+
request_options: helper_types.RequestOptionsType | None = None,
228226
) -> text_types.EmbeddingDict: ...
229227

230228

@@ -236,7 +234,7 @@ async def embed_content_async(
236234
title: str | None = None,
237235
output_dimensionality: int | None = None,
238236
client: glm.GenerativeServiceAsyncClient | None = None,
239-
request_options: dict[str, Any] | None = None,
237+
request_options: helper_types.RequestOptionsType | None = None,
240238
) -> text_types.BatchEmbeddingDict: ...
241239

242240

@@ -247,7 +245,7 @@ async def embed_content_async(
247245
title: str | None = None,
248246
output_dimensionality: int | None = None,
249247
client: glm.GenerativeServiceAsyncClient = None,
250-
request_options: dict[str, Any] | None = None,
248+
request_options: helper_types.RequestOptionsType | None = None,
251249
) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict:
252250
"""The async version of `genai.embed_content`."""
253251
model = model_types.make_model_name(model)

google/generativeai/files.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,24 @@ def upload_file(
3535
mime_type: str | None = None,
3636
name: str | None = None,
3737
display_name: str | None = None,
38+
resumable: bool = True,
3839
) -> file_types.File:
40+
"""Uploads a file using a supported file service.
41+
42+
Args:
43+
path: The path to the file to be uploaded.
44+
mime_type: The MIME type of the file. If not provided, it will be
45+
inferred from the file extension.
46+
name: The name of the file in the destination (e.g., 'files/sample-image').
47+
If not provided, a system generated ID will be created.
48+
display_name: Optional display name of the file.
49+
resumable: Whether to use the resumable upload protocol. By default, this is enabled.
50+
See details at
51+
https://googleapis.github.io/google-api-python-client/docs/epy/googleapiclient.http.MediaFileUpload-class.html#resumable
52+
53+
Returns:
54+
file_types.File: The response of the uploaded file.
55+
"""
3956
client = get_default_file_client()
4057

4158
path = pathlib.Path(os.fspath(path))
@@ -50,7 +67,7 @@ def upload_file(
5067
display_name = path.name
5168

5269
response = client.create_file(
53-
path=path, mime_type=mime_type, name=name, display_name=display_name
70+
path=path, mime_type=mime_type, name=name, display_name=display_name, resumable=resumable
5471
)
5572
return file_types.File(response)
5673

0 commit comments

Comments
 (0)