Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit a4499d4

Browse files
authored
Realtime guardrail support (openai#1082)
Guadrails with debouncing support. Output only. --- [//]: # (BEGIN SAPLING FOOTER) * openai#1084 * __->__ openai#1082
1 parent eafd8df commit a4499d4

File tree

5 files changed

+312
-9
lines changed

5 files changed

+312
-9
lines changed

src/agents/realtime/config.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from typing_extensions import NotRequired, TypeAlias, TypedDict
1010

11+
from ..guardrail import OutputGuardrail
1112
from ..model_settings import ToolChoice
1213
from ..tool import Tool
1314

@@ -83,10 +84,26 @@ class RealtimeSessionModelSettings(TypedDict):
8384
tools: NotRequired[list[Tool]]
8485

8586

87+
class RealtimeGuardrailsSettings(TypedDict):
88+
"""Settings for output guardrails in realtime sessions."""
89+
90+
debounce_text_length: NotRequired[int]
91+
"""
92+
The minimum number of characters to accumulate before running guardrails on transcript
93+
deltas. Defaults to 100. Guardrails run every time the accumulated text reaches
94+
1x, 2x, 3x, etc. times this threshold.
95+
"""
96+
97+
8698
class RealtimeRunConfig(TypedDict):
8799
model_settings: NotRequired[RealtimeSessionModelSettings]
88100

101+
output_guardrails: NotRequired[list[OutputGuardrail[Any]]]
102+
"""List of output guardrails to run on the agent's responses."""
103+
104+
guardrails_settings: NotRequired[RealtimeGuardrailsSettings]
105+
"""Settings for guardrail execution."""
106+
89107
# TODO (rm) Add tracing support
90108
# tracing: NotRequired[RealtimeTracingConfig | None]
91-
# TODO (rm) Add guardrail support
92109
# TODO (rm) Add history audio storage config

src/agents/realtime/events.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from typing_extensions import TypeAlias
77

8+
from ..guardrail import OutputGuardrailResult
89
from ..run_context import RunContextWrapper
910
from ..tool import Tool
1011
from .agent import RealtimeAgent
@@ -181,7 +182,20 @@ class RealtimeHistoryAdded:
181182
type: Literal["history_added"] = "history_added"
182183

183184

184-
# TODO (rm) Add guardrails
185+
@dataclass
186+
class RealtimeGuardrailTripped:
187+
"""A guardrail has been tripped and the agent has been interrupted."""
188+
189+
guardrail_results: list[OutputGuardrailResult]
190+
"""The results from all triggered guardrails."""
191+
192+
message: str
193+
"""The message that was being generated when the guardrail was triggered."""
194+
195+
info: RealtimeEventInfo
196+
"""Common info for all events, such as the context."""
197+
198+
type: Literal["guardrail_tripped"] = "guardrail_tripped"
185199

186200
RealtimeSessionEvent: TypeAlias = Union[
187201
RealtimeAgentStartEvent,
@@ -196,5 +210,6 @@ class RealtimeHistoryAdded:
196210
RealtimeError,
197211
RealtimeHistoryUpdated,
198212
RealtimeHistoryAdded,
213+
RealtimeGuardrailTripped,
199214
]
200215
"""An event emitted by the realtime session."""

src/agents/realtime/runner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ async def run(
8282
agent=self._starting_agent,
8383
context=context,
8484
model_config=model_config,
85+
run_config=self._config,
8586
)
8687

8788
return session

src/agents/realtime/session.py

Lines changed: 83 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,17 @@
22

33
import asyncio
44
from collections.abc import AsyncIterator
5-
from typing import Any
5+
from typing import Any, cast
66

77
from typing_extensions import assert_never
88

9+
from ..agent import Agent
910
from ..handoffs import Handoff
1011
from ..run_context import RunContextWrapper, TContext
1112
from ..tool import FunctionTool
1213
from ..tool_context import ToolContext
1314
from .agent import RealtimeAgent
14-
from .config import RealtimeUserInput
15+
from .config import RealtimeRunConfig, RealtimeUserInput
1516
from .events import (
1617
RealtimeAgentEndEvent,
1718
RealtimeAgentStartEvent,
@@ -20,6 +21,7 @@
2021
RealtimeAudioInterrupted,
2122
RealtimeError,
2223
RealtimeEventInfo,
24+
RealtimeGuardrailTripped,
2325
RealtimeHistoryAdded,
2426
RealtimeHistoryUpdated,
2527
RealtimeRawModelEvent,
@@ -62,26 +64,35 @@ def __init__(
6264
agent: RealtimeAgent,
6365
context: TContext | None,
6466
model_config: RealtimeModelConfig | None = None,
67+
run_config: RealtimeRunConfig | None = None,
6568
) -> None:
6669
"""Initialize the session.
6770
6871
Args:
6972
model: The model to use.
7073
agent: The current agent.
71-
context_wrapper: The context wrapper.
72-
event_info: Event info object.
73-
history: The conversation history.
74+
context: The context object.
7475
model_config: Model configuration.
76+
run_config: Runtime configuration including guardrails.
7577
"""
7678
self._model = model
7779
self._current_agent = agent
7880
self._context_wrapper = RunContextWrapper(context)
7981
self._event_info = RealtimeEventInfo(context=self._context_wrapper)
8082
self._history: list[RealtimeItem] = []
8183
self._model_config = model_config or {}
84+
self._run_config = run_config or {}
8285
self._event_queue: asyncio.Queue[RealtimeSessionEvent] = asyncio.Queue()
8386
self._closed = False
8487

88+
# Guardrails state tracking
89+
self._interrupted_by_guardrail = False
90+
self._item_transcripts: dict[str, str] = {} # item_id -> accumulated transcript
91+
self._item_guardrail_run_counts: dict[str, int] = {} # item_id -> run count
92+
self._debounce_text_length = self._run_config.get("guardrails_settings", {}).get(
93+
"debounce_text_length", 100
94+
)
95+
8596
async def __aenter__(self) -> RealtimeSession:
8697
"""Start the session by connecting to the model. After this, you will be able to stream
8798
events from the model and send messages and audio to the model.
@@ -159,8 +170,22 @@ async def on_event(self, event: RealtimeModelEvent) -> None:
159170
RealtimeHistoryUpdated(info=self._event_info, history=self._history)
160171
)
161172
elif event.type == "transcript_delta":
162-
# TODO (rm) Add guardrails
163-
pass
173+
# Accumulate transcript text for guardrail debouncing per item_id
174+
item_id = event.item_id
175+
if item_id not in self._item_transcripts:
176+
self._item_transcripts[item_id] = ""
177+
self._item_guardrail_run_counts[item_id] = 0
178+
179+
self._item_transcripts[item_id] += event.delta
180+
181+
# Check if we should run guardrails based on debounce threshold
182+
current_length = len(self._item_transcripts[item_id])
183+
threshold = self._debounce_text_length
184+
next_run_threshold = (self._item_guardrail_run_counts[item_id] + 1) * threshold
185+
186+
if current_length >= next_run_threshold:
187+
self._item_guardrail_run_counts[item_id] += 1
188+
await self._run_output_guardrails(self._item_transcripts[item_id])
164189
elif event.type == "item_updated":
165190
is_new = not any(item.item_id == event.item.item_id for item in self._history)
166191
self._history = self._get_new_history(self._history, event.item)
@@ -189,6 +214,11 @@ async def on_event(self, event: RealtimeModelEvent) -> None:
189214
)
190215
)
191216
elif event.type == "turn_ended":
217+
# Clear guardrail state for next turn
218+
self._item_transcripts.clear()
219+
self._item_guardrail_run_counts.clear()
220+
self._interrupted_by_guardrail = False
221+
192222
await self._put_event(
193223
RealtimeAgentEndEvent(
194224
agent=self._current_agent,
@@ -290,3 +320,49 @@ def _get_new_history(
290320

291321
# Otherwise, add it to the end
292322
return old_history + [event]
323+
324+
async def _run_output_guardrails(self, text: str) -> bool:
325+
"""Run output guardrails on the given text. Returns True if any guardrail was triggered."""
326+
output_guardrails = self._run_config.get("output_guardrails", [])
327+
if not output_guardrails or self._interrupted_by_guardrail:
328+
return False
329+
330+
triggered_results = []
331+
332+
for guardrail in output_guardrails:
333+
try:
334+
result = await guardrail.run(
335+
# TODO (rm) Remove this cast, it's wrong
336+
self._context_wrapper,
337+
cast(Agent[Any], self._current_agent),
338+
text,
339+
)
340+
if result.output.tripwire_triggered:
341+
triggered_results.append(result)
342+
except Exception:
343+
# Continue with other guardrails if one fails
344+
continue
345+
346+
if triggered_results:
347+
# Mark as interrupted to prevent multiple interrupts
348+
self._interrupted_by_guardrail = True
349+
350+
# Emit guardrail tripped event
351+
await self._put_event(
352+
RealtimeGuardrailTripped(
353+
guardrail_results=triggered_results,
354+
message=text,
355+
info=self._event_info,
356+
)
357+
)
358+
359+
# Interrupt the model
360+
await self._model.interrupt()
361+
362+
# Send guardrail triggered message
363+
guardrail_names = [result.guardrail.get_name() for result in triggered_results]
364+
await self._model.send_message(f"guardrail triggered: {', '.join(guardrail_names)}")
365+
366+
return True
367+
368+
return False

0 commit comments

Comments
 (0)