diff --git a/pyproject.toml b/pyproject.toml index 095a38cb0..586a956af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,7 +96,8 @@ sagemaker = [ ] a2a = [ - "a2a-sdk[sql]>=0.2.11,<1.0.0", + "a2a-sdk>=0.3.0,<0.4.0", + "a2a-sdk[sql]>=0.3.0,<0.4.0", "uvicorn>=0.34.2,<1.0.0", "httpx>=0.28.1,<1.0.0", "fastapi>=0.115.12,<1.0.0", @@ -142,7 +143,7 @@ all = [ "opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0", # a2a - "a2a-sdk[sql]>=0.2.11,<1.0.0", + "a2a-sdk[sql]>=0.3.0,<0.4.0", "uvicorn>=0.34.2,<1.0.0", "httpx>=0.28.1,<1.0.0", "fastapi>=0.115.12,<1.0.0", @@ -321,4 +322,4 @@ style = [ ["instruction", ""], ["text", ""], ["disabled", "fg:#858585 italic"] -] \ No newline at end of file +] diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index ffcb6a5c9..ae21d4c6d 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -28,7 +28,12 @@ from ..telemetry.tracer import get_tracer from ..tools.executor import run_tools, validate_and_prepare_tools from ..types.content import Message -from ..types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException +from ..types.exceptions import ( + ContextWindowOverflowException, + EventLoopException, + MaxTokensReachedException, + ModelThrottledException, +) from ..types.streaming import Metrics, StopReason from ..types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolGenerator, ToolResult, ToolUse from .streaming import stream_messages @@ -187,6 +192,22 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> raise e try: + if stop_reason == "max_tokens": + """ + Handle max_tokens limit reached by the model. + + When the model reaches its maximum token limit, this represents a potentially unrecoverable + state where the model's response was truncated. By default, Strands fails hard with a + MaxTokensReachedException to maintain consistency with other failure types. 
+ """ + raise MaxTokensReachedException( + message=( + "Agent has reached an unrecoverable state due to max_tokens limit. " + "For more information see: " + "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception" + ), + incomplete_message=message, + ) # Add message in trace and mark the end of the stream messages trace stream_trace.add_message(message) stream_trace.end() @@ -231,7 +252,8 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> # Don't yield or log the exception - we already did it when we # raised the exception and we don't need that duplication. raise - except ContextWindowOverflowException as e: + except (ContextWindowOverflowException, MaxTokensReachedException) as e: + # Special cased exceptions which we want to bubble up rather than get wrapped in an EventLoopException if cycle_span: tracer.end_span_with_error(cycle_span, str(e), e) raise e diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index 0d734b762..975fca3e9 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -414,7 +414,7 @@ async def structured_output( stop_reason, messages, _, _ = event["stop"] if stop_reason != "tool_use": - raise ValueError(f"Model returned stop_reason: {stop_reason} instead of \"tool_use\".") + raise ValueError(f'Model returned stop_reason: {stop_reason} instead of "tool_use".') content = messages["content"] output_response: dict[str, Any] | None = None diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 9b36b4244..4ea1453a4 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -631,7 +631,7 @@ async def structured_output( stop_reason, messages, _, _ = event["stop"] if stop_reason != "tool_use": - raise ValueError(f"Model returned stop_reason: {stop_reason} instead of \"tool_use\".") + raise ValueError(f'Model returned stop_reason: {stop_reason} instead of "tool_use".') 
content = messages["content"] output_response: dict[str, Any] | None = None diff --git a/src/strands/multiagent/a2a/executor.py b/src/strands/multiagent/a2a/executor.py index d65c64aff..5bf9cbfe9 100644 --- a/src/strands/multiagent/a2a/executor.py +++ b/src/strands/multiagent/a2a/executor.py @@ -61,7 +61,7 @@ async def execute( task = new_task(context.message) # type: ignore await event_queue.enqueue_event(task) - updater = TaskUpdater(event_queue, task.id, task.contextId) + updater = TaskUpdater(event_queue, task.id, task.context_id) try: await self._execute_streaming(context, updater) diff --git a/src/strands/session/file_session_manager.py b/src/strands/session/file_session_manager.py index b32cb00e6..fec2f0761 100644 --- a/src/strands/session/file_session_manager.py +++ b/src/strands/session/file_session_manager.py @@ -23,6 +23,7 @@ class FileSessionManager(RepositorySessionManager, SessionRepository): """File-based session manager for local filesystem storage. Creates the following filesystem structure for the session storage: + ```bash // └── session_/ ├── session.json # Session metadata @@ -32,7 +33,7 @@ class FileSessionManager(RepositorySessionManager, SessionRepository): └── messages/ ├── message_.json └── message_.json - + ``` """ def __init__(self, session_id: str, storage_dir: Optional[str] = None, **kwargs: Any): diff --git a/src/strands/session/s3_session_manager.py b/src/strands/session/s3_session_manager.py index 8f8423828..0cc0a68c1 100644 --- a/src/strands/session/s3_session_manager.py +++ b/src/strands/session/s3_session_manager.py @@ -24,6 +24,7 @@ class S3SessionManager(RepositorySessionManager, SessionRepository): """S3-based session manager for cloud storage. 
Creates the following filesystem structure for the session storage: + ```bash // └── session_/ ├── session.json # Session metadata @@ -33,7 +34,7 @@ class S3SessionManager(RepositorySessionManager, SessionRepository): └── messages/ ├── message_.json └── message_.json - + ``` """ def __init__( diff --git a/src/strands/types/exceptions.py b/src/strands/types/exceptions.py index 4bd3fd88e..71ea28b9f 100644 --- a/src/strands/types/exceptions.py +++ b/src/strands/types/exceptions.py @@ -2,6 +2,8 @@ from typing import Any +from strands.types.content import Message + class EventLoopException(Exception): """Exception raised by the event loop.""" @@ -18,6 +20,25 @@ def __init__(self, original_exception: Exception, request_state: Any = None) -> super().__init__(str(original_exception)) +class MaxTokensReachedException(Exception): + """Exception raised when the model reaches its maximum token generation limit. + + This exception is raised when the model stops generating tokens because it has reached the maximum number of + tokens allowed for output generation. This can occur when the model's max_tokens parameter is set too low for + the complexity of the response, or when the model naturally reaches its configured output limit during generation. + """ + + def __init__(self, message: str, incomplete_message: Message): + """Initialize the exception with an error message and the incomplete message object. + + Args: + message: The error message describing the token limit issue + incomplete_message: The valid Message object with incomplete content due to token limits + """ + self.incomplete_message = incomplete_message + super().__init__(message) + + class ContextWindowOverflowException(Exception): """Exception raised when the context window is exceeded. 
diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 1ac2f8258..3886df8b9 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -19,7 +19,12 @@ ) from strands.telemetry.metrics import EventLoopMetrics from strands.tools.registry import ToolRegistry -from strands.types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException +from strands.types.exceptions import ( + ContextWindowOverflowException, + EventLoopException, + MaxTokensReachedException, + ModelThrottledException, +) from tests.fixtures.mock_hook_provider import MockHookProvider @@ -556,6 +561,51 @@ async def test_event_loop_tracing_with_model_error( mock_tracer.end_span_with_error.assert_called_once_with(model_span, "Input too long", model.stream.side_effect) +@pytest.mark.asyncio +async def test_event_loop_cycle_max_tokens_exception( + agent, + model, + agenerator, + alist, +): + """Test that max_tokens stop reason raises MaxTokensReachedException.""" + + # Note the empty toolUse to handle case raised in https://github.com/strands-agents/sdk-python/issues/495 + model.stream.return_value = agenerator( + [ + { + "contentBlockStart": { + "start": { + "toolUse": {}, + }, + }, + }, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "max_tokens"}}, + ] + ) + + # Call event_loop_cycle, expecting it to raise MaxTokensReachedException + with pytest.raises(MaxTokensReachedException) as exc_info: + stream = strands.event_loop.event_loop.event_loop_cycle( + agent=agent, + invocation_state={}, + ) + await alist(stream) + + # Verify the exception message contains the expected content + expected_message = ( + "Agent has reached an unrecoverable state due to max_tokens limit. 
" + "For more information see: " + "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception" + ) + assert str(exc_info.value) == expected_message + + # Verify that the message has not been appended to the messages array + assert len(agent.messages) == 1 + assert exc_info.value.incomplete_message not in agent.messages + + @patch("strands.event_loop.event_loop.get_tracer") @pytest.mark.asyncio async def test_event_loop_tracing_with_tool_execution( diff --git a/tests/strands/multiagent/a2a/test_executor.py b/tests/strands/multiagent/a2a/test_executor.py index a956cb769..77645fc73 100644 --- a/tests/strands/multiagent/a2a/test_executor.py +++ b/tests/strands/multiagent/a2a/test_executor.py @@ -36,7 +36,7 @@ async def mock_stream(user_input): # Mock the task creation mock_task = MagicMock() mock_task.id = "test-task-id" - mock_task.contextId = "test-context-id" + mock_task.context_id = "test-context-id" mock_request_context.current_task = mock_task await executor.execute(mock_request_context, mock_event_queue) @@ -65,7 +65,7 @@ async def mock_stream(user_input): # Mock the task creation mock_task = MagicMock() mock_task.id = "test-task-id" - mock_task.contextId = "test-context-id" + mock_task.context_id = "test-context-id" mock_request_context.current_task = mock_task await executor.execute(mock_request_context, mock_event_queue) @@ -95,7 +95,7 @@ async def mock_stream(user_input): # Mock the task creation mock_task = MagicMock() mock_task.id = "test-task-id" - mock_task.contextId = "test-context-id" + mock_task.context_id = "test-context-id" mock_request_context.current_task = mock_task await executor.execute(mock_request_context, mock_event_queue) @@ -125,7 +125,7 @@ async def mock_stream(user_input): # Mock the task creation mock_task = MagicMock() mock_task.id = "test-task-id" - mock_task.contextId = "test-context-id" + mock_task.context_id = "test-context-id" mock_request_context.current_task = mock_task await 
executor.execute(mock_request_context, mock_event_queue) @@ -156,7 +156,7 @@ async def mock_stream(user_input): mock_request_context.current_task = None with patch("strands.multiagent.a2a.executor.new_task") as mock_new_task: - mock_new_task.return_value = MagicMock(id="new-task-id", contextId="new-context-id") + mock_new_task.return_value = MagicMock(id="new-task-id", context_id="new-context-id") await executor.execute(mock_request_context, mock_event_queue) @@ -180,7 +180,7 @@ async def test_execute_streaming_mode_handles_agent_exception( # Mock the task creation mock_task = MagicMock() mock_task.id = "test-task-id" - mock_task.contextId = "test-context-id" + mock_task.context_id = "test-context-id" mock_request_context.current_task = mock_task with pytest.raises(ServerError): @@ -210,7 +210,7 @@ async def test_handle_agent_result_with_none_result(mock_strands_agent, mock_req # Mock the task creation mock_task = MagicMock() mock_task.id = "test-task-id" - mock_task.contextId = "test-context-id" + mock_task.context_id = "test-context-id" mock_request_context.current_task = mock_task # Mock TaskUpdater @@ -235,7 +235,7 @@ async def test_handle_agent_result_with_result_but_no_message( # Mock the task creation mock_task = MagicMock() mock_task.id = "test-task-id" - mock_task.contextId = "test-context-id" + mock_task.context_id = "test-context-id" mock_request_context.current_task = mock_task # Mock TaskUpdater diff --git a/tests/strands/multiagent/a2a/test_server.py b/tests/strands/multiagent/a2a/test_server.py index fc76b5f1d..a3b47581c 100644 --- a/tests/strands/multiagent/a2a/test_server.py +++ b/tests/strands/multiagent/a2a/test_server.py @@ -87,8 +87,8 @@ def test_public_agent_card(mock_strands_agent): assert card.description == "A test agent for unit testing" assert card.url == "http://0.0.0.0:9000/" assert card.version == "0.0.1" - assert card.defaultInputModes == ["text"] - assert card.defaultOutputModes == ["text"] + assert card.default_input_modes == 
["text"] + assert card.default_output_modes == ["text"] assert card.skills == [] assert card.capabilities == a2a_agent.capabilities diff --git a/tests_integ/test_max_tokens_reached.py b/tests_integ/test_max_tokens_reached.py new file mode 100644 index 000000000..d9c2817b3 --- /dev/null +++ b/tests_integ/test_max_tokens_reached.py @@ -0,0 +1,20 @@ +import pytest + +from strands import Agent, tool +from strands.models.bedrock import BedrockModel +from strands.types.exceptions import MaxTokensReachedException + + +@tool +def story_tool(story: str) -> str: + return story + + +def test_max_tokens_reached(): + model = BedrockModel(max_tokens=100) + agent = Agent(model=model, tools=[story_tool]) + + with pytest.raises(MaxTokensReachedException): + agent("Tell me a story!") + + assert len(agent.messages) == 1