diff --git a/lib/sycamore/sycamore/llms/gemini.py b/lib/sycamore/sycamore/llms/gemini.py index b4dc14607..e289fb22e 100644 --- a/lib/sycamore/sycamore/llms/gemini.py +++ b/lib/sycamore/sycamore/llms/gemini.py @@ -3,6 +3,7 @@ from enum import Enum from typing import Any, Optional, Union import os +import io from sycamore.llms.llms import LLM from sycamore.llms.prompts.prompts import RenderedPrompt @@ -22,10 +23,10 @@ class GeminiModels(Enum): """Represents available Gemini models. More info: https://googleapis.github.io/python-genai/""" # Note that the models available on a given Gemini account may vary. - GEMINI_2_FLASH = GeminiModel(name="gemini-2.0-flash-exp", is_chat=True) + GEMINI_2_FLASH = GeminiModel(name="gemini-2.0-flash", is_chat=True) GEMINI_2_FLASH_LITE = GeminiModel(name="gemini-2.0-flash-lite-preview-02-05", is_chat=True) - GEMINI_2_FLASH_THINKING = GeminiModel(name="gemini-2.0-flash-thinking-exp", is_chat=True) - GEMINI_2_PRO = GeminiModel(name="gemini-2.0-pro-exp", is_chat=True) + GEMINI_2_FLASH_THINKING = GeminiModel(name="gemini-2.0-flash-thinking-exp-01-21", is_chat=True) + GEMINI_2_PRO = GeminiModel(name="gemini-2.0-pro-exp-02-05", is_chat=True) @classmethod def from_name(cls, name: str): @@ -86,7 +87,7 @@ def get_generate_kwargs(self, prompt: RenderedPrompt, llm_kwargs: Optional[dict] if prompt.response_format: config["response_mime_type"] = "application/json" config["response_schema"] = prompt.response_format - content_list = [] + content_list: list[types.Content] = [] for message in prompt.messages: if message.role == "system": config["system_message"] = message.content @@ -95,13 +96,15 @@ def get_generate_kwargs(self, prompt: RenderedPrompt, llm_kwargs: Optional[dict] content = types.Content(parts=[types.Part.from_text(text=message.content)], role=role) if message.images: for image in message.images: - image_bytes = image.convert("RGB").tobytes() - content.parts.append(types.Part.from_bytes(image_bytes, media_type="image/png")) + buffered = io.BytesIO() + image.save(buffered, format="PNG") + image_bytes = buffered.getvalue() + content.parts.append(types.Part.from_bytes(data=image_bytes, mime_type="image/png")) content_list.append(content) kwargs["config"] = None if config: kwargs["config"] = types.GenerateContentConfig(**config) - kwargs["content"] = content + kwargs["content"] = content_list return kwargs def generate_metadata(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> dict: