diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index b87a083ac..028c43d7c 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -466,16 +466,18 @@ async def _postprocess_handle_function_calls_async( function_call_event: Event, llm_request: LlmRequest, ) -> AsyncGenerator[Event, None]: - if function_response_event := await functions.handle_function_calls_async( + if function_response_event_generator := functions.handle_function_calls_async( invocation_context, function_call_event, llm_request.tools_dict ): - auth_event = functions.generate_auth_event( - invocation_context, function_response_event - ) - if auth_event: - yield auth_event + async for function_response_event in function_response_event_generator: + auth_event = functions.generate_auth_event( + invocation_context, function_response_event + ) + if auth_event: + yield auth_event + + yield function_response_event - yield function_response_event transfer_to_agent = function_response_event.actions.transfer_to_agent if transfer_to_agent: agent_to_run = self._get_agent_to_run( diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 379e11ef7..6cfd3ca04 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -22,6 +22,7 @@ from typing import Any from typing import AsyncGenerator from typing import cast +from typing import Union from typing import Optional import uuid @@ -128,17 +129,17 @@ async def handle_function_calls_async( function_call_event: Event, tools_dict: dict[str, BaseTool], filters: Optional[set[str]] = None, -) -> Optional[Event]: +) -> AsyncGenerator[Optional[Event], None]: """Calls the functions and returns the function response event.""" from ...agents.llm_agent import LlmAgent agent = invocation_context.agent if not isinstance(agent, LlmAgent): + yield None return function_calls = function_call_event.get_function_calls() - function_response_events: list[Event] = [] for function_call in function_calls: if filters and function_call.id not in filters: continue @@ -217,33 +218,50 @@ async def handle_function_calls_async( if not function_response: continue - # Builds the function response event. - function_response_event = __build_response_event( - tool, function_response, tool_context, invocation_context - ) - trace_tool_call( - tool=tool, - args=function_args, - function_response_event=function_response_event, - ) - function_response_events.append(function_response_event) - - if not function_response_events: - return None - merged_event = merge_parallel_function_response_events( - function_response_events - ) + function_response_events: list[Event] = [] + async for function_response_one in _ensure_async_generator_function_response( + function_response + ): + # Builds the function response event. + function_response_event = __build_response_event( + tool, function_response_one, tool_context, invocation_context + ) + trace_tool_call( + tool=tool, + args=function_args, + function_response_event=function_response_event, + ) + function_response_events.append(function_response_event) - if len(function_response_events) > 1: - # this is needed for debug traces of parallel calls - # individual response with tool.name is traced in __build_response_event - # (we drop tool.name from span name here as this is merged event) - with tracer.start_as_current_span('execute_tool (merged)'): - trace_merged_tool_calls( - response_event_id=merged_event.id, - function_response_event=merged_event, - ) - return merged_event + if not function_response_events: + yield None + return + merged_event = merge_parallel_function_response_events( + function_response_events + ) + if len(function_response_events) > 1: + # this is needed for debug traces of parallel calls + # individual response with tool.name is traced in __build_response_event + # (we drop tool.name from span name here as this is merged event) + with tracer.start_as_current_span('execute_tool (merged)'): + trace_merged_tool_calls( + response_event_id=merged_event.id, + function_response_event=merged_event, + ) + yield merged_event + + +async def _ensure_async_generator_function_response( + function_response: Union[str, dict, AsyncGenerator] +) -> AsyncGenerator[Optional[Union[str, dict]], None]: + if inspect.isasyncgen(function_response): + async for response in function_response: + yield response + elif inspect.isgenerator(function_response): + for response in function_response: + yield response + else: + yield function_response async def handle_function_calls_live( @@ -587,4 +605,4 @@ def find_matching_function_call( for function_call in function_calls: if function_call.id == function_call_id: return event - return None + return None \ No newline at end of file