diff --git a/google/generativeai/discuss.py b/google/generativeai/discuss.py index 60bcbba95..1ada73036 100644 --- a/google/generativeai/discuss.py +++ b/google/generativeai/discuss.py @@ -30,6 +30,7 @@ def _make_message(content: discuss_types.MessageOptions) -> glm.Message: + """Creates a `glm.Message` object from the provided content.""" if isinstance(content, glm.Message): return content if isinstance(content, str): @@ -39,6 +40,20 @@ def _make_message(content: discuss_types.MessageOptions) -> glm.Message: def _make_messages(messages: discuss_types.MessagesOptions) -> List[glm.Message]: + """ + Creates a list of `glm.Message` objects from the provided messages. + + This function takes a variety of message content inputs, such as strings, dictionaries, + or `glm.Message` objects, and creates a list of `glm.Message` objects. It ensures that + the authors of the messages alternate appropriately. If authors are not provided, + default authors are assigned based on their position in the list. + + Args: + messages: The messages to convert. + + Returns: + A list of `glm.Message` objects with alternating authors. + """ if isinstance(messages, (str, dict, glm.Message)): messages = [_make_message(messages)] else: @@ -71,6 +86,7 @@ def _make_messages(messages: discuss_types.MessagesOptions) -> List[glm.Message] def _make_example(item: discuss_types.ExampleOptions) -> glm.Example: + """Creates a `glm.Example` object from the provided item.""" if isinstance(item, glm.Example): return item @@ -91,6 +107,21 @@ def _make_example(item: discuss_types.ExampleOptions) -> glm.Example: def _make_examples_from_flat( examples: List[discuss_types.MessageOptions], ) -> List[glm.Example]: + """ + Creates a list of `glm.Example` objects from a list of message options. + + This function takes a list of `discuss_types.MessageOptions` and pairs them into + `glm.Example` objects. The input examples must be in pairs to create valid examples. + + Args: + examples: The list of `discuss_types.MessageOptions`. + + Returns: + A list of `glm.Example objects` created by pairing up the provided messages. + + Raises: + ValueError: If the provided list of examples is not of even length. + """ if len(examples) % 2 != 0: raise ValueError( textwrap.dedent( @@ -116,6 +147,19 @@ def _make_examples_from_flat( def _make_examples(examples: discuss_types.ExamplesOptions) -> List[glm.Example]: + """ + Creates a list of `glm.Example` objects from the provided examples. + + This function takes various types of example content inputs and creates a list + of `glm.Example` objects. It handles the conversion of different input types and ensures + the appropriate structure for creating valid examples. + + Args: + examples: The examples to convert. + + Returns: + A list of `glm.Example` objects created from the provided examples. + """ if isinstance(examples, glm.Example): return [examples] @@ -155,6 +199,23 @@ def _make_message_prompt_dict( examples: discuss_types.ExamplesOptions | None = None, messages: discuss_types.MessagesOptions | None = None, ) -> glm.MessagePrompt: + """ + Creates a `glm.MessagePrompt` object from the provided prompt components. + + This function constructs a `glm.MessagePrompt` object using the provided `context`, `examples`, + or `messages`. It ensures the proper structure and handling of the input components. + + Either pass a `prompt` or it's component `context`, `examples`, `messages`. + + Args: + prompt: The complete prompt components. + context: The context for the prompt. + examples: The examples for the prompt. + messages: The messages for the prompt. + + Returns: + A `glm.MessagePrompt` object created from the provided prompt components. + """ if prompt is None: prompt = dict( context=context, @@ -201,6 +262,7 @@ def _make_message_prompt( examples: discuss_types.ExamplesOptions | None = None, messages: discuss_types.MessagesOptions | None = None, ) -> glm.MessagePrompt: + """Creates a `glm.MessagePrompt` object from the provided prompt components.""" prompt = _make_message_prompt_dict( prompt=prompt, context=context, examples=examples, messages=messages ) @@ -219,6 +281,7 @@ def _make_generate_message_request( top_k: float | None = None, prompt: discuss_types.MessagePromptOptions | None = None, ) -> glm.GenerateMessageRequest: + """Creates a `glm.GenerateMessageRequest` object for generating messages.""" model = model_types.make_model_name(model) prompt = _make_message_prompt( @@ -236,6 +299,8 @@ def _make_generate_message_request( def set_doc(doc): + """A decorator to set the docstring of a function.""" + def inner(f): f.__doc__ = doc return f diff --git a/google/generativeai/models.py b/google/generativeai/models.py index 498e5566b..b725cc034 100644 --- a/google/generativeai/models.py +++ b/google/generativeai/models.py @@ -15,7 +15,7 @@ from __future__ import annotations import re -from typing import Optional, List +from typing import Optional, List, Iterator import google.ai.generativelanguage as glm from google.generativeai.client import get_default_model_client @@ -35,6 +35,22 @@ def get_model(name: str, *, client=None) -> model_types.Model: class ModelsIterable(model_types.ModelsIterable): + """ + An iterable class to traverse through a list of models. + + This class allows you to iterate over a list of models, fetching them in pages + if necessary based on the provided `page_size` and `page_token`. + + Args: + page_size: The number of `models` to fetch per page. + page_token: Token representing the current page. Pass `None` for the first page. + models: List of models to iterate through. + client: An optional client for the model service. + + Returns: + A `ModelsIterable` iterable object that allows iterating through the models. + """ + def __init__( self, *, @@ -48,13 +64,19 @@ def __init__( self._models = models self._client = client - def __iter__(self): + def __iter__(self) -> Iterator[model_types.Model]: + """ + Returns an iterator over the models. + """ while self: page = self._models yield from page self = self._next_page() - def _next_page(self): + def _next_page(self) -> ModelsIterable | None: + """ + Fetches the next page of models based on the page token. + """ if not self._page_token: return None return _list_models( @@ -62,7 +84,24 @@ def _next_page(self): ) -def _list_models(page_size, page_token, client): +def _list_models( + page_size: int, page_token: str | None, client: glm.ModelServiceClient +) -> ModelsIterable: + """ + Fetches a page of models using the provided client and pagination tokens. + + This function queries the `client` to retrieve a page of models based on the given + `page_size` and `page_token`. It then processes the response and returns an iterable + object to traverse through the models. + + Args: + page_size: How many `types.Models` to fetch per page (api call). + page_token: Token representing the current page. + client: The client to communicate with the model service. + + Returns: + An iterable `ModelsIterable` object containing the fetched models and pagination info. + """ result = client.list_models(page_size=page_size, page_token=page_token) result = result._response result = type(result).to_dict(result) diff --git a/google/generativeai/text.py b/google/generativeai/text.py index 051bfa2f9..d9362ce7f 100644 --- a/google/generativeai/text.py +++ b/google/generativeai/text.py @@ -29,6 +29,18 @@ def _make_text_prompt(prompt: str | dict[str, str]) -> glm.TextPrompt: + """ + Creates a `glm.TextPrompt` object based on the provided prompt input. + + Args: + prompt: The prompt input, either a string or a dictionary. + + Returns: + glm.TextPrompt: A TextPrompt object containing the prompt text. + + Raises: + TypeError: If the provided prompt is neither a string nor a dictionary. + """ if isinstance(prompt, str): return glm.TextPrompt(text=prompt) elif isinstance(prompt, dict): @@ -49,6 +61,28 @@ def _make_generate_text_request( safety_settings: safety_types.SafetySettingOptions | None = None, stop_sequences: str | Iterable[str] | None = None, ) -> glm.GenerateTextRequest: + """ + Creates a `glm.GenerateTextRequest` object based on the provided parameters. + + This function generates a `glm.GenerateTextRequest` object with the specified + parameters. It prepares the input parameters and creates a request that can be + used for generating text using the chosen model. + + Args: + model: The model to use for text generation. + prompt: The prompt for text generation. Defaults to None. + temperature: The temperature for randomness in generation. Defaults to None. + candidate_count: The number of candidates to consider. Defaults to None. + max_output_tokens: The maximum number of output tokens. Defaults to None. + top_p: The nucleus sampling probability threshold. Defaults to None. + top_k: The top-k sampling parameter. Defaults to None. + safety_settings: Safety settings for generated text. Defaults to None. + stop_sequences: Stop sequences to halt text generation. Can be a string + or iterable of strings. Defaults to None. + + Returns: + `glm.GenerateTextRequest`: A `GenerateTextRequest` object configured with the specified parameters. + """ model = model_types.make_model_name(model) prompt = _make_text_prompt(prompt=prompt) safety_settings = safety_types.normalize_safety_settings(safety_settings) @@ -155,6 +189,17 @@ def __init__(self, **kwargs): def _generate_response( request: glm.GenerateTextRequest, client: glm.TextServiceClient = None ) -> Completion: + """ + Generates a response using the provided `glm.GenerateTextRequest` and client. + + Args: + request: The text generation request. + client: The client to use for text generation. Defaults to None, in which + case the default text client is used. + + Returns: + `Completion`: A `Completion` object with the generated text and response information. + """ if client is None: client = get_default_text_client() diff --git a/tests/test_models.py b/tests/test_models.py index bfa829d3b..12e5366c4 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -17,6 +17,7 @@ from absl.testing import absltest import google.ai.generativelanguage as glm + from google.ai.generativelanguage_v1beta2.types import model from google.generativeai import models