diff --git a/bigframes/bigquery/_operations/ai.py b/bigframes/bigquery/_operations/ai.py index d7ea29322d..d82023e4b5 100644 --- a/bigframes/bigquery/_operations/ai.py +++ b/bigframes/bigquery/_operations/ai.py @@ -47,30 +47,13 @@ def generate_bool( ... "col_1": ["apple", "bear", "pear"], ... "col_2": ["fruit", "animal", "animal"] ... }) - >>> bbq.ai_generate_bool((df["col_1"], " is a ", df["col_2"])) + >>> bbq.ai.generate_bool((df["col_1"], " is a ", df["col_2"])) 0 {'result': True, 'full_response': '{"candidate... 1 {'result': True, 'full_response': '{"candidate... 2 {'result': False, 'full_response': '{"candidat... dtype: struct[pyarrow] - >>> bbq.ai_generate_bool((df["col_1"], " is a ", df["col_2"])).struct.field("result") - 0 True - 1 True - 2 False - Name: result, dtype: boolean - - >>> model_params = { - ... "generation_config": { - ... "thinking_config": { - ... "thinking_budget": 0 - ... } - ... } - ... } - >>> bbq.ai_generate_bool( - ... (df["col_1"], " is a ", df["col_2"]), - ... endpoint="gemini-2.5-pro", - ... model_params=model_params, - ... ).struct.field("result") + >>> bbq.ai.generate_bool((df["col_1"], " is a ", df["col_2"])).struct.field("result") 0 True 1 True 2 False diff --git a/tests/system/large/bigquery/__init__.py b/tests/system/large/bigquery/__init__.py deleted file mode 100644 index 0a2669d7a2..0000000000 --- a/tests/system/large/bigquery/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/tests/system/large/bigquery/test_ai.py b/tests/system/large/bigquery/test_ai.py deleted file mode 100644 index be0216a526..0000000000 --- a/tests/system/large/bigquery/test_ai.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pandas as pd -import pandas.testing - -import bigframes.bigquery as bbq - - -def test_ai_generate_bool_multi_model(session): - df = session.from_glob_path( - "gs://bigframes-dev-testing/a_multimodel/images/*", name="image" - ) - - result = bbq.ai.generate_bool((df["image"], " contains an animal")).struct.field( - "result" - ) - - pandas.testing.assert_series_equal( - result.to_pandas(), - pd.Series([True, True, False, False, False], name="result"), - check_dtype=False, - check_index=False, - ) diff --git a/tests/system/small/bigquery/test_ai.py b/tests/system/small/bigquery/test_ai.py index 01050ade04..443d4c54a3 100644 --- a/tests/system/small/bigquery/test_ai.py +++ b/tests/system/small/bigquery/test_ai.py @@ -15,9 +15,10 @@ import sys import pandas as pd -import pandas.testing +import pyarrow as pa import pytest +from bigframes import series import bigframes.bigquery as bbq import bigframes.pandas as bpd @@ -27,15 +28,17 @@ def test_ai_generate_bool(session): s2 = bpd.Series(["fruit", "tree"], session=session) prompt = (s1, " is a ", s2) - result = bbq.ai.generate_bool(prompt, endpoint="gemini-2.5-flash").struct.field( - "result" - ) + result = bbq.ai.generate_bool(prompt, endpoint="gemini-2.5-flash") - pandas.testing.assert_series_equal( - result.to_pandas(), - pd.Series([True, False], name="result"), - check_dtype=False, - check_index=False, + assert _contains_no_nulls(result) + assert result.dtype == pd.ArrowDtype( + pa.struct( + ( + pa.field("result", pa.bool_()), + pa.field("full_response", pa.string()), + pa.field("status", pa.string()), + ) + ) ) @@ -52,11 +55,38 @@ def test_ai_generate_bool_with_model_params(session): result = bbq.ai.generate_bool( prompt, endpoint="gemini-2.5-flash", model_params=model_params - ).struct.field("result") + ) + + assert _contains_no_nulls(result) + assert result.dtype == pd.ArrowDtype( + pa.struct( + ( + pa.field("result", pa.bool_()), + pa.field("full_response", pa.string()), + pa.field("status", pa.string()), + ) + ) + ) + - pandas.testing.assert_series_equal( - result.to_pandas(), - pd.Series([True, False], name="result"), - check_dtype=False, - check_index=False, +def test_ai_generate_bool_multi_model(session): + df = session.from_glob_path( + "gs://bigframes-dev-testing/a_multimodel/images/*", name="image" ) + + result = bbq.ai.generate_bool((df["image"], " contains an animal")) + + assert _contains_no_nulls(result) + assert result.dtype == pd.ArrowDtype( + pa.struct( + ( + pa.field("result", pa.bool_()), + pa.field("full_response", pa.string()), + pa.field("status", pa.string()), + ) + ) + ) + + +def _contains_no_nulls(s: series.Series) -> bool: + return len(s) == s.count()