From 5b694d1c386d4984c81bc13a4a2b132ec0b66f67 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Thu, 17 Jul 2025 10:35:10 -0700 Subject: [PATCH] feat: Allow toolset to process llm_request before tools returned by it PiperOrigin-RevId: 784234023 --- src/google/adk/agents/llm_agent.py | 18 +- .../adk/flows/llm_flows/base_llm_flow.py | 2 +- src/google/adk/tools/base_toolset.py | 22 +++ .../unittests/agents/test_llm_agent_fields.py | 161 +++++++++++++++++- .../flows/llm_flows/test_base_llm_flow.py | 150 ++++++++++++++++ tests/unittests/tools/test_base_toolset.py | 109 ++++++++++++ 6 files changed, 454 insertions(+), 8 deletions(-) create mode 100644 tests/unittests/flows/llm_flows/test_base_llm_flow.py create mode 100644 tests/unittests/tools/test_base_toolset.py diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index c9441cd04..4c7705b05 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -110,14 +110,18 @@ async def _convert_tool_union_to_tools( - tool_union: ToolUnion, ctx: ReadonlyContext -) -> list[BaseTool]: + tool_union: ToolUnion, ctx: ReadonlyContext, with_toolset: bool = False +) -> list[Union[BaseTool, BaseToolset]]: if isinstance(tool_union, BaseTool): return [tool_union] if isinstance(tool_union, Callable): return [FunctionTool(func=tool_union)] - return await tool_union.get_tools(ctx) + return ( + [tool_union] + await tool_union.get_tools(ctx) + if with_toolset + else await tool_union.get_tools(ctx) + ) class LlmAgent(BaseAgent): @@ -361,15 +365,17 @@ async def canonical_global_instruction( return global_instruction, True async def canonical_tools( - self, ctx: ReadonlyContext = None - ) -> list[BaseTool]: + self, ctx: ReadonlyContext = None, with_toolset: bool = False + ) -> list[Union[BaseTool, BaseToolset]]: """The resolved self.tools field as a list of BaseTool based on the context. This method is only for use by Agent Development Kit. """ resolved_tools = [] for tool_union in self.tools: - resolved_tools.extend(await _convert_tool_union_to_tools(tool_union, ctx)) + resolved_tools.extend( + await _convert_tool_union_to_tools(tool_union, ctx, with_toolset) + ) return resolved_tools @property diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index b87a083ac..51922f6b4 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -342,7 +342,7 @@ async def _preprocess_async( # Run processors for tools. for tool in await agent.canonical_tools( - ReadonlyContext(invocation_context) + ReadonlyContext(invocation_context), with_toolset=True ): tool_context = ToolContext(invocation_context) await tool.process_llm_request( diff --git a/src/google/adk/tools/base_toolset.py b/src/google/adk/tools/base_toolset.py index 6003f560b..706b4c42c 100644 --- a/src/google/adk/tools/base_toolset.py +++ b/src/google/adk/tools/base_toolset.py @@ -20,11 +20,16 @@ from typing import Optional from typing import Protocol from typing import runtime_checkable +from typing import TYPE_CHECKING from typing import Union from ..agents.readonly_context import ReadonlyContext from .base_tool import BaseTool +if TYPE_CHECKING: + from ..models.llm_request import LlmRequest + from .tool_context import ToolContext + @runtime_checkable class ToolPredicate(Protocol): @@ -96,3 +101,20 @@ def _is_tool_selected( return tool.name in self.tool_filter return False + + async def process_llm_request( + self, *, tool_context: ToolContext, llm_request: LlmRequest + ) -> None: + """Processes the outgoing LLM request for this toolset. This method will be + called before each tool processes the llm request. + + Use cases: + - Instead of let each tool process the llm request, we can let the toolset + process the llm request. e.g. ComputerUseToolset can add computer use + tool to the llm request. + + Args: + tool_context: The context of the tool. + llm_request: The outgoing LLM request, mutable this method. + """ + pass diff --git a/tests/unittests/agents/test_llm_agent_fields.py b/tests/unittests/agents/test_llm_agent_fields.py index 9b3a4abca..300f0a26d 100644 --- a/tests/unittests/agents/test_llm_agent_fields.py +++ b/tests/unittests/agents/test_llm_agent_fields.py @@ -21,11 +21,13 @@ from google.adk.agents.callback_context import CallbackContext from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.llm_agent import LlmAgent -from google.adk.agents.loop_agent import LoopAgent from google.adk.agents.readonly_context import ReadonlyContext from google.adk.models.llm_request import LlmRequest from google.adk.models.registry import LLMRegistry from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.adk.tools.base_tool import BaseTool +from google.adk.tools.base_toolset import BaseToolset +from google.adk.tools.function_tool import FunctionTool from google.genai import types from pydantic import BaseModel import pytest @@ -280,3 +282,160 @@ def test_allow_transfer_by_default(): assert not agent.disallow_transfer_to_parent assert not agent.disallow_transfer_to_peers + + +@pytest.mark.asyncio +async def test_canonical_tools_without_toolset(): + """Test canonical_tools returns only tools when with_toolset=False (default).""" + + def _test_function(): + pass + + class _TestToolset(BaseToolset): + + async def get_tools(self, readonly_context=None): + return [FunctionTool(func=_test_function)] + + async def close(self): + pass + + toolset = _TestToolset() + agent = LlmAgent(name='test_agent', tools=[toolset]) + ctx = await _create_readonly_context(agent) + + tools = await agent.canonical_tools(ctx) + + # Should only return tools, not the toolset itself + assert len(tools) == 1 + assert all(isinstance(tool, BaseTool) for tool in tools) + assert isinstance(tools[0], FunctionTool) + + +@pytest.mark.asyncio +async def test_canonical_tools_with_toolset(): + """Test canonical_tools returns tools and toolsets when with_toolset=True.""" + + def _test_function(): + pass + + class _TestToolset(BaseToolset): + + async def get_tools(self, readonly_context=None): + return [FunctionTool(func=_test_function)] + + async def close(self): + pass + + toolset = _TestToolset() + agent = LlmAgent(name='test_agent', tools=[toolset]) + ctx = await _create_readonly_context(agent) + + tools = await agent.canonical_tools(ctx, with_toolset=True) + + # Should return both the toolset and its tools + assert len(tools) == 2 + assert any(isinstance(item, BaseToolset) for item in tools) + assert any(isinstance(item, BaseTool) for item in tools) + + +@pytest.mark.asyncio +async def test_canonical_tools_mixed_tools_and_toolsets(): + """Test canonical_tools with mixed tools, functions, and toolsets.""" + + def _test_function(): + pass + + class _TestTool(BaseTool): + + def __init__(self): + super().__init__(name='test_tool', description='Test tool') + + async def call(self, **kwargs): + return 'test' + + class _TestToolset(BaseToolset): + + async def get_tools(self, readonly_context=None): + return [_TestTool()] + + async def close(self): + pass + + direct_tool = _TestTool() + toolset = _TestToolset() + + agent = LlmAgent( + name='test_agent', tools=[direct_tool, _test_function, toolset] + ) + ctx = await _create_readonly_context(agent) + + # Test without toolset + tools_only = await agent.canonical_tools(ctx, with_toolset=False) + assert len(tools_only) == 3 # direct_tool + function_tool + toolset's tool + assert all(isinstance(tool, BaseTool) for tool in tools_only) + + # Test with toolset + tools_with_toolset = await agent.canonical_tools(ctx, with_toolset=True) + assert ( + len(tools_with_toolset) == 4 + ) # direct_tool + function_tool + toolset + toolset's tool + assert ( + sum(1 for item in tools_with_toolset if isinstance(item, BaseToolset)) + == 1 + ) + assert ( + sum(1 for item in tools_with_toolset if isinstance(item, BaseTool)) == 3 + ) + + +@pytest.mark.asyncio +async def test_convert_tool_union_to_tools(): + """Test the _convert_tool_union_to_tools function with different tool types.""" + from google.adk.agents.llm_agent import _convert_tool_union_to_tools + + def _test_function(): + pass + + class _TestTool(BaseTool): + + def __init__(self): + super().__init__(name='test_tool', description='Test tool') + + async def call(self, **kwargs): + return 'test' + + class _TestToolset(BaseToolset): + + async def get_tools(self, readonly_context=None): + return [_TestTool()] + + async def close(self): + pass + + agent = LlmAgent(name='test_agent') + ctx = await _create_readonly_context(agent) + + # Test with BaseTool + tool = _TestTool() + result = await _convert_tool_union_to_tools(tool, ctx, with_toolset=False) + assert len(result) == 1 + assert isinstance(result[0], BaseTool) + + # Test with function + result = await _convert_tool_union_to_tools( + _test_function, ctx, with_toolset=False + ) + assert len(result) == 1 + assert isinstance(result[0], FunctionTool) + + # Test with toolset without including toolset + toolset = _TestToolset() + result = await _convert_tool_union_to_tools(toolset, ctx, with_toolset=False) + assert len(result) == 1 + assert isinstance(result[0], BaseTool) + + # Test with toolset including toolset + result = await _convert_tool_union_to_tools(toolset, ctx, with_toolset=True) + assert len(result) == 2 + assert any(isinstance(item, BaseToolset) for item in result) + assert any(isinstance(item, BaseTool) for item in result) diff --git a/tests/unittests/flows/llm_flows/test_base_llm_flow.py b/tests/unittests/flows/llm_flows/test_base_llm_flow.py new file mode 100644 index 000000000..82333c45a --- /dev/null +++ b/tests/unittests/flows/llm_flows/test_base_llm_flow.py @@ -0,0 +1,150 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for BaseLlmFlow toolset integration.""" + +from unittest.mock import AsyncMock + +from google.adk.agents import Agent +from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow +from google.adk.models.llm_request import LlmRequest +from google.adk.models.llm_response import LlmResponse +from google.adk.tools.base_toolset import BaseToolset +from google.genai import types +import pytest + +from ... import testing_utils + + +class BaseLlmFlowForTesting(BaseLlmFlow): + """Test implementation of BaseLlmFlow for testing purposes.""" + + pass + + +@pytest.mark.asyncio +async def test_preprocess_calls_toolset_process_llm_request(): + """Test that _preprocess_async calls process_llm_request on toolsets.""" + + # Create a mock toolset that tracks if process_llm_request was called + class _MockToolset(BaseToolset): + + def __init__(self): + super().__init__() + self.process_llm_request_called = False + self.process_llm_request = AsyncMock(side_effect=self._track_call) + + async def _track_call(self, **kwargs): + self.process_llm_request_called = True + + async def get_tools(self, readonly_context=None): + return [] + + async def close(self): + pass + + mock_toolset = _MockToolset() + + # Create a mock model that returns a simple response + mock_response = LlmResponse( + content=types.Content( + role='model', parts=[types.Part.from_text(text='Test response')] + ), + partial=False, + ) + + mock_model = testing_utils.MockModel.create(responses=[mock_response]) + + # Create agent with the mock toolset + agent = Agent(name='test_agent', model=mock_model, tools=[mock_toolset]) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='test message' + ) + + flow = BaseLlmFlowForTesting() + + # Call _preprocess_async + llm_request = LlmRequest() + events = [] + async for event in flow._preprocess_async(invocation_context, llm_request): + events.append(event) + + # Verify that process_llm_request was called on the toolset + assert mock_toolset.process_llm_request_called + + +@pytest.mark.asyncio +async def test_preprocess_handles_mixed_tools_and_toolsets(): + """Test that _preprocess_async properly handles both tools and toolsets.""" + from google.adk.tools.base_tool import BaseTool + from google.adk.tools.function_tool import FunctionTool + + # Create a mock tool + class _MockTool(BaseTool): + + def __init__(self): + super().__init__(name='mock_tool', description='Mock tool') + self.process_llm_request_called = False + self.process_llm_request = AsyncMock(side_effect=self._track_call) + + async def _track_call(self, **kwargs): + self.process_llm_request_called = True + + async def call(self, **kwargs): + return 'mock result' + + # Create a mock toolset + class _MockToolset(BaseToolset): + + def __init__(self): + super().__init__() + self.process_llm_request_called = False + self.process_llm_request = AsyncMock(side_effect=self._track_call) + + async def _track_call(self, **kwargs): + self.process_llm_request_called = True + + async def get_tools(self, readonly_context=None): + return [] + + async def close(self): + pass + + def _test_function(): + """Test function tool.""" + return 'function result' + + mock_tool = _MockTool() + mock_toolset = _MockToolset() + + # Create agent with mixed tools and toolsets + agent = Agent( + name='test_agent', tools=[mock_tool, _test_function, mock_toolset] + ) + + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='test message' + ) + + flow = BaseLlmFlowForTesting() + + # Call _preprocess_async + llm_request = LlmRequest() + events = [] + async for event in flow._preprocess_async(invocation_context, llm_request): + events.append(event) + + # Verify that process_llm_request was called on both tools and toolsets + assert mock_tool.process_llm_request_called + assert mock_toolset.process_llm_request_called diff --git a/tests/unittests/tools/test_base_toolset.py b/tests/unittests/tools/test_base_toolset.py new file mode 100644 index 000000000..5414bb3c8 --- /dev/null +++ b/tests/unittests/tools/test_base_toolset.py @@ -0,0 +1,109 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for BaseToolset.""" + +from typing import Optional + +from google.adk.agents.invocation_context import InvocationContext +from google.adk.agents.readonly_context import ReadonlyContext +from google.adk.agents.sequential_agent import SequentialAgent +from google.adk.models.llm_request import LlmRequest +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.adk.tools.base_tool import BaseTool +from google.adk.tools.base_toolset import BaseToolset +from google.adk.tools.tool_context import ToolContext +import pytest + + +class _TestingToolset(BaseToolset): + """A test implementation of BaseToolset.""" + + async def get_tools( + self, readonly_context: Optional[ReadonlyContext] = None + ) -> list[BaseTool]: + return [] + + async def close(self) -> None: + pass + + +@pytest.mark.asyncio +async def test_process_llm_request_default_implementation(): + """Test that the default process_llm_request implementation does nothing.""" + toolset = _TestingToolset() + + # Create test objects + session_service = InMemorySessionService() + session = await session_service.create_session( + app_name='test_app', user_id='test_user' + ) + agent = SequentialAgent(name='test_agent') + invocation_context = InvocationContext( + invocation_id='test_id', + agent=agent, + session=session, + session_service=session_service, + ) + tool_context = ToolContext(invocation_context) + llm_request = LlmRequest() + + # The default implementation should not modify the request + original_request = LlmRequest.model_validate(llm_request.model_dump()) + + await toolset.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + # Verify the request was not modified + assert llm_request.model_dump() == original_request.model_dump() + + +@pytest.mark.asyncio +async def test_process_llm_request_can_be_overridden(): + """Test that process_llm_request can be overridden by subclasses.""" + + class _CustomToolset(_TestingToolset): + + async def process_llm_request( + self, *, tool_context: ToolContext, llm_request: LlmRequest + ) -> None: + # Add some custom processing + if not llm_request.contents: + llm_request.contents = [] + llm_request.contents.append('Custom processing applied') + + toolset = _CustomToolset() + + # Create test objects + session_service = InMemorySessionService() + session = await session_service.create_session( + app_name='test_app', user_id='test_user' + ) + agent = SequentialAgent(name='test_agent') + invocation_context = InvocationContext( + invocation_id='test_id', + agent=agent, + session=session, + session_service=session_service, + ) + tool_context = ToolContext(invocation_context) + llm_request = LlmRequest() + + await toolset.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + # Verify the custom processing was applied + assert llm_request.contents == ['Custom processing applied']