diff --git a/.gitignore b/.gitignore index 7dd22b88..2e9b9237 100644 --- a/.gitignore +++ b/.gitignore @@ -141,4 +141,5 @@ cython_debug/ .ruff_cache/ # PyPI configuration file -.pypirc \ No newline at end of file +.pypirc +.aider* diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 00000000..ff37db32 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,69 @@ +Welcome to the OpenAI Agents SDK repository. This file contains the main points for new contributors. + +## Repository overview + +- **Source code**: `src/agents/` contains the implementation. +- **Tests**: `tests/` with a short guide in `tests/README.md`. +- **Examples**: under `examples/`. +- **Documentation**: markdown pages live in `docs/` with `mkdocs.yml` controlling the site. +- **Utilities**: developer commands are defined in the `Makefile`. +- **PR template**: `.github/PULL_REQUEST_TEMPLATE/pull_request_template.md` describes the information every PR must include. + +## Local workflow + +1. Format, lint and type‑check your changes: + + ```bash + make format + make lint + make mypy + ``` + +2. Run the tests: + + ```bash + make tests + ``` + + To run a single test, use `uv run pytest -s -k `. + +3. Build the documentation (optional but recommended for docs changes): + + ```bash + make build-docs + ``` + + Coverage can be generated with `make coverage`. + +## Snapshot tests + +Some tests rely on inline snapshots. See `tests/README.md` for details on updating them: + +```bash +make snapshots-fix # update existing snapshots +make snapshots-create # create new snapshots +``` + +Run `make tests` again after updating snapshots to ensure they pass. + +## Style notes + +- Write comments as full sentences and end them with a period. + +## Pull request expectations + +PRs should use the template located at `.github/PULL_REQUEST_TEMPLATE/pull_request_template.md`. Provide a summary, test plan and issue number if applicable, then check that: + +- New tests are added when needed. +- Documentation is updated. +- `make lint` and `make format` have been run. +- The full test suite passes. + +Commit messages should be concise and written in the imperative mood. Small, focused commits are preferred. + +## What reviewers look for + +- Tests covering new behaviour. +- Consistent style: code formatted with `ruff format`, imports sorted, and type hints passing `mypy`. +- Clear documentation for any public API changes. +- Clean history and a helpful PR description. diff --git a/README.md b/README.md index bbd4a5a0..7dcd97b3 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # OpenAI Agents SDK -The OpenAI Agents SDK is a lightweight yet powerful framework for building multi-agent workflows. +The OpenAI Agents SDK is a lightweight yet powerful framework for building multi-agent workflows. It is provider-agnostic, supporting the OpenAI Responses and Chat Completions APIs, as well as 100+ other LLMs. Image of the Agents Tracing UI @@ -13,8 +13,6 @@ The OpenAI Agents SDK is a lightweight yet powerful framework for building multi Explore the [examples](examples) directory to see the SDK in action, and read our [documentation](https://openai.github.io/openai-agents-python/) for more details. -Notably, our SDK [is compatible](https://openai.github.io/openai-agents-python/models/) with any model providers that support the OpenAI Chat Completions API format. - ## Get started 1. 
Set up your Python environment diff --git a/docs/ja/visualization.md b/docs/ja/visualization.md index 3c5c64ec..9093bb65 100644 --- a/docs/ja/visualization.md +++ b/docs/ja/visualization.md @@ -81,7 +81,7 @@ draw_graph(triage_agent).view() デフォルトでは、`draw_graph` はグラフをインラインで表示します。ファイルとして保存するには、ファイル名を指定します: ```python -draw_graph(triage_agent, filename="agent_graph.png") +draw_graph(triage_agent, filename="agent_graph") ``` -これにより、作業ディレクトリに `agent_graph.png` が生成されます。 \ No newline at end of file +これにより、作業ディレクトリに `agent_graph.png` が生成されます。 diff --git a/docs/models/index.md b/docs/models/index.md index 4cf4f643..1c89d778 100644 --- a/docs/models/index.md +++ b/docs/models/index.md @@ -5,11 +5,40 @@ The Agents SDK comes with out-of-the-box support for OpenAI models in two flavor - **Recommended**: the [`OpenAIResponsesModel`][agents.models.openai_responses.OpenAIResponsesModel], which calls OpenAI APIs using the new [Responses API](https://platform.openai.com/docs/api-reference/responses). - The [`OpenAIChatCompletionsModel`][agents.models.openai_chatcompletions.OpenAIChatCompletionsModel], which calls OpenAI APIs using the [Chat Completions API](https://platform.openai.com/docs/api-reference/chat). +## Non-OpenAI models + +You can use most other non-OpenAI models via the [LiteLLM integration](./litellm.md). First, install the litellm dependency group: + +```bash +pip install "openai-agents[litellm]" +``` + +Then, use any of the [supported models](https://docs.litellm.ai/docs/providers) with the `litellm/` prefix: + +```python +claude_agent = Agent(model="litellm/anthropic/claude-3-5-sonnet-20240620", ...) +gemini_agent = Agent(model="litellm/gemini/gemini-2.5-flash-preview-04-17", ...) +``` + +### Other ways to use non-OpenAI models + +You can integrate other LLM providers in 3 more ways (examples [here](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/)): + +1. [`set_default_openai_client`][agents.set_default_openai_client] is useful in cases where you want to globally use an instance of `AsyncOpenAI` as the LLM client. This is for cases where the LLM provider has an OpenAI compatible API endpoint, and you can set the `base_url` and `api_key`. See a configurable example in [examples/model_providers/custom_example_global.py](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/custom_example_global.py). +2. [`ModelProvider`][agents.models.interface.ModelProvider] is at the `Runner.run` level. This lets you say "use a custom model provider for all agents in this run". See a configurable example in [examples/model_providers/custom_example_provider.py](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/custom_example_provider.py). +3. [`Agent.model`][agents.agent.Agent.model] lets you specify the model on a specific Agent instance. This enables you to mix and match different providers for different agents. See a configurable example in [examples/model_providers/custom_example_agent.py](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/custom_example_agent.py). An easy way to use most available models is via the [LiteLLM integration](./litellm.md). + +In cases where you do not have an API key from `platform.openai.com`, we recommend disabling tracing via `set_tracing_disabled()`, or setting up a [different tracing processor](../tracing.md). + +!!! note + + In these examples, we use the Chat Completions API/model, because most LLM providers don't yet support the Responses API. 
If your LLM provider does support it, we recommend using Responses. + ## Mixing and matching models Within a single workflow, you may want to use different models for each agent. For example, you could use a smaller, faster model for triage, while using a larger, more capable model for complex tasks. When configuring an [`Agent`][agents.Agent], you can select a specific model by either: -1. Passing the name of an OpenAI model. +1. Passing the name of a model. 2. Passing any model name + a [`ModelProvider`][agents.models.interface.ModelProvider] that can map that name to a Model instance. 3. Directly providing a [`Model`][agents.models.interface.Model] implementation. @@ -64,20 +93,6 @@ english_agent = Agent( ) ``` -## Using other LLM providers - -You can use other LLM providers in 3 ways (examples [here](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/)): - -1. [`set_default_openai_client`][agents.set_default_openai_client] is useful in cases where you want to globally use an instance of `AsyncOpenAI` as the LLM client. This is for cases where the LLM provider has an OpenAI compatible API endpoint, and you can set the `base_url` and `api_key`. See a configurable example in [examples/model_providers/custom_example_global.py](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/custom_example_global.py). -2. [`ModelProvider`][agents.models.interface.ModelProvider] is at the `Runner.run` level. This lets you say "use a custom model provider for all agents in this run". See a configurable example in [examples/model_providers/custom_example_provider.py](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/custom_example_provider.py). -3. [`Agent.model`][agents.agent.Agent.model] lets you specify the model on a specific Agent instance. This enables you to mix and match different providers for different agents. See a configurable example in [examples/model_providers/custom_example_agent.py](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/custom_example_agent.py). An easy way to use most available models is via the [LiteLLM integration](./litellm.md). - -In cases where you do not have an API key from `platform.openai.com`, we recommend disabling tracing via `set_tracing_disabled()`, or setting up a [different tracing processor](../tracing.md). - -!!! note - - In these examples, we use the Chat Completions API/model, because most LLM providers don't yet support the Responses API. If your LLM provider does support it, we recommend using Responses. - ## Common issues with using other LLM providers ### Tracing client error 401 @@ -100,7 +115,17 @@ The SDK uses the Responses API by default, but most other LLM providers don't ye Some model providers don't have support for [structured outputs](https://platform.openai.com/docs/guides/structured-outputs). This sometimes results in an error that looks something like this: ``` + BadRequestError: Error code: 400 - {'error': {'message': "'response_format.type' : value is not one of the allowed values ['text','json_object']", 'type': 'invalid_request_error'}} + ``` This is a shortcoming of some model providers - they support JSON outputs, but don't allow you to specify the `json_schema` to use for the output. We are working on a fix for this, but we suggest relying on providers that do have support for JSON schema output, because otherwise your app will often break because of malformed JSON. 
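+One way to reduce the impact of this, sketched below with illustrative model names, is to keep any agent that relies on structured outputs on a provider with JSON schema support, and route plain-text agents to other providers (for example via the `litellm/` prefix described above):
+
+```python
+from pydantic import BaseModel
+
+from agents import Agent
+
+
+class Summary(BaseModel):
+    title: str
+    bullet_points: list[str]
+
+
+# The structured-output agent stays on an OpenAI model, which supports JSON schema output.
+extractor = Agent(
+    name="Extractor",
+    instructions="Summarize the input text.",
+    model="gpt-4o",  # illustrative model name
+    output_type=Summary,
+)
+
+# A plain-text agent can run on another provider via the LiteLLM integration.
+writer = Agent(
+    name="Writer",
+    instructions="Turn the summary into a short blurb.",
+    model="litellm/anthropic/claude-3-5-sonnet-20240620",
+)
+```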
+ +## Mixing models across providers + +You need to be aware of feature differences between model providers, or you may run into errors. For example, OpenAI supports structured outputs, multimodal input, and hosted file search and web search, but many other providers don't support these features. Be aware of these limitations: + +- Don't send unsupported `tools` to providers that don't understand them. +- Filter out multimodal inputs before calling models that are text-only. +- Providers that don't support structured JSON outputs will occasionally produce invalid JSON. diff --git a/docs/tracing.md b/docs/tracing.md index ea48a2e2..4a9c1bd9 100644 --- a/docs/tracing.md +++ b/docs/tracing.md @@ -101,6 +101,7 @@ To customize this default setup, to send traces to alternative or additional bac - [Weights & Biases](https://weave-docs.wandb.ai/guides/integrations/openai_agents) - [Arize-Phoenix](https://docs.arize.com/phoenix/tracing/integrations-tracing/openai-agents-sdk) +- [Future AGI](https://docs.futureagi.com/future-agi/products/observability/auto-instrumentation/openai_agents) - [MLflow (self-hosted/OSS)](https://mlflow.org/docs/latest/tracing/integrations/openai-agent) - [MLflow (Databricks hosted)](https://docs.databricks.com/aws/en/mlflow/mlflow-tracing#-automatic-tracing) - [Braintrust](https://braintrust.dev/docs/guides/traces/integrations#openai-agents-sdk) @@ -114,3 +115,4 @@ To customize this default setup, to send traces to alternative or additional bac - [Langfuse](https://langfuse.com/docs/integrations/openaiagentssdk/openai-agents) - [Langtrace](https://docs.langtrace.ai/supported-integrations/llm-frameworks/openai-agents-sdk) - [Okahu-Monocle](https://github.com/monocle2ai/monocle) +- [Galileo](https://v2docs.galileo.ai/integrations/openai-agent-integration#openai-agent-integration) diff --git a/docs/visualization.md b/docs/visualization.md index 00f3126d..409803f7 100644 --- a/docs/visualization.md +++ b/docs/visualization.md @@ -78,9 +78,7 @@ draw_graph(triage_agent).view() By default, `draw_graph` displays the graph inline. To save it as a file, specify a filename: ```python -draw_graph(triage_agent, filename="agent_graph.png") +draw_graph(triage_agent, filename="agent_graph") ``` This will generate `agent_graph.png` in the working directory. 
- - diff --git a/examples/basic/local_image.py b/examples/basic/local_image.py new file mode 100644 index 00000000..d4a784ba --- /dev/null +++ b/examples/basic/local_image.py @@ -0,0 +1,48 @@ +import asyncio +import base64 +import os + +from agents import Agent, Runner + +FILEPATH = os.path.join(os.path.dirname(__file__), "media/image_bison.jpg") + + +def image_to_base64(image_path): + with open(image_path, "rb") as image_file: + encoded_string = base64.b64encode(image_file.read()).decode("utf-8") + return encoded_string + + +async def main(): + # Print base64-encoded image + b64_image = image_to_base64(FILEPATH) + + agent = Agent( + name="Assistant", + instructions="You are a helpful assistant.", + ) + + result = await Runner.run( + agent, + [ + { + "role": "user", + "content": [ + { + "type": "input_image", + "detail": "auto", + "image_url": f"data:image/jpeg;base64,{b64_image}", + } + ], + }, + { + "role": "user", + "content": "What do you see in this image?", + }, + ], + ) + print(result.final_output) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/basic/media/image_bison.jpg b/examples/basic/media/image_bison.jpg new file mode 100644 index 00000000..b113c91f Binary files /dev/null and b/examples/basic/media/image_bison.jpg differ diff --git a/examples/basic/non_strict_output_type.py b/examples/basic/non_strict_output_type.py new file mode 100644 index 00000000..49fcc4e2 --- /dev/null +++ b/examples/basic/non_strict_output_type.py @@ -0,0 +1,81 @@ +import asyncio +import json +from dataclasses import dataclass +from typing import Any + +from agents import Agent, AgentOutputSchema, AgentOutputSchemaBase, Runner + +"""This example demonstrates how to use an output type that is not in strict mode. Strict mode +allows us to guarantee valid JSON output, but some schemas are not strict-compatible. + +In this example, we define an output type that is not strict-compatible, and then we run the +agent with strict_json_schema=False. + +We also demonstrate a custom output type. + +To understand which schemas are strict-compatible, see: +https://platform.openai.com/docs/guides/structured-outputs?api-mode=responses#supported-schemas +""" + + +@dataclass +class OutputType: + jokes: dict[int, str] + """A list of jokes, indexed by joke number.""" + + +class CustomOutputSchema(AgentOutputSchemaBase): + """A demonstration of a custom output schema.""" + + def is_plain_text(self) -> bool: + return False + + def name(self) -> str: + return "CustomOutputSchema" + + def json_schema(self) -> dict[str, Any]: + return { + "type": "object", + "properties": {"jokes": {"type": "object", "properties": {"joke": {"type": "string"}}}}, + } + + def is_strict_json_schema(self) -> bool: + return False + + def validate_json(self, json_str: str) -> Any: + json_obj = json.loads(json_str) + # Just for demonstration, we'll return a list. + return list(json_obj["jokes"].values()) + + +async def main(): + agent = Agent( + name="Assistant", + instructions="You are a helpful assistant.", + output_type=OutputType, + ) + + input = "Tell me 3 short jokes." + + # First, let's try with a strict output type. This should raise an exception. + try: + result = await Runner.run(agent, input) + raise AssertionError("Should have raised an exception") + except Exception as e: + print(f"Error (expected): {e}") + + # Now let's try again with a non-strict output type. This should work. + # In some cases, it will raise an error - the schema isn't strict, so the model may + # produce an invalid JSON object. 
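+    # AgentOutputSchema wraps OutputType and turns off strict mode, so the generated schema
+    # is sent without the strict-only constraints that the dict[int, str] field can't meet.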
+ agent.output_type = AgentOutputSchema(OutputType, strict_json_schema=False) + result = await Runner.run(agent, input) + print(result.final_output) + + # Finally, let's try a custom output type. + agent.output_type = CustomOutputSchema() + result = await Runner.run(agent, input) + print(result.final_output) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/basic/remote_image.py b/examples/basic/remote_image.py new file mode 100644 index 00000000..948a22d9 --- /dev/null +++ b/examples/basic/remote_image.py @@ -0,0 +1,31 @@ +import asyncio + +from agents import Agent, Runner + +URL = "https://upload.wikimedia.org/wikipedia/commons/0/0c/GoldenGateBridge-001.jpg" + + +async def main(): + agent = Agent( + name="Assistant", + instructions="You are a helpful assistant.", + ) + + result = await Runner.run( + agent, + [ + { + "role": "user", + "content": [{"type": "input_image", "detail": "auto", "image_url": URL}], + }, + { + "role": "user", + "content": "What do you see in this image?", + }, + ], + ) + print(result.final_output) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/financial_research_agent/main.py b/examples/financial_research_agent/main.py index 3fa8a7e0..b5b6cfdf 100644 --- a/examples/financial_research_agent/main.py +++ b/examples/financial_research_agent/main.py @@ -4,7 +4,7 @@ # Entrypoint for the financial bot example. -# Run this as `python -m examples.financial_bot.main` and enter a +# Run this as `python -m examples.financial_research_agent.main` and enter a # financial research query, for example: # "Write up an analysis of Apple Inc.'s most recent quarter." async def main() -> None: diff --git a/examples/hosted_mcp/__init__.py b/examples/hosted_mcp/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/hosted_mcp/approvals.py b/examples/hosted_mcp/approvals.py new file mode 100644 index 00000000..2cabb3ee --- /dev/null +++ b/examples/hosted_mcp/approvals.py @@ -0,0 +1,61 @@ +import argparse +import asyncio + +from agents import ( + Agent, + HostedMCPTool, + MCPToolApprovalFunctionResult, + MCPToolApprovalRequest, + Runner, +) + +"""This example demonstrates how to use the hosted MCP support in the OpenAI Responses API, with +approval callbacks.""" + + +def approval_callback(request: MCPToolApprovalRequest) -> MCPToolApprovalFunctionResult: + answer = input(f"Approve running the tool `{request.data.name}`? 
(y/n) ") + result: MCPToolApprovalFunctionResult = {"approve": answer == "y"} + if not result["approve"]: + result["reason"] = "User denied" + return result + + +async def main(verbose: bool, stream: bool): + agent = Agent( + name="Assistant", + tools=[ + HostedMCPTool( + tool_config={ + "type": "mcp", + "server_label": "gitmcp", + "server_url": "https://gitmcp.io/openai/codex", + "require_approval": "always", + }, + on_approval_request=approval_callback, + ) + ], + ) + + if stream: + result = Runner.run_streamed(agent, "Which language is this repo written in?") + async for event in result.stream_events(): + if event.type == "run_item_stream_event": + print(f"Got event of type {event.item.__class__.__name__}") + print(f"Done streaming; final result: {result.final_output}") + else: + res = await Runner.run(agent, "Which language is this repo written in?") + print(res.final_output) + + if verbose: + for item in result.new_items: + print(item) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--verbose", action="store_true", default=False) + parser.add_argument("--stream", action="store_true", default=False) + args = parser.parse_args() + + asyncio.run(main(args.verbose, args.stream)) diff --git a/examples/hosted_mcp/simple.py b/examples/hosted_mcp/simple.py new file mode 100644 index 00000000..508c3a7a --- /dev/null +++ b/examples/hosted_mcp/simple.py @@ -0,0 +1,47 @@ +import argparse +import asyncio + +from agents import Agent, HostedMCPTool, Runner + +"""This example demonstrates how to use the hosted MCP support in the OpenAI Responses API, with +approvals not required for any tools. You should only use this for trusted MCP servers.""" + + +async def main(verbose: bool, stream: bool): + agent = Agent( + name="Assistant", + tools=[ + HostedMCPTool( + tool_config={ + "type": "mcp", + "server_label": "gitmcp", + "server_url": "https://gitmcp.io/openai/codex", + "require_approval": "never", + } + ) + ], + ) + + if stream: + result = Runner.run_streamed(agent, "Which language is this repo written in?") + async for event in result.stream_events(): + if event.type == "run_item_stream_event": + print(f"Got event of type {event.item.__class__.__name__}") + print(f"Done streaming; final result: {result.final_output}") + else: + res = await Runner.run(agent, "Which language is this repo written in?") + print(res.final_output) + # The repository is primarily written in multiple languages, including Rust and TypeScript... + + if verbose: + for item in result.new_items: + print(item) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--verbose", action="store_true", default=False) + parser.add_argument("--stream", action="store_true", default=False) + args = parser.parse_args() + + asyncio.run(main(args.verbose, args.stream)) diff --git a/examples/mcp/streamablehttp_example/README.md b/examples/mcp/streamablehttp_example/README.md new file mode 100644 index 00000000..a07fe19b --- /dev/null +++ b/examples/mcp/streamablehttp_example/README.md @@ -0,0 +1,13 @@ +# MCP Streamable HTTP Example + +This example uses a local Streamable HTTP server in [server.py](server.py). + +Run the example via: + +``` +uv run python examples/mcp/streamablehttp_example/main.py +``` + +## Details + +The example uses the `MCPServerStreamableHttp` class from `agents.mcp`. The server runs in a sub-process at `https://localhost:8000/mcp`. 
diff --git a/examples/mcp/streamablehttp_example/main.py b/examples/mcp/streamablehttp_example/main.py new file mode 100644 index 00000000..cc95e798 --- /dev/null +++ b/examples/mcp/streamablehttp_example/main.py @@ -0,0 +1,83 @@ +import asyncio +import os +import shutil +import subprocess +import time +from typing import Any + +from agents import Agent, Runner, gen_trace_id, trace +from agents.mcp import MCPServer, MCPServerStreamableHttp +from agents.model_settings import ModelSettings + + +async def run(mcp_server: MCPServer): + agent = Agent( + name="Assistant", + instructions="Use the tools to answer the questions.", + mcp_servers=[mcp_server], + model_settings=ModelSettings(tool_choice="required"), + ) + + # Use the `add` tool to add two numbers + message = "Add these numbers: 7 and 22." + print(f"Running: {message}") + result = await Runner.run(starting_agent=agent, input=message) + print(result.final_output) + + # Run the `get_weather` tool + message = "What's the weather in Tokyo?" + print(f"\n\nRunning: {message}") + result = await Runner.run(starting_agent=agent, input=message) + print(result.final_output) + + # Run the `get_secret_word` tool + message = "What's the secret word?" + print(f"\n\nRunning: {message}") + result = await Runner.run(starting_agent=agent, input=message) + print(result.final_output) + + +async def main(): + async with MCPServerStreamableHttp( + name="Streamable HTTP Python Server", + params={ + "url": "http://localhost:8000/mcp", + }, + ) as server: + trace_id = gen_trace_id() + with trace(workflow_name="Streamable HTTP Example", trace_id=trace_id): + print(f"View trace: https://platform.openai.com/traces/trace?trace_id={trace_id}\n") + await run(server) + + +if __name__ == "__main__": + # Let's make sure the user has uv installed + if not shutil.which("uv"): + raise RuntimeError( + "uv is not installed. Please install it: https://docs.astral.sh/uv/getting-started/installation/" + ) + + # We'll run the Streamable HTTP server in a subprocess. Usually this would be a remote server, but for this + # demo, we'll run it locally at http://localhost:8000/mcp + process: subprocess.Popen[Any] | None = None + try: + this_dir = os.path.dirname(os.path.abspath(__file__)) + server_file = os.path.join(this_dir, "server.py") + + print("Starting Streamable HTTP server at http://localhost:8000/mcp ...") + + # Run `uv run server.py` to start the Streamable HTTP server + process = subprocess.Popen(["uv", "run", server_file]) + # Give it 3 seconds to start + time.sleep(3) + + print("Streamable HTTP server started. 
Running example...\n\n") + except Exception as e: + print(f"Error starting Streamable HTTP server: {e}") + exit(1) + + try: + asyncio.run(main()) + finally: + if process: + process.terminate() diff --git a/examples/mcp/streamablehttp_example/server.py b/examples/mcp/streamablehttp_example/server.py new file mode 100644 index 00000000..d8f83965 --- /dev/null +++ b/examples/mcp/streamablehttp_example/server.py @@ -0,0 +1,33 @@ +import random + +import requests +from mcp.server.fastmcp import FastMCP + +# Create server +mcp = FastMCP("Echo Server") + + +@mcp.tool() +def add(a: int, b: int) -> int: + """Add two numbers""" + print(f"[debug-server] add({a}, {b})") + return a + b + + +@mcp.tool() +def get_secret_word() -> str: + print("[debug-server] get_secret_word()") + return random.choice(["apple", "banana", "cherry"]) + + +@mcp.tool() +def get_current_weather(city: str) -> str: + print(f"[debug-server] get_current_weather({city})") + + endpoint = "https://wttr.in" + response = requests.get(f"{endpoint}/{city}") + return response.text + + +if __name__ == "__main__": + mcp.run(transport="streamable-http") diff --git a/examples/model_providers/litellm_auto.py b/examples/model_providers/litellm_auto.py new file mode 100644 index 00000000..12b1e891 --- /dev/null +++ b/examples/model_providers/litellm_auto.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +import asyncio + +from agents import Agent, Runner, function_tool, set_tracing_disabled + +"""This example uses the built-in support for LiteLLM. To use this, ensure you have the +ANTHROPIC_API_KEY environment variable set. +""" + +set_tracing_disabled(disabled=True) + + +@function_tool +def get_weather(city: str): + print(f"[debug] getting weather for {city}") + return f"The weather in {city} is sunny." + + +async def main(): + agent = Agent( + name="Assistant", + instructions="You only respond in haikus.", + # We prefix with litellm/ to tell the Runner to use the LitellmModel + model="litellm/anthropic/claude-3-5-sonnet-20240620", + tools=[get_weather], + ) + + result = await Runner.run(agent, "What's the weather in Tokyo?") + print(result.final_output) + + +if __name__ == "__main__": + import os + + if os.getenv("ANTHROPIC_API_KEY") is None: + raise ValueError( + "ANTHROPIC_API_KEY is not set. Please set it the environment variable and try again." + ) + + asyncio.run(main()) diff --git a/examples/research_bot/agents/search_agent.py b/examples/research_bot/agents/search_agent.py index 0212ce5b..61f91701 100644 --- a/examples/research_bot/agents/search_agent.py +++ b/examples/research_bot/agents/search_agent.py @@ -3,7 +3,7 @@ INSTRUCTIONS = ( "You are a research assistant. Given a search term, you search the web for that term and " - "produce a concise summary of the results. The summary must 2-3 paragraphs and less than 300 " + "produce a concise summary of the results. The summary must be 2-3 paragraphs and less than 300 " "words. Capture the main points. Write succinctly, no need to have complete sentences or good " "grammar. This will be consumed by someone synthesizing a report, so its vital you capture the " "essence and ignore any fluff. 
Do not include any additional commentary other than the summary " diff --git a/examples/tools/code_interpreter.py b/examples/tools/code_interpreter.py new file mode 100644 index 00000000..a5843ce3 --- /dev/null +++ b/examples/tools/code_interpreter.py @@ -0,0 +1,34 @@ +import asyncio + +from agents import Agent, CodeInterpreterTool, Runner, trace + + +async def main(): + agent = Agent( + name="Code interpreter", + instructions="You love doing math.", + tools=[ + CodeInterpreterTool( + tool_config={"type": "code_interpreter", "container": {"type": "auto"}}, + ) + ], + ) + + with trace("Code interpreter example"): + print("Solving math problem...") + result = Runner.run_streamed(agent, "What is the square root of 273 * 312821 plus 1782?") + async for event in result.stream_events(): + if ( + event.type == "run_item_stream_event" + and event.item.type == "tool_call_item" + and event.item.raw_item.type == "code_interpreter_call" + ): + print(f"Code interpreter code:\n```\n{event.item.raw_item.code}\n```\n") + elif event.type == "run_item_stream_event": + print(f"Other event: {event.item.type}") + + print(f"Final output: {result.final_output}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/tools/image_generator.py b/examples/tools/image_generator.py new file mode 100644 index 00000000..fd6fcc6b --- /dev/null +++ b/examples/tools/image_generator.py @@ -0,0 +1,54 @@ +import asyncio +import base64 +import os +import subprocess +import sys +import tempfile + +from agents import Agent, ImageGenerationTool, Runner, trace + + +def open_file(path: str) -> None: + if sys.platform.startswith("darwin"): + subprocess.run(["open", path], check=False) # macOS + elif os.name == "nt": # Windows + os.startfile(path) # type: ignore + elif os.name == "posix": + subprocess.run(["xdg-open", path], check=False) # Linux/Unix + else: + print(f"Don't know how to open files on this platform: {sys.platform}") + + +async def main(): + agent = Agent( + name="Image generator", + instructions="You are a helpful agent.", + tools=[ + ImageGenerationTool( + tool_config={"type": "image_generation", "quality": "low"}, + ) + ], + ) + + with trace("Image generation example"): + print("Generating image, this may take a while...") + result = await Runner.run( + agent, "Create an image of a frog eating a pizza, comic book style." 
+ ) + print(result.final_output) + for item in result.new_items: + if ( + item.type == "tool_call_item" + and item.raw_item.type == "image_generation_call" + and (img_result := item.raw_item.result) + ): + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp: + tmp.write(base64.b64decode(img_result)) + temp_path = tmp.name + + # Open the image + open_file(temp_path) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml index f94eb859..38a2f2b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,19 +1,19 @@ [project] name = "openai-agents" -version = "0.0.11" +version = "0.0.16" description = "OpenAI Agents SDK" readme = "README.md" requires-python = ">=3.9" license = "MIT" authors = [{ name = "OpenAI", email = "support@openai.com" }] dependencies = [ - "openai>=1.66.5", + "openai>=1.81.0", "pydantic>=2.10, <3", "griffe>=1.5.6, <2", "typing-extensions>=4.12.2, <5", "requests>=2.0, <3", "types-requests>=2.0, <3", - "mcp>=1.6.0, <2; python_version >= '3.10'", + "mcp>=1.8.0, <2; python_version >= '3.10'", ] classifiers = [ "Typing :: Typed", @@ -36,7 +36,7 @@ Repository = "https://github.com/openai/openai-agents-python" [project.optional-dependencies] voice = ["numpy>=2.2.0, <3; python_version>='3.10'", "websockets>=15.0, <16"] viz = ["graphviz>=0.17"] -litellm = ["litellm>=1.65.0, <2"] +litellm = ["litellm>=1.67.4.post1, <2"] [dependency-groups] dev = [ @@ -61,6 +61,7 @@ dev = [ "graphviz", "mkdocs-static-i18n>=1.3.0", "eval-type-backport>=0.2.2", + "fastapi >= 0.110.0, <1", ] [tool.uv.workspace] diff --git a/src/agents/__init__.py b/src/agents/__init__.py index db7d312e..58949157 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -6,7 +6,7 @@ from . import _config from .agent import Agent, ToolsToFinalOutputFunction, ToolsToFinalOutputResult -from .agent_output import AgentOutputSchema +from .agent_output import AgentOutputSchema, AgentOutputSchemaBase from .computer import AsyncComputer, Button, Computer, Environment from .exceptions import ( AgentsException, @@ -54,10 +54,19 @@ StreamEvent, ) from .tool import ( + CodeInterpreterTool, ComputerTool, FileSearchTool, FunctionTool, FunctionToolResult, + HostedMCPTool, + ImageGenerationTool, + LocalShellCommandRequest, + LocalShellExecutor, + LocalShellTool, + MCPToolApprovalFunction, + MCPToolApprovalFunctionResult, + MCPToolApprovalRequest, Tool, WebSearchTool, default_tool_error_function, @@ -158,6 +167,7 @@ def enable_verbose_stdout_logging(): "OpenAIProvider", "OpenAIResponsesModel", "AgentOutputSchema", + "AgentOutputSchemaBase", "Computer", "AsyncComputer", "Environment", @@ -205,8 +215,17 @@ def enable_verbose_stdout_logging(): "FunctionToolResult", "ComputerTool", "FileSearchTool", + "CodeInterpreterTool", + "ImageGenerationTool", + "LocalShellCommandRequest", + "LocalShellExecutor", + "LocalShellTool", "Tool", "WebSearchTool", + "HostedMCPTool", + "MCPToolApprovalFunction", + "MCPToolApprovalRequest", + "MCPToolApprovalFunctionResult", "function_tool", "Usage", "add_trace_processor", diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index 94c181b7..2cfa270e 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -14,6 +14,9 @@ ResponseFunctionWebSearch, ResponseOutputMessage, ) +from openai.types.responses.response_code_interpreter_tool_call import ( + ResponseCodeInterpreterToolCall, +) from openai.types.responses.response_computer_tool_call import ( ActionClick, ActionDoubleClick, @@ -25,11 +28,17 @@ ActionType, ActionWait, ) 
-from openai.types.responses.response_input_param import ComputerCallOutput +from openai.types.responses.response_input_param import ComputerCallOutput, McpApprovalResponse +from openai.types.responses.response_output_item import ( + ImageGenerationCall, + LocalShellCall, + McpApprovalRequest, + McpListTools, +) from openai.types.responses.response_reasoning_item import ResponseReasoningItem from .agent import Agent, ToolsToFinalOutputResult -from .agent_output import AgentOutputSchema +from .agent_output import AgentOutputSchemaBase from .computer import AsyncComputer, Computer from .exceptions import AgentsException, ModelBehaviorError, UserError from .guardrail import InputGuardrail, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult @@ -38,6 +47,9 @@ HandoffCallItem, HandoffOutputItem, ItemHelpers, + MCPApprovalRequestItem, + MCPApprovalResponseItem, + MCPListToolsItem, MessageOutputItem, ModelResponse, ReasoningItem, @@ -52,7 +64,16 @@ from .models.interface import ModelTracing from .run_context import RunContextWrapper, TContext from .stream_events import RunItemStreamEvent, StreamEvent -from .tool import ComputerTool, FunctionTool, FunctionToolResult, Tool +from .tool import ( + ComputerTool, + FunctionTool, + FunctionToolResult, + HostedMCPTool, + LocalShellCommandRequest, + LocalShellTool, + MCPToolApprovalRequest, + Tool, +) from .tracing import ( SpanError, Trace, @@ -112,15 +133,29 @@ class ToolRunComputerAction: computer_tool: ComputerTool +@dataclass +class ToolRunMCPApprovalRequest: + request_item: McpApprovalRequest + mcp_tool: HostedMCPTool + + +@dataclass +class ToolRunLocalShellCall: + tool_call: LocalShellCall + local_shell_tool: LocalShellTool + + @dataclass class ProcessedResponse: new_items: list[RunItem] handoffs: list[ToolRunHandoff] functions: list[ToolRunFunction] computer_actions: list[ToolRunComputerAction] + local_shell_calls: list[ToolRunLocalShellCall] tools_used: list[str] # Names of all tools used, including hosted tools + mcp_approval_requests: list[ToolRunMCPApprovalRequest] # Only requests with callbacks - def has_tools_to_run(self) -> bool: + def has_tools_or_approvals_to_run(self) -> bool: # Handoffs, functions and computer actions need local processing # Hosted tools have already run, so there's nothing to do. 
return any( @@ -128,6 +163,8 @@ def has_tools_to_run(self) -> bool: self.handoffs, self.functions, self.computer_actions, + self.local_shell_calls, + self.mcp_approval_requests, ] ) @@ -195,7 +232,7 @@ async def execute_tools_and_side_effects( pre_step_items: list[RunItem], new_response: ModelResponse, processed_response: ProcessedResponse, - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, hooks: RunHooks[TContext], context_wrapper: RunContextWrapper[TContext], run_config: RunConfig, @@ -226,7 +263,16 @@ async def execute_tools_and_side_effects( new_step_items.extend([result.run_item for result in function_results]) new_step_items.extend(computer_results) - # Second, check if there are any handoffs + # Next, run the MCP approval requests + if processed_response.mcp_approval_requests: + approval_results = await cls.execute_mcp_approval_requests( + agent=agent, + approval_requests=processed_response.mcp_approval_requests, + context_wrapper=context_wrapper, + ) + new_step_items.extend(approval_results) + + # Next, check if there are any handoffs if run_handoffs := processed_response.handoffs: return await cls.execute_handoffs( agent=agent, @@ -240,7 +286,7 @@ async def execute_tools_and_side_effects( run_config=run_config, ) - # Third, we'll check if the tool use should result in a final output + # Next, we'll check if the tool use should result in a final output check_tool_use = await cls._check_for_final_output_from_tools( agent=agent, tool_results=function_results, @@ -295,7 +341,7 @@ async def execute_tools_and_side_effects( ) elif ( not output_schema or output_schema.is_plain_text() - ) and not processed_response.has_tools_to_run(): + ) and not processed_response.has_tools_or_approvals_to_run(): return await cls.execute_final_output( agent=agent, original_input=original_input, @@ -335,7 +381,7 @@ def process_model_response( agent: Agent[Any], all_tools: list[Tool], response: ModelResponse, - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], ) -> ProcessedResponse: items: list[RunItem] = [] @@ -343,10 +389,20 @@ def process_model_response( run_handoffs = [] functions = [] computer_actions = [] + local_shell_calls = [] + mcp_approval_requests = [] tools_used: list[str] = [] handoff_map = {handoff.tool_name: handoff for handoff in handoffs} function_map = {tool.name: tool for tool in all_tools if isinstance(tool, FunctionTool)} computer_tool = next((tool for tool in all_tools if isinstance(tool, ComputerTool)), None) + local_shell_tool = next( + (tool for tool in all_tools if isinstance(tool, LocalShellTool)), None + ) + hosted_mcp_server_map = { + tool.tool_config["server_label"]: tool + for tool in all_tools + if isinstance(tool, HostedMCPTool) + } for output in response.output: if isinstance(output, ResponseOutputMessage): @@ -375,6 +431,54 @@ def process_model_response( computer_actions.append( ToolRunComputerAction(tool_call=output, computer_tool=computer_tool) ) + elif isinstance(output, McpApprovalRequest): + items.append(MCPApprovalRequestItem(raw_item=output, agent=agent)) + if output.server_label not in hosted_mcp_server_map: + _error_tracing.attach_error_to_current_span( + SpanError( + message="MCP server label not found", + data={"server_label": output.server_label}, + ) + ) + raise ModelBehaviorError(f"MCP server label {output.server_label} not found") + else: + server = hosted_mcp_server_map[output.server_label] + if server.on_approval_request: + mcp_approval_requests.append( 
+ ToolRunMCPApprovalRequest( + request_item=output, + mcp_tool=server, + ) + ) + else: + logger.warning( + f"MCP server {output.server_label} has no on_approval_request hook" + ) + elif isinstance(output, McpListTools): + items.append(MCPListToolsItem(raw_item=output, agent=agent)) + elif isinstance(output, ImageGenerationCall): + items.append(ToolCallItem(raw_item=output, agent=agent)) + tools_used.append("image_generation") + elif isinstance(output, ResponseCodeInterpreterToolCall): + items.append(ToolCallItem(raw_item=output, agent=agent)) + tools_used.append("code_interpreter") + elif isinstance(output, LocalShellCall): + items.append(ToolCallItem(raw_item=output, agent=agent)) + tools_used.append("local_shell") + if not local_shell_tool: + _error_tracing.attach_error_to_current_span( + SpanError( + message="Local shell tool not found", + data={}, + ) + ) + raise ModelBehaviorError( + "Model produced local shell call without a local shell tool." + ) + local_shell_calls.append( + ToolRunLocalShellCall(tool_call=output, local_shell_tool=local_shell_tool) + ) + elif not isinstance(output, ResponseFunctionToolCall): logger.warning(f"Unexpected output type, ignoring: {type(output)}") continue @@ -416,7 +520,9 @@ def process_model_response( handoffs=run_handoffs, functions=functions, computer_actions=computer_actions, + local_shell_calls=local_shell_calls, tools_used=tools_used, + mcp_approval_requests=mcp_approval_requests, ) @classmethod @@ -489,6 +595,30 @@ async def run_single_tool( for tool_run, result in zip(tool_runs, results) ] + @classmethod + async def execute_local_shell_calls( + cls, + *, + agent: Agent[TContext], + calls: list[ToolRunLocalShellCall], + context_wrapper: RunContextWrapper[TContext], + hooks: RunHooks[TContext], + config: RunConfig, + ) -> list[RunItem]: + results: list[RunItem] = [] + # Need to run these serially, because each call can affect the local shell state + for call in calls: + results.append( + await LocalShellAction.execute( + agent=agent, + call=call, + hooks=hooks, + context_wrapper=context_wrapper, + config=config, + ) + ) + return results + @classmethod async def execute_computer_actions( cls, @@ -643,6 +773,40 @@ async def execute_handoffs( next_step=NextStepHandoff(new_agent), ) + @classmethod + async def execute_mcp_approval_requests( + cls, + *, + agent: Agent[TContext], + approval_requests: list[ToolRunMCPApprovalRequest], + context_wrapper: RunContextWrapper[TContext], + ) -> list[RunItem]: + async def run_single_approval(approval_request: ToolRunMCPApprovalRequest) -> RunItem: + callback = approval_request.mcp_tool.on_approval_request + assert callback is not None, "Callback is required for MCP approval requests" + maybe_awaitable_result = callback( + MCPToolApprovalRequest(context_wrapper, approval_request.request_item) + ) + if inspect.isawaitable(maybe_awaitable_result): + result = await maybe_awaitable_result + else: + result = maybe_awaitable_result + reason = result.get("reason", None) + raw_item: McpApprovalResponse = { + "approval_request_id": approval_request.request_item.id, + "approve": result["approve"], + "type": "mcp_approval_response", + } + if not result["approve"] and reason: + raw_item["reason"] = reason + return MCPApprovalResponseItem( + raw_item=raw_item, + agent=agent, + ) + + tasks = [run_single_approval(approval_request) for approval_request in approval_requests] + return await asyncio.gather(*tasks) + @classmethod async def execute_final_output( cls, @@ -727,6 +891,11 @@ def stream_step_result_to_queue( event = 
RunItemStreamEvent(item=item, name="tool_output") elif isinstance(item, ReasoningItem): event = RunItemStreamEvent(item=item, name="reasoning_item_created") + elif isinstance(item, MCPApprovalRequestItem): + event = RunItemStreamEvent(item=item, name="mcp_approval_requested") + elif isinstance(item, MCPListToolsItem): + event = RunItemStreamEvent(item=item, name="mcp_list_tools") + else: logger.warning(f"Unexpected item type: {type(item)}") event = None @@ -919,3 +1088,54 @@ async def _get_screenshot_async( await computer.wait() return await computer.screenshot() + + +class LocalShellAction: + @classmethod + async def execute( + cls, + *, + agent: Agent[TContext], + call: ToolRunLocalShellCall, + hooks: RunHooks[TContext], + context_wrapper: RunContextWrapper[TContext], + config: RunConfig, + ) -> RunItem: + await asyncio.gather( + hooks.on_tool_start(context_wrapper, agent, call.local_shell_tool), + ( + agent.hooks.on_tool_start(context_wrapper, agent, call.local_shell_tool) + if agent.hooks + else _coro.noop_coroutine() + ), + ) + + request = LocalShellCommandRequest( + ctx_wrapper=context_wrapper, + data=call.tool_call, + ) + output = call.local_shell_tool.executor(request) + if inspect.isawaitable(output): + result = await output + else: + result = output + + await asyncio.gather( + hooks.on_tool_end(context_wrapper, agent, call.local_shell_tool, result), + ( + agent.hooks.on_tool_end(context_wrapper, agent, call.local_shell_tool, result) + if agent.hooks + else _coro.noop_coroutine() + ), + ) + + return ToolCallOutputItem( + agent=agent, + output=output, + raw_item={ + "type": "local_shell_call_output", + "id": call.tool_call.call_id, + "output": result, + # "id": "out" + call.tool_call.id, # TODO remove this, it should be optional + }, + ) diff --git a/src/agents/agent.py b/src/agents/agent.py index a24456b0..e22f579f 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -8,6 +8,7 @@ from typing_extensions import NotRequired, TypeAlias, TypedDict +from .agent_output import AgentOutputSchemaBase from .guardrail import InputGuardrail, OutputGuardrail from .handoffs import Handoff from .items import ItemHelpers @@ -141,8 +142,14 @@ class Agent(Generic[TContext]): Runs only if the agent produces a final output. """ - output_type: type[Any] | None = None - """The type of the output object. If not provided, the output will be `str`.""" + output_type: type[Any] | AgentOutputSchemaBase | None = None + """The type of the output object. If not provided, the output will be `str`. In most cases, + you should pass a regular Python type (e.g. a dataclass, Pydantic model, TypedDict, etc). + You can customize this in two ways: + 1. If you want non-strict schemas, pass `AgentOutputSchema(MyClass, strict_json_schema=False)`. + 2. If you want to use a custom JSON schema (i.e. without using the SDK's automatic schema) + creation, subclass and pass an `AgentOutputSchemaBase` subclass. + """ hooks: AgentHooks[TContext] | None = None """A class that receives callbacks on various lifecycle events for this agent. diff --git a/src/agents/agent_output.py b/src/agents/agent_output.py index 3262c57d..066bbd83 100644 --- a/src/agents/agent_output.py +++ b/src/agents/agent_output.py @@ -1,3 +1,4 @@ +import abc from dataclasses import dataclass from typing import Any @@ -12,8 +13,46 @@ _WRAPPER_DICT_KEY = "response" +class AgentOutputSchemaBase(abc.ABC): + """An object that captures the JSON schema of the output, as well as validating/parsing JSON + produced by the LLM into the output type. 
+ """ + + @abc.abstractmethod + def is_plain_text(self) -> bool: + """Whether the output type is plain text (versus a JSON object).""" + pass + + @abc.abstractmethod + def name(self) -> str: + """The name of the output type.""" + pass + + @abc.abstractmethod + def json_schema(self) -> dict[str, Any]: + """Returns the JSON schema of the output. Will only be called if the output type is not + plain text. + """ + pass + + @abc.abstractmethod + def is_strict_json_schema(self) -> bool: + """Whether the JSON schema is in strict mode. Strict mode constrains the JSON schema + features, but guarantees valis JSON. See here for details: + https://platform.openai.com/docs/guides/structured-outputs#supported-schemas + """ + pass + + @abc.abstractmethod + def validate_json(self, json_str: str) -> Any: + """Validate a JSON string against the output type. You must return the validated object, + or raise a `ModelBehaviorError` if the JSON is invalid. + """ + pass + + @dataclass(init=False) -class AgentOutputSchema: +class AgentOutputSchema(AgentOutputSchemaBase): """An object that captures the JSON schema of the output, as well as validating/parsing JSON produced by the LLM into the output type. """ @@ -32,7 +71,7 @@ class AgentOutputSchema: _output_schema: dict[str, Any] """The JSON schema of the output.""" - strict_json_schema: bool + _strict_json_schema: bool """Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True, as it increases the likelihood of correct JSON input. """ @@ -45,7 +84,7 @@ def __init__(self, output_type: type[Any], strict_json_schema: bool = True): setting this to True, as it increases the likelihood of correct JSON input. """ self.output_type = output_type - self.strict_json_schema = strict_json_schema + self._strict_json_schema = strict_json_schema if output_type is None or output_type is str: self._is_wrapped = False @@ -70,24 +109,35 @@ def __init__(self, output_type: type[Any], strict_json_schema: bool = True): self._type_adapter = TypeAdapter(output_type) self._output_schema = self._type_adapter.json_schema() - if self.strict_json_schema: - self._output_schema = ensure_strict_json_schema(self._output_schema) + if self._strict_json_schema: + try: + self._output_schema = ensure_strict_json_schema(self._output_schema) + except UserError as e: + raise UserError( + "Strict JSON schema is enabled, but the output type is not valid. " + "Either make the output type strict, or pass output_schema_strict=False to " + "your Agent()" + ) from e def is_plain_text(self) -> bool: """Whether the output type is plain text (versus a JSON object).""" return self.output_type is None or self.output_type is str + def is_strict_json_schema(self) -> bool: + """Whether the JSON schema is in strict mode.""" + return self._strict_json_schema + def json_schema(self) -> dict[str, Any]: """The JSON schema of the output type.""" if self.is_plain_text(): raise UserError("Output type is plain text, so no JSON schema is available") return self._output_schema - def validate_json(self, json_str: str, partial: bool = False) -> Any: + def validate_json(self, json_str: str) -> Any: """Validate a JSON string against the output type. Returns the validated object, or raises a `ModelBehaviorError` if the JSON is invalid. 
""" - validated = _json.validate_json(json_str, self._type_adapter, partial) + validated = _json.validate_json(json_str, self._type_adapter, partial=False) if self._is_wrapped: if not isinstance(validated, dict): _error_tracing.attach_error_to_current_span( @@ -113,7 +163,7 @@ def validate_json(self, json_str: str, partial: bool = False) -> Any: return validated[_WRAPPER_DICT_KEY] return validated - def output_type_name(self) -> str: + def name(self) -> str: """The name of the output type.""" return _type_to_str(self.output_type) diff --git a/src/agents/extensions/models/litellm_model.py b/src/agents/extensions/models/litellm_model.py index 0fc277c3..49e2d42d 100644 --- a/src/agents/extensions/models/litellm_model.py +++ b/src/agents/extensions/models/litellm_model.py @@ -1,12 +1,12 @@ from __future__ import annotations -import dataclasses import json import time from collections.abc import AsyncIterator from typing import Any, Literal, cast, overload import litellm.types +from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails from agents.exceptions import ModelBehaviorError @@ -29,7 +29,7 @@ from openai.types.responses import Response from ... import _debug -from ...agent_output import AgentOutputSchema +from ...agent_output import AgentOutputSchemaBase from ...handoffs import Handoff from ...items import ModelResponse, TResponseInputItem, TResponseStreamEvent from ...logger import logger @@ -68,14 +68,14 @@ async def get_response( input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], tracing: ModelTracing, previous_response_id: str | None, ) -> ModelResponse: with generation_span( model=str(self.model), - model_config=dataclasses.asdict(model_settings) + model_config=model_settings.to_json_dict() | {"base_url": str(self.base_url or ""), "model_impl": "litellm"}, disabled=tracing.is_disabled(), ) as span_generation: @@ -108,6 +108,16 @@ async def get_response( input_tokens=response_usage.prompt_tokens, output_tokens=response_usage.completion_tokens, total_tokens=response_usage.total_tokens, + input_tokens_details=InputTokensDetails( + cached_tokens=getattr( + response_usage.prompt_tokens_details, "cached_tokens", 0 + ) or 0 + ), + output_tokens_details=OutputTokensDetails( + reasoning_tokens=getattr( + response_usage.completion_tokens_details, "reasoning_tokens", 0 + ) or 0 + ), ) if response.usage else Usage() @@ -139,7 +149,7 @@ async def stream_response( input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], tracing: ModelTracing, *, @@ -147,7 +157,7 @@ async def stream_response( ) -> AsyncIterator[TResponseStreamEvent]: with generation_span( model=str(self.model), - model_config=dataclasses.asdict(model_settings) + model_config=model_settings.to_json_dict() | {"base_url": str(self.base_url or ""), "model_impl": "litellm"}, disabled=tracing.is_disabled(), ) as span_generation: @@ -186,7 +196,7 @@ async def _fetch_response( input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], span: Span[GenerationSpanData], tracing: ModelTracing, @@ -200,7 +210,7 @@ async def _fetch_response( input: str | 
list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], span: Span[GenerationSpanData], tracing: ModelTracing, @@ -213,7 +223,7 @@ async def _fetch_response( input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], span: Span[GenerationSpanData], tracing: ModelTracing, @@ -270,6 +280,8 @@ async def _fetch_response( extra_kwargs["extra_query"] = model_settings.extra_query if model_settings.metadata: extra_kwargs["metadata"] = model_settings.metadata + if model_settings.extra_body and isinstance(model_settings.extra_body, dict): + extra_kwargs.update(model_settings.extra_body) ret = await litellm.acompletion( model=self.model, @@ -286,7 +298,7 @@ async def _fetch_response( stream=stream, stream_options=stream_options, reasoning_effort=reasoning_effort, - extra_headers=HEADERS, + extra_headers={**HEADERS, **(model_settings.extra_headers or {})}, api_key=self.api_key, base_url=self.base_url, **extra_kwargs, diff --git a/src/agents/extensions/models/litellm_provider.py b/src/agents/extensions/models/litellm_provider.py new file mode 100644 index 00000000..5a2dc166 --- /dev/null +++ b/src/agents/extensions/models/litellm_provider.py @@ -0,0 +1,21 @@ +from ...models.interface import Model, ModelProvider +from .litellm_model import LitellmModel + +DEFAULT_MODEL: str = "gpt-4.1" + + +class LitellmProvider(ModelProvider): + """A ModelProvider that uses LiteLLM to route to any model provider. You can use it via: + ```python + Runner.run(agent, input, run_config=RunConfig(model_provider=LitellmProvider())) + ``` + See supported models here: [litellm models](https://docs.litellm.ai/docs/providers). + + NOTE: API keys must be set via environment variables. If you're using models that require + additional configuration (e.g. Azure API base or version), those must also be set via the + environment variables that LiteLLM expects. If you have more advanced needs, we recommend + copy-pasting this class and making any modifications you need. 
+ """ + + def get_model(self, model_name: str | None) -> Model: + return LitellmModel(model_name or DEFAULT_MODEL) diff --git a/src/agents/extensions/visualization.py b/src/agents/extensions/visualization.py index 5fb35062..888e262c 100644 --- a/src/agents/extensions/visualization.py +++ b/src/agents/extensions/visualization.py @@ -132,6 +132,6 @@ def draw_graph(agent: Agent, filename: Optional[str] = None) -> graphviz.Source: graph = graphviz.Source(dot_code) if filename: - graph.render(filename, format="png") + graph.render(filename, format="png", cleanup=True) return graph diff --git a/src/agents/items.py b/src/agents/items.py index 8fb2b52a..64797ad2 100644 --- a/src/agents/items.py +++ b/src/agents/items.py @@ -18,7 +18,22 @@ ResponseOutputText, ResponseStreamEvent, ) -from openai.types.responses.response_input_item_param import ComputerCallOutput, FunctionCallOutput +from openai.types.responses.response_code_interpreter_tool_call import ( + ResponseCodeInterpreterToolCall, +) +from openai.types.responses.response_input_item_param import ( + ComputerCallOutput, + FunctionCallOutput, + LocalShellCallOutput, + McpApprovalResponse, +) +from openai.types.responses.response_output_item import ( + ImageGenerationCall, + LocalShellCall, + McpApprovalRequest, + McpCall, + McpListTools, +) from openai.types.responses.response_reasoning_item import ResponseReasoningItem from pydantic import BaseModel from typing_extensions import TypeAlias @@ -108,6 +123,10 @@ class HandoffOutputItem(RunItemBase[TResponseInputItem]): ResponseComputerToolCall, ResponseFileSearchToolCall, ResponseFunctionWebSearch, + McpCall, + ResponseCodeInterpreterToolCall, + ImageGenerationCall, + LocalShellCall, ] """A type that represents a tool call item.""" @@ -123,10 +142,12 @@ class ToolCallItem(RunItemBase[ToolCallItemTypes]): @dataclass -class ToolCallOutputItem(RunItemBase[Union[FunctionCallOutput, ComputerCallOutput]]): +class ToolCallOutputItem( + RunItemBase[Union[FunctionCallOutput, ComputerCallOutput, LocalShellCallOutput]] +): """Represents the output of a tool call.""" - raw_item: FunctionCallOutput | ComputerCallOutput + raw_item: FunctionCallOutput | ComputerCallOutput | LocalShellCallOutput """The raw item from the model.""" output: Any @@ -147,6 +168,36 @@ class ReasoningItem(RunItemBase[ResponseReasoningItem]): type: Literal["reasoning_item"] = "reasoning_item" +@dataclass +class MCPListToolsItem(RunItemBase[McpListTools]): + """Represents a call to an MCP server to list tools.""" + + raw_item: McpListTools + """The raw MCP list tools call.""" + + type: Literal["mcp_list_tools_item"] = "mcp_list_tools_item" + + +@dataclass +class MCPApprovalRequestItem(RunItemBase[McpApprovalRequest]): + """Represents a request for MCP approval.""" + + raw_item: McpApprovalRequest + """The raw MCP approval request.""" + + type: Literal["mcp_approval_request_item"] = "mcp_approval_request_item" + + +@dataclass +class MCPApprovalResponseItem(RunItemBase[McpApprovalResponse]): + """Represents a response to an MCP approval request.""" + + raw_item: McpApprovalResponse + """The raw MCP approval response.""" + + type: Literal["mcp_approval_response_item"] = "mcp_approval_response_item" + + RunItem: TypeAlias = Union[ MessageOutputItem, HandoffCallItem, @@ -154,6 +205,9 @@ class ReasoningItem(RunItemBase[ResponseReasoningItem]): ToolCallItem, ToolCallOutputItem, ReasoningItem, + MCPListToolsItem, + MCPApprovalRequestItem, + MCPApprovalResponseItem, ] """An item generated by an agent.""" diff --git a/src/agents/mcp/__init__.py 
b/src/agents/mcp/__init__.py index 1a72a89f..d4eb8fa6 100644 --- a/src/agents/mcp/__init__.py +++ b/src/agents/mcp/__init__.py @@ -5,6 +5,8 @@ MCPServerSseParams, MCPServerStdio, MCPServerStdioParams, + MCPServerStreamableHttp, + MCPServerStreamableHttpParams, ) except ImportError: pass @@ -17,5 +19,7 @@ "MCPServerSseParams", "MCPServerStdio", "MCPServerStdioParams", + "MCPServerStreamableHttp", + "MCPServerStreamableHttpParams", "MCPUtil", ] diff --git a/src/agents/mcp/server.py b/src/agents/mcp/server.py index e70d7ce6..414b517a 100644 --- a/src/agents/mcp/server.py +++ b/src/agents/mcp/server.py @@ -3,13 +3,16 @@ import abc import asyncio from contextlib import AbstractAsyncContextManager, AsyncExitStack +from datetime import timedelta from pathlib import Path from typing import Any, Literal from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client from mcp.client.sse import sse_client -from mcp.types import CallToolResult, JSONRPCMessage +from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client +from mcp.shared.message import SessionMessage +from mcp.types import CallToolResult, InitializeResult from typing_extensions import NotRequired, TypedDict from ..exceptions import UserError @@ -54,7 +57,7 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> C class _MCPServerWithClientSession(MCPServer, abc.ABC): """Base class for MCP servers that use a `ClientSession` to communicate with the server.""" - def __init__(self, cache_tools_list: bool): + def __init__(self, cache_tools_list: bool, client_session_timeout_seconds: float | None): """ Args: cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be @@ -63,11 +66,16 @@ def __init__(self, cache_tools_list: bool): by calling `invalidate_tools_cache()`. You should set this to `True` if you know the server will not change its tools list, because it can drastically improve latency (by avoiding a round-trip to the server every time). + + client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. 
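+
+        A small caching sketch (illustrative; `server` stands for any connected concrete
+        subclass instance, e.g. `MCPServerStdio`):
+        ```python
+        async def refresh_tools(server):
+            tools = await server.list_tools()   # Cached after the first fetch when cache_tools_list=True.
+            server.invalidate_tools_cache()     # Mark the cache dirty...
+            return await server.list_tools()    # ...so this call fetches from the server again.
+        ```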
""" self.session: ClientSession | None = None self.exit_stack: AsyncExitStack = AsyncExitStack() self._cleanup_lock: asyncio.Lock = asyncio.Lock() self.cache_tools_list = cache_tools_list + self.server_initialize_result: InitializeResult | None = None + + self.client_session_timeout_seconds = client_session_timeout_seconds # The cache is always dirty at startup, so that we fetch tools at least once self._cache_dirty = True @@ -78,8 +86,9 @@ def create_streams( self, ) -> AbstractAsyncContextManager[ tuple[ - MemoryObjectReceiveStream[JSONRPCMessage | Exception], - MemoryObjectSendStream[JSONRPCMessage], + MemoryObjectReceiveStream[SessionMessage | Exception], + MemoryObjectSendStream[SessionMessage], + GetSessionIdCallback | None ] ]: """Create the streams for the server.""" @@ -100,9 +109,22 @@ async def connect(self): """Connect to the server.""" try: transport = await self.exit_stack.enter_async_context(self.create_streams()) - read, write = transport - session = await self.exit_stack.enter_async_context(ClientSession(read, write)) - await session.initialize() + # streamablehttp_client returns (read, write, get_session_id) + # sse_client returns (read, write) + + read, write, *_ = transport + + session = await self.exit_stack.enter_async_context( + ClientSession( + read, + write, + timedelta(seconds=self.client_session_timeout_seconds) + if self.client_session_timeout_seconds + else None, + ) + ) + server_result = await session.initialize() + self.server_initialize_result = server_result self.session = session except Exception as e: logger.error(f"Error initializing MCP server: {e}") @@ -137,9 +159,10 @@ async def cleanup(self): async with self._cleanup_lock: try: await self.exit_stack.aclose() - self.session = None except Exception as e: logger.error(f"Error cleaning up server: {e}") + finally: + self.session = None class MCPServerStdioParams(TypedDict): @@ -182,6 +205,7 @@ def __init__( params: MCPServerStdioParams, cache_tools_list: bool = False, name: str | None = None, + client_session_timeout_seconds: float | None = 5, ): """Create a new MCP server based on the stdio transport. @@ -198,8 +222,9 @@ def __init__( improve latency (by avoiding a round-trip to the server every time). name: A readable name for the server. If not provided, we'll create one from the command. + client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. """ - super().__init__(cache_tools_list) + super().__init__(cache_tools_list, client_session_timeout_seconds) self.params = StdioServerParameters( command=params["command"], @@ -216,8 +241,9 @@ def create_streams( self, ) -> AbstractAsyncContextManager[ tuple[ - MemoryObjectReceiveStream[JSONRPCMessage | Exception], - MemoryObjectSendStream[JSONRPCMessage], + MemoryObjectReceiveStream[SessionMessage | Exception], + MemoryObjectSendStream[SessionMessage], + GetSessionIdCallback | None ] ]: """Create the streams for the server.""" @@ -256,6 +282,7 @@ def __init__( params: MCPServerSseParams, cache_tools_list: bool = False, name: str | None = None, + client_session_timeout_seconds: float | None = 5, ): """Create a new MCP server based on the HTTP with SSE transport. @@ -273,8 +300,10 @@ def __init__( name: A readable name for the server. If not provided, we'll create one from the URL. + + client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. 
""" - super().__init__(cache_tools_list) + super().__init__(cache_tools_list, client_session_timeout_seconds) self.params = params self._name = name or f"sse: {self.params['url']}" @@ -283,8 +312,9 @@ def create_streams( self, ) -> AbstractAsyncContextManager[ tuple[ - MemoryObjectReceiveStream[JSONRPCMessage | Exception], - MemoryObjectSendStream[JSONRPCMessage], + MemoryObjectReceiveStream[SessionMessage | Exception], + MemoryObjectSendStream[SessionMessage], + GetSessionIdCallback | None ] ]: """Create the streams for the server.""" @@ -299,3 +329,84 @@ def create_streams( def name(self) -> str: """A readable name for the server.""" return self._name + + +class MCPServerStreamableHttpParams(TypedDict): + """Mirrors the params in`mcp.client.streamable_http.streamablehttp_client`.""" + + url: str + """The URL of the server.""" + + headers: NotRequired[dict[str, str]] + """The headers to send to the server.""" + + timeout: NotRequired[timedelta] + """The timeout for the HTTP request. Defaults to 5 seconds.""" + + sse_read_timeout: NotRequired[timedelta] + """The timeout for the SSE connection, in seconds. Defaults to 5 minutes.""" + + terminate_on_close: NotRequired[bool] + """Terminate on close""" + + +class MCPServerStreamableHttp(_MCPServerWithClientSession): + """MCP server implementation that uses the Streamable HTTP transport. See the [spec] + (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http) + for details. + """ + + def __init__( + self, + params: MCPServerStreamableHttpParams, + cache_tools_list: bool = False, + name: str | None = None, + client_session_timeout_seconds: float | None = 5, + ): + """Create a new MCP server based on the Streamable HTTP transport. + + Args: + params: The params that configure the server. This includes the URL of the server, + the headers to send to the server, the timeout for the HTTP request, and the + timeout for the Streamable HTTP connection and whether we need to + terminate on close. + + cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be + cached and only fetched from the server once. If `False`, the tools list will be + fetched from the server on each call to `list_tools()`. The cache can be + invalidated by calling `invalidate_tools_cache()`. You should set this to `True` + if you know the server will not change its tools list, because it can drastically + improve latency (by avoiding a round-trip to the server every time). + + name: A readable name for the server. If not provided, we'll create one from the + URL. + + client_session_timeout_seconds: the read timeout passed to the MCP ClientSession. 
+ """ + super().__init__(cache_tools_list, client_session_timeout_seconds) + + self.params = params + self._name = name or f"streamable_http: {self.params['url']}" + + def create_streams( + self, + ) -> AbstractAsyncContextManager[ + tuple[ + MemoryObjectReceiveStream[SessionMessage | Exception], + MemoryObjectSendStream[SessionMessage], + GetSessionIdCallback | None + ] + ]: + """Create the streams for the server.""" + return streamablehttp_client( + url=self.params["url"], + headers=self.params.get("headers", None), + timeout=self.params.get("timeout", timedelta(seconds=30)), + sse_read_timeout=self.params.get("sse_read_timeout", timedelta(seconds=60 * 5)), + terminate_on_close=self.params.get("terminate_on_close", True) + ) + + @property + def name(self) -> str: + """A readable name for the server.""" + return self._name diff --git a/src/agents/model_settings.py b/src/agents/model_settings.py index ed9a0131..7b016c98 100644 --- a/src/agents/model_settings.py +++ b/src/agents/model_settings.py @@ -1,10 +1,12 @@ from __future__ import annotations +import dataclasses from dataclasses import dataclass, fields, replace -from typing import Literal +from typing import Any, Literal -from openai._types import Body, Query +from openai._types import Body, Headers, Query from openai.types.shared import Reasoning +from pydantic import BaseModel @dataclass @@ -67,6 +69,10 @@ class ModelSettings: """Additional body fields to provide with the request. Defaults to None if not provided.""" + extra_headers: Headers | None = None + """Additional headers to provide with the request. + Defaults to None if not provided.""" + def resolve(self, override: ModelSettings | None) -> ModelSettings: """Produce a new ModelSettings by overlaying any non-None values from the override on top of this instance.""" @@ -79,3 +85,16 @@ def resolve(self, override: ModelSettings | None) -> ModelSettings: if getattr(override, field.name) is not None } return replace(self, **changes) + + def to_json_dict(self) -> dict[str, Any]: + dataclass_dict = dataclasses.asdict(self) + + json_dict: dict[str, Any] = {} + + for field_name, value in dataclass_dict.items(): + if isinstance(value, BaseModel): + json_dict[field_name] = value.model_dump(mode="json") + else: + json_dict[field_name] = value + + return json_dict diff --git a/src/agents/models/chatcmpl_converter.py b/src/agents/models/chatcmpl_converter.py index 00175a16..1d599e8c 100644 --- a/src/agents/models/chatcmpl_converter.py +++ b/src/agents/models/chatcmpl_converter.py @@ -36,7 +36,7 @@ ) from openai.types.responses.response_input_param import FunctionCallOutput, ItemReference, Message -from ..agent_output import AgentOutputSchema +from ..agent_output import AgentOutputSchemaBase from ..exceptions import AgentsException, UserError from ..handoffs import Handoff from ..items import TResponseInputItem, TResponseOutputItem @@ -67,7 +67,7 @@ def convert_tool_choice( @classmethod def convert_response_format( - cls, final_output_schema: AgentOutputSchema | None + cls, final_output_schema: AgentOutputSchemaBase | None ) -> ResponseFormat | NotGiven: if not final_output_schema or final_output_schema.is_plain_text(): return NOT_GIVEN @@ -76,7 +76,7 @@ def convert_response_format( "type": "json_schema", "json_schema": { "name": "final_output", - "strict": final_output_schema.strict_json_schema, + "strict": final_output_schema.is_strict_json_schema(), "schema": final_output_schema.json_schema(), }, } @@ -234,7 +234,7 @@ def extract_all_content( type="image_url", image_url={ "url": 
casted_image_param["image_url"], - "detail": casted_image_param["detail"], + "detail": casted_image_param.get("detail", "auto"), }, ) ) diff --git a/src/agents/models/chatcmpl_stream_handler.py b/src/agents/models/chatcmpl_stream_handler.py index 32f04acb..d18f5912 100644 --- a/src/agents/models/chatcmpl_stream_handler.py +++ b/src/agents/models/chatcmpl_stream_handler.py @@ -38,6 +38,16 @@ class StreamingState: function_calls: dict[int, ResponseFunctionToolCall] = field(default_factory=dict) +class SequenceNumber: + def __init__(self): + self._sequence_number = 0 + + def get_and_increment(self) -> int: + num = self._sequence_number + self._sequence_number += 1 + return num + + class ChatCmplStreamHandler: @classmethod async def handle_stream( @@ -47,16 +57,18 @@ async def handle_stream( ) -> AsyncIterator[TResponseStreamEvent]: usage: CompletionUsage | None = None state = StreamingState() - + sequence_number = SequenceNumber() async for chunk in stream: if not state.started: state.started = True yield ResponseCreatedEvent( response=response, type="response.created", + sequence_number=sequence_number.get_and_increment(), ) - usage = chunk.usage + # This is always set by the OpenAI API, but not by others e.g. LiteLLM + usage = chunk.usage if hasattr(chunk, "usage") else None if not chunk.choices or not chunk.choices[0].delta: continue @@ -88,6 +100,7 @@ async def handle_stream( item=assistant_item, output_index=0, type="response.output_item.added", + sequence_number=sequence_number.get_and_increment(), ) yield ResponseContentPartAddedEvent( content_index=state.text_content_index_and_output[0], @@ -99,6 +112,7 @@ async def handle_stream( annotations=[], ), type="response.content_part.added", + sequence_number=sequence_number.get_and_increment(), ) # Emit the delta for this segment of content yield ResponseTextDeltaEvent( @@ -107,12 +121,14 @@ async def handle_stream( item_id=FAKE_RESPONSES_ID, output_index=0, type="response.output_text.delta", + sequence_number=sequence_number.get_and_increment(), ) # Accumulate the text into the response part state.text_content_index_and_output[1].text += delta.content # Handle refusals (model declines to answer) - if delta.refusal: + # This is always set by the OpenAI API, but not by others e.g. 
LiteLLM + if hasattr(delta, "refusal") and delta.refusal: if not state.refusal_content_index_and_output: # Initialize a content tracker for streaming refusal text state.refusal_content_index_and_output = ( @@ -132,6 +148,7 @@ async def handle_stream( item=assistant_item, output_index=0, type="response.output_item.added", + sequence_number=sequence_number.get_and_increment(), ) yield ResponseContentPartAddedEvent( content_index=state.refusal_content_index_and_output[0], @@ -143,6 +160,7 @@ async def handle_stream( annotations=[], ), type="response.content_part.added", + sequence_number=sequence_number.get_and_increment(), ) # Emit the delta for this segment of refusal yield ResponseRefusalDeltaEvent( @@ -151,6 +169,7 @@ async def handle_stream( item_id=FAKE_RESPONSES_ID, output_index=0, type="response.refusal.delta", + sequence_number=sequence_number.get_and_increment(), ) # Accumulate the refusal string in the output part state.refusal_content_index_and_output[1].refusal += delta.refusal @@ -188,6 +207,7 @@ async def handle_stream( output_index=0, part=state.text_content_index_and_output[1], type="response.content_part.done", + sequence_number=sequence_number.get_and_increment(), ) if state.refusal_content_index_and_output: @@ -199,6 +219,7 @@ async def handle_stream( output_index=0, part=state.refusal_content_index_and_output[1], type="response.content_part.done", + sequence_number=sequence_number.get_and_increment(), ) # Actually send events for the function calls @@ -214,6 +235,7 @@ async def handle_stream( ), output_index=function_call_starting_index, type="response.output_item.added", + sequence_number=sequence_number.get_and_increment(), ) # Then, yield the args yield ResponseFunctionCallArgumentsDeltaEvent( @@ -221,6 +243,7 @@ async def handle_stream( item_id=FAKE_RESPONSES_ID, output_index=function_call_starting_index, type="response.function_call_arguments.delta", + sequence_number=sequence_number.get_and_increment(), ) # Finally, the ResponseOutputItemDone yield ResponseOutputItemDoneEvent( @@ -233,6 +256,7 @@ async def handle_stream( ), output_index=function_call_starting_index, type="response.output_item.done", + sequence_number=sequence_number.get_and_increment(), ) # Finally, send the Response completed event @@ -256,6 +280,7 @@ async def handle_stream( item=assistant_msg, output_index=0, type="response.output_item.done", + sequence_number=sequence_number.get_and_increment(), ) for function_call in state.function_calls.values(): @@ -287,4 +312,5 @@ async def handle_stream( yield ResponseCompletedEvent( response=final_response, type="response.completed", + sequence_number=sequence_number.get_and_increment(), ) diff --git a/src/agents/models/interface.py b/src/agents/models/interface.py index bcf2c1a6..3a79e564 100644 --- a/src/agents/models/interface.py +++ b/src/agents/models/interface.py @@ -5,7 +5,7 @@ from collections.abc import AsyncIterator from typing import TYPE_CHECKING -from ..agent_output import AgentOutputSchema +from ..agent_output import AgentOutputSchemaBase from ..handoffs import Handoff from ..items import ModelResponse, TResponseInputItem, TResponseStreamEvent from ..tool import Tool @@ -41,7 +41,7 @@ async def get_response( input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], tracing: ModelTracing, *, @@ -72,7 +72,7 @@ def stream_response( input: str | list[TResponseInputItem], model_settings: ModelSettings, 
tools: list[Tool], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], tracing: ModelTracing, *, diff --git a/src/agents/models/multi_provider.py b/src/agents/models/multi_provider.py new file mode 100644 index 00000000..d075ac9b --- /dev/null +++ b/src/agents/models/multi_provider.py @@ -0,0 +1,144 @@ +from __future__ import annotations + +from openai import AsyncOpenAI + +from ..exceptions import UserError +from .interface import Model, ModelProvider +from .openai_provider import OpenAIProvider + + +class MultiProviderMap: + """A map of model name prefixes to ModelProviders.""" + + def __init__(self): + self._mapping: dict[str, ModelProvider] = {} + + def has_prefix(self, prefix: str) -> bool: + """Returns True if the given prefix is in the mapping.""" + return prefix in self._mapping + + def get_mapping(self) -> dict[str, ModelProvider]: + """Returns a copy of the current prefix -> ModelProvider mapping.""" + return self._mapping.copy() + + def set_mapping(self, mapping: dict[str, ModelProvider]): + """Overwrites the current mapping with a new one.""" + self._mapping = mapping + + def get_provider(self, prefix: str) -> ModelProvider | None: + """Returns the ModelProvider for the given prefix. + + Args: + prefix: The prefix of the model name e.g. "openai" or "my_prefix". + """ + return self._mapping.get(prefix) + + def add_provider(self, prefix: str, provider: ModelProvider): + """Adds a new prefix -> ModelProvider mapping. + + Args: + prefix: The prefix of the model name e.g. "openai" or "my_prefix". + provider: The ModelProvider to use for the given prefix. + """ + self._mapping[prefix] = provider + + def remove_provider(self, prefix: str): + """Removes the mapping for the given prefix. + + Args: + prefix: The prefix of the model name e.g. "openai" or "my_prefix". + """ + del self._mapping[prefix] + + +class MultiProvider(ModelProvider): + """This ModelProvider maps to a Model based on the prefix of the model name. By default, the + mapping is: + - "openai/" prefix or no prefix -> OpenAIProvider. e.g. "openai/gpt-4.1", "gpt-4.1" + - "litellm/" prefix -> LitellmProvider. e.g. "litellm/openai/gpt-4.1" + + You can override or customize this mapping. + """ + + def __init__( + self, + *, + provider_map: MultiProviderMap | None = None, + openai_api_key: str | None = None, + openai_base_url: str | None = None, + openai_client: AsyncOpenAI | None = None, + openai_organization: str | None = None, + openai_project: str | None = None, + openai_use_responses: bool | None = None, + ) -> None: + """Create a new OpenAI provider. + + Args: + provider_map: A MultiProviderMap that maps prefixes to ModelProviders. If not provided, + we will use a default mapping. See the documentation for this class to see the + default mapping. + openai_api_key: The API key to use for the OpenAI provider. If not provided, we will use + the default API key. + openai_base_url: The base URL to use for the OpenAI provider. If not provided, we will + use the default base URL. + openai_client: An optional OpenAI client to use. If not provided, we will create a new + OpenAI client using the api_key and base_url. + openai_organization: The organization to use for the OpenAI provider. + openai_project: The project to use for the OpenAI provider. + openai_use_responses: Whether to use the OpenAI responses API. 
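+
+        A routing sketch (illustrative; `MyCustomProvider` is a placeholder for any
+        `ModelProvider` implementation, not part of the SDK):
+        ```python
+        provider_map = MultiProviderMap()
+        provider_map.add_provider("my_prefix", MyCustomProvider())
+
+        provider = MultiProvider(provider_map=provider_map)
+        custom = provider.get_model("my_prefix/some-model")  # Routed to MyCustomProvider.
+        default = provider.get_model("gpt-4.1")              # No prefix -> OpenAIProvider.
+        ```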
+ """ + self.provider_map = provider_map + self.openai_provider = OpenAIProvider( + api_key=openai_api_key, + base_url=openai_base_url, + openai_client=openai_client, + organization=openai_organization, + project=openai_project, + use_responses=openai_use_responses, + ) + + self._fallback_providers: dict[str, ModelProvider] = {} + + def _get_prefix_and_model_name(self, model_name: str | None) -> tuple[str | None, str | None]: + if model_name is None: + return None, None + elif "/" in model_name: + prefix, model_name = model_name.split("/", 1) + return prefix, model_name + else: + return None, model_name + + def _create_fallback_provider(self, prefix: str) -> ModelProvider: + if prefix == "litellm": + from ..extensions.models.litellm_provider import LitellmProvider + + return LitellmProvider() + else: + raise UserError(f"Unknown prefix: {prefix}") + + def _get_fallback_provider(self, prefix: str | None) -> ModelProvider: + if prefix is None or prefix == "openai": + return self.openai_provider + elif prefix in self._fallback_providers: + return self._fallback_providers[prefix] + else: + self._fallback_providers[prefix] = self._create_fallback_provider(prefix) + return self._fallback_providers[prefix] + + def get_model(self, model_name: str | None) -> Model: + """Returns a Model based on the model name. The model name can have a prefix, ending with + a "/", which will be used to look up the ModelProvider. If there is no prefix, we will use + the OpenAI provider. + + Args: + model_name: The name of the model to get. + + Returns: + A Model. + """ + prefix, model_name = self._get_prefix_and_model_name(model_name) + + if prefix and self.provider_map and (provider := self.provider_map.get_provider(prefix)): + return provider.get_model(model_name) + else: + return self._get_fallback_provider(prefix).get_model(model_name) diff --git a/src/agents/models/openai_chatcompletions.py b/src/agents/models/openai_chatcompletions.py index 9989c1ee..4465ff2f 100644 --- a/src/agents/models/openai_chatcompletions.py +++ b/src/agents/models/openai_chatcompletions.py @@ -1,6 +1,5 @@ from __future__ import annotations -import dataclasses import json import time from collections.abc import AsyncIterator @@ -10,9 +9,10 @@ from openai.types import ChatModel from openai.types.chat import ChatCompletion, ChatCompletionChunk from openai.types.responses import Response +from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails from .. 
import _debug -from ..agent_output import AgentOutputSchema +from ..agent_output import AgentOutputSchemaBase from ..handoffs import Handoff from ..items import ModelResponse, TResponseInputItem, TResponseStreamEvent from ..logger import logger @@ -49,15 +49,14 @@ async def get_response( input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], tracing: ModelTracing, previous_response_id: str | None, ) -> ModelResponse: with generation_span( model=str(self.model), - model_config=dataclasses.asdict(model_settings) - | {"base_url": str(self._client.base_url)}, + model_config=model_settings.to_json_dict() | {"base_url": str(self._client.base_url)}, disabled=tracing.is_disabled(), ) as span_generation: response = await self._fetch_response( @@ -85,6 +84,18 @@ async def get_response( input_tokens=response.usage.prompt_tokens, output_tokens=response.usage.completion_tokens, total_tokens=response.usage.total_tokens, + input_tokens_details=InputTokensDetails( + cached_tokens=getattr( + response.usage.prompt_tokens_details, "cached_tokens", 0 + ) + or 0, + ), + output_tokens_details=OutputTokensDetails( + reasoning_tokens=getattr( + response.usage.completion_tokens_details, "reasoning_tokens", 0 + ) + or 0, + ), ) if response.usage else Usage() @@ -110,7 +121,7 @@ async def stream_response( input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], tracing: ModelTracing, *, @@ -121,8 +132,7 @@ async def stream_response( """ with generation_span( model=str(self.model), - model_config=dataclasses.asdict(model_settings) - | {"base_url": str(self._client.base_url)}, + model_config=model_settings.to_json_dict() | {"base_url": str(self._client.base_url)}, disabled=tracing.is_disabled(), ) as span_generation: response, stream = await self._fetch_response( @@ -160,7 +170,7 @@ async def _fetch_response( input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], span: Span[GenerationSpanData], tracing: ModelTracing, @@ -174,7 +184,7 @@ async def _fetch_response( input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], span: Span[GenerationSpanData], tracing: ModelTracing, @@ -187,7 +197,7 @@ async def _fetch_response( input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], span: Span[GenerationSpanData], tracing: ModelTracing, @@ -255,7 +265,7 @@ async def _fetch_response( stream_options=self._non_null_or_not_given(stream_options), store=self._non_null_or_not_given(store), reasoning_effort=self._non_null_or_not_given(reasoning_effort), - extra_headers=HEADERS, + extra_headers={**HEADERS, **(model_settings.extra_headers or {})}, extra_query=model_settings.extra_query, extra_body=model_settings.extra_body, metadata=self._non_null_or_not_given(model_settings.metadata), diff --git a/src/agents/models/openai_responses.py b/src/agents/models/openai_responses.py index ab4617d4..86c8e69c 100644 --- 
a/src/agents/models/openai_responses.py +++ b/src/agents/models/openai_responses.py @@ -10,6 +10,7 @@ from openai.types.responses import ( Response, ResponseCompletedEvent, + ResponseIncludable, ResponseStreamEvent, ResponseTextConfigParam, ToolParam, @@ -18,12 +19,22 @@ ) from .. import _debug -from ..agent_output import AgentOutputSchema +from ..agent_output import AgentOutputSchemaBase from ..exceptions import UserError from ..handoffs import Handoff from ..items import ItemHelpers, ModelResponse, TResponseInputItem from ..logger import logger -from ..tool import ComputerTool, FileSearchTool, FunctionTool, Tool, WebSearchTool +from ..tool import ( + CodeInterpreterTool, + ComputerTool, + FileSearchTool, + FunctionTool, + HostedMCPTool, + ImageGenerationTool, + LocalShellTool, + Tool, + WebSearchTool, +) from ..tracing import SpanError, response_span from ..usage import Usage from ..version import __version__ @@ -36,13 +47,6 @@ _USER_AGENT = f"Agents/Python {__version__}" _HEADERS = {"User-Agent": _USER_AGENT} -# From the Responses API -IncludeLiteral = Literal[ - "file_search_call.results", - "message.input_image.image_url", - "computer_call_output.output.image_url", -] - class OpenAIResponsesModel(Model): """ @@ -66,7 +70,7 @@ async def get_response( input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], tracing: ModelTracing, previous_response_id: str | None, @@ -98,6 +102,8 @@ async def get_response( input_tokens=response.usage.input_tokens, output_tokens=response.usage.output_tokens, total_tokens=response.usage.total_tokens, + input_tokens_details=response.usage.input_tokens_details, + output_tokens_details=response.usage.output_tokens_details, ) if response.usage else Usage() @@ -131,7 +137,7 @@ async def stream_response( input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], tracing: ModelTracing, previous_response_id: str | None, @@ -182,7 +188,7 @@ async def _fetch_response( input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], previous_response_id: str | None, stream: Literal[True], @@ -195,7 +201,7 @@ async def _fetch_response( input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], previous_response_id: str | None, stream: Literal[False], @@ -207,7 +213,7 @@ async def _fetch_response( input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], previous_response_id: str | None, stream: Literal[True] | Literal[False] = False, @@ -253,7 +259,7 @@ async def _fetch_response( tool_choice=tool_choice, parallel_tool_calls=parallel_tool_calls, stream=stream, - extra_headers=_HEADERS, + extra_headers={**_HEADERS, **(model_settings.extra_headers or {})}, extra_query=model_settings.extra_query, extra_body=model_settings.extra_body, text=response_format, @@ -271,7 +277,7 @@ def _get_client(self) -> AsyncOpenAI: @dataclass class ConvertedTools: tools: 
list[ToolParam] - includes: list[IncludeLiteral] + includes: list[ResponseIncludable] class Converter: @@ -299,6 +305,18 @@ def convert_tool_choice( return { "type": "computer_use_preview", } + elif tool_choice == "image_generation": + return { + "type": "image_generation", + } + elif tool_choice == "code_interpreter": + return { + "type": "code_interpreter", + } + elif tool_choice == "mcp": + return { + "type": "mcp", + } else: return { "type": "function", @@ -307,7 +325,7 @@ def convert_tool_choice( @classmethod def get_response_format( - cls, output_schema: AgentOutputSchema | None + cls, output_schema: AgentOutputSchemaBase | None ) -> ResponseTextConfigParam | NotGiven: if output_schema is None or output_schema.is_plain_text(): return NOT_GIVEN @@ -317,7 +335,7 @@ def get_response_format( "type": "json_schema", "name": "final_output", "schema": output_schema.json_schema(), - "strict": output_schema.strict_json_schema, + "strict": output_schema.is_strict_json_schema(), } } @@ -328,7 +346,7 @@ def convert_tools( handoffs: list[Handoff[Any]], ) -> ConvertedTools: converted_tools: list[ToolParam] = [] - includes: list[IncludeLiteral] = [] + includes: list[ResponseIncludable] = [] computer_tools = [tool for tool in tools if isinstance(tool, ComputerTool)] if len(computer_tools) > 1: @@ -346,7 +364,7 @@ def convert_tools( return ConvertedTools(tools=converted_tools, includes=includes) @classmethod - def _convert_tool(cls, tool: Tool) -> tuple[ToolParam, IncludeLiteral | None]: + def _convert_tool(cls, tool: Tool) -> tuple[ToolParam, ResponseIncludable | None]: """Returns converted tool and includes""" if isinstance(tool, FunctionTool): @@ -357,7 +375,7 @@ def _convert_tool(cls, tool: Tool) -> tuple[ToolParam, IncludeLiteral | None]: "type": "function", "description": tool.description, } - includes: IncludeLiteral | None = None + includes: ResponseIncludable | None = None elif isinstance(tool, WebSearchTool): ws: WebSearchToolParam = { "type": "web_search_preview", @@ -387,7 +405,20 @@ def _convert_tool(cls, tool: Tool) -> tuple[ToolParam, IncludeLiteral | None]: "display_height": tool.computer.dimensions[1], } includes = None - + elif isinstance(tool, HostedMCPTool): + converted_tool = tool.tool_config + includes = None + elif isinstance(tool, ImageGenerationTool): + converted_tool = tool.tool_config + includes = None + elif isinstance(tool, CodeInterpreterTool): + converted_tool = tool.tool_config + includes = None + elif isinstance(tool, LocalShellTool): + converted_tool = { + "type": "local_shell", + } + includes = None else: raise UserError(f"Unknown tool type: {type(tool)}, tool") diff --git a/src/agents/result.py b/src/agents/result.py index a2a6cc4a..243db155 100644 --- a/src/agents/result.py +++ b/src/agents/result.py @@ -10,11 +10,12 @@ from ._run_impl import QueueCompleteSentinel from .agent import Agent -from .agent_output import AgentOutputSchema +from .agent_output import AgentOutputSchemaBase from .exceptions import InputGuardrailTripwireTriggered, MaxTurnsExceeded from .guardrail import InputGuardrailResult, OutputGuardrailResult from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem from .logger import logger +from .run_context import RunContextWrapper from .stream_events import StreamEvent from .tracing import Trace from .util._pretty_print import pretty_print_result, pretty_print_run_result_streaming @@ -50,6 +51,9 @@ class RunResultBase(abc.ABC): output_guardrail_results: list[OutputGuardrailResult] """Guardrail results for the final output of the 
agent.""" + context_wrapper: RunContextWrapper[Any] + """The context wrapper for the agent run.""" + @property @abc.abstractmethod def last_agent(self) -> Agent[Any]: @@ -124,9 +128,9 @@ class RunResultStreaming(RunResultBase): final_output: Any """The final output of the agent. This is None until the agent has finished running.""" - _current_agent_output_schema: AgentOutputSchema | None = field(repr=False) + _current_agent_output_schema: AgentOutputSchemaBase | None = field(repr=False) - _trace: Trace | None = field(repr=False) + trace: Trace | None = field(repr=False) is_complete: bool = False """Whether the agent has finished running.""" @@ -152,6 +156,18 @@ def last_agent(self) -> Agent[Any]: """ return self.current_agent + def cancel(self) -> None: + """Cancels the streaming run, stopping all background tasks and marking the run as + complete.""" + self._cleanup_tasks() # Cancel all running tasks + self.is_complete = True # Mark the run as complete to stop event streaming + + # Optionally, clear the event queue to prevent processing stale events + while not self._event_queue.empty(): + self._event_queue.get_nowait() + while not self._input_guardrail_queue.empty(): + self._input_guardrail_queue.get_nowait() + async def stream_events(self) -> AsyncIterator[StreamEvent]: """Stream deltas for new items as they are generated. We're using the types from the OpenAI Responses API, so these are semantic events: each event has a `type` field that @@ -185,9 +201,6 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]: yield item self._event_queue.task_done() - if self._trace: - self._trace.finish(reset_current=True) - self._cleanup_tasks() if self._stored_exception: diff --git a/src/agents/run.py b/src/agents/run.py index e2b0dbce..b196c3bf 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -19,7 +19,7 @@ get_model_tracing_impl, ) from .agent import Agent -from .agent_output import AgentOutputSchema +from .agent_output import AgentOutputSchema, AgentOutputSchemaBase from .exceptions import ( AgentsException, InputGuardrailTripwireTriggered, @@ -34,7 +34,7 @@ from .logger import logger from .model_settings import ModelSettings from .models.interface import Model, ModelProvider -from .models.openai_provider import OpenAIProvider +from .models.multi_provider import MultiProvider from .result import RunResult, RunResultStreaming from .run_context import RunContextWrapper, TContext from .stream_events import AgentUpdatedStreamEvent, RawResponsesStreamEvent @@ -56,7 +56,7 @@ class RunConfig: agent. The model_provider passed in below must be able to resolve this model name. """ - model_provider: ModelProvider = field(default_factory=OpenAIProvider) + model_provider: ModelProvider = field(default_factory=MultiProvider) """The model provider to use when looking up string model names. 
Defaults to OpenAI.""" model_settings: ModelSettings | None = None @@ -185,7 +185,7 @@ async def run( if current_span is None: handoff_names = [h.agent_name for h in cls._get_handoffs(current_agent)] if output_schema := cls._get_output_schema(current_agent): - output_type_name = output_schema.output_type_name() + output_type_name = output_schema.name() else: output_type_name = "str" @@ -270,6 +270,7 @@ async def run( _last_agent=current_agent, input_guardrail_results=input_guardrail_results, output_guardrail_results=output_guardrail_results, + context_wrapper=context_wrapper, ) elif isinstance(turn_result.next_step, NextStepHandoff): current_agent = cast(Agent[TContext], turn_result.next_step.new_agent) @@ -404,10 +405,6 @@ def run_streamed( disabled=run_config.tracing_disabled, ) ) - # Need to start the trace here, because the current trace contextvar is captured at - # asyncio.create_task time - if new_trace: - new_trace.start(mark_as_current=True) output_schema = cls._get_output_schema(starting_agent) context_wrapper: RunContextWrapper[TContext] = RunContextWrapper( @@ -426,7 +423,8 @@ def run_streamed( input_guardrail_results=[], output_guardrail_results=[], _current_agent_output_schema=output_schema, - _trace=new_trace, + trace=new_trace, + context_wrapper=context_wrapper, ) # Kick off the actual agent loop in the background and return the streamed result object. @@ -499,6 +497,9 @@ async def _run_streamed_impl( run_config: RunConfig, previous_response_id: str | None, ): + if streamed_result.trace: + streamed_result.trace.start(mark_as_current=True) + current_span: Span[AgentSpanData] | None = None current_agent = starting_agent current_turn = 0 @@ -517,7 +518,7 @@ async def _run_streamed_impl( if current_span is None: handoff_names = [h.agent_name for h in cls._get_handoffs(current_agent)] if output_schema := cls._get_output_schema(current_agent): - output_type_name = output_schema.output_type_name() + output_type_name = output_schema.name() else: output_type_name = "str" @@ -625,6 +626,8 @@ async def _run_streamed_impl( finally: if current_span: current_span.finish(reset_current=True) + if streamed_result.trace: + streamed_result.trace.finish(reset_current=True) @classmethod async def _run_single_turn_streamed( @@ -686,6 +689,8 @@ async def _run_single_turn_streamed( input_tokens=event.response.usage.input_tokens, output_tokens=event.response.usage.output_tokens, total_tokens=event.response.usage.total_tokens, + input_tokens_details=event.response.usage.input_tokens_details, + output_tokens_details=event.response.usage.output_tokens_details, ) if event.response.usage else Usage() @@ -695,6 +700,7 @@ async def _run_single_turn_streamed( usage=usage, response_id=event.response.id, ) + context_wrapper.usage.add(usage) streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event)) @@ -789,7 +795,7 @@ async def _get_single_step_result_from_response( original_input: str | list[TResponseInputItem], pre_step_items: list[RunItem], new_response: ModelResponse, - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], hooks: RunHooks[TContext], context_wrapper: RunContextWrapper[TContext], @@ -900,7 +906,7 @@ async def _get_new_response( agent: Agent[TContext], system_prompt: str | None, input: list[TResponseInputItem], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, all_tools: list[Tool], handoffs: list[Handoff], context_wrapper: RunContextWrapper[TContext], @@ -930,9 +936,11 @@ 
async def _get_new_response( return new_response @classmethod - def _get_output_schema(cls, agent: Agent[Any]) -> AgentOutputSchema | None: + def _get_output_schema(cls, agent: Agent[Any]) -> AgentOutputSchemaBase | None: if agent.output_type is None or agent.output_type is str: return None + elif isinstance(agent.output_type, AgentOutputSchemaBase): + return agent.output_type return AgentOutputSchema(agent.output_type) diff --git a/src/agents/stream_events.py b/src/agents/stream_events.py index bd37d11f..111d0b95 100644 --- a/src/agents/stream_events.py +++ b/src/agents/stream_events.py @@ -35,6 +35,8 @@ class RunItemStreamEvent: "tool_called", "tool_output", "reasoning_item_created", + "mcp_approval_requested", + "mcp_list_tools", ] """The name of the event.""" diff --git a/src/agents/tool.py b/src/agents/tool.py index c1c16242..fd5a21c8 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -7,9 +7,11 @@ from typing import Any, Callable, Literal, Union, overload from openai.types.responses.file_search_tool_param import Filters, RankingOptions +from openai.types.responses.response_output_item import LocalShellCall, McpApprovalRequest +from openai.types.responses.tool_param import CodeInterpreter, ImageGeneration, Mcp from openai.types.responses.web_search_tool_param import UserLocation from pydantic import ValidationError -from typing_extensions import Concatenate, ParamSpec +from typing_extensions import Concatenate, NotRequired, ParamSpec, TypedDict from . import _debug from .computer import AsyncComputer, Computer @@ -130,7 +132,115 @@ def name(self): return "computer_use_preview" -Tool = Union[FunctionTool, FileSearchTool, WebSearchTool, ComputerTool] +@dataclass +class MCPToolApprovalRequest: + """A request to approve a tool call.""" + + ctx_wrapper: RunContextWrapper[Any] + """The run context.""" + + data: McpApprovalRequest + """The data from the MCP tool approval request.""" + + +class MCPToolApprovalFunctionResult(TypedDict): + """The result of an MCP tool approval function.""" + + approve: bool + """Whether to approve the tool call.""" + + reason: NotRequired[str] + """An optional reason, if rejected.""" + + +MCPToolApprovalFunction = Callable[ + [MCPToolApprovalRequest], MaybeAwaitable[MCPToolApprovalFunctionResult] +] +"""A function that approves or rejects a tool call.""" + + +@dataclass +class HostedMCPTool: + """A tool that allows the LLM to use a remote MCP server. The LLM will automatically list and + call tools, without requiring a a round trip back to your code. + If you want to run MCP servers locally via stdio, in a VPC or other non-publicly-accessible + environment, or you just prefer to run tool calls locally, then you can instead use the servers + in `agents.mcp` and pass `Agent(mcp_servers=[...])` to the agent.""" + + tool_config: Mcp + """The MCP tool config, which includes the server URL and other settings.""" + + on_approval_request: MCPToolApprovalFunction | None = None + """An optional function that will be called if approval is requested for an MCP tool. 
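+
+    For example, an auto-approving callback might look like this (sketch only; inspect
+    `request.data` and apply your own policy in real code):
+    ```python
+    def approve_all(request: MCPToolApprovalRequest) -> MCPToolApprovalFunctionResult:
+        return {"approve": True}
+
+    tool = HostedMCPTool(
+        tool_config={
+            "type": "mcp",
+            "server_label": "my_server",
+            "server_url": "https://example.com/mcp",
+            "require_approval": "always",
+        },
+        on_approval_request=approve_all,
+    )
+    ```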
If not + provided, you will need to manually add approvals/rejections to the input and call + `Runner.run(...)` again.""" + + @property + def name(self): + return "hosted_mcp" + + +@dataclass +class CodeInterpreterTool: + """A tool that allows the LLM to execute code in a sandboxed environment.""" + + tool_config: CodeInterpreter + """The tool config, which includes the container and other settings.""" + + @property + def name(self): + return "code_interpreter" + + +@dataclass +class ImageGenerationTool: + """A tool that allows the LLM to generate images.""" + + tool_config: ImageGeneration + """The tool config, which image generation settings.""" + + @property + def name(self): + return "image_generation" + + +@dataclass +class LocalShellCommandRequest: + """A request to execute a command on a shell.""" + + ctx_wrapper: RunContextWrapper[Any] + """The run context.""" + + data: LocalShellCall + """The data from the local shell tool call.""" + + +LocalShellExecutor = Callable[[LocalShellCommandRequest], MaybeAwaitable[str]] +"""A function that executes a command on a shell.""" + + +@dataclass +class LocalShellTool: + """A tool that allows the LLM to execute commands on a shell.""" + + executor: LocalShellExecutor + """A function that executes a command on a shell.""" + + @property + def name(self): + return "local_shell" + + +Tool = Union[ + FunctionTool, + FileSearchTool, + WebSearchTool, + ComputerTool, + HostedMCPTool, + LocalShellTool, + ImageGenerationTool, + CodeInterpreterTool, +] """A tool that can be used in an agent.""" diff --git a/src/agents/tracing/processors.py b/src/agents/tracing/processors.py index f929d05d..2913b11a 100644 --- a/src/agents/tracing/processors.py +++ b/src/agents/tracing/processors.py @@ -102,6 +102,12 @@ def export(self, items: list[Trace | Span[Any]]) -> None: "OpenAI-Beta": "traces=v1", } + if self.organization: + headers["OpenAI-Organization"] = self.organization + + if self.project: + headers["OpenAI-Project"] = self.project + # Exponential backoff loop attempt = 0 delay = self.base_delay diff --git a/src/agents/tracing/span_data.py b/src/agents/tracing/span_data.py index 1a450280..cb3e8491 100644 --- a/src/agents/tracing/span_data.py +++ b/src/agents/tracing/span_data.py @@ -338,7 +338,7 @@ def __init__( @property def type(self) -> str: - return "speech-group" + return "speech_group" def export(self) -> dict[str, Any]: return { diff --git a/src/agents/usage.py b/src/agents/usage.py index 23d989b4..843f6293 100644 --- a/src/agents/usage.py +++ b/src/agents/usage.py @@ -1,4 +1,6 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field + +from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails @dataclass @@ -9,9 +11,18 @@ class Usage: input_tokens: int = 0 """Total input tokens sent, across all requests.""" + input_tokens_details: InputTokensDetails = field( + default_factory=lambda: InputTokensDetails(cached_tokens=0) + ) + """Details about the input tokens, matching responses API usage details.""" output_tokens: int = 0 """Total output tokens received, across all requests.""" + output_tokens_details: OutputTokensDetails = field( + default_factory=lambda: OutputTokensDetails(reasoning_tokens=0) + ) + """Details about the output tokens, matching responses API usage details.""" + total_tokens: int = 0 """Total tokens sent and received, across all requests.""" @@ -20,3 +31,12 @@ def add(self, other: "Usage") -> None: self.input_tokens += other.input_tokens if other.input_tokens else 0 
self.output_tokens += other.output_tokens if other.output_tokens else 0 self.total_tokens += other.total_tokens if other.total_tokens else 0 + self.input_tokens_details = InputTokensDetails( + cached_tokens=self.input_tokens_details.cached_tokens + + other.input_tokens_details.cached_tokens + ) + + self.output_tokens_details = OutputTokensDetails( + reasoning_tokens=self.output_tokens_details.reasoning_tokens + + other.output_tokens_details.reasoning_tokens + ) diff --git a/src/agents/voice/__init__.py b/src/agents/voice/__init__.py index 499c064c..e11ee446 100644 --- a/src/agents/voice/__init__.py +++ b/src/agents/voice/__init__.py @@ -7,6 +7,7 @@ STTModelSettings, TTSModel, TTSModelSettings, + TTSVoice, VoiceModelProvider, ) from .models.openai_model_provider import OpenAIVoiceModelProvider @@ -30,6 +31,7 @@ "STTModelSettings", "TTSModel", "TTSModelSettings", + "TTSVoice", "VoiceModelProvider", "StreamedAudioResult", "SingleAgentVoiceWorkflow", diff --git a/src/agents/voice/model.py b/src/agents/voice/model.py index 220d4b48..c36a4de7 100644 --- a/src/agents/voice/model.py +++ b/src/agents/voice/model.py @@ -14,14 +14,13 @@ ) DEFAULT_TTS_BUFFER_SIZE = 120 +TTSVoice = Literal["alloy", "ash", "coral", "echo", "fable", "onyx", "nova", "sage", "shimmer"] +"""Exportable type for the TTSModelSettings voice enum""" @dataclass class TTSModelSettings: """Settings for a TTS model.""" - - voice: ( - Literal["alloy", "ash", "coral", "echo", "fable", "onyx", "nova", "sage", "shimmer"] | None - ) = None + voice: TTSVoice | None = None """ The voice to use for the TTS model. If not provided, the default voice for the respective model will be used. diff --git a/tests/fake_model.py b/tests/fake_model.py index 52d3a3b2..9f0c83a2 100644 --- a/tests/fake_model.py +++ b/tests/fake_model.py @@ -3,9 +3,10 @@ from collections.abc import AsyncIterator from typing import Any -from openai.types.responses import Response, ResponseCompletedEvent +from openai.types.responses import Response, ResponseCompletedEvent, ResponseUsage +from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails -from agents.agent_output import AgentOutputSchema +from agents.agent_output import AgentOutputSchemaBase from agents.handoffs import Handoff from agents.items import ( ModelResponse, @@ -33,6 +34,10 @@ def __init__( ) self.tracing_enabled = tracing_enabled self.last_turn_args: dict[str, Any] = {} + self.hardcoded_usage: Usage | None = None + + def set_hardcoded_usage(self, usage: Usage): + self.hardcoded_usage = usage def set_next_output(self, output: list[TResponseOutputItem] | Exception): self.turn_outputs.append(output) @@ -51,7 +56,7 @@ async def get_response( input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], tracing: ModelTracing, *, @@ -83,7 +88,7 @@ async def get_response( return ModelResponse( output=output, - usage=Usage(), + usage=self.hardcoded_usage or Usage(), response_id=None, ) @@ -93,7 +98,7 @@ async def stream_response( input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], tracing: ModelTracing, *, @@ -123,11 +128,16 @@ async def stream_response( yield ResponseCompletedEvent( type="response.completed", - response=get_response_obj(output), + response=get_response_obj(output, 
usage=self.hardcoded_usage), + sequence_number=0, ) -def get_response_obj(output: list[TResponseOutputItem], response_id: str | None = None) -> Response: +def get_response_obj( + output: list[TResponseOutputItem], + response_id: str | None = None, + usage: Usage | None = None, +) -> Response: return Response( id=response_id or "123", created_at=123, @@ -138,4 +148,11 @@ def get_response_obj(output: list[TResponseOutputItem], response_id: str | None tools=[], top_p=None, parallel_tool_calls=False, + usage=ResponseUsage( + input_tokens=usage.input_tokens if usage else 0, + output_tokens=usage.output_tokens if usage else 0, + total_tokens=usage.total_tokens if usage else 0, + input_tokens_details=InputTokensDetails(cached_tokens=0), + output_tokens_details=OutputTokensDetails(reasoning_tokens=0), + ), ) diff --git a/tests/fastapi/__init__.py b/tests/fastapi/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/fastapi/streaming_app.py b/tests/fastapi/streaming_app.py new file mode 100644 index 00000000..b93ccf3f --- /dev/null +++ b/tests/fastapi/streaming_app.py @@ -0,0 +1,30 @@ +from collections.abc import AsyncIterator + +from fastapi import FastAPI +from starlette.responses import StreamingResponse + +from agents import Agent, Runner, RunResultStreaming + +agent = Agent( + name="Assistant", + instructions="You are a helpful assistant.", +) + + +app = FastAPI() + + +@app.post("/stream") +async def stream(): + result = Runner.run_streamed(agent, input="Tell me a joke") + stream_handler = StreamHandler(result) + return StreamingResponse(stream_handler.stream_events(), media_type="application/x-ndjson") + + +class StreamHandler: + def __init__(self, result: RunResultStreaming): + self.result = result + + async def stream_events(self) -> AsyncIterator[str]: + async for event in self.result.stream_events(): + yield f"{event.type}\n\n" diff --git a/tests/fastapi/test_streaming_context.py b/tests/fastapi/test_streaming_context.py new file mode 100644 index 00000000..ee13045e --- /dev/null +++ b/tests/fastapi/test_streaming_context.py @@ -0,0 +1,29 @@ +import pytest +from httpx import ASGITransport, AsyncClient +from inline_snapshot import snapshot + +from ..fake_model import FakeModel +from ..test_responses import get_text_message +from .streaming_app import agent, app + + +@pytest.mark.asyncio +async def test_streaming_context(): + """This ensures that FastAPI streaming works. The context for this test is that the Runner + method was called in one async context, and the streaming was ended in another context, + leading to a tracing error because the context was closed in the wrong context. This test + ensures that this actually works. 
+ """ + model = FakeModel() + agent.model = model + model.set_next_output([get_text_message("done")]) + + transport = ASGITransport(app) + async with AsyncClient(transport=transport, base_url="http://test") as ac: + async with ac.stream("POST", "/stream") as r: + assert r.status_code == 200 + body = (await r.aread()).decode("utf-8") + lines = [line for line in body.splitlines() if line] + assert lines == snapshot( + ["agent_updated_stream_event", "raw_response_event", "run_item_stream_event"] + ) diff --git a/tests/mcp/test_server_errors.py b/tests/mcp/test_server_errors.py index bdca7ce6..fbd8db17 100644 --- a/tests/mcp/test_server_errors.py +++ b/tests/mcp/test_server_errors.py @@ -6,7 +6,7 @@ class CrashingClientSessionServer(_MCPServerWithClientSession): def __init__(self): - super().__init__(cache_tools_list=False) + super().__init__(cache_tools_list=False, client_session_timeout_seconds=5) self.cleanup_called = False def create_streams(self): diff --git a/tests/model_settings/test_serialization.py b/tests/model_settings/test_serialization.py new file mode 100644 index 00000000..d76a58d1 --- /dev/null +++ b/tests/model_settings/test_serialization.py @@ -0,0 +1,59 @@ +import json +from dataclasses import fields + +from openai.types.shared import Reasoning + +from agents.model_settings import ModelSettings + + +def verify_serialization(model_settings: ModelSettings) -> None: + """Verify that ModelSettings can be serialized to a JSON string.""" + json_dict = model_settings.to_json_dict() + json_string = json.dumps(json_dict) + assert json_string is not None + + +def test_basic_serialization() -> None: + """Tests whether ModelSettings can be serialized to a JSON string.""" + + # First, lets create a ModelSettings instance + model_settings = ModelSettings( + temperature=0.5, + top_p=0.9, + max_tokens=100, + ) + + # Now, lets serialize the ModelSettings instance to a JSON string + verify_serialization(model_settings) + + +def test_all_fields_serialization() -> None: + """Tests whether ModelSettings can be serialized to a JSON string.""" + + # First, lets create a ModelSettings instance + model_settings = ModelSettings( + temperature=0.5, + top_p=0.9, + frequency_penalty=0.0, + presence_penalty=0.0, + tool_choice="auto", + parallel_tool_calls=True, + truncation="auto", + max_tokens=100, + reasoning=Reasoning(), + metadata={"foo": "bar"}, + store=False, + include_usage=False, + extra_query={"foo": "bar"}, + extra_body={"foo": "bar"}, + extra_headers={"foo": "bar"}, + ) + + # Verify that every single field is set to a non-None value + for field in fields(model_settings): + assert getattr(model_settings, field.name) is not None, ( + f"You must set the {field.name} field" + ) + + # Now, lets serialize the ModelSettings instance to a JSON string + verify_serialization(model_settings) diff --git a/tests/models/__init__.py b/tests/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/models/conftest.py b/tests/models/conftest.py new file mode 100644 index 00000000..79d85d8b --- /dev/null +++ b/tests/models/conftest.py @@ -0,0 +1,11 @@ +import os +import sys + + +# Skip voice tests on Python 3.9 +def pytest_ignore_collect(collection_path, config): + if sys.version_info[:2] == (3, 9): + this_dir = os.path.dirname(__file__) + + if str(collection_path).startswith(this_dir): + return True diff --git a/tests/models/test_litellm_chatcompletions_stream.py b/tests/models/test_litellm_chatcompletions_stream.py new file mode 100644 index 00000000..06e46b39 --- /dev/null +++ 
b/tests/models/test_litellm_chatcompletions_stream.py @@ -0,0 +1,298 @@ +from collections.abc import AsyncIterator + +import pytest +from openai.types.chat.chat_completion_chunk import ( + ChatCompletionChunk, + Choice, + ChoiceDelta, + ChoiceDeltaToolCall, + ChoiceDeltaToolCallFunction, +) +from openai.types.completion_usage import ( + CompletionTokensDetails, + CompletionUsage, + PromptTokensDetails, +) +from openai.types.responses import ( + Response, + ResponseFunctionToolCall, + ResponseOutputMessage, + ResponseOutputRefusal, + ResponseOutputText, +) + +from agents.extensions.models.litellm_model import LitellmModel +from agents.extensions.models.litellm_provider import LitellmProvider +from agents.model_settings import ModelSettings +from agents.models.interface import ModelTracing + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_stream_response_yields_events_for_text_content(monkeypatch) -> None: + """ + Validate that `stream_response` emits the correct sequence of events when + streaming a simple assistant message consisting of plain text content. + We simulate two chunks of text returned from the chat completion stream. + """ + # Create two chunks that will be emitted by the fake stream. + chunk1 = ChatCompletionChunk( + id="chunk-id", + created=1, + model="fake", + object="chat.completion.chunk", + choices=[Choice(index=0, delta=ChoiceDelta(content="He"))], + ) + # Mark last chunk with usage so stream_response knows this is final. + chunk2 = ChatCompletionChunk( + id="chunk-id", + created=1, + model="fake", + object="chat.completion.chunk", + choices=[Choice(index=0, delta=ChoiceDelta(content="llo"))], + usage=CompletionUsage( + completion_tokens=5, + prompt_tokens=7, + total_tokens=12, + completion_tokens_details=CompletionTokensDetails(reasoning_tokens=2), + prompt_tokens_details=PromptTokensDetails(cached_tokens=6), + ), + ) + + async def fake_stream() -> AsyncIterator[ChatCompletionChunk]: + for c in (chunk1, chunk2): + yield c + + # Patch _fetch_response to inject our fake stream + async def patched_fetch_response(self, *args, **kwargs): + # `_fetch_response` is expected to return a Response skeleton and the async stream + resp = Response( + id="resp-id", + created_at=0, + model="fake-model", + object="response", + output=[], + tool_choice="none", + tools=[], + parallel_tool_calls=False, + ) + return resp, fake_stream() + + monkeypatch.setattr(LitellmModel, "_fetch_response", patched_fetch_response) + model = LitellmProvider().get_model("gpt-4") + output_events = [] + async for event in model.stream_response( + system_instructions=None, + input="", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + ): + output_events.append(event) + # We expect a response.created, then a response.output_item.added, content part added, + # two content delta events (for "He" and "llo"), a content part done, the assistant message + # output_item.done, and finally response.completed. + # There should be 8 events in total. + assert len(output_events) == 8 + # First event indicates creation. + assert output_events[0].type == "response.created" + # The output item added and content part added events should mark the assistant message. + assert output_events[1].type == "response.output_item.added" + assert output_events[2].type == "response.content_part.added" + # Two text delta events. 
+ assert output_events[3].type == "response.output_text.delta" + assert output_events[3].delta == "He" + assert output_events[4].type == "response.output_text.delta" + assert output_events[4].delta == "llo" + # After streaming, the content part and item should be marked done. + assert output_events[5].type == "response.content_part.done" + assert output_events[6].type == "response.output_item.done" + # Last event indicates completion of the stream. + assert output_events[7].type == "response.completed" + # The completed response should have one output message with full text. + completed_resp = output_events[7].response + assert isinstance(completed_resp.output[0], ResponseOutputMessage) + assert isinstance(completed_resp.output[0].content[0], ResponseOutputText) + assert completed_resp.output[0].content[0].text == "Hello" + + assert completed_resp.usage, "usage should not be None" + assert completed_resp.usage.input_tokens == 7 + assert completed_resp.usage.output_tokens == 5 + assert completed_resp.usage.total_tokens == 12 + assert completed_resp.usage.input_tokens_details.cached_tokens == 6 + assert completed_resp.usage.output_tokens_details.reasoning_tokens == 2 + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_stream_response_yields_events_for_refusal_content(monkeypatch) -> None: + """ + Validate that when the model streams a refusal string instead of normal content, + `stream_response` emits the appropriate sequence of events including + `response.refusal.delta` events for each chunk of the refusal message and + constructs a completed assistant message with a `ResponseOutputRefusal` part. + """ + # Simulate refusal text coming in two pieces, like content but using the `refusal` + # field on the delta rather than `content`. + chunk1 = ChatCompletionChunk( + id="chunk-id", + created=1, + model="fake", + object="chat.completion.chunk", + choices=[Choice(index=0, delta=ChoiceDelta(refusal="No"))], + ) + chunk2 = ChatCompletionChunk( + id="chunk-id", + created=1, + model="fake", + object="chat.completion.chunk", + choices=[Choice(index=0, delta=ChoiceDelta(refusal="Thanks"))], + usage=CompletionUsage(completion_tokens=2, prompt_tokens=2, total_tokens=4), + ) + + async def fake_stream() -> AsyncIterator[ChatCompletionChunk]: + for c in (chunk1, chunk2): + yield c + + async def patched_fetch_response(self, *args, **kwargs): + resp = Response( + id="resp-id", + created_at=0, + model="fake-model", + object="response", + output=[], + tool_choice="none", + tools=[], + parallel_tool_calls=False, + ) + return resp, fake_stream() + + monkeypatch.setattr(LitellmModel, "_fetch_response", patched_fetch_response) + model = LitellmProvider().get_model("gpt-4") + output_events = [] + async for event in model.stream_response( + system_instructions=None, + input="", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + ): + output_events.append(event) + # Expect sequence similar to text: created, output_item.added, content part added, + # two refusal delta events, content part done, output_item.done, completed. 
+ assert len(output_events) == 8 + assert output_events[0].type == "response.created" + assert output_events[1].type == "response.output_item.added" + assert output_events[2].type == "response.content_part.added" + assert output_events[3].type == "response.refusal.delta" + assert output_events[3].delta == "No" + assert output_events[4].type == "response.refusal.delta" + assert output_events[4].delta == "Thanks" + assert output_events[5].type == "response.content_part.done" + assert output_events[6].type == "response.output_item.done" + assert output_events[7].type == "response.completed" + completed_resp = output_events[7].response + assert isinstance(completed_resp.output[0], ResponseOutputMessage) + refusal_part = completed_resp.output[0].content[0] + assert isinstance(refusal_part, ResponseOutputRefusal) + assert refusal_part.refusal == "NoThanks" + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_stream_response_yields_events_for_tool_call(monkeypatch) -> None: + """ + Validate that `stream_response` emits the correct sequence of events when + the model is streaming a function/tool call instead of plain text. + The function call will be split across two chunks. + """ + # Simulate a single tool call whose ID stays constant and function name/args built over chunks. + tool_call_delta1 = ChoiceDeltaToolCall( + index=0, + id="tool-id", + function=ChoiceDeltaToolCallFunction(name="my_", arguments="arg1"), + type="function", + ) + tool_call_delta2 = ChoiceDeltaToolCall( + index=0, + id="tool-id", + function=ChoiceDeltaToolCallFunction(name="func", arguments="arg2"), + type="function", + ) + chunk1 = ChatCompletionChunk( + id="chunk-id", + created=1, + model="fake", + object="chat.completion.chunk", + choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[tool_call_delta1]))], + ) + chunk2 = ChatCompletionChunk( + id="chunk-id", + created=1, + model="fake", + object="chat.completion.chunk", + choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[tool_call_delta2]))], + usage=CompletionUsage(completion_tokens=1, prompt_tokens=1, total_tokens=2), + ) + + async def fake_stream() -> AsyncIterator[ChatCompletionChunk]: + for c in (chunk1, chunk2): + yield c + + async def patched_fetch_response(self, *args, **kwargs): + resp = Response( + id="resp-id", + created_at=0, + model="fake-model", + object="response", + output=[], + tool_choice="none", + tools=[], + parallel_tool_calls=False, + ) + return resp, fake_stream() + + monkeypatch.setattr(LitellmModel, "_fetch_response", patched_fetch_response) + model = LitellmProvider().get_model("gpt-4") + output_events = [] + async for event in model.stream_response( + system_instructions=None, + input="", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + ): + output_events.append(event) + # Sequence should be: response.created, then after loop we expect function call-related events: + # one response.output_item.added for function call, a response.function_call_arguments.delta, + # a response.output_item.done, and finally response.completed. + assert output_events[0].type == "response.created" + # The next three events are about the tool call. + assert output_events[1].type == "response.output_item.added" + # The added item should be a ResponseFunctionToolCall. + added_fn = output_events[1].item + assert isinstance(added_fn, ResponseFunctionToolCall) + assert added_fn.name == "my_func" # Name should be concatenation of both chunks. 
+ assert added_fn.arguments == "arg1arg2" + assert output_events[2].type == "response.function_call_arguments.delta" + assert output_events[2].delta == "arg1arg2" + assert output_events[3].type == "response.output_item.done" + assert output_events[4].type == "response.completed" + assert output_events[2].delta == "arg1arg2" + assert output_events[3].type == "response.output_item.done" + assert output_events[4].type == "response.completed" + assert added_fn.name == "my_func" # Name should be concatenation of both chunks. + assert added_fn.arguments == "arg1arg2" + assert output_events[2].type == "response.function_call_arguments.delta" + assert output_events[2].delta == "arg1arg2" + assert output_events[3].type == "response.output_item.done" + assert output_events[4].type == "response.completed" diff --git a/tests/models/test_litellm_extra_body.py b/tests/models/test_litellm_extra_body.py new file mode 100644 index 00000000..ac56c25c --- /dev/null +++ b/tests/models/test_litellm_extra_body.py @@ -0,0 +1,45 @@ +import litellm +import pytest +from litellm.types.utils import Choices, Message, ModelResponse, Usage + +from agents.extensions.models.litellm_model import LitellmModel +from agents.model_settings import ModelSettings +from agents.models.interface import ModelTracing + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_extra_body_is_forwarded(monkeypatch): + """ + Forward `extra_body` entries into litellm.acompletion kwargs. + + This ensures that user-provided parameters (e.g. cached_content) + arrive alongside default arguments. + """ + captured: dict[str, object] = {} + + async def fake_acompletion(model, messages=None, **kwargs): + captured.update(kwargs) + msg = Message(role="assistant", content="ok") + choice = Choices(index=0, message=msg) + return ModelResponse(choices=[choice], usage=Usage(0, 0, 0)) + + monkeypatch.setattr(litellm, "acompletion", fake_acompletion) + settings = ModelSettings( + temperature=0.1, + extra_body={"cached_content": "some_cache", "foo": 123} + ) + model = LitellmModel(model="test-model") + + await model.get_response( + system_instructions=None, + input=[], + model_settings=settings, + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + ) + + assert {"cached_content": "some_cache", "foo": 123}.items() <= captured.items() diff --git a/tests/models/test_map.py b/tests/models/test_map.py new file mode 100644 index 00000000..6b65fc09 --- /dev/null +++ b/tests/models/test_map.py @@ -0,0 +1,20 @@ +from agents import Agent, OpenAIResponsesModel, RunConfig, Runner +from agents.extensions.models.litellm_model import LitellmModel + + +def test_no_prefix_is_openai(): + agent = Agent(model="gpt-4o", instructions="", name="test") + model = Runner._get_model(agent, RunConfig()) + assert isinstance(model, OpenAIResponsesModel) + + +def openai_prefix_is_openai(): + agent = Agent(model="openai/gpt-4o", instructions="", name="test") + model = Runner._get_model(agent, RunConfig()) + assert isinstance(model, OpenAIResponsesModel) + + +def test_litellm_prefix_is_litellm(): + agent = Agent(model="litellm/foo/bar", instructions="", name="test") + model = Runner._get_model(agent, RunConfig()) + assert isinstance(model, LitellmModel) diff --git a/tests/test_agent_config.py b/tests/test_agent_config.py index 44339dad..f79c0cf8 100644 --- a/tests/test_agent_config.py +++ b/tests/test_agent_config.py @@ -1,7 +1,7 @@ import pytest from pydantic import BaseModel -from agents import Agent, 
Handoff, RunContextWrapper, Runner, handoff +from agents import Agent, AgentOutputSchema, Handoff, RunContextWrapper, Runner, handoff @pytest.mark.asyncio @@ -160,8 +160,9 @@ async def test_agent_final_output(): ) schema = Runner._get_output_schema(agent) + assert isinstance(schema, AgentOutputSchema) assert schema is not None assert schema.output_type == Foo - assert schema.strict_json_schema is True + assert schema.is_strict_json_schema() is True assert schema.json_schema() is not None assert not schema.is_plain_text() diff --git a/tests/test_cancel_streaming.py b/tests/test_cancel_streaming.py new file mode 100644 index 00000000..3417a3c5 --- /dev/null +++ b/tests/test_cancel_streaming.py @@ -0,0 +1,116 @@ +import json + +import pytest + +from agents import Agent, Runner + +from .fake_model import FakeModel +from .test_responses import get_function_tool, get_function_tool_call, get_text_message + + +@pytest.mark.asyncio +async def test_simple_streaming_with_cancel(): + model = FakeModel() + agent = Agent(name="Joker", model=model) + + result = Runner.run_streamed(agent, input="Please tell me 5 jokes.") + num_events = 0 + stop_after = 1 # There are two that the model gives back. + + async for _event in result.stream_events(): + num_events += 1 + if num_events == stop_after: + result.cancel() + + assert num_events == 1, f"Expected {stop_after} visible events, but got {num_events}" + + +@pytest.mark.asyncio +async def test_multiple_events_streaming_with_cancel(): + model = FakeModel() + agent = Agent( + name="Joker", + model=model, + tools=[get_function_tool("foo", "tool_result")], + ) + + model.add_multiple_turn_outputs( + [ + # First turn: a message and tool call + [ + get_text_message("a_message"), + get_function_tool_call("foo", json.dumps({"a": "b"})), + ], + # Second turn: text message + [get_text_message("done")], + ] + ) + + result = Runner.run_streamed(agent, input="Please tell me 5 jokes.") + num_events = 0 + stop_after = 2 + + async for _ in result.stream_events(): + num_events += 1 + if num_events == stop_after: + result.cancel() + + assert num_events == stop_after, f"Expected {stop_after} visible events, but got {num_events}" + + +@pytest.mark.asyncio +async def test_cancel_prevents_further_events(): + model = FakeModel() + agent = Agent(name="Joker", model=model) + result = Runner.run_streamed(agent, input="Please tell me 5 jokes.") + events = [] + async for event in result.stream_events(): + events.append(event) + result.cancel() + break # Cancel after first event + # Try to get more events after cancel + more_events = [e async for e in result.stream_events()] + assert len(events) == 1 + assert more_events == [], "No events should be yielded after cancel()" + + +@pytest.mark.asyncio +async def test_cancel_is_idempotent(): + model = FakeModel() + agent = Agent(name="Joker", model=model) + result = Runner.run_streamed(agent, input="Please tell me 5 jokes.") + events = [] + async for event in result.stream_events(): + events.append(event) + result.cancel() + result.cancel() # Call cancel again + break + # Should not raise or misbehave + assert len(events) == 1 + + +@pytest.mark.asyncio +async def test_cancel_before_streaming(): + model = FakeModel() + agent = Agent(name="Joker", model=model) + result = Runner.run_streamed(agent, input="Please tell me 5 jokes.") + result.cancel() # Cancel before streaming + events = [e async for e in result.stream_events()] + assert events == [], "No events should be yielded if cancel() is called before streaming." 
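The cancellation tests above exercise `RunResultStreaming.cancel()` purely through the public streaming API: start a run with `Runner.run_streamed`, consume `stream_events()`, and call `cancel()` part-way through. A minimal consumer-side sketch of that pattern, outside the test harness (the agent configuration and the two-event cutoff are illustrative assumptions, not part of this patch):

```python
import asyncio

from agents import Agent, Runner


async def main() -> None:
    # Illustrative agent; any configured agent works the same way.
    agent = Agent(name="Joker", instructions="You are a helpful assistant.")

    # Start a streamed run and consume events incrementally.
    result = Runner.run_streamed(agent, input="Please tell me 5 jokes.")

    seen = 0
    async for event in result.stream_events():
        seen += 1
        print(event.type)
        if seen >= 2:
            # Stop the run early; cancel() is idempotent and later
            # iterations of stream_events() yield nothing further.
            result.cancel()

    # As the cleanup test below asserts, the result is marked complete
    # once the stream has been cancelled.
    print("complete:", result.is_complete)


if __name__ == "__main__":
    asyncio.run(main())
```

Running this for real requires a configured model and API key; the tests instead swap in `FakeModel` so the behaviour is exercised hermetically.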
+ + +@pytest.mark.asyncio +async def test_cancel_cleans_up_resources(): + model = FakeModel() + agent = Agent(name="Joker", model=model) + result = Runner.run_streamed(agent, input="Please tell me 5 jokes.") + # Start streaming, then cancel + async for _ in result.stream_events(): + result.cancel() + break + # After cancel, queues should be empty and is_complete True + assert result.is_complete, "Result should be marked complete after cancel." + assert result._event_queue.empty(), "Event queue should be empty after cancel." + assert result._input_guardrail_queue.empty(), ( + "Input guardrail queue should be empty after cancel." + ) diff --git a/tests/test_extra_headers.py b/tests/test_extra_headers.py new file mode 100644 index 00000000..a6af3007 --- /dev/null +++ b/tests/test_extra_headers.py @@ -0,0 +1,100 @@ +import pytest +from openai.types.chat.chat_completion import ChatCompletion, Choice +from openai.types.chat.chat_completion_message import ChatCompletionMessage +from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails + +from agents import ModelSettings, ModelTracing, OpenAIChatCompletionsModel, OpenAIResponsesModel + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_extra_headers_passed_to_openai_responses_model(): + """ + Ensure extra_headers in ModelSettings is passed to the OpenAIResponsesModel client. + """ + called_kwargs = {} + + class DummyResponses: + async def create(self, **kwargs): + nonlocal called_kwargs + called_kwargs = kwargs + + class DummyResponse: + id = "dummy" + output = [] + usage = type( + "Usage", + (), + { + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + "input_tokens_details": InputTokensDetails(cached_tokens=0), + "output_tokens_details": OutputTokensDetails(reasoning_tokens=0), + }, + )() + + return DummyResponse() + + class DummyClient: + def __init__(self): + self.responses = DummyResponses() + + model = OpenAIResponsesModel(model="gpt-4", openai_client=DummyClient()) # type: ignore + extra_headers = {"X-Test-Header": "test-value"} + await model.get_response( + system_instructions=None, + input="hi", + model_settings=ModelSettings(extra_headers=extra_headers), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + ) + assert "extra_headers" in called_kwargs + assert called_kwargs["extra_headers"]["X-Test-Header"] == "test-value" + + +@pytest.mark.allow_call_model_methods +@pytest.mark.asyncio +async def test_extra_headers_passed_to_openai_client(): + """ + Ensure extra_headers in ModelSettings is passed to the OpenAI client. 
+ """ + called_kwargs = {} + + class DummyCompletions: + async def create(self, **kwargs): + nonlocal called_kwargs + called_kwargs = kwargs + msg = ChatCompletionMessage(role="assistant", content="Hello") + choice = Choice(index=0, finish_reason="stop", message=msg) + return ChatCompletion( + id="resp-id", + created=0, + model="fake", + object="chat.completion", + choices=[choice], + usage=None, + ) + + class DummyClient: + def __init__(self): + self.chat = type("_Chat", (), {"completions": DummyCompletions()})() + self.base_url = "https://api.openai.com" + + model = OpenAIChatCompletionsModel(model="gpt-4", openai_client=DummyClient()) # type: ignore + extra_headers = {"X-Test-Header": "test-value"} + await model.get_response( + system_instructions=None, + input="hi", + model_settings=ModelSettings(extra_headers=extra_headers), + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + ) + assert "extra_headers" in called_kwargs + assert called_kwargs["extra_headers"]["X-Test-Header"] == "test-value" diff --git a/tests/test_openai_chatcompletions.py b/tests/test_openai_chatcompletions.py index ba3ec68d..ba4605d0 100644 --- a/tests/test_openai_chatcompletions.py +++ b/tests/test_openai_chatcompletions.py @@ -13,7 +13,10 @@ ChatCompletionMessageToolCall, Function, ) -from openai.types.completion_usage import CompletionUsage +from openai.types.completion_usage import ( + CompletionUsage, + PromptTokensDetails, +) from openai.types.responses import ( Response, ResponseFunctionToolCall, @@ -51,7 +54,13 @@ async def test_get_response_with_text_message(monkeypatch) -> None: model="fake", object="chat.completion", choices=[choice], - usage=CompletionUsage(completion_tokens=5, prompt_tokens=7, total_tokens=12), + usage=CompletionUsage( + completion_tokens=5, + prompt_tokens=7, + total_tokens=12, + # completion_tokens_details left blank to test default + prompt_tokens_details=PromptTokensDetails(cached_tokens=3), + ), ) async def patched_fetch_response(self, *args, **kwargs): @@ -81,6 +90,8 @@ async def patched_fetch_response(self, *args, **kwargs): assert resp.usage.input_tokens == 7 assert resp.usage.output_tokens == 5 assert resp.usage.total_tokens == 12 + assert resp.usage.input_tokens_details.cached_tokens == 3 + assert resp.usage.output_tokens_details.reasoning_tokens == 0 assert resp.response_id is None @@ -127,6 +138,8 @@ async def patched_fetch_response(self, *args, **kwargs): assert resp.usage.requests == 0 assert resp.usage.input_tokens == 0 assert resp.usage.output_tokens == 0 + assert resp.usage.input_tokens_details.cached_tokens == 0 + assert resp.usage.output_tokens_details.reasoning_tokens == 0 @pytest.mark.allow_call_model_methods diff --git a/tests/test_openai_chatcompletions_converter.py b/tests/test_openai_chatcompletions_converter.py index e3a18b25..bcfca549 100644 --- a/tests/test_openai_chatcompletions_converter.py +++ b/tests/test_openai_chatcompletions_converter.py @@ -232,7 +232,7 @@ def test_convert_response_format_returns_not_given_for_plain_text_and_dict_for_s assert resp_format["type"] == "json_schema" assert resp_format["json_schema"]["name"] == "final_output" assert "strict" in resp_format["json_schema"] - assert resp_format["json_schema"]["strict"] == schema.strict_json_schema + assert resp_format["json_schema"]["strict"] == schema.is_strict_json_schema() assert "schema" in resp_format["json_schema"] assert resp_format["json_schema"]["schema"] == schema.json_schema() diff --git 
a/tests/test_openai_chatcompletions_stream.py b/tests/test_openai_chatcompletions_stream.py index b82f2430..5c8bb9e3 100644 --- a/tests/test_openai_chatcompletions_stream.py +++ b/tests/test_openai_chatcompletions_stream.py @@ -8,7 +8,11 @@ ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction, ) -from openai.types.completion_usage import CompletionUsage +from openai.types.completion_usage import ( + CompletionTokensDetails, + CompletionUsage, + PromptTokensDetails, +) from openai.types.responses import ( Response, ResponseFunctionToolCall, @@ -46,7 +50,13 @@ async def test_stream_response_yields_events_for_text_content(monkeypatch) -> No model="fake", object="chat.completion.chunk", choices=[Choice(index=0, delta=ChoiceDelta(content="llo"))], - usage=CompletionUsage(completion_tokens=5, prompt_tokens=7, total_tokens=12), + usage=CompletionUsage( + completion_tokens=5, + prompt_tokens=7, + total_tokens=12, + prompt_tokens_details=PromptTokensDetails(cached_tokens=2), + completion_tokens_details=CompletionTokensDetails(reasoning_tokens=3), + ), ) async def fake_stream() -> AsyncIterator[ChatCompletionChunk]: @@ -112,6 +122,8 @@ async def patched_fetch_response(self, *args, **kwargs): assert completed_resp.usage.input_tokens == 7 assert completed_resp.usage.output_tokens == 5 assert completed_resp.usage.total_tokens == 12 + assert completed_resp.usage.input_tokens_details.cached_tokens == 2 + assert completed_resp.usage.output_tokens_details.reasoning_tokens == 3 @pytest.mark.allow_call_model_methods diff --git a/tests/test_openai_responses_converter.py b/tests/test_openai_responses_converter.py index 34cbac5c..8e486665 100644 --- a/tests/test_openai_responses_converter.py +++ b/tests/test_openai_responses_converter.py @@ -92,7 +92,7 @@ class OutModel(BaseModel): assert inner.get("name") == "final_output" assert isinstance(inner.get("schema"), dict) # Should include a strict flag matching the schema's strictness setting. 
- assert inner.get("strict") == out_schema.strict_json_schema + assert inner.get("strict") == out_schema.is_strict_json_schema() def test_convert_tools_basic_types_and_includes(): diff --git a/tests/test_output_tool.py b/tests/test_output_tool.py index 86c4b3b5..37c1b1b6 100644 --- a/tests/test_output_tool.py +++ b/tests/test_output_tool.py @@ -1,10 +1,18 @@ import json +from typing import Any import pytest from pydantic import BaseModel from typing_extensions import TypedDict -from agents import Agent, AgentOutputSchema, ModelBehaviorError, Runner, UserError +from agents import ( + Agent, + AgentOutputSchema, + AgentOutputSchemaBase, + ModelBehaviorError, + Runner, + UserError, +) from agents.agent_output import _WRAPPER_DICT_KEY from agents.util import _json @@ -27,6 +35,7 @@ def test_structured_output_pydantic(): output_schema = Runner._get_output_schema(agent) assert output_schema, "Should have an output tool config with a structured output type" + assert isinstance(output_schema, AgentOutputSchema) assert output_schema.output_type == Foo, "Should have the correct output type" assert not output_schema._is_wrapped, "Pydantic objects should not be wrapped" for key, value in Foo.model_json_schema().items(): @@ -45,6 +54,7 @@ def test_structured_output_typed_dict(): agent = Agent(name="test", output_type=Bar) output_schema = Runner._get_output_schema(agent) assert output_schema, "Should have an output tool config with a structured output type" + assert isinstance(output_schema, AgentOutputSchema) assert output_schema.output_type == Bar, "Should have the correct output type" assert not output_schema._is_wrapped, "TypedDicts should not be wrapped" @@ -57,6 +67,7 @@ def test_structured_output_list(): agent = Agent(name="test", output_type=list[str]) output_schema = Runner._get_output_schema(agent) assert output_schema, "Should have an output tool config with a structured output type" + assert isinstance(output_schema, AgentOutputSchema) assert output_schema.output_type == list[str], "Should have the correct output type" assert output_schema._is_wrapped, "Lists should be wrapped" @@ -98,7 +109,7 @@ def test_plain_text_obj_doesnt_produce_schema(): def test_structured_output_is_strict(): output_wrapper = AgentOutputSchema(output_type=Foo) - assert output_wrapper.strict_json_schema + assert output_wrapper.is_strict_json_schema() for key, value in Foo.model_json_schema().items(): assert output_wrapper.json_schema()[key] == value @@ -110,6 +121,48 @@ def test_structured_output_is_strict(): def test_setting_strict_false_works(): output_wrapper = AgentOutputSchema(output_type=Foo, strict_json_schema=False) - assert not output_wrapper.strict_json_schema + assert not output_wrapper.is_strict_json_schema() assert output_wrapper.json_schema() == Foo.model_json_schema() assert output_wrapper.json_schema() == Foo.model_json_schema() + + +_CUSTOM_OUTPUT_SCHEMA_JSON_SCHEMA = { + "type": "object", + "properties": { + "foo": {"type": "string"}, + }, + "required": ["foo"], +} + + +class CustomOutputSchema(AgentOutputSchemaBase): + def is_plain_text(self) -> bool: + return False + + def name(self) -> str: + return "FooBarBaz" + + def json_schema(self) -> dict[str, Any]: + return _CUSTOM_OUTPUT_SCHEMA_JSON_SCHEMA + + def is_strict_json_schema(self) -> bool: + return False + + def validate_json(self, json_str: str) -> Any: + return ["some", "output"] + + +def test_custom_output_schema(): + custom_output_schema = CustomOutputSchema() + agent = Agent(name="test", output_type=custom_output_schema) + output_schema = 
Runner._get_output_schema(agent) + + assert output_schema, "Should have an output tool config with a structured output type" + assert isinstance(output_schema, CustomOutputSchema) + assert output_schema.json_schema() == _CUSTOM_OUTPUT_SCHEMA_JSON_SCHEMA + assert not output_schema.is_strict_json_schema() + assert not output_schema.is_plain_text() + + json_str = json.dumps({"foo": "bar"}) + validated = output_schema.validate_json(json_str) + assert validated == ["some", "output"] diff --git a/tests/test_responses_tracing.py b/tests/test_responses_tracing.py index 0bc97a95..db24fe49 100644 --- a/tests/test_responses_tracing.py +++ b/tests/test_responses_tracing.py @@ -1,7 +1,10 @@ +from typing import Optional + import pytest from inline_snapshot import snapshot from openai import AsyncOpenAI from openai.types.responses import ResponseCompletedEvent +from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails from agents import ModelSettings, ModelTracing, OpenAIResponsesModel, trace from agents.tracing.span_data import ResponseSpanData @@ -16,10 +19,25 @@ def is_disabled(self): class DummyUsage: - def __init__(self, input_tokens=1, output_tokens=1, total_tokens=2): + def __init__( + self, + input_tokens: int = 1, + input_tokens_details: Optional[InputTokensDetails] = None, + output_tokens: int = 1, + output_tokens_details: Optional[OutputTokensDetails] = None, + total_tokens: int = 2, + ): self.input_tokens = input_tokens self.output_tokens = output_tokens self.total_tokens = total_tokens + self.input_tokens_details = ( + input_tokens_details if input_tokens_details else InputTokensDetails(cached_tokens=0) + ) + self.output_tokens_details = ( + output_tokens_details + if output_tokens_details + else OutputTokensDetails(reasoning_tokens=0) + ) class DummyResponse: @@ -32,6 +50,7 @@ def __aiter__(self): yield ResponseCompletedEvent( type="response.completed", response=fake_model.get_response_obj(self.output), + sequence_number=0, ) @@ -183,6 +202,7 @@ async def __aiter__(self): yield ResponseCompletedEvent( type="response.completed", response=fake_model.get_response_obj([], "dummy-id-123"), + sequence_number=0, ) return DummyStream() @@ -235,6 +255,7 @@ async def __aiter__(self): yield ResponseCompletedEvent( type="response.completed", response=fake_model.get_response_obj([], "dummy-id-123"), + sequence_number=0, ) return DummyStream() @@ -286,6 +307,7 @@ async def __aiter__(self): yield ResponseCompletedEvent( type="response.completed", response=fake_model.get_response_obj([], "dummy-id-123"), + sequence_number=0, ) return DummyStream() diff --git a/tests/test_result_cast.py b/tests/test_result_cast.py index ec17e327..c621e735 100644 --- a/tests/test_result_cast.py +++ b/tests/test_result_cast.py @@ -3,7 +3,7 @@ import pytest from pydantic import BaseModel -from agents import Agent, RunResult +from agents import Agent, RunContextWrapper, RunResult def create_run_result(final_output: Any) -> RunResult: @@ -15,6 +15,7 @@ def create_run_result(final_output: Any) -> RunResult: input_guardrail_results=[], output_guardrail_results=[], _last_agent=Agent(name="test"), + context_wrapper=RunContextWrapper(context=None), ) diff --git a/tests/test_usage.py b/tests/test_usage.py new file mode 100644 index 00000000..405f99dd --- /dev/null +++ b/tests/test_usage.py @@ -0,0 +1,52 @@ +from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails + +from agents.usage import Usage + + +def test_usage_add_aggregates_all_fields(): + u1 = Usage( + 
requests=1, + input_tokens=10, + input_tokens_details=InputTokensDetails(cached_tokens=3), + output_tokens=20, + output_tokens_details=OutputTokensDetails(reasoning_tokens=5), + total_tokens=30, + ) + u2 = Usage( + requests=2, + input_tokens=7, + input_tokens_details=InputTokensDetails(cached_tokens=4), + output_tokens=8, + output_tokens_details=OutputTokensDetails(reasoning_tokens=6), + total_tokens=15, + ) + + u1.add(u2) + + assert u1.requests == 3 + assert u1.input_tokens == 17 + assert u1.output_tokens == 28 + assert u1.total_tokens == 45 + assert u1.input_tokens_details.cached_tokens == 7 + assert u1.output_tokens_details.reasoning_tokens == 11 + + +def test_usage_add_aggregates_with_none_values(): + u1 = Usage() + u2 = Usage( + requests=2, + input_tokens=7, + input_tokens_details=InputTokensDetails(cached_tokens=4), + output_tokens=8, + output_tokens_details=OutputTokensDetails(reasoning_tokens=6), + total_tokens=15, + ) + + u1.add(u2) + + assert u1.requests == 2 + assert u1.input_tokens == 7 + assert u1.output_tokens == 8 + assert u1.total_tokens == 15 + assert u1.input_tokens_details.cached_tokens == 4 + assert u1.output_tokens_details.reasoning_tokens == 6 diff --git a/tests/voice/conftest.py b/tests/voice/conftest.py index 6ed7422c..79d85d8b 100644 --- a/tests/voice/conftest.py +++ b/tests/voice/conftest.py @@ -9,4 +9,3 @@ def pytest_ignore_collect(collection_path, config): if str(collection_path).startswith(this_dir): return True - diff --git a/tests/voice/test_workflow.py b/tests/voice/test_workflow.py index 72a3370d..035a05d5 100644 --- a/tests/voice/test_workflow.py +++ b/tests/voice/test_workflow.py @@ -9,7 +9,7 @@ from openai.types.responses.response_text_delta_event import ResponseTextDeltaEvent from agents import Agent, Model, ModelSettings, ModelTracing, Tool -from agents.agent_output import AgentOutputSchema +from agents.agent_output import AgentOutputSchemaBase from agents.handoffs import Handoff from agents.items import ( ModelResponse, @@ -48,7 +48,7 @@ async def get_response( input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], tracing: ModelTracing, *, @@ -62,7 +62,7 @@ async def stream_response( input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], tracing: ModelTracing, *, @@ -81,11 +81,13 @@ async def stream_response( type="response.output_text.delta", output_index=0, item_id=item.id, + sequence_number=0, ) yield ResponseCompletedEvent( type="response.completed", response=get_response_obj(output), + sequence_number=1, ) diff --git a/uv.lock b/uv.lock index 24d089a5..6f2f3f84 100644 --- a/uv.lock +++ b/uv.lock @@ -483,6 +483,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7b/8f/c4d9bafc34ad7ad5d8dc16dd1347ee0e507a52c3adb6bfa8887e1c6a26ba/executing-2.2.0-py2.py3-none-any.whl", hash = "sha256:11387150cad388d62750327a53d3339fad4888b39a6fe233c3afbb54ecffd3aa", size = 26702 }, ] +[[package]] +name = "fastapi" +version = "0.115.12" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "starlette" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f4/55/ae499352d82338331ca1e28c7f4a63bfd09479b16395dce38cf50a39e2c2/fastapi-0.115.12.tar.gz", hash = 
"sha256:1e2c2a2646905f9e83d32f04a3f86aff4a286669c6c950ca95b5fd68c2602681", size = 295236 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/50/b3/b51f09c2ba432a576fe63758bddc81f78f0c6309d9e5c10d194313bf021e/fastapi-0.115.12-py3-none-any.whl", hash = "sha256:e94613d6c05e27be7ffebdd6ea5f388112e5e430c8f7d6494a9d1d88d43e814d", size = 95164 }, +] + [[package]] name = "filelock" version = "3.18.0" @@ -914,7 +928,7 @@ wheels = [ [[package]] name = "litellm" -version = "1.66.1" +version = "1.67.4.post1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohttp" }, @@ -929,10 +943,7 @@ dependencies = [ { name = "tiktoken" }, { name = "tokenizers" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c1/21/12562c37310254456afdd277454dac4d14b8b40796216e8a438a9e1c5e86/litellm-1.66.1.tar.gz", hash = "sha256:98f7add913e5eae2131dd412ee27532d9a309defd9dbb64f6c6c42ea8a2af068", size = 7203211 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e9/33/fdc4615ca621940406e3b0b303e900bc2868cfcd8c62c4a6f5e7d2f6a56c/litellm-1.66.1-py3-none-any.whl", hash = "sha256:1f601fea3f086c1d2d91be60b9db115082a2f3a697e4e0def72f8b9c777c7232", size = 7559553 }, -] +sdist = { url = "https://files.pythonhosted.org/packages/4d/89/bacf75633dd43d6c5536380fb652c4af25046c29f5c6e5fdb4e8fe5af505/litellm-1.67.4.post1.tar.gz", hash = "sha256:057f2505f82d8c3f83d705c375b0d1931de998b13e239a6b06e16ee351fda648", size = 7243930 } [[package]] name = "markdown" @@ -1036,7 +1047,7 @@ wheels = [ [[package]] name = "mcp" -version = "1.6.0" +version = "1.8.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio", marker = "python_full_version >= '3.10'" }, @@ -1044,13 +1055,14 @@ dependencies = [ { name = "httpx-sse", marker = "python_full_version >= '3.10'" }, { name = "pydantic", marker = "python_full_version >= '3.10'" }, { name = "pydantic-settings", marker = "python_full_version >= '3.10'" }, + { name = "python-multipart", marker = "python_full_version >= '3.10'" }, { name = "sse-starlette", marker = "python_full_version >= '3.10'" }, { name = "starlette", marker = "python_full_version >= '3.10'" }, - { name = "uvicorn", marker = "python_full_version >= '3.10'" }, + { name = "uvicorn", marker = "python_full_version >= '3.10' and sys_platform != 'emscripten'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/95/d2/f587cb965a56e992634bebc8611c5b579af912b74e04eb9164bd49527d21/mcp-1.6.0.tar.gz", hash = "sha256:d9324876de2c5637369f43161cd71eebfd803df5a95e46225cab8d280e366723", size = 200031 } +sdist = { url = "https://files.pythonhosted.org/packages/7c/13/16b712e8a3be6a736b411df2fc6b4e75eb1d3e99b1cd57a3a1decf17f612/mcp-1.8.1.tar.gz", hash = "sha256:ec0646271d93749f784d2316fb5fe6102fb0d1be788ec70a9e2517e8f2722c0e", size = 265605 } wheels = [ - { url = "https://files.pythonhosted.org/packages/10/30/20a7f33b0b884a9d14dd3aa94ff1ac9da1479fe2ad66dd9e2736075d2506/mcp-1.6.0-py3-none-any.whl", hash = "sha256:7bd24c6ea042dbec44c754f100984d186620d8b841ec30f1b19eda9b93a634d0", size = 76077 }, + { url = "https://files.pythonhosted.org/packages/1c/5d/91cf0d40e40ae9ecf8d4004e0f9611eea86085aa0b5505493e0ff53972da/mcp-1.8.1-py3-none-any.whl", hash = "sha256:948e03783859fa35abe05b9b6c0a1d5519be452fc079dc8d7f682549591c1770", size = 119761 }, ] [[package]] @@ -1449,7 +1461,7 @@ wheels = [ [[package]] name = "openai" -version = "1.74.0" +version = "1.81.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -1461,14 +1473,14 @@ 
dependencies = [ { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/75/86/c605a6e84da0248f2cebfcd864b5a6076ecf78849245af5e11d2a5ec7977/openai-1.74.0.tar.gz", hash = "sha256:592c25b8747a7cad33a841958f5eb859a785caea9ee22b9e4f4a2ec062236526", size = 427571 } +sdist = { url = "https://files.pythonhosted.org/packages/1c/89/a1e4f3fa7ca4f7fec90dbf47d93b7cd5ff65924926733af15044e302a192/openai-1.81.0.tar.gz", hash = "sha256:349567a8607e0bcffd28e02f96b5c2397d0d25d06732d90ab3ecbf97abf030f9", size = 456861 } wheels = [ - { url = "https://files.pythonhosted.org/packages/a9/91/8c150f16a96367e14bd7d20e86e0bbbec3080e3eb593e63f21a7f013f8e4/openai-1.74.0-py3-none-any.whl", hash = "sha256:aff3e0f9fb209836382ec112778667027f4fd6ae38bdb2334bc9e173598b092a", size = 644790 }, + { url = "https://files.pythonhosted.org/packages/02/66/bcc7f9bf48e8610a33e3b5c96a5a644dad032d92404ea2a5e8b43ba067e8/openai-1.81.0-py3-none-any.whl", hash = "sha256:1c71572e22b43876c5d7d65ade0b7b516bb527c3d44ae94111267a09125f7bae", size = 717529 }, ] [[package]] name = "openai-agents" -version = "0.0.11" +version = "0.0.16" source = { editable = "." } dependencies = [ { name = "griffe" }, @@ -1496,6 +1508,7 @@ voice = [ dev = [ { name = "coverage" }, { name = "eval-type-backport" }, + { name = "fastapi" }, { name = "graphviz" }, { name = "inline-snapshot" }, { name = "mkdocs" }, @@ -1520,10 +1533,10 @@ dev = [ requires-dist = [ { name = "graphviz", marker = "extra == 'viz'", specifier = ">=0.17" }, { name = "griffe", specifier = ">=1.5.6,<2" }, - { name = "litellm", marker = "extra == 'litellm'", specifier = ">=1.65.0,<2" }, - { name = "mcp", marker = "python_full_version >= '3.10'", specifier = ">=1.6.0,<2" }, + { name = "litellm", marker = "extra == 'litellm'", specifier = ">=1.67.4.post1,<2" }, + { name = "mcp", marker = "python_full_version >= '3.10'", specifier = ">=1.8.0,<2" }, { name = "numpy", marker = "python_full_version >= '3.10' and extra == 'voice'", specifier = ">=2.2.0,<3" }, - { name = "openai", specifier = ">=1.66.5" }, + { name = "openai", specifier = ">=1.81.0" }, { name = "pydantic", specifier = ">=2.10,<3" }, { name = "requests", specifier = ">=2.0,<3" }, { name = "types-requests", specifier = ">=2.0,<3" }, @@ -1536,6 +1549,7 @@ provides-extras = ["voice", "viz", "litellm"] dev = [ { name = "coverage", specifier = ">=7.6.12" }, { name = "eval-type-backport", specifier = ">=0.2.2" }, + { name = "fastapi", specifier = ">=0.110.0,<1" }, { name = "graphviz" }, { name = "inline-snapshot", specifier = ">=0.20.7" }, { name = "mkdocs", specifier = ">=1.6.0" }, @@ -2072,6 +2086,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1e/18/98a99ad95133c6a6e2005fe89faedf294a748bd5dc803008059409ac9b1e/python_dotenv-1.1.0-py3-none-any.whl", hash = "sha256:d7c01d9e2293916c18baf562d95698754b0dbbb5e74d457c45d4f6561fb9d55d", size = 20256 }, ] +[[package]] +name = "python-multipart" +version = "0.0.20" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f3/87/f44d7c9f274c7ee665a29b885ec97089ec5dc034c7f3fafa03da9e39a09e/python_multipart-0.0.20.tar.gz", hash = "sha256:8dd0cab45b8e23064ae09147625994d090fa46f5b0d1e13af944c331a7fa9d13", size = 37158 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/45/58/38b5afbc1a800eeea951b9285d3912613f2603bdf897a4ab0f4bd7f405fc/python_multipart-0.0.20-py3-none-any.whl", hash = "sha256:8a62d3a8335e06589fe01f2a3e178cdcc632f3fbe0d492ad9ee0ec35aab1f104", size = 24546 }, 
+] + [[package]] name = "python-xlib" version = "0.33" @@ -2474,7 +2497,8 @@ name = "starlette" version = "0.46.2" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "anyio", marker = "python_full_version >= '3.10'" }, + { name = "anyio" }, + { name = "typing-extensions", marker = "python_full_version < '3.10'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/ce/20/08dfcd9c983f6a6f4a1000d934b9e6d626cff8d2eeb77a89a68eef20a2b7/starlette-0.46.2.tar.gz", hash = "sha256:7f7361f34eed179294600af672f565727419830b54b7b084efe44bb82d2fccd5", size = 2580846 } wheels = [