diff --git a/src/agents/realtime/config.py b/src/agents/realtime/config.py index 459a03a9a..ece72e755 100644 --- a/src/agents/realtime/config.py +++ b/src/agents/realtime/config.py @@ -8,6 +8,7 @@ from typing_extensions import NotRequired, TypeAlias, TypedDict +from ..guardrail import OutputGuardrail from ..model_settings import ToolChoice from ..tool import Tool @@ -82,11 +83,43 @@ class RealtimeSessionModelSettings(TypedDict): tool_choice: NotRequired[ToolChoice] tools: NotRequired[list[Tool]] + tracing: NotRequired[RealtimeModelTracingConfig | None] + + +class RealtimeGuardrailsSettings(TypedDict): + """Settings for output guardrails in realtime sessions.""" + + debounce_text_length: NotRequired[int] + """ + The minimum number of characters to accumulate before running guardrails on transcript + deltas. Defaults to 100. Guardrails run every time the accumulated text reaches + 1x, 2x, 3x, etc. times this threshold. + """ + + +class RealtimeModelTracingConfig(TypedDict): + """Configuration for tracing in realtime model sessions.""" + + workflow_name: NotRequired[str] + """The workflow name to use for tracing.""" + + group_id: NotRequired[str] + """A group identifier to use for tracing, to link multiple traces together.""" + + metadata: NotRequired[dict[str, Any]] + """Additional metadata to include with the trace.""" + class RealtimeRunConfig(TypedDict): model_settings: NotRequired[RealtimeSessionModelSettings] - # TODO (rm) Add tracing support - # tracing: NotRequired[RealtimeTracingConfig | None] - # TODO (rm) Add guardrail support + output_guardrails: NotRequired[list[OutputGuardrail[Any]]] + """List of output guardrails to run on the agent's responses.""" + + guardrails_settings: NotRequired[RealtimeGuardrailsSettings] + """Settings for guardrail execution.""" + + tracing_disabled: NotRequired[bool] + """Whether tracing is disabled for this run.""" + # TODO (rm) Add history audio storage config diff --git a/src/agents/realtime/events.py b/src/agents/realtime/events.py index 44a588f03..e1b3cfea3 100644 --- a/src/agents/realtime/events.py +++ b/src/agents/realtime/events.py @@ -5,6 +5,7 @@ from typing_extensions import TypeAlias +from ..guardrail import OutputGuardrailResult from ..run_context import RunContextWrapper from ..tool import Tool from .agent import RealtimeAgent @@ -181,7 +182,20 @@ class RealtimeHistoryAdded: type: Literal["history_added"] = "history_added" -# TODO (rm) Add guardrails +@dataclass +class RealtimeGuardrailTripped: + """A guardrail has been tripped and the agent has been interrupted.""" + + guardrail_results: list[OutputGuardrailResult] + """The results from all triggered guardrails.""" + + message: str + """The message that was being generated when the guardrail was triggered.""" + + info: RealtimeEventInfo + """Common info for all events, such as the context.""" + + type: Literal["guardrail_tripped"] = "guardrail_tripped" RealtimeSessionEvent: TypeAlias = Union[ RealtimeAgentStartEvent, @@ -196,5 +210,6 @@ class RealtimeHistoryAdded: RealtimeError, RealtimeHistoryUpdated, RealtimeHistoryAdded, + RealtimeGuardrailTripped, ] """An event emitted by the realtime session.""" diff --git a/src/agents/realtime/model.py b/src/agents/realtime/model.py index 2b41960e7..abb3a1eac 100644 --- a/src/agents/realtime/model.py +++ b/src/agents/realtime/model.py @@ -38,6 +38,7 @@ class RealtimeModelConfig(TypedDict): """ initial_model_settings: NotRequired[RealtimeSessionModelSettings] + """The initial model settings to use when connecting.""" class RealtimeModel(abc.ABC): diff --git a/src/agents/realtime/openai_realtime.py b/src/agents/realtime/openai_realtime.py index de8f57ac7..797753242 100644 --- a/src/agents/realtime/openai_realtime.py +++ b/src/agents/realtime/openai_realtime.py @@ -6,7 +6,7 @@ import json import os from datetime import datetime -from typing import Any, Callable +from typing import Any, Callable, Literal import websockets from openai.types.beta.realtime.conversation_item import ConversationItem @@ -23,6 +23,7 @@ from ..logger import logger from .config import ( RealtimeClientMessage, + RealtimeModelTracingConfig, RealtimeSessionModelSettings, RealtimeUserInput, ) @@ -73,6 +74,7 @@ def __init__(self) -> None: self._audio_length_ms: float = 0.0 self._ongoing_response: bool = False self._current_audio_content_index: int | None = None + self._tracing_config: RealtimeModelTracingConfig | Literal["auto"] | None = None async def connect(self, options: RealtimeModelConfig) -> None: """Establish a connection to the model and keep it alive.""" @@ -84,6 +86,11 @@ async def connect(self, options: RealtimeModelConfig) -> None: self.model = model_settings.get("model_name", self.model) api_key = await get_api_key(options.get("api_key")) + if "tracing" in model_settings: + self._tracing_config = model_settings["tracing"] + else: + self._tracing_config = "auto" + if not api_key: raise UserError("API key is required but was not provided.") @@ -96,6 +103,15 @@ async def connect(self, options: RealtimeModelConfig) -> None: self._websocket = await websockets.connect(url, additional_headers=headers) self._websocket_task = asyncio.create_task(self._listen_for_messages()) + async def _send_tracing_config( + self, tracing_config: RealtimeModelTracingConfig | Literal["auto"] | None + ) -> None: + """Update tracing configuration via session.update event.""" + if tracing_config is not None: + await self.send_event( + {"type": "session.update", "other_data": {"session": {"tracing": tracing_config}}} + ) + def add_listener(self, listener: RealtimeModelListener) -> None: """Add a listener to the model.""" self._listeners.append(listener) @@ -343,8 +359,7 @@ async def _handle_ws_event(self, event: dict[str, Any]): self._ongoing_response = False await self._emit_event(RealtimeModelTurnEndedEvent()) elif parsed.type == "session.created": - # TODO (rm) tracing stuff here - pass + await self._send_tracing_config(self._tracing_config) elif parsed.type == "error": await self._emit_event(RealtimeModelErrorEvent(error=parsed.error)) elif parsed.type == "conversation.item.deleted": diff --git a/src/agents/realtime/runner.py b/src/agents/realtime/runner.py index 4470ab220..a7047a6f5 100644 --- a/src/agents/realtime/runner.py +++ b/src/agents/realtime/runner.py @@ -69,6 +69,7 @@ async def run( """ model_settings = await self._get_model_settings( agent=self._starting_agent, + disable_tracing=self._config.get("tracing_disabled", False) if self._config else False, initial_settings=model_config.get("initial_model_settings") if model_config else None, overrides=self._config.get("model_settings") if self._config else None, ) @@ -82,6 +83,7 @@ async def run( agent=self._starting_agent, context=context, model_config=model_config, + run_config=self._config, ) return session @@ -89,6 +91,7 @@ async def run( async def _get_model_settings( self, agent: RealtimeAgent, + disable_tracing: bool, context: TContext | None = None, initial_settings: RealtimeSessionModelSettings | None = None, overrides: RealtimeSessionModelSettings | None = None, @@ -109,4 +112,7 @@ async def _get_model_settings( if overrides: model_settings.update(overrides) + if disable_tracing: + model_settings["tracing"] = None + return model_settings diff --git a/src/agents/realtime/session.py b/src/agents/realtime/session.py index df80c063f..0ca3bf7af 100644 --- a/src/agents/realtime/session.py +++ b/src/agents/realtime/session.py @@ -2,16 +2,17 @@ import asyncio from collections.abc import AsyncIterator -from typing import Any +from typing import Any, cast from typing_extensions import assert_never +from ..agent import Agent from ..handoffs import Handoff from ..run_context import RunContextWrapper, TContext from ..tool import FunctionTool from ..tool_context import ToolContext from .agent import RealtimeAgent -from .config import RealtimeUserInput +from .config import RealtimeRunConfig, RealtimeUserInput from .events import ( RealtimeAgentEndEvent, RealtimeAgentStartEvent, @@ -20,6 +21,7 @@ RealtimeAudioInterrupted, RealtimeError, RealtimeEventInfo, + RealtimeGuardrailTripped, RealtimeHistoryAdded, RealtimeHistoryUpdated, RealtimeRawModelEvent, @@ -62,16 +64,16 @@ def __init__( agent: RealtimeAgent, context: TContext | None, model_config: RealtimeModelConfig | None = None, + run_config: RealtimeRunConfig | None = None, ) -> None: """Initialize the session. Args: model: The model to use. agent: The current agent. - context_wrapper: The context wrapper. - event_info: Event info object. - history: The conversation history. + context: The context object. model_config: Model configuration. + run_config: Runtime configuration including guardrails. """ self._model = model self._current_agent = agent @@ -79,9 +81,18 @@ def __init__( self._event_info = RealtimeEventInfo(context=self._context_wrapper) self._history: list[RealtimeItem] = [] self._model_config = model_config or {} + self._run_config = run_config or {} self._event_queue: asyncio.Queue[RealtimeSessionEvent] = asyncio.Queue() self._closed = False + # Guardrails state tracking + self._interrupted_by_guardrail = False + self._item_transcripts: dict[str, str] = {} # item_id -> accumulated transcript + self._item_guardrail_run_counts: dict[str, int] = {} # item_id -> run count + self._debounce_text_length = self._run_config.get("guardrails_settings", {}).get( + "debounce_text_length", 100 + ) + async def __aenter__(self) -> RealtimeSession: """Start the session by connecting to the model. After this, you will be able to stream events from the model and send messages and audio to the model. @@ -159,8 +170,22 @@ async def on_event(self, event: RealtimeModelEvent) -> None: RealtimeHistoryUpdated(info=self._event_info, history=self._history) ) elif event.type == "transcript_delta": - # TODO (rm) Add guardrails - pass + # Accumulate transcript text for guardrail debouncing per item_id + item_id = event.item_id + if item_id not in self._item_transcripts: + self._item_transcripts[item_id] = "" + self._item_guardrail_run_counts[item_id] = 0 + + self._item_transcripts[item_id] += event.delta + + # Check if we should run guardrails based on debounce threshold + current_length = len(self._item_transcripts[item_id]) + threshold = self._debounce_text_length + next_run_threshold = (self._item_guardrail_run_counts[item_id] + 1) * threshold + + if current_length >= next_run_threshold: + self._item_guardrail_run_counts[item_id] += 1 + await self._run_output_guardrails(self._item_transcripts[item_id]) elif event.type == "item_updated": is_new = not any(item.item_id == event.item.item_id for item in self._history) self._history = self._get_new_history(self._history, event.item) @@ -189,6 +214,11 @@ async def on_event(self, event: RealtimeModelEvent) -> None: ) ) elif event.type == "turn_ended": + # Clear guardrail state for next turn + self._item_transcripts.clear() + self._item_guardrail_run_counts.clear() + self._interrupted_by_guardrail = False + await self._put_event( RealtimeAgentEndEvent( agent=self._current_agent, @@ -290,3 +320,49 @@ def _get_new_history( # Otherwise, add it to the end return old_history + [event] + + async def _run_output_guardrails(self, text: str) -> bool: + """Run output guardrails on the given text. Returns True if any guardrail was triggered.""" + output_guardrails = self._run_config.get("output_guardrails", []) + if not output_guardrails or self._interrupted_by_guardrail: + return False + + triggered_results = [] + + for guardrail in output_guardrails: + try: + result = await guardrail.run( + # TODO (rm) Remove this cast, it's wrong + self._context_wrapper, + cast(Agent[Any], self._current_agent), + text, + ) + if result.output.tripwire_triggered: + triggered_results.append(result) + except Exception: + # Continue with other guardrails if one fails + continue + + if triggered_results: + # Mark as interrupted to prevent multiple interrupts + self._interrupted_by_guardrail = True + + # Emit guardrail tripped event + await self._put_event( + RealtimeGuardrailTripped( + guardrail_results=triggered_results, + message=text, + info=self._event_info, + ) + ) + + # Interrupt the model + await self._model.interrupt() + + # Send guardrail triggered message + guardrail_names = [result.guardrail.get_name() for result in triggered_results] + await self._model.send_message(f"guardrail triggered: {', '.join(guardrail_names)}") + + return True + + return False diff --git a/tests/realtime/test_session.py b/tests/realtime/test_session.py index f0abd4502..6003f9443 100644 --- a/tests/realtime/test_session.py +++ b/tests/realtime/test_session.py @@ -3,8 +3,10 @@ import pytest +from agents.guardrail import GuardrailFunctionOutput, OutputGuardrail from agents.handoffs import Handoff from agents.realtime.agent import RealtimeAgent +from agents.realtime.config import RealtimeRunConfig from agents.realtime.events import ( RealtimeAgentEndEvent, RealtimeAgentStartEvent, @@ -12,6 +14,7 @@ RealtimeAudioEnd, RealtimeAudioInterrupted, RealtimeError, + RealtimeGuardrailTripped, RealtimeHistoryAdded, RealtimeHistoryUpdated, RealtimeRawModelEvent, @@ -963,3 +966,194 @@ async def test_mixed_tool_types_filtering(self, mock_model, mock_agent): # Verify result sent_call, sent_output, _ = mock_model.sent_tool_outputs[0] assert sent_output == "result2" + + +class TestGuardrailFunctionality: + """Test suite for output guardrail functionality in RealtimeSession""" + + @pytest.fixture + def triggered_guardrail(self): + """Creates a guardrail that always triggers""" + def guardrail_func(context, agent, output): + return GuardrailFunctionOutput( + output_info={"reason": "test trigger"}, + tripwire_triggered=True + ) + return OutputGuardrail(guardrail_function=guardrail_func, name="triggered_guardrail") + + @pytest.fixture + def safe_guardrail(self): + """Creates a guardrail that never triggers""" + def guardrail_func(context, agent, output): + return GuardrailFunctionOutput( + output_info={"reason": "safe content"}, + tripwire_triggered=False + ) + return OutputGuardrail(guardrail_function=guardrail_func, name="safe_guardrail") + + @pytest.mark.asyncio + async def test_transcript_delta_triggers_guardrail_at_threshold( + self, mock_model, mock_agent, triggered_guardrail + ): + """Test that guardrails run when transcript delta reaches debounce threshold""" + run_config: RealtimeRunConfig = { + "output_guardrails": [triggered_guardrail], + "guardrails_settings": {"debounce_text_length": 10} + } + + session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config) + + # Send transcript delta that exceeds threshold (10 chars) + transcript_event = RealtimeModelTranscriptDeltaEvent( + item_id="item_1", delta="this is more than ten characters", response_id="resp_1" + ) + + await session.on_event(transcript_event) + + # Should have triggered guardrail and interrupted + assert session._interrupted_by_guardrail is True + assert mock_model.interrupts_called == 1 + assert len(mock_model.sent_messages) == 1 + assert "triggered_guardrail" in mock_model.sent_messages[0] + + # Should have emitted guardrail_tripped event + events = [] + while not session._event_queue.empty(): + events.append(await session._event_queue.get()) + + guardrail_events = [e for e in events if isinstance(e, RealtimeGuardrailTripped)] + assert len(guardrail_events) == 1 + assert guardrail_events[0].message == "this is more than ten characters" + + @pytest.mark.asyncio + async def test_transcript_delta_multiple_thresholds_same_item( + self, mock_model, mock_agent, triggered_guardrail + ): + """Test guardrails run at 1x, 2x, 3x thresholds for same item_id""" + run_config: RealtimeRunConfig = { + "output_guardrails": [triggered_guardrail], + "guardrails_settings": {"debounce_text_length": 5} + } + + session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config) + + # First delta - reaches 1x threshold (5 chars) + await session.on_event(RealtimeModelTranscriptDeltaEvent( + item_id="item_1", delta="12345", response_id="resp_1" + )) + + # Second delta - reaches 2x threshold (10 chars total) + await session.on_event(RealtimeModelTranscriptDeltaEvent( + item_id="item_1", delta="67890", response_id="resp_1" + )) + + # Should only trigger once due to interrupted_by_guardrail flag + assert mock_model.interrupts_called == 1 + assert len(mock_model.sent_messages) == 1 + + @pytest.mark.asyncio + async def test_transcript_delta_different_items_tracked_separately( + self, mock_model, mock_agent, safe_guardrail + ): + """Test that different item_ids are tracked separately for debouncing""" + run_config: RealtimeRunConfig = { + "output_guardrails": [safe_guardrail], + "guardrails_settings": {"debounce_text_length": 10} + } + + session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config) + + # Add text to item_1 (8 chars - below threshold) + await session.on_event(RealtimeModelTranscriptDeltaEvent( + item_id="item_1", delta="12345678", response_id="resp_1" + )) + + # Add text to item_2 (8 chars - below threshold) + await session.on_event(RealtimeModelTranscriptDeltaEvent( + item_id="item_2", delta="abcdefgh", response_id="resp_2" + )) + + # Neither should trigger guardrails yet + assert mock_model.interrupts_called == 0 + + # Add more text to item_1 (total 12 chars - above threshold) + await session.on_event(RealtimeModelTranscriptDeltaEvent( + item_id="item_1", delta="90ab", response_id="resp_1" + )) + + # item_1 should have triggered guardrail run (but not interrupted since safe) + assert session._item_guardrail_run_counts["item_1"] == 1 + assert ( + "item_2" not in session._item_guardrail_run_counts + or session._item_guardrail_run_counts["item_2"] == 0 + ) + + @pytest.mark.asyncio + async def test_turn_ended_clears_guardrail_state( + self, mock_model, mock_agent, triggered_guardrail + ): + """Test that turn_ended event clears guardrail state for next turn""" + run_config: RealtimeRunConfig = { + "output_guardrails": [triggered_guardrail], + "guardrails_settings": {"debounce_text_length": 5} + } + + session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config) + + # Trigger guardrail + await session.on_event(RealtimeModelTranscriptDeltaEvent( + item_id="item_1", delta="trigger", response_id="resp_1" + )) + + assert session._interrupted_by_guardrail is True + assert len(session._item_transcripts) == 1 + + # End turn + await session.on_event(RealtimeModelTurnEndedEvent()) + + # State should be cleared + assert session._interrupted_by_guardrail is False + assert len(session._item_transcripts) == 0 + assert len(session._item_guardrail_run_counts) == 0 + + @pytest.mark.asyncio + async def test_multiple_guardrails_all_triggered( + self, mock_model, mock_agent + ): + """Test that all triggered guardrails are included in the event""" + def create_triggered_guardrail(name): + def guardrail_func(context, agent, output): + return GuardrailFunctionOutput( + output_info={"name": name}, + tripwire_triggered=True + ) + return OutputGuardrail(guardrail_function=guardrail_func, name=name) + + guardrail1 = create_triggered_guardrail("guardrail_1") + guardrail2 = create_triggered_guardrail("guardrail_2") + + run_config: RealtimeRunConfig = { + "output_guardrails": [guardrail1, guardrail2], + "guardrails_settings": {"debounce_text_length": 5} + } + + session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config) + + await session.on_event(RealtimeModelTranscriptDeltaEvent( + item_id="item_1", delta="trigger", response_id="resp_1" + )) + + # Should have interrupted and sent message with both guardrail names + assert mock_model.interrupts_called == 1 + assert len(mock_model.sent_messages) == 1 + message = mock_model.sent_messages[0] + assert "guardrail_1" in message and "guardrail_2" in message + + # Should have emitted event with both guardrail results + events = [] + while not session._event_queue.empty(): + events.append(await session._event_queue.get()) + + guardrail_events = [e for e in events if isinstance(e, RealtimeGuardrailTripped)] + assert len(guardrail_events) == 1 + assert len(guardrail_events[0].guardrail_results) == 2 diff --git a/tests/realtime/test_tracing.py b/tests/realtime/test_tracing.py new file mode 100644 index 000000000..456ae125f --- /dev/null +++ b/tests/realtime/test_tracing.py @@ -0,0 +1,257 @@ +from unittest.mock import AsyncMock, patch + +import pytest + +from agents.realtime.openai_realtime import OpenAIRealtimeWebSocketModel + + +class TestRealtimeTracingIntegration: + """Test tracing configuration and session.update integration.""" + + @pytest.fixture + def model(self): + """Create a fresh model instance for each test.""" + return OpenAIRealtimeWebSocketModel() + + @pytest.fixture + def mock_websocket(self): + """Create a mock websocket connection.""" + mock_ws = AsyncMock() + mock_ws.send = AsyncMock() + mock_ws.close = AsyncMock() + return mock_ws + + @pytest.mark.asyncio + async def test_tracing_config_storage_and_defaults(self, model, mock_websocket): + """Test that tracing config is stored correctly and defaults to 'auto'.""" + # Test with explicit tracing config + config_with_tracing = { + "api_key": "test-key", + "initial_model_settings": { + "tracing": { + "workflow_name": "test_workflow", + "group_id": "group_123", + "metadata": {"version": "1.0"}, + } + }, + } + + async def async_websocket(*args, **kwargs): + return mock_websocket + + with patch("websockets.connect", side_effect=async_websocket): + with patch("asyncio.create_task") as mock_create_task: + mock_task = AsyncMock() + mock_create_task.return_value = mock_task + mock_create_task.side_effect = lambda coro: (coro.close(), mock_task)[1] + + await model.connect(config_with_tracing) + + # Should store the tracing config + assert model._tracing_config == { + "workflow_name": "test_workflow", + "group_id": "group_123", + "metadata": {"version": "1.0"}, + } + + # Test without tracing config - should default to "auto" + model2 = OpenAIRealtimeWebSocketModel() + config_no_tracing = { + "api_key": "test-key", + "initial_model_settings": {}, + } + + with patch("websockets.connect", side_effect=async_websocket): + with patch("asyncio.create_task") as mock_create_task: + mock_create_task.side_effect = lambda coro: (coro.close(), mock_task)[1] + + await model2.connect(config_no_tracing) # type: ignore[arg-type] + assert model2._tracing_config == "auto" + + @pytest.mark.asyncio + async def test_send_tracing_config_on_session_created(self, model, mock_websocket): + """Test that tracing config is sent when session.created event is received.""" + config = { + "api_key": "test-key", + "initial_model_settings": { + "tracing": {"workflow_name": "test_workflow", "group_id": "group_123"} + }, + } + + async def async_websocket(*args, **kwargs): + return mock_websocket + + with patch("websockets.connect", side_effect=async_websocket): + with patch("asyncio.create_task") as mock_create_task: + mock_task = AsyncMock() + mock_create_task.side_effect = lambda coro: (coro.close(), mock_task)[1] + + await model.connect(config) + + # Simulate session.created event + session_created_event = { + "type": "session.created", + "event_id": "event_123", + "session": {"id": "session_456"}, + } + + with patch.object(model, "send_event") as mock_send_event: + await model._handle_ws_event(session_created_event) + + # Should send session.update with tracing config + mock_send_event.assert_called_once_with( + { + "type": "session.update", + "other_data": { + "session": { + "tracing": { + "workflow_name": "test_workflow", + "group_id": "group_123", + } + } + }, + } + ) + + @pytest.mark.asyncio + async def test_send_tracing_config_auto_mode(self, model, mock_websocket): + """Test that 'auto' tracing config is sent correctly.""" + config = { + "api_key": "test-key", + "initial_model_settings": {}, # No tracing config - defaults to "auto" + } + + async def async_websocket(*args, **kwargs): + return mock_websocket + + with patch("websockets.connect", side_effect=async_websocket): + with patch("asyncio.create_task") as mock_create_task: + mock_task = AsyncMock() + mock_create_task.side_effect = lambda coro: (coro.close(), mock_task)[1] + + await model.connect(config) + + session_created_event = { + "type": "session.created", + "event_id": "event_123", + "session": {"id": "session_456"}, + } + + with patch.object(model, "send_event") as mock_send_event: + await model._handle_ws_event(session_created_event) + + # Should send session.update with "auto" + mock_send_event.assert_called_once_with( + {"type": "session.update", "other_data": {"session": {"tracing": "auto"}}} + ) + + @pytest.mark.asyncio + async def test_tracing_config_none_skips_session_update(self, model, mock_websocket): + """Test that None tracing config skips sending session.update.""" + # Manually set tracing config to None (this would happen if explicitly set) + model._tracing_config = None + + session_created_event = { + "type": "session.created", + "event_id": "event_123", + "session": {"id": "session_456"}, + } + + with patch.object(model, "send_event") as mock_send_event: + await model._handle_ws_event(session_created_event) + + # Should not send any session.update + mock_send_event.assert_not_called() + + @pytest.mark.asyncio + async def test_tracing_config_with_metadata_serialization(self, model, mock_websocket): + """Test that complex metadata in tracing config is handled correctly.""" + complex_metadata = { + "user_id": "user_123", + "session_type": "demo", + "features": ["audio", "tools"], + "config": {"timeout": 30, "retries": 3}, + } + + config = { + "api_key": "test-key", + "initial_model_settings": { + "tracing": {"workflow_name": "complex_workflow", "metadata": complex_metadata} + }, + } + + async def async_websocket(*args, **kwargs): + return mock_websocket + + with patch("websockets.connect", side_effect=async_websocket): + with patch("asyncio.create_task") as mock_create_task: + mock_task = AsyncMock() + mock_create_task.side_effect = lambda coro: (coro.close(), mock_task)[1] + + await model.connect(config) + + session_created_event = { + "type": "session.created", + "event_id": "event_123", + "session": {"id": "session_456"}, + } + + with patch.object(model, "send_event") as mock_send_event: + await model._handle_ws_event(session_created_event) + + # Should send session.update with complete tracing config including metadata + expected_call = { + "type": "session.update", + "other_data": { + "session": { + "tracing": { + "workflow_name": "complex_workflow", + "metadata": complex_metadata, + } + } + }, + } + mock_send_event.assert_called_once_with(expected_call) + + @pytest.mark.asyncio + async def test_tracing_disabled_prevents_tracing(self, mock_websocket): + """Test that tracing_disabled=True prevents tracing configuration.""" + from agents.realtime.agent import RealtimeAgent + from agents.realtime.runner import RealtimeRunner + + # Create a test agent and runner with tracing disabled + agent = RealtimeAgent(name="test_agent", instructions="test") + + runner = RealtimeRunner( + starting_agent=agent, + config={"tracing_disabled": True} + ) + + # Test the _get_model_settings method directly since that's where the logic is + model_settings = await runner._get_model_settings( + agent=agent, + disable_tracing=True, # This should come from config["tracing_disabled"] + initial_settings=None, + overrides=None + ) + + # When tracing is disabled, model settings should have tracing=None + assert model_settings["tracing"] is None + + # Also test that the runner passes disable_tracing=True correctly + with patch.object(runner, '_get_model_settings') as mock_get_settings: + mock_get_settings.return_value = {"tracing": None} + + with patch('agents.realtime.session.RealtimeSession') as mock_session_class: + mock_session = AsyncMock() + mock_session_class.return_value = mock_session + + await runner.run() + + # Verify that _get_model_settings was called with disable_tracing=True + mock_get_settings.assert_called_once_with( + agent=agent, + disable_tracing=True, + initial_settings=None, + overrides=None + )