diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 000000000..a75c1414f --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,14 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python Debugger: Python File", + "type": "debugpy", + "request": "launch", + "program": "${file}" + } + ] +} \ No newline at end of file diff --git a/README.md b/README.md index 755c342ae..638f94b21 100644 --- a/README.md +++ b/README.md @@ -12,10 +12,119 @@ The OpenAI Agents SDK is a lightweight yet powerful framework for building multi 1. [**Agents**](https://openai.github.io/openai-agents-python/agents): LLMs configured with instructions, tools, guardrails, and handoffs 2. [**Handoffs**](https://openai.github.io/openai-agents-python/handoffs/): A specialized tool call used by the Agents SDK for transferring control between agents 3. [**Guardrails**](https://openai.github.io/openai-agents-python/guardrails/): Configurable safety checks for input and output validation -4. [**Tracing**](https://openai.github.io/openai-agents-python/tracing/): Built-in tracking of agent runs, allowing you to view, debug and optimize your workflows +4. [**Sessions**](#sessions): Automatic conversation history management across agent runs +5. [**Tracing**](https://openai.github.io/openai-agents-python/tracing/): Built-in tracking of agent runs, allowing you to view, debug and optimize your workflows Explore the [examples](examples) directory to see the SDK in action, and read our [documentation](https://openai.github.io/openai-agents-python/) for more details. +## Sessions + +The Agents SDK provides built-in session memory to automatically maintain conversation history across multiple agent runs, eliminating the need to manually handle `.to_input_list()` between turns. + +### Quick start + +```python +from agents import Agent, Runner, SQLiteSession + +# Create agent +agent = Agent( + name="Assistant", + instructions="Reply very concisely.", +) + +# Create a session instance +session = SQLiteSession("conversation_123") + +# First turn +result = await Runner.run( + agent, + "What city is the Golden Gate Bridge in?", + session=session +) +print(result.final_output) # "San Francisco" + +# Second turn - agent automatically remembers previous context +result = await Runner.run( + agent, + "What state is it in?", + session=session +) +print(result.final_output) # "California" + +# Also works with synchronous runner +result = Runner.run_sync( + agent, + "What's the population?", + session=session +) +print(result.final_output) # "Approximately 39 million" +``` + +### Session options + +- **No memory** (default): No session memory when session parameter is omitted +- **`session: Session = DatabaseSession(...)`**: Use a Session instance to manage conversation history + +```python +from agents import Agent, Runner, SQLiteSession + +# Custom SQLite database file +session = SQLiteSession("user_123", "conversations.db") +agent = Agent(name="Assistant") + +# Different session IDs maintain separate conversation histories +result1 = await Runner.run( + agent, + "Hello", + session=session +) +result2 = await Runner.run( + agent, + "Hello", + session=SQLiteSession("user_456", "conversations.db") +) +``` + +### Custom session implementations + +You can implement your own session memory by creating a class that follows the `Session` protocol: + +```python +from agents.memory import Session +from typing import List + +class MyCustomSession: + """Custom session implementation following the Session protocol.""" + + def __init__(self, session_id: str): + self.session_id = session_id + # Your initialization here + + async def get_items(self, limit: int | None = None) -> List[dict]: + # Retrieve conversation history for the session + pass + + async def add_items(self, items: List[dict]) -> None: + # Store new items for the session + pass + + async def pop_item(self) -> dict | None: + # Remove and return the most recent item from the session + pass + + async def clear_session(self) -> None: + # Clear all items for the session + pass + +# Use your custom session +agent = Agent(name="Assistant") +result = await Runner.run( + agent, + "Hello", + session=MyCustomSession("my_session") +) +``` + ## Get started 1. Set up your Python environment diff --git a/docs/examples.md b/docs/examples.md index 30d602827..ae40fa909 100644 --- a/docs/examples.md +++ b/docs/examples.md @@ -40,3 +40,6 @@ Check out a variety of sample implementations of the SDK in the examples section - **[voice](https://github.com/openai/openai-agents-python/tree/main/examples/voice):** See examples of voice agents, using our TTS and STT models. + +- **[realtime](https://github.com/openai/openai-agents-python/tree/main/examples/realtime):** + Examples showing how to build realtime experiences using the SDK. diff --git a/docs/index.md b/docs/index.md index 8aef6574e..935c4be5b 100644 --- a/docs/index.md +++ b/docs/index.md @@ -5,6 +5,7 @@ The [OpenAI Agents SDK](https://github.com/openai/openai-agents-python) enables - **Agents**, which are LLMs equipped with instructions and tools - **Handoffs**, which allow agents to delegate to other agents for specific tasks - **Guardrails**, which enable the inputs to agents to be validated +- **Sessions**, which automatically maintains conversation history across agent runs In combination with Python, these primitives are powerful enough to express complex relationships between tools and agents, and allow you to build real-world applications without a steep learning curve. In addition, the SDK comes with built-in **tracing** that lets you visualize and debug your agentic flows, as well as evaluate them and even fine-tune models for your application. @@ -21,6 +22,7 @@ Here are the main features of the SDK: - Python-first: Use built-in language features to orchestrate and chain agents, rather than needing to learn new abstractions. - Handoffs: A powerful feature to coordinate and delegate between multiple agents. - Guardrails: Run input validations and checks in parallel to your agents, breaking early if the checks fail. +- Sessions: Automatic conversation history management across agent runs, eliminating manual state handling. - Function tools: Turn any Python function into a tool, with automatic schema generation and Pydantic-powered validation. - Tracing: Built-in tracing that lets you visualize, debug and monitor your workflows, as well as use the OpenAI suite of evaluation, fine-tuning and distillation tools. diff --git a/docs/ref/memory.md b/docs/ref/memory.md new file mode 100644 index 000000000..04a2258bf --- /dev/null +++ b/docs/ref/memory.md @@ -0,0 +1,8 @@ +# Memory + +::: agents.memory + + options: + members: + - Session + - SQLiteSession diff --git a/docs/release.md b/docs/release.md index a86103f96..8498c8b17 100644 --- a/docs/release.md +++ b/docs/release.md @@ -19,6 +19,10 @@ We will increment `Z` for non-breaking changes: ## Breaking change changelog +### 0.2.0 + +In this version, a few places that used to take `Agent` as an arg, now take `AgentBase` as an arg instead. For example, the `list_tools()` call in MCP servers. This is a purely typing change, you will still receive `Agent` objects. To update, just fix type errors by replacing `Agent` with `AgentBase`. + ### 0.1.0 In this version, [`MCPServer.list_tools()`][agents.mcp.server.MCPServer] has two new params: `run_context` and `agent`. You'll need to add these params to any classes that subclass `MCPServer`. diff --git a/docs/running_agents.md b/docs/running_agents.md index f631cf46f..6898f5101 100644 --- a/docs/running_agents.md +++ b/docs/running_agents.md @@ -65,7 +65,9 @@ Calling any of the run methods can result in one or more agents running (and hen At the end of the agent run, you can choose what to show to the user. For example, you might show the user every new item generated by the agents, or just the final output. Either way, the user might then ask a followup question, in which case you can call the run method again. -You can use the base [`RunResultBase.to_input_list()`][agents.result.RunResultBase.to_input_list] method to get the inputs for the next turn. +### Manual conversation management + +You can manually manage conversation history using the [`RunResultBase.to_input_list()`][agents.result.RunResultBase.to_input_list] method to get the inputs for the next turn: ```python async def main(): @@ -84,6 +86,39 @@ async def main(): # California ``` +### Automatic conversation management with Sessions + +For a simpler approach, you can use [Sessions](sessions.md) to automatically handle conversation history without manually calling `.to_input_list()`: + +```python +from agents import Agent, Runner, SQLiteSession + +async def main(): + agent = Agent(name="Assistant", instructions="Reply very concisely.") + + # Create session instance + session = SQLiteSession("conversation_123") + + with trace(workflow_name="Conversation", group_id=thread_id): + # First turn + result = await Runner.run(agent, "What city is the Golden Gate Bridge in?", session=session) + print(result.final_output) + # San Francisco + + # Second turn - agent automatically remembers previous context + result = await Runner.run(agent, "What state is it in?", session=session) + print(result.final_output) + # California +``` + +Sessions automatically: + +- Retrieves conversation history before each run +- Stores new messages after each run +- Maintains separate conversations for different session IDs + +See the [Sessions documentation](sessions.md) for more details. + ## Exceptions The SDK raises exceptions in certain cases. The full list is in [`agents.exceptions`][]. As an overview: diff --git a/docs/scripts/translate_docs.py b/docs/scripts/translate_docs.py index 5dada2681..ac40b6fa8 100644 --- a/docs/scripts/translate_docs.py +++ b/docs/scripts/translate_docs.py @@ -266,7 +266,9 @@ def translate_single_source_file(file_path: str) -> None: def main(): parser = argparse.ArgumentParser(description="Translate documentation files") - parser.add_argument("--file", type=str, help="Specific file to translate (relative to docs directory)") + parser.add_argument( + "--file", type=str, help="Specific file to translate (relative to docs directory)" + ) args = parser.parse_args() if args.file: diff --git a/docs/sessions.md b/docs/sessions.md new file mode 100644 index 000000000..956712438 --- /dev/null +++ b/docs/sessions.md @@ -0,0 +1,319 @@ +# Sessions + +The Agents SDK provides built-in session memory to automatically maintain conversation history across multiple agent runs, eliminating the need to manually handle `.to_input_list()` between turns. + +Sessions stores conversation history for a specific session, allowing agents to maintain context without requiring explicit manual memory management. This is particularly useful for building chat applications or multi-turn conversations where you want the agent to remember previous interactions. + +## Quick start + +```python +from agents import Agent, Runner, SQLiteSession + +# Create agent +agent = Agent( + name="Assistant", + instructions="Reply very concisely.", +) + +# Create a session instance with a session ID +session = SQLiteSession("conversation_123") + +# First turn +result = await Runner.run( + agent, + "What city is the Golden Gate Bridge in?", + session=session +) +print(result.final_output) # "San Francisco" + +# Second turn - agent automatically remembers previous context +result = await Runner.run( + agent, + "What state is it in?", + session=session +) +print(result.final_output) # "California" + +# Also works with synchronous runner +result = Runner.run_sync( + agent, + "What's the population?", + session=session +) +print(result.final_output) # "Approximately 39 million" +``` + +## How it works + +When session memory is enabled: + +1. **Before each run**: The runner automatically retrieves the conversation history for the session and prepends it to the input items. +2. **After each run**: All new items generated during the run (user input, assistant responses, tool calls, etc.) are automatically stored in the session. +3. **Context preservation**: Each subsequent run with the same session includes the full conversation history, allowing the agent to maintain context. + +This eliminates the need to manually call `.to_input_list()` and manage conversation state between runs. + +## Memory operations + +### Basic operations + +Sessions supports several operations for managing conversation history: + +```python +from agents import SQLiteSession + +session = SQLiteSession("user_123", "conversations.db") + +# Get all items in a session +items = await session.get_items() + +# Add new items to a session +new_items = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"} +] +await session.add_items(new_items) + +# Remove and return the most recent item +last_item = await session.pop_item() +print(last_item) # {"role": "assistant", "content": "Hi there!"} + +# Clear all items from a session +await session.clear_session() +``` + +### Using pop_item for corrections + +The `pop_item` method is particularly useful when you want to undo or modify the last item in a conversation: + +```python +from agents import Agent, Runner, SQLiteSession + +agent = Agent(name="Assistant") +session = SQLiteSession("correction_example") + +# Initial conversation +result = await Runner.run( + agent, + "What's 2 + 2?", + session=session +) +print(f"Agent: {result.final_output}") + +# User wants to correct their question +assistant_item = await session.pop_item() # Remove agent's response +user_item = await session.pop_item() # Remove user's question + +# Ask a corrected question +result = await Runner.run( + agent, + "What's 2 + 3?", + session=session +) +print(f"Agent: {result.final_output}") +``` + +## Memory options + +### No memory (default) + +```python +# Default behavior - no session memory +result = await Runner.run(agent, "Hello") +``` + +### SQLite memory + +```python +from agents import SQLiteSession + +# In-memory database (lost when process ends) +session = SQLiteSession("user_123") + +# Persistent file-based database +session = SQLiteSession("user_123", "conversations.db") + +# Use the session +result = await Runner.run( + agent, + "Hello", + session=session +) +``` + +### Multiple sessions + +```python +from agents import Agent, Runner, SQLiteSession + +agent = Agent(name="Assistant") + +# Different sessions maintain separate conversation histories +session_1 = SQLiteSession("user_123", "conversations.db") +session_2 = SQLiteSession("user_456", "conversations.db") + +result1 = await Runner.run( + agent, + "Hello", + session=session_1 +) +result2 = await Runner.run( + agent, + "Hello", + session=session_2 +) +``` + +## Custom memory implementations + +You can implement your own session memory by creating a class that follows the [`Session`][agents.memory.session.Session] protocol: + +````python +from agents.memory import Session +from typing import List + +class MyCustomSession: + """Custom session implementation following the Session protocol.""" + + def __init__(self, session_id: str): + self.session_id = session_id + # Your initialization here + + async def get_items(self, limit: int | None = None) -> List[dict]: + """Retrieve conversation history for this session.""" + # Your implementation here + pass + + async def add_items(self, items: List[dict]) -> None: + """Store new items for this session.""" + # Your implementation here + pass + + async def pop_item(self) -> dict | None: + """Remove and return the most recent item from this session.""" + # Your implementation here + pass + + async def clear_session(self) -> None: + """Clear all items for this session.""" + # Your implementation here + pass + +# Use your custom session +agent = Agent(name="Assistant") +result = await Runner.run( + agent, + "Hello", + session=MyCustomSession("my_session") +) + +## Session management + +### Session ID naming + +Use meaningful session IDs that help you organize conversations: + +- User-based: `"user_12345"` +- Thread-based: `"thread_abc123"` +- Context-based: `"support_ticket_456"` + +### Memory persistence + +- Use in-memory SQLite (`SQLiteSession("session_id")`) for temporary conversations +- Use file-based SQLite (`SQLiteSession("session_id", "path/to/db.sqlite")`) for persistent conversations +- Consider implementing custom session backends for production systems (Redis, PostgreSQL, etc.) + +### Session management + +```python +# Clear a session when conversation should start fresh +await session.clear_session() + +# Different agents can share the same session +support_agent = Agent(name="Support") +billing_agent = Agent(name="Billing") +session = SQLiteSession("user_123") + +# Both agents will see the same conversation history +result1 = await Runner.run( + support_agent, + "Help me with my account", + session=session +) +result2 = await Runner.run( + billing_agent, + "What are my charges?", + session=session +) +```` + +## Complete example + +Here's a complete example showing session memory in action: + +```python +import asyncio +from agents import Agent, Runner, SQLiteSession + + +async def main(): + # Create an agent + agent = Agent( + name="Assistant", + instructions="Reply very concisely.", + ) + + # Create a session instance that will persist across runs + session = SQLiteSession("conversation_123", "conversation_history.db") + + print("=== Sessions Example ===") + print("The agent will remember previous messages automatically.\n") + + # First turn + print("First turn:") + print("User: What city is the Golden Gate Bridge in?") + result = await Runner.run( + agent, + "What city is the Golden Gate Bridge in?", + session=session + ) + print(f"Assistant: {result.final_output}") + print() + + # Second turn - the agent will remember the previous conversation + print("Second turn:") + print("User: What state is it in?") + result = await Runner.run( + agent, + "What state is it in?", + session=session + ) + print(f"Assistant: {result.final_output}") + print() + + # Third turn - continuing the conversation + print("Third turn:") + print("User: What's the population of that state?") + result = await Runner.run( + agent, + "What's the population of that state?", + session=session + ) + print(f"Assistant: {result.final_output}") + print() + + print("=== Conversation Complete ===") + print("Notice how the agent remembered the context from previous turns!") + print("Sessions automatically handles conversation history.") + + +if __name__ == "__main__": + asyncio.run(main()) +``` + +## API Reference + +For detailed API documentation, see: + +- [`Session`][agents.memory.Session] - Protocol interface +- [`SQLiteSession`][agents.memory.SQLiteSession] - SQLite implementation diff --git a/docs/tools.md b/docs/tools.md index 6dba1a853..17f7da0a1 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -180,7 +180,7 @@ Sometimes, you don't want to use a Python function as a tool. You can directly c - `name` - `description` - `params_json_schema`, which is the JSON schema for the arguments -- `on_invoke_tool`, which is an async function that receives the context and the arguments as a JSON string, and must return the tool output as a string. +- `on_invoke_tool`, which is an async function that receives a [`ToolContext`][agents.tool_context.ToolContext] and the arguments as a JSON string, and must return the tool output as a string. ```python from typing import Any diff --git a/examples/agent_patterns/llm_as_a_judge.py b/examples/agent_patterns/llm_as_a_judge.py index 5a46cc3eb..81dec7501 100644 --- a/examples/agent_patterns/llm_as_a_judge.py +++ b/examples/agent_patterns/llm_as_a_judge.py @@ -32,7 +32,7 @@ class EvaluationFeedback: instructions=( "You evaluate a story outline and decide if it's good enough." "If it's not good enough, you provide feedback on what needs to be improved." - "Never give it a pass on the first try." + "Never give it a pass on the first try. After 5 attempts, you can give it a pass if story outline is good enough - do not go for perfection" ), output_type=EvaluationFeedback, ) diff --git a/examples/basic/hello_world_jupyter.ipynb b/examples/basic/hello_world_jupyter.ipynb index 42ee8e6a2..8dd3bb379 100644 --- a/examples/basic/hello_world_jupyter.ipynb +++ b/examples/basic/hello_world_jupyter.ipynb @@ -30,7 +30,7 @@ "agent = Agent(name=\"Assistant\", instructions=\"You are a helpful assistant\")\n", "\n", "# Intended for Jupyter notebooks where there's an existing event loop\n", - "result = await Runner.run(agent, \"Write a haiku about recursion in programming.\") # type: ignore[top-level-await] # noqa: F704\n", + "result = await Runner.run(agent, \"Write a haiku about recursion in programming.\") # type: ignore[top-level-await] # noqa: F704\n", "print(result.final_output)" ] } diff --git a/examples/basic/session_example.py b/examples/basic/session_example.py new file mode 100644 index 000000000..63d1d1b7c --- /dev/null +++ b/examples/basic/session_example.py @@ -0,0 +1,77 @@ +""" +Example demonstrating session memory functionality. + +This example shows how to use session memory to maintain conversation history +across multiple agent runs without manually handling .to_input_list(). +""" + +import asyncio + +from agents import Agent, Runner, SQLiteSession + + +async def main(): + # Create an agent + agent = Agent( + name="Assistant", + instructions="Reply very concisely.", + ) + + # Create a session instance that will persist across runs + session_id = "conversation_123" + session = SQLiteSession(session_id) + + print("=== Session Example ===") + print("The agent will remember previous messages automatically.\n") + + # First turn + print("First turn:") + print("User: What city is the Golden Gate Bridge in?") + result = await Runner.run( + agent, + "What city is the Golden Gate Bridge in?", + session=session, + ) + print(f"Assistant: {result.final_output}") + print() + + # Second turn - the agent will remember the previous conversation + print("Second turn:") + print("User: What state is it in?") + result = await Runner.run(agent, "What state is it in?", session=session) + print(f"Assistant: {result.final_output}") + print() + + # Third turn - continuing the conversation + print("Third turn:") + print("User: What's the population of that state?") + result = await Runner.run( + agent, + "What's the population of that state?", + session=session, + ) + print(f"Assistant: {result.final_output}") + print() + + print("=== Conversation Complete ===") + print("Notice how the agent remembered the context from previous turns!") + print("Sessions automatically handles conversation history.") + + # Demonstrate the limit parameter - get only the latest 2 items + print("\n=== Latest Items Demo ===") + latest_items = await session.get_items(limit=2) + print("Latest 2 items:") + for i, msg in enumerate(latest_items, 1): + role = msg.get("role", "unknown") + content = msg.get("content", "") + print(f" {i}. {role}: {content}") + + print(f"\nFetched {len(latest_items)} out of total conversation history.") + + # Get all items to show the difference + all_items = await session.get_items() + print(f"Total items in session: {len(all_items)}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/mcp/prompt_server/main.py b/examples/mcp/prompt_server/main.py index 8f2991fc0..4caa95d88 100644 --- a/examples/mcp/prompt_server/main.py +++ b/examples/mcp/prompt_server/main.py @@ -17,7 +17,7 @@ async def get_instructions_from_prompt(mcp_server: MCPServer, prompt_name: str, try: prompt_result = await mcp_server.get_prompt(prompt_name, kwargs) content = prompt_result.messages[0].content - if hasattr(content, 'text'): + if hasattr(content, "text"): instructions = content.text else: instructions = str(content) diff --git a/examples/realtime/demo.py b/examples/realtime/demo.py new file mode 100644 index 000000000..3db051963 --- /dev/null +++ b/examples/realtime/demo.py @@ -0,0 +1,115 @@ +import asyncio +import os +import sys +from typing import TYPE_CHECKING + +import numpy as np + +from agents.realtime import RealtimeSession + +# Add the current directory to path so we can import ui +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +from agents import function_tool +from agents.realtime import RealtimeAgent, RealtimeRunner, RealtimeSessionEvent + +if TYPE_CHECKING: + from .ui import AppUI +else: + # Try both import styles + try: + # Try relative import first (when used as a package) + from .ui import AppUI + except ImportError: + # Fall back to direct import (when run as a script) + from ui import AppUI + + +@function_tool +def get_weather(city: str) -> str: + """Get the weather in a city.""" + return f"The weather in {city} is sunny." + + +agent = RealtimeAgent( + name="Assistant", + instructions="You always greet the user with 'Top of the morning to you'.", + tools=[get_weather], +) + + +def _truncate_str(s: str, max_length: int) -> str: + if len(s) > max_length: + return s[:max_length] + "..." + return s + + +class Example: + def __init__(self) -> None: + self.ui = AppUI() + self.ui.connected = asyncio.Event() + self.ui.last_audio_item_id = None + # Set the audio callback + self.ui.set_audio_callback(self.on_audio_recorded) + + self.session: RealtimeSession | None = None + + async def run(self) -> None: + # Start UI in a separate task instead of waiting for it to complete + ui_task = asyncio.create_task(self.ui.run_async()) + + # Set up session immediately without waiting for UI to finish + runner = RealtimeRunner(agent) + async with await runner.run() as session: + self.session = session + self.ui.set_is_connected(True) + async for event in session: + await self._on_event(event) + print("done") + + # Wait for UI task to complete when session ends + await ui_task + + async def on_audio_recorded(self, audio_bytes: bytes) -> None: + # Send the audio to the session + assert self.session is not None + await self.session.send_audio(audio_bytes) + + async def _on_event(self, event: RealtimeSessionEvent) -> None: + try: + if event.type == "agent_start": + self.ui.add_transcript(f"Agent started: {event.agent.name}") + elif event.type == "agent_end": + self.ui.add_transcript(f"Agent ended: {event.agent.name}") + elif event.type == "handoff": + self.ui.add_transcript( + f"Handoff from {event.from_agent.name} to {event.to_agent.name}" + ) + elif event.type == "tool_start": + self.ui.add_transcript(f"Tool started: {event.tool.name}") + elif event.type == "tool_end": + self.ui.add_transcript(f"Tool ended: {event.tool.name}; output: {event.output}") + elif event.type == "audio_end": + self.ui.add_transcript("Audio ended") + elif event.type == "audio": + np_audio = np.frombuffer(event.audio.data, dtype=np.int16) + self.ui.play_audio(np_audio) + elif event.type == "audio_interrupted": + self.ui.add_transcript("Audio interrupted") + elif event.type == "error": + pass + elif event.type == "history_updated": + pass + elif event.type == "history_added": + pass + elif event.type == "raw_model_event": + self.ui.log_message(f"Raw model event: {_truncate_str(str(event.data), 50)}") + else: + self.ui.log_message(f"Unknown event type: {event.type}") + except Exception as e: + self.ui.log_message(f"Error processing event: {_truncate_str(str(e), 50)}") + + +if __name__ == "__main__": + example = Example() + asyncio.run(example.run()) diff --git a/examples/realtime/ui.py b/examples/realtime/ui.py new file mode 100644 index 000000000..51a1fed41 --- /dev/null +++ b/examples/realtime/ui.py @@ -0,0 +1,268 @@ +from __future__ import annotations + +import asyncio +from collections.abc import Coroutine +from typing import Any, Callable + +import numpy as np +import numpy.typing as npt +import sounddevice as sd +from textual import events +from textual.app import App, ComposeResult +from textual.containers import Container, Horizontal +from textual.reactive import reactive +from textual.widgets import RichLog, Static +from typing_extensions import override + +CHUNK_LENGTH_S = 0.05 # 50ms +SAMPLE_RATE = 24000 +FORMAT = np.int16 +CHANNELS = 1 + + +class Header(Static): + """A header widget.""" + + @override + def render(self) -> str: + return "Realtime Demo" + + +class AudioStatusIndicator(Static): + """A widget that shows the current audio recording status.""" + + is_recording = reactive(False) + + @override + def render(self) -> str: + status = ( + "🔴 Conversation started." + if self.is_recording + else "⚪ Press SPACE to start the conversation (q to quit)" + ) + return status + + +class AppUI(App[None]): + CSS = """ + Screen { + background: #1a1b26; /* Dark blue-grey background */ + } + + Container { + border: double rgb(91, 164, 91); + } + + #input-container { + height: 5; /* Explicit height for input container */ + margin: 1 1; + padding: 1 2; + } + + #bottom-pane { + width: 100%; + height: 82%; /* Reduced to make room for session display */ + border: round rgb(205, 133, 63); + } + + #status-indicator { + height: 3; + content-align: center middle; + background: #2a2b36; + border: solid rgb(91, 164, 91); + margin: 1 1; + } + + #session-display { + height: 3; + content-align: center middle; + background: #2a2b36; + border: solid rgb(91, 164, 91); + margin: 1 1; + } + + #transcripts { + width: 50%; + height: 100%; + border-right: solid rgb(91, 164, 91); + } + + #transcripts-header { + height: 2; + background: #2a2b36; + content-align: center middle; + border-bottom: solid rgb(91, 164, 91); + } + + #transcripts-content { + height: 100%; + } + + #event-log { + width: 50%; + height: 100%; + } + + #event-log-header { + height: 2; + background: #2a2b36; + content-align: center middle; + border-bottom: solid rgb(91, 164, 91); + } + + #event-log-content { + height: 100%; + } + + Static { + color: white; + } + """ + + should_send_audio: asyncio.Event + connected: asyncio.Event + last_audio_item_id: str | None + audio_callback: Callable[[bytes], Coroutine[Any, Any, None]] | None + + def __init__(self) -> None: + super().__init__() + self.audio_player = sd.OutputStream( + samplerate=SAMPLE_RATE, + channels=CHANNELS, + dtype=FORMAT, + ) + self.should_send_audio = asyncio.Event() + self.connected = asyncio.Event() + self.audio_callback = None + + @override + def compose(self) -> ComposeResult: + """Create child widgets for the app.""" + with Container(): + yield Header(id="session-display") + yield AudioStatusIndicator(id="status-indicator") + with Container(id="bottom-pane"): + with Horizontal(): + with Container(id="transcripts"): + yield Static("Conversation transcript", id="transcripts-header") + yield RichLog( + id="transcripts-content", wrap=True, highlight=True, markup=True + ) + with Container(id="event-log"): + yield Static("Raw event log", id="event-log-header") + yield RichLog( + id="event-log-content", wrap=True, highlight=True, markup=True + ) + + def set_is_connected(self, is_connected: bool) -> None: + self.connected.set() if is_connected else self.connected.clear() + + def set_audio_callback(self, callback: Callable[[bytes], Coroutine[Any, Any, None]]) -> None: + """Set a callback function to be called when audio is recorded.""" + self.audio_callback = callback + + # High-level methods for UI operations + def set_header_text(self, text: str) -> None: + """Update the header text.""" + header = self.query_one("#session-display", Header) + header.update(text) + + def set_recording_status(self, is_recording: bool) -> None: + """Set the recording status indicator.""" + status_indicator = self.query_one(AudioStatusIndicator) + status_indicator.is_recording = is_recording + + def log_message(self, message: str) -> None: + """Add a message to the event log.""" + try: + log_pane = self.query_one("#event-log-content", RichLog) + log_pane.write(message) + except Exception: + # Handle the case where the widget might not be available + pass + + def add_transcript(self, message: str) -> None: + """Add a transcript message to the transcripts panel.""" + try: + transcript_pane = self.query_one("#transcripts-content", RichLog) + transcript_pane.write(message) + except Exception: + # Handle the case where the widget might not be available + pass + + def play_audio(self, audio_data: npt.NDArray[np.int16]) -> None: + """Play audio data through the audio player.""" + try: + self.audio_player.write(audio_data) + except Exception as e: + self.log_message(f"Audio play error: {e}") + + async def on_mount(self) -> None: + """Set up audio player and start the audio capture worker.""" + self.audio_player.start() + self.run_worker(self.capture_audio()) + + async def capture_audio(self) -> None: + """Capture audio from the microphone and send to the session.""" + # Wait for connection to be established + await self.connected.wait() + + # Set up audio input stream + stream = sd.InputStream( + channels=CHANNELS, + samplerate=SAMPLE_RATE, + dtype=FORMAT, + ) + + try: + # Wait for user to press spacebar to start + await self.should_send_audio.wait() + + stream.start() + self.set_recording_status(True) + self.log_message("Recording started - speak to the agent") + + # Buffer size in samples + read_size = int(SAMPLE_RATE * CHUNK_LENGTH_S) + + while True: + # Check if there's enough data to read + if stream.read_available < read_size: + await asyncio.sleep(0.01) # Small sleep to avoid CPU hogging + continue + + # Read audio data + data, _ = stream.read(read_size) + + # Convert numpy array to bytes + audio_bytes = data.tobytes() + + # Call audio callback if set + if self.audio_callback: + await self.audio_callback(audio_bytes) + + # Yield control back to event loop + await asyncio.sleep(0) + + except Exception as e: + self.log_message(f"Audio capture error: {e}") + finally: + if stream.active: + stream.stop() + stream.close() + + async def on_key(self, event: events.Key) -> None: + """Handle key press events.""" + # add the keypress to the log + self.log_message(f"Key pressed: {event.key}") + + if event.key == "q": + self.audio_player.stop() + self.audio_player.close() + self.exit() + return + + if event.key == "space": # Spacebar + if not self.should_send_audio.is_set(): + self.should_send_audio.set() + self.set_recording_status(True) diff --git a/examples/reasoning_content/main.py b/examples/reasoning_content/main.py index 5f67e1779..c23b04254 100644 --- a/examples/reasoning_content/main.py +++ b/examples/reasoning_content/main.py @@ -44,7 +44,7 @@ async def stream_with_reasoning_content(): handoffs=[], tracing=ModelTracing.DISABLED, previous_response_id=None, - prompt=None + prompt=None, ): if event.type == "response.reasoning_summary_text.delta": print( @@ -82,7 +82,7 @@ async def get_response_with_reasoning_content(): handoffs=[], tracing=ModelTracing.DISABLED, previous_response_id=None, - prompt=None + prompt=None, ) # Extract reasoning content and regular content from the response diff --git a/mkdocs.yml b/mkdocs.yml index b79e6454f..19529bf30 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -57,6 +57,7 @@ plugins: - Documentation: - agents.md - running_agents.md + - sessions.md - results.md - streaming.md - repl.md @@ -82,6 +83,7 @@ plugins: - ref/index.md - ref/agent.md - ref/run.md + - ref/memory.md - ref/repl.md - ref/tool.md - ref/result.md @@ -90,6 +92,7 @@ plugins: - ref/lifecycle.md - ref/items.md - ref/run_context.md + - ref/tool_context.md - ref/usage.md - ref/exceptions.md - ref/guardrail.md @@ -140,6 +143,7 @@ plugins: - ドキュメント: - agents.md - running_agents.md + - sessions.md - results.md - streaming.md - repl.md diff --git a/pyproject.toml b/pyproject.toml index e659348ca..0f9b70852 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,13 +1,13 @@ [project] name = "openai-agents" -version = "0.1.0" +version = "0.2.0" description = "OpenAI Agents SDK" readme = "README.md" requires-python = ">=3.9" license = "MIT" authors = [{ name = "OpenAI", email = "support@openai.com" }] dependencies = [ - "openai>=1.87.0", + "openai>=1.93.1, <2", "pydantic>=2.10, <3", "griffe>=1.5.6, <2", "typing-extensions>=4.12.2, <5", @@ -37,6 +37,7 @@ Repository = "https://github.com/openai/openai-agents-python" voice = ["numpy>=2.2.0, <3; python_version>='3.10'", "websockets>=15.0, <16"] viz = ["graphviz>=0.17"] litellm = ["litellm>=1.67.4.post1, <2"] +realtime = ["websockets>=15.0, <16"] [dependency-groups] dev = [ diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 1296b72be..7de17efdb 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -5,7 +5,7 @@ from openai import AsyncOpenAI from . import _config -from .agent import Agent, ToolsToFinalOutputFunction, ToolsToFinalOutputResult +from .agent import Agent, AgentBase, ToolsToFinalOutputFunction, ToolsToFinalOutputResult from .agent_output import AgentOutputSchema, AgentOutputSchemaBase from .computer import AsyncComputer, Button, Computer, Environment from .exceptions import ( @@ -40,6 +40,7 @@ TResponseInputItem, ) from .lifecycle import AgentHooks, RunHooks +from .memory import Session, SQLiteSession from .model_settings import ModelSettings from .models.interface import Model, ModelProvider, ModelTracing from .models.openai_chatcompletions import OpenAIChatCompletionsModel @@ -160,6 +161,7 @@ def enable_verbose_stdout_logging(): __all__ = [ "Agent", + "AgentBase", "ToolsToFinalOutputFunction", "ToolsToFinalOutputResult", "Runner", @@ -209,6 +211,8 @@ def enable_verbose_stdout_logging(): "ItemHelpers", "RunHooks", "AgentHooks", + "Session", + "SQLiteSession", "RunContextWrapper", "TContext", "RunErrorDetails", diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index 4ac8b316b..a83af62a1 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -548,7 +548,11 @@ async def run_single_tool( func_tool: FunctionTool, tool_call: ResponseFunctionToolCall ) -> Any: with function_span(func_tool.name) as span_fn: - tool_context = ToolContext.from_agent_context(context_wrapper, tool_call.call_id) + tool_context = ToolContext.from_agent_context( + context_wrapper, + tool_call.call_id, + tool_call=tool_call, + ) if config.trace_include_sensitive_data: span_fn.span_data.input = tool_call.arguments try: diff --git a/src/agents/agent.py b/src/agents/agent.py index 6c87297f1..9c107a81b 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -67,7 +67,63 @@ class MCPConfig(TypedDict): @dataclass -class Agent(Generic[TContext]): +class AgentBase(Generic[TContext]): + """Base class for `Agent` and `RealtimeAgent`.""" + + name: str + """The name of the agent.""" + + handoff_description: str | None = None + """A description of the agent. This is used when the agent is used as a handoff, so that an + LLM knows what it does and when to invoke it. + """ + + tools: list[Tool] = field(default_factory=list) + """A list of tools that the agent can use.""" + + mcp_servers: list[MCPServer] = field(default_factory=list) + """A list of [Model Context Protocol](https://modelcontextprotocol.io/) servers that + the agent can use. Every time the agent runs, it will include tools from these servers in the + list of available tools. + + NOTE: You are expected to manage the lifecycle of these servers. Specifically, you must call + `server.connect()` before passing it to the agent, and `server.cleanup()` when the server is no + longer needed. + """ + + mcp_config: MCPConfig = field(default_factory=lambda: MCPConfig()) + """Configuration for MCP servers.""" + + async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]: + """Fetches the available tools from the MCP servers.""" + convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False) + return await MCPUtil.get_all_function_tools( + self.mcp_servers, convert_schemas_to_strict, run_context, self + ) + + async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]: + """All agent tools, including MCP tools and function tools.""" + mcp_tools = await self.get_mcp_tools(run_context) + + async def _check_tool_enabled(tool: Tool) -> bool: + if not isinstance(tool, FunctionTool): + return True + + attr = tool.is_enabled + if isinstance(attr, bool): + return attr + res = attr(run_context, self) + if inspect.isawaitable(res): + return bool(await res) + return bool(res) + + results = await asyncio.gather(*(_check_tool_enabled(t) for t in self.tools)) + enabled: list[Tool] = [t for t, ok in zip(self.tools, results) if ok] + return [*mcp_tools, *enabled] + + +@dataclass +class Agent(AgentBase, Generic[TContext]): """An agent is an AI model configured with instructions, tools, guardrails, handoffs and more. We strongly recommend passing `instructions`, which is the "system prompt" for the agent. In @@ -76,10 +132,9 @@ class Agent(Generic[TContext]): Agents are generic on the context type. The context is a (mutable) object you create. It is passed to tool functions, handoffs, guardrails, etc. - """ - name: str - """The name of the agent.""" + See `AgentBase` for base parameters that are shared with `RealtimeAgent`s. + """ instructions: ( str @@ -103,11 +158,6 @@ class Agent(Generic[TContext]): usable with OpenAI models, using the Responses API. """ - handoff_description: str | None = None - """A description of the agent. This is used when the agent is used as a handoff, so that an - LLM knows what it does and when to invoke it. - """ - handoffs: list[Agent[Any] | Handoff[TContext]] = field(default_factory=list) """Handoffs are sub-agents that the agent can delegate to. You can provide a list of handoffs, and the agent can choose to delegate to them if relevant. Allows for separation of concerns and @@ -125,22 +175,6 @@ class Agent(Generic[TContext]): """Configures model-specific tuning parameters (e.g. temperature, top_p). """ - tools: list[Tool] = field(default_factory=list) - """A list of tools that the agent can use.""" - - mcp_servers: list[MCPServer] = field(default_factory=list) - """A list of [Model Context Protocol](https://modelcontextprotocol.io/) servers that - the agent can use. Every time the agent runs, it will include tools from these servers in the - list of available tools. - - NOTE: You are expected to manage the lifecycle of these servers. Specifically, you must call - `server.connect()` before passing it to the agent, and `server.cleanup()` when the server is no - longer needed. - """ - - mcp_config: MCPConfig = field(default_factory=lambda: MCPConfig()) - """Configuration for MCP servers.""" - input_guardrails: list[InputGuardrail[TContext]] = field(default_factory=list) """A list of checks that run in parallel to the agent's execution, before generating a response. Runs only if the agent is the first agent in the chain. @@ -176,7 +210,7 @@ class Agent(Generic[TContext]): The final output will be the output of the first matching tool call. The LLM does not process the result of the tool call. - A function: If you pass a function, it will be called with the run context and the list of - tool results. It must return a `ToolToFinalOutputResult`, which determines whether the tool + tool results. It must return a `ToolsToFinalOutputResult`, which determines whether the tool calls result in a final output. NOTE: This configuration is specific to FunctionTools. Hosted tools, such as file search, @@ -256,9 +290,7 @@ async def get_prompt( """Get the prompt for the agent.""" return await PromptUtil.to_model_input(self.prompt, run_context, self) - async def get_mcp_tools( - self, run_context: RunContextWrapper[TContext] - ) -> list[Tool]: + async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]: """Fetches the available tools from the MCP servers.""" convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False) return await MCPUtil.get_all_function_tools( diff --git a/src/agents/function_schema.py b/src/agents/function_schema.py index e1a91e189..ec54d8227 100644 --- a/src/agents/function_schema.py +++ b/src/agents/function_schema.py @@ -9,6 +9,7 @@ from griffe import Docstring, DocstringSectionKind from pydantic import BaseModel, Field, create_model +from pydantic.fields import FieldInfo from .exceptions import UserError from .run_context import RunContextWrapper @@ -319,6 +320,14 @@ def function_schema( ann, Field(..., description=field_description), ) + elif isinstance(default, FieldInfo): + # Parameter with a default value that is a Field(...) + fields[name] = ( + ann, + FieldInfo.merge_field_infos( + default, description=field_description or default.description + ), + ) else: # Parameter with a default value fields[name] = ( diff --git a/src/agents/guardrail.py b/src/agents/guardrail.py index a96f0f7d7..f8a272b53 100644 --- a/src/agents/guardrail.py +++ b/src/agents/guardrail.py @@ -241,7 +241,11 @@ async def my_async_guardrail(...): ... def decorator( f: _InputGuardrailFuncSync[TContext_co] | _InputGuardrailFuncAsync[TContext_co], ) -> InputGuardrail[TContext_co]: - return InputGuardrail(guardrail_function=f, name=name) + return InputGuardrail( + guardrail_function=f, + # If not set, guardrail name uses the function’s name by default. + name=name if name else f.__name__ + ) if func is not None: # Decorator was used without parentheses diff --git a/src/agents/lifecycle.py b/src/agents/lifecycle.py index 8643248b1..2cce496c8 100644 --- a/src/agents/lifecycle.py +++ b/src/agents/lifecycle.py @@ -1,25 +1,27 @@ from typing import Any, Generic -from .agent import Agent +from typing_extensions import TypeVar + +from .agent import Agent, AgentBase from .run_context import RunContextWrapper, TContext from .tool import Tool +TAgent = TypeVar("TAgent", bound=AgentBase, default=AgentBase) + -class RunHooks(Generic[TContext]): +class RunHooksBase(Generic[TContext, TAgent]): """A class that receives callbacks on various lifecycle events in an agent run. Subclass and override the methods you need. """ - async def on_agent_start( - self, context: RunContextWrapper[TContext], agent: Agent[TContext] - ) -> None: + async def on_agent_start(self, context: RunContextWrapper[TContext], agent: TAgent) -> None: """Called before the agent is invoked. Called each time the current agent changes.""" pass async def on_agent_end( self, context: RunContextWrapper[TContext], - agent: Agent[TContext], + agent: TAgent, output: Any, ) -> None: """Called when the agent produces a final output.""" @@ -28,8 +30,8 @@ async def on_agent_end( async def on_handoff( self, context: RunContextWrapper[TContext], - from_agent: Agent[TContext], - to_agent: Agent[TContext], + from_agent: TAgent, + to_agent: TAgent, ) -> None: """Called when a handoff occurs.""" pass @@ -37,7 +39,7 @@ async def on_handoff( async def on_tool_start( self, context: RunContextWrapper[TContext], - agent: Agent[TContext], + agent: TAgent, tool: Tool, ) -> None: """Called before a tool is invoked.""" @@ -46,7 +48,7 @@ async def on_tool_start( async def on_tool_end( self, context: RunContextWrapper[TContext], - agent: Agent[TContext], + agent: TAgent, tool: Tool, result: str, ) -> None: @@ -54,14 +56,14 @@ async def on_tool_end( pass -class AgentHooks(Generic[TContext]): +class AgentHooksBase(Generic[TContext, TAgent]): """A class that receives callbacks on various lifecycle events for a specific agent. You can set this on `agent.hooks` to receive events for that specific agent. Subclass and override the methods you need. """ - async def on_start(self, context: RunContextWrapper[TContext], agent: Agent[TContext]) -> None: + async def on_start(self, context: RunContextWrapper[TContext], agent: TAgent) -> None: """Called before the agent is invoked. Called each time the running agent is changed to this agent.""" pass @@ -69,7 +71,7 @@ async def on_start(self, context: RunContextWrapper[TContext], agent: Agent[TCon async def on_end( self, context: RunContextWrapper[TContext], - agent: Agent[TContext], + agent: TAgent, output: Any, ) -> None: """Called when the agent produces a final output.""" @@ -78,8 +80,8 @@ async def on_end( async def on_handoff( self, context: RunContextWrapper[TContext], - agent: Agent[TContext], - source: Agent[TContext], + agent: TAgent, + source: TAgent, ) -> None: """Called when the agent is being handed off to. The `source` is the agent that is handing off to this agent.""" @@ -88,7 +90,7 @@ async def on_handoff( async def on_tool_start( self, context: RunContextWrapper[TContext], - agent: Agent[TContext], + agent: TAgent, tool: Tool, ) -> None: """Called before a tool is invoked.""" @@ -97,9 +99,16 @@ async def on_tool_start( async def on_tool_end( self, context: RunContextWrapper[TContext], - agent: Agent[TContext], + agent: TAgent, tool: Tool, result: str, ) -> None: """Called after a tool is invoked.""" pass + + +RunHooks = RunHooksBase[TContext, Agent] +"""Run hooks when using `Agent`.""" + +AgentHooks = AgentHooksBase[TContext, Agent] +"""Agent hooks for `Agent`s.""" diff --git a/src/agents/mcp/server.py b/src/agents/mcp/server.py index 4fd606e34..91a9274fc 100644 --- a/src/agents/mcp/server.py +++ b/src/agents/mcp/server.py @@ -22,7 +22,7 @@ from .util import ToolFilter, ToolFilterCallable, ToolFilterContext, ToolFilterStatic if TYPE_CHECKING: - from ..agent import Agent + from ..agent import AgentBase class MCPServer(abc.ABC): @@ -53,7 +53,7 @@ async def cleanup(self): async def list_tools( self, run_context: RunContextWrapper[Any] | None = None, - agent: Agent[Any] | None = None, + agent: AgentBase | None = None, ) -> list[MCPTool]: """List the tools available on the server.""" pass @@ -117,7 +117,7 @@ async def _apply_tool_filter( self, tools: list[MCPTool], run_context: RunContextWrapper[Any], - agent: Agent[Any], + agent: AgentBase, ) -> list[MCPTool]: """Apply the tool filter to the list of tools.""" if self.tool_filter is None: @@ -153,7 +153,7 @@ async def _apply_dynamic_tool_filter( self, tools: list[MCPTool], run_context: RunContextWrapper[Any], - agent: Agent[Any], + agent: AgentBase, ) -> list[MCPTool]: """Apply dynamic tool filtering using a callable filter function.""" @@ -244,7 +244,7 @@ async def connect(self): async def list_tools( self, run_context: RunContextWrapper[Any] | None = None, - agent: Agent[Any] | None = None, + agent: AgentBase | None = None, ) -> list[MCPTool]: """List the tools available on the server.""" if not self.session: diff --git a/src/agents/mcp/util.py b/src/agents/mcp/util.py index 48da9f841..18cf4440a 100644 --- a/src/agents/mcp/util.py +++ b/src/agents/mcp/util.py @@ -5,12 +5,11 @@ from typing_extensions import NotRequired, TypedDict -from agents.strict_schema import ensure_strict_json_schema - from .. import _debug from ..exceptions import AgentsException, ModelBehaviorError, UserError from ..logger import logger from ..run_context import RunContextWrapper +from ..strict_schema import ensure_strict_json_schema from ..tool import FunctionTool, Tool from ..tracing import FunctionSpanData, get_current_span, mcp_tools_span from ..util._types import MaybeAwaitable @@ -18,7 +17,7 @@ if TYPE_CHECKING: from mcp.types import Tool as MCPTool - from ..agent import Agent + from ..agent import AgentBase from .server import MCPServer @@ -29,7 +28,7 @@ class ToolFilterContext: run_context: RunContextWrapper[Any] """The current run context.""" - agent: "Agent[Any]" + agent: "AgentBase" """The agent that is requesting the tool list.""" server_name: str @@ -100,7 +99,7 @@ async def get_all_function_tools( servers: list["MCPServer"], convert_schemas_to_strict: bool, run_context: RunContextWrapper[Any], - agent: "Agent[Any]", + agent: "AgentBase", ) -> list[Tool]: """Get all function tools from a list of MCP servers.""" tools = [] @@ -126,7 +125,7 @@ async def get_function_tools( server: "MCPServer", convert_schemas_to_strict: bool, run_context: RunContextWrapper[Any], - agent: "Agent[Any]", + agent: "AgentBase", ) -> list[Tool]: """Get all function tools from a single MCP server.""" diff --git a/src/agents/memory/__init__.py b/src/agents/memory/__init__.py new file mode 100644 index 000000000..059ca57ab --- /dev/null +++ b/src/agents/memory/__init__.py @@ -0,0 +1,3 @@ +from .session import Session, SQLiteSession + +__all__ = ["Session", "SQLiteSession"] diff --git a/src/agents/memory/session.py b/src/agents/memory/session.py new file mode 100644 index 000000000..8db0971eb --- /dev/null +++ b/src/agents/memory/session.py @@ -0,0 +1,369 @@ +from __future__ import annotations + +import asyncio +import json +import sqlite3 +import threading +from abc import ABC, abstractmethod +from pathlib import Path +from typing import TYPE_CHECKING, Protocol, runtime_checkable + +if TYPE_CHECKING: + from ..items import TResponseInputItem + + +@runtime_checkable +class Session(Protocol): + """Protocol for session implementations. + + Session stores conversation history for a specific session, allowing + agents to maintain context without requiring explicit manual memory management. + """ + + session_id: str + + async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + """Retrieve the conversation history for this session. + + Args: + limit: Maximum number of items to retrieve. If None, retrieves all items. + When specified, returns the latest N items in chronological order. + + Returns: + List of input items representing the conversation history + """ + ... + + async def add_items(self, items: list[TResponseInputItem]) -> None: + """Add new items to the conversation history. + + Args: + items: List of input items to add to the history + """ + ... + + async def pop_item(self) -> TResponseInputItem | None: + """Remove and return the most recent item from the session. + + Returns: + The most recent item if it exists, None if the session is empty + """ + ... + + async def clear_session(self) -> None: + """Clear all items for this session.""" + ... + + +class SessionABC(ABC): + """Abstract base class for session implementations. + + Session stores conversation history for a specific session, allowing + agents to maintain context without requiring explicit manual memory management. + + This ABC is intended for internal use and as a base class for concrete implementations. + Third-party libraries should implement the Session protocol instead. + """ + + session_id: str + + @abstractmethod + async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + """Retrieve the conversation history for this session. + + Args: + limit: Maximum number of items to retrieve. If None, retrieves all items. + When specified, returns the latest N items in chronological order. + + Returns: + List of input items representing the conversation history + """ + ... + + @abstractmethod + async def add_items(self, items: list[TResponseInputItem]) -> None: + """Add new items to the conversation history. + + Args: + items: List of input items to add to the history + """ + ... + + @abstractmethod + async def pop_item(self) -> TResponseInputItem | None: + """Remove and return the most recent item from the session. + + Returns: + The most recent item if it exists, None if the session is empty + """ + ... + + @abstractmethod + async def clear_session(self) -> None: + """Clear all items for this session.""" + ... + + +class SQLiteSession(SessionABC): + """SQLite-based implementation of session storage. + + This implementation stores conversation history in a SQLite database. + By default, uses an in-memory database that is lost when the process ends. + For persistent storage, provide a file path. + """ + + def __init__( + self, + session_id: str, + db_path: str | Path = ":memory:", + sessions_table: str = "agent_sessions", + messages_table: str = "agent_messages", + ): + """Initialize the SQLite session. + + Args: + session_id: Unique identifier for the conversation session + db_path: Path to the SQLite database file. Defaults to ':memory:' (in-memory database) + sessions_table: Name of the table to store session metadata. Defaults to + 'agent_sessions' + messages_table: Name of the table to store message data. Defaults to 'agent_messages' + """ + self.session_id = session_id + self.db_path = db_path + self.sessions_table = sessions_table + self.messages_table = messages_table + self._local = threading.local() + self._lock = threading.Lock() + + # For in-memory databases, we need a shared connection to avoid thread isolation + # For file databases, we use thread-local connections for better concurrency + self._is_memory_db = str(db_path) == ":memory:" + if self._is_memory_db: + self._shared_connection = sqlite3.connect(":memory:", check_same_thread=False) + self._shared_connection.execute("PRAGMA journal_mode=WAL") + self._init_db_for_connection(self._shared_connection) + else: + # For file databases, initialize the schema once since it persists + init_conn = sqlite3.connect(str(self.db_path), check_same_thread=False) + init_conn.execute("PRAGMA journal_mode=WAL") + self._init_db_for_connection(init_conn) + init_conn.close() + + def _get_connection(self) -> sqlite3.Connection: + """Get a database connection.""" + if self._is_memory_db: + # Use shared connection for in-memory database to avoid thread isolation + return self._shared_connection + else: + # Use thread-local connections for file databases + if not hasattr(self._local, "connection"): + self._local.connection = sqlite3.connect( + str(self.db_path), + check_same_thread=False, + ) + self._local.connection.execute("PRAGMA journal_mode=WAL") + assert isinstance(self._local.connection, sqlite3.Connection), ( + f"Expected sqlite3.Connection, got {type(self._local.connection)}" + ) + return self._local.connection + + def _init_db_for_connection(self, conn: sqlite3.Connection) -> None: + """Initialize the database schema for a specific connection.""" + conn.execute( + f""" + CREATE TABLE IF NOT EXISTS {self.sessions_table} ( + session_id TEXT PRIMARY KEY, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """ + ) + + conn.execute( + f""" + CREATE TABLE IF NOT EXISTS {self.messages_table} ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id TEXT NOT NULL, + message_data TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (session_id) REFERENCES {self.sessions_table} (session_id) + ON DELETE CASCADE + ) + """ + ) + + conn.execute( + f""" + CREATE INDEX IF NOT EXISTS idx_{self.messages_table}_session_id + ON {self.messages_table} (session_id, created_at) + """ + ) + + conn.commit() + + async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + """Retrieve the conversation history for this session. + + Args: + limit: Maximum number of items to retrieve. If None, retrieves all items. + When specified, returns the latest N items in chronological order. + + Returns: + List of input items representing the conversation history + """ + + def _get_items_sync(): + conn = self._get_connection() + with self._lock if self._is_memory_db else threading.Lock(): + if limit is None: + # Fetch all items in chronological order + cursor = conn.execute( + f""" + SELECT message_data FROM {self.messages_table} + WHERE session_id = ? + ORDER BY created_at ASC + """, + (self.session_id,), + ) + else: + # Fetch the latest N items in chronological order + cursor = conn.execute( + f""" + SELECT message_data FROM {self.messages_table} + WHERE session_id = ? + ORDER BY created_at DESC + LIMIT ? + """, + (self.session_id, limit), + ) + + rows = cursor.fetchall() + + # Reverse to get chronological order when using DESC + if limit is not None: + rows = list(reversed(rows)) + + items = [] + for (message_data,) in rows: + try: + item = json.loads(message_data) + items.append(item) + except json.JSONDecodeError: + # Skip invalid JSON entries + continue + + return items + + return await asyncio.to_thread(_get_items_sync) + + async def add_items(self, items: list[TResponseInputItem]) -> None: + """Add new items to the conversation history. + + Args: + items: List of input items to add to the history + """ + if not items: + return + + def _add_items_sync(): + conn = self._get_connection() + + with self._lock if self._is_memory_db else threading.Lock(): + # Ensure session exists + conn.execute( + f""" + INSERT OR IGNORE INTO {self.sessions_table} (session_id) VALUES (?) + """, + (self.session_id,), + ) + + # Add items + message_data = [(self.session_id, json.dumps(item)) for item in items] + conn.executemany( + f""" + INSERT INTO {self.messages_table} (session_id, message_data) VALUES (?, ?) + """, + message_data, + ) + + # Update session timestamp + conn.execute( + f""" + UPDATE {self.sessions_table} + SET updated_at = CURRENT_TIMESTAMP + WHERE session_id = ? + """, + (self.session_id,), + ) + + conn.commit() + + await asyncio.to_thread(_add_items_sync) + + async def pop_item(self) -> TResponseInputItem | None: + """Remove and return the most recent item from the session. + + Returns: + The most recent item if it exists, None if the session is empty + """ + + def _pop_item_sync(): + conn = self._get_connection() + with self._lock if self._is_memory_db else threading.Lock(): + # Use DELETE with RETURNING to atomically delete and return the most recent item + cursor = conn.execute( + f""" + DELETE FROM {self.messages_table} + WHERE id = ( + SELECT id FROM {self.messages_table} + WHERE session_id = ? + ORDER BY created_at DESC + LIMIT 1 + ) + RETURNING message_data + """, + (self.session_id,), + ) + + result = cursor.fetchone() + conn.commit() + + if result: + message_data = result[0] + try: + item = json.loads(message_data) + return item + except json.JSONDecodeError: + # Return None for corrupted JSON entries (already deleted) + return None + + return None + + return await asyncio.to_thread(_pop_item_sync) + + async def clear_session(self) -> None: + """Clear all items for this session.""" + + def _clear_session_sync(): + conn = self._get_connection() + with self._lock if self._is_memory_db else threading.Lock(): + conn.execute( + f"DELETE FROM {self.messages_table} WHERE session_id = ?", + (self.session_id,), + ) + conn.execute( + f"DELETE FROM {self.sessions_table} WHERE session_id = ?", + (self.session_id,), + ) + conn.commit() + + await asyncio.to_thread(_clear_session_sync) + + def close(self) -> None: + """Close the database connection.""" + if self._is_memory_db: + if hasattr(self, "_shared_connection"): + self._shared_connection.close() + else: + if hasattr(self._local, "connection"): + self._local.connection.close() diff --git a/src/agents/model_settings.py b/src/agents/model_settings.py index 26af94ba3..edb692960 100644 --- a/src/agents/model_settings.py +++ b/src/agents/model_settings.py @@ -17,9 +17,9 @@ class _OmitTypeAnnotation: @classmethod def __get_pydantic_core_schema__( - cls, - _source_type: Any, - _handler: GetCoreSchemaHandler, + cls, + _source_type: Any, + _handler: GetCoreSchemaHandler, ) -> core_schema.CoreSchema: def validate_from_none(value: None) -> _Omit: return _Omit() @@ -39,12 +39,20 @@ def validate_from_none(value: None) -> _Omit: from_none_schema, ] ), - serialization=core_schema.plain_serializer_function_ser_schema( - lambda instance: None - ), + serialization=core_schema.plain_serializer_function_ser_schema(lambda instance: None), ) + + +@dataclass +class MCPToolChoice: + server_label: str + name: str + + Omit = Annotated[_Omit, _OmitTypeAnnotation] Headers: TypeAlias = Mapping[str, Union[str, Omit]] +ToolChoice: TypeAlias = Union[Literal["auto", "required", "none"], str, MCPToolChoice, None] + @dataclass class ModelSettings: @@ -69,7 +77,7 @@ class ModelSettings: presence_penalty: float | None = None """The presence penalty to use when calling the model.""" - tool_choice: Literal["auto", "required", "none"] | str | None = None + tool_choice: ToolChoice | None = None """The tool choice to use when calling the model.""" parallel_tool_calls: bool | None = None diff --git a/src/agents/models/chatcmpl_converter.py b/src/agents/models/chatcmpl_converter.py index 9d0c6cf5e..d3c71c24e 100644 --- a/src/agents/models/chatcmpl_converter.py +++ b/src/agents/models/chatcmpl_converter.py @@ -44,6 +44,7 @@ from ..exceptions import AgentsException, UserError from ..handoffs import Handoff from ..items import TResponseInputItem, TResponseOutputItem +from ..model_settings import MCPToolChoice from ..tool import FunctionTool, Tool from .fake_id import FAKE_RESPONSES_ID @@ -51,10 +52,12 @@ class Converter: @classmethod def convert_tool_choice( - cls, tool_choice: Literal["auto", "required", "none"] | str | None + cls, tool_choice: Literal["auto", "required", "none"] | str | MCPToolChoice | None ) -> ChatCompletionToolChoiceOptionParam | NotGiven: if tool_choice is None: return NOT_GIVEN + elif isinstance(tool_choice, MCPToolChoice): + raise UserError("MCPToolChoice is not supported for Chat Completions models") elif tool_choice == "auto": return "auto" elif tool_choice == "required": diff --git a/src/agents/models/openai_responses.py b/src/agents/models/openai_responses.py index a7ce62983..76c67903c 100644 --- a/src/agents/models/openai_responses.py +++ b/src/agents/models/openai_responses.py @@ -25,6 +25,7 @@ from ..handoffs import Handoff from ..items import ItemHelpers, ModelResponse, TResponseInputItem from ..logger import logger +from ..model_settings import MCPToolChoice from ..tool import ( CodeInterpreterTool, ComputerTool, @@ -303,10 +304,16 @@ class ConvertedTools: class Converter: @classmethod def convert_tool_choice( - cls, tool_choice: Literal["auto", "required", "none"] | str | None + cls, tool_choice: Literal["auto", "required", "none"] | str | MCPToolChoice | None ) -> response_create_params.ToolChoice | NotGiven: if tool_choice is None: return NOT_GIVEN + elif isinstance(tool_choice, MCPToolChoice): + return { + "server_label": tool_choice.server_label, + "type": "mcp", + "name": tool_choice.name, + } elif tool_choice == "required": return "required" elif tool_choice == "auto": @@ -334,9 +341,9 @@ def convert_tool_choice( "type": "code_interpreter", } elif tool_choice == "mcp": - return { - "type": "mcp", - } + # Note that this is still here for backwards compatibility, + # but migrating to MCPToolChoice is recommended. + return {"type": "mcp"} # type: ignore [typeddict-item] else: return { "type": "function", diff --git a/src/agents/realtime/README.md b/src/agents/realtime/README.md new file mode 100644 index 000000000..9acc23160 --- /dev/null +++ b/src/agents/realtime/README.md @@ -0,0 +1,3 @@ +# Realtime + +Realtime agents are in beta: expect some breaking changes over the next few weeks as we find issues and fix them. diff --git a/src/agents/realtime/__init__.py b/src/agents/realtime/__init__.py new file mode 100644 index 000000000..0e3e12f75 --- /dev/null +++ b/src/agents/realtime/__init__.py @@ -0,0 +1,174 @@ +from .agent import RealtimeAgent, RealtimeAgentHooks, RealtimeRunHooks +from .config import ( + RealtimeAudioFormat, + RealtimeClientMessage, + RealtimeGuardrailsSettings, + RealtimeInputAudioTranscriptionConfig, + RealtimeModelName, + RealtimeModelTracingConfig, + RealtimeRunConfig, + RealtimeSessionModelSettings, + RealtimeTurnDetectionConfig, + RealtimeUserInput, + RealtimeUserInputMessage, + RealtimeUserInputText, +) +from .events import ( + RealtimeAgentEndEvent, + RealtimeAgentStartEvent, + RealtimeAudio, + RealtimeAudioEnd, + RealtimeAudioInterrupted, + RealtimeError, + RealtimeEventInfo, + RealtimeGuardrailTripped, + RealtimeHandoffEvent, + RealtimeHistoryAdded, + RealtimeHistoryUpdated, + RealtimeRawModelEvent, + RealtimeSessionEvent, + RealtimeToolEnd, + RealtimeToolStart, +) +from .items import ( + AssistantMessageItem, + AssistantText, + InputAudio, + InputText, + RealtimeItem, + RealtimeMessageItem, + RealtimeResponse, + RealtimeToolCallItem, + SystemMessageItem, + UserMessageItem, +) +from .model import ( + RealtimeModel, + RealtimeModelConfig, + RealtimeModelListener, +) +from .model_events import ( + RealtimeConnectionStatus, + RealtimeModelAudioDoneEvent, + RealtimeModelAudioEvent, + RealtimeModelAudioInterruptedEvent, + RealtimeModelConnectionStatusEvent, + RealtimeModelErrorEvent, + RealtimeModelEvent, + RealtimeModelExceptionEvent, + RealtimeModelInputAudioTranscriptionCompletedEvent, + RealtimeModelItemDeletedEvent, + RealtimeModelItemUpdatedEvent, + RealtimeModelOtherEvent, + RealtimeModelToolCallEvent, + RealtimeModelTranscriptDeltaEvent, + RealtimeModelTurnEndedEvent, + RealtimeModelTurnStartedEvent, +) +from .model_inputs import ( + RealtimeModelInputTextContent, + RealtimeModelRawClientMessage, + RealtimeModelSendAudio, + RealtimeModelSendEvent, + RealtimeModelSendInterrupt, + RealtimeModelSendRawMessage, + RealtimeModelSendSessionUpdate, + RealtimeModelSendToolOutput, + RealtimeModelSendUserInput, + RealtimeModelUserInput, + RealtimeModelUserInputMessage, +) +from .openai_realtime import ( + DEFAULT_MODEL_SETTINGS, + OpenAIRealtimeWebSocketModel, + get_api_key, +) +from .runner import RealtimeRunner +from .session import RealtimeSession + +__all__ = [ + # Agent + "RealtimeAgent", + "RealtimeAgentHooks", + "RealtimeRunHooks", + "RealtimeRunner", + # Config + "RealtimeAudioFormat", + "RealtimeClientMessage", + "RealtimeGuardrailsSettings", + "RealtimeInputAudioTranscriptionConfig", + "RealtimeModelName", + "RealtimeModelTracingConfig", + "RealtimeRunConfig", + "RealtimeSessionModelSettings", + "RealtimeTurnDetectionConfig", + "RealtimeUserInput", + "RealtimeUserInputMessage", + "RealtimeUserInputText", + # Events + "RealtimeAgentEndEvent", + "RealtimeAgentStartEvent", + "RealtimeAudio", + "RealtimeAudioEnd", + "RealtimeAudioInterrupted", + "RealtimeError", + "RealtimeEventInfo", + "RealtimeGuardrailTripped", + "RealtimeHandoffEvent", + "RealtimeHistoryAdded", + "RealtimeHistoryUpdated", + "RealtimeRawModelEvent", + "RealtimeSessionEvent", + "RealtimeToolEnd", + "RealtimeToolStart", + # Items + "AssistantMessageItem", + "AssistantText", + "InputAudio", + "InputText", + "RealtimeItem", + "RealtimeMessageItem", + "RealtimeResponse", + "RealtimeToolCallItem", + "SystemMessageItem", + "UserMessageItem", + # Model + "RealtimeModel", + "RealtimeModelConfig", + "RealtimeModelListener", + # Model Events + "RealtimeConnectionStatus", + "RealtimeModelAudioDoneEvent", + "RealtimeModelAudioEvent", + "RealtimeModelAudioInterruptedEvent", + "RealtimeModelConnectionStatusEvent", + "RealtimeModelErrorEvent", + "RealtimeModelEvent", + "RealtimeModelExceptionEvent", + "RealtimeModelInputAudioTranscriptionCompletedEvent", + "RealtimeModelItemDeletedEvent", + "RealtimeModelItemUpdatedEvent", + "RealtimeModelOtherEvent", + "RealtimeModelToolCallEvent", + "RealtimeModelTranscriptDeltaEvent", + "RealtimeModelTurnEndedEvent", + "RealtimeModelTurnStartedEvent", + # Model Inputs + "RealtimeModelInputTextContent", + "RealtimeModelRawClientMessage", + "RealtimeModelSendAudio", + "RealtimeModelSendEvent", + "RealtimeModelSendInterrupt", + "RealtimeModelSendRawMessage", + "RealtimeModelSendSessionUpdate", + "RealtimeModelSendToolOutput", + "RealtimeModelSendUserInput", + "RealtimeModelUserInput", + "RealtimeModelUserInputMessage", + # OpenAI Realtime + "DEFAULT_MODEL_SETTINGS", + "OpenAIRealtimeWebSocketModel", + "get_api_key", + # Session + "RealtimeSession", +] diff --git a/src/agents/realtime/agent.py b/src/agents/realtime/agent.py new file mode 100644 index 000000000..9bbed8cb4 --- /dev/null +++ b/src/agents/realtime/agent.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +import dataclasses +import inspect +from collections.abc import Awaitable +from dataclasses import dataclass +from typing import Any, Callable, Generic, cast + +from ..agent import AgentBase +from ..lifecycle import AgentHooksBase, RunHooksBase +from ..logger import logger +from ..run_context import RunContextWrapper, TContext +from ..util._types import MaybeAwaitable + +RealtimeAgentHooks = AgentHooksBase[TContext, "RealtimeAgent[TContext]"] +"""Agent hooks for `RealtimeAgent`s.""" + +RealtimeRunHooks = RunHooksBase[TContext, "RealtimeAgent[TContext]"] +"""Run hooks for `RealtimeAgent`s.""" + + +@dataclass +class RealtimeAgent(AgentBase, Generic[TContext]): + """A specialized agent instance that is meant to be used within a `RealtimeSession` to build + voice agents. Due to the nature of this agent, some configuration options are not supported + that are supported by regular `Agent` instances. For example: + - `model` choice is not supported, as all RealtimeAgents will be handled by the same model + within a `RealtimeSession`. + - `modelSettings` is not supported, as all RealtimeAgents will be handled by the same model + within a `RealtimeSession`. + - `outputType` is not supported, as RealtimeAgents do not support structured outputs. + - `toolUseBehavior` is not supported, as all RealtimeAgents will be handled by the same model + within a `RealtimeSession`. + - `voice` can be configured on an `Agent` level; however, it cannot be changed after the first + agent within a `RealtimeSession` has spoken. + + See `AgentBase` for base parameters that are shared with `Agent`s. + """ + + instructions: ( + str + | Callable[ + [RunContextWrapper[TContext], RealtimeAgent[TContext]], + MaybeAwaitable[str], + ] + | None + ) = None + """The instructions for the agent. Will be used as the "system prompt" when this agent is + invoked. Describes what the agent should do, and how it responds. + + Can either be a string, or a function that dynamically generates instructions for the agent. If + you provide a function, it will be called with the context and the agent instance. It must + return a string. + """ + + hooks: RealtimeAgentHooks | None = None + """A class that receives callbacks on various lifecycle events for this agent. + """ + + def clone(self, **kwargs: Any) -> RealtimeAgent[TContext]: + """Make a copy of the agent, with the given arguments changed. For example, you could do: + ``` + new_agent = agent.clone(instructions="New instructions") + ``` + """ + return dataclasses.replace(self, **kwargs) + + async def get_system_prompt(self, run_context: RunContextWrapper[TContext]) -> str | None: + """Get the system prompt for the agent.""" + if isinstance(self.instructions, str): + return self.instructions + elif callable(self.instructions): + if inspect.iscoroutinefunction(self.instructions): + return await cast(Awaitable[str], self.instructions(run_context, self)) + else: + return cast(str, self.instructions(run_context, self)) + elif self.instructions is not None: + logger.error(f"Instructions must be a string or a function, got {self.instructions}") + + return None diff --git a/src/agents/realtime/config.py b/src/agents/realtime/config.py new file mode 100644 index 000000000..7f874cfb0 --- /dev/null +++ b/src/agents/realtime/config.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +from typing import ( + Any, + Literal, + Union, +) + +from typing_extensions import NotRequired, TypeAlias, TypedDict + +from ..guardrail import OutputGuardrail +from ..model_settings import ToolChoice +from ..tool import Tool + +RealtimeModelName: TypeAlias = Union[ + Literal[ + "gpt-4o-realtime-preview", + "gpt-4o-mini-realtime-preview", + "gpt-4o-realtime-preview-2025-06-03", + "gpt-4o-realtime-preview-2024-12-17", + "gpt-4o-realtime-preview-2024-10-01", + "gpt-4o-mini-realtime-preview-2024-12-17", + ], + str, +] +"""The name of a realtime model.""" + + +RealtimeAudioFormat: TypeAlias = Union[Literal["pcm16", "g711_ulaw", "g711_alaw"], str] + + +class RealtimeClientMessage(TypedDict): + """A raw message to be sent to the model.""" + + type: str # explicitly required + other_data: NotRequired[dict[str, Any]] + """Merged into the message body.""" + + +class RealtimeInputAudioTranscriptionConfig(TypedDict): + language: NotRequired[str] + model: NotRequired[Literal["gpt-4o-transcribe", "gpt-4o-mini-transcribe", "whisper-1"] | str] + prompt: NotRequired[str] + + +class RealtimeTurnDetectionConfig(TypedDict): + """Turn detection config. Allows extra vendor keys if needed.""" + + type: NotRequired[Literal["semantic_vad", "server_vad"]] + create_response: NotRequired[bool] + eagerness: NotRequired[Literal["auto", "low", "medium", "high"]] + interrupt_response: NotRequired[bool] + prefix_padding_ms: NotRequired[int] + silence_duration_ms: NotRequired[int] + threshold: NotRequired[float] + + +class RealtimeSessionModelSettings(TypedDict): + """Model settings for a realtime model session.""" + + model_name: NotRequired[RealtimeModelName] + + instructions: NotRequired[str] + modalities: NotRequired[list[Literal["text", "audio"]]] + voice: NotRequired[str] + + input_audio_format: NotRequired[RealtimeAudioFormat] + output_audio_format: NotRequired[RealtimeAudioFormat] + input_audio_transcription: NotRequired[RealtimeInputAudioTranscriptionConfig] + turn_detection: NotRequired[RealtimeTurnDetectionConfig] + + tool_choice: NotRequired[ToolChoice] + tools: NotRequired[list[Tool]] + + tracing: NotRequired[RealtimeModelTracingConfig | None] + + +class RealtimeGuardrailsSettings(TypedDict): + """Settings for output guardrails in realtime sessions.""" + + debounce_text_length: NotRequired[int] + """ + The minimum number of characters to accumulate before running guardrails on transcript + deltas. Defaults to 100. Guardrails run every time the accumulated text reaches + 1x, 2x, 3x, etc. times this threshold. + """ + + +class RealtimeModelTracingConfig(TypedDict): + """Configuration for tracing in realtime model sessions.""" + + workflow_name: NotRequired[str] + """The workflow name to use for tracing.""" + + group_id: NotRequired[str] + """A group identifier to use for tracing, to link multiple traces together.""" + + metadata: NotRequired[dict[str, Any]] + """Additional metadata to include with the trace.""" + + +class RealtimeRunConfig(TypedDict): + model_settings: NotRequired[RealtimeSessionModelSettings] + + output_guardrails: NotRequired[list[OutputGuardrail[Any]]] + """List of output guardrails to run on the agent's responses.""" + + guardrails_settings: NotRequired[RealtimeGuardrailsSettings] + """Settings for guardrail execution.""" + + tracing_disabled: NotRequired[bool] + """Whether tracing is disabled for this run.""" + + # TODO (rm) Add history audio storage config + + +class RealtimeUserInputText(TypedDict): + type: Literal["input_text"] + text: str + + +class RealtimeUserInputMessage(TypedDict): + type: Literal["message"] + role: Literal["user"] + content: list[RealtimeUserInputText] + + +RealtimeUserInput: TypeAlias = Union[str, RealtimeUserInputMessage] diff --git a/src/agents/realtime/events.py b/src/agents/realtime/events.py new file mode 100644 index 000000000..24444b66e --- /dev/null +++ b/src/agents/realtime/events.py @@ -0,0 +1,216 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Literal, Union + +from typing_extensions import TypeAlias + +from ..guardrail import OutputGuardrailResult +from ..run_context import RunContextWrapper +from ..tool import Tool +from .agent import RealtimeAgent +from .items import RealtimeItem +from .model_events import RealtimeModelAudioEvent, RealtimeModelEvent + + +@dataclass +class RealtimeEventInfo: + context: RunContextWrapper + """The context for the event.""" + + +@dataclass +class RealtimeAgentStartEvent: + """A new agent has started.""" + + agent: RealtimeAgent + """The new agent.""" + + info: RealtimeEventInfo + """Common info for all events, such as the context.""" + + type: Literal["agent_start"] = "agent_start" + + +@dataclass +class RealtimeAgentEndEvent: + """An agent has ended.""" + + agent: RealtimeAgent + """The agent that ended.""" + + info: RealtimeEventInfo + """Common info for all events, such as the context.""" + + type: Literal["agent_end"] = "agent_end" + + +@dataclass +class RealtimeHandoffEvent: + """An agent has handed off to another agent.""" + + from_agent: RealtimeAgent + """The agent that handed off.""" + + to_agent: RealtimeAgent + """The agent that was handed off to.""" + + info: RealtimeEventInfo + """Common info for all events, such as the context.""" + + type: Literal["handoff"] = "handoff" + + +@dataclass +class RealtimeToolStart: + """An agent is starting a tool call.""" + + agent: RealtimeAgent + """The agent that updated.""" + + tool: Tool + + info: RealtimeEventInfo + """Common info for all events, such as the context.""" + + type: Literal["tool_start"] = "tool_start" + + +@dataclass +class RealtimeToolEnd: + """An agent has ended a tool call.""" + + agent: RealtimeAgent + """The agent that ended the tool call.""" + + tool: Tool + """The tool that was called.""" + + output: Any + """The output of the tool call.""" + + info: RealtimeEventInfo + """Common info for all events, such as the context.""" + + type: Literal["tool_end"] = "tool_end" + + +@dataclass +class RealtimeRawModelEvent: + """Forwards raw events from the model layer.""" + + data: RealtimeModelEvent + """The raw data from the model layer.""" + + info: RealtimeEventInfo + """Common info for all events, such as the context.""" + + type: Literal["raw_model_event"] = "raw_model_event" + + +@dataclass +class RealtimeAudioEnd: + """Triggered when the agent stops generating audio.""" + + info: RealtimeEventInfo + """Common info for all events, such as the context.""" + + type: Literal["audio_end"] = "audio_end" + + +@dataclass +class RealtimeAudio: + """Triggered when the agent generates new audio to be played.""" + + audio: RealtimeModelAudioEvent + """The audio event from the model layer.""" + + info: RealtimeEventInfo + """Common info for all events, such as the context.""" + + type: Literal["audio"] = "audio" + + +@dataclass +class RealtimeAudioInterrupted: + """Triggered when the agent is interrupted. Can be listened to by the user to stop audio + playback or give visual indicators to the user. + """ + + info: RealtimeEventInfo + """Common info for all events, such as the context.""" + + type: Literal["audio_interrupted"] = "audio_interrupted" + + +@dataclass +class RealtimeError: + """An error has occurred.""" + + error: Any + """The error that occurred.""" + + info: RealtimeEventInfo + """Common info for all events, such as the context.""" + + type: Literal["error"] = "error" + + +@dataclass +class RealtimeHistoryUpdated: + """The history has been updated. Contains the full history of the session.""" + + history: list[RealtimeItem] + """The full history of the session.""" + + info: RealtimeEventInfo + """Common info for all events, such as the context.""" + + type: Literal["history_updated"] = "history_updated" + + +@dataclass +class RealtimeHistoryAdded: + """A new item has been added to the history.""" + + item: RealtimeItem + """The new item that was added to the history.""" + + info: RealtimeEventInfo + """Common info for all events, such as the context.""" + + type: Literal["history_added"] = "history_added" + + +@dataclass +class RealtimeGuardrailTripped: + """A guardrail has been tripped and the agent has been interrupted.""" + + guardrail_results: list[OutputGuardrailResult] + """The results from all triggered guardrails.""" + + message: str + """The message that was being generated when the guardrail was triggered.""" + + info: RealtimeEventInfo + """Common info for all events, such as the context.""" + + type: Literal["guardrail_tripped"] = "guardrail_tripped" + + +RealtimeSessionEvent: TypeAlias = Union[ + RealtimeAgentStartEvent, + RealtimeAgentEndEvent, + RealtimeHandoffEvent, + RealtimeToolStart, + RealtimeToolEnd, + RealtimeRawModelEvent, + RealtimeAudioEnd, + RealtimeAudio, + RealtimeAudioInterrupted, + RealtimeError, + RealtimeHistoryUpdated, + RealtimeHistoryAdded, + RealtimeGuardrailTripped, +] +"""An event emitted by the realtime session.""" diff --git a/src/agents/realtime/items.py b/src/agents/realtime/items.py new file mode 100644 index 000000000..a835e7a88 --- /dev/null +++ b/src/agents/realtime/items.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +from typing import Annotated, Literal, Union + +from pydantic import BaseModel, ConfigDict, Field + + +class InputText(BaseModel): + type: Literal["input_text"] = "input_text" + text: str | None = None + + # Allow extra data + model_config = ConfigDict(extra="allow") + + +class InputAudio(BaseModel): + type: Literal["input_audio"] = "input_audio" + audio: str | None = None + transcript: str | None = None + + # Allow extra data + model_config = ConfigDict(extra="allow") + + +class AssistantText(BaseModel): + type: Literal["text"] = "text" + text: str | None = None + + # Allow extra data + model_config = ConfigDict(extra="allow") + + +class SystemMessageItem(BaseModel): + item_id: str + previous_item_id: str | None = None + type: Literal["message"] = "message" + role: Literal["system"] = "system" + content: list[InputText] + + # Allow extra data + model_config = ConfigDict(extra="allow") + + +class UserMessageItem(BaseModel): + item_id: str + previous_item_id: str | None = None + type: Literal["message"] = "message" + role: Literal["user"] = "user" + content: list[Annotated[InputText | InputAudio, Field(discriminator="type")]] + + # Allow extra data + model_config = ConfigDict(extra="allow") + + +class AssistantMessageItem(BaseModel): + item_id: str + previous_item_id: str | None = None + type: Literal["message"] = "message" + role: Literal["assistant"] = "assistant" + status: Literal["in_progress", "completed", "incomplete"] | None = None + content: list[AssistantText] + + # Allow extra data + model_config = ConfigDict(extra="allow") + + +RealtimeMessageItem = Annotated[ + Union[SystemMessageItem, UserMessageItem, AssistantMessageItem], + Field(discriminator="role"), +] + + +class RealtimeToolCallItem(BaseModel): + item_id: str + previous_item_id: str | None = None + type: Literal["function_call"] = "function_call" + status: Literal["in_progress", "completed"] + arguments: str + name: str + output: str | None = None + + # Allow extra data + model_config = ConfigDict(extra="allow") + + +RealtimeItem = Union[RealtimeMessageItem, RealtimeToolCallItem] + + +class RealtimeResponse(BaseModel): + id: str + output: list[RealtimeMessageItem] diff --git a/src/agents/realtime/model.py b/src/agents/realtime/model.py new file mode 100644 index 000000000..e279ecc95 --- /dev/null +++ b/src/agents/realtime/model.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +import abc +from typing import Callable + +from typing_extensions import NotRequired, TypedDict + +from ..util._types import MaybeAwaitable +from .config import ( + RealtimeSessionModelSettings, +) +from .model_events import RealtimeModelEvent +from .model_inputs import RealtimeModelSendEvent + + +class RealtimeModelListener(abc.ABC): + """A listener for realtime transport events.""" + + @abc.abstractmethod + async def on_event(self, event: RealtimeModelEvent) -> None: + """Called when an event is emitted by the realtime transport.""" + pass + + +class RealtimeModelConfig(TypedDict): + """Options for connecting to a realtime model.""" + + api_key: NotRequired[str | Callable[[], MaybeAwaitable[str]]] + """The API key (or function that returns a key) to use when connecting. If unset, the model will + try to use a sane default. For example, the OpenAI Realtime model will try to use the + `OPENAI_API_KEY` environment variable. + """ + + url: NotRequired[str] + """The URL to use when connecting. If unset, the model will use a sane default. For example, + the OpenAI Realtime model will use the default OpenAI WebSocket URL. + """ + + initial_model_settings: NotRequired[RealtimeSessionModelSettings] + """The initial model settings to use when connecting.""" + + +class RealtimeModel(abc.ABC): + """Interface for connecting to a realtime model and sending/receiving events.""" + + @abc.abstractmethod + async def connect(self, options: RealtimeModelConfig) -> None: + """Establish a connection to the model and keep it alive.""" + pass + + @abc.abstractmethod + def add_listener(self, listener: RealtimeModelListener) -> None: + """Add a listener to the model.""" + pass + + @abc.abstractmethod + def remove_listener(self, listener: RealtimeModelListener) -> None: + """Remove a listener from the model.""" + pass + + @abc.abstractmethod + async def send_event(self, event: RealtimeModelSendEvent) -> None: + """Send an event to the model.""" + pass + + @abc.abstractmethod + async def close(self) -> None: + """Close the session.""" + pass diff --git a/src/agents/realtime/model_events.py b/src/agents/realtime/model_events.py new file mode 100644 index 000000000..3a158ef4e --- /dev/null +++ b/src/agents/realtime/model_events.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Literal, Union + +from typing_extensions import TypeAlias + +from .items import RealtimeItem + +RealtimeConnectionStatus: TypeAlias = Literal["connecting", "connected", "disconnected"] + + +@dataclass +class RealtimeModelErrorEvent: + """Represents a transport‑layer error.""" + + error: Any + + type: Literal["error"] = "error" + + +@dataclass +class RealtimeModelToolCallEvent: + """Model attempted a tool/function call.""" + + name: str + call_id: str + arguments: str + + id: str | None = None + previous_item_id: str | None = None + + type: Literal["function_call"] = "function_call" + + +@dataclass +class RealtimeModelAudioEvent: + """Raw audio bytes emitted by the model.""" + + data: bytes + response_id: str + + type: Literal["audio"] = "audio" + + +@dataclass +class RealtimeModelAudioInterruptedEvent: + """Audio interrupted.""" + + type: Literal["audio_interrupted"] = "audio_interrupted" + + +@dataclass +class RealtimeModelAudioDoneEvent: + """Audio done.""" + + type: Literal["audio_done"] = "audio_done" + + +@dataclass +class RealtimeModelInputAudioTranscriptionCompletedEvent: + """Input audio transcription completed.""" + + item_id: str + transcript: str + + type: Literal["input_audio_transcription_completed"] = "input_audio_transcription_completed" + + +@dataclass +class RealtimeModelTranscriptDeltaEvent: + """Partial transcript update.""" + + item_id: str + delta: str + response_id: str + + type: Literal["transcript_delta"] = "transcript_delta" + + +@dataclass +class RealtimeModelItemUpdatedEvent: + """Item added to the history or updated.""" + + item: RealtimeItem + + type: Literal["item_updated"] = "item_updated" + + +@dataclass +class RealtimeModelItemDeletedEvent: + """Item deleted from the history.""" + + item_id: str + + type: Literal["item_deleted"] = "item_deleted" + + +@dataclass +class RealtimeModelConnectionStatusEvent: + """Connection status changed.""" + + status: RealtimeConnectionStatus + + type: Literal["connection_status"] = "connection_status" + + +@dataclass +class RealtimeModelTurnStartedEvent: + """Triggered when the model starts generating a response for a turn.""" + + type: Literal["turn_started"] = "turn_started" + + +@dataclass +class RealtimeModelTurnEndedEvent: + """Triggered when the model finishes generating a response for a turn.""" + + type: Literal["turn_ended"] = "turn_ended" + + +@dataclass +class RealtimeModelOtherEvent: + """Used as a catchall for vendor-specific events.""" + + data: Any + + type: Literal["other"] = "other" + + +@dataclass +class RealtimeModelExceptionEvent: + """Exception occurred during model operation.""" + + exception: Exception + context: str | None = None + + type: Literal["exception"] = "exception" + + +# TODO (rm) Add usage events + + +RealtimeModelEvent: TypeAlias = Union[ + RealtimeModelErrorEvent, + RealtimeModelToolCallEvent, + RealtimeModelAudioEvent, + RealtimeModelAudioInterruptedEvent, + RealtimeModelAudioDoneEvent, + RealtimeModelInputAudioTranscriptionCompletedEvent, + RealtimeModelTranscriptDeltaEvent, + RealtimeModelItemUpdatedEvent, + RealtimeModelItemDeletedEvent, + RealtimeModelConnectionStatusEvent, + RealtimeModelTurnStartedEvent, + RealtimeModelTurnEndedEvent, + RealtimeModelOtherEvent, + RealtimeModelExceptionEvent, +] diff --git a/src/agents/realtime/model_inputs.py b/src/agents/realtime/model_inputs.py new file mode 100644 index 000000000..df09e6697 --- /dev/null +++ b/src/agents/realtime/model_inputs.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Literal, Union + +from typing_extensions import NotRequired, TypeAlias, TypedDict + +from .config import RealtimeSessionModelSettings +from .model_events import RealtimeModelToolCallEvent + + +class RealtimeModelRawClientMessage(TypedDict): + """A raw message to be sent to the model.""" + + type: str # explicitly required + other_data: NotRequired[dict[str, Any]] + """Merged into the message body.""" + + +class RealtimeModelInputTextContent(TypedDict): + """A piece of text to be sent to the model.""" + + type: Literal["input_text"] + text: str + + +class RealtimeModelUserInputMessage(TypedDict): + """A message to be sent to the model.""" + + type: Literal["message"] + role: Literal["user"] + content: list[RealtimeModelInputTextContent] + + +RealtimeModelUserInput: TypeAlias = Union[str, RealtimeModelUserInputMessage] +"""A user input to be sent to the model.""" + + +# Model messages + + +@dataclass +class RealtimeModelSendRawMessage: + """Send a raw message to the model.""" + + message: RealtimeModelRawClientMessage + """The message to send.""" + + +@dataclass +class RealtimeModelSendUserInput: + """Send a user input to the model.""" + + user_input: RealtimeModelUserInput + """The user input to send.""" + + +@dataclass +class RealtimeModelSendAudio: + """Send audio to the model.""" + + audio: bytes + commit: bool = False + + +@dataclass +class RealtimeModelSendToolOutput: + """Send tool output to the model.""" + + tool_call: RealtimeModelToolCallEvent + """The tool call to send.""" + + output: str + """The output to send.""" + + start_response: bool + """Whether to start a response.""" + + +@dataclass +class RealtimeModelSendInterrupt: + """Send an interrupt to the model.""" + + +@dataclass +class RealtimeModelSendSessionUpdate: + """Send a session update to the model.""" + + session_settings: RealtimeSessionModelSettings + """The updated session settings to send.""" + + +RealtimeModelSendEvent: TypeAlias = Union[ + RealtimeModelSendRawMessage, + RealtimeModelSendUserInput, + RealtimeModelSendAudio, + RealtimeModelSendToolOutput, + RealtimeModelSendInterrupt, + RealtimeModelSendSessionUpdate, +] diff --git a/src/agents/realtime/openai_realtime.py b/src/agents/realtime/openai_realtime.py new file mode 100644 index 000000000..1c4a4de3c --- /dev/null +++ b/src/agents/realtime/openai_realtime.py @@ -0,0 +1,584 @@ +from __future__ import annotations + +import asyncio +import base64 +import inspect +import json +import os +from datetime import datetime +from typing import Any, Callable, Literal + +import pydantic +import websockets +from openai.types.beta.realtime.conversation_item import ConversationItem +from openai.types.beta.realtime.realtime_server_event import ( + RealtimeServerEvent as OpenAIRealtimeServerEvent, +) +from openai.types.beta.realtime.response_audio_delta_event import ResponseAudioDeltaEvent +from openai.types.beta.realtime.session_update_event import ( + Session as OpenAISessionObject, + SessionTool as OpenAISessionTool, +) +from pydantic import TypeAdapter +from typing_extensions import assert_never +from websockets.asyncio.client import ClientConnection + +from agents.tool import FunctionTool, Tool +from agents.util._types import MaybeAwaitable + +from ..exceptions import UserError +from ..logger import logger +from .config import ( + RealtimeModelTracingConfig, + RealtimeSessionModelSettings, +) +from .items import RealtimeMessageItem, RealtimeToolCallItem +from .model import ( + RealtimeModel, + RealtimeModelConfig, + RealtimeModelListener, +) +from .model_events import ( + RealtimeModelAudioDoneEvent, + RealtimeModelAudioEvent, + RealtimeModelAudioInterruptedEvent, + RealtimeModelErrorEvent, + RealtimeModelEvent, + RealtimeModelExceptionEvent, + RealtimeModelInputAudioTranscriptionCompletedEvent, + RealtimeModelItemDeletedEvent, + RealtimeModelItemUpdatedEvent, + RealtimeModelToolCallEvent, + RealtimeModelTranscriptDeltaEvent, + RealtimeModelTurnEndedEvent, + RealtimeModelTurnStartedEvent, +) +from .model_inputs import ( + RealtimeModelSendAudio, + RealtimeModelSendEvent, + RealtimeModelSendInterrupt, + RealtimeModelSendRawMessage, + RealtimeModelSendSessionUpdate, + RealtimeModelSendToolOutput, + RealtimeModelSendUserInput, +) + +DEFAULT_MODEL_SETTINGS: RealtimeSessionModelSettings = { + "voice": "ash", + "modalities": ["text", "audio"], + "input_audio_format": "pcm16", + "output_audio_format": "pcm16", + "input_audio_transcription": { + "model": "gpt-4o-mini-transcribe", + }, + "turn_detection": {"type": "semantic_vad"}, +} + + +async def get_api_key(key: str | Callable[[], MaybeAwaitable[str]] | None) -> str | None: + if isinstance(key, str): + return key + elif callable(key): + result = key() + if inspect.isawaitable(result): + return await result + return result + + return os.getenv("OPENAI_API_KEY") + + +class OpenAIRealtimeWebSocketModel(RealtimeModel): + """A model that uses OpenAI's WebSocket API.""" + + def __init__(self) -> None: + self.model = "gpt-4o-realtime-preview" # Default model + self._websocket: ClientConnection | None = None + self._websocket_task: asyncio.Task[None] | None = None + self._listeners: list[RealtimeModelListener] = [] + self._current_item_id: str | None = None + self._audio_start_time: datetime | None = None + self._audio_length_ms: float = 0.0 + self._ongoing_response: bool = False + self._current_audio_content_index: int | None = None + self._tracing_config: RealtimeModelTracingConfig | Literal["auto"] | None = None + + async def connect(self, options: RealtimeModelConfig) -> None: + """Establish a connection to the model and keep it alive.""" + assert self._websocket is None, "Already connected" + assert self._websocket_task is None, "Already connected" + + model_settings: RealtimeSessionModelSettings = options.get("initial_model_settings", {}) + + self.model = model_settings.get("model_name", self.model) + api_key = await get_api_key(options.get("api_key")) + + if "tracing" in model_settings: + self._tracing_config = model_settings["tracing"] + else: + self._tracing_config = "auto" + + if not api_key: + raise UserError("API key is required but was not provided.") + + url = options.get("url", f"wss://api.openai.com/v1/realtime?model={self.model}") + + headers = { + "Authorization": f"Bearer {api_key}", + "OpenAI-Beta": "realtime=v1", + } + self._websocket = await websockets.connect(url, additional_headers=headers) + self._websocket_task = asyncio.create_task(self._listen_for_messages()) + await self._update_session_config(model_settings) + + async def _send_tracing_config( + self, tracing_config: RealtimeModelTracingConfig | Literal["auto"] | None + ) -> None: + """Update tracing configuration via session.update event.""" + if tracing_config is not None: + await self._send_raw_message( + RealtimeModelSendRawMessage( + message={ + "type": "session.update", + "other_data": {"session": {"tracing": tracing_config}}, + } + ) + ) + + def add_listener(self, listener: RealtimeModelListener) -> None: + """Add a listener to the model.""" + if listener not in self._listeners: + self._listeners.append(listener) + + def remove_listener(self, listener: RealtimeModelListener) -> None: + """Remove a listener from the model.""" + if listener in self._listeners: + self._listeners.remove(listener) + + async def _emit_event(self, event: RealtimeModelEvent) -> None: + """Emit an event to the listeners.""" + for listener in self._listeners: + await listener.on_event(event) + + async def _listen_for_messages(self): + assert self._websocket is not None, "Not connected" + + try: + async for message in self._websocket: + try: + parsed = json.loads(message) + await self._handle_ws_event(parsed) + except json.JSONDecodeError as e: + await self._emit_event( + RealtimeModelExceptionEvent( + exception=e, context="Failed to parse WebSocket message as JSON" + ) + ) + except Exception as e: + await self._emit_event( + RealtimeModelExceptionEvent( + exception=e, context="Error handling WebSocket event" + ) + ) + + except websockets.exceptions.ConnectionClosedOK: + # Normal connection closure - no exception event needed + logger.info("WebSocket connection closed normally") + except websockets.exceptions.ConnectionClosed as e: + await self._emit_event( + RealtimeModelExceptionEvent( + exception=e, context="WebSocket connection closed unexpectedly" + ) + ) + except Exception as e: + await self._emit_event( + RealtimeModelExceptionEvent( + exception=e, context="WebSocket error in message listener" + ) + ) + + async def send_event(self, event: RealtimeModelSendEvent) -> None: + """Send an event to the model.""" + if isinstance(event, RealtimeModelSendRawMessage): + await self._send_raw_message(event) + elif isinstance(event, RealtimeModelSendUserInput): + await self._send_user_input(event) + elif isinstance(event, RealtimeModelSendAudio): + await self._send_audio(event) + elif isinstance(event, RealtimeModelSendToolOutput): + await self._send_tool_output(event) + elif isinstance(event, RealtimeModelSendInterrupt): + await self._send_interrupt(event) + elif isinstance(event, RealtimeModelSendSessionUpdate): + await self._send_session_update(event) + else: + assert_never(event) + raise ValueError(f"Unknown event type: {type(event)}") + + async def _send_raw_message(self, event: RealtimeModelSendRawMessage) -> None: + """Send a raw message to the model.""" + assert self._websocket is not None, "Not connected" + + converted_event = { + "type": event.message["type"], + } + + converted_event.update(event.message.get("other_data", {})) + + await self._websocket.send(json.dumps(converted_event)) + + async def _send_user_input(self, event: RealtimeModelSendUserInput) -> None: + message = ( + event.user_input + if isinstance(event.user_input, dict) + else { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": event.user_input}], + } + ) + other_data = { + "item": message, + } + + await self._send_raw_message( + RealtimeModelSendRawMessage( + message={"type": "conversation.item.create", "other_data": other_data} + ) + ) + await self._send_raw_message( + RealtimeModelSendRawMessage(message={"type": "response.create"}) + ) + + async def _send_audio(self, event: RealtimeModelSendAudio) -> None: + base64_audio = base64.b64encode(event.audio).decode("utf-8") + await self._send_raw_message( + RealtimeModelSendRawMessage( + message={ + "type": "input_audio_buffer.append", + "other_data": { + "audio": base64_audio, + }, + } + ) + ) + if event.commit: + await self._send_raw_message( + RealtimeModelSendRawMessage(message={"type": "input_audio_buffer.commit"}) + ) + + async def _send_tool_output(self, event: RealtimeModelSendToolOutput) -> None: + await self._send_raw_message( + RealtimeModelSendRawMessage( + message={ + "type": "conversation.item.create", + "other_data": { + "item": { + "type": "function_call_output", + "output": event.output, + "call_id": event.tool_call.id, + }, + }, + } + ) + ) + + tool_item = RealtimeToolCallItem( + item_id=event.tool_call.id or "", + previous_item_id=event.tool_call.previous_item_id, + type="function_call", + status="completed", + arguments=event.tool_call.arguments, + name=event.tool_call.name, + output=event.output, + ) + await self._emit_event(RealtimeModelItemUpdatedEvent(item=tool_item)) + + if event.start_response: + await self._send_raw_message( + RealtimeModelSendRawMessage(message={"type": "response.create"}) + ) + + async def _send_interrupt(self, event: RealtimeModelSendInterrupt) -> None: + if not self._current_item_id or not self._audio_start_time: + return + + await self._cancel_response() + + elapsed_time_ms = (datetime.now() - self._audio_start_time).total_seconds() * 1000 + if elapsed_time_ms > 0 and elapsed_time_ms < self._audio_length_ms: + await self._emit_event(RealtimeModelAudioInterruptedEvent()) + await self._send_raw_message( + RealtimeModelSendRawMessage( + message={ + "type": "conversation.item.truncate", + "other_data": { + "item_id": self._current_item_id, + "content_index": self._current_audio_content_index, + "audio_end_ms": elapsed_time_ms, + }, + } + ) + ) + + self._current_item_id = None + self._audio_start_time = None + self._audio_length_ms = 0.0 + self._current_audio_content_index = None + + async def _send_session_update(self, event: RealtimeModelSendSessionUpdate) -> None: + """Send a session update to the model.""" + await self._update_session_config(event.session_settings) + + async def _handle_audio_delta(self, parsed: ResponseAudioDeltaEvent) -> None: + """Handle audio delta events and update audio tracking state.""" + self._current_audio_content_index = parsed.content_index + self._current_item_id = parsed.item_id + if self._audio_start_time is None: + self._audio_start_time = datetime.now() + self._audio_length_ms = 0.0 + + audio_bytes = base64.b64decode(parsed.delta) + # Calculate audio length in ms using 24KHz pcm16le + self._audio_length_ms += self._calculate_audio_length_ms(audio_bytes) + await self._emit_event( + RealtimeModelAudioEvent(data=audio_bytes, response_id=parsed.response_id) + ) + + def _calculate_audio_length_ms(self, audio_bytes: bytes) -> float: + """Calculate audio length in milliseconds for 24KHz PCM16LE format.""" + return len(audio_bytes) / 24 / 2 + + async def _handle_output_item(self, item: ConversationItem) -> None: + """Handle response output item events (function calls and messages).""" + if item.type == "function_call" and item.status == "completed": + tool_call = RealtimeToolCallItem( + item_id=item.id or "", + previous_item_id=None, + type="function_call", + # We use the same item for tool call and output, so it will be completed by the + # output being added + status="in_progress", + arguments=item.arguments or "", + name=item.name or "", + output=None, + ) + await self._emit_event(RealtimeModelItemUpdatedEvent(item=tool_call)) + await self._emit_event( + RealtimeModelToolCallEvent( + call_id=item.id or "", + name=item.name or "", + arguments=item.arguments or "", + id=item.id or "", + ) + ) + elif item.type == "message": + # Handle message items from output_item events (no previous_item_id) + message_item: RealtimeMessageItem = TypeAdapter(RealtimeMessageItem).validate_python( + { + "item_id": item.id or "", + "type": item.type, + "role": item.role, + "content": item.content, + "status": "in_progress", + } + ) + await self._emit_event(RealtimeModelItemUpdatedEvent(item=message_item)) + + async def _handle_conversation_item( + self, item: ConversationItem, previous_item_id: str | None + ) -> None: + """Handle conversation item creation/retrieval events.""" + message_item = _ConversionHelper.conversation_item_to_realtime_message_item( + item, previous_item_id + ) + await self._emit_event(RealtimeModelItemUpdatedEvent(item=message_item)) + + async def close(self) -> None: + """Close the session.""" + if self._websocket: + await self._websocket.close() + self._websocket = None + if self._websocket_task: + self._websocket_task.cancel() + self._websocket_task = None + + async def _cancel_response(self) -> None: + if self._ongoing_response: + await self._send_raw_message( + RealtimeModelSendRawMessage(message={"type": "response.cancel"}) + ) + self._ongoing_response = False + + async def _handle_ws_event(self, event: dict[str, Any]): + try: + if "previous_item_id" in event and event["previous_item_id"] is None: + event["previous_item_id"] = "" # TODO (rm) remove + parsed: OpenAIRealtimeServerEvent = TypeAdapter( + OpenAIRealtimeServerEvent + ).validate_python(event) + except pydantic.ValidationError as e: + logger.error(f"Failed to validate server event: {event}", exc_info=True) + await self._emit_event( + RealtimeModelErrorEvent( + error=e, + ) + ) + return + except Exception as e: + event_type = event.get("type", "unknown") if isinstance(event, dict) else "unknown" + logger.error(f"Failed to validate server event: {event}", exc_info=True) + await self._emit_event( + RealtimeModelExceptionEvent( + exception=e, + context=f"Failed to validate server event: {event_type}", + ) + ) + return + + if parsed.type == "response.audio.delta": + await self._handle_audio_delta(parsed) + elif parsed.type == "response.audio.done": + await self._emit_event(RealtimeModelAudioDoneEvent()) + elif parsed.type == "input_audio_buffer.speech_started": + await self._send_interrupt(RealtimeModelSendInterrupt()) + elif parsed.type == "response.created": + self._ongoing_response = True + await self._emit_event(RealtimeModelTurnStartedEvent()) + elif parsed.type == "response.done": + self._ongoing_response = False + await self._emit_event(RealtimeModelTurnEndedEvent()) + elif parsed.type == "session.created": + await self._send_tracing_config(self._tracing_config) + elif parsed.type == "error": + await self._emit_event(RealtimeModelErrorEvent(error=parsed.error)) + elif parsed.type == "conversation.item.deleted": + await self._emit_event(RealtimeModelItemDeletedEvent(item_id=parsed.item_id)) + elif ( + parsed.type == "conversation.item.created" + or parsed.type == "conversation.item.retrieved" + ): + previous_item_id = ( + parsed.previous_item_id if parsed.type == "conversation.item.created" else None + ) + if parsed.item.type == "message": + await self._handle_conversation_item(parsed.item, previous_item_id) + elif ( + parsed.type == "conversation.item.input_audio_transcription.completed" + or parsed.type == "conversation.item.truncated" + ): + await self._send_raw_message( + RealtimeModelSendRawMessage( + message={ + "type": "conversation.item.retrieve", + "other_data": { + "item_id": self._current_item_id, + }, + } + ) + ) + if parsed.type == "conversation.item.input_audio_transcription.completed": + await self._emit_event( + RealtimeModelInputAudioTranscriptionCompletedEvent( + item_id=parsed.item_id, transcript=parsed.transcript + ) + ) + elif parsed.type == "response.audio_transcript.delta": + await self._emit_event( + RealtimeModelTranscriptDeltaEvent( + item_id=parsed.item_id, delta=parsed.delta, response_id=parsed.response_id + ) + ) + elif ( + parsed.type == "conversation.item.input_audio_transcription.delta" + or parsed.type == "response.text.delta" + or parsed.type == "response.function_call_arguments.delta" + ): + # No support for partials yet + pass + elif ( + parsed.type == "response.output_item.added" + or parsed.type == "response.output_item.done" + ): + await self._handle_output_item(parsed.item) + + async def _update_session_config(self, model_settings: RealtimeSessionModelSettings) -> None: + session_config = self._get_session_config(model_settings) + await self._send_raw_message( + RealtimeModelSendRawMessage( + message={ + "type": "session.update", + "other_data": { + "session": session_config.model_dump(exclude_unset=True, exclude_none=True) + }, + } + ) + ) + + def _get_session_config( + self, model_settings: RealtimeSessionModelSettings + ) -> OpenAISessionObject: + """Get the session config.""" + return OpenAISessionObject( + instructions=model_settings.get("instructions", None), + model=( + model_settings.get("model_name", self.model) # type: ignore + or DEFAULT_MODEL_SETTINGS.get("model_name") + ), + voice=model_settings.get("voice", DEFAULT_MODEL_SETTINGS.get("voice")), + modalities=model_settings.get("modalities", DEFAULT_MODEL_SETTINGS.get("modalities")), + input_audio_format=model_settings.get( + "input_audio_format", + DEFAULT_MODEL_SETTINGS.get("input_audio_format"), # type: ignore + ), + output_audio_format=model_settings.get( + "output_audio_format", + DEFAULT_MODEL_SETTINGS.get("output_audio_format"), # type: ignore + ), + input_audio_transcription=model_settings.get( + "input_audio_transcription", + DEFAULT_MODEL_SETTINGS.get("input_audio_transcription"), # type: ignore + ), + turn_detection=model_settings.get( + "turn_detection", + DEFAULT_MODEL_SETTINGS.get("turn_detection"), # type: ignore + ), + tool_choice=model_settings.get( + "tool_choice", + DEFAULT_MODEL_SETTINGS.get("tool_choice"), # type: ignore + ), + tools=self._tools_to_session_tools(model_settings.get("tools", [])), + ) + + def _tools_to_session_tools(self, tools: list[Tool]) -> list[OpenAISessionTool]: + converted_tools: list[OpenAISessionTool] = [] + for tool in tools: + if not isinstance(tool, FunctionTool): + raise UserError(f"Tool {tool.name} is unsupported. Must be a function tool.") + converted_tools.append( + OpenAISessionTool( + name=tool.name, + description=tool.description, + parameters=tool.params_json_schema, + type="function", + ) + ) + return converted_tools + + +class _ConversionHelper: + @classmethod + def conversation_item_to_realtime_message_item( + cls, item: ConversationItem, previous_item_id: str | None + ) -> RealtimeMessageItem: + return TypeAdapter(RealtimeMessageItem).validate_python( + { + "item_id": item.id or "", + "previous_item_id": previous_item_id, + "type": item.type, + "role": item.role, + "content": ( + [content.model_dump() for content in item.content] if item.content else [] + ), + "status": "in_progress", + }, + ) diff --git a/src/agents/realtime/runner.py b/src/agents/realtime/runner.py new file mode 100644 index 000000000..a7047a6f5 --- /dev/null +++ b/src/agents/realtime/runner.py @@ -0,0 +1,118 @@ +"""Minimal realtime session implementation for voice agents.""" + +from __future__ import annotations + +import asyncio + +from ..run_context import RunContextWrapper, TContext +from .agent import RealtimeAgent +from .config import ( + RealtimeRunConfig, + RealtimeSessionModelSettings, +) +from .model import ( + RealtimeModel, + RealtimeModelConfig, +) +from .openai_realtime import OpenAIRealtimeWebSocketModel +from .session import RealtimeSession + + +class RealtimeRunner: + """A `RealtimeRunner` is the equivalent of `Runner` for realtime agents. It automatically + handles multiple turns by maintaining a persistent connection with the underlying model + layer. + + The session manages the local history copy, executes tools, runs guardrails and facilitates + handoffs between agents. + + Since this code runs on your server, it uses WebSockets by default. You can optionally create + your own custom model layer by implementing the `RealtimeModel` interface. + """ + + def __init__( + self, + starting_agent: RealtimeAgent, + *, + model: RealtimeModel | None = None, + config: RealtimeRunConfig | None = None, + ) -> None: + """Initialize the realtime runner. + + Args: + starting_agent: The agent to start the session with. + context: The context to use for the session. + model: The model to use. If not provided, will use a default OpenAI realtime model. + config: Override parameters to use for the entire run. + """ + self._starting_agent = starting_agent + self._config = config + self._model = model or OpenAIRealtimeWebSocketModel() + + async def run( + self, *, context: TContext | None = None, model_config: RealtimeModelConfig | None = None + ) -> RealtimeSession: + """Start and returns a realtime session. + + Returns: + RealtimeSession: A session object that allows bidirectional communication with the + realtime model. + + Example: + ```python + runner = RealtimeRunner(agent) + async with await runner.run() as session: + await session.send_message("Hello") + async for event in session: + print(event) + ``` + """ + model_settings = await self._get_model_settings( + agent=self._starting_agent, + disable_tracing=self._config.get("tracing_disabled", False) if self._config else False, + initial_settings=model_config.get("initial_model_settings") if model_config else None, + overrides=self._config.get("model_settings") if self._config else None, + ) + + model_config = model_config.copy() if model_config else {} + model_config["initial_model_settings"] = model_settings + + # Create and return the connection + session = RealtimeSession( + model=self._model, + agent=self._starting_agent, + context=context, + model_config=model_config, + run_config=self._config, + ) + + return session + + async def _get_model_settings( + self, + agent: RealtimeAgent, + disable_tracing: bool, + context: TContext | None = None, + initial_settings: RealtimeSessionModelSettings | None = None, + overrides: RealtimeSessionModelSettings | None = None, + ) -> RealtimeSessionModelSettings: + context_wrapper = RunContextWrapper(context) + model_settings = initial_settings.copy() if initial_settings else {} + + instructions, tools = await asyncio.gather( + agent.get_system_prompt(context_wrapper), + agent.get_all_tools(context_wrapper), + ) + + if instructions is not None: + model_settings["instructions"] = instructions + if tools is not None: + model_settings["tools"] = tools + + if overrides: + model_settings.update(overrides) + + if disable_tracing: + model_settings["tracing"] = None + + return model_settings diff --git a/src/agents/realtime/session.py b/src/agents/realtime/session.py new file mode 100644 index 000000000..07791c8d8 --- /dev/null +++ b/src/agents/realtime/session.py @@ -0,0 +1,502 @@ +from __future__ import annotations + +import asyncio +from collections.abc import AsyncIterator +from typing import Any, cast + +from typing_extensions import assert_never + +from ..agent import Agent +from ..exceptions import ModelBehaviorError, UserError +from ..handoffs import Handoff +from ..run_context import RunContextWrapper, TContext +from ..tool import FunctionTool +from ..tool_context import ToolContext +from .agent import RealtimeAgent +from .config import RealtimeRunConfig, RealtimeSessionModelSettings, RealtimeUserInput +from .events import ( + RealtimeAgentEndEvent, + RealtimeAgentStartEvent, + RealtimeAudio, + RealtimeAudioEnd, + RealtimeAudioInterrupted, + RealtimeError, + RealtimeEventInfo, + RealtimeGuardrailTripped, + RealtimeHandoffEvent, + RealtimeHistoryAdded, + RealtimeHistoryUpdated, + RealtimeRawModelEvent, + RealtimeSessionEvent, + RealtimeToolEnd, + RealtimeToolStart, +) +from .items import InputAudio, InputText, RealtimeItem +from .model import RealtimeModel, RealtimeModelConfig, RealtimeModelListener +from .model_events import ( + RealtimeModelEvent, + RealtimeModelInputAudioTranscriptionCompletedEvent, + RealtimeModelToolCallEvent, +) +from .model_inputs import ( + RealtimeModelSendAudio, + RealtimeModelSendInterrupt, + RealtimeModelSendSessionUpdate, + RealtimeModelSendToolOutput, + RealtimeModelSendUserInput, +) + + +class RealtimeSession(RealtimeModelListener): + """A connection to a realtime model. It streams events from the model to you, and allows you to + send messages and audio to the model. + + Example: + ```python + runner = RealtimeRunner(agent) + async with await runner.run() as session: + # Send messages + await session.send_message("Hello") + await session.send_audio(audio_bytes) + + # Stream events + async for event in session: + if event.type == "audio": + # Handle audio event + pass + ``` + """ + + def __init__( + self, + model: RealtimeModel, + agent: RealtimeAgent, + context: TContext | None, + model_config: RealtimeModelConfig | None = None, + run_config: RealtimeRunConfig | None = None, + ) -> None: + """Initialize the session. + + Args: + model: The model to use. + agent: The current agent. + context: The context object. + model_config: Model configuration. + run_config: Runtime configuration including guardrails. + """ + self._model = model + self._current_agent = agent + self._context_wrapper = RunContextWrapper(context) + self._event_info = RealtimeEventInfo(context=self._context_wrapper) + self._history: list[RealtimeItem] = [] + self._model_config = model_config or {} + self._run_config = run_config or {} + self._event_queue: asyncio.Queue[RealtimeSessionEvent] = asyncio.Queue() + self._closed = False + self._stored_exception: Exception | None = None + + # Guardrails state tracking + self._interrupted_by_guardrail = False + self._item_transcripts: dict[str, str] = {} # item_id -> accumulated transcript + self._item_guardrail_run_counts: dict[str, int] = {} # item_id -> run count + self._debounce_text_length = self._run_config.get("guardrails_settings", {}).get( + "debounce_text_length", 100 + ) + + self._guardrail_tasks: set[asyncio.Task[Any]] = set() + + async def __aenter__(self) -> RealtimeSession: + """Start the session by connecting to the model. After this, you will be able to stream + events from the model and send messages and audio to the model. + """ + # Add ourselves as a listener + self._model.add_listener(self) + + # Connect to the model + await self._model.connect(self._model_config) + + # Emit initial history update + await self._put_event( + RealtimeHistoryUpdated( + history=self._history, + info=self._event_info, + ) + ) + + return self + + async def enter(self) -> RealtimeSession: + """Enter the async context manager. We strongly recommend using the async context manager + pattern instead of this method. If you use this, you need to manually call `close()` when + you are done. + """ + return await self.__aenter__() + + async def __aexit__(self, _exc_type: Any, _exc_val: Any, _exc_tb: Any) -> None: + """End the session.""" + await self.close() + + async def __aiter__(self) -> AsyncIterator[RealtimeSessionEvent]: + """Iterate over events from the session.""" + while not self._closed: + try: + # Check if there's a stored exception to raise + if self._stored_exception is not None: + # Clean up resources before raising + await self._cleanup() + raise self._stored_exception + + event = await self._event_queue.get() + yield event + except asyncio.CancelledError: + break + + async def close(self) -> None: + """Close the session.""" + await self._cleanup() + + async def send_message(self, message: RealtimeUserInput) -> None: + """Send a message to the model.""" + await self._model.send_event(RealtimeModelSendUserInput(user_input=message)) + + async def send_audio(self, audio: bytes, *, commit: bool = False) -> None: + """Send a raw audio chunk to the model.""" + await self._model.send_event(RealtimeModelSendAudio(audio=audio, commit=commit)) + + async def interrupt(self) -> None: + """Interrupt the model.""" + await self._model.send_event(RealtimeModelSendInterrupt()) + + async def on_event(self, event: RealtimeModelEvent) -> None: + await self._put_event(RealtimeRawModelEvent(data=event, info=self._event_info)) + + if event.type == "error": + await self._put_event(RealtimeError(info=self._event_info, error=event.error)) + elif event.type == "function_call": + await self._handle_tool_call(event) + elif event.type == "audio": + await self._put_event(RealtimeAudio(info=self._event_info, audio=event)) + elif event.type == "audio_interrupted": + await self._put_event(RealtimeAudioInterrupted(info=self._event_info)) + elif event.type == "audio_done": + await self._put_event(RealtimeAudioEnd(info=self._event_info)) + elif event.type == "input_audio_transcription_completed": + self._history = RealtimeSession._get_new_history(self._history, event) + await self._put_event( + RealtimeHistoryUpdated(info=self._event_info, history=self._history) + ) + elif event.type == "transcript_delta": + # Accumulate transcript text for guardrail debouncing per item_id + item_id = event.item_id + if item_id not in self._item_transcripts: + self._item_transcripts[item_id] = "" + self._item_guardrail_run_counts[item_id] = 0 + + self._item_transcripts[item_id] += event.delta + + # Check if we should run guardrails based on debounce threshold + current_length = len(self._item_transcripts[item_id]) + threshold = self._debounce_text_length + next_run_threshold = (self._item_guardrail_run_counts[item_id] + 1) * threshold + + if current_length >= next_run_threshold: + self._item_guardrail_run_counts[item_id] += 1 + self._enqueue_guardrail_task(self._item_transcripts[item_id]) + elif event.type == "item_updated": + is_new = not any(item.item_id == event.item.item_id for item in self._history) + self._history = self._get_new_history(self._history, event.item) + if is_new: + new_item = next( + item for item in self._history if item.item_id == event.item.item_id + ) + await self._put_event(RealtimeHistoryAdded(info=self._event_info, item=new_item)) + else: + await self._put_event( + RealtimeHistoryUpdated(info=self._event_info, history=self._history) + ) + elif event.type == "item_deleted": + deleted_id = event.item_id + self._history = [item for item in self._history if item.item_id != deleted_id] + await self._put_event( + RealtimeHistoryUpdated(info=self._event_info, history=self._history) + ) + elif event.type == "connection_status": + pass + elif event.type == "turn_started": + await self._put_event( + RealtimeAgentStartEvent( + agent=self._current_agent, + info=self._event_info, + ) + ) + elif event.type == "turn_ended": + # Clear guardrail state for next turn + self._item_transcripts.clear() + self._item_guardrail_run_counts.clear() + self._interrupted_by_guardrail = False + + await self._put_event( + RealtimeAgentEndEvent( + agent=self._current_agent, + info=self._event_info, + ) + ) + elif event.type == "exception": + # Store the exception to be raised in __aiter__ + self._stored_exception = event.exception + elif event.type == "other": + pass + else: + assert_never(event) + + async def _put_event(self, event: RealtimeSessionEvent) -> None: + """Put an event into the queue.""" + await self._event_queue.put(event) + + async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None: + """Handle a tool call event.""" + all_tools = await self._current_agent.get_all_tools(self._context_wrapper) + function_map = {tool.name: tool for tool in all_tools if isinstance(tool, FunctionTool)} + handoff_map = {tool.name: tool for tool in all_tools if isinstance(tool, Handoff)} + + if event.name in function_map: + await self._put_event( + RealtimeToolStart( + info=self._event_info, + tool=function_map[event.name], + agent=self._current_agent, + ) + ) + + func_tool = function_map[event.name] + tool_context = ToolContext( + context=self._context_wrapper.context, + usage=self._context_wrapper.usage, + tool_name=event.name, + tool_call_id=event.call_id, + ) + result = await func_tool.on_invoke_tool(tool_context, event.arguments) + + await self._model.send_event( + RealtimeModelSendToolOutput( + tool_call=event, output=str(result), start_response=True + ) + ) + + await self._put_event( + RealtimeToolEnd( + info=self._event_info, + tool=func_tool, + output=result, + agent=self._current_agent, + ) + ) + elif event.name in handoff_map: + handoff = handoff_map[event.name] + tool_context = ToolContext( + context=self._context_wrapper.context, + usage=self._context_wrapper.usage, + tool_name=event.name, + tool_call_id=event.call_id, + ) + + # Execute the handoff to get the new agent + result = await handoff.on_invoke_handoff(self._context_wrapper, event.arguments) + if not isinstance(result, RealtimeAgent): + raise UserError(f"Handoff {handoff.name} returned invalid result: {type(result)}") + + # Store previous agent for event + previous_agent = self._current_agent + + # Update current agent + self._current_agent = result + + # Get updated model settings from new agent + updated_settings = await self._get__updated_model_settings(self._current_agent) + + # Send handoff event + await self._put_event( + RealtimeHandoffEvent( + from_agent=previous_agent, + to_agent=self._current_agent, + info=self._event_info, + ) + ) + + # Send tool output to complete the handoff + await self._model.send_event( + RealtimeModelSendToolOutput( + tool_call=event, + output=f"Handed off to {self._current_agent.name}", + start_response=True, + ) + ) + + # Send session update to model + await self._model.send_event( + RealtimeModelSendSessionUpdate(session_settings=updated_settings) + ) + else: + raise ModelBehaviorError(f"Tool {event.name} not found") + + @classmethod + def _get_new_history( + cls, + old_history: list[RealtimeItem], + event: RealtimeModelInputAudioTranscriptionCompletedEvent | RealtimeItem, + ) -> list[RealtimeItem]: + # Merge transcript into placeholder input_audio message. + if isinstance(event, RealtimeModelInputAudioTranscriptionCompletedEvent): + new_history: list[RealtimeItem] = [] + for item in old_history: + if item.item_id == event.item_id and item.type == "message" and item.role == "user": + content: list[InputText | InputAudio] = [] + for entry in item.content: + if entry.type == "input_audio": + copied_entry = entry.model_copy(update={"transcript": event.transcript}) + content.append(copied_entry) + else: + content.append(entry) # type: ignore + new_history.append( + item.model_copy(update={"content": content, "status": "completed"}) + ) + else: + new_history.append(item) + return new_history + + # Otherwise it's just a new item + # TODO (rm) Add support for audio storage config + + # If the item already exists, update it + existing_index = next( + (i for i, item in enumerate(old_history) if item.item_id == event.item_id), None + ) + if existing_index is not None: + new_history = old_history.copy() + new_history[existing_index] = event + return new_history + # Otherwise, insert it after the previous_item_id if that is set + elif event.previous_item_id: + # Insert the new item after the previous item + previous_index = next( + (i for i, item in enumerate(old_history) if item.item_id == event.previous_item_id), + None, + ) + if previous_index is not None: + new_history = old_history.copy() + new_history.insert(previous_index + 1, event) + return new_history + + # Otherwise, add it to the end + return old_history + [event] + + async def _run_output_guardrails(self, text: str) -> bool: + """Run output guardrails on the given text. Returns True if any guardrail was triggered.""" + output_guardrails = self._run_config.get("output_guardrails", []) + if not output_guardrails or self._interrupted_by_guardrail: + return False + + triggered_results = [] + + for guardrail in output_guardrails: + try: + result = await guardrail.run( + # TODO (rm) Remove this cast, it's wrong + self._context_wrapper, + cast(Agent[Any], self._current_agent), + text, + ) + if result.output.tripwire_triggered: + triggered_results.append(result) + except Exception: + # Continue with other guardrails if one fails + continue + + if triggered_results: + # Mark as interrupted to prevent multiple interrupts + self._interrupted_by_guardrail = True + + # Emit guardrail tripped event + await self._put_event( + RealtimeGuardrailTripped( + guardrail_results=triggered_results, + message=text, + info=self._event_info, + ) + ) + + # Interrupt the model + await self._model.send_event(RealtimeModelSendInterrupt()) + + # Send guardrail triggered message + guardrail_names = [result.guardrail.get_name() for result in triggered_results] + await self._model.send_event( + RealtimeModelSendUserInput( + user_input=f"guardrail triggered: {', '.join(guardrail_names)}" + ) + ) + + return True + + return False + + def _enqueue_guardrail_task(self, text: str) -> None: + # Runs the guardrails in a separate task to avoid blocking the main loop + + task = asyncio.create_task(self._run_output_guardrails(text)) + self._guardrail_tasks.add(task) + + # Add callback to remove completed tasks and handle exceptions + task.add_done_callback(self._on_guardrail_task_done) + + def _on_guardrail_task_done(self, task: asyncio.Task[Any]) -> None: + """Handle completion of a guardrail task.""" + # Remove from tracking set + self._guardrail_tasks.discard(task) + + # Check for exceptions and propagate as events + if not task.cancelled(): + exception = task.exception() + if exception: + # Create an exception event instead of raising + asyncio.create_task( + self._put_event( + RealtimeError( + info=self._event_info, + error={"message": f"Guardrail task failed: {str(exception)}"}, + ) + ) + ) + + def _cleanup_guardrail_tasks(self) -> None: + for task in self._guardrail_tasks: + if not task.done(): + task.cancel() + self._guardrail_tasks.clear() + + async def _cleanup(self) -> None: + """Clean up all resources and mark session as closed.""" + # Cancel and cleanup guardrail tasks + self._cleanup_guardrail_tasks() + + # Remove ourselves as a listener + self._model.remove_listener(self) + + # Close the model connection + await self._model.close() + + # Mark as closed + self._closed = True + + async def _get__updated_model_settings( + self, new_agent: RealtimeAgent + ) -> RealtimeSessionModelSettings: + updated_settings: RealtimeSessionModelSettings = {} + instructions, tools = await asyncio.gather( + new_agent.get_system_prompt(self._context_wrapper), + new_agent.get_all_tools(self._context_wrapper), + ) + updated_settings["instructions"] = instructions or "" + updated_settings["tools"] = tools or [] + + return updated_settings diff --git a/src/agents/run.py b/src/agents/run.py index e5f9378ec..2dd9524bb 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -32,6 +32,7 @@ ModelBehaviorError, OutputGuardrailTripwireTriggered, RunErrorDetails, + UserError, ) from .guardrail import ( InputGuardrail, @@ -43,6 +44,7 @@ from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem from .lifecycle import RunHooks from .logger import logger +from .memory import Session from .model_settings import ModelSettings from .models.interface import Model, ModelProvider from .models.multi_provider import MultiProvider @@ -156,6 +158,9 @@ class RunOptions(TypedDict, Generic[TContext]): previous_response_id: NotRequired[str | None] """The ID of the previous response, if any.""" + session: NotRequired[Session | None] + """The session for the run.""" + class Runner: @classmethod @@ -169,6 +174,7 @@ async def run( hooks: RunHooks[TContext] | None = None, run_config: RunConfig | None = None, previous_response_id: str | None = None, + session: Session | None = None, ) -> RunResult: """Run a workflow starting at the given agent. The agent will run in a loop until a final output is generated. The loop runs like so: @@ -205,6 +211,7 @@ async def run( hooks=hooks, run_config=run_config, previous_response_id=previous_response_id, + session=session, ) @classmethod @@ -218,6 +225,7 @@ def run_sync( hooks: RunHooks[TContext] | None = None, run_config: RunConfig | None = None, previous_response_id: str | None = None, + session: Session | None = None, ) -> RunResult: """Run a workflow synchronously, starting at the given agent. Note that this just wraps the `run` method, so it will not work if there's already an event loop (e.g. inside an async @@ -257,6 +265,7 @@ def run_sync( hooks=hooks, run_config=run_config, previous_response_id=previous_response_id, + session=session, ) @classmethod @@ -269,6 +278,7 @@ def run_streamed( hooks: RunHooks[TContext] | None = None, run_config: RunConfig | None = None, previous_response_id: str | None = None, + session: Session | None = None, ) -> RunResultStreaming: """Run a workflow starting at the given agent in streaming mode. The returned result object contains a method you can use to stream semantic events as they are generated. @@ -305,6 +315,7 @@ def run_streamed( hooks=hooks, run_config=run_config, previous_response_id=previous_response_id, + session=session, ) @@ -325,11 +336,15 @@ async def run( hooks = kwargs.get("hooks") run_config = kwargs.get("run_config") previous_response_id = kwargs.get("previous_response_id") + session = kwargs.get("session") if hooks is None: hooks = RunHooks[Any]() if run_config is None: run_config = RunConfig() + # Prepare input with session if enabled + prepared_input = await self._prepare_input_with_session(input, session) + tool_use_tracker = AgentToolUseTracker() with TraceCtxManager( @@ -340,7 +355,7 @@ async def run( disabled=run_config.tracing_disabled, ): current_turn = 0 - original_input: str | list[TResponseInputItem] = copy.deepcopy(input) + original_input: str | list[TResponseInputItem] = copy.deepcopy(prepared_input) generated_items: list[RunItem] = [] model_responses: list[ModelResponse] = [] @@ -399,7 +414,7 @@ async def run( starting_agent, starting_agent.input_guardrails + (run_config.input_guardrails or []), - copy.deepcopy(input), + copy.deepcopy(prepared_input), context_wrapper, ), self._run_single_turn( @@ -441,7 +456,7 @@ async def run( turn_result.next_step.output, context_wrapper, ) - return RunResult( + result = RunResult( input=original_input, new_items=generated_items, raw_responses=model_responses, @@ -451,6 +466,11 @@ async def run( output_guardrail_results=output_guardrail_results, context_wrapper=context_wrapper, ) + + # Save the conversation to session if enabled + await self._save_result_to_session(session, input, result) + + return result elif isinstance(turn_result.next_step, NextStepHandoff): current_agent = cast(Agent[TContext], turn_result.next_step.new_agent) current_span.finish(reset_current=True) @@ -488,10 +508,13 @@ def run_sync( hooks = kwargs.get("hooks") run_config = kwargs.get("run_config") previous_response_id = kwargs.get("previous_response_id") + session = kwargs.get("session") + return asyncio.get_event_loop().run_until_complete( self.run( starting_agent, input, + session=session, context=context, max_turns=max_turns, hooks=hooks, @@ -511,6 +534,8 @@ def run_streamed( hooks = kwargs.get("hooks") run_config = kwargs.get("run_config") previous_response_id = kwargs.get("previous_response_id") + session = kwargs.get("session") + if hooks is None: hooks = RunHooks[Any]() if run_config is None: @@ -563,6 +588,7 @@ def run_streamed( context_wrapper=context_wrapper, run_config=run_config, previous_response_id=previous_response_id, + session=session, ) ) return streamed_result @@ -621,6 +647,7 @@ async def _start_streaming( context_wrapper: RunContextWrapper[TContext], run_config: RunConfig, previous_response_id: str | None, + session: Session | None, ): if streamed_result.trace: streamed_result.trace.start(mark_as_current=True) @@ -634,6 +661,12 @@ async def _start_streaming( streamed_result._event_queue.put_nowait(AgentUpdatedStreamEvent(new_agent=current_agent)) try: + # Prepare input with session if enabled + prepared_input = await AgentRunner._prepare_input_with_session(starting_input, session) + + # Update the streamed result with the prepared input + streamed_result.input = prepared_input + while True: if streamed_result.is_complete: break @@ -680,7 +713,7 @@ async def _start_streaming( cls._run_input_guardrails_with_queue( starting_agent, starting_agent.input_guardrails + (run_config.input_guardrails or []), - copy.deepcopy(ItemHelpers.input_to_new_input_list(starting_input)), + copy.deepcopy(ItemHelpers.input_to_new_input_list(prepared_input)), context_wrapper, streamed_result, current_span, @@ -734,6 +767,23 @@ async def _start_streaming( streamed_result.output_guardrail_results = output_guardrail_results streamed_result.final_output = turn_result.next_step.output streamed_result.is_complete = True + + # Save the conversation to session if enabled + # Create a temporary RunResult for session saving + temp_result = RunResult( + input=streamed_result.input, + new_items=streamed_result.new_items, + raw_responses=streamed_result.raw_responses, + final_output=streamed_result.final_output, + _last_agent=current_agent, + input_guardrail_results=streamed_result.input_guardrail_results, + output_guardrail_results=streamed_result.output_guardrail_results, + context_wrapper=context_wrapper, + ) + await AgentRunner._save_result_to_session( + session, starting_input, temp_result + ) + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) elif isinstance(turn_result.next_step, NextStepRunAgain): pass @@ -1136,5 +1186,57 @@ def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model: return run_config.model_provider.get_model(agent.model) + @classmethod + async def _prepare_input_with_session( + cls, + input: str | list[TResponseInputItem], + session: Session | None, + ) -> str | list[TResponseInputItem]: + """Prepare input by combining it with session history if enabled.""" + if session is None: + return input + + # Validate that we don't have both a session and a list input, as this creates + # ambiguity about whether the list should append to or replace existing session history + if isinstance(input, list): + raise UserError( + "Cannot provide both a session and a list of input items. " + "When using session memory, provide only a string input to append to the " + "conversation, or use session=None and provide a list to manually manage " + "conversation history." + ) + + # Get previous conversation history + history = await session.get_items() + + # Convert input to list format + new_input_list = ItemHelpers.input_to_new_input_list(input) + + # Combine history with new input + combined_input = history + new_input_list + + return combined_input + + @classmethod + async def _save_result_to_session( + cls, + session: Session | None, + original_input: str | list[TResponseInputItem], + result: RunResult, + ) -> None: + """Save the conversation turn to session.""" + if session is None: + return + + # Convert original input to list format if needed + input_list = ItemHelpers.input_to_new_input_list(original_input) + + # Convert new items to input format + new_items_as_input = [item.to_input_item() for item in result.new_items] + + # Save all items from this turn + items_to_save = input_list + new_items_as_input + await session.add_items(items_to_save) + DEFAULT_AGENT_RUNNER = AgentRunner() diff --git a/src/agents/tool.py b/src/agents/tool.py index 3aab47752..b967e899b 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -30,8 +30,7 @@ from .util._types import MaybeAwaitable if TYPE_CHECKING: - - from .agent import Agent + from .agent import Agent, AgentBase ToolParams = ParamSpec("ToolParams") @@ -88,7 +87,7 @@ class FunctionTool: """Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True, as it increases the likelihood of correct JSON input.""" - is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True + is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True """Whether the tool is enabled. Either a bool or a Callable that takes the run context and agent and returns whether the tool is enabled. You can use this to dynamically enable/disable a tool based on your context/state.""" @@ -201,7 +200,7 @@ class MCPToolApprovalFunctionResult(TypedDict): @dataclass class HostedMCPTool: """A tool that allows the LLM to use a remote MCP server. The LLM will automatically list and - call tools, without requiring a a round trip back to your code. + call tools, without requiring a round trip back to your code. If you want to run MCP servers locally via stdio, in a VPC or other non-publicly-accessible environment, or you just prefer to run tool calls locally, then you can instead use the servers in `agents.mcp` and pass `Agent(mcp_servers=[...])` to the agent.""" @@ -301,7 +300,7 @@ def function_tool( use_docstring_info: bool = True, failure_error_function: ToolErrorFunction | None = None, strict_mode: bool = True, - is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True, + is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True, ) -> FunctionTool: """Overload for usage as @function_tool (no parentheses).""" ... @@ -316,7 +315,7 @@ def function_tool( use_docstring_info: bool = True, failure_error_function: ToolErrorFunction | None = None, strict_mode: bool = True, - is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True, + is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True, ) -> Callable[[ToolFunction[...]], FunctionTool]: """Overload for usage as @function_tool(...).""" ... @@ -331,7 +330,7 @@ def function_tool( use_docstring_info: bool = True, failure_error_function: ToolErrorFunction | None = default_tool_error_function, strict_mode: bool = True, - is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True, + is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True, ) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]: """ Decorator to create a FunctionTool from a function. By default, we will: diff --git a/src/agents/tool_context.py b/src/agents/tool_context.py index c4329b8af..16845badd 100644 --- a/src/agents/tool_context.py +++ b/src/agents/tool_context.py @@ -1,5 +1,7 @@ from dataclasses import dataclass, field, fields -from typing import Any +from typing import Any, Optional + +from openai.types.responses import ResponseFunctionToolCall from .run_context import RunContextWrapper, TContext @@ -8,16 +10,26 @@ def _assert_must_pass_tool_call_id() -> str: raise ValueError("tool_call_id must be passed to ToolContext") +def _assert_must_pass_tool_name() -> str: + raise ValueError("tool_name must be passed to ToolContext") + + @dataclass class ToolContext(RunContextWrapper[TContext]): """The context of a tool call.""" + tool_name: str = field(default_factory=_assert_must_pass_tool_name) + """The name of the tool being invoked.""" + tool_call_id: str = field(default_factory=_assert_must_pass_tool_call_id) """The ID of the tool call.""" @classmethod def from_agent_context( - cls, context: RunContextWrapper[TContext], tool_call_id: str + cls, + context: RunContextWrapper[TContext], + tool_call_id: str, + tool_call: Optional[ResponseFunctionToolCall] = None, ) -> "ToolContext": """ Create a ToolContext from a RunContextWrapper. @@ -26,4 +38,5 @@ def from_agent_context( base_values: dict[str, Any] = { f.name: getattr(context, f.name) for f in fields(RunContextWrapper) if f.init } - return cls(tool_call_id=tool_call_id, **base_values) + tool_name = tool_call.name if tool_call is not None else _assert_must_pass_tool_name() + return cls(tool_name=tool_name, tool_call_id=tool_call_id, **base_values) diff --git a/src/agents/voice/models/openai_stt.py b/src/agents/voice/models/openai_stt.py index 1ae4ea147..b1f1b6da7 100644 --- a/src/agents/voice/models/openai_stt.py +++ b/src/agents/voice/models/openai_stt.py @@ -226,7 +226,7 @@ async def _handle_events(self) -> None: break event_type = event.get("type", "unknown") - if event_type == "conversation.item.input_audio_transcription.completed": + if event_type == "input_audio_transcription_completed": transcript = cast(str, event.get("transcript", "")) if len(transcript) > 0: self._end_turn(transcript) diff --git a/tests/mcp/test_tool_filtering.py b/tests/mcp/test_tool_filtering.py index c1ffff4b8..0127df806 100644 --- a/tests/mcp/test_tool_filtering.py +++ b/tests/mcp/test_tool_filtering.py @@ -3,6 +3,7 @@ external dependencies (processes, network connections) and ensure fast, reliable unit tests. FakeMCPServer delegates filtering logic to the real _MCPServerWithClientSession implementation. """ + import asyncio import pytest @@ -27,6 +28,7 @@ def create_test_context() -> RunContextWrapper: # === Static Tool Filtering Tests === + @pytest.mark.asyncio async def test_static_tool_filtering(): """Test all static tool filtering scenarios: allowed, blocked, both, none, etc.""" @@ -55,7 +57,7 @@ async def test_static_tool_filtering(): # Test both filters together (allowed first, then blocked) server.tool_filter = { "allowed_tool_names": ["tool1", "tool2", "tool3"], - "blocked_tool_names": ["tool3"] + "blocked_tool_names": ["tool3"], } tools = await server.list_tools(run_context, agent) assert len(tools) == 2 @@ -68,8 +70,7 @@ async def test_static_tool_filtering(): # Test helper function server.tool_filter = create_static_tool_filter( - allowed_tool_names=["tool1", "tool2"], - blocked_tool_names=["tool2"] + allowed_tool_names=["tool1", "tool2"], blocked_tool_names=["tool2"] ) tools = await server.list_tools(run_context, agent) assert len(tools) == 1 @@ -78,6 +79,7 @@ async def test_static_tool_filtering(): # === Dynamic Tool Filtering Core Tests === + @pytest.mark.asyncio async def test_dynamic_filter_sync_and_async(): """Test both synchronous and asynchronous dynamic filters""" @@ -181,6 +183,7 @@ def error_prone_filter(context: ToolFilterContext, tool: MCPTool) -> bool: # === Integration Tests === + @pytest.mark.asyncio async def test_agent_dynamic_filtering_integration(): """Test dynamic filtering integration with Agent methods""" diff --git a/tests/model_settings/test_serialization.py b/tests/model_settings/test_serialization.py index 94d11def3..6e8c65180 100644 --- a/tests/model_settings/test_serialization.py +++ b/tests/model_settings/test_serialization.py @@ -5,7 +5,7 @@ from pydantic import TypeAdapter from pydantic_core import to_json -from agents.model_settings import ModelSettings +from agents.model_settings import MCPToolChoice, ModelSettings def verify_serialization(model_settings: ModelSettings) -> None: @@ -29,6 +29,17 @@ def test_basic_serialization() -> None: verify_serialization(model_settings) +def test_mcp_tool_choice_serialization() -> None: + """Tests whether ModelSettings with MCPToolChoice can be serialized to a JSON string.""" + # First, lets create a ModelSettings instance + model_settings = ModelSettings( + temperature=0.5, + tool_choice=MCPToolChoice(server_label="mcp", name="mcp_tool"), + ) + # Now, lets serialize the ModelSettings instance to a JSON string + verify_serialization(model_settings) + + def test_all_fields_serialization() -> None: """Tests whether ModelSettings can be serialized to a JSON string.""" @@ -135,8 +146,8 @@ def test_extra_args_resolve_both_none() -> None: assert resolved.temperature == 0.5 assert resolved.top_p == 0.9 -def test_pydantic_serialization() -> None: +def test_pydantic_serialization() -> None: """Tests whether ModelSettings can be serialized with Pydantic.""" # First, lets create a ModelSettings instance diff --git a/tests/realtime/__init__.py b/tests/realtime/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/realtime/test_agent.py b/tests/realtime/test_agent.py new file mode 100644 index 000000000..aae8bc47c --- /dev/null +++ b/tests/realtime/test_agent.py @@ -0,0 +1,25 @@ +import pytest + +from agents import RunContextWrapper +from agents.realtime.agent import RealtimeAgent + + +def test_can_initialize_realtime_agent(): + agent = RealtimeAgent(name="test", instructions="Hello") + assert agent.name == "test" + assert agent.instructions == "Hello" + + +@pytest.mark.asyncio +async def test_dynamic_instructions(): + agent = RealtimeAgent(name="test") + assert agent.instructions is None + + def _instructions(ctx, agt) -> str: + assert ctx.context is None + assert agt == agent + return "Dynamic" + + agent = RealtimeAgent(name="test", instructions=_instructions) + instructions = await agent.get_system_prompt(RunContextWrapper(context=None)) + assert instructions == "Dynamic" diff --git a/tests/realtime/test_item_parsing.py b/tests/realtime/test_item_parsing.py new file mode 100644 index 000000000..ba128f7fd --- /dev/null +++ b/tests/realtime/test_item_parsing.py @@ -0,0 +1,80 @@ +from openai.types.beta.realtime.conversation_item import ConversationItem +from openai.types.beta.realtime.conversation_item_content import ConversationItemContent + +from agents.realtime.items import ( + AssistantMessageItem, + RealtimeMessageItem, + SystemMessageItem, + UserMessageItem, +) +from agents.realtime.openai_realtime import _ConversionHelper + + +def test_user_message_conversion() -> None: + item = ConversationItem( + id="123", + type="message", + role="user", + content=[ + ConversationItemContent( + id=None, audio=None, text=None, transcript=None, type="input_text" + ) + ], + ) + + converted: RealtimeMessageItem = _ConversionHelper.conversation_item_to_realtime_message_item( + item, None + ) + + assert isinstance(converted, UserMessageItem) + + item = ConversationItem( + id="123", + type="message", + role="user", + content=[ + ConversationItemContent( + id=None, audio=None, text=None, transcript=None, type="input_audio" + ) + ], + ) + + converted = _ConversionHelper.conversation_item_to_realtime_message_item(item, None) + + assert isinstance(converted, UserMessageItem) + + +def test_assistant_message_conversion() -> None: + item = ConversationItem( + id="123", + type="message", + role="assistant", + content=[ + ConversationItemContent(id=None, audio=None, text=None, transcript=None, type="text") + ], + ) + + converted: RealtimeMessageItem = _ConversionHelper.conversation_item_to_realtime_message_item( + item, None + ) + + assert isinstance(converted, AssistantMessageItem) + + +def test_system_message_conversion() -> None: + item = ConversationItem( + id="123", + type="message", + role="system", + content=[ + ConversationItemContent( + id=None, audio=None, text=None, transcript=None, type="input_text" + ) + ], + ) + + converted: RealtimeMessageItem = _ConversionHelper.conversation_item_to_realtime_message_item( + item, None + ) + + assert isinstance(converted, SystemMessageItem) diff --git a/tests/realtime/test_model_events.py b/tests/realtime/test_model_events.py new file mode 100644 index 000000000..b8696cc29 --- /dev/null +++ b/tests/realtime/test_model_events.py @@ -0,0 +1,12 @@ +from typing import get_args + +from agents.realtime.model_events import RealtimeModelEvent + + +def test_all_events_have_type() -> None: + """Test that all events have a type.""" + events = get_args(RealtimeModelEvent) + assert len(events) > 0 + for event in events: + assert event.type is not None + assert isinstance(event.type, str) diff --git a/tests/realtime/test_openai_realtime.py b/tests/realtime/test_openai_realtime.py new file mode 100644 index 000000000..9ecc433ca --- /dev/null +++ b/tests/realtime/test_openai_realtime.py @@ -0,0 +1,378 @@ +from datetime import datetime +from typing import Any +from unittest.mock import AsyncMock, Mock, patch + +import pytest +import websockets + +from agents.exceptions import UserError +from agents.realtime.model_events import ( + RealtimeModelAudioEvent, + RealtimeModelErrorEvent, + RealtimeModelToolCallEvent, +) +from agents.realtime.openai_realtime import OpenAIRealtimeWebSocketModel + + +class TestOpenAIRealtimeWebSocketModel: + """Test suite for OpenAIRealtimeWebSocketModel connection and event handling.""" + + @pytest.fixture + def model(self): + """Create a fresh model instance for each test.""" + return OpenAIRealtimeWebSocketModel() + + @pytest.fixture + def mock_websocket(self): + """Create a mock websocket connection.""" + mock_ws = AsyncMock() + mock_ws.send = AsyncMock() + mock_ws.close = AsyncMock() + return mock_ws + + +class TestConnectionLifecycle(TestOpenAIRealtimeWebSocketModel): + """Test connection establishment, configuration, and error handling.""" + + @pytest.mark.asyncio + async def test_connect_missing_api_key_raises_error(self, model): + """Test that missing API key raises UserError.""" + config: dict[str, Any] = {"initial_model_settings": {}} + + with patch.dict("os.environ", {}, clear=True): + with pytest.raises(UserError, match="API key is required"): + await model.connect(config) + + @pytest.mark.asyncio + async def test_connect_with_string_api_key(self, model, mock_websocket): + """Test successful connection with string API key.""" + config = { + "api_key": "test-api-key-123", + "initial_model_settings": {"model_name": "gpt-4o-realtime-preview"}, + } + + async def async_websocket(*args, **kwargs): + return mock_websocket + + with patch("websockets.connect", side_effect=async_websocket) as mock_connect: + with patch("asyncio.create_task") as mock_create_task: + # Mock create_task to return a mock task and properly handle the coroutine + mock_task = AsyncMock() + + def mock_create_task_func(coro): + # Properly close the coroutine to avoid RuntimeWarning + coro.close() + return mock_task + + mock_create_task.side_effect = mock_create_task_func + + await model.connect(config) + + # Verify WebSocket connection called with correct parameters + mock_connect.assert_called_once() + call_args = mock_connect.call_args + assert ( + call_args[0][0] + == "wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview" + ) + assert ( + call_args[1]["additional_headers"]["Authorization"] == "Bearer test-api-key-123" + ) + assert call_args[1]["additional_headers"]["OpenAI-Beta"] == "realtime=v1" + + # Verify task was created for message listening + mock_create_task.assert_called_once() + + # Verify internal state + assert model._websocket == mock_websocket + assert model._websocket_task is not None + assert model.model == "gpt-4o-realtime-preview" + + @pytest.mark.asyncio + async def test_connect_with_callable_api_key(self, model, mock_websocket): + """Test connection with callable API key provider.""" + + def get_api_key(): + return "callable-api-key" + + config = {"api_key": get_api_key} + + async def async_websocket(*args, **kwargs): + return mock_websocket + + with patch("websockets.connect", side_effect=async_websocket): + with patch("asyncio.create_task") as mock_create_task: + # Mock create_task to return a mock task and properly handle the coroutine + mock_task = AsyncMock() + + def mock_create_task_func(coro): + # Properly close the coroutine to avoid RuntimeWarning + coro.close() + return mock_task + + mock_create_task.side_effect = mock_create_task_func + + await model.connect(config) + # Should succeed with callable API key + assert model._websocket == mock_websocket + + @pytest.mark.asyncio + async def test_connect_with_async_callable_api_key(self, model, mock_websocket): + """Test connection with async callable API key provider.""" + + async def get_api_key(): + return "async-api-key" + + config = {"api_key": get_api_key} + + async def async_websocket(*args, **kwargs): + return mock_websocket + + with patch("websockets.connect", side_effect=async_websocket): + with patch("asyncio.create_task") as mock_create_task: + # Mock create_task to return a mock task and properly handle the coroutine + mock_task = AsyncMock() + + def mock_create_task_func(coro): + # Properly close the coroutine to avoid RuntimeWarning + coro.close() + return mock_task + + mock_create_task.side_effect = mock_create_task_func + + await model.connect(config) + assert model._websocket == mock_websocket + + @pytest.mark.asyncio + async def test_connect_websocket_failure_propagates(self, model): + """Test that WebSocket connection failures are properly propagated.""" + config = {"api_key": "test-key"} + + with patch( + "websockets.connect", side_effect=websockets.exceptions.ConnectionClosed(None, None) + ): + with pytest.raises(websockets.exceptions.ConnectionClosed): + await model.connect(config) + + # Verify internal state remains clean after failure + assert model._websocket is None + assert model._websocket_task is None + + @pytest.mark.asyncio + async def test_connect_already_connected_assertion(self, model, mock_websocket): + """Test that connecting when already connected raises assertion error.""" + model._websocket = mock_websocket # Simulate already connected + + config = {"api_key": "test-key"} + + with pytest.raises(AssertionError, match="Already connected"): + await model.connect(config) + + +class TestEventHandlingRobustness(TestOpenAIRealtimeWebSocketModel): + """Test event parsing, validation, and error handling robustness.""" + + @pytest.mark.asyncio + async def test_handle_malformed_json_logs_error_continues(self, model): + """Test that malformed JSON emits error event but doesn't crash.""" + mock_listener = AsyncMock() + model.add_listener(mock_listener) + + # Malformed JSON should not crash the handler + await model._handle_ws_event("invalid json {") + + # Should emit error event to listeners + mock_listener.on_event.assert_called_once() + error_event = mock_listener.on_event.call_args[0][0] + assert error_event.type == "error" + + @pytest.mark.asyncio + async def test_handle_invalid_event_schema_logs_error(self, model): + """Test that events with invalid schema emit error events but don't crash.""" + mock_listener = AsyncMock() + model.add_listener(mock_listener) + + invalid_event = {"type": "response.audio.delta"} # Missing required fields + + await model._handle_ws_event(invalid_event) + + # Should emit error event to listeners + mock_listener.on_event.assert_called_once() + error_event = mock_listener.on_event.call_args[0][0] + assert error_event.type == "error" + + @pytest.mark.asyncio + async def test_handle_unknown_event_type_ignored(self, model): + """Test that unknown event types are ignored gracefully.""" + mock_listener = AsyncMock() + model.add_listener(mock_listener) + + # Create a well-formed but unknown event type + unknown_event = {"type": "unknown.event.type", "data": "some data"} + + # Should not raise error or log anything for unknown types + with patch("agents.realtime.openai_realtime.logger"): + await model._handle_ws_event(unknown_event) + + # Should not log errors for unknown events (they're just ignored) + # This will depend on the TypeAdapter validation behavior + # If it fails validation, it should log; if it passes but type is + # unknown, it should be ignored + pass + + @pytest.mark.asyncio + async def test_handle_audio_delta_event_success(self, model): + """Test successful handling of audio delta events.""" + mock_listener = AsyncMock() + model.add_listener(mock_listener) + + # Valid audio delta event (minimal required fields for OpenAI spec) + audio_event = { + "type": "response.audio.delta", + "event_id": "event_123", + "response_id": "resp_123", + "item_id": "item_456", + "output_index": 0, + "content_index": 0, + "delta": "dGVzdCBhdWRpbw==", # base64 encoded "test audio" + } + + with patch("agents.realtime.openai_realtime.datetime") as mock_datetime: + mock_now = datetime(2024, 1, 1, 12, 0, 0) + mock_datetime.now.return_value = mock_now + + await model._handle_ws_event(audio_event) + + # Should emit audio event to listeners + mock_listener.on_event.assert_called_once() + emitted_event = mock_listener.on_event.call_args[0][0] + assert isinstance(emitted_event, RealtimeModelAudioEvent) + assert emitted_event.response_id == "resp_123" + assert emitted_event.data == b"test audio" # decoded from base64 + + # Should update internal audio tracking state + assert model._current_item_id == "item_456" + assert model._current_audio_content_index == 0 + assert model._audio_start_time == mock_now + + @pytest.mark.asyncio + async def test_handle_error_event_success(self, model): + """Test successful handling of error events.""" + mock_listener = AsyncMock() + model.add_listener(mock_listener) + + error_event = { + "type": "error", + "event_id": "event_456", + "error": { + "type": "invalid_request_error", + "code": "invalid_api_key", + "message": "Invalid API key provided", + }, + } + + await model._handle_ws_event(error_event) + + # Should emit error event to listeners + mock_listener.on_event.assert_called_once() + emitted_event = mock_listener.on_event.call_args[0][0] + assert isinstance(emitted_event, RealtimeModelErrorEvent) + + @pytest.mark.asyncio + async def test_handle_tool_call_event_success(self, model): + """Test successful handling of function call events.""" + mock_listener = AsyncMock() + model.add_listener(mock_listener) + + # Test response.output_item.done with function_call + tool_call_event = { + "type": "response.output_item.done", + "event_id": "event_789", + "response_id": "resp_789", + "output_index": 0, + "item": { + "id": "call_123", + "type": "function_call", + "status": "completed", + "name": "get_weather", + "arguments": '{"location": "San Francisco"}', + }, + } + + await model._handle_ws_event(tool_call_event) + + # Should emit both item updated and tool call events + assert mock_listener.on_event.call_count == 2 + + # First should be item updated, second should be tool call + calls = mock_listener.on_event.call_args_list + tool_call_emitted = calls[1][0][0] + assert isinstance(tool_call_emitted, RealtimeModelToolCallEvent) + assert tool_call_emitted.name == "get_weather" + assert tool_call_emitted.arguments == '{"location": "San Francisco"}' + assert tool_call_emitted.call_id == "call_123" + + @pytest.mark.asyncio + async def test_audio_timing_calculation_accuracy(self, model): + """Test that audio timing calculations are accurate for interruption handling.""" + mock_listener = AsyncMock() + model.add_listener(mock_listener) + + # Send multiple audio deltas to test cumulative timing + audio_deltas = [ + { + "type": "response.audio.delta", + "event_id": "event_1", + "response_id": "resp_1", + "item_id": "item_1", + "output_index": 0, + "content_index": 0, + "delta": "dGVzdA==", # 4 bytes -> "test" + }, + { + "type": "response.audio.delta", + "event_id": "event_2", + "response_id": "resp_1", + "item_id": "item_1", + "output_index": 0, + "content_index": 0, + "delta": "bW9yZQ==", # 4 bytes -> "more" + }, + ] + + for event in audio_deltas: + await model._handle_ws_event(event) + + # Should accumulate audio length: 8 bytes / 24 / 2 = ~0.167ms per byte + # Total: 8 bytes / 24 / 2 = 0.167ms + expected_length = 8 / 24 / 2 + assert abs(model._audio_length_ms - expected_length) < 0.001 + + def test_calculate_audio_length_ms_pure_function(self, model): + """Test the pure audio length calculation function.""" + # Test various audio buffer sizes + assert model._calculate_audio_length_ms(b"test") == 4 / 24 / 2 # 4 bytes + assert model._calculate_audio_length_ms(b"") == 0 # empty + assert model._calculate_audio_length_ms(b"a" * 48) == 1.0 # exactly 1ms worth + + @pytest.mark.asyncio + async def test_handle_audio_delta_state_management(self, model): + """Test that _handle_audio_delta properly manages internal state.""" + # Create mock parsed event + mock_parsed = Mock() + mock_parsed.content_index = 5 + mock_parsed.item_id = "test_item" + mock_parsed.delta = "dGVzdA==" # "test" in base64 + mock_parsed.response_id = "resp_123" + + with patch("agents.realtime.openai_realtime.datetime") as mock_datetime: + mock_now = datetime(2024, 1, 1, 12, 0, 0) + mock_datetime.now.return_value = mock_now + + await model._handle_audio_delta(mock_parsed) + + # Check state was updated correctly + assert model._current_audio_content_index == 5 + assert model._current_item_id == "test_item" + assert model._audio_start_time == mock_now + assert model._audio_length_ms == 4 / 24 / 2 # 4 bytes diff --git a/tests/realtime/test_runner.py b/tests/realtime/test_runner.py new file mode 100644 index 000000000..aabdff140 --- /dev/null +++ b/tests/realtime/test_runner.py @@ -0,0 +1,224 @@ +from unittest.mock import AsyncMock, Mock, patch + +import pytest +from inline_snapshot import snapshot + +from agents.realtime.agent import RealtimeAgent +from agents.realtime.config import RealtimeRunConfig, RealtimeSessionModelSettings +from agents.realtime.model import RealtimeModel, RealtimeModelConfig +from agents.realtime.runner import RealtimeRunner +from agents.realtime.session import RealtimeSession + + +class MockRealtimeModel(RealtimeModel): + async def connect(self, options=None): + pass + + def add_listener(self, listener): + pass + + def remove_listener(self, listener): + pass + + async def send_event(self, event): + pass + + async def send_message(self, message, other_event_data=None): + pass + + async def send_audio(self, audio, commit=False): + pass + + async def send_tool_output(self, tool_call, output, start_response=True): + pass + + async def interrupt(self): + pass + + async def close(self): + pass + + +@pytest.fixture +def mock_agent(): + agent = Mock(spec=RealtimeAgent) + agent.get_system_prompt = AsyncMock(return_value="Test instructions") + agent.get_all_tools = AsyncMock(return_value=[{"type": "function", "name": "test_tool"}]) + return agent + + +@pytest.fixture +def mock_model(): + return MockRealtimeModel() + + +@pytest.mark.asyncio +async def test_run_creates_session_with_no_settings(mock_agent, mock_model): + """Test that run() creates a session correctly if no settings are provided""" + runner = RealtimeRunner(mock_agent, model=mock_model) + + with patch("agents.realtime.runner.RealtimeSession") as mock_session_class: + mock_session = Mock(spec=RealtimeSession) + mock_session_class.return_value = mock_session + + session = await runner.run() + + # Verify session was created with correct parameters + mock_session_class.assert_called_once() + call_args = mock_session_class.call_args + + assert call_args[1]["model"] == mock_model + assert call_args[1]["agent"] == mock_agent + assert call_args[1]["context"] is None + + # Verify model_config contains expected settings from agent + model_config = call_args[1]["model_config"] + assert model_config == snapshot( + { + "initial_model_settings": { + "instructions": "Test instructions", + "tools": [{"type": "function", "name": "test_tool"}], + } + } + ) + + assert session == mock_session + + +@pytest.mark.asyncio +async def test_run_creates_session_with_settings_only_in_init(mock_agent, mock_model): + """Test that it creates a session with the right settings if they are provided only in init""" + config = RealtimeRunConfig( + model_settings=RealtimeSessionModelSettings(model_name="gpt-4o-realtime", voice="nova") + ) + runner = RealtimeRunner(mock_agent, model=mock_model, config=config) + + with patch("agents.realtime.runner.RealtimeSession") as mock_session_class: + mock_session = Mock(spec=RealtimeSession) + mock_session_class.return_value = mock_session + + _ = await runner.run() + + # Verify session was created with config overrides + call_args = mock_session_class.call_args + model_config = call_args[1]["model_config"] + + # Should have agent settings plus config overrides + assert model_config == snapshot( + { + "initial_model_settings": { + "instructions": "Test instructions", + "tools": [{"type": "function", "name": "test_tool"}], + "model_name": "gpt-4o-realtime", + "voice": "nova", + } + } + ) + + +@pytest.mark.asyncio +async def test_run_creates_session_with_settings_in_both_init_and_run_overrides( + mock_agent, mock_model +): + """Test settings in both init and run() - init should override run()""" + init_config = RealtimeRunConfig( + model_settings=RealtimeSessionModelSettings(model_name="gpt-4o-realtime", voice="nova") + ) + runner = RealtimeRunner(mock_agent, model=mock_model, config=init_config) + + run_model_config: RealtimeModelConfig = { + "initial_model_settings": RealtimeSessionModelSettings( + voice="alloy", input_audio_format="pcm16" + ) + } + + with patch("agents.realtime.runner.RealtimeSession") as mock_session_class: + mock_session = Mock(spec=RealtimeSession) + mock_session_class.return_value = mock_session + + _ = await runner.run(model_config=run_model_config) + + # Verify run() settings override init settings + call_args = mock_session_class.call_args + model_config = call_args[1]["model_config"] + + # Should have agent settings, then init config, then run config overrides + assert model_config == snapshot( + { + "initial_model_settings": { + "voice": "nova", + "input_audio_format": "pcm16", + "instructions": "Test instructions", + "tools": [{"type": "function", "name": "test_tool"}], + "model_name": "gpt-4o-realtime", + } + } + ) + + +@pytest.mark.asyncio +async def test_run_creates_session_with_settings_only_in_run(mock_agent, mock_model): + """Test settings provided only in run()""" + runner = RealtimeRunner(mock_agent, model=mock_model) + + run_model_config: RealtimeModelConfig = { + "initial_model_settings": RealtimeSessionModelSettings( + model_name="gpt-4o-realtime-preview", voice="shimmer", modalities=["text", "audio"] + ) + } + + with patch("agents.realtime.runner.RealtimeSession") as mock_session_class: + mock_session = Mock(spec=RealtimeSession) + mock_session_class.return_value = mock_session + + _ = await runner.run(model_config=run_model_config) + + # Verify run() settings are applied + call_args = mock_session_class.call_args + model_config = call_args[1]["model_config"] + + # Should have agent settings plus run() settings + assert model_config == snapshot( + { + "initial_model_settings": { + "model_name": "gpt-4o-realtime-preview", + "voice": "shimmer", + "modalities": ["text", "audio"], + "instructions": "Test instructions", + "tools": [{"type": "function", "name": "test_tool"}], + } + } + ) + + +@pytest.mark.asyncio +async def test_run_with_context_parameter(mock_agent, mock_model): + """Test that context parameter is passed through to session""" + runner = RealtimeRunner(mock_agent, model=mock_model) + test_context = {"user_id": "test123"} + + with patch("agents.realtime.runner.RealtimeSession") as mock_session_class: + mock_session = Mock(spec=RealtimeSession) + mock_session_class.return_value = mock_session + + await runner.run(context=test_context) + + call_args = mock_session_class.call_args + assert call_args[1]["context"] == test_context + + +@pytest.mark.asyncio +async def test_get_model_settings_with_none_values(mock_model): + """Test _get_model_settings handles None values from agent properly""" + agent = Mock(spec=RealtimeAgent) + agent.get_system_prompt = AsyncMock(return_value=None) + agent.get_all_tools = AsyncMock(return_value=None) + + runner = RealtimeRunner(agent, model=mock_model) + + with patch("agents.realtime.runner.RealtimeSession"): + await runner.run() + + # Should not crash and agent methods should be called + agent.get_system_prompt.assert_called_once() + agent.get_all_tools.assert_called_once() diff --git a/tests/realtime/test_session.py b/tests/realtime/test_session.py new file mode 100644 index 000000000..4cc0dae6b --- /dev/null +++ b/tests/realtime/test_session.py @@ -0,0 +1,1210 @@ +from typing import cast +from unittest.mock import AsyncMock, Mock + +import pytest + +from agents.guardrail import GuardrailFunctionOutput, OutputGuardrail +from agents.handoffs import Handoff +from agents.realtime.agent import RealtimeAgent +from agents.realtime.config import RealtimeRunConfig +from agents.realtime.events import ( + RealtimeAgentEndEvent, + RealtimeAgentStartEvent, + RealtimeAudio, + RealtimeAudioEnd, + RealtimeAudioInterrupted, + RealtimeError, + RealtimeGuardrailTripped, + RealtimeHistoryAdded, + RealtimeHistoryUpdated, + RealtimeRawModelEvent, + RealtimeToolEnd, + RealtimeToolStart, +) +from agents.realtime.items import ( + AssistantMessageItem, + AssistantText, + InputAudio, + InputText, + RealtimeItem, + UserMessageItem, +) +from agents.realtime.model import RealtimeModel +from agents.realtime.model_events import ( + RealtimeModelAudioDoneEvent, + RealtimeModelAudioEvent, + RealtimeModelAudioInterruptedEvent, + RealtimeModelConnectionStatusEvent, + RealtimeModelErrorEvent, + RealtimeModelInputAudioTranscriptionCompletedEvent, + RealtimeModelItemDeletedEvent, + RealtimeModelItemUpdatedEvent, + RealtimeModelOtherEvent, + RealtimeModelToolCallEvent, + RealtimeModelTranscriptDeltaEvent, + RealtimeModelTurnEndedEvent, + RealtimeModelTurnStartedEvent, +) +from agents.realtime.session import RealtimeSession +from agents.tool import FunctionTool +from agents.tool_context import ToolContext + + +class MockRealtimeModel(RealtimeModel): + def __init__(self): + super().__init__() + self.listeners = [] + self.connect_called = False + self.close_called = False + self.sent_events = [] + # Legacy tracking for tests that haven't been updated yet + self.sent_messages = [] + self.sent_audio = [] + self.sent_tool_outputs = [] + self.interrupts_called = 0 + + async def connect(self, options=None): + self.connect_called = True + + def add_listener(self, listener): + self.listeners.append(listener) + + def remove_listener(self, listener): + if listener in self.listeners: + self.listeners.remove(listener) + + async def send_event(self, event): + from agents.realtime.model_inputs import ( + RealtimeModelSendAudio, + RealtimeModelSendInterrupt, + RealtimeModelSendToolOutput, + RealtimeModelSendUserInput, + ) + + self.sent_events.append(event) + + # Update legacy tracking for compatibility + if isinstance(event, RealtimeModelSendUserInput): + self.sent_messages.append(event.user_input) + elif isinstance(event, RealtimeModelSendAudio): + self.sent_audio.append((event.audio, event.commit)) + elif isinstance(event, RealtimeModelSendToolOutput): + self.sent_tool_outputs.append((event.tool_call, event.output, event.start_response)) + elif isinstance(event, RealtimeModelSendInterrupt): + self.interrupts_called += 1 + + async def close(self): + self.close_called = True + + +@pytest.fixture +def mock_agent(): + agent = Mock(spec=RealtimeAgent) + agent.get_all_tools = AsyncMock(return_value=[]) + return agent + + +@pytest.fixture +def mock_model(): + return MockRealtimeModel() + + +@pytest.fixture +def mock_function_tool(): + tool = Mock(spec=FunctionTool) + tool.name = "test_function" + tool.on_invoke_tool = AsyncMock(return_value="function_result") + return tool + + +@pytest.fixture +def mock_handoff(): + handoff = Mock(spec=Handoff) + handoff.name = "test_handoff" + return handoff + + +class TestEventHandling: + """Test suite for event handling and transformation in RealtimeSession.on_event""" + + @pytest.mark.asyncio + async def test_error_event_transformation(self, mock_model, mock_agent): + """Test that error events are properly transformed and queued""" + session = RealtimeSession(mock_model, mock_agent, None) + + error_event = RealtimeModelErrorEvent(error="Test error") + + await session.on_event(error_event) + + # Check that events were queued + assert session._event_queue.qsize() == 2 + + # First event should be raw model event + raw_event = await session._event_queue.get() + assert isinstance(raw_event, RealtimeRawModelEvent) + assert raw_event.data == error_event + + # Second event should be transformed error event + error_session_event = await session._event_queue.get() + assert isinstance(error_session_event, RealtimeError) + assert error_session_event.error == "Test error" + + @pytest.mark.asyncio + async def test_audio_events_transformation(self, mock_model, mock_agent): + """Test that audio-related events are properly transformed""" + session = RealtimeSession(mock_model, mock_agent, None) + + # Test audio event + audio_event = RealtimeModelAudioEvent(data=b"audio_data", response_id="resp_1") + await session.on_event(audio_event) + + # Test audio interrupted event + interrupted_event = RealtimeModelAudioInterruptedEvent() + await session.on_event(interrupted_event) + + # Test audio done event + done_event = RealtimeModelAudioDoneEvent() + await session.on_event(done_event) + + # Should have 6 events total (2 per event: raw + transformed) + assert session._event_queue.qsize() == 6 + + # Check audio event transformation + await session._event_queue.get() # raw event + audio_session_event = await session._event_queue.get() + assert isinstance(audio_session_event, RealtimeAudio) + assert audio_session_event.audio == audio_event + + # Check audio interrupted transformation + await session._event_queue.get() # raw event + interrupted_session_event = await session._event_queue.get() + assert isinstance(interrupted_session_event, RealtimeAudioInterrupted) + + # Check audio done transformation + await session._event_queue.get() # raw event + done_session_event = await session._event_queue.get() + assert isinstance(done_session_event, RealtimeAudioEnd) + + @pytest.mark.asyncio + async def test_turn_events_transformation(self, mock_model, mock_agent): + """Test that turn start/end events are properly transformed""" + session = RealtimeSession(mock_model, mock_agent, None) + + # Test turn started event + turn_started = RealtimeModelTurnStartedEvent() + await session.on_event(turn_started) + + # Test turn ended event + turn_ended = RealtimeModelTurnEndedEvent() + await session.on_event(turn_ended) + + # Should have 4 events total (2 per event: raw + transformed) + assert session._event_queue.qsize() == 4 + + # Check turn started transformation + await session._event_queue.get() # raw event + start_session_event = await session._event_queue.get() + assert isinstance(start_session_event, RealtimeAgentStartEvent) + assert start_session_event.agent == mock_agent + + # Check turn ended transformation + await session._event_queue.get() # raw event + end_session_event = await session._event_queue.get() + assert isinstance(end_session_event, RealtimeAgentEndEvent) + assert end_session_event.agent == mock_agent + + @pytest.mark.asyncio + async def test_transcription_completed_event_updates_history(self, mock_model, mock_agent): + """Test that transcription completed events update history and emit events""" + session = RealtimeSession(mock_model, mock_agent, None) + + # Set up initial history with an audio message + initial_item = UserMessageItem( + item_id="item_1", role="user", content=[InputAudio(transcript=None)] + ) + session._history = [initial_item] + + # Create transcription completed event + transcription_event = RealtimeModelInputAudioTranscriptionCompletedEvent( + item_id="item_1", transcript="Hello world" + ) + + await session.on_event(transcription_event) + + # Check that history was updated + assert len(session._history) == 1 + updated_item = session._history[0] + assert updated_item.content[0].transcript == "Hello world" # type: ignore + assert updated_item.status == "completed" # type: ignore + + # Should have 2 events: raw + history updated + assert session._event_queue.qsize() == 2 + + await session._event_queue.get() # raw event + history_event = await session._event_queue.get() + assert isinstance(history_event, RealtimeHistoryUpdated) + assert len(history_event.history) == 1 + + @pytest.mark.asyncio + async def test_item_updated_event_adds_new_item(self, mock_model, mock_agent): + """Test that item_updated events add new items to history""" + session = RealtimeSession(mock_model, mock_agent, None) + + new_item = AssistantMessageItem( + item_id="new_item", role="assistant", content=[AssistantText(text="Hello")] + ) + + item_updated_event = RealtimeModelItemUpdatedEvent(item=new_item) + + await session.on_event(item_updated_event) + + # Check that item was added to history + assert len(session._history) == 1 + assert session._history[0] == new_item + + # Should have 2 events: raw + history added + assert session._event_queue.qsize() == 2 + + await session._event_queue.get() # raw event + history_event = await session._event_queue.get() + assert isinstance(history_event, RealtimeHistoryAdded) + assert history_event.item == new_item + + @pytest.mark.asyncio + async def test_item_updated_event_updates_existing_item(self, mock_model, mock_agent): + """Test that item_updated events update existing items in history""" + session = RealtimeSession(mock_model, mock_agent, None) + + # Set up initial history + initial_item = AssistantMessageItem( + item_id="existing_item", role="assistant", content=[AssistantText(text="Initial")] + ) + session._history = [initial_item] + + # Create updated version + updated_item = AssistantMessageItem( + item_id="existing_item", role="assistant", content=[AssistantText(text="Updated")] + ) + + item_updated_event = RealtimeModelItemUpdatedEvent(item=updated_item) + + await session.on_event(item_updated_event) + + # Check that item was updated + assert len(session._history) == 1 + updated_item = cast(AssistantMessageItem, session._history[0]) + assert updated_item.content[0].text == "Updated" + + # Should have 2 events: raw + history updated (not added) + assert session._event_queue.qsize() == 2 + + await session._event_queue.get() # raw event + history_event = await session._event_queue.get() + assert isinstance(history_event, RealtimeHistoryUpdated) + + @pytest.mark.asyncio + async def test_item_deleted_event_removes_item(self, mock_model, mock_agent): + """Test that item_deleted events remove items from history""" + session = RealtimeSession(mock_model, mock_agent, None) + + # Set up initial history with multiple items + item1 = AssistantMessageItem( + item_id="item_1", role="assistant", content=[AssistantText(text="First")] + ) + item2 = AssistantMessageItem( + item_id="item_2", role="assistant", content=[AssistantText(text="Second")] + ) + session._history = [item1, item2] + + # Delete first item + delete_event = RealtimeModelItemDeletedEvent(item_id="item_1") + + await session.on_event(delete_event) + + # Check that item was removed + assert len(session._history) == 1 + assert session._history[0].item_id == "item_2" + + # Should have 2 events: raw + history updated + assert session._event_queue.qsize() == 2 + + await session._event_queue.get() # raw event + history_event = await session._event_queue.get() + assert isinstance(history_event, RealtimeHistoryUpdated) + assert len(history_event.history) == 1 + + @pytest.mark.asyncio + async def test_ignored_events_only_generate_raw_events(self, mock_model, mock_agent): + """Test that ignored events (transcript_delta, connection_status, other) only generate raw + events""" + session = RealtimeSession(mock_model, mock_agent, None) + + # Test transcript delta (should be ignored per TODO comment) + transcript_event = RealtimeModelTranscriptDeltaEvent( + item_id="item_1", delta="hello", response_id="resp_1" + ) + await session.on_event(transcript_event) + + # Test connection status (should be ignored) + connection_event = RealtimeModelConnectionStatusEvent(status="connected") + await session.on_event(connection_event) + + # Test other event (should be ignored) + other_event = RealtimeModelOtherEvent(data={"custom": "data"}) + await session.on_event(other_event) + + # Should only have 3 raw events (no transformed events) + assert session._event_queue.qsize() == 3 + + for _ in range(3): + event = await session._event_queue.get() + assert isinstance(event, RealtimeRawModelEvent) + + @pytest.mark.asyncio + async def test_function_call_event_triggers_tool_handling(self, mock_model, mock_agent): + """Test that function_call events trigger tool call handling""" + session = RealtimeSession(mock_model, mock_agent, None) + + # Create function call event + function_call_event = RealtimeModelToolCallEvent( + name="test_function", call_id="call_123", arguments='{"param": "value"}' + ) + + # We'll test the detailed tool handling in a separate test class + # Here we just verify that it gets to the handler + with pytest.MonkeyPatch().context() as m: + handle_tool_call_mock = AsyncMock() + m.setattr(session, "_handle_tool_call", handle_tool_call_mock) + + await session.on_event(function_call_event) + + # Should have called the tool handler + handle_tool_call_mock.assert_called_once_with(function_call_event) + + # Should still have raw event + assert session._event_queue.qsize() == 1 + raw_event = await session._event_queue.get() + assert isinstance(raw_event, RealtimeRawModelEvent) + assert raw_event.data == function_call_event + + +class TestHistoryManagement: + """Test suite for history management and audio transcription in + RealtimeSession._get_new_history""" + + def test_merge_transcript_into_existing_audio_message(self): + """Test merging audio transcript into existing placeholder input_audio message""" + # Create initial history with audio message without transcript + initial_item = UserMessageItem( + item_id="item_1", + role="user", + content=[ + InputText(text="Before audio"), + InputAudio(transcript=None, audio="audio_data"), + InputText(text="After audio"), + ], + ) + old_history = [initial_item] + + # Create transcription completed event + transcription_event = RealtimeModelInputAudioTranscriptionCompletedEvent( + item_id="item_1", transcript="Hello world" + ) + + # Apply the history update + new_history = RealtimeSession._get_new_history( + cast(list[RealtimeItem], old_history), transcription_event + ) + + # Verify the transcript was merged + assert len(new_history) == 1 + updated_item = cast(UserMessageItem, new_history[0]) + assert updated_item.item_id == "item_1" + assert hasattr(updated_item, "status") and updated_item.status == "completed" + assert len(updated_item.content) == 3 + + # Check that audio content got transcript but other content unchanged + assert cast(InputText, updated_item.content[0]).text == "Before audio" + assert cast(InputAudio, updated_item.content[1]).transcript == "Hello world" + # Should preserve audio data + assert cast(InputAudio, updated_item.content[1]).audio == "audio_data" + assert cast(InputText, updated_item.content[2]).text == "After audio" + + def test_merge_transcript_preserves_other_items(self): + """Test that merging transcript preserves other items in history""" + # Create history with multiple items + item1 = UserMessageItem( + item_id="item_1", role="user", content=[InputText(text="First message")] + ) + item2 = UserMessageItem( + item_id="item_2", role="user", content=[InputAudio(transcript=None)] + ) + item3 = AssistantMessageItem( + item_id="item_3", role="assistant", content=[AssistantText(text="Third message")] + ) + old_history = [item1, item2, item3] + + # Create transcription event for item_2 + transcription_event = RealtimeModelInputAudioTranscriptionCompletedEvent( + item_id="item_2", transcript="Transcribed audio" + ) + + new_history = RealtimeSession._get_new_history( + cast(list[RealtimeItem], old_history), transcription_event + ) + + # Should have same number of items + assert len(new_history) == 3 + + # First and third items should be unchanged + assert new_history[0] == item1 + assert new_history[2] == item3 + + # Second item should have transcript + updated_item2 = cast(UserMessageItem, new_history[1]) + assert updated_item2.item_id == "item_2" + assert cast(InputAudio, updated_item2.content[0]).transcript == "Transcribed audio" + assert hasattr(updated_item2, "status") and updated_item2.status == "completed" + + def test_merge_transcript_only_affects_matching_audio_content(self): + """Test that transcript merge only affects audio content, not text content""" + # Create item with mixed content including multiple audio items + item = UserMessageItem( + item_id="item_1", + role="user", + content=[ + InputText(text="Text content"), + InputAudio(transcript=None, audio="audio1"), + InputAudio(transcript="existing", audio="audio2"), + InputText(text="More text"), + ], + ) + old_history = [item] + + transcription_event = RealtimeModelInputAudioTranscriptionCompletedEvent( + item_id="item_1", transcript="New transcript" + ) + + new_history = RealtimeSession._get_new_history( + cast(list[RealtimeItem], old_history), transcription_event + ) + + updated_item = cast(UserMessageItem, new_history[0]) + + # Text content should be unchanged + assert cast(InputText, updated_item.content[0]).text == "Text content" + assert cast(InputText, updated_item.content[3]).text == "More text" + + # All audio content should have the new transcript (current implementation overwrites all) + assert cast(InputAudio, updated_item.content[1]).transcript == "New transcript" + assert ( + cast(InputAudio, updated_item.content[2]).transcript == "New transcript" + ) # Implementation overwrites existing + + def test_update_existing_item_by_id(self): + """Test updating an existing item by item_id""" + # Create initial history + original_item = AssistantMessageItem( + item_id="item_1", role="assistant", content=[AssistantText(text="Original")] + ) + old_history = [original_item] + + # Create updated version of same item + updated_item = AssistantMessageItem( + item_id="item_1", role="assistant", content=[AssistantText(text="Updated")] + ) + + new_history = RealtimeSession._get_new_history( + cast(list[RealtimeItem], old_history), updated_item + ) + + # Should have same number of items + assert len(new_history) == 1 + + # Item should be updated + result_item = cast(AssistantMessageItem, new_history[0]) + assert result_item.item_id == "item_1" + assert result_item.content[0].text == "Updated" + + def test_update_existing_item_preserves_order(self): + """Test that updating existing item preserves its position in history""" + # Create history with multiple items + item1 = AssistantMessageItem( + item_id="item_1", role="assistant", content=[AssistantText(text="First")] + ) + item2 = AssistantMessageItem( + item_id="item_2", role="assistant", content=[AssistantText(text="Second")] + ) + item3 = AssistantMessageItem( + item_id="item_3", role="assistant", content=[AssistantText(text="Third")] + ) + old_history = [item1, item2, item3] + + # Update middle item + updated_item2 = AssistantMessageItem( + item_id="item_2", role="assistant", content=[AssistantText(text="Updated Second")] + ) + + new_history = RealtimeSession._get_new_history( + cast(list[RealtimeItem], old_history), updated_item2 + ) + + # Should have same number of items in same order + assert len(new_history) == 3 + assert new_history[0].item_id == "item_1" + assert new_history[1].item_id == "item_2" + assert new_history[2].item_id == "item_3" + + # Middle item should be updated + updated_result = cast(AssistantMessageItem, new_history[1]) + assert updated_result.content[0].text == "Updated Second" + + # Other items should be unchanged + item1_result = cast(AssistantMessageItem, new_history[0]) + item3_result = cast(AssistantMessageItem, new_history[2]) + assert item1_result.content[0].text == "First" + assert item3_result.content[0].text == "Third" + + def test_insert_new_item_after_previous_item(self): + """Test inserting new item after specified previous_item_id""" + # Create initial history + item1 = AssistantMessageItem( + item_id="item_1", role="assistant", content=[AssistantText(text="First")] + ) + item3 = AssistantMessageItem( + item_id="item_3", role="assistant", content=[AssistantText(text="Third")] + ) + old_history = [item1, item3] + + # Create new item to insert between them + new_item = AssistantMessageItem( + item_id="item_2", + previous_item_id="item_1", + role="assistant", + content=[AssistantText(text="Second")], + ) + + new_history = RealtimeSession._get_new_history( + cast(list[RealtimeItem], old_history), new_item + ) + + # Should have one more item + assert len(new_history) == 3 + + # Items should be in correct order + assert new_history[0].item_id == "item_1" + assert new_history[1].item_id == "item_2" + assert new_history[2].item_id == "item_3" + + # Content should be correct + item2_result = cast(AssistantMessageItem, new_history[1]) + assert item2_result.content[0].text == "Second" + + def test_insert_new_item_after_nonexistent_previous_item(self): + """Test that item with nonexistent previous_item_id gets added to end""" + # Create initial history + item1 = AssistantMessageItem( + item_id="item_1", role="assistant", content=[AssistantText(text="First")] + ) + old_history = [item1] + + # Create new item with nonexistent previous_item_id + new_item = AssistantMessageItem( + item_id="item_2", + previous_item_id="nonexistent", + role="assistant", + content=[AssistantText(text="Second")], + ) + + new_history = RealtimeSession._get_new_history( + cast(list[RealtimeItem], old_history), new_item + ) + + # Should add to end when previous_item_id not found + assert len(new_history) == 2 + assert new_history[0].item_id == "item_1" + assert new_history[1].item_id == "item_2" + + def test_add_new_item_to_end_when_no_previous_item_id(self): + """Test adding new item to end when no previous_item_id is specified""" + # Create initial history + item1 = AssistantMessageItem( + item_id="item_1", role="assistant", content=[AssistantText(text="First")] + ) + old_history = [item1] + + # Create new item without previous_item_id + new_item = AssistantMessageItem( + item_id="item_2", role="assistant", content=[AssistantText(text="Second")] + ) + + new_history = RealtimeSession._get_new_history( + cast(list[RealtimeItem], old_history), new_item + ) + + # Should add to end + assert len(new_history) == 2 + assert new_history[0].item_id == "item_1" + assert new_history[1].item_id == "item_2" + + def test_add_first_item_to_empty_history(self): + """Test adding first item to empty history""" + old_history: list[RealtimeItem] = [] + + new_item = AssistantMessageItem( + item_id="item_1", role="assistant", content=[AssistantText(text="First")] + ) + + new_history = RealtimeSession._get_new_history(old_history, new_item) + + assert len(new_history) == 1 + assert new_history[0].item_id == "item_1" + + def test_complex_insertion_scenario(self): + """Test complex scenario with multiple insertions and updates""" + # Start with items A and C + itemA = AssistantMessageItem( + item_id="A", role="assistant", content=[AssistantText(text="A")] + ) + itemC = AssistantMessageItem( + item_id="C", role="assistant", content=[AssistantText(text="C")] + ) + history: list[RealtimeItem] = [itemA, itemC] + + # Insert B after A + itemB = AssistantMessageItem( + item_id="B", previous_item_id="A", role="assistant", content=[AssistantText(text="B")] + ) + history = RealtimeSession._get_new_history(history, itemB) + + # Should be A, B, C + assert len(history) == 3 + assert [item.item_id for item in history] == ["A", "B", "C"] + + # Insert D after B + itemD = AssistantMessageItem( + item_id="D", previous_item_id="B", role="assistant", content=[AssistantText(text="D")] + ) + history = RealtimeSession._get_new_history(history, itemD) + + # Should be A, B, D, C + assert len(history) == 4 + assert [item.item_id for item in history] == ["A", "B", "D", "C"] + + # Update B + updated_itemB = AssistantMessageItem( + item_id="B", role="assistant", content=[AssistantText(text="Updated B")] + ) + history = RealtimeSession._get_new_history(history, updated_itemB) + + # Should still be A, B, D, C but B is updated + assert len(history) == 4 + assert [item.item_id for item in history] == ["A", "B", "D", "C"] + itemB_result = cast(AssistantMessageItem, history[1]) + assert itemB_result.content[0].text == "Updated B" + + +# Test 3: Tool call execution flow (_handle_tool_call method) +class TestToolCallExecution: + """Test suite for tool call execution flow in RealtimeSession._handle_tool_call""" + + @pytest.mark.asyncio + async def test_function_tool_execution_success( + self, mock_model, mock_agent, mock_function_tool + ): + """Test successful function tool execution""" + # Set up agent to return our mock tool + mock_agent.get_all_tools.return_value = [mock_function_tool] + + session = RealtimeSession(mock_model, mock_agent, None) + + # Create function call event + tool_call_event = RealtimeModelToolCallEvent( + name="test_function", call_id="call_123", arguments='{"param": "value"}' + ) + + await session._handle_tool_call(tool_call_event) + + # Verify the flow + mock_agent.get_all_tools.assert_called_once() + mock_function_tool.on_invoke_tool.assert_called_once() + + # Check the tool context was created correctly + call_args = mock_function_tool.on_invoke_tool.call_args + tool_context = call_args[0][0] + assert isinstance(tool_context, ToolContext) + assert call_args[0][1] == '{"param": "value"}' + + # Verify tool output was sent to model + assert len(mock_model.sent_tool_outputs) == 1 + sent_call, sent_output, start_response = mock_model.sent_tool_outputs[0] + assert sent_call == tool_call_event + assert sent_output == "function_result" + assert start_response is True + + # Verify events were queued + assert session._event_queue.qsize() == 2 + + # Check tool start event + tool_start_event = await session._event_queue.get() + assert isinstance(tool_start_event, RealtimeToolStart) + assert tool_start_event.tool == mock_function_tool + assert tool_start_event.agent == mock_agent + + # Check tool end event + tool_end_event = await session._event_queue.get() + assert isinstance(tool_end_event, RealtimeToolEnd) + assert tool_end_event.tool == mock_function_tool + assert tool_end_event.output == "function_result" + assert tool_end_event.agent == mock_agent + + @pytest.mark.asyncio + async def test_function_tool_with_multiple_tools_available(self, mock_model, mock_agent): + """Test function tool execution when multiple tools are available""" + # Create multiple mock tools + tool1 = Mock(spec=FunctionTool) + tool1.name = "tool_one" + tool1.on_invoke_tool = AsyncMock(return_value="result_one") + + tool2 = Mock(spec=FunctionTool) + tool2.name = "tool_two" + tool2.on_invoke_tool = AsyncMock(return_value="result_two") + + handoff = Mock(spec=Handoff) + handoff.name = "handoff_tool" + + # Set up agent to return all tools + mock_agent.get_all_tools.return_value = [tool1, tool2, handoff] + + session = RealtimeSession(mock_model, mock_agent, None) + + # Call tool_two + tool_call_event = RealtimeModelToolCallEvent( + name="tool_two", call_id="call_456", arguments='{"test": "data"}' + ) + + await session._handle_tool_call(tool_call_event) + + # Only tool2 should have been called + tool1.on_invoke_tool.assert_not_called() + tool2.on_invoke_tool.assert_called_once() + + # Verify correct result was sent + sent_call, sent_output, _ = mock_model.sent_tool_outputs[0] + assert sent_output == "result_two" + + @pytest.mark.asyncio + async def test_handoff_tool_handling(self, mock_model, mock_agent, mock_handoff): + """Test that handoff tools are properly handled""" + from unittest.mock import AsyncMock + + from agents.realtime.agent import RealtimeAgent + + # Create a mock new agent to be returned by handoff + mock_new_agent = Mock(spec=RealtimeAgent) + mock_new_agent.name = "new_agent" + mock_new_agent.instructions = "New agent instructions" + mock_new_agent.get_all_tools = AsyncMock(return_value=[]) + mock_new_agent.get_system_prompt = AsyncMock(return_value="New agent system prompt") + + # Set up handoff to return the new agent + mock_handoff.on_invoke_handoff = AsyncMock(return_value=mock_new_agent) + mock_handoff.name = "test_handoff" + + # Set up agent to return handoff tool + mock_agent.get_all_tools.return_value = [mock_handoff] + + session = RealtimeSession(mock_model, mock_agent, None) + + tool_call_event = RealtimeModelToolCallEvent( + name="test_handoff", call_id="call_789", arguments="{}" + ) + + await session._handle_tool_call(tool_call_event) + + # Should have sent session update and tool output + assert len(mock_model.sent_events) >= 2 + + # Should have sent handoff event + assert session._event_queue.qsize() >= 1 + + # Verify agent was updated + assert session._current_agent == mock_new_agent + + @pytest.mark.asyncio + async def test_unknown_tool_handling(self, mock_model, mock_agent, mock_function_tool): + """Test that unknown tools raise an error""" + import pytest + + from agents.exceptions import ModelBehaviorError + + # Set up agent to return different tool than what's called + mock_function_tool.name = "known_tool" + mock_agent.get_all_tools.return_value = [mock_function_tool] + + session = RealtimeSession(mock_model, mock_agent, None) + + # Call unknown tool + tool_call_event = RealtimeModelToolCallEvent( + name="unknown_tool", call_id="call_unknown", arguments="{}" + ) + + # Should raise an error for unknown tool + with pytest.raises(ModelBehaviorError, match="Tool unknown_tool not found"): + await session._handle_tool_call(tool_call_event) + + # Should not have called any tools + mock_function_tool.on_invoke_tool.assert_not_called() + + @pytest.mark.asyncio + async def test_function_tool_exception_handling( + self, mock_model, mock_agent, mock_function_tool + ): + """Test that exceptions in function tools are handled (currently they propagate)""" + # Set up tool to raise exception + mock_function_tool.on_invoke_tool.side_effect = ValueError("Tool error") + mock_agent.get_all_tools.return_value = [mock_function_tool] + + session = RealtimeSession(mock_model, mock_agent, None) + + tool_call_event = RealtimeModelToolCallEvent( + name="test_function", call_id="call_error", arguments="{}" + ) + + # Currently exceptions propagate (no error handling implemented) + with pytest.raises(ValueError, match="Tool error"): + await session._handle_tool_call(tool_call_event) + + # Tool start event should have been queued before the error + assert session._event_queue.qsize() == 1 + tool_start_event = await session._event_queue.get() + assert isinstance(tool_start_event, RealtimeToolStart) + + # But no tool output should have been sent and no end event queued + assert len(mock_model.sent_tool_outputs) == 0 + + @pytest.mark.asyncio + async def test_tool_call_with_complex_arguments( + self, mock_model, mock_agent, mock_function_tool + ): + """Test tool call with complex JSON arguments""" + mock_agent.get_all_tools.return_value = [mock_function_tool] + + session = RealtimeSession(mock_model, mock_agent, None) + + # Complex arguments + complex_args = '{"nested": {"data": [1, 2, 3]}, "bool": true, "null": null}' + + tool_call_event = RealtimeModelToolCallEvent( + name="test_function", call_id="call_complex", arguments=complex_args + ) + + await session._handle_tool_call(tool_call_event) + + # Verify arguments were passed correctly + call_args = mock_function_tool.on_invoke_tool.call_args + assert call_args[0][1] == complex_args + + @pytest.mark.asyncio + async def test_tool_call_with_custom_call_id(self, mock_model, mock_agent, mock_function_tool): + """Test that tool context receives correct call_id""" + mock_agent.get_all_tools.return_value = [mock_function_tool] + + session = RealtimeSession(mock_model, mock_agent, None) + + custom_call_id = "custom_call_id_12345" + + tool_call_event = RealtimeModelToolCallEvent( + name="test_function", call_id=custom_call_id, arguments="{}" + ) + + await session._handle_tool_call(tool_call_event) + + # Verify tool context was created with correct call_id + call_args = mock_function_tool.on_invoke_tool.call_args + tool_context = call_args[0][0] + # The call_id is used internally in ToolContext.from_agent_context + # We can't directly access it, but we can verify the context was created + assert isinstance(tool_context, ToolContext) + + @pytest.mark.asyncio + async def test_tool_result_conversion_to_string(self, mock_model, mock_agent): + """Test that tool results are converted to strings for model output""" + # Create tool that returns non-string result + tool = Mock(spec=FunctionTool) + tool.name = "test_function" + tool.on_invoke_tool = AsyncMock(return_value={"result": "data", "count": 42}) + + mock_agent.get_all_tools.return_value = [tool] + + session = RealtimeSession(mock_model, mock_agent, None) + + tool_call_event = RealtimeModelToolCallEvent( + name="test_function", call_id="call_conversion", arguments="{}" + ) + + await session._handle_tool_call(tool_call_event) + + # Verify result was converted to string + sent_call, sent_output, _ = mock_model.sent_tool_outputs[0] + assert isinstance(sent_output, str) + assert sent_output == "{'result': 'data', 'count': 42}" + + @pytest.mark.asyncio + async def test_mixed_tool_types_filtering(self, mock_model, mock_agent): + """Test that function tools and handoffs are properly separated""" + # Create mixed tools + func_tool1 = Mock(spec=FunctionTool) + func_tool1.name = "func1" + func_tool1.on_invoke_tool = AsyncMock(return_value="result1") + + handoff1 = Mock(spec=Handoff) + handoff1.name = "handoff1" + + func_tool2 = Mock(spec=FunctionTool) + func_tool2.name = "func2" + func_tool2.on_invoke_tool = AsyncMock(return_value="result2") + + handoff2 = Mock(spec=Handoff) + handoff2.name = "handoff2" + + # Add some other object that's neither (should be ignored) + other_tool = Mock() + other_tool.name = "other" + + all_tools = [func_tool1, handoff1, func_tool2, handoff2, other_tool] + mock_agent.get_all_tools.return_value = all_tools + + session = RealtimeSession(mock_model, mock_agent, None) + + # Call a function tool + tool_call_event = RealtimeModelToolCallEvent( + name="func2", call_id="call_filtering", arguments="{}" + ) + + await session._handle_tool_call(tool_call_event) + + # Only func2 should have been called + func_tool1.on_invoke_tool.assert_not_called() + func_tool2.on_invoke_tool.assert_called_once() + + # Verify result + sent_call, sent_output, _ = mock_model.sent_tool_outputs[0] + assert sent_output == "result2" + + +class TestGuardrailFunctionality: + """Test suite for output guardrail functionality in RealtimeSession""" + + async def _wait_for_guardrail_tasks(self, session): + """Wait for all pending guardrail tasks to complete.""" + import asyncio + + if session._guardrail_tasks: + await asyncio.gather(*session._guardrail_tasks, return_exceptions=True) + + @pytest.fixture + def triggered_guardrail(self): + """Creates a guardrail that always triggers""" + + def guardrail_func(context, agent, output): + return GuardrailFunctionOutput( + output_info={"reason": "test trigger"}, tripwire_triggered=True + ) + + return OutputGuardrail(guardrail_function=guardrail_func, name="triggered_guardrail") + + @pytest.fixture + def safe_guardrail(self): + """Creates a guardrail that never triggers""" + + def guardrail_func(context, agent, output): + return GuardrailFunctionOutput( + output_info={"reason": "safe content"}, tripwire_triggered=False + ) + + return OutputGuardrail(guardrail_function=guardrail_func, name="safe_guardrail") + + @pytest.mark.asyncio + async def test_transcript_delta_triggers_guardrail_at_threshold( + self, mock_model, mock_agent, triggered_guardrail + ): + """Test that guardrails run when transcript delta reaches debounce threshold""" + run_config: RealtimeRunConfig = { + "output_guardrails": [triggered_guardrail], + "guardrails_settings": {"debounce_text_length": 10}, + } + + session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config) + + # Send transcript delta that exceeds threshold (10 chars) + transcript_event = RealtimeModelTranscriptDeltaEvent( + item_id="item_1", delta="this is more than ten characters", response_id="resp_1" + ) + + await session.on_event(transcript_event) + + # Wait for async guardrail tasks to complete + await self._wait_for_guardrail_tasks(session) + + # Should have triggered guardrail and interrupted + assert session._interrupted_by_guardrail is True + assert mock_model.interrupts_called == 1 + assert len(mock_model.sent_messages) == 1 + assert "triggered_guardrail" in mock_model.sent_messages[0] + + # Should have emitted guardrail_tripped event + events = [] + while not session._event_queue.empty(): + events.append(await session._event_queue.get()) + + guardrail_events = [e for e in events if isinstance(e, RealtimeGuardrailTripped)] + assert len(guardrail_events) == 1 + assert guardrail_events[0].message == "this is more than ten characters" + + @pytest.mark.asyncio + async def test_transcript_delta_multiple_thresholds_same_item( + self, mock_model, mock_agent, triggered_guardrail + ): + """Test guardrails run at 1x, 2x, 3x thresholds for same item_id""" + run_config: RealtimeRunConfig = { + "output_guardrails": [triggered_guardrail], + "guardrails_settings": {"debounce_text_length": 5}, + } + + session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config) + + # First delta - reaches 1x threshold (5 chars) + await session.on_event( + RealtimeModelTranscriptDeltaEvent(item_id="item_1", delta="12345", response_id="resp_1") + ) + + # Second delta - reaches 2x threshold (10 chars total) + await session.on_event( + RealtimeModelTranscriptDeltaEvent(item_id="item_1", delta="67890", response_id="resp_1") + ) + + # Wait for async guardrail tasks to complete + await self._wait_for_guardrail_tasks(session) + + # Should only trigger once due to interrupted_by_guardrail flag + assert mock_model.interrupts_called == 1 + assert len(mock_model.sent_messages) == 1 + + @pytest.mark.asyncio + async def test_transcript_delta_different_items_tracked_separately( + self, mock_model, mock_agent, safe_guardrail + ): + """Test that different item_ids are tracked separately for debouncing""" + run_config: RealtimeRunConfig = { + "output_guardrails": [safe_guardrail], + "guardrails_settings": {"debounce_text_length": 10}, + } + + session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config) + + # Add text to item_1 (8 chars - below threshold) + await session.on_event( + RealtimeModelTranscriptDeltaEvent( + item_id="item_1", delta="12345678", response_id="resp_1" + ) + ) + + # Add text to item_2 (8 chars - below threshold) + await session.on_event( + RealtimeModelTranscriptDeltaEvent( + item_id="item_2", delta="abcdefgh", response_id="resp_2" + ) + ) + + # Neither should trigger guardrails yet + assert mock_model.interrupts_called == 0 + + # Add more text to item_1 (total 12 chars - above threshold) + await session.on_event( + RealtimeModelTranscriptDeltaEvent(item_id="item_1", delta="90ab", response_id="resp_1") + ) + + # item_1 should have triggered guardrail run (but not interrupted since safe) + assert session._item_guardrail_run_counts["item_1"] == 1 + assert ( + "item_2" not in session._item_guardrail_run_counts + or session._item_guardrail_run_counts["item_2"] == 0 + ) + + @pytest.mark.asyncio + async def test_turn_ended_clears_guardrail_state( + self, mock_model, mock_agent, triggered_guardrail + ): + """Test that turn_ended event clears guardrail state for next turn""" + run_config: RealtimeRunConfig = { + "output_guardrails": [triggered_guardrail], + "guardrails_settings": {"debounce_text_length": 5}, + } + + session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config) + + # Trigger guardrail + await session.on_event( + RealtimeModelTranscriptDeltaEvent( + item_id="item_1", delta="trigger", response_id="resp_1" + ) + ) + + # Wait for async guardrail tasks to complete + await self._wait_for_guardrail_tasks(session) + + assert session._interrupted_by_guardrail is True + assert len(session._item_transcripts) == 1 + + # End turn + await session.on_event(RealtimeModelTurnEndedEvent()) + + # State should be cleared + assert session._interrupted_by_guardrail is False + assert len(session._item_transcripts) == 0 + assert len(session._item_guardrail_run_counts) == 0 + + @pytest.mark.asyncio + async def test_multiple_guardrails_all_triggered(self, mock_model, mock_agent): + """Test that all triggered guardrails are included in the event""" + + def create_triggered_guardrail(name): + def guardrail_func(context, agent, output): + return GuardrailFunctionOutput(output_info={"name": name}, tripwire_triggered=True) + + return OutputGuardrail(guardrail_function=guardrail_func, name=name) + + guardrail1 = create_triggered_guardrail("guardrail_1") + guardrail2 = create_triggered_guardrail("guardrail_2") + + run_config: RealtimeRunConfig = { + "output_guardrails": [guardrail1, guardrail2], + "guardrails_settings": {"debounce_text_length": 5}, + } + + session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config) + + await session.on_event( + RealtimeModelTranscriptDeltaEvent( + item_id="item_1", delta="trigger", response_id="resp_1" + ) + ) + + # Wait for async guardrail tasks to complete + await self._wait_for_guardrail_tasks(session) + + # Should have interrupted and sent message with both guardrail names + assert mock_model.interrupts_called == 1 + assert len(mock_model.sent_messages) == 1 + message = mock_model.sent_messages[0] + assert "guardrail_1" in message and "guardrail_2" in message + + # Should have emitted event with both guardrail results + events = [] + while not session._event_queue.empty(): + events.append(await session._event_queue.get()) + + guardrail_events = [e for e in events if isinstance(e, RealtimeGuardrailTripped)] + assert len(guardrail_events) == 1 + assert len(guardrail_events[0].guardrail_results) == 2 diff --git a/tests/realtime/test_tracing.py b/tests/realtime/test_tracing.py new file mode 100644 index 000000000..4cff46c49 --- /dev/null +++ b/tests/realtime/test_tracing.py @@ -0,0 +1,264 @@ +from unittest.mock import AsyncMock, patch + +import pytest + +from agents.realtime.openai_realtime import OpenAIRealtimeWebSocketModel + + +class TestRealtimeTracingIntegration: + """Test tracing configuration and session.update integration.""" + + @pytest.fixture + def model(self): + """Create a fresh model instance for each test.""" + return OpenAIRealtimeWebSocketModel() + + @pytest.fixture + def mock_websocket(self): + """Create a mock websocket connection.""" + mock_ws = AsyncMock() + mock_ws.send = AsyncMock() + mock_ws.close = AsyncMock() + return mock_ws + + @pytest.mark.asyncio + async def test_tracing_config_storage_and_defaults(self, model, mock_websocket): + """Test that tracing config is stored correctly and defaults to 'auto'.""" + # Test with explicit tracing config + config_with_tracing = { + "api_key": "test-key", + "initial_model_settings": { + "tracing": { + "workflow_name": "test_workflow", + "group_id": "group_123", + "metadata": {"version": "1.0"}, + } + }, + } + + async def async_websocket(*args, **kwargs): + return mock_websocket + + with patch("websockets.connect", side_effect=async_websocket): + with patch("asyncio.create_task") as mock_create_task: + mock_task = AsyncMock() + mock_create_task.return_value = mock_task + mock_create_task.side_effect = lambda coro: (coro.close(), mock_task)[1] + + await model.connect(config_with_tracing) + + # Should store the tracing config + assert model._tracing_config == { + "workflow_name": "test_workflow", + "group_id": "group_123", + "metadata": {"version": "1.0"}, + } + + # Test without tracing config - should default to "auto" + model2 = OpenAIRealtimeWebSocketModel() + config_no_tracing = { + "api_key": "test-key", + "initial_model_settings": {}, + } + + with patch("websockets.connect", side_effect=async_websocket): + with patch("asyncio.create_task") as mock_create_task: + mock_create_task.side_effect = lambda coro: (coro.close(), mock_task)[1] + + await model2.connect(config_no_tracing) # type: ignore[arg-type] + assert model2._tracing_config == "auto" + + @pytest.mark.asyncio + async def test_send_tracing_config_on_session_created(self, model, mock_websocket): + """Test that tracing config is sent when session.created event is received.""" + config = { + "api_key": "test-key", + "initial_model_settings": { + "tracing": {"workflow_name": "test_workflow", "group_id": "group_123"} + }, + } + + async def async_websocket(*args, **kwargs): + return mock_websocket + + with patch("websockets.connect", side_effect=async_websocket): + with patch("asyncio.create_task") as mock_create_task: + mock_task = AsyncMock() + mock_create_task.side_effect = lambda coro: (coro.close(), mock_task)[1] + + await model.connect(config) + + # Simulate session.created event + session_created_event = { + "type": "session.created", + "event_id": "event_123", + "session": {"id": "session_456"}, + } + + with patch.object(model, "_send_raw_message") as mock_send_raw_message: + await model._handle_ws_event(session_created_event) + + # Should send session.update with tracing config + from agents.realtime.model_inputs import RealtimeModelSendRawMessage + + expected_event = RealtimeModelSendRawMessage( + message={ + "type": "session.update", + "other_data": { + "session": { + "tracing": { + "workflow_name": "test_workflow", + "group_id": "group_123", + } + } + }, + } + ) + mock_send_raw_message.assert_called_once_with(expected_event) + + @pytest.mark.asyncio + async def test_send_tracing_config_auto_mode(self, model, mock_websocket): + """Test that 'auto' tracing config is sent correctly.""" + config = { + "api_key": "test-key", + "initial_model_settings": {}, # No tracing config - defaults to "auto" + } + + async def async_websocket(*args, **kwargs): + return mock_websocket + + with patch("websockets.connect", side_effect=async_websocket): + with patch("asyncio.create_task") as mock_create_task: + mock_task = AsyncMock() + mock_create_task.side_effect = lambda coro: (coro.close(), mock_task)[1] + + await model.connect(config) + + session_created_event = { + "type": "session.created", + "event_id": "event_123", + "session": {"id": "session_456"}, + } + + with patch.object(model, "_send_raw_message") as mock_send_raw_message: + await model._handle_ws_event(session_created_event) + + # Should send session.update with "auto" + from agents.realtime.model_inputs import RealtimeModelSendRawMessage + + expected_event = RealtimeModelSendRawMessage( + message={ + "type": "session.update", + "other_data": {"session": {"tracing": "auto"}}, + } + ) + mock_send_raw_message.assert_called_once_with(expected_event) + + @pytest.mark.asyncio + async def test_tracing_config_none_skips_session_update(self, model, mock_websocket): + """Test that None tracing config skips sending session.update.""" + # Manually set tracing config to None (this would happen if explicitly set) + model._tracing_config = None + + session_created_event = { + "type": "session.created", + "event_id": "event_123", + "session": {"id": "session_456"}, + } + + with patch.object(model, "send_event") as mock_send_event: + await model._handle_ws_event(session_created_event) + + # Should not send any session.update + mock_send_event.assert_not_called() + + @pytest.mark.asyncio + async def test_tracing_config_with_metadata_serialization(self, model, mock_websocket): + """Test that complex metadata in tracing config is handled correctly.""" + complex_metadata = { + "user_id": "user_123", + "session_type": "demo", + "features": ["audio", "tools"], + "config": {"timeout": 30, "retries": 3}, + } + + config = { + "api_key": "test-key", + "initial_model_settings": { + "tracing": {"workflow_name": "complex_workflow", "metadata": complex_metadata} + }, + } + + async def async_websocket(*args, **kwargs): + return mock_websocket + + with patch("websockets.connect", side_effect=async_websocket): + with patch("asyncio.create_task") as mock_create_task: + mock_task = AsyncMock() + mock_create_task.side_effect = lambda coro: (coro.close(), mock_task)[1] + + await model.connect(config) + + session_created_event = { + "type": "session.created", + "event_id": "event_123", + "session": {"id": "session_456"}, + } + + with patch.object(model, "_send_raw_message") as mock_send_raw_message: + await model._handle_ws_event(session_created_event) + + # Should send session.update with complete tracing config including metadata + from agents.realtime.model_inputs import RealtimeModelSendRawMessage + + expected_event = RealtimeModelSendRawMessage( + message={ + "type": "session.update", + "other_data": { + "session": { + "tracing": { + "workflow_name": "complex_workflow", + "metadata": complex_metadata, + } + } + }, + } + ) + mock_send_raw_message.assert_called_once_with(expected_event) + + @pytest.mark.asyncio + async def test_tracing_disabled_prevents_tracing(self, mock_websocket): + """Test that tracing_disabled=True prevents tracing configuration.""" + from agents.realtime.agent import RealtimeAgent + from agents.realtime.runner import RealtimeRunner + + # Create a test agent and runner with tracing disabled + agent = RealtimeAgent(name="test_agent", instructions="test") + + runner = RealtimeRunner(starting_agent=agent, config={"tracing_disabled": True}) + + # Test the _get_model_settings method directly since that's where the logic is + model_settings = await runner._get_model_settings( + agent=agent, + disable_tracing=True, # This should come from config["tracing_disabled"] + initial_settings=None, + overrides=None, + ) + + # When tracing is disabled, model settings should have tracing=None + assert model_settings["tracing"] is None + + # Also test that the runner passes disable_tracing=True correctly + with patch.object(runner, "_get_model_settings") as mock_get_settings: + mock_get_settings.return_value = {"tracing": None} + + with patch("agents.realtime.session.RealtimeSession") as mock_session_class: + mock_session = AsyncMock() + mock_session_class.return_value = mock_session + + await runner.run() + + # Verify that _get_model_settings was called with disable_tracing=True + mock_get_settings.assert_called_once_with( + agent=agent, disable_tracing=True, initial_settings=None, overrides=None + ) diff --git a/tests/test_function_schema.py b/tests/test_function_schema.py index 54e7e2f9e..f63d9b534 100644 --- a/tests/test_function_schema.py +++ b/tests/test_function_schema.py @@ -3,7 +3,7 @@ from typing import Any, Literal import pytest -from pydantic import BaseModel, ValidationError +from pydantic import BaseModel, Field, ValidationError from typing_extensions import TypedDict from agents import RunContextWrapper @@ -451,3 +451,220 @@ def foo(x: int) -> int: assert fs.name == "custom" assert fs.params_json_schema.get("title") == "custom_args" + + +def test_function_with_field_required_constraints(): + """Test function with required Field parameter that has constraints.""" + + def func_with_field_constraints(my_number: int = Field(..., gt=10, le=100)) -> int: + return my_number * 2 + + fs = function_schema(func_with_field_constraints, use_docstring_info=False) + + # Check that the schema includes the constraints + properties = fs.params_json_schema.get("properties", {}) + my_number_schema = properties.get("my_number", {}) + assert my_number_schema.get("type") == "integer" + assert my_number_schema.get("exclusiveMinimum") == 10 # gt=10 + assert my_number_schema.get("maximum") == 100 # le=100 + + # Valid input should work + valid_input = {"my_number": 50} + parsed = fs.params_pydantic_model(**valid_input) + args, kwargs_dict = fs.to_call_args(parsed) + result = func_with_field_constraints(*args, **kwargs_dict) + assert result == 100 + + # Invalid input: too small (should violate gt=10) + with pytest.raises(ValidationError): + fs.params_pydantic_model(**{"my_number": 5}) + + # Invalid input: too large (should violate le=100) + with pytest.raises(ValidationError): + fs.params_pydantic_model(**{"my_number": 150}) + + +def test_function_with_field_optional_with_default(): + """Test function with optional Field parameter that has default and constraints.""" + + def func_with_optional_field( + required_param: str, + optional_param: float = Field(default=5.0, ge=0.0), + ) -> str: + return f"{required_param}: {optional_param}" + + fs = function_schema(func_with_optional_field, use_docstring_info=False) + + # Check that the schema includes the constraints and description + properties = fs.params_json_schema.get("properties", {}) + optional_schema = properties.get("optional_param", {}) + assert optional_schema.get("type") == "number" + assert optional_schema.get("minimum") == 0.0 # ge=0.0 + assert optional_schema.get("default") == 5.0 + + # Valid input with default + valid_input = {"required_param": "test"} + parsed = fs.params_pydantic_model(**valid_input) + args, kwargs_dict = fs.to_call_args(parsed) + result = func_with_optional_field(*args, **kwargs_dict) + assert result == "test: 5.0" + + # Valid input with explicit value + valid_input2 = {"required_param": "test", "optional_param": 10.5} + parsed2 = fs.params_pydantic_model(**valid_input2) + args2, kwargs_dict2 = fs.to_call_args(parsed2) + result2 = func_with_optional_field(*args2, **kwargs_dict2) + assert result2 == "test: 10.5" + + # Invalid input: negative value (should violate ge=0.0) + with pytest.raises(ValidationError): + fs.params_pydantic_model(**{"required_param": "test", "optional_param": -1.0}) + + +def test_function_with_field_description_merge(): + """Test that Field descriptions are merged with docstring descriptions.""" + + def func_with_field_and_docstring( + param_with_field_desc: int = Field(..., description="Field description"), + param_with_both: str = Field(default="hello", description="Field description"), + ) -> str: + """ + Function with both field and docstring descriptions. + + Args: + param_with_field_desc: Docstring description + param_with_both: Docstring description + """ + return f"{param_with_field_desc}: {param_with_both}" + + fs = function_schema(func_with_field_and_docstring, use_docstring_info=True) + + # Check that docstring description takes precedence when both exist + properties = fs.params_json_schema.get("properties", {}) + param1_schema = properties.get("param_with_field_desc", {}) + param2_schema = properties.get("param_with_both", {}) + + # The docstring description should be used when both are present + assert param1_schema.get("description") == "Docstring description" + assert param2_schema.get("description") == "Docstring description" + + +def func_with_field_desc_only( + param_with_field_desc: int = Field(..., description="Field description only"), + param_without_desc: str = Field(default="hello"), +) -> str: + return f"{param_with_field_desc}: {param_without_desc}" + + +def test_function_with_field_description_only(): + """Test that Field descriptions are used when no docstring info.""" + + fs = function_schema(func_with_field_desc_only) + + # Check that field description is used when no docstring + properties = fs.params_json_schema.get("properties", {}) + param1_schema = properties.get("param_with_field_desc", {}) + param2_schema = properties.get("param_without_desc", {}) + + assert param1_schema.get("description") == "Field description only" + assert param2_schema.get("description") is None + + +def test_function_with_field_string_constraints(): + """Test function with Field parameter that has string-specific constraints.""" + + def func_with_string_field( + name: str = Field(..., min_length=3, max_length=20, pattern=r"^[A-Za-z]+$"), + ) -> str: + return f"Hello, {name}!" + + fs = function_schema(func_with_string_field, use_docstring_info=False) + + # Check that the schema includes string constraints + properties = fs.params_json_schema.get("properties", {}) + name_schema = properties.get("name", {}) + assert name_schema.get("type") == "string" + assert name_schema.get("minLength") == 3 + assert name_schema.get("maxLength") == 20 + assert name_schema.get("pattern") == r"^[A-Za-z]+$" + + # Valid input + valid_input = {"name": "Alice"} + parsed = fs.params_pydantic_model(**valid_input) + args, kwargs_dict = fs.to_call_args(parsed) + result = func_with_string_field(*args, **kwargs_dict) + assert result == "Hello, Alice!" + + # Invalid input: too short + with pytest.raises(ValidationError): + fs.params_pydantic_model(**{"name": "Al"}) + + # Invalid input: too long + with pytest.raises(ValidationError): + fs.params_pydantic_model(**{"name": "A" * 25}) + + # Invalid input: doesn't match pattern (contains numbers) + with pytest.raises(ValidationError): + fs.params_pydantic_model(**{"name": "Alice123"}) + + +def test_function_with_field_multiple_constraints(): + """Test function with multiple Field parameters having different constraint types.""" + + def func_with_multiple_field_constraints( + score: int = Field(..., ge=0, le=100, description="Score from 0 to 100"), + name: str = Field(default="Unknown", min_length=1, max_length=50), + factor: float = Field(default=1.0, gt=0.0, description="Positive multiplier"), + ) -> str: + final_score = score * factor + return f"{name} scored {final_score}" + + fs = function_schema(func_with_multiple_field_constraints, use_docstring_info=False) + + # Check schema structure + properties = fs.params_json_schema.get("properties", {}) + + # Check score field + score_schema = properties.get("score", {}) + assert score_schema.get("type") == "integer" + assert score_schema.get("minimum") == 0 + assert score_schema.get("maximum") == 100 + assert score_schema.get("description") == "Score from 0 to 100" + + # Check name field + name_schema = properties.get("name", {}) + assert name_schema.get("type") == "string" + assert name_schema.get("minLength") == 1 + assert name_schema.get("maxLength") == 50 + assert name_schema.get("default") == "Unknown" + + # Check factor field + factor_schema = properties.get("factor", {}) + assert factor_schema.get("type") == "number" + assert factor_schema.get("exclusiveMinimum") == 0.0 + assert factor_schema.get("default") == 1.0 + assert factor_schema.get("description") == "Positive multiplier" + + # Valid input with defaults + valid_input = {"score": 85} + parsed = fs.params_pydantic_model(**valid_input) + args, kwargs_dict = fs.to_call_args(parsed) + result = func_with_multiple_field_constraints(*args, **kwargs_dict) + assert result == "Unknown scored 85.0" + + # Valid input with all parameters + valid_input2 = {"score": 90, "name": "Alice", "factor": 1.5} + parsed2 = fs.params_pydantic_model(**valid_input2) + args2, kwargs_dict2 = fs.to_call_args(parsed2) + result2 = func_with_multiple_field_constraints(*args2, **kwargs_dict2) + assert result2 == "Alice scored 135.0" + + # Test various validation errors + with pytest.raises(ValidationError): # score too high + fs.params_pydantic_model(**{"score": 150}) + + with pytest.raises(ValidationError): # empty name + fs.params_pydantic_model(**{"score": 50, "name": ""}) + + with pytest.raises(ValidationError): # zero factor + fs.params_pydantic_model(**{"score": 50, "factor": 0.0}) diff --git a/tests/test_function_tool.py b/tests/test_function_tool.py index b232bf75e..15602bbac 100644 --- a/tests/test_function_tool.py +++ b/tests/test_function_tool.py @@ -5,7 +5,14 @@ from pydantic import BaseModel from typing_extensions import TypedDict -from agents import Agent, FunctionTool, ModelBehaviorError, RunContextWrapper, function_tool +from agents import ( + Agent, + AgentBase, + FunctionTool, + ModelBehaviorError, + RunContextWrapper, + function_tool, +) from agents.tool import default_tool_error_function from agents.tool_context import ToolContext @@ -19,7 +26,9 @@ async def test_argless_function(): tool = function_tool(argless_function) assert tool.name == "argless_function" - result = await tool.on_invoke_tool(ToolContext(context=None, tool_call_id="1"), "") + result = await tool.on_invoke_tool( + ToolContext(context=None, tool_name=tool.name, tool_call_id="1"), "" + ) assert result == "ok" @@ -32,11 +41,13 @@ async def test_argless_with_context(): tool = function_tool(argless_with_context) assert tool.name == "argless_with_context" - result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), "") + result = await tool.on_invoke_tool(ToolContext(None, tool_name=tool.name, tool_call_id="1"), "") assert result == "ok" # Extra JSON should not raise an error - result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1}') + result = await tool.on_invoke_tool( + ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"a": 1}' + ) assert result == "ok" @@ -49,15 +60,19 @@ async def test_simple_function(): tool = function_tool(simple_function, failure_error_function=None) assert tool.name == "simple_function" - result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1}') + result = await tool.on_invoke_tool( + ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"a": 1}' + ) assert result == 6 - result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1, "b": 2}') + result = await tool.on_invoke_tool( + ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"a": 1, "b": 2}' + ) assert result == 3 # Missing required argument should raise an error with pytest.raises(ModelBehaviorError): - await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), "") + await tool.on_invoke_tool(ToolContext(None, tool_name=tool.name, tool_call_id="1"), "") class Foo(BaseModel): @@ -85,7 +100,9 @@ async def test_complex_args_function(): "bar": Bar(x="hello", y=10), } ) - result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json) + result = await tool.on_invoke_tool( + ToolContext(None, tool_name=tool.name, tool_call_id="1"), valid_json + ) assert result == "6 hello10 hello" valid_json = json.dumps( @@ -94,7 +111,9 @@ async def test_complex_args_function(): "bar": Bar(x="hello", y=10), } ) - result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json) + result = await tool.on_invoke_tool( + ToolContext(None, tool_name=tool.name, tool_call_id="1"), valid_json + ) assert result == "3 hello10 hello" valid_json = json.dumps( @@ -104,12 +123,16 @@ async def test_complex_args_function(): "baz": "world", } ) - result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json) + result = await tool.on_invoke_tool( + ToolContext(None, tool_name=tool.name, tool_call_id="1"), valid_json + ) assert result == "3 hello10 world" # Missing required argument should raise an error with pytest.raises(ModelBehaviorError): - await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"foo": {"a": 1}}') + await tool.on_invoke_tool( + ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"foo": {"a": 1}}' + ) def test_function_config_overrides(): @@ -169,7 +192,9 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str: assert tool.params_json_schema[key] == value assert tool.strict_json_schema - result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"data": "hello"}') + result = await tool.on_invoke_tool( + ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"data": "hello"}' + ) assert result == "hello_done" tool_not_strict = FunctionTool( @@ -184,7 +209,8 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str: assert "additionalProperties" not in tool_not_strict.params_json_schema result = await tool_not_strict.on_invoke_tool( - ToolContext(None, tool_call_id="1"), '{"data": "hello", "bar": "baz"}' + ToolContext(None, tool_name=tool_not_strict.name, tool_call_id="1"), + '{"data": "hello", "bar": "baz"}', ) assert result == "hello_done" @@ -195,7 +221,7 @@ def my_func(a: int, b: int = 5): raise ValueError("test") tool = function_tool(my_func) - ctx = ToolContext(None, tool_call_id="1") + ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1") result = await tool.on_invoke_tool(ctx, "") assert "Invalid JSON" in str(result) @@ -219,7 +245,7 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) -> return f"error_{error.__class__.__name__}" tool = function_tool(my_func, failure_error_function=custom_sync_error_function) - ctx = ToolContext(None, tool_call_id="1") + ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1") result = await tool.on_invoke_tool(ctx, "") assert result == "error_ModelBehaviorError" @@ -243,7 +269,7 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) -> return f"error_{error.__class__.__name__}" tool = function_tool(my_func, failure_error_function=custom_sync_error_function) - ctx = ToolContext(None, tool_call_id="1") + ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1") result = await tool.on_invoke_tool(ctx, "") assert result == "error_ModelBehaviorError" @@ -268,7 +294,7 @@ async def test_is_enabled_bool_and_callable(): def disabled_tool(): return "nope" - async def cond_enabled(ctx: RunContextWrapper[BoolCtx], agent: Agent[Any]) -> bool: + async def cond_enabled(ctx: RunContextWrapper[BoolCtx], agent: AgentBase) -> bool: return ctx.context.enable_tools @function_tool(is_enabled=cond_enabled) diff --git a/tests/test_function_tool_decorator.py b/tests/test_function_tool_decorator.py index d334d8f84..b81d5dbe2 100644 --- a/tests/test_function_tool_decorator.py +++ b/tests/test_function_tool_decorator.py @@ -16,7 +16,7 @@ def __init__(self): def ctx_wrapper() -> ToolContext[DummyContext]: - return ToolContext(context=DummyContext(), tool_call_id="1") + return ToolContext(context=DummyContext(), tool_name="dummy", tool_call_id="1") @function_tool diff --git a/tests/test_items_helpers.py b/tests/test_items_helpers.py index 5dba21d88..f711f21e1 100644 --- a/tests/test_items_helpers.py +++ b/tests/test_items_helpers.py @@ -11,7 +11,10 @@ ) from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall from openai.types.responses.response_function_tool_call_param import ResponseFunctionToolCallParam -from openai.types.responses.response_function_web_search import ResponseFunctionWebSearch +from openai.types.responses.response_function_web_search import ( + ActionSearch, + ResponseFunctionWebSearch, +) from openai.types.responses.response_function_web_search_param import ResponseFunctionWebSearchParam from openai.types.responses.response_output_message import ResponseOutputMessage from openai.types.responses.response_output_message_param import ResponseOutputMessageParam @@ -225,7 +228,12 @@ def test_to_input_items_for_file_search_call() -> None: def test_to_input_items_for_web_search_call() -> None: """A web search tool call output should produce the same dict as a web search input.""" - ws_call = ResponseFunctionWebSearch(id="w1", status="completed", type="web_search_call") + ws_call = ResponseFunctionWebSearch( + id="w1", + action=ActionSearch(type="search", query="query"), + status="completed", + type="web_search_call", + ) resp = ModelResponse(output=[ws_call], usage=Usage(), response_id=None) input_items = resp.to_input_items() assert isinstance(input_items, list) and len(input_items) == 1 @@ -233,6 +241,7 @@ def test_to_input_items_for_web_search_call() -> None: "id": "w1", "status": "completed", "type": "web_search_call", + "action": {"type": "search", "query": "query"}, } assert input_items[0] == expected diff --git a/tests/test_reasoning_content.py b/tests/test_reasoning_content.py index 5160e09c2..69e9a7d0c 100644 --- a/tests/test_reasoning_content.py +++ b/tests/test_reasoning_content.py @@ -27,12 +27,8 @@ # Helper functions to create test objects consistently def create_content_delta(content: str) -> dict[str, Any]: """Create a delta dictionary with regular content""" - return { - "content": content, - "role": None, - "function_call": None, - "tool_calls": None - } + return {"content": content, "role": None, "function_call": None, "tool_calls": None} + def create_reasoning_delta(content: str) -> dict[str, Any]: """Create a delta dictionary with reasoning content. The Only difference is reasoning_content""" @@ -41,7 +37,7 @@ def create_reasoning_delta(content: str) -> dict[str, Any]: "role": None, "function_call": None, "tool_calls": None, - "reasoning_content": content + "reasoning_content": content, } @@ -188,7 +184,7 @@ async def test_get_response_with_reasoning_content(monkeypatch) -> None: "index": 0, "finish_reason": "stop", "message": msg_with_reasoning, - "delta": None + "delta": None, } chat = ChatCompletion( @@ -274,7 +270,7 @@ async def patched_fetch_response(self, *args, **kwargs): handoffs=[], tracing=ModelTracing.DISABLED, previous_response_id=None, - prompt=None + prompt=None, ): output_events.append(event) diff --git a/tests/test_run_step_processing.py b/tests/test_run_step_processing.py index 6a2904791..27d36afa8 100644 --- a/tests/test_run_step_processing.py +++ b/tests/test_run_step_processing.py @@ -7,6 +7,7 @@ ResponseFunctionWebSearch, ) from openai.types.responses.response_computer_tool_call import ActionClick +from openai.types.responses.response_function_web_search import ActionSearch from openai.types.responses.response_reasoning_item import ResponseReasoningItem, Summary from pydantic import BaseModel @@ -306,7 +307,12 @@ async def test_file_search_tool_call_parsed_correctly(): @pytest.mark.asyncio async def test_function_web_search_tool_call_parsed_correctly(): agent = Agent(name="test") - web_search_call = ResponseFunctionWebSearch(id="w1", status="completed", type="web_search_call") + web_search_call = ResponseFunctionWebSearch( + id="w1", + action=ActionSearch(type="search", query="query"), + status="completed", + type="web_search_call", + ) response = ModelResponse( output=[get_text_message("hello"), web_search_call], usage=Usage(), diff --git a/tests/test_session.py b/tests/test_session.py new file mode 100644 index 000000000..032f2bb38 --- /dev/null +++ b/tests/test_session.py @@ -0,0 +1,400 @@ +"""Tests for session memory functionality.""" + +import asyncio +import tempfile +from pathlib import Path + +import pytest + +from agents import Agent, Runner, SQLiteSession, TResponseInputItem +from agents.exceptions import UserError + +from .fake_model import FakeModel +from .test_responses import get_text_message + + +# Helper functions for parametrized testing of different Runner methods +def _run_sync_wrapper(agent, input_data, **kwargs): + """Wrapper for run_sync that properly sets up an event loop.""" + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + return Runner.run_sync(agent, input_data, **kwargs) + finally: + loop.close() + + +async def run_agent_async(runner_method: str, agent, input_data, **kwargs): + """Helper function to run agent with different methods.""" + if runner_method == "run": + return await Runner.run(agent, input_data, **kwargs) + elif runner_method == "run_sync": + # For run_sync, we need to run it in a thread with its own event loop + return await asyncio.to_thread(_run_sync_wrapper, agent, input_data, **kwargs) + elif runner_method == "run_streamed": + result = Runner.run_streamed(agent, input_data, **kwargs) + # For streaming, we first try to get at least one event to trigger any early exceptions + # If there's an exception in setup (like memory validation), it will be raised here + try: + first_event = None + async for event in result.stream_events(): + if first_event is None: + first_event = event + # Continue consuming all events + pass + except Exception: + # If an exception occurs during streaming, we let it propagate up + raise + return result + else: + raise ValueError(f"Unknown runner method: {runner_method}") + + +# Parametrized tests for different runner methods +@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"]) +@pytest.mark.asyncio +async def test_session_memory_basic_functionality_parametrized(runner_method): + """Test basic session memory functionality with SQLite backend across all runner methods.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_memory.db" + session_id = "test_session_123" + session = SQLiteSession(session_id, db_path) + + model = FakeModel() + agent = Agent(name="test", model=model) + + # First turn + model.set_next_output([get_text_message("San Francisco")]) + result1 = await run_agent_async( + runner_method, + agent, + "What city is the Golden Gate Bridge in?", + session=session, + ) + assert result1.final_output == "San Francisco" + + # Second turn - should have conversation history + model.set_next_output([get_text_message("California")]) + result2 = await run_agent_async( + runner_method, + agent, + "What state is it in?", + session=session, + ) + assert result2.final_output == "California" + + # Verify that the input to the second turn includes the previous conversation + # The model should have received the full conversation history + last_input = model.last_turn_args["input"] + assert len(last_input) > 1 # Should have more than just the current message + + session.close() + + +@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"]) +@pytest.mark.asyncio +async def test_session_memory_with_explicit_instance_parametrized(runner_method): + """Test session memory with an explicit SQLiteSession instance across all runner methods.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_memory.db" + session_id = "test_session_456" + session = SQLiteSession(session_id, db_path) + + model = FakeModel() + agent = Agent(name="test", model=model) + + # First turn + model.set_next_output([get_text_message("Hello")]) + result1 = await run_agent_async(runner_method, agent, "Hi there", session=session) + assert result1.final_output == "Hello" + + # Second turn + model.set_next_output([get_text_message("I remember you said hi")]) + result2 = await run_agent_async( + runner_method, + agent, + "Do you remember what I said?", + session=session, + ) + assert result2.final_output == "I remember you said hi" + + session.close() + + +@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"]) +@pytest.mark.asyncio +async def test_session_memory_disabled_parametrized(runner_method): + """Test that session memory is disabled when session=None across all runner methods.""" + model = FakeModel() + agent = Agent(name="test", model=model) + + # First turn (no session parameters = disabled) + model.set_next_output([get_text_message("Hello")]) + result1 = await run_agent_async(runner_method, agent, "Hi there") + assert result1.final_output == "Hello" + + # Second turn - should NOT have conversation history + model.set_next_output([get_text_message("I don't remember")]) + result2 = await run_agent_async(runner_method, agent, "Do you remember what I said?") + assert result2.final_output == "I don't remember" + + # Verify that the input to the second turn is just the current message + last_input = model.last_turn_args["input"] + assert len(last_input) == 1 # Should only have the current message + + +@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"]) +@pytest.mark.asyncio +async def test_session_memory_different_sessions_parametrized(runner_method): + """Test that different session IDs maintain separate conversation histories across all runner + methods.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_memory.db" + + model = FakeModel() + agent = Agent(name="test", model=model) + + # Session 1 + session_id_1 = "session_1" + session_1 = SQLiteSession(session_id_1, db_path) + + model.set_next_output([get_text_message("I like cats")]) + result1 = await run_agent_async(runner_method, agent, "I like cats", session=session_1) + assert result1.final_output == "I like cats" + + # Session 2 - different session + session_id_2 = "session_2" + session_2 = SQLiteSession(session_id_2, db_path) + + model.set_next_output([get_text_message("I like dogs")]) + result2 = await run_agent_async(runner_method, agent, "I like dogs", session=session_2) + assert result2.final_output == "I like dogs" + + # Back to Session 1 - should remember cats, not dogs + model.set_next_output([get_text_message("Yes, you mentioned cats")]) + result3 = await run_agent_async( + runner_method, + agent, + "What did I say I like?", + session=session_1, + ) + assert result3.final_output == "Yes, you mentioned cats" + + session_1.close() + session_2.close() + + +@pytest.mark.asyncio +async def test_sqlite_session_memory_direct(): + """Test SQLiteSession class directly.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_direct.db" + session_id = "direct_test" + session = SQLiteSession(session_id, db_path) + + # Test adding and retrieving items + items: list[TResponseInputItem] = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + + await session.add_items(items) + retrieved = await session.get_items() + + assert len(retrieved) == 2 + assert retrieved[0].get("role") == "user" + assert retrieved[0].get("content") == "Hello" + assert retrieved[1].get("role") == "assistant" + assert retrieved[1].get("content") == "Hi there!" + + # Test clearing session + await session.clear_session() + retrieved_after_clear = await session.get_items() + assert len(retrieved_after_clear) == 0 + + session.close() + + +@pytest.mark.asyncio +async def test_sqlite_session_memory_pop_item(): + """Test SQLiteSession pop_item functionality.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_pop.db" + session_id = "pop_test" + session = SQLiteSession(session_id, db_path) + + # Test popping from empty session + popped = await session.pop_item() + assert popped is None + + # Add items + items: list[TResponseInputItem] = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "How are you?"}, + ] + + await session.add_items(items) + + # Verify all items are there + retrieved = await session.get_items() + assert len(retrieved) == 3 + + # Pop the most recent item + popped = await session.pop_item() + assert popped is not None + assert popped.get("role") == "user" + assert popped.get("content") == "How are you?" + + # Verify item was removed + retrieved_after_pop = await session.get_items() + assert len(retrieved_after_pop) == 2 + assert retrieved_after_pop[-1].get("content") == "Hi there!" + + # Pop another item + popped2 = await session.pop_item() + assert popped2 is not None + assert popped2.get("role") == "assistant" + assert popped2.get("content") == "Hi there!" + + # Pop the last item + popped3 = await session.pop_item() + assert popped3 is not None + assert popped3.get("role") == "user" + assert popped3.get("content") == "Hello" + + # Try to pop from empty session again + popped4 = await session.pop_item() + assert popped4 is None + + # Verify session is empty + final_items = await session.get_items() + assert len(final_items) == 0 + + session.close() + + +@pytest.mark.asyncio +async def test_session_memory_pop_different_sessions(): + """Test that pop_item only affects the specified session.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_pop_sessions.db" + + session_1_id = "session_1" + session_2_id = "session_2" + session_1 = SQLiteSession(session_1_id, db_path) + session_2 = SQLiteSession(session_2_id, db_path) + + # Add items to both sessions + items_1: list[TResponseInputItem] = [ + {"role": "user", "content": "Session 1 message"}, + ] + items_2: list[TResponseInputItem] = [ + {"role": "user", "content": "Session 2 message 1"}, + {"role": "user", "content": "Session 2 message 2"}, + ] + + await session_1.add_items(items_1) + await session_2.add_items(items_2) + + # Pop from session 2 + popped = await session_2.pop_item() + assert popped is not None + assert popped.get("content") == "Session 2 message 2" + + # Verify session 1 is unaffected + session_1_items = await session_1.get_items() + assert len(session_1_items) == 1 + assert session_1_items[0].get("content") == "Session 1 message" + + # Verify session 2 has one item left + session_2_items = await session_2.get_items() + assert len(session_2_items) == 1 + assert session_2_items[0].get("content") == "Session 2 message 1" + + session_1.close() + session_2.close() + + +@pytest.mark.asyncio +async def test_sqlite_session_get_items_with_limit(): + """Test SQLiteSession get_items with limit parameter.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_count.db" + session_id = "count_test" + session = SQLiteSession(session_id, db_path) + + # Add multiple items + items: list[TResponseInputItem] = [ + {"role": "user", "content": "Message 1"}, + {"role": "assistant", "content": "Response 1"}, + {"role": "user", "content": "Message 2"}, + {"role": "assistant", "content": "Response 2"}, + {"role": "user", "content": "Message 3"}, + {"role": "assistant", "content": "Response 3"}, + ] + + await session.add_items(items) + + # Test getting all items (default behavior) + all_items = await session.get_items() + assert len(all_items) == 6 + assert all_items[0].get("content") == "Message 1" + assert all_items[-1].get("content") == "Response 3" + + # Test getting latest 2 items + latest_2 = await session.get_items(limit=2) + assert len(latest_2) == 2 + assert latest_2[0].get("content") == "Message 3" + assert latest_2[1].get("content") == "Response 3" + + # Test getting latest 4 items + latest_4 = await session.get_items(limit=4) + assert len(latest_4) == 4 + assert latest_4[0].get("content") == "Message 2" + assert latest_4[1].get("content") == "Response 2" + assert latest_4[2].get("content") == "Message 3" + assert latest_4[3].get("content") == "Response 3" + + # Test getting more items than available + latest_10 = await session.get_items(limit=10) + assert len(latest_10) == 6 # Should return all available items + assert latest_10[0].get("content") == "Message 1" + assert latest_10[-1].get("content") == "Response 3" + + # Test getting 0 items + latest_0 = await session.get_items(limit=0) + assert len(latest_0) == 0 + + session.close() + + +@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"]) +@pytest.mark.asyncio +async def test_session_memory_rejects_both_session_and_list_input(runner_method): + """Test that passing both a session and list input raises a UserError across all runner + methods. + """ + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_validation.db" + session_id = "test_validation_parametrized" + session = SQLiteSession(session_id, db_path) + + model = FakeModel() + agent = Agent(name="test", model=model) + + # Test that providing both a session and a list input raises a UserError + model.set_next_output([get_text_message("This shouldn't run")]) + + list_input = [ + {"role": "user", "content": "Test message"}, + ] + + with pytest.raises(UserError) as exc_info: + await run_agent_async(runner_method, agent, list_input, session=session) + + # Verify the error message explains the issue + assert "Cannot provide both a session and a list of input items" in str(exc_info.value) + assert "manually manage conversation history" in str(exc_info.value) + + session.close() diff --git a/tests/test_session_exceptions.py b/tests/test_session_exceptions.py new file mode 100644 index 000000000..a454cca92 --- /dev/null +++ b/tests/test_session_exceptions.py @@ -0,0 +1,300 @@ +from __future__ import annotations + +import asyncio +import json +from typing import Any +from unittest.mock import AsyncMock, Mock + +import pytest +import websockets.exceptions + +from agents.realtime.events import RealtimeError +from agents.realtime.model import RealtimeModel, RealtimeModelConfig, RealtimeModelListener +from agents.realtime.model_events import ( + RealtimeModelErrorEvent, + RealtimeModelEvent, + RealtimeModelExceptionEvent, +) +from agents.realtime.session import RealtimeSession + + +class FakeRealtimeModel(RealtimeModel): + """Fake model for testing that forwards events to listeners.""" + + def __init__(self): + self._listeners: list[RealtimeModelListener] = [] + self._events_to_send: list[RealtimeModelEvent] = [] + self._is_connected = False + self._send_task: asyncio.Task[None] | None = None + + def set_next_events(self, events: list[RealtimeModelEvent]) -> None: + """Set events to be sent to listeners.""" + self._events_to_send = events.copy() + + async def connect(self, options: RealtimeModelConfig) -> None: + """Fake connection that starts sending events.""" + self._is_connected = True + self._send_task = asyncio.create_task(self._send_events()) + + async def _send_events(self) -> None: + """Send queued events to all listeners.""" + for event in self._events_to_send: + await asyncio.sleep(0.001) # Small delay to simulate async behavior + for listener in self._listeners: + await listener.on_event(event) + + def add_listener(self, listener: RealtimeModelListener) -> None: + """Add a listener.""" + self._listeners.append(listener) + + def remove_listener(self, listener: RealtimeModelListener) -> None: + """Remove a listener.""" + if listener in self._listeners: + self._listeners.remove(listener) + + async def close(self) -> None: + """Close the fake model.""" + self._is_connected = False + if self._send_task and not self._send_task.done(): + self._send_task.cancel() + try: + await self._send_task + except asyncio.CancelledError: + pass + + async def send_message( + self, message: Any, other_event_data: dict[str, Any] | None = None + ) -> None: + """Fake send message.""" + pass + + async def send_audio(self, audio: bytes, *, commit: bool = False) -> None: + """Fake send audio.""" + pass + + async def send_event(self, event: Any) -> None: + """Fake send event.""" + pass + + async def send_tool_output(self, tool_call: Any, output: str, start_response: bool) -> None: + """Fake send tool output.""" + pass + + async def interrupt(self) -> None: + """Fake interrupt.""" + pass + + +@pytest.fixture +def fake_agent(): + """Create a fake agent for testing.""" + agent = Mock() + agent.get_all_tools = AsyncMock(return_value=[]) + return agent + + +@pytest.fixture +def fake_model(): + """Create a fake model for testing.""" + return FakeRealtimeModel() + + +class TestSessionExceptions: + """Test exception handling in RealtimeSession.""" + + @pytest.mark.asyncio + async def test_end_to_end_exception_propagation_and_cleanup( + self, fake_model: FakeRealtimeModel, fake_agent + ): + """Test that exceptions are stored, trigger cleanup, and are raised in __aiter__.""" + # Create test exception + test_exception = ValueError("Test error") + exception_event = RealtimeModelExceptionEvent( + exception=test_exception, context="Test context" + ) + + # Set up session + session = RealtimeSession(fake_model, fake_agent, None) + + # Set events to send + fake_model.set_next_events([exception_event]) + + # Start session + async with session: + # Try to iterate and expect exception + with pytest.raises(ValueError, match="Test error"): + async for _ in session: + pass # Should never reach here + + # Verify cleanup occurred + assert session._closed is True + assert session._stored_exception == test_exception + assert fake_model._is_connected is False + assert len(fake_model._listeners) == 0 + + @pytest.mark.asyncio + async def test_websocket_connection_closure_type_distinction( + self, fake_model: FakeRealtimeModel, fake_agent + ): + """Test different WebSocket closure types generate appropriate events.""" + # Test ConnectionClosed (should create exception event) + error_closure = websockets.exceptions.ConnectionClosed(None, None) + error_event = RealtimeModelExceptionEvent( + exception=error_closure, context="WebSocket connection closed unexpectedly" + ) + + session = RealtimeSession(fake_model, fake_agent, None) + fake_model.set_next_events([error_event]) + + with pytest.raises(websockets.exceptions.ConnectionClosed): + async with session: + async for _event in session: + pass + + # Verify error closure triggered cleanup + assert session._closed is True + assert isinstance(session._stored_exception, websockets.exceptions.ConnectionClosed) + + @pytest.mark.asyncio + async def test_json_parsing_error_handling(self, fake_model: FakeRealtimeModel, fake_agent): + """Test JSON parsing errors are properly handled and contextualized.""" + # Create JSON decode error + json_error = json.JSONDecodeError("Invalid JSON", "bad json", 0) + json_exception_event = RealtimeModelExceptionEvent( + exception=json_error, context="Failed to parse WebSocket message as JSON" + ) + + session = RealtimeSession(fake_model, fake_agent, None) + fake_model.set_next_events([json_exception_event]) + + with pytest.raises(json.JSONDecodeError): + async with session: + async for _event in session: + pass + + # Verify context is preserved + assert session._stored_exception == json_error + assert session._closed is True + + @pytest.mark.asyncio + async def test_exception_context_preservation(self, fake_model: FakeRealtimeModel, fake_agent): + """Test that exception context information is preserved through the handling process.""" + test_contexts = [ + ("Failed to send audio", RuntimeError("Audio encoding failed")), + ("WebSocket error in message listener", ConnectionError("Network error")), + ("Failed to send event: response.create", OSError("Socket closed")), + ] + + for context, exception in test_contexts: + exception_event = RealtimeModelExceptionEvent(exception=exception, context=context) + + session = RealtimeSession(fake_model, fake_agent, None) + fake_model.set_next_events([exception_event]) + + with pytest.raises(type(exception)): + async with session: + async for _event in session: + pass + + # Verify the exact exception is stored + assert session._stored_exception == exception + assert session._closed is True + + # Reset for next iteration + fake_model._is_connected = False + fake_model._listeners.clear() + + @pytest.mark.asyncio + async def test_multiple_exception_handling_behavior( + self, fake_model: FakeRealtimeModel, fake_agent + ): + """Test behavior when multiple exceptions occur before consumption.""" + # Create multiple exceptions + first_exception = ValueError("First error") + second_exception = RuntimeError("Second error") + + first_event = RealtimeModelExceptionEvent( + exception=first_exception, context="First context" + ) + second_event = RealtimeModelExceptionEvent( + exception=second_exception, context="Second context" + ) + + session = RealtimeSession(fake_model, fake_agent, None) + fake_model.set_next_events([first_event, second_event]) + + # Start session and let events process + async with session: + # Give time for events to be processed + await asyncio.sleep(0.05) + + # The first exception should be stored (second should overwrite, but that's + # the current behavior). In practice, once an exception occurs, cleanup + # should prevent further processing + assert session._stored_exception is not None + assert session._closed is True + + @pytest.mark.asyncio + async def test_exception_during_guardrail_processing( + self, fake_model: FakeRealtimeModel, fake_agent + ): + """Test that exceptions don't interfere with guardrail task cleanup.""" + # Create exception event + test_exception = RuntimeError("Processing error") + exception_event = RealtimeModelExceptionEvent( + exception=test_exception, context="Processing failed" + ) + + session = RealtimeSession(fake_model, fake_agent, None) + + # Add some fake guardrail tasks + fake_task1 = Mock() + fake_task1.done.return_value = False + fake_task1.cancel = Mock() + + fake_task2 = Mock() + fake_task2.done.return_value = True + fake_task2.cancel = Mock() + + session._guardrail_tasks = {fake_task1, fake_task2} + + fake_model.set_next_events([exception_event]) + + with pytest.raises(RuntimeError, match="Processing error"): + async with session: + async for _event in session: + pass + + # Verify guardrail tasks were properly cleaned up + fake_task1.cancel.assert_called_once() + fake_task2.cancel.assert_not_called() # Already done + assert len(session._guardrail_tasks) == 0 + + @pytest.mark.asyncio + async def test_normal_events_still_work_before_exception( + self, fake_model: FakeRealtimeModel, fake_agent + ): + """Test that normal events are processed before an exception occurs.""" + # Create normal event followed by exception + normal_event = RealtimeModelErrorEvent(error={"message": "Normal error"}) + exception_event = RealtimeModelExceptionEvent( + exception=ValueError("Fatal error"), context="Fatal context" + ) + + session = RealtimeSession(fake_model, fake_agent, None) + fake_model.set_next_events([normal_event, exception_event]) + + events_received = [] + + with pytest.raises(ValueError, match="Fatal error"): + async with session: + async for event in session: + events_received.append(event) + + # Should have received events before exception + assert len(events_received) >= 1 + # Look for the error event (might not be first due to history_updated + # being emitted initially) + error_events = [e for e in events_received if hasattr(e, "type") and e.type == "error"] + assert len(error_events) >= 1 + assert isinstance(error_events[0], RealtimeError) diff --git a/tests/voice/test_openai_stt.py b/tests/voice/test_openai_stt.py index 89b5cca70..ecc41f2e2 100644 --- a/tests/voice/test_openai_stt.py +++ b/tests/voice/test_openai_stt.py @@ -186,7 +186,7 @@ async def test_stream_audio_sends_correct_json(): @pytest.mark.asyncio async def test_transcription_event_puts_output_in_queue(): """ - Test that a 'conversation.item.input_audio_transcription.completed' event + Test that a 'input_audio_transcription_completed' event yields a transcript from transcribe_turns(). """ mock_ws = create_mock_websocket( @@ -196,7 +196,7 @@ async def test_transcription_event_puts_output_in_queue(): # Once configured, we mock a completed transcription event: json.dumps( { - "type": "conversation.item.input_audio_transcription.completed", + "type": "input_audio_transcription_completed", "transcript": "Hello world!", } ), diff --git a/uv.lock b/uv.lock index d882c9bc5..7d0621d88 100644 --- a/uv.lock +++ b/uv.lock @@ -1461,7 +1461,7 @@ wheels = [ [[package]] name = "openai" -version = "1.87.0" +version = "1.93.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -1473,14 +1473,14 @@ dependencies = [ { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/47/ed/2b3f6c7e950784e9442115ab8ebeff514d543fb33da10607b39364645a75/openai-1.87.0.tar.gz", hash = "sha256:5c69764171e0db9ef993e7a4d8a01fd8ff1026b66f8bdd005b9461782b6e7dfc", size = 470880, upload-time = "2025-06-16T19:04:26.316Z" } +sdist = { url = "https://files.pythonhosted.org/packages/5e/a8/e4427729da048cb33bda15e70f09f7520bdf3577bafc546b135ecb36af7d/openai-1.93.1.tar.gz", hash = "sha256:11eb8932965d0f79ecc4cb38a60a0c4cef4bcd5fcf08b99fc9a399fa5f1e50ab", size = 487124, upload-time = "2025-07-07T16:40:38.389Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/36/ac/313ded47ce1d5bc2ec02ed5dd5506bf5718678a4655ac20f337231d9aae3/openai-1.87.0-py3-none-any.whl", hash = "sha256:f9bcae02ac4fff6522276eee85d33047335cfb692b863bd8261353ce4ada5692", size = 734368, upload-time = "2025-06-16T19:04:23.181Z" }, + { url = "https://files.pythonhosted.org/packages/64/4f/875e5af1fb4e5ed4ea9e4a88f482d9ca2e48932105605b6c516e9a14de25/openai-1.93.1-py3-none-any.whl", hash = "sha256:a2c2946c4f21346d4902311a7440381fd8a33466ee7ca688133d1cad29a9357c", size = 755081, upload-time = "2025-07-07T16:40:36.585Z" }, ] [[package]] name = "openai-agents" -version = "0.1.0" +version = "0.2.0" source = { editable = "." } dependencies = [ { name = "griffe" }, @@ -1496,6 +1496,9 @@ dependencies = [ litellm = [ { name = "litellm" }, ] +realtime = [ + { name = "websockets" }, +] viz = [ { name = "graphviz" }, ] @@ -1536,14 +1539,15 @@ requires-dist = [ { name = "litellm", marker = "extra == 'litellm'", specifier = ">=1.67.4.post1,<2" }, { name = "mcp", marker = "python_full_version >= '3.10'", specifier = ">=1.9.4,<2" }, { name = "numpy", marker = "python_full_version >= '3.10' and extra == 'voice'", specifier = ">=2.2.0,<3" }, - { name = "openai", specifier = ">=1.87.0" }, + { name = "openai", specifier = ">=1.93.1,<2" }, { name = "pydantic", specifier = ">=2.10,<3" }, { name = "requests", specifier = ">=2.0,<3" }, { name = "types-requests", specifier = ">=2.0,<3" }, { name = "typing-extensions", specifier = ">=4.12.2,<5" }, + { name = "websockets", marker = "extra == 'realtime'", specifier = ">=15.0,<16" }, { name = "websockets", marker = "extra == 'voice'", specifier = ">=15.0,<16" }, ] -provides-extras = ["voice", "viz", "litellm"] +provides-extras = ["voice", "viz", "litellm", "realtime"] [package.metadata.requires-dev] dev = [