diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 000000000..ff37db326 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,69 @@ +Welcome to the OpenAI Agents SDK repository. This file contains the main points for new contributors. + +## Repository overview + +- **Source code**: `src/agents/` contains the implementation. +- **Tests**: `tests/` with a short guide in `tests/README.md`. +- **Examples**: under `examples/`. +- **Documentation**: markdown pages live in `docs/` with `mkdocs.yml` controlling the site. +- **Utilities**: developer commands are defined in the `Makefile`. +- **PR template**: `.github/PULL_REQUEST_TEMPLATE/pull_request_template.md` describes the information every PR must include. + +## Local workflow + +1. Format, lint and type‑check your changes: + + ```bash + make format + make lint + make mypy + ``` + +2. Run the tests: + + ```bash + make tests + ``` + + To run a single test, use `uv run pytest -s -k <test_name>`. + +3. Build the documentation (optional but recommended for docs changes): + + ```bash + make build-docs + ``` + + Coverage can be generated with `make coverage`. + +## Snapshot tests + +Some tests rely on inline snapshots. See `tests/README.md` for details on updating them: + +```bash +make snapshots-fix # update existing snapshots +make snapshots-create # create new snapshots +``` + +Run `make tests` again after updating snapshots to ensure they pass. + +## Style notes + +- Write comments as full sentences and end them with a period. + +## Pull request expectations + +PRs should use the template located at `.github/PULL_REQUEST_TEMPLATE/pull_request_template.md`. Provide a summary, test plan and issue number if applicable, then check that: + +- New tests are added when needed. +- Documentation is updated. +- `make lint` and `make format` have been run. +- The full test suite passes. + +Commit messages should be concise and written in the imperative mood. Small, focused commits are preferred. 
+ +## What reviewers look for + +- Tests covering new behaviour. +- Consistent style: code formatted with `ruff format`, imports sorted, and type hints passing `mypy`. +- Clear documentation for any public API changes. +- Clean history and a helpful PR description. diff --git a/README.md b/README.md index 7dcd97b33..785177916 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,9 @@ The OpenAI Agents SDK is a lightweight yet powerful framework for building multi Image of the Agents Tracing UI +> [!NOTE] +> Looking for the JavaScript/TypeScript version? Check out [Agents SDK JS/TS](https://github.com/openai/openai-agents-js). + ### Core concepts: 1. [**Agents**](https://openai.github.io/openai-agents-python/agents): LLMs configured with instructions, tools, guardrails, and handoffs diff --git a/docs/ja/mcp.md b/docs/ja/mcp.md index 7cdaa57ee..09804beb2 100644 --- a/docs/ja/mcp.md +++ b/docs/ja/mcp.md @@ -12,12 +12,13 @@ Agents SDK は MCP をサポートしており、これにより幅広い MCP ## MCP サーバー -現在、MCP 仕様では使用するトランスポート方式に基づき 2 種類のサーバーが定義されています。 +現在、MCP 仕様では使用するトランスポート方式に基づき 3 種類のサーバーが定義されています。 -1. **stdio** サーバー: アプリケーションのサブプロセスとして実行されます。ローカルで動かすイメージです。 +1. **stdio** サーバー: アプリケーションのサブプロセスとして実行されます。ローカルで動かすイメージです。 2. **HTTP over SSE** サーバー: リモートで動作し、 URL 経由で接続します。 +3. 
**Streamable HTTP** サーバー: MCP 仕様に定義された Streamable HTTP トランスポートを使用してリモートで動作します。 -これらのサーバーへは [`MCPServerStdio`][agents.mcp.server.MCPServerStdio] と [`MCPServerSse`][agents.mcp.server.MCPServerSse] クラスを使用して接続できます。 +これらのサーバーへは [`MCPServerStdio`][agents.mcp.server.MCPServerStdio]、[`MCPServerSse`][agents.mcp.server.MCPServerSse]、[`MCPServerStreamableHttp`][agents.mcp.server.MCPServerStreamableHttp] クラスを使用して接続できます。 たとえば、[公式 MCP filesystem サーバー](https://www.npmjs.com/package/@modelcontextprotocol/server-filesystem)を利用する場合は次のようになります。 @@ -46,7 +47,7 @@ agent=Agent( ## キャッシュ -エージェントが実行されるたびに、MCP サーバーへ `list_tools()` が呼び出されます。サーバーがリモートの場合は特にレイテンシが発生します。ツール一覧を自動でキャッシュしたい場合は、[`MCPServerStdio`][agents.mcp.server.MCPServerStdio] と [`MCPServerSse`][agents.mcp.server.MCPServerSse] の両方に `cache_tools_list=True` を渡してください。ツール一覧が変更されないと確信できる場合のみ使用してください。 +エージェントが実行されるたびに、MCP サーバーへ `list_tools()` が呼び出されます。サーバーがリモートの場合は特にレイテンシが発生します。ツール一覧を自動でキャッシュしたい場合は、[`MCPServerStdio`][agents.mcp.server.MCPServerStdio]、[`MCPServerSse`][agents.mcp.server.MCPServerSse]、[`MCPServerStreamableHttp`][agents.mcp.server.MCPServerStreamableHttp] の各クラスに `cache_tools_list=True` を渡してください。ツール一覧が変更されないと確信できる場合のみ使用してください。 キャッシュを無効化したい場合は、サーバーで `invalidate_tools_cache()` を呼び出します。 diff --git a/docs/ja/repl.md b/docs/ja/repl.md new file mode 100644 index 000000000..108c12b90 --- /dev/null +++ b/docs/ja/repl.md @@ -0,0 +1,22 @@ +--- +search: + exclude: true +--- +# REPL ユーティリティ + +`run_demo_loop` を使うと、ターミナルから手軽にエージェントを試せます。 + +```python +import asyncio +from agents import Agent, run_demo_loop + +async def main() -> None: + agent = Agent(name="Assistant", instructions="あなたは親切なアシスタントです") + await run_demo_loop(agent) + +if __name__ == "__main__": + asyncio.run(main()) +``` + +`run_demo_loop` は入力を繰り返し受け取り、会話履歴を保持したままエージェントを実行します。既定ではストリーミング出力を表示します。 +`quit` または `exit` と入力するか `Ctrl-D` を押すと終了します。 diff --git a/docs/ja/tools.md b/docs/ja/tools.md index 7ab15e472..cd80092d5 100644 --- a/docs/ja/tools.md +++ b/docs/ja/tools.md @@ 
-17,6 +17,10 @@ OpenAI は [`OpenAIResponsesModel`][agents.models.openai_responses.OpenAIRespons - [`WebSearchTool`][agents.tool.WebSearchTool] はエージェントに Web 検索を行わせます。 - [`FileSearchTool`][agents.tool.FileSearchTool] は OpenAI ベクトルストアから情報を取得します。 - [`ComputerTool`][agents.tool.ComputerTool] はコンピュータ操作タスクを自動化します。 +- [`CodeInterpreterTool`][agents.tool.CodeInterpreterTool] はサンドボックス環境でコードを実行します。 +- [`HostedMCPTool`][agents.tool.HostedMCPTool] はリモート MCP サーバーのツールをモデルから直接利用できるようにします。 +- [`ImageGenerationTool`][agents.tool.ImageGenerationTool] はプロンプトから画像を生成します。 +- [`LocalShellTool`][agents.tool.LocalShellTool] はローカルマシンでシェルコマンドを実行します。 ```python from agents import Agent, FileSearchTool, Runner, WebSearchTool diff --git a/docs/ja/tracing.md b/docs/ja/tracing.md index 0e0d0e77d..d67200ce4 100644 --- a/docs/ja/tracing.md +++ b/docs/ja/tracing.md @@ -119,4 +119,5 @@ async def main(): - [Comet Opik](https://www.comet.com/docs/opik/tracing/integrations/openai_agents) - [Langfuse](https://langfuse.com/docs/integrations/openaiagentssdk/openai-agents) - [Langtrace](https://docs.langtrace.ai/supported-integrations/llm-frameworks/openai-agents-sdk) -- [Okahu‑Monocle](https://github.com/monocle2ai/monocle) \ No newline at end of file +- [Okahu‑Monocle](https://github.com/monocle2ai/monocle) +- [Portkey AI](https://portkey.ai/docs/integrations/agents/openai-agents) diff --git a/docs/mcp.md b/docs/mcp.md index e279a25e0..76d142029 100644 --- a/docs/mcp.md +++ b/docs/mcp.md @@ -8,12 +8,13 @@ The Agents SDK has support for MCP. This enables you to use a wide range of MCP ## MCP servers -Currently, the MCP spec defines two kinds of servers, based on the transport mechanism they use: +Currently, the MCP spec defines three kinds of servers, based on the transport mechanism they use: 1. **stdio** servers run as a subprocess of your application. You can think of them as running "locally". 2. **HTTP over SSE** servers run remotely. You connect to them via a URL. +3. 
**Streamable HTTP** servers run remotely using the Streamable HTTP transport defined in the MCP spec. -You can use the [`MCPServerStdio`][agents.mcp.server.MCPServerStdio] and [`MCPServerSse`][agents.mcp.server.MCPServerSse] classes to connect to these servers. +You can use the [`MCPServerStdio`][agents.mcp.server.MCPServerStdio], [`MCPServerSse`][agents.mcp.server.MCPServerSse], and [`MCPServerStreamableHttp`][agents.mcp.server.MCPServerStreamableHttp] classes to connect to these servers. For example, this is how you'd use the [official MCP filesystem server](https://www.npmjs.com/package/@modelcontextprotocol/server-filesystem). @@ -42,7 +43,7 @@ agent=Agent( ## Caching -Every time an Agent runs, it calls `list_tools()` on the MCP server. This can be a latency hit, especially if the server is a remote server. To automatically cache the list of tools, you can pass `cache_tools_list=True` to both [`MCPServerStdio`][agents.mcp.server.MCPServerStdio] and [`MCPServerSse`][agents.mcp.server.MCPServerSse]. You should only do this if you're certain the tool list will not change. +Every time an Agent runs, it calls `list_tools()` on the MCP server. This can be a latency hit, especially if the server is a remote server. To automatically cache the list of tools, you can pass `cache_tools_list=True` to [`MCPServerStdio`][agents.mcp.server.MCPServerStdio], [`MCPServerSse`][agents.mcp.server.MCPServerSse], and [`MCPServerStreamableHttp`][agents.mcp.server.MCPServerStreamableHttp]. You should only do this if you're certain the tool list will not change. If you want to invalidate the cache, you can call `invalidate_tools_cache()` on the servers. 
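The caching behaviour described in the `docs/mcp.md` hunk above can be illustrated with a small stand-in. This is a hedged sketch, not the SDK's implementation: `FakeServer`, its `fetch_count` counter, and the hard-coded tool names are hypothetical, and only the `cache_tools_list` flag and the `list_tools()` / `invalidate_tools_cache()` names are taken from the docs.

```python
import asyncio
from typing import Optional


class FakeServer:
    """Hypothetical stand-in for an MCP server wrapper; not part of the SDK."""

    def __init__(self, cache_tools_list: bool = False) -> None:
        self.cache_tools_list = cache_tools_list
        self.fetch_count = 0  # how many times the "remote" list was fetched
        self._cached: Optional[list] = None

    async def list_tools(self) -> list:
        # Serve from the cache only when caching is enabled and populated.
        if self.cache_tools_list and self._cached is not None:
            return self._cached
        self.fetch_count += 1  # simulates the round trip to the server
        tools = ["add", "get_current_weather", "get_secret_word"]
        if self.cache_tools_list:
            self._cached = tools
        return tools

    def invalidate_tools_cache(self) -> None:
        # Drop the cached list so the next call fetches again.
        self._cached = None


async def demo() -> int:
    server = FakeServer(cache_tools_list=True)
    await server.list_tools()        # first call fetches
    await server.list_tools()        # second call is served from the cache
    server.invalidate_tools_cache()
    await server.list_tools()        # fetched again after invalidation
    return server.fetch_count


fetches = asyncio.run(demo())
print(fetches)
```

With caching enabled, only the first call and the call after invalidation reach the simulated server, which is why the docs recommend the flag only when the tool list is known to be stable.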
diff --git a/docs/ref/repl.md b/docs/ref/repl.md new file mode 100644 index 000000000..a064a9bff --- /dev/null +++ b/docs/ref/repl.md @@ -0,0 +1,6 @@ +# `repl` + +::: agents.repl + options: + members: + - run_demo_loop diff --git a/docs/release.md b/docs/release.md new file mode 100644 index 000000000..3bfc6e9b8 --- /dev/null +++ b/docs/release.md @@ -0,0 +1,18 @@ +# Release process + +The project follows a slightly modified version of semantic versioning using the form `0.Y.Z`. The leading `0` indicates the SDK is still evolving rapidly. Increment the components as follows: + +## Minor (`Y`) versions + +We will increase minor versions `Y` for **breaking changes** to any public interfaces that are not marked as beta. For example, going from `0.0.x` to `0.1.x` might include breaking changes. + +If you don't want breaking changes, we recommend pinning to `0.0.x` versions in your project. + +## Patch (`Z`) versions + +We will increment `Z` for non-breaking changes: + +- Bug fixes +- New features +- Changes to private interfaces +- Updates to beta features diff --git a/docs/repl.md b/docs/repl.md new file mode 100644 index 000000000..073b87f51 --- /dev/null +++ b/docs/repl.md @@ -0,0 +1,19 @@ +# REPL utility + +The SDK provides `run_demo_loop` for quick interactive testing. + +```python +import asyncio +from agents import Agent, run_demo_loop + +async def main() -> None: + agent = Agent(name="Assistant", instructions="You are a helpful assistant.") + await run_demo_loop(agent) + +if __name__ == "__main__": + asyncio.run(main()) +``` + +`run_demo_loop` prompts for user input in a loop, keeping the conversation +history between turns. By default it streams model output as it is produced. +Type `quit` or `exit` (or press `Ctrl-D`) to leave the loop. 
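The loop behaviour that `docs/repl.md` describes (prompt repeatedly, keep history between turns, stop on `quit`, `exit`, or end of input) can be sketched without the SDK. The names `demo_loop` and `echo_agent` are hypothetical, the loop is synchronous and non-streaming for brevity, and the case-insensitive match is an assumption of this sketch rather than documented behaviour of `run_demo_loop`.

```python
from typing import Callable, Iterable


def demo_loop(lines: Iterable[str], respond: Callable[[list], str]) -> list:
    """Hypothetical, synchronous imitation of a REPL loop; not the SDK code."""
    history: list = []
    for line in lines:  # running out of lines plays the role of Ctrl-D (EOF)
        text = line.strip()
        # Assumption of this sketch: exit words are matched case-insensitively.
        if text.lower() in {"quit", "exit"}:
            break
        history.append(f"user: {text}")
        # The responder sees the whole history, so context carries across turns.
        history.append(f"assistant: {respond(history)}")
    return history


def echo_agent(history: list) -> str:
    # Stand-in for a model call: report how many user turns it has seen.
    turns = sum(1 for entry in history if entry.startswith("user: "))
    return f"seen {turns} user turn(s)"


transcript = demo_loop(["hi", "again", "exit", "ignored"], echo_agent)
print("\n".join(transcript))
```

Passing the full history to the responder on every turn is what lets the second reply reflect the first exchange. Input after `exit` is never read.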
diff --git a/docs/tools.md b/docs/tools.md index 5fe2ecedb..4e9a20d32 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -13,6 +13,10 @@ OpenAI offers a few built-in tools when using the [`OpenAIResponsesModel`][agent - The [`WebSearchTool`][agents.tool.WebSearchTool] lets an agent search the web. - The [`FileSearchTool`][agents.tool.FileSearchTool] allows retrieving information from your OpenAI Vector Stores. - The [`ComputerTool`][agents.tool.ComputerTool] allows automating computer use tasks. +- The [`CodeInterpreterTool`][agents.tool.CodeInterpreterTool] lets the LLM execute code in a sandboxed environment. +- The [`HostedMCPTool`][agents.tool.HostedMCPTool] exposes a remote MCP server's tools to the model. +- The [`ImageGenerationTool`][agents.tool.ImageGenerationTool] generates images from a prompt. +- The [`LocalShellTool`][agents.tool.LocalShellTool] runs shell commands on your machine. ```python from agents import Agent, FileSearchTool, Runner, WebSearchTool @@ -266,7 +270,7 @@ The `agent.as_tool` function is a convenience method to make it easy to turn an ```python @function_tool async def run_my_agent() -> str: - """A tool that runs the agent with custom configs". 
+ """A tool that runs the agent with custom configs""" agent = Agent(name="My agent", instructions="...") diff --git a/docs/tracing.md b/docs/tracing.md index dd883c5aa..c7776ad7b 100644 --- a/docs/tracing.md +++ b/docs/tracing.md @@ -115,3 +115,5 @@ To customize this default setup, to send traces to alternative or additional bac - [Langfuse](https://langfuse.com/docs/integrations/openaiagentssdk/openai-agents) - [Langtrace](https://docs.langtrace.ai/supported-integrations/llm-frameworks/openai-agents-sdk) - [Okahu-Monocle](https://github.com/monocle2ai/monocle) +- [Galileo](https://v2docs.galileo.ai/integrations/openai-agent-integration#openai-agent-integration) +- [Portkey AI](https://portkey.ai/docs/integrations/agents/openai-agents) diff --git a/examples/agent_patterns/input_guardrails.py b/examples/agent_patterns/input_guardrails.py index 154535511..18ab9d2a7 100644 --- a/examples/agent_patterns/input_guardrails.py +++ b/examples/agent_patterns/input_guardrails.py @@ -20,7 +20,7 @@ Guardrails are checks that run in parallel to the agent's execution. They can be used to do things like: - Check if input messages are off-topic -- Check that output messages don't violate any policies +- Check that input messages don't violate any policies - Take over control of the agent's execution if an unexpected input is detected In this example, we'll setup an input guardrail that trips if the user is asking to do math homework. 
diff --git a/examples/hosted_mcp/__init__.py b/examples/hosted_mcp/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/hosted_mcp/approvals.py b/examples/hosted_mcp/approvals.py new file mode 100644 index 000000000..3080a1d63 --- /dev/null +++ b/examples/hosted_mcp/approvals.py @@ -0,0 +1,61 @@ +import argparse +import asyncio + +from agents import ( + Agent, + HostedMCPTool, + MCPToolApprovalFunctionResult, + MCPToolApprovalRequest, + Runner, +) + +"""This example demonstrates how to use the hosted MCP support in the OpenAI Responses API, with +approval callbacks.""" + + +def approval_callback(request: MCPToolApprovalRequest) -> MCPToolApprovalFunctionResult: + answer = input(f"Approve running the tool `{request.data.name}`? (y/n) ") + result: MCPToolApprovalFunctionResult = {"approve": answer == "y"} + if not result["approve"]: + result["reason"] = "User denied" + return result + + +async def main(verbose: bool, stream: bool): + agent = Agent( + name="Assistant", + tools=[ + HostedMCPTool( + tool_config={ + "type": "mcp", + "server_label": "gitmcp", + "server_url": "https://gitmcp.io/openai/codex", + "require_approval": "always", + }, + on_approval_request=approval_callback, + ) + ], + ) + + if stream: + result = Runner.run_streamed(agent, "Which language is this repo written in?") + async for event in result.stream_events(): + if event.type == "run_item_stream_event": + print(f"Got event of type {event.item.__class__.__name__}") + print(f"Done streaming; final result: {result.final_output}") + else: + res = await Runner.run(agent, "Which language is this repo written in?") + print(res.final_output) + + if verbose: + for item in res.new_items: + print(item) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--verbose", action="store_true", default=False) + parser.add_argument("--stream", action="store_true", default=False) + args = parser.parse_args() + + asyncio.run(main(args.verbose, 
args.stream)) diff --git a/examples/hosted_mcp/simple.py b/examples/hosted_mcp/simple.py new file mode 100644 index 000000000..895fdfbe0 --- /dev/null +++ b/examples/hosted_mcp/simple.py @@ -0,0 +1,47 @@ +import argparse +import asyncio + +from agents import Agent, HostedMCPTool, Runner + +"""This example demonstrates how to use the hosted MCP support in the OpenAI Responses API, with +approvals not required for any tools. You should only use this for trusted MCP servers.""" + + +async def main(verbose: bool, stream: bool): + agent = Agent( + name="Assistant", + tools=[ + HostedMCPTool( + tool_config={ + "type": "mcp", + "server_label": "gitmcp", + "server_url": "https://gitmcp.io/openai/codex", + "require_approval": "never", + } + ) + ], + ) + + if stream: + result = Runner.run_streamed(agent, "Which language is this repo written in?") + async for event in result.stream_events(): + if event.type == "run_item_stream_event": + print(f"Got event of type {event.item.__class__.__name__}") + print(f"Done streaming; final result: {result.final_output}") + else: + res = await Runner.run(agent, "Which language is this repo written in?") + print(res.final_output) + # The repository is primarily written in multiple languages, including Rust and TypeScript... + + if verbose: + for item in res.new_items: + print(item) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--verbose", action="store_true", default=False) + parser.add_argument("--stream", action="store_true", default=False) + args = parser.parse_args() + + asyncio.run(main(args.verbose, args.stream)) diff --git a/examples/mcp/streamablehttp_example/README.md b/examples/mcp/streamablehttp_example/README.md new file mode 100644 index 000000000..a07fe19be --- /dev/null +++ b/examples/mcp/streamablehttp_example/README.md @@ -0,0 +1,13 @@ +# MCP Streamable HTTP Example + +This example uses a local Streamable HTTP server in [server.py](server.py). 
+ +Run the example via: + +``` +uv run python examples/mcp/streamablehttp_example/main.py +``` + +## Details + +The example uses the `MCPServerStreamableHttp` class from `agents.mcp`. The server runs in a sub-process at `http://localhost:8000/mcp`. diff --git a/examples/mcp/streamablehttp_example/main.py b/examples/mcp/streamablehttp_example/main.py new file mode 100644 index 000000000..cc95e798b --- /dev/null +++ b/examples/mcp/streamablehttp_example/main.py @@ -0,0 +1,83 @@ +import asyncio +import os +import shutil +import subprocess +import time +from typing import Any + +from agents import Agent, Runner, gen_trace_id, trace +from agents.mcp import MCPServer, MCPServerStreamableHttp +from agents.model_settings import ModelSettings + + +async def run(mcp_server: MCPServer): + agent = Agent( + name="Assistant", + instructions="Use the tools to answer the questions.", + mcp_servers=[mcp_server], + model_settings=ModelSettings(tool_choice="required"), + ) + + # Use the `add` tool to add two numbers + message = "Add these numbers: 7 and 22." + print(f"Running: {message}") + result = await Runner.run(starting_agent=agent, input=message) + print(result.final_output) + + # Run the `get_current_weather` tool + message = "What's the weather in Tokyo?" + print(f"\n\nRunning: {message}") + result = await Runner.run(starting_agent=agent, input=message) + print(result.final_output) + + # Run the `get_secret_word` tool + message = "What's the secret word?" 
+ print(f"\n\nRunning: {message}") + result = await Runner.run(starting_agent=agent, input=message) + print(result.final_output) + + +async def main(): + async with MCPServerStreamableHttp( + name="Streamable HTTP Python Server", + params={ + "url": "http://localhost:8000/mcp", + }, + ) as server: + trace_id = gen_trace_id() + with trace(workflow_name="Streamable HTTP Example", trace_id=trace_id): + print(f"View trace: https://platform.openai.com/traces/trace?trace_id={trace_id}\n") + await run(server) + + +if __name__ == "__main__": + # Let's make sure the user has uv installed + if not shutil.which("uv"): + raise RuntimeError( + "uv is not installed. Please install it: https://docs.astral.sh/uv/getting-started/installation/" + ) + + # We'll run the Streamable HTTP server in a subprocess. Usually this would be a remote server, but for this + # demo, we'll run it locally at http://localhost:8000/mcp + process: subprocess.Popen[Any] | None = None + try: + this_dir = os.path.dirname(os.path.abspath(__file__)) + server_file = os.path.join(this_dir, "server.py") + + print("Starting Streamable HTTP server at http://localhost:8000/mcp ...") + + # Run `uv run server.py` to start the Streamable HTTP server + process = subprocess.Popen(["uv", "run", server_file]) + # Give it 3 seconds to start + time.sleep(3) + + print("Streamable HTTP server started. 
Running example...\n\n") + except Exception as e: + print(f"Error starting Streamable HTTP server: {e}") + exit(1) + + try: + asyncio.run(main()) + finally: + if process: + process.terminate() diff --git a/examples/mcp/streamablehttp_example/server.py b/examples/mcp/streamablehttp_example/server.py new file mode 100644 index 000000000..d8f839652 --- /dev/null +++ b/examples/mcp/streamablehttp_example/server.py @@ -0,0 +1,33 @@ +import random + +import requests +from mcp.server.fastmcp import FastMCP + +# Create server +mcp = FastMCP("Echo Server") + + +@mcp.tool() +def add(a: int, b: int) -> int: + """Add two numbers""" + print(f"[debug-server] add({a}, {b})") + return a + b + + +@mcp.tool() +def get_secret_word() -> str: + print("[debug-server] get_secret_word()") + return random.choice(["apple", "banana", "cherry"]) + + +@mcp.tool() +def get_current_weather(city: str) -> str: + print(f"[debug-server] get_current_weather({city})") + + endpoint = "https://wttr.in" + response = requests.get(f"{endpoint}/{city}") + return response.text + + +if __name__ == "__main__": + mcp.run(transport="streamable-http") diff --git a/examples/research_bot/agents/search_agent.py b/examples/research_bot/agents/search_agent.py index 0212ce5b5..61f91701f 100644 --- a/examples/research_bot/agents/search_agent.py +++ b/examples/research_bot/agents/search_agent.py @@ -3,7 +3,7 @@ INSTRUCTIONS = ( "You are a research assistant. Given a search term, you search the web for that term and " - "produce a concise summary of the results. The summary must 2-3 paragraphs and less than 300 " + "produce a concise summary of the results. The summary must be 2-3 paragraphs and less than 300 " "words. Capture the main points. Write succinctly, no need to have complete sentences or good " "grammar. This will be consumed by someone synthesizing a report, so its vital you capture the " "essence and ignore any fluff. 
Do not include any additional commentary other than the summary " diff --git a/examples/tools/code_interpreter.py b/examples/tools/code_interpreter.py new file mode 100644 index 000000000..a5843ce3f --- /dev/null +++ b/examples/tools/code_interpreter.py @@ -0,0 +1,34 @@ +import asyncio + +from agents import Agent, CodeInterpreterTool, Runner, trace + + +async def main(): + agent = Agent( + name="Code interpreter", + instructions="You love doing math.", + tools=[ + CodeInterpreterTool( + tool_config={"type": "code_interpreter", "container": {"type": "auto"}}, + ) + ], + ) + + with trace("Code interpreter example"): + print("Solving math problem...") + result = Runner.run_streamed(agent, "What is the square root of 273 * 312821 plus 1782?") + async for event in result.stream_events(): + if ( + event.type == "run_item_stream_event" + and event.item.type == "tool_call_item" + and event.item.raw_item.type == "code_interpreter_call" + ): + print(f"Code interpreter code:\n```\n{event.item.raw_item.code}\n```\n") + elif event.type == "run_item_stream_event": + print(f"Other event: {event.item.type}") + + print(f"Final output: {result.final_output}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/tools/image_generator.py b/examples/tools/image_generator.py new file mode 100644 index 000000000..fd6fcc6ba --- /dev/null +++ b/examples/tools/image_generator.py @@ -0,0 +1,54 @@ +import asyncio +import base64 +import os +import subprocess +import sys +import tempfile + +from agents import Agent, ImageGenerationTool, Runner, trace + + +def open_file(path: str) -> None: + if sys.platform.startswith("darwin"): + subprocess.run(["open", path], check=False) # macOS + elif os.name == "nt": # Windows + os.startfile(path) # type: ignore + elif os.name == "posix": + subprocess.run(["xdg-open", path], check=False) # Linux/Unix + else: + print(f"Don't know how to open files on this platform: {sys.platform}") + + +async def main(): + agent = Agent( + name="Image 
generator", + instructions="You are a helpful agent.", + tools=[ + ImageGenerationTool( + tool_config={"type": "image_generation", "quality": "low"}, + ) + ], + ) + + with trace("Image generation example"): + print("Generating image, this may take a while...") + result = await Runner.run( + agent, "Create an image of a frog eating a pizza, comic book style." + ) + print(result.final_output) + for item in result.new_items: + if ( + item.type == "tool_call_item" + and item.raw_item.type == "image_generation_call" + and (img_result := item.raw_item.result) + ): + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp: + tmp.write(base64.b64decode(img_result)) + temp_path = tmp.name + + # Open the image + open_file(temp_path) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/mkdocs.yml b/mkdocs.yml index ad719670c..b79e6454f 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -59,6 +59,7 @@ plugins: - running_agents.md - results.md - streaming.md + - repl.md - tools.md - mcp.md - handoffs.md @@ -71,6 +72,7 @@ plugins: - models/litellm.md - config.md - visualization.md + - release.md - Voice agents: - voice/quickstart.md - voice/pipeline.md @@ -80,6 +82,7 @@ plugins: - ref/index.md - ref/agent.md - ref/run.md + - ref/repl.md - ref/tool.md - ref/result.md - ref/stream_events.md @@ -139,6 +142,7 @@ plugins: - running_agents.md - results.md - streaming.md + - repl.md - tools.md - mcp.md - handoffs.md diff --git a/pyproject.toml b/pyproject.toml index 22b028ae7..b59073ed8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,19 +1,19 @@ [project] name = "openai-agents" -version = "0.0.14" +version = "0.0.17" description = "OpenAI Agents SDK" readme = "README.md" requires-python = ">=3.9" license = "MIT" authors = [{ name = "OpenAI", email = "support@openai.com" }] dependencies = [ - "openai>=1.76.0", + "openai>=1.81.0", "pydantic>=2.10, <3", "griffe>=1.5.6, <2", "typing-extensions>=4.12.2, <5", "requests>=2.0, <3", "types-requests>=2.0, <3", - 
"mcp>=1.6.0, <2; python_version >= '3.10'", + "mcp>=1.8.0, <2; python_version >= '3.10'", ] classifiers = [ "Typing :: Typed", diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 6d7c90b4f..ee2f2aa16 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -14,6 +14,7 @@ MaxTurnsExceeded, ModelBehaviorError, OutputGuardrailTripwireTriggered, + RunErrorDetails, UserError, ) from .guardrail import ( @@ -44,6 +45,7 @@ from .models.openai_chatcompletions import OpenAIChatCompletionsModel from .models.openai_provider import OpenAIProvider from .models.openai_responses import OpenAIResponsesModel +from .repl import run_demo_loop from .result import RunResult, RunResultStreaming from .run import RunConfig, Runner from .run_context import RunContextWrapper, TContext @@ -54,10 +56,19 @@ StreamEvent, ) from .tool import ( + CodeInterpreterTool, ComputerTool, FileSearchTool, FunctionTool, FunctionToolResult, + HostedMCPTool, + ImageGenerationTool, + LocalShellCommandRequest, + LocalShellExecutor, + LocalShellTool, + MCPToolApprovalFunction, + MCPToolApprovalFunctionResult, + MCPToolApprovalRequest, Tool, WebSearchTool, default_tool_error_function, @@ -150,6 +161,7 @@ def enable_verbose_stdout_logging(): "ToolsToFinalOutputFunction", "ToolsToFinalOutputResult", "Runner", + "run_demo_loop", "Model", "ModelProvider", "ModelTracing", @@ -195,6 +207,7 @@ def enable_verbose_stdout_logging(): "AgentHooks", "RunContextWrapper", "TContext", + "RunErrorDetails", "RunResult", "RunResultStreaming", "RunConfig", @@ -206,8 +219,17 @@ def enable_verbose_stdout_logging(): "FunctionToolResult", "ComputerTool", "FileSearchTool", + "CodeInterpreterTool", + "ImageGenerationTool", + "LocalShellCommandRequest", + "LocalShellExecutor", + "LocalShellTool", "Tool", "WebSearchTool", + "HostedMCPTool", + "MCPToolApprovalFunction", + "MCPToolApprovalRequest", + "MCPToolApprovalFunctionResult", "function_tool", "Usage", "add_trace_processor", diff --git a/src/agents/_run_impl.py 
b/src/agents/_run_impl.py index b5a83685c..e98010555 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -14,6 +14,9 @@ ResponseFunctionWebSearch, ResponseOutputMessage, ) +from openai.types.responses.response_code_interpreter_tool_call import ( + ResponseCodeInterpreterToolCall, +) from openai.types.responses.response_computer_tool_call import ( ActionClick, ActionDoubleClick, @@ -25,7 +28,14 @@ ActionType, ActionWait, ) -from openai.types.responses.response_input_param import ComputerCallOutput +from openai.types.responses.response_input_param import ComputerCallOutput, McpApprovalResponse +from openai.types.responses.response_output_item import ( + ImageGenerationCall, + LocalShellCall, + McpApprovalRequest, + McpCall, + McpListTools, +) from openai.types.responses.response_reasoning_item import ResponseReasoningItem from .agent import Agent, ToolsToFinalOutputResult @@ -38,6 +48,9 @@ HandoffCallItem, HandoffOutputItem, ItemHelpers, + MCPApprovalRequestItem, + MCPApprovalResponseItem, + MCPListToolsItem, MessageOutputItem, ModelResponse, ReasoningItem, @@ -52,7 +65,16 @@ from .models.interface import ModelTracing from .run_context import RunContextWrapper, TContext from .stream_events import RunItemStreamEvent, StreamEvent -from .tool import ComputerTool, FunctionTool, FunctionToolResult, Tool +from .tool import ( + ComputerTool, + FunctionTool, + FunctionToolResult, + HostedMCPTool, + LocalShellCommandRequest, + LocalShellTool, + MCPToolApprovalRequest, + Tool, +) from .tracing import ( SpanError, Trace, @@ -112,15 +134,29 @@ class ToolRunComputerAction: computer_tool: ComputerTool +@dataclass +class ToolRunMCPApprovalRequest: + request_item: McpApprovalRequest + mcp_tool: HostedMCPTool + + +@dataclass +class ToolRunLocalShellCall: + tool_call: LocalShellCall + local_shell_tool: LocalShellTool + + @dataclass class ProcessedResponse: new_items: list[RunItem] handoffs: list[ToolRunHandoff] functions: list[ToolRunFunction] computer_actions: 
list[ToolRunComputerAction] + local_shell_calls: list[ToolRunLocalShellCall] tools_used: list[str] # Names of all tools used, including hosted tools + mcp_approval_requests: list[ToolRunMCPApprovalRequest] # Only requests with callbacks - def has_tools_to_run(self) -> bool: + def has_tools_or_approvals_to_run(self) -> bool: # Handoffs, functions and computer actions need local processing # Hosted tools have already run, so there's nothing to do. return any( @@ -128,6 +164,8 @@ def has_tools_to_run(self) -> bool: self.handoffs, self.functions, self.computer_actions, + self.local_shell_calls, + self.mcp_approval_requests, ] ) @@ -226,7 +264,16 @@ async def execute_tools_and_side_effects( new_step_items.extend([result.run_item for result in function_results]) new_step_items.extend(computer_results) - # Second, check if there are any handoffs + # Next, run the MCP approval requests + if processed_response.mcp_approval_requests: + approval_results = await cls.execute_mcp_approval_requests( + agent=agent, + approval_requests=processed_response.mcp_approval_requests, + context_wrapper=context_wrapper, + ) + new_step_items.extend(approval_results) + + # Next, check if there are any handoffs if run_handoffs := processed_response.handoffs: return await cls.execute_handoffs( agent=agent, @@ -240,7 +287,7 @@ async def execute_tools_and_side_effects( run_config=run_config, ) - # Third, we'll check if the tool use should result in a final output + # Next, we'll check if the tool use should result in a final output check_tool_use = await cls._check_for_final_output_from_tools( agent=agent, tool_results=function_results, @@ -295,7 +342,7 @@ async def execute_tools_and_side_effects( ) elif ( not output_schema or output_schema.is_plain_text() - ) and not processed_response.has_tools_to_run(): + ) and not processed_response.has_tools_or_approvals_to_run(): return await cls.execute_final_output( agent=agent, original_input=original_input, @@ -343,10 +390,20 @@ def 
process_model_response( run_handoffs = [] functions = [] computer_actions = [] + local_shell_calls = [] + mcp_approval_requests = [] tools_used: list[str] = [] handoff_map = {handoff.tool_name: handoff for handoff in handoffs} function_map = {tool.name: tool for tool in all_tools if isinstance(tool, FunctionTool)} computer_tool = next((tool for tool in all_tools if isinstance(tool, ComputerTool)), None) + local_shell_tool = next( + (tool for tool in all_tools if isinstance(tool, LocalShellTool)), None + ) + hosted_mcp_server_map = { + tool.tool_config["server_label"]: tool + for tool in all_tools + if isinstance(tool, HostedMCPTool) + } for output in response.output: if isinstance(output, ResponseOutputMessage): @@ -375,6 +432,57 @@ def process_model_response( computer_actions.append( ToolRunComputerAction(tool_call=output, computer_tool=computer_tool) ) + elif isinstance(output, McpApprovalRequest): + items.append(MCPApprovalRequestItem(raw_item=output, agent=agent)) + if output.server_label not in hosted_mcp_server_map: + _error_tracing.attach_error_to_current_span( + SpanError( + message="MCP server label not found", + data={"server_label": output.server_label}, + ) + ) + raise ModelBehaviorError(f"MCP server label {output.server_label} not found") + else: + server = hosted_mcp_server_map[output.server_label] + if server.on_approval_request: + mcp_approval_requests.append( + ToolRunMCPApprovalRequest( + request_item=output, + mcp_tool=server, + ) + ) + else: + logger.warning( + f"MCP server {output.server_label} has no on_approval_request hook" + ) + elif isinstance(output, McpListTools): + items.append(MCPListToolsItem(raw_item=output, agent=agent)) + elif isinstance(output, McpCall): + items.append(ToolCallItem(raw_item=output, agent=agent)) + tools_used.append("mcp") + elif isinstance(output, ImageGenerationCall): + items.append(ToolCallItem(raw_item=output, agent=agent)) + tools_used.append("image_generation") + elif isinstance(output, 
ResponseCodeInterpreterToolCall): + items.append(ToolCallItem(raw_item=output, agent=agent)) + tools_used.append("code_interpreter") + elif isinstance(output, LocalShellCall): + items.append(ToolCallItem(raw_item=output, agent=agent)) + tools_used.append("local_shell") + if not local_shell_tool: + _error_tracing.attach_error_to_current_span( + SpanError( + message="Local shell tool not found", + data={}, + ) + ) + raise ModelBehaviorError( + "Model produced local shell call without a local shell tool." + ) + local_shell_calls.append( + ToolRunLocalShellCall(tool_call=output, local_shell_tool=local_shell_tool) + ) + elif not isinstance(output, ResponseFunctionToolCall): logger.warning(f"Unexpected output type, ignoring: {type(output)}") continue @@ -416,7 +524,9 @@ def process_model_response( handoffs=run_handoffs, functions=functions, computer_actions=computer_actions, + local_shell_calls=local_shell_calls, tools_used=tools_used, + mcp_approval_requests=mcp_approval_requests, ) @classmethod @@ -489,6 +599,30 @@ async def run_single_tool( for tool_run, result in zip(tool_runs, results) ] + @classmethod + async def execute_local_shell_calls( + cls, + *, + agent: Agent[TContext], + calls: list[ToolRunLocalShellCall], + context_wrapper: RunContextWrapper[TContext], + hooks: RunHooks[TContext], + config: RunConfig, + ) -> list[RunItem]: + results: list[RunItem] = [] + # Need to run these serially, because each call can affect the local shell state + for call in calls: + results.append( + await LocalShellAction.execute( + agent=agent, + call=call, + hooks=hooks, + context_wrapper=context_wrapper, + config=config, + ) + ) + return results + @classmethod async def execute_computer_actions( cls, @@ -643,6 +777,40 @@ async def execute_handoffs( next_step=NextStepHandoff(new_agent), ) + @classmethod + async def execute_mcp_approval_requests( + cls, + *, + agent: Agent[TContext], + approval_requests: list[ToolRunMCPApprovalRequest], + context_wrapper: 
RunContextWrapper[TContext], + ) -> list[RunItem]: + async def run_single_approval(approval_request: ToolRunMCPApprovalRequest) -> RunItem: + callback = approval_request.mcp_tool.on_approval_request + assert callback is not None, "Callback is required for MCP approval requests" + maybe_awaitable_result = callback( + MCPToolApprovalRequest(context_wrapper, approval_request.request_item) + ) + if inspect.isawaitable(maybe_awaitable_result): + result = await maybe_awaitable_result + else: + result = maybe_awaitable_result + reason = result.get("reason", None) + raw_item: McpApprovalResponse = { + "approval_request_id": approval_request.request_item.id, + "approve": result["approve"], + "type": "mcp_approval_response", + } + if not result["approve"] and reason: + raw_item["reason"] = reason + return MCPApprovalResponseItem( + raw_item=raw_item, + agent=agent, + ) + + tasks = [run_single_approval(approval_request) for approval_request in approval_requests] + return await asyncio.gather(*tasks) + @classmethod async def execute_final_output( cls, @@ -727,6 +895,11 @@ def stream_step_result_to_queue( event = RunItemStreamEvent(item=item, name="tool_output") elif isinstance(item, ReasoningItem): event = RunItemStreamEvent(item=item, name="reasoning_item_created") + elif isinstance(item, MCPApprovalRequestItem): + event = RunItemStreamEvent(item=item, name="mcp_approval_requested") + elif isinstance(item, MCPListToolsItem): + event = RunItemStreamEvent(item=item, name="mcp_list_tools") + else: logger.warning(f"Unexpected item type: {type(item)}") event = None @@ -919,3 +1092,54 @@ async def _get_screenshot_async( await computer.wait() return await computer.screenshot() + + +class LocalShellAction: + @classmethod + async def execute( + cls, + *, + agent: Agent[TContext], + call: ToolRunLocalShellCall, + hooks: RunHooks[TContext], + context_wrapper: RunContextWrapper[TContext], + config: RunConfig, + ) -> RunItem: + await asyncio.gather( + hooks.on_tool_start(context_wrapper, 
agent, call.local_shell_tool), + ( + agent.hooks.on_tool_start(context_wrapper, agent, call.local_shell_tool) + if agent.hooks + else _coro.noop_coroutine() + ), + ) + + request = LocalShellCommandRequest( + ctx_wrapper=context_wrapper, + data=call.tool_call, + ) + output = call.local_shell_tool.executor(request) + if inspect.isawaitable(output): + result = await output + else: + result = output + + await asyncio.gather( + hooks.on_tool_end(context_wrapper, agent, call.local_shell_tool, result), + ( + agent.hooks.on_tool_end(context_wrapper, agent, call.local_shell_tool, result) + if agent.hooks + else _coro.noop_coroutine() + ), + ) + + return ToolCallOutputItem( + agent=agent, + output=output, + raw_item={ + "type": "local_shell_call_output", + "id": call.tool_call.call_id, + "output": result, + # "id": "out" + call.tool_call.id, # TODO remove this, it should be optional + }, + ) diff --git a/src/agents/agent.py b/src/agents/agent.py index e22f579fa..6adccedd5 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import dataclasses import inspect from collections.abc import Awaitable @@ -17,7 +18,7 @@ from .model_settings import ModelSettings from .models.interface import Model from .run_context import RunContextWrapper, TContext -from .tool import FunctionToolResult, Tool, function_tool +from .tool import FunctionTool, FunctionToolResult, Tool, function_tool from .util import _transforms from .util._types import MaybeAwaitable @@ -246,7 +247,22 @@ async def get_mcp_tools(self) -> list[Tool]: convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False) return await MCPUtil.get_all_function_tools(self.mcp_servers, convert_schemas_to_strict) - async def get_all_tools(self) -> list[Tool]: + async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]: """All agent tools, including MCP tools and function tools.""" mcp_tools = await self.get_mcp_tools() - 
return mcp_tools + self.tools + + async def _check_tool_enabled(tool: Tool) -> bool: + if not isinstance(tool, FunctionTool): + return True + + attr = tool.is_enabled + if isinstance(attr, bool): + return attr + res = attr(run_context, self) + if inspect.isawaitable(res): + return bool(await res) + return bool(res) + + results = await asyncio.gather(*(_check_tool_enabled(t) for t in self.tools)) + enabled: list[Tool] = [t for t, ok in zip(self.tools, results) if ok] + return [*mcp_tools, *enabled] diff --git a/src/agents/agent_output.py b/src/agents/agent_output.py index 066bbd835..ee14956e3 100644 --- a/src/agents/agent_output.py +++ b/src/agents/agent_output.py @@ -38,7 +38,7 @@ def json_schema(self) -> dict[str, Any]: @abc.abstractmethod def is_strict_json_schema(self) -> bool: """Whether the JSON schema is in strict mode. Strict mode constrains the JSON schema - features, but guarantees valis JSON. See here for details: + features, but guarantees valid JSON. See here for details: https://platform.openai.com/docs/guides/structured-outputs#supported-schemas """ pass diff --git a/src/agents/exceptions.py b/src/agents/exceptions.py index 78898f017..c00024c2e 100644 --- a/src/agents/exceptions.py +++ b/src/agents/exceptions.py @@ -1,12 +1,42 @@ -from typing import TYPE_CHECKING +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: + from .agent import Agent from .guardrail import InputGuardrailResult, OutputGuardrailResult + from .items import ModelResponse, RunItem, TResponseInputItem + from .run_context import RunContextWrapper + +from .util._pretty_print import pretty_print_run_error_details + + +@dataclass +class RunErrorDetails: + """Data collected from an agent run when an exception occurs.""" + + input: str | list[TResponseInputItem] + new_items: list[RunItem] + raw_responses: list[ModelResponse] + last_agent: Agent[Any] + context_wrapper: RunContextWrapper[Any] + 
input_guardrail_results: list[InputGuardrailResult] + output_guardrail_results: list[OutputGuardrailResult] + + def __str__(self) -> str: + return pretty_print_run_error_details(self) class AgentsException(Exception): """Base class for all exceptions in the Agents SDK.""" + run_data: RunErrorDetails | None + + def __init__(self, *args: object) -> None: + super().__init__(*args) + self.run_data = None + class MaxTurnsExceeded(AgentsException): """Exception raised when the maximum number of turns is exceeded.""" @@ -15,6 +45,7 @@ class MaxTurnsExceeded(AgentsException): def __init__(self, message: str): self.message = message + super().__init__(message) class ModelBehaviorError(AgentsException): @@ -26,6 +57,7 @@ class ModelBehaviorError(AgentsException): def __init__(self, message: str): self.message = message + super().__init__(message) class UserError(AgentsException): @@ -35,15 +67,16 @@ class UserError(AgentsException): def __init__(self, message: str): self.message = message + super().__init__(message) class InputGuardrailTripwireTriggered(AgentsException): """Exception raised when a guardrail tripwire is triggered.""" - guardrail_result: "InputGuardrailResult" + guardrail_result: InputGuardrailResult """The result data of the guardrail that was triggered.""" - def __init__(self, guardrail_result: "InputGuardrailResult"): + def __init__(self, guardrail_result: InputGuardrailResult): self.guardrail_result = guardrail_result super().__init__( f"Guardrail {guardrail_result.guardrail.__class__.__name__} triggered tripwire" @@ -53,10 +86,10 @@ def __init__(self, guardrail_result: "InputGuardrailResult"): class OutputGuardrailTripwireTriggered(AgentsException): """Exception raised when a guardrail tripwire is triggered.""" - guardrail_result: "OutputGuardrailResult" + guardrail_result: OutputGuardrailResult """The result data of the guardrail that was triggered.""" - def __init__(self, guardrail_result: "OutputGuardrailResult"): + def __init__(self, guardrail_result: 
OutputGuardrailResult): self.guardrail_result = guardrail_result super().__init__( f"Guardrail {guardrail_result.guardrail.__class__.__name__} triggered tripwire" diff --git a/src/agents/extensions/models/litellm_model.py b/src/agents/extensions/models/litellm_model.py index dc672acd4..7cf7a2de5 100644 --- a/src/agents/extensions/models/litellm_model.py +++ b/src/agents/extensions/models/litellm_model.py @@ -5,7 +5,7 @@ from collections.abc import AsyncIterator from typing import Any, Literal, cast, overload -import litellm.types +from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails from agents.exceptions import ModelBehaviorError @@ -107,6 +107,18 @@ async def get_response( input_tokens=response_usage.prompt_tokens, output_tokens=response_usage.completion_tokens, total_tokens=response_usage.total_tokens, + input_tokens_details=InputTokensDetails( + cached_tokens=getattr( + response_usage.prompt_tokens_details, "cached_tokens", 0 + ) + or 0 + ), + output_tokens_details=OutputTokensDetails( + reasoning_tokens=getattr( + response_usage.completion_tokens_details, "reasoning_tokens", 0 + ) + or 0 + ), ) if response.usage else Usage() @@ -269,6 +281,8 @@ async def _fetch_response( extra_kwargs["extra_query"] = model_settings.extra_query if model_settings.metadata: extra_kwargs["metadata"] = model_settings.metadata + if model_settings.extra_body and isinstance(model_settings.extra_body, dict): + extra_kwargs.update(model_settings.extra_body) ret = await litellm.acompletion( model=self.model, diff --git a/src/agents/extensions/visualization.py b/src/agents/extensions/visualization.py index 888e262c3..be762a330 100644 --- a/src/agents/extensions/visualization.py +++ b/src/agents/extensions/visualization.py @@ -1,4 +1,4 @@ -from typing import Optional +from __future__ import annotations import graphviz # type: ignore @@ -31,7 +31,9 @@ def get_main_graph(agent: Agent) -> str: return "".join(parts) -def get_all_nodes(agent: Agent, parent: 
Optional[Agent] = None) -> str: +def get_all_nodes( + agent: Agent, parent: Agent | None = None, visited: set[str] | None = None +) -> str: """ Recursively generates the nodes for the given agent and its handoffs in DOT format. @@ -41,17 +43,23 @@ def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str: Returns: str: The DOT format string representing the nodes. """ + if visited is None: + visited = set() + if agent.name in visited: + return "" + visited.add(agent.name) + parts = [] # Start and end the graph - parts.append( - '"__start__" [label="__start__", shape=ellipse, style=filled, ' - "fillcolor=lightblue, width=0.5, height=0.3];" - '"__end__" [label="__end__", shape=ellipse, style=filled, ' - "fillcolor=lightblue, width=0.5, height=0.3];" - ) - # Ensure parent agent node is colored if not parent: + parts.append( + '"__start__" [label="__start__", shape=ellipse, style=filled, ' + "fillcolor=lightblue, width=0.5, height=0.3];" + '"__end__" [label="__end__", shape=ellipse, style=filled, ' + "fillcolor=lightblue, width=0.5, height=0.3];" + ) + # Ensure parent agent node is colored parts.append( f'"{agent.name}" [label="{agent.name}", shape=box, style=filled, ' "fillcolor=lightyellow, width=1.5, height=0.8];" @@ -71,17 +79,20 @@ def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str: f"fillcolor=lightyellow, width=1.5, height=0.8];" ) if isinstance(handoff, Agent): - parts.append( - f'"{handoff.name}" [label="{handoff.name}", ' - f"shape=box, style=filled, style=rounded, " - f"fillcolor=lightyellow, width=1.5, height=0.8];" - ) - parts.append(get_all_nodes(handoff)) + if handoff.name not in visited: + parts.append( + f'"{handoff.name}" [label="{handoff.name}", ' + f"shape=box, style=filled, style=rounded, " + f"fillcolor=lightyellow, width=1.5, height=0.8];" + ) + parts.append(get_all_nodes(handoff, agent, visited)) return "".join(parts) -def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str: +def get_all_edges( + 
agent: Agent, parent: Agent | None = None, visited: set[str] | None = None +) -> str: """ Recursively generates the edges for the given agent and its handoffs in DOT format. @@ -92,6 +103,12 @@ def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str: Returns: str: The DOT format string representing the edges. """ + if visited is None: + visited = set() + if agent.name in visited: + return "" + visited.add(agent.name) + parts = [] if not parent: @@ -109,7 +126,7 @@ def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str: if isinstance(handoff, Agent): parts.append(f""" "{agent.name}" -> "{handoff.name}";""") - parts.append(get_all_edges(handoff, agent)) + parts.append(get_all_edges(handoff, agent, visited)) if not agent.handoffs and not isinstance(agent, Tool): # type: ignore parts.append(f'"{agent.name}" -> "__end__";') @@ -117,7 +134,7 @@ def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str: return "".join(parts) -def draw_graph(agent: Agent, filename: Optional[str] = None) -> graphviz.Source: +def draw_graph(agent: Agent, filename: str | None = None) -> graphviz.Source: """ Draws the graph for the given agent and optionally saves it as a PNG file. diff --git a/src/agents/handoffs.py b/src/agents/handoffs.py index 686191f3d..50dd417a2 100644 --- a/src/agents/handoffs.py +++ b/src/agents/handoffs.py @@ -168,7 +168,7 @@ def handoff( input_filter: a function that filters the inputs that are passed to the next agent. 
""" assert (on_handoff and input_type) or not (on_handoff and input_type), ( - "You must provide either both on_input and input_type, or neither" + "You must provide either both on_handoff and input_type, or neither" ) type_adapter: TypeAdapter[Any] | None if input_type is not None: diff --git a/src/agents/items.py b/src/agents/items.py index 8fb2b52a3..64797ad22 100644 --- a/src/agents/items.py +++ b/src/agents/items.py @@ -18,7 +18,22 @@ ResponseOutputText, ResponseStreamEvent, ) -from openai.types.responses.response_input_item_param import ComputerCallOutput, FunctionCallOutput +from openai.types.responses.response_code_interpreter_tool_call import ( + ResponseCodeInterpreterToolCall, +) +from openai.types.responses.response_input_item_param import ( + ComputerCallOutput, + FunctionCallOutput, + LocalShellCallOutput, + McpApprovalResponse, +) +from openai.types.responses.response_output_item import ( + ImageGenerationCall, + LocalShellCall, + McpApprovalRequest, + McpCall, + McpListTools, +) from openai.types.responses.response_reasoning_item import ResponseReasoningItem from pydantic import BaseModel from typing_extensions import TypeAlias @@ -108,6 +123,10 @@ class HandoffOutputItem(RunItemBase[TResponseInputItem]): ResponseComputerToolCall, ResponseFileSearchToolCall, ResponseFunctionWebSearch, + McpCall, + ResponseCodeInterpreterToolCall, + ImageGenerationCall, + LocalShellCall, ] """A type that represents a tool call item.""" @@ -123,10 +142,12 @@ class ToolCallItem(RunItemBase[ToolCallItemTypes]): @dataclass -class ToolCallOutputItem(RunItemBase[Union[FunctionCallOutput, ComputerCallOutput]]): +class ToolCallOutputItem( + RunItemBase[Union[FunctionCallOutput, ComputerCallOutput, LocalShellCallOutput]] +): """Represents the output of a tool call.""" - raw_item: FunctionCallOutput | ComputerCallOutput + raw_item: FunctionCallOutput | ComputerCallOutput | LocalShellCallOutput """The raw item from the model.""" output: Any @@ -147,6 +168,36 @@ class 
ReasoningItem(RunItemBase[ResponseReasoningItem]): type: Literal["reasoning_item"] = "reasoning_item" +@dataclass +class MCPListToolsItem(RunItemBase[McpListTools]): + """Represents a call to an MCP server to list tools.""" + + raw_item: McpListTools + """The raw MCP list tools call.""" + + type: Literal["mcp_list_tools_item"] = "mcp_list_tools_item" + + +@dataclass +class MCPApprovalRequestItem(RunItemBase[McpApprovalRequest]): + """Represents a request for MCP approval.""" + + raw_item: McpApprovalRequest + """The raw MCP approval request.""" + + type: Literal["mcp_approval_request_item"] = "mcp_approval_request_item" + + +@dataclass +class MCPApprovalResponseItem(RunItemBase[McpApprovalResponse]): + """Represents a response to an MCP approval request.""" + + raw_item: McpApprovalResponse + """The raw MCP approval response.""" + + type: Literal["mcp_approval_response_item"] = "mcp_approval_response_item" + + RunItem: TypeAlias = Union[ MessageOutputItem, HandoffCallItem, @@ -154,6 +205,9 @@ class ReasoningItem(RunItemBase[ResponseReasoningItem]): ToolCallItem, ToolCallOutputItem, ReasoningItem, + MCPListToolsItem, + MCPApprovalRequestItem, + MCPApprovalResponseItem, ] """An item generated by an agent.""" diff --git a/src/agents/mcp/__init__.py b/src/agents/mcp/__init__.py index 1a72a89f0..d4eb8fa68 100644 --- a/src/agents/mcp/__init__.py +++ b/src/agents/mcp/__init__.py @@ -5,6 +5,8 @@ MCPServerSseParams, MCPServerStdio, MCPServerStdioParams, + MCPServerStreamableHttp, + MCPServerStreamableHttpParams, ) except ImportError: pass @@ -17,5 +19,7 @@ "MCPServerSseParams", "MCPServerStdio", "MCPServerStdioParams", + "MCPServerStreamableHttp", + "MCPServerStreamableHttpParams", "MCPUtil", ] diff --git a/src/agents/mcp/server.py b/src/agents/mcp/server.py index 9916c92b0..b012a0968 100644 --- a/src/agents/mcp/server.py +++ b/src/agents/mcp/server.py @@ -10,7 +10,9 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from mcp import 
ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client from mcp.client.sse import sse_client -from mcp.types import CallToolResult, JSONRPCMessage +from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client +from mcp.shared.message import SessionMessage +from mcp.types import CallToolResult, InitializeResult from typing_extensions import NotRequired, TypedDict from ..exceptions import UserError @@ -71,6 +73,7 @@ def __init__(self, cache_tools_list: bool, client_session_timeout_seconds: float self.exit_stack: AsyncExitStack = AsyncExitStack() self._cleanup_lock: asyncio.Lock = asyncio.Lock() self.cache_tools_list = cache_tools_list + self.server_initialize_result: InitializeResult | None = None self.client_session_timeout_seconds = client_session_timeout_seconds @@ -83,8 +86,9 @@ def create_streams( self, ) -> AbstractAsyncContextManager[ tuple[ - MemoryObjectReceiveStream[JSONRPCMessage | Exception], - MemoryObjectSendStream[JSONRPCMessage], + MemoryObjectReceiveStream[SessionMessage | Exception], + MemoryObjectSendStream[SessionMessage], + GetSessionIdCallback | None, ] ]: """Create the streams for the server.""" @@ -105,7 +109,11 @@ async def connect(self): """Connect to the server.""" try: transport = await self.exit_stack.enter_async_context(self.create_streams()) - read, write = transport + # streamablehttp_client returns (read, write, get_session_id) + # sse_client returns (read, write) + + read, write, *_ = transport + session = await self.exit_stack.enter_async_context( ClientSession( read, @@ -115,7 +123,8 @@ async def connect(self): else None, ) ) - await session.initialize() + server_result = await session.initialize() + self.server_initialize_result = server_result self.session = session except Exception as e: logger.error(f"Error initializing MCP server: {e}") @@ -232,8 +241,9 @@ def create_streams( self, ) -> AbstractAsyncContextManager[ tuple[ - MemoryObjectReceiveStream[JSONRPCMessage | Exception], - 
MemoryObjectSendStream[JSONRPCMessage], + MemoryObjectReceiveStream[SessionMessage | Exception], + MemoryObjectSendStream[SessionMessage], + GetSessionIdCallback | None, ] ]: """Create the streams for the server.""" @@ -302,8 +312,9 @@ def create_streams( self, ) -> AbstractAsyncContextManager[ tuple[ - MemoryObjectReceiveStream[JSONRPCMessage | Exception], - MemoryObjectSendStream[JSONRPCMessage], + MemoryObjectReceiveStream[SessionMessage | Exception], + MemoryObjectSendStream[SessionMessage], + GetSessionIdCallback | None, ] ]: """Create the streams for the server.""" @@ -318,3 +329,84 @@ def create_streams( def name(self) -> str: """A readable name for the server.""" return self._name + + +class MCPServerStreamableHttpParams(TypedDict): + """Mirrors the params in `mcp.client.streamable_http.streamablehttp_client`.""" + + url: str + """The URL of the server.""" + + headers: NotRequired[dict[str, str]] + """The headers to send to the server.""" + + timeout: NotRequired[timedelta] + """The timeout for the HTTP request. Defaults to 30 seconds.""" + + sse_read_timeout: NotRequired[timedelta] + """The timeout for the SSE connection. Defaults to 5 minutes.""" + + terminate_on_close: NotRequired[bool] + """Whether to terminate the session on close.""" + + +class MCPServerStreamableHttp(_MCPServerWithClientSession): + """MCP server implementation that uses the Streamable HTTP transport. See the [spec] + (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http) + for details. + """ + + def __init__( + self, + params: MCPServerStreamableHttpParams, + cache_tools_list: bool = False, + name: str | None = None, + client_session_timeout_seconds: float | None = 5, + ): + """Create a new MCP server based on the Streamable HTTP transport. + + Args: + params: The params that configure the server.
This includes the URL of the server, + the headers to send to the server, the timeout for the HTTP request, and the + timeout for the Streamable HTTP connection and whether we need to + terminate on close. + + cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be + cached and only fetched from the server once. If `False`, the tools list will be + fetched from the server on each call to `list_tools()`. The cache can be + invalidated by calling `invalidate_tools_cache()`. You should set this to `True` + if you know the server will not change its tools list, because it can drastically + improve latency (by avoiding a round-trip to the server every time). + + name: A readable name for the server. If not provided, we'll create one from the + URL. + + client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. + """ + super().__init__(cache_tools_list, client_session_timeout_seconds) + + self.params = params + self._name = name or f"streamable_http: {self.params['url']}" + + def create_streams( + self, + ) -> AbstractAsyncContextManager[ + tuple[ + MemoryObjectReceiveStream[SessionMessage | Exception], + MemoryObjectSendStream[SessionMessage], + GetSessionIdCallback | None, + ] + ]: + """Create the streams for the server.""" + return streamablehttp_client( + url=self.params["url"], + headers=self.params.get("headers", None), + timeout=self.params.get("timeout", timedelta(seconds=30)), + sse_read_timeout=self.params.get("sse_read_timeout", timedelta(seconds=60 * 5)), + terminate_on_close=self.params.get("terminate_on_close", True), + ) + + @property + def name(self) -> str: + """A readable name for the server.""" + return self._name diff --git a/src/agents/mcp/util.py b/src/agents/mcp/util.py index bbfe1885c..5a963bc01 100644 --- a/src/agents/mcp/util.py +++ b/src/agents/mcp/util.py @@ -116,7 +116,7 @@ async def invoke_mcp_tool( if len(result.content) == 1: tool_output = result.content[0].model_dump_json() elif 
len(result.content) > 1: - tool_output = json.dumps([item.model_dump() for item in result.content]) + tool_output = json.dumps([item.model_dump(mode="json") for item in result.content]) else: logger.error(f"Errored MCP tool result: {result}") tool_output = "Error running tool." diff --git a/src/agents/models/chatcmpl_converter.py b/src/agents/models/chatcmpl_converter.py index 613a37453..1d599e8c0 100644 --- a/src/agents/models/chatcmpl_converter.py +++ b/src/agents/models/chatcmpl_converter.py @@ -234,7 +234,7 @@ def extract_all_content( type="image_url", image_url={ "url": casted_image_param["image_url"], - "detail": casted_image_param["detail"], + "detail": casted_image_param.get("detail", "auto"), }, ) ) diff --git a/src/agents/models/chatcmpl_stream_handler.py b/src/agents/models/chatcmpl_stream_handler.py index c71adeb55..d18f5912a 100644 --- a/src/agents/models/chatcmpl_stream_handler.py +++ b/src/agents/models/chatcmpl_stream_handler.py @@ -38,6 +38,16 @@ class StreamingState: function_calls: dict[int, ResponseFunctionToolCall] = field(default_factory=dict) +class SequenceNumber: + def __init__(self): + self._sequence_number = 0 + + def get_and_increment(self) -> int: + num = self._sequence_number + self._sequence_number += 1 + return num + + class ChatCmplStreamHandler: @classmethod async def handle_stream( @@ -47,13 +57,14 @@ async def handle_stream( ) -> AsyncIterator[TResponseStreamEvent]: usage: CompletionUsage | None = None state = StreamingState() - + sequence_number = SequenceNumber() async for chunk in stream: if not state.started: state.started = True yield ResponseCreatedEvent( response=response, type="response.created", + sequence_number=sequence_number.get_and_increment(), ) # This is always set by the OpenAI API, but not by others e.g. 
LiteLLM @@ -89,6 +100,7 @@ async def handle_stream( item=assistant_item, output_index=0, type="response.output_item.added", + sequence_number=sequence_number.get_and_increment(), ) yield ResponseContentPartAddedEvent( content_index=state.text_content_index_and_output[0], @@ -100,6 +112,7 @@ async def handle_stream( annotations=[], ), type="response.content_part.added", + sequence_number=sequence_number.get_and_increment(), ) # Emit the delta for this segment of content yield ResponseTextDeltaEvent( @@ -108,6 +121,7 @@ async def handle_stream( item_id=FAKE_RESPONSES_ID, output_index=0, type="response.output_text.delta", + sequence_number=sequence_number.get_and_increment(), ) # Accumulate the text into the response part state.text_content_index_and_output[1].text += delta.content @@ -134,6 +148,7 @@ async def handle_stream( item=assistant_item, output_index=0, type="response.output_item.added", + sequence_number=sequence_number.get_and_increment(), ) yield ResponseContentPartAddedEvent( content_index=state.refusal_content_index_and_output[0], @@ -145,6 +160,7 @@ async def handle_stream( annotations=[], ), type="response.content_part.added", + sequence_number=sequence_number.get_and_increment(), ) # Emit the delta for this segment of refusal yield ResponseRefusalDeltaEvent( @@ -153,6 +169,7 @@ async def handle_stream( item_id=FAKE_RESPONSES_ID, output_index=0, type="response.refusal.delta", + sequence_number=sequence_number.get_and_increment(), ) # Accumulate the refusal string in the output part state.refusal_content_index_and_output[1].refusal += delta.refusal @@ -190,6 +207,7 @@ async def handle_stream( output_index=0, part=state.text_content_index_and_output[1], type="response.content_part.done", + sequence_number=sequence_number.get_and_increment(), ) if state.refusal_content_index_and_output: @@ -201,6 +219,7 @@ async def handle_stream( output_index=0, part=state.refusal_content_index_and_output[1], type="response.content_part.done", + 
sequence_number=sequence_number.get_and_increment(), ) # Actually send events for the function calls @@ -216,6 +235,7 @@ async def handle_stream( ), output_index=function_call_starting_index, type="response.output_item.added", + sequence_number=sequence_number.get_and_increment(), ) # Then, yield the args yield ResponseFunctionCallArgumentsDeltaEvent( @@ -223,6 +243,7 @@ async def handle_stream( item_id=FAKE_RESPONSES_ID, output_index=function_call_starting_index, type="response.function_call_arguments.delta", + sequence_number=sequence_number.get_and_increment(), ) # Finally, the ResponseOutputItemDone yield ResponseOutputItemDoneEvent( @@ -235,6 +256,7 @@ async def handle_stream( ), output_index=function_call_starting_index, type="response.output_item.done", + sequence_number=sequence_number.get_and_increment(), ) # Finally, send the Response completed event @@ -258,6 +280,7 @@ async def handle_stream( item=assistant_msg, output_index=0, type="response.output_item.done", + sequence_number=sequence_number.get_and_increment(), ) for function_call in state.function_calls.values(): @@ -289,4 +312,5 @@ async def handle_stream( yield ResponseCompletedEvent( response=final_response, type="response.completed", + sequence_number=sequence_number.get_and_increment(), ) diff --git a/src/agents/models/openai_chatcompletions.py b/src/agents/models/openai_chatcompletions.py index 89619f838..6b4045d21 100644 --- a/src/agents/models/openai_chatcompletions.py +++ b/src/agents/models/openai_chatcompletions.py @@ -9,6 +9,7 @@ from openai.types import ChatModel from openai.types.chat import ChatCompletion, ChatCompletionChunk from openai.types.responses import Response +from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails from .. 
import _debug from ..agent_output import AgentOutputSchemaBase @@ -70,12 +71,22 @@ async def get_response( stream=False, ) + first_choice = response.choices[0] + message = first_choice.message + if _debug.DONT_LOG_MODEL_DATA: logger.debug("Received model response") else: - logger.debug( - f"LLM resp:\n{json.dumps(response.choices[0].message.model_dump(), indent=2)}\n" - ) + if message is not None: + logger.debug( + "LLM resp:\n%s\n", + json.dumps(message.model_dump(), indent=2), + ) + else: + logger.debug( + "LLM resp had no message. finish_reason: %s", + first_choice.finish_reason, + ) usage = ( Usage( @@ -83,18 +94,32 @@ async def get_response( input_tokens=response.usage.prompt_tokens, output_tokens=response.usage.completion_tokens, total_tokens=response.usage.total_tokens, + input_tokens_details=InputTokensDetails( + cached_tokens=getattr( + response.usage.prompt_tokens_details, "cached_tokens", 0 + ) + or 0, + ), + output_tokens_details=OutputTokensDetails( + reasoning_tokens=getattr( + response.usage.completion_tokens_details, "reasoning_tokens", 0 + ) + or 0, + ), ) if response.usage else Usage() ) if tracing.include_data(): - span_generation.span_data.output = [response.choices[0].message.model_dump()] + span_generation.span_data.output = ( + [message.model_dump()] if message is not None else [] + ) span_generation.span_data.usage = { "input_tokens": usage.input_tokens, "output_tokens": usage.output_tokens, } - items = Converter.message_to_output_items(response.choices[0].message) + items = Converter.message_to_output_items(message) if message is not None else [] return ModelResponse( output=items, @@ -252,7 +277,7 @@ async def _fetch_response( stream_options=self._non_null_or_not_given(stream_options), store=self._non_null_or_not_given(store), reasoning_effort=self._non_null_or_not_given(reasoning_effort), - extra_headers={ **HEADERS, **(model_settings.extra_headers or {}) }, + extra_headers={**HEADERS, **(model_settings.extra_headers or {})}, 
extra_query=model_settings.extra_query, extra_body=model_settings.extra_body, metadata=self._non_null_or_not_given(model_settings.metadata), diff --git a/src/agents/models/openai_responses.py b/src/agents/models/openai_responses.py index c1ff85b98..86c8e69cb 100644 --- a/src/agents/models/openai_responses.py +++ b/src/agents/models/openai_responses.py @@ -10,6 +10,7 @@ from openai.types.responses import ( Response, ResponseCompletedEvent, + ResponseIncludable, ResponseStreamEvent, ResponseTextConfigParam, ToolParam, @@ -23,7 +24,17 @@ from ..handoffs import Handoff from ..items import ItemHelpers, ModelResponse, TResponseInputItem from ..logger import logger -from ..tool import ComputerTool, FileSearchTool, FunctionTool, Tool, WebSearchTool +from ..tool import ( + CodeInterpreterTool, + ComputerTool, + FileSearchTool, + FunctionTool, + HostedMCPTool, + ImageGenerationTool, + LocalShellTool, + Tool, + WebSearchTool, +) from ..tracing import SpanError, response_span from ..usage import Usage from ..version import __version__ @@ -36,13 +47,6 @@ _USER_AGENT = f"Agents/Python {__version__}" _HEADERS = {"User-Agent": _USER_AGENT} -# From the Responses API -IncludeLiteral = Literal[ - "file_search_call.results", - "message.input_image.image_url", - "computer_call_output.output.image_url", -] - class OpenAIResponsesModel(Model): """ @@ -98,6 +102,8 @@ async def get_response( input_tokens=response.usage.input_tokens, output_tokens=response.usage.output_tokens, total_tokens=response.usage.total_tokens, + input_tokens_details=response.usage.input_tokens_details, + output_tokens_details=response.usage.output_tokens_details, ) if response.usage else Usage() @@ -271,7 +277,7 @@ def _get_client(self) -> AsyncOpenAI: @dataclass class ConvertedTools: tools: list[ToolParam] - includes: list[IncludeLiteral] + includes: list[ResponseIncludable] class Converter: @@ -299,6 +305,18 @@ def convert_tool_choice( return { "type": "computer_use_preview", } + elif tool_choice == 
"image_generation": + return { + "type": "image_generation", + } + elif tool_choice == "code_interpreter": + return { + "type": "code_interpreter", + } + elif tool_choice == "mcp": + return { + "type": "mcp", + } else: return { "type": "function", @@ -328,7 +346,7 @@ def convert_tools( handoffs: list[Handoff[Any]], ) -> ConvertedTools: converted_tools: list[ToolParam] = [] - includes: list[IncludeLiteral] = [] + includes: list[ResponseIncludable] = [] computer_tools = [tool for tool in tools if isinstance(tool, ComputerTool)] if len(computer_tools) > 1: @@ -346,7 +364,7 @@ def convert_tools( return ConvertedTools(tools=converted_tools, includes=includes) @classmethod - def _convert_tool(cls, tool: Tool) -> tuple[ToolParam, IncludeLiteral | None]: + def _convert_tool(cls, tool: Tool) -> tuple[ToolParam, ResponseIncludable | None]: """Returns converted tool and includes""" if isinstance(tool, FunctionTool): @@ -357,7 +375,7 @@ def _convert_tool(cls, tool: Tool) -> tuple[ToolParam, IncludeLiteral | None]: "type": "function", "description": tool.description, } - includes: IncludeLiteral | None = None + includes: ResponseIncludable | None = None elif isinstance(tool, WebSearchTool): ws: WebSearchToolParam = { "type": "web_search_preview", @@ -387,7 +405,20 @@ def _convert_tool(cls, tool: Tool) -> tuple[ToolParam, IncludeLiteral | None]: "display_height": tool.computer.dimensions[1], } includes = None - + elif isinstance(tool, HostedMCPTool): + converted_tool = tool.tool_config + includes = None + elif isinstance(tool, ImageGenerationTool): + converted_tool = tool.tool_config + includes = None + elif isinstance(tool, CodeInterpreterTool): + converted_tool = tool.tool_config + includes = None + elif isinstance(tool, LocalShellTool): + converted_tool = { + "type": "local_shell", + } + includes = None else: raise UserError(f"Unknown tool type: {type(tool)}, tool") diff --git a/src/agents/repl.py b/src/agents/repl.py new file mode 100644 index 000000000..9a4f30759 --- 
/dev/null +++ b/src/agents/repl.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +from typing import Any + +from openai.types.responses.response_text_delta_event import ResponseTextDeltaEvent + +from .agent import Agent +from .items import ItemHelpers, TResponseInputItem +from .result import RunResultBase +from .run import Runner +from .stream_events import AgentUpdatedStreamEvent, RawResponsesStreamEvent, RunItemStreamEvent + + +async def run_demo_loop(agent: Agent[Any], *, stream: bool = True) -> None: + """Run a simple REPL loop with the given agent. + + This utility allows quick manual testing and debugging of an agent from the + command line. Conversation state is preserved across turns. Enter ``exit`` + or ``quit`` to stop the loop. + + Args: + agent: The starting agent to run. + stream: Whether to stream the agent output. + """ + + current_agent = agent + input_items: list[TResponseInputItem] = [] + while True: + try: + user_input = input(" > ") + except (EOFError, KeyboardInterrupt): + print() + break + if user_input.strip().lower() in {"exit", "quit"}: + break + if not user_input: + continue + + input_items.append({"role": "user", "content": user_input}) + + result: RunResultBase + if stream: + result = Runner.run_streamed(current_agent, input=input_items) + async for event in result.stream_events(): + if isinstance(event, RawResponsesStreamEvent): + if isinstance(event.data, ResponseTextDeltaEvent): + print(event.data.delta, end="", flush=True) + elif isinstance(event, RunItemStreamEvent): + if event.item.type == "tool_call_item": + print("\n[tool called]", flush=True) + elif event.item.type == "tool_call_output_item": + print(f"\n[tool output: {event.item.output}]", flush=True) + elif event.item.type == "message_output_item": + message = ItemHelpers.text_message_output(event.item) + print(message, end="", flush=True) + elif isinstance(event, AgentUpdatedStreamEvent): + print(f"\n[Agent updated: {event.new_agent.name}]", flush=True) + print() + 
else: + result = await Runner.run(current_agent, input_items) + if result.final_output is not None: + print(result.final_output) + + current_agent = result.last_agent + input_items = result.to_input_list() diff --git a/src/agents/result.py b/src/agents/result.py index 243db155c..5cf0e74c8 100644 --- a/src/agents/result.py +++ b/src/agents/result.py @@ -11,14 +11,22 @@ from ._run_impl import QueueCompleteSentinel from .agent import Agent from .agent_output import AgentOutputSchemaBase -from .exceptions import InputGuardrailTripwireTriggered, MaxTurnsExceeded +from .exceptions import ( + AgentsException, + InputGuardrailTripwireTriggered, + MaxTurnsExceeded, + RunErrorDetails, +) from .guardrail import InputGuardrailResult, OutputGuardrailResult from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem from .logger import logger from .run_context import RunContextWrapper from .stream_events import StreamEvent from .tracing import Trace -from .util._pretty_print import pretty_print_result, pretty_print_run_result_streaming +from .util._pretty_print import ( + pretty_print_result, + pretty_print_run_result_streaming, +) if TYPE_CHECKING: from ._run_impl import QueueCompleteSentinel @@ -206,31 +214,53 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]: if self._stored_exception: raise self._stored_exception + def _create_error_details(self) -> RunErrorDetails: + """Return a `RunErrorDetails` object considering the current attributes of the class.""" + return RunErrorDetails( + input=self.input, + new_items=self.new_items, + raw_responses=self.raw_responses, + last_agent=self.current_agent, + context_wrapper=self.context_wrapper, + input_guardrail_results=self.input_guardrail_results, + output_guardrail_results=self.output_guardrail_results, + ) + def _check_errors(self): if self.current_turn > self.max_turns: - self._stored_exception = MaxTurnsExceeded(f"Max turns ({self.max_turns}) exceeded") + max_turns_exc = MaxTurnsExceeded(f"Max turns 
({self.max_turns}) exceeded") + max_turns_exc.run_data = self._create_error_details() + self._stored_exception = max_turns_exc # Fetch all the completed guardrail results from the queue and raise if needed while not self._input_guardrail_queue.empty(): guardrail_result = self._input_guardrail_queue.get_nowait() if guardrail_result.output.tripwire_triggered: - self._stored_exception = InputGuardrailTripwireTriggered(guardrail_result) + tripwire_exc = InputGuardrailTripwireTriggered(guardrail_result) + tripwire_exc.run_data = self._create_error_details() + self._stored_exception = tripwire_exc # Check the tasks for any exceptions if self._run_impl_task and self._run_impl_task.done(): - exc = self._run_impl_task.exception() - if exc and isinstance(exc, Exception): - self._stored_exception = exc + run_impl_exc = self._run_impl_task.exception() + if run_impl_exc and isinstance(run_impl_exc, Exception): + if isinstance(run_impl_exc, AgentsException) and run_impl_exc.run_data is None: + run_impl_exc.run_data = self._create_error_details() + self._stored_exception = run_impl_exc if self._input_guardrails_task and self._input_guardrails_task.done(): - exc = self._input_guardrails_task.exception() - if exc and isinstance(exc, Exception): - self._stored_exception = exc + in_guard_exc = self._input_guardrails_task.exception() + if in_guard_exc and isinstance(in_guard_exc, Exception): + if isinstance(in_guard_exc, AgentsException) and in_guard_exc.run_data is None: + in_guard_exc.run_data = self._create_error_details() + self._stored_exception = in_guard_exc if self._output_guardrails_task and self._output_guardrails_task.done(): - exc = self._output_guardrails_task.exception() - if exc and isinstance(exc, Exception): - self._stored_exception = exc + out_guard_exc = self._output_guardrails_task.exception() + if out_guard_exc and isinstance(out_guard_exc, Exception): + if isinstance(out_guard_exc, AgentsException) and out_guard_exc.run_data is None: + out_guard_exc.run_data = 
self._create_error_details() + self._stored_exception = out_guard_exc def _cleanup_tasks(self): if self._run_impl_task and not self._run_impl_task.done(): diff --git a/src/agents/run.py b/src/agents/run.py index 849da7bfc..f375a0b89 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -26,6 +26,7 @@ MaxTurnsExceeded, ModelBehaviorError, OutputGuardrailTripwireTriggered, + RunErrorDetails, ) from .guardrail import InputGuardrail, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult from .handoffs import Handoff, HandoffInputFilter, handoff @@ -180,6 +181,8 @@ async def run( try: while True: + all_tools = await cls._get_all_tools(current_agent, context_wrapper) + # Start an agent span if we don't have one. This span is ended if the current # agent changes, or if the agent loop ends. if current_span is None: @@ -195,8 +198,6 @@ async def run( output_type=output_type_name, ) current_span.start(mark_as_current=True) - - all_tools = await cls._get_all_tools(current_agent) current_span.span_data.tools = [t.name for t in all_tools] current_turn += 1 @@ -283,6 +284,17 @@ async def run( raise AgentsException( f"Unknown next step type: {type(turn_result.next_step)}" ) + except AgentsException as exc: + exc.run_data = RunErrorDetails( + input=original_input, + new_items=generated_items, + raw_responses=model_responses, + last_agent=current_agent, + context_wrapper=context_wrapper, + input_guardrail_results=input_guardrail_results, + output_guardrail_results=[], + ) + raise finally: if current_span: current_span.finish(reset_current=True) @@ -513,6 +525,8 @@ async def _run_streamed_impl( if streamed_result.is_complete: break + all_tools = await cls._get_all_tools(current_agent, context_wrapper) + # Start an agent span if we don't have one. This span is ended if the current # agent changes, or if the agent loop ends. 
if current_span is None: @@ -528,8 +542,6 @@ async def _run_streamed_impl( output_type=output_type_name, ) current_span.start(mark_as_current=True) - - all_tools = await cls._get_all_tools(current_agent) tool_names = [t.name for t in all_tools] current_span.span_data.tools = tool_names current_turn += 1 @@ -609,6 +621,19 @@ async def _run_streamed_impl( streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) elif isinstance(turn_result.next_step, NextStepRunAgain): pass + except AgentsException as exc: + streamed_result.is_complete = True + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + exc.run_data = RunErrorDetails( + input=streamed_result.input, + new_items=streamed_result.new_items, + raw_responses=streamed_result.raw_responses, + last_agent=current_agent, + context_wrapper=context_wrapper, + input_guardrail_results=streamed_result.input_guardrail_results, + output_guardrail_results=streamed_result.output_guardrail_results, + ) + raise except Exception as e: if current_span: _error_tracing.attach_error_to_span( @@ -689,6 +714,8 @@ async def _run_single_turn_streamed( input_tokens=event.response.usage.input_tokens, output_tokens=event.response.usage.output_tokens, total_tokens=event.response.usage.total_tokens, + input_tokens_details=event.response.usage.input_tokens_details, + output_tokens_details=event.response.usage.output_tokens_details, ) if event.response.usage else Usage() @@ -953,8 +980,10 @@ def _get_handoffs(cls, agent: Agent[Any]) -> list[Handoff]: return handoffs @classmethod - async def _get_all_tools(cls, agent: Agent[Any]) -> list[Tool]: - return await agent.get_all_tools() + async def _get_all_tools( + cls, agent: Agent[Any], context_wrapper: RunContextWrapper[Any] + ) -> list[Tool]: + return await agent.get_all_tools(context_wrapper) @classmethod def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model: diff --git a/src/agents/stream_events.py b/src/agents/stream_events.py index bd37d11f3..a271e8acd 
100644 --- a/src/agents/stream_events.py +++ b/src/agents/stream_events.py @@ -31,10 +31,13 @@ class RunItemStreamEvent: name: Literal[ "message_output_created", "handoff_requested", + # This is misspelled, but we can't change it because that would be a breaking change "handoff_occured", "tool_called", "tool_output", "reasoning_item_created", + "mcp_approval_requested", + "mcp_list_tools", ] """The name of the event.""" diff --git a/src/agents/tool.py b/src/agents/tool.py index c1c162423..57272f9f4 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -4,12 +4,14 @@ import json from collections.abc import Awaitable from dataclasses import dataclass -from typing import Any, Callable, Literal, Union, overload +from typing import TYPE_CHECKING, Any, Callable, Literal, Union, overload from openai.types.responses.file_search_tool_param import Filters, RankingOptions +from openai.types.responses.response_output_item import LocalShellCall, McpApprovalRequest +from openai.types.responses.tool_param import CodeInterpreter, ImageGeneration, Mcp from openai.types.responses.web_search_tool_param import UserLocation from pydantic import ValidationError -from typing_extensions import Concatenate, ParamSpec +from typing_extensions import Concatenate, NotRequired, ParamSpec, TypedDict from . import _debug from .computer import AsyncComputer, Computer @@ -22,6 +24,9 @@ from .util import _error_tracing from .util._types import MaybeAwaitable +if TYPE_CHECKING: + from .agent import Agent + ToolParams = ParamSpec("ToolParams") ToolFunctionWithoutContext = Callable[ToolParams, Any] @@ -72,6 +77,11 @@ class FunctionTool: """Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True, as it increases the likelihood of correct JSON input.""" + is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True + """Whether the tool is enabled. 
Either a bool or a Callable that takes the run context and agent + and returns whether the tool is enabled. You can use this to dynamically enable/disable a tool + based on your context/state.""" + + @dataclass class FileSearchTool: @@ -130,7 +140,115 @@ def name(self): return "computer_use_preview" -Tool = Union[FunctionTool, FileSearchTool, WebSearchTool, ComputerTool] +@dataclass +class MCPToolApprovalRequest: + """A request to approve a tool call.""" + + ctx_wrapper: RunContextWrapper[Any] + """The run context.""" + + data: McpApprovalRequest + """The data from the MCP tool approval request.""" + + +class MCPToolApprovalFunctionResult(TypedDict): + """The result of an MCP tool approval function.""" + + approve: bool + """Whether to approve the tool call.""" + + reason: NotRequired[str] + """An optional reason, if rejected.""" + + +MCPToolApprovalFunction = Callable[ + [MCPToolApprovalRequest], MaybeAwaitable[MCPToolApprovalFunctionResult] +] +"""A function that approves or rejects a tool call.""" + + +@dataclass +class HostedMCPTool: + """A tool that allows the LLM to use a remote MCP server. The LLM will automatically list and + call tools, without requiring a round trip back to your code. + If you want to run MCP servers locally via stdio, in a VPC or other non-publicly-accessible + environment, or you just prefer to run tool calls locally, then you can instead use the servers + in `agents.mcp` and pass `Agent(mcp_servers=[...])` to the agent.""" + + tool_config: Mcp + """The MCP tool config, which includes the server URL and other settings.""" + + on_approval_request: MCPToolApprovalFunction | None = None + """An optional function that will be called if approval is requested for an MCP tool.
If not + provided, you will need to manually add approvals/rejections to the input and call + `Runner.run(...)` again.""" + + @property + def name(self): + return "hosted_mcp" + + +@dataclass +class CodeInterpreterTool: + """A tool that allows the LLM to execute code in a sandboxed environment.""" + + tool_config: CodeInterpreter + """The tool config, which includes the container and other settings.""" + + @property + def name(self): + return "code_interpreter" + + +@dataclass +class ImageGenerationTool: + """A tool that allows the LLM to generate images.""" + + tool_config: ImageGeneration + """The tool config, which includes image generation settings.""" + + @property + def name(self): + return "image_generation" + + +@dataclass +class LocalShellCommandRequest: + """A request to execute a command on a shell.""" + + ctx_wrapper: RunContextWrapper[Any] + """The run context.""" + + data: LocalShellCall + """The data from the local shell tool call.""" + + +LocalShellExecutor = Callable[[LocalShellCommandRequest], MaybeAwaitable[str]] +"""A function that executes a command on a shell.""" + + +@dataclass +class LocalShellTool: + """A tool that allows the LLM to execute commands on a shell.""" + + executor: LocalShellExecutor + """A function that executes a command on a shell.""" + + @property + def name(self): + return "local_shell" + + +Tool = Union[ + FunctionTool, + FileSearchTool, + WebSearchTool, + ComputerTool, + HostedMCPTool, + LocalShellTool, + ImageGenerationTool, + CodeInterpreterTool, +] """A tool that can be used in an agent.""" @@ -152,6 +270,7 @@ def function_tool( use_docstring_info: bool = True, failure_error_function: ToolErrorFunction | None = None, strict_mode: bool = True, + is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True, ) -> FunctionTool: """Overload for usage as @function_tool (no parentheses).""" ...
@@ -166,6 +285,7 @@ def function_tool( use_docstring_info: bool = True, failure_error_function: ToolErrorFunction | None = None, strict_mode: bool = True, + is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True, ) -> Callable[[ToolFunction[...]], FunctionTool]: """Overload for usage as @function_tool(...).""" ... @@ -180,6 +300,7 @@ def function_tool( use_docstring_info: bool = True, failure_error_function: ToolErrorFunction | None = default_tool_error_function, strict_mode: bool = True, + is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True, ) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]: """ Decorator to create a FunctionTool from a function. By default, we will: @@ -208,6 +329,9 @@ def function_tool( If False, it allows non-strict JSON schemas. For example, if a parameter has a default value, it will be optional, additional properties are allowed, etc. See here for more: https://platform.openai.com/docs/guides/structured-outputs?api-mode=responses#supported-schemas + is_enabled: Whether the tool is enabled. Can be a bool or a callable that takes the run + context and agent and returns whether the tool is enabled. Disabled tools are hidden + from the LLM at runtime. 
""" def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool: @@ -297,6 +421,7 @@ async def _on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> Any: params_json_schema=schema.params_json_schema, on_invoke_tool=_on_invoke_tool, strict_json_schema=strict_mode, + is_enabled=is_enabled, ) # If func is actually a callable, we were used as @function_tool with no parentheses diff --git a/src/agents/tracing/processors.py b/src/agents/tracing/processors.py index 2913b11a4..73f733125 100644 --- a/src/agents/tracing/processors.py +++ b/src/agents/tracing/processors.py @@ -188,10 +188,27 @@ def __init__( # Track when we next *must* perform a scheduled export self._next_export_time = time.time() + self._schedule_delay - self._worker_thread = threading.Thread(target=self._run, daemon=True) - self._worker_thread.start() + # We lazily start the background worker thread the first time a span/trace is queued. + self._worker_thread: threading.Thread | None = None + self._thread_start_lock = threading.Lock() + + def _ensure_thread_started(self) -> None: + # Fast path without holding the lock + if self._worker_thread and self._worker_thread.is_alive(): + return + + # Double-checked locking to avoid starting multiple threads + with self._thread_start_lock: + if self._worker_thread and self._worker_thread.is_alive(): + return + + self._worker_thread = threading.Thread(target=self._run, daemon=True) + self._worker_thread.start() def on_trace_start(self, trace: Trace) -> None: + # Ensure the background worker is running before we enqueue anything. + self._ensure_thread_started() + try: self._queue.put_nowait(trace) except queue.Full: @@ -206,6 +223,9 @@ def on_span_start(self, span: Span[Any]) -> None: pass def on_span_end(self, span: Span[Any]) -> None: + # Ensure the background worker is running before we enqueue anything. 
+ self._ensure_thread_started() + try: self._queue.put_nowait(span) except queue.Full: @@ -216,7 +236,13 @@ def shutdown(self, timeout: float | None = None): Called when the application stops. We signal our thread to stop, then join it. """ self._shutdown_event.set() - self._worker_thread.join(timeout=timeout) + + # Only join if we ever started the background thread; otherwise flush synchronously. + if self._worker_thread and self._worker_thread.is_alive(): + self._worker_thread.join(timeout=timeout) + else: + # No background thread: process any remaining items synchronously. + self._export_batches(force=True) def force_flush(self): """ diff --git a/src/agents/usage.py b/src/agents/usage.py index 23d989b4b..843f62937 100644 --- a/src/agents/usage.py +++ b/src/agents/usage.py @@ -1,4 +1,6 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field + +from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails @dataclass @@ -9,9 +11,18 @@ class Usage: input_tokens: int = 0 """Total input tokens sent, across all requests.""" + input_tokens_details: InputTokensDetails = field( + default_factory=lambda: InputTokensDetails(cached_tokens=0) + ) + """Details about the input tokens, matching responses API usage details.""" output_tokens: int = 0 """Total output tokens received, across all requests.""" + output_tokens_details: OutputTokensDetails = field( + default_factory=lambda: OutputTokensDetails(reasoning_tokens=0) + ) + """Details about the output tokens, matching responses API usage details.""" + total_tokens: int = 0 """Total tokens sent and received, across all requests.""" @@ -20,3 +31,12 @@ def add(self, other: "Usage") -> None: self.input_tokens += other.input_tokens if other.input_tokens else 0 self.output_tokens += other.output_tokens if other.output_tokens else 0 self.total_tokens += other.total_tokens if other.total_tokens else 0 + self.input_tokens_details = InputTokensDetails( + 
cached_tokens=self.input_tokens_details.cached_tokens + + other.input_tokens_details.cached_tokens + ) + + self.output_tokens_details = OutputTokensDetails( + reasoning_tokens=self.output_tokens_details.reasoning_tokens + + other.output_tokens_details.reasoning_tokens + ) diff --git a/src/agents/util/_pretty_print.py b/src/agents/util/_pretty_print.py index afd3e2b1b..29df3562e 100644 --- a/src/agents/util/_pretty_print.py +++ b/src/agents/util/_pretty_print.py @@ -3,6 +3,7 @@ from pydantic import BaseModel if TYPE_CHECKING: + from ..exceptions import RunErrorDetails from ..result import RunResult, RunResultBase, RunResultStreaming @@ -38,6 +39,17 @@ def pretty_print_result(result: "RunResult") -> str: return output +def pretty_print_run_error_details(result: "RunErrorDetails") -> str: + output = "RunErrorDetails:" + output += f'\n- Last agent: Agent(name="{result.last_agent.name}", ...)' + output += f"\n- {len(result.new_items)} new item(s)" + output += f"\n- {len(result.raw_responses)} raw response(s)" + output += f"\n- {len(result.input_guardrail_results)} input guardrail result(s)" + output += "\n(See `RunErrorDetails` for more details)" + + return output + + def pretty_print_run_result_streaming(result: "RunResultStreaming") -> str: output = "RunResultStreaming:" output += f'\n- Current agent: Agent(name="{result.current_agent.name}", ...)' diff --git a/src/agents/voice/model.py b/src/agents/voice/model.py index c36a4de76..b048a452d 100644 --- a/src/agents/voice/model.py +++ b/src/agents/voice/model.py @@ -17,9 +17,11 @@ TTSVoice = Literal["alloy", "ash", "coral", "echo", "fable", "onyx", "nova", "sage", "shimmer"] """Exportable type for the TTSModelSettings voice enum""" + @dataclass class TTSModelSettings: """Settings for a TTS model.""" + voice: TTSVoice | None = None """ The voice to use for the TTS model. 
If not provided, the default voice for the respective model diff --git a/tests/fake_model.py b/tests/fake_model.py index 32f919ef1..9f0c83a2f 100644 --- a/tests/fake_model.py +++ b/tests/fake_model.py @@ -129,6 +129,7 @@ async def stream_response( yield ResponseCompletedEvent( type="response.completed", response=get_response_obj(output, usage=self.hardcoded_usage), + sequence_number=0, ) diff --git a/tests/mcp/test_mcp_tracing.py b/tests/mcp/test_mcp_tracing.py index b71954b5b..54575dcb5 100644 --- a/tests/mcp/test_mcp_tracing.py +++ b/tests/mcp/test_mcp_tracing.py @@ -44,6 +44,10 @@ async def test_mcp_tracing(): { "workflow_name": "Agent workflow", "children": [ + { + "type": "mcp_tools", + "data": {"server": "fake_mcp_server", "result": ["test_tool_1"]}, + }, { "type": "agent", "data": { @@ -53,10 +57,6 @@ async def test_mcp_tracing(): "output_type": "str", }, "children": [ - { - "type": "mcp_tools", - "data": {"server": "fake_mcp_server", "result": ["test_tool_1"]}, - }, { "type": "function", "data": { @@ -66,8 +66,12 @@ async def test_mcp_tracing(): "mcp_data": {"server": "fake_mcp_server"}, }, }, + { + "type": "mcp_tools", + "data": {"server": "fake_mcp_server", "result": ["test_tool_1"]}, + }, ], - } + }, ], } ] @@ -100,6 +104,13 @@ async def test_mcp_tracing(): { "workflow_name": "Agent workflow", "children": [ + { + "type": "mcp_tools", + "data": { + "server": "fake_mcp_server", + "result": ["test_tool_1", "test_tool_2"], + }, + }, { "type": "agent", "data": { @@ -109,13 +120,6 @@ async def test_mcp_tracing(): "output_type": "str", }, "children": [ - { - "type": "mcp_tools", - "data": { - "server": "fake_mcp_server", - "result": ["test_tool_1", "test_tool_2"], - }, - }, { "type": "function", "data": { @@ -133,8 +137,15 @@ async def test_mcp_tracing(): "mcp_data": {"server": "fake_mcp_server"}, }, }, + { + "type": "mcp_tools", + "data": { + "server": "fake_mcp_server", + "result": ["test_tool_1", "test_tool_2"], + }, + }, ], - } + }, ], } ] @@ -165,6 +176,13 
@@ async def test_mcp_tracing(): { "workflow_name": "Agent workflow", "children": [ + { + "type": "mcp_tools", + "data": { + "server": "fake_mcp_server", + "result": ["test_tool_1", "test_tool_2", "test_tool_3"], + }, + }, { "type": "agent", "data": { @@ -174,13 +192,6 @@ async def test_mcp_tracing(): "output_type": "str", }, "children": [ - { - "type": "mcp_tools", - "data": { - "server": "fake_mcp_server", - "result": ["test_tool_1", "test_tool_2", "test_tool_3"], - }, - }, { "type": "function", "data": { @@ -190,8 +201,15 @@ async def test_mcp_tracing(): "mcp_data": {"server": "fake_mcp_server"}, }, }, + { + "type": "mcp_tools", + "data": { + "server": "fake_mcp_server", + "result": ["test_tool_1", "test_tool_2", "test_tool_3"], + }, + }, ], - } + }, ], } ] diff --git a/tests/models/test_litellm_chatcompletions_stream.py b/tests/models/test_litellm_chatcompletions_stream.py index 80bd8ea22..06e46b39c 100644 --- a/tests/models/test_litellm_chatcompletions_stream.py +++ b/tests/models/test_litellm_chatcompletions_stream.py @@ -8,7 +8,11 @@ ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction, ) -from openai.types.completion_usage import CompletionUsage +from openai.types.completion_usage import ( + CompletionTokensDetails, + CompletionUsage, + PromptTokensDetails, +) from openai.types.responses import ( Response, ResponseFunctionToolCall, @@ -46,7 +50,13 @@ async def test_stream_response_yields_events_for_text_content(monkeypatch) -> No model="fake", object="chat.completion.chunk", choices=[Choice(index=0, delta=ChoiceDelta(content="llo"))], - usage=CompletionUsage(completion_tokens=5, prompt_tokens=7, total_tokens=12), + usage=CompletionUsage( + completion_tokens=5, + prompt_tokens=7, + total_tokens=12, + completion_tokens_details=CompletionTokensDetails(reasoning_tokens=2), + prompt_tokens_details=PromptTokensDetails(cached_tokens=6), + ), ) async def fake_stream() -> AsyncIterator[ChatCompletionChunk]: @@ -112,6 +122,8 @@ async def patched_fetch_response(self, 
*args, **kwargs): assert completed_resp.usage.input_tokens == 7 assert completed_resp.usage.output_tokens == 5 assert completed_resp.usage.total_tokens == 12 + assert completed_resp.usage.input_tokens_details.cached_tokens == 6 + assert completed_resp.usage.output_tokens_details.reasoning_tokens == 2 @pytest.mark.allow_call_model_methods diff --git a/tests/models/test_litellm_extra_body.py b/tests/models/test_litellm_extra_body.py new file mode 100644 index 000000000..3c83b0607 --- /dev/null +++ b/tests/models/test_litellm_extra_body.py @@ -0,0 +1,44 @@ +import litellm +import pytest +from litellm.types.utils import Choices, Message, ModelResponse, Usage + +from agents.extensions.models.litellm_model import LitellmModel +from agents.model_settings import ModelSettings +from agents.models.interface import ModelTracing + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_extra_body_is_forwarded(monkeypatch): + """ + Forward `extra_body` entries into litellm.acompletion kwargs. + + This ensures that user-provided parameters (e.g. cached_content) + arrive alongside default arguments. 
+ """ + captured: dict[str, object] = {} + + async def fake_acompletion(model, messages=None, **kwargs): + captured.update(kwargs) + msg = Message(role="assistant", content="ok") + choice = Choices(index=0, message=msg) + return ModelResponse(choices=[choice], usage=Usage(0, 0, 0)) + + monkeypatch.setattr(litellm, "acompletion", fake_acompletion) + settings = ModelSettings( + temperature=0.1, extra_body={"cached_content": "some_cache", "foo": 123} + ) + model = LitellmModel(model="test-model") + + await model.get_response( + system_instructions=None, + input=[], + model_settings=settings, + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + ) + + assert {"cached_content": "some_cache", "foo": 123}.items() <= captured.items() diff --git a/tests/test_agent_runner.py b/tests/test_agent_runner.py index 14a278a90..77eb8019e 100644 --- a/tests/test_agent_runner.py +++ b/tests/test_agent_runner.py @@ -745,3 +745,38 @@ async def test_previous_response_id_passed_between_runs_streamed_multi_turn(): pass assert model.last_turn_args.get("previous_response_id") == "resp-stream-test" + + +@pytest.mark.asyncio +async def test_dynamic_tool_addition_run() -> None: + """Test that tools can be added to an agent during a run.""" + model = FakeModel() + + executed: dict[str, bool] = {"called": False} + + agent = Agent(name="test", model=model, tool_use_behavior="run_llm_again") + + @function_tool(name_override="tool2") + def tool2() -> str: + executed["called"] = True + return "result2" + + @function_tool(name_override="add_tool") + async def add_tool() -> str: + agent.tools.append(tool2) + return "added" + + agent.tools.append(add_tool) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("add_tool", json.dumps({}))], + [get_function_tool_call("tool2", json.dumps({}))], + [get_text_message("done")], + ] + ) + + result = await Runner.run(agent, input="start") + + assert executed["called"] is True + assert 
result.final_output == "done"
diff --git a/tests/test_agent_runner_streamed.py b/tests/test_agent_runner_streamed.py
index 87a76a706..31fe2979b 100644
--- a/tests/test_agent_runner_streamed.py
+++ b/tests/test_agent_runner_streamed.py
@@ -18,6 +18,7 @@
     RunContextWrapper,
     Runner,
     UserError,
+    function_tool,
     handoff,
 )
 from agents.items import RunItem
@@ -684,3 +685,39 @@ async def test_streaming_events():
     assert len(agent_data) == 2, "should have 2 agent updated events"
     assert agent_data[0].new_agent == agent_2, "should have started with agent_2"
     assert agent_data[1].new_agent == agent_1, "should have handed off to agent_1"
+
+
+@pytest.mark.asyncio
+async def test_dynamic_tool_addition_run_streamed() -> None:
+    model = FakeModel()
+
+    executed: dict[str, bool] = {"called": False}
+
+    agent = Agent(name="test", model=model, tool_use_behavior="run_llm_again")
+
+    @function_tool(name_override="tool2")
+    def tool2() -> str:
+        executed["called"] = True
+        return "result2"
+
+    @function_tool(name_override="add_tool")
+    async def add_tool() -> str:
+        agent.tools.append(tool2)
+        return "added"
+
+    agent.tools.append(add_tool)
+
+    model.add_multiple_turn_outputs(
+        [
+            [get_function_tool_call("add_tool", json.dumps({}))],
+            [get_function_tool_call("tool2", json.dumps({}))],
+            [get_text_message("done")],
+        ]
+    )
+
+    result = Runner.run_streamed(agent, input="start")
+    async for _ in result.stream_events():
+        pass
+
+    assert executed["called"] is True
+    assert result.final_output == "done"
diff --git a/tests/test_extra_headers.py b/tests/test_extra_headers.py
index f29c25408..a6af30077 100644
--- a/tests/test_extra_headers.py
+++ b/tests/test_extra_headers.py
@@ -1,6 +1,7 @@
 import pytest
 from openai.types.chat.chat_completion import ChatCompletion, Choice
 from openai.types.chat.chat_completion_message import ChatCompletionMessage
+from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
 
 from agents import ModelSettings, ModelTracing,
OpenAIChatCompletionsModel, OpenAIResponsesModel
 
@@ -17,21 +18,29 @@ class DummyResponses:
         async def create(self, **kwargs):
             nonlocal called_kwargs
             called_kwargs = kwargs
+
             class DummyResponse:
                 id = "dummy"
                 output = []
                 usage = type(
-                    "Usage", (), {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
+                    "Usage",
+                    (),
+                    {
+                        "input_tokens": 0,
+                        "output_tokens": 0,
+                        "total_tokens": 0,
+                        "input_tokens_details": InputTokensDetails(cached_tokens=0),
+                        "output_tokens_details": OutputTokensDetails(reasoning_tokens=0),
+                    },
                 )()
+
             return DummyResponse()
 
     class DummyClient:
         def __init__(self):
             self.responses = DummyResponses()
 
-
-
-    model = OpenAIResponsesModel(model="gpt-4", openai_client=DummyClient()) # type: ignore
+    model = OpenAIResponsesModel(model="gpt-4", openai_client=DummyClient())  # type: ignore
     extra_headers = {"X-Test-Header": "test-value"}
     await model.get_response(
         system_instructions=None,
@@ -47,7 +56,6 @@ def __init__(self):
     assert called_kwargs["extra_headers"]["X-Test-Header"] == "test-value"
 
 
-
 @pytest.mark.allow_call_model_methods
 @pytest.mark.asyncio
 async def test_extra_headers_passed_to_openai_client():
@@ -76,7 +84,7 @@ def __init__(self):
             self.chat = type("_Chat", (), {"completions": DummyCompletions()})()
             self.base_url = "https://api.openai.com"
 
-    model = OpenAIChatCompletionsModel(model="gpt-4", openai_client=DummyClient()) # type: ignore
+    model = OpenAIChatCompletionsModel(model="gpt-4", openai_client=DummyClient())  # type: ignore
     extra_headers = {"X-Test-Header": "test-value"}
     await model.get_response(
         system_instructions=None,
diff --git a/tests/test_function_tool.py b/tests/test_function_tool.py
index 0a57aea87..c5d7da649 100644
--- a/tests/test_function_tool.py
+++ b/tests/test_function_tool.py
@@ -5,7 +5,7 @@
 from pydantic import BaseModel
 from typing_extensions import TypedDict
 
-from agents import FunctionTool, ModelBehaviorError, RunContextWrapper, function_tool
+from agents import Agent, FunctionTool, ModelBehaviorError,
RunContextWrapper, function_tool
 from agents.tool import default_tool_error_function
 
 
@@ -255,3 +255,44 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) ->
     result = await tool.on_invoke_tool(ctx, '{"a": 1, "b": 2}')
 
     assert result == "error_ValueError"
+
+
+class BoolCtx(BaseModel):
+    enable_tools: bool
+
+
+@pytest.mark.asyncio
+async def test_is_enabled_bool_and_callable():
+    @function_tool(is_enabled=False)
+    def disabled_tool():
+        return "nope"
+
+    async def cond_enabled(ctx: RunContextWrapper[BoolCtx], agent: Agent[Any]) -> bool:
+        return ctx.context.enable_tools
+
+    @function_tool(is_enabled=cond_enabled)
+    def another_tool():
+        return "hi"
+
+    async def third_tool_on_invoke_tool(ctx: RunContextWrapper[Any], args: str) -> str:
+        return "third"
+
+    third_tool = FunctionTool(
+        name="third_tool",
+        description="third tool",
+        on_invoke_tool=third_tool_on_invoke_tool,
+        is_enabled=lambda ctx, agent: ctx.context.enable_tools,
+        params_json_schema={},
+    )
+
+    agent = Agent(name="t", tools=[disabled_tool, another_tool, third_tool])
+    context_1 = RunContextWrapper(BoolCtx(enable_tools=False))
+    context_2 = RunContextWrapper(BoolCtx(enable_tools=True))
+
+    tools_with_ctx = await agent.get_all_tools(context_1)
+    assert tools_with_ctx == []
+
+    tools_with_ctx = await agent.get_all_tools(context_2)
+    assert len(tools_with_ctx) == 2
+    assert tools_with_ctx[0].name == "another_tool"
+    assert tools_with_ctx[1].name == "third_tool"
diff --git a/tests/test_openai_chatcompletions.py b/tests/test_openai_chatcompletions.py
index ba3ec68d0..9a85dcb7b 100644
--- a/tests/test_openai_chatcompletions.py
+++ b/tests/test_openai_chatcompletions.py
@@ -13,7 +13,10 @@
     ChatCompletionMessageToolCall,
     Function,
 )
-from openai.types.completion_usage import CompletionUsage
+from openai.types.completion_usage import (
+    CompletionUsage,
+    PromptTokensDetails,
+)
 from openai.types.responses import (
     Response,
     ResponseFunctionToolCall,
@@ -51,7 +54,13 @@ async def
test_get_response_with_text_message(monkeypatch) -> None:
         model="fake",
         object="chat.completion",
         choices=[choice],
-        usage=CompletionUsage(completion_tokens=5, prompt_tokens=7, total_tokens=12),
+        usage=CompletionUsage(
+            completion_tokens=5,
+            prompt_tokens=7,
+            total_tokens=12,
+            # completion_tokens_details left blank to test default
+            prompt_tokens_details=PromptTokensDetails(cached_tokens=3),
+        ),
     )
 
     async def patched_fetch_response(self, *args, **kwargs):
@@ -81,6 +90,8 @@ async def patched_fetch_response(self, *args, **kwargs):
     assert resp.usage.input_tokens == 7
     assert resp.usage.output_tokens == 5
     assert resp.usage.total_tokens == 12
+    assert resp.usage.input_tokens_details.cached_tokens == 3
+    assert resp.usage.output_tokens_details.reasoning_tokens == 0
     assert resp.response_id is None
 
 
@@ -127,6 +138,8 @@ async def patched_fetch_response(self, *args, **kwargs):
     assert resp.usage.requests == 0
     assert resp.usage.input_tokens == 0
     assert resp.usage.output_tokens == 0
+    assert resp.usage.input_tokens_details.cached_tokens == 0
+    assert resp.usage.output_tokens_details.reasoning_tokens == 0
 
 
 @pytest.mark.allow_call_model_methods
@@ -178,6 +191,40 @@ async def patched_fetch_response(self, *args, **kwargs):
     assert fn_call_item.arguments == "{'x':1}"
 
 
+@pytest.mark.allow_call_model_methods
+@pytest.mark.asyncio
+async def test_get_response_with_no_message(monkeypatch) -> None:
+    """If the model returns no message, get_response should return an empty output."""
+    msg = ChatCompletionMessage(role="assistant", content="ignored")
+    choice = Choice(index=0, finish_reason="content_filter", message=msg)
+    choice.message = None  # type: ignore[assignment]
+    chat = ChatCompletion(
+        id="resp-id",
+        created=0,
+        model="fake",
+        object="chat.completion",
+        choices=[choice],
+        usage=None,
+    )
+
+    async def patched_fetch_response(self, *args, **kwargs):
+        return chat
+
+    monkeypatch.setattr(OpenAIChatCompletionsModel, "_fetch_response", patched_fetch_response)
+    model =
OpenAIProvider(use_responses=False).get_model("gpt-4")
+    resp: ModelResponse = await model.get_response(
+        system_instructions=None,
+        input="",
+        model_settings=ModelSettings(),
+        tools=[],
+        output_schema=None,
+        handoffs=[],
+        tracing=ModelTracing.DISABLED,
+        previous_response_id=None,
+    )
+    assert resp.output == []
+
+
 @pytest.mark.asyncio
 async def test_fetch_response_non_stream(monkeypatch) -> None:
     """
diff --git a/tests/test_openai_chatcompletions_stream.py b/tests/test_openai_chatcompletions_stream.py
index b82f24303..5c8bb9e3a 100644
--- a/tests/test_openai_chatcompletions_stream.py
+++ b/tests/test_openai_chatcompletions_stream.py
@@ -8,7 +8,11 @@
     ChoiceDeltaToolCall,
     ChoiceDeltaToolCallFunction,
 )
-from openai.types.completion_usage import CompletionUsage
+from openai.types.completion_usage import (
+    CompletionTokensDetails,
+    CompletionUsage,
+    PromptTokensDetails,
+)
 from openai.types.responses import (
     Response,
     ResponseFunctionToolCall,
@@ -46,7 +50,13 @@ async def test_stream_response_yields_events_for_text_content(monkeypatch) -> No
         model="fake",
         object="chat.completion.chunk",
         choices=[Choice(index=0, delta=ChoiceDelta(content="llo"))],
-        usage=CompletionUsage(completion_tokens=5, prompt_tokens=7, total_tokens=12),
+        usage=CompletionUsage(
+            completion_tokens=5,
+            prompt_tokens=7,
+            total_tokens=12,
+            prompt_tokens_details=PromptTokensDetails(cached_tokens=2),
+            completion_tokens_details=CompletionTokensDetails(reasoning_tokens=3),
+        ),
     )
 
     async def fake_stream() -> AsyncIterator[ChatCompletionChunk]:
@@ -112,6 +122,8 @@ async def patched_fetch_response(self, *args, **kwargs):
     assert completed_resp.usage.input_tokens == 7
     assert completed_resp.usage.output_tokens == 5
     assert completed_resp.usage.total_tokens == 12
+    assert completed_resp.usage.input_tokens_details.cached_tokens == 2
+    assert completed_resp.usage.output_tokens_details.reasoning_tokens == 3
 
 
 @pytest.mark.allow_call_model_methods
diff --git a/tests/test_repl.py b/tests/test_repl.py
new file mode 100644
index 000000000..7ba2011be
--- /dev/null
+++ b/tests/test_repl.py
@@ -0,0 +1,28 @@
+import pytest
+
+from agents import Agent, run_demo_loop
+
+from .fake_model import FakeModel
+from .test_responses import get_text_input_item, get_text_message
+
+
+@pytest.mark.asyncio
+async def test_run_demo_loop_conversation(monkeypatch, capsys):
+    model = FakeModel()
+    model.add_multiple_turn_outputs([[get_text_message("hello")], [get_text_message("good")]])
+
+    agent = Agent(name="test", model=model)
+
+    inputs = iter(["Hi", "How are you?", "quit"])
+    monkeypatch.setattr("builtins.input", lambda _=" > ": next(inputs))
+
+    await run_demo_loop(agent, stream=False)
+
+    output = capsys.readouterr().out
+    assert "hello" in output
+    assert "good" in output
+    assert model.last_turn_args["input"] == [
+        get_text_input_item("Hi"),
+        get_text_message("hello").model_dump(exclude_unset=True),
+        get_text_input_item("How are you?"),
+    ]
diff --git a/tests/test_responses_tracing.py b/tests/test_responses_tracing.py
index 0bc97a953..db24fe496 100644
--- a/tests/test_responses_tracing.py
+++ b/tests/test_responses_tracing.py
@@ -1,7 +1,10 @@
+from typing import Optional
+
 import pytest
 from inline_snapshot import snapshot
 from openai import AsyncOpenAI
 from openai.types.responses import ResponseCompletedEvent
+from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
 
 from agents import ModelSettings, ModelTracing, OpenAIResponsesModel, trace
 from agents.tracing.span_data import ResponseSpanData
@@ -16,10 +19,25 @@ def is_disabled(self):
 
 
 class DummyUsage:
-    def __init__(self, input_tokens=1, output_tokens=1, total_tokens=2):
+    def __init__(
+        self,
+        input_tokens: int = 1,
+        input_tokens_details: Optional[InputTokensDetails] = None,
+        output_tokens: int = 1,
+        output_tokens_details: Optional[OutputTokensDetails] = None,
+        total_tokens: int = 2,
+    ):
         self.input_tokens = input_tokens
         self.output_tokens = output_tokens
         self.total_tokens =
total_tokens
+        self.input_tokens_details = (
+            input_tokens_details if input_tokens_details else InputTokensDetails(cached_tokens=0)
+        )
+        self.output_tokens_details = (
+            output_tokens_details
+            if output_tokens_details
+            else OutputTokensDetails(reasoning_tokens=0)
+        )
 
 
 class DummyResponse:
@@ -32,6 +50,7 @@ def __aiter__(self):
         yield ResponseCompletedEvent(
             type="response.completed",
             response=fake_model.get_response_obj(self.output),
+            sequence_number=0,
         )
 
 
@@ -183,6 +202,7 @@ async def __aiter__(self):
                 yield ResponseCompletedEvent(
                     type="response.completed",
                     response=fake_model.get_response_obj([], "dummy-id-123"),
+                    sequence_number=0,
                 )
 
         return DummyStream()
@@ -235,6 +255,7 @@ async def __aiter__(self):
                 yield ResponseCompletedEvent(
                     type="response.completed",
                     response=fake_model.get_response_obj([], "dummy-id-123"),
+                    sequence_number=0,
                 )
 
         return DummyStream()
@@ -286,6 +307,7 @@ async def __aiter__(self):
                 yield ResponseCompletedEvent(
                     type="response.completed",
                     response=fake_model.get_response_obj([], "dummy-id-123"),
+                    sequence_number=0,
                 )
 
         return DummyStream()
diff --git a/tests/test_run_error_details.py b/tests/test_run_error_details.py
new file mode 100644
index 000000000..104b248fc
--- /dev/null
+++ b/tests/test_run_error_details.py
@@ -0,0 +1,48 @@
+import json
+
+import pytest
+
+from agents import Agent, MaxTurnsExceeded, RunErrorDetails, Runner
+
+from .fake_model import FakeModel
+from .test_responses import get_function_tool, get_function_tool_call, get_text_message
+
+
+@pytest.mark.asyncio
+async def test_run_error_includes_data():
+    model = FakeModel()
+    agent = Agent(name="test", model=model, tools=[get_function_tool("foo", "res")])
+    model.add_multiple_turn_outputs(
+        [
+            [get_text_message("1"), get_function_tool_call("foo", json.dumps({"a": "b"}))],
+            [get_text_message("done")],
+        ]
+    )
+    with pytest.raises(MaxTurnsExceeded) as exc:
+        await Runner.run(agent, input="hello", max_turns=1)
+    data = exc.value.run_data
+    assert isinstance(data,
RunErrorDetails)
+    assert data.last_agent == agent
+    assert len(data.raw_responses) == 1
+    assert len(data.new_items) > 0
+
+
+@pytest.mark.asyncio
+async def test_streamed_run_error_includes_data():
+    model = FakeModel()
+    agent = Agent(name="test", model=model, tools=[get_function_tool("foo", "res")])
+    model.add_multiple_turn_outputs(
+        [
+            [get_text_message("1"), get_function_tool_call("foo", json.dumps({"a": "b"}))],
+            [get_text_message("done")],
+        ]
+    )
+    result = Runner.run_streamed(agent, input="hello", max_turns=1)
+    with pytest.raises(MaxTurnsExceeded) as exc:
+        async for _ in result.stream_events():
+            pass
+    data = exc.value.run_data
+    assert isinstance(data, RunErrorDetails)
+    assert data.last_agent == agent
+    assert len(data.raw_responses) == 1
+    assert len(data.new_items) > 0
diff --git a/tests/test_run_step_execution.py b/tests/test_run_step_execution.py
index 6ae25fbd5..b4c83d015 100644
--- a/tests/test_run_step_execution.py
+++ b/tests/test_run_step_execution.py
@@ -290,7 +290,7 @@ async def get_execute_result(
 
     processed_response = RunImpl.process_model_response(
         agent=agent,
-        all_tools=await agent.get_all_tools(),
+        all_tools=await agent.get_all_tools(context_wrapper or RunContextWrapper(None)),
         response=response,
         output_schema=output_schema,
         handoffs=handoffs,
diff --git a/tests/test_run_step_processing.py b/tests/test_run_step_processing.py
index 2ea98f06a..3cc1231c1 100644
--- a/tests/test_run_step_processing.py
+++ b/tests/test_run_step_processing.py
@@ -34,6 +34,10 @@
 )
+
+def _dummy_ctx() -> RunContextWrapper[None]:
+    return RunContextWrapper(context=None)
+
+
 def test_empty_response():
     agent = Agent(name="test")
     response = ModelResponse(
@@ -83,7 +87,7 @@ async def test_single_tool_call():
         response=response,
         output_schema=None,
         handoffs=[],
-        all_tools=await agent.get_all_tools(),
+        all_tools=await agent.get_all_tools(_dummy_ctx()),
     )
     assert not result.handoffs
     assert result.functions and len(result.functions) == 1
@@ -111,7 +115,7 @@ async
def test_missing_tool_call_raises_error():
             response=response,
             output_schema=None,
             handoffs=[],
-            all_tools=await agent.get_all_tools(),
+            all_tools=await agent.get_all_tools(_dummy_ctx()),
         )
 
 
@@ -140,7 +144,7 @@ async def test_multiple_tool_calls():
         response=response,
         output_schema=None,
         handoffs=[],
-        all_tools=await agent.get_all_tools(),
+        all_tools=await agent.get_all_tools(_dummy_ctx()),
     )
     assert not result.handoffs
     assert result.functions and len(result.functions) == 2
@@ -169,7 +173,7 @@ async def test_handoffs_parsed_correctly():
         response=response,
         output_schema=None,
         handoffs=[],
-        all_tools=await agent_3.get_all_tools(),
+        all_tools=await agent_3.get_all_tools(_dummy_ctx()),
     )
     assert not result.handoffs, "Shouldn't have a handoff here"
 
@@ -183,7 +187,7 @@ async def test_handoffs_parsed_correctly():
         response=response,
         output_schema=None,
         handoffs=Runner._get_handoffs(agent_3),
-        all_tools=await agent_3.get_all_tools(),
+        all_tools=await agent_3.get_all_tools(_dummy_ctx()),
     )
     assert len(result.handoffs) == 1, "Should have a handoff here"
     handoff = result.handoffs[0]
@@ -213,7 +217,7 @@ async def test_missing_handoff_fails():
             response=response,
             output_schema=None,
             handoffs=Runner._get_handoffs(agent_3),
-            all_tools=await agent_3.get_all_tools(),
+            all_tools=await agent_3.get_all_tools(_dummy_ctx()),
         )
 
 
@@ -236,7 +240,7 @@ async def test_multiple_handoffs_doesnt_error():
         response=response,
         output_schema=None,
         handoffs=Runner._get_handoffs(agent_3),
-        all_tools=await agent_3.get_all_tools(),
+        all_tools=await agent_3.get_all_tools(_dummy_ctx()),
     )
     assert len(result.handoffs) == 2, "Should have multiple handoffs here"
 
@@ -262,7 +266,7 @@ async def test_final_output_parsed_correctly():
         response=response,
         output_schema=Runner._get_output_schema(agent),
         handoffs=[],
-        all_tools=await agent.get_all_tools(),
+        all_tools=await agent.get_all_tools(_dummy_ctx()),
     )
 
 
@@ -288,7 +292,7 @@ async def test_file_search_tool_call_parsed_correctly():
         response=response,
         output_schema=None,
         handoffs=[],
-        all_tools=await agent.get_all_tools(),
+        all_tools=await agent.get_all_tools(_dummy_ctx()),
     )
     # The final item should be a ToolCallItem for the file search call
     assert any(
@@ -313,7 +317,7 @@ async def test_function_web_search_tool_call_parsed_correctly():
         response=response,
         output_schema=None,
         handoffs=[],
-        all_tools=await agent.get_all_tools(),
+        all_tools=await agent.get_all_tools(_dummy_ctx()),
     )
     assert any(
         isinstance(item, ToolCallItem) and item.raw_item is web_search_call
@@ -340,7 +344,7 @@ async def test_reasoning_item_parsed_correctly():
         response=response,
         output_schema=None,
         handoffs=[],
-        all_tools=await Agent(name="test").get_all_tools(),
+        all_tools=await Agent(name="test").get_all_tools(_dummy_ctx()),
     )
     assert any(
         isinstance(item, ReasoningItem) and item.raw_item is reasoning for item in result.new_items
@@ -409,7 +413,7 @@ async def test_computer_tool_call_without_computer_tool_raises_error():
             response=response,
             output_schema=None,
             handoffs=[],
-            all_tools=await Agent(name="test").get_all_tools(),
+            all_tools=await Agent(name="test").get_all_tools(_dummy_ctx()),
         )
 
 
@@ -437,7 +441,7 @@ async def test_computer_tool_call_with_computer_tool_parsed_correctly():
         response=response,
         output_schema=None,
         handoffs=[],
-        all_tools=await agent.get_all_tools(),
+        all_tools=await agent.get_all_tools(_dummy_ctx()),
     )
     assert any(
         isinstance(item, ToolCallItem) and item.raw_item is computer_call
@@ -468,7 +472,7 @@ async def test_tool_and_handoff_parsed_correctly():
         response=response,
         output_schema=None,
         handoffs=Runner._get_handoffs(agent_3),
-        all_tools=await agent_3.get_all_tools(),
+        all_tools=await agent_3.get_all_tools(_dummy_ctx()),
     )
     assert result.functions and len(result.functions) == 1
     assert len(result.handoffs) == 1, "Should have a handoff here"
diff --git a/tests/test_tracing_errors_streamed.py b/tests/test_tracing_errors_streamed.py
index 416793e70..40efef3fa 100644
--- a/tests/test_tracing_errors_streamed.py
+++
b/tests/test_tracing_errors_streamed.py
@@ -168,10 +168,6 @@ async def test_tool_call_error():
                 "children": [
                     {
                         "type": "agent",
-                        "error": {
-                            "message": "Error in agent run",
-                            "data": {"error": "Invalid JSON input for tool foo: bad_json"},
-                        },
                         "data": {
                             "name": "test_agent",
                             "handoffs": [],
diff --git a/tests/test_usage.py b/tests/test_usage.py
new file mode 100644
index 000000000..405f99ddf
--- /dev/null
+++ b/tests/test_usage.py
@@ -0,0 +1,52 @@
+from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
+
+from agents.usage import Usage
+
+
+def test_usage_add_aggregates_all_fields():
+    u1 = Usage(
+        requests=1,
+        input_tokens=10,
+        input_tokens_details=InputTokensDetails(cached_tokens=3),
+        output_tokens=20,
+        output_tokens_details=OutputTokensDetails(reasoning_tokens=5),
+        total_tokens=30,
+    )
+    u2 = Usage(
+        requests=2,
+        input_tokens=7,
+        input_tokens_details=InputTokensDetails(cached_tokens=4),
+        output_tokens=8,
+        output_tokens_details=OutputTokensDetails(reasoning_tokens=6),
+        total_tokens=15,
+    )
+
+    u1.add(u2)
+
+    assert u1.requests == 3
+    assert u1.input_tokens == 17
+    assert u1.output_tokens == 28
+    assert u1.total_tokens == 45
+    assert u1.input_tokens_details.cached_tokens == 7
+    assert u1.output_tokens_details.reasoning_tokens == 11
+
+
+def test_usage_add_aggregates_with_none_values():
+    u1 = Usage()
+    u2 = Usage(
+        requests=2,
+        input_tokens=7,
+        input_tokens_details=InputTokensDetails(cached_tokens=4),
+        output_tokens=8,
+        output_tokens_details=OutputTokensDetails(reasoning_tokens=6),
+        total_tokens=15,
+    )
+
+    u1.add(u2)
+
+    assert u1.requests == 2
+    assert u1.input_tokens == 7
+    assert u1.output_tokens == 8
+    assert u1.total_tokens == 15
+    assert u1.input_tokens_details.cached_tokens == 4
+    assert u1.output_tokens_details.reasoning_tokens == 6
diff --git a/tests/test_visualization.py b/tests/test_visualization.py
index 6aa867743..8bce897e9 100644
--- a/tests/test_visualization.py
+++
b/tests/test_visualization.py
@@ -134,3 +134,18 @@ def test_draw_graph(mock_agent):
         '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, '
         "fillcolor=lightyellow, width=1.5, height=0.8];" in graph.source
     )
+
+
+def test_cycle_detection():
+    agent_a = Agent(name="A")
+    agent_b = Agent(name="B")
+    agent_a.handoffs.append(agent_b)
+    agent_b.handoffs.append(agent_a)
+
+    nodes = get_all_nodes(agent_a)
+    edges = get_all_edges(agent_a)
+
+    assert nodes.count('"A" [label="A"') == 1
+    assert nodes.count('"B" [label="B"') == 1
+    assert '"A" -> "B"' in edges
+    assert '"B" -> "A"' in edges
diff --git a/tests/voice/test_workflow.py b/tests/voice/test_workflow.py
index 2bdf2a657..035a05d56 100644
--- a/tests/voice/test_workflow.py
+++ b/tests/voice/test_workflow.py
@@ -81,11 +81,13 @@ async def stream_response(
                     type="response.output_text.delta",
                     output_index=0,
                     item_id=item.id,
+                    sequence_number=0,
                 )
 
         yield ResponseCompletedEvent(
             type="response.completed",
             response=get_response_obj(output),
+            sequence_number=1,
         )
 
 
diff --git a/uv.lock b/uv.lock
index 87dd3cd2c..1779b8637 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1047,7 +1047,7 @@ wheels = [
 
 [[package]]
 name = "mcp"
-version = "1.6.0"
+version = "1.8.1"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "anyio", marker = "python_full_version >= '3.10'" },
@@ -1055,13 +1055,14 @@ dependencies = [
     { name = "httpx-sse", marker = "python_full_version >= '3.10'" },
     { name = "pydantic", marker = "python_full_version >= '3.10'" },
     { name = "pydantic-settings", marker = "python_full_version >= '3.10'" },
+    { name = "python-multipart", marker = "python_full_version >= '3.10'" },
     { name = "sse-starlette", marker = "python_full_version >= '3.10'" },
     { name = "starlette", marker = "python_full_version >= '3.10'" },
-    { name = "uvicorn", marker = "python_full_version >= '3.10'" },
+    { name = "uvicorn", marker = "python_full_version >= '3.10' and sys_platform != 'emscripten'" },
 ]
-sdist = { url =
"https://files.pythonhosted.org/packages/95/d2/f587cb965a56e992634bebc8611c5b579af912b74e04eb9164bd49527d21/mcp-1.6.0.tar.gz", hash = "sha256:d9324876de2c5637369f43161cd71eebfd803df5a95e46225cab8d280e366723", size = 200031 } +sdist = { url = "https://files.pythonhosted.org/packages/7c/13/16b712e8a3be6a736b411df2fc6b4e75eb1d3e99b1cd57a3a1decf17f612/mcp-1.8.1.tar.gz", hash = "sha256:ec0646271d93749f784d2316fb5fe6102fb0d1be788ec70a9e2517e8f2722c0e", size = 265605 } wheels = [ - { url = "https://files.pythonhosted.org/packages/10/30/20a7f33b0b884a9d14dd3aa94ff1ac9da1479fe2ad66dd9e2736075d2506/mcp-1.6.0-py3-none-any.whl", hash = "sha256:7bd24c6ea042dbec44c754f100984d186620d8b841ec30f1b19eda9b93a634d0", size = 76077 }, + { url = "https://files.pythonhosted.org/packages/1c/5d/91cf0d40e40ae9ecf8d4004e0f9611eea86085aa0b5505493e0ff53972da/mcp-1.8.1-py3-none-any.whl", hash = "sha256:948e03783859fa35abe05b9b6c0a1d5519be452fc079dc8d7f682549591c1770", size = 119761 }, ] [[package]] @@ -1460,7 +1461,7 @@ wheels = [ [[package]] name = "openai" -version = "1.76.0" +version = "1.81.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -1472,14 +1473,14 @@ dependencies = [ { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/84/51/817969ec969b73d8ddad085670ecd8a45ef1af1811d8c3b8a177ca4d1309/openai-1.76.0.tar.gz", hash = "sha256:fd2bfaf4608f48102d6b74f9e11c5ecaa058b60dad9c36e409c12477dfd91fb2", size = 434660 } +sdist = { url = "https://files.pythonhosted.org/packages/1c/89/a1e4f3fa7ca4f7fec90dbf47d93b7cd5ff65924926733af15044e302a192/openai-1.81.0.tar.gz", hash = "sha256:349567a8607e0bcffd28e02f96b5c2397d0d25d06732d90ab3ecbf97abf030f9", size = 456861 } wheels = [ - { url = "https://files.pythonhosted.org/packages/59/aa/84e02ab500ca871eb8f62784426963a1c7c17a72fea3c7f268af4bbaafa5/openai-1.76.0-py3-none-any.whl", hash = 
"sha256:a712b50e78cf78e6d7b2a8f69c4978243517c2c36999756673e07a14ce37dc0a", size = 661201 }, + { url = "https://files.pythonhosted.org/packages/02/66/bcc7f9bf48e8610a33e3b5c96a5a644dad032d92404ea2a5e8b43ba067e8/openai-1.81.0-py3-none-any.whl", hash = "sha256:1c71572e22b43876c5d7d65ade0b7b516bb527c3d44ae94111267a09125f7bae", size = 717529 }, ] [[package]] name = "openai-agents" -version = "0.0.14" +version = "0.0.17" source = { editable = "." } dependencies = [ { name = "griffe" }, @@ -1533,9 +1534,9 @@ requires-dist = [ { name = "graphviz", marker = "extra == 'viz'", specifier = ">=0.17" }, { name = "griffe", specifier = ">=1.5.6,<2" }, { name = "litellm", marker = "extra == 'litellm'", specifier = ">=1.67.4.post1,<2" }, - { name = "mcp", marker = "python_full_version >= '3.10'", specifier = ">=1.6.0,<2" }, + { name = "mcp", marker = "python_full_version >= '3.10'", specifier = ">=1.8.0,<2" }, { name = "numpy", marker = "python_full_version >= '3.10' and extra == 'voice'", specifier = ">=2.2.0,<3" }, - { name = "openai", specifier = ">=1.76.0" }, + { name = "openai", specifier = ">=1.81.0" }, { name = "pydantic", specifier = ">=2.10,<3" }, { name = "requests", specifier = ">=2.0,<3" }, { name = "types-requests", specifier = ">=2.0,<3" }, @@ -2085,6 +2086,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1e/18/98a99ad95133c6a6e2005fe89faedf294a748bd5dc803008059409ac9b1e/python_dotenv-1.1.0-py3-none-any.whl", hash = "sha256:d7c01d9e2293916c18baf562d95698754b0dbbb5e74d457c45d4f6561fb9d55d", size = 20256 }, ] +[[package]] +name = "python-multipart" +version = "0.0.20" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f3/87/f44d7c9f274c7ee665a29b885ec97089ec5dc034c7f3fafa03da9e39a09e/python_multipart-0.0.20.tar.gz", hash = "sha256:8dd0cab45b8e23064ae09147625994d090fa46f5b0d1e13af944c331a7fa9d13", size = 37158 } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/45/58/38b5afbc1a800eeea951b9285d3912613f2603bdf897a4ab0f4bd7f405fc/python_multipart-0.0.20-py3-none-any.whl", hash = "sha256:8a62d3a8335e06589fe01f2a3e178cdcc632f3fbe0d492ad9ee0ec35aab1f104", size = 24546 }, +] + [[package]] name = "python-xlib" version = "0.33"