Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit e5d3ad7

Browse files
authored
[1/n] Break Agent into AgentBase+Agent (openai#1068)
This allows base agent stuff to be shared with the realtime agent
1 parent a9757de commit e5d3ad7

File tree

7 files changed

+111
-61
lines changed

7 files changed

+111
-61
lines changed

src/agents/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from openai import AsyncOpenAI
66

77
from . import _config
8-
from .agent import Agent, ToolsToFinalOutputFunction, ToolsToFinalOutputResult
8+
from .agent import Agent, AgentBase, ToolsToFinalOutputFunction, ToolsToFinalOutputResult
99
from .agent_output import AgentOutputSchema, AgentOutputSchemaBase
1010
from .computer import AsyncComputer, Button, Computer, Environment
1111
from .exceptions import (
@@ -161,6 +161,7 @@ def enable_verbose_stdout_logging():
161161

162162
__all__ = [
163163
"Agent",
164+
"AgentBase",
164165
"ToolsToFinalOutputFunction",
165166
"ToolsToFinalOutputResult",
166167
"Runner",

src/agents/agent.py

Lines changed: 59 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,63 @@ class MCPConfig(TypedDict):
6767

6868

6969
@dataclass
70-
class Agent(Generic[TContext]):
70+
class AgentBase(Generic[TContext]):
71+
"""Base class for `Agent` and `RealtimeAgent`."""
72+
73+
name: str
74+
"""The name of the agent."""
75+
76+
handoff_description: str | None = None
77+
"""A description of the agent. This is used when the agent is used as a handoff, so that an
78+
LLM knows what it does and when to invoke it.
79+
"""
80+
81+
tools: list[Tool] = field(default_factory=list)
82+
"""A list of tools that the agent can use."""
83+
84+
mcp_servers: list[MCPServer] = field(default_factory=list)
85+
"""A list of [Model Context Protocol](https://modelcontextprotocol.io/) servers that
86+
the agent can use. Every time the agent runs, it will include tools from these servers in the
87+
list of available tools.
88+
89+
NOTE: You are expected to manage the lifecycle of these servers. Specifically, you must call
90+
`server.connect()` before passing it to the agent, and `server.cleanup()` when the server is no
91+
longer needed.
92+
"""
93+
94+
mcp_config: MCPConfig = field(default_factory=lambda: MCPConfig())
95+
"""Configuration for MCP servers."""
96+
97+
async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]:
98+
"""Fetches the available tools from the MCP servers."""
99+
convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
100+
return await MCPUtil.get_all_function_tools(
101+
self.mcp_servers, convert_schemas_to_strict, run_context, self
102+
)
103+
104+
async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]:
105+
"""All agent tools, including MCP tools and function tools."""
106+
mcp_tools = await self.get_mcp_tools(run_context)
107+
108+
async def _check_tool_enabled(tool: Tool) -> bool:
109+
if not isinstance(tool, FunctionTool):
110+
return True
111+
112+
attr = tool.is_enabled
113+
if isinstance(attr, bool):
114+
return attr
115+
res = attr(run_context, self)
116+
if inspect.isawaitable(res):
117+
return bool(await res)
118+
return bool(res)
119+
120+
results = await asyncio.gather(*(_check_tool_enabled(t) for t in self.tools))
121+
enabled: list[Tool] = [t for t, ok in zip(self.tools, results) if ok]
122+
return [*mcp_tools, *enabled]
123+
124+
125+
@dataclass
126+
class Agent(AgentBase, Generic[TContext]):
71127
"""An agent is an AI model configured with instructions, tools, guardrails, handoffs and more.
72128
73129
We strongly recommend passing `instructions`, which is the "system prompt" for the agent. In
@@ -76,10 +132,9 @@ class Agent(Generic[TContext]):
76132
77133
Agents are generic on the context type. The context is a (mutable) object you create. It is
78134
passed to tool functions, handoffs, guardrails, etc.
79-
"""
80135
81-
name: str
82-
"""The name of the agent."""
136+
See `AgentBase` for base parameters that are shared with `RealtimeAgent`s.
137+
"""
83138

84139
instructions: (
85140
str
@@ -103,11 +158,6 @@ class Agent(Generic[TContext]):
103158
usable with OpenAI models, using the Responses API.
104159
"""
105160

106-
handoff_description: str | None = None
107-
"""A description of the agent. This is used when the agent is used as a handoff, so that an
108-
LLM knows what it does and when to invoke it.
109-
"""
110-
111161
handoffs: list[Agent[Any] | Handoff[TContext]] = field(default_factory=list)
112162
"""Handoffs are sub-agents that the agent can delegate to. You can provide a list of handoffs,
113163
and the agent can choose to delegate to them if relevant. Allows for separation of concerns and
@@ -125,22 +175,6 @@ class Agent(Generic[TContext]):
125175
"""Configures model-specific tuning parameters (e.g. temperature, top_p).
126176
"""
127177

128-
tools: list[Tool] = field(default_factory=list)
129-
"""A list of tools that the agent can use."""
130-
131-
mcp_servers: list[MCPServer] = field(default_factory=list)
132-
"""A list of [Model Context Protocol](https://modelcontextprotocol.io/) servers that
133-
the agent can use. Every time the agent runs, it will include tools from these servers in the
134-
list of available tools.
135-
136-
NOTE: You are expected to manage the lifecycle of these servers. Specifically, you must call
137-
`server.connect()` before passing it to the agent, and `server.cleanup()` when the server is no
138-
longer needed.
139-
"""
140-
141-
mcp_config: MCPConfig = field(default_factory=lambda: MCPConfig())
142-
"""Configuration for MCP servers."""
143-
144178
input_guardrails: list[InputGuardrail[TContext]] = field(default_factory=list)
145179
"""A list of checks that run in parallel to the agent's execution, before generating a
146180
response. Runs only if the agent is the first agent in the chain.

src/agents/lifecycle.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,27 @@
11
from typing import Any, Generic
22

3-
from .agent import Agent
3+
from typing_extensions import TypeVar
4+
5+
from .agent import Agent, AgentBase
46
from .run_context import RunContextWrapper, TContext
57
from .tool import Tool
68

9+
TAgent = TypeVar("TAgent", bound=AgentBase, default=AgentBase)
10+
711

8-
class RunHooks(Generic[TContext]):
12+
class RunHooksBase(Generic[TContext, TAgent]):
913
"""A class that receives callbacks on various lifecycle events in an agent run. Subclass and
1014
override the methods you need.
1115
"""
1216

13-
async def on_agent_start(
14-
self, context: RunContextWrapper[TContext], agent: Agent[TContext]
15-
) -> None:
17+
async def on_agent_start(self, context: RunContextWrapper[TContext], agent: TAgent) -> None:
1618
"""Called before the agent is invoked. Called each time the current agent changes."""
1719
pass
1820

1921
async def on_agent_end(
2022
self,
2123
context: RunContextWrapper[TContext],
22-
agent: Agent[TContext],
24+
agent: TAgent,
2325
output: Any,
2426
) -> None:
2527
"""Called when the agent produces a final output."""
@@ -28,16 +30,16 @@ async def on_agent_end(
2830
async def on_handoff(
2931
self,
3032
context: RunContextWrapper[TContext],
31-
from_agent: Agent[TContext],
32-
to_agent: Agent[TContext],
33+
from_agent: TAgent,
34+
to_agent: TAgent,
3335
) -> None:
3436
"""Called when a handoff occurs."""
3537
pass
3638

3739
async def on_tool_start(
3840
self,
3941
context: RunContextWrapper[TContext],
40-
agent: Agent[TContext],
42+
agent: TAgent,
4143
tool: Tool,
4244
) -> None:
4345
"""Called before a tool is invoked."""
@@ -46,30 +48,30 @@ async def on_tool_start(
4648
async def on_tool_end(
4749
self,
4850
context: RunContextWrapper[TContext],
49-
agent: Agent[TContext],
51+
agent: TAgent,
5052
tool: Tool,
5153
result: str,
5254
) -> None:
5355
"""Called after a tool is invoked."""
5456
pass
5557

5658

57-
class AgentHooks(Generic[TContext]):
59+
class AgentHooksBase(Generic[TContext, TAgent]):
5860
"""A class that receives callbacks on various lifecycle events for a specific agent. You can
5961
set this on `agent.hooks` to receive events for that specific agent.
6062
6163
Subclass and override the methods you need.
6264
"""
6365

64-
async def on_start(self, context: RunContextWrapper[TContext], agent: Agent[TContext]) -> None:
66+
async def on_start(self, context: RunContextWrapper[TContext], agent: TAgent) -> None:
6567
"""Called before the agent is invoked. Called each time the running agent is changed to this
6668
agent."""
6769
pass
6870

6971
async def on_end(
7072
self,
7173
context: RunContextWrapper[TContext],
72-
agent: Agent[TContext],
74+
agent: TAgent,
7375
output: Any,
7476
) -> None:
7577
"""Called when the agent produces a final output."""
@@ -78,8 +80,8 @@ async def on_end(
7880
async def on_handoff(
7981
self,
8082
context: RunContextWrapper[TContext],
81-
agent: Agent[TContext],
82-
source: Agent[TContext],
83+
agent: TAgent,
84+
source: TAgent,
8385
) -> None:
8486
"""Called when the agent is being handed off to. The `source` is the agent that is handing
8587
off to this agent."""
@@ -88,7 +90,7 @@ async def on_handoff(
8890
async def on_tool_start(
8991
self,
9092
context: RunContextWrapper[TContext],
91-
agent: Agent[TContext],
93+
agent: TAgent,
9294
tool: Tool,
9395
) -> None:
9496
"""Called before a tool is invoked."""
@@ -97,9 +99,16 @@ async def on_tool_start(
9799
async def on_tool_end(
98100
self,
99101
context: RunContextWrapper[TContext],
100-
agent: Agent[TContext],
102+
agent: TAgent,
101103
tool: Tool,
102104
result: str,
103105
) -> None:
104106
"""Called after a tool is invoked."""
105107
pass
108+
109+
110+
RunHooks = RunHooksBase[TContext, Agent]
111+
"""Run hooks when using `Agent`."""
112+
113+
AgentHooks = AgentHooksBase[TContext, Agent]
114+
"""Agent hooks for `Agent`s."""

src/agents/mcp/server.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from .util import ToolFilter, ToolFilterCallable, ToolFilterContext, ToolFilterStatic
2323

2424
if TYPE_CHECKING:
25-
from ..agent import Agent
25+
from ..agent import AgentBase
2626

2727

2828
class MCPServer(abc.ABC):
@@ -53,7 +53,7 @@ async def cleanup(self):
5353
async def list_tools(
5454
self,
5555
run_context: RunContextWrapper[Any] | None = None,
56-
agent: Agent[Any] | None = None,
56+
agent: AgentBase | None = None,
5757
) -> list[MCPTool]:
5858
"""List the tools available on the server."""
5959
pass
@@ -117,7 +117,7 @@ async def _apply_tool_filter(
117117
self,
118118
tools: list[MCPTool],
119119
run_context: RunContextWrapper[Any],
120-
agent: Agent[Any],
120+
agent: AgentBase,
121121
) -> list[MCPTool]:
122122
"""Apply the tool filter to the list of tools."""
123123
if self.tool_filter is None:
@@ -153,7 +153,7 @@ async def _apply_dynamic_tool_filter(
153153
self,
154154
tools: list[MCPTool],
155155
run_context: RunContextWrapper[Any],
156-
agent: Agent[Any],
156+
agent: AgentBase,
157157
) -> list[MCPTool]:
158158
"""Apply dynamic tool filtering using a callable filter function."""
159159

@@ -244,7 +244,7 @@ async def connect(self):
244244
async def list_tools(
245245
self,
246246
run_context: RunContextWrapper[Any] | None = None,
247-
agent: Agent[Any] | None = None,
247+
agent: AgentBase | None = None,
248248
) -> list[MCPTool]:
249249
"""List the tools available on the server."""
250250
if not self.session:

src/agents/mcp/util.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,19 @@
55

66
from typing_extensions import NotRequired, TypedDict
77

8-
from agents.strict_schema import ensure_strict_json_schema
9-
108
from .. import _debug
119
from ..exceptions import AgentsException, ModelBehaviorError, UserError
1210
from ..logger import logger
1311
from ..run_context import RunContextWrapper
12+
from ..strict_schema import ensure_strict_json_schema
1413
from ..tool import FunctionTool, Tool
1514
from ..tracing import FunctionSpanData, get_current_span, mcp_tools_span
1615
from ..util._types import MaybeAwaitable
1716

1817
if TYPE_CHECKING:
1918
from mcp.types import Tool as MCPTool
2019

21-
from ..agent import Agent
20+
from ..agent import AgentBase
2221
from .server import MCPServer
2322

2423

@@ -29,7 +28,7 @@ class ToolFilterContext:
2928
run_context: RunContextWrapper[Any]
3029
"""The current run context."""
3130

32-
agent: "Agent[Any]"
31+
agent: "AgentBase"
3332
"""The agent that is requesting the tool list."""
3433

3534
server_name: str
@@ -100,7 +99,7 @@ async def get_all_function_tools(
10099
servers: list["MCPServer"],
101100
convert_schemas_to_strict: bool,
102101
run_context: RunContextWrapper[Any],
103-
agent: "Agent[Any]",
102+
agent: "AgentBase",
104103
) -> list[Tool]:
105104
"""Get all function tools from a list of MCP servers."""
106105
tools = []
@@ -126,7 +125,7 @@ async def get_function_tools(
126125
server: "MCPServer",
127126
convert_schemas_to_strict: bool,
128127
run_context: RunContextWrapper[Any],
129-
agent: "Agent[Any]",
128+
agent: "AgentBase",
130129
) -> list[Tool]:
131130
"""Get all function tools from a single MCP server."""
132131

src/agents/tool.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from .util._types import MaybeAwaitable
3131

3232
if TYPE_CHECKING:
33-
from .agent import Agent
33+
from .agent import Agent, AgentBase
3434

3535
ToolParams = ParamSpec("ToolParams")
3636

@@ -87,7 +87,7 @@ class FunctionTool:
8787
"""Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
8888
as it increases the likelihood of correct JSON input."""
8989

90-
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True
90+
is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True
9191
"""Whether the tool is enabled. Either a bool or a Callable that takes the run context and agent
9292
and returns whether the tool is enabled. You can use this to dynamically enable/disable a tool
9393
based on your context/state."""
@@ -300,7 +300,7 @@ def function_tool(
300300
use_docstring_info: bool = True,
301301
failure_error_function: ToolErrorFunction | None = None,
302302
strict_mode: bool = True,
303-
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
303+
is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True,
304304
) -> FunctionTool:
305305
"""Overload for usage as @function_tool (no parentheses)."""
306306
...
@@ -315,7 +315,7 @@ def function_tool(
315315
use_docstring_info: bool = True,
316316
failure_error_function: ToolErrorFunction | None = None,
317317
strict_mode: bool = True,
318-
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
318+
is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True,
319319
) -> Callable[[ToolFunction[...]], FunctionTool]:
320320
"""Overload for usage as @function_tool(...)."""
321321
...
@@ -330,7 +330,7 @@ def function_tool(
330330
use_docstring_info: bool = True,
331331
failure_error_function: ToolErrorFunction | None = default_tool_error_function,
332332
strict_mode: bool = True,
333-
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
333+
is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True,
334334
) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]:
335335
"""
336336
Decorator to create a FunctionTool from a function. By default, we will:

0 commit comments

Comments
 (0)