|
21 | 21 | TEST_IMAGE_DATA = TEST_IMAGE_PATH.read_bytes()
|
22 | 22 |
|
23 | 23 |
|
| 24 | +def simple_part(text: str) -> glm.Content: |
| 25 | + return glm.Content({"parts": [{"text": text}]}) |
| 26 | + |
| 27 | + |
| 28 | +def iter_part(texts: Iterable[str]) -> glm.Content: |
| 29 | + return glm.Content({"parts": [{"text": t} for t in texts]}) |
| 30 | + |
| 31 | + |
24 | 32 | def simple_response(text: str) -> glm.GenerateContentResponse:
|
25 |
| - return glm.GenerateContentResponse({"candidates": [{"content": {"parts": [{"text": text}]}}]}) |
| 33 | + return glm.GenerateContentResponse({"candidates": [{"content": simple_part(text)}]}) |
26 | 34 |
|
27 | 35 |
|
28 | 36 | class CUJTests(parameterized.TestCase):
|
@@ -605,6 +613,24 @@ def test_tools(self):
|
605 | 613 | self.assertLen(obr.tools, 1)
|
606 | 614 | self.assertEqual(type(obr.tools[0]).to_dict(obr.tools[0]), tools)
|
607 | 615 |
|
| 616 | + @parameterized.named_parameters( |
| 617 | + ["bare_str", "talk like a pirate", simple_part("talk like a pirate")], |
| 618 | + [ |
| 619 | + "part_dict", |
| 620 | + {"parts": [{"text": "talk like a pirate"}]}, |
| 621 | + simple_part("talk like a pirate"), |
| 622 | + ], |
| 623 | + ["part_list", ["talk like:", "a pirate"], iter_part(["talk like:", "a pirate"])], |
| 624 | + ) |
| 625 | + def test_system_instruction(self, instruction, expected_instr): |
| 626 | + self.responses["generate_content"] = [simple_response("echo echo")] |
| 627 | + model = generative_models.GenerativeModel("gemini-pro", system_instruction=instruction) |
| 628 | + |
| 629 | + _ = model.generate_content("test") |
| 630 | + |
| 631 | + [req] = self.observed_requests |
| 632 | + self.assertEqual(req.system_instruction, expected_instr) |
| 633 | + |
608 | 634 | @parameterized.named_parameters(
|
609 | 635 | ["basic", "Hello"],
|
610 | 636 | ["list", ["Hello"]],
|
|
0 commit comments