diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 5f89155092..b8ab970696 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -16,6 +16,7 @@ from typing_extensions import TypeVar, assert_never from pydantic_ai._function_schema import _takes_ctx as is_takes_ctx # type: ignore +from pydantic_ai._instrumentation import DEFAULT_INSTRUMENTATION_VERSION from pydantic_ai._tool_manager import ToolManager from pydantic_ai._utils import dataclasses_no_defaults_repr, get_union_args, is_async_callable, run_in_executor from pydantic_ai.builtin_tools import AbstractBuiltinTool @@ -704,6 +705,9 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT tracer=ctx.deps.tracer, trace_include_content=ctx.deps.instrumentation_settings is not None and ctx.deps.instrumentation_settings.include_content, + instrumentation_version=ctx.deps.instrumentation_settings.version + if ctx.deps.instrumentation_settings + else DEFAULT_INSTRUMENTATION_VERSION, run_step=ctx.state.run_step, tool_call_approved=ctx.state.run_step == 0, ) diff --git a/pydantic_ai_slim/pydantic_ai/_instrumentation.py b/pydantic_ai_slim/pydantic_ai/_instrumentation.py new file mode 100644 index 0000000000..cba1eb8139 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/_instrumentation.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing_extensions import Self + +DEFAULT_INSTRUMENTATION_VERSION = 2 +"""Default instrumentation version for `InstrumentationSettings`.""" + + +@dataclass(frozen=True) +class InstrumentationNames: + """Configuration for instrumentation span names and attributes based on version.""" + + # Agent run span configuration + agent_run_span_name: str + agent_name_attr: str + + # Tool execution span configuration + tool_span_name: str + tool_arguments_attr: str + tool_result_attr: str + + # Output Tool execution span configuration + output_tool_span_name: str + + @classmethod + def for_version(cls, version: int) -> Self: + """Create instrumentation configuration for a specific version. + + Args: + version: The instrumentation version (1, 2, or 3+) + + Returns: + InstrumentationNames instance with version-appropriate settings + """ + if version <= 2: + return cls( + agent_run_span_name='agent run', + agent_name_attr='agent_name', + tool_span_name='running tool', + tool_arguments_attr='tool_arguments', + tool_result_attr='tool_response', + output_tool_span_name='running output function', + ) + else: + return cls( + agent_run_span_name='invoke_agent', + agent_name_attr='gen_ai.agent.name', + tool_span_name='execute_tool', # Will be formatted with tool name + tool_arguments_attr='gen_ai.tool.call.arguments', + tool_result_attr='gen_ai.tool.call.result', + output_tool_span_name='execute_tool', + ) + + def get_agent_run_span_name(self, agent_name: str) -> str: + """Get the formatted agent span name. + + Args: + agent_name: Name of the agent being executed + + Returns: + Formatted span name + """ + if self.agent_run_span_name == 'invoke_agent': + return f'invoke_agent {agent_name}' + return self.agent_run_span_name + + def get_tool_span_name(self, tool_name: str) -> str: + """Get the formatted tool span name. 
+ + Args: + tool_name: Name of the tool being executed + + Returns: + Formatted span name + """ + if self.tool_span_name == 'execute_tool': + return f'execute_tool {tool_name}' + return self.tool_span_name + + def get_output_tool_span_name(self, tool_name: str) -> str: + """Get the formatted output tool span name. + + Args: + tool_name: Name of the tool being executed + + Returns: + Formatted span name + """ + if self.output_tool_span_name == 'execute_tool': + return f'execute_tool {tool_name}' + return self.output_tool_span_name diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index c516b92644..b70a0780f9 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -11,6 +11,8 @@ from pydantic_core import SchemaValidator, to_json from typing_extensions import Self, TypedDict, TypeVar, assert_never +from pydantic_ai._instrumentation import InstrumentationNames + from . import _function_schema, _utils, messages as _messages from ._run_context import AgentDepsT, RunContext from .exceptions import ModelRetry, ToolRetryError, UserError @@ -95,6 +97,7 @@ async def execute_traced_output_function( ToolRetryError: When wrap_validation_errors is True and a ModelRetry is caught ModelRetry: When wrap_validation_errors is False and a ModelRetry occurs """ + instrumentation_names = InstrumentationNames.for_version(run_context.instrumentation_version) # Set up span attributes tool_name = run_context.tool_name or getattr(function_schema.function, '__name__', 'output_function') attributes = { @@ -104,18 +107,29 @@ async def execute_traced_output_function( if run_context.tool_call_id: attributes['gen_ai.tool.call.id'] = run_context.tool_call_id if run_context.trace_include_content: - attributes['tool_arguments'] = to_json(args).decode() - attributes['logfire.json_schema'] = json.dumps( - { - 'type': 'object', - 'properties': { - 'tool_arguments': {'type': 'object'}, - 'tool_response': {'type': 'object'}, - }, - } - ) + attributes[instrumentation_names.tool_arguments_attr] = to_json(args).decode() + + attributes['logfire.json_schema'] = json.dumps( + { + 'type': 'object', + 'properties': { + **( + { + instrumentation_names.tool_arguments_attr: {'type': 'object'}, + instrumentation_names.tool_result_attr: {'type': 'object'}, + } + if run_context.trace_include_content + else {} + ), + 'gen_ai.tool.name': {}, + **({'gen_ai.tool.call.id': {}} if run_context.tool_call_id else {}), + }, + } + ) - with run_context.tracer.start_as_current_span('running output function', attributes=attributes) as span: + with run_context.tracer.start_as_current_span( + instrumentation_names.get_output_tool_span_name(tool_name), attributes=attributes + ) as span: try: output = await function_schema.call(args, run_context) except ModelRetry as r: @@ -135,7 +149,7 @@ async def execute_traced_output_function( from .models.instrumented import InstrumentedModel span.set_attribute( - 'tool_response', + instrumentation_names.tool_result_attr, output if isinstance(output, str) else json.dumps(InstrumentedModel.serialize_any(output)), ) diff --git a/pydantic_ai_slim/pydantic_ai/_run_context.py b/pydantic_ai_slim/pydantic_ai/_run_context.py index 428697b59a..df2a4c1b5a 100644 --- a/pydantic_ai_slim/pydantic_ai/_run_context.py +++ b/pydantic_ai_slim/pydantic_ai/_run_context.py @@ -8,6 +8,8 @@ from opentelemetry.trace import NoOpTracer, Tracer from typing_extensions import TypeVar +from pydantic_ai._instrumentation import DEFAULT_INSTRUMENTATION_VERSION + from 
. import _utils, messages as _messages if TYPE_CHECKING: @@ -36,6 +38,8 @@ class RunContext(Generic[AgentDepsT]): """The tracer to use for tracing the run.""" trace_include_content: bool = False """Whether to include the content of the messages in the trace.""" + instrumentation_version: int = DEFAULT_INSTRUMENTATION_VERSION + """Instrumentation settings version, if instrumentation is enabled.""" retries: dict[str, int] = field(default_factory=dict) """Number of retries for each tool so far.""" tool_call_id: str | None = None diff --git a/pydantic_ai_slim/pydantic_ai/_tool_manager.py b/pydantic_ai_slim/pydantic_ai/_tool_manager.py index 5e27c2584e..5cf66b00dd 100644 --- a/pydantic_ai_slim/pydantic_ai/_tool_manager.py +++ b/pydantic_ai_slim/pydantic_ai/_tool_manager.py @@ -12,6 +12,7 @@ from typing_extensions import assert_never from . import messages as _messages +from ._instrumentation import InstrumentationNames from ._run_context import AgentDepsT, RunContext from .exceptions import ModelRetry, ToolRetryError, UnexpectedModelBehavior from .messages import ToolCallPart @@ -115,6 +116,7 @@ async def handle_call( wrap_validation_errors, self.ctx.tracer, self.ctx.trace_include_content, + self.ctx.instrumentation_version, usage_limits, ) @@ -203,15 +205,18 @@ async def _call_tool_traced( allow_partial: bool, wrap_validation_errors: bool, tracer: Tracer, - include_content: bool = False, + include_content: bool, + instrumentation_version: int, usage_limits: UsageLimits | None = None, ) -> Any: """See .""" + instrumentation_names = InstrumentationNames.for_version(instrumentation_version) + span_attributes = { 'gen_ai.tool.name': call.tool_name, # NOTE: this means `gen_ai.tool.call.id` will be included even if it was generated by pydantic-ai 'gen_ai.tool.call.id': call.tool_call_id, - **({'tool_arguments': call.args_as_json_str()} if include_content else {}), + **({instrumentation_names.tool_arguments_attr: call.args_as_json_str()} if include_content else {}), 'logfire.msg': f'running tool: {call.tool_name}', # add the JSON schema so these attributes are formatted nicely in Logfire 'logfire.json_schema': json.dumps( @@ -220,8 +225,8 @@ async def _call_tool_traced( 'properties': { **( { - 'tool_arguments': {'type': 'object'}, - 'tool_response': {'type': 'object'}, + instrumentation_names.tool_arguments_attr: {'type': 'object'}, + instrumentation_names.tool_result_attr: {'type': 'object'}, } if include_content else {} @@ -232,18 +237,21 @@ async def _call_tool_traced( } ), } - with tracer.start_as_current_span('running tool', attributes=span_attributes) as span: + with tracer.start_as_current_span( + instrumentation_names.get_tool_span_name(call.tool_name), + attributes=span_attributes, + ) as span: try: tool_result = await self._call_tool(call, allow_partial, wrap_validation_errors, usage_limits) except ToolRetryError as e: part = e.tool_retry if include_content and span.is_recording(): - span.set_attribute('tool_response', part.model_response()) + span.set_attribute(instrumentation_names.tool_result_attr, part.model_response()) raise e if include_content and span.is_recording(): span.set_attribute( - 'tool_response', + instrumentation_names.tool_result_attr, tool_result if isinstance(tool_result, str) else _messages.tool_return_ta.dump_json(tool_result).decode(), diff --git a/pydantic_ai_slim/pydantic_ai/agent/__init__.py b/pydantic_ai_slim/pydantic_ai/agent/__init__.py index 7a411c556a..72c256e9c4 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/__init__.py +++ 
b/pydantic_ai_slim/pydantic_ai/agent/__init__.py @@ -14,6 +14,7 @@ from pydantic.json_schema import GenerateJsonSchema from typing_extensions import Self, TypeVar, deprecated +from pydantic_ai._instrumentation import DEFAULT_INSTRUMENTATION_VERSION, InstrumentationNames from pydantic_graph import Graph from .. import ( @@ -65,7 +66,7 @@ from ..toolsets.combined import CombinedToolset from ..toolsets.function import FunctionToolset from ..toolsets.prepared import PreparedToolset -from .abstract import AbstractAgent, EventStreamHandler, RunOutputDataT +from .abstract import AbstractAgent, EventStreamHandler, Instructions, RunOutputDataT from .wrapper import WrapperAgent if TYPE_CHECKING: @@ -136,8 +137,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]): _deps_type: type[AgentDepsT] = dataclasses.field(repr=False) _output_schema: _output.BaseOutputSchema[OutputDataT] = dataclasses.field(repr=False) _output_validators: list[_output.OutputValidator[AgentDepsT, OutputDataT]] = dataclasses.field(repr=False) - _instructions: str | None = dataclasses.field(repr=False) - _instructions_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(repr=False) + _instructions: list[str | _system_prompt.SystemPromptFunc[AgentDepsT]] = dataclasses.field(repr=False) _system_prompts: tuple[str, ...] = dataclasses.field(repr=False) _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(repr=False) _system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field( @@ -163,10 +163,7 @@ def __init__( model: models.Model | models.KnownModelName | str | None = None, *, output_type: OutputSpec[OutputDataT] = str, - instructions: str - | _system_prompt.SystemPromptFunc[AgentDepsT] - | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]] - | None = None, + instructions: Instructions[AgentDepsT] = None, system_prompt: str | Sequence[str] = (), deps_type: type[AgentDepsT] = NoneType, name: str | None = None, @@ -192,10 +189,7 @@ def __init__( model: models.Model | models.KnownModelName | str | None = None, *, output_type: OutputSpec[OutputDataT] = str, - instructions: str - | _system_prompt.SystemPromptFunc[AgentDepsT] - | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]] - | None = None, + instructions: Instructions[AgentDepsT] = None, system_prompt: str | Sequence[str] = (), deps_type: type[AgentDepsT] = NoneType, name: str | None = None, @@ -219,10 +213,7 @@ def __init__( model: models.Model | models.KnownModelName | str | None = None, *, output_type: OutputSpec[OutputDataT] = str, - instructions: str - | _system_prompt.SystemPromptFunc[AgentDepsT] - | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]] - | None = None, + instructions: Instructions[AgentDepsT] = None, system_prompt: str | Sequence[str] = (), deps_type: type[AgentDepsT] = NoneType, name: str | None = None, @@ -321,16 +312,7 @@ def __init__( self._output_schema = _output.OutputSchema[OutputDataT].build(output_type, default_mode=default_output_mode) self._output_validators = [] - self._instructions = '' - self._instructions_functions = [] - if isinstance(instructions, str | Callable): - instructions = [instructions] - for instruction in instructions or []: - if isinstance(instruction, str): - self._instructions += instruction + '\n' - else: - self._instructions_functions.append(_system_prompt.SystemPromptRunner(instruction)) - self._instructions = self._instructions.strip() or None + self._instructions = 
self._normalize_instructions(instructions) self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt) self._system_prompt_functions = [] @@ -370,6 +352,9 @@ def __init__( self._override_tools: ContextVar[ _utils.Option[Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]]] ] = ContextVar('_override_tools', default=None) + self._override_instructions: ContextVar[ + _utils.Option[list[str | _system_prompt.SystemPromptFunc[AgentDepsT]]] + ] = ContextVar('_override_instructions', default=None) self._enter_lock = Lock() self._entered_count = 0 @@ -592,10 +577,12 @@ async def main(): model_settings = merge_model_settings(merged_settings, model_settings) usage_limits = usage_limits or _usage.UsageLimits() + instructions_literal, instructions_functions = self._get_instructions() + async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: parts = [ - self._instructions, - *[await func.run(run_context) for func in self._instructions_functions], + instructions_literal, + *[await func.run(run_context) for func in instructions_functions], ] model_profile = model_used.profile @@ -633,22 +620,28 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: get_instructions=get_instructions, instrumentation_settings=instrumentation_settings, ) + start_node = _agent_graph.UserPromptNode[AgentDepsT]( user_prompt=user_prompt, deferred_tool_results=deferred_tool_results, - instructions=self._instructions, - instructions_functions=self._instructions_functions, + instructions=instructions_literal, + instructions_functions=instructions_functions, system_prompts=self._system_prompts, system_prompt_functions=self._system_prompt_functions, system_prompt_dynamic_functions=self._system_prompt_dynamic_functions, ) agent_name = self.name or 'agent' + instrumentation_names = InstrumentationNames.for_version( + instrumentation_settings.version if instrumentation_settings else DEFAULT_INSTRUMENTATION_VERSION + ) + run_span = tracer.start_span( - 'agent run', + instrumentation_names.get_agent_run_span_name(agent_name), attributes={ 'model_name': model_used.model_name if model_used else 'no-model', 'agent_name': agent_name, + 'gen_ai.agent.name': agent_name, 'logfire.msg': f'{agent_name} run', }, ) @@ -684,6 +677,8 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: def _run_span_end_attributes( self, state: _agent_graph.GraphAgentState, usage: _usage.RunUsage, settings: InstrumentationSettings ): + literal_instructions, _ = self._get_instructions() + if settings.version == 1: attrs = { 'all_messages_events': json.dumps( @@ -696,7 +691,7 @@ def _run_span_end_attributes( else: attrs = { 'pydantic_ai.all_messages': json.dumps(settings.messages_to_otel_messages(state.message_history)), - **settings.system_instructions_attributes(self._instructions), + **settings.system_instructions_attributes(literal_instructions), } return { @@ -721,8 +716,9 @@ def override( model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET, toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET, + instructions: Instructions[AgentDepsT] | _utils.Unset = _utils.UNSET, ) -> Iterator[None]: - """Context manager to temporarily override agent dependencies, model, toolsets, or tools. + """Context manager to temporarily override agent dependencies, model, toolsets, tools, or instructions. 
This is particularly useful when testing. You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures). @@ -732,6 +728,7 @@ def override( model: The model to use instead of the model passed to the agent run. toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run. tools: The tools to use instead of the tools registered with the agent. + instructions: The instructions to use instead of the instructions registered with the agent. """ if _utils.is_set(deps): deps_token = self._override_deps.set(_utils.Some(deps)) @@ -753,6 +750,12 @@ def override( else: tools_token = None + if _utils.is_set(instructions): + normalized_instructions = self._normalize_instructions(instructions) + instructions_token = self._override_instructions.set(_utils.Some(normalized_instructions)) + else: + instructions_token = None + try: yield finally: @@ -764,6 +767,8 @@ def override( self._override_toolsets.reset(toolsets_token) if tools_token is not None: self._override_tools.reset(tools_token) + if instructions_token is not None: + self._override_instructions.reset(instructions_token) @overload def instructions( @@ -824,12 +829,12 @@ async def async_instructions(ctx: RunContext[str]) -> str: def decorator( func_: _system_prompt.SystemPromptFunc[AgentDepsT], ) -> _system_prompt.SystemPromptFunc[AgentDepsT]: - self._instructions_functions.append(_system_prompt.SystemPromptRunner(func_)) + self._instructions.append(func_) return func_ return decorator else: - self._instructions_functions.append(_system_prompt.SystemPromptRunner(func)) + self._instructions.append(func) return func @overload @@ -1270,6 +1275,34 @@ def _get_deps(self: Agent[T, OutputDataT], deps: T) -> T: else: return deps + def _normalize_instructions( + self, + instructions: Instructions[AgentDepsT], + ) -> list[str | _system_prompt.SystemPromptFunc[AgentDepsT]]: + if instructions is None: + return [] + if isinstance(instructions, str) or callable(instructions): + return [instructions] + return list(instructions) + + def _get_instructions( + self, + ) -> tuple[str | None, list[_system_prompt.SystemPromptRunner[AgentDepsT]]]: + override_instructions = self._override_instructions.get() + instructions = override_instructions.value if override_instructions else self._instructions + + literal_parts: list[str] = [] + functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = [] + + for instruction in instructions: + if isinstance(instruction, str): + literal_parts.append(instruction) + else: + functions.append(_system_prompt.SystemPromptRunner[AgentDepsT](instruction)) + + literal = '\n'.join(literal_parts).strip() or None + return literal, functions + def _get_toolset( self, output_toolset: AbstractToolset[AgentDepsT] | None | _utils.Unset = _utils.UNSET, diff --git a/pydantic_ai_slim/pydantic_ai/agent/abstract.py b/pydantic_ai_slim/pydantic_ai/agent/abstract.py index 8d6c9ff293..fdd21b8065 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/abstract.py +++ b/pydantic_ai_slim/pydantic_ai/agent/abstract.py @@ -14,6 +14,7 @@ from .. 
import ( _agent_graph, + _system_prompt, _utils, exceptions, messages as _messages, @@ -60,6 +61,14 @@ """A function that receives agent [`RunContext`][pydantic_ai.tools.RunContext] and an async iterable of events from the model's streaming response and the agent's execution of tools.""" +Instructions = ( + str + | _system_prompt.SystemPromptFunc[AgentDepsT] + | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]] + | None +) + + class AbstractAgent(Generic[AgentDepsT, OutputDataT], ABC): """Abstract superclass for [`Agent`][pydantic_ai.agent.Agent], [`WrapperAgent`][pydantic_ai.agent.WrapperAgent], and your own custom agent implementations.""" @@ -681,8 +690,9 @@ def override( model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET, toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET, + instructions: Instructions[AgentDepsT] | _utils.Unset = _utils.UNSET, ) -> Iterator[None]: - """Context manager to temporarily override agent dependencies, model, toolsets, or tools. + """Context manager to temporarily override agent dependencies, model, toolsets, tools, or instructions. This is particularly useful when testing. You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures). @@ -692,6 +702,7 @@ def override( model: The model to use instead of the model passed to the agent run. toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run. tools: The tools to use instead of the tools registered with the agent. + instructions: The instructions to use instead of the instructions registered with the agent. """ raise NotImplementedError yield diff --git a/pydantic_ai_slim/pydantic_ai/agent/wrapper.py b/pydantic_ai_slim/pydantic_ai/agent/wrapper.py index 36f7969323..ba735f0907 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/wrapper.py +++ b/pydantic_ai_slim/pydantic_ai/agent/wrapper.py @@ -20,7 +20,7 @@ ToolFuncEither, ) from ..toolsets import AbstractToolset -from .abstract import AbstractAgent, EventStreamHandler, RunOutputDataT +from .abstract import AbstractAgent, EventStreamHandler, Instructions, RunOutputDataT class WrapperAgent(AbstractAgent[AgentDepsT, OutputDataT]): @@ -214,8 +214,9 @@ def override( model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET, toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET, + instructions: Instructions[AgentDepsT] | _utils.Unset = _utils.UNSET, ) -> Iterator[None]: - """Context manager to temporarily override agent dependencies, model, toolsets, or tools. + """Context manager to temporarily override agent dependencies, model, toolsets, tools, or instructions. This is particularly useful when testing. You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures). @@ -225,6 +226,13 @@ def override( model: The model to use instead of the model passed to the agent run. toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run. tools: The tools to use instead of the tools registered with the agent. + instructions: The instructions to use instead of the instructions registered with the agent. 
""" - with self.wrapped.override(deps=deps, model=model, toolsets=toolsets, tools=tools): + with self.wrapped.override( + deps=deps, + model=model, + toolsets=toolsets, + tools=tools, + instructions=instructions, + ): yield diff --git a/pydantic_ai_slim/pydantic_ai/direct.py b/pydantic_ai_slim/pydantic_ai/direct.py index 1383f58c66..38c92f71eb 100644 --- a/pydantic_ai_slim/pydantic_ai/direct.py +++ b/pydantic_ai_slim/pydantic_ai/direct.py @@ -81,7 +81,7 @@ async def main(): return await model_instance.request( messages, model_settings, - model_instance.customize_request_parameters(model_request_parameters or models.ModelRequestParameters()), + model_request_parameters or models.ModelRequestParameters(), ) @@ -193,7 +193,7 @@ async def main(): return model_instance.request_stream( messages, model_settings, - model_instance.customize_request_parameters(model_request_parameters or models.ModelRequestParameters()), + model_request_parameters or models.ModelRequestParameters(), ) diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_agent.py b/pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_agent.py index 10b3184fd4..d7d4987d8f 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_agent.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_agent.py @@ -15,6 +15,7 @@ usage as _usage, ) from pydantic_ai.agent import AbstractAgent, AgentRun, AgentRunResult, EventStreamHandler, RunOutputDataT, WrapperAgent +from pydantic_ai.agent.abstract import Instructions from pydantic_ai.exceptions import UserError from pydantic_ai.models import Model from pydantic_ai.output import OutputDataT, OutputSpec @@ -704,8 +705,9 @@ def override( model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET, toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET, + instructions: Instructions[AgentDepsT] | _utils.Unset = _utils.UNSET, ) -> Iterator[None]: - """Context manager to temporarily override agent dependencies, model, toolsets, or tools. + """Context manager to temporarily override agent dependencies, model, toolsets, tools, or instructions. This is particularly useful when testing. You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures). @@ -715,11 +717,18 @@ def override( model: The model to use instead of the model passed to the agent run. toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run. tools: The tools to use instead of the tools registered with the agent. + instructions: The instructions to use instead of the instructions registered with the agent. """ if _utils.is_set(model) and not isinstance(model, (DBOSModel)): raise UserError( 'Non-DBOS model cannot be contextually overridden inside a DBOS workflow, it must be set at agent creation time.' 
) - with super().override(deps=deps, model=model, toolsets=toolsets, tools=tools): + with super().override( + deps=deps, + model=model, + toolsets=toolsets, + tools=tools, + instructions=instructions, + ): yield diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py index 4f2161ebbf..a87b5195a9 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py @@ -22,7 +22,8 @@ models, usage as _usage, ) -from pydantic_ai.agent import AbstractAgent, AgentRun, AgentRunResult, EventStreamHandler, RunOutputDataT, WrapperAgent +from pydantic_ai.agent import AbstractAgent, AgentRun, AgentRunResult, EventStreamHandler, WrapperAgent +from pydantic_ai.agent.abstract import Instructions, RunOutputDataT from pydantic_ai.exceptions import UserError from pydantic_ai.models import Model from pydantic_ai.output import OutputDataT, OutputSpec @@ -748,8 +749,9 @@ def override( model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET, toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET, + instructions: Instructions[AgentDepsT] | _utils.Unset = _utils.UNSET, ) -> Iterator[None]: - """Context manager to temporarily override agent dependencies, model, toolsets, or tools. + """Context manager to temporarily override agent dependencies, model, toolsets, tools, or instructions. This is particularly useful when testing. You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures). @@ -759,6 +761,7 @@ def override( model: The model to use instead of the model passed to the agent run. toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run. tools: The tools to use instead of the tools registered with the agent. + instructions: The instructions to use instead of the instructions registered with the agent. """ if workflow.in_workflow(): if _utils.is_set(model): @@ -774,5 +777,11 @@ def override( 'Tools cannot be contextually overridden inside a Temporal workflow, they must be set at agent creation time.' ) - with super().override(deps=deps, model=model, toolsets=toolsets, tools=tools): + with super().override( + deps=deps, + model=model, + toolsets=toolsets, + tools=tools, + instructions=instructions, + ): yield diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index 82f1260df7..024f32f686 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -112,6 +112,7 @@ class MCPServer(AbstractToolset[Any], ABC): _client: ClientSession _read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] _write_stream: MemoryObjectSendStream[SessionMessage] + _server_info: mcp_types.Implementation def __init__( self, @@ -177,6 +178,15 @@ def label(self) -> str: def tool_name_conflict_hint(self) -> str: return 'Set the `tool_prefix` attribute to avoid name conflicts.' + @property + def server_info(self) -> mcp_types.Implementation: + """Access the information sent by the MCP server during initialization.""" + if getattr(self, '_server_info', None) is None: + raise AttributeError( + f'The `{self.__class__.__name__}.server_info` is only instantiated after initialization.' 
+ ) + return self._server_info + async def list_tools(self) -> list[mcp_types.Tool]: """Retrieve tools that are currently active on the server. @@ -312,8 +322,8 @@ async def __aenter__(self) -> Self: self._client = await exit_stack.enter_async_context(client) with anyio.fail_after(self.timeout): - await self._client.initialize() - + result = await self._client.initialize() + self._server_info = result.serverInfo if log_level := self.log_level: await self._client.set_logging_level(log_level) diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index c5a0d26724..508c88554b 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -41,7 +41,7 @@ ) from ..output import OutputMode from ..profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec -from ..settings import ModelSettings +from ..settings import ModelSettings, merge_model_settings from ..tools import ToolDefinition from ..usage import RequestUsage @@ -390,6 +390,23 @@ def customize_request_parameters(self, model_request_parameters: ModelRequestPar return model_request_parameters + def prepare_request( + self, + model_settings: ModelSettings | None, + model_request_parameters: ModelRequestParameters, + ) -> tuple[ModelSettings | None, ModelRequestParameters]: + """Prepare request inputs before they are passed to the provider. + + This merges the given ``model_settings`` with the model's own ``settings`` attribute and ensures + ``customize_request_parameters`` is applied to the resolved + [`ModelRequestParameters`][pydantic_ai.models.ModelRequestParameters]. Subclasses can override this method if + they need to customize the preparation flow further, but most implementations should simply call + ``self.prepare_request(...)`` at the start of their ``request`` (and related) methods. 
+ """ + merged_settings = merge_model_settings(self.settings, model_settings) + customized_parameters = self.customize_request_parameters(model_request_parameters) + return merged_settings, customized_parameters + @property @abstractmethod def model_name(self) -> str: diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index 26cd3af956..6c6c99210e 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -205,6 +205,10 @@ async def request( model_request_parameters: ModelRequestParameters, ) -> ModelResponse: check_allow_model_requests() + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) response = await self._messages_create( messages, False, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters ) @@ -220,6 +224,10 @@ async def request_stream( run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) response = await self._messages_create( messages, True, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters ) diff --git a/pydantic_ai_slim/pydantic_ai/models/bedrock.py b/pydantic_ai_slim/pydantic_ai/models/bedrock.py index 40da9d262a..2e4a65c6ba 100644 --- a/pydantic_ai_slim/pydantic_ai/models/bedrock.py +++ b/pydantic_ai_slim/pydantic_ai/models/bedrock.py @@ -264,6 +264,10 @@ async def request( model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> ModelResponse: + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) settings = cast(BedrockModelSettings, model_settings or {}) response = await self._messages_create(messages, False, settings, model_request_parameters) model_response = await self._process_response(response) @@ -277,6 +281,10 @@ async def request_stream( model_request_parameters: ModelRequestParameters, run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) settings = cast(BedrockModelSettings, model_settings or {}) response = await self._messages_create(messages, True, settings, model_request_parameters) yield BedrockStreamedResponse( diff --git a/pydantic_ai_slim/pydantic_ai/models/cohere.py b/pydantic_ai_slim/pydantic_ai/models/cohere.py index 18f3b5e084..1a08db4348 100644 --- a/pydantic_ai_slim/pydantic_ai/models/cohere.py +++ b/pydantic_ai_slim/pydantic_ai/models/cohere.py @@ -165,6 +165,10 @@ async def request( model_request_parameters: ModelRequestParameters, ) -> ModelResponse: check_allow_model_requests() + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) response = await self._chat(messages, cast(CohereModelSettings, model_settings or {}), model_request_parameters) model_response = self._process_response(response) return model_response diff --git a/pydantic_ai_slim/pydantic_ai/models/fallback.py b/pydantic_ai_slim/pydantic_ai/models/fallback.py index cf24f7002f..c8430f5775 100644 --- a/pydantic_ai_slim/pydantic_ai/models/fallback.py +++ b/pydantic_ai_slim/pydantic_ai/models/fallback.py @@ -11,7 +11,6 @@ from pydantic_ai.models.instrumented import InstrumentedModel 
from ..exceptions import FallbackExceptionGroup, ModelHTTPError -from ..settings import merge_model_settings from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model if TYPE_CHECKING: @@ -78,10 +77,8 @@ async def request( exceptions: list[Exception] = [] for model in self.models: - customized_model_request_parameters = model.customize_request_parameters(model_request_parameters) - merged_settings = merge_model_settings(model.settings, model_settings) try: - response = await model.request(messages, merged_settings, customized_model_request_parameters) + response = await model.request(messages, model_settings, model_request_parameters) except Exception as exc: if self._fallback_on(exc): exceptions.append(exc) @@ -105,14 +102,10 @@ async def request_stream( exceptions: list[Exception] = [] for model in self.models: - customized_model_request_parameters = model.customize_request_parameters(model_request_parameters) - merged_settings = merge_model_settings(model.settings, model_settings) async with AsyncExitStack() as stack: try: response = await stack.enter_async_context( - model.request_stream( - messages, merged_settings, customized_model_request_parameters, run_context - ) + model.request_stream(messages, model_settings, model_request_parameters, run_context) ) except Exception as exc: if self._fallback_on(exc): diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index e9cc927735..c0c0d7d62f 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -125,6 +125,10 @@ async def request( model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> ModelResponse: + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) agent_info = AgentInfo( function_tools=model_request_parameters.function_tools, allow_text_output=model_request_parameters.allow_text_output, @@ -154,6 +158,10 @@ async def request_stream( model_request_parameters: ModelRequestParameters, run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) agent_info = AgentInfo( function_tools=model_request_parameters.function_tools, allow_text_output=model_request_parameters.allow_text_output, diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 5f17800a78..fc0a2973be 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -155,6 +155,10 @@ async def request( model_request_parameters: ModelRequestParameters, ) -> ModelResponse: check_allow_model_requests() + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) async with self._make_request( messages, False, cast(GeminiModelSettings, model_settings or {}), model_request_parameters ) as http_response: @@ -171,6 +175,10 @@ async def request_stream( run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) async with self._make_request( messages, True, cast(GeminiModelSettings, model_settings or {}), model_request_parameters ) as http_response: diff --git 
a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index 6e05675fa7..4e01ca18f8 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -225,6 +225,10 @@ async def request( model_request_parameters: ModelRequestParameters, ) -> ModelResponse: check_allow_model_requests() + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) model_settings = cast(GoogleModelSettings, model_settings or {}) response = await self._generate_content(messages, False, model_settings, model_request_parameters) return self._process_response(response) @@ -236,6 +240,10 @@ async def count_tokens( model_request_parameters: ModelRequestParameters, ) -> usage.RequestUsage: check_allow_model_requests() + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) model_settings = cast(GoogleModelSettings, model_settings or {}) contents, generation_config = await self._build_content_and_config( messages, model_settings, model_request_parameters @@ -291,6 +299,10 @@ async def request_stream( run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) model_settings = cast(GoogleModelSettings, model_settings or {}) response = await self._generate_content(messages, True, model_settings, model_request_parameters) yield await self._process_streamed_response(response, model_request_parameters) # type: ignore diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index b05913d2d6..81907d3944 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -182,6 +182,10 @@ async def request( model_request_parameters: ModelRequestParameters, ) -> ModelResponse: check_allow_model_requests() + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) try: response = await self._completions_create( messages, False, cast(GroqModelSettings, model_settings or {}), model_request_parameters @@ -218,6 +222,10 @@ async def request_stream( run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) response = await self._completions_create( messages, True, cast(GroqModelSettings, model_settings or {}), model_request_parameters ) diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py index 23473fb7d6..6fbbfd3dc8 100644 --- a/pydantic_ai_slim/pydantic_ai/models/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py @@ -166,6 +166,10 @@ async def request( model_request_parameters: ModelRequestParameters, ) -> ModelResponse: check_allow_model_requests() + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) response = await self._completions_create( messages, False, cast(HuggingFaceModelSettings, model_settings or {}), model_request_parameters ) @@ -181,6 +185,10 @@ async def request_stream( run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() + 
model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) response = await self._completions_create( messages, True, cast(HuggingFaceModelSettings, model_settings or {}), model_request_parameters ) @@ -377,8 +385,6 @@ def _map_tool_definition(f: ToolDefinition) -> ChatCompletionInputTool: }, } ) - if f.strict is not None: - tool_param['function']['strict'] = f.strict return tool_param async def _map_user_message( diff --git a/pydantic_ai_slim/pydantic_ai/models/instrumented.py b/pydantic_ai_slim/pydantic_ai/models/instrumented.py index 57dc235da6..84cd23ba80 100644 --- a/pydantic_ai_slim/pydantic_ai/models/instrumented.py +++ b/pydantic_ai_slim/pydantic_ai/models/instrumented.py @@ -21,6 +21,8 @@ from opentelemetry.util.types import AttributeValue from pydantic import TypeAdapter +from pydantic_ai._instrumentation import DEFAULT_INSTRUMENTATION_VERSION + from .. import _otel_messages from .._run_context import RunContext from ..messages import ( @@ -90,7 +92,7 @@ class InstrumentationSettings: event_mode: Literal['attributes', 'logs'] = 'attributes' include_binary_content: bool = True include_content: bool = True - version: Literal[1, 2] = 1 + version: Literal[1, 2, 3] = DEFAULT_INSTRUMENTATION_VERSION def __init__( self, @@ -99,7 +101,7 @@ def __init__( meter_provider: MeterProvider | None = None, include_binary_content: bool = True, include_content: bool = True, - version: Literal[1, 2] = 2, + version: Literal[1, 2, 3] = DEFAULT_INSTRUMENTATION_VERSION, event_mode: Literal['attributes', 'logs'] = 'attributes', event_logger_provider: EventLoggerProvider | None = None, ): @@ -352,8 +354,12 @@ async def request( model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> ModelResponse: - with self._instrument(messages, model_settings, model_request_parameters) as finish: - response = await super().request(messages, model_settings, model_request_parameters) + prepared_settings, prepared_parameters = self.wrapped.prepare_request( + model_settings, + model_request_parameters, + ) + with self._instrument(messages, prepared_settings, prepared_parameters) as finish: + response = await self.wrapped.request(messages, model_settings, model_request_parameters) finish(response) return response @@ -365,10 +371,14 @@ async def request_stream( model_request_parameters: ModelRequestParameters, run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: - with self._instrument(messages, model_settings, model_request_parameters) as finish: + prepared_settings, prepared_parameters = self.wrapped.prepare_request( + model_settings, + model_request_parameters, + ) + with self._instrument(messages, prepared_settings, prepared_parameters) as finish: response_stream: StreamedResponse | None = None try: - async with super().request_stream( + async with self.wrapped.request_stream( messages, model_settings, model_request_parameters, run_context ) as response_stream: yield response_stream diff --git a/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py b/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py index d1f3ade1e0..7b54ba0f6d 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py +++ b/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py @@ -52,6 +52,8 @@ async def request( model_request_parameters: ModelRequestParameters, ) -> ModelResponse: system_prompt, sampling_messages = _mcp.map_from_pai_messages(messages) + + model_settings, _ = self.prepare_request(model_settings, 
model_request_parameters) model_settings = cast(MCPSamplingModelSettings, model_settings or {}) result = await self.session.create_message( diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index 0c749f3c60..dec09c0b1c 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -185,6 +185,10 @@ async def request( ) -> ModelResponse: """Make a non-streaming request to the model from Pydantic AI call.""" check_allow_model_requests() + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) response = await self._completions_create( messages, cast(MistralModelSettings, model_settings or {}), model_request_parameters ) @@ -201,6 +205,10 @@ async def request_stream( ) -> AsyncIterator[StreamedResponse]: """Make a streaming request to the model from Pydantic AI call.""" check_allow_model_requests() + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) response = await self._stream_completions_create( messages, cast(MistralModelSettings, model_settings or {}), model_request_parameters ) diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 843716fcef..3ad09cb579 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -393,6 +393,10 @@ async def request( model_request_parameters: ModelRequestParameters, ) -> ModelResponse: check_allow_model_requests() + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) response = await self._completions_create( messages, False, cast(OpenAIChatModelSettings, model_settings or {}), model_request_parameters ) @@ -408,6 +412,10 @@ async def request_stream( run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) response = await self._completions_create( messages, True, cast(OpenAIChatModelSettings, model_settings or {}), model_request_parameters ) @@ -926,6 +934,10 @@ async def request( model_request_parameters: ModelRequestParameters, ) -> ModelResponse: check_allow_model_requests() + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) response = await self._responses_create( messages, False, cast(OpenAIResponsesModelSettings, model_settings or {}), model_request_parameters ) @@ -940,6 +952,10 @@ async def request_stream( run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: check_allow_model_requests() + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) response = await self._responses_create( messages, True, cast(OpenAIResponsesModelSettings, model_settings or {}), model_request_parameters ) diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py index bb374c78a3..7a01e5324c 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -110,6 +110,10 @@ async def request( model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> ModelResponse: + model_settings, model_request_parameters = 
self.prepare_request( + model_settings, + model_request_parameters, + ) self.last_model_request_parameters = model_request_parameters model_response = self._request(messages, model_settings, model_request_parameters) model_response.usage = _estimate_usage([*messages, model_response]) @@ -123,6 +127,10 @@ async def request_stream( model_request_parameters: ModelRequestParameters, run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) self.last_model_request_parameters = model_request_parameters model_response = self._request(messages, model_settings, model_request_parameters) diff --git a/pydantic_ai_slim/pydantic_ai/models/wrapper.py b/pydantic_ai_slim/pydantic_ai/models/wrapper.py index 4c91991cc1..3260cc7d65 100644 --- a/pydantic_ai_slim/pydantic_ai/models/wrapper.py +++ b/pydantic_ai_slim/pydantic_ai/models/wrapper.py @@ -46,6 +46,13 @@ async def request_stream( def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters: return self.wrapped.customize_request_parameters(model_request_parameters) + def prepare_request( + self, + model_settings: ModelSettings | None, + model_request_parameters: ModelRequestParameters, + ) -> tuple[ModelSettings | None, ModelRequestParameters]: + return self.wrapped.prepare_request(model_settings, model_request_parameters) + @property def model_name(self) -> str: return self.wrapped.model_name diff --git a/pydantic_evals/pydantic_evals/dataset.py b/pydantic_evals/pydantic_evals/dataset.py index 98c703b048..ecc7697fdf 100644 --- a/pydantic_evals/pydantic_evals/dataset.py +++ b/pydantic_evals/pydantic_evals/dataset.py @@ -98,6 +98,7 @@ class _DatasetModel(BaseModel, Generic[InputsT, OutputT, MetadataT], extra='forb # $schema is included to avoid validation fails from the `$schema` key, see `_add_json_schema` below for context json_schema_path: str | None = Field(default=None, alias='$schema') + name: str | None = None cases: list[_CaseModel[InputsT, OutputT, MetadataT]] evaluators: list[EvaluatorSpec] = Field(default_factory=list) @@ -218,6 +219,8 @@ async def main(): ``` """ + name: str | None = None + """Optional name of the dataset.""" cases: list[Case[InputsT, OutputT, MetadataT]] """List of test cases in the dataset.""" evaluators: list[Evaluator[InputsT, OutputT, MetadataT]] = [] @@ -226,12 +229,14 @@ async def main(): def __init__( self, *, + name: str | None = None, cases: Sequence[Case[InputsT, OutputT, MetadataT]], evaluators: Sequence[Evaluator[InputsT, OutputT, MetadataT]] = (), ): """Initialize a new dataset with test cases and optional evaluators. Args: + name: Optional name for the dataset. cases: Sequence of test cases to include in the dataset. evaluators: Optional sequence of evaluators to apply to all cases in the dataset. """ @@ -244,10 +249,12 @@ def __init__( case_names.add(case.name) super().__init__( + name=name, cases=cases, evaluators=list(evaluators), ) + # TODO in v2: Make everything not required keyword-only async def evaluate( self, task: Callable[[InputsT], Awaitable[OutputT]] | Callable[[InputsT], OutputT], @@ -256,6 +263,8 @@ async def evaluate( progress: bool = True, retry_task: RetryConfig | None = None, retry_evaluators: RetryConfig | None = None, + *, + task_name: str | None = None, ) -> EvaluationReport[InputsT, OutputT, MetadataT]: """Evaluates the test cases in the dataset using the given task. 
@@ -265,28 +274,38 @@ async def evaluate( Args: task: The task to evaluate. This should be a callable that takes the inputs of the case and returns the output. - name: The name of the task being evaluated, this is used to identify the task in the report. - If omitted, the name of the task function will be used. + name: The name of the experiment being run, this is used to identify the experiment in the report. + If omitted, the task_name will be used; if that is not specified, the name of the task function is used. max_concurrency: The maximum number of concurrent evaluations of the task to allow. If None, all cases will be evaluated concurrently. progress: Whether to show a progress bar for the evaluation. Defaults to `True`. retry_task: Optional retry configuration for the task execution. retry_evaluators: Optional retry configuration for evaluator execution. + task_name: Optional override to the name of the task being executed, otherwise the name of the task + function will be used. Returns: A report containing the results of the evaluation. """ - name = name or get_unwrapped_function_name(task) + task_name = task_name or get_unwrapped_function_name(task) + name = name or task_name total_cases = len(self.cases) progress_bar = Progress() if progress else None limiter = anyio.Semaphore(max_concurrency) if max_concurrency is not None else AsyncExitStack() with ( - logfire_span('evaluate {name}', name=name, n_cases=len(self.cases)) as eval_span, + logfire_span( + 'evaluate {name}', + name=name, + task_name=task_name, + dataset_name=self.name, + n_cases=len(self.cases), + **{'gen_ai.operation.name': 'experiment'}, # pyright: ignore[reportArgumentType] + ) as eval_span, progress_bar or nullcontext(), ): - task_id = progress_bar.add_task(f'Evaluating {name}', total=total_cases) if progress_bar else None + task_id = progress_bar.add_task(f'Evaluating {task_name}', total=total_cases) if progress_bar else None async def _handle_case(case: Case[InputsT, OutputT, MetadataT], report_case_name: str): async with limiter: @@ -357,7 +376,7 @@ def evaluate_sync( return get_event_loop().run_until_complete( self.evaluate( task, - name=name, + task_name=name, max_concurrency=max_concurrency, progress=progress, retry_task=retry_task, @@ -474,7 +493,7 @@ def from_file( raw = Path(path).read_text() try: - return cls.from_text(raw, fmt=fmt, custom_evaluator_types=custom_evaluator_types) + return cls.from_text(raw, fmt=fmt, custom_evaluator_types=custom_evaluator_types, default_name=path.stem) except ValidationError as e: # pragma: no cover raise ValueError(f'{path} contains data that does not match the schema for {cls.__name__}:\n{e}.') from e @@ -484,6 +503,8 @@ def from_text( contents: str, fmt: Literal['yaml', 'json'] = 'yaml', custom_evaluator_types: Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]] = (), + *, + default_name: str | None = None, ) -> Self: """Load a dataset from a string. @@ -492,6 +513,7 @@ def from_text( fmt: Format of the content. Must be either 'yaml' or 'json'. custom_evaluator_types: Custom evaluator classes to use when deserializing the dataset. These are additional evaluators beyond the default ones. + default_name: Default name of the dataset, to be used if not specified in the serialized contents. Returns: A new Dataset instance parsed from the string. 
@@ -501,17 +523,19 @@ def from_text( """ if fmt == 'yaml': loaded = yaml.safe_load(contents) - return cls.from_dict(loaded, custom_evaluator_types) + return cls.from_dict(loaded, custom_evaluator_types, default_name=default_name) else: dataset_model_type = cls._serialization_type() dataset_model = dataset_model_type.model_validate_json(contents) - return cls._from_dataset_model(dataset_model, custom_evaluator_types) + return cls._from_dataset_model(dataset_model, custom_evaluator_types, default_name) @classmethod def from_dict( cls, data: dict[str, Any], custom_evaluator_types: Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]] = (), + *, + default_name: str | None = None, ) -> Self: """Load a dataset from a dictionary. @@ -519,6 +543,7 @@ def from_dict( data: Dictionary representation of the dataset. custom_evaluator_types: Custom evaluator classes to use when deserializing the dataset. These are additional evaluators beyond the default ones. + default_name: Default name of the dataset, to be used if not specified in the data. Returns: A new Dataset instance created from the dictionary. @@ -528,19 +553,21 @@ def from_dict( """ dataset_model_type = cls._serialization_type() dataset_model = dataset_model_type.model_validate(data) - return cls._from_dataset_model(dataset_model, custom_evaluator_types) + return cls._from_dataset_model(dataset_model, custom_evaluator_types, default_name) @classmethod def _from_dataset_model( cls, dataset_model: _DatasetModel[InputsT, OutputT, MetadataT], custom_evaluator_types: Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]] = (), + default_name: str | None = None, ) -> Self: """Create a Dataset from a _DatasetModel. Args: dataset_model: The _DatasetModel to convert. custom_evaluator_types: Custom evaluator classes to register for deserialization. + default_name: Default name of the dataset, to be used if the value is `None` in the provided model. Returns: A new Dataset instance created from the _DatasetModel. @@ -577,7 +604,9 @@ def _from_dataset_model( cases.append(row) if errors: raise ExceptionGroup(f'{len(errors)} error(s) loading evaluators from registry', errors[:3]) - result = cls(cases=cases) + result = cls(name=dataset_model.name, cases=cases) + if result.name is None: + result.name = default_name result.evaluators = dataset_evaluators return result diff --git a/tests/evals/test_dataset.py b/tests/evals/test_dataset.py index d94b8f9825..450bcacfd8 100644 --- a/tests/evals/test_dataset.py +++ b/tests/evals/test_dataset.py @@ -7,6 +7,7 @@ from typing import Any import pytest +import yaml from dirty_equals import HasRepr, IsNumber from inline_snapshot import snapshot from pydantic import BaseModel, TypeAdapter @@ -106,7 +107,7 @@ def example_cases() -> list[Case[TaskInput, TaskOutput, TaskMetadata]]: def example_dataset( example_cases: list[Case[TaskInput, TaskOutput, TaskMetadata]], ) -> Dataset[TaskInput, TaskOutput, TaskMetadata]: - return Dataset[TaskInput, TaskOutput, TaskMetadata](cases=example_cases) + return Dataset[TaskInput, TaskOutput, TaskMetadata](name='example', cases=example_cases) @pytest.fixture @@ -820,10 +821,29 @@ async def test_serialization_to_yaml(example_dataset: Dataset[TaskInput, TaskOut # Test loading back loaded_dataset = Dataset[TaskInput, TaskOutput, TaskMetadata].from_file(yaml_path) assert len(loaded_dataset.cases) == 2 + assert loaded_dataset.name == 'example' assert loaded_dataset.cases[0].name == 'case1' assert loaded_dataset.cases[0].inputs.query == 'What is 2+2?' 
+async def test_deserializing_without_name( + example_dataset: Dataset[TaskInput, TaskOutput, TaskMetadata], tmp_path: Path +): + """Test serializing a dataset to YAML.""" + # Save the dataset + yaml_path = tmp_path / 'test_cases.yaml' + example_dataset.to_file(yaml_path) + + # Rewrite the file _without_ a name to test deserializing a name-less file + obj = yaml.safe_load(yaml_path.read_text()) + obj.pop('name', None) + yaml_path.write_text(yaml.dump(obj)) + + # Test loading results in the name coming from the filename stem + loaded_dataset = Dataset[TaskInput, TaskOutput, TaskMetadata].from_file(yaml_path) + assert loaded_dataset.name == 'test_cases' + + async def test_serialization_to_json(example_dataset: Dataset[TaskInput, TaskOutput, TaskMetadata], tmp_path: Path): """Test serializing a dataset to JSON.""" json_path = tmp_path / 'test_cases.json' @@ -855,6 +875,7 @@ def test_serialization_errors(tmp_path: Path): async def test_from_text(): """Test creating a dataset from text.""" dataset_dict = { + 'name': 'my dataset', 'cases': [ { 'name': '1', @@ -874,6 +895,7 @@ async def test_from_text(): } loaded_dataset = Dataset[TaskInput, TaskOutput, TaskMetadata].from_text(json.dumps(dataset_dict)) + assert loaded_dataset.name == 'my dataset' assert loaded_dataset.cases == snapshot( [ Case( @@ -1241,7 +1263,7 @@ async def test_dataset_evaluate_with_custom_name(example_dataset: Dataset[TaskIn async def task(inputs: TaskInput) -> TaskOutput: return TaskOutput(answer=inputs.query.upper()) - report = await example_dataset.evaluate(task, name='custom_task') + report = await example_dataset.evaluate(task, task_name='custom_task') assert report.name == 'custom_task' @@ -1491,16 +1513,26 @@ async def mock_async_task(inputs: TaskInput) -> TaskOutput: ( 'evaluate {name}', { - 'name': 'mock_async_task', - 'n_cases': 2, 'assertion_pass_rate': 1.0, - 'logfire.msg_template': 'evaluate {name}', - 'logfire.msg': 'evaluate mock_async_task', - 'logfire.span_type': 'span', + 'dataset_name': 'example', + 'gen_ai.operation.name': 'experiment', 'logfire.json_schema': { + 'properties': { + 'assertion_pass_rate': {}, + 'dataset_name': {}, + 'gen_ai.operation.name': {}, + 'n_cases': {}, + 'name': {}, + 'task_name': {}, + }, 'type': 'object', - 'properties': {'name': {}, 'n_cases': {}, 'assertion_pass_rate': {}}, }, + 'logfire.msg': 'evaluate mock_async_task', + 'logfire.msg_template': 'evaluate {name}', + 'logfire.span_type': 'span', + 'n_cases': 2, + 'name': 'mock_async_task', + 'task_name': 'mock_async_task', }, ), ( diff --git a/tests/models/test_fallback.py b/tests/models/test_fallback.py index 1220d3f0fd..673e67e105 100644 --- a/tests/models/test_fallback.py +++ b/tests/models/test_fallback.py @@ -169,6 +169,7 @@ def test_first_failed_instrumented(capfire: CaptureLogfire) -> None: 'attributes': { 'model_name': 'fallback:function:failure_response:,function:success_response:', 'agent_name': 'agent', + 'gen_ai.agent.name': 'agent', 'logfire.msg': 'agent run', 'logfire.span_type': 'span', 'gen_ai.usage.input_tokens': 51, @@ -268,6 +269,7 @@ async def test_first_failed_instrumented_stream(capfire: CaptureLogfire) -> None 'attributes': { 'model_name': 'fallback:function::failure_response_stream,function::success_response_stream', 'agent_name': 'agent', + 'gen_ai.agent.name': 'agent', 'logfire.msg': 'agent run', 'logfire.span_type': 'span', 'gen_ai.usage.input_tokens': 50, @@ -375,6 +377,7 @@ def test_all_failed_instrumented(capfire: CaptureLogfire) -> None: 'attributes': { 'model_name': 
'fallback:function:failure_response:,function:failure_response:', 'agent_name': 'agent', + 'gen_ai.agent.name': 'agent', 'logfire.msg': 'agent run', 'logfire.span_type': 'span', 'pydantic_ai.all_messages': [{'role': 'user', 'parts': [{'type': 'text', 'content': 'hello'}]}], diff --git a/tests/models/test_huggingface.py b/tests/models/test_huggingface.py index c872c6154b..7c9493fed7 100644 --- a/tests/models/test_huggingface.py +++ b/tests/models/test_huggingface.py @@ -893,51 +893,6 @@ async def test_thinking_part_in_history(allow_model_requests: None): ) -@pytest.mark.parametrize('strict', [True, False, None]) -async def test_tool_strict_mode(allow_model_requests: None, strict: bool | None): - tool_call = ChatCompletionOutputToolCall.parse_obj_as_instance( # type:ignore - { - 'function': ChatCompletionOutputFunctionDefinition.parse_obj_as_instance( # type:ignore - { - 'name': 'my_tool', - 'arguments': '{"x": 42}', - } - ), - 'id': '1', - 'type': 'function', - } - ) - responses = [ - completion_message( - ChatCompletionOutputMessage.parse_obj_as_instance( # type:ignore - { - 'content': None, - 'role': 'assistant', - 'tool_calls': [tool_call], - } - ) - ), - completion_message(ChatCompletionOutputMessage(content='final response', role='assistant')), # type: ignore - ] - mock_client = MockHuggingFace.create_mock(responses) - model = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) - agent = Agent(model) - - @agent.tool_plain(strict=strict) - def my_tool(x: int) -> int: - return x - - result = await agent.run('hello') - assert result.output == 'final response' - - kwargs = get_mock_chat_completion_kwargs(mock_client)[0] - tools = kwargs['tools'] - if strict is not None: - assert tools[0]['function']['strict'] is strict - else: - assert 'strict' not in tools[0]['function'] - - @pytest.mark.parametrize( 'content_item, error_message', [ diff --git a/tests/models/test_model_settings.py b/tests/models/test_model_settings.py index 4db59e741d..3bcbedbafd 100644 --- a/tests/models/test_model_settings.py +++ b/tests/models/test_model_settings.py @@ -2,7 +2,12 @@ from __future__ import annotations -from pydantic_ai import Agent, ModelMessage, ModelResponse, TextPart +import asyncio + +from pydantic_ai import Agent +from pydantic_ai.direct import model_request as direct_model_request +from pydantic_ai.messages import ModelMessage, ModelRequest, ModelResponse, TextPart +from pydantic_ai.models import ModelRequestParameters from pydantic_ai.models.function import AgentInfo, FunctionModel from pydantic_ai.models.instrumented import InstrumentedModel from pydantic_ai.models.test import TestModel @@ -161,3 +166,39 @@ def capture_settings(messages: list[ModelMessage], agent_info: AgentInfo) -> Mod assert captured_settings is not None assert captured_settings.get('temperature') == 0.75 assert len(captured_settings) == 1 # Only one setting should be present + + +def test_direct_model_request_merges_model_settings(): + """Ensure direct requests merge model defaults with provided run settings.""" + + captured_settings = None + + async def capture(messages: list[ModelMessage], agent_info: AgentInfo) -> ModelResponse: + nonlocal captured_settings + captured_settings = agent_info.model_settings + return ModelResponse(parts=[TextPart('ok')]) + + model = FunctionModel( + capture, + settings=ModelSettings(max_tokens=50, temperature=0.3), + ) + + messages: list[ModelMessage] = [ModelRequest.user_text_prompt('hi')] + run_settings = ModelSettings(temperature=0.9, top_p=0.2) + 
+ async def _run() -> ModelResponse: + return await direct_model_request( + model, + messages, + model_settings=run_settings, + model_request_parameters=ModelRequestParameters(), + ) + + response = asyncio.run(_run()) + + assert response.parts == [TextPart('ok')] + assert captured_settings == { + 'max_tokens': 50, + 'temperature': 0.9, + 'top_p': 0.2, + } diff --git a/tests/test_agent.py b/tests/test_agent.py index fcffd2b846..f11ba9aadb 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -1,3 +1,4 @@ +import asyncio import json import re import sys @@ -5148,3 +5149,231 @@ def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: ), ] ) + + +def test_override_instructions_basic(): + """Test that override can override instructions.""" + agent = Agent('test') + + @agent.instructions + def instr_fn() -> str: + return 'SHOULD_BE_IGNORED' + + with capture_run_messages() as base_messages: + agent.run_sync('Hello', model=TestModel(custom_output_text='baseline')) + + base_req = base_messages[0] + assert isinstance(base_req, ModelRequest) + assert base_req.instructions == 'SHOULD_BE_IGNORED' + + with agent.override(instructions='OVERRIDE'): + with capture_run_messages() as messages: + agent.run_sync('Hello', model=TestModel(custom_output_text='ok')) + + req = messages[0] + assert isinstance(req, ModelRequest) + assert req.instructions == 'OVERRIDE' + + +def test_override_reset_after_context(): + """Test that instructions are reset after exiting the override context.""" + agent = Agent('test', instructions='ORIG') + + with agent.override(instructions='NEW'): + with capture_run_messages() as messages_new: + agent.run_sync('Hi', model=TestModel(custom_output_text='ok')) + + with capture_run_messages() as messages_orig: + agent.run_sync('Hi', model=TestModel(custom_output_text='ok')) + + req_new = messages_new[0] + assert isinstance(req_new, ModelRequest) + req_orig = messages_orig[0] + assert isinstance(req_orig, ModelRequest) + assert req_new.instructions == 'NEW' + assert req_orig.instructions == 'ORIG' + + +def test_override_none_clears_instructions(): + """Test that passing None for instructions clears all instructions.""" + agent = Agent('test', instructions='BASE') + + @agent.instructions + def instr_fn() -> str: # pragma: no cover - ignored under override + return 'ALSO_BASE' + + with agent.override(instructions=None): + with capture_run_messages() as messages: + agent.run_sync('Hello', model=TestModel(custom_output_text='ok')) + + req = messages[0] + assert isinstance(req, ModelRequest) + assert req.instructions is None + + +def test_override_instructions_callable_replaces_functions(): + """Override with a callable should replace existing instruction functions.""" + agent = Agent('test') + + @agent.instructions + def base_fn() -> str: + return 'BASE_FN' + + def override_fn() -> str: + return 'OVERRIDE_FN' + + with capture_run_messages() as base_messages: + agent.run_sync('Hello', model=TestModel(custom_output_text='baseline')) + + base_req = base_messages[0] + assert isinstance(base_req, ModelRequest) + assert base_req.instructions is not None + assert 'BASE_FN' in base_req.instructions + + with agent.override(instructions=override_fn): + with capture_run_messages() as messages: + agent.run_sync('Hello', model=TestModel(custom_output_text='ok')) + + req = messages[0] + assert isinstance(req, ModelRequest) + assert req.instructions == 'OVERRIDE_FN' + assert 'BASE_FN' not in req.instructions + + +async def test_override_instructions_async_callable(): + """Override with 
an async callable should be awaited.""" + agent = Agent('test') + + async def override_fn() -> str: + await asyncio.sleep(0) + return 'ASYNC_FN' + + with agent.override(instructions=override_fn): + with capture_run_messages() as messages: + await agent.run('Hi', model=TestModel(custom_output_text='ok')) + + req = messages[0] + assert isinstance(req, ModelRequest) + assert req.instructions == 'ASYNC_FN' + + +def test_override_instructions_sequence_mixed_types(): + """Override can mix literal strings and functions.""" + agent = Agent('test', instructions='BASE') + + def override_fn() -> str: + return 'FUNC_PART' + + def override_fn_2() -> str: + return 'FUNC_PART_2' + + with agent.override(instructions=['OVERRIDE1', override_fn, 'OVERRIDE2', override_fn_2]): + with capture_run_messages() as messages: + agent.run_sync('Hello', model=TestModel(custom_output_text='ok')) + + req = messages[0] + assert isinstance(req, ModelRequest) + assert req.instructions == 'OVERRIDE1\nOVERRIDE2\n\nFUNC_PART\n\nFUNC_PART_2' + assert 'BASE' not in req.instructions + + +async def test_override_concurrent_isolation(): + """Test that concurrent overrides are isolated from each other.""" + agent = Agent('test', instructions='ORIG') + + async def run_with(instr: str) -> str | None: + with agent.override(instructions=instr): + with capture_run_messages() as messages: + await agent.run('Hi', model=TestModel(custom_output_text='ok')) + req = messages[0] + assert isinstance(req, ModelRequest) + return req.instructions + + a, b = await asyncio.gather( + run_with('A'), + run_with('B'), + ) + + assert a == 'A' + assert b == 'B' + + +def test_override_replaces_instructions(): + """Test overriding instructions replaces the base instructions.""" + agent = Agent('test', instructions='ORIG_INSTR') + + with agent.override(instructions='NEW_INSTR'): + with capture_run_messages() as messages: + agent.run_sync('Hi', model=TestModel(custom_output_text='ok')) + + req = messages[0] + assert isinstance(req, ModelRequest) + assert req.instructions == 'NEW_INSTR' + + +def test_override_nested_contexts(): + """Test nested override contexts.""" + agent = Agent('test', instructions='ORIG') + + with agent.override(instructions='OUTER'): + with capture_run_messages() as outer_messages: + agent.run_sync('Hi', model=TestModel(custom_output_text='ok')) + + with agent.override(instructions='INNER'): + with capture_run_messages() as inner_messages: + agent.run_sync('Hi', model=TestModel(custom_output_text='ok')) + + outer_req = outer_messages[0] + assert isinstance(outer_req, ModelRequest) + inner_req = inner_messages[0] + assert isinstance(inner_req, ModelRequest) + + assert outer_req.instructions == 'OUTER' + assert inner_req.instructions == 'INNER' + + +async def test_override_async_run(): + """Test override with async run method.""" + agent = Agent('test', instructions='ORIG') + + with agent.override(instructions='ASYNC_OVERRIDE'): + with capture_run_messages() as messages: + await agent.run('Hi', model=TestModel(custom_output_text='ok')) + + req = messages[0] + assert isinstance(req, ModelRequest) + assert req.instructions == 'ASYNC_OVERRIDE' + + +def test_override_with_dynamic_prompts(): + """Test override interacting with dynamic prompts.""" + agent = Agent('test') + + dynamic_value = 'DYNAMIC' + + @agent.system_prompt + def dynamic_sys() -> str: + return dynamic_value + + @agent.instructions + def dynamic_instr() -> str: + return 'DYNAMIC_INSTR' + + with capture_run_messages() as base_messages: + agent.run_sync('Hi', 
model=TestModel(custom_output_text='baseline')) + + base_req = base_messages[0] + assert isinstance(base_req, ModelRequest) + assert base_req.instructions == 'DYNAMIC_INSTR' + + # Override should take precedence over dynamic instructions but leave system prompts intact + with agent.override(instructions='OVERRIDE_INSTR'): + with capture_run_messages() as messages: + agent.run_sync('Hi', model=TestModel(custom_output_text='ok')) + + req = messages[0] + assert isinstance(req, ModelRequest) + assert req.instructions == 'OVERRIDE_INSTR' + sys_texts = [p.content for p in req.parts if isinstance(p, SystemPromptPart)] + # The dynamic system prompt should still be present since overrides target instructions only + assert dynamic_value in sys_texts diff --git a/tests/test_direct.py b/tests/test_direct.py index b1149e7239..a26c18c0b4 100644 --- a/tests/test_direct.py +++ b/tests/test_direct.py @@ -7,8 +7,16 @@ import pytest from inline_snapshot import snapshot -from pydantic_ai import ( - Agent, +from pydantic_ai import Agent +from pydantic_ai.direct import ( + StreamedResponseSync, + _prepare_model, # pyright: ignore[reportPrivateUsage] + model_request, + model_request_stream, + model_request_stream_sync, + model_request_sync, +) +from pydantic_ai.messages import ( FinalResultEvent, ModelMessage, ModelRequest, @@ -19,14 +27,6 @@ TextPartDelta, ToolCallPart, ) -from pydantic_ai.direct import ( - StreamedResponseSync, - _prepare_model, # pyright: ignore[reportPrivateUsage] - model_request, - model_request_stream, - model_request_stream_sync, - model_request_sync, -) from pydantic_ai.models import ModelRequestParameters from pydantic_ai.models.instrumented import InstrumentedModel from pydantic_ai.models.test import TestModel diff --git a/tests/test_logfire.py b/tests/test_logfire.py index d40a1492f1..091687f5a1 100644 --- a/tests/test_logfire.py +++ b/tests/test_logfire.py @@ -2,7 +2,7 @@ from collections.abc import Callable from dataclasses import dataclass -from typing import Any +from typing import Any, Literal import pytest from dirty_equals import IsInt, IsJson, IsList @@ -31,6 +31,7 @@ class SpanSummary(TypedDict): id: int + name: str message: str children: NotRequired[list[SpanSummary]] @@ -49,7 +50,9 @@ def __init__(self, capfire: CaptureLogfire): id_counter = 0 for span in spans: tid = span['context']['trace_id'], span['context']['span_id'] - span_lookup[tid] = span_summary = SpanSummary(id=id_counter, message=span['attributes']['logfire.msg']) + span_lookup[tid] = span_summary = SpanSummary( + id=id_counter, name=span['name'], message=span['attributes']['logfire.msg'] + ) self.attributes[id_counter] = span['attributes'] id_counter += 1 if parent := span['parent']: @@ -76,6 +79,7 @@ def get_summary() -> LogfireSummary: InstrumentationSettings(version=1, event_mode='attributes'), InstrumentationSettings(version=1, event_mode='logs'), InstrumentationSettings(version=2), + InstrumentationSettings(version=3), ], ) def test_logfire( @@ -97,31 +101,101 @@ async def my_ret(x: int) -> str: assert summary.traces == [] return - assert summary.traces == snapshot( - [ - { - 'id': 0, - 'message': 'my_agent run', - 'children': [ - {'id': 1, 'message': 'chat test'}, + if isinstance(instrument, InstrumentationSettings) and instrument.version == 3: + assert summary.traces == snapshot( + [ + { + 'id': 0, + 'name': 'invoke_agent my_agent', + 'message': 'my_agent run', + 'children': [ + {'id': 1, 'name': 'chat test', 'message': 'chat test'}, + { + 'id': 2, + 'name': 'running tools', + 'message': 'running 1 
tool', + 'children': [ + {'id': 3, 'name': 'execute_tool my_ret', 'message': 'running tool: my_ret'}, + ], + }, + {'id': 4, 'name': 'chat test', 'message': 'chat test'}, + ], + } + ] + ) + else: + assert summary.traces == snapshot( + [ + { + 'id': 0, + 'name': 'agent run', + 'message': 'my_agent run', + 'children': [ + {'id': 1, 'name': 'chat test', 'message': 'chat test'}, + { + 'id': 2, + 'name': 'running tools', + 'message': 'running 1 tool', + 'children': [ + {'id': 3, 'name': 'running tool', 'message': 'running tool: my_ret'}, + ], + }, + {'id': 4, 'name': 'chat test', 'message': 'chat test'}, + ], + } + ] + ) + + if instrument is True or (isinstance(instrument, InstrumentationSettings) and instrument.version in (2, 3)): + if instrument is True or isinstance(instrument, InstrumentationSettings) and instrument.version == 2: + # default instrumentation settings + assert summary.traces == snapshot( + [ { - 'id': 2, - 'message': 'running 1 tool', + 'id': 0, + 'name': 'agent run', + 'message': 'my_agent run', 'children': [ - {'id': 3, 'message': 'running tool: my_ret'}, + {'id': 1, 'name': 'chat test', 'message': 'chat test'}, + { + 'id': 2, + 'name': 'running tools', + 'message': 'running 1 tool', + 'children': [{'id': 3, 'name': 'running tool', 'message': 'running tool: my_ret'}], + }, + {'id': 4, 'name': 'chat test', 'message': 'chat test'}, ], - }, - {'id': 4, 'message': 'chat test'}, - ], - } - ] - ) + } + ] + ) + else: + assert summary.traces == snapshot( + [ + { + 'id': 0, + 'name': 'invoke_agent my_agent', + 'message': 'my_agent run', + 'children': [ + {'id': 1, 'name': 'chat test', 'message': 'chat test'}, + { + 'id': 2, + 'name': 'running tools', + 'message': 'running 1 tool', + 'children': [ + {'id': 3, 'name': 'execute_tool my_ret', 'message': 'running tool: my_ret'} + ], + }, + {'id': 4, 'name': 'chat test', 'message': 'chat test'}, + ], + } + ] + ) - if instrument is True or isinstance(instrument, InstrumentationSettings) and instrument.version == 2: assert summary.attributes[0] == snapshot( { 'model_name': 'test', 'agent_name': 'my_agent', + 'gen_ai.agent.name': 'my_agent', 'logfire.msg': 'my_agent run', 'logfire.span_type': 'span', 'final_result': '{"my_ret":"1"}', @@ -175,6 +249,7 @@ async def my_ret(x: int) -> str: { 'model_name': 'test', 'agent_name': 'my_agent', + 'gen_ai.agent.name': 'my_agent', 'logfire.msg': 'my_agent run', 'logfire.span_type': 'span', 'gen_ai.usage.input_tokens': 103, @@ -410,10 +485,7 @@ async def my_ret(x: int) -> str: @pytest.mark.skipif(not logfire_installed, reason='logfire not installed') @pytest.mark.parametrize( 'instrument', - [ - InstrumentationSettings(version=1), - InstrumentationSettings(version=2), - ], + [InstrumentationSettings(version=1), InstrumentationSettings(version=2), InstrumentationSettings(version=3)], ) def test_instructions_with_structured_output( get_logfire_summary: Callable[[], LogfireSummary], instrument: InstrumentationSettings @@ -434,6 +506,7 @@ class MyOutput: { 'model_name': 'test', 'agent_name': 'my_agent', + 'gen_ai.agent.name': 'my_agent', 'logfire.msg': 'my_agent run', 'logfire.span_type': 'span', 'gen_ai.usage.input_tokens': 51, @@ -525,12 +598,37 @@ class MyOutput: ) ) else: + if instrument.version == 2: + assert summary.traces == snapshot( + [ + { + 'id': 0, + 'name': 'agent run', + 'message': 'my_agent run', + 'children': [{'id': 1, 'name': 'chat test', 'message': 'chat test'}], + } + ] + ) + else: + assert summary.traces == snapshot( + [ + { + 'id': 0, + 'name': 'invoke_agent my_agent', + 'message': 
'my_agent run', + 'children': [{'id': 1, 'name': 'chat test', 'message': 'chat test'}], + } + ] + ) + assert summary.attributes[0] == snapshot( { 'model_name': 'test', 'agent_name': 'my_agent', + 'gen_ai.agent.name': 'my_agent', 'logfire.msg': 'my_agent run', 'logfire.span_type': 'span', + 'final_result': '{"content": "a"}', 'gen_ai.usage.input_tokens': 51, 'gen_ai.usage.output_tokens': 5, 'pydantic_ai.all_messages': IsJson( @@ -562,7 +660,6 @@ class MyOutput: ] ) ), - 'final_result': '{"content": "a"}', 'gen_ai.system_instructions': '[{"type": "text", "content": "Here are some instructions"}]', 'logfire.json_schema': IsJson( snapshot( @@ -619,6 +716,7 @@ class MyOutput: { 'model_name': 'test', 'agent_name': 'my_agent', + 'gen_ai.agent.name': 'my_agent', 'logfire.msg': 'my_agent run', 'logfire.span_type': 'span', 'gen_ai.usage.input_tokens': 51, @@ -690,25 +788,53 @@ class MyOutput: @pytest.mark.skipif(not logfire_installed, reason='logfire not installed') -def test_instructions_with_structured_output_exclude_content_v2( +@pytest.mark.parametrize('version', [2, 3]) +def test_instructions_with_structured_output_exclude_content_v2_v3( get_logfire_summary: Callable[[], LogfireSummary], + version: Literal[2, 3], ) -> None: @dataclass class MyOutput: content: str - settings: InstrumentationSettings = InstrumentationSettings(include_content=False, version=2) + settings: InstrumentationSettings = InstrumentationSettings(include_content=False, version=version) my_agent = Agent(model=TestModel(), instructions='Here are some instructions', instrument=settings) result = my_agent.run_sync('Hello', output_type=MyOutput) - assert result.output == snapshot(MyOutput(content='a')) + assert result.output == MyOutput(content='a') summary = get_logfire_summary() + + if version == 2: + assert summary.traces == snapshot( + [ + { + 'id': 0, + 'name': 'agent run', + 'message': 'my_agent run', + 'children': [{'id': 1, 'name': 'chat test', 'message': 'chat test'}], + } + ] + ) + else: + assert summary.traces == snapshot( + [ + { + 'id': 0, + 'name': 'invoke_agent my_agent', + 'message': 'my_agent run', + 'children': [{'id': 1, 'name': 'chat test', 'message': 'chat test'}], + } + ] + ) + + # Version 2 and 3 have identical snapshots for this test case assert summary.attributes[0] == snapshot( { 'model_name': 'test', 'agent_name': 'my_agent', + 'gen_ai.agent.name': 'my_agent', 'logfire.msg': 'my_agent run', 'logfire.span_type': 'span', 'gen_ai.usage.input_tokens': 51, @@ -930,6 +1056,7 @@ async def test_feedback(capfire: CaptureLogfire) -> None: 'attributes': { 'model_name': 'test', 'agent_name': 'agent', + 'gen_ai.agent.name': 'agent', 'logfire.msg': 'agent run', 'logfire.span_type': 'span', 'gen_ai.usage.input_tokens': 51, @@ -1122,6 +1249,129 @@ def get_weather_info(city: str) -> WeatherInfo: return WeatherInfo(temperature=28.7, description='sunny') +@pytest.mark.skipif(not logfire_installed, reason='logfire not installed') +@pytest.mark.parametrize( + 'instrument', + [ + True, + False, + InstrumentationSettings(version=2), + InstrumentationSettings(version=3), + ], +) +def test_logfire_output_function_v2_v3( + get_logfire_summary: Callable[[], LogfireSummary], + instrument: InstrumentationSettings | bool, +) -> None: + def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + args_json = '{"city": "Mexico City"}' + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + my_agent = Agent(model=FunctionModel(call_tool), 
instrument=instrument) + result = my_agent.run_sync('Mexico City', output_type=get_weather_info) + assert result.output == WeatherInfo(temperature=28.7, description='sunny') + + summary = get_logfire_summary() + + if instrument is True or isinstance(instrument, InstrumentationSettings) and instrument.version == 2: + [output_function_attributes] = [ + attributes + for attributes in summary.attributes.values() + if attributes.get('gen_ai.tool.name') == 'final_result' + ] + assert summary.traces == snapshot( + [ + { + 'id': 0, + 'name': 'agent run', + 'message': 'my_agent run', + 'children': [ + {'id': 1, 'name': 'chat function:call_tool:', 'message': 'chat function:call_tool:'}, + { + 'id': 2, + 'name': 'running output function', + 'message': 'running output function: final_result', + }, + ], + } + ] + ) + assert output_function_attributes == snapshot( + { + 'gen_ai.tool.name': 'final_result', + 'logfire.msg': 'running output function: final_result', + 'gen_ai.tool.call.id': IsStr(), + 'tool_arguments': '{"city":"Mexico City"}', + 'logfire.json_schema': IsJson( + snapshot( + { + 'type': 'object', + 'properties': { + 'tool_arguments': {'type': 'object'}, + 'tool_response': {'type': 'object'}, + 'gen_ai.tool.name': {}, + 'gen_ai.tool.call.id': {}, + }, + } + ) + ), + 'logfire.span_type': 'span', + 'tool_response': '{"temperature": 28.7, "description": "sunny"}', + } + ) + + elif isinstance(instrument, InstrumentationSettings) and instrument.version == 3: + [output_function_attributes] = [ + attributes + for attributes in summary.attributes.values() + if attributes.get('gen_ai.tool.name') == 'final_result' + ] + assert summary.traces == snapshot( + [ + { + 'id': 0, + 'name': 'invoke_agent my_agent', + 'message': 'my_agent run', + 'children': [ + {'id': 1, 'name': 'chat function:call_tool:', 'message': 'chat function:call_tool:'}, + { + 'id': 2, + 'name': 'execute_tool final_result', + 'message': 'running output function: final_result', + }, + ], + } + ] + ) + assert output_function_attributes == snapshot( + { + 'gen_ai.tool.name': 'final_result', + 'logfire.msg': 'running output function: final_result', + 'gen_ai.tool.call.id': IsStr(), + 'gen_ai.tool.call.arguments': '{"city":"Mexico City"}', + 'logfire.json_schema': IsJson( + snapshot( + { + 'type': 'object', + 'properties': { + 'gen_ai.tool.call.arguments': {'type': 'object'}, + 'gen_ai.tool.call.result': {'type': 'object'}, + 'gen_ai.tool.name': {}, + 'gen_ai.tool.call.id': {}, + }, + } + ) + ), + 'logfire.span_type': 'span', + 'gen_ai.tool.call.result': '{"temperature": 28.7, "description": "sunny"}', + } + ) + else: + assert summary.traces == snapshot([]) + assert summary.attributes == snapshot({}) + + @pytest.mark.skipif(not logfire_installed, reason='logfire not installed') @pytest.mark.parametrize('include_content', [True, False]) def test_output_type_function_logfire_attributes( @@ -1160,6 +1410,8 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'properties': { 'tool_arguments': {'type': 'object'}, 'tool_response': {'type': 'object'}, + 'gen_ai.tool.name': {}, + 'gen_ai.tool.call.id': {}, }, } ) @@ -1174,6 +1426,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'gen_ai.tool.name': 'final_result', 'gen_ai.tool.call.id': IsStr(), 'logfire.msg': 'running output function: final_result', + 'logfire.json_schema': '{"type": "object", "properties": {"gen_ai.tool.name": {}, "gen_ai.tool.call.id": {}}}', 'logfire.span_type': 'span', } ) @@ -1221,6 +1474,8 @@ def call_tool(_: 
list[ModelMessage], info: AgentInfo) -> ModelResponse: 'properties': { 'tool_arguments': {'type': 'object'}, 'tool_response': {'type': 'object'}, + 'gen_ai.tool.name': {}, + 'gen_ai.tool.call.id': {}, }, } ) @@ -1235,6 +1490,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'gen_ai.tool.name': 'final_result', 'gen_ai.tool.call.id': IsStr(), 'logfire.msg': 'running output function: final_result', + 'logfire.json_schema': '{"type": "object", "properties": {"gen_ai.tool.name": {}, "gen_ai.tool.call.id": {}}}', 'logfire.span_type': 'span', } ) @@ -1290,6 +1546,8 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'properties': { 'tool_arguments': {'type': 'object'}, 'tool_response': {'type': 'object'}, + 'gen_ai.tool.name': {}, + 'gen_ai.tool.call.id': {}, }, } ) @@ -1309,6 +1567,8 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'properties': { 'tool_arguments': {'type': 'object'}, 'tool_response': {'type': 'object'}, + 'gen_ai.tool.name': {}, + 'gen_ai.tool.call.id': {}, }, } ) @@ -1325,6 +1585,9 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'gen_ai.tool.name': 'final_result', 'logfire.msg': 'running output function: final_result', 'gen_ai.tool.call.id': IsStr(), + 'logfire.json_schema': IsJson( + snapshot({'type': 'object', 'properties': {'gen_ai.tool.name': {}, 'gen_ai.tool.call.id': {}}}) + ), 'logfire.span_type': 'span', 'logfire.level_num': 17, }, @@ -1332,6 +1595,9 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'gen_ai.tool.name': 'final_result', 'logfire.msg': 'running output function: final_result', 'gen_ai.tool.call.id': IsStr(), + 'logfire.json_schema': IsJson( + snapshot({'type': 'object', 'properties': {'gen_ai.tool.name': {}, 'gen_ai.tool.call.id': {}}}) + ), 'logfire.span_type': 'span', }, ] @@ -1378,6 +1644,8 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'properties': { 'tool_arguments': {'type': 'object'}, 'tool_response': {'type': 'object'}, + 'gen_ai.tool.name': {}, + 'gen_ai.tool.call.id': {}, }, } ) @@ -1392,6 +1660,9 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'gen_ai.tool.name': 'get_weather', 'gen_ai.tool.call.id': IsStr(), 'logfire.msg': 'running output function: get_weather', + 'logfire.json_schema': IsJson( + snapshot({'type': 'object', 'properties': {'gen_ai.tool.name': {}, 'gen_ai.tool.call.id': {}}}) + ), 'logfire.span_type': 'span', } ) @@ -1444,6 +1715,8 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'properties': { 'tool_arguments': {'type': 'object'}, 'tool_response': {'type': 'object'}, + 'gen_ai.tool.name': {}, + 'gen_ai.tool.call.id': {}, }, } ) @@ -1458,6 +1731,9 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'gen_ai.tool.name': 'final_result', 'gen_ai.tool.call.id': IsStr(), 'logfire.msg': 'running output function: final_result', + 'logfire.json_schema': IsJson( + snapshot({'type': 'object', 'properties': {'gen_ai.tool.name': {}, 'gen_ai.tool.call.id': {}}}) + ), 'logfire.span_type': 'span', } ) @@ -1511,6 +1787,8 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'properties': { 'tool_arguments': {'type': 'object'}, 'tool_response': {'type': 'object'}, + 'gen_ai.tool.name': {}, + 'gen_ai.tool.call.id': {}, }, } ) @@ -1525,6 +1803,9 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'gen_ai.tool.name': 'final_result', 
'gen_ai.tool.call.id': IsStr(), 'logfire.msg': 'running output function: final_result', + 'logfire.json_schema': IsJson( + snapshot({'type': 'object', 'properties': {'gen_ai.tool.name': {}, 'gen_ai.tool.call.id': {}}}) + ), 'logfire.span_type': 'span', } ) @@ -1573,6 +1854,8 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'properties': { 'tool_arguments': {'type': 'object'}, 'tool_response': {'type': 'object'}, + 'gen_ai.tool.name': {}, + 'gen_ai.tool.call.id': {}, }, } ) @@ -1587,6 +1870,9 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'gen_ai.tool.name': 'final_result', 'gen_ai.tool.call.id': IsStr(), 'logfire.msg': 'running output function: final_result', + 'logfire.json_schema': IsJson( + snapshot({'type': 'object', 'properties': {'gen_ai.tool.name': {}, 'gen_ai.tool.call.id': {}}}) + ), 'logfire.span_type': 'span', } ) @@ -1639,6 +1925,7 @@ def call_text_response(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'properties': { 'tool_arguments': {'type': 'object'}, 'tool_response': {'type': 'object'}, + 'gen_ai.tool.name': {}, }, } ) @@ -1652,6 +1939,7 @@ def call_text_response(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: { 'gen_ai.tool.name': 'upcase_text', 'logfire.msg': 'running output function: upcase_text', + 'logfire.json_schema': IsJson(snapshot({'type': 'object', 'properties': {'gen_ai.tool.name': {}}})), 'logfire.span_type': 'span', } ) @@ -1709,6 +1997,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'properties': { 'tool_arguments': {'type': 'object'}, 'tool_response': {'type': 'object'}, + 'gen_ai.tool.name': {}, }, } ) @@ -1722,6 +2011,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: { 'gen_ai.tool.name': 'upcase_text', 'logfire.msg': 'running output function: upcase_text', + 'logfire.json_schema': IsJson(snapshot({'type': 'object', 'properties': {'gen_ai.tool.name': {}}})), 'logfire.span_type': 'span', } ) @@ -1779,6 +2069,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'properties': { 'tool_arguments': {'type': 'object'}, 'tool_response': {'type': 'object'}, + 'gen_ai.tool.name': {}, }, } ) @@ -1797,6 +2088,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'properties': { 'tool_arguments': {'type': 'object'}, 'tool_response': {'type': 'object'}, + 'gen_ai.tool.name': {}, }, } ) @@ -1812,12 +2104,14 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: { 'gen_ai.tool.name': 'get_weather_with_retry', 'logfire.msg': 'running output function: get_weather_with_retry', + 'logfire.json_schema': IsJson(snapshot({'type': 'object', 'properties': {'gen_ai.tool.name': {}}})), 'logfire.span_type': 'span', 'logfire.level_num': 17, }, { 'gen_ai.tool.name': 'get_weather_with_retry', 'logfire.msg': 'running output function: get_weather_with_retry', + 'logfire.json_schema': IsJson(snapshot({'type': 'object', 'properties': {'gen_ai.tool.name': {}}})), 'logfire.span_type': 'span', }, ] diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 125874c43e..6034707dcd 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -1474,3 +1474,13 @@ def test_load_mcp_servers(tmp_path: Path): with pytest.raises(FileNotFoundError): load_mcp_servers(tmp_path / 'does_not_exist.json') + + +async def test_server_info(mcp_server: MCPServerStdio) -> None: + with pytest.raises( + AttributeError, match='The `MCPServerStdio.server_info` is only instantiated after 
initialization.' + ): + mcp_server.server_info + async with mcp_server: + assert mcp_server.server_info is not None + assert mcp_server.server_info.name == 'Pydantic AI MCP Server'
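The new `test_direct_model_request_merges_model_settings` above pins down the merge order for direct requests: settings attached to the model act as defaults and per-request settings win on conflicts. A compact sketch of that behavior, reusing `FunctionModel` in the same way as the test (the `echo` capture function is illustrative):

import asyncio

from pydantic_ai.direct import model_request
from pydantic_ai.messages import ModelMessage, ModelRequest, ModelResponse, TextPart
from pydantic_ai.models import ModelRequestParameters
from pydantic_ai.models.function import AgentInfo, FunctionModel
from pydantic_ai.settings import ModelSettings


async def echo(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
    # info.model_settings carries the merged settings: max_tokens from the model, temperature from the request.
    return ModelResponse(parts=[TextPart(str(info.model_settings))])


model = FunctionModel(echo, settings=ModelSettings(max_tokens=50, temperature=0.3))
response = asyncio.run(
    model_request(
        model,
        [ModelRequest.user_text_prompt('hi')],
        model_settings=ModelSettings(temperature=0.9),
        model_request_parameters=ModelRequestParameters(),
    )
)
print(response.parts[0])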
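The `Agent.override` tests above reduce to a simple contract: inside the context manager, the supplied instructions (a string, a callable, or a sequence mixing both) replace every instruction registered on the agent, and the originals apply again once the context exits; system prompts are untouched. A short sketch mirroring the tests:

from pydantic_ai import Agent
from pydantic_ai.models.test import TestModel

agent = Agent(TestModel(), instructions='ORIG')


@agent.instructions
def extra_instructions() -> str:
    return 'ALSO_ORIG'


def override_fn() -> str:
    return 'FUNC_PART'


# Literal strings and callables can be mixed; both replace the agent's own instructions for this run.
with agent.override(instructions=['OVERRIDE', override_fn]):
    agent.run_sync('Hello')

# Outside the context the original instructions apply again.
agent.run_sync('Hello')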
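The test_logfire updates capture the observable effect of the new instrumentation version: with `version=3` the agent and tool spans are named `invoke_agent <agent>` and `execute_tool <tool>` and content is recorded under `gen_ai.tool.call.arguments` / `gen_ai.tool.call.result`, while version 2 (the default) keeps `agent run` / `running tool` but now also emits `gen_ai.agent.name`. Opting in is a one-liner; the agent and tool below are illustrative:

from pydantic_ai import Agent
from pydantic_ai.models.instrumented import InstrumentationSettings
from pydantic_ai.models.test import TestModel

agent = Agent(
    TestModel(),
    name='my_agent',
    instrument=InstrumentationSettings(version=3),  # spans: 'invoke_agent my_agent' / 'execute_tool my_ret'
)


@agent.tool_plain
async def my_ret(x: int) -> str:
    return str(x + 1)


agent.run_sync('Hello')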
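Finally, the new MCP test documents when `server_info` becomes available: accessing it before the server has been initialized raises `AttributeError`, so it should be read inside the server's async context. A brief sketch; the stdio command and module name are hypothetical:

import asyncio

from pydantic_ai.mcp import MCPServerStdio


async def main() -> None:
    server = MCPServerStdio(command='python', args=['-m', 'my_mcp_server'])  # hypothetical server command
    async with server:  # the server is initialized when the context is entered
        print(server.server_info.name)


asyncio.run(main())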