Thanks to visit codestin.com
Credit goes to github.com

Skip to content

feat: Allow toolset to process llm_request before tools returned by it #2013

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions src/google/adk/agents/llm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,18 @@


async def _convert_tool_union_to_tools(
tool_union: ToolUnion, ctx: ReadonlyContext
) -> list[BaseTool]:
tool_union: ToolUnion, ctx: ReadonlyContext, with_toolset: bool = False
) -> list[Union[BaseTool, BaseToolset]]:
if isinstance(tool_union, BaseTool):
return [tool_union]
if isinstance(tool_union, Callable):
return [FunctionTool(func=tool_union)]

return await tool_union.get_tools(ctx)
return (
[tool_union] + await tool_union.get_tools(ctx)
if with_toolset
else await tool_union.get_tools(ctx)
)


class LlmAgent(BaseAgent):
Expand Down Expand Up @@ -361,15 +365,17 @@ async def canonical_global_instruction(
return global_instruction, True

async def canonical_tools(
self, ctx: ReadonlyContext = None
) -> list[BaseTool]:
self, ctx: ReadonlyContext = None, with_toolset: bool = False
) -> list[Union[BaseTool, BaseToolset]]:
"""The resolved self.tools field as a list of BaseTool based on the context.

This method is only for use by Agent Development Kit.
"""
resolved_tools = []
for tool_union in self.tools:
resolved_tools.extend(await _convert_tool_union_to_tools(tool_union, ctx))
resolved_tools.extend(
await _convert_tool_union_to_tools(tool_union, ctx, with_toolset)
)
return resolved_tools

@property
Expand Down
2 changes: 1 addition & 1 deletion src/google/adk/flows/llm_flows/base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ async def _preprocess_async(

# Run processors for tools.
for tool in await agent.canonical_tools(
ReadonlyContext(invocation_context)
ReadonlyContext(invocation_context), with_toolset=True
):
tool_context = ToolContext(invocation_context)
await tool.process_llm_request(
Expand Down
22 changes: 22 additions & 0 deletions src/google/adk/tools/base_toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,16 @@
from typing import Optional
from typing import Protocol
from typing import runtime_checkable
from typing import TYPE_CHECKING
from typing import Union

from ..agents.readonly_context import ReadonlyContext
from .base_tool import BaseTool

if TYPE_CHECKING:
from ..models.llm_request import LlmRequest
from .tool_context import ToolContext


@runtime_checkable
class ToolPredicate(Protocol):
Expand Down Expand Up @@ -96,3 +101,20 @@ def _is_tool_selected(
return tool.name in self.tool_filter

return False

async def process_llm_request(
self, *, tool_context: ToolContext, llm_request: LlmRequest
) -> None:
"""Processes the outgoing LLM request for this toolset. This method will be
called before each tool processes the llm request.

Use cases:
- Instead of let each tool process the llm request, we can let the toolset
process the llm request. e.g. ComputerUseToolset can add computer use
tool to the llm request.

Args:
tool_context: The context of the tool.
llm_request: The outgoing LLM request, mutable this method.
"""
pass
161 changes: 160 additions & 1 deletion tests/unittests/agents/test_llm_agent_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@
from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.invocation_context import InvocationContext
from google.adk.agents.llm_agent import LlmAgent
from google.adk.agents.loop_agent import LoopAgent
from google.adk.agents.readonly_context import ReadonlyContext
from google.adk.models.llm_request import LlmRequest
from google.adk.models.registry import LLMRegistry
from google.adk.sessions.in_memory_session_service import InMemorySessionService
from google.adk.tools.base_tool import BaseTool
from google.adk.tools.base_toolset import BaseToolset
from google.adk.tools.function_tool import FunctionTool
from google.genai import types
from pydantic import BaseModel
import pytest
Expand Down Expand Up @@ -280,3 +282,160 @@ def test_allow_transfer_by_default():

assert not agent.disallow_transfer_to_parent
assert not agent.disallow_transfer_to_peers


@pytest.mark.asyncio
async def test_canonical_tools_without_toolset():
"""Test canonical_tools returns only tools when with_toolset=False (default)."""

def _test_function():
pass

class _TestToolset(BaseToolset):

async def get_tools(self, readonly_context=None):
return [FunctionTool(func=_test_function)]

async def close(self):
pass

toolset = _TestToolset()
agent = LlmAgent(name='test_agent', tools=[toolset])
ctx = await _create_readonly_context(agent)

tools = await agent.canonical_tools(ctx)

# Should only return tools, not the toolset itself
assert len(tools) == 1
assert all(isinstance(tool, BaseTool) for tool in tools)
assert isinstance(tools[0], FunctionTool)


@pytest.mark.asyncio
async def test_canonical_tools_with_toolset():
"""Test canonical_tools returns tools and toolsets when with_toolset=True."""

def _test_function():
pass

class _TestToolset(BaseToolset):

async def get_tools(self, readonly_context=None):
return [FunctionTool(func=_test_function)]

async def close(self):
pass

toolset = _TestToolset()
agent = LlmAgent(name='test_agent', tools=[toolset])
ctx = await _create_readonly_context(agent)

tools = await agent.canonical_tools(ctx, with_toolset=True)

# Should return both the toolset and its tools
assert len(tools) == 2
assert any(isinstance(item, BaseToolset) for item in tools)
assert any(isinstance(item, BaseTool) for item in tools)


@pytest.mark.asyncio
async def test_canonical_tools_mixed_tools_and_toolsets():
"""Test canonical_tools with mixed tools, functions, and toolsets."""

def _test_function():
pass

class _TestTool(BaseTool):

def __init__(self):
super().__init__(name='test_tool', description='Test tool')

async def call(self, **kwargs):
return 'test'

class _TestToolset(BaseToolset):

async def get_tools(self, readonly_context=None):
return [_TestTool()]

async def close(self):
pass

direct_tool = _TestTool()
toolset = _TestToolset()

agent = LlmAgent(
name='test_agent', tools=[direct_tool, _test_function, toolset]
)
ctx = await _create_readonly_context(agent)

# Test without toolset
tools_only = await agent.canonical_tools(ctx, with_toolset=False)
assert len(tools_only) == 3 # direct_tool + function_tool + toolset's tool
assert all(isinstance(tool, BaseTool) for tool in tools_only)

# Test with toolset
tools_with_toolset = await agent.canonical_tools(ctx, with_toolset=True)
assert (
len(tools_with_toolset) == 4
) # direct_tool + function_tool + toolset + toolset's tool
assert (
sum(1 for item in tools_with_toolset if isinstance(item, BaseToolset))
== 1
)
assert (
sum(1 for item in tools_with_toolset if isinstance(item, BaseTool)) == 3
)


@pytest.mark.asyncio
async def test_convert_tool_union_to_tools():
"""Test the _convert_tool_union_to_tools function with different tool types."""
from google.adk.agents.llm_agent import _convert_tool_union_to_tools

def _test_function():
pass

class _TestTool(BaseTool):

def __init__(self):
super().__init__(name='test_tool', description='Test tool')

async def call(self, **kwargs):
return 'test'

class _TestToolset(BaseToolset):

async def get_tools(self, readonly_context=None):
return [_TestTool()]

async def close(self):
pass

agent = LlmAgent(name='test_agent')
ctx = await _create_readonly_context(agent)

# Test with BaseTool
tool = _TestTool()
result = await _convert_tool_union_to_tools(tool, ctx, with_toolset=False)
assert len(result) == 1
assert isinstance(result[0], BaseTool)

# Test with function
result = await _convert_tool_union_to_tools(
_test_function, ctx, with_toolset=False
)
assert len(result) == 1
assert isinstance(result[0], FunctionTool)

# Test with toolset without including toolset
toolset = _TestToolset()
result = await _convert_tool_union_to_tools(toolset, ctx, with_toolset=False)
assert len(result) == 1
assert isinstance(result[0], BaseTool)

# Test with toolset including toolset
result = await _convert_tool_union_to_tools(toolset, ctx, with_toolset=True)
assert len(result) == 2
assert any(isinstance(item, BaseToolset) for item in result)
assert any(isinstance(item, BaseTool) for item in result)
Loading