From 2b2ad4d64eb742adf6850cc21334d0d2250aed54 Mon Sep 17 00:00:00 2001
From: Mark McDonald
Date: Mon, 22 Jan 2024 10:43:02 +0800
Subject: [PATCH] Rename model defaults to use gemini-pro

Updated some test references too.
---
 google/generativeai/generative_models.py |  2 +-
 tests/test_generative_models.py          | 40 ++++++++++++------------
 tests/test_generative_models_async.py    |  6 ++--
 3 files changed, 24 insertions(+), 24 deletions(-)

diff --git a/google/generativeai/generative_models.py b/google/generativeai/generative_models.py
index b1421fa1c..f78b203cd 100644
--- a/google/generativeai/generative_models.py
+++ b/google/generativeai/generative_models.py
@@ -156,7 +156,7 @@ class GenerativeModel:

     def __init__(
         self,
-        model_name: str = "gemini-m",
+        model_name: str = "gemini-pro",
         safety_settings: safety_types.SafetySettingOptions | None = None,
         generation_config: generation_types.GenerationConfigType | None = None,
         tools: content_types.ToolsType = None,
diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py
index 9ed5df12d..d7af691f1 100644
--- a/tests/test_generative_models.py
+++ b/tests/test_generative_models.py
@@ -66,7 +66,7 @@ def count_tokens(

     def test_hello(self):
         # Generate text from text prompt
-        model = generative_models.GenerativeModel(model_name="gemini-m")
+        model = generative_models.GenerativeModel(model_name="gemini-pro")

         self.responses["generate_content"].append(simple_response("world!"))

@@ -89,7 +89,7 @@ def test_hello(self):
     )
     def test_image(self, content):
         # Generate text from image
-        model = generative_models.GenerativeModel("gemini-m")
+        model = generative_models.GenerativeModel("gemini-pro")

         cat = "It's a cat"
         self.responses["generate_content"].append(simple_response(cat))
@@ -123,7 +123,7 @@ def test_image(self, content):
     )
     def test_generation_config_overwrite(self, config1, config2):
         # Generation config
-        model = generative_models.GenerativeModel("gemini-m", generation_config=config1)
+        model = generative_models.GenerativeModel("gemini-pro", generation_config=config1)

         self.responses["generate_content"] = [
             simple_response(" world!"),
@@ -168,7 +168,7 @@ def test_generation_config_overwrite(self, config1, config2):
     )
     def test_safety_overwrite(self, safe1, safe2):
         # Safety
-        model = generative_models.GenerativeModel("gemini-m", safety_settings={"danger": "low"})
+        model = generative_models.GenerativeModel("gemini-pro", safety_settings={"danger": "low"})

         self.responses["generate_content"] = [
             simple_response(" world!"),
@@ -200,7 +200,7 @@ def test_stream_basic(self):
         chunks = ["first", " second", " third"]
         self.responses["stream_generate_content"] = [(simple_response(text) for text in chunks)]

-        model = generative_models.GenerativeModel("gemini-m")
+        model = generative_models.GenerativeModel("gemini-pro")
         response = model.generate_content("Hello", stream=True)

         self.assertEqual(self.observed_requests[0].contents[0].parts[0].text, "Hello")
@@ -214,7 +214,7 @@ def test_stream_lookahead(self):
         chunks = ["first", " second", " third"]
         self.responses["stream_generate_content"] = [(simple_response(text) for text in chunks)]

-        model = generative_models.GenerativeModel("gemini-m")
+        model = generative_models.GenerativeModel("gemini-pro")
         response = model.generate_content("Hello", stream=True)

         self.assertEqual(self.observed_requests[0].contents[0].parts[0].text, "Hello")
@@ -234,7 +234,7 @@ def test_stream_prompt_feedback_blocked(self):
         ]
         self.responses["stream_generate_content"] = [(chunk for chunk in chunks)]

-        model = generative_models.GenerativeModel("gemini-m")
+        model = generative_models.GenerativeModel("gemini-pro")
         response = model.generate_content("Bad stuff!", stream=True)

         self.assertEqual(
@@ -269,7 +269,7 @@ def test_stream_prompt_feedback_not_blocked(self):
         ]
         self.responses["stream_generate_content"] = [(chunk for chunk in chunks)]

-        model = generative_models.GenerativeModel("gemini-m")
+        model = generative_models.GenerativeModel("gemini-pro")
         response = model.generate_content("Hello", stream=True)

         self.assertEqual(
@@ -282,7 +282,7 @@ def test_stream_prompt_feedback_not_blocked(self):

     def test_chat(self):
         # Multi turn chat
-        model = generative_models.GenerativeModel("gemini-m")
+        model = generative_models.GenerativeModel("gemini-pro")
         chat = model.start_chat()

         self.responses["generate_content"] = [
@@ -331,7 +331,7 @@ def test_chat_streaming_basic(self):
             iter([simple_response("x"), simple_response("y"), simple_response("z")]),
         ]

-        model = generative_models.GenerativeModel("gemini-mm-m")
+        model = generative_models.GenerativeModel("gemini-pro-vision")
         chat = model.start_chat()

         response = chat.send_message("letters?", stream=True)
@@ -354,7 +354,7 @@ def test_chat_incomplete_streaming_errors(self):
             iter([simple_response("x"), simple_response("y"), simple_response("z")]),
         ]

-        model = generative_models.GenerativeModel("gemini-mm-m")
+        model = generative_models.GenerativeModel("gemini-pro-vision")
         chat = model.start_chat()

         response = chat.send_message("letters?", stream=True)
@@ -378,7 +378,7 @@ def test_edit_history(self):
             simple_response("third"),
         ]

-        model = generative_models.GenerativeModel("gemini-mm-m")
+        model = generative_models.GenerativeModel("gemini-pro-vision")
         chat = model.start_chat()

         response = chat.send_message("hello")
@@ -404,7 +404,7 @@ def test_replace_history(self):
             simple_response("third"),
         ]

-        model = generative_models.GenerativeModel("gemini-mm-m")
+        model = generative_models.GenerativeModel("gemini-pro-vision")
         chat = model.start_chat()
         chat.send_message("hello1")
         chat.send_message("hello2")
@@ -426,7 +426,7 @@ def test_copy_history(self):
             simple_response("third"),
         ]

-        model = generative_models.GenerativeModel("gemini-mm-m")
+        model = generative_models.GenerativeModel("gemini-pro-vision")
         chat1 = model.start_chat()

         chat1.send_message("hello1")
@@ -471,7 +471,7 @@ def no_throw():
             no_throw(),
         ]

-        model = generative_models.GenerativeModel("gemini-mm-m")
+        model = generative_models.GenerativeModel("gemini-pro-vision")
         chat = model.start_chat()

         # Send a message, the response is okay..
@@ -514,7 +514,7 @@ def test_chat_prompt_blocked(self):
             )
         ]

-        model = generative_models.GenerativeModel("gemini-mm-m")
+        model = generative_models.GenerativeModel("gemini-pro-vision")
         chat = model.start_chat()

         with self.assertRaises(generation_types.BlockedPromptException):
@@ -532,7 +532,7 @@ def test_chat_candidate_blocked(self):
             )
         ]

-        model = generative_models.GenerativeModel("gemini-mm-m")
+        model = generative_models.GenerativeModel("gemini-pro-vision")
         chat = model.start_chat()

         with self.assertRaises(generation_types.StopCandidateException):
@@ -554,7 +554,7 @@ def test_chat_streaming_unexpected_stop(self):
             )
         ]

-        model = generative_models.GenerativeModel("gemini-mm-m")
+        model = generative_models.GenerativeModel("gemini-pro-vision")
         chat = model.start_chat()

         response = chat.send_message("hello", stream=True)
@@ -578,7 +578,7 @@ def test_tools(self):
                dict(name="datetime", description="Returns the current UTC date and time.")
             ]
         )
-        model = generative_models.GenerativeModel("gemini-mm-m", tools=tools)
+        model = generative_models.GenerativeModel("gemini-pro-vision", tools=tools)

         self.responses["generate_content"] = [
             simple_response("a"),
@@ -611,7 +611,7 @@ def test_tools(self):
     )
     def test_count_tokens_smoke(self, contents):
         self.responses["count_tokens"] = [glm.CountTokensResponse(total_tokens=7)]
-        model = generative_models.GenerativeModel("gemini-mm-m")
+        model = generative_models.GenerativeModel("gemini-pro-vision")
         response = model.count_tokens(contents)
         self.assertEqual(type(response).to_dict(response), {"total_tokens": 7})

diff --git a/tests/test_generative_models_async.py b/tests/test_generative_models_async.py
index d2d62ebb0..e7a1405a3 100644
--- a/tests/test_generative_models_async.py
+++ b/tests/test_generative_models_async.py
@@ -72,7 +72,7 @@ async def count_tokens(

     async def test_basic(self):
         # Generate text from text prompt
-        model = generative_models.GenerativeModel(model_name="gemini-m")
+        model = generative_models.GenerativeModel(model_name="gemini-pro")

         self.responses["generate_content"] = [simple_response("world!")]

@@ -85,7 +85,7 @@ async def test_basic(self):

     async def test_streaming(self):
         # Generate text from text prompt
-        model = generative_models.GenerativeModel(model_name="gemini-m")
+        model = generative_models.GenerativeModel(model_name="gemini-pro")

         async def responses():
             for c in "world!":
@@ -113,7 +113,7 @@ async def responses():
     )
     async def test_count_tokens_smoke(self, contents):
         self.responses["count_tokens"] = [glm.CountTokensResponse(total_tokens=7)]
-        model = generative_models.GenerativeModel("gemini-mm-m")
+        model = generative_models.GenerativeModel("gemini-pro-vision")
         response = await model.count_tokens_async(contents)
         self.assertEqual(type(response).to_dict(response), {"total_tokens": 7})
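
A minimal usage sketch of what the renamed default means for callers:
constructing a GenerativeModel without an explicit model_name now targets
"gemini-pro" rather than the old "gemini-m" placeholder. The snippet below is
illustrative only, not part of the patch; it assumes the package's documented
top-level `genai` alias and a valid API key:

    import google.generativeai as genai

    # Assumption: a real key is supplied here or via the GOOGLE_API_KEY env var.
    genai.configure(api_key="...")

    # No model_name argument: the new "gemini-pro" default applies.
    model = genai.GenerativeModel()
    response = model.generate_content("Hello")
    print(response.text)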