diff --git a/.release-please-manifest.json b/.release-please-manifest.json
index 166b986ce3..887f3bfead 100644
--- a/.release-please-manifest.json
+++ b/.release-please-manifest.json
@@ -1,3 +1,3 @@
 {
-  ".": "1.98.0"
+  ".": "1.99.0"
 }
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0e24845fd3..cd328a9014 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,24 @@
 # Changelog
 
+## [1.99.0](https://github.com/googleapis/python-aiplatform/compare/v1.98.0...v1.99.0) (2025-06-24)
+
+
+### Features
+
+* [vertexai] Added concise option name to `OpenModel.list_deploy_options()` ([9a0eec6](https://github.com/googleapis/python-aiplatform/commit/9a0eec64544630ca3681acf1ba6bb79280ee8c11))
+* Add resource usage assessment for batch prediction. ([f63e436](https://github.com/googleapis/python-aiplatform/commit/f63e436ed0cb30316dc8569a6e1adfe5cb2cced8))
+* Add support for ADK memory service to AdkApp template ([733fddd](https://github.com/googleapis/python-aiplatform/commit/733fdddb7dd61ae9b9e5180e832776b0c4e2682b))
+* GenAI SDK client - Add automatic candidate naming and creation timestamp to evaluation dataset metadata ([e8897e7](https://github.com/googleapis/python-aiplatform/commit/e8897e7bee243fe9ac9996d451ad313e9ff6484c))
+* GenAI SDK client - Add support for OpenAI data format for evals ([f8f66f1](https://github.com/googleapis/python-aiplatform/commit/f8f66f1420271c9278236c6c1bf0b652126c4959))
+* GenAI SDK client - Adding client-based SDKs for Agent Engine ([7b51d9e](https://github.com/googleapis/python-aiplatform/commit/7b51d9e0a06662def2275c84cda6d08fed796740))
+
+
+### Documentation
+
+* Add deprecation notice to readme for Generative AI submodules: vertexai.generative_models, vertexai.language_models, vertexai.vision_models, vertexai.tuning, vertexai.caching ([beae2e3](https://github.com/googleapis/python-aiplatform/commit/beae2e3c39904c83a63f77be5bb7dc9958abd19d))
+* Add deprecation notice to readme for Generative AI submodules: vertexai.generative_models, vertexai.language_models, vertexai.vision_models, vertexai.tuning, vertexai.caching ([cdee7c2](https://github.com/googleapis/python-aiplatform/commit/cdee7c2406ad50dd2d3af0075f1c84d38b684b33))
+* Add deprecation notice to readme for Generative AI submodules: vertexai.generative_models, vertexai.language_models, vertexai.vision_models, vertexai.tuning, vertexai.caching ([9b0beae](https://github.com/googleapis/python-aiplatform/commit/9b0beae22be2f7618618ef52971f0b3603ae3885))
+
 ## [1.98.0](https://github.com/googleapis/python-aiplatform/compare/v1.97.0...v1.98.0) (2025-06-19)
diff --git a/README.rst b/README.rst
index 9ef50aabb1..626221cfe8 100644
--- a/README.rst
+++ b/README.rst
@@ -1,6 +1,14 @@
 Vertex AI SDK for Python
 =================================================
 
+.. note::
+
+   The following Generative AI modules in the Vertex AI SDK are deprecated as of June 24, 2025 and will be removed on June 24, 2026:
+   `vertexai.generative_models`, `vertexai.language_models`, `vertexai.vision_models`, `vertexai.tuning`, `vertexai.caching`. Please use the
+   `Google Gen AI SDK <https://pypi.org/project/google-genai/>`_ to access these features. See
+   `the migration guide <https://cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk>`_ for details.
+   You can continue using all other Vertex AI SDK modules, as they are the recommended way to use the API.
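(Editor's aside: for readers acting on the deprecation note above, here is a minimal migration sketch using the Google Gen AI SDK. It assumes `pip install google-genai`; the project, location, model name, and prompt are illustrative placeholders, not part of this diff.)

```python
# Hedged sketch of the recommended replacement for vertexai.generative_models.
from google import genai

# vertexai=True routes requests through the Vertex AI API rather than the
# Gemini Developer API; project and location are placeholders.
client = genai.Client(vertexai=True, project="my-project", location="us-central1")

response = client.models.generate_content(
    model="gemini-2.0-flash-001",  # illustrative model name
    contents="Say hello in one short sentence.",
)
print(response.text)
```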
+
 |GA| |pypi| |versions| |unit-tests| |system-tests| |sample-tests|
 
 `Vertex AI`_: Google Vertex AI is an integrated suite of machine learning tools and services for building and using ML models with AutoML or custom code. It offers both novices and experts the best workbench for the entire machine learning development lifecycle.
diff --git a/google/cloud/aiplatform/gapic_version.py b/google/cloud/aiplatform/gapic_version.py
index 1a1b7dfaef..fc8db199f8 100644
--- a/google/cloud/aiplatform/gapic_version.py
+++ b/google/cloud/aiplatform/gapic_version.py
@@ -13,4 +13,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-__version__ = "1.98.0"  # {x-release-please-version}
+__version__ = "1.99.0"  # {x-release-please-version}
diff --git a/google/cloud/aiplatform/preview/datasets.py b/google/cloud/aiplatform/preview/datasets.py
index 8340533950..b185cbd3c2 100644
--- a/google/cloud/aiplatform/preview/datasets.py
+++ b/google/cloud/aiplatform/preview/datasets.py
@@ -602,6 +602,21 @@ class TuningValidationAssessmentResult:
     errors: List[str]
 
 
+@dataclasses.dataclass(frozen=True)
+class BatchPredictionResourceUsageAssessmentResult:
+    """The result of a batch prediction resource usage assessment.
+
+    Attributes:
+        token_count (int):
+            Number of tokens in the dataset.
+        audio_token_count (int):
+            Number of audio tokens in the dataset.
+    """
+
+    token_count: int
+    audio_token_count: int
+
+
 class MultimodalDataset(base.VertexAiResourceNounWithFutureManager):
     """A class representing a unified multimodal dataset."""
 
@@ -1499,6 +1514,48 @@ def assess_tuning_validity(
             errors=assessment_result.tuning_validation_assessment_result.errors
         )
 
+    def assess_batch_prediction_resources(
+        self,
+        *,
+        model_name: str,
+        template_config: Optional[GeminiTemplateConfig] = None,
+        assess_request_timeout: Optional[float] = None,
+    ) -> BatchPredictionResourceUsageAssessmentResult:
+        """Assess the batch prediction resources required for a given model.
+
+        Args:
+            model_name (str):
+                Required. The name of the model for which to assess
+                batch prediction resource usage.
+            template_config (GeminiTemplateConfig):
+                Optional. The template config used to assemble the dataset
+                before assessing the batch prediction resources. If not provided, the
+                template config attached to the dataset will be used. Required
+                if no template config is attached to the dataset.
+            assess_request_timeout (float):
+                Optional. The timeout for the assess batch prediction resources request.
+        Returns:
+            A `BatchPredictionResourceUsageAssessmentResult` with the
+            following attributes:
+            - token_count: The number of tokens in the dataset.
+            - audio_token_count: The number of audio tokens in the dataset.
+ + """ + request = self._build_assess_data_request(template_config) + request.batch_prediction_resource_usage_assessment_config = gca_dataset_service.AssessDataRequest.BatchPredictionResourceUsageAssessmentConfig( + model_name=model_name + ) + + assessment_result = ( + self.api_client.assess_data(request=request, timeout=assess_request_timeout) + .result(timeout=None) + .batch_prediction_resource_usage_assessment_result + ) + return BatchPredictionResourceUsageAssessmentResult( + token_count=assessment_result.token_count, + audio_token_count=assessment_result.audio_token_count, + ) + def _build_assess_data_request( self, template_config: Optional[GeminiTemplateConfig] = None, diff --git a/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py index 1a1b7dfaef..fc8db199f8 100644 --- a/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.98.0" # {x-release-please-version} +__version__ = "1.99.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py index 1a1b7dfaef..fc8db199f8 100644 --- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.98.0" # {x-release-please-version} +__version__ = "1.99.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py index 1a1b7dfaef..fc8db199f8 100644 --- a/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.98.0" # {x-release-please-version} +__version__ = "1.99.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py index 1a1b7dfaef..fc8db199f8 100644 --- a/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.98.0" # {x-release-please-version} +__version__ = "1.99.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py index 1a1b7dfaef..fc8db199f8 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.98.0" # {x-release-please-version} +__version__ = "1.99.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py index 1a1b7dfaef..fc8db199f8 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.98.0" # {x-release-please-version} +__version__ = "1.99.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py index 1a1b7dfaef..fc8db199f8 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.98.0" # {x-release-please-version} +__version__ = "1.99.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py index 1a1b7dfaef..fc8db199f8 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.98.0" # {x-release-please-version} +__version__ = "1.99.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py index 1a1b7dfaef..fc8db199f8 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.98.0" # {x-release-please-version} +__version__ = "1.99.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py index 1a1b7dfaef..fc8db199f8 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.98.0" # {x-release-please-version} +__version__ = "1.99.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py index 1a1b7dfaef..fc8db199f8 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.98.0" # {x-release-please-version} +__version__ = "1.99.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py index 1a1b7dfaef..fc8db199f8 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.98.0" # {x-release-please-version} +__version__ = "1.99.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py index 1a1b7dfaef..fc8db199f8 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.98.0" # {x-release-please-version} +__version__ = "1.99.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py index 1a1b7dfaef..fc8db199f8 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.98.0" # {x-release-please-version} +__version__ = "1.99.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py index 1a1b7dfaef..fc8db199f8 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.98.0" # {x-release-please-version} +__version__ = "1.99.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py index 1a1b7dfaef..fc8db199f8 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.98.0" # {x-release-please-version} +__version__ = "1.99.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/version.py b/google/cloud/aiplatform/version.py index 4ae710b3b4..f9391e1333 100644 --- a/google/cloud/aiplatform/version.py +++ b/google/cloud/aiplatform/version.py @@ -15,4 +15,4 @@ # limitations under the License. 
# -__version__ = "1.98.0" +__version__ = "1.99.0" diff --git a/google/cloud/aiplatform_v1/gapic_version.py b/google/cloud/aiplatform_v1/gapic_version.py index 1a1b7dfaef..fc8db199f8 100644 --- a/google/cloud/aiplatform_v1/gapic_version.py +++ b/google/cloud/aiplatform_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.98.0" # {x-release-please-version} +__version__ = "1.99.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform_v1beta1/gapic_version.py b/google/cloud/aiplatform_v1beta1/gapic_version.py index 1a1b7dfaef..fc8db199f8 100644 --- a/google/cloud/aiplatform_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.98.0" # {x-release-please-version} +__version__ = "1.99.0" # {x-release-please-version} diff --git a/pypi/_vertex_ai_placeholder/version.py b/pypi/_vertex_ai_placeholder/version.py index ebd0774020..f5da39f18d 100644 --- a/pypi/_vertex_ai_placeholder/version.py +++ b/pypi/_vertex_ai_placeholder/version.py @@ -15,4 +15,4 @@ # limitations under the License. # -__version__ = "1.98.0" +__version__ = "1.99.0" diff --git a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json index 524f1eb18e..ac2091d3c9 100644 --- a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json +++ b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json @@ -8,7 +8,7 @@ ], "language": "PYTHON", "name": "google-cloud-aiplatform", - "version": "1.98.0" + "version": "1.99.0" }, "snippets": [ { diff --git a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json index a8587832b3..3373fc56dc 100644 --- a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json +++ b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json @@ -8,7 +8,7 @@ ], "language": "PYTHON", "name": "google-cloud-aiplatform", - "version": "1.98.0" + "version": "1.99.0" }, "snippets": [ { diff --git a/tests/unit/aiplatform/test_multimodal_datasets.py b/tests/unit/aiplatform/test_multimodal_datasets.py index c116f9ce9d..77a224e9b9 100644 --- a/tests/unit/aiplatform/test_multimodal_datasets.py +++ b/tests/unit/aiplatform/test_multimodal_datasets.py @@ -282,6 +282,21 @@ def assess_data_tuning_validation_mock(): yield assess_data_mock +@pytest.fixture +def assess_data_batch_prediction_resources_mock(): + with mock.patch.object( + dataset_service.DatasetServiceClient, "assess_data" + ) as assess_data_mock: + assess_data_lro_mock = mock.Mock(operation.Operation) + assess_data_lro_mock.result.return_value = gca_dataset_service.AssessDataResponse( + batch_prediction_resource_usage_assessment_result=gca_dataset_service.AssessDataResponse.BatchPredictionResourceUsageAssessmentResult( + token_count=100, audio_token_count=200 + ) + ) + assess_data_mock.return_value = assess_data_lro_mock + yield assess_data_mock + + @pytest.fixture def assemble_data_mock(): with mock.patch.object( @@ -746,6 +761,55 @@ def test_assess_tuning_validity(self, assess_data_tuning_validation_mock): ) assert result == ummd.TuningValidationAssessmentResult(errors=["error message"]) + 
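(Editor's aside: the tests added below exercise the new `assess_batch_prediction_resources` method end to end. For orientation, a user-level call looks roughly like the sketch here; the project, dataset resource name, and model name are placeholders.)

```python
# Hedged usage sketch of the API added in this release; identifiers are placeholders.
from google.cloud import aiplatform
from google.cloud.aiplatform.preview import datasets

aiplatform.init(project="my-project", location="us-central1")
dataset = datasets.MultimodalDataset(
    dataset_name="projects/my-project/locations/us-central1/datasets/123"
)

# With no explicit template_config, the template config attached to the
# dataset is used; the call fails if neither is available.
result = dataset.assess_batch_prediction_resources(model_name="gemini-2.0-flash-001")
print(result.token_count, result.audio_token_count)
```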
@pytest.mark.usefixtures("get_dataset_mock") + def test_assess_batch_prediction_resources( + self, assess_data_batch_prediction_resources_mock + ): + aiplatform.init(project=_TEST_PROJECT) + dataset = ummd.MultimodalDataset(dataset_name=_TEST_NAME) + template_config = ummd.GeminiTemplateConfig( + field_mapping={"question": "questionColumn"}, + ) + result = dataset.assess_batch_prediction_resources( + model_name="gemini-1.5-flash-exp", + template_config=template_config, + ) + assess_data_batch_prediction_resources_mock.assert_called_once_with( + request=gca_dataset_service.AssessDataRequest( + name=_TEST_NAME, + batch_prediction_resource_usage_assessment_config=gca_dataset_service.AssessDataRequest.BatchPredictionResourceUsageAssessmentConfig( + model_name="gemini-1.5-flash-exp" + ), + gemini_request_read_config=gca_dataset_service.GeminiRequestReadConfig( + template_config=template_config._raw_gemini_template_config + ), + ), + timeout=None, + ) + assert result == ummd.BatchPredictionResourceUsageAssessmentResult( + token_count=100, audio_token_count=200 + ) + + @pytest.mark.usefixtures("get_dataset_request_column_name_mock") + def test_assess_batch_prediction_resources_request_column_name( + self, assess_data_batch_prediction_resources_mock + ): + aiplatform.init(project=_TEST_PROJECT) + dataset = ummd.MultimodalDataset(dataset_name=_TEST_NAME) + dataset.assess_batch_prediction_resources(model_name="gemini-1.5-flash-exp") + assess_data_batch_prediction_resources_mock.assert_called_once_with( + request=gca_dataset_service.AssessDataRequest( + name=_TEST_NAME, + batch_prediction_resource_usage_assessment_config=gca_dataset_service.AssessDataRequest.BatchPredictionResourceUsageAssessmentConfig( + model_name="gemini-1.5-flash-exp" + ), + gemini_request_read_config=gca_dataset_service.GeminiRequestReadConfig( + assembled_request_column_name="requests" + ), + ), + timeout=None, + ) + @pytest.mark.usefixtures("get_dataset_request_column_name_mock") def test_assess_tuning_validity_request_column_name( self, assess_data_tuning_validation_mock diff --git a/tests/unit/vertex_langchain/test_agent_engines.py b/tests/unit/vertex_langchain/test_agent_engines.py index 92ffb0e27b..ec2e554aec 100644 --- a/tests/unit/vertex_langchain/test_agent_engines.py +++ b/tests/unit/vertex_langchain/test_agent_engines.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from absl.testing import parameterized import cloudpickle import difflib import importlib @@ -1147,7 +1146,6 @@ def test_create_agent_engine_with_env_vars_list( retry=_TEST_RETRY, ) - # pytest does not allow absl.testing.parameterized.named_parameters. @pytest.mark.parametrize( "test_case_name, test_engine_instance, expected_framework", [ @@ -1190,7 +1188,6 @@ def test_get_agent_framework( framework = _agent_engines._get_agent_framework(test_engine_instance) assert framework == expected_framework - # pytest does not allow absl.testing.parameterized.named_parameters. @pytest.mark.parametrize( "test_case_name, test_kwargs, want_request", [ @@ -1601,7 +1598,6 @@ def test_query_agent_engine( test_agent_engine.query(query=_TEST_QUERY_PROMPT) query_mock.assert_called_with(request=_TEST_AGENT_ENGINE_QUERY_REQUEST) - # pytest does not allow absl.testing.parameterized.named_parameters. 
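(Editor's aside: the conversions below preserve absl's `named_parameters` test names only as comments. pytest can keep readable test IDs natively via `pytest.param(..., id=...)`; a small sketch, not taken from this diff.)

```python
import pytest

# Each case carries an explicit id, so pytest reports e.g.
# test_roundtrip[empty_dict] instead of a bare positional index.
@pytest.mark.parametrize(
    "obj, expected",
    [
        pytest.param({}, {}, id="empty_dict"),
        pytest.param({"a": 1}, {"a": 1}, id="nonempty_dict"),
    ],
)
def test_roundtrip(obj, expected):
    assert obj == expected
```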
@pytest.mark.parametrize( "test_case_name, test_class_methods_spec, want_operation_schema_api_modes", [ @@ -1847,7 +1843,6 @@ def test_query_after_create_agent_engine_with_operation_schema( class_method=method_name, ) ) - assert invoked_method.__doc__ == test_doc # pytest does not allow absl.testing.parameterized.named_parameters. @pytest.mark.parametrize( @@ -2062,7 +2057,6 @@ def test_stream_query_after_create_agent_engine_with_operation_schema( class_method=method_name, ) ) - assert invoked_method.__doc__ == test_doc # pytest does not allow absl.testing.parameterized.named_parameters. @pytest.mark.parametrize( @@ -2243,7 +2237,6 @@ async def test_async_stream_query_after_create_agent_engine_with_operation_schem class_method=method_name, ) ) - assert invoked_method.__doc__ == test_doc # pytest does not allow absl.testing.parameterized.named_parameters. @pytest.mark.parametrize( @@ -2870,7 +2863,7 @@ def test_update_class_methods_spec_with_registered_operation_not_found(self): "register the API methods: " "https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/develop/custom#custom-methods. " "Error: {Unsupported api mode: `UNKNOWN_API_MODE`, " - "Supported modes are: ``, `async`, `stream` and `async_stream`.}" + "Supported modes are: ``, `async`, `async_stream`, `stream`.}" ), ), ], @@ -2987,161 +2980,169 @@ def assert_called_with_diff(mock_obj, expected_kwargs=None): ) -class TestGenerateSchema(parameterized.TestCase): - @parameterized.named_parameters( - dict( - testcase_name="place_tool_query", - func=place_tool_query, - required=["city", "activity"], - expected_operation={ - "name": "place_tool_query", - "description": ( - "Searches the city for recommendations on the activity." - ), - "parameters": { - "type": "object", - "properties": { - "city": {"type": "string"}, - "activity": {"type": "string", "nullable": True}, - "page_size": {"type": "integer"}, +class TestGenerateSchema: + # pytest does not allow absl.testing.parameterized.named_parameters. + @pytest.mark.parametrize( + "func, required, expected_operation", + [ + ( + # "place_tool_query", + place_tool_query, + ["city", "activity"], + { + "name": "place_tool_query", + "description": ( + "Searches the city for recommendations on the activity." 
+ ), + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string"}, + "activity": {"type": "string", "nullable": True}, + "page_size": {"type": "integer"}, + }, + "required": ["city", "activity"], }, - "required": ["city", "activity"], }, - }, - ), - dict( - testcase_name="place_photo_query", - func=place_photo_query, - required=["photo_reference"], - expected_operation={ - "name": "place_photo_query", - "description": "Returns the photo for a given reference.", - "parameters": { - "properties": { - "photo_reference": {"type": "string"}, - "maxwidth": {"type": "integer"}, - "maxheight": {"type": "integer", "nullable": True}, + ), + ( + # "place_photo_query", + place_photo_query, + ["photo_reference"], + { + "name": "place_photo_query", + "description": "Returns the photo for a given reference.", + "parameters": { + "type": "object", + "properties": { + "photo_reference": {"type": "string"}, + "maxwidth": {"type": "integer"}, + "maxheight": {"type": "integer", "nullable": True}, + }, + "required": ["photo_reference"], }, - "required": ["photo_reference"], - "type": "object", }, - }, - ), + ), + ], ) def test_generate_schemas(self, func, required, expected_operation): result = _utils.generate_schema(func, required=required) - self.assertDictEqual(result, expected_operation) + assert result == expected_operation -class TestToProto(parameterized.TestCase): - @parameterized.named_parameters( - dict( - testcase_name="empty_dict", - obj={}, - expected_proto=struct_pb2.Struct(fields={}), - ), - dict( - testcase_name="nonempty_dict", - obj={"snake_case": 1, "camelCase": 2}, - expected_proto=struct_pb2.Struct( - fields={ - "snake_case": struct_pb2.Value(number_value=1), - "camelCase": struct_pb2.Value(number_value=2), - }, +class TestToProto: + # pytest does not allow absl.testing.parameterized.named_parameters. + @pytest.mark.parametrize( + "obj, expected_proto", + [ + ( + # "empty_dict", + {}, + struct_pb2.Struct(fields={}), ), - ), - dict( - testcase_name="empty_proto_message", - obj=struct_pb2.Struct(fields={}), - expected_proto=struct_pb2.Struct(fields={}), - ), - dict( - testcase_name="nonempty_proto_message", - obj=struct_pb2.Struct( - fields={ - "snake_case": struct_pb2.Value(number_value=1), - "camelCase": struct_pb2.Value(number_value=2), - }, + ( + # "nonempty_dict", + {"snake_case": 1, "camelCase": 2}, + struct_pb2.Struct( + fields={ + "snake_case": struct_pb2.Value(number_value=1), + "camelCase": struct_pb2.Value(number_value=2), + }, + ), ), - expected_proto=struct_pb2.Struct( - fields={ - "snake_case": struct_pb2.Value(number_value=1), - "camelCase": struct_pb2.Value(number_value=2), - }, + ( + # "empty_proto_message", + struct_pb2.Struct(fields={}), + struct_pb2.Struct(fields={}), ), - ), + ( + # "nonempty_proto_message", + struct_pb2.Struct( + fields={ + "snake_case": struct_pb2.Value(number_value=1), + "camelCase": struct_pb2.Value(number_value=2), + }, + ), + struct_pb2.Struct( + fields={ + "snake_case": struct_pb2.Value(number_value=1), + "camelCase": struct_pb2.Value(number_value=2), + }, + ), + ), + ], ) def test_to_proto(self, obj, expected_proto): result = _utils.to_proto(obj) - self.assertDictEqual(_utils.to_dict(result), _utils.to_dict(expected_proto)) - # converting a new object to proto should not modify earlier objects. 
- new_result = _utils.to_proto({}) - self.assertDictEqual(_utils.to_dict(result), _utils.to_dict(expected_proto)) - self.assertEmpty(new_result) + assert _utils.to_dict(result) == _utils.to_dict(expected_proto) -class ToParsedJsonTest(parameterized.TestCase): - @parameterized.named_parameters( - dict( - testcase_name="valid_json", - obj=httpbody_pb2.HttpBody( - content_type="application/json", data=b'{"a": 1, "b": "hello"}' +class ToParsedJsonTest: + # pytest does not allow absl.testing.parameterized.named_parameters. + @pytest.mark.parametrize( + "obj, expected", + [ + ( + # "valid_json", + httpbody_pb2.HttpBody( + content_type="application/json", data=b'{"a": 1, "b": "hello"}' + ), + [{"a": 1, "b": "hello"}], ), - expected=[{"a": 1, "b": "hello"}], - ), - dict( - testcase_name="invalid_json", - obj=httpbody_pb2.HttpBody( - content_type="application/json", data=b'{"a": 1, "b": "hello"' + ( + # "invalid_json", + httpbody_pb2.HttpBody( + content_type="application/json", data=b'{"a": 1, "b": "hello"' + ), + ['{"a": 1, "b": "hello"'], # returns the unparsed string ), - expected=['{"a": 1, "b": "hello"'], # returns the unparsed string - ), - dict( - testcase_name="missing_content_type", - obj=httpbody_pb2.HttpBody(data=b'{"a": 1}'), - expected=[httpbody_pb2.HttpBody(data=b'{"a": 1}')], - ), - dict( - testcase_name="missing_data", - obj=httpbody_pb2.HttpBody(content_type="application/json"), - expected=[None], - ), - dict( - testcase_name="wrong_content_type", - obj=httpbody_pb2.HttpBody(content_type="text/plain", data=b"hello"), - expected=[httpbody_pb2.HttpBody(content_type="text/plain", data=b"hello")], - ), - dict( - testcase_name="empty_data", - obj=httpbody_pb2.HttpBody(content_type="application/json", data=b""), - expected=[None], - ), - dict( - testcase_name="unicode_data", - obj=httpbody_pb2.HttpBody( - content_type="application/json", data='{"a": "你好"}'.encode("utf-8") + ( + # "missing_content_type", + httpbody_pb2.HttpBody(data=b'{"a": 1}'), + [httpbody_pb2.HttpBody(data=b'{"a": 1}')], ), - expected=[{"a": "你好"}], - ), - dict( - testcase_name="nested_json", - obj=httpbody_pb2.HttpBody( - content_type="application/json", data=b'{"a": {"b": 1}}' + ( + # "missing_data", + httpbody_pb2.HttpBody(content_type="application/json"), + [None], ), - expected=[{"a": {"b": 1}}], - ), - dict( - testcase_name="multiline_json", - obj=httpbody_pb2.HttpBody( - content_type="application/json", - data=b'{"a": {"b": 1}}\n{"a": {"b": 2}}', + ( + # "wrong_content_type", + httpbody_pb2.HttpBody(content_type="text/plain", data=b"hello"), + [httpbody_pb2.HttpBody(content_type="text/plain", data=b"hello")], ), - expected=[{"a": {"b": 1}}, {"a": {"b": 2}}], - ), + ( + # "empty_data", + httpbody_pb2.HttpBody(content_type="application/json", data=b""), + [None], + ), + ( + # "unicode_data", + httpbody_pb2.HttpBody( + content_type="application/json", data='{"a": "你好"}'.encode("utf-8") + ), + [{"a": "你好"}], + ), + ( + # "nested_json", + httpbody_pb2.HttpBody( + content_type="application/json", data=b'{"a": {"b": 1}}' + ), + [{"a": {"b": 1}}], + ), + ( + # "multiline_json", + httpbody_pb2.HttpBody( + content_type="application/json", + data=b'{"a": {"b": 1}}\n{"a": {"b": 2}}', + ), + [{"a": {"b": 1}}, {"a": {"b": 2}}], + ), + ], ) def test_to_parsed_json(self, obj, expected): for got, want in zip(_utils.yield_parsed_json(obj), expected): - self.assertEqual(got, want) + assert got == want class TestRequirements: diff --git a/tests/unit/vertex_langchain/test_reasoning_engines.py 
b/tests/unit/vertex_langchain/test_reasoning_engines.py index 15257abc26..9f7f547ead 100644 --- a/tests/unit/vertex_langchain/test_reasoning_engines.py +++ b/tests/unit/vertex_langchain/test_reasoning_engines.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from absl.testing import parameterized import cloudpickle import dataclasses import datetime @@ -1373,9 +1372,6 @@ def test_query_after_create_reasoning_engine_with_operation_schema( ) ) - invoked_method = getattr(test_reasoning_engine, method_name) - assert invoked_method.__doc__ == test_doc - # pytest does not allow absl.testing.parameterized.named_parameters. @pytest.mark.parametrize( "test_case_name, test_engine, test_class_methods, test_class_methods_spec", @@ -1599,7 +1595,6 @@ def test_stream_query_after_create_reasoning_engine_with_operation_schema( class_method=method_name, ) ) - assert invoked_method.__doc__ == test_doc # pytest does not allow absl.testing.parameterized.named_parameters. @pytest.mark.parametrize( @@ -2162,97 +2157,94 @@ def assert_called_with_diff(mock_obj, expected_kwargs=None): ) -class TestGenerateSchema(parameterized.TestCase): - @parameterized.named_parameters( - dict( - testcase_name="place_tool_query", - func=place_tool_query, - required=["city", "activity"], - expected_operation={ - "name": "place_tool_query", - "description": ( - "Searches the city for recommendations on the activity." - ), - "parameters": { - "type": "object", - "properties": { - "city": {"type": "string"}, - "activity": {"type": "string", "nullable": True}, - "page_size": {"type": "integer"}, +class TestGenerateSchema: + # pytest does not allow absl.testing.parameterized.named_parameters. + @pytest.mark.parametrize( + "func, required, expected_operation", + [ + ( + place_tool_query, + ["city", "activity"], + { + "name": "place_tool_query", + "description": ( + "Searches the city for recommendations on the activity." 
+ ), + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string"}, + "activity": {"type": "string", "nullable": True}, + "page_size": {"type": "integer"}, + }, + "required": ["city", "activity"], }, - "required": ["city", "activity"], }, - }, - ), - dict( - testcase_name="place_photo_query", - func=place_photo_query, - required=["photo_reference"], - expected_operation={ - "name": "place_photo_query", - "description": "Returns the photo for a given reference.", - "parameters": { - "properties": { - "photo_reference": {"type": "string"}, - "maxwidth": {"type": "integer"}, - "maxheight": {"type": "integer", "nullable": True}, + ), + ( + place_photo_query, + ["photo_reference"], + { + "name": "place_photo_query", + "description": "Returns the photo for a given reference.", + "parameters": { + "properties": { + "photo_reference": {"type": "string"}, + "maxwidth": {"type": "integer"}, + "maxheight": {"type": "integer", "nullable": True}, + }, + "required": ["photo_reference"], + "type": "object", }, - "required": ["photo_reference"], - "type": "object", }, - }, - ), + ), + ], ) def test_generate_schemas(self, func, required, expected_operation): result = _utils.generate_schema(func, required=required) - self.assertDictEqual(result, expected_operation) + assert result == expected_operation -class TestToProto(parameterized.TestCase): - @parameterized.named_parameters( - dict( - testcase_name="empty_dict", - obj={}, - expected_proto=struct_pb2.Struct(fields={}), - ), - dict( - testcase_name="nonempty_dict", - obj={"a": 1, "b": 2}, - expected_proto=struct_pb2.Struct( - fields={ - "a": struct_pb2.Value(number_value=1), - "b": struct_pb2.Value(number_value=2), - }, +class TestToProto: + @pytest.mark.parametrize( + "obj, expected_proto", + [ + ( + {}, + struct_pb2.Struct(fields={}), ), - ), - dict( - testcase_name="empty_proto_message", - obj=struct_pb2.Struct(fields={}), - expected_proto=struct_pb2.Struct(fields={}), - ), - dict( - testcase_name="nonempty_proto_message", - obj=struct_pb2.Struct( - fields={ - "a": struct_pb2.Value(number_value=1), - "b": struct_pb2.Value(number_value=2), - }, + ( + {"a": 1, "b": 2}, + struct_pb2.Struct( + fields={ + "a": struct_pb2.Value(number_value=1), + "b": struct_pb2.Value(number_value=2), + }, + ), ), - expected_proto=struct_pb2.Struct( - fields={ - "a": struct_pb2.Value(number_value=1), - "b": struct_pb2.Value(number_value=2), - }, + ( + struct_pb2.Struct(fields={}), + struct_pb2.Struct(fields={}), ), - ), + ( + struct_pb2.Struct( + fields={ + "a": struct_pb2.Value(number_value=1), + "b": struct_pb2.Value(number_value=2), + }, + ), + struct_pb2.Struct( + fields={ + "a": struct_pb2.Value(number_value=1), + "b": struct_pb2.Value(number_value=2), + }, + ), + ), + ], ) def test_to_proto(self, obj, expected_proto): result = _utils.to_proto(obj) - self.assertDictEqual(_utils.to_dict(result), _utils.to_dict(expected_proto)) - # converting a new object to proto should not modify earlier objects. 
- new_result = _utils.to_proto({}) - self.assertDictEqual(_utils.to_dict(result), _utils.to_dict(expected_proto)) - self.assertEmpty(new_result) + assert _utils.to_dict(result) == _utils.to_dict(expected_proto) # class TestDataclassToDict(parameterized.TestCase): # @parameterized.named_parameters( @@ -2283,82 +2275,86 @@ def test_to_proto(self, obj, expected_proto): # result = _utils.dataclass_to_dict(obj) # self.assertEqual(result, expected_dict) - @parameterized.named_parameters( - dict( - testcase_name="non_dataclass_input", - obj="not a dataclass", - expected_exception=TypeError, - ), - dict( - testcase_name="non_serializable_field", - obj=NonSerializableClass(name="test", date=datetime.datetime.now()), - expected_exception=TypeError, - ), + @pytest.mark.parametrize( + "obj, expected_exception", + [ + ( + "not a dataclass", + TypeError, + ), + ( + NonSerializableClass(name="test", date=datetime.datetime.now()), + TypeError, + ), + ], ) def test_dataclass_to_dict_failure(self, obj, expected_exception): - with self.assertRaises(expected_exception): + with pytest.raises(expected_exception): _utils.dataclass_to_dict(obj) -class ToParsedJsonTest(parameterized.TestCase): - @parameterized.named_parameters( - dict( - testcase_name="valid_json", - obj=httpbody_pb2.HttpBody( - content_type="application/json", data=b'{"a": 1, "b": "hello"}' +class ToParsedJsonTest: + @pytest.mark.parametrize( + "obj, expected", + [ + ( + # "valid_json", + httpbody_pb2.HttpBody( + content_type="application/json", data=b'{"a": 1, "b": "hello"}' + ), + [{"a": 1, "b": "hello"}], ), - expected=[{"a": 1, "b": "hello"}], - ), - dict( - testcase_name="invalid_json", - obj=httpbody_pb2.HttpBody( - content_type="application/json", data=b'{"a": 1, "b": "hello"' + ( + # "invalid_json", + httpbody_pb2.HttpBody( + content_type="application/json", data=b'{"a": 1, "b": "hello"' + ), + ['{"a": 1, "b": "hello"'], # returns the unparsed string ), - expected=['{"a": 1, "b": "hello"'], # returns the unparsed string - ), - dict( - testcase_name="missing_content_type", - obj=httpbody_pb2.HttpBody(data=b'{"a": 1}'), - expected=[httpbody_pb2.HttpBody(data=b'{"a": 1}')], - ), - dict( - testcase_name="missing_data", - obj=httpbody_pb2.HttpBody(content_type="application/json"), - expected=[None], - ), - dict( - testcase_name="wrong_content_type", - obj=httpbody_pb2.HttpBody(content_type="text/plain", data=b"hello"), - expected=[httpbody_pb2.HttpBody(content_type="text/plain", data=b"hello")], - ), - dict( - testcase_name="empty_data", - obj=httpbody_pb2.HttpBody(content_type="application/json", data=b""), - expected=[None], - ), - dict( - testcase_name="unicode_data", - obj=httpbody_pb2.HttpBody( - content_type="application/json", data='{"a": "你好"}'.encode("utf-8") + ( + # "missing_content_type", + httpbody_pb2.HttpBody(data=b'{"a": 1}'), + [httpbody_pb2.HttpBody(data=b'{"a": 1}')], ), - expected=[{"a": "你好"}], - ), - dict( - testcase_name="nested_json", - obj=httpbody_pb2.HttpBody( - content_type="application/json", data=b'{"a": {"b": 1}}' + ( + # "missing_data", + httpbody_pb2.HttpBody(content_type="application/json"), + [None], ), - expected=[{"a": {"b": 1}}], - ), - dict( - testcase_name="multiline_json", - obj=httpbody_pb2.HttpBody( - content_type="application/json", - data=b'{"a": {"b": 1}}\n{"a": {"b": 2}}', + ( + # "wrong_content_type", + httpbody_pb2.HttpBody(content_type="text/plain", data=b"hello"), + [httpbody_pb2.HttpBody(content_type="text/plain", data=b"hello")], ), - expected=[{"a": {"b": 1}}, {"a": {"b": 2}}], - ), + ( + # 
"empty_data", + httpbody_pb2.HttpBody(content_type="application/json", data=b""), + [None], + ), + ( + # "unicode_data", + httpbody_pb2.HttpBody( + content_type="application/json", data='{"a": "你好"}'.encode("utf-8") + ), + [{"a": "你好"}], + ), + ( + # "nested_json", + httpbody_pb2.HttpBody( + content_type="application/json", data=b'{"a": {"b": 1}}' + ), + [{"a": {"b": 1}}], + ), + ( + # "multiline_json", + httpbody_pb2.HttpBody( + content_type="application/json", + data=b'{"a": {"b": 1}}\n{"a": {"b": 2}}', + ), + [{"a": {"b": 1}}, {"a": {"b": 2}}], + ), + ], ) def test_to_parsed_json(self, obj, expected): for got, want in zip(_utils.yield_parsed_json(obj), expected): - self.assertEqual(got, want) + assert got == want diff --git a/tests/unit/vertexai/genai/replays/conftest.py b/tests/unit/vertexai/genai/replays/conftest.py new file mode 100644 index 0000000000..e41405e3cb --- /dev/null +++ b/tests/unit/vertexai/genai/replays/conftest.py @@ -0,0 +1,132 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +"""Conftest for Vertex SDK GenAI tests.""" + +import os +from unittest import mock + +from vertexai._genai import ( + client as vertexai_genai_client_module, +) +from google.genai import _replay_api_client +from google.genai import client as google_genai_client_module +import pytest + + +def pytest_addoption(parser): + parser.addoption( + "--mode", + action="store", + default="auto", + help="""Replay mode. + One of: + * auto: Replay if replay files exist, otherwise record. + * record: Always call the API and record. + * replay: Always replay, fail if replay files do not exist. + * api: Always call the API and do not record. + * tap: Always replay, fail if replay files do not exist. Also sets default values for the API key and replay directory. + """, + ) + parser.addoption( + "--replays-directory-prefix", + action="store", + default=None, + help="""Directory to use for replays. + If not set, the default directory will be used. + """, + ) + + +@pytest.fixture +def use_vertex(): + return True + + +# Overridden at the module level for each test file. +@pytest.fixture +def replays_prefix(): + return "test" + + +def _get_replay_id(use_vertex: bool, replays_prefix: str) -> str: + test_name_ending = os.environ.get("PYTEST_CURRENT_TEST").split("::")[-1] + test_name = test_name_ending.split(" ")[0].split("[")[0] + "." + "vertex" + return "/".join([replays_prefix, test_name]) + + +@pytest.fixture +def client(use_vertex, replays_prefix, http_options, request): + + mode = request.config.getoption("--mode") + replays_directory_prefix = request.config.getoption("--replays-directory-prefix") + if mode not in ["auto", "record", "replay", "api", "tap"]: + raise ValueError("Invalid mode: " + mode) + test_function_name = request.function.__name__ + test_filename = os.path.splitext(os.path.basename(request.path))[0] + if test_function_name.startswith(test_filename): + raise ValueError( + f""" + {test_function_name}: + Do not include the test filename in the test function name. 
+            Keep the test function name short."""
+        )
+
+    replay_id = _get_replay_id(use_vertex, replays_prefix)
+
+    if mode == "tap":
+        mode = "replay"
+
+    # Set various environment variables to ensure that the test runs.
+    os.environ["GOOGLE_API_KEY"] = "dummy-api-key"
+    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.join(
+        os.path.dirname(__file__),
+        "credentials.json",
+    )
+    os.environ["GOOGLE_CLOUD_PROJECT"] = "project-id"
+    os.environ["GOOGLE_CLOUD_LOCATION"] = "location"
+
+    # Set the replay directory to the root directory of the replays.
+    # This is needed to ensure that the replay files are found.
+    replays_root_directory = os.path.abspath(
+        os.path.join(
+            os.path.dirname(__file__),
+            "../../../../../../../../../google/cloud/aiplatform/sdk/genai/replays",
+        )
+    )
+    os.environ["GOOGLE_GENAI_REPLAYS_DIRECTORY"] = replays_root_directory
+    replay_client = _replay_api_client.ReplayApiClient(
+        mode=mode,
+        replay_id=replay_id,
+        vertexai=use_vertex,
+        http_options=http_options,
+    )
+
+    # Only override the replays directory when a prefix was supplied;
+    # otherwise the default directory is used, per the
+    # --replays-directory-prefix help text.
+    if replays_directory_prefix:
+        replay_client.replays_directory = (
+            f"{replays_directory_prefix}/google/cloud/aiplatform/sdk/replays/"
+        )
+
+    with mock.patch.object(
+        google_genai_client_module.Client, "_get_api_client"
+    ) as patch_method:
+        patch_method.return_value = replay_client
+        google_genai_client = vertexai_genai_client_module.Client()
+
+    # Yield the client so that cleanup can be completed at the end of the test.
+    yield google_genai_client
+
+    # Save the replay after the test if we're in recording mode.
+    replay_client.close()
diff --git a/tests/unit/vertexai/genai/replays/pytest_helper.py b/tests/unit/vertexai/genai/replays/pytest_helper.py
new file mode 100644
index 0000000000..ffe4c0c4c3
--- /dev/null
+++ b/tests/unit/vertexai/genai/replays/pytest_helper.py
@@ -0,0 +1,49 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+from typing import Any, Optional
+
+from google.genai._api_client import HttpOptions
+import pytest
+
+
+is_api_mode = "config.getoption('--mode') == 'api'"
+
+
+# Sets up the test framework.
+# file: Always use __file__
+# globals_for_file: Always use globals()
+def setup(
+    *,
+    file: str,
+    globals_for_file: Optional[dict[str, Any]] = None,
+    test_method: Optional[str] = None,
+    http_options: Optional[HttpOptions] = None,
+):
+    """Generates parameterization for tests."""
+    replays_directory = (
+        file.replace(os.path.dirname(__file__), "tests/vertex_sdk_genai_replays")
+        .replace(".py", "")
+        .replace("/test_", "/")
+    )
+
+    # Add fixture for requested client option.
+    return pytest.mark.parametrize(
+        "use_vertex, replays_prefix, http_options",
+        [
+            (True, replays_directory, http_options),
+        ],
+    )
diff --git a/tests/unit/vertexai/genai/replays/test_evaluate_instances.py b/tests/unit/vertexai/genai/replays/test_evaluate_instances.py
new file mode 100644
index 0000000000..cb67dd0028
--- /dev/null
+++ b/tests/unit/vertexai/genai/replays/test_evaluate_instances.py
@@ -0,0 +1,49 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# pylint: disable=protected-access,bad-continuation,missing-function-docstring
+
+import os
+
+from tests.unit.vertexai.genai.replays import pytest_helper
+from vertexai._genai import types
+import pytest
+
+
+IS_KOKORO = os.getenv("KOKORO_BUILD_NUMBER") is not None
+
+
+@pytest.mark.skipif(IS_KOKORO, reason="This test is only run in google3 env.")
+class TestEvaluateInstances:
+    """Tests for evaluate instances."""
+
+    def test_bleu_metric(self, client):
+        test_bleu_input = types.BleuInput(
+            instances=[
+                types.BleuInstance(
+                    reference="The quick brown fox jumps over the lazy dog.",
+                    prediction="A fast brown fox leaps over a lazy dog.",
+                )
+            ],
+            metric_spec=types.BleuSpec(),
+        )
+        response = client.evals._evaluate_instances(bleu_input=test_bleu_input)
+        assert len(response.bleu_results.bleu_metric_values) == 1
+
+
+pytestmark = pytest_helper.setup(
+    file=__file__,
+    globals_for_file=globals(),
+    test_method="evals.evaluate",
+)
diff --git a/tests/unit/vertexai/genai/run_replay_tests.sh b/tests/unit/vertexai/genai/run_replay_tests.sh
new file mode 100755
index 0000000000..5ee3cef025
--- /dev/null
+++ b/tests/unit/vertexai/genai/run_replay_tests.sh
@@ -0,0 +1,149 @@
+#!/bin/bash
+
+# This script runs replay tests for the Vertex SDK GenAI client.
+# It is intended to be used from the google3 directory of a CitC client.
+# You can provide a specific test file to run, or it will run all the replay tests
+# in third_party/py/google/cloud/aiplatform/tests/unit/vertexai/genai/replays/
+#
+# Example:
+# ./third_party/py/google/cloud/aiplatform/tests/unit/vertexai/genai/run_replay_tests.sh test_evals.py
+
+# It also supports a --mode flag, which can be one of:
+# * record: Always call the API and record the result in a replay file.
+# * replay: Always use the recorded replay files to simulate API calls; fail if a replay file does not exist.
+# * api: Always call the API and do not record.
+
+# Get the current working directory
+START_DIR=$(pwd)
+
+# Check if the current directory ends with /google3
+# Otherwise the copybara command will fail
+if [[ "$START_DIR" != */google3 ]]; then
+  echo "Error: This script must be run from your client's '/google3' directory."
+ echo "Your current directory is: $START_DIR" + exit 1 +fi + +# Check required env vars have been set +REQUIRED_ENV_VARS=( + "GOOGLE_CLOUD_PROJECT" + "GOOGLE_CLOUD_LOCATION" + "GOOGLE_GENAI_REPLAYS_DIRECTORY" +) + +for var_name in "${REQUIRED_ENV_VARS[@]}"; do + var_value="${!var_name}" + if [ -z "$var_value" ]; then + echo "Error: Environment variable $var_name is not set." + echo "Please set it before running this script." + exit 1 + fi +done + +# Generate a unique temporary directory in /tmp/ +TEMP_DIR=$(mktemp -d -t XXXXXX) + +# Check if the temporary directory was created successfully +if [ -z "$TEMP_DIR" ]; then + echo "Error: Could not create a temporary directory." + exit 1 +fi + +echo "Created temporary directory: $TEMP_DIR" + +# Run copybara and copy Vertex SDK to the temporary directory +# The --folder-dir argument is set to the newly created temporary directory. +echo "Running copybara..." +COPYBARA_EXEC="/google/bin/releases/copybara/public/copybara/copybara" +"$COPYBARA_EXEC" third_party/py/google/cloud/aiplatform/copy.bara.sky folder_to_folder .. --folder-dir="$TEMP_DIR" --ignore-noop + +# Check copybara's exit status +if [ $? -ne 0 ]; then + echo "Error: copybara command failed. Exiting." + # Clean up the temporary directory on failure + rm -rf "$TEMP_DIR" + exit 1 +fi + +echo "Copybara completed successfully." + +# Change into the temporary directory with copybara output +echo "Changing into temp directory: $TEMP_DIR" +cd "$TEMP_DIR" + +if [ $? -ne 0 ]; then + echo "Error: Could not change into directory $TEMP_DIR. Exiting." + exit 1 +fi + +PARSED_ARGS=$(getopt -o "" -l "mode:" -- "$@") + +if [ $? -ne 0 ]; then + echo "Error: Failed to parse command line arguments." >&2 + exit 1 +fi + +# Get the test file path and --mode flag value if provided +eval set -- "$PARSED_ARGS" + +TEST_FILE_ARG="" # Stores the provided test path, if any +MODE_VALUE="" # Stores the value of the --mode flag (e.g., 'replay') + +while true; do + case "$1" in + --mode) + MODE_VALUE="$2" + shift 2 + ;; + --) + shift + break + ;; + *) + echo "Internal error: Unrecognized arg option: '$1'" >&2 + exit 1 + ;; + esac +done + +# We expect at most one positional argument (the test file path). +if [ -n "$1" ]; then + TEST_FILE_ARG="$1" + if [ "$#" -gt 1 ]; then + echo "Warning: Ignoring extra positional arguments after '$TEST_FILE_ARG'. Only one test file/path can be specified." >&2 + fi +fi + +# Construct the full --mode argument string to be passed to pytest. +MODE_ARG="" +if [ -n "$MODE_VALUE" ]; then + MODE_ARG="--mode $MODE_VALUE" +fi + + +# Set pytest path for which tests to run +DEFAULT_TEST_PATH="tests/unit/vertexai/genai/replays/" + +if [ -n "$TEST_FILE_ARG" ]; then + PYTEST_PATH="${DEFAULT_TEST_PATH}${TEST_FILE_ARG}" + echo "Provided test file path: '$TEST_FILE_ARG'. Running pytest on: ${PYTEST_PATH}" +else + PYTEST_PATH="$DEFAULT_TEST_PATH" + echo "No test file arg provided. Running pytest on default path: ${PYTEST_PATH}" +fi + +# Run tests +# -s is equivalent to --capture=no, it ensures pytest doesn't capture the output from stdout and stderr +# so it can be logged when this script is run +pytest -v -s "$PYTEST_PATH" ${MODE_ARG} --replays-directory-prefix="$START_DIR" + +PYTEST_EXIT_CODE=$? + +echo "Cleaning up temporary directory: $TEMP_DIR" +# Go back to the original directory before removing the temporary directory +cd - > /dev/null +rm -rf "$TEMP_DIR" + +echo "Pytest tests completed with exit code: $PYTEST_EXIT_CODE." 
+ +exit $PYTEST_EXIT_CODE \ No newline at end of file diff --git a/tests/unit/vertexai/genai/test_agent_engines.py b/tests/unit/vertexai/genai/test_agent_engines.py new file mode 100644 index 0000000000..11a383a72e --- /dev/null +++ b/tests/unit/vertexai/genai/test_agent_engines.py @@ -0,0 +1,1712 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import asyncio +import importlib +import json +import os +import pytest +import sys +import tempfile +from typing import Any, AsyncIterable, Dict, Iterable, List +from unittest import mock +from urllib.parse import urlencode + +from google import auth +from google.auth import credentials as auth_credentials +from google.cloud import aiplatform +import vertexai +from google.cloud.aiplatform import initializer +from vertexai._genai import agent_engines +from vertexai._genai import types as _genai_types +from vertexai.agent_engines import _agent_engines +from vertexai.agent_engines import _utils +from google.genai import client as genai_client +from google.genai import types as genai_types + + +_TEST_AGENT_FRAMEWORK = "test-agent-framework" + + +class CapitalizeEngine: + """A sample Agent Engine.""" + + def set_up(self): + pass + + def query(self, unused_arbitrary_string_name: str) -> str: + """Runs the engine.""" + return unused_arbitrary_string_name.upper() + + def clone(self): + return self + + +class AsyncQueryEngine: + """A sample Agent Engine that implements `async_query`.""" + + def set_up(self): + pass + + async def async_query(self, unused_arbitrary_string_name: str): + """Runs the query asynchronously.""" + return unused_arbitrary_string_name.upper() + + def clone(self): + return self + + +class AsyncStreamQueryEngine: + """A sample Agent Engine that implements `async_stream_query`.""" + + def set_up(self): + pass + + async def async_stream_query( + self, unused_arbitrary_string_name: str + ) -> AsyncIterable[Any]: + """Runs the async stream engine.""" + for chunk in _TEST_AGENT_ENGINE_STREAM_QUERY_RESPONSE: + yield chunk + + def clone(self): + return self + + +class StreamQueryEngine: + """A sample Agent Engine that implements `stream_query`.""" + + def set_up(self): + pass + + def stream_query(self, unused_arbitrary_string_name: str) -> Iterable[Any]: + """Runs the stream engine.""" + for chunk in _TEST_AGENT_ENGINE_STREAM_QUERY_RESPONSE: + yield chunk + + def clone(self): + return self + + +class OperationRegistrableEngine: + """Add a test class that implements OperationRegistrable.""" + + agent_framework = _TEST_AGENT_FRAMEWORK + + def query(self, unused_arbitrary_string_name: str) -> str: + """Runs the engine.""" + return unused_arbitrary_string_name.upper() + + async def async_query(self, unused_arbitrary_string_name: str) -> str: + """Runs the query asynchronously.""" + return unused_arbitrary_string_name.upper() + + # Add a custom method to test the custom method registration. 
+ def custom_method(self, x: str) -> str: + return x.upper() + + # Add a custom async method to test the custom async method registration. + async def custom_async_method(self, x: str): + return x.upper() + + def stream_query(self, unused_arbitrary_string_name: str) -> Iterable[Any]: + """Runs the stream engine.""" + for chunk in _TEST_AGENT_ENGINE_STREAM_QUERY_RESPONSE: + yield chunk + + async def async_stream_query( + self, unused_arbitrary_string_name: str + ) -> AsyncIterable[Any]: + """Runs the async stream engine.""" + for chunk in _TEST_AGENT_ENGINE_STREAM_QUERY_RESPONSE: + yield chunk + + # Add a custom method to test the custom stream method registration. + def custom_stream_query(self, unused_arbitrary_string_name: str) -> Iterable[Any]: + """Runs the stream engine.""" + for chunk in _TEST_AGENT_ENGINE_STREAM_QUERY_RESPONSE: + yield chunk + + # Add a custom method to test the custom stream method registration. + def custom_stream_method(self, unused_arbitrary_string_name: str) -> Iterable[Any]: + for chunk in _TEST_AGENT_ENGINE_STREAM_QUERY_RESPONSE: + yield chunk + + async def custom_async_stream_method( + self, unused_arbitrary_string_name: str + ) -> AsyncIterable[Any]: + for chunk in _TEST_AGENT_ENGINE_STREAM_QUERY_RESPONSE: + yield chunk + + def clone(self): + return self + + def register_operations(self) -> Dict[str, List[str]]: + return { + _TEST_STANDARD_API_MODE: [ + _TEST_DEFAULT_METHOD_NAME, + _TEST_CUSTOM_METHOD_NAME, + ], + _TEST_ASYNC_API_MODE: [ + _TEST_DEFAULT_ASYNC_METHOD_NAME, + _TEST_CUSTOM_ASYNC_METHOD_NAME, + ], + _TEST_STREAM_API_MODE: [ + _TEST_DEFAULT_STREAM_METHOD_NAME, + _TEST_CUSTOM_STREAM_METHOD_NAME, + ], + _TEST_ASYNC_STREAM_API_MODE: [ + _TEST_DEFAULT_ASYNC_STREAM_METHOD_NAME, + _TEST_CUSTOM_ASYNC_STREAM_METHOD_NAME, + ], + } + + +class SameRegisteredOperationsEngine: + """Add a test class that is different from `OperationRegistrableEngine` but has the same registered operations.""" + + def query(self, unused_arbitrary_string_name: str) -> str: + """Runs the engine.""" + return unused_arbitrary_string_name.upper() + + async def async_query(self, unused_arbitrary_string_name: str) -> str: + """Runs the query asynchronously.""" + return unused_arbitrary_string_name.upper() + + # Add a custom method to test the custom method registration + def custom_method(self, x: str) -> str: + return x.upper() + + # Add a custom method that is not registered. + def custom_method_2(self, x: str) -> str: + return x.upper() + + # Add a custom async method to test the custom async method registration. + async def custom_async_method(self, x: str): + return x.upper() + + def stream_query(self, unused_arbitrary_string_name: str) -> Iterable[Any]: + """Runs the stream engine.""" + for chunk in _TEST_AGENT_ENGINE_STREAM_QUERY_RESPONSE: + yield chunk + + async def async_stream_query( + self, unused_arbitrary_string_name: str + ) -> AsyncIterable[Any]: + """Runs the async stream engine.""" + for chunk in _TEST_AGENT_ENGINE_STREAM_QUERY_RESPONSE: + yield chunk + + # Add a custom method to test the custom stream method registration. 
+    def custom_stream_method(self, unused_arbitrary_string_name: str) -> Iterable[Any]:
+        for chunk in _TEST_AGENT_ENGINE_STREAM_QUERY_RESPONSE:
+            yield chunk
+
+    async def custom_async_stream_method(
+        self, unused_arbitrary_string_name: str
+    ) -> AsyncIterable[Any]:
+        for chunk in _TEST_AGENT_ENGINE_STREAM_QUERY_RESPONSE:
+            yield chunk
+
+    def clone(self):
+        return self
+
+    def register_operations(self) -> Dict[str, List[str]]:
+        return {
+            _TEST_STANDARD_API_MODE: [
+                _TEST_DEFAULT_METHOD_NAME,
+                _TEST_CUSTOM_METHOD_NAME,
+            ],
+            _TEST_ASYNC_API_MODE: [
+                _TEST_DEFAULT_ASYNC_METHOD_NAME,
+                _TEST_CUSTOM_ASYNC_METHOD_NAME,
+            ],
+            _TEST_STREAM_API_MODE: [
+                _TEST_DEFAULT_STREAM_METHOD_NAME,
+                _TEST_CUSTOM_STREAM_METHOD_NAME,
+            ],
+            _TEST_ASYNC_STREAM_API_MODE: [
+                _TEST_DEFAULT_ASYNC_STREAM_METHOD_NAME,
+                _TEST_CUSTOM_ASYNC_STREAM_METHOD_NAME,
+            ],
+        }
+
+
+class OperationNotRegisteredEngine:
+    """Add a test class that has a method that is not registered."""
+
+    def query(self, unused_arbitrary_string_name: str) -> str:
+        """Runs the engine."""
+        return unused_arbitrary_string_name.upper()
+
+    def custom_method(self, x: str) -> str:
+        return x.upper()
+
+    def clone(self):
+        return self
+
+    def register_operations(self) -> Dict[str, List[str]]:
+        # `query` method is not exported in registered operations.
+        return {
+            _TEST_STANDARD_API_MODE: [
+                _TEST_CUSTOM_METHOD_NAME,
+            ]
+        }
+
+
+class RegisteredOperationNotExistEngine:
+    """Add a test class that has a method that is registered but does not exist."""
+
+    def query(self, unused_arbitrary_string_name: str) -> str:
+        """Runs the engine."""
+        return unused_arbitrary_string_name.upper()
+
+    def custom_method(self, x: str) -> str:
+        return x.upper()
+
+    def clone(self):
+        return self
+
+    def register_operations(self) -> Dict[str, List[str]]:
+        # Registered method `missing_method` is not a method of the AgentEngine.
+        return {
+            _TEST_STANDARD_API_MODE: [
+                _TEST_DEFAULT_METHOD_NAME,
+                _TEST_CUSTOM_METHOD_NAME,
+                "missing_method",
+            ]
+        }
+
+
+class MethodToBeUnregisteredEngine:
+    """An Agent Engine that has a method to be unregistered."""
+
+    def method_to_be_unregistered(self, unused_arbitrary_string_name: str) -> str:
+        """Method to be unregistered."""
+        return unused_arbitrary_string_name.upper()
+
+    def register_operations(self) -> Dict[str, List[str]]:
+        # Register the method that the tests will later unregister.
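+        # `register_operations` maps each API mode to the names of the methods
+        # exposed under that mode; here only the standard ("") mode is used.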
+ return {_TEST_STANDARD_API_MODE: [_TEST_METHOD_TO_BE_UNREGISTERED_NAME]} + + +_TEST_CREDENTIALS = mock.Mock(spec=auth_credentials.AnonymousCredentials()) +_TEST_STAGING_BUCKET = "gs://test-bucket" +_TEST_LOCATION = "us-central1" +_TEST_PROJECT = "test-project" +_TEST_RESOURCE_ID = "1028944691210842416" +_TEST_OPERATION_ID = "4589432830794137600" +_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" +_TEST_AGENT_ENGINE_RESOURCE_NAME = ( + f"{_TEST_PARENT}/reasoningEngines/{_TEST_RESOURCE_ID}" +) +_TEST_AGENT_ENGINE_OPERATION_NAME = f"{_TEST_PARENT}/operations/{_TEST_OPERATION_ID}" +_TEST_AGENT_ENGINE_DISPLAY_NAME = "Agent Engine Display Name" +_TEST_AGENT_ENGINE_DESCRIPTION = "Agent Engine Description" +_TEST_AGENT_ENGINE_LIST_FILTER = f'display_name="{_TEST_AGENT_ENGINE_DISPLAY_NAME}"' +_TEST_GCS_DIR_NAME = _agent_engines._DEFAULT_GCS_DIR_NAME +_TEST_BLOB_FILENAME = _agent_engines._BLOB_FILENAME +_TEST_REQUIREMENTS_FILE = _agent_engines._REQUIREMENTS_FILE +_TEST_EXTRA_PACKAGES_FILE = _agent_engines._EXTRA_PACKAGES_FILE +_TEST_STANDARD_API_MODE = _agent_engines._STANDARD_API_MODE +_TEST_ASYNC_API_MODE = _agent_engines._ASYNC_API_MODE +_TEST_STREAM_API_MODE = _agent_engines._STREAM_API_MODE +_TEST_ASYNC_STREAM_API_MODE = _agent_engines._ASYNC_STREAM_API_MODE +_TEST_DEFAULT_METHOD_NAME = _agent_engines._DEFAULT_METHOD_NAME +_TEST_DEFAULT_ASYNC_METHOD_NAME = _agent_engines._DEFAULT_ASYNC_METHOD_NAME +_TEST_DEFAULT_STREAM_METHOD_NAME = _agent_engines._DEFAULT_STREAM_METHOD_NAME +_TEST_DEFAULT_ASYNC_STREAM_METHOD_NAME = ( + _agent_engines._DEFAULT_ASYNC_STREAM_METHOD_NAME +) +_TEST_CAPITALIZE_ENGINE_METHOD_DOCSTRING = "Runs the engine." +_TEST_STREAM_METHOD_DOCSTRING = "Runs the stream engine." +_TEST_ASYNC_STREAM_METHOD_DOCSTRING = "Runs the async stream engine." +_TEST_MODE_KEY_IN_SCHEMA = _agent_engines._MODE_KEY_IN_SCHEMA +_TEST_METHOD_NAME_KEY_IN_SCHEMA = _agent_engines._METHOD_NAME_KEY_IN_SCHEMA +_TEST_CUSTOM_METHOD_NAME = "custom_method" +_TEST_CUSTOM_ASYNC_METHOD_NAME = "custom_async_method" +_TEST_CUSTOM_STREAM_METHOD_NAME = "custom_stream_method" +_TEST_CUSTOM_ASYNC_STREAM_METHOD_NAME = "custom_async_stream_method" +_TEST_CUSTOM_METHOD_DEFAULT_DOCSTRING = """ + Runs the Agent Engine to serve the user request. + + This will be based on the `.custom_method(...)` of the python object that + was passed in when creating the Agent Engine. The method will invoke the + `query` API client of the python object. + + Args: + **kwargs: + Optional. The arguments of the `.custom_method(...)` method. + + Returns: + dict[str, Any]: The response from serving the user request. +""" +_TEST_CUSTOM_ASYNC_METHOD_DEFAULT_DOCSTRING = """ + Runs the Agent Engine to serve the user request. + + This will be based on the `.custom_async_method(...)` of the python object that + was passed in when creating the Agent Engine. The method will invoke the + `async_query` API client of the python object. + + Args: + **kwargs: + Optional. The arguments of the `.custom_async_method(...)` method. + + Returns: + Coroutine[Any]: The response from serving the user request. +""" +_TEST_CUSTOM_STREAM_METHOD_DEFAULT_DOCSTRING = """ + Runs the Agent Engine to serve the user request. + + This will be based on the `.custom_stream_method(...)` of the python object that + was passed in when creating the Agent Engine. The method will invoke the + `stream_query` API client of the python object. + + Args: + **kwargs: + Optional. The arguments of the `.custom_stream_method(...)` method. 
+ + Returns: + Iterable[Any]: The response from serving the user request. +""" +_TEST_CUSTOM_ASYNC_STREAM_METHOD_DEFAULT_DOCSTRING = """ + Runs the Agent Engine to serve the user request. + + This will be based on the `.custom_async_stream_method(...)` of the python object that + was passed in when creating the Agent Engine. The method will invoke the + `async_stream_query` API client of the python object. + + Args: + **kwargs: + Optional. The arguments of the `.custom_async_stream_method(...)` method. + + Returns: + AsyncIterable[Any]: The response from serving the user request. +""" +_TEST_METHOD_TO_BE_UNREGISTERED_NAME = "method_to_be_unregistered" +_TEST_QUERY_PROMPT = "Find the first fibonacci number greater than 999" +_TEST_AGENT_ENGINE_ENV_KEY = "GOOGLE_CLOUD_AGENT_ENGINE_ENV" +_TEST_AGENT_ENGINE_ENV_VALUE = "test_env_value" +_TEST_AGENT_ENGINE_GCS_URI = "{}/{}/{}".format( + _TEST_STAGING_BUCKET, + _TEST_GCS_DIR_NAME, + _TEST_BLOB_FILENAME, +) +_TEST_AGENT_ENGINE_DEPENDENCY_FILES_GCS_URI = "{}/{}/{}".format( + _TEST_STAGING_BUCKET, + _TEST_GCS_DIR_NAME, + _TEST_EXTRA_PACKAGES_FILE, +) +_TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI = "{}/{}/{}".format( + _TEST_STAGING_BUCKET, + _TEST_GCS_DIR_NAME, + _TEST_REQUIREMENTS_FILE, +) +_TEST_AGENT_ENGINE_REQUIREMENTS = [ + "google-cloud-aiplatform==1.29.0", + "langchain", +] +_TEST_AGENT_ENGINE_INVALID_EXTRA_PACKAGES = [ + "lib", + "main.py", +] +_TEST_AGENT_ENGINE_QUERY_SCHEMA = _utils.generate_schema( + CapitalizeEngine().query, + schema_name=_TEST_DEFAULT_METHOD_NAME, +) +_TEST_AGENT_ENGINE_QUERY_SCHEMA[_TEST_MODE_KEY_IN_SCHEMA] = _TEST_STANDARD_API_MODE +_TEST_PYTHON_VERSION = f"{sys.version_info.major}.{sys.version_info.minor}" +_TEST_AGENT_ENGINE_FRAMEWORK = _agent_engines._DEFAULT_AGENT_FRAMEWORK +_TEST_AGENT_ENGINE_CLASS_METHOD_1 = { + "description": "Runs the engine.", + "name": "query", + "parameters": { + "type": "object", + "properties": { + "unused_arbitrary_string_name": {"type": "string"}, + }, + "required": ["unused_arbitrary_string_name"], + }, + "api_mode": "", +} +_TEST_AGENT_ENGINE_CLASS_METHOD_ASYNC_QUERY = { + "description": "Runs the engine.", + "name": "async_query", + "parameters": { + "type": "object", + "properties": { + "unused_arbitrary_string_name": {"type": "string"}, + }, + "required": ["unused_arbitrary_string_name"], + }, + "api_mode": "async", +} +_TEST_AGENT_ENGINE_CLASS_METHOD_STREAM_QUERY = { + "description": "Runs the engine.", + "name": "stream_query", + "parameters": { + "type": "object", + "properties": { + "unused_arbitrary_string_name": {"type": "string"}, + }, + "required": ["unused_arbitrary_string_name"], + }, + "api_mode": "stream", +} +_TEST_AGENT_ENGINE_CLASS_METHOD_ASYNC_STREAM_QUERY = { + "description": "Runs the engine.", + "name": "async_stream_query", + "parameters": { + "type": "object", + "properties": { + "unused_arbitrary_string_name": {"type": "string"}, + }, + "required": ["unused_arbitrary_string_name"], + }, + "api_mode": "async_stream", +} +_TEST_AGENT_ENGINE_ENV_VARS_INPUT = { + "TEST_ENV_VAR": "TEST_ENV_VAR_VALUE", + "TEST_ENV_VAR_2": "TEST_ENV_VAR_VALUE_2", + "TEST_SECRET_ENV_VAR": { + "secret": "TEST_SECRET_NAME_1", + "version": "TEST_SECRET_VERSION_1", + }, +} +_TEST_AGENT_ENGINE_SPEC = _genai_types.ReasoningEngineSpecDict( + agent_framework=_TEST_AGENT_ENGINE_FRAMEWORK, + class_methods=[_TEST_AGENT_ENGINE_CLASS_METHOD_1], + deployment_spec={ + "env": [ + {"name": "TEST_ENV_VAR", "value": "TEST_ENV_VAR_VALUE"}, + {"name": "TEST_ENV_VAR_2", "value": "TEST_ENV_VAR_VALUE_2"}, + ], + 
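+        # Mirrors _TEST_AGENT_ENGINE_ENV_VARS_INPUT above: string-valued entries
+        # become plain `env` entries, while dict-valued entries (a secret name plus
+        # version) are expected to surface as `secret_env` references below.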
"secret_env": [ + { + "name": "TEST_SECRET_ENV_VAR", + "secret_ref": { + "secret": "TEST_SECRET_NAME_1", + "version": "TEST_SECRET_VERSION_1", + }, + }, + ], + }, + package_spec=_genai_types.ReasoningEngineSpecPackageSpecDict( + python_version=f"{sys.version_info.major}.{sys.version_info.minor}", + pickle_object_gcs_uri=_TEST_AGENT_ENGINE_GCS_URI, + dependency_files_gcs_uri=_TEST_AGENT_ENGINE_DEPENDENCY_FILES_GCS_URI, + requirements_gcs_uri=_TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI, + ), +) +_TEST_AGENT_ENGINE_STREAM_QUERY_RESPONSE = [{"output": "hello"}, {"output": "world"}] +_TEST_AGENT_ENGINE_OPERATION_SCHEMAS = [] +_TEST_AGENT_ENGINE_EXTRA_PACKAGE = "fake.py" +_TEST_AGENT_ENGINE_ASYNC_METHOD_SCHEMA = _utils.generate_schema( + AsyncQueryEngine().async_query, + schema_name=_TEST_DEFAULT_ASYNC_METHOD_NAME, +) +_TEST_AGENT_ENGINE_ASYNC_METHOD_SCHEMA[_TEST_MODE_KEY_IN_SCHEMA] = _TEST_ASYNC_API_MODE +_TEST_AGENT_ENGINE_CUSTOM_METHOD_SCHEMA = _utils.generate_schema( + OperationRegistrableEngine().custom_method, + schema_name=_TEST_CUSTOM_METHOD_NAME, +) +_TEST_AGENT_ENGINE_CUSTOM_METHOD_SCHEMA[ + _TEST_MODE_KEY_IN_SCHEMA +] = _TEST_STANDARD_API_MODE +_TEST_AGENT_ENGINE_ASYNC_CUSTOM_METHOD_SCHEMA = _utils.generate_schema( + OperationRegistrableEngine().custom_async_method, + schema_name=_TEST_CUSTOM_ASYNC_METHOD_NAME, +) +_TEST_AGENT_ENGINE_ASYNC_CUSTOM_METHOD_SCHEMA[ + _TEST_MODE_KEY_IN_SCHEMA +] = _TEST_ASYNC_API_MODE +_TEST_AGENT_ENGINE_STREAM_QUERY_SCHEMA = _utils.generate_schema( + StreamQueryEngine().stream_query, + schema_name=_TEST_DEFAULT_STREAM_METHOD_NAME, +) +_TEST_AGENT_ENGINE_STREAM_QUERY_SCHEMA[_TEST_MODE_KEY_IN_SCHEMA] = _TEST_STREAM_API_MODE +_TEST_AGENT_ENGINE_CUSTOM_STREAM_QUERY_SCHEMA = _utils.generate_schema( + OperationRegistrableEngine().custom_stream_method, + schema_name=_TEST_CUSTOM_STREAM_METHOD_NAME, +) +_TEST_AGENT_ENGINE_CUSTOM_STREAM_QUERY_SCHEMA[ + _TEST_MODE_KEY_IN_SCHEMA +] = _TEST_STREAM_API_MODE +_TEST_AGENT_ENGINE_ASYNC_STREAM_QUERY_SCHEMA = _utils.generate_schema( + AsyncStreamQueryEngine().async_stream_query, + schema_name=_TEST_DEFAULT_ASYNC_STREAM_METHOD_NAME, +) +_TEST_AGENT_ENGINE_ASYNC_STREAM_QUERY_SCHEMA[ + _TEST_MODE_KEY_IN_SCHEMA +] = _TEST_ASYNC_STREAM_API_MODE +_TEST_AGENT_ENGINE_CUSTOM_ASYNC_STREAM_QUERY_SCHEMA = _utils.generate_schema( + OperationRegistrableEngine().custom_async_stream_method, + schema_name=_TEST_CUSTOM_ASYNC_STREAM_METHOD_NAME, +) +_TEST_AGENT_ENGINE_CUSTOM_ASYNC_STREAM_QUERY_SCHEMA[ + _TEST_MODE_KEY_IN_SCHEMA +] = _TEST_ASYNC_STREAM_API_MODE +_TEST_OPERATION_REGISTRABLE_SCHEMAS = [ + _TEST_AGENT_ENGINE_QUERY_SCHEMA, + _TEST_AGENT_ENGINE_CUSTOM_METHOD_SCHEMA, + _TEST_AGENT_ENGINE_ASYNC_METHOD_SCHEMA, + _TEST_AGENT_ENGINE_ASYNC_CUSTOM_METHOD_SCHEMA, + _TEST_AGENT_ENGINE_STREAM_QUERY_SCHEMA, + _TEST_AGENT_ENGINE_CUSTOM_STREAM_QUERY_SCHEMA, + _TEST_AGENT_ENGINE_ASYNC_STREAM_QUERY_SCHEMA, + _TEST_AGENT_ENGINE_CUSTOM_ASYNC_STREAM_QUERY_SCHEMA, +] +_TEST_OPERATION_NOT_REGISTERED_SCHEMAS = [ + _TEST_AGENT_ENGINE_CUSTOM_METHOD_SCHEMA, +] +_TEST_REGISTERED_OPERATION_NOT_EXIST_SCHEMAS = [ + _TEST_AGENT_ENGINE_QUERY_SCHEMA, + _TEST_AGENT_ENGINE_CUSTOM_METHOD_SCHEMA, +] +_TEST_NO_OPERATION_REGISTRABLE_SCHEMAS = [ + _TEST_AGENT_ENGINE_QUERY_SCHEMA, +] +_TEST_METHOD_TO_BE_UNREGISTERED_SCHEMA = _utils.generate_schema( + MethodToBeUnregisteredEngine().method_to_be_unregistered, + schema_name=_TEST_METHOD_TO_BE_UNREGISTERED_NAME, +) +_TEST_METHOD_TO_BE_UNREGISTERED_SCHEMA[ + _TEST_MODE_KEY_IN_SCHEMA +] = _TEST_STANDARD_API_MODE 
+_TEST_ASYNC_QUERY_SCHEMAS = [_TEST_AGENT_ENGINE_ASYNC_METHOD_SCHEMA]
+_TEST_STREAM_QUERY_SCHEMAS = [
+    _TEST_AGENT_ENGINE_STREAM_QUERY_SCHEMA,
+]
+_TEST_ASYNC_STREAM_QUERY_SCHEMAS = [
+    _TEST_AGENT_ENGINE_ASYNC_STREAM_QUERY_SCHEMA,
+]
+_TEST_PACKAGE_DISTRIBUTIONS = {
+    "requests": ["requests"],
+    "cloudpickle": ["cloudpickle"],
+    "pydantic": ["pydantic"],
+}
+_TEST_OPERATION_NAME = "test_operation_name"
+
+
+def _create_empty_fake_package(package_name: str) -> str:
+    """Creates a temporary directory structure representing an empty fake Python package.
+
+    Args:
+        package_name (str): The name of the fake package.
+
+    Returns:
+        str: The path to the top-level directory of the fake package.
+    """
+    temp_dir = tempfile.mkdtemp()
+    package_dir = os.path.join(temp_dir, package_name)
+    os.makedirs(package_dir)
+
+    # Create an empty __init__.py file to mark it as a package
+    init_path = os.path.join(package_dir, "__init__.py")
+    open(init_path, "w").close()
+
+    return temp_dir
+
+
+_TEST_AGENT_ENGINE_EXTRA_PACKAGE_PATH = _create_empty_fake_package(
+    _TEST_AGENT_ENGINE_EXTRA_PACKAGE
+)
+
+
+@pytest.fixture(scope="module")
+def google_auth_mock():
+    with mock.patch.object(auth, "default") as google_auth_mock:
+        google_auth_mock.return_value = (
+            auth_credentials.AnonymousCredentials(),
+            _TEST_PROJECT,
+        )
+        yield google_auth_mock
+
+
+@pytest.fixture(scope="module")
+def importlib_metadata_version_mock():
+    with mock.patch.object(
+        importlib.metadata, "version"
+    ) as importlib_metadata_version_mock:
+
+        def get_version(pkg):
+            versions = {
+                "requests": "2.0.0",
+                "cloudpickle": "3.0.0",
+                "pydantic": "1.11.1",
+            }
+            return versions.get(pkg, "unknown")
+
+        importlib_metadata_version_mock.side_effect = get_version
+        yield importlib_metadata_version_mock
+
+
+class InvalidCapitalizeEngineWithoutQuerySelf:
+    """A sample Agent Engine with an invalid query method."""
+
+    def set_up(self):
+        pass
+
+    def query() -> str:
+        """Runs the engine."""
+        return "RESPONSE"
+
+
+class InvalidCapitalizeEngineWithoutAsyncQuerySelf:
+    """A sample Agent Engine with an invalid async_query method."""
+
+    def set_up(self):
+        pass
+
+    async def async_query() -> str:
+        """Runs the engine."""
+        return "RESPONSE"
+
+
+class InvalidCapitalizeEngineWithoutStreamQuerySelf:
+    """A sample Agent Engine with an invalid stream_query method."""
+
+    def set_up(self):
+        pass
+
+    def stream_query() -> str:
+        """Runs the engine."""
+        return "RESPONSE"
+
+
+class InvalidCapitalizeEngineWithoutAsyncStreamQuerySelf:
+    """A sample Agent Engine with an invalid async_stream_query method."""
+
+    def set_up(self):
+        pass
+
+    async def async_stream_query() -> str:
+        """Runs the engine."""
+        return "RESPONSE"
+
+
+class InvalidCapitalizeEngineWithoutRegisterOperationsSelf:
+    """A sample Agent Engine with an invalid register_operations method."""
+
+    def set_up(self):
+        pass
+
+    def register_operations() -> str:
+        """Runs the engine."""
+        return "RESPONSE"
+
+
+class InvalidCapitalizeEngineWithoutQueryMethod:
+    """A sample Agent Engine without a query method."""
+
+    def set_up(self):
+        pass
+
+    def invoke(self) -> str:
+        """Runs the engine."""
+        return "RESPONSE"
+
+
+class InvalidCapitalizeEngineWithNoncallableQueryStreamQuery:
+    """A sample Agent Engine with a noncallable query attribute."""
+
+    def __init__(self):
+        self.query = "RESPONSE"
+
+    def set_up(self):
+        pass
+
+
+@pytest.mark.usefixtures("google_auth_mock")
+class TestAgentEngineHelpers:
+    def setup_method(self):
+        importlib.reload(initializer)
+        importlib.reload(aiplatform)
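+        # Reloading these modules keeps each test hermetic: module-level state
+        # (such as the global initializer configuration) is reset so one test
+        # cannot leak configuration into the next.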
+ importlib.reload(vertexai) + importlib.reload(os) + os.environ[_TEST_AGENT_ENGINE_ENV_KEY] = _TEST_AGENT_ENGINE_ENV_VALUE + self.client = vertexai.Client( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, + ) + self.test_agent = CapitalizeEngine() + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + @mock.patch.object(_agent_engines, "_prepare") + def test_create_agent_engine_config_lightweight(self, mock_prepare): + config = self.client.agent_engines._create_config( + mode="create", + staging_bucket=_TEST_STAGING_BUCKET, + display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME, + description=_TEST_AGENT_ENGINE_DESCRIPTION, + ) + assert config == { + "display_name": _TEST_AGENT_ENGINE_DISPLAY_NAME, + "description": _TEST_AGENT_ENGINE_DESCRIPTION, + } + + @mock.patch.object(_agent_engines, "_prepare") + def test_create_agent_engine_config_full(self, mock_prepare): + config = self.client.agent_engines._create_config( + mode="create", + agent_engine=self.test_agent, + staging_bucket=_TEST_STAGING_BUCKET, + requirements=_TEST_AGENT_ENGINE_REQUIREMENTS, + display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME, + description=_TEST_AGENT_ENGINE_DESCRIPTION, + gcs_dir_name=_TEST_GCS_DIR_NAME, + extra_packages=[_TEST_AGENT_ENGINE_EXTRA_PACKAGE_PATH], + env_vars=_TEST_AGENT_ENGINE_ENV_VARS_INPUT, + ) + assert config["display_name"] == _TEST_AGENT_ENGINE_DISPLAY_NAME + assert config["description"] == _TEST_AGENT_ENGINE_DESCRIPTION + assert config["spec"]["agent_framework"] == "custom" + assert config["spec"]["package_spec"] == { + "python_version": _TEST_PYTHON_VERSION, + "pickle_object_gcs_uri": _TEST_AGENT_ENGINE_GCS_URI, + "dependency_files_gcs_uri": _TEST_AGENT_ENGINE_DEPENDENCY_FILES_GCS_URI, + "requirements_gcs_uri": _TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI, + } + assert config["spec"]["deployment_spec"] == { + "env": [ + {"name": "TEST_ENV_VAR", "value": "TEST_ENV_VAR_VALUE"}, + {"name": "TEST_ENV_VAR_2", "value": "TEST_ENV_VAR_VALUE_2"}, + ], + "secret_env": [ + { + "name": "TEST_SECRET_ENV_VAR", + "secret_ref": { + "secret": "TEST_SECRET_NAME_1", + "version": "TEST_SECRET_VERSION_1", + }, + }, + ], + } + assert config["spec"]["class_methods"] == [_TEST_AGENT_ENGINE_CLASS_METHOD_1] + + @mock.patch.object(_agent_engines, "_prepare") + def test_update_agent_engine_config_full(self, mock_prepare): + config = self.client.agent_engines._create_config( + mode="update", + agent_engine=self.test_agent, + staging_bucket=_TEST_STAGING_BUCKET, + requirements=_TEST_AGENT_ENGINE_REQUIREMENTS, + display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME, + description=_TEST_AGENT_ENGINE_DESCRIPTION, + gcs_dir_name=_TEST_GCS_DIR_NAME, + extra_packages=[_TEST_AGENT_ENGINE_EXTRA_PACKAGE_PATH], + env_vars=_TEST_AGENT_ENGINE_ENV_VARS_INPUT, + ) + assert config["display_name"] == _TEST_AGENT_ENGINE_DISPLAY_NAME + assert config["description"] == _TEST_AGENT_ENGINE_DESCRIPTION + assert config["spec"]["agent_framework"] == "custom" + assert config["spec"]["package_spec"] == { + "python_version": _TEST_PYTHON_VERSION, + "pickle_object_gcs_uri": _TEST_AGENT_ENGINE_GCS_URI, + "dependency_files_gcs_uri": _TEST_AGENT_ENGINE_DEPENDENCY_FILES_GCS_URI, + "requirements_gcs_uri": _TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI, + } + assert config["spec"]["deployment_spec"] == { + "env": [ + {"name": "TEST_ENV_VAR", "value": "TEST_ENV_VAR_VALUE"}, + {"name": "TEST_ENV_VAR_2", "value": "TEST_ENV_VAR_VALUE_2"}, + ], + "secret_env": [ + { + "name": "TEST_SECRET_ENV_VAR", + "secret_ref": { + "secret": 
"TEST_SECRET_NAME_1", + "version": "TEST_SECRET_VERSION_1", + }, + }, + ], + } + assert config["spec"]["class_methods"] == [_TEST_AGENT_ENGINE_CLASS_METHOD_1] + assert config["update_mask"] == ",".join( + [ + "display_name", + "description", + "spec.package_spec.pickle_object_gcs_uri", + "spec.package_spec.dependency_files_gcs_uri", + "spec.package_spec.requirements_gcs_uri", + "spec.deployment_spec.env", + "spec.deployment_spec.secret_env", + "spec.class_methods", + "spec.agent_framework", + ] + ) + + def test_get_operation(self): + with mock.patch.object( + self.client.agent_engines._api_client, "request" + ) as request_mock: + request_mock.return_value = genai_types.HttpResponse( + body=json.dumps( + { + "name": _TEST_AGENT_ENGINE_OPERATION_NAME, + "done": True, + "response": _TEST_AGENT_ENGINE_SPEC, + } + ), + ) + operation = self.client.agent_engines._get_operation( + operation_name=_TEST_AGENT_ENGINE_OPERATION_NAME, + ) + request_mock.assert_called_with( + "get", + _TEST_AGENT_ENGINE_OPERATION_NAME, + {"_url": {"operationName": _TEST_AGENT_ENGINE_OPERATION_NAME}}, + None, + ) + assert isinstance(operation, _genai_types.AgentEngineOperation) + assert operation.done + assert isinstance(operation.response, _genai_types.ReasoningEngine) + + def test_await_operation(self): + with mock.patch.object( + self.client.agent_engines._api_client, "request" + ) as request_mock: + request_mock.return_value = genai_types.HttpResponse( + body=json.dumps( + { + "name": _TEST_AGENT_ENGINE_OPERATION_NAME, + "done": True, + "response": _TEST_AGENT_ENGINE_SPEC, + } + ), + ) + agent_engine = self.client.agent_engines._await_operation( + operation_name=_TEST_AGENT_ENGINE_OPERATION_NAME, + ) + request_mock.assert_called_with( + "get", + _TEST_AGENT_ENGINE_OPERATION_NAME, + {"_url": {"operationName": _TEST_AGENT_ENGINE_OPERATION_NAME}}, + None, + ) + assert isinstance(agent_engine, _genai_types.AgentEngine) + + def test_register_api_methods(self): + agent = self.client.agent_engines._register_api_methods( + agent=_genai_types.AgentEngine( + api_client=self.client.agent_engines._api_client, + api_resource=_genai_types.ReasoningEngine( + spec=_genai_types.ReasoningEngineSpec( + class_methods=[ + _TEST_AGENT_ENGINE_CLASS_METHOD_1, + ] + ), + ), + ) + ) + assert agent.query.__doc__ == _TEST_AGENT_ENGINE_CLASS_METHOD_1.get( + "description" + ) + + +@pytest.mark.usefixtures("google_auth_mock") +class TestAgentEngine: + def setup_method(self): + importlib.reload(initializer) + importlib.reload(aiplatform) + importlib.reload(vertexai) + importlib.reload(os) + os.environ[_TEST_AGENT_ENGINE_ENV_KEY] = _TEST_AGENT_ENGINE_ENV_VALUE + self.client = vertexai.Client( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, + ) + self.test_agent = CapitalizeEngine() + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_get_agent_engine(self): + with mock.patch.object( + self.client.agent_engines._api_client, "request" + ) as request_mock: + request_mock.return_value = genai_types.HttpResponse(body="") + self.client.agent_engines.get(name=_TEST_AGENT_ENGINE_RESOURCE_NAME) + request_mock.assert_called_with( + "get", + _TEST_AGENT_ENGINE_RESOURCE_NAME, + {"_url": {"name": _TEST_AGENT_ENGINE_RESOURCE_NAME}}, + None, + ) + + def test_list_agent_engine(self): + with mock.patch.object( + self.client.agent_engines._api_client, "request" + ) as request_mock: + request_mock.return_value = genai_types.HttpResponse(body="") + expected_query_params = {"filter": 
_TEST_AGENT_ENGINE_LIST_FILTER} + list(self.client.agent_engines.list(config=expected_query_params)) + request_mock.assert_called_with( + "get", + f"reasoningEngines?{urlencode(expected_query_params)}", + {"_query": expected_query_params}, + None, + ) + + @mock.patch.object(_agent_engines, "_prepare") + @mock.patch.object(agent_engines.AgentEngines, "_await_operation") + def test_create_agent_engine(self, mock_await_operation, mock_prepare): + mock_await_operation.return_value = _genai_types.AgentEngine() + with mock.patch.object( + self.client.agent_engines._api_client, "request" + ) as request_mock: + request_mock.return_value = genai_types.HttpResponse(body="") + self.client.agent_engines.create( + agent_engine=self.test_agent, + config=_genai_types.AgentEngineConfig( + display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME, + description=_TEST_AGENT_ENGINE_DESCRIPTION, + requirements=_TEST_AGENT_ENGINE_REQUIREMENTS, + extra_packages=[_TEST_AGENT_ENGINE_EXTRA_PACKAGE_PATH], + staging_bucket=_TEST_STAGING_BUCKET, + gcs_dir_name=_TEST_GCS_DIR_NAME, + env_vars=_TEST_AGENT_ENGINE_ENV_VARS_INPUT, + ), + ) + request_mock.assert_called_with( + "post", + "reasoningEngines", + { + "displayName": _TEST_AGENT_ENGINE_DISPLAY_NAME, + "description": _TEST_AGENT_ENGINE_DESCRIPTION, + "spec": { + "agentFramework": _TEST_AGENT_ENGINE_FRAMEWORK, + "classMethods": mock.ANY, # dict ordering was too flakey + "deploymentSpec": _TEST_AGENT_ENGINE_SPEC.get( + "deployment_spec" + ), + "packageSpec": _TEST_AGENT_ENGINE_SPEC.get("package_spec"), + }, + }, + None, + ) + + @mock.patch.object(agent_engines.AgentEngines, "_create_config") + @mock.patch.object(agent_engines.AgentEngines, "_await_operation") + def test_create_agent_engine_lightweight( + self, + mock_await_operation, + mock_create_config, + ): + mock_create_config.return_value = _genai_types.CreateAgentEngineConfig( + display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME, + description=_TEST_AGENT_ENGINE_DESCRIPTION, + ) + mock_await_operation.return_value = _genai_types.AgentEngine() + with mock.patch.object( + self.client.agent_engines._api_client, "request" + ) as request_mock: + request_mock.return_value = genai_types.HttpResponse(body="") + self.client.agent_engines.create( + config=_genai_types.AgentEngineConfig( + display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME, + description=_TEST_AGENT_ENGINE_DESCRIPTION, + ) + ) + request_mock.assert_called_with( + "post", + "reasoningEngines", + { + "displayName": _TEST_AGENT_ENGINE_DISPLAY_NAME, + "description": _TEST_AGENT_ENGINE_DESCRIPTION, + }, + None, + ) + + @mock.patch.object(agent_engines.AgentEngines, "_create_config") + @mock.patch.object(agent_engines.AgentEngines, "_await_operation") + def test_create_agent_engine_with_env_vars_dict( + self, + mock_await_operation, + mock_create_config, + ): + mock_create_config.return_value = { + "display_name": _TEST_AGENT_ENGINE_DISPLAY_NAME, + "description": _TEST_AGENT_ENGINE_DESCRIPTION, + "spec": { + "package_spec": { + "python_version": _TEST_PYTHON_VERSION, + "pickle_object_gcs_uri": _TEST_AGENT_ENGINE_GCS_URI, + "requirements_gcs_uri": _TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI, + }, + "class_methods": [_TEST_AGENT_ENGINE_CLASS_METHOD_1], + "agent_framework": _TEST_AGENT_ENGINE_FRAMEWORK, + }, + # "update_mask": "display_name,spec.package_spec.pickle_object_gcs_uri,spec.package_spec.requirements_gcs_uri", + } + mock_await_operation.return_value = _genai_types.AgentEngineOperation() + with mock.patch.object( + self.client.agent_engines._api_client, "request" + ) as request_mock: 
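+            # A sketch of the pattern used throughout these tests: the transport
+            # is stubbed at `_api_client.request`, so assertions can inspect the
+            # exact (method, path, payload) the SDK would send, with no network I/O.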
+ request_mock.return_value = genai_types.HttpResponse(body="") + self.client.agent_engines.create( + agent_engine=self.test_agent, + config=_genai_types.AgentEngineConfig( + display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME, + requirements=_TEST_AGENT_ENGINE_REQUIREMENTS, + extra_packages=[_TEST_AGENT_ENGINE_EXTRA_PACKAGE_PATH], + env_vars=_TEST_AGENT_ENGINE_ENV_VARS_INPUT, + staging_bucket=_TEST_STAGING_BUCKET, + return_agent=False, + ), + ) + mock_create_config.assert_called_with( + mode="create", + agent_engine=self.test_agent, + staging_bucket=_TEST_STAGING_BUCKET, + requirements=_TEST_AGENT_ENGINE_REQUIREMENTS, + display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME, + description=None, + gcs_dir_name=None, + extra_packages=[_TEST_AGENT_ENGINE_EXTRA_PACKAGE_PATH], + env_vars=_TEST_AGENT_ENGINE_ENV_VARS_INPUT, + ) + request_mock.assert_called_with( + "post", + "reasoningEngines", + { + "displayName": _TEST_AGENT_ENGINE_DISPLAY_NAME, + "description": _TEST_AGENT_ENGINE_DESCRIPTION, + "spec": { + "agentFramework": _TEST_AGENT_ENGINE_FRAMEWORK, + "classMethods": [_TEST_AGENT_ENGINE_CLASS_METHOD_1], + "packageSpec": { + "pickle_object_gcs_uri": _TEST_AGENT_ENGINE_GCS_URI, + "python_version": _TEST_PYTHON_VERSION, + "requirements_gcs_uri": _TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI, + }, + }, + }, + None, + ) + + @mock.patch.object(agent_engines.AgentEngines, "_create") + @mock.patch.object(agent_engines.AgentEngines, "_create_config") + def test_create_agent_engine_operation( + self, + mock_create_config, + mock_create, + ): + mock_create.return_value = _genai_types.AgentEngineOperation( + name=_TEST_OPERATION_NAME, + ) + operation = self.client.agent_engines.create( + config=_genai_types.AgentEngineConfig(return_agent=False) + ) + assert operation.name == _TEST_OPERATION_NAME + + @mock.patch.object(_agent_engines, "_prepare") + @mock.patch.object(agent_engines.AgentEngines, "_await_operation") + def test_update_agent_engine_requirements(self, mock_await_operation, mock_prepare): + mock_await_operation.return_value = _genai_types.AgentEngine() + with mock.patch.object( + self.client.agent_engines._api_client, "request" + ) as request_mock: + request_mock.return_value = genai_types.HttpResponse(body="") + self.client.agent_engines.update( + name=_TEST_AGENT_ENGINE_RESOURCE_NAME, + agent_engine=self.test_agent, + config=_genai_types.AgentEngineConfig( + staging_bucket=_TEST_STAGING_BUCKET, + requirements=_TEST_AGENT_ENGINE_REQUIREMENTS, + ), + ) + update_mask = ",".join( + [ + "spec.package_spec.pickle_object_gcs_uri", + "spec.package_spec.requirements_gcs_uri", + "spec.class_methods", + "spec.agent_framework", + ] + ) + query_params = {"updateMask": update_mask} + request_mock.assert_called_with( + "patch", + f"{_TEST_AGENT_ENGINE_RESOURCE_NAME}?{urlencode(query_params)}", + { + "_url": {"name": _TEST_AGENT_ENGINE_RESOURCE_NAME}, + "spec": { + "agentFramework": _TEST_AGENT_ENGINE_FRAMEWORK, + "classMethods": mock.ANY, + "packageSpec": { + "python_version": _TEST_PYTHON_VERSION, + "pickle_object_gcs_uri": _TEST_AGENT_ENGINE_GCS_URI, + "requirements_gcs_uri": _TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI, + }, + }, + "_query": {"updateMask": update_mask}, + }, + None, + ) + + @mock.patch.object(_agent_engines, "_prepare") + @mock.patch.object(agent_engines.AgentEngines, "_await_operation") + def test_update_agent_engine_extra_packages( + self, mock_await_operation, mock_prepare + ): + with mock.patch.object( + self.client.agent_engines._api_client, "request" + ) as request_mock: + request_mock.return_value = 
genai_types.HttpResponse(body="") + self.client.agent_engines.update( + name=_TEST_AGENT_ENGINE_RESOURCE_NAME, + agent_engine=self.test_agent, + config=_genai_types.AgentEngineConfig( + staging_bucket=_TEST_STAGING_BUCKET, + requirements=_TEST_AGENT_ENGINE_REQUIREMENTS, + extra_packages=[_TEST_AGENT_ENGINE_EXTRA_PACKAGE_PATH], + ), + ) + update_mask = ",".join( + [ + "spec.package_spec.pickle_object_gcs_uri", + "spec.package_spec.dependency_files_gcs_uri", + "spec.package_spec.requirements_gcs_uri", + "spec.class_methods", + "spec.agent_framework", + ] + ) + query_params = {"updateMask": update_mask} + request_mock.assert_called_with( + "patch", + f"{_TEST_AGENT_ENGINE_RESOURCE_NAME}?{urlencode(query_params)}", + { + "_url": {"name": _TEST_AGENT_ENGINE_RESOURCE_NAME}, + "spec": { + "agentFramework": _TEST_AGENT_ENGINE_FRAMEWORK, + "classMethods": mock.ANY, + "packageSpec": { + "python_version": _TEST_PYTHON_VERSION, + "pickle_object_gcs_uri": _TEST_AGENT_ENGINE_GCS_URI, + "dependency_files_gcs_uri": _TEST_AGENT_ENGINE_DEPENDENCY_FILES_GCS_URI, + "requirements_gcs_uri": _TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI, + }, + }, + "_query": {"updateMask": update_mask}, + }, + None, + ) + + @mock.patch.object(_agent_engines, "_prepare") + @mock.patch.object(agent_engines.AgentEngines, "_await_operation") + def test_update_agent_engine_env_vars(self, mock_await_operation, mock_prepare): + with mock.patch.object( + self.client.agent_engines._api_client, "request" + ) as request_mock: + request_mock.return_value = genai_types.HttpResponse(body="") + self.client.agent_engines.update( + name=_TEST_AGENT_ENGINE_RESOURCE_NAME, + agent_engine=self.test_agent, + config=_genai_types.AgentEngineConfig( + staging_bucket=_TEST_STAGING_BUCKET, + requirements=_TEST_AGENT_ENGINE_REQUIREMENTS, + env_vars=_TEST_AGENT_ENGINE_ENV_VARS_INPUT, + ), + ) + update_mask = ",".join( + [ + "spec.package_spec.pickle_object_gcs_uri", + "spec.package_spec.requirements_gcs_uri", + "spec.deployment_spec.env", + "spec.deployment_spec.secret_env", + "spec.class_methods", + "spec.agent_framework", + ] + ) + query_params = {"updateMask": update_mask} + request_mock.assert_called_with( + "patch", + f"{_TEST_AGENT_ENGINE_RESOURCE_NAME}?{urlencode(query_params)}", + { + "_url": {"name": _TEST_AGENT_ENGINE_RESOURCE_NAME}, + "spec": { + "agentFramework": _TEST_AGENT_ENGINE_FRAMEWORK, + "classMethods": mock.ANY, + "packageSpec": { + "python_version": _TEST_PYTHON_VERSION, + "pickle_object_gcs_uri": _TEST_AGENT_ENGINE_GCS_URI, + "requirements_gcs_uri": _TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI, + }, + "deploymentSpec": _TEST_AGENT_ENGINE_SPEC.get( + "deployment_spec" + ), + }, + "_query": {"updateMask": update_mask}, + }, + None, + ) + + @mock.patch.object(agent_engines.AgentEngines, "_await_operation") + def test_update_agent_engine_display_name(self, mock_await_operation): + with mock.patch.object( + self.client.agent_engines._api_client, "request" + ) as request_mock: + request_mock.return_value = genai_types.HttpResponse(body="") + self.client.agent_engines.update( + name=_TEST_AGENT_ENGINE_RESOURCE_NAME, + config=_genai_types.AgentEngineConfig( + display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME, + ), + ) + request_mock.assert_called_with( + "patch", + f"{_TEST_AGENT_ENGINE_RESOURCE_NAME}?updateMask=display_name", + { + "_url": {"name": _TEST_AGENT_ENGINE_RESOURCE_NAME}, + "displayName": _TEST_AGENT_ENGINE_DISPLAY_NAME, + "_query": {"updateMask": "display_name"}, + }, + None, + ) + + @mock.patch.object(agent_engines.AgentEngines, 
"_await_operation") + def test_update_agent_engine_description(self, mock_await_operation): + with mock.patch.object( + self.client.agent_engines._api_client, "request" + ) as request_mock: + request_mock.return_value = genai_types.HttpResponse(body="") + self.client.agent_engines.update( + name=_TEST_AGENT_ENGINE_RESOURCE_NAME, + config=_genai_types.AgentEngineConfig( + description=_TEST_AGENT_ENGINE_DESCRIPTION, + ), + ) + request_mock.assert_called_with( + "patch", + f"{_TEST_AGENT_ENGINE_RESOURCE_NAME}?updateMask=description", + { + "_url": {"name": _TEST_AGENT_ENGINE_RESOURCE_NAME}, + "description": _TEST_AGENT_ENGINE_DESCRIPTION, + "_query": {"updateMask": "description"}, + }, + None, + ) + + def test_delete_agent_engine(self): + with mock.patch.object( + self.client.agent_engines._api_client, "request" + ) as request_mock: + request_mock.return_value = genai_types.HttpResponse(body="") + self.client.agent_engines.delete(name=_TEST_AGENT_ENGINE_RESOURCE_NAME) + request_mock.assert_called_with( + "delete", + _TEST_AGENT_ENGINE_RESOURCE_NAME, + {"_url": {"name": _TEST_AGENT_ENGINE_RESOURCE_NAME}}, + None, + ) + + def test_delete_agent_engine_force(self): + with mock.patch.object( + self.client.agent_engines._api_client, "request" + ) as request_mock: + request_mock.return_value = genai_types.HttpResponse(body="") + self.client.agent_engines.delete( + name=_TEST_AGENT_ENGINE_RESOURCE_NAME, + force=True, + ) + request_mock.assert_called_with( + "delete", + _TEST_AGENT_ENGINE_RESOURCE_NAME, + {"_url": {"name": _TEST_AGENT_ENGINE_RESOURCE_NAME}, "force": True}, + None, + ) + + def test_query_agent_engine(self): + with mock.patch.object( + self.client.agent_engines._api_client, "request" + ) as request_mock: + request_mock.return_value = genai_types.HttpResponse(body="") + agent = self.client.agent_engines._register_api_methods( + agent=_genai_types.AgentEngine( + api_client=self.client.agent_engines, + api_resource=_genai_types.ReasoningEngine( + name=_TEST_AGENT_ENGINE_RESOURCE_NAME, + spec=_genai_types.ReasoningEngineSpec( + class_methods=[ + _TEST_AGENT_ENGINE_CLASS_METHOD_1, + ] + ), + ), + ) + ) + agent.query(query=_TEST_QUERY_PROMPT) + request_mock.assert_called_with( + "post", + f"{_TEST_AGENT_ENGINE_RESOURCE_NAME}:query", + { + "_url": {"name": _TEST_AGENT_ENGINE_RESOURCE_NAME}, + "classMethod": "query", + "input": {"query": _TEST_QUERY_PROMPT}, + }, + None, + ) + + def test_query_agent_engine_async(self): + agent = self.client.agent_engines._register_api_methods( + agent=_genai_types.AgentEngine( + api_async_client=agent_engines.AsyncAgentEngines( + api_client_=self.client.agent_engines._api_client + ), + api_resource=_genai_types.ReasoningEngine( + name=_TEST_AGENT_ENGINE_RESOURCE_NAME, + spec=_genai_types.ReasoningEngineSpec( + class_methods=[ + _TEST_AGENT_ENGINE_CLASS_METHOD_ASYNC_QUERY, + ] + ), + ), + ) + ) + with mock.patch.object( + self.client.agent_engines._api_client, "async_request" + ) as request_mock: + request_mock.return_value = genai_types.HttpResponse(body="") + asyncio.run(agent.async_query(query=_TEST_QUERY_PROMPT)) + request_mock.assert_called_with( + "post", + f"{_TEST_AGENT_ENGINE_RESOURCE_NAME}:query", + { + "_url": {"name": _TEST_AGENT_ENGINE_RESOURCE_NAME}, + "classMethod": "async_query", + "input": {"query": _TEST_QUERY_PROMPT}, + }, + None, + ) + + def test_query_agent_engine_stream(self): + with mock.patch.object( + self.client.agent_engines._api_client, "request_streamed" + ) as request_mock: + agent = self.client.agent_engines._register_api_methods( 
+ agent=_genai_types.AgentEngine( + api_client=self.client.agent_engines, + api_resource=_genai_types.ReasoningEngine( + name=_TEST_AGENT_ENGINE_RESOURCE_NAME, + spec=_genai_types.ReasoningEngineSpec( + class_methods=[ + _TEST_AGENT_ENGINE_CLASS_METHOD_STREAM_QUERY, + ] + ), + ), + ) + ) + list(agent.stream_query(query=_TEST_QUERY_PROMPT)) + request_mock.assert_called_with( + "post", + f"{_TEST_AGENT_ENGINE_RESOURCE_NAME}:streamQuery?alt=sse", + { + "_url": {"name": _TEST_AGENT_ENGINE_RESOURCE_NAME}, + "classMethod": "stream_query", + "input": {"query": _TEST_QUERY_PROMPT}, + }, + None, + ) + + def test_query_agent_engine_async_stream(self): + with mock.patch.object( + self.client.agent_engines._api_client, "request_streamed" + ) as request_mock: + agent = self.client.agent_engines._register_api_methods( + agent=_genai_types.AgentEngine( + api_client=self.client.agent_engines, + api_resource=_genai_types.ReasoningEngine( + name=_TEST_AGENT_ENGINE_RESOURCE_NAME, + spec=_genai_types.ReasoningEngineSpec( + class_methods=[ + _TEST_AGENT_ENGINE_CLASS_METHOD_ASYNC_STREAM_QUERY, + ] + ), + ), + ) + ) + + async def consume(): + async for response in agent.async_stream_query( + query=_TEST_QUERY_PROMPT + ): + print(response) + + asyncio.run(consume()) + request_mock.assert_called_with( + "post", + f"{_TEST_AGENT_ENGINE_RESOURCE_NAME}:streamQuery?alt=sse", + { + "_url": {"name": _TEST_AGENT_ENGINE_RESOURCE_NAME}, + "classMethod": "async_stream_query", + "input": {"query": _TEST_QUERY_PROMPT}, + }, + None, + ) + + @pytest.mark.parametrize( + "test_case_name, test_class_methods_spec, want_operation_schema_api_modes", + [ + ( + "Default (Not Operation Registrable) Engine", + _TEST_NO_OPERATION_REGISTRABLE_SCHEMAS, + [ + ( + _utils.generate_schema( + CapitalizeEngine().query, + schema_name=_TEST_DEFAULT_METHOD_NAME, + ), + _TEST_STANDARD_API_MODE, + ) + ], + ), + ( + "Operation Registrable Engine", + _TEST_OPERATION_REGISTRABLE_SCHEMAS, + [ + ( + _utils.generate_schema( + OperationRegistrableEngine().query, + schema_name=_TEST_DEFAULT_METHOD_NAME, + ), + _TEST_STANDARD_API_MODE, + ), + ( + _utils.generate_schema( + OperationRegistrableEngine().custom_method, + schema_name=_TEST_CUSTOM_METHOD_NAME, + ), + _TEST_STANDARD_API_MODE, + ), + ( + _utils.generate_schema( + OperationRegistrableEngine().async_query, + schema_name=_TEST_DEFAULT_ASYNC_METHOD_NAME, + ), + _TEST_ASYNC_API_MODE, + ), + ( + _utils.generate_schema( + OperationRegistrableEngine().custom_async_method, + schema_name=_TEST_CUSTOM_ASYNC_METHOD_NAME, + ), + _TEST_ASYNC_API_MODE, + ), + ( + _utils.generate_schema( + OperationRegistrableEngine().stream_query, + schema_name=_TEST_DEFAULT_STREAM_METHOD_NAME, + ), + _TEST_STREAM_API_MODE, + ), + ( + _utils.generate_schema( + OperationRegistrableEngine().custom_stream_method, + schema_name=_TEST_CUSTOM_STREAM_METHOD_NAME, + ), + _TEST_STREAM_API_MODE, + ), + ( + _utils.generate_schema( + OperationRegistrableEngine().async_stream_query, + schema_name=_TEST_DEFAULT_ASYNC_STREAM_METHOD_NAME, + ), + _TEST_ASYNC_STREAM_API_MODE, + ), + ( + _utils.generate_schema( + OperationRegistrableEngine().custom_async_stream_method, + schema_name=_TEST_CUSTOM_ASYNC_STREAM_METHOD_NAME, + ), + _TEST_ASYNC_STREAM_API_MODE, + ), + ], + ), + ( + "Operation Not Registered Engine", + _TEST_OPERATION_NOT_REGISTERED_SCHEMAS, + [ + ( + _utils.generate_schema( + OperationNotRegisteredEngine().custom_method, + schema_name=_TEST_CUSTOM_METHOD_NAME, + ), + _TEST_STANDARD_API_MODE, + ), + ], + ), + ], + ) + 
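+    # Note: stacked `mock.patch.object` decorators are applied bottom-up, so the
+    # decorator closest to the `def` supplies the first mock argument after `self`.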
@mock.patch.object(genai_client.Client, "_get_api_client") + @mock.patch.object(agent_engines.AgentEngines, "_get") + def test_operation_schemas( + self, + mock_get, + mock_get_api_client, + test_case_name, + test_class_methods_spec, + want_operation_schema_api_modes, + ): + test_agent_engine = _genai_types.AgentEngine( + api_resource=_genai_types.ReasoningEngine( + name=_TEST_AGENT_ENGINE_RESOURCE_NAME, + spec=_genai_types.ReasoningEngineSpec( + class_methods=test_class_methods_spec, + ), + ), + ) + want_operation_schemas = [] + for want_operation_schema, api_mode in want_operation_schema_api_modes: + want_operation_schema[_TEST_MODE_KEY_IN_SCHEMA] = api_mode + want_operation_schemas.append(want_operation_schema) + assert test_agent_engine.operation_schemas() == want_operation_schemas + + +@pytest.mark.usefixtures("google_auth_mock") +class TestAgentEngineErrors: + def setup_method(self): + importlib.reload(initializer) + importlib.reload(aiplatform) + importlib.reload(vertexai) + self.client = vertexai.Client( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, + ) + self.test_agent = CapitalizeEngine() + + @pytest.mark.parametrize( + "test_case_name, test_operation_schemas, want_log_output", + [ + ( + "No API mode in operation schema", + [ + { + _TEST_METHOD_NAME_KEY_IN_SCHEMA: _TEST_DEFAULT_METHOD_NAME, + }, + ], + ( + "Failed to register API methods. Please follow the guide to " + "register the API methods: " + "https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/develop/custom#custom-methods. " + "Error: {Operation schema {'name': 'query'} does not " + "contain an `api_mode` field.}" + ), + ), + ( + "No method name in operation schema", + [ + { + _TEST_MODE_KEY_IN_SCHEMA: _TEST_STANDARD_API_MODE, + }, + ], + ( + "Failed to register API methods. Please follow the guide to " + "register the API methods: " + "https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/develop/custom#custom-methods. " + "Error: {Operation schema {'api_mode': ''} does not " + "contain a `name` field.}" + ), + ), + ( + "Unknown API mode in operation schema", + [ + { + _TEST_MODE_KEY_IN_SCHEMA: "UNKNOWN_API_MODE", + _TEST_METHOD_NAME_KEY_IN_SCHEMA: _TEST_DEFAULT_METHOD_NAME, + }, + ], + ( + "Failed to register API methods. Please follow the guide to " + "register the API methods: " + "https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/develop/custom#custom-methods. 
" + "Error: {Unsupported api mode: `UNKNOWN_API_MODE`, " + "Supported modes are: ``, `async`, `async_stream`, `stream`.}" + ), + ), + ], + ) + @pytest.mark.usefixtures("caplog") + @mock.patch.object(_genai_types.AgentEngine, "operation_schemas") + @mock.patch.object(agent_engines.AgentEngines, "_get") + def test_invalid_operation_schema( + self, + mock_get, + mock_operation_schemas, + test_case_name, + test_operation_schemas, + want_log_output, + caplog, + ): + mock_get.return_value = _genai_types.AgentEngine() # just to avoid an API call + mock_operation_schemas.return_value = test_operation_schemas + self.client.agent_engines.get(name=_TEST_AGENT_ENGINE_RESOURCE_NAME) + assert want_log_output in caplog.text + + +@pytest.mark.usefixtures("google_auth_mock") +class TestAsyncAgentEngine: + def setup_method(self): + importlib.reload(initializer) + importlib.reload(aiplatform) + importlib.reload(vertexai) + importlib.reload(os) + os.environ[_TEST_AGENT_ENGINE_ENV_KEY] = _TEST_AGENT_ENGINE_ENV_VALUE + self.client = vertexai.Client( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, + ) + self.test_agent = CapitalizeEngine() + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_delete_agent_engine(self): + with mock.patch.object( + self.client.agent_engines._api_client, "async_request" + ) as request_mock: + request_mock.return_value = genai_types.HttpResponse(body="") + asyncio.run( + self.client.aio.agent_engines.delete( + name=_TEST_AGENT_ENGINE_RESOURCE_NAME + ) + ) + request_mock.assert_called_with( + "delete", + _TEST_AGENT_ENGINE_RESOURCE_NAME, + {"_url": {"name": _TEST_AGENT_ENGINE_RESOURCE_NAME}}, + None, + ) + + def test_delete_agent_engine_force(self): + with mock.patch.object( + self.client.agent_engines._api_client, "async_request" + ) as request_mock: + request_mock.return_value = genai_types.HttpResponse(body="") + asyncio.run( + self.client.aio.agent_engines.delete( + name=_TEST_AGENT_ENGINE_RESOURCE_NAME, + force=True, + ) + ) + request_mock.assert_called_with( + "delete", + _TEST_AGENT_ENGINE_RESOURCE_NAME, + {"_url": {"name": _TEST_AGENT_ENGINE_RESOURCE_NAME}, "force": True}, + None, + ) diff --git a/tests/unit/vertexai/genai/test_evals.py b/tests/unit/vertexai/genai/test_evals.py index 0bd23ec7e8..b9d3a89b10 100644 --- a/tests/unit/vertexai/genai/test_evals.py +++ b/tests/unit/vertexai/genai/test_evals.py @@ -154,6 +154,49 @@ def test_inference_with_string_model_success( } ), ) + assert inference_result.candidate_name == "gemini-pro" + assert inference_result.gcs_source is None + + @mock.patch.object(_evals_utils, "EvalDatasetLoader") + def test_inference_with_callable_model_sets_candidate_name( + self, mock_eval_dataset_loader + ): + mock_df = pd.DataFrame({"prompt": ["test prompt"]}) + mock_eval_dataset_loader.return_value.load.return_value = mock_df.to_dict( + orient="records" + ) + + def my_model_fn(contents): + return "callable response" + + inference_result = self.client.evals.run_inference( + model=my_model_fn, + src=mock_df, + ) + assert inference_result.candidate_name == "my_model_fn" + assert inference_result.gcs_source is None + + @mock.patch.object(_evals_utils, "EvalDatasetLoader") + def test_inference_with_lambda_model_candidate_name_is_none( + self, mock_eval_dataset_loader + ): + mock_df = pd.DataFrame({"prompt": ["test prompt"]}) + mock_eval_dataset_loader.return_value.load.return_value = mock_df.to_dict( + orient="records" + ) + + inference_result = self.client.evals.run_inference( + 
+            model=lambda x: "lambda response",  # pylint: disable=unnecessary-lambda
+            src=mock_df,
+        )
+        # Lambdas may or may not have a __name__ depending on Python version/env,
+        # but it's typically '<lambda>' if it exists.
+        # The code under test uses getattr(model, "__name__", None)
+        assert (
+            inference_result.candidate_name == "<lambda>"
+            or inference_result.candidate_name is None
+        )
+        assert inference_result.gcs_source is None

     @mock.patch.object(_evals_utils, "EvalDatasetLoader")
     def test_inference_with_callable_model_success(self, mock_eval_dataset_loader):
@@ -179,6 +222,8 @@ def mock_model_fn(contents):
                 }
             ),
         )
+        assert inference_result.candidate_name == "mock_model_fn"
+        assert inference_result.gcs_source is None

     @mock.patch.object(_evals_common, "Models")
     @mock.patch.object(_evals_utils, "EvalDatasetLoader")
@@ -224,6 +269,8 @@ def test_inference_with_prompt_template(
                 }
             ),
         )
+        assert inference_result.candidate_name == "gemini-pro"
+        assert inference_result.gcs_source is None

     @mock.patch.object(_evals_common, "Models")
     @mock.patch.object(_evals_utils, "EvalDatasetLoader")
@@ -273,6 +320,10 @@ def test_inference_with_gcs_destination(
         pd.testing.assert_frame_equal(
             inference_result.eval_dataset_df, expected_df_to_save
         )
+        assert inference_result.candidate_name == "gemini-pro"
+        assert inference_result.gcs_source == vertexai_genai_types.GcsSource(
+            uris=[gcs_dest_path]
+        )

     @mock.patch.object(_evals_common, "Models")
     @mock.patch.object(_evals_utils, "EvalDatasetLoader")
@@ -322,6 +373,8 @@ def test_inference_with_local_destination(
             }
         )
         pd.testing.assert_frame_equal(inference_result.eval_dataset_df, expected_df)
+        assert inference_result.candidate_name == "gemini-pro"
+        assert inference_result.gcs_source is None

     @mock.patch.object(_evals_common, "Models")
     @mock.patch.object(_evals_utils, "EvalDatasetLoader")
@@ -405,6 +458,8 @@ def test_inference_from_request_column_save_locally(
             expected_records, key=lambda x: x["request"]
         )
         os.remove(local_dest_path)
+        assert inference_result.candidate_name == "gemini-pro"
+        assert inference_result.gcs_source is None

     @mock.patch.object(_evals_common, "Models")
     def test_inference_from_local_jsonl_file(self, mock_models):
@@ -478,6 +533,8 @@ def test_inference_from_local_jsonl_file(self, mock_models):
             any_order=True,
         )
         os.remove(local_src_path)
+        assert inference_result.candidate_name == "gemini-pro"
+        assert inference_result.gcs_source is None

     @mock.patch.object(_evals_common, "Models")
     def test_inference_from_local_csv_file(self, mock_models):
@@ -548,6 +605,8 @@ def test_inference_from_local_csv_file(self, mock_models):
             any_order=True,
         )
         os.remove(local_src_path)
+        assert inference_result.candidate_name == "gemini-pro"
+        assert inference_result.gcs_source is None

     @mock.patch.object(_evals_common, "Models")
     @mock.patch.object(_evals_utils, "EvalDatasetLoader")
@@ -719,6 +778,8 @@ def mock_generate_content_logic(*args, **kwargs):
             expected_df.sort_values(by="id").reset_index(drop=True),
             check_dtype=False,
         )
+        assert inference_result.candidate_name == "gemini-pro"
+        assert inference_result.gcs_source is None

     @mock.patch.object(_evals_common, "Models")
     @mock.patch.object(_evals_utils, "EvalDatasetLoader")
@@ -794,6 +855,8 @@ def test_inference_with_multimodal_content(
                 }
             ),
         )
+        assert inference_result.candidate_name == "gemini-pro"
+        assert inference_result.gcs_source is None


 class TestMetricPromptBuilder:
@@ -1471,6 +1534,146 @@ def test_convert_with_additional_columns(self):
         assert eval_case.custom_column == "custom_value"


+class TestOpenAIDataConverter:
+    """Unit tests for
the _OpenAIDataConverter class.""" + + def setup_method(self): + self.converter = _evals_data_converters._OpenAIDataConverter() + + def test_convert_simple_prompt_response(self): + raw_data = [ + { + "request": {"messages": [{"role": "user", "content": "Hello"}]}, + "response": { + "choices": [{"message": {"role": "assistant", "content": "Hi"}}] + }, + } + ] + result_dataset = self.converter.convert(raw_data) + assert isinstance(result_dataset, vertexai_genai_types.EvaluationDataset) + assert len(result_dataset.eval_cases) == 1 + eval_case = result_dataset.eval_cases[0] + + assert eval_case.prompt == genai_types.Content( + parts=[genai_types.Part(text="Hello")], role="user" + ) + assert len(eval_case.responses) == 1 + assert eval_case.responses[0].response == genai_types.Content( + parts=[genai_types.Part(text="Hi")] + ) + assert eval_case.reference is None + assert eval_case.system_instruction is None + assert eval_case.conversation_history == [] + + def test_convert_with_system_instruction(self): + raw_data = [ + { + "request": { + "messages": [ + {"role": "system", "content": "Be helpful."}, + {"role": "user", "content": "Hello"}, + ] + }, + "response": { + "choices": [{"message": {"role": "assistant", "content": "Hi"}}] + }, + } + ] + result_dataset = self.converter.convert(raw_data) + eval_case = result_dataset.eval_cases[0] + assert eval_case.system_instruction == genai_types.Content( + parts=[genai_types.Part(text="Be helpful.")] + ) + assert eval_case.prompt == genai_types.Content( + parts=[genai_types.Part(text="Hello")], role="user" + ) + + def test_convert_with_conversation_history_and_reference(self): + raw_data = [ + { + "request": { + "messages": [ + {"role": "user", "content": "Initial user"}, + {"role": "assistant", "content": "Initial model (ref)"}, + ] + }, + "response": { + "choices": [ + { + "message": { + "role": "assistant", + "content": "Actual response", + } + } + ] + }, + } + ] + result_dataset = self.converter.convert(raw_data) + eval_case = result_dataset.eval_cases[0] + + assert eval_case.prompt == genai_types.Content( + parts=[genai_types.Part(text="Initial user")], role="user" + ) + assert eval_case.reference.response == genai_types.Content( + parts=[genai_types.Part(text="Initial model (ref)")], role="assistant" + ) + assert len(eval_case.conversation_history) == 0 # History before prompt and ref + assert eval_case.responses[0].response == genai_types.Content( + parts=[genai_types.Part(text="Actual response")] + ) + + def test_convert_with_conversation_history_no_reference(self): + raw_data = [ + { + "request": { + "messages": [ + {"role": "user", "content": "Old user msg"}, + {"role": "assistant", "content": "Old model msg"}, + {"role": "user", "content": "Current prompt"}, + ] + }, + "response": { + "choices": [ + {"message": {"role": "assistant", "content": "A response"}} + ] + }, + } + ] + result_dataset = self.converter.convert(raw_data) + eval_case = result_dataset.eval_cases[0] + + assert eval_case.prompt == genai_types.Content( + parts=[genai_types.Part(text="Current prompt")], role="user" + ) + assert eval_case.reference is None + assert len(eval_case.conversation_history) == 2 + assert eval_case.conversation_history[0].content.parts[0].text == "Old user msg" + assert ( + eval_case.conversation_history[1].content.parts[0].text == "Old model msg" + ) + + def test_convert_empty_choices_uses_placeholder(self): + raw_data = [ + { + "request": {"messages": [{"role": "user", "content": "Hello"}]}, + "response": {"choices": []}, + } + ] + result_dataset = 
self.converter.convert(raw_data) + eval_case = result_dataset.eval_cases[0] + assert len(eval_case.responses) == 1 + assert ( + eval_case.responses[0].response.parts[0].text + == _evals_data_converters._PLACEHOLDER_RESPONSE_TEXT + ) + + def test_convert_skips_missing_request_or_response(self): + raw_data = [{"response": {"choices": []}}, {"request": {"messages": []}}] + result_dataset = self.converter.convert(raw_data) + assert len(result_dataset.eval_cases) == 0 + + class TestMetric: """Unit tests for the Metric class.""" @@ -1649,6 +1852,113 @@ def test_merge_two_flatten_datasets(self): 1 ].response == genai_types.Content(parts=[genai_types.Part(text="Response 2b")]) + def test_merge_flatten_and_openai_datasets(self): + raw_dataset_flatten = [ + { + "prompt": "Prompt 1", + "response": "Response 1 Flatten", + "reference": "Ref 1", + }, + ] + raw_dataset_openai = [ + { + "request": { + "messages": [ + {"role": "user", "content": "Prompt 1"}, + {"role": "assistant", "content": "Ref 1"}, + ] + }, + "response": { + "choices": [ + { + "message": { + "role": "assistant", + "content": "Response 1 OpenAI", + } + } + ] + }, + } + ] + schemas = [ + _evals_data_converters.EvalDatasetSchema.FLATTEN, + _evals_data_converters.EvalDatasetSchema.OPENAI, + ] + + merged_dataset = ( + _evals_data_converters.merge_response_datasets_into_canonical_format( + [raw_dataset_flatten, raw_dataset_openai], schemas=schemas + ) + ) + assert len(merged_dataset.eval_cases) == 1 + case0 = merged_dataset.eval_cases[0] + assert case0.prompt == genai_types.Content( + parts=[genai_types.Part(text="Prompt 1")] + ) + assert case0.reference.response == genai_types.Content( + parts=[genai_types.Part(text="Ref 1")] + ) + assert len(case0.responses) == 2 + assert case0.responses[0].response == genai_types.Content( + parts=[genai_types.Part(text="Response 1 Flatten")] + ) + assert case0.responses[1].response == genai_types.Content( + parts=[genai_types.Part(text="Response 1 OpenAI")] + ) + + def test_merge_two_openai_datasets(self): + raw_dataset_openai_1 = [ + { + "request": { + "messages": [ + {"role": "developer", "content": "Sys1"}, + {"role": "user", "content": "P1"}, + ] + }, + "response": { + "choices": [{"message": {"role": "assistant", "content": "R1a"}}] + }, + } + ] + raw_dataset_openai_2 = [ + { + "request": { + "messages": [ + {"role": "system", "content": "Sys1"}, + {"role": "user", "content": "P1"}, + ] + }, + "response": { + "choices": [{"message": {"role": "assistant", "content": "R1b"}}] + }, + } + ] + schemas = [ + _evals_data_converters.EvalDatasetSchema.OPENAI, + _evals_data_converters.EvalDatasetSchema.OPENAI, + ] + + merged_dataset = ( + _evals_data_converters.merge_response_datasets_into_canonical_format( + [raw_dataset_openai_1, raw_dataset_openai_2], schemas=schemas + ) + ) + assert len(merged_dataset.eval_cases) == 1 + case0 = merged_dataset.eval_cases[0] + assert case0.prompt == genai_types.Content( + parts=[genai_types.Part(text="P1")], role="user" + ) + assert case0.system_instruction == genai_types.Content( + parts=[genai_types.Part(text="Sys1")] + ) + assert len(case0.responses) == 2 + assert case0.responses[0].response == genai_types.Content( + parts=[genai_types.Part(text="R1a")] + ) + assert case0.responses[1].response == genai_types.Content( + parts=[genai_types.Part(text="R1b")] + ) + def test_merge_flatten_and_gemini_datasets(self): raw_dataset_1 = [ {"prompt": "Prompt 1", "response": "Response 1a"}, @@ -2373,6 +2683,61 @@ def test_merge_with_missing_response(self): ) 
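The schema auto-detection exercised by the tests below is a purely key-based heuristic: Gemini-format rows nest `request.contents[*].parts`, OpenAI-format rows nest `request.messages[*]` with `role`/`content` keys, and flatten-format rows keep `prompt`/`response` (or `response`/`reference`) at the top level. A minimal standalone sketch of that decision tree, written against plain dicts rather than the SDK's converter machinery (the function name here is illustrative):

.. code-block:: python

    from typing import Any


    def detect_schema(rows: list[dict[str, Any]]) -> str:
        """Illustrative re-creation of the key-based schema detection."""
        if not rows:
            return "unknown"
        first = rows[0]
        request = first.get("request")
        if "response" in first and isinstance(request, dict):
            contents = request.get("contents")
            if contents and isinstance(contents[0], dict) and "parts" in contents[0]:
                return "gemini"  # GenerateContent-style request payload
            messages = request.get("messages")
            if (
                messages
                and isinstance(messages[0], dict)
                and "role" in messages[0]
                and "content" in messages[0]
            ):
                return "openai"  # ChatCompletion-style request payload
        keys = set(first)
        if {"prompt", "response"} <= keys or {"response", "reference"} <= keys:
            return "flatten"
        return "unknown"


    assert detect_schema([{"prompt": "Hello", "response": "Hi"}]) == "flatten"
    assert detect_schema([{"foo": "bar"}]) == "unknown"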
+@pytest.mark.usefixtures("google_auth_mock")
+class TestAutoDetectDatasetSchema:
+    def test_auto_detect_gemini_schema(self):
+        raw_data = [
+            {
+                "request": {
+                    "contents": [{"role": "user", "parts": [{"text": "Hello"}]}]
+                },
+                "response": {
+                    "candidates": [
+                        {"content": {"role": "model", "parts": [{"text": "Hi"}]}}
+                    ]
+                },
+            }
+        ]
+        assert (
+            _evals_data_converters.auto_detect_dataset_schema(raw_data)
+            == _evals_data_converters.EvalDatasetSchema.GEMINI
+        )
+
+    def test_auto_detect_flatten_schema(self):
+        raw_data = [{"prompt": "Hello", "response": "Hi"}]
+        assert (
+            _evals_data_converters.auto_detect_dataset_schema(raw_data)
+            == _evals_data_converters.EvalDatasetSchema.FLATTEN
+        )
+
+    def test_auto_detect_openai_schema(self):
+        raw_data = [
+            {
+                "request": {"messages": [{"role": "user", "content": "Hello"}]},
+                "response": {
+                    "choices": [{"message": {"role": "assistant", "content": "Hi"}}]
+                },
+            }
+        ]
+        assert (
+            _evals_data_converters.auto_detect_dataset_schema(raw_data)
+            == _evals_data_converters.EvalDatasetSchema.OPENAI
+        )
+
+    def test_auto_detect_unknown_schema(self):
+        raw_data = [{"foo": "bar"}]
+        assert (
+            _evals_data_converters.auto_detect_dataset_schema(raw_data)
+            == _evals_data_converters.EvalDatasetSchema.UNKNOWN
+        )
+
+    def test_auto_detect_empty_dataset(self):
+        assert (
+            _evals_data_converters.auto_detect_dataset_schema([])
+            == _evals_data_converters.EvalDatasetSchema.UNKNOWN
+        )
+
+
 @pytest.fixture
 def mock_api_client_fixture():
     mock_client = mock.Mock(spec=client.Client)
@@ -2560,6 +2925,98 @@ def test_execute_evaluation_llm_metric(
         call_args = mock_eval_dependencies["mock_evaluate_instances"].call_args
         assert "pointwise_metric_input" in call_args[1]["metric_config"]

+    @mock.patch.object(_evals_data_converters, "get_dataset_converter")
+    def test_execute_evaluation_with_openai_schema(
+        self,
+        mock_get_converter,
+        mock_api_client_fixture,
+        mock_eval_dependencies,
+    ):
+        mock_openai_raw_data = [
+            {
+                "request": {"messages": [{"role": "user", "content": "OpenAI Prompt"}]},
+                "response": {
+                    "choices": [
+                        {
+                            "index": 0,
+                            "message": {
+                                "role": "assistant",
+                                "content": "OpenAI Response",
+                                "refusal": None,
+                                "annotations": [],
+                            },
+                            "logprobs": None,
+                            "finish_reason": "stop",
+                        }
+                    ]
+                },
+            }
+        ]
+        converted_eval_case = vertexai_genai_types.EvalCase(
+            prompt=genai_types.Content(
+                parts=[genai_types.Part(text="OpenAI Prompt")], role="user"
+            ),
+            responses=[
+                vertexai_genai_types.ResponseCandidate(
+                    response=genai_types.Content(
+                        parts=[genai_types.Part(text="Candidate Response")]
+                    )
+                )
+            ],
+        )
+        mock_converted_dataset = vertexai_genai_types.EvaluationDataset(
+            eval_cases=[converted_eval_case]
+        )
+
+        mock_converter_instance = mock.Mock(
+            spec=_evals_data_converters._OpenAIDataConverter
+        )
+        mock_converter_instance.convert.return_value = mock_converted_dataset
+        mock_get_converter.return_value = mock_converter_instance
+
+        input_dataset_for_loader = vertexai_genai_types.EvaluationDataset(
+            eval_dataset_df=pd.DataFrame(mock_openai_raw_data)
+        )
+        llm_metric = vertexai_genai_types.LLMMetric(
+            name="test_metric", prompt_template="Evaluate: {response}"
+        )
+
+        with mock.patch.object(_evals_utils, "EvalDatasetLoader") as mock_loader_class:
+            mock_loader_instance = mock_loader_class.return_value
+            mock_loader_instance.load.return_value = mock_openai_raw_data
+
+            with mock.patch.object(
+                _evals_metric_handlers.LLMMetricHandler, "process"
+            ) as mock_llm_process:
+                mock_llm_process.return_value = (
vertexai_genai_types.EvalCaseMetricResult( + metric_name="test_metric", score=0.75 + ) + ) + + result = _evals_common._execute_evaluation( + api_client=mock_api_client_fixture, + dataset=input_dataset_for_loader, + metrics=[llm_metric], + dataset_schema="OPENAI", + ) + + mock_loader_instance.load.assert_called_once_with( + input_dataset_for_loader.eval_dataset_df + ) + mock_get_converter.assert_called_with( + _evals_data_converters.EvalDatasetSchema.OPENAI + ) + mock_converter_instance.convert.assert_called_once_with(mock_openai_raw_data) + + assert isinstance(result, vertexai_genai_types.EvaluationResult) + assert len(result.summary_metrics) == 1 + summary_metric = result.summary_metrics[0] + assert summary_metric.metric_name == "test_metric" + assert summary_metric.mean_score == 0.75 + def test_execute_evaluation_custom_metric( self, mock_api_client_fixture, mock_eval_dependencies ): @@ -2901,3 +3358,76 @@ def test_execute_evaluation_multiple_datasets( assert summary_metric.mean_score == 1.0 assert mock_eval_dependencies["mock_evaluate_instances"].call_count == 2 + + def test_execute_evaluation_deduplicates_candidate_names( + self, mock_api_client_fixture, mock_eval_dependencies + ): + """Tests that duplicate candidate names are indexed.""" + dataset1 = vertexai_genai_types.EvaluationDataset( + eval_dataset_df=pd.DataFrame( + [{"prompt": "p1", "response": "r1", "reference": "ref1"}] + ), + candidate_name="gemini-pro", + ) + dataset2 = vertexai_genai_types.EvaluationDataset( + eval_dataset_df=pd.DataFrame( + [{"prompt": "p1", "response": "r2", "reference": "ref1"}] + ), + candidate_name="gemini-flash", + ) + dataset3 = vertexai_genai_types.EvaluationDataset( + eval_dataset_df=pd.DataFrame( + [{"prompt": "p1", "response": "r3", "reference": "ref1"}] + ), + candidate_name="gemini-pro", + ) + + mock_eval_dependencies[ + "mock_evaluate_instances" + ].return_value = vertexai_genai_types.EvaluateInstancesResponse( + exact_match_results=vertexai_genai_types.ExactMatchResults( + exact_match_metric_values=[ + vertexai_genai_types.ExactMatchMetricValue(score=1.0) + ] + ) + ) + + result = _evals_common._execute_evaluation( + api_client=mock_api_client_fixture, + dataset=[dataset1, dataset2, dataset3], + metrics=[vertexai_genai_types.Metric(name="exact_match")], + ) + + assert result.metadata.candidate_names == [ + "gemini-pro #1", + "gemini-flash", + "gemini-pro #2", + ] + + @mock.patch("vertexai._genai._evals_common.datetime") + def test_execute_evaluation_adds_creation_timestamp( + self, mock_datetime, mock_api_client_fixture, mock_eval_dependencies + ): + """Tests that creation_timestamp is added to the result metadata.""" + import datetime + + mock_now = datetime.datetime( + 2025, 6, 18, 12, 0, 0, tzinfo=datetime.timezone.utc + ) + mock_datetime.datetime.now.return_value = mock_now + + dataset = vertexai_genai_types.EvaluationDataset( + eval_dataset_df=pd.DataFrame( + [{"prompt": "p", "response": "r", "reference": "r"}] + ) + ) + metric = vertexai_genai_types.Metric(name="exact_match") + + result = _evals_common._execute_evaluation( + api_client=mock_api_client_fixture, + dataset=dataset, + metrics=[metric], + ) + + assert result.metadata is not None + assert result.metadata.creation_timestamp == mock_now diff --git a/tests/unit/vertexai/model_garden/test_model_garden.py b/tests/unit/vertexai/model_garden/test_model_garden.py index c16ba04f28..d2b334db16 100644 --- a/tests/unit/vertexai/model_garden/test_model_garden.py +++ b/tests/unit/vertexai/model_garden/test_model_garden.py @@ -181,6 +181,7 
@@ def get_publisher_model_mock(): multi_deploy_vertex=types.PublisherModel.CallToAction.DeployVertex( multi_deploy_vertex=[ types.PublisherModel.CallToAction.Deploy( + deploy_task_name="vLLM 32K context", container_spec=types.ModelContainerSpec( image_uri="us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/pytorch-vllm-serve:20241202_0916_RC00", command=["python", "main.py"], @@ -198,6 +199,7 @@ def get_publisher_model_mock(): ), ), types.PublisherModel.CallToAction.Deploy( + deploy_task_name="vLLM 128K context", container_spec=types.ModelContainerSpec( image_uri="us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/text-generation-inference-cu121.2-1.py310:latest", command=["python", "main.py"], @@ -1032,17 +1034,17 @@ def test_list_deploy_options_concise(self, get_publisher_model_mock): result = model.list_deploy_options(concise=True) expected_result = textwrap.dedent( """\ - [Option 1] - serving_container_image_uri="us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/pytorch-vllm-serve:20241202_0916_RC00", - machine_type="g2-standard-16", - accelerator_type="NVIDIA_L4", - accelerator_count=1, + [Option 1: vLLM 32K context] + serving_container_image_uri="us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/pytorch-vllm-serve:20241202_0916_RC00", + machine_type="g2-standard-16", + accelerator_type="NVIDIA_L4", + accelerator_count=1, - [Option 2] - serving_container_image_uri="us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/text-generation-inference-cu121.2-1.py310:latest", - machine_type="g2-standard-32", - accelerator_type="NVIDIA_L4", - accelerator_count=4,""" + [Option 2: vLLM 128K context] + serving_container_image_uri="us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/text-generation-inference-cu121.2-1.py310:latest", + machine_type="g2-standard-32", + accelerator_type="NVIDIA_L4", + accelerator_count=4,""" ) assert result == expected_result get_publisher_model_mock.assert_called_with( @@ -1058,16 +1060,16 @@ def test_list_deploy_options_concise(self, get_publisher_model_mock): expected_hf_result = textwrap.dedent( """\ [Option 1] - serving_container_image_uri="us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/pytorch-vllm-serve:20241202_0916_RC00", - machine_type="g2-standard-16", - accelerator_type="NVIDIA_L4", - accelerator_count=1, + serving_container_image_uri="us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/pytorch-vllm-serve:20241202_0916_RC00", + machine_type="g2-standard-16", + accelerator_type="NVIDIA_L4", + accelerator_count=1, [Option 2] - serving_container_image_uri="us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/text-generation-inference-cu121.2-1.py310:latest", - machine_type="g2-standard-32", - accelerator_type="NVIDIA_L4", - accelerator_count=4,""" + serving_container_image_uri="us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/text-generation-inference-cu121.2-1.py310:latest", + machine_type="g2-standard-32", + accelerator_type="NVIDIA_L4", + accelerator_count=4,""" ) assert hf_result == expected_hf_result get_publisher_model_mock.assert_called_with( diff --git a/vertexai/_genai/_agent_engines_utils.py b/vertexai/_genai/_agent_engines_utils.py new file mode 100644 index 0000000000..a6b8657311 --- /dev/null +++ b/vertexai/_genai/_agent_engines_utils.py @@ -0,0 +1,156 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""Utility functions for agent engines."""
+
+from typing import Any, Callable, Coroutine, Iterator, AsyncIterator
+from . import types
+
+
+def _wrap_query_operation(*, method_name: str) -> Callable[..., Any]:
+    """Wraps an Agent Engine method, creating a callable for `query` API.
+
+    This function creates a callable object that executes the specified
+    Agent Engine method using the `query` API. It handles the creation of
+    the API request and the processing of the API response.
+
+    Args:
+        method_name: The name of the Agent Engine method to call.
+
+    Returns:
+        A callable object that executes the method on the Agent Engine via
+        the `query` API.
+    """
+
+    def _method(self: types.AgentEngine, **kwargs):
+        if not self.api_client:
+            raise ValueError("api_client is not initialized.")
+        if not self.api_resource:
+            raise ValueError("api_resource is not initialized.")
+        response = self.api_client._query(
+            name=self.api_resource.name,
+            config={
+                "class_method": method_name,
+                "input": kwargs,
+                "include_all_fields": True,
+            },
+        )
+        return response.output
+
+    return _method
+
+
+def _wrap_async_query_operation(*, method_name: str) -> Callable[..., Coroutine]:
+    """Wraps an Agent Engine method, creating an async callable for `query` API.
+
+    This function creates a callable object that executes the specified
+    Agent Engine method asynchronously using the `query` API. It handles the
+    creation of the API request and the processing of the API response.
+
+    Args:
+        method_name: The name of the Agent Engine method to call.
+
+    Returns:
+        A callable object that executes the method on the Agent Engine via
+        the `query` API.
+    """
+
+    async def _method(self: types.AgentEngine, **kwargs):
+        if not self.api_async_client:
+            raise ValueError("api_async_client is not initialized.")
+        if not self.api_resource:
+            raise ValueError("api_resource is not initialized.")
+        response = await self.api_async_client._query(
+            name=self.api_resource.name,
+            config={
+                "class_method": method_name,
+                "input": kwargs,
+                "include_all_fields": True,
+            },
+        )
+        return response.output
+
+    return _method
+
+
+def _wrap_stream_query_operation(*, method_name: str) -> Callable[..., Iterator[Any]]:
+    """Wraps an Agent Engine method, creating a callable for `stream_query` API.
+
+    This function creates a callable object that executes the specified
+    Agent Engine method using the `stream_query` API. It handles the
+    creation of the API request and the processing of the API response.
+
+    Args:
+        method_name: The name of the Agent Engine method to call.
+
+    Returns:
+        A callable object that executes the method on the Agent Engine via
+        the `stream_query` API.
+    """
+
+    def _method(self: types.AgentEngine, **kwargs) -> Iterator[Any]:
+        if not self.api_client:
+            raise ValueError("api_client is not initialized.")
+        if not self.api_resource:
+            raise ValueError("api_resource is not initialized.")
+        for response in self.api_client._stream_query(
+            name=self.api_resource.name,
+            config={
+                "class_method": method_name,
+                "input": kwargs,
+                "include_all_fields": True,
+            },
+        ):
+            yield response
+
+    return _method
+
+
+def _wrap_async_stream_query_operation(
+    *, method_name: str
+) -> Callable[..., AsyncIterator]:
+    """Wraps an Agent Engine method, creating an async callable for `stream_query` API.
+
+    This function creates a callable object that executes the specified
+    Agent Engine method using the `stream_query` API. It handles the
+    creation of the API request and the processing of the API response.
+
+    Args:
+        method_name: The name of the Agent Engine method to call.
+
+    Returns:
+        A callable object that executes the method on the Agent Engine via
+        the `stream_query` API.
+    """
+
+    async def _method(self: types.AgentEngine, **kwargs):
+        if not self.api_client:
+            raise ValueError("api_client is not initialized.")
+        if not self.api_resource:
+            raise ValueError("api_resource is not initialized.")
+        # Note: streaming is delegated to the synchronous client here; each
+        # chunk is re-yielded from this async generator.
+        for response in self.api_client._stream_query(
+            name=self.api_resource.name,
+            config={
+                "class_method": method_name,
+                "input": kwargs,
+                "include_all_fields": True,
+            },
+        ):
+            yield response
+
+    return _method
diff --git a/vertexai/_genai/_evals_common.py b/vertexai/_genai/_evals_common.py
index c460c64950..6256dbab15 100644
--- a/vertexai/_genai/_evals_common.py
+++ b/vertexai/_genai/_evals_common.py
@@ -13,7 +13,9 @@
 # limitations under the License.
 #
 """Common utilities for evals."""
+import collections
 import concurrent.futures
+import datetime
 import json
 import logging
 import os
@@ -475,6 +477,17 @@ def _execute_inference(
     end_time = time.time()
     logger.info("Inference completed in %.2f seconds.", end_time - start_time)

+    candidate_name = None
+    if isinstance(model, str):
+        candidate_name = model
+    elif callable(model):
+        candidate_name = getattr(model, "__name__", None)
+
+    evaluation_dataset = types.EvaluationDataset(
+        eval_dataset_df=results_df,
+        candidate_name=candidate_name,
+    )
+
     if dest:
         file_name = "inference_results.jsonl"
         full_dest_path = dest
@@ -500,13 +513,14 @@ def _execute_inference(
                     file_type="jsonl",
                 )
                 logger.info("Results saved to GCS: %s", full_dest_path)
+                evaluation_dataset.gcs_source = types.GcsSource(uris=[full_dest_path])
             else:
                 results_df.to_json(full_dest_path, orient="records", lines=True)
                 logger.info("Results saved locally to: %s", full_dest_path)
         except Exception as e:  # pylint: disable=broad-exception-caught
             logger.error("Failed to save results to %s. Error: %s", full_dest_path, e)

-    return types.EvaluationDataset(eval_dataset_df=results_df)
+    return evaluation_dataset


 def _get_dataset_source(
@@ -534,7 +548,7 @@ def _get_dataset_source(

 def _resolve_dataset_inputs(
     dataset: list[types.EvaluationDataset],
-    dataset_schema: Optional[Literal["gemini", "flatten"]],
+    dataset_schema: Optional[Literal["GEMINI", "FLATTEN", "OPENAI"]],
     loader: _evals_utils.EvalDatasetLoader,
 ) -> tuple[types.EvaluationDataset, int]:
     """Loads and processes single or multiple datasets for evaluation.
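The hunk above derives a `candidate_name` for the inference result (the model ID when the model is a string, the callable's `__name__` otherwise), and the hunk below deduplicates those names before they are written into the run metadata. The two-pass counting scheme is small enough to show in isolation; this is a minimal sketch, independent of the SDK types:

.. code-block:: python

    import collections


    def dedupe_candidate_names(names: list[str]) -> list[str]:
        """Suffix ' #k' onto names that occur more than once; leave unique names alone."""
        totals = collections.Counter(names)
        seen = collections.defaultdict(int)
        deduped = []
        for name in names:
            if totals[name] > 1:
                seen[name] += 1
                deduped.append(f"{name} #{seen[name]}")
            else:
                deduped.append(name)
        return deduped


    assert dedupe_candidate_names(["gemini-pro", "gemini-flash", "gemini-pro"]) == [
        "gemini-pro #1",
        "gemini-flash",
        "gemini-pro #2",
    ]

The first pass counts every name; the second pass numbers only names whose total count exceeds one, so a run whose candidates are all distinct keeps its names untouched.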
@@ -657,7 +671,7 @@ def _execute_evaluation( api_client: Any, dataset: Union[types.EvaluationDataset, list[types.EvaluationDataset]], metrics: list[types.Metric], - dataset_schema: Optional[Literal["gemini", "flatten"]] = None, + dataset_schema: Optional[Literal["GEMINI", "FLATTEN", "OPENAI"]] = None, dest: Optional[str] = None, ) -> types.EvaluationResult: """Evaluates a dataset using the provided metrics. @@ -690,6 +704,19 @@ def _execute_evaluation( f"Unsupported dataset type: {type(dataset)}. Must be an" " EvaluationDataset or a list of EvaluationDataset." ) + original_candidate_names = [ + ds.candidate_name or f"candidate_{i+1}" for i, ds in enumerate(dataset_list) + ] + name_counts = collections.Counter(original_candidate_names) + deduped_candidate_names = [] + current_name_counts = collections.defaultdict(int) + + for name in original_candidate_names: + if name_counts[name] > 1: + current_name_counts[name] += 1 + deduped_candidate_names.append(f"{name} #{current_name_counts[name]}") + else: + deduped_candidate_names.append(name) loader = _evals_utils.EvalDatasetLoader(api_client=api_client) processed_eval_dataset, num_response_candidates = _resolve_dataset_inputs( @@ -714,6 +741,17 @@ def _execute_evaluation( logger.info("Evaluation took: %f seconds", t2 - t1) evaluation_result.evaluation_dataset = dataset_list + + if not evaluation_result.metadata: + evaluation_result.metadata = types.EvaluationRunMetadata() + + evaluation_result.metadata.creation_timestamp = datetime.datetime.now( + datetime.timezone.utc + ) + + if deduped_candidate_names: + evaluation_result.metadata.candidate_names = deduped_candidate_names + logger.info("Evaluation run completed.") if dest: diff --git a/vertexai/_genai/_evals_data_converters.py b/vertexai/_genai/_evals_data_converters.py index cb41f3c461..cb814e032f 100644 --- a/vertexai/_genai/_evals_data_converters.py +++ b/vertexai/_genai/_evals_data_converters.py @@ -27,23 +27,13 @@ logger = logging.getLogger("vertexai_genai._evals_data_converters") -_PLACEHOLDER_RESPONSE_TEXT = "Error: Missing response for this candidate" - - -def _create_placeholder_response_candidate( - text: str = _PLACEHOLDER_RESPONSE_TEXT, -) -> types.ResponseCandidate: - """Creates a ResponseCandidate with placeholder text.""" - return types.ResponseCandidate( - response=genai_types.Content(parts=[genai_types.Part(text=text)]) - ) - class EvalDatasetSchema(_common.CaseInSensitiveEnum): """Represents the schema of an evaluation dataset.""" GEMINI = "gemini" FLATTEN = "flatten" + OPENAI = "openai" UNKNOWN = "unknown" @@ -56,6 +46,18 @@ def convert(self, raw_data: Any) -> types.EvaluationDataset: raise NotImplementedError() +_PLACEHOLDER_RESPONSE_TEXT = "Error: Missing response for this candidate" + + +def _create_placeholder_response_candidate( + text: str = _PLACEHOLDER_RESPONSE_TEXT, +) -> types.ResponseCandidate: + """Creates a ResponseCandidate with placeholder text.""" + return types.ResponseCandidate( + response=genai_types.Content(parts=[genai_types.Part(text=text)]) + ) + + class _GeminiEvalDataConverter(_EvalDataConverter): """Converter for dataset in the Gemini format.""" @@ -96,15 +98,17 @@ def _parse_request( ) if conversation_history: last_message = conversation_history.pop() - if last_message.content and last_message.content.role == "user": + last_message_role = ( + last_message.content.role if last_message.content else "user" + ) + if last_message_role in ["user", None]: prompt = last_message.content - elif last_message.content and last_message.content.role == "model": - 
# If the last message is from the model, then it's the reference. + elif last_message_role == "model": reference = types.ResponseCandidate(response=last_message.content) - if conversation_history: # Ensure there's a previous message + if conversation_history: second_to_last_message = conversation_history.pop() prompt = second_to_last_message.content - else: # If only one model message, prompt is invalid. + else: prompt = genai_types.Content() return prompt, system_instruction, conversation_history, reference @@ -152,13 +156,11 @@ def convert(self, raw_data: list[dict[str, Any]]) -> types.EvaluationDataset: ) ) ) - else: # Handle cases where there are no candidates. + else: responses.append(_create_placeholder_response_candidate()) except Exception: - # Fallback for dicts that don't match the schema, treat as empty. responses.append(_create_placeholder_response_candidate()) else: - # For any other type, treat as an empty/invalid response. responses.append(_create_placeholder_response_candidate()) eval_case = types.EvalCase( @@ -291,66 +293,171 @@ def convert(self, raw_data: list[dict[str, Any]]) -> types.EvaluationDataset: return types.EvaluationDataset(eval_cases=eval_cases) +class _OpenAIDataConverter(_EvalDataConverter): + """Converter for dataset in OpenAI's Chat Completion format.""" + + def _parse_messages( + self, messages: list[dict[str, Any]] + ) -> tuple[ + Optional[genai_types.Content], + list[types.Message], + Optional[genai_types.Content], + Optional[types.ResponseCandidate], + ]: + """Parses a list of messages into instruction, history, prompt, and reference.""" + system_instruction = None + prompt = None + reference = None + conversation_history = [] + + if messages and messages[0].get("role") in ["system", "developer"]: + system_instruction = genai_types.Content( + parts=[genai_types.Part(text=messages[0].get("content"))] + ) + messages = messages[1:] + + for turn_id, msg in enumerate(messages): + role = msg.get("role", "user") + content = msg.get("content", "") + conversation_history.append( + types.Message( + turn_id=str(turn_id), + content=genai_types.Content( + parts=[genai_types.Part(text=content)], role=role + ), + author=role, + ) + ) + + if conversation_history: + last_message = conversation_history.pop() + if last_message.content and last_message.content.role == "user": + prompt = last_message.content + elif last_message.content and last_message.content.role == "assistant": + reference = types.ResponseCandidate(response=last_message.content) + if conversation_history: + second_to_last_message = conversation_history.pop() + prompt = second_to_last_message.content + + return system_instruction, conversation_history, prompt, reference + + @override + def convert(self, raw_data: list[dict[str, Any]]) -> types.EvaluationDataset: + """Converts a list of OpenAI ChatCompletion data into an EvaluationDataset.""" + eval_cases = [] + for i, item in enumerate(raw_data): + eval_case_id = f"openai_eval_case_{i}" + + if "request" not in item or "response" not in item: + logger.warning( + f"Skipping case {i} due to missing 'request' or 'response' key." + ) + continue + + request_data = item.get("request", {}) + response_data = item.get("response", {}) + + messages = request_data.get("messages", []) + choices = response_data.get("choices", []) + + ( + system_instruction, + conversation_history, + prompt, + reference, + ) = self._parse_messages(messages) + + if prompt is None and reference is None: + logger.warning( + "Could not determine a user prompt or reference for case %s." 
+ " Skipping.", + i, + ) + continue + + responses = [] + if ( + choices + and isinstance(choices, list) + and isinstance(choices[0], dict) + and choices[0].get("message") + ): + response_content = choices[0]["message"].get("content", "") + responses.append( + types.ResponseCandidate( + response=genai_types.Content( + parts=[genai_types.Part(text=response_content)] + ) + ) + ) + else: + responses.append(_create_placeholder_response_candidate()) + + other_fields = { + k: v for k, v in item.items() if k not in ["request", "response"] + } + + eval_case = types.EvalCase( + eval_case_id=eval_case_id, + prompt=prompt, + responses=responses, + reference=reference, + system_instruction=system_instruction, + conversation_history=conversation_history, + **other_fields, + ) + eval_cases.append(eval_case) + + return types.EvaluationDataset(eval_cases=eval_cases) + + def auto_detect_dataset_schema( raw_dataset: list[dict[str, Any]], ) -> EvalDatasetSchema: """Detects the schema of a raw dataset.""" if not raw_dataset: - logger.debug("Empty dataset, returning UNKNOWN schema.") return EvalDatasetSchema.UNKNOWN first_item = raw_dataset[0] - if not isinstance(first_item, dict): - logger.warning( - "First item in dataset is not a dictionary. Cannot determine schema." - ) - return EvalDatasetSchema.UNKNOWN - keys = set(first_item.keys()) - request_field = first_item.get("request") - if isinstance(request_field, dict) and isinstance( - request_field.get("contents"), list - ): - try: - _GeminiEvalDataConverter().convert([first_item]) - logger.debug( - "Detected GEMINI schema based on 'request.contents' presence and" - " successful conversion." - ) - return EvalDatasetSchema.GEMINI - except (ValueError, KeyError, AttributeError, TypeError) as e: - logger.debug( - "First item looked like Gemini schema (due to 'request.contents') but" - " conversion failed (error: %s). Will try other schemas.", - e, - ) + if "request" in keys and "response" in keys: + request_content = first_item.get("request", {}) + if isinstance(request_content, dict) and "contents" in request_content: + contents_list = request_content.get("contents") + if ( + contents_list + and isinstance(contents_list, list) + and isinstance(contents_list[0], dict) + ): + if "parts" in contents_list[0]: + return EvalDatasetSchema.GEMINI + + if "request" in keys and "response" in keys: + request_content = first_item.get("request", {}) + if isinstance(request_content, dict) and "messages" in request_content: + messages_list = request_content.get("messages") + if ( + messages_list + and isinstance(messages_list, list) + and isinstance(messages_list[0], dict) + ): + if "role" in messages_list[0] and "content" in messages_list[0]: + return EvalDatasetSchema.OPENAI - # Check for flatten schema if Gemini check failed or wasn't applicable if {"prompt", "response"}.issubset(keys) or { "response", "reference", }.issubset(keys): - try: - _FlattenEvalDataConverter().convert([first_item]) - logger.debug( - "Detected FLATTEN schema based on key presence and successful" - " conversion." - ) - return EvalDatasetSchema.FLATTEN - except (ValueError, KeyError, AttributeError, TypeError) as e: - logger.debug( - "Flatten schema key check passed, but conversion failed (error: %s).", - e, - ) - - logger.debug("Could not confidently determine schema. 
Returning UNKNOWN.") - return EvalDatasetSchema.UNKNOWN + return EvalDatasetSchema.FLATTEN + else: + return EvalDatasetSchema.UNKNOWN -_SCHEMA_TO_CONVERTER = { +_CONVERTER_REGISTRY = { EvalDatasetSchema.GEMINI: _GeminiEvalDataConverter, EvalDatasetSchema.FLATTEN: _FlattenEvalDataConverter, + EvalDatasetSchema.OPENAI: _OpenAIDataConverter, } @@ -358,8 +465,8 @@ def get_dataset_converter( dataset_schema: EvalDatasetSchema, ) -> _EvalDataConverter: """Returns the appropriate dataset converter for the given schema.""" - if dataset_schema in _SCHEMA_TO_CONVERTER: - return _SCHEMA_TO_CONVERTER[dataset_schema]() + if dataset_schema in _CONVERTER_REGISTRY: + return _CONVERTER_REGISTRY[dataset_schema]() else: raise ValueError(f"Unsupported dataset schema: {dataset_schema}") @@ -398,10 +505,15 @@ def _validate_case_consistency( base_prompt_text_preview = _get_first_part_text(base_case.prompt)[:50] current_prompt_text_preview = _get_first_part_text(current_case.prompt)[:50] logger.warning( - f"Prompt mismatch for case index {case_idx} between base dataset (0)" - f" and dataset {dataset_idx}. Using prompt from base. Base prompt" - f" preview: '{base_prompt_text_preview}...', Dataset" - f" {dataset_idx} prompt preview: '{current_prompt_text_preview}...'" + "Prompt mismatch for case index %d between base dataset (0)" + " and dataset %d. Using prompt from base. Base prompt" + " preview: '%s...', Dataset" + " %d prompt preview: '%s...'", + case_idx, + dataset_idx, + base_prompt_text_preview, + dataset_idx, + current_prompt_text_preview, ) base_ref_text = _get_text_from_reference(base_case.reference) @@ -409,16 +521,22 @@ def _validate_case_consistency( if bool(base_case.reference) != bool(current_case.reference): logger.warning( - f"Reference presence mismatch for case index {case_idx} between base" - f" dataset (0) and dataset {dataset_idx}. Using reference (or lack" - " thereof) from base." + "Reference presence mismatch for case index %d between base" + " dataset (0) and dataset %d. Using reference (or lack" + " thereof) from base.", + case_idx, + dataset_idx, ) elif base_ref_text != current_ref_text: logger.warning( - f"Reference text mismatch for case index {case_idx} between base" - f" dataset (0) and dataset {dataset_idx}. Using reference from base." - f" Base ref: '{str(base_ref_text)[:50]}...', Current ref:" - f" '{str(current_ref_text)[:50]}...'" + "Reference text mismatch for case index %d between base" + " dataset (0) and dataset %d. Using reference from base. " + " Base ref: '%s...', Current ref:" + " '%s...'", + case_idx, + dataset_idx, + str(base_ref_text)[:50], + str(current_ref_text)[:50], ) diff --git a/vertexai/_genai/agent_engines.py b/vertexai/_genai/agent_engines.py new file mode 100644 index 0000000000..4f86678ded --- /dev/null +++ b/vertexai/_genai/agent_engines.py @@ -0,0 +1,1661 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Code generated by the Google Gen AI SDK generator DO NOT EDIT. 
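Every generated converter in this module follows the same mechanical pattern: read a snake_case path from the source object with `get_value_by_path`, and, when the value is present, write the camelCase path on the target dict with `set_value_by_path`. A toy rendition of that pattern with simplified stand-in helpers (the real helpers live in `google.genai._common` and handle more cases; these are for illustration only):

.. code-block:: python

    from typing import Any, Optional


    def getv(obj: Any, path: list[str]) -> Optional[Any]:
        # Walk nested dict keys; return None as soon as a segment is missing.
        for key in path:
            if not isinstance(obj, dict) or key not in obj:
                return None
            obj = obj[key]
        return obj


    def setv(obj: dict, path: list[str], value: Any) -> None:
        # Create intermediate dicts as needed, then assign the leaf value.
        for key in path[:-1]:
            obj = obj.setdefault(key, {})
        obj[path[-1]] = value


    # The per-field conversion idiom used by the _to_vertex functions below:
    from_object = {"display_name": "my-agent"}
    to_object: dict[str, Any] = {}
    if getv(from_object, ["display_name"]) is not None:
        setv(to_object, ["displayName"], getv(from_object, ["display_name"]))
    assert to_object == {"displayName": "my-agent"}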
+ +import json +import logging +import time +from typing import Any, Iterator, Optional, Sequence, Union +from urllib.parse import urlencode + +from google.genai import _api_module +from google.genai import _common +from google.genai import types as genai_types +from google.genai._common import get_value_by_path as getv +from google.genai._common import set_value_by_path as setv +from google.genai.pagers import Pager + +from . import _agent_engines_utils +from . import types + + +logger = logging.getLogger("vertexai_genai.agentengines") + + +def _ReasoningEngineSpec_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["agent_framework"]) is not None: + setv( + to_object, + ["agentFramework"], + getv(from_object, ["agent_framework"]), + ) + + if getv(from_object, ["class_methods"]) is not None: + setv(to_object, ["classMethods"], getv(from_object, ["class_methods"])) + + if getv(from_object, ["deployment_spec"]) is not None: + setv( + to_object, + ["deploymentSpec"], + getv(from_object, ["deployment_spec"]), + ) + + if getv(from_object, ["package_spec"]) is not None: + setv(to_object, ["packageSpec"], getv(from_object, ["package_spec"])) + + return to_object + + +def _CreateAgentEngineConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["display_name"]) is not None: + setv(parent_object, ["displayName"], getv(from_object, ["display_name"])) + + if getv(from_object, ["description"]) is not None: + setv(parent_object, ["description"], getv(from_object, ["description"])) + + if getv(from_object, ["spec"]) is not None: + setv( + parent_object, + ["spec"], + _ReasoningEngineSpec_to_vertex(getv(from_object, ["spec"]), to_object), + ) + + return to_object + + +def _CreateAgentEngineRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["config"]) is not None: + setv( + to_object, + ["config"], + _CreateAgentEngineConfig_to_vertex( + getv(from_object, ["config"]), to_object + ), + ) + + return to_object + + +def _DeleteAgentEngineRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["force"]) is not None: + setv(to_object, ["force"], getv(from_object, ["force"])) + + if getv(from_object, ["config"]) is not None: + setv(to_object, ["config"], getv(from_object, ["config"])) + + return to_object + + +def _GetAgentEngineRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["config"]) is not None: + setv(to_object, ["config"], getv(from_object, ["config"])) + + return to_object + + +def _ListAgentEngineConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + 
if getv(from_object, ["page_size"]) is not None: + setv( + parent_object, + ["_query", "pageSize"], + getv(from_object, ["page_size"]), + ) + + if getv(from_object, ["page_token"]) is not None: + setv( + parent_object, + ["_query", "pageToken"], + getv(from_object, ["page_token"]), + ) + + if getv(from_object, ["filter"]) is not None: + setv(parent_object, ["_query", "filter"], getv(from_object, ["filter"])) + + return to_object + + +def _ListAgentEngineRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["config"]) is not None: + setv( + to_object, + ["config"], + _ListAgentEngineConfig_to_vertex(getv(from_object, ["config"]), to_object), + ) + + return to_object + + +def _GetAgentEngineOperationParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["operation_name"]) is not None: + setv( + to_object, + ["_url", "operationName"], + getv(from_object, ["operation_name"]), + ) + + if getv(from_object, ["config"]) is not None: + setv(to_object, ["config"], getv(from_object, ["config"])) + + return to_object + + +def _QueryAgentEngineConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["class_method"]) is not None: + setv(parent_object, ["classMethod"], getv(from_object, ["class_method"])) + + if getv(from_object, ["input"]) is not None: + setv(parent_object, ["input"], getv(from_object, ["input"])) + + if getv(from_object, ["include_all_fields"]) is not None: + setv( + to_object, + ["includeAllFields"], + getv(from_object, ["include_all_fields"]), + ) + + return to_object + + +def _QueryAgentEngineRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["config"]) is not None: + setv( + to_object, + ["config"], + _QueryAgentEngineConfig_to_vertex(getv(from_object, ["config"]), to_object), + ) + + return to_object + + +def _UpdateAgentEngineConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["display_name"]) is not None: + setv(parent_object, ["displayName"], getv(from_object, ["display_name"])) + + if getv(from_object, ["description"]) is not None: + setv(parent_object, ["description"], getv(from_object, ["description"])) + + if getv(from_object, ["spec"]) is not None: + setv( + parent_object, + ["spec"], + _ReasoningEngineSpec_to_vertex(getv(from_object, ["spec"]), to_object), + ) + + if getv(from_object, ["update_mask"]) is not None: + setv( + parent_object, + ["_query", "updateMask"], + getv(from_object, ["update_mask"]), + ) + + return to_object + + +def _UpdateAgentEngineRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if 
getv(from_object, ["config"]) is not None: + setv( + to_object, + ["config"], + _UpdateAgentEngineConfig_to_vertex( + getv(from_object, ["config"]), to_object + ), + ) + + return to_object + + +def _ReasoningEngine_from_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["createTime"]) is not None: + setv(to_object, ["create_time"], getv(from_object, ["createTime"])) + + if getv(from_object, ["description"]) is not None: + setv(to_object, ["description"], getv(from_object, ["description"])) + + if getv(from_object, ["displayName"]) is not None: + setv(to_object, ["display_name"], getv(from_object, ["displayName"])) + + if getv(from_object, ["etag"]) is not None: + setv(to_object, ["etag"], getv(from_object, ["etag"])) + + if getv(from_object, ["name"]) is not None: + setv(to_object, ["name"], getv(from_object, ["name"])) + + if getv(from_object, ["spec"]) is not None: + setv(to_object, ["spec"], getv(from_object, ["spec"])) + + if getv(from_object, ["updateTime"]) is not None: + setv(to_object, ["update_time"], getv(from_object, ["updateTime"])) + + return to_object + + +def _AgentEngineOperation_from_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["name"], getv(from_object, ["name"])) + + if getv(from_object, ["metadata"]) is not None: + setv(to_object, ["metadata"], getv(from_object, ["metadata"])) + + if getv(from_object, ["done"]) is not None: + setv(to_object, ["done"], getv(from_object, ["done"])) + + if getv(from_object, ["error"]) is not None: + setv(to_object, ["error"], getv(from_object, ["error"])) + + if getv(from_object, ["response"]) is not None: + setv( + to_object, + ["response"], + _ReasoningEngine_from_vertex(getv(from_object, ["response"]), to_object), + ) + + return to_object + + +def _DeleteAgentEngineOperation_from_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["name"], getv(from_object, ["name"])) + + if getv(from_object, ["metadata"]) is not None: + setv(to_object, ["metadata"], getv(from_object, ["metadata"])) + + if getv(from_object, ["done"]) is not None: + setv(to_object, ["done"], getv(from_object, ["done"])) + + if getv(from_object, ["error"]) is not None: + setv(to_object, ["error"], getv(from_object, ["error"])) + + return to_object + + +def _ListReasoningEnginesResponse_from_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["nextPageToken"]) is not None: + setv(to_object, ["next_page_token"], getv(from_object, ["nextPageToken"])) + + if getv(from_object, ["reasoningEngines"]) is not None: + setv( + to_object, + ["reasoning_engines"], + [ + _ReasoningEngine_from_vertex(item, to_object) + for item in getv(from_object, ["reasoningEngines"]) + ], + ) + + return to_object + + +def _QueryReasoningEngineResponse_from_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["output"]) is not None: + setv(to_object, ["output"], 
getv(from_object, ["output"])) + + return to_object + + +class AgentEngines(_api_module.BaseModule): + def _create( + self, *, config: Optional[types.CreateAgentEngineConfigOrDict] = None + ) -> types.AgentEngineOperation: + """Creates a new Agent Engine.""" + + parameter_model = types._CreateAgentEngineRequestParameters( + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError("This method is only supported in the Vertex AI client.") + else: + request_dict = _CreateAgentEngineRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "reasoningEngines".format_map(request_url_dict) + else: + path = "reasoningEngines" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[genai_types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = "" if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _AgentEngineOperation_from_vertex(response_dict) + + return_value = types.AgentEngineOperation._from_response( + response=response_dict, kwargs=parameter_model.model_dump() + ) + self._api_client._verify_response(return_value) + + return return_value + + def delete( + self, + *, + name: str, + force: Optional[bool] = None, + config: Optional[types.DeleteAgentEngineConfigOrDict] = None, + ) -> types.DeleteAgentEngineOperation: + """Delete an Agent Engine resource. + + Args: + name (str): Required. The name of the Agent Engine to be deleted. + Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}` + or `reasoningEngines/{resource_id}`. + force (bool): Optional. If set to True, child resources will also be + deleted. Otherwise, the request will fail with FAILED_PRECONDITION + error when the Agent Engine has undeleted child resources. + Defaults to False. + config (DeleteAgentEngineConfig): Optional. Additional + configurations for deleting the Agent Engine. + """ + + parameter_model = types._DeleteAgentEngineRequestParameters( + name=name, + force=force, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError("This method is only supported in the Vertex AI client.") + else: + request_dict = _DeleteAgentEngineRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[genai_types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("delete", path, request_dict, http_options) + + response_dict = "" if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _DeleteAgentEngineOperation_from_vertex(response_dict) + + return_value = types.DeleteAgentEngineOperation._from_response( + response=response_dict, kwargs=parameter_model.model_dump() + ) + self._api_client._verify_response(return_value) + + return return_value + + def _get( + self, + *, + name: str, + config: Optional[types.GetAgentEngineConfigOrDict] = None, + ) -> types.ReasoningEngine: + """Get an Agent Engine instance.""" + + parameter_model = types._GetAgentEngineRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError("This method is only supported in the Vertex AI client.") + else: + request_dict = _GetAgentEngineRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[genai_types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = "" if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _ReasoningEngine_from_vertex(response_dict) + + return_value = types.ReasoningEngine._from_response( + response=response_dict, kwargs=parameter_model.model_dump() + ) + self._api_client._verify_response(return_value) + + return return_value + + def _list( + self, *, config: Optional[types.ListAgentEngineConfigOrDict] = None + ) -> types.ListReasoningEnginesResponse: + """Lists Agent Engines.""" + + parameter_model = types._ListAgentEngineRequestParameters( + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError("This method is only supported in the Vertex AI client.") + else: + request_dict = _ListAgentEngineRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "reasoningEngines".format_map(request_url_dict) + else: + path = "reasoningEngines" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[genai_types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = "" if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _ListReasoningEnginesResponse_from_vertex(response_dict) + + return_value = types.ListReasoningEnginesResponse._from_response( + response=response_dict, kwargs=parameter_model.model_dump() + ) + self._api_client._verify_response(return_value) + + return return_value + + def _get_operation( + self, + *, + operation_name: str, + config: Optional[types.GetAgentEngineOperationConfigOrDict] = None, + ) -> types.AgentEngineOperation: + parameter_model = types._GetAgentEngineOperationParameters( + operation_name=operation_name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError("This method is only supported in the Vertex AI client.") + else: + request_dict = _GetAgentEngineOperationParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{operationName}".format_map(request_url_dict) + else: + path = "{operationName}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[genai_types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = "" if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _AgentEngineOperation_from_vertex(response_dict) + + return_value = types.AgentEngineOperation._from_response( + response=response_dict, kwargs=parameter_model.model_dump() + ) + self._api_client._verify_response(return_value) + + return return_value + + def _query( + self, + *, + name: str, + config: Optional[types.QueryAgentEngineConfigOrDict] = None, + ) -> types.QueryReasoningEngineResponse: + """Query an Agent Engine.""" + + parameter_model = types._QueryAgentEngineRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError("This method is only supported in the Vertex AI client.") + else: + request_dict = _QueryAgentEngineRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}:query".format_map(request_url_dict) + else: + path = "{name}:query" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[genai_types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = "" if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _QueryReasoningEngineResponse_from_vertex(response_dict) + + return_value = types.QueryReasoningEngineResponse._from_response( + response=response_dict, kwargs=parameter_model.model_dump() + ) + self._api_client._verify_response(return_value) + + return return_value + + def _update( + self, + *, + name: str, + config: Optional[types.UpdateAgentEngineConfigOrDict] = None, + ) -> types.AgentEngineOperation: + """Updates an Agent Engine.""" + + parameter_model = types._UpdateAgentEngineRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError("This method is only supported in the Vertex AI client.") + else: + request_dict = _UpdateAgentEngineRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[genai_types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("patch", path, request_dict, http_options) + + response_dict = "" if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _AgentEngineOperation_from_vertex(response_dict) + + return_value = types.AgentEngineOperation._from_response( + response=response_dict, kwargs=parameter_model.model_dump() + ) + self._api_client._verify_response(return_value) + + return return_value + + def _list_pager( + self, *, config: Optional[types.ListAgentEngineConfigOrDict] = None + ) -> Pager[types.ReasoningEngine]: + return Pager( + "reasoning_engines", + self._list, + self._list(config=config), + config, + ) + + def get( + self, + *, + name: str, + config: Optional[types.GetAgentEngineConfigOrDict] = None, + ) -> types.AgentEngine: + """Gets an agent engine. + + Args: + name (str): Required. A fully-qualified resource name or ID such as + "projects/123/locations/us-central1/reasoningEngines/456" or a + shortened name such as "reasoningEngines/456". + """ + agent = types.AgentEngine( + api_client=self, + api_async_client=AsyncAgentEngines(api_client_=self._api_client), + api_resource=self._get(name=name, config=config), + ) + self._register_api_methods(agent=agent) + return agent + + def create( + self, + *, + agent_engine: Any = None, + config: types.AgentEngineConfigOrDict, + ) -> Union[types.AgentEngine, types.AgentEngineOperation]: + """Creates an agent engine. 
+ + The Agent Engine will be an instance of the `agent_engine` that + was passed in, running remotely on Vertex AI. + + Sample ``src_dir`` contents (e.g. ``./user_src_dir``): + + .. code-block:: python + + user_src_dir/ + |-- main.py + |-- requirements.txt + |-- user_code/ + | |-- utils.py + | |-- ... + |-- ... + + To build an Agent Engine with the above files, run: + + .. code-block:: python + + client = vertexai.Client( + project="your-project", + location="us-central1", + ) + remote_agent = client.agent_engines.create( + agent_engine=local_agent, + config=dict( + requirements=[ + # I.e. the PyPI dependencies listed in requirements.txt + "google-cloud-aiplatform[agent_engines,adk]", + ... + ], + extra_packages=[ + "./user_src_dir/main.py", # a single file + "./user_src_dir/user_code", # a directory + ... + ], + ), + ) + + Args: + agent_engine (Any): Optional. The Agent Engine to be created. If not + specified, this will correspond to a lightweight instance that + cannot be queried (but can be updated to future instances that can + be queried). + config (AgentEngineConfig): Required. The configurations to use for + creating the Agent Engine. + + Returns: + Union[types.AgentEngine, types.AgentEngineOperation]: + It returns the Agent Engine if `config.return_agent` is True, + otherwise it returns the operation for creating the Agent + Engine. + + Raises: + ValueError: If the `project` was not set using `client.Client`. + ValueError: If the `location` was not set using `client.Client`. + ValueError: If `config.staging_bucket` was not set when + `agent_engine` + is specified. + ValueError: If `config.staging_bucket` does not start with "gs://". + ValueError: If `config.extra_packages` is specified but + `agent_engine` + is None. + ValueError: If `config.requirements` is specified but `agent_engine` + is + None. + ValueError: If `config.env_vars` has a dictionary entry that does + not + correspond to an environment variable value or a SecretRef. + TypeError: If `config.env_vars` is not a dictionary. + FileNotFoundError: If `config.extra_packages` includes a file or + directory that does not exist. + IOError: If ``config.requirements` is a string that corresponds to a + nonexistent file. + """ + if isinstance(config, dict): + config = types.AgentEngineConfig.model_validate(config) + elif not isinstance(config, types.AgentEngineConfig): + raise TypeError( + "config must be a dict or AgentEngineConfig, but got" + f" {type(config)}." 
+ ) + api_config = self._create_config( + mode="create", + agent_engine=agent_engine, + staging_bucket=config.staging_bucket, + requirements=config.requirements, + display_name=config.display_name, + description=config.description, + gcs_dir_name=config.gcs_dir_name, + extra_packages=config.extra_packages, + env_vars=config.env_vars, + ) + operation = self._create(config=api_config) + if config.return_agent: + return self._await_operation(operation_name=operation.name) + return operation + + def _create_config( + self, + *, + mode: str, + agent_engine: Any = None, + staging_bucket: Optional[str] = None, + requirements: Optional[Union[str, Sequence[str]]] = None, + display_name: Optional[str] = None, + description: Optional[str] = None, + gcs_dir_name: Optional[str] = None, + extra_packages: Optional[Sequence[str]] = None, + env_vars: Optional[dict[str, Union[str, Any]]] = None, + ): + import sys + from vertexai.agent_engines import _agent_engines + from vertexai.agent_engines import _utils + + config = {} + update_masks = [] + if mode not in ["create", "update"]: + raise ValueError(f"Unsupported mode: {mode}") + if agent_engine is None: + if requirements is not None: + raise ValueError("requirements must be None if agent_engine is None.") + if extra_packages is not None: + raise ValueError("extra_packages must be None if agent_engine is None.") + if display_name is not None: + update_masks.append("display_name") + config["display_name"] = display_name + if description is not None: + update_masks.append("description") + config["description"] = description + if agent_engine is not None: + sys_version = f"{sys.version_info.major}.{sys.version_info.minor}" + gcs_dir_name = gcs_dir_name or _agent_engines._DEFAULT_GCS_DIR_NAME + agent_engine = _agent_engines._validate_agent_engine_or_raise(agent_engine) + _agent_engines._validate_staging_bucket_or_raise(staging_bucket) + requirements = _agent_engines._validate_requirements_or_raise( + agent_engine=agent_engine, + requirements=requirements, + ) + extra_packages = _agent_engines._validate_extra_packages_or_raise( + extra_packages + ) + # Prepares the Agent Engine for creation/update in Vertex AI. This + # involves packaging and uploading the artifacts for agent_engine, + # requirements and extra_packages to `staging_bucket/gcs_dir_name`. + _agent_engines._prepare( + agent_engine=agent_engine, + requirements=requirements, + project=self._api_client.project, + location=self._api_client.location, + staging_bucket=staging_bucket, + gcs_dir_name=gcs_dir_name, + extra_packages=extra_packages, + ) + # Update the package spec. 
+ update_masks.append("spec.package_spec.pickle_object_gcs_uri") + package_spec = { + "python_version": sys_version, + "pickle_object_gcs_uri": "{}/{}/{}".format( + staging_bucket, + gcs_dir_name, + _agent_engines._BLOB_FILENAME, + ), + } + if extra_packages: + update_masks.append("spec.package_spec.dependency_files_gcs_uri") + package_spec["dependency_files_gcs_uri"] = "{}/{}/{}".format( + staging_bucket, + gcs_dir_name, + _agent_engines._EXTRA_PACKAGES_FILE, + ) + if requirements: + update_masks.append("spec.package_spec.requirements_gcs_uri") + package_spec["requirements_gcs_uri"] = "{}/{}/{}".format( + staging_bucket, + gcs_dir_name, + _agent_engines._REQUIREMENTS_FILE, + ) + agent_engine_spec = {"package_spec": package_spec} + if env_vars is not None: + ( + deployment_spec, + deployment_update_masks, + ) = self._generate_deployment_spec_or_raise(env_vars=env_vars) + update_masks.extend(deployment_update_masks) + agent_engine_spec["deployment_spec"] = deployment_spec + class_methods = _agent_engines._generate_class_methods_spec_or_raise( + agent_engine=agent_engine, + operations=_agent_engines._get_registered_operations(agent_engine), + ) + agent_engine_spec["class_methods"] = [ + _utils.to_dict(class_method) for class_method in class_methods + ] + update_masks.append("spec.class_methods") + agent_engine_spec["agent_framework"] = _agent_engines._get_agent_framework( + agent_engine + ) + update_masks.append("spec.agent_framework") + config["spec"] = agent_engine_spec + if update_masks and mode == "update": + config["update_mask"] = ",".join(update_masks) + return config + + def _generate_deployment_spec_or_raise( + self, + *, + env_vars: Optional[dict[str, Union[str, Any]]] = None, + ): + deployment_spec = {} + update_masks = [] + if env_vars: + deployment_spec["env"] = [] + deployment_spec["secret_env"] = [] + if isinstance(env_vars, dict): + self._update_deployment_spec_with_env_vars_dict_or_raise( + deployment_spec=deployment_spec, + env_vars=env_vars, + ) + else: + raise TypeError(f"env_vars must be a dict, but got {type(env_vars)}.") + if deployment_spec.get("env"): + update_masks.append("spec.deployment_spec.env") + if deployment_spec.get("secret_env"): + update_masks.append("spec.deployment_spec.secret_env") + return deployment_spec, update_masks + + def _update_deployment_spec_with_env_vars_dict_or_raise( + self, + *, + deployment_spec: dict[str, Any], + env_vars: dict[str, Any], + ) -> None: + for key, value in env_vars.items(): + if isinstance(value, dict): + if "secret_env" not in deployment_spec: + deployment_spec["secret_env"] = [] + deployment_spec["secret_env"].append({"name": key, "secret_ref": value}) + elif isinstance(value, str): + if "env" not in deployment_spec: + deployment_spec["env"] = [] + deployment_spec["env"].append({"name": key, "value": value}) + else: + raise TypeError( + f"Unknown value type in env_vars for {key}. " + f"Must be a str or SecretRef: {value}" + ) + + def _await_operation( + self, + *, + operation_name: str, + poll_interval_seconds: int = 10, + ) -> types.AgentEngine: + """Waits for the operation for creating an agent engine to complete. + + Args: + operation_name (str): Required. The name of the operation for + creating the Agent Engine. + poll_interval_seconds (int): The number of seconds to wait between + each poll. + + Returns: + AgentEngine: The Agent Engine that was created. 
+ """ + operation = self._get_operation(operation_name=operation_name) + while not operation.done: + time.sleep(poll_interval_seconds) + operation = self._get_operation(operation_name=operation.name) + + agent = types.AgentEngine( + api_client=self, + api_async_client=AsyncAgentEngines(api_client_=self._api_client), + api_resource=operation.response, + ) + return self._register_api_methods(agent=agent) + + def _register_api_methods(self, *, agent: types.AgentEngine) -> types.AgentEngine: + """Registers the API methods for the agent engine.""" + from vertexai.agent_engines import _agent_engines + + try: + _agent_engines._register_api_methods_or_raise( + agent, + wrap_operation_fn={ + "": _agent_engines_utils._wrap_query_operation, + "async": _agent_engines_utils._wrap_async_query_operation, + "stream": _agent_engines_utils._wrap_stream_query_operation, + "async_stream": _agent_engines_utils._wrap_async_stream_query_operation, + }, + ) + except Exception as e: + logger.warning( + _agent_engines._FAILED_TO_REGISTER_API_METHODS_WARNING_TEMPLATE, + e, + ) + return agent + + def list( + self, *, config: Optional[types.ListAgentEngineConfigOrDict] = None + ) -> Iterator[types.AgentEngine]: + """List all instances of Agent Engine matching the filter. + + Example Usage: + + .. code-block:: python + import vertexai + + client = vertexai.Client(project="my_project", + location="us-central1") + for agent in client.agent_engines.list( + config={"filter": "'display_name="My Custom Agent"'}, + ): + print(agent.api_resource.name) + + Args: + config (ListAgentEngineConfig): Optional. The config (e.g. filter) + for the agents to be listed. + + Returns: + Iterable[AgentEngine]: An iterable of Agent Engines matching the + filter. + """ + + for reasoning_engine in self._list_pager(config=config): + yield types.AgentEngine( + api_client=self, + api_async_client=AsyncAgentEngines(api_client_=self._api_client), + api_resource=reasoning_engine, + ) + + def update( + self, + *, + name: str, + agent_engine: Any = None, + config: types.AgentEngineConfigOrDict, + ) -> types.AgentEngine: + """Updates an existing Agent Engine. + + This method updates the configuration of an existing Agent Engine + running + remotely, which is identified by its name. + + Args: + name (str): Required. A fully-qualified resource name or ID such as + "projects/123/locations/us-central1/reasoningEngines/456" or a + shortened name such as "reasoningEngines/456". + agent_engine (Any): Optional. The instance to be used as the updated + Agent Engine. If it is not specified, the existing instance will + be used. + config (AgentEngineConfig): Optional. The configurations to use for + updating the Agent Engine. + + Returns: + Union[types.AgentEngine, types.UpdateAgentEngineOperation]: + It returns the updated Agent Engine if `config.return_agent` is + True, otherwise it returns the operation for updating the Agent + Engine. + + Raises: + ValueError: If the `project` was not set using `client.Client`. + ValueError: If the `location` was not set using `client.Client`. + ValueError: If `config.staging_bucket` was not set when `agent_engine` + is specified. + ValueError: If `config.staging_bucket` does not start with "gs://". + ValueError: If `config.extra_packages` is specified but `agent_engine` + is None. + ValueError: If `config.requirements` is specified but `agent_engine` + is + None. + ValueError: If `config.env_vars` has a dictionary entry that does not + correspond to an environment variable value or a SecretRef. 
+ TypeError: If `config.env_vars` is not a dictionary. + FileNotFoundError: If `config.extra_packages` includes a file or + directory that does not exist. + IOError: If `config.requirements` is a string that corresponds to a + nonexistent file. + """ + if isinstance(config, dict): + config = types.AgentEngineConfig.model_validate(config) + elif not isinstance(config, types.AgentEngineConfig): + raise TypeError( + "config must be a dict or AgentEngineConfig, but got" + f" {type(config)}." + ) + api_config = self._create_config( + mode="update", + agent_engine=agent_engine, + staging_bucket=config.staging_bucket, + requirements=config.requirements, + display_name=config.display_name, + description=config.description, + gcs_dir_name=config.gcs_dir_name, + extra_packages=config.extra_packages, + env_vars=config.env_vars, + ) + operation = self._update(name=name, config=api_config) + if config.return_agent: + return self._await_operation(operation_name=operation.name) + return operation + + def _stream_query( + self, + *, + name: str, + config: Optional[types.QueryAgentEngineConfigOrDict] = None, + ) -> Iterator[Any]: + """Streams the response of the agent engine.""" + parameter_model = types._QueryAgentEngineRequestParameters( + name=name, + config=config, + ) + request_dict = _QueryAgentEngineRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}:streamQuery?alt=sse".format_map(request_url_dict) + else: + path = "{name}:streamQuery?alt=sse" + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + http_options: Optional[genai_types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + for response in self._api_client.request_streamed( + "post", path, request_dict, http_options + ): + yield response + + +class AsyncAgentEngines(_api_module.BaseModule): + async def _create( + self, *, config: Optional[types.CreateAgentEngineConfigOrDict] = None + ) -> types.AgentEngineOperation: + """Creates a new Agent Engine.""" + + parameter_model = types._CreateAgentEngineRequestParameters( + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError("This method is only supported in the Vertex AI client.") + else: + request_dict = _CreateAgentEngineRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "reasoningEngines".format_map(request_url_dict) + else: + path = "reasoningEngines" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[genai_types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + response_dict = "" if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _AgentEngineOperation_from_vertex(response_dict) + + return_value = types.AgentEngineOperation._from_response( + response=response_dict, kwargs=parameter_model.model_dump() + ) + self._api_client._verify_response(return_value) + + return return_value + + async def delete( + self, + *, + name: str, + force: Optional[bool] = None, + config: Optional[types.DeleteAgentEngineConfigOrDict] = None, + ) -> types.DeleteAgentEngineOperation: + """Delete an Agent Engine resource. + + Args: + name (str): Required. The name of the Agent Engine to be deleted. + Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}` + or `reasoningEngines/{resource_id}`. + force (bool): Optional. If set to True, child resources will also be + deleted. Otherwise, the request will fail with FAILED_PRECONDITION + error when the Agent Engine has undeleted child resources. + Defaults to False. + config (DeleteAgentEngineConfig): Optional. Additional + configurations for deleting the Agent Engine. + """ + + parameter_model = types._DeleteAgentEngineRequestParameters( + name=name, + force=force, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError("This method is only supported in the Vertex AI client.") + else: + request_dict = _DeleteAgentEngineRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[genai_types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "delete", path, request_dict, http_options + ) + + response_dict = "" if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _DeleteAgentEngineOperation_from_vertex(response_dict) + + return_value = types.DeleteAgentEngineOperation._from_response( + response=response_dict, kwargs=parameter_model.model_dump() + ) + self._api_client._verify_response(return_value) + + return return_value + + async def _get( + self, + *, + name: str, + config: Optional[types.GetAgentEngineConfigOrDict] = None, + ) -> types.ReasoningEngine: + """Get an Agent Engine instance.""" + + parameter_model = types._GetAgentEngineRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError("This method is only supported in the Vertex AI client.") + else: + request_dict = _GetAgentEngineRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[genai_types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = "" if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _ReasoningEngine_from_vertex(response_dict) + + return_value = types.ReasoningEngine._from_response( + response=response_dict, kwargs=parameter_model.model_dump() + ) + self._api_client._verify_response(return_value) + + return return_value + + async def _list( + self, *, config: Optional[types.ListAgentEngineConfigOrDict] = None + ) -> types.ListReasoningEnginesResponse: + """Lists Agent Engines.""" + + parameter_model = types._ListAgentEngineRequestParameters( + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError("This method is only supported in the Vertex AI client.") + else: + request_dict = _ListAgentEngineRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "reasoningEngines".format_map(request_url_dict) + else: + path = "reasoningEngines" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[genai_types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = "" if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _ListReasoningEnginesResponse_from_vertex(response_dict) + + return_value = types.ListReasoningEnginesResponse._from_response( + response=response_dict, kwargs=parameter_model.model_dump() + ) + self._api_client._verify_response(return_value) + + return return_value + + async def _get_operation( + self, + *, + operation_name: str, + config: Optional[types.GetAgentEngineOperationConfigOrDict] = None, + ) -> types.AgentEngineOperation: + parameter_model = types._GetAgentEngineOperationParameters( + operation_name=operation_name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError("This method is only supported in the Vertex AI client.") + else: + request_dict = _GetAgentEngineOperationParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{operationName}".format_map(request_url_dict) + else: + path = "{operationName}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[genai_types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = "" if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _AgentEngineOperation_from_vertex(response_dict) + + return_value = types.AgentEngineOperation._from_response( + response=response_dict, kwargs=parameter_model.model_dump() + ) + self._api_client._verify_response(return_value) + + return return_value + + async def _query( + self, + *, + name: str, + config: Optional[types.QueryAgentEngineConfigOrDict] = None, + ) -> types.QueryReasoningEngineResponse: + """Query an Agent Engine.""" + + parameter_model = types._QueryAgentEngineRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError("This method is only supported in the Vertex AI client.") + else: + request_dict = _QueryAgentEngineRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}:query".format_map(request_url_dict) + else: + path = "{name}:query" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[genai_types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + response_dict = "" if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _QueryReasoningEngineResponse_from_vertex(response_dict) + + return_value = types.QueryReasoningEngineResponse._from_response( + response=response_dict, kwargs=parameter_model.model_dump() + ) + self._api_client._verify_response(return_value) + + return return_value + + async def _update( + self, + *, + name: str, + config: Optional[types.UpdateAgentEngineConfigOrDict] = None, + ) -> types.AgentEngineOperation: + """Updates an Agent Engine.""" + + parameter_model = types._UpdateAgentEngineRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError("This method is only supported in the Vertex AI client.") + else: + request_dict = _UpdateAgentEngineRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[genai_types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "patch", path, request_dict, http_options + ) + + response_dict = "" if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _AgentEngineOperation_from_vertex(response_dict) + + return_value = types.AgentEngineOperation._from_response( + response=response_dict, kwargs=parameter_model.model_dump() + ) + self._api_client._verify_response(return_value) + + return return_value diff --git a/vertexai/_genai/client.py b/vertexai/_genai/client.py index 3d06a34064..ee89116000 100644 --- a/vertexai/_genai/client.py +++ b/vertexai/_genai/client.py @@ -18,7 +18,7 @@ import google.auth from google.genai import _common -from google.genai import client +from google.genai import client as genai_client from google.genai import types @@ -26,9 +26,10 @@ class AsyncClient: """Async Client for the GenAI SDK.""" - def __init__(self, api_client: client.Client): + def __init__(self, api_client: genai_client.Client): self._api_client = api_client self._evals = None + self._agent_engines = None @property @_common.experimental_warning( @@ -51,6 +52,28 @@ def evals(self): # TODO(b/424176979): add async prompt optimizer here. + @property + @_common.experimental_warning( + "The Vertex SDK GenAI agent engines module is experimental, " + "and may change in future versions." 
+ ) + def agent_engines(self): + if self._agent_engines is None: + try: + # We need to lazy load the agent_engines module to handle the + # possibility of ImportError when dependencies are not installed. + self._agent_engines = importlib.import_module( + ".agent_engines", + __package__, + ) + except ImportError as e: + raise ImportError( + "The 'agent_engines' module requires 'additional packages'. " + "Please install them using pip install " + "google-cloud-aiplatform[agent_engines]" + ) from e + return self._agent_engines.AsyncAgentEngines(self._api_client) + class Client: """Client for the GenAI SDK. @@ -64,7 +87,7 @@ def __init__( credentials: Optional[google.auth.credentials.Credentials] = None, project: Optional[str] = None, location: Optional[str] = None, - debug_config: Optional[client.DebugConfig] = None, + debug_config: Optional[genai_client.DebugConfig] = None, http_options: Optional[Union[types.HttpOptions, types.HttpOptionsDict]] = None, ): """Initializes the client. @@ -89,11 +112,11 @@ def __init__( for the client. """ - self._debug_config = debug_config or client.DebugConfig() + self._debug_config = debug_config or genai_client.DebugConfig() if isinstance(http_options, dict): http_options = types.HttpOptions(**http_options) - self._api_client = client.Client._get_api_client( + self._api_client = genai_client.Client._get_api_client( vertexai=True, credentials=credentials, project=project, @@ -104,6 +127,7 @@ def __init__( self._aio = AsyncClient(self._api_client) self._evals = None self._prompt_optimizer = None + self._agent_engines = None @property @_common.experimental_warning( @@ -131,3 +155,52 @@ def evals(self): ) def aio(self): return self._aio + + # This is only used for replay tests + @staticmethod + def _get_api_client( + api_key: Optional[str] = None, + credentials: Optional[google.auth.credentials.Credentials] = None, + project: Optional[str] = None, + location: Optional[str] = None, + debug_config: Optional[genai_client.DebugConfig] = None, + http_options: Optional[genai_client.HttpOptions] = None, + ) -> Optional[genai_client.BaseApiClient]: + if debug_config and debug_config.client_mode in [ + "record", + "replay", + "auto", + ]: + return genai_client.ReplayApiClient( + mode=debug_config.client_mode, # type: ignore[arg-type] + replay_id=debug_config.replay_id, # type: ignore[arg-type] + replays_directory=debug_config.replays_directory, + vertexai=True, # type: ignore[arg-type] + api_key=api_key, + credentials=credentials, + project=project, + location=location, + http_options=http_options, + ) + + @property + @_common.experimental_warning( + "The Vertex SDK GenAI agent engines module is experimental, " + "and may change in future versions." + ) + def agent_engines(self): + if self._agent_engines is None: + try: + # We need to lazy load the agent_engines module to handle the + # possibility of ImportError when dependencies are not installed. + self._agent_engines = importlib.import_module( + ".agent_engines", + __package__, + ) + except ImportError as e: + raise ImportError( + "The 'agent_engines' module requires 'additional packages'. " + "Please install them using pip install " + "google-cloud-aiplatform[agent_engines]" + ) from e + return self._agent_engines.AgentEngines(self._api_client) diff --git a/vertexai/_genai/evals.py b/vertexai/_genai/evals.py index 506ea8200b..f504002ff8 100644 --- a/vertexai/_genai/evals.py +++ b/vertexai/_genai/evals.py @@ -15,6 +15,7 @@ # Code generated by the Google Gen AI SDK generator DO NOT EDIT. 
+import json import logging from typing import Any, Callable, Optional, Union from urllib.parse import urlencode @@ -22,7 +23,6 @@ from google.genai import _api_module from google.genai import _common from google.genai import types as genai_types -from google.genai._api_client import BaseApiClient from google.genai._common import get_value_by_path as getv from google.genai._common import set_value_by_path as setv import pandas as pd @@ -35,7 +35,6 @@ def _BleuInstance_to_vertex( - api_client: BaseApiClient, from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: @@ -50,7 +49,6 @@ def _BleuInstance_to_vertex( def _BleuSpec_to_vertex( - api_client: BaseApiClient, from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: @@ -66,7 +64,6 @@ def _BleuSpec_to_vertex( def _BleuInput_to_vertex( - api_client: BaseApiClient, from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: @@ -76,7 +73,7 @@ def _BleuInput_to_vertex( to_object, ["instances"], [ - _BleuInstance_to_vertex(api_client, item, to_object) + _BleuInstance_to_vertex(item, to_object) for item in getv(from_object, ["instances"]) ], ) @@ -85,16 +82,13 @@ def _BleuInput_to_vertex( setv( to_object, ["metricSpec"], - _BleuSpec_to_vertex( - api_client, getv(from_object, ["metric_spec"]), to_object - ), + _BleuSpec_to_vertex(getv(from_object, ["metric_spec"]), to_object), ) return to_object def _ExactMatchInstance_to_vertex( - api_client: BaseApiClient, from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: @@ -109,7 +103,6 @@ def _ExactMatchInstance_to_vertex( def _ExactMatchSpec_to_vertex( - api_client: BaseApiClient, from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: @@ -119,7 +112,6 @@ def _ExactMatchSpec_to_vertex( def _ExactMatchInput_to_vertex( - api_client: BaseApiClient, from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: @@ -129,7 +121,7 @@ def _ExactMatchInput_to_vertex( to_object, ["instances"], [ - _ExactMatchInstance_to_vertex(api_client, item, to_object) + _ExactMatchInstance_to_vertex(item, to_object) for item in getv(from_object, ["instances"]) ], ) @@ -138,16 +130,13 @@ def _ExactMatchInput_to_vertex( setv( to_object, ["metricSpec"], - _ExactMatchSpec_to_vertex( - api_client, getv(from_object, ["metric_spec"]), to_object - ), + _ExactMatchSpec_to_vertex(getv(from_object, ["metric_spec"]), to_object), ) return to_object def _RougeInstance_to_vertex( - api_client: BaseApiClient, from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: @@ -162,7 +151,6 @@ def _RougeInstance_to_vertex( def _RougeSpec_to_vertex( - api_client: BaseApiClient, from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: @@ -184,7 +172,6 @@ def _RougeSpec_to_vertex( def _RougeInput_to_vertex( - api_client: BaseApiClient, from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: @@ -194,7 +181,7 @@ def _RougeInput_to_vertex( to_object, ["instances"], [ - _RougeInstance_to_vertex(api_client, item, to_object) + _RougeInstance_to_vertex(item, to_object) for item in getv(from_object, ["instances"]) ], ) @@ -203,16 +190,13 @@ def _RougeInput_to_vertex( 
setv( to_object, ["metricSpec"], - _RougeSpec_to_vertex( - api_client, getv(from_object, ["metric_spec"]), to_object - ), + _RougeSpec_to_vertex(getv(from_object, ["metric_spec"]), to_object), ) return to_object def _PointwiseMetricInstance_to_vertex( - api_client: BaseApiClient, from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: @@ -224,7 +208,6 @@ def _PointwiseMetricInstance_to_vertex( def _PointwiseMetricSpec_to_vertex( - api_client: BaseApiClient, from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: @@ -254,7 +237,6 @@ def _PointwiseMetricSpec_to_vertex( def _PointwiseMetricInput_to_vertex( - api_client: BaseApiClient, from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: @@ -264,7 +246,7 @@ def _PointwiseMetricInput_to_vertex( to_object, ["instance"], _PointwiseMetricInstance_to_vertex( - api_client, getv(from_object, ["instance"]), to_object + getv(from_object, ["instance"]), to_object ), ) @@ -273,7 +255,7 @@ def _PointwiseMetricInput_to_vertex( to_object, ["metricSpec"], _PointwiseMetricSpec_to_vertex( - api_client, getv(from_object, ["metric_spec"]), to_object + getv(from_object, ["metric_spec"]), to_object ), ) @@ -281,7 +263,6 @@ def _PointwiseMetricInput_to_vertex( def _PairwiseMetricInstance_to_vertex( - api_client: BaseApiClient, from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: @@ -294,7 +275,6 @@ def _PairwiseMetricInstance_to_vertex( def _PairwiseMetricSpec_to_vertex( - api_client: BaseApiClient, from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: @@ -338,7 +318,6 @@ def _PairwiseMetricSpec_to_vertex( def _PairwiseMetricInput_to_vertex( - api_client: BaseApiClient, from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: @@ -348,7 +327,7 @@ def _PairwiseMetricInput_to_vertex( to_object, ["instance"], _PairwiseMetricInstance_to_vertex( - api_client, getv(from_object, ["instance"]), to_object + getv(from_object, ["instance"]), to_object ), ) @@ -357,7 +336,7 @@ def _PairwiseMetricInput_to_vertex( to_object, ["metricSpec"], _PairwiseMetricSpec_to_vertex( - api_client, getv(from_object, ["metric_spec"]), to_object + getv(from_object, ["metric_spec"]), to_object ), ) @@ -365,7 +344,6 @@ def _PairwiseMetricInput_to_vertex( def _ToolCallValidInstance_to_vertex( - api_client: BaseApiClient, from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: @@ -380,7 +358,6 @@ def _ToolCallValidInstance_to_vertex( def _ToolCallValidSpec_to_vertex( - api_client: BaseApiClient, from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: @@ -390,7 +367,6 @@ def _ToolCallValidSpec_to_vertex( def _ToolCallValidInput_to_vertex( - api_client: BaseApiClient, from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: @@ -400,7 +376,7 @@ def _ToolCallValidInput_to_vertex( to_object, ["instances"], [ - _ToolCallValidInstance_to_vertex(api_client, item, to_object) + _ToolCallValidInstance_to_vertex(item, to_object) for item in getv(from_object, ["instances"]) ], ) @@ -409,16 +385,13 @@ def _ToolCallValidInput_to_vertex( setv( to_object, ["metricSpec"], - _ToolCallValidSpec_to_vertex( - 
api_client, getv(from_object, ["metric_spec"]), to_object - ), + _ToolCallValidSpec_to_vertex(getv(from_object, ["metric_spec"]), to_object), ) return to_object def _ToolNameMatchInstance_to_vertex( - api_client: BaseApiClient, from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: @@ -433,7 +406,6 @@ def _ToolNameMatchInstance_to_vertex( def _ToolNameMatchSpec_to_vertex( - api_client: BaseApiClient, from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: @@ -443,7 +415,6 @@ def _ToolNameMatchSpec_to_vertex( def _ToolNameMatchInput_to_vertex( - api_client: BaseApiClient, from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: @@ -453,7 +424,7 @@ def _ToolNameMatchInput_to_vertex( to_object, ["instances"], [ - _ToolNameMatchInstance_to_vertex(api_client, item, to_object) + _ToolNameMatchInstance_to_vertex(item, to_object) for item in getv(from_object, ["instances"]) ], ) @@ -462,16 +433,13 @@ def _ToolNameMatchInput_to_vertex( setv( to_object, ["metricSpec"], - _ToolNameMatchSpec_to_vertex( - api_client, getv(from_object, ["metric_spec"]), to_object - ), + _ToolNameMatchSpec_to_vertex(getv(from_object, ["metric_spec"]), to_object), ) return to_object def _ToolParameterKeyMatchInstance_to_vertex( - api_client: BaseApiClient, from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: @@ -486,7 +454,6 @@ def _ToolParameterKeyMatchInstance_to_vertex( def _ToolParameterKeyMatchSpec_to_vertex( - api_client: BaseApiClient, from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: @@ -496,7 +463,6 @@ def _ToolParameterKeyMatchSpec_to_vertex( def _ToolParameterKeyMatchInput_to_vertex( - api_client: BaseApiClient, from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: @@ -506,7 +472,7 @@ def _ToolParameterKeyMatchInput_to_vertex( to_object, ["instances"], [ - _ToolParameterKeyMatchInstance_to_vertex(api_client, item, to_object) + _ToolParameterKeyMatchInstance_to_vertex(item, to_object) for item in getv(from_object, ["instances"]) ], ) @@ -516,7 +482,7 @@ def _ToolParameterKeyMatchInput_to_vertex( to_object, ["metricSpec"], _ToolParameterKeyMatchSpec_to_vertex( - api_client, getv(from_object, ["metric_spec"]), to_object + getv(from_object, ["metric_spec"]), to_object ), ) @@ -524,7 +490,6 @@ def _ToolParameterKeyMatchInput_to_vertex( def _ToolParameterKVMatchInstance_to_vertex( - api_client: BaseApiClient, from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: @@ -539,7 +504,6 @@ def _ToolParameterKVMatchInstance_to_vertex( def _ToolParameterKVMatchSpec_to_vertex( - api_client: BaseApiClient, from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: @@ -555,7 +519,6 @@ def _ToolParameterKVMatchSpec_to_vertex( def _ToolParameterKVMatchInput_to_vertex( - api_client: BaseApiClient, from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: @@ -565,7 +528,7 @@ def _ToolParameterKVMatchInput_to_vertex( to_object, ["instances"], [ - _ToolParameterKVMatchInstance_to_vertex(api_client, item, to_object) + _ToolParameterKVMatchInstance_to_vertex(item, to_object) for item in getv(from_object, ["instances"]) ], ) @@ 
-575,7 +538,7 @@ def _ToolParameterKVMatchInput_to_vertex( to_object, ["metricSpec"], _ToolParameterKVMatchSpec_to_vertex( - api_client, getv(from_object, ["metric_spec"]), to_object + getv(from_object, ["metric_spec"]), to_object ), ) @@ -583,7 +546,6 @@ def _ToolParameterKVMatchInput_to_vertex( def _AutoraterConfig_to_vertex( - api_client: BaseApiClient, from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: @@ -605,7 +567,6 @@ def _AutoraterConfig_to_vertex( def _EvaluateInstancesRequestParameters_to_vertex( - api_client: BaseApiClient, from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: @@ -614,9 +575,7 @@ def _EvaluateInstancesRequestParameters_to_vertex( setv( to_object, ["bleuInput"], - _BleuInput_to_vertex( - api_client, getv(from_object, ["bleu_input"]), to_object - ), + _BleuInput_to_vertex(getv(from_object, ["bleu_input"]), to_object), ) if getv(from_object, ["exact_match_input"]) is not None: @@ -624,7 +583,7 @@ def _EvaluateInstancesRequestParameters_to_vertex( to_object, ["exactMatchInput"], _ExactMatchInput_to_vertex( - api_client, getv(from_object, ["exact_match_input"]), to_object + getv(from_object, ["exact_match_input"]), to_object ), ) @@ -632,9 +591,7 @@ def _EvaluateInstancesRequestParameters_to_vertex( setv( to_object, ["rougeInput"], - _RougeInput_to_vertex( - api_client, getv(from_object, ["rouge_input"]), to_object - ), + _RougeInput_to_vertex(getv(from_object, ["rouge_input"]), to_object), ) if getv(from_object, ["pointwise_metric_input"]) is not None: @@ -642,9 +599,7 @@ def _EvaluateInstancesRequestParameters_to_vertex( to_object, ["pointwiseMetricInput"], _PointwiseMetricInput_to_vertex( - api_client, - getv(from_object, ["pointwise_metric_input"]), - to_object, + getv(from_object, ["pointwise_metric_input"]), to_object ), ) @@ -653,9 +608,7 @@ def _EvaluateInstancesRequestParameters_to_vertex( to_object, ["pairwiseMetricInput"], _PairwiseMetricInput_to_vertex( - api_client, - getv(from_object, ["pairwise_metric_input"]), - to_object, + getv(from_object, ["pairwise_metric_input"]), to_object ), ) @@ -664,9 +617,7 @@ def _EvaluateInstancesRequestParameters_to_vertex( to_object, ["toolCallValidInput"], _ToolCallValidInput_to_vertex( - api_client, - getv(from_object, ["tool_call_valid_input"]), - to_object, + getv(from_object, ["tool_call_valid_input"]), to_object ), ) @@ -675,9 +626,7 @@ def _EvaluateInstancesRequestParameters_to_vertex( to_object, ["toolNameMatchInput"], _ToolNameMatchInput_to_vertex( - api_client, - getv(from_object, ["tool_name_match_input"]), - to_object, + getv(from_object, ["tool_name_match_input"]), to_object ), ) @@ -686,9 +635,7 @@ def _EvaluateInstancesRequestParameters_to_vertex( to_object, ["toolParameterKeyMatchInput"], _ToolParameterKeyMatchInput_to_vertex( - api_client, - getv(from_object, ["tool_parameter_key_match_input"]), - to_object, + getv(from_object, ["tool_parameter_key_match_input"]), to_object ), ) @@ -697,9 +644,7 @@ def _EvaluateInstancesRequestParameters_to_vertex( to_object, ["toolParameterKvMatchInput"], _ToolParameterKVMatchInput_to_vertex( - api_client, - getv(from_object, ["tool_parameter_kv_match_input"]), - to_object, + getv(from_object, ["tool_parameter_kv_match_input"]), to_object ), ) @@ -708,7 +653,7 @@ def _EvaluateInstancesRequestParameters_to_vertex( to_object, ["autoraterConfig"], _AutoraterConfig_to_vertex( - api_client, getv(from_object, ["autorater_config"]), to_object + 
getv(from_object, ["autorater_config"]), to_object ), ) @@ -719,7 +664,6 @@ def _EvaluateInstancesRequestParameters_to_vertex( def _EvaluationDataset_to_vertex( - api_client: BaseApiClient, from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: @@ -743,7 +687,6 @@ def _EvaluationDataset_to_vertex( def _Metric_to_vertex( - api_client: BaseApiClient, from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: @@ -792,7 +735,6 @@ def _Metric_to_vertex( def _OutputConfig_to_vertex( - api_client: BaseApiClient, from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: @@ -808,7 +750,6 @@ def _OutputConfig_to_vertex( def _EvaluateDatasetRequestParameters_to_vertex( - api_client: BaseApiClient, from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: @@ -817,9 +758,7 @@ def _EvaluateDatasetRequestParameters_to_vertex( setv( to_object, ["dataset"], - _EvaluationDataset_to_vertex( - api_client, getv(from_object, ["dataset"]), to_object - ), + _EvaluationDataset_to_vertex(getv(from_object, ["dataset"]), to_object), ) if getv(from_object, ["metrics"]) is not None: @@ -827,7 +766,7 @@ def _EvaluateDatasetRequestParameters_to_vertex( to_object, ["metrics"], [ - _Metric_to_vertex(api_client, item, to_object) + _Metric_to_vertex(item, to_object) for item in getv(from_object, ["metrics"]) ], ) @@ -836,9 +775,7 @@ def _EvaluateDatasetRequestParameters_to_vertex( setv( to_object, ["outputConfig"], - _OutputConfig_to_vertex( - api_client, getv(from_object, ["output_config"]), to_object - ), + _OutputConfig_to_vertex(getv(from_object, ["output_config"]), to_object), ) if getv(from_object, ["autorater_config"]) is not None: @@ -846,7 +783,7 @@ def _EvaluateDatasetRequestParameters_to_vertex( to_object, ["autoraterConfig"], _AutoraterConfig_to_vertex( - api_client, getv(from_object, ["autorater_config"]), to_object + getv(from_object, ["autorater_config"]), to_object ), ) @@ -857,7 +794,6 @@ def _EvaluateDatasetRequestParameters_to_vertex( def _EvaluateInstancesResponse_from_vertex( - api_client: BaseApiClient, from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: @@ -984,7 +920,6 @@ def _EvaluateInstancesResponse_from_vertex( def _EvaluationDataset_from_vertex( - api_client: BaseApiClient, from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: @@ -1008,7 +943,6 @@ def _EvaluationDataset_from_vertex( def _EvaluateDatasetOperation_from_vertex( - api_client: BaseApiClient, from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: @@ -1029,9 +963,7 @@ def _EvaluateDatasetOperation_from_vertex( setv( to_object, ["response"], - _EvaluationDataset_from_vertex( - api_client, getv(from_object, ["response"]), to_object - ), + _EvaluationDataset_from_vertex(getv(from_object, ["response"]), to_object), ) return to_object @@ -1078,7 +1010,7 @@ def _evaluate_instances( raise ValueError("This method is only supported in the Vertex AI client.") else: request_dict = _EvaluateInstancesRequestParameters_to_vertex( - self._api_client, parameter_model + parameter_model ) request_url_dict = request_dict.get("_url") if request_url_dict: @@ -1102,19 +1034,18 @@ def _evaluate_instances( request_dict = _common.convert_to_dict(request_dict) 
request_dict = _common.encode_unserializable_types(request_dict) - response_dict = self._api_client.request( - "post", path, request_dict, http_options - ) + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = "" if not response.body else json.loads(response.body) if self._api_client.vertexai: - response_dict = _EvaluateInstancesResponse_from_vertex( - self._api_client, response_dict - ) + response_dict = _EvaluateInstancesResponse_from_vertex(response_dict) return_value = types.EvaluateInstancesResponse._from_response( response=response_dict, kwargs=parameter_model.model_dump() ) self._api_client._verify_response(return_value) + return return_value def batch_eval( @@ -1140,9 +1071,7 @@ def batch_eval( if not self._api_client.vertexai: raise ValueError("This method is only supported in the Vertex AI client.") else: - request_dict = _EvaluateDatasetRequestParameters_to_vertex( - self._api_client, parameter_model - ) + request_dict = _EvaluateDatasetRequestParameters_to_vertex(parameter_model) request_url_dict = request_dict.get("_url") if request_url_dict: path = ":evaluateDataset".format_map(request_url_dict) @@ -1165,19 +1094,18 @@ def batch_eval( request_dict = _common.convert_to_dict(request_dict) request_dict = _common.encode_unserializable_types(request_dict) - response_dict = self._api_client.request( - "post", path, request_dict, http_options - ) + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = "" if not response.body else json.loads(response.body) if self._api_client.vertexai: - response_dict = _EvaluateDatasetOperation_from_vertex( - self._api_client, response_dict - ) + response_dict = _EvaluateDatasetOperation_from_vertex(response_dict) return_value = types.EvaluateDatasetOperation._from_response( response=response_dict, kwargs=parameter_model.model_dump() ) self._api_client._verify_response(return_value) + return return_value def run(self) -> types.EvaluateInstancesResponse: @@ -1299,7 +1227,7 @@ async def _evaluate_instances( raise ValueError("This method is only supported in the Vertex AI client.") else: request_dict = _EvaluateInstancesRequestParameters_to_vertex( - self._api_client, parameter_model + parameter_model ) request_url_dict = request_dict.get("_url") if request_url_dict: @@ -1323,19 +1251,20 @@ async def _evaluate_instances( request_dict = _common.convert_to_dict(request_dict) request_dict = _common.encode_unserializable_types(request_dict) - response_dict = await self._api_client.async_request( + response = await self._api_client.async_request( "post", path, request_dict, http_options ) + response_dict = "" if not response.body else json.loads(response.body) + if self._api_client.vertexai: - response_dict = _EvaluateInstancesResponse_from_vertex( - self._api_client, response_dict - ) + response_dict = _EvaluateInstancesResponse_from_vertex(response_dict) return_value = types.EvaluateInstancesResponse._from_response( response=response_dict, kwargs=parameter_model.model_dump() ) self._api_client._verify_response(return_value) + return return_value async def batch_eval( @@ -1361,9 +1290,7 @@ async def batch_eval( if not self._api_client.vertexai: raise ValueError("This method is only supported in the Vertex AI client.") else: - request_dict = _EvaluateDatasetRequestParameters_to_vertex( - self._api_client, parameter_model - ) + request_dict = _EvaluateDatasetRequestParameters_to_vertex(parameter_model) request_url_dict = request_dict.get("_url") if request_url_dict: 
path = ":evaluateDataset".format_map(request_url_dict) @@ -1386,17 +1313,18 @@ async def batch_eval( request_dict = _common.convert_to_dict(request_dict) request_dict = _common.encode_unserializable_types(request_dict) - response_dict = await self._api_client.async_request( + response = await self._api_client.async_request( "post", path, request_dict, http_options ) + response_dict = "" if not response.body else json.loads(response.body) + if self._api_client.vertexai: - response_dict = _EvaluateDatasetOperation_from_vertex( - self._api_client, response_dict - ) + response_dict = _EvaluateDatasetOperation_from_vertex(response_dict) return_value = types.EvaluateDatasetOperation._from_response( response=response_dict, kwargs=parameter_model.model_dump() ) self._api_client._verify_response(return_value) + return return_value diff --git a/vertexai/_genai/prompt_optimizer.py b/vertexai/_genai/prompt_optimizer.py index 6065260d9d..97bc701650 100644 --- a/vertexai/_genai/prompt_optimizer.py +++ b/vertexai/_genai/prompt_optimizer.py @@ -16,6 +16,7 @@ # Code generated by the Google Gen AI SDK generator DO NOT EDIT. import datetime +import json import logging from typing import Any, Optional, Union from urllib.parse import urlencode @@ -24,7 +25,6 @@ from google.genai import _api_module from google.genai import _common from google.genai import types as genai_types -from google.genai._api_client import BaseApiClient from google.genai._common import get_value_by_path as getv from google.genai._common import set_value_by_path as setv @@ -35,7 +35,6 @@ def _OptimizeRequestParameters_to_vertex( - api_client: BaseApiClient, from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: @@ -47,7 +46,6 @@ def _OptimizeRequestParameters_to_vertex( def _OptimizeResponse_from_vertex( - api_client: BaseApiClient, from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: @@ -72,9 +70,7 @@ def optimize_dummy( if not self._api_client.vertexai: raise ValueError("This method is only supported in the Vertex AI client.") else: - request_dict = _OptimizeRequestParameters_to_vertex( - self._api_client, parameter_model - ) + request_dict = _OptimizeRequestParameters_to_vertex(parameter_model) request_url_dict = request_dict.get("_url") if request_url_dict: path = ":optimize".format_map(request_url_dict) @@ -97,19 +93,18 @@ def optimize_dummy( request_dict = _common.convert_to_dict(request_dict) request_dict = _common.encode_unserializable_types(request_dict) - response_dict = self._api_client.request( - "post", path, request_dict, http_options - ) + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = "" if not response.body else json.loads(response.body) if self._api_client.vertexai: - response_dict = _OptimizeResponse_from_vertex( - self._api_client, response_dict - ) + response_dict = _OptimizeResponse_from_vertex(response_dict) return_value = types.OptimizeResponse._from_response( response=response_dict, kwargs=parameter_model.model_dump() ) self._api_client._verify_response(return_value) + return return_value """Prompt Optimizer PO-Data.""" @@ -218,9 +213,7 @@ async def optimize_dummy( if not self._api_client.vertexai: raise ValueError("This method is only supported in the Vertex AI client.") else: - request_dict = _OptimizeRequestParameters_to_vertex( - self._api_client, parameter_model - ) + request_dict = _OptimizeRequestParameters_to_vertex(parameter_model) 
request_url_dict = request_dict.get("_url") if request_url_dict: path = ":optimize".format_map(request_url_dict) @@ -243,17 +236,18 @@ async def optimize_dummy( request_dict = _common.convert_to_dict(request_dict) request_dict = _common.encode_unserializable_types(request_dict) - response_dict = await self._api_client.async_request( + response = await self._api_client.async_request( "post", path, request_dict, http_options ) + response_dict = "" if not response.body else json.loads(response.body) + if self._api_client.vertexai: - response_dict = _OptimizeResponse_from_vertex( - self._api_client, response_dict - ) + response_dict = _OptimizeResponse_from_vertex(response_dict) return_value = types.OptimizeResponse._from_response( response=response_dict, kwargs=parameter_model.model_dump() ) self._api_client._verify_response(return_value) + return return_value diff --git a/vertexai/_genai/types.py b/vertexai/_genai/types.py index 3d224c982a..feeb929deb 100644 --- a/vertexai/_genai/types.py +++ b/vertexai/_genai/types.py @@ -25,6 +25,7 @@ Any, Callable, ClassVar, + Dict, List, Literal, Optional, @@ -804,6 +805,63 @@ class AutoraterConfigDict(TypedDict, total=False): AutoraterConfigOrDict = Union[AutoraterConfig, AutoraterConfigDict] +class HttpRetryOptions(_common.BaseModel): + """HTTP retry options to be used in each of the requests.""" + + attempts: Optional[int] = Field( + default=None, + description="""Maximum number of attempts, including the original request. + If 0 or 1, it means no retries.""", + ) + initial_delay: Optional[float] = Field( + default=None, + description="""Initial delay before the first retry, in fractions of a second.""", + ) + max_delay: Optional[float] = Field( + default=None, + description="""Maximum delay between retries, in fractions of a second.""", + ) + exp_base: Optional[float] = Field( + default=None, + description="""Multiplier by which the delay increases after each attempt.""", + ) + jitter: Optional[float] = Field( + default=None, description="""Randomness factor for the delay.""" + ) + http_status_codes: Optional[list[int]] = Field( + default=None, + description="""List of HTTP status codes that should trigger a retry. + If not specified, a default set of retryable codes may be used.""", + ) + + +class HttpRetryOptionsDict(TypedDict, total=False): + """HTTP retry options to be used in each of the requests.""" + + attempts: Optional[int] + """Maximum number of attempts, including the original request. + If 0 or 1, it means no retries.""" + + initial_delay: Optional[float] + """Initial delay before the first retry, in fractions of a second.""" + + max_delay: Optional[float] + """Maximum delay between retries, in fractions of a second.""" + + exp_base: Optional[float] + """Multiplier by which the delay increases after each attempt.""" + + jitter: Optional[float] + """Randomness factor for the delay.""" + + http_status_codes: Optional[list[int]] + """List of HTTP status codes that should trigger a retry. 
+ If not specified, a default set of retryable codes may be used.""" + + +HttpRetryOptionsOrDict = Union[HttpRetryOptions, HttpRetryOptionsDict] + + class HttpOptions(_common.BaseModel): """HTTP options to be used in each of the requests.""" @@ -827,6 +885,13 @@ class HttpOptions(_common.BaseModel): async_client_args: Optional[dict[str, Any]] = Field( default=None, description="""Args passed to the async HTTP client.""" ) + extra_body: Optional[dict[str, Any]] = Field( + default=None, + description="""Extra parameters to add to the request body.""", + ) + retry_options: Optional[HttpRetryOptions] = Field( + default=None, description="""HTTP retry options for the request.""" + ) class HttpOptionsDict(TypedDict, total=False): @@ -850,6 +915,12 @@ class HttpOptionsDict(TypedDict, total=False): async_client_args: Optional[dict[str, Any]] """Args passed to the async HTTP client.""" + extra_body: Optional[dict[str, Any]] + """Extra parameters to add to the request body.""" + + retry_options: Optional[HttpRetryOptionsDict] + """HTTP retry options for the request.""" + HttpOptionsOrDict = Union[HttpOptions, HttpOptionsDict] @@ -2064,6 +2135,10 @@ class EvaluationDataset(_common.BaseModel): default=None, description="""The evaluation dataset in the form of a Pandas DataFrame.""", ) + candidate_name: Optional[str] = Field( + default=None, + description="""The name of the candidate model or agent for this evaluation dataset.""", + ) gcs_source: Optional[GcsSource] = Field( default=None, description="""The GCS source for the evaluation dataset.""", @@ -2089,6 +2164,9 @@ class EvaluationDatasetDict(TypedDict, total=False): eval_dataset_df: Optional[pd.DataFrame] """The evaluation dataset in the form of a Pandas DataFrame.""" + candidate_name: Optional[str] + """The name of the candidate model or agent for this evaluation dataset.""" + gcs_source: Optional[GcsSourceDict] """The GCS source for the evaluation dataset.""" @@ -2449,250 +2527,1032 @@ class OptimizeResponseDict(TypedDict, total=False): OptimizeResponseOrDict = Union[OptimizeResponse, OptimizeResponseDict] -class PromptOptimizerVAPOConfig(_common.BaseModel): - """VAPO Prompt Optimizer Config.""" +class EnvVar(_common.BaseModel): + """Represents an environment variable present in a Container or Python Module.""" - config_path: Optional[str] = Field( - default=None, description="""The gcs path to the config file.""" + name: Optional[str] = Field( + default=None, + description="""Required. Name of the environment variable. Must be a valid C identifier.""", + ) + value: Optional[str] = Field( + default=None, + description="""Required. Variables that reference a $(VAR_NAME) are expanded using the previous defined environment variables in the container and any service environment variables. If a variable cannot be resolved, the reference in the input string will be unchanged. The $(VAR_NAME) syntax can be escaped with a double $$, ie: $$(VAR_NAME). Escaped references will never be expanded, regardless of whether the variable exists or not.""", ) - wait_for_completion: Optional[bool] = Field(default=None, description="""""") -class PromptOptimizerVAPOConfigDict(TypedDict, total=False): - """VAPO Prompt Optimizer Config.""" +class EnvVarDict(TypedDict, total=False): + """Represents an environment variable present in a Container or Python Module.""" - config_path: Optional[str] - """The gcs path to the config file.""" + name: Optional[str] + """Required. Name of the environment variable. 
Must be a valid C identifier.""" - wait_for_completion: Optional[bool] - """""" + value: Optional[str] + """Required. Variables that reference a $(VAR_NAME) are expanded using the previous defined environment variables in the container and any service environment variables. If a variable cannot be resolved, the reference in the input string will be unchanged. The $(VAR_NAME) syntax can be escaped with a double $$, ie: $$(VAR_NAME). Escaped references will never be expanded, regardless of whether the variable exists or not.""" -PromptOptimizerVAPOConfigOrDict = Union[ - PromptOptimizerVAPOConfig, PromptOptimizerVAPOConfigDict -] +EnvVarOrDict = Union[EnvVar, EnvVarDict] -class PromptTemplate(_common.BaseModel): - """A prompt template for creating prompts with variables.""" +class SecretRef(_common.BaseModel): + """Reference to a secret stored in the Cloud Secret Manager that will provide the value for this environment variable.""" - text: Optional[str] = Field( - default=None, description="""The prompt template text.""" + secret: Optional[str] = Field( + default=None, + description="""Required. The name of the secret in Cloud Secret Manager. Format: {secret_name}.""", + ) + version: Optional[str] = Field( + default=None, + description="""The Cloud Secret Manager secret version. Can be 'latest' for the latest version, an integer for a specific version, or a version alias.""", ) - _VARIABLE_NAME_REGEX: ClassVar[str] = r"\{([_a-zA-Z][_a-zA-Z0-9]*)\}" - - @field_validator("text") - @classmethod - def text_must_not_be_empty(cls, value: str) -> str: - if not value.strip(): - raise ValueError( - "Prompt template text cannot be empty or consist only of" " whitespace." - ) - return value - @computed_field - @property - def variables(self) -> set[str]: - return set(re.findall(self._VARIABLE_NAME_REGEX, self.text)) - def _split_template_by_variables(self) -> list[Tuple[str, str]]: - parts = [] - last_end = 0 - for match in re.finditer(self._VARIABLE_NAME_REGEX, self.text): - start, end = match.span() - var_name = match.group(1) - if start > last_end: - parts.append(("text", self.text[last_end:start])) - parts.append(("var", var_name)) - last_end = end - if last_end < len(self.text): - parts.append(("text", self.text[last_end:])) - return parts +class SecretRefDict(TypedDict, total=False): + """Reference to a secret stored in the Cloud Secret Manager that will provide the value for this environment variable.""" - def _merge_adjacent_text_parts( - self, parts: list[genai_types.Part] - ) -> list[genai_types.Part]: - if not parts: - return [] + secret: Optional[str] + """Required. The name of the secret in Cloud Secret Manager. Format: {secret_name}.""" - merged = [] - current_text_buffer = [] + version: Optional[str] + """The Cloud Secret Manager secret version. 
Can be 'latest' for the latest version, an integer for a specific version, or a version alias.""" - for part in parts: - is_purely_text = part.text is not None and all( - getattr(part, field) is None - for field in part.model_fields - if field != "text" - ) - if is_purely_text: - current_text_buffer.append(part.text) - else: - if current_text_buffer: - merged.append(genai_types.Part(text="".join(current_text_buffer))) - current_text_buffer = [] - merged.append(part) +SecretRefOrDict = Union[SecretRef, SecretRefDict] - if current_text_buffer: - merged.append(genai_types.Part(text="".join(current_text_buffer))) - return merged +class SecretEnvVar(_common.BaseModel): + """Represents an environment variable where the value is a secret in Cloud Secret Manager.""" - def _is_multimodal_json_string( - self, - value: Any, - ) -> bool: - """Checks if the input value is a multimodal JSON string.""" - if not isinstance(value, str): - return False - try: - data = json.loads(value) - # Check for the specific structure: {"contents": [{"parts": [...]}]} - # or {"parts": [...]} if assemble returns a single Content JSON - if isinstance(data, dict): - if "contents" in data and isinstance(data["contents"], list): - if not data["contents"]: - return False - first_content = data["contents"][0] - if isinstance(first_content, dict) and "parts" in first_content: - try: - genai_types.Content.model_validate(first_content) - return True - except ValueError: - return False - # Adding a check if 'data' itself is a Content-like object with parts - elif "parts" in data and isinstance(data["parts"], list): - try: - genai_types.Content.model_validate(data) - return True - except ValueError: - return False - return False - except json.JSONDecodeError: - return False + name: Optional[str] = Field( + default=None, + description="""Required. Name of the secret environment variable.""", + ) + secret_ref: Optional[SecretRef] = Field( + default=None, + description="""Required. Reference to a secret stored in the Cloud Secret Manager that will provide the value for this environment variable.""", + ) - def _parse_multimodal_json_string_into_parts( - self, - value: str, - ) -> list[genai_types.Part]: - """Parses a multimodal JSON string and returns its list of Parts.""" - try: - content = genai_types.Content.model_validate_json(value) - return content.parts - except Exception: - return [genai_types.Part(text=value)] - def assemble(self, **kwargs: Any) -> str: - """Assembles the prompt template with the given keyword arguments. +class SecretEnvVarDict(TypedDict, total=False): + """Represents an environment variable where the value is a secret in Cloud Secret Manager.""" - Supports both text and multimodal content. The `assemble` method - substitutes variables from the prompt template text with provided - values. + name: Optional[str] + """Required. Name of the secret environment variable.""" - Key Behaviors of `assemble()`: - 1. Variable Substitution: Replaces all defined variables with their - corresponding keyword argument values. Raises ValueError if a - template - variable is missing a value or if an extraneous kwarg is provided. - 2. Multimodal Handling: - - Detects if any variable's value is a JSON string representing - multimodal content (specifically, `{"contents": [{"parts": [...]}]}` - or `{"role": "user", "parts": [...]}`). - - If multimodal content is detected for a variable, its `Part` - objects - are extracted and inserted into the assembled sequence. 
- - Text segments from the template and simple text variable values - become `Part(text=...)`. - 3. Output Format: - - If ALL substituted variables were simple text AND the assembled - result (after merging adjacent text parts) consists of a single, - purely textual `Part`, `assemble()` returns a raw Python string. - - Otherwise (if any variable was multimodal, or if the assembly - results in multiple parts or non-textual parts), `assemble()` - returns - a JSON string representing a single `google.genai.types.Content` - object with `role="user"` and the assembled parts. - 4. Text Part Merging: Consecutively assembled text parts are - automatically merged into a single text `Part` to create a more - concise list of parts. + secret_ref: Optional[SecretRefDict] + """Required. Reference to a secret stored in the Cloud Secret Manager that will provide the value for this environment variable.""" - This dual output format (raw string or JSON string of `Content`) allows - the downstream inference functions to seamlessly handle both simple text - prompts and more complex multimodal prompts generated from the same - templating mechanism. - """ - current_variables = self.variables - for var_name_in_kwarg in kwargs: - if var_name_in_kwarg not in current_variables: - raise ValueError( - f"Invalid variable name '{var_name_in_kwarg}' provided to" - " assemble. Valid variables in template are:" - f" {current_variables}" - ) - # Check if all template variables are provided in kwargs - for tpl_var in current_variables: - if tpl_var not in kwargs: - raise ValueError(f"Missing value for template variable '{tpl_var}'.") - template_segments = self._split_template_by_variables() +SecretEnvVarOrDict = Union[SecretEnvVar, SecretEnvVarDict] - raw_assembled_parts: list[genai_types.Part] = [] - contains_multimodal_variable_type = False - for segment_type, segment_value in template_segments: - if segment_type == "text": - if segment_value: - raw_assembled_parts.append(genai_types.Part(text=segment_value)) - elif segment_type == "var": - var_value = kwargs.get(segment_value) +class ReasoningEngineSpecDeploymentSpec(_common.BaseModel): + """The specification of a Reasoning Engine deployment.""" - str_var_value = str(var_value) + env: Optional[list[EnvVar]] = Field( + default=None, + description="""Optional. Environment variables to be set with the Reasoning Engine deployment. The environment variables can be updated through the UpdateReasoningEngine API.""", + ) + secret_env: Optional[list[SecretEnvVar]] = Field( + default=None, + description="""Optional. Environment variables where the value is a secret in Cloud Secret Manager. To use this feature, add 'Secret Manager Secret Accessor' role (roles/secretmanager.secretAccessor) to AI Platform Reasoning Engine Service Agent.""", + ) - if self._is_multimodal_json_string(str_var_value): - multimodal_parts = self._parse_multimodal_json_string_into_parts( - str_var_value - ) - if multimodal_parts: - contains_multimodal_variable_type = True - raw_assembled_parts.extend(multimodal_parts) - else: - raw_assembled_parts.append(genai_types.Part(text=str_var_value)) - else: - raw_assembled_parts.append(genai_types.Part(text=str_var_value)) - final_assembled_parts = self._merge_adjacent_text_parts(raw_assembled_parts) +class ReasoningEngineSpecDeploymentSpecDict(TypedDict, total=False): + """The specification of a Reasoning Engine deployment.""" - # Condition for returning raw text string: - # 1. No multimodal variable was *originally* a multimodal JSON string. - # 2. 
After merging, there's exactly one part.
-        # 3. That single part is purely textual.
-        if (
-            not contains_multimodal_variable_type
-            and len(final_assembled_parts) == 1
-            and final_assembled_parts[0].text is not None
-            and all(
-                getattr(final_assembled_parts[0], field) is None
-                for field in final_assembled_parts[0].model_fields
-                if field not in ["text", "role"]
-            )
-        ):
-            return final_assembled_parts[0].text
+    env: Optional[list[EnvVarDict]]
+    """Optional. Environment variables to be set with the Reasoning Engine deployment. The environment variables can be updated through the UpdateReasoningEngine API."""
-        # Otherwise, construct a Content object (as JSON string).
-        final_content_obj = genai_types.Content(parts=final_assembled_parts)
-        return final_content_obj.model_dump_json(exclude_none=True)
+    secret_env: Optional[list[SecretEnvVarDict]]
+    """Optional. Environment variables where the value is a secret in Cloud Secret Manager. To use this feature, add 'Secret Manager Secret Accessor' role (roles/secretmanager.secretAccessor) to AI Platform Reasoning Engine Service Agent."""
-    def __str__(self) -> str:
-        return self.text
-    def __repr__(self) -> str:
-        return f"PromptTemplate(text='{self.text}')"
+ReasoningEngineSpecDeploymentSpecOrDict = Union[
+    ReasoningEngineSpecDeploymentSpec, ReasoningEngineSpecDeploymentSpecDict
+]
-class MetricPromptBuilder(PromptTemplate):
-    """Builder class for structured LLM-based metric prompt template."""
+class ReasoningEngineSpecPackageSpec(_common.BaseModel):
+    """User provided package spec like pickled object and package requirements."""
-    criteria: Optional[dict[str, str]] = Field(
+    dependency_files_gcs_uri: Optional[str] = Field(
+        default=None,
+        description="""Optional. The Cloud Storage URI of the dependency files in tar.gz format.""",
+    )
+    pickle_object_gcs_uri: Optional[str] = Field(
+        default=None,
+        description="""Optional. The Cloud Storage URI of the pickled python object.""",
+    )
+    python_version: Optional[str] = Field(
+        default=None,
+        description="""Optional. The Python version. Currently supports 3.8, 3.9, 3.10, 3.11. If not specified, default value is 3.10.""",
+    )
+    requirements_gcs_uri: Optional[str] = Field(
+        default=None,
+        description="""Optional. The Cloud Storage URI of the `requirements.txt` file.""",
+    )
+
+
+class ReasoningEngineSpecPackageSpecDict(TypedDict, total=False):
+    """User provided package spec like pickled object and package requirements."""
+
+    dependency_files_gcs_uri: Optional[str]
+    """Optional. The Cloud Storage URI of the dependency files in tar.gz format."""
+
+    pickle_object_gcs_uri: Optional[str]
+    """Optional. The Cloud Storage URI of the pickled python object."""
+
+    python_version: Optional[str]
+    """Optional. The Python version. Currently supports 3.8, 3.9, 3.10, 3.11. If not specified, default value is 3.10."""
+
+    requirements_gcs_uri: Optional[str]
+    """Optional. The Cloud Storage URI of the `requirements.txt` file."""
+
+
+ReasoningEngineSpecPackageSpecOrDict = Union[
+    ReasoningEngineSpecPackageSpec, ReasoningEngineSpecPackageSpecDict
+]
+
+
+class ReasoningEngineSpec(_common.BaseModel):
+    """The specification of a Reasoning Engine."""
+
+    agent_framework: Optional[str] = Field(
+        default=None,
+        description="""Optional. The OSS agent framework used to develop the agent. Currently supported values: "google-adk", "langchain", "langgraph", "ag2", "llama-index", "custom".""",
+    )
+    class_methods: Optional[list[dict[str, Any]]] = Field(
+        default=None,
+        description="""Optional.
Declarations for object class methods in OpenAPI specification format.""", + ) + deployment_spec: Optional[ReasoningEngineSpecDeploymentSpec] = Field( + default=None, + description="""Optional. The specification of a Reasoning Engine deployment.""", + ) + package_spec: Optional[ReasoningEngineSpecPackageSpec] = Field( + default=None, + description="""Optional. User provided package spec of the ReasoningEngine. Ignored when users directly specify a deployment image through `deployment_spec.first_party_image_override`, but keeping the field_behavior to avoid introducing breaking changes.""", + ) + + +class ReasoningEngineSpecDict(TypedDict, total=False): + """The specification of a Reasoning Engine.""" + + agent_framework: Optional[str] + """Optional. The OSS agent framework used to develop the agent. Currently supported values: "google-adk", "langchain", "langgraph", "ag2", "llama-index", "custom".""" + + class_methods: Optional[list[dict[str, Any]]] + """Optional. Declarations for object class methods in OpenAPI specification format.""" + + deployment_spec: Optional[ReasoningEngineSpecDeploymentSpecDict] + """Optional. The specification of a Reasoning Engine deployment.""" + + package_spec: Optional[ReasoningEngineSpecPackageSpecDict] + """Optional. User provided package spec of the ReasoningEngine. Ignored when users directly specify a deployment image through `deployment_spec.first_party_image_override`, but keeping the field_behavior to avoid introducing breaking changes.""" + + +ReasoningEngineSpecOrDict = Union[ReasoningEngineSpec, ReasoningEngineSpecDict] + + +class CreateAgentEngineConfig(_common.BaseModel): + """Config for create agent engine.""" + + http_options: Optional[HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + display_name: Optional[str] = Field( + default=None, + description="""The user-defined name of the Agent Engine. + + The display name can be up to 128 characters long and can comprise any + UTF-8 characters. + """, + ) + description: Optional[str] = Field( + default=None, description="""The description of the Agent Engine.""" + ) + spec: Optional[ReasoningEngineSpec] = Field( + default=None, + description="""Optional. Configurations of the ReasoningEngine.""", + ) + + +class CreateAgentEngineConfigDict(TypedDict, total=False): + """Config for create agent engine.""" + + http_options: Optional[HttpOptionsDict] + """Used to override HTTP request options.""" + + display_name: Optional[str] + """The user-defined name of the Agent Engine. + + The display name can be up to 128 characters long and can comprise any + UTF-8 characters. + """ + + description: Optional[str] + """The description of the Agent Engine.""" + + spec: Optional[ReasoningEngineSpecDict] + """Optional. 
Configurations of the ReasoningEngine.""" + + +CreateAgentEngineConfigOrDict = Union[ + CreateAgentEngineConfig, CreateAgentEngineConfigDict +] + + +class _CreateAgentEngineRequestParameters(_common.BaseModel): + """Parameters for creating agent engines.""" + + config: Optional[CreateAgentEngineConfig] = Field(default=None, description="""""") + + +class _CreateAgentEngineRequestParametersDict(TypedDict, total=False): + """Parameters for creating agent engines.""" + + config: Optional[CreateAgentEngineConfigDict] + """""" + + +_CreateAgentEngineRequestParametersOrDict = Union[ + _CreateAgentEngineRequestParameters, _CreateAgentEngineRequestParametersDict +] + + +class ReasoningEngine(_common.BaseModel): + """An agent engine.""" + + create_time: Optional[datetime.datetime] = Field( + default=None, + description="""Output only. Timestamp when this ReasoningEngine was created.""", + ) + description: Optional[str] = Field( + default=None, + description="""Optional. The description of the ReasoningEngine.""", + ) + display_name: Optional[str] = Field( + default=None, + description="""Required. The display name of the ReasoningEngine.""", + ) + etag: Optional[str] = Field( + default=None, + description="""Optional. Used to perform consistent read-modify-write updates. If not set, a blind "overwrite" update happens.""", + ) + name: Optional[str] = Field( + default=None, + description="""Identifier. The resource name of the ReasoningEngine. Format: `projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}`""", + ) + spec: Optional[ReasoningEngineSpec] = Field( + default=None, + description="""Optional. Configurations of the ReasoningEngine""", + ) + update_time: Optional[datetime.datetime] = Field( + default=None, + description="""Output only. Timestamp when this ReasoningEngine was most recently updated.""", + ) + + +class ReasoningEngineDict(TypedDict, total=False): + """An agent engine.""" + + create_time: Optional[datetime.datetime] + """Output only. Timestamp when this ReasoningEngine was created.""" + + description: Optional[str] + """Optional. The description of the ReasoningEngine.""" + + display_name: Optional[str] + """Required. The display name of the ReasoningEngine.""" + + etag: Optional[str] + """Optional. Used to perform consistent read-modify-write updates. If not set, a blind "overwrite" update happens.""" + + name: Optional[str] + """Identifier. The resource name of the ReasoningEngine. Format: `projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}`""" + + spec: Optional[ReasoningEngineSpecDict] + """Optional. Configurations of the ReasoningEngine""" + + update_time: Optional[datetime.datetime] + """Output only. Timestamp when this ReasoningEngine was most recently updated.""" + + +ReasoningEngineOrDict = Union[ReasoningEngine, ReasoningEngineDict] + + +class AgentEngineOperation(_common.BaseModel): + """Operation that has an agent engine as a response.""" + + name: Optional[str] = Field( + default=None, + description="""The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""", + ) + metadata: Optional[dict[str, Any]] = Field( + default=None, + description="""Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. 
Any method that returns a long-running operation should document the metadata type, if any.""", + ) + done: Optional[bool] = Field( + default=None, + description="""If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""", + ) + error: Optional[dict[str, Any]] = Field( + default=None, + description="""The error result of the operation in case of failure or cancellation.""", + ) + response: Optional[ReasoningEngine] = Field( + default=None, description="""The created Agent Engine.""" + ) + + +class AgentEngineOperationDict(TypedDict, total=False): + """Operation that has an agent engine as a response.""" + + name: Optional[str] + """The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""" + + metadata: Optional[dict[str, Any]] + """Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""" + + done: Optional[bool] + """If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""" + + error: Optional[dict[str, Any]] + """The error result of the operation in case of failure or cancellation.""" + + response: Optional[ReasoningEngineDict] + """The created Agent Engine.""" + + +AgentEngineOperationOrDict = Union[AgentEngineOperation, AgentEngineOperationDict] + + +class DeleteAgentEngineConfig(_common.BaseModel): + """Config for deleting agent engine.""" + + http_options: Optional[HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class DeleteAgentEngineConfigDict(TypedDict, total=False): + """Config for deleting agent engine.""" + + http_options: Optional[HttpOptionsDict] + """Used to override HTTP request options.""" + + +DeleteAgentEngineConfigOrDict = Union[ + DeleteAgentEngineConfig, DeleteAgentEngineConfigDict +] + + +class _DeleteAgentEngineRequestParameters(_common.BaseModel): + """Parameters for deleting agent engines.""" + + name: Optional[str] = Field( + default=None, description="""Name of the agent engine.""" + ) + force: Optional[bool] = Field( + default=False, + description="""If set to true, any child resources will also be deleted.""", + ) + config: Optional[DeleteAgentEngineConfig] = Field(default=None, description="""""") + + +class _DeleteAgentEngineRequestParametersDict(TypedDict, total=False): + """Parameters for deleting agent engines.""" + + name: Optional[str] + """Name of the agent engine.""" + + force: Optional[bool] + """If set to true, any child resources will also be deleted.""" + + config: Optional[DeleteAgentEngineConfigDict] + """""" + + +_DeleteAgentEngineRequestParametersOrDict = Union[ + _DeleteAgentEngineRequestParameters, _DeleteAgentEngineRequestParametersDict +] + + +class DeleteAgentEngineOperation(_common.BaseModel): + """Operation for deleting agent engines.""" + + name: Optional[str] = Field( + default=None, + description="""The server-assigned name, which is only unique within the same service that originally returns it. 
If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""",
+    )
+    metadata: Optional[dict[str, Any]] = Field(
+        default=None,
+        description="""Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""",
+    )
+    done: Optional[bool] = Field(
+        default=None,
+        description="""If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""",
+    )
+    error: Optional[dict[str, Any]] = Field(
+        default=None,
+        description="""The error result of the operation in case of failure or cancellation.""",
+    )
+
+
+class DeleteAgentEngineOperationDict(TypedDict, total=False):
+    """Operation for deleting agent engines."""
+
+    name: Optional[str]
+    """The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`."""
+
+    metadata: Optional[dict[str, Any]]
+    """Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any."""
+
+    done: Optional[bool]
+    """If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available."""
+
+    error: Optional[dict[str, Any]]
+    """The error result of the operation in case of failure or cancellation."""
+
+
+DeleteAgentEngineOperationOrDict = Union[
+    DeleteAgentEngineOperation, DeleteAgentEngineOperationDict
+]
+
+
+class GetAgentEngineConfig(_common.BaseModel):
+    """Config for getting agent engine."""
+
+    http_options: Optional[HttpOptions] = Field(
+        default=None, description="""Used to override HTTP request options."""
+    )
+
+
+class GetAgentEngineConfigDict(TypedDict, total=False):
+    """Config for getting agent engine."""
+
+    http_options: Optional[HttpOptionsDict]
+    """Used to override HTTP request options."""
+
+
+GetAgentEngineConfigOrDict = Union[GetAgentEngineConfig, GetAgentEngineConfigDict]
+
+
+class _GetAgentEngineRequestParameters(_common.BaseModel):
+    """Parameters for getting agent engines."""
+
+    name: Optional[str] = Field(
+        default=None, description="""Name of the agent engine."""
+    )
+    config: Optional[GetAgentEngineConfig] = Field(default=None, description="""""")
+
+
+class _GetAgentEngineRequestParametersDict(TypedDict, total=False):
+    """Parameters for getting agent engines."""
+
+    name: Optional[str]
+    """Name of the agent engine."""
+
+    config: Optional[GetAgentEngineConfigDict]
+    """"""
+
+
+_GetAgentEngineRequestParametersOrDict = Union[
+    _GetAgentEngineRequestParameters, _GetAgentEngineRequestParametersDict
+]
+
+
+class ListAgentEngineConfig(_common.BaseModel):
+    """Config for listing agent engines."""
+
+    http_options: Optional[HttpOptions] = Field(
+        default=None, description="""Used to override HTTP request options."""
+    )
+    page_size: Optional[int] = Field(default=None, description="""""")
+    page_token: Optional[str] = Field(default=None, description="""""")
+    filter: Optional[str] = Field(
+        default=None,
+        description="""An
expression for filtering the results of the request. + For field names both snake_case and camelCase are supported.""", + ) + + +class ListAgentEngineConfigDict(TypedDict, total=False): + """Config for listing agent engines.""" + + http_options: Optional[HttpOptionsDict] + """Used to override HTTP request options.""" + + page_size: Optional[int] + """""" + + page_token: Optional[str] + """""" + + filter: Optional[str] + """An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported.""" + + +ListAgentEngineConfigOrDict = Union[ListAgentEngineConfig, ListAgentEngineConfigDict] + + +class _ListAgentEngineRequestParameters(_common.BaseModel): + """Parameters for listing agent engines.""" + + config: Optional[ListAgentEngineConfig] = Field(default=None, description="""""") + + +class _ListAgentEngineRequestParametersDict(TypedDict, total=False): + """Parameters for listing agent engines.""" + + config: Optional[ListAgentEngineConfigDict] + """""" + + +_ListAgentEngineRequestParametersOrDict = Union[ + _ListAgentEngineRequestParameters, _ListAgentEngineRequestParametersDict +] + + +class ListReasoningEnginesResponse(_common.BaseModel): + """Response for listing agent engines.""" + + next_page_token: Optional[str] = Field(default=None, description="""""") + reasoning_engines: Optional[list[ReasoningEngine]] = Field( + default=None, + description="""List of agent engines. + """, + ) + + +class ListReasoningEnginesResponseDict(TypedDict, total=False): + """Response for listing agent engines.""" + + next_page_token: Optional[str] + """""" + + reasoning_engines: Optional[list[ReasoningEngineDict]] + """List of agent engines. + """ + + +ListReasoningEnginesResponseOrDict = Union[ + ListReasoningEnginesResponse, ListReasoningEnginesResponseDict +] + + +class GetAgentEngineOperationConfig(_common.BaseModel): + + http_options: Optional[HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class GetAgentEngineOperationConfigDict(TypedDict, total=False): + + http_options: Optional[HttpOptionsDict] + """Used to override HTTP request options.""" + + +GetAgentEngineOperationConfigOrDict = Union[ + GetAgentEngineOperationConfig, GetAgentEngineOperationConfigDict +] + + +class _GetAgentEngineOperationParameters(_common.BaseModel): + """Parameters for the GET method.""" + + operation_name: Optional[str] = Field( + default=None, + description="""The server-assigned name for the operation.""", + ) + config: Optional[GetAgentEngineOperationConfig] = Field( + default=None, + description="""Used to override the default configuration.""", + ) + + +class _GetAgentEngineOperationParametersDict(TypedDict, total=False): + """Parameters for the GET method.""" + + operation_name: Optional[str] + """The server-assigned name for the operation.""" + + config: Optional[GetAgentEngineOperationConfigDict] + """Used to override the default configuration.""" + + +_GetAgentEngineOperationParametersOrDict = Union[ + _GetAgentEngineOperationParameters, _GetAgentEngineOperationParametersDict +] + + +class QueryAgentEngineConfig(_common.BaseModel): + """Config for querying agent engines.""" + + http_options: Optional[HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + class_method: Optional[str] = Field( + default=None, description="""The class method to call.""" + ) + input: Optional[dict[str, Any]] = Field( + default=None, description="""The input to the class method.""" + ) + 
include_all_fields: Optional[bool] = Field(default=False, description="""""") + + +class QueryAgentEngineConfigDict(TypedDict, total=False): + """Config for querying agent engines.""" + + http_options: Optional[HttpOptionsDict] + """Used to override HTTP request options.""" + + class_method: Optional[str] + """The class method to call.""" + + input: Optional[dict[str, Any]] + """The input to the class method.""" + + include_all_fields: Optional[bool] + """""" + + +QueryAgentEngineConfigOrDict = Union[QueryAgentEngineConfig, QueryAgentEngineConfigDict] + + +class _QueryAgentEngineRequestParameters(_common.BaseModel): + """Parameters for querying agent engines.""" + + name: Optional[str] = Field( + default=None, description="""Name of the agent engine.""" + ) + config: Optional[QueryAgentEngineConfig] = Field(default=None, description="""""") + + +class _QueryAgentEngineRequestParametersDict(TypedDict, total=False): + """Parameters for querying agent engines.""" + + name: Optional[str] + """Name of the agent engine.""" + + config: Optional[QueryAgentEngineConfigDict] + """""" + + +_QueryAgentEngineRequestParametersOrDict = Union[ + _QueryAgentEngineRequestParameters, _QueryAgentEngineRequestParametersDict +] + + +class QueryReasoningEngineResponse(_common.BaseModel): + """The response for querying an agent engine.""" + + output: Optional[Any] = Field( + default=None, + description="""Response provided by users in JSON object format.""", + ) + + +class QueryReasoningEngineResponseDict(TypedDict, total=False): + """The response for querying an agent engine.""" + + output: Optional[Any] + """Response provided by users in JSON object format.""" + + +QueryReasoningEngineResponseOrDict = Union[ + QueryReasoningEngineResponse, QueryReasoningEngineResponseDict +] + + +class UpdateAgentEngineConfig(_common.BaseModel): + """Config for updating agent engine.""" + + http_options: Optional[HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + display_name: Optional[str] = Field( + default=None, + description="""The user-defined name of the Agent Engine. + + The display name can be up to 128 characters long and can comprise any + UTF-8 characters. + """, + ) + description: Optional[str] = Field( + default=None, description="""The description of the Agent Engine.""" + ) + spec: Optional[ReasoningEngineSpec] = Field( + default=None, + description="""Optional. Configurations of the ReasoningEngine.""", + ) + update_mask: Optional[str] = Field( + default=None, + description="""The update mask to apply. For the `FieldMask` definition, see + https://protobuf.dev/reference/protobuf/google.protobuf/#field-mask.""", + ) + + +class UpdateAgentEngineConfigDict(TypedDict, total=False): + """Config for updating agent engine.""" + + http_options: Optional[HttpOptionsDict] + """Used to override HTTP request options.""" + + display_name: Optional[str] + """The user-defined name of the Agent Engine. + + The display name can be up to 128 characters long and can comprise any + UTF-8 characters. + """ + + description: Optional[str] + """The description of the Agent Engine.""" + + spec: Optional[ReasoningEngineSpecDict] + """Optional. Configurations of the ReasoningEngine.""" + + update_mask: Optional[str] + """The update mask to apply. 
For the `FieldMask` definition, see + https://protobuf.dev/reference/protobuf/google.protobuf/#field-mask.""" + + +UpdateAgentEngineConfigOrDict = Union[ + UpdateAgentEngineConfig, UpdateAgentEngineConfigDict +] + + +class _UpdateAgentEngineRequestParameters(_common.BaseModel): + """Parameters for updating agent engines.""" + + name: Optional[str] = Field( + default=None, description="""Name of the agent engine.""" + ) + config: Optional[UpdateAgentEngineConfig] = Field(default=None, description="""""") + + +class _UpdateAgentEngineRequestParametersDict(TypedDict, total=False): + """Parameters for updating agent engines.""" + + name: Optional[str] + """Name of the agent engine.""" + + config: Optional[UpdateAgentEngineConfigDict] + """""" + + +_UpdateAgentEngineRequestParametersOrDict = Union[ + _UpdateAgentEngineRequestParameters, _UpdateAgentEngineRequestParametersDict +] + + +class PromptOptimizerVAPOConfig(_common.BaseModel): + """VAPO Prompt Optimizer Config.""" + + config_path: Optional[str] = Field( + default=None, description="""The gcs path to the config file.""" + ) + wait_for_completion: Optional[bool] = Field(default=None, description="""""") + + +class PromptOptimizerVAPOConfigDict(TypedDict, total=False): + """VAPO Prompt Optimizer Config.""" + + config_path: Optional[str] + """The gcs path to the config file.""" + + wait_for_completion: Optional[bool] + """""" + + +PromptOptimizerVAPOConfigOrDict = Union[ + PromptOptimizerVAPOConfig, PromptOptimizerVAPOConfigDict +] + + +class PromptTemplate(_common.BaseModel): + """A prompt template for creating prompts with variables.""" + + text: Optional[str] = Field( + default=None, description="""The prompt template text.""" + ) + _VARIABLE_NAME_REGEX: ClassVar[str] = r"\{([_a-zA-Z][_a-zA-Z0-9]*)\}" + + @field_validator("text") + @classmethod + def text_must_not_be_empty(cls, value: str) -> str: + if not value.strip(): + raise ValueError( + "Prompt template text cannot be empty or consist only of" " whitespace." 
+ ) + return value + + @computed_field + @property + def variables(self) -> set[str]: + return set(re.findall(self._VARIABLE_NAME_REGEX, self.text)) + + def _split_template_by_variables(self) -> list[Tuple[str, str]]: + parts = [] + last_end = 0 + for match in re.finditer(self._VARIABLE_NAME_REGEX, self.text): + start, end = match.span() + var_name = match.group(1) + if start > last_end: + parts.append(("text", self.text[last_end:start])) + parts.append(("var", var_name)) + last_end = end + if last_end < len(self.text): + parts.append(("text", self.text[last_end:])) + return parts + + def _merge_adjacent_text_parts( + self, parts: list[genai_types.Part] + ) -> list[genai_types.Part]: + if not parts: + return [] + + merged = [] + current_text_buffer = [] + + for part in parts: + is_purely_text = part.text is not None and all( + getattr(part, field) is None + for field in part.model_fields + if field != "text" + ) + + if is_purely_text: + current_text_buffer.append(part.text) + else: + if current_text_buffer: + merged.append(genai_types.Part(text="".join(current_text_buffer))) + current_text_buffer = [] + merged.append(part) + + if current_text_buffer: + merged.append(genai_types.Part(text="".join(current_text_buffer))) + + return merged + + def _is_multimodal_json_string( + self, + value: Any, + ) -> bool: + """Checks if the input value is a multimodal JSON string.""" + if not isinstance(value, str): + return False + try: + data = json.loads(value) + # Check for the specific structure: {"contents": [{"parts": [...]}]} + # or {"parts": [...]} if assemble returns a single Content JSON + if isinstance(data, dict): + if "contents" in data and isinstance(data["contents"], list): + if not data["contents"]: + return False + first_content = data["contents"][0] + if isinstance(first_content, dict) and "parts" in first_content: + try: + genai_types.Content.model_validate(first_content) + return True + except ValueError: + return False + # Adding a check if 'data' itself is a Content-like object with parts + elif "parts" in data and isinstance(data["parts"], list): + try: + genai_types.Content.model_validate(data) + return True + except ValueError: + return False + return False + except json.JSONDecodeError: + return False + + def _parse_multimodal_json_string_into_parts( + self, + value: str, + ) -> list[genai_types.Part]: + """Parses a multimodal JSON string and returns its list of Parts.""" + try: + content = genai_types.Content.model_validate_json(value) + return content.parts + except Exception: + return [genai_types.Part(text=value)] + + def assemble(self, **kwargs: Any) -> str: + """Assembles the prompt template with the given keyword arguments. + + Supports both text and multimodal content. The `assemble` method + substitutes variables from the prompt template text with provided + values. + + Key Behaviors of `assemble()`: + 1. Variable Substitution: Replaces all defined variables with their + corresponding keyword argument values. Raises ValueError if a + template + variable is missing a value or if an extraneous kwarg is provided. + 2. Multimodal Handling: + - Detects if any variable's value is a JSON string representing + multimodal content (specifically, `{"contents": [{"parts": [...]}]}` + or `{"role": "user", "parts": [...]}`). + - If multimodal content is detected for a variable, its `Part` + objects + are extracted and inserted into the assembled sequence. + - Text segments from the template and simple text variable values + become `Part(text=...)`. + 3. 
Output Format: + - If ALL substituted variables were simple text AND the assembled + result (after merging adjacent text parts) consists of a single, + purely textual `Part`, `assemble()` returns a raw Python string. + - Otherwise (if any variable was multimodal, or if the assembly + results in multiple parts or non-textual parts), `assemble()` + returns + a JSON string representing a single `google.genai.types.Content` + object with `role="user"` and the assembled parts. + 4. Text Part Merging: Consecutively assembled text parts are + automatically merged into a single text `Part` to create a more + concise list of parts. + + This dual output format (raw string or JSON string of `Content`) allows + the downstream inference functions to seamlessly handle both simple text + prompts and more complex multimodal prompts generated from the same + templating mechanism. + """ + current_variables = self.variables + for var_name_in_kwarg in kwargs: + if var_name_in_kwarg not in current_variables: + raise ValueError( + f"Invalid variable name '{var_name_in_kwarg}' provided to" + " assemble. Valid variables in template are:" + f" {current_variables}" + ) + # Check if all template variables are provided in kwargs + for tpl_var in current_variables: + if tpl_var not in kwargs: + raise ValueError(f"Missing value for template variable '{tpl_var}'.") + + template_segments = self._split_template_by_variables() + + raw_assembled_parts: list[genai_types.Part] = [] + contains_multimodal_variable_type = False + + for segment_type, segment_value in template_segments: + if segment_type == "text": + if segment_value: + raw_assembled_parts.append(genai_types.Part(text=segment_value)) + elif segment_type == "var": + var_value = kwargs.get(segment_value) + + str_var_value = str(var_value) + + if self._is_multimodal_json_string(str_var_value): + multimodal_parts = self._parse_multimodal_json_string_into_parts( + str_var_value + ) + if multimodal_parts: + contains_multimodal_variable_type = True + raw_assembled_parts.extend(multimodal_parts) + else: + raw_assembled_parts.append(genai_types.Part(text=str_var_value)) + else: + raw_assembled_parts.append(genai_types.Part(text=str_var_value)) + + final_assembled_parts = self._merge_adjacent_text_parts(raw_assembled_parts) + + # Condition for returning raw text string: + # 1. No multimodal variable was *originally* a multimodal JSON string. + # 2. After merging, there's exactly one part. + # 3. That single part is purely textual. + if ( + not contains_multimodal_variable_type + and len(final_assembled_parts) == 1 + and final_assembled_parts[0].text is not None + and all( + getattr(final_assembled_parts[0], field) is None + for field in final_assembled_parts[0].model_fields + if field not in ["text", "role"] + ) + ): + return final_assembled_parts[0].text + + # Otherwise, construct a Content object (as JSON string). + final_content_obj = genai_types.Content(parts=final_assembled_parts) + return final_content_obj.model_dump_json(exclude_none=True) + + def __str__(self) -> str: + return self.text + + def __repr__(self) -> str: + return f"PromptTemplate(text='{self.text}')" + + +class MetricPromptBuilder(PromptTemplate): + """Builder class for structured LLM-based metric prompt template.""" + + criteria: Optional[dict[str, str]] = Field( None, description="""A dictionary of criteria used to evaluate the model responses. 
The keys are criterion names, and the values are the corresponding @@ -3183,7 +4043,7 @@ class EvaluateMethodConfig(_common.BaseModel): http_options: Optional[HttpOptions] = Field( default=None, description="""Used to override HTTP request options.""" ) - dataset_schema: Optional[Literal["gemini", "flatten"]] = Field( + dataset_schema: Optional[Literal["GEMINI", "FLATTEN", "OPENAI"]] = Field( default=None, description="""The schema to use for the dataset. If not specified, the dataset schema will be inferred from the first @@ -3201,7 +4061,7 @@ class EvaluateMethodConfigDict(TypedDict, total=False): http_options: Optional[HttpOptionsDict] """Used to override HTTP request options.""" - dataset_schema: Optional[Literal["gemini", "flatten"]] + dataset_schema: Optional[Literal["GEMINI", "FLATTEN", "OPENAI"]] """The schema to use for the dataset. If not specified, the dataset schema will be inferred from the first example in the dataset.""" @@ -3211,3 +4071,175 @@ class EvaluateMethodConfigDict(TypedDict, total=False): EvaluateMethodConfigOrDict = Union[EvaluateMethodConfig, EvaluateMethodConfigDict] + + +class AgentEngine(_common.BaseModel): + """An agent engine instance.""" + + api_client: Optional[Any] = Field( + default=None, description="""The underlying API client.""" + ) + api_async_client: Optional[Any] = Field( + default=None, + description="""The underlying API client for asynchronous operations.""", + ) + api_resource: Optional[ReasoningEngine] = Field( + default=None, + description="""The underlying API resource (i.e. ReasoningEngine).""", + ) + + # Allows dynamic binding of methods based on the registered operations. + model_config = ConfigDict(extra="allow") + + def __repr__(self) -> str: + return f"AgentEngine(api_resource.name='{self.api_resource.name}')" + + def operation_schemas(self) -> list[Dict[str, Any]]: + """Returns the schemas of all registered operations for the agent.""" + if not isinstance(self.api_resource, ReasoningEngine): + raise ValueError("api_resource is not initialized.") + if not self.api_resource.spec: + raise ValueError("api_resource.spec is not initialized.") + return self.api_resource.spec.class_methods + + def delete( + self, + force: bool = False, + config: Optional[DeleteAgentEngineConfigOrDict] = None, + ): + """Deletes the agent engine. + + Args: + force (bool): Optional. If set to True, child resources will also be + deleted. Otherwise, the request will fail with FAILED_PRECONDITION + error when the Agent Engine has undeleted child resources. Defaults + to False. + config (DeleteAgentEngineConfig): Optional. Additional configurations + for deleting the Agent Engine. + """ + if not isinstance(self.api_resource, ReasoningEngine): + raise ValueError("api_resource is not initialized.") + self.api_client.delete(name=self.api_resource.name, force=force, config=config) + + +class AgentEngineDict(TypedDict, total=False): + """An agent engine instance.""" + + api_client: Optional[Any] + """The underlying API client.""" + + api_async_client: Optional[Any] + """The underlying API client for asynchronous operations.""" + + api_resource: Optional[ReasoningEngineDict] + """The underlying API resource (i.e. 
ReasoningEngine).""" + + +AgentEngineOrDict = Union[AgentEngine, AgentEngineDict] + + +class AgentEngineConfig(_common.BaseModel): + """Config for agent engine methods.""" + + http_options: Optional[HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + staging_bucket: Optional[str] = Field( + default=None, + description="""The GCS bucket to use for staging the artifacts needed. + + It must be a valid GCS bucket name, e.g. "gs://bucket-name". It is + required if `agent_engine` is specified.""", + ) + requirements: Optional[Any] = Field( + default=None, + description="""The set of PyPI dependencies needed. + + It can either be the path to a single file (requirements.txt), or an + ordered list of strings corresponding to each line of the requirements + file.""", + ) + display_name: Optional[str] = Field( + default=None, + description="""The user-defined name of the Agent Engine. + + The name can be up to 128 characters long and can comprise any UTF-8 + character.""", + ) + description: Optional[str] = Field( + default=None, description="""The description of the Agent Engine.""" + ) + gcs_dir_name: Optional[str] = Field( + default=None, + description="""The GCS bucket directory under `staging_bucket` to use for staging + the artifacts needed.""", + ) + extra_packages: Optional[list[str]] = Field( + default=None, + description="""The set of extra user-provided packages (if any).""", + ) + env_vars: Optional[Any] = Field( + default=None, + description="""The environment variables to be set when running the Agent Engine. + + If it is a dictionary, the keys are the environment variable names, and + the values are the corresponding values.""", + ) + return_agent: Optional[bool] = Field( + default=True, + description="""If True, the agent will be returned. + + Otherwise, the operation for creating or updating the agent will be + returned.""", + ) + + +class AgentEngineConfigDict(TypedDict, total=False): + """Config for agent engine methods.""" + + http_options: Optional[HttpOptionsDict] + """Used to override HTTP request options.""" + + staging_bucket: Optional[str] + """The GCS bucket to use for staging the artifacts needed. + + It must be a valid GCS bucket name, e.g. "gs://bucket-name". It is + required if `agent_engine` is specified.""" + + requirements: Optional[Any] + """The set of PyPI dependencies needed. + + It can either be the path to a single file (requirements.txt), or an + ordered list of strings corresponding to each line of the requirements + file.""" + + display_name: Optional[str] + """The user-defined name of the Agent Engine. + + The name can be up to 128 characters long and can comprise any UTF-8 + character.""" + + description: Optional[str] + """The description of the Agent Engine.""" + + gcs_dir_name: Optional[str] + """The GCS bucket directory under `staging_bucket` to use for staging + the artifacts needed.""" + + extra_packages: Optional[list[str]] + """The set of extra user-provided packages (if any).""" + + env_vars: Optional[Any] + """The environment variables to be set when running the Agent Engine. + + If it is a dictionary, the keys are the environment variable names, and + the values are the corresponding values.""" + + return_agent: Optional[bool] + """If True, the agent will be returned. 
+ + Otherwise, the operation for creating or updating the agent will be + returned.""" + + +AgentEngineConfigOrDict = Union[AgentEngineConfig, AgentEngineConfigDict] diff --git a/vertexai/_model_garden/_model_garden_models.py b/vertexai/_model_garden/_model_garden_models.py index 846a46f92c..699104f9c5 100644 --- a/vertexai/_model_garden/_model_garden_models.py +++ b/vertexai/_model_garden/_model_garden_models.py @@ -24,6 +24,7 @@ from google.cloud.aiplatform import initializer as aiplatform_initializer from google.cloud.aiplatform import models as aiplatform_models from google.cloud.aiplatform import _publisher_models +from vertexai._utils import warning_logs _SUPPORTED_PUBLISHERS = ["google"] @@ -274,7 +275,7 @@ def from_pretrained(cls: Type[T], model_name: str) -> T: ValueError: If model_name is unknown. ValueError: If model does not support this class. """ - + warning_logs.show_deprecation_warning() credential_exception_str = ( "\nUnable to authenticate your request." "\nDepending on your runtime environment, you can complete authentication by:" diff --git a/vertexai/_utils/warning_logs.py b/vertexai/_utils/warning_logs.py new file mode 100644 index 0000000000..cb3912731b --- /dev/null +++ b/vertexai/_utils/warning_logs.py @@ -0,0 +1,31 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import warnings + + +GENAI_DEPRECATION_WARNING_MESSAGE = ( + "This feature is deprecated as of June 24, 2025 and will be removed on June" + " 24, 2026. For details, see" + " https://cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk." +) + + +def show_deprecation_warning() -> None: + warnings.warn( + message=GENAI_DEPRECATION_WARNING_MESSAGE, + category=UserWarning, + stacklevel=2, + ) diff --git a/vertexai/agent_engines/_agent_engines.py b/vertexai/agent_engines/_agent_engines.py index cd7f0e3592..f76dcb43e3 100644 --- a/vertexai/agent_engines/_agent_engines.py +++ b/vertexai/agent_engines/_agent_engines.py @@ -71,15 +71,12 @@ _DEFAULT_ASYNC_STREAM_METHOD_RETURN_TYPE = "AsyncIterable[Any]" _DEFAULT_METHOD_DOCSTRING_TEMPLATE = """ Runs the Agent Engine to serve the user request. - This will be based on the `.{method_name}(...)` of the python object that was passed in when creating the Agent Engine. The method will invoke the `{default_method_name}` API client of the python object. - Args: **kwargs: Optional. The arguments of the `.{method_name}(...)` method. - Returns: {return_type}: The response from serving the user request. 
""" @@ -91,6 +88,18 @@ ) _AGENT_FRAMEWORK_ATTR = "agent_framework" _DEFAULT_AGENT_FRAMEWORK = "custom" +_DEFAULT_METHOD_NAME_MAP = { + _STANDARD_API_MODE: _DEFAULT_METHOD_NAME, + _ASYNC_API_MODE: _DEFAULT_ASYNC_METHOD_NAME, + _STREAM_API_MODE: _DEFAULT_STREAM_METHOD_NAME, + _ASYNC_STREAM_API_MODE: _DEFAULT_ASYNC_STREAM_METHOD_NAME, +} +_DEFAULT_METHOD_RETURN_TYPE_MAP = { + _STANDARD_API_MODE: _DEFAULT_METHOD_RETURN_TYPE, + _ASYNC_API_MODE: _DEFAULT_ASYNC_METHOD_RETURN_TYPE, + _STREAM_API_MODE: _DEFAULT_STREAM_METHOD_RETURN_TYPE, + _ASYNC_STREAM_API_MODE: _DEFAULT_ASYNC_STREAM_METHOD_RETURN_TYPE, +} try: @@ -1179,7 +1188,7 @@ def _generate_update_request_or_raise( ) -def _wrap_query_operation(method_name: str, doc: str) -> Callable[..., _utils.JsonDict]: +def _wrap_query_operation(method_name: str) -> Callable[..., _utils.JsonDict]: """Wraps an Agent Engine method, creating a callable for `query` API. This function creates a callable object that executes the specified @@ -1206,13 +1215,10 @@ def _method(self, **kwargs) -> _utils.JsonDict: output = _utils.to_dict(response) return output.get("output", output) - _method.__name__ = method_name - _method.__doc__ = doc - return _method -def _wrap_async_query_operation(method_name: str, doc: str) -> Callable[..., Coroutine]: +def _wrap_async_query_operation(method_name: str) -> Callable[..., Coroutine]: """Wraps an Agent Engine method, creating an async callable for `query` API. This function creates a callable object that executes the specified @@ -1239,15 +1245,10 @@ async def _method(self, **kwargs) -> _utils.JsonDict: output = _utils.to_dict(response) return output.get("output", output) - _method.__name__ = method_name - _method.__doc__ = doc - return _method -def _wrap_stream_query_operation( - *, method_name: str, doc: str -) -> Callable[..., Iterable[Any]]: +def _wrap_stream_query_operation(*, method_name: str) -> Callable[..., Iterable[Any]]: """Wraps an Agent Engine method, creating a callable for `stream_query` API. This function creates a callable object that executes the specified @@ -1276,14 +1277,11 @@ def _method(self, **kwargs) -> Iterable[Any]: if parsed_json is not None: yield parsed_json - _method.__name__ = method_name - _method.__doc__ = doc - return _method def _wrap_async_stream_query_operation( - *, method_name: str, doc: str + *, method_name: str ) -> Callable[..., AsyncIterable[Any]]: """Wraps an Agent Engine method, creating an async callable for `stream_query` API. @@ -1313,9 +1311,6 @@ async def _method(self, **kwargs) -> AsyncIterable[Any]: if parsed_json is not None: yield parsed_json - _method.__name__ = method_name - _method.__doc__ = doc - return _method @@ -1340,7 +1335,12 @@ def _unregister_api_methods( delattr(obj, method_name) -def _register_api_methods_or_raise(obj: "AgentEngine"): +def _register_api_methods_or_raise( + obj: "AgentEngine", + wrap_operation_fn: Optional[ + dict[str, Callable[[str, str], Callable[..., Any]]] + ] = None, +): """Registers Agent Engine API methods based on operation schemas. This function iterates through operation schemas provided by the @@ -1351,6 +1351,8 @@ def _register_api_methods_or_raise(obj: "AgentEngine"): Args: obj: The AgentEngine object to augment with API methods. + wrap_operation_fn: A dictionary of API modes and method wrapping + functions. Raises: ValueError: If the API mode is not supported or if the operation schema @@ -1369,68 +1371,43 @@ def _register_api_methods_or_raise(obj: "AgentEngine"): f" contain a `{_METHOD_NAME_KEY_IN_SCHEMA}` field." 
) method_name = operation_schema.get(_METHOD_NAME_KEY_IN_SCHEMA) - method_description = operation_schema.get("description") - - if api_mode == _STANDARD_API_MODE: - method_description = ( - method_description - or _DEFAULT_METHOD_DOCSTRING_TEMPLATE.format( - method_name=method_name, - default_method_name=_DEFAULT_METHOD_NAME, - return_type=_DEFAULT_METHOD_RETURN_TYPE, - ) - ) - method = _wrap_query_operation( - method_name=method_name, - doc=method_description, - ) - elif api_mode == _ASYNC_API_MODE: - method_description = ( - method_description - or _DEFAULT_METHOD_DOCSTRING_TEMPLATE.format( - method_name=method_name, - default_method_name=_DEFAULT_ASYNC_METHOD_NAME, - return_type=_DEFAULT_ASYNC_METHOD_RETURN_TYPE, - ) - ) - method = _wrap_async_query_operation( + method_description = operation_schema.get( + "description", + _DEFAULT_METHOD_DOCSTRING_TEMPLATE.format( method_name=method_name, - doc=method_description, - ) - elif api_mode == _STREAM_API_MODE: - method_description = ( - method_description - or _DEFAULT_METHOD_DOCSTRING_TEMPLATE.format( - method_name=method_name, - default_method_name=_DEFAULT_STREAM_METHOD_NAME, - return_type=_DEFAULT_STREAM_METHOD_RETURN_TYPE, - ) - ) - method = _wrap_stream_query_operation( - method_name=method_name, - doc=method_description, - ) - elif api_mode == _ASYNC_STREAM_API_MODE: - method_description = ( - method_description - or _DEFAULT_METHOD_DOCSTRING_TEMPLATE.format( - method_name=method_name, - default_method_name=_DEFAULT_ASYNC_STREAM_METHOD_NAME, - return_type=_DEFAULT_ASYNC_STREAM_METHOD_RETURN_TYPE, - ) - ) - method = _wrap_async_stream_query_operation( - method_name=method_name, - doc=method_description, - ) + default_method_name=_DEFAULT_METHOD_NAME_MAP.get( + api_mode, _DEFAULT_METHOD_NAME + ), + return_type=_DEFAULT_METHOD_RETURN_TYPE_MAP.get( + api_mode, + _DEFAULT_METHOD_RETURN_TYPE, + ), + ), + ) + _wrap_operation_map = { + _STANDARD_API_MODE: _wrap_query_operation, + _ASYNC_API_MODE: _wrap_async_query_operation, + _STREAM_API_MODE: _wrap_stream_query_operation, + _ASYNC_STREAM_API_MODE: _wrap_async_stream_query_operation, + } + if isinstance(wrap_operation_fn, dict) and api_mode in wrap_operation_fn: + # Override the default function with user-specified function if it exists. + _wrap_operation = wrap_operation_fn[api_mode] + elif api_mode in _wrap_operation_map: + _wrap_operation = _wrap_operation_map[api_mode] else: + supported_api_modes = ", ".join( + f"`{mode}`" for mode in sorted(_wrap_operation_map.keys()) + ) raise ValueError( f"Unsupported api mode: `{api_mode}`," - f" Supported modes are: `{_STANDARD_API_MODE}`, `{_ASYNC_API_MODE}`," - f" `{_STREAM_API_MODE}` and `{_ASYNC_STREAM_API_MODE}`." + f" Supported modes are: {supported_api_modes}." ) - # Binds the method to the object. + # Bind the method to the object. + method = _wrap_operation(method_name=method_name) + method.__name__ = method_name + method.__doc__ = method_description setattr(obj, method_name, types.MethodType(method, obj)) diff --git a/vertexai/caching/_caching.py b/vertexai/caching/_caching.py index fc19c9d308..aa3ecda64a 100644 --- a/vertexai/caching/_caching.py +++ b/vertexai/caching/_caching.py @@ -49,6 +49,7 @@ ContentsType, ) from google.protobuf import field_mask_pb2 +from vertexai._utils import warning_logs def _prepare_create_request( @@ -157,6 +158,7 @@ def __init__(self, cached_content_name: str): ID. Example: "projects/.../locations/../cachedContents/456" or "456". 
""" + warning_logs.show_deprecation_warning() super().__init__(resource_name=cached_content_name) self._gca_resource = self._get_gca_resource(cached_content_name) diff --git a/vertexai/generative_models/_generative_models.py b/vertexai/generative_models/_generative_models.py index de35ffc164..e54d3c9e6a 100644 --- a/vertexai/generative_models/_generative_models.py +++ b/vertexai/generative_models/_generative_models.py @@ -62,6 +62,7 @@ from google.protobuf import json_format from google.protobuf import field_mask_pb2 import warnings +from vertexai._utils import warning_logs if TYPE_CHECKING: from vertexai.caching import CachedContent @@ -429,6 +430,7 @@ def __init__( Content of each part will become a separate paragraph. labels: labels that will be passed to billing for cost tracking. """ + warning_logs.show_deprecation_warning() project = aiplatform_initializer.global_config.project location = aiplatform_initializer.global_config.location model_name = _reconcile_model_name(model_name, project, location) diff --git a/vertexai/language_models/_language_models.py b/vertexai/language_models/_language_models.py index 51b2594d81..47625d6ad2 100644 --- a/vertexai/language_models/_language_models.py +++ b/vertexai/language_models/_language_models.py @@ -42,6 +42,7 @@ from vertexai.language_models import ( _evaluatable_language_models, ) +from vertexai._utils import warning_logs try: import pandas @@ -3275,6 +3276,7 @@ def __init__( message_history: Optional[List[ChatMessage]] = None, stop_sequences: Optional[List[str]] = None, ): + warning_logs.show_deprecation_warning() super().__init__( model=model, context=context, @@ -3305,6 +3307,7 @@ def __init__( message_history: Optional[List[ChatMessage]] = None, stop_sequences: Optional[List[str]] = None, ): + warning_logs.show_deprecation_warning() super().__init__( model=model, context=context, diff --git a/vertexai/model_garden/_model_garden.py b/vertexai/model_garden/_model_garden.py index d0a8b7e8a7..cfafac0aad 100644 --- a/vertexai/model_garden/_model_garden.py +++ b/vertexai/model_garden/_model_garden.py @@ -661,7 +661,7 @@ def list_deploy_options( Args: concise: If true, returns a human-readable string with container and - machine specs. + machine specs. Returns: A list of deploy options or a concise formatted string. 
@@ -694,8 +694,10 @@ def _extract_config(option): if option.dedicated_resources else None ) + option_name = getattr(option, "deploy_task_name", None) return { + "option_name": option_name, "serving_container_image_uri": container, "machine_type": getattr(machine, "machine_type", None), "accelerator_type": getattr( @@ -706,11 +708,15 @@ def _extract_config(option): concise_deploy_options = [_extract_config(opt) for opt in deploy_options] return "\n\n".join( - f"[Option {i + 1}]\n" + ( + f"[Option {i + 1}: {config['option_name']}]\n" + if config.get("option_name") + else f"[Option {i + 1}]\n" + ) + "\n".join( - f' {k}="{v}",' if k != "accelerator_count" else f" {k}={v}," + f' {k}="{v}",' if k != "accelerator_count" else f" {k}={v}," for k, v in config.items() - if v is not None + if v is not None and k != "option_name" ) for i, config in enumerate(concise_deploy_options) ) diff --git a/vertexai/preview/reasoning_engines/templates/adk.py b/vertexai/preview/reasoning_engines/templates/adk.py index 06a7ff01ef..d4ba24361d 100644 --- a/vertexai/preview/reasoning_engines/templates/adk.py +++ b/vertexai/preview/reasoning_engines/templates/adk.py @@ -58,6 +58,13 @@ except (ImportError, AttributeError): BaseArtifactService = Any + try: + from google.adk.memory import BaseMemoryService + + BaseMemoryService = BaseMemoryService + except (ImportError, AttributeError): + BaseMemoryService = Any + try: from opentelemetry.sdk import trace @@ -281,6 +288,7 @@ def __init__( enable_tracing: bool = False, session_service_builder: Optional[Callable[..., "BaseSessionService"]] = None, artifact_service_builder: Optional[Callable[..., "BaseArtifactService"]] = None, + memory_service_builder: Optional[Callable[..., "BaseMemoryService"]] = None, env_vars: Optional[Dict[str, str]] = None, ): """An ADK Application.""" @@ -301,6 +309,7 @@ def __init__( "enable_tracing": enable_tracing, "session_service_builder": session_service_builder, "artifact_service_builder": artifact_service_builder, + "memory_service_builder": memory_service_builder, "app_name": _DEFAULT_APP_NAME, "env_vars": env_vars or {}, } @@ -410,6 +419,7 @@ def clone(self): enable_tracing=self._tmpl_attrs.get("enable_tracing"), session_service_builder=self._tmpl_attrs.get("session_service_builder"), artifact_service_builder=self._tmpl_attrs.get("artifact_service_builder"), + memory_service_builder=self._tmpl_attrs.get("memory_service_builder"), env_vars=self._tmpl_attrs.get("env_vars"), ) @@ -421,6 +431,7 @@ def set_up(self): from google.adk.artifacts.in_memory_artifact_service import ( InMemoryArtifactService, ) + from google.adk.memory.in_memory_memory_service import InMemoryMemoryService os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "1" project = self._tmpl_attrs.get("project") @@ -460,18 +471,27 @@ def set_up(self): else: self._tmpl_attrs["session_service"] = InMemorySessionService() + memory_service_builder = self._tmpl_attrs.get("memory_service_builder") + if memory_service_builder: + self._tmpl_attrs["memory_service"] = memory_service_builder() + else: + self._tmpl_attrs["memory_service"] = InMemoryMemoryService() + self._tmpl_attrs["runner"] = Runner( agent=self._tmpl_attrs.get("agent"), session_service=self._tmpl_attrs.get("session_service"), artifact_service=self._tmpl_attrs.get("artifact_service"), + memory_service=self._tmpl_attrs.get("memory_service"), app_name=self._tmpl_attrs.get("app_name"), ) self._tmpl_attrs["in_memory_session_service"] = InMemorySessionService() self._tmpl_attrs["in_memory_artifact_service"] = InMemoryArtifactService() + 
self._tmpl_attrs["in_memory_memory_service"] = InMemoryMemoryService() self._tmpl_attrs["in_memory_runner"] = Runner( agent=self._tmpl_attrs.get("agent"), session_service=self._tmpl_attrs.get("in_memory_session_service"), artifact_service=self._tmpl_attrs.get("in_memory_artifact_service"), + memory_service=self._tmpl_attrs.get("in_memory_memory_service"), app_name=self._tmpl_attrs.get("app_name"), ) diff --git a/vertexai/tuning/_tuning.py b/vertexai/tuning/_tuning.py index f080608dd9..97d35eb1ce 100644 --- a/vertexai/tuning/_tuning.py +++ b/vertexai/tuning/_tuning.py @@ -38,7 +38,7 @@ from google.cloud.aiplatform_v1beta1 import types as gca_types from google.rpc import status_pb2 # type: ignore - +from vertexai._utils import warning_logs _LOGGER = aiplatform_base.Logger(__name__) @@ -71,6 +71,7 @@ class TuningJob(aiplatform_base._VertexAiResourceNounPlus): api_client: gen_ai_tuning_service_v1beta1.client.GenAiTuningServiceClient def __init__(self, tuning_job_name: str): + warning_logs.show_deprecation_warning() super().__init__(resource_name=tuning_job_name) self._gca_resource: gca_tuning_job_types.TuningJob = self._get_gca_resource( resource_name=tuning_job_name diff --git a/vertexai/vision_models/_vision_models.py b/vertexai/vision_models/_vision_models.py index c252fc3e6f..69901d0198 100644 --- a/vertexai/vision_models/_vision_models.py +++ b/vertexai/vision_models/_vision_models.py @@ -29,6 +29,7 @@ from google.cloud.aiplatform import initializer as aiplatform_initializer from vertexai._model_garden import _model_garden_models +from vertexai._utils import warning_logs # pylint: disable=g-import-not-at-top try: @@ -149,6 +150,7 @@ def __init__( image_bytes: Image file bytes. Image can be in PNG or JPEG format. gcs_uri: Image URI in Google Cloud Storage. """ + warning_logs.show_deprecation_warning() if bool(image_bytes) == bool(gcs_uri): raise ValueError("Either image_bytes or gcs_uri must be provided.") @@ -487,6 +489,7 @@ def __init__( MP4, MPEG, MPG, WEBM, and WMV formats. gcs_uri: Image URI in Google Cloud Storage. """ + warning_logs.show_deprecation_warning() if bool(video_bytes) == bool(gcs_uri): raise ValueError("Either video_bytes or gcs_uri must be provided.") @@ -594,6 +597,7 @@ def __init__( end_offset_sec: End time offset (in seconds) to generate embeddings for. interval_sec: Interval to divide video for generated embeddings. """ + warning_logs.show_deprecation_warning() self.start_offset_sec = start_offset_sec self.end_offset_sec = end_offset_sec self.interval_sec = interval_sec @@ -618,6 +622,7 @@ def __init__( end_offset_sec: End time offset (in seconds) of generated embeddings. embedding: Generated embedding for interval. """ + warning_logs.show_deprecation_warning() self.start_offset_sec = start_offset_sec self.end_offset_sec = end_offset_sec self.embedding = embedding @@ -1422,6 +1427,7 @@ def __init__( generation_parameters: Image generation parameter values. gcs_uri: Image file Google Cloud Storage uri. """ + warning_logs.show_deprecation_warning() super().__init__(image_bytes=image_bytes, gcs_uri=gcs_uri) self._generation_parameters = generation_parameters