From 57a1575474a34b8f337b93ae458cbfaec71e2c40 Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Fri, 11 Jul 2025 22:05:19 -0400 Subject: [PATCH] Realtime guardrail support --- src/agents/realtime/config.py | 19 +++- src/agents/realtime/events.py | 17 ++- src/agents/realtime/runner.py | 1 + src/agents/realtime/session.py | 90 +++++++++++++-- tests/realtime/test_session.py | 194 +++++++++++++++++++++++++++++++++ 5 files changed, 312 insertions(+), 9 deletions(-) diff --git a/src/agents/realtime/config.py b/src/agents/realtime/config.py index 459a03a9a..ad98c9c7a 100644 --- a/src/agents/realtime/config.py +++ b/src/agents/realtime/config.py @@ -8,6 +8,7 @@ from typing_extensions import NotRequired, TypeAlias, TypedDict +from ..guardrail import OutputGuardrail from ..model_settings import ToolChoice from ..tool import Tool @@ -83,10 +84,26 @@ class RealtimeSessionModelSettings(TypedDict): tools: NotRequired[list[Tool]] +class RealtimeGuardrailsSettings(TypedDict): + """Settings for output guardrails in realtime sessions.""" + + debounce_text_length: NotRequired[int] + """ + The minimum number of characters to accumulate before running guardrails on transcript + deltas. Defaults to 100. Guardrails run every time the accumulated text reaches + 1x, 2x, 3x, etc. times this threshold. + """ + + class RealtimeRunConfig(TypedDict): model_settings: NotRequired[RealtimeSessionModelSettings] + output_guardrails: NotRequired[list[OutputGuardrail[Any]]] + """List of output guardrails to run on the agent's responses.""" + + guardrails_settings: NotRequired[RealtimeGuardrailsSettings] + """Settings for guardrail execution.""" + # TODO (rm) Add tracing support # tracing: NotRequired[RealtimeTracingConfig | None] - # TODO (rm) Add guardrail support # TODO (rm) Add history audio storage config diff --git a/src/agents/realtime/events.py b/src/agents/realtime/events.py index 44a588f03..e1b3cfea3 100644 --- a/src/agents/realtime/events.py +++ b/src/agents/realtime/events.py @@ -5,6 +5,7 @@ from typing_extensions import TypeAlias +from ..guardrail import OutputGuardrailResult from ..run_context import RunContextWrapper from ..tool import Tool from .agent import RealtimeAgent @@ -181,7 +182,20 @@ class RealtimeHistoryAdded: type: Literal["history_added"] = "history_added" -# TODO (rm) Add guardrails +@dataclass +class RealtimeGuardrailTripped: + """A guardrail has been tripped and the agent has been interrupted.""" + + guardrail_results: list[OutputGuardrailResult] + """The results from all triggered guardrails.""" + + message: str + """The message that was being generated when the guardrail was triggered.""" + + info: RealtimeEventInfo + """Common info for all events, such as the context.""" + + type: Literal["guardrail_tripped"] = "guardrail_tripped" RealtimeSessionEvent: TypeAlias = Union[ RealtimeAgentStartEvent, @@ -196,5 +210,6 @@ class RealtimeHistoryAdded: RealtimeError, RealtimeHistoryUpdated, RealtimeHistoryAdded, + RealtimeGuardrailTripped, ] """An event emitted by the realtime session.""" diff --git a/src/agents/realtime/runner.py b/src/agents/realtime/runner.py index 4470ab220..369267797 100644 --- a/src/agents/realtime/runner.py +++ b/src/agents/realtime/runner.py @@ -82,6 +82,7 @@ async def run( agent=self._starting_agent, context=context, model_config=model_config, + run_config=self._config, ) return session diff --git a/src/agents/realtime/session.py b/src/agents/realtime/session.py index df80c063f..0ca3bf7af 100644 --- a/src/agents/realtime/session.py +++ b/src/agents/realtime/session.py @@ -2,16 +2,17 @@ import asyncio from collections.abc import AsyncIterator -from typing import Any +from typing import Any, cast from typing_extensions import assert_never +from ..agent import Agent from ..handoffs import Handoff from ..run_context import RunContextWrapper, TContext from ..tool import FunctionTool from ..tool_context import ToolContext from .agent import RealtimeAgent -from .config import RealtimeUserInput +from .config import RealtimeRunConfig, RealtimeUserInput from .events import ( RealtimeAgentEndEvent, RealtimeAgentStartEvent, @@ -20,6 +21,7 @@ RealtimeAudioInterrupted, RealtimeError, RealtimeEventInfo, + RealtimeGuardrailTripped, RealtimeHistoryAdded, RealtimeHistoryUpdated, RealtimeRawModelEvent, @@ -62,16 +64,16 @@ def __init__( agent: RealtimeAgent, context: TContext | None, model_config: RealtimeModelConfig | None = None, + run_config: RealtimeRunConfig | None = None, ) -> None: """Initialize the session. Args: model: The model to use. agent: The current agent. - context_wrapper: The context wrapper. - event_info: Event info object. - history: The conversation history. + context: The context object. model_config: Model configuration. + run_config: Runtime configuration including guardrails. """ self._model = model self._current_agent = agent @@ -79,9 +81,18 @@ def __init__( self._event_info = RealtimeEventInfo(context=self._context_wrapper) self._history: list[RealtimeItem] = [] self._model_config = model_config or {} + self._run_config = run_config or {} self._event_queue: asyncio.Queue[RealtimeSessionEvent] = asyncio.Queue() self._closed = False + # Guardrails state tracking + self._interrupted_by_guardrail = False + self._item_transcripts: dict[str, str] = {} # item_id -> accumulated transcript + self._item_guardrail_run_counts: dict[str, int] = {} # item_id -> run count + self._debounce_text_length = self._run_config.get("guardrails_settings", {}).get( + "debounce_text_length", 100 + ) + async def __aenter__(self) -> RealtimeSession: """Start the session by connecting to the model. After this, you will be able to stream events from the model and send messages and audio to the model. @@ -159,8 +170,22 @@ async def on_event(self, event: RealtimeModelEvent) -> None: RealtimeHistoryUpdated(info=self._event_info, history=self._history) ) elif event.type == "transcript_delta": - # TODO (rm) Add guardrails - pass + # Accumulate transcript text for guardrail debouncing per item_id + item_id = event.item_id + if item_id not in self._item_transcripts: + self._item_transcripts[item_id] = "" + self._item_guardrail_run_counts[item_id] = 0 + + self._item_transcripts[item_id] += event.delta + + # Check if we should run guardrails based on debounce threshold + current_length = len(self._item_transcripts[item_id]) + threshold = self._debounce_text_length + next_run_threshold = (self._item_guardrail_run_counts[item_id] + 1) * threshold + + if current_length >= next_run_threshold: + self._item_guardrail_run_counts[item_id] += 1 + await self._run_output_guardrails(self._item_transcripts[item_id]) elif event.type == "item_updated": is_new = not any(item.item_id == event.item.item_id for item in self._history) self._history = self._get_new_history(self._history, event.item) @@ -189,6 +214,11 @@ async def on_event(self, event: RealtimeModelEvent) -> None: ) ) elif event.type == "turn_ended": + # Clear guardrail state for next turn + self._item_transcripts.clear() + self._item_guardrail_run_counts.clear() + self._interrupted_by_guardrail = False + await self._put_event( RealtimeAgentEndEvent( agent=self._current_agent, @@ -290,3 +320,49 @@ def _get_new_history( # Otherwise, add it to the end return old_history + [event] + + async def _run_output_guardrails(self, text: str) -> bool: + """Run output guardrails on the given text. Returns True if any guardrail was triggered.""" + output_guardrails = self._run_config.get("output_guardrails", []) + if not output_guardrails or self._interrupted_by_guardrail: + return False + + triggered_results = [] + + for guardrail in output_guardrails: + try: + result = await guardrail.run( + # TODO (rm) Remove this cast, it's wrong + self._context_wrapper, + cast(Agent[Any], self._current_agent), + text, + ) + if result.output.tripwire_triggered: + triggered_results.append(result) + except Exception: + # Continue with other guardrails if one fails + continue + + if triggered_results: + # Mark as interrupted to prevent multiple interrupts + self._interrupted_by_guardrail = True + + # Emit guardrail tripped event + await self._put_event( + RealtimeGuardrailTripped( + guardrail_results=triggered_results, + message=text, + info=self._event_info, + ) + ) + + # Interrupt the model + await self._model.interrupt() + + # Send guardrail triggered message + guardrail_names = [result.guardrail.get_name() for result in triggered_results] + await self._model.send_message(f"guardrail triggered: {', '.join(guardrail_names)}") + + return True + + return False diff --git a/tests/realtime/test_session.py b/tests/realtime/test_session.py index f0abd4502..6003f9443 100644 --- a/tests/realtime/test_session.py +++ b/tests/realtime/test_session.py @@ -3,8 +3,10 @@ import pytest +from agents.guardrail import GuardrailFunctionOutput, OutputGuardrail from agents.handoffs import Handoff from agents.realtime.agent import RealtimeAgent +from agents.realtime.config import RealtimeRunConfig from agents.realtime.events import ( RealtimeAgentEndEvent, RealtimeAgentStartEvent, @@ -12,6 +14,7 @@ RealtimeAudioEnd, RealtimeAudioInterrupted, RealtimeError, + RealtimeGuardrailTripped, RealtimeHistoryAdded, RealtimeHistoryUpdated, RealtimeRawModelEvent, @@ -963,3 +966,194 @@ async def test_mixed_tool_types_filtering(self, mock_model, mock_agent): # Verify result sent_call, sent_output, _ = mock_model.sent_tool_outputs[0] assert sent_output == "result2" + + +class TestGuardrailFunctionality: + """Test suite for output guardrail functionality in RealtimeSession""" + + @pytest.fixture + def triggered_guardrail(self): + """Creates a guardrail that always triggers""" + def guardrail_func(context, agent, output): + return GuardrailFunctionOutput( + output_info={"reason": "test trigger"}, + tripwire_triggered=True + ) + return OutputGuardrail(guardrail_function=guardrail_func, name="triggered_guardrail") + + @pytest.fixture + def safe_guardrail(self): + """Creates a guardrail that never triggers""" + def guardrail_func(context, agent, output): + return GuardrailFunctionOutput( + output_info={"reason": "safe content"}, + tripwire_triggered=False + ) + return OutputGuardrail(guardrail_function=guardrail_func, name="safe_guardrail") + + @pytest.mark.asyncio + async def test_transcript_delta_triggers_guardrail_at_threshold( + self, mock_model, mock_agent, triggered_guardrail + ): + """Test that guardrails run when transcript delta reaches debounce threshold""" + run_config: RealtimeRunConfig = { + "output_guardrails": [triggered_guardrail], + "guardrails_settings": {"debounce_text_length": 10} + } + + session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config) + + # Send transcript delta that exceeds threshold (10 chars) + transcript_event = RealtimeModelTranscriptDeltaEvent( + item_id="item_1", delta="this is more than ten characters", response_id="resp_1" + ) + + await session.on_event(transcript_event) + + # Should have triggered guardrail and interrupted + assert session._interrupted_by_guardrail is True + assert mock_model.interrupts_called == 1 + assert len(mock_model.sent_messages) == 1 + assert "triggered_guardrail" in mock_model.sent_messages[0] + + # Should have emitted guardrail_tripped event + events = [] + while not session._event_queue.empty(): + events.append(await session._event_queue.get()) + + guardrail_events = [e for e in events if isinstance(e, RealtimeGuardrailTripped)] + assert len(guardrail_events) == 1 + assert guardrail_events[0].message == "this is more than ten characters" + + @pytest.mark.asyncio + async def test_transcript_delta_multiple_thresholds_same_item( + self, mock_model, mock_agent, triggered_guardrail + ): + """Test guardrails run at 1x, 2x, 3x thresholds for same item_id""" + run_config: RealtimeRunConfig = { + "output_guardrails": [triggered_guardrail], + "guardrails_settings": {"debounce_text_length": 5} + } + + session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config) + + # First delta - reaches 1x threshold (5 chars) + await session.on_event(RealtimeModelTranscriptDeltaEvent( + item_id="item_1", delta="12345", response_id="resp_1" + )) + + # Second delta - reaches 2x threshold (10 chars total) + await session.on_event(RealtimeModelTranscriptDeltaEvent( + item_id="item_1", delta="67890", response_id="resp_1" + )) + + # Should only trigger once due to interrupted_by_guardrail flag + assert mock_model.interrupts_called == 1 + assert len(mock_model.sent_messages) == 1 + + @pytest.mark.asyncio + async def test_transcript_delta_different_items_tracked_separately( + self, mock_model, mock_agent, safe_guardrail + ): + """Test that different item_ids are tracked separately for debouncing""" + run_config: RealtimeRunConfig = { + "output_guardrails": [safe_guardrail], + "guardrails_settings": {"debounce_text_length": 10} + } + + session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config) + + # Add text to item_1 (8 chars - below threshold) + await session.on_event(RealtimeModelTranscriptDeltaEvent( + item_id="item_1", delta="12345678", response_id="resp_1" + )) + + # Add text to item_2 (8 chars - below threshold) + await session.on_event(RealtimeModelTranscriptDeltaEvent( + item_id="item_2", delta="abcdefgh", response_id="resp_2" + )) + + # Neither should trigger guardrails yet + assert mock_model.interrupts_called == 0 + + # Add more text to item_1 (total 12 chars - above threshold) + await session.on_event(RealtimeModelTranscriptDeltaEvent( + item_id="item_1", delta="90ab", response_id="resp_1" + )) + + # item_1 should have triggered guardrail run (but not interrupted since safe) + assert session._item_guardrail_run_counts["item_1"] == 1 + assert ( + "item_2" not in session._item_guardrail_run_counts + or session._item_guardrail_run_counts["item_2"] == 0 + ) + + @pytest.mark.asyncio + async def test_turn_ended_clears_guardrail_state( + self, mock_model, mock_agent, triggered_guardrail + ): + """Test that turn_ended event clears guardrail state for next turn""" + run_config: RealtimeRunConfig = { + "output_guardrails": [triggered_guardrail], + "guardrails_settings": {"debounce_text_length": 5} + } + + session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config) + + # Trigger guardrail + await session.on_event(RealtimeModelTranscriptDeltaEvent( + item_id="item_1", delta="trigger", response_id="resp_1" + )) + + assert session._interrupted_by_guardrail is True + assert len(session._item_transcripts) == 1 + + # End turn + await session.on_event(RealtimeModelTurnEndedEvent()) + + # State should be cleared + assert session._interrupted_by_guardrail is False + assert len(session._item_transcripts) == 0 + assert len(session._item_guardrail_run_counts) == 0 + + @pytest.mark.asyncio + async def test_multiple_guardrails_all_triggered( + self, mock_model, mock_agent + ): + """Test that all triggered guardrails are included in the event""" + def create_triggered_guardrail(name): + def guardrail_func(context, agent, output): + return GuardrailFunctionOutput( + output_info={"name": name}, + tripwire_triggered=True + ) + return OutputGuardrail(guardrail_function=guardrail_func, name=name) + + guardrail1 = create_triggered_guardrail("guardrail_1") + guardrail2 = create_triggered_guardrail("guardrail_2") + + run_config: RealtimeRunConfig = { + "output_guardrails": [guardrail1, guardrail2], + "guardrails_settings": {"debounce_text_length": 5} + } + + session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config) + + await session.on_event(RealtimeModelTranscriptDeltaEvent( + item_id="item_1", delta="trigger", response_id="resp_1" + )) + + # Should have interrupted and sent message with both guardrail names + assert mock_model.interrupts_called == 1 + assert len(mock_model.sent_messages) == 1 + message = mock_model.sent_messages[0] + assert "guardrail_1" in message and "guardrail_2" in message + + # Should have emitted event with both guardrail results + events = [] + while not session._event_queue.empty(): + events.append(await session._event_queue.get()) + + guardrail_events = [e for e in events if isinstance(e, RealtimeGuardrailTripped)] + assert len(guardrail_events) == 1 + assert len(guardrail_events[0].guardrail_results) == 2