Fixup #128 (Merged)

2 changes: 1 addition & 1 deletion google/generativeai/__init__.py
@@ -30,7 +30,7 @@

     genai.configure(api_key=os.environ['API_KEY'])

-    model = genai.Model(name='gemini-pro')
+    model = genai.GenerativeModel(name='gemini-pro')
     response = model.generate_content('Please summarise this document: ...')

     print(response.text)
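For context, the corrected docstring example runs end to end roughly as follows; a minimal runnable sketch, assuming the google-generativeai package is installed and API_KEY is set in the environment:

import os

import google.generativeai as genai

genai.configure(api_key=os.environ["API_KEY"])

# The public class is GenerativeModel; genai.Model is not a public name,
# which is what this docstring fix corrects.
model = genai.GenerativeModel("gemini-pro")
response = model.generate_content("Please summarise this document: ...")
print(response.text)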
8 changes: 6 additions & 2 deletions google/generativeai/generative_models.py
@@ -274,14 +274,18 @@ async def generate_content_async(
     def count_tokens(
         self, contents: content_types.ContentsType
     ) -> glm.CountTokensResponse:
+        if self._client is None:
+            self._client = client.get_default_generative_client()
         contents = content_types.to_contents(contents)
-        return self._client.count_tokens(model=self.model_name, contents=contents)
+        return self._client.count_tokens(glm.CountTokensRequest(model=self.model_name, contents=contents))

     async def count_tokens_async(
         self, contents: content_types.ContentsType
     ) -> glm.CountTokensResponse:
+        if self._async_client is None:
+            self._async_client = client.get_default_generative_async_client()
         contents = content_types.to_contents(contents)
-        return await self._client.count_tokens(model=self.model_name, contents=contents)
+        return await self._async_client.count_tokens(glm.CountTokensRequest(model=self.model_name, contents=contents))
     # fmt: on

     def start_chat(
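In effect, both methods now lazily construct their client, send a glm.CountTokensRequest instead of bare keyword arguments, and count_tokens_async uses the async client rather than the sync one it previously hit by mistake. A minimal caller-side sketch, assuming an API key is already configured:

import google.generativeai as genai

model = genai.GenerativeModel("gemini-pro")

# count_tokens accepts the same ContentsType values as generate_content:
# a plain string, a list of parts, or full role/parts dictionaries.
response = model.count_tokens("Hello")
print(response.total_tokens)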
23 changes: 23 additions & 0 deletions tests/test_generative_models.py
@@ -56,6 +56,14 @@ def stream_generate_content(
             response = self.responses["stream_generate_content"].pop(0)
             return response

+        @add_client_method
+        def count_tokens(
+            request: glm.CountTokensRequest,
+        ) -> glm.CountTokensResponse:
+            self.observed_requests.append(request)
+            response = self.responses["count_tokens"].pop(0)
+            return response
+
     def test_hello(self):
         # Generate text from text prompt
         model = generative_models.GenerativeModel(model_name="gemini-m")
@@ -564,6 +572,21 @@ def test_chat_streaming_unexpected_stop(self):
         chat.rewind()
         self.assertLen(chat.history, 0)

+    @parameterized.named_parameters(
+        ["basic", "Hello"],
+        ["list", ["Hello"]],
+        [
+            "list2",
+            [{"text": "Hello"}, {"inline_data": {"data": b"PNG!", "mime_type": "image/png"}}],
+        ],
+        ["contents", [{"role": "user", "parts": ["hello"]}]],
+    )
+    def test_count_tokens_smoke(self, contents):
+        self.responses["count_tokens"] = [glm.CountTokensResponse(total_tokens=7)]
+        model = generative_models.GenerativeModel("gemini-mm-m")
+        response = model.count_tokens(contents)
+        self.assertEqual(type(response).to_dict(response), {"total_tokens": 7})
+
     @parameterized.named_parameters(
         [
             "GenerateContentResponse",
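The stub added above follows the file's existing fake-client pattern: each @add_client_method function records the incoming request and pops a canned response. A self-contained sketch of that pattern, with names simplified for illustration (FakeClient here is hypothetical, not the actual test harness):

class FakeClient:
    """Simplified stand-in for the generated API client used by the tests."""

    def __init__(self, responses):
        self.responses = responses        # canned responses keyed by method name
        self.observed_requests = []       # every request the code under test sent

    def count_tokens(self, request):
        self.observed_requests.append(request)
        return self.responses["count_tokens"].pop(0)

A test then seeds responses["count_tokens"] with one glm.CountTokensResponse and can assert on both the returned value and the recorded request.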
23 changes: 23 additions & 0 deletions tests/test_generative_models_async.py
@@ -62,6 +62,14 @@ async def stream_generate_content(
             response = self.responses["stream_generate_content"].pop(0)
             return response

+        @add_client_method
+        async def count_tokens(
+            request: glm.CountTokensRequest,
+        ) -> glm.CountTokensResponse:
+            self.observed_requests.append(request)
+            response = self.responses["count_tokens"].pop(0)
+            return response
+
     async def test_basic(self):
         # Generate text from text prompt
         model = generative_models.GenerativeModel(model_name="gemini-m")
@@ -98,6 +106,21 @@ async def responses():

         self.assertEqual(response.text, "world!")

+    @parameterized.named_parameters(
+        ["basic", "Hello"],
+        ["list", ["Hello"]],
+        [
+            "list2",
+            [{"text": "Hello"}, {"inline_data": {"data": b"PNG!", "mime_type": "image/png"}}],
+        ],
+        ["contents", [{"role": "user", "parts": ["hello"]}]],
+    )
+    async def test_count_tokens_smoke(self, contents):
+        self.responses["count_tokens"] = [glm.CountTokensResponse(total_tokens=7)]
+        model = generative_models.GenerativeModel("gemini-mm-m")
+        response = await model.count_tokens_async(contents)
+        self.assertEqual(type(response).to_dict(response), {"total_tokens": 7})
+

 if __name__ == "__main__":
     absltest.main()
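For completeness, outside the test harness the async path would be driven with an event loop; a minimal sketch, again assuming a configured API key:

import asyncio

import google.generativeai as genai

async def main():
    model = genai.GenerativeModel("gemini-pro")
    # Routed through the async client after this fix.
    response = await model.count_tokens_async("Hello")
    print(response.total_tokens)

asyncio.run(main())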