From 111fc9ee66c5f6f4af3ba6faaedba4789da3d03a Mon Sep 17 00:00:00 2001 From: Jonny Kalambay Date: Tue, 22 Apr 2025 19:26:47 -0700 Subject: [PATCH 1/8] Adding extra_headers parameters to ModelSettings (#550) --- src/agents/extensions/models/litellm_model.py | 2 +- src/agents/model_settings.py | 6 +- src/agents/models/openai_chatcompletions.py | 2 +- src/agents/models/openai_responses.py | 2 +- tests/test_extra_headers.py | 92 +++++++++++++++++++ 5 files changed, 100 insertions(+), 4 deletions(-) create mode 100644 tests/test_extra_headers.py diff --git a/src/agents/extensions/models/litellm_model.py b/src/agents/extensions/models/litellm_model.py index e939ee8d..f5e7752f 100644 --- a/src/agents/extensions/models/litellm_model.py +++ b/src/agents/extensions/models/litellm_model.py @@ -286,7 +286,7 @@ async def _fetch_response( stream=stream, stream_options=stream_options, reasoning_effort=reasoning_effort, - extra_headers=HEADERS, + extra_headers={**HEADERS, **(model_settings.extra_headers or {})}, api_key=self.api_key, base_url=self.base_url, **extra_kwargs, diff --git a/src/agents/model_settings.py b/src/agents/model_settings.py index ed9a0131..fee92b4e 100644 --- a/src/agents/model_settings.py +++ b/src/agents/model_settings.py @@ -3,7 +3,7 @@ from dataclasses import dataclass, fields, replace from typing import Literal -from openai._types import Body, Query +from openai._types import Body, Headers, Query from openai.types.shared import Reasoning @@ -67,6 +67,10 @@ class ModelSettings: """Additional body fields to provide with the request. Defaults to None if not provided.""" + extra_headers: Headers | None = None + """Additional headers to provide with the request. + Defaults to None if not provided.""" + def resolve(self, override: ModelSettings | None) -> ModelSettings: """Produce a new ModelSettings by overlaying any non-None values from the override on top of this instance.""" diff --git a/src/agents/models/openai_chatcompletions.py b/src/agents/models/openai_chatcompletions.py index 9fd10269..15bf19cb 100644 --- a/src/agents/models/openai_chatcompletions.py +++ b/src/agents/models/openai_chatcompletions.py @@ -255,7 +255,7 @@ async def _fetch_response( stream_options=self._non_null_or_not_given(stream_options), store=self._non_null_or_not_given(store), reasoning_effort=self._non_null_or_not_given(reasoning_effort), - extra_headers=HEADERS, + extra_headers={ **HEADERS, **(model_settings.extra_headers or {}) }, extra_query=model_settings.extra_query, extra_body=model_settings.extra_body, metadata=self._non_null_or_not_given(model_settings.metadata), diff --git a/src/agents/models/openai_responses.py b/src/agents/models/openai_responses.py index b751663d..c1ff85b9 100644 --- a/src/agents/models/openai_responses.py +++ b/src/agents/models/openai_responses.py @@ -253,7 +253,7 @@ async def _fetch_response( tool_choice=tool_choice, parallel_tool_calls=parallel_tool_calls, stream=stream, - extra_headers=_HEADERS, + extra_headers={**_HEADERS, **(model_settings.extra_headers or {})}, extra_query=model_settings.extra_query, extra_body=model_settings.extra_body, text=response_format, diff --git a/tests/test_extra_headers.py b/tests/test_extra_headers.py new file mode 100644 index 00000000..f29c2540 --- /dev/null +++ b/tests/test_extra_headers.py @@ -0,0 +1,92 @@ +import pytest +from openai.types.chat.chat_completion import ChatCompletion, Choice +from openai.types.chat.chat_completion_message import ChatCompletionMessage + +from agents import ModelSettings, ModelTracing, 
OpenAIChatCompletionsModel, OpenAIResponsesModel + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_extra_headers_passed_to_openai_responses_model(): + """ + Ensure extra_headers in ModelSettings is passed to the OpenAIResponsesModel client. + """ + called_kwargs = {} + + class DummyResponses: + async def create(self, **kwargs): + nonlocal called_kwargs + called_kwargs = kwargs + class DummyResponse: + id = "dummy" + output = [] + usage = type( + "Usage", (), {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} + )() + return DummyResponse() + + class DummyClient: + def __init__(self): + self.responses = DummyResponses() + + + + model = OpenAIResponsesModel(model="gpt-4", openai_client=DummyClient()) # type: ignore + extra_headers = {"X-Test-Header": "test-value"} + await model.get_response( + system_instructions=None, + input="hi", + model_settings=ModelSettings(extra_headers=extra_headers), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + ) + assert "extra_headers" in called_kwargs + assert called_kwargs["extra_headers"]["X-Test-Header"] == "test-value" + + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_extra_headers_passed_to_openai_client(): + """ + Ensure extra_headers in ModelSettings is passed to the OpenAI client. + """ + called_kwargs = {} + + class DummyCompletions: + async def create(self, **kwargs): + nonlocal called_kwargs + called_kwargs = kwargs + msg = ChatCompletionMessage(role="assistant", content="Hello") + choice = Choice(index=0, finish_reason="stop", message=msg) + return ChatCompletion( + id="resp-id", + created=0, + model="fake", + object="chat.completion", + choices=[choice], + usage=None, + ) + + class DummyClient: + def __init__(self): + self.chat = type("_Chat", (), {"completions": DummyCompletions()})() + self.base_url = "https://api.openai.com" + + model = OpenAIChatCompletionsModel(model="gpt-4", openai_client=DummyClient()) # type: ignore + extra_headers = {"X-Test-Header": "test-value"} + await model.get_response( + system_instructions=None, + input="hi", + model_settings=ModelSettings(extra_headers=extra_headers), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + ) + assert "extra_headers" in called_kwargs + assert called_kwargs["extra_headers"]["X-Test-Header"] == "test-value" From 178020ea33980e5873a82dc715e79f0c6a285623 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Wed, 23 Apr 2025 11:29:12 +0900 Subject: [PATCH 2/8] Examples: Fix financial_research_agent instructions (#573) --- examples/financial_research_agent/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/financial_research_agent/main.py b/examples/financial_research_agent/main.py index 3fa8a7e0..b5b6cfdf 100644 --- a/examples/financial_research_agent/main.py +++ b/examples/financial_research_agent/main.py @@ -4,7 +4,7 @@ # Entrypoint for the financial bot example. -# Run this as `python -m examples.financial_bot.main` and enter a +# Run this as `python -m examples.financial_research_agent.main` and enter a # financial research query, for example: # "Write up an analysis of Apple Inc.'s most recent quarter." 
async def main() -> None: From a113fea0eef82bb37a0a803eaae42c4761d0ebdf Mon Sep 17 00:00:00 2001 From: Andrew Han Date: Wed, 23 Apr 2025 16:51:10 -0700 Subject: [PATCH 3/8] Allow cancel out of the streaming result (#579) Fix for #574 @rm-openai I'm not sure how to add a test within the repo but I have pasted a test script below that seems to work ```python import asyncio from openai.types.responses import ResponseTextDeltaEvent from agents import Agent, Runner async def main(): agent = Agent( name="Joker", instructions="You are a helpful assistant.", ) result = Runner.run_streamed(agent, input="Please tell me 5 jokes.") num_visible_event = 0 async for event in result.stream_events(): if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent): print(event.data.delta, end="", flush=True) num_visible_event += 1 print(num_visible_event) if num_visible_event == 3: result.cancel() if __name__ == "__main__": asyncio.run(main()) ```` --- src/agents/result.py | 24 +++++++++++++++++++++--- tests/test_cancel_streaming.py | 22 ++++++++++++++++++++++ 2 files changed, 43 insertions(+), 3 deletions(-) create mode 100644 tests/test_cancel_streaming.py diff --git a/src/agents/result.py b/src/agents/result.py index 0d8372c8..1f1c7832 100644 --- a/src/agents/result.py +++ b/src/agents/result.py @@ -75,7 +75,9 @@ def final_output_as(self, cls: type[T], raise_if_incorrect_type: bool = False) - def to_input_list(self) -> list[TResponseInputItem]: """Creates a new input list, merging the original input with all the new items generated.""" - original_items: list[TResponseInputItem] = ItemHelpers.input_to_new_input_list(self.input) + original_items: list[TResponseInputItem] = ItemHelpers.input_to_new_input_list( + self.input + ) new_items = [item.to_input_item() for item in self.new_items] return original_items + new_items @@ -152,6 +154,18 @@ def last_agent(self) -> Agent[Any]: """ return self.current_agent + def cancel(self) -> None: + """Cancels the streaming run, stopping all background tasks and marking the run as + complete.""" + self._cleanup_tasks() # Cancel all running tasks + self.is_complete = True # Mark the run as complete to stop event streaming + + # Optionally, clear the event queue to prevent processing stale events + while not self._event_queue.empty(): + self._event_queue.get_nowait() + while not self._input_guardrail_queue.empty(): + self._input_guardrail_queue.get_nowait() + async def stream_events(self) -> AsyncIterator[StreamEvent]: """Stream deltas for new items as they are generated. 
We're using the types from the OpenAI Responses API, so these are semantic events: each event has a `type` field that
@@ -192,13 +206,17 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]:
 
     def _check_errors(self):
         if self.current_turn > self.max_turns:
-            self._stored_exception = MaxTurnsExceeded(f"Max turns ({self.max_turns}) exceeded")
+            self._stored_exception = MaxTurnsExceeded(
+                f"Max turns ({self.max_turns}) exceeded"
+            )
 
         # Fetch all the completed guardrail results from the queue and raise if needed
         while not self._input_guardrail_queue.empty():
             guardrail_result = self._input_guardrail_queue.get_nowait()
             if guardrail_result.output.tripwire_triggered:
-                self._stored_exception = InputGuardrailTripwireTriggered(guardrail_result)
+                self._stored_exception = InputGuardrailTripwireTriggered(
+                    guardrail_result
+                )
 
         # Check the tasks for any exceptions
         if self._run_impl_task and self._run_impl_task.done():
diff --git a/tests/test_cancel_streaming.py b/tests/test_cancel_streaming.py
new file mode 100644
index 00000000..6d1807d7
--- /dev/null
+++ b/tests/test_cancel_streaming.py
@@ -0,0 +1,22 @@
+import pytest
+
+from agents import Agent, Runner
+
+from .fake_model import FakeModel
+
+
+@pytest.mark.asyncio
+async def test_joker_streamed_jokes_with_cancel():
+    model = FakeModel()
+    agent = Agent(name="Joker", model=model)
+
+    result = Runner.run_streamed(agent, input="Please tell me 5 jokes.")
+    num_events = 0
+    stop_after = 1  # There are two that the model gives back.
+
+    async for _event in result.stream_events():
+        num_events += 1
+        if num_events == 1:
+            result.cancel()
+
+    assert num_events == 1, f"Expected {stop_after} visible events, but got {num_events}"

From 3755ea86589b8e929c5b2bdd51df9f62c1cad8bf Mon Sep 17 00:00:00 2001
From: Rohan Mehta
Date: Wed, 23 Apr 2025 20:39:07 -0400
Subject: [PATCH 4/8] Create to_json_dict for ModelSettings (#582)

Now that `ModelSettings` has `Reasoning`, a non-primitive object, `dataclasses.asdict()` won't work. It will raise an error when you try to serialize (e.g. for tracing). This ensures the object is actually serializable.
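
For illustration only (not part of this patch): a minimal sketch of the failure mode and of how the new helper is meant to be used, with arbitrary example values.

```python
# Sketch: ModelSettings now carries a pydantic model (Reasoning), so a plain
# dataclasses.asdict() result is no longer JSON-serializable.
import dataclasses
import json

from openai.types.shared import Reasoning

from agents.model_settings import ModelSettings

settings = ModelSettings(temperature=0.1, reasoning=Reasoning(effort="low"))

try:
    json.dumps(dataclasses.asdict(settings))
except TypeError as err:
    # asdict() leaves the Reasoning instance inside the dict, which json.dumps rejects.
    print(f"not serializable: {err}")

# to_json_dict() dumps BaseModel fields via model_dump(mode="json"), so this works.
print(json.dumps(settings.to_json_dict()))
```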
--- pyproject.toml | 2 +- src/agents/extensions/models/litellm_model.py | 5 +- src/agents/model_settings.py | 17 +++++- src/agents/models/openai_chatcompletions.py | 7 +-- tests/model_settings/test_serialization.py | 59 +++++++++++++++++++ tests/voice/conftest.py | 1 - uv.lock | 8 +-- 7 files changed, 84 insertions(+), 15 deletions(-) create mode 100644 tests/model_settings/test_serialization.py diff --git a/pyproject.toml b/pyproject.toml index eeeb6d3d..12ffff1f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ requires-python = ">=3.9" license = "MIT" authors = [{ name = "OpenAI", email = "support@openai.com" }] dependencies = [ - "openai>=1.66.5", + "openai>=1.76.0", "pydantic>=2.10, <3", "griffe>=1.5.6, <2", "typing-extensions>=4.12.2, <5", diff --git a/src/agents/extensions/models/litellm_model.py b/src/agents/extensions/models/litellm_model.py index f5e7752f..dc672acd 100644 --- a/src/agents/extensions/models/litellm_model.py +++ b/src/agents/extensions/models/litellm_model.py @@ -1,6 +1,5 @@ from __future__ import annotations -import dataclasses import json import time from collections.abc import AsyncIterator @@ -75,7 +74,7 @@ async def get_response( ) -> ModelResponse: with generation_span( model=str(self.model), - model_config=dataclasses.asdict(model_settings) + model_config=model_settings.to_json_dict() | {"base_url": str(self.base_url or ""), "model_impl": "litellm"}, disabled=tracing.is_disabled(), ) as span_generation: @@ -147,7 +146,7 @@ async def stream_response( ) -> AsyncIterator[TResponseStreamEvent]: with generation_span( model=str(self.model), - model_config=dataclasses.asdict(model_settings) + model_config=model_settings.to_json_dict() | {"base_url": str(self.base_url or ""), "model_impl": "litellm"}, disabled=tracing.is_disabled(), ) as span_generation: diff --git a/src/agents/model_settings.py b/src/agents/model_settings.py index fee92b4e..7b016c98 100644 --- a/src/agents/model_settings.py +++ b/src/agents/model_settings.py @@ -1,10 +1,12 @@ from __future__ import annotations +import dataclasses from dataclasses import dataclass, fields, replace -from typing import Literal +from typing import Any, Literal from openai._types import Body, Headers, Query from openai.types.shared import Reasoning +from pydantic import BaseModel @dataclass @@ -83,3 +85,16 @@ def resolve(self, override: ModelSettings | None) -> ModelSettings: if getattr(override, field.name) is not None } return replace(self, **changes) + + def to_json_dict(self) -> dict[str, Any]: + dataclass_dict = dataclasses.asdict(self) + + json_dict: dict[str, Any] = {} + + for field_name, value in dataclass_dict.items(): + if isinstance(value, BaseModel): + json_dict[field_name] = value.model_dump(mode="json") + else: + json_dict[field_name] = value + + return json_dict diff --git a/src/agents/models/openai_chatcompletions.py b/src/agents/models/openai_chatcompletions.py index 15bf19cb..89619f83 100644 --- a/src/agents/models/openai_chatcompletions.py +++ b/src/agents/models/openai_chatcompletions.py @@ -1,6 +1,5 @@ from __future__ import annotations -import dataclasses import json import time from collections.abc import AsyncIterator @@ -56,8 +55,7 @@ async def get_response( ) -> ModelResponse: with generation_span( model=str(self.model), - model_config=dataclasses.asdict(model_settings) - | {"base_url": str(self._client.base_url)}, + model_config=model_settings.to_json_dict() | {"base_url": str(self._client.base_url)}, disabled=tracing.is_disabled(), ) as span_generation: response = await 
self._fetch_response( @@ -121,8 +119,7 @@ async def stream_response( """ with generation_span( model=str(self.model), - model_config=dataclasses.asdict(model_settings) - | {"base_url": str(self._client.base_url)}, + model_config=model_settings.to_json_dict() | {"base_url": str(self._client.base_url)}, disabled=tracing.is_disabled(), ) as span_generation: response, stream = await self._fetch_response( diff --git a/tests/model_settings/test_serialization.py b/tests/model_settings/test_serialization.py new file mode 100644 index 00000000..d76a58d1 --- /dev/null +++ b/tests/model_settings/test_serialization.py @@ -0,0 +1,59 @@ +import json +from dataclasses import fields + +from openai.types.shared import Reasoning + +from agents.model_settings import ModelSettings + + +def verify_serialization(model_settings: ModelSettings) -> None: + """Verify that ModelSettings can be serialized to a JSON string.""" + json_dict = model_settings.to_json_dict() + json_string = json.dumps(json_dict) + assert json_string is not None + + +def test_basic_serialization() -> None: + """Tests whether ModelSettings can be serialized to a JSON string.""" + + # First, lets create a ModelSettings instance + model_settings = ModelSettings( + temperature=0.5, + top_p=0.9, + max_tokens=100, + ) + + # Now, lets serialize the ModelSettings instance to a JSON string + verify_serialization(model_settings) + + +def test_all_fields_serialization() -> None: + """Tests whether ModelSettings can be serialized to a JSON string.""" + + # First, lets create a ModelSettings instance + model_settings = ModelSettings( + temperature=0.5, + top_p=0.9, + frequency_penalty=0.0, + presence_penalty=0.0, + tool_choice="auto", + parallel_tool_calls=True, + truncation="auto", + max_tokens=100, + reasoning=Reasoning(), + metadata={"foo": "bar"}, + store=False, + include_usage=False, + extra_query={"foo": "bar"}, + extra_body={"foo": "bar"}, + extra_headers={"foo": "bar"}, + ) + + # Verify that every single field is set to a non-None value + for field in fields(model_settings): + assert getattr(model_settings, field.name) is not None, ( + f"You must set the {field.name} field" + ) + + # Now, lets serialize the ModelSettings instance to a JSON string + verify_serialization(model_settings) diff --git a/tests/voice/conftest.py b/tests/voice/conftest.py index 6ed7422c..79d85d8b 100644 --- a/tests/voice/conftest.py +++ b/tests/voice/conftest.py @@ -9,4 +9,3 @@ def pytest_ignore_collect(collection_path, config): if str(collection_path).startswith(this_dir): return True - diff --git a/uv.lock b/uv.lock index 3a737cf3..4c6c370a 100644 --- a/uv.lock +++ b/uv.lock @@ -1463,7 +1463,7 @@ wheels = [ [[package]] name = "openai" -version = "1.74.0" +version = "1.76.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -1475,9 +1475,9 @@ dependencies = [ { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/75/86/c605a6e84da0248f2cebfcd864b5a6076ecf78849245af5e11d2a5ec7977/openai-1.74.0.tar.gz", hash = "sha256:592c25b8747a7cad33a841958f5eb859a785caea9ee22b9e4f4a2ec062236526", size = 427571 } +sdist = { url = "https://files.pythonhosted.org/packages/84/51/817969ec969b73d8ddad085670ecd8a45ef1af1811d8c3b8a177ca4d1309/openai-1.76.0.tar.gz", hash = "sha256:fd2bfaf4608f48102d6b74f9e11c5ecaa058b60dad9c36e409c12477dfd91fb2", size = 434660 } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/a9/91/8c150f16a96367e14bd7d20e86e0bbbec3080e3eb593e63f21a7f013f8e4/openai-1.74.0-py3-none-any.whl", hash = "sha256:aff3e0f9fb209836382ec112778667027f4fd6ae38bdb2334bc9e173598b092a", size = 644790 }, + { url = "https://files.pythonhosted.org/packages/59/aa/84e02ab500ca871eb8f62784426963a1c7c17a72fea3c7f268af4bbaafa5/openai-1.76.0-py3-none-any.whl", hash = "sha256:a712b50e78cf78e6d7b2a8f69c4978243517c2c36999756673e07a14ce37dc0a", size = 661201 }, ] [[package]] @@ -1538,7 +1538,7 @@ requires-dist = [ { name = "litellm", marker = "extra == 'litellm'", specifier = ">=1.65.0,<2" }, { name = "mcp", marker = "python_full_version >= '3.10'", specifier = ">=1.6.0,<2" }, { name = "numpy", marker = "python_full_version >= '3.10' and extra == 'voice'", specifier = ">=2.2.0,<3" }, - { name = "openai", specifier = ">=1.66.5" }, + { name = "openai", specifier = ">=1.76.0" }, { name = "pydantic", specifier = ">=2.10,<3" }, { name = "requests", specifier = ">=2.0,<3" }, { name = "types-requests", specifier = ">=2.0,<3" }, From af80e3a97123a5a0ad0fba695fbc257163c23224 Mon Sep 17 00:00:00 2001 From: Nathan Brake <33383515+njbrake@users.noreply.github.com> Date: Thu, 24 Apr 2025 12:12:46 -0400 Subject: [PATCH 5/8] Prevent MCP ClientSession hang (#580) Per https://modelcontextprotocol.io/specification/draft/basic/lifecycle#timeouts "Implementations SHOULD establish timeouts for all sent requests, to prevent hung connections and resource exhaustion. When the request has not received a success or error response within the timeout period, the sender SHOULD issue a cancellation notification for that request and stop waiting for a response. SDKs and other middleware SHOULD allow these timeouts to be configured on a per-request basis." I picked 5 seconds since that's the default for SSE --- src/agents/mcp/server.py | 26 ++++++++++++++++++++++---- tests/mcp/test_server_errors.py | 2 +- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/src/agents/mcp/server.py b/src/agents/mcp/server.py index 9a137bbd..9916c92b 100644 --- a/src/agents/mcp/server.py +++ b/src/agents/mcp/server.py @@ -3,6 +3,7 @@ import abc import asyncio from contextlib import AbstractAsyncContextManager, AsyncExitStack +from datetime import timedelta from pathlib import Path from typing import Any, Literal @@ -54,7 +55,7 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> C class _MCPServerWithClientSession(MCPServer, abc.ABC): """Base class for MCP servers that use a `ClientSession` to communicate with the server.""" - def __init__(self, cache_tools_list: bool): + def __init__(self, cache_tools_list: bool, client_session_timeout_seconds: float | None): """ Args: cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be @@ -63,12 +64,16 @@ def __init__(self, cache_tools_list: bool): by calling `invalidate_tools_cache()`. You should set this to `True` if you know the server will not change its tools list, because it can drastically improve latency (by avoiding a round-trip to the server every time). + + client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. 
""" self.session: ClientSession | None = None self.exit_stack: AsyncExitStack = AsyncExitStack() self._cleanup_lock: asyncio.Lock = asyncio.Lock() self.cache_tools_list = cache_tools_list + self.client_session_timeout_seconds = client_session_timeout_seconds + # The cache is always dirty at startup, so that we fetch tools at least once self._cache_dirty = True self._tools_list: list[MCPTool] | None = None @@ -101,7 +106,15 @@ async def connect(self): try: transport = await self.exit_stack.enter_async_context(self.create_streams()) read, write = transport - session = await self.exit_stack.enter_async_context(ClientSession(read, write)) + session = await self.exit_stack.enter_async_context( + ClientSession( + read, + write, + timedelta(seconds=self.client_session_timeout_seconds) + if self.client_session_timeout_seconds + else None, + ) + ) await session.initialize() self.session = session except Exception as e: @@ -183,6 +196,7 @@ def __init__( params: MCPServerStdioParams, cache_tools_list: bool = False, name: str | None = None, + client_session_timeout_seconds: float | None = 5, ): """Create a new MCP server based on the stdio transport. @@ -199,8 +213,9 @@ def __init__( improve latency (by avoiding a round-trip to the server every time). name: A readable name for the server. If not provided, we'll create one from the command. + client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. """ - super().__init__(cache_tools_list) + super().__init__(cache_tools_list, client_session_timeout_seconds) self.params = StdioServerParameters( command=params["command"], @@ -257,6 +272,7 @@ def __init__( params: MCPServerSseParams, cache_tools_list: bool = False, name: str | None = None, + client_session_timeout_seconds: float | None = 5, ): """Create a new MCP server based on the HTTP with SSE transport. @@ -274,8 +290,10 @@ def __init__( name: A readable name for the server. If not provided, we'll create one from the URL. + + client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. """ - super().__init__(cache_tools_list) + super().__init__(cache_tools_list, client_session_timeout_seconds) self.params = params self._name = name or f"sse: {self.params['url']}" diff --git a/tests/mcp/test_server_errors.py b/tests/mcp/test_server_errors.py index bdca7ce6..fbd8db17 100644 --- a/tests/mcp/test_server_errors.py +++ b/tests/mcp/test_server_errors.py @@ -6,7 +6,7 @@ class CrashingClientSessionServer(_MCPServerWithClientSession): def __init__(self): - super().__init__(cache_tools_list=False) + super().__init__(cache_tools_list=False, client_session_timeout_seconds=5) self.cleanup_called = False def create_streams(self): From e11b822d5f075fc32683c8df71ac9388a7df79e5 Mon Sep 17 00:00:00 2001 From: Daniele Morotti <58258368+DanieleMorotti@users.noreply.github.com> Date: Thu, 24 Apr 2025 18:53:39 +0200 Subject: [PATCH 6/8] Fix stream error using LiteLLM (#589) In response to issue #587 , I implemented a solution to first check if `refusal` and `usage` attributes exist in the `delta` object. I added a unit test similar to `test_openai_chatcompletions_stream.py`. Let me know if I should change something. 
--------- Co-authored-by: Rohan Mehta --- src/agents/models/chatcmpl_stream_handler.py | 6 +- .../test_litellm_chatcompletions_stream.py | 286 ++++++++++++++++++ 2 files changed, 290 insertions(+), 2 deletions(-) create mode 100644 tests/models/test_litellm_chatcompletions_stream.py diff --git a/src/agents/models/chatcmpl_stream_handler.py b/src/agents/models/chatcmpl_stream_handler.py index 32f04acb..c71adeb5 100644 --- a/src/agents/models/chatcmpl_stream_handler.py +++ b/src/agents/models/chatcmpl_stream_handler.py @@ -56,7 +56,8 @@ async def handle_stream( type="response.created", ) - usage = chunk.usage + # This is always set by the OpenAI API, but not by others e.g. LiteLLM + usage = chunk.usage if hasattr(chunk, "usage") else None if not chunk.choices or not chunk.choices[0].delta: continue @@ -112,7 +113,8 @@ async def handle_stream( state.text_content_index_and_output[1].text += delta.content # Handle refusals (model declines to answer) - if delta.refusal: + # This is always set by the OpenAI API, but not by others e.g. LiteLLM + if hasattr(delta, "refusal") and delta.refusal: if not state.refusal_content_index_and_output: # Initialize a content tracker for streaming refusal text state.refusal_content_index_and_output = ( diff --git a/tests/models/test_litellm_chatcompletions_stream.py b/tests/models/test_litellm_chatcompletions_stream.py new file mode 100644 index 00000000..80bd8ea2 --- /dev/null +++ b/tests/models/test_litellm_chatcompletions_stream.py @@ -0,0 +1,286 @@ +from collections.abc import AsyncIterator + +import pytest +from openai.types.chat.chat_completion_chunk import ( + ChatCompletionChunk, + Choice, + ChoiceDelta, + ChoiceDeltaToolCall, + ChoiceDeltaToolCallFunction, +) +from openai.types.completion_usage import CompletionUsage +from openai.types.responses import ( + Response, + ResponseFunctionToolCall, + ResponseOutputMessage, + ResponseOutputRefusal, + ResponseOutputText, +) + +from agents.extensions.models.litellm_model import LitellmModel +from agents.extensions.models.litellm_provider import LitellmProvider +from agents.model_settings import ModelSettings +from agents.models.interface import ModelTracing + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_stream_response_yields_events_for_text_content(monkeypatch) -> None: + """ + Validate that `stream_response` emits the correct sequence of events when + streaming a simple assistant message consisting of plain text content. + We simulate two chunks of text returned from the chat completion stream. + """ + # Create two chunks that will be emitted by the fake stream. + chunk1 = ChatCompletionChunk( + id="chunk-id", + created=1, + model="fake", + object="chat.completion.chunk", + choices=[Choice(index=0, delta=ChoiceDelta(content="He"))], + ) + # Mark last chunk with usage so stream_response knows this is final. 
+ chunk2 = ChatCompletionChunk( + id="chunk-id", + created=1, + model="fake", + object="chat.completion.chunk", + choices=[Choice(index=0, delta=ChoiceDelta(content="llo"))], + usage=CompletionUsage(completion_tokens=5, prompt_tokens=7, total_tokens=12), + ) + + async def fake_stream() -> AsyncIterator[ChatCompletionChunk]: + for c in (chunk1, chunk2): + yield c + + # Patch _fetch_response to inject our fake stream + async def patched_fetch_response(self, *args, **kwargs): + # `_fetch_response` is expected to return a Response skeleton and the async stream + resp = Response( + id="resp-id", + created_at=0, + model="fake-model", + object="response", + output=[], + tool_choice="none", + tools=[], + parallel_tool_calls=False, + ) + return resp, fake_stream() + + monkeypatch.setattr(LitellmModel, "_fetch_response", patched_fetch_response) + model = LitellmProvider().get_model("gpt-4") + output_events = [] + async for event in model.stream_response( + system_instructions=None, + input="", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + ): + output_events.append(event) + # We expect a response.created, then a response.output_item.added, content part added, + # two content delta events (for "He" and "llo"), a content part done, the assistant message + # output_item.done, and finally response.completed. + # There should be 8 events in total. + assert len(output_events) == 8 + # First event indicates creation. + assert output_events[0].type == "response.created" + # The output item added and content part added events should mark the assistant message. + assert output_events[1].type == "response.output_item.added" + assert output_events[2].type == "response.content_part.added" + # Two text delta events. + assert output_events[3].type == "response.output_text.delta" + assert output_events[3].delta == "He" + assert output_events[4].type == "response.output_text.delta" + assert output_events[4].delta == "llo" + # After streaming, the content part and item should be marked done. + assert output_events[5].type == "response.content_part.done" + assert output_events[6].type == "response.output_item.done" + # Last event indicates completion of the stream. + assert output_events[7].type == "response.completed" + # The completed response should have one output message with full text. + completed_resp = output_events[7].response + assert isinstance(completed_resp.output[0], ResponseOutputMessage) + assert isinstance(completed_resp.output[0].content[0], ResponseOutputText) + assert completed_resp.output[0].content[0].text == "Hello" + + assert completed_resp.usage, "usage should not be None" + assert completed_resp.usage.input_tokens == 7 + assert completed_resp.usage.output_tokens == 5 + assert completed_resp.usage.total_tokens == 12 + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_stream_response_yields_events_for_refusal_content(monkeypatch) -> None: + """ + Validate that when the model streams a refusal string instead of normal content, + `stream_response` emits the appropriate sequence of events including + `response.refusal.delta` events for each chunk of the refusal message and + constructs a completed assistant message with a `ResponseOutputRefusal` part. + """ + # Simulate refusal text coming in two pieces, like content but using the `refusal` + # field on the delta rather than `content`. 
+ chunk1 = ChatCompletionChunk( + id="chunk-id", + created=1, + model="fake", + object="chat.completion.chunk", + choices=[Choice(index=0, delta=ChoiceDelta(refusal="No"))], + ) + chunk2 = ChatCompletionChunk( + id="chunk-id", + created=1, + model="fake", + object="chat.completion.chunk", + choices=[Choice(index=0, delta=ChoiceDelta(refusal="Thanks"))], + usage=CompletionUsage(completion_tokens=2, prompt_tokens=2, total_tokens=4), + ) + + async def fake_stream() -> AsyncIterator[ChatCompletionChunk]: + for c in (chunk1, chunk2): + yield c + + async def patched_fetch_response(self, *args, **kwargs): + resp = Response( + id="resp-id", + created_at=0, + model="fake-model", + object="response", + output=[], + tool_choice="none", + tools=[], + parallel_tool_calls=False, + ) + return resp, fake_stream() + + monkeypatch.setattr(LitellmModel, "_fetch_response", patched_fetch_response) + model = LitellmProvider().get_model("gpt-4") + output_events = [] + async for event in model.stream_response( + system_instructions=None, + input="", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + ): + output_events.append(event) + # Expect sequence similar to text: created, output_item.added, content part added, + # two refusal delta events, content part done, output_item.done, completed. + assert len(output_events) == 8 + assert output_events[0].type == "response.created" + assert output_events[1].type == "response.output_item.added" + assert output_events[2].type == "response.content_part.added" + assert output_events[3].type == "response.refusal.delta" + assert output_events[3].delta == "No" + assert output_events[4].type == "response.refusal.delta" + assert output_events[4].delta == "Thanks" + assert output_events[5].type == "response.content_part.done" + assert output_events[6].type == "response.output_item.done" + assert output_events[7].type == "response.completed" + completed_resp = output_events[7].response + assert isinstance(completed_resp.output[0], ResponseOutputMessage) + refusal_part = completed_resp.output[0].content[0] + assert isinstance(refusal_part, ResponseOutputRefusal) + assert refusal_part.refusal == "NoThanks" + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_stream_response_yields_events_for_tool_call(monkeypatch) -> None: + """ + Validate that `stream_response` emits the correct sequence of events when + the model is streaming a function/tool call instead of plain text. + The function call will be split across two chunks. + """ + # Simulate a single tool call whose ID stays constant and function name/args built over chunks. 
+    tool_call_delta1 = ChoiceDeltaToolCall(
+        index=0,
+        id="tool-id",
+        function=ChoiceDeltaToolCallFunction(name="my_", arguments="arg1"),
+        type="function",
+    )
+    tool_call_delta2 = ChoiceDeltaToolCall(
+        index=0,
+        id="tool-id",
+        function=ChoiceDeltaToolCallFunction(name="func", arguments="arg2"),
+        type="function",
+    )
+    chunk1 = ChatCompletionChunk(
+        id="chunk-id",
+        created=1,
+        model="fake",
+        object="chat.completion.chunk",
+        choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[tool_call_delta1]))],
+    )
+    chunk2 = ChatCompletionChunk(
+        id="chunk-id",
+        created=1,
+        model="fake",
+        object="chat.completion.chunk",
+        choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[tool_call_delta2]))],
+        usage=CompletionUsage(completion_tokens=1, prompt_tokens=1, total_tokens=2),
+    )
+
+    async def fake_stream() -> AsyncIterator[ChatCompletionChunk]:
+        for c in (chunk1, chunk2):
+            yield c
+
+    async def patched_fetch_response(self, *args, **kwargs):
+        resp = Response(
+            id="resp-id",
+            created_at=0,
+            model="fake-model",
+            object="response",
+            output=[],
+            tool_choice="none",
+            tools=[],
+            parallel_tool_calls=False,
+        )
+        return resp, fake_stream()
+
+    monkeypatch.setattr(LitellmModel, "_fetch_response", patched_fetch_response)
+    model = LitellmProvider().get_model("gpt-4")
+    output_events = []
+    async for event in model.stream_response(
+        system_instructions=None,
+        input="",
+        model_settings=ModelSettings(),
+        tools=[],
+        output_schema=None,
+        handoffs=[],
+        tracing=ModelTracing.DISABLED,
+        previous_response_id=None,
+    ):
+        output_events.append(event)
+    # Sequence should be: response.created, then after loop we expect function call-related events:
+    # one response.output_item.added for function call, a response.function_call_arguments.delta,
+    # a response.output_item.done, and finally response.completed.
+    assert output_events[0].type == "response.created"
+    # The next three events are about the tool call.
+    assert output_events[1].type == "response.output_item.added"
+    # The added item should be a ResponseFunctionToolCall.
+    added_fn = output_events[1].item
+    assert isinstance(added_fn, ResponseFunctionToolCall)
+    assert added_fn.name == "my_func"  # Name should be concatenation of both chunks.
+    assert added_fn.arguments == "arg1arg2"
+    assert output_events[2].type == "response.function_call_arguments.delta"
+    assert output_events[2].delta == "arg1arg2"
+    assert output_events[3].type == "response.output_item.done"
+    assert output_events[4].type == "response.completed"

From 45eb41f1e668a45c6b53b64e06fa7db9eab4db46 Mon Sep 17 00:00:00 2001
From: Rohan Mehta
Date: Thu, 24 Apr 2025 14:45:03 -0400
Subject: [PATCH 7/8] More tests for cancelling streamed run (#590)

---
 tests/fake_model.py            |  5 +-
 tests/test_cancel_streaming.py | 98 +++++++++++++++++++++++++++++++++-
 2 files changed, 100 insertions(+), 3 deletions(-)

diff --git a/tests/fake_model.py b/tests/fake_model.py
index c6b3ba92..da3019a0 100644
--- a/tests/fake_model.py
+++ b/tests/fake_model.py
@@ -127,7 +127,10 @@ async def stream_response(
         )
 
 
-def get_response_obj(output: list[TResponseOutputItem], response_id: str | None = None) -> Response:
+def get_response_obj(
+    output: list[TResponseOutputItem],
+    response_id: str | None = None,
+) -> Response:
     return Response(
         id=response_id or "123",
         created_at=123,
diff --git a/tests/test_cancel_streaming.py b/tests/test_cancel_streaming.py
index 6d1807d7..3417a3c5 100644
--- a/tests/test_cancel_streaming.py
+++ b/tests/test_cancel_streaming.py
@@ -1,12 +1,15 @@
+import json
+
 import pytest
 
 from agents import Agent, Runner
 
 from .fake_model import FakeModel
+from .test_responses import get_function_tool, get_function_tool_call, get_text_message
 
 
 @pytest.mark.asyncio
-async def test_joker_streamed_jokes_with_cancel():
+async def test_simple_streaming_with_cancel():
     model = FakeModel()
     agent = Agent(name="Joker", model=model)
 
@@ -16,7 +19,98 @@ async def test_joker_streamed_jokes_with_cancel():
 
     async for _event in result.stream_events():
         num_events += 1
-        if num_events == 1:
+        if num_events == stop_after:
             result.cancel()
 
     assert num_events == 1, f"Expected {stop_after} visible events, but got {num_events}"
+
+
+@pytest.mark.asyncio
+async def test_multiple_events_streaming_with_cancel():
+    model = FakeModel()
+    agent = Agent(
+        name="Joker",
+        model=model,
+        tools=[get_function_tool("foo", "tool_result")],
+    )
+
+    model.add_multiple_turn_outputs(
+        [
+            # First turn: a message and tool call
+            [
+                get_text_message("a_message"),
+                get_function_tool_call("foo", json.dumps({"a": "b"})),
+            ],
+            # Second turn: text message
+            [get_text_message("done")],
+        ]
+    )
+
+    result = Runner.run_streamed(agent, input="Please tell me 5 jokes.")
+    num_events = 0
+    stop_after = 2
+
+    async for _ in result.stream_events():
+        num_events += 1
+        if num_events == stop_after:
+            result.cancel()
+
+    assert num_events == stop_after, f"Expected {stop_after} visible events, but got {num_events}"
+
+
+@pytest.mark.asyncio
+async def test_cancel_prevents_further_events():
+    model = FakeModel()
+    agent = Agent(name="Joker", model=model)
+    result = Runner.run_streamed(agent, input="Please tell me 5 jokes.")
+    events = []
+    async for event in result.stream_events():
+        events.append(event)
+        result.cancel()
+        break  # Cancel after first event
+    # Try to get more events after cancel
+    more_events = [e async for e in result.stream_events()]
+    assert len(events) == 1
+    assert more_events == [], "No events should be yielded after cancel()"
+
+
+@pytest.mark.asyncio
+async def test_cancel_is_idempotent():
+    model = FakeModel()
+    agent = Agent(name="Joker", model=model)
+    result = Runner.run_streamed(agent, input="Please tell me 5 jokes.")
+    events = []
+    async for event in result.stream_events():
+        events.append(event)
+ 
result.cancel() + result.cancel() # Call cancel again + break + # Should not raise or misbehave + assert len(events) == 1 + + +@pytest.mark.asyncio +async def test_cancel_before_streaming(): + model = FakeModel() + agent = Agent(name="Joker", model=model) + result = Runner.run_streamed(agent, input="Please tell me 5 jokes.") + result.cancel() # Cancel before streaming + events = [e async for e in result.stream_events()] + assert events == [], "No events should be yielded if cancel() is called before streaming." + + +@pytest.mark.asyncio +async def test_cancel_cleans_up_resources(): + model = FakeModel() + agent = Agent(name="Joker", model=model) + result = Runner.run_streamed(agent, input="Please tell me 5 jokes.") + # Start streaming, then cancel + async for _ in result.stream_events(): + result.cancel() + break + # After cancel, queues should be empty and is_complete True + assert result.is_complete, "Result should be marked complete after cancel." + assert result._event_queue.empty(), "Event queue should be empty after cancel." + assert result._input_guardrail_queue.empty(), ( + "Input guardrail queue should be empty after cancel." + ) From 3bbc7c48cb9ee80ed4b3dfbbd55efddf7f77d6a3 Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Thu, 24 Apr 2025 14:58:38 -0400 Subject: [PATCH 8/8] v0.0.13 (#593) --- pyproject.toml | 2 +- uv.lock | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 12ffff1f..c1ae467a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "openai-agents" -version = "0.0.12" +version = "0.0.13" description = "OpenAI Agents SDK" readme = "README.md" requires-python = ">=3.9" diff --git a/uv.lock b/uv.lock index 4c6c370a..c6824a08 100644 --- a/uv.lock +++ b/uv.lock @@ -1482,7 +1482,7 @@ wheels = [ [[package]] name = "openai-agents" -version = "0.0.12" +version = "0.0.13" source = { editable = "." } dependencies = [ { name = "griffe" },