17 changes: 10 additions & 7 deletions lib/sycamore/sycamore/llms/gemini.py
@@ -3,6 +3,7 @@
 from enum import Enum
 from typing import Any, Optional, Union
 import os
+import io

 from sycamore.llms.llms import LLM
 from sycamore.llms.prompts.prompts import RenderedPrompt
@@ -22,10 +23,10 @@ class GeminiModels(Enum):
     """Represents available Gemini models. More info: https://googleapis.github.io/python-genai/"""

     # Note that the models available on a given Gemini account may vary.
-    GEMINI_2_FLASH = GeminiModel(name="gemini-2.0-flash-exp", is_chat=True)
+    GEMINI_2_FLASH = GeminiModel(name="gemini-2.0-flash", is_chat=True)
     GEMINI_2_FLASH_LITE = GeminiModel(name="gemini-2.0-flash-lite-preview-02-05", is_chat=True)
-    GEMINI_2_FLASH_THINKING = GeminiModel(name="gemini-2.0-flash-thinking-exp", is_chat=True)
-    GEMINI_2_PRO = GeminiModel(name="gemini-2.0-pro-exp", is_chat=True)
+    GEMINI_2_FLASH_THINKING = GeminiModel(name="gemini-2.0-flash-thinking-exp-01-21", is_chat=True)
+    GEMINI_2_PRO = GeminiModel(name="gemini-2.0-pro-exp-02-05", is_chat=True)

     @classmethod
     def from_name(cls, name: str):
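
As background for the renamed models above, a minimal, self-contained sketch of the enum pattern they live in. The member names and model IDs are taken from the diff; the GeminiModel dataclass definition and the from_name body are assumptions, since neither is shown here.

from dataclasses import dataclass
from enum import Enum


@dataclass(frozen=True)
class GeminiModel:
    # Assumed shape: the diff only shows that GeminiModel takes `name` and `is_chat`.
    name: str
    is_chat: bool = True


class GeminiModels(Enum):
    # The model IDs this PR switches to (replacing the older "-exp" aliases).
    GEMINI_2_FLASH = GeminiModel(name="gemini-2.0-flash")
    GEMINI_2_FLASH_LITE = GeminiModel(name="gemini-2.0-flash-lite-preview-02-05")
    GEMINI_2_FLASH_THINKING = GeminiModel(name="gemini-2.0-flash-thinking-exp-01-21")
    GEMINI_2_PRO = GeminiModel(name="gemini-2.0-pro-exp-02-05")

    @classmethod
    def from_name(cls, name: str) -> "GeminiModels":
        # Hypothetical lookup by model ID; the real from_name body is not shown in this diff.
        for model in cls:
            if model.value.name == name:
                return model
        raise ValueError(f"Unknown Gemini model: {name}")


assert GeminiModels.from_name("gemini-2.0-flash") is GeminiModels.GEMINI_2_FLASH
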
@@ -86,7 +87,7 @@ def get_generate_kwargs(self, prompt: RenderedPrompt, llm_kwargs: Optional[dict]
         if prompt.response_format:
             config["response_mime_type"] = "application/json"
             config["response_schema"] = prompt.response_format
-        content_list = []
+        content_list: list[types.Content] = []
         for message in prompt.messages:
             if message.role == "system":
                 config["system_message"] = message.content
@@ -95,13 +96,15 @@ def get_generate_kwargs(self, prompt: RenderedPrompt, llm_kwargs: Optional[dict]
             content = types.Content(parts=[types.Part.from_text(text=message.content)], role=role)
             if message.images:
                 for image in message.images:
-                    image_bytes = image.convert("RGB").tobytes()
-                    content.parts.append(types.Part.from_bytes(image_bytes, media_type="image/png"))
+                    buffered = io.BytesIO()
+                    image.save(buffered, format="PNG")
+                    image_bytes = buffered.getvalue()
+                    content.parts.append(types.Part.from_bytes(data=image_bytes, mime_type="image/png"))
             content_list.append(content)
         kwargs["config"] = None
         if config:
             kwargs["config"] = types.GenerateContentConfig(**config)
-        kwargs["content"] = content
+        kwargs["content"] = content_list
         return kwargs

     def generate_metadata(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> dict:
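
The substantive fix in this hunk is the image serialization: image.convert("RGB").tobytes() returns a raw pixel buffer, not an encoded PNG stream, so labelling those bytes image/png handed the model undecodable data. The new code runs the PNG encoder through io.BytesIO. The hunk also passes the full content_list to the request kwargs instead of only the last content. A minimal sketch of the corrected conversion, assuming (as the diff implies) that message.images holds PIL Image objects:

import io

from PIL import Image


def image_to_png_bytes(image: Image.Image) -> bytes:
    """Encode a PIL image as a real PNG byte stream, mirroring the fix above."""
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")  # runs the PNG encoder; .tobytes() only dumps raw pixel data
    return buffered.getvalue()


# The encoded bytes are then attached to the request, as in the diff:
#     types.Part.from_bytes(data=image_to_png_bytes(image), mime_type="image/png")
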