From 4c5ef1e235fb59966f486b9a57819afec2e3ce71 Mon Sep 17 00:00:00 2001 From: Sean Zhou Date: Thu, 10 Jul 2025 21:22:48 -0700 Subject: [PATCH] feat: Add configure service account auth for google_api_tool_set --- .../tools/google_api_tool/google_api_tool.py | 17 +- .../google_api_tool/google_api_toolset.py | 16 +- .../google_api_tool/google_api_toolsets.py | 61 ++- .../google_api_tool/test_google_api_tool.py | 145 ++++++ .../test_google_api_toolset.py | 427 ++++++++++++++++++ 5 files changed, 638 insertions(+), 28 deletions(-) create mode 100644 tests/unittests/tools/google_api_tool/test_google_api_tool.py create mode 100644 tests/unittests/tools/google_api_tool/test_google_api_toolset.py diff --git a/src/google/adk/tools/google_api_tool/google_api_tool.py b/src/google/adk/tools/google_api_tool/google_api_tool.py index 4fc254b25..5b2d51a23 100644 --- a/src/google/adk/tools/google_api_tool/google_api_tool.py +++ b/src/google/adk/tools/google_api_tool/google_api_tool.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from typing import Any from typing import Dict from typing import Optional @@ -23,7 +25,9 @@ from ...auth import AuthCredential from ...auth import AuthCredentialTypes from ...auth import OAuth2Auth +from ...auth.auth_credential import ServiceAccount from ..openapi_tool import RestApiTool +from ..openapi_tool.auth.auth_helpers import service_account_scheme_credential from ..tool_context import ToolContext @@ -34,6 +38,7 @@ def __init__( rest_api_tool: RestApiTool, client_id: Optional[str] = None, client_secret: Optional[str] = None, + service_account: Optional[ServiceAccount] = None, ): super().__init__( name=rest_api_tool.name, @@ -41,7 +46,10 @@ def __init__( is_long_running=rest_api_tool.is_long_running, ) self._rest_api_tool = rest_api_tool - self.configure_auth(client_id, client_secret) + if service_account is not None: + self.configure_sa_auth(service_account) + else: + self.configure_auth(client_id, client_secret) @override def _get_declaration(self) -> FunctionDeclaration: @@ -63,3 +71,10 @@ def configure_auth(self, client_id: str, client_secret: str): client_secret=client_secret, ), ) + + def configure_sa_auth(self, service_account: ServiceAccount): + auth_scheme, auth_credential = service_account_scheme_credential( + service_account + ) + self._rest_api_tool.auth_scheme = auth_scheme + self._rest_api_tool.auth_credential = auth_credential diff --git a/src/google/adk/tools/google_api_tool/google_api_toolset.py b/src/google/adk/tools/google_api_tool/google_api_toolset.py index 2cb00fa6d..47b3838e1 100644 --- a/src/google/adk/tools/google_api_tool/google_api_toolset.py +++ b/src/google/adk/tools/google_api_tool/google_api_toolset.py @@ -14,18 +14,15 @@ from __future__ import annotations -import inspect -import os -from typing import Any from typing import List from typing import Optional -from typing import Type from typing import Union from typing_extensions import override from ...agents.readonly_context import ReadonlyContext from ...auth import OpenIdConnectWithConfig +from ...auth.auth_credential import ServiceAccount from ...tools.base_toolset import BaseToolset from ...tools.base_toolset import ToolPredicate from ..openapi_tool import OpenAPIToolset @@ -48,11 +45,13 @@ def __init__( client_id: Optional[str] = None, client_secret: Optional[str] = None, tool_filter: Optional[Union[ToolPredicate, List[str]]] = None, + service_account: Optional[ServiceAccount] = None, ): self.api_name = api_name self.api_version = api_version self._client_id = client_id self._client_secret = client_secret + self._service_account = service_account self._openapi_toolset = self._load_toolset_with_oidc_auth() self.tool_filter = tool_filter @@ -61,10 +60,10 @@ async def get_tools( self, readonly_context: Optional[ReadonlyContext] = None ) -> List[GoogleApiTool]: """Get all tools in the toolset.""" - tools = [] - return [ - GoogleApiTool(tool, self._client_id, self._client_secret) + GoogleApiTool( + tool, self._client_id, self._client_secret, self._service_account + ) for tool in await self._openapi_toolset.get_tools(readonly_context) if self._is_tool_selected(tool, readonly_context) ] @@ -106,6 +105,9 @@ def configure_auth(self, client_id: str, client_secret: str): self._client_id = client_id self._client_secret = client_secret + def configure_sa_auth(self, service_account: ServiceAccount): + self._service_account = service_account + @override async def close(self): if self._openapi_toolset: diff --git a/src/google/adk/tools/google_api_tool/google_api_toolsets.py b/src/google/adk/tools/google_api_tool/google_api_toolsets.py index 22ecb39e6..6e50440d4 100644 --- a/src/google/adk/tools/google_api_tool/google_api_toolsets.py +++ b/src/google/adk/tools/google_api_tool/google_api_toolsets.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import logging from typing import List from typing import Optional from typing import Union +from ...auth.auth_credential import ServiceAccount from ..base_toolset import ToolPredicate from .google_api_toolset import GoogleApiToolset @@ -29,11 +31,14 @@ class BigQueryToolset(GoogleApiToolset): def __init__( self, - client_id: str = None, - client_secret: str = None, + client_id: Optional[str] = None, + client_secret: Optional[str] = None, tool_filter: Optional[Union[ToolPredicate, List[str]]] = None, + service_account: Optional[ServiceAccount] = None, ): - super().__init__("bigquery", "v2", client_id, client_secret, tool_filter) + super().__init__( + "bigquery", "v2", client_id, client_secret, tool_filter, service_account + ) class CalendarToolset(GoogleApiToolset): @@ -41,11 +46,14 @@ class CalendarToolset(GoogleApiToolset): def __init__( self, - client_id: str = None, - client_secret: str = None, + client_id: Optional[str] = None, + client_secret: Optional[str] = None, tool_filter: Optional[Union[ToolPredicate, List[str]]] = None, + service_account: Optional[ServiceAccount] = None, ): - super().__init__("calendar", "v3", client_id, client_secret, tool_filter) + super().__init__( + "calendar", "v3", client_id, client_secret, tool_filter, service_account + ) class GmailToolset(GoogleApiToolset): @@ -53,11 +61,14 @@ class GmailToolset(GoogleApiToolset): def __init__( self, - client_id: str = None, - client_secret: str = None, + client_id: Optional[str] = None, + client_secret: Optional[str] = None, tool_filter: Optional[Union[ToolPredicate, List[str]]] = None, + service_account: Optional[ServiceAccount] = None, ): - super().__init__("gmail", "v1", client_id, client_secret, tool_filter) + super().__init__( + "gmail", "v1", client_id, client_secret, tool_filter, service_account + ) class YoutubeToolset(GoogleApiToolset): @@ -65,11 +76,14 @@ class YoutubeToolset(GoogleApiToolset): def __init__( self, - client_id: str = None, - client_secret: str = None, + client_id: Optional[str] = None, + client_secret: Optional[str] = None, tool_filter: Optional[Union[ToolPredicate, List[str]]] = None, + service_account: Optional[ServiceAccount] = None, ): - super().__init__("youtube", "v3", client_id, client_secret, tool_filter) + super().__init__( + "youtube", "v3", client_id, client_secret, tool_filter, service_account + ) class SlidesToolset(GoogleApiToolset): @@ -77,11 +91,14 @@ class SlidesToolset(GoogleApiToolset): def __init__( self, - client_id: str = None, - client_secret: str = None, + client_id: Optional[str] = None, + client_secret: Optional[str] = None, tool_filter: Optional[Union[ToolPredicate, List[str]]] = None, + service_account: Optional[ServiceAccount] = None, ): - super().__init__("slides", "v1", client_id, client_secret, tool_filter) + super().__init__( + "slides", "v1", client_id, client_secret, tool_filter, service_account + ) class SheetsToolset(GoogleApiToolset): @@ -89,9 +106,10 @@ class SheetsToolset(GoogleApiToolset): def __init__( self, - client_id: str = None, - client_secret: str = None, + client_id: Optional[str] = None, + client_secret: Optional[str] = None, tool_filter: Optional[Union[ToolPredicate, List[str]]] = None, + service_account: Optional[ServiceAccount] = None, ): super().__init__("sheets", "v4", client_id, client_secret, tool_filter) @@ -101,8 +119,11 @@ class DocsToolset(GoogleApiToolset): def __init__( self, - client_id: str = None, - client_secret: str = None, + client_id: Optional[str] = None, + client_secret: Optional[str] = None, tool_filter: Optional[Union[ToolPredicate, List[str]]] = None, + service_account: Optional[ServiceAccount] = None, ): - super().__init__("docs", "v1", client_id, client_secret, tool_filter) + super().__init__( + "docs", "v1", client_id, client_secret, tool_filter, service_account + ) diff --git a/tests/unittests/tools/google_api_tool/test_google_api_tool.py b/tests/unittests/tools/google_api_tool/test_google_api_tool.py new file mode 100644 index 000000000..0d9c1f9ef --- /dev/null +++ b/tests/unittests/tools/google_api_tool/test_google_api_tool.py @@ -0,0 +1,145 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock + +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import ServiceAccount +from google.adk.auth.auth_credential import ServiceAccountCredential +from google.adk.tools.google_api_tool.google_api_tool import GoogleApiTool +from google.adk.tools.openapi_tool import RestApiTool +from google.adk.tools.tool_context import ToolContext +from google.genai.types import FunctionDeclaration +import pytest + + +@pytest.fixture +def mock_rest_api_tool(): + """Fixture for a mock RestApiTool.""" + mock_tool = mock.MagicMock(spec=RestApiTool) + mock_tool.name = "test_tool" + mock_tool.description = "Test Tool Description" + mock_tool.is_long_running = False + mock_tool._get_declaration.return_value = FunctionDeclaration( + name="test_function", description="Test function description" + ) + mock_tool.run_async.return_value = {"result": "success"} + return mock_tool + + +@pytest.fixture +def mock_tool_context(): + """Fixture for a mock ToolContext.""" + return mock.MagicMock(spec=ToolContext) + + +class TestGoogleApiTool: + """Test suite for the GoogleApiTool class.""" + + def test_init(self, mock_rest_api_tool): + """Test GoogleApiTool initialization.""" + tool = GoogleApiTool(mock_rest_api_tool) + + assert tool.name == "test_tool" + assert tool.description == "Test Tool Description" + assert tool.is_long_running is False + assert tool._rest_api_tool == mock_rest_api_tool + + def test_get_declaration(self, mock_rest_api_tool): + """Test _get_declaration method.""" + tool = GoogleApiTool(mock_rest_api_tool) + + declaration = tool._get_declaration() + + assert isinstance(declaration, FunctionDeclaration) + assert declaration.name == "test_function" + assert declaration.description == "Test function description" + mock_rest_api_tool._get_declaration.assert_called_once() + + @pytest.mark.asyncio + async def test_run_async(self, mock_rest_api_tool, mock_tool_context): + """Test run_async method.""" + tool = GoogleApiTool(mock_rest_api_tool) + args = {"param1": "value1"} + + result = await tool.run_async(args=args, tool_context=mock_tool_context) + + assert result == {"result": "success"} + mock_rest_api_tool.run_async.assert_called_once_with( + args=args, tool_context=mock_tool_context + ) + + def test_configure_auth(self, mock_rest_api_tool): + """Test configure_auth method.""" + tool = GoogleApiTool(mock_rest_api_tool) + client_id = "test_client_id" + client_secret = "test_client_secret" + + tool.configure_auth(client_id=client_id, client_secret=client_secret) + + # Check that auth_credential was set correctly on the rest_api_tool + assert mock_rest_api_tool.auth_credential is not None + assert ( + mock_rest_api_tool.auth_credential.auth_type + == AuthCredentialTypes.OPEN_ID_CONNECT + ) + assert mock_rest_api_tool.auth_credential.oauth2.client_id == client_id + assert ( + mock_rest_api_tool.auth_credential.oauth2.client_secret == client_secret + ) + + @mock.patch( + "google.adk.tools.google_api_tool.google_api_tool.service_account_scheme_credential" + ) + def test_configure_sa_auth( + self, mock_service_account_scheme_credential, mock_rest_api_tool + ): + """Test configure_sa_auth method.""" + # Setup mock return values + mock_auth_scheme = mock.MagicMock() + mock_auth_credential = mock.MagicMock() + mock_service_account_scheme_credential.return_value = ( + mock_auth_scheme, + mock_auth_credential, + ) + + service_account = ServiceAccount( + service_account_credential=ServiceAccountCredential( + type="service_account", + project_id="project_id", + private_key_id="private_key_id", + private_key="private_key", + client_email="client_email", + client_id="client_id", + auth_uri="auth_uri", + token_uri="token_uri", + auth_provider_x509_cert_url="auth_provider_x509_cert_url", + client_x509_cert_url="client_x509_cert_url", + universe_domain="universe_domain", + ), + scopes=["scope1", "scope2"], + ) + + # Create tool and call method + tool = GoogleApiTool(mock_rest_api_tool) + tool.configure_sa_auth(service_account=service_account) + + # Verify service_account_scheme_credential was called correctly + mock_service_account_scheme_credential.assert_called_once_with( + service_account + ) + + # Verify auth_scheme and auth_credential were set correctly on the rest_api_tool + assert mock_rest_api_tool.auth_scheme == mock_auth_scheme + assert mock_rest_api_tool.auth_credential == mock_auth_credential diff --git a/tests/unittests/tools/google_api_tool/test_google_api_toolset.py b/tests/unittests/tools/google_api_tool/test_google_api_toolset.py new file mode 100644 index 000000000..4f5ca1f22 --- /dev/null +++ b/tests/unittests/tools/google_api_tool/test_google_api_toolset.py @@ -0,0 +1,427 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional +from unittest import mock + +from google.adk.agents.readonly_context import ReadonlyContext +from google.adk.auth import OpenIdConnectWithConfig +from google.adk.auth.auth_credential import ServiceAccount +from google.adk.auth.auth_credential import ServiceAccountCredential +from google.adk.tools import BaseTool +from google.adk.tools.base_toolset import ToolPredicate +from google.adk.tools.google_api_tool.google_api_tool import GoogleApiTool +from google.adk.tools.google_api_tool.google_api_toolset import GoogleApiToolset +from google.adk.tools.google_api_tool.googleapi_to_openapi_converter import GoogleApiToOpenApiConverter +from google.adk.tools.openapi_tool import OpenAPIToolset +from google.adk.tools.openapi_tool import RestApiTool +import pytest + +TEST_API_NAME = "calendar" +TEST_API_VERSION = "v3" +DEFAULT_SCOPE = "https://www.googleapis.com/auth/calendar" + + +@pytest.fixture +def mock_rest_api_tool(): + """Fixture for a mock RestApiTool.""" + mock_tool = mock.MagicMock(spec=RestApiTool) + mock_tool.name = "test_tool" + mock_tool.description = "Test Tool Description" + return mock_tool + + +@pytest.fixture +def mock_google_api_tool_instance( + mock_rest_api_tool, +): # Renamed from mock_google_api_tool + """Fixture for a mock GoogleApiTool instance.""" + mock_tool = mock.MagicMock(spec=GoogleApiTool) + mock_tool.name = "test_tool" + mock_tool.description = "Test Tool Description" + mock_tool.rest_api_tool = mock_rest_api_tool + return mock_tool + + +@pytest.fixture +def mock_rest_api_tools(): + """Fixture for a list of mock RestApiTools.""" + tools = [] + for i in range(3): + mock_tool = mock.MagicMock( + spec=RestApiTool, description=f"Test Tool Description {i}" + ) + mock_tool.name = f"test_tool_{i}" + tools.append(mock_tool) + return tools + + +@pytest.fixture +def mock_openapi_toolset_instance(): # Renamed from mock_openapi_toolset + """Fixture for a mock OpenAPIToolset instance.""" + mock_toolset = mock.MagicMock(spec=OpenAPIToolset) + # Mock async methods if they are called + mock_toolset.get_tools = mock.AsyncMock(return_value=[]) + mock_toolset.close = mock.AsyncMock() + return mock_toolset + + +@pytest.fixture +def mock_converter_instance(): # Renamed from mock_converter + """Fixture for a mock GoogleApiToOpenApiConverter instance.""" + mock_conv = mock.MagicMock(spec=GoogleApiToOpenApiConverter) + mock_conv.convert.return_value = { + "components": { + "securitySchemes": { + "oauth2": { + "flows": { + "authorizationCode": { + "scopes": { + DEFAULT_SCOPE: "Full access to Google Calendar" + } + } + } + } + } + } + } + return mock_conv + + +@pytest.fixture +def mock_readonly_context(): + """Fixture for a mock ReadonlyContext.""" + return mock.MagicMock(spec=ReadonlyContext) + + +class TestGoogleApiToolset: + """Test suite for the GoogleApiToolset class.""" + + @mock.patch( + "google.adk.tools.google_api_tool.google_api_toolset.OpenAPIToolset" + ) + @mock.patch( + "google.adk.tools.google_api_tool.google_api_toolset.GoogleApiToOpenApiConverter" + ) + def test_init( + self, + mock_converter_class, + mock_openapi_toolset_class, + mock_converter_instance, + mock_openapi_toolset_instance, + ): + """Test GoogleApiToolset initialization.""" + mock_converter_class.return_value = mock_converter_instance + mock_openapi_toolset_class.return_value = mock_openapi_toolset_instance + + client_id = "test_client_id" + client_secret = "test_client_secret" + + tool_set = GoogleApiToolset( + api_name=TEST_API_NAME, + api_version=TEST_API_VERSION, + client_id=client_id, + client_secret=client_secret, + ) + + assert tool_set.api_name == TEST_API_NAME + assert tool_set.api_version == TEST_API_VERSION + assert tool_set._client_id == client_id + assert tool_set._client_secret == client_secret + assert tool_set._service_account is None + assert tool_set.tool_filter is None + assert tool_set._openapi_toolset == mock_openapi_toolset_instance + + mock_converter_class.assert_called_once_with( + TEST_API_NAME, TEST_API_VERSION + ) + mock_converter_instance.convert.assert_called_once() + spec_dict = mock_converter_instance.convert.return_value + + mock_openapi_toolset_class.assert_called_once() + _, kwargs = mock_openapi_toolset_class.call_args + assert kwargs["spec_dict"] == spec_dict + assert kwargs["spec_str_type"] == "yaml" + assert isinstance(kwargs["auth_scheme"], OpenIdConnectWithConfig) + assert kwargs["auth_scheme"].scopes == [DEFAULT_SCOPE] + + @mock.patch( + "google.adk.tools.google_api_tool.google_api_toolset.GoogleApiTool" + ) + @mock.patch( + "google.adk.tools.google_api_tool.google_api_toolset.OpenAPIToolset" + ) + @mock.patch( + "google.adk.tools.google_api_tool.google_api_toolset.GoogleApiToOpenApiConverter" + ) + async def test_get_tools( + self, + mock_converter_class, + mock_openapi_toolset_class, + mock_google_api_tool_class, + mock_converter_instance, + mock_openapi_toolset_instance, + mock_rest_api_tools, + mock_readonly_context, + ): + """Test get_tools method.""" + mock_converter_class.return_value = mock_converter_instance + mock_openapi_toolset_class.return_value = mock_openapi_toolset_instance + mock_openapi_toolset_instance.get_tools = mock.AsyncMock( + return_value=mock_rest_api_tools + ) + + # Setup mock GoogleApiTool instances to be returned by the constructor + mock_google_api_tool_instances = [ + mock.MagicMock(spec=GoogleApiTool, name=f"google_tool_{i}") + for i in range(len(mock_rest_api_tools)) + ] + mock_google_api_tool_class.side_effect = mock_google_api_tool_instances + + client_id = "cid" + client_secret = "csecret" + sa_mock = mock.MagicMock(spec=ServiceAccount) + + tool_set = GoogleApiToolset( + api_name=TEST_API_NAME, + api_version=TEST_API_VERSION, + client_id=client_id, + client_secret=client_secret, + service_account=sa_mock, + ) + + tools = await tool_set.get_tools(mock_readonly_context) + + assert len(tools) == len(mock_rest_api_tools) + mock_openapi_toolset_instance.get_tools.assert_called_once_with( + mock_readonly_context + ) + + for i, rest_tool in enumerate(mock_rest_api_tools): + mock_google_api_tool_class.assert_any_call( + rest_tool, client_id, client_secret, sa_mock + ) + assert tools[i] is mock_google_api_tool_instances[i] + + @mock.patch( + "google.adk.tools.google_api_tool.google_api_toolset.OpenAPIToolset" + ) + @mock.patch( + "google.adk.tools.google_api_tool.google_api_toolset.GoogleApiToOpenApiConverter" + ) + async def test_get_tools_with_filter_list( + self, + mock_converter_class, + mock_openapi_toolset_class, + mock_openapi_toolset_instance, + mock_rest_api_tools, # Has test_tool_0, test_tool_1, test_tool_2 + mock_readonly_context, + mock_converter_instance, + ): + """Test get_tools method with a list filter.""" + mock_converter_class.return_value = mock_converter_instance + mock_openapi_toolset_class.return_value = mock_openapi_toolset_instance + mock_openapi_toolset_instance.get_tools = mock.AsyncMock( + return_value=mock_rest_api_tools + ) + + tool_filter = ["test_tool_0", "test_tool_2"] + tool_set = GoogleApiToolset( + api_name=TEST_API_NAME, + api_version=TEST_API_VERSION, + tool_filter=tool_filter, + ) + + tools = await tool_set.get_tools(mock_readonly_context) + + assert len(tools) == 2 + assert tools[0].name == "test_tool_0" + assert tools[1].name == "test_tool_2" + + @mock.patch( + "google.adk.tools.google_api_tool.google_api_toolset.OpenAPIToolset" + ) + @mock.patch( + "google.adk.tools.google_api_tool.google_api_toolset.GoogleApiToOpenApiConverter" + ) + async def test_get_tools_with_filter_predicate( + self, + mock_converter_class, + mock_openapi_toolset_class, + mock_converter_instance, + mock_openapi_toolset_instance, + mock_rest_api_tools, # Has test_tool_0, test_tool_1, test_tool_2 + mock_readonly_context, + ): + """Test get_tools method with a predicate filter.""" + mock_converter_class.return_value = mock_converter_instance + mock_openapi_toolset_class.return_value = mock_openapi_toolset_instance + mock_openapi_toolset_instance.get_tools = mock.AsyncMock( + return_value=mock_rest_api_tools + ) + + class MyPredicate(ToolPredicate): + + def __call__( + self, + tool: BaseTool, + readonly_context: Optional[ReadonlyContext] = None, + ) -> bool: + return tool.name == "test_tool_1" + + tool_set = GoogleApiToolset( + api_name=TEST_API_NAME, + api_version=TEST_API_VERSION, + tool_filter=MyPredicate(), + ) + + tools = await tool_set.get_tools(mock_readonly_context) + + assert len(tools) == 1 + assert tools[0].name == "test_tool_1" + + @mock.patch( + "google.adk.tools.google_api_tool.google_api_toolset.OpenAPIToolset" + ) + @mock.patch( + "google.adk.tools.google_api_tool.google_api_toolset.GoogleApiToOpenApiConverter" + ) + def test_configure_auth( + self, + mock_converter_class, + mock_openapi_toolset_class, + mock_converter_instance, + mock_openapi_toolset_instance, + ): + """Test configure_auth method.""" + mock_converter_class.return_value = mock_converter_instance + mock_openapi_toolset_class.return_value = mock_openapi_toolset_instance + + tool_set = GoogleApiToolset( + api_name=TEST_API_NAME, api_version=TEST_API_VERSION + ) + client_id = "test_client_id" + client_secret = "test_client_secret" + + tool_set.configure_auth(client_id, client_secret) + + assert tool_set._client_id == client_id + assert tool_set._client_secret == client_secret + + # To verify its effect, we would ideally call get_tools and check + # how GoogleApiTool is instantiated. This is covered in test_get_tools. + + @mock.patch( + "google.adk.tools.google_api_tool.google_api_toolset.OpenAPIToolset" + ) + @mock.patch( + "google.adk.tools.google_api_tool.google_api_toolset.GoogleApiToOpenApiConverter" + ) + def test_configure_sa_auth( + self, + mock_converter_class, + mock_openapi_toolset_class, + mock_converter_instance, + mock_openapi_toolset_instance, + ): + """Test configure_sa_auth method.""" + mock_converter_class.return_value = mock_converter_instance + mock_openapi_toolset_class.return_value = mock_openapi_toolset_instance + + tool_set = GoogleApiToolset( + api_name=TEST_API_NAME, api_version=TEST_API_VERSION + ) + service_account = ServiceAccount( + service_account_credential=ServiceAccountCredential( + type="service_account", + project_id="project_id", + private_key_id="private_key_id", + private_key=( + "-----BEGIN PRIVATE KEY-----\nprivate_key\n-----END PRIVATE" + " KEY-----\n" + ), + client_email="client_email", + client_id="client_id", + auth_uri="auth_uri", + token_uri="token_uri", + auth_provider_x509_cert_url="auth_provider_x509_cert_url", + client_x509_cert_url="client_x509_cert_url", + universe_domain="universe_domain", + ), + scopes=["scope1", "scope2"], + ) + + tool_set.configure_sa_auth(service_account) + assert tool_set._service_account == service_account + # Effect verification is covered in test_get_tools. + + @mock.patch( + "google.adk.tools.google_api_tool.google_api_toolset.OpenAPIToolset" + ) + @mock.patch( + "google.adk.tools.google_api_tool.google_api_toolset.GoogleApiToOpenApiConverter" + ) + async def test_close( + self, + mock_converter_class, + mock_openapi_toolset_class, + mock_converter_instance, + mock_openapi_toolset_instance, + ): + """Test close method.""" + mock_converter_class.return_value = mock_converter_instance + mock_openapi_toolset_class.return_value = mock_openapi_toolset_instance + + tool_set = GoogleApiToolset( + api_name=TEST_API_NAME, api_version=TEST_API_VERSION + ) + await tool_set.close() + + mock_openapi_toolset_instance.close.assert_called_once() + + @mock.patch( + "google.adk.tools.google_api_tool.google_api_toolset.OpenAPIToolset" + ) + @mock.patch( + "google.adk.tools.google_api_tool.google_api_toolset.GoogleApiToOpenApiConverter" + ) + def test_set_tool_filter( + self, + mock_converter_class, + mock_openapi_toolset_class, + mock_converter_instance, + mock_openapi_toolset_instance, + ): + """Test set_tool_filter method.""" + mock_converter_class.return_value = mock_converter_instance + mock_openapi_toolset_class.return_value = mock_openapi_toolset_instance + + tool_set = GoogleApiToolset( + api_name=TEST_API_NAME, api_version=TEST_API_VERSION + ) + + assert tool_set.tool_filter is None + + new_filter_list = ["tool1", "tool2"] + tool_set.set_tool_filter(new_filter_list) + assert tool_set.tool_filter == new_filter_list + + def new_filter_predicate( + tool_name: str, + tool: RestApiTool, + readonly_context: Optional[ReadonlyContext] = None, + ) -> bool: + return True + + tool_set.set_tool_filter(new_filter_predicate) + assert tool_set.tool_filter == new_filter_predicate