Status: Merged
2 changes: 1 addition & 1 deletion google/generativeai/generative_models.py
@@ -156,7 +156,7 @@ class GenerativeModel:

def __init__(
self,
model_name: str = "gemini-m",
model_name: str = "gemini-pro",
safety_settings: safety_types.SafetySettingOptions | None = None,
generation_config: generation_types.GenerationConfigType | None = None,
tools: content_types.ToolsType = None,
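The only library change in this PR is the default `model_name` above. As a minimal, non-authoritative sketch of what that means for callers that omit the argument (assuming the google-generativeai package is installed; the API key string is a placeholder, not part of this PR):

```python
import google.generativeai as genai

genai.configure(api_key="YOUR_API_KEY")  # placeholder key, not part of this PR

# With this change, omitting model_name selects "gemini-pro" rather than the
# old "gemini-m" placeholder default.
model = genai.GenerativeModel()
response = model.generate_content("Hello")
print(response.text)
```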
40 changes: 20 additions & 20 deletions tests/test_generative_models.py
@@ -66,7 +66,7 @@ def count_tokens(

def test_hello(self):
# Generate text from text prompt
model = generative_models.GenerativeModel(model_name="gemini-m")
model = generative_models.GenerativeModel(model_name="gemini-pro")

self.responses["generate_content"].append(simple_response("world!"))

@@ -89,7 +89,7 @@ def test_hello(self):
)
def test_image(self, content):
# Generate text from image
model = generative_models.GenerativeModel("gemini-m")
model = generative_models.GenerativeModel("gemini-pro")

cat = "It's a cat"
self.responses["generate_content"].append(simple_response(cat))
@@ -123,7 +123,7 @@ def test_image(self, content):
)
def test_generation_config_overwrite(self, config1, config2):
# Generation config
model = generative_models.GenerativeModel("gemini-m", generation_config=config1)
model = generative_models.GenerativeModel("gemini-pro", generation_config=config1)

self.responses["generate_content"] = [
simple_response(" world!"),
@@ -168,7 +168,7 @@ def test_generation_config_overwrite(self, config1, config2):
)
def test_safety_overwrite(self, safe1, safe2):
# Safety
model = generative_models.GenerativeModel("gemini-m", safety_settings={"danger": "low"})
model = generative_models.GenerativeModel("gemini-pro", safety_settings={"danger": "low"})

self.responses["generate_content"] = [
simple_response(" world!"),
@@ -200,7 +200,7 @@ def test_stream_basic(self):
chunks = ["first", " second", " third"]
self.responses["stream_generate_content"] = [(simple_response(text) for text in chunks)]

model = generative_models.GenerativeModel("gemini-m")
model = generative_models.GenerativeModel("gemini-pro")
response = model.generate_content("Hello", stream=True)

self.assertEqual(self.observed_requests[0].contents[0].parts[0].text, "Hello")
@@ -214,7 +214,7 @@ def test_stream_lookahead(self):
chunks = ["first", " second", " third"]
self.responses["stream_generate_content"] = [(simple_response(text) for text in chunks)]

model = generative_models.GenerativeModel("gemini-m")
model = generative_models.GenerativeModel("gemini-pro")
response = model.generate_content("Hello", stream=True)

self.assertEqual(self.observed_requests[0].contents[0].parts[0].text, "Hello")
@@ -234,7 +234,7 @@ def test_stream_prompt_feedback_blocked(self):
]
self.responses["stream_generate_content"] = [(chunk for chunk in chunks)]

model = generative_models.GenerativeModel("gemini-m")
model = generative_models.GenerativeModel("gemini-pro")
response = model.generate_content("Bad stuff!", stream=True)

self.assertEqual(
@@ -269,7 +269,7 @@ def test_stream_prompt_feedback_not_blocked(self):
]
self.responses["stream_generate_content"] = [(chunk for chunk in chunks)]

model = generative_models.GenerativeModel("gemini-m")
model = generative_models.GenerativeModel("gemini-pro")
response = model.generate_content("Hello", stream=True)

self.assertEqual(
@@ -282,7 +282,7 @@ def test_stream_prompt_feedback_not_blocked(self):

def test_chat(self):
# Multi turn chat
model = generative_models.GenerativeModel("gemini-m")
model = generative_models.GenerativeModel("gemini-pro")
chat = model.start_chat()

self.responses["generate_content"] = [
@@ -331,7 +331,7 @@ def test_chat_streaming_basic(self):
iter([simple_response("x"), simple_response("y"), simple_response("z")]),
]

model = generative_models.GenerativeModel("gemini-mm-m")
model = generative_models.GenerativeModel("gemini-pro-vision")
chat = model.start_chat()

response = chat.send_message("letters?", stream=True)
@@ -354,7 +354,7 @@ def test_chat_incomplete_streaming_errors(self):
iter([simple_response("x"), simple_response("y"), simple_response("z")]),
]

model = generative_models.GenerativeModel("gemini-mm-m")
model = generative_models.GenerativeModel("gemini-pro-vision")
chat = model.start_chat()
response = chat.send_message("letters?", stream=True)

@@ -378,7 +378,7 @@ def test_edit_history(self):
simple_response("third"),
]

model = generative_models.GenerativeModel("gemini-mm-m")
model = generative_models.GenerativeModel("gemini-pro-vision")
chat = model.start_chat()

response = chat.send_message("hello")
@@ -404,7 +404,7 @@ def test_replace_history(self):
simple_response("third"),
]

model = generative_models.GenerativeModel("gemini-mm-m")
model = generative_models.GenerativeModel("gemini-pro-vision")
chat = model.start_chat()
chat.send_message("hello1")
chat.send_message("hello2")
@@ -426,7 +426,7 @@ def test_copy_history(self):
simple_response("third"),
]

model = generative_models.GenerativeModel("gemini-mm-m")
model = generative_models.GenerativeModel("gemini-pro-vision")
chat1 = model.start_chat()
chat1.send_message("hello1")

@@ -471,7 +471,7 @@ def no_throw():
no_throw(),
]

model = generative_models.GenerativeModel("gemini-mm-m")
model = generative_models.GenerativeModel("gemini-pro-vision")
chat = model.start_chat()

# Send a message, the response is okay..
@@ -514,7 +514,7 @@ def test_chat_prompt_blocked(self):
)
]

model = generative_models.GenerativeModel("gemini-mm-m")
model = generative_models.GenerativeModel("gemini-pro-vision")
chat = model.start_chat()

with self.assertRaises(generation_types.BlockedPromptException):
@@ -532,7 +532,7 @@ def test_chat_candidate_blocked(self):
)
]

model = generative_models.GenerativeModel("gemini-mm-m")
model = generative_models.GenerativeModel("gemini-pro-vision")
chat = model.start_chat()

with self.assertRaises(generation_types.StopCandidateException):
@@ -554,7 +554,7 @@ def test_chat_streaming_unexpected_stop(self):
)
]

model = generative_models.GenerativeModel("gemini-mm-m")
model = generative_models.GenerativeModel("gemini-pro-vision")
chat = model.start_chat()

response = chat.send_message("hello", stream=True)
@@ -578,7 +578,7 @@ def test_tools(self):
dict(name="datetime", description="Returns the current UTC date and time.")
]
)
model = generative_models.GenerativeModel("gemini-mm-m", tools=tools)
model = generative_models.GenerativeModel("gemini-pro-vision", tools=tools)

self.responses["generate_content"] = [
simple_response("a"),
@@ -611,7 +611,7 @@ def test_tools(self):
)
def test_count_tokens_smoke(self, contents):
self.responses["count_tokens"] = [glm.CountTokensResponse(total_tokens=7)]
model = generative_models.GenerativeModel("gemini-mm-m")
model = generative_models.GenerativeModel("gemini-pro-vision")
response = model.count_tokens(contents)
self.assertEqual(type(response).to_dict(response), {"total_tokens": 7})

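The renamed test models above flow through the SDK's streaming and chat paths. As a rough, non-authoritative sketch of the same calls outside the test harness (prompts are illustrative, and an API key is assumed to be configured already):

```python
import google.generativeai as genai

# Streaming, as exercised by test_stream_basic: chunks arrive incrementally.
model = genai.GenerativeModel("gemini-pro")
for chunk in model.generate_content("Hello", stream=True):
    print(chunk.text)

# Multi-turn chat, as exercised by test_chat: history accumulates on the chat object.
chat = genai.GenerativeModel("gemini-pro").start_chat()
reply = chat.send_message("letters?")
print(reply.text)
print(len(chat.history))  # one user turn plus one model turn so far
```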
6 changes: 3 additions & 3 deletions tests/test_generative_models_async.py
@@ -72,7 +72,7 @@ async def count_tokens(

async def test_basic(self):
# Generate text from text prompt
model = generative_models.GenerativeModel(model_name="gemini-m")
model = generative_models.GenerativeModel(model_name="gemini-pro")

self.responses["generate_content"] = [simple_response("world!")]

@@ -85,7 +85,7 @@ async def test_basic(self):

async def test_streaming(self):
# Generate text from text prompt
model = generative_models.GenerativeModel(model_name="gemini-m")
model = generative_models.GenerativeModel(model_name="gemini-pro")

async def responses():
for c in "world!":
@@ -113,7 +113,7 @@ async def responses():
)
async def test_count_tokens_smoke(self, contents):
self.responses["count_tokens"] = [glm.CountTokensResponse(total_tokens=7)]
model = generative_models.GenerativeModel("gemini-mm-m")
model = generative_models.GenerativeModel("gemini-pro-vision")
response = await model.count_tokens_async(contents)
self.assertEqual(type(response).to_dict(response), {"total_tokens": 7})

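The async tests mirror the sync surface. A minimal sketch of the equivalent calls (method names match the async variants visible in the diff, such as count_tokens_async; the prompt and API-key setup are illustrative assumptions):

```python
import asyncio

import google.generativeai as genai

async def main():
    model = genai.GenerativeModel("gemini-pro")

    # Async generation, as in test_basic / test_streaming.
    response = await model.generate_content_async("Hello")
    print(response.text)

    # Async token counting, as in test_count_tokens_smoke.
    count = await model.count_tokens_async("Hello")
    print(count.total_tokens)

asyncio.run(main())
```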