Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 4 additions & 12 deletions docs/build_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,7 @@
"search_hints", True, "Include metadata search hints in the generated files"
)

_SITE_PATH = flags.DEFINE_string(
"site_path", "/api/python", "Path prefix in the _toc.yaml"
)
_SITE_PATH = flags.DEFINE_string("site_path", "/api/python", "Path prefix in the _toc.yaml")

_CODE_URL_PREFIX = flags.DEFINE_string(
"code_url_prefix",
Expand All @@ -139,9 +137,7 @@ def drop_staticmethods(self, parent, children):
def __call__(self, path, parent, children):
if any("generativelanguage" in part for part in path) or "generativeai" in path:
children = self.filter_base_dirs(path, parent, children)
children = public_api.explicit_package_contents_filter(
path, parent, children
)
children = public_api.explicit_package_contents_filter(path, parent, children)

if any("generativelanguage" in part for part in path):
if "ServiceClient" in path[-1] or "ServiceAsyncClient" in path[-1]:
Expand All @@ -159,9 +155,7 @@ def make_default_filters(self):
public_api.add_proto_fields,
public_api.filter_builtin_modules,
public_api.filter_private_symbols,
MyFilter(
self._base_dir
), # Replaces: public_api.FilterBaseDirs(self._base_dir),
MyFilter(self._base_dir), # Replaces: public_api.FilterBaseDirs(self._base_dir),
public_api.FilterPrivateMap(self._private_map),
public_api.filter_doc_controls_skip,
public_api.ignore_typing,
Expand Down Expand Up @@ -229,9 +223,7 @@ def gen_api_docs():
new_content = re.sub(r".*?`oneof`_ ``_.*?\n", "", new_content, re.MULTILINE)
new_content = re.sub(r"\.\. code-block:: python.*?\n", "", new_content)

new_content = re.sub(
r"generativelanguage_\w+.types", "generativelanguage", new_content
)
new_content = re.sub(r"generativelanguage_\w+.types", "generativelanguage", new_content)

if new_content != old_content:
fpath.write_text(new_content)
Expand Down
12 changes: 3 additions & 9 deletions google/generativeai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,7 @@ def configure(

if had_api_key_value:
if api_key is not None:
raise ValueError(
"You can't set both `api_key` and `client_options['api_key']`."
)
raise ValueError("You can't set both `api_key` and `client_options['api_key']`.")
else:
if api_key is None:
# If no key is provided explicitly, attempt to load one from the
Expand All @@ -107,9 +105,7 @@ def configure(
}

new_default_client_config = {
key: value
for key, value in new_default_client_config.items()
if value is not None
key: value for key, value in new_default_client_config.items() if value is not None
}

default_client_config = new_default_client_config
Expand Down Expand Up @@ -147,9 +143,7 @@ def get_default_discuss_async_client() -> glm.DiscussServiceAsyncClient:
# Attempt to configure using defaults.
if not default_client_config:
configure()
default_discuss_async_client = glm.DiscussServiceAsyncClient(
**default_client_config
)
default_discuss_async_client = glm.DiscussServiceAsyncClient(**default_client_config)

return default_discuss_async_client

Expand Down
32 changes: 12 additions & 20 deletions google/generativeai/discuss.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ def _make_message(content: discuss_types.MessageOptions) -> glm.Message:
return glm.Message(content)


def _make_messages(messages: discuss_types.MessagesOptions) -> List[glm.Message]:
def _make_messages(
messages: discuss_types.MessagesOptions,
) -> List[glm.Message]:
"""
Creates a list of `glm.Message` objects from the provided messages.

Expand Down Expand Up @@ -146,7 +148,9 @@ def _make_examples_from_flat(
return result


def _make_examples(examples: discuss_types.ExamplesOptions) -> List[glm.Example]:
def _make_examples(
examples: discuss_types.ExamplesOptions,
) -> List[glm.Example]:
"""
Creates a list of `glm.Example` objects from the provided examples.

Expand Down Expand Up @@ -223,9 +227,7 @@ def _make_message_prompt_dict(
messages=messages,
)
else:
flat_prompt = (
(context is not None) or (examples is not None) or (messages is not None)
)
flat_prompt = (context is not None) or (examples is not None) or (messages is not None)
if flat_prompt:
raise ValueError(
"You can't set `prompt`, and its fields `(context, examples, messages)`"
Expand Down Expand Up @@ -446,9 +448,7 @@ async def chat_async(
@set_doc(discuss_types.ChatResponse.__doc__)
@dataclasses.dataclass(**DATACLASS_KWARGS, init=False)
class ChatResponse(discuss_types.ChatResponse):
_client: glm.DiscussServiceClient | None = dataclasses.field(
default=lambda: None, repr=False
)
_client: glm.DiscussServiceClient | None = dataclasses.field(default=lambda: None, repr=False)

def __init__(self, **kwargs):
for key, value in kwargs.items():
Expand All @@ -469,13 +469,9 @@ def last(self, message: discuss_types.MessageOptions):
self.messages[-1] = message

@set_doc(discuss_types.ChatResponse.reply.__doc__)
def reply(
self, message: discuss_types.MessageOptions
) -> discuss_types.ChatResponse:
def reply(self, message: discuss_types.MessageOptions) -> discuss_types.ChatResponse:
if isinstance(self._client, glm.DiscussServiceAsyncClient):
raise TypeError(
f"reply can't be called on an async client, use reply_async instead."
)
raise TypeError(f"reply can't be called on an async client, use reply_async instead.")
if self.last is None:
raise ValueError(
"The last response from the model did not return any candidates.\n"
Expand Down Expand Up @@ -532,9 +528,7 @@ def _build_chat_response(
request.setdefault("temperature", None)
request.setdefault("candidate_count", None)

return ChatResponse(
_client=client, **response, **request
) # pytype: disable=missing-parameter
return ChatResponse(_client=client, **response, **request) # pytype: disable=missing-parameter


def _generate_response(
Expand Down Expand Up @@ -571,9 +565,7 @@ def count_message_tokens(
client: glm.DiscussServiceAsyncClient | None = None,
):
model = model_types.make_model_name(model)
prompt = _make_message_prompt(
prompt, context=context, examples=examples, messages=messages
)
prompt = _make_message_prompt(prompt, context=context, examples=examples, messages=messages)

if client is None:
client = get_default_discuss_client()
Expand Down
32 changes: 14 additions & 18 deletions google/generativeai/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,7 @@ def get_model(
raise ValueError("Model names must start with `models/` or `tunedModels/`")


def get_base_model(
name: model_types.BaseModelNameOptions, *, client=None
) -> model_types.Model:
def get_base_model(name: model_types.BaseModelNameOptions, *, client=None) -> model_types.Model:
"""Get the `types.Model` for the given base model name.

```
Expand Down Expand Up @@ -133,9 +131,7 @@ def _list_tuned_models_next_page(page_size, page_token, client):
)
result = result._response
result = type(result).to_dict(result)
result["models"] = [
model_types.decode_tuned_model(mod) for mod in result.pop("tuned_models")
]
result["models"] = [model_types.decode_tuned_model(mod) for mod in result.pop("tuned_models")]
result["page_size"] = page_size
result["page_token"] = result.pop("next_page_token")
result["client"] = client
Expand All @@ -154,21 +150,19 @@ def _list_models_iter_pages(
page_token = None
while True:
if select == "base":
result = _list_base_models_next_page(
page_size, page_token=page_token, client=client
)
result = _list_base_models_next_page(page_size, page_token=page_token, client=client)
elif select == "tuned":
result = _list_tuned_models_next_page(
page_size, page_token=page_token, client=client
)
result = _list_tuned_models_next_page(page_size, page_token=page_token, client=client)
yield from result["models"]
page_token = result["page_token"]
if page_token == "":
break


def list_models(
*, page_size: int | None = None, client: glm.ModelServiceClient | None = None
*,
page_size: int | None = None,
client: glm.ModelServiceClient | None = None,
) -> model_types.ModelsIterable:
"""Lists available models.

Expand All @@ -190,7 +184,9 @@ def list_models(


def list_tuned_models(
*, page_size: int | None = None, client: glm.ModelServiceClient | None = None
*,
page_size: int | None = None,
client: glm.ModelServiceClient | None = None,
) -> model_types.TunedModelsIterable:
"""Lists available models.

Expand Down Expand Up @@ -294,7 +290,9 @@ def create_tuned_model(
training_data = model_types.encode_tuning_data(training_data)

hyperparameters = glm.Hyperparameters(
epoch_count=epoch_count, batch_size=batch_size, learning_rate=learning_rate
epoch_count=epoch_count,
batch_size=batch_size,
learning_rate=learning_rate,
)
tuning_task = glm.TuningTask(
training_data=training_data,
Expand All @@ -310,9 +308,7 @@ def create_tuned_model(
top_k=top_k,
tuning_task=tuning_task,
)
operation = client.create_tuned_model(
dict(tuned_model_id=id, tuned_model=tuned_model)
)
operation = client.create_tuned_model(dict(tuned_model_id=id, tuned_model=tuned_model))

return operations.CreateTunedModelOperation.from_core_operation(operation)

Expand Down
8 changes: 2 additions & 6 deletions google/generativeai/notebook/argument_parser_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@ class ArgumentParserTest(absltest.TestCase):
def test_help(self):
"""Verify that help messages raise ParserNormalExit."""
parser = parser_lib.ArgumentParser()
with self.assertRaisesRegex(
parser_lib.ParserNormalExit, "show this help message and exit"
):
with self.assertRaisesRegex(parser_lib.ParserNormalExit, "show this help message and exit"):
parser.parse_args(["-h"])

def test_parse_arg_errors(self):
Expand All @@ -42,9 +40,7 @@ def new_parser() -> argparse.ArgumentParser:
with self.assertRaisesRegex(parser_lib.ParserError, "invalid int value"):
new_parser().parse_args(["--value", "forty-two"])

with self.assertRaisesRegex(
parser_lib.ParserError, "the following arguments are required"
):
with self.assertRaisesRegex(parser_lib.ParserError, "the following arguments are required"):
new_parser().parse_args([])

with self.assertRaisesRegex(parser_lib.ParserError, "expected one argument"):
Expand Down
57 changes: 16 additions & 41 deletions google/generativeai/notebook/cmd_line_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,7 @@ def _resolve_compare_fn_var(
"""Resolves a value passed into --compare_fn."""
fn = py_utils.get_py_var(name)
if not isinstance(fn, Callable):
raise ValueError(
'Variable "{}" does not contain a Callable object'.format(name)
)
raise ValueError('Variable "{}" does not contain a Callable object'.format(name))

return name, fn

Expand All @@ -80,19 +78,11 @@ def _resolve_ground_truth_var(name: str) -> Sequence[str]:

# "str" and "bytes" are also Sequences but we want an actual Sequence of
# strings, like a list.
if (
not isinstance(value, Sequence)
or isinstance(value, str)
or isinstance(value, bytes)
):
raise ValueError(
'Variable "{}" does not contain a Sequence of strings'.format(name)
)
if not isinstance(value, Sequence) or isinstance(value, str) or isinstance(value, bytes):
raise ValueError('Variable "{}" does not contain a Sequence of strings'.format(name))
for x in value:
if not isinstance(x, str):
raise ValueError(
'Variable "{}" does not contain a Sequence of strings'.format(name)
)
raise ValueError('Variable "{}" does not contain a Sequence of strings'.format(name))
return value


Expand Down Expand Up @@ -128,9 +118,7 @@ def _add_model_flags(

def _check_is_greater_than_or_equal_to_zero(x: float) -> float:
if x < 0:
raise ValueError(
"Value should be greater than or equal to zero, got {}".format(x)
)
raise ValueError("Value should be greater than or equal to zero, got {}".format(x))
return x

flag_def.SingleValueFlagDef(
Expand All @@ -154,8 +142,7 @@ def _check_is_greater_than_or_equal_to_zero(x: float) -> float:
short_name="m",
default_value=None,
help_msg=(
"The name of the model to use. If not provided, a default model will"
" be used."
"The name of the model to use. If not provided, a default model will" " be used."
),
).add_argument_to_parser(parser)

Expand Down Expand Up @@ -315,9 +302,7 @@ def _compile_save_name_fn(var_name: str) -> str:
return var_name

save_name_help = "The name of a Python variable to save the compiled function to."
parser.add_argument(
"compile_save_name", help=save_name_help, type=_compile_save_name_fn
)
parser.add_argument("compile_save_name", help=save_name_help, type=_compile_save_name_fn)
_add_model_flags(parser)


Expand Down Expand Up @@ -346,22 +331,16 @@ def _resolve_llm_function_fn(
if not isinstance(fn, llm_function.LLMFunction):
raise argparse.ArgumentError(
None,
'{} is not a function created with the "compile" command'.format(
var_name
),
'{} is not a function created with the "compile" command'.format(var_name),
)
return var_name, fn

name_help = (
"The name of a Python variable containing a function previously created"
' with the "compile" command.'
)
parser.add_argument(
"lhs_name_and_fn", help=name_help, type=_resolve_llm_function_fn
)
parser.add_argument(
"rhs_name_and_fn", help=name_help, type=_resolve_llm_function_fn
)
parser.add_argument("lhs_name_and_fn", help=name_help, type=_resolve_llm_function_fn)
parser.add_argument("rhs_name_and_fn", help=name_help, type=_resolve_llm_function_fn)

_add_input_flags(parser, placeholders)
_add_output_flags(parser)
Expand Down Expand Up @@ -409,9 +388,7 @@ def _create_parser(
subparsers.add_parser(parsed_args_lib.CommandName.RUN_CMD.value),
placeholders,
)
_create_compile_parser(
subparsers.add_parser(parsed_args_lib.CommandName.COMPILE_CMD.value)
)
_create_compile_parser(subparsers.add_parser(parsed_args_lib.CommandName.COMPILE_CMD.value))
_create_compare_parser(
subparsers.add_parser(parsed_args_lib.CommandName.COMPARE_CMD.value),
placeholders,
Expand Down Expand Up @@ -471,9 +448,7 @@ def _split_post_processing_tokens(
if start_idx is None:
start_idx = token_num
if token == CmdLineParser.PIPE_OP:
split_tokens.append(
tokens[start_idx:token_num] if start_idx is not None else []
)
split_tokens.append(tokens[start_idx:token_num] if start_idx is not None else [])
start_idx = None

# Add the remaining tokens after the last PIPE_OP.
Expand Down Expand Up @@ -518,7 +493,9 @@ def _get_model_args(
candidate_count = parsed_results.pop("candidate_count", None)

model_args = model_lib.ModelArguments(
model=model, temperature=temperature, candidate_count=candidate_count
model=model,
temperature=temperature,
candidate_count=candidate_count,
)
return parsed_results, model_args

Expand Down Expand Up @@ -556,9 +533,7 @@ def parse_line(
_, rhs_fn = parsed_args.rhs_name_and_fn
parsed_args = self._get_parsed_args_from_cmd_line_tokens(
tokens=tokens,
placeholders=frozenset(lhs_fn.get_placeholders()).union(
rhs_fn.get_placeholders()
),
placeholders=frozenset(lhs_fn.get_placeholders()).union(rhs_fn.get_placeholders()),
)

_validate_parsed_args(parsed_args)
Expand Down
Loading