diff --git a/src/agents/realtime/runner.py b/src/agents/realtime/runner.py index a7047a6f5..e51a094d8 100644 --- a/src/agents/realtime/runner.py +++ b/src/agents/realtime/runner.py @@ -2,13 +2,10 @@ from __future__ import annotations -import asyncio - -from ..run_context import RunContextWrapper, TContext +from ..run_context import TContext from .agent import RealtimeAgent from .config import ( RealtimeRunConfig, - RealtimeSessionModelSettings, ) from .model import ( RealtimeModel, @@ -67,16 +64,6 @@ async def run( print(event) ``` """ - model_settings = await self._get_model_settings( - agent=self._starting_agent, - disable_tracing=self._config.get("tracing_disabled", False) if self._config else False, - initial_settings=model_config.get("initial_model_settings") if model_config else None, - overrides=self._config.get("model_settings") if self._config else None, - ) - - model_config = model_config.copy() if model_config else {} - model_config["initial_model_settings"] = model_settings - # Create and return the connection session = RealtimeSession( model=self._model, @@ -87,32 +74,3 @@ async def run( ) return session - - async def _get_model_settings( - self, - agent: RealtimeAgent, - disable_tracing: bool, - context: TContext | None = None, - initial_settings: RealtimeSessionModelSettings | None = None, - overrides: RealtimeSessionModelSettings | None = None, - ) -> RealtimeSessionModelSettings: - context_wrapper = RunContextWrapper(context) - model_settings = initial_settings.copy() if initial_settings else {} - - instructions, tools = await asyncio.gather( - agent.get_system_prompt(context_wrapper), - agent.get_all_tools(context_wrapper), - ) - - if instructions is not None: - model_settings["instructions"] = instructions - if tools is not None: - model_settings["tools"] = tools - - if overrides: - model_settings.update(overrides) - - if disable_tracing: - model_settings["tracing"] = None - - return model_settings diff --git a/src/agents/realtime/session.py b/src/agents/realtime/session.py index 6df35b438..83e56d5fc 100644 --- a/src/agents/realtime/session.py +++ b/src/agents/realtime/session.py @@ -114,8 +114,13 @@ async def __aenter__(self) -> RealtimeSession: # Add ourselves as a listener self._model.add_listener(self) + model_config = self._model_config.copy() + model_config["initial_model_settings"] = await self._get_updated_model_settings_from_agent( + self._current_agent + ) + # Connect to the model - await self._model.connect(self._model_config) + await self._model.connect(model_config) # Emit initial history update await self._put_event( @@ -319,7 +324,9 @@ async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None: self._current_agent = result # Get updated model settings from new agent - updated_settings = await self._get__updated_model_settings(self._current_agent) + updated_settings = await self._get_updated_model_settings_from_agent( + self._current_agent + ) # Send handoff event await self._put_event( @@ -495,19 +502,28 @@ async def _cleanup(self) -> None: # Mark as closed self._closed = True - async def _get__updated_model_settings( - self, new_agent: RealtimeAgent + async def _get_updated_model_settings_from_agent( + self, + agent: RealtimeAgent, ) -> RealtimeSessionModelSettings: updated_settings: RealtimeSessionModelSettings = {} instructions, tools, handoffs = await asyncio.gather( - new_agent.get_system_prompt(self._context_wrapper), - new_agent.get_all_tools(self._context_wrapper), - self._get_handoffs(new_agent, self._context_wrapper), + agent.get_system_prompt(self._context_wrapper), + agent.get_all_tools(self._context_wrapper), + self._get_handoffs(agent, self._context_wrapper), ) updated_settings["instructions"] = instructions or "" updated_settings["tools"] = tools or [] updated_settings["handoffs"] = handoffs or [] + # Override with initial settings + initial_settings = self._model_config.get("initial_model_settings", {}) + updated_settings.update(initial_settings) + + disable_tracing = self._run_config.get("tracing_disabled", False) + if disable_tracing: + updated_settings["tracing"] = None + return updated_settings @classmethod diff --git a/tests/realtime/test_runner.py b/tests/realtime/test_runner.py index aabdff140..1e6eccbae 100644 --- a/tests/realtime/test_runner.py +++ b/tests/realtime/test_runner.py @@ -1,18 +1,21 @@ from unittest.mock import AsyncMock, Mock, patch import pytest -from inline_snapshot import snapshot from agents.realtime.agent import RealtimeAgent from agents.realtime.config import RealtimeRunConfig, RealtimeSessionModelSettings from agents.realtime.model import RealtimeModel, RealtimeModelConfig from agents.realtime.runner import RealtimeRunner from agents.realtime.session import RealtimeSession +from agents.tool import function_tool class MockRealtimeModel(RealtimeModel): + def __init__(self): + self.connect_args = None + async def connect(self, options=None): - pass + self.connect_args = options def add_listener(self, listener): pass @@ -53,7 +56,9 @@ def mock_model(): @pytest.mark.asyncio -async def test_run_creates_session_with_no_settings(mock_agent, mock_model): +async def test_run_creates_session_with_no_settings( + mock_agent: Mock, mock_model: MockRealtimeModel +): """Test that run() creates a session correctly if no settings are provided""" runner = RealtimeRunner(mock_agent, model=mock_model) @@ -71,22 +76,17 @@ async def test_run_creates_session_with_no_settings(mock_agent, mock_model): assert call_args[1]["agent"] == mock_agent assert call_args[1]["context"] is None - # Verify model_config contains expected settings from agent + # With no settings provided, model_config should be None model_config = call_args[1]["model_config"] - assert model_config == snapshot( - { - "initial_model_settings": { - "instructions": "Test instructions", - "tools": [{"type": "function", "name": "test_tool"}], - } - } - ) + assert model_config is None assert session == mock_session @pytest.mark.asyncio -async def test_run_creates_session_with_settings_only_in_init(mock_agent, mock_model): +async def test_run_creates_session_with_settings_only_in_init( + mock_agent: Mock, mock_model: MockRealtimeModel +): """Test that it creates a session with the right settings if they are provided only in init""" config = RealtimeRunConfig( model_settings=RealtimeSessionModelSettings(model_name="gpt-4o-realtime", voice="nova") @@ -99,28 +99,19 @@ async def test_run_creates_session_with_settings_only_in_init(mock_agent, mock_m _ = await runner.run() - # Verify session was created with config overrides + # Verify session was created - runner no longer processes settings call_args = mock_session_class.call_args model_config = call_args[1]["model_config"] - # Should have agent settings plus config overrides - assert model_config == snapshot( - { - "initial_model_settings": { - "instructions": "Test instructions", - "tools": [{"type": "function", "name": "test_tool"}], - "model_name": "gpt-4o-realtime", - "voice": "nova", - } - } - ) + # Runner should pass None for model_config when none provided to run() + assert model_config is None @pytest.mark.asyncio async def test_run_creates_session_with_settings_in_both_init_and_run_overrides( - mock_agent, mock_model + mock_agent: Mock, mock_model: MockRealtimeModel ): - """Test settings in both init and run() - init should override run()""" + """Test settings provided in run() parameter are passed through""" init_config = RealtimeRunConfig( model_settings=RealtimeSessionModelSettings(model_name="gpt-4o-realtime", voice="nova") ) @@ -138,26 +129,18 @@ async def test_run_creates_session_with_settings_in_both_init_and_run_overrides( _ = await runner.run(model_config=run_model_config) - # Verify run() settings override init settings + # Verify run() model_config is passed through as-is call_args = mock_session_class.call_args model_config = call_args[1]["model_config"] - # Should have agent settings, then init config, then run config overrides - assert model_config == snapshot( - { - "initial_model_settings": { - "voice": "nova", - "input_audio_format": "pcm16", - "instructions": "Test instructions", - "tools": [{"type": "function", "name": "test_tool"}], - "model_name": "gpt-4o-realtime", - } - } - ) + # Runner should pass the model_config from run() parameter directly + assert model_config == run_model_config @pytest.mark.asyncio -async def test_run_creates_session_with_settings_only_in_run(mock_agent, mock_model): +async def test_run_creates_session_with_settings_only_in_run( + mock_agent: Mock, mock_model: MockRealtimeModel +): """Test settings provided only in run()""" runner = RealtimeRunner(mock_agent, model=mock_model) @@ -173,26 +156,16 @@ async def test_run_creates_session_with_settings_only_in_run(mock_agent, mock_mo _ = await runner.run(model_config=run_model_config) - # Verify run() settings are applied + # Verify run() model_config is passed through as-is call_args = mock_session_class.call_args model_config = call_args[1]["model_config"] - # Should have agent settings plus run() settings - assert model_config == snapshot( - { - "initial_model_settings": { - "model_name": "gpt-4o-realtime-preview", - "voice": "shimmer", - "modalities": ["text", "audio"], - "instructions": "Test instructions", - "tools": [{"type": "function", "name": "test_tool"}], - } - } - ) + # Runner should pass the model_config from run() parameter directly + assert model_config == run_model_config @pytest.mark.asyncio -async def test_run_with_context_parameter(mock_agent, mock_model): +async def test_run_with_context_parameter(mock_agent: Mock, mock_model: MockRealtimeModel): """Test that context parameter is passed through to session""" runner = RealtimeRunner(mock_agent, model=mock_model) test_context = {"user_id": "test123"} @@ -208,17 +181,69 @@ async def test_run_with_context_parameter(mock_agent, mock_model): @pytest.mark.asyncio -async def test_get_model_settings_with_none_values(mock_model): - """Test _get_model_settings handles None values from agent properly""" +async def test_run_with_none_values_from_agent_does_not_crash(mock_model: MockRealtimeModel): + """Test that runner handles agents with None values without crashing""" agent = Mock(spec=RealtimeAgent) agent.get_system_prompt = AsyncMock(return_value=None) agent.get_all_tools = AsyncMock(return_value=None) runner = RealtimeRunner(agent, model=mock_model) - with patch("agents.realtime.runner.RealtimeSession"): - await runner.run() + with patch("agents.realtime.runner.RealtimeSession") as mock_session_class: + mock_session = Mock(spec=RealtimeSession) + mock_session_class.return_value = mock_session + + session = await runner.run() + + # Should not crash and return session + assert session == mock_session + # Runner no longer calls agent methods directly - session does that + agent.get_system_prompt.assert_not_called() + agent.get_all_tools.assert_not_called() + + +@pytest.mark.asyncio +async def test_tool_and_handoffs_are_correct(mock_model: MockRealtimeModel): + @function_tool + def tool_one(): + return "result_one" + + agent_1 = RealtimeAgent( + name="one", + instructions="instr_one", + ) + agent_2 = RealtimeAgent( + name="two", + instructions="instr_two", + tools=[tool_one], + handoffs=[agent_1], + ) + + session = RealtimeSession( + model=mock_model, + agent=agent_2, + context=None, + model_config=None, + run_config=None, + ) + + async with session: + pass - # Should not crash and agent methods should be called - agent.get_system_prompt.assert_called_once() - agent.get_all_tools.assert_called_once() + # Assert that the model.connect() was called with the correct settings + connect_args = mock_model.connect_args + assert connect_args is not None + assert isinstance(connect_args, dict) + initial_model_settings = connect_args["initial_model_settings"] + assert initial_model_settings is not None + assert isinstance(initial_model_settings, dict) + assert initial_model_settings["instructions"] == "instr_two" + assert len(initial_model_settings["tools"]) == 1 + tool = initial_model_settings["tools"][0] + assert tool.name == "tool_one" + + handoffs = initial_model_settings["handoffs"] + assert len(handoffs) == 1 + handoff = handoffs[0] + assert handoff.tool_name == "transfer_to_one" + assert handoff.agent_name == "one" diff --git a/tests/realtime/test_session.py b/tests/realtime/test_session.py index 7c1eb53ff..3bb9b5931 100644 --- a/tests/realtime/test_session.py +++ b/tests/realtime/test_session.py @@ -29,7 +29,7 @@ RealtimeItem, UserMessageItem, ) -from agents.realtime.model import RealtimeModel +from agents.realtime.model import RealtimeModel, RealtimeModelConfig from agents.realtime.model_events import ( RealtimeModelAudioDoneEvent, RealtimeModelAudioEvent, @@ -1206,3 +1206,117 @@ def guardrail_func(context, agent, output): guardrail_events = [e for e in events if isinstance(e, RealtimeGuardrailTripped)] assert len(guardrail_events) == 1 assert len(guardrail_events[0].guardrail_results) == 2 + + +class TestModelSettingsIntegration: + """Test suite for model settings integration in RealtimeSession.""" + + @pytest.mark.asyncio + async def test_session_gets_model_settings_from_agent_during_connection(self): + """Test that session properly gets model settings from agent during __aenter__.""" + # Create mock model that records the config passed to connect() + mock_model = Mock(spec=RealtimeModel) + mock_model.connect = AsyncMock() + mock_model.add_listener = Mock() + + # Create agent with specific settings + agent = Mock(spec=RealtimeAgent) + agent.get_system_prompt = AsyncMock(return_value="Test agent instructions") + agent.get_all_tools = AsyncMock(return_value=[{"type": "function", "name": "test_tool"}]) + agent.handoffs = [] + + session = RealtimeSession(mock_model, agent, None) + + # Connect the session + await session.__aenter__() + + # Verify model.connect was called with settings from agent + mock_model.connect.assert_called_once() + connect_config = mock_model.connect.call_args[0][0] + + initial_settings = connect_config["initial_model_settings"] + assert initial_settings["instructions"] == "Test agent instructions" + assert initial_settings["tools"] == [{"type": "function", "name": "test_tool"}] + assert initial_settings["handoffs"] == [] + + await session.__aexit__(None, None, None) + + @pytest.mark.asyncio + async def test_model_config_overrides_agent_settings(self): + """Test that initial_model_settings from model_config override agent settings.""" + mock_model = Mock(spec=RealtimeModel) + mock_model.connect = AsyncMock() + mock_model.add_listener = Mock() + + agent = Mock(spec=RealtimeAgent) + agent.get_system_prompt = AsyncMock(return_value="Agent instructions") + agent.get_all_tools = AsyncMock(return_value=[{"type": "function", "name": "agent_tool"}]) + agent.handoffs = [] + + # Provide model config with overrides + model_config: RealtimeModelConfig = { + "initial_model_settings": { + "instructions": "Override instructions", + "voice": "nova", + "model_name": "gpt-4o-realtime", + } + } + + session = RealtimeSession(mock_model, agent, None, model_config=model_config) + + await session.__aenter__() + + # Verify overrides were applied + connect_config = mock_model.connect.call_args[0][0] + initial_settings = connect_config["initial_model_settings"] + + # Should have override values + assert initial_settings["instructions"] == "Override instructions" + assert initial_settings["voice"] == "nova" + assert initial_settings["model_name"] == "gpt-4o-realtime" + # Should still have agent tools since not overridden + assert initial_settings["tools"] == [{"type": "function", "name": "agent_tool"}] + + await session.__aexit__(None, None, None) + + @pytest.mark.asyncio + async def test_handoffs_are_included_in_model_settings(self): + """Test that handoffs from agent are properly processed into model settings.""" + mock_model = Mock(spec=RealtimeModel) + mock_model.connect = AsyncMock() + mock_model.add_listener = Mock() + + # Create agent with handoffs + agent = Mock(spec=RealtimeAgent) + agent.get_system_prompt = AsyncMock(return_value="Agent with handoffs") + agent.get_all_tools = AsyncMock(return_value=[]) + + # Create a mock handoff + handoff_agent = Mock(spec=RealtimeAgent) + handoff_agent.name = "handoff_target" + + mock_handoff = Mock(spec=Handoff) + mock_handoff.tool_name = "transfer_to_specialist" + mock_handoff.is_enabled = True + + agent.handoffs = [handoff_agent] # Agent handoff + + # Mock the _get_handoffs method since it's complex + with pytest.MonkeyPatch().context() as m: + + async def mock_get_handoffs(cls, agent, context_wrapper): + return [mock_handoff] + + m.setattr("agents.realtime.session.RealtimeSession._get_handoffs", mock_get_handoffs) + + session = RealtimeSession(mock_model, agent, None) + + await session.__aenter__() + + # Verify handoffs were included + connect_config = mock_model.connect.call_args[0][0] + initial_settings = connect_config["initial_model_settings"] + + assert initial_settings["handoffs"] == [mock_handoff] + + await session.__aexit__(None, None, None) diff --git a/tests/realtime/test_tracing.py b/tests/realtime/test_tracing.py index 85da63897..6ef1ae9a7 100644 --- a/tests/realtime/test_tracing.py +++ b/tests/realtime/test_tracing.py @@ -1,8 +1,11 @@ -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, Mock, patch import pytest +from agents.realtime.agent import RealtimeAgent +from agents.realtime.model import RealtimeModel from agents.realtime.openai_realtime import OpenAIRealtimeWebSocketModel +from agents.realtime.session import RealtimeSession class TestRealtimeTracingIntegration: @@ -219,36 +222,24 @@ async def async_websocket(*args, **kwargs): @pytest.mark.asyncio async def test_tracing_disabled_prevents_tracing(self, mock_websocket): """Test that tracing_disabled=True prevents tracing configuration.""" - from agents.realtime.agent import RealtimeAgent - from agents.realtime.runner import RealtimeRunner - # Create a test agent and runner with tracing disabled + # Create a test agent and mock model agent = RealtimeAgent(name="test_agent", instructions="test") + agent.handoffs = [] - runner = RealtimeRunner(starting_agent=agent, config={"tracing_disabled": True}) + mock_model = Mock(spec=RealtimeModel) - # Test the _get_model_settings method directly since that's where the logic is - model_settings = await runner._get_model_settings( + # Create session with tracing disabled + session = RealtimeSession( + model=mock_model, agent=agent, - disable_tracing=True, # This should come from config["tracing_disabled"] - initial_settings=None, - overrides=None, + context=None, + model_config=None, + run_config={"tracing_disabled": True}, ) + # Test the _get_updated_model_settings_from_agent method directly + model_settings = await session._get_updated_model_settings_from_agent(agent) + # When tracing is disabled, model settings should have tracing=None assert model_settings["tracing"] is None - - # Also test that the runner passes disable_tracing=True correctly - with patch.object(runner, "_get_model_settings") as mock_get_settings: - mock_get_settings.return_value = {"tracing": None} - - with patch("agents.realtime.session.RealtimeSession") as mock_session_class: - mock_session = AsyncMock() - mock_session_class.return_value = mock_session - - await runner.run() - - # Verify that _get_model_settings was called with disable_tracing=True - mock_get_settings.assert_called_once_with( - agent=agent, disable_tracing=True, initial_settings=None, overrides=None - ) diff --git a/tests/test_session_exceptions.py b/tests/test_session_exceptions.py index a454cca92..da9390236 100644 --- a/tests/test_session_exceptions.py +++ b/tests/test_session_exceptions.py @@ -90,6 +90,8 @@ def fake_agent(): """Create a fake agent for testing.""" agent = Mock() agent.get_all_tools = AsyncMock(return_value=[]) + agent.get_system_prompt = AsyncMock(return_value="test instructions") + agent.handoffs = [] return agent