Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Realtime guardrail support #1082

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion src/agents/realtime/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from typing_extensions import NotRequired, TypeAlias, TypedDict

from ..guardrail import OutputGuardrail
from ..model_settings import ToolChoice
from ..tool import Tool

Expand Down Expand Up @@ -83,10 +84,26 @@ class RealtimeSessionModelSettings(TypedDict):
tools: NotRequired[list[Tool]]


class RealtimeGuardrailsSettings(TypedDict):
"""Settings for output guardrails in realtime sessions."""

debounce_text_length: NotRequired[int]
"""
The minimum number of characters to accumulate before running guardrails on transcript
deltas. Defaults to 100. Guardrails run every time the accumulated text reaches
1x, 2x, 3x, etc. times this threshold.
"""


class RealtimeRunConfig(TypedDict):
    """Run-level configuration for a realtime session. Every key is optional."""

    model_settings: NotRequired[RealtimeSessionModelSettings]

    output_guardrails: NotRequired[list[OutputGuardrail[Any]]]
    """Output guardrails to run against the agent's responses."""

    guardrails_settings: NotRequired[RealtimeGuardrailsSettings]
    """Options controlling how the guardrails are executed."""

    # TODO (rm) Add tracing support
    # tracing: NotRequired[RealtimeTracingConfig | None]
    # TODO (rm) Add guardrail support
    # NOTE(review): output guardrails are configurable via the two keys above; confirm
    # whether the guardrail TODO is still needed.
    # TODO (rm) Add history audio storage config
17 changes: 16 additions & 1 deletion src/agents/realtime/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from typing_extensions import TypeAlias

from ..guardrail import OutputGuardrailResult
from ..run_context import RunContextWrapper
from ..tool import Tool
from .agent import RealtimeAgent
Expand Down Expand Up @@ -181,7 +182,20 @@ class RealtimeHistoryAdded:
type: Literal["history_added"] = "history_added"


# TODO (rm) Add guardrails
@dataclass
class RealtimeGuardrailTripped:
    """Event emitted when an output guardrail trips and the agent is interrupted."""

    guardrail_results: list[OutputGuardrailResult]
    """Results from every guardrail whose tripwire was triggered."""

    message: str
    """The text that was being generated when the guardrail tripped."""

    info: RealtimeEventInfo
    """Info shared by all events, such as the context."""

    type: Literal["guardrail_tripped"] = "guardrail_tripped"

RealtimeSessionEvent: TypeAlias = Union[
RealtimeAgentStartEvent,
Expand All @@ -196,5 +210,6 @@ class RealtimeHistoryAdded:
RealtimeError,
RealtimeHistoryUpdated,
RealtimeHistoryAdded,
RealtimeGuardrailTripped,
]
"""An event emitted by the realtime session."""
1 change: 1 addition & 0 deletions src/agents/realtime/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ async def run(
agent=self._starting_agent,
context=context,
model_config=model_config,
run_config=self._config,
)

return session
Expand Down
90 changes: 83 additions & 7 deletions src/agents/realtime/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@

import asyncio
from collections.abc import AsyncIterator
from typing import Any
from typing import Any, cast

from typing_extensions import assert_never

from ..agent import Agent
from ..handoffs import Handoff
from ..run_context import RunContextWrapper, TContext
from ..tool import FunctionTool
from ..tool_context import ToolContext
from .agent import RealtimeAgent
from .config import RealtimeUserInput
from .config import RealtimeRunConfig, RealtimeUserInput
from .events import (
RealtimeAgentEndEvent,
RealtimeAgentStartEvent,
Expand All @@ -20,6 +21,7 @@
RealtimeAudioInterrupted,
RealtimeError,
RealtimeEventInfo,
RealtimeGuardrailTripped,
RealtimeHistoryAdded,
RealtimeHistoryUpdated,
RealtimeRawModelEvent,
Expand Down Expand Up @@ -62,26 +64,35 @@ def __init__(
agent: RealtimeAgent,
context: TContext | None,
model_config: RealtimeModelConfig | None = None,
run_config: RealtimeRunConfig | None = None,
) -> None:
"""Initialize the session.

Args:
model: The model to use.
agent: The current agent.
context_wrapper: The context wrapper.
event_info: Event info object.
history: The conversation history.
context: The context object.
model_config: Model configuration.
run_config: Runtime configuration including guardrails.
"""
self._model = model
self._current_agent = agent
self._context_wrapper = RunContextWrapper(context)
self._event_info = RealtimeEventInfo(context=self._context_wrapper)
self._history: list[RealtimeItem] = []
self._model_config = model_config or {}
self._run_config = run_config or {}
self._event_queue: asyncio.Queue[RealtimeSessionEvent] = asyncio.Queue()
self._closed = False

# Guardrails state tracking
self._interrupted_by_guardrail = False
self._item_transcripts: dict[str, str] = {} # item_id -> accumulated transcript
self._item_guardrail_run_counts: dict[str, int] = {} # item_id -> run count
self._debounce_text_length = self._run_config.get("guardrails_settings", {}).get(
"debounce_text_length", 100
)

async def __aenter__(self) -> RealtimeSession:
"""Start the session by connecting to the model. After this, you will be able to stream
events from the model and send messages and audio to the model.
Expand Down Expand Up @@ -159,8 +170,22 @@ async def on_event(self, event: RealtimeModelEvent) -> None:
RealtimeHistoryUpdated(info=self._event_info, history=self._history)
)
elif event.type == "transcript_delta":
# TODO (rm) Add guardrails
pass
# Accumulate transcript text for guardrail debouncing per item_id
item_id = event.item_id
if item_id not in self._item_transcripts:
self._item_transcripts[item_id] = ""
self._item_guardrail_run_counts[item_id] = 0

self._item_transcripts[item_id] += event.delta

# Check if we should run guardrails based on debounce threshold
current_length = len(self._item_transcripts[item_id])
threshold = self._debounce_text_length
next_run_threshold = (self._item_guardrail_run_counts[item_id] + 1) * threshold

if current_length >= next_run_threshold:
self._item_guardrail_run_counts[item_id] += 1
await self._run_output_guardrails(self._item_transcripts[item_id])
elif event.type == "item_updated":
is_new = not any(item.item_id == event.item.item_id for item in self._history)
self._history = self._get_new_history(self._history, event.item)
Expand Down Expand Up @@ -189,6 +214,11 @@ async def on_event(self, event: RealtimeModelEvent) -> None:
)
)
elif event.type == "turn_ended":
# Clear guardrail state for next turn
self._item_transcripts.clear()
self._item_guardrail_run_counts.clear()
self._interrupted_by_guardrail = False

await self._put_event(
RealtimeAgentEndEvent(
agent=self._current_agent,
Expand Down Expand Up @@ -290,3 +320,49 @@ def _get_new_history(

# Otherwise, add it to the end
return old_history + [event]

async def _run_output_guardrails(self, text: str) -> bool:
    """Run the configured output guardrails against ``text``.

    Guardrails are read from the ``output_guardrails`` key of the run config and
    awaited one after another. A guardrail that raises is logged and skipped so
    the remaining guardrails still run. If at least one tripwire is triggered,
    the session is marked as interrupted, a ``RealtimeGuardrailTripped`` event
    is queued, the model is interrupted, and a message naming the triggered
    guardrails is sent to the model.

    Args:
        text: The accumulated transcript text to check.

    Returns:
        True if any guardrail was triggered, False otherwise (including when no
        guardrails are configured or the session was already interrupted by a
        guardrail).
    """
    # Local import keeps the block self-contained; only used to report failures.
    import logging

    output_guardrails = self._run_config.get("output_guardrails", [])
    if not output_guardrails or self._interrupted_by_guardrail:
        return False

    triggered_results = []

    for guardrail in output_guardrails:
        try:
            result = await guardrail.run(
                # TODO (rm) Remove this cast, it's wrong
                self._context_wrapper,
                cast(Agent[Any], self._current_agent),
                text,
            )
            if result.output.tripwire_triggered:
                triggered_results.append(result)
        except Exception:
            # Best effort: one failing guardrail must not stop the others, but the
            # failure is logged with its traceback rather than silently dropped.
            logging.getLogger(__name__).exception(
                "Output guardrail %r raised; continuing with remaining guardrails",
                guardrail,
            )
            continue

    if not triggered_results:
        return False

    # Mark as interrupted first so later transcript deltas do not trigger a
    # second interrupt.
    self._interrupted_by_guardrail = True

    # Emit guardrail tripped event
    await self._put_event(
        RealtimeGuardrailTripped(
            guardrail_results=triggered_results,
            message=text,
            info=self._event_info,
        )
    )

    # Interrupt the model
    await self._model.interrupt()

    # Send guardrail triggered message
    guardrail_names = [result.guardrail.get_name() for result in triggered_results]
    await self._model.send_message(f"guardrail triggered: {', '.join(guardrail_names)}")

    return True
Loading