From b4e2bda632c631020b414472515ff715290c45e2 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 19 Feb 2025 15:19:10 -0500 Subject: [PATCH 01/25] state persistence --- docs/graph.md | 7 +- .../pydantic_ai_examples/question_graph.py | 4 +- pydantic_ai_slim/pydantic_ai/agent.py | 4 +- pydantic_graph/pydantic_graph/__init__.py | 8 +- pydantic_graph/pydantic_graph/_utils.py | 5 - pydantic_graph/pydantic_graph/graph.py | 124 ++++++++-------- pydantic_graph/pydantic_graph/nodes.py | 16 ++- pydantic_graph/pydantic_graph/state.py | 126 ---------------- .../pydantic_graph/state/__init__.py | 135 ++++++++++++++++++ pydantic_graph/pydantic_graph/state/_utils.py | 52 +++++++ pydantic_graph/pydantic_graph/state/memory.py | 107 ++++++++++++++ tests/graph/test_graph.py | 42 +++--- tests/graph/test_history.py | 16 +-- tests/graph/test_mermaid.py | 8 +- tests/graph/test_state.py | 8 +- tests/typed_graph.py | 4 +- 16 files changed, 417 insertions(+), 249 deletions(-) delete mode 100644 pydantic_graph/pydantic_graph/state.py create mode 100644 pydantic_graph/pydantic_graph/state/__init__.py create mode 100644 pydantic_graph/pydantic_graph/state/_utils.py create mode 100644 pydantic_graph/pydantic_graph/state/memory.py diff --git a/docs/graph.md b/docs/graph.md index fa1b873430..b62d67e890 100644 --- a/docs/graph.md +++ b/docs/graph.md @@ -576,11 +576,10 @@ In this example, an AI asks the user a question, the user provides an answer, th _(This example is complete, it can be run "as is" with Python 3.10+)_ - ```python {title="ai_q_and_a_run.py" noqa="I001" py="3.10"} from rich.prompt import Prompt -from pydantic_graph import End, HistoryStep +from pydantic_graph import End, Snapshot from ai_q_and_a_graph import Ask, question_graph, QuestionState, Answer @@ -588,14 +587,14 @@ from ai_q_and_a_graph import Ask, question_graph, QuestionState, Answer async def main(): state = QuestionState() # (1)! node = Ask() # (2)! - history: list[HistoryStep[QuestionState]] = [] # (3)! + history: list[Snapshot[QuestionState]] = [] # (3)! while True: node = await question_graph.next(node, history, state=state) # (4)! if isinstance(node, Answer): node.answer = Prompt.ask(node.question) # (5)! elif isinstance(node, End): # (6)! print(f'Correct answer! {node.data}') - #> Correct answer! Well done, 1 + 1 = 2 + # > Correct answer! Well done, 1 + 1 = 2 print([e.data_snapshot() for e in history]) """ [ diff --git a/examples/pydantic_ai_examples/question_graph.py b/examples/pydantic_ai_examples/question_graph.py index af39466890..957e5d75b1 100644 --- a/examples/pydantic_ai_examples/question_graph.py +++ b/examples/pydantic_ai_examples/question_graph.py @@ -13,7 +13,7 @@ import logfire from devtools import debug -from pydantic_graph import BaseNode, Edge, End, Graph, GraphRunContext, HistoryStep +from pydantic_graph import BaseNode, Edge, End, Graph, GraphRunContext, Snapshot from pydantic_ai import Agent from pydantic_ai.format_as_xml import format_as_xml @@ -116,7 +116,7 @@ async def run(self, ctx: GraphRunContext[QuestionState]) -> Ask: async def run_as_continuous(): state = QuestionState() node = Ask() - history: list[HistoryStep[QuestionState, None]] = [] + history: list[Snapshot[QuestionState, None]] = [] with logfire.span('run questions graph'): while True: node = await question_graph.next(node, history, state=state) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 3501833d23..3978d0138d 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -11,7 +11,7 @@ import logfire_api from typing_extensions import TypeVar, deprecated -from pydantic_graph import Graph, GraphRunContext, HistoryStep +from pydantic_graph import Graph, GraphRunContext, Snapshot from pydantic_graph.nodes import End from . import ( @@ -583,7 +583,7 @@ async def main(): # Actually run node = start_node - history: list[HistoryStep[_agent_graph.GraphAgentState, RunResultDataT]] = [] + history: list[Snapshot[_agent_graph.GraphAgentState, RunResultDataT]] = [] while True: if isinstance(node, _agent_graph.StreamModelRequestNode): node = cast( diff --git a/pydantic_graph/pydantic_graph/__init__.py b/pydantic_graph/pydantic_graph/__init__.py index d4c6074e1a..0b1ff87e9d 100644 --- a/pydantic_graph/pydantic_graph/__init__.py +++ b/pydantic_graph/pydantic_graph/__init__.py @@ -1,7 +1,7 @@ from .exceptions import GraphRuntimeError, GraphSetupError from .graph import Graph from .nodes import BaseNode, Edge, End, GraphRunContext -from .state import EndStep, HistoryStep, NodeStep +from .state import EndSnapshot, NodeSnapshot, Snapshot __all__ = ( 'Graph', @@ -9,9 +9,9 @@ 'End', 'GraphRunContext', 'Edge', - 'EndStep', - 'HistoryStep', - 'NodeStep', + 'EndSnapshot', + 'Snapshot', + 'NodeSnapshot', 'GraphSetupError', 'GraphRuntimeError', ) diff --git a/pydantic_graph/pydantic_graph/_utils.py b/pydantic_graph/pydantic_graph/_utils.py index 6753138c60..390534a8b9 100644 --- a/pydantic_graph/pydantic_graph/_utils.py +++ b/pydantic_graph/pydantic_graph/_utils.py @@ -2,7 +2,6 @@ import sys import types -from datetime import datetime, timezone from typing import Annotated, Any, TypeVar, Union, get_args, get_origin import typing_extensions @@ -80,10 +79,6 @@ def get_parent_namespace(frame: types.FrameType | None) -> dict[str, Any] | None return back.f_locals -def now_utc() -> datetime: - return datetime.now(tz=timezone.utc) - - class Unset: """A singleton to represent an unset value. diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index a670c3d399..d4ae197fe2 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -7,9 +7,7 @@ from contextlib import ExitStack from dataclasses import dataclass, field from functools import cached_property -from pathlib import Path -from time import perf_counter -from typing import TYPE_CHECKING, Annotated, Any, Callable, Generic, TypeVar +from typing import TYPE_CHECKING, Any, Generic, TypeVar import logfire_api import pydantic @@ -17,7 +15,8 @@ from . import _utils, exceptions, mermaid from .nodes import BaseNode, DepsT, End, GraphRunContext, NodeDef, RunEndT -from .state import EndStep, HistoryStep, NodeStep, StateT, deep_copy_state, nodes_schema_var +from .state import StatePersistence, StateT, node_type_adapter +from .state.memory import LatestMemoryStatePersistence # while waiting for https://github.com/pydantic/logfire/issues/745 try: @@ -84,7 +83,6 @@ async def run(self, ctx: GraphRunContext) -> Increment | End[int]: name: str | None node_defs: dict[str, NodeDef[StateT, DepsT, RunEndT]] - snapshot_state: Callable[[StateT], StateT] _state_type: type[StateT] | _utils.Unset = field(repr=False) _run_end_type: type[RunEndT] | _utils.Unset = field(repr=False) _auto_instrument: bool = field(repr=False) @@ -96,7 +94,6 @@ def __init__( name: str | None = None, state_type: type[StateT] | _utils.Unset = _utils.UNSET, run_end_type: type[RunEndT] | _utils.Unset = _utils.UNSET, - snapshot_state: Callable[[StateT], StateT] = deep_copy_state, auto_instrument: bool = True, ): """Create a graph from a sequence of nodes. @@ -108,16 +105,12 @@ def __init__( on the first call to a graph method. state_type: The type of the state for the graph, this can generally be inferred from `nodes`. run_end_type: The type of the result of running the graph, this can generally be inferred from `nodes`. - snapshot_state: A function to snapshot the state of the graph, this is used in - [`NodeStep`][pydantic_graph.state.NodeStep] and [`EndStep`][pydantic_graph.state.EndStep] to record - the state before each step. auto_instrument: Whether to create a span for the graph run and the execution of each node's run method. """ self.name = name self._state_type = state_type self._run_end_type = run_end_type self._auto_instrument = auto_instrument - self.snapshot_state = snapshot_state parent_namespace = _utils.get_parent_namespace(inspect.currentframe()) self.node_defs: dict[str, NodeDef[StateT, DepsT, RunEndT]] = {} @@ -132,8 +125,9 @@ async def run( *, state: StateT = None, deps: DepsT = None, + state_persistence: StatePersistence[StateT, T] | None = None, infer_name: bool = True, - ) -> tuple[T, list[HistoryStep[StateT, T]]]: + ) -> T: """Run the graph from a starting node until it ends. Args: @@ -141,6 +135,8 @@ async def run( you need to provide the starting node. state: The initial state of the graph. deps: The dependencies of the graph. + state_persistence: State persistence interface, defaults to + [`LatestMemoryStatePersistence`][pydantic_graph.state.memory.LatestMemoryStatePersistence] if `None`. infer_name: Whether to infer the graph name from the calling frame. Returns: @@ -170,11 +166,16 @@ async def main(): if infer_name and self.name is None: self._infer_name(inspect.currentframe()) - history: list[HistoryStep[StateT, T]] = [] + if state_persistence is None: + state_persistence = LatestMemoryStatePersistence() + + # have to snapshot state before iterating over nodes, as we'll expect a snapshot in the + # state_persistence soon + await state_persistence.snapshot_node(state, start_node, self._node_type_adapter) + with ExitStack() as stack: - run_span: logfire_api.LogfireSpan | None = None if self._auto_instrument: - run_span = stack.enter_context( + stack.enter_context( _logfire.span( '{graph_name} run {start=}', graph_name=self.name or 'graph', @@ -184,12 +185,9 @@ async def main(): next_node = start_node while True: - next_node = await self.next(next_node, history, state=state, deps=deps, infer_name=False) + next_node = await self.next(next_node, state_persistence, state=state, deps=deps, infer_name=False) if isinstance(next_node, End): - history.append(EndStep(result=next_node)) - if run_span is not None: - run_span.set_attribute('history', history) - return next_node.data, history + return next_node.data elif not isinstance(next_node, BaseNode): if TYPE_CHECKING: typing_extensions.assert_never(next_node) @@ -204,8 +202,9 @@ def run_sync( *, state: StateT = None, deps: DepsT = None, + state_persistence: StatePersistence[StateT, T] | None = None, infer_name: bool = True, - ) -> tuple[T, list[HistoryStep[StateT, T]]]: + ) -> T: """Run the graph synchronously. This is a convenience method that wraps [`self.run`][pydantic_graph.Graph.run] with `loop.run_until_complete(...)`. @@ -216,6 +215,8 @@ def run_sync( you need to provide the starting node. state: The initial state of the graph. deps: The dependencies of the graph. + state_persistence: State persistence interface, defaults to + [`LatestMemoryStatePersistence`][pydantic_graph.state.memory.LatestMemoryStatePersistence] if `None`. infer_name: Whether to infer the graph name from the calling frame. Returns: @@ -224,13 +225,13 @@ def run_sync( if infer_name and self.name is None: self._infer_name(inspect.currentframe()) return asyncio.get_event_loop().run_until_complete( - self.run(start_node, state=state, deps=deps, infer_name=False) + self.run(start_node, state=state, deps=deps, state_persistence=state_persistence, infer_name=False) ) async def next( self: Graph[StateT, DepsT, T], node: BaseNode[StateT, DepsT, T], - history: list[HistoryStep[StateT, T]], + state_persistence: StatePersistence[StateT, T], *, state: StateT = None, deps: DepsT = None, @@ -240,7 +241,7 @@ async def next( Args: node: The node to run. - history: The history of the graph run so far. NOTE: this will be mutated to add the new step. + state_persistence: State persistence interface. state: The current state of the graph. deps: The dependencies of the graph. infer_name: Whether to infer the graph name from the calling frame. @@ -258,52 +259,33 @@ async def next( if self._auto_instrument: stack.enter_context(_logfire.span('run node {node_id}', node_id=node_id, node=node)) ctx = GraphRunContext(state, deps) - start_ts = _utils.now_utc() - start = perf_counter() - next_node = await node.run(ctx) - duration = perf_counter() - start + async with state_persistence.record_run(): + next_or_end = await node.run(ctx) - history.append( - NodeStep(state=state, node=node, start_ts=start_ts, duration=duration, snapshot_state=self.snapshot_state) - ) - return next_node - - def dump_history( - self: Graph[StateT, DepsT, T], history: list[HistoryStep[StateT, T]], *, indent: int | None = None - ) -> bytes: - """Dump the history of a graph run as JSON. - - Args: - history: The history of the graph run. - indent: The number of spaces to indent the JSON. - - Returns: - The JSON representation of the history. - """ - return self.history_type_adapter.dump_json(history, indent=indent) - - def load_history(self, json_bytes: str | bytes | bytearray) -> list[HistoryStep[StateT, RunEndT]]: - """Load the history of a graph run from JSON. - - Args: - json_bytes: The JSON representation of the history. + if isinstance(next_or_end, BaseNode): + await state_persistence.snapshot_node(state, next_or_end, self._node_type_adapter) + else: + await state_persistence.snapshot_end(state, next_or_end, self._end_data_type_adapter) + return next_or_end - Returns: - The history of the graph run. - """ - return self.history_type_adapter.validate_json(json_bytes) + async def next_from_persistence( + self: Graph[StateT, DepsT, T], + state_persistence: StatePersistence[StateT, T], + *, + deps: DepsT = None, + infer_name: bool = True, + ) -> BaseNode[StateT, DepsT, Any] | End[T]: + if infer_name and self.name is None: + self._infer_name(inspect.currentframe()) - @cached_property - def history_type_adapter(self) -> pydantic.TypeAdapter[list[HistoryStep[StateT, RunEndT]]]: - nodes = [node_def.node for node_def in self.node_defs.values()] - state_t = self._get_state_type() - end_t = self._get_run_end_type() - token = nodes_schema_var.set(nodes) - try: - ta = pydantic.TypeAdapter(list[Annotated[HistoryStep[state_t, end_t], pydantic.Discriminator('kind')]]) - finally: - nodes_schema_var.reset(token) - return ta + snapshot = await state_persistence.restore_node_snapshot() + return await self.next( + snapshot.node, + state_persistence, + state=snapshot.state, + deps=deps, + infer_name=False, + ) def mermaid_code( self, @@ -428,6 +410,16 @@ def mermaid_save( kwargs['title'] = self.name mermaid.save_image(path, self, **kwargs) + @cached_property + def _node_type_adapter(self) -> pydantic.TypeAdapter[BaseNode[StateT, Any, RunEndT]]: + nodes = [node_def.node for node_def in self.node_defs.values()] + return node_type_adapter(nodes, self._get_state_type(), self._get_run_end_type()) + + @cached_property + def _end_data_type_adapter(self) -> pydantic.TypeAdapter[RunEndT]: + end_t = self._get_run_end_type() + return pydantic.TypeAdapter(end_t) + def _get_state_type(self) -> type[StateT]: if _utils.is_set(self._state_type): return self._state_type diff --git a/pydantic_graph/pydantic_graph/nodes.py b/pydantic_graph/pydantic_graph/nodes.py index b43391ffef..744896b07a 100644 --- a/pydantic_graph/pydantic_graph/nodes.py +++ b/pydantic_graph/pydantic_graph/nodes.py @@ -1,9 +1,10 @@ from __future__ import annotations as _annotations +import copy from abc import ABC, abstractmethod from dataclasses import dataclass, is_dataclass from functools import cache -from typing import TYPE_CHECKING, Any, ClassVar, Generic, get_origin, get_type_hints +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Self, get_origin, get_type_hints from typing_extensions import Never, TypeVar @@ -14,7 +15,7 @@ else: StateT = TypeVar('StateT', default=None) -__all__ = 'GraphRunContext', 'BaseNode', 'End', 'Edge', 'NodeDef', 'DepsT' +__all__ = 'GraphRunContext', 'BaseNode', 'End', 'Edge', 'NodeDef', 'DepsT', 'RunEndT' RunEndT = TypeVar('RunEndT', covariant=True, default=None) """Covariant type variable for the return type of a graph [`run`][pydantic_graph.graph.Graph.run].""" @@ -125,6 +126,10 @@ def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, DepsT, returns_base_node, ) + def deep_copy(self) -> Self: + """Returns a deep copy of the node.""" + return copy.deepcopy(self) + @dataclass class End(Generic[RunEndT]): @@ -133,6 +138,13 @@ class End(Generic[RunEndT]): data: RunEndT """Data to return from the graph.""" + def deep_copy_data(self) -> RunEndT: + """Returns a deep copy of the end of the run.""" + if self.data is None: + return self.data + else: + return copy.deepcopy(self.data) + @dataclass class Edge: diff --git a/pydantic_graph/pydantic_graph/state.py b/pydantic_graph/pydantic_graph/state.py deleted file mode 100644 index 99bddd6138..0000000000 --- a/pydantic_graph/pydantic_graph/state.py +++ /dev/null @@ -1,126 +0,0 @@ -from __future__ import annotations as _annotations - -import copy -from collections.abc import Sequence -from contextvars import ContextVar -from dataclasses import dataclass, field -from datetime import datetime -from typing import Annotated, Any, Callable, Generic, Literal, Union - -import pydantic -from pydantic_core import core_schema -from typing_extensions import TypeVar - -from . import _utils -from .nodes import BaseNode, End, RunEndT - -__all__ = 'StateT', 'NodeStep', 'EndStep', 'HistoryStep', 'deep_copy_state', 'nodes_schema_var' - - -StateT = TypeVar('StateT', default=None) -"""Type variable for the state in a graph.""" - - -def deep_copy_state(state: StateT) -> StateT: - """Default method for snapshotting the state in a graph run, uses [`copy.deepcopy`][copy.deepcopy].""" - if state is None: - return state - else: - return copy.deepcopy(state) - - -@dataclass -class NodeStep(Generic[StateT, RunEndT]): - """History step describing the execution of a node in a graph.""" - - state: StateT - """The state of the graph after the node has been run.""" - node: Annotated[BaseNode[StateT, Any, RunEndT], CustomNodeSchema()] - """The node that was run.""" - start_ts: datetime = field(default_factory=_utils.now_utc) - """The timestamp when the node started running.""" - duration: float | None = None - """The duration of the node run in seconds.""" - kind: Literal['node'] = 'node' - """The kind of history step, can be used as a discriminator when deserializing history.""" - # TODO waiting for https://github.com/pydantic/pydantic/issues/11264, should be an InitVar - snapshot_state: Annotated[Callable[[StateT], StateT], pydantic.Field(exclude=True, repr=False)] = field( - default=deep_copy_state, repr=False - ) - """Function to snapshot the state of the graph.""" - - def __post_init__(self): - # Copy the state to prevent it from being modified by other code - self.state = self.snapshot_state(self.state) - - def data_snapshot(self) -> BaseNode[StateT, Any, RunEndT]: - """Returns a deep copy of [`self.node`][pydantic_graph.state.NodeStep.node]. - - Useful for summarizing history. - """ - return copy.deepcopy(self.node) - - -@dataclass -class EndStep(Generic[RunEndT]): - """History step describing the end of a graph run.""" - - result: End[RunEndT] - """The result of the graph run.""" - ts: datetime = field(default_factory=_utils.now_utc) - """The timestamp when the graph run ended.""" - kind: Literal['end'] = 'end' - """The kind of history step, can be used as a discriminator when deserializing history.""" - - def data_snapshot(self) -> End[RunEndT]: - """Returns a deep copy of [`self.result`][pydantic_graph.state.EndStep.result]. - - Useful for summarizing history. - """ - return copy.deepcopy(self.result) - - -HistoryStep = Union[NodeStep[StateT, RunEndT], EndStep[RunEndT]] -"""A step in the history of a graph run. - -[`Graph.run`][pydantic_graph.graph.Graph.run] returns a list of these steps describing the execution of the graph, -together with the run return value. -""" - - -nodes_schema_var: ContextVar[Sequence[type[BaseNode[Any, Any, Any]]]] = ContextVar('nodes_var') - - -class CustomNodeSchema: - def __get_pydantic_core_schema__( - self, _source_type: Any, handler: pydantic.GetCoreSchemaHandler - ) -> core_schema.CoreSchema: - try: - nodes = nodes_schema_var.get() - except LookupError as e: - raise RuntimeError( - 'Unable to build a Pydantic schema for `NodeStep` or `HistoryStep` without setting `nodes_schema_var`. ' - 'You probably want to use ' - ) from e - if len(nodes) == 1: - nodes_type = nodes[0] - else: - nodes_annotated = [Annotated[node, pydantic.Tag(node.get_id())] for node in nodes] - nodes_type = Annotated[Union[tuple(nodes_annotated)], pydantic.Discriminator(self._node_discriminator)] - - schema = handler(nodes_type) - schema['serialization'] = core_schema.wrap_serializer_function_ser_schema( - function=self._node_serializer, - return_schema=core_schema.dict_schema(core_schema.str_schema(), core_schema.any_schema()), - ) - return schema - - @staticmethod - def _node_discriminator(node_data: Any) -> str: - return node_data.get('node_id') - - @staticmethod - def _node_serializer(node: Any, handler: pydantic.SerializerFunctionWrapHandler) -> dict[str, Any]: - node_dict = handler(node) - node_dict['node_id'] = node.get_id() - return node_dict diff --git a/pydantic_graph/pydantic_graph/state/__init__.py b/pydantic_graph/pydantic_graph/state/__init__.py new file mode 100644 index 0000000000..748664d584 --- /dev/null +++ b/pydantic_graph/pydantic_graph/state/__init__.py @@ -0,0 +1,135 @@ +from __future__ import annotations as _annotations + +from abc import ABC, abstractmethod +from collections.abc import AsyncIterator, Sequence +from contextlib import asynccontextmanager +from dataclasses import dataclass, field +from datetime import datetime +from typing import Annotated, Any, Generic, Literal, Union + +import pydantic +from typing_extensions import TypeVar + +from .. import exceptions +from ..nodes import BaseNode, End, RunEndT +from . import _utils + +__all__ = 'StateT', 'NodeSnapshot', 'EndSnapshot', 'Snapshot', 'StatePersistence', 'node_type_adapter' + + +StateT = TypeVar('StateT', default=None) +"""Type variable for the state in a graph.""" + + +@dataclass +class NodeSnapshot(Generic[StateT, RunEndT]): + """History step describing the execution of a node in a graph.""" + + state: StateT + """The state of the graph before the node is run.""" + node: Annotated[BaseNode[StateT, Any, RunEndT], _utils.CustomNodeSchema()] + """The node to run next.""" + start_ts: datetime | None = None + """The timestamp when the node started running, `None` until the run starts.""" + duration: float | None = None + """The duration of the node run in seconds, if the node has been run.""" + kind: Literal['node'] = 'node' + """The kind of history step, can be used as a discriminator when deserializing history.""" + + +@dataclass +class EndSnapshot(Generic[StateT, RunEndT]): + """History step describing the end of a graph run.""" + + state: StateT + """The state of the graph at the end of teh run.""" + result: RunEndT + """The result of the graph run.""" + ts: datetime = field(default_factory=_utils.now_utc) + """The timestamp when the graph run ended.""" + kind: Literal['end'] = 'end' + """The kind of history step, can be used as a discriminator when deserializing history.""" + + +Snapshot = Union[NodeSnapshot[StateT, RunEndT], EndSnapshot[StateT, RunEndT]] +"""A step in the history of a graph run. + +[`Graph.run`][pydantic_graph.graph.Graph.run] returns a list of these steps describing the execution of the graph, +together with the run return value. +""" + + +class StatePersistence(ABC, Generic[StateT, RunEndT]): + """Abstract base class for storing the state of a graph.""" + + @abstractmethod + async def snapshot_node( + self, + state: StateT, + next_node: BaseNode[StateT, Any, RunEndT], + node_type_adapter: pydantic.TypeAdapter[BaseNode[StateT, Any, RunEndT]], + ) -> NodeSnapshot[StateT, RunEndT]: + """Snapshot the state of a graph before a node is run. + + Args: + state: The state of the graph. + next_node: The next node to run or end if the graph has ended + node_type_adapter: Pydantic [`TypeAdapter`][pydantic.TypeAdapter] for the node. + + Returns: + The snapshot + """ + raise NotImplementedError + + @abstractmethod + @asynccontextmanager + async def record_run(self) -> AsyncIterator[None]: + """Record the run of the node. + + In particular this should set [`NodeSnapshot.start_ts`][pydantic_graph.state.NodeSnapshot.start_ts] + and [`NodeSnapshot.duration`][pydantic_graph.state.NodeSnapshot.duration]. + """ + yield + raise NotImplementedError + + @abstractmethod + async def snapshot_end( + self, state: StateT, end: End[RunEndT], end_data_type_adapter: pydantic.TypeAdapter[RunEndT] + ) -> None: + """Snapshot the state of a graph before a node is run. + + Args: + state: The state of the graph. + end: data from the end of the run. + end_data_type_adapter: a Pydantic [`TypeAdapter`][pydantic.TypeAdapter] for the end data. + """ + raise NotImplementedError + + @abstractmethod + async def restore(self) -> Snapshot[StateT, RunEndT] | None: + """Retrieve a snapshot. + + Returns: + The most recent [`Snapshot`][pydantic_graph.state.Snapshot] of the run. + """ + raise NotImplementedError + + async def restore_node_snapshot(self) -> NodeSnapshot[StateT, RunEndT]: + snapshot = await self.restore() + if snapshot is None: + raise exceptions.GraphRuntimeError('Unable to restore snapshot from state persistence.') + elif not isinstance(snapshot, NodeSnapshot): + raise exceptions.GraphRuntimeError('Snapshot returned from persistence indicates the graph has ended.') + return snapshot + + +def node_type_adapter( + nodes: Sequence[type[BaseNode[Any, Any, Any]]], state_t: type[StateT], end_t: type[RunEndT] +) -> pydantic.TypeAdapter[BaseNode[StateT, Any, RunEndT]]: + debug([repr(node) for node in nodes]) + token = _utils.nodes_schema_var.set(nodes) + try: + ta = pydantic.TypeAdapter(BaseNode[state_t, Any, end_t]) + finally: + _utils.nodes_schema_var.reset(token) + return ta diff --git a/pydantic_graph/pydantic_graph/state/_utils.py b/pydantic_graph/pydantic_graph/state/_utils.py new file mode 100644 index 0000000000..3bd35fb8af --- /dev/null +++ b/pydantic_graph/pydantic_graph/state/_utils.py @@ -0,0 +1,52 @@ +from __future__ import annotations as _annotations + +from collections.abc import Sequence +from contextvars import ContextVar +from datetime import datetime, timezone +from typing import Annotated, Any, Union + +import pydantic +from pydantic_core import core_schema + +from ..nodes import BaseNode + +nodes_schema_var: ContextVar[Sequence[type[BaseNode[Any, Any, Any]]]] = ContextVar('nodes_var') + + +class CustomNodeSchema: + def __get_pydantic_core_schema__( + self, _source_type: Any, handler: pydantic.GetCoreSchemaHandler + ) -> core_schema.CoreSchema: + try: + nodes = nodes_schema_var.get() + except LookupError as e: + raise RuntimeError( + 'Unable to build a Pydantic schema for `NodeStep` without setting `nodes_schema_var`. ' + 'You probably want to use TODO' + ) from e + if len(nodes) == 1: + nodes_type = nodes[0] + else: + nodes_annotated = [Annotated[node, pydantic.Tag(node.get_id())] for node in nodes] + nodes_type = Annotated[Union[tuple(nodes_annotated)], pydantic.Discriminator(self._node_discriminator)] + + schema = handler(nodes_type) + schema['serialization'] = core_schema.wrap_serializer_function_ser_schema( + function=self._node_serializer, + return_schema=core_schema.dict_schema(core_schema.str_schema(), core_schema.any_schema()), + ) + return schema + + @staticmethod + def _node_discriminator(node_data: Any) -> str: + return node_data.get('node_id') + + @staticmethod + def _node_serializer(node: Any, handler: pydantic.SerializerFunctionWrapHandler) -> dict[str, Any]: + node_dict = handler(node) + node_dict['node_id'] = node.get_id() + return node_dict + + +def now_utc() -> datetime: + return datetime.now(tz=timezone.utc) diff --git a/pydantic_graph/pydantic_graph/state/memory.py b/pydantic_graph/pydantic_graph/state/memory.py new file mode 100644 index 0000000000..94e70e3fb0 --- /dev/null +++ b/pydantic_graph/pydantic_graph/state/memory.py @@ -0,0 +1,107 @@ +import copy +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from dataclasses import dataclass, field +from time import perf_counter +from typing import Any + +import pydantic + +from ..nodes import BaseNode, End, RunEndT +from . import EndSnapshot, NodeSnapshot, Snapshot, StatePersistence, StateT, _utils + + +@dataclass +class LatestMemoryStatePersistence(StatePersistence[StateT, RunEndT]): + deep_copy: bool = True + last_snapshot: Snapshot[StateT, RunEndT] | None = None + + async def snapshot_node( + self, + state: StateT, + next_node: BaseNode[StateT, Any, RunEndT], + node_type_adapter: pydantic.TypeAdapter[BaseNode[StateT, Any, RunEndT]], + ) -> NodeSnapshot[StateT, RunEndT]: + self.last_snapshot = snapshot = NodeSnapshot( + state=self.prep_state(state), + node=next_node.deep_copy() if self.deep_copy else next_node, + ) + return snapshot + + @asynccontextmanager + async def record_run(self) -> AsyncIterator[None]: + last_snapshot = await self.restore_node_snapshot() + last_snapshot.start_ts = _utils.now_utc() + start = perf_counter() + try: + yield + finally: + last_snapshot.duration = perf_counter() - start + + async def snapshot_end( + self, state: StateT, end: End[RunEndT], end_data_type_adapter: pydantic.TypeAdapter[RunEndT] + ) -> None: + self.last_snapshot = EndSnapshot( + state=self.prep_state(state), + result=end.deep_copy_data() if self.deep_copy else end.data, + ) + + async def restore(self) -> Snapshot[StateT, RunEndT] | None: + return self.last_snapshot + + def prep_state(self, state: StateT) -> StateT: + """Prepare state for snapshot, uses [`copy.deepcopy`][copy.deepcopy] by default.""" + if not self.deep_copy or state is None: + return state + else: + return copy.deepcopy(state) + + +@dataclass +class HistoryMemoryStatePersistence(StatePersistence[StateT, RunEndT]): + deep_copy: bool = True + history: list[Snapshot[StateT, RunEndT]] = field(default_factory=list) + + async def snapshot_node( + self, + state: StateT, + next_node: BaseNode[StateT, Any, RunEndT], + node_type_adapter: pydantic.TypeAdapter[BaseNode[StateT, Any, RunEndT]], + ) -> NodeSnapshot[StateT, RunEndT]: + snapshot = NodeSnapshot( + state=self.prep_state(state), + node=next_node.deep_copy() if self.deep_copy else next_node, + ) + self.history.append(snapshot) + return snapshot + + @asynccontextmanager + async def record_run(self) -> AsyncIterator[None]: + last_snapshot = await self.restore_node_snapshot() + last_snapshot.start_ts = _utils.now_utc() + start = perf_counter() + try: + yield + finally: + last_snapshot.duration = perf_counter() - start + + async def snapshot_end( + self, state: StateT, end: End[RunEndT], end_data_type_adapter: pydantic.TypeAdapter[RunEndT] + ) -> None: + self.history.append( + EndSnapshot( + state=self.prep_state(state), + result=end.deep_copy_data() if self.deep_copy else end.data, + ) + ) + + async def restore(self) -> Snapshot[StateT, RunEndT] | None: + if self.history: + return self.history[-1] + + def prep_state(self, state: StateT) -> StateT: + """Prepare state for snapshot, uses [`copy.deepcopy`][copy.deepcopy] by default.""" + if not self.deep_copy or state is None: + return state + else: + return copy.deepcopy(state) diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py index ebd254a370..20735278d5 100644 --- a/tests/graph/test_graph.py +++ b/tests/graph/test_graph.py @@ -14,13 +14,13 @@ from pydantic_graph import ( BaseNode, End, - EndStep, + EndSnapshot, Graph, GraphRunContext, GraphRuntimeError, GraphSetupError, - HistoryStep, - NodeStep, + NodeSnapshot, + Snapshot, ) from ..conftest import IsFloat, IsNow @@ -63,25 +63,25 @@ async def run(self, ctx: GraphRunContext) -> Union[String2Length, End[int]]: # assert my_graph.name == 'my_graph' assert history == snapshot( [ - NodeStep( + NodeSnapshot( state=None, node=Float2String(input_data=3.14), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - NodeStep( + NodeSnapshot( state=None, node=String2Length(input_data='3.14'), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - NodeStep( + NodeSnapshot( state=None, node=Double(input_data=4), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - EndStep(result=End(data=8), ts=IsNow(tz=timezone.utc)), + EndSnapshot(result=End(data=8), ts=IsNow(tz=timezone.utc)), ] ) result, history = await my_graph.run(Float2String(3.14159)) @@ -89,37 +89,37 @@ async def run(self, ctx: GraphRunContext) -> Union[String2Length, End[int]]: # assert result == 42 assert history == snapshot( [ - NodeStep( + NodeSnapshot( state=None, node=Float2String(input_data=3.14159), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - NodeStep( + NodeSnapshot( state=None, node=String2Length(input_data='3.14159'), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - NodeStep( + NodeSnapshot( state=None, node=Double(input_data=7), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - NodeStep( + NodeSnapshot( state=None, node=String2Length(input_data='xxxxxxxxxxxxxxxxxxxxx'), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - NodeStep( + NodeSnapshot( state=None, node=Double(input_data=21), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - EndStep(result=End(data=42), ts=IsNow(tz=timezone.utc)), + EndSnapshot(result=End(data=42), ts=IsNow(tz=timezone.utc)), ] ) assert [e.data_snapshot() for e in history] == snapshot( @@ -283,11 +283,13 @@ async def run(self, ctx: GraphRunContext) -> Foo: g = Graph(nodes=(Foo, Bar)) assert g.name is None - history: list[HistoryStep[None, Never]] = [] + history: list[Snapshot[None, Never]] = [] n = await g.next(Foo(), history) assert n == Bar() assert g.name == 'g' - assert history == snapshot([NodeStep(state=None, node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat())]) + assert history == snapshot( + [NodeSnapshot(state=None, node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat())] + ) assert isinstance(n, Bar) n2 = await g.next(n, history) @@ -295,8 +297,8 @@ async def run(self, ctx: GraphRunContext) -> Foo: assert history == snapshot( [ - NodeStep(state=None, node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), - NodeStep(state=None, node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), + NodeSnapshot(state=None, node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), + NodeSnapshot(state=None, node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), ] ) @@ -325,8 +327,8 @@ async def run(self, ctx: GraphRunContext[None, Deps]) -> End[int]: assert result == 123 assert history == snapshot( [ - NodeStep(state=None, node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), - NodeStep(state=None, node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), - EndStep(result=End(data=123), ts=IsNow(tz=timezone.utc)), + NodeSnapshot(state=None, node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), + NodeSnapshot(state=None, node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), + EndSnapshot(result=End(data=123), ts=IsNow(tz=timezone.utc)), ] ) diff --git a/tests/graph/test_history.py b/tests/graph/test_history.py index 2508a53475..36d802f542 100644 --- a/tests/graph/test_history.py +++ b/tests/graph/test_history.py @@ -9,7 +9,7 @@ from dirty_equals import IsStr from inline_snapshot import snapshot -from pydantic_graph import BaseNode, End, EndStep, Graph, GraphRunContext, GraphSetupError, NodeStep +from pydantic_graph import BaseNode, End, EndSnapshot, Graph, GraphRunContext, GraphSetupError, NodeSnapshot from ..conftest import IsFloat, IsNow @@ -50,9 +50,9 @@ async def test_dump_load_history(graph: Graph[MyState, None, int]): assert result == snapshot(4) assert history == snapshot( [ - NodeStep(state=MyState(x=2, y=''), node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), - NodeStep(state=MyState(x=2, y='y'), node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), - EndStep(result=End(4), ts=IsNow(tz=timezone.utc)), + NodeSnapshot(state=MyState(x=2, y=''), node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), + NodeSnapshot(state=MyState(x=2, y='y'), node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), + EndSnapshot(result=End(4), ts=IsNow(tz=timezone.utc)), ] ) history_json = graph.dump_history(history) @@ -91,13 +91,13 @@ async def test_dump_load_history(graph: Graph[MyState, None, int]): history_loaded = graph.load_history(json.dumps(custom_history)) assert history_loaded == snapshot( [ - NodeStep( + NodeSnapshot( state=MyState(x=2, y=''), node=Foo(), start_ts=datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc), duration=123.0, ), - EndStep(result=End(data=42), ts=datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc)), + EndSnapshot(result=End(data=42), ts=datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc)), ] ) @@ -116,7 +116,7 @@ async def run(self, ctx: GraphRunContext) -> End[int]: history_loaded = g.load_history(json.dumps(custom_history)) assert history_loaded == snapshot( [ - EndStep(result=End(data=123), ts=datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc)), + EndSnapshot(result=End(data=123), ts=datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc)), ] ) @@ -141,6 +141,6 @@ async def run(self, ctx: GraphRunContext) -> NoGenericArgsNode: history_loaded = g.load_history(json.dumps(custom_history)) assert history_loaded == snapshot( [ - EndStep(result=End(data=None), ts=datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc)), + EndSnapshot(result=End(data=None), ts=datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc)), ] ) diff --git a/tests/graph/test_mermaid.py b/tests/graph/test_mermaid.py index 9f76d93cdd..b90b87a06d 100644 --- a/tests/graph/test_mermaid.py +++ b/tests/graph/test_mermaid.py @@ -11,7 +11,7 @@ import pytest from inline_snapshot import snapshot -from pydantic_graph import BaseNode, Edge, End, EndStep, Graph, GraphRunContext, GraphSetupError, NodeStep +from pydantic_graph import BaseNode, Edge, End, EndSnapshot, Graph, GraphRunContext, GraphSetupError, NodeSnapshot from pydantic_graph.nodes import NodeDef from ..conftest import IsFloat, IsNow @@ -62,19 +62,19 @@ async def test_run_graph(): assert result is None assert history == snapshot( [ - NodeStep( + NodeSnapshot( state=None, node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - NodeStep( + NodeSnapshot( state=None, node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - EndStep(result=End(data=None), ts=IsNow(tz=timezone.utc)), + EndSnapshot(result=End(data=None), ts=IsNow(tz=timezone.utc)), ] ) diff --git a/tests/graph/test_state.py b/tests/graph/test_state.py index fbb570cf0c..7eec60824f 100644 --- a/tests/graph/test_state.py +++ b/tests/graph/test_state.py @@ -7,7 +7,7 @@ import pytest from inline_snapshot import snapshot -from pydantic_graph import BaseNode, End, EndStep, Graph, GraphRunContext, NodeStep +from pydantic_graph import BaseNode, End, EndSnapshot, Graph, GraphRunContext, NodeSnapshot from ..conftest import IsFloat, IsNow @@ -40,19 +40,19 @@ async def run(self, ctx: GraphRunContext[MyState]) -> End[str]: assert result == snapshot('x=2 y=y') assert history == snapshot( [ - NodeStep( + NodeSnapshot( state=MyState(x=2, y=''), node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - NodeStep( + NodeSnapshot( state=MyState(x=2, y='y'), node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - EndStep(result=End('x=2 y=y'), ts=IsNow(tz=timezone.utc)), + EndSnapshot(result=End('x=2 y=y'), ts=IsNow(tz=timezone.utc)), ] ) assert state == MyState(x=2, y='y') diff --git a/tests/typed_graph.py b/tests/typed_graph.py index d0b6a02b7e..b484ad0571 100644 --- a/tests/typed_graph.py +++ b/tests/typed_graph.py @@ -4,7 +4,7 @@ from typing_extensions import assert_type -from pydantic_graph import BaseNode, End, Graph, GraphRunContext, HistoryStep +from pydantic_graph import BaseNode, End, Graph, GraphRunContext, Snapshot @dataclass @@ -111,4 +111,4 @@ def run_g5() -> None: g5.run_sync(A(), deps=MyDeps(y='y')) # pyright: ignore[reportArgumentType] answer, history = g5.run_sync(A(), state=MyState(x=1), deps=MyDeps(y='y')) assert_type(answer, int) - assert_type(history, list[HistoryStep[MyState, int]]) + assert_type(history, list[Snapshot[MyState, int]]) From 4c1d50d217839564f3e47c01acfc15caae1eb97b Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sun, 23 Feb 2025 18:44:13 -0500 Subject: [PATCH 02/25] fixing tests --- pydantic_graph/pydantic_graph/__init__.py | 3 + pydantic_graph/pydantic_graph/graph.py | 71 +++---- pydantic_graph/pydantic_graph/nodes.py | 6 +- .../pydantic_graph/state/__init__.py | 74 ++++---- pydantic_graph/pydantic_graph/state/_utils.py | 20 +- pydantic_graph/pydantic_graph/state/memory.py | 98 +++++----- tests/graph/test_graph.py | 88 +++++---- tests/graph/test_history.py | 174 +++++++++--------- tests/graph/test_mermaid.py | 19 +- tests/graph/test_state.py | 26 ++- 10 files changed, 317 insertions(+), 262 deletions(-) diff --git a/pydantic_graph/pydantic_graph/__init__.py b/pydantic_graph/pydantic_graph/__init__.py index 0b1ff87e9d..0a96d52b02 100644 --- a/pydantic_graph/pydantic_graph/__init__.py +++ b/pydantic_graph/pydantic_graph/__init__.py @@ -2,6 +2,7 @@ from .graph import Graph from .nodes import BaseNode, Edge, End, GraphRunContext from .state import EndSnapshot, NodeSnapshot, Snapshot +from .state.memory import FullStatePersistence, SimpleStatePersistence __all__ = ( 'Graph', @@ -14,4 +15,6 @@ 'NodeSnapshot', 'GraphSetupError', 'GraphRuntimeError', + 'SimpleStatePersistence', + 'FullStatePersistence', ) diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index d4ae197fe2..d4b79497d6 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -12,11 +12,12 @@ import logfire_api import pydantic import typing_extensions +from inline_snapshot import Snapshot from . import _utils, exceptions, mermaid from .nodes import BaseNode, DepsT, End, GraphRunContext, NodeDef, RunEndT -from .state import StatePersistence, StateT, node_type_adapter -from .state.memory import LatestMemoryStatePersistence +from .state import StatePersistence, StateT, build_nodes_type_adapter +from .state.memory import SimpleStatePersistence # while waiting for https://github.com/pydantic/logfire/issues/745 try: @@ -125,7 +126,7 @@ async def run( *, state: StateT = None, deps: DepsT = None, - state_persistence: StatePersistence[StateT, T] | None = None, + persistence: StatePersistence[StateT, T] | None = None, infer_name: bool = True, ) -> T: """Run the graph from a starting node until it ends. @@ -135,8 +136,8 @@ async def run( you need to provide the starting node. state: The initial state of the graph. deps: The dependencies of the graph. - state_persistence: State persistence interface, defaults to - [`LatestMemoryStatePersistence`][pydantic_graph.state.memory.LatestMemoryStatePersistence] if `None`. + persistence: State persistence interface, defaults to + [`SimpleStatePersistence`][pydantic_graph.state.memory.SimpleStatePersistence] if `None`. infer_name: Whether to infer the graph name from the calling frame. Returns: @@ -166,27 +167,22 @@ async def main(): if infer_name and self.name is None: self._infer_name(inspect.currentframe()) - if state_persistence is None: - state_persistence = LatestMemoryStatePersistence() - - # have to snapshot state before iterating over nodes, as we'll expect a snapshot in the - # state_persistence soon - await state_persistence.snapshot_node(state, start_node, self._node_type_adapter) + if persistence is None: + persistence = SimpleStatePersistence() with ExitStack() as stack: if self._auto_instrument: stack.enter_context( - _logfire.span( - '{graph_name} run {start=}', - graph_name=self.name or 'graph', - start=start_node, - ) + _logfire.span('{graph_name} run {start=}', graph_name=self.name or 'graph', start=start_node) ) next_node = start_node while True: - next_node = await self.next(next_node, state_persistence, state=state, deps=deps, infer_name=False) + next_node = await self.next( + next_node, persistence=persistence, state=state, deps=deps, infer_name=False + ) if isinstance(next_node, End): + await persistence.snapshot_end(state, next_node) return next_node.data elif not isinstance(next_node, BaseNode): if TYPE_CHECKING: @@ -202,7 +198,7 @@ def run_sync( *, state: StateT = None, deps: DepsT = None, - state_persistence: StatePersistence[StateT, T] | None = None, + persistence: StatePersistence[StateT, T] | None = None, infer_name: bool = True, ) -> T: """Run the graph synchronously. @@ -215,8 +211,8 @@ def run_sync( you need to provide the starting node. state: The initial state of the graph. deps: The dependencies of the graph. - state_persistence: State persistence interface, defaults to - [`LatestMemoryStatePersistence`][pydantic_graph.state.memory.LatestMemoryStatePersistence] if `None`. + persistence: State persistence interface, defaults to + [`SimpleStatePersistence`][pydantic_graph.state.memory.SimpleStatePersistence] if `None`. infer_name: Whether to infer the graph name from the calling frame. Returns: @@ -225,14 +221,14 @@ def run_sync( if infer_name and self.name is None: self._infer_name(inspect.currentframe()) return asyncio.get_event_loop().run_until_complete( - self.run(start_node, state=state, deps=deps, state_persistence=state_persistence, infer_name=False) + self.run(start_node, state=state, deps=deps, persistence=persistence, infer_name=False) ) async def next( self: Graph[StateT, DepsT, T], node: BaseNode[StateT, DepsT, T], - state_persistence: StatePersistence[StateT, T], *, + persistence: StatePersistence[StateT, T] | None = None, state: StateT = None, deps: DepsT = None, infer_name: bool = True, @@ -241,7 +237,8 @@ async def next( Args: node: The node to run. - state_persistence: State persistence interface. + persistence: State persistence interface, defaults to + [`SimpleStatePersistence`][pydantic_graph.state.memory.SimpleStatePersistence] if `None`. state: The current state of the graph. deps: The dependencies of the graph. infer_name: Whether to infer the graph name from the calling frame. @@ -251,26 +248,27 @@ async def next( """ if infer_name and self.name is None: self._infer_name(inspect.currentframe()) + + if persistence is None: + persistence = SimpleStatePersistence() + node_id = node.get_id() if node_id not in self.node_defs: raise exceptions.GraphRuntimeError(f'Node `{node}` is not in the graph.') + await persistence.snapshot_node(state, node) + with ExitStack() as stack: if self._auto_instrument: stack.enter_context(_logfire.span('run node {node_id}', node_id=node_id, node=node)) ctx = GraphRunContext(state, deps) - async with state_persistence.record_run(): + async with persistence.record_run(): next_or_end = await node.run(ctx) - - if isinstance(next_or_end, BaseNode): - await state_persistence.snapshot_node(state, next_or_end, self._node_type_adapter) - else: - await state_persistence.snapshot_end(state, next_or_end, self._end_data_type_adapter) return next_or_end async def next_from_persistence( self: Graph[StateT, DepsT, T], - state_persistence: StatePersistence[StateT, T], + persistence: StatePersistence[StateT, T], *, deps: DepsT = None, infer_name: bool = True, @@ -278,10 +276,10 @@ async def next_from_persistence( if infer_name and self.name is None: self._infer_name(inspect.currentframe()) - snapshot = await state_persistence.restore_node_snapshot() + snapshot = await persistence.restore_node_snapshot() return await self.next( snapshot.node, - state_persistence, + persistence=persistence, state=snapshot.state, deps=deps, infer_name=False, @@ -413,13 +411,17 @@ def mermaid_save( @cached_property def _node_type_adapter(self) -> pydantic.TypeAdapter[BaseNode[StateT, Any, RunEndT]]: nodes = [node_def.node for node_def in self.node_defs.values()] - return node_type_adapter(nodes, self._get_state_type(), self._get_run_end_type()) + return build_nodes_type_adapter(nodes, self._get_state_type(), self._get_run_end_type()) @cached_property def _end_data_type_adapter(self) -> pydantic.TypeAdapter[RunEndT]: end_t = self._get_run_end_type() return pydantic.TypeAdapter(end_t) + @cached_property + def _snapshot_type_adapter(self) -> pydantic.TypeAdapter[Snapshot[StateT, RunEndT]]: + pass + def _get_state_type(self) -> type[StateT]: if _utils.is_set(self._state_type): return self._state_type @@ -449,7 +451,8 @@ def _get_run_end_type(self) -> type[RunEndT]: return t # break the inner (bases) loop break - raise exceptions.GraphSetupError('Could not infer run end type from nodes, please set `run_end_type`.') + # this happens if a graph has no return nodes, use None so any downstream errors a clear + return type(None) # pyright: ignore[reportReturnType] def _register_node( self: Graph[StateT, DepsT, T], diff --git a/pydantic_graph/pydantic_graph/nodes.py b/pydantic_graph/pydantic_graph/nodes.py index 744896b07a..df3c04260f 100644 --- a/pydantic_graph/pydantic_graph/nodes.py +++ b/pydantic_graph/pydantic_graph/nodes.py @@ -138,12 +138,12 @@ class End(Generic[RunEndT]): data: RunEndT """Data to return from the graph.""" - def deep_copy_data(self) -> RunEndT: + def deep_copy_data(self) -> End[RunEndT]: """Returns a deep copy of the end of the run.""" if self.data is None: - return self.data + return self else: - return copy.deepcopy(self.data) + return End(copy.deepcopy(self.data)) @dataclass diff --git a/pydantic_graph/pydantic_graph/state/__init__.py b/pydantic_graph/pydantic_graph/state/__init__.py index 748664d584..2ef10cd138 100644 --- a/pydantic_graph/pydantic_graph/state/__init__.py +++ b/pydantic_graph/pydantic_graph/state/__init__.py @@ -5,7 +5,7 @@ from contextlib import asynccontextmanager from dataclasses import dataclass, field from datetime import datetime -from typing import Annotated, Any, Generic, Literal, Union +from typing import Annotated, Any, Callable, Generic, Literal, Union import pydantic from typing_extensions import TypeVar @@ -14,8 +14,7 @@ from ..nodes import BaseNode, End, RunEndT from . import _utils -__all__ = 'StateT', 'NodeSnapshot', 'EndSnapshot', 'Snapshot', 'StatePersistence', 'node_type_adapter' - +__all__ = 'StateT', 'NodeSnapshot', 'EndSnapshot', 'Snapshot', 'StatePersistence', 'build_nodes_type_adapter' StateT = TypeVar('StateT', default=None) """Type variable for the state in a graph.""" @@ -27,7 +26,7 @@ class NodeSnapshot(Generic[StateT, RunEndT]): state: StateT """The state of the graph before the node is run.""" - node: Annotated[BaseNode[StateT, Any, RunEndT], _utils.CustomNodeSchema()] + node: BaseNode[StateT, Any, RunEndT] """The node to run next.""" start_ts: datetime | None = None """The timestamp when the node started running, `None` until the run starts.""" @@ -42,14 +41,22 @@ class EndSnapshot(Generic[StateT, RunEndT]): """History step describing the end of a graph run.""" state: StateT - """The state of the graph at the end of teh run.""" - result: RunEndT + """The state of the graph at the end of the run.""" + result: End[RunEndT] """The result of the graph run.""" ts: datetime = field(default_factory=_utils.now_utc) """The timestamp when the graph run ended.""" kind: Literal['end'] = 'end' """The kind of history step, can be used as a discriminator when deserializing history.""" + @property + def node(self) -> End[RunEndT]: + """Shim to get the [`result`][pydantic_graph.state.EndSnapshot.result]. + + Useful to allow `[snapshot.node for snapshot in persistence.history]`. + """ + return self.result + Snapshot = Union[NodeSnapshot[StateT, RunEndT], EndSnapshot[StateT, RunEndT]] """A step in the history of a graph run. @@ -63,24 +70,28 @@ class StatePersistence(ABC, Generic[StateT, RunEndT]): """Abstract base class for storing the state of a graph.""" @abstractmethod - async def snapshot_node( - self, - state: StateT, - next_node: BaseNode[StateT, Any, RunEndT], - node_type_adapter: pydantic.TypeAdapter[BaseNode[StateT, Any, RunEndT]], - ) -> NodeSnapshot[StateT, RunEndT]: + async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> None: """Snapshot the state of a graph before a node is run. Args: state: The state of the graph. next_node: The next node to run or end if the graph has ended - node_type_adapter: Pydantic [`TypeAdapter`][pydantic.TypeAdapter] for the node. Returns: The snapshot """ raise NotImplementedError + @abstractmethod + async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None: + """Snapshot the state of a graph before a node is run. + + Args: + state: The state of the graph. + end: data from the end of the run. + """ + raise NotImplementedError + @abstractmethod @asynccontextmanager async def record_run(self) -> AsyncIterator[None]: @@ -92,28 +103,24 @@ async def record_run(self) -> AsyncIterator[None]: yield raise NotImplementedError - @abstractmethod - async def snapshot_end( - self, state: StateT, end: End[RunEndT], end_data_type_adapter: pydantic.TypeAdapter[RunEndT] - ) -> None: - """Snapshot the state of a graph before a node is run. - - Args: - state: The state of the graph. - end: data from the end of the run. - end_data_type_adapter: a Pydantic [`TypeAdapter`][pydantic.TypeAdapter] for the end data. - """ - raise NotImplementedError - @abstractmethod async def restore(self) -> Snapshot[StateT, RunEndT] | None: - """Retrieve a snapshot. + """Retrieve the latest snapshot. Returns: The most recent [`Snapshot`][pydantic_graph.state.Snapshot] of the run. """ raise NotImplementedError + def set_type_adapters( + self, + *, + get_node_type_adapter: Callable[[], pydantic.TypeAdapter[BaseNode[StateT, Any, RunEndT]]], + get_end_data_type_adapter: Callable[[], pydantic.TypeAdapter[RunEndT]], + get_snapshot_type_adapter: Callable[[], pydantic.TypeAdapter[Snapshot[StateT, RunEndT]]], + ): + pass + async def restore_node_snapshot(self) -> NodeSnapshot[StateT, RunEndT]: snapshot = await self.restore() if snapshot is None: @@ -123,13 +130,10 @@ async def restore_node_snapshot(self) -> NodeSnapshot[StateT, RunEndT]: return snapshot -def node_type_adapter( +def build_nodes_type_adapter( # noqa: D103 nodes: Sequence[type[BaseNode[Any, Any, Any]]], state_t: type[StateT], end_t: type[RunEndT] ) -> pydantic.TypeAdapter[BaseNode[StateT, Any, RunEndT]]: - debug([repr(node) for node in nodes]) - token = _utils.nodes_schema_var.set(nodes) - try: - ta = pydantic.TypeAdapter(BaseNode[state_t, Any, end_t]) - finally: - _utils.nodes_schema_var.reset(token) - return ta + return pydantic.TypeAdapter( + Annotated[BaseNode[state_t, Any, end_t], _utils.CustomNodeSchema(nodes)], + config=pydantic.ConfigDict(defer_build=True), + ) diff --git a/pydantic_graph/pydantic_graph/state/_utils.py b/pydantic_graph/pydantic_graph/state/_utils.py index 3bd35fb8af..b854df8d14 100644 --- a/pydantic_graph/pydantic_graph/state/_utils.py +++ b/pydantic_graph/pydantic_graph/state/_utils.py @@ -1,7 +1,7 @@ from __future__ import annotations as _annotations from collections.abc import Sequence -from contextvars import ContextVar +from dataclasses import dataclass from datetime import datetime, timezone from typing import Annotated, Any, Union @@ -10,24 +10,18 @@ from ..nodes import BaseNode -nodes_schema_var: ContextVar[Sequence[type[BaseNode[Any, Any, Any]]]] = ContextVar('nodes_var') - +@dataclass class CustomNodeSchema: + nodes: Sequence[type[BaseNode[Any, Any, Any]]] + def __get_pydantic_core_schema__( self, _source_type: Any, handler: pydantic.GetCoreSchemaHandler ) -> core_schema.CoreSchema: - try: - nodes = nodes_schema_var.get() - except LookupError as e: - raise RuntimeError( - 'Unable to build a Pydantic schema for `NodeStep` without setting `nodes_schema_var`. ' - 'You probably want to use TODO' - ) from e - if len(nodes) == 1: - nodes_type = nodes[0] + if len(self.nodes) == 1: + nodes_type = self.nodes[0] else: - nodes_annotated = [Annotated[node, pydantic.Tag(node.get_id())] for node in nodes] + nodes_annotated = [Annotated[node, pydantic.Tag(node.get_id())] for node in self.nodes] nodes_type = Annotated[Union[tuple(nodes_annotated)], pydantic.Discriminator(self._node_discriminator)] schema = handler(nodes_type) diff --git a/pydantic_graph/pydantic_graph/state/memory.py b/pydantic_graph/pydantic_graph/state/memory.py index 94e70e3fb0..3a83eba241 100644 --- a/pydantic_graph/pydantic_graph/state/memory.py +++ b/pydantic_graph/pydantic_graph/state/memory.py @@ -1,36 +1,50 @@ +from __future__ import annotations as _annotations + import copy from collections.abc import AsyncIterator from contextlib import asynccontextmanager from dataclasses import dataclass, field from time import perf_counter -from typing import Any - -import pydantic +from typing import Any, TypeVar from ..nodes import BaseNode, End, RunEndT from . import EndSnapshot, NodeSnapshot, Snapshot, StatePersistence, StateT, _utils +S = TypeVar('S') +R = TypeVar('R') + @dataclass -class LatestMemoryStatePersistence(StatePersistence[StateT, RunEndT]): +class SimpleStatePersistence(StatePersistence[StateT, RunEndT]): + """Simple in memory state persistence that just hold the latest snapshot.""" + deep_copy: bool = True last_snapshot: Snapshot[StateT, RunEndT] | None = None - async def snapshot_node( - self, - state: StateT, - next_node: BaseNode[StateT, Any, RunEndT], - node_type_adapter: pydantic.TypeAdapter[BaseNode[StateT, Any, RunEndT]], - ) -> NodeSnapshot[StateT, RunEndT]: - self.last_snapshot = snapshot = NodeSnapshot( + @classmethod + def from_types(cls, state_type: type[S], run_end_type: type[R]) -> SimpleStatePersistence[S, R]: + """No-op init method that help type checkers.""" + return SimpleStatePersistence() + + async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> None: + self.last_snapshot = NodeSnapshot( state=self.prep_state(state), node=next_node.deep_copy() if self.deep_copy else next_node, ) - return snapshot + + async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None: + self.last_snapshot = EndSnapshot( + state=self.prep_state(state), + result=end.deep_copy_data() if self.deep_copy else end, + ) @asynccontextmanager async def record_run(self) -> AsyncIterator[None]: - last_snapshot = await self.restore_node_snapshot() + last_snapshot = await self.restore() + if not isinstance(last_snapshot, NodeSnapshot): + yield + return + last_snapshot.start_ts = _utils.now_utc() start = perf_counter() try: @@ -38,14 +52,6 @@ async def record_run(self) -> AsyncIterator[None]: finally: last_snapshot.duration = perf_counter() - start - async def snapshot_end( - self, state: StateT, end: End[RunEndT], end_data_type_adapter: pydantic.TypeAdapter[RunEndT] - ) -> None: - self.last_snapshot = EndSnapshot( - state=self.prep_state(state), - result=end.deep_copy_data() if self.deep_copy else end.data, - ) - async def restore(self) -> Snapshot[StateT, RunEndT] | None: return self.last_snapshot @@ -58,26 +64,40 @@ def prep_state(self, state: StateT) -> StateT: @dataclass -class HistoryMemoryStatePersistence(StatePersistence[StateT, RunEndT]): +class FullStatePersistence(StatePersistence[StateT, RunEndT]): + """In memory state persistence that hold a history of nodes that were executed.""" + deep_copy: bool = True history: list[Snapshot[StateT, RunEndT]] = field(default_factory=list) - async def snapshot_node( - self, - state: StateT, - next_node: BaseNode[StateT, Any, RunEndT], - node_type_adapter: pydantic.TypeAdapter[BaseNode[StateT, Any, RunEndT]], - ) -> NodeSnapshot[StateT, RunEndT]: - snapshot = NodeSnapshot( - state=self.prep_state(state), - node=next_node.deep_copy() if self.deep_copy else next_node, + @classmethod + def from_types(cls, state_type: type[S], run_end_type: type[R]) -> FullStatePersistence[S, R]: + """No-op init method that help type checkers.""" + return FullStatePersistence() + + async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> None: + self.history.append( + NodeSnapshot( + state=self.prep_state(state), + node=next_node.deep_copy() if self.deep_copy else next_node, + ) + ) + + async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None: + self.history.append( + EndSnapshot( + state=self.prep_state(state), + result=end.deep_copy_data() if self.deep_copy else end, + ) ) - self.history.append(snapshot) - return snapshot @asynccontextmanager async def record_run(self) -> AsyncIterator[None]: - last_snapshot = await self.restore_node_snapshot() + last_snapshot = await self.restore() + if not isinstance(last_snapshot, NodeSnapshot): + yield + return + last_snapshot.start_ts = _utils.now_utc() start = perf_counter() try: @@ -85,16 +105,6 @@ async def record_run(self) -> AsyncIterator[None]: finally: last_snapshot.duration = perf_counter() - start - async def snapshot_end( - self, state: StateT, end: End[RunEndT], end_data_type_adapter: pydantic.TypeAdapter[RunEndT] - ) -> None: - self.history.append( - EndSnapshot( - state=self.prep_state(state), - result=end.deep_copy_data() if self.deep_copy else end.data, - ) - ) - async def restore(self) -> Snapshot[StateT, RunEndT] | None: if self.history: return self.history[-1] diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py index 20735278d5..6320ea89d3 100644 --- a/tests/graph/test_graph.py +++ b/tests/graph/test_graph.py @@ -9,18 +9,17 @@ import pytest from dirty_equals import IsStr from inline_snapshot import snapshot -from typing_extensions import Never from pydantic_graph import ( BaseNode, End, EndSnapshot, + FullStatePersistence, Graph, GraphRunContext, GraphRuntimeError, GraphSetupError, NodeSnapshot, - Snapshot, ) from ..conftest import IsFloat, IsNow @@ -28,40 +27,55 @@ pytestmark = pytest.mark.anyio -async def test_graph(): - @dataclass - class Float2String(BaseNode): - input_data: float +@dataclass +class Float2String(BaseNode): + input_data: float - async def run(self, ctx: GraphRunContext) -> String2Length: - return String2Length(str(self.input_data)) + async def run(self, ctx: GraphRunContext) -> String2Length: + return String2Length(str(self.input_data)) - @dataclass - class String2Length(BaseNode): - input_data: str - async def run(self, ctx: GraphRunContext) -> Double: - return Double(len(self.input_data)) +@dataclass +class String2Length(BaseNode): + input_data: str - @dataclass - class Double(BaseNode[None, None, int]): - input_data: int + async def run(self, ctx: GraphRunContext) -> Double: + return Double(len(self.input_data)) + + +@dataclass +class Double(BaseNode[None, None, int]): + input_data: int + + async def run(self, ctx: GraphRunContext) -> Union[String2Length, End[int]]: # noqa: UP007 + if self.input_data == 7: + return String2Length('x' * 21) + else: + return End(self.input_data * 2) + + +async def test_graph(): + my_graph = Graph(nodes=(Float2String, String2Length, Double)) + assert my_graph.name is None + assert my_graph._get_state_type() is type(None) + assert my_graph._get_run_end_type() is int + result = await my_graph.run(Float2String(3.14)) + # len('3.14') * 2 == 8 + assert result == 8 + assert my_graph.name == 'my_graph' - async def run(self, ctx: GraphRunContext) -> Union[String2Length, End[int]]: # noqa: UP007 - if self.input_data == 7: - return String2Length('x' * 21) - else: - return End(self.input_data * 2) +async def test_graph_history(): my_graph = Graph[None, None, int](nodes=(Float2String, String2Length, Double)) assert my_graph.name is None assert my_graph._get_state_type() is type(None) assert my_graph._get_run_end_type() is int - result, history = await my_graph.run(Float2String(3.14)) + sp = FullStatePersistence() + result = await my_graph.run(Float2String(3.14), persistence=sp) # len('3.14') * 2 == 8 assert result == 8 assert my_graph.name == 'my_graph' - assert history == snapshot( + assert sp.history == snapshot( [ NodeSnapshot( state=None, @@ -81,13 +95,14 @@ async def run(self, ctx: GraphRunContext) -> Union[String2Length, End[int]]: # start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - EndSnapshot(result=End(data=8), ts=IsNow(tz=timezone.utc)), + EndSnapshot(state=None, result=End(8), ts=IsNow(tz=timezone.utc)), ] ) - result, history = await my_graph.run(Float2String(3.14159)) + sp = FullStatePersistence() + result = await my_graph.run(Float2String(3.14159), persistence=sp) # len('3.14159') == 7, 21 * 2 == 42 assert result == 42 - assert history == snapshot( + assert sp.history == snapshot( [ NodeSnapshot( state=None, @@ -119,10 +134,10 @@ async def run(self, ctx: GraphRunContext) -> Union[String2Length, End[int]]: # start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - EndSnapshot(result=End(data=42), ts=IsNow(tz=timezone.utc)), + EndSnapshot(state=None, result=End(42), ts=IsNow(tz=timezone.utc)), ] ) - assert [e.data_snapshot() for e in history] == snapshot( + assert [e.node for e in sp.history] == snapshot( [ Float2String(input_data=3.14159), String2Length(input_data='3.14159'), @@ -283,19 +298,19 @@ async def run(self, ctx: GraphRunContext) -> Foo: g = Graph(nodes=(Foo, Bar)) assert g.name is None - history: list[Snapshot[None, Never]] = [] - n = await g.next(Foo(), history) + sp = FullStatePersistence() + n = await g.next(Foo(), persistence=sp) assert n == Bar() assert g.name == 'g' - assert history == snapshot( + assert sp.history == snapshot( [NodeSnapshot(state=None, node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat())] ) assert isinstance(n, Bar) - n2 = await g.next(n, history) + n2 = await g.next(n, persistence=sp) assert n2 == Foo() - assert history == snapshot( + assert sp.history == snapshot( [ NodeSnapshot(state=None, node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), NodeSnapshot(state=None, node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), @@ -322,13 +337,14 @@ async def run(self, ctx: GraphRunContext[None, Deps]) -> End[int]: return End(123) g = Graph(nodes=(Foo, Bar)) - result, history = await g.run(Foo(), deps=Deps(1, 2)) + sp = FullStatePersistence() + result = await g.run(Foo(), deps=Deps(1, 2), persistence=sp) assert result == 123 - assert history == snapshot( + assert sp.history == snapshot( [ NodeSnapshot(state=None, node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), NodeSnapshot(state=None, node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), - EndSnapshot(result=End(data=123), ts=IsNow(tz=timezone.utc)), + EndSnapshot(state=None, result=End(123), ts=IsNow(tz=timezone.utc)), ] ) diff --git a/tests/graph/test_history.py b/tests/graph/test_history.py index 36d802f542..8dcfda570e 100644 --- a/tests/graph/test_history.py +++ b/tests/graph/test_history.py @@ -1,15 +1,21 @@ # pyright: reportPrivateUsage=false from __future__ import annotations as _annotations -import json from dataclasses import dataclass -from datetime import datetime, timezone +from datetime import timezone import pytest -from dirty_equals import IsStr from inline_snapshot import snapshot -from pydantic_graph import BaseNode, End, EndSnapshot, Graph, GraphRunContext, GraphSetupError, NodeSnapshot +from pydantic_graph import ( + BaseNode, + End, + EndSnapshot, + FullStatePersistence, + Graph, + GraphRunContext, + NodeSnapshot, +) from ..conftest import IsFloat, IsNow @@ -46,79 +52,80 @@ async def run(self, ctx: GraphRunContext[MyState]) -> End[int]: ], ) async def test_dump_load_history(graph: Graph[MyState, None, int]): - result, history = await graph.run(Foo(), state=MyState(1, '')) + sp = FullStatePersistence.from_types(MyState, int) + result = await graph.run(Foo(), state=MyState(1, ''), persistence=sp) assert result == snapshot(4) - assert history == snapshot( - [ - NodeSnapshot(state=MyState(x=2, y=''), node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), - NodeSnapshot(state=MyState(x=2, y='y'), node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), - EndSnapshot(result=End(4), ts=IsNow(tz=timezone.utc)), - ] - ) - history_json = graph.dump_history(history) - assert json.loads(history_json) == snapshot( - [ - { - 'state': {'x': 2, 'y': ''}, - 'node': {'node_id': 'Foo'}, - 'start_ts': IsStr(regex=r'20\d\d-\d\d-\d\dT.+'), - 'duration': IsFloat(), - 'kind': 'node', - }, - { - 'state': {'x': 2, 'y': 'y'}, - 'node': {'node_id': 'Bar'}, - 'start_ts': IsStr(regex=r'20\d\d-\d\d-\d\dT.+'), - 'duration': IsFloat(), - 'kind': 'node', - }, - {'result': {'data': 4}, 'ts': IsStr(regex=r'20\d\d-\d\d-\d\dT.+'), 'kind': 'end'}, - ] - ) - history_loaded = graph.load_history(history_json) - assert history == history_loaded - - custom_history = [ - { - 'state': {'x': 2, 'y': ''}, - 'node': {'node_id': 'Foo'}, - 'start_ts': '2025-01-01T00:00:00Z', - 'duration': 123, - 'kind': 'node', - }, - {'result': {'data': '42'}, 'ts': '2025-01-01T00:00:00Z', 'kind': 'end'}, - ] - history_loaded = graph.load_history(json.dumps(custom_history)) - assert history_loaded == snapshot( - [ - NodeSnapshot( - state=MyState(x=2, y=''), - node=Foo(), - start_ts=datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc), - duration=123.0, - ), - EndSnapshot(result=End(data=42), ts=datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc)), - ] - ) - - -def test_one_node(): - @dataclass - class MyNode(BaseNode[None, None, int]): - async def run(self, ctx: GraphRunContext) -> End[int]: - return End(123) - - g = Graph(nodes=[MyNode]) - - custom_history = [ - {'result': {'data': '123'}, 'ts': '2025-01-01T00:00:00Z', 'kind': 'end'}, - ] - history_loaded = g.load_history(json.dumps(custom_history)) - assert history_loaded == snapshot( + assert sp.history == snapshot( [ - EndSnapshot(result=End(data=123), ts=datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc)), + NodeSnapshot(state=MyState(x=1, y=''), node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), + NodeSnapshot(state=MyState(x=2, y=''), node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), + EndSnapshot(state=MyState(x=2, y='y'), result=End(4), ts=IsNow(tz=timezone.utc)), ] ) + # history_json = graph.dump_history(history) + # assert json.loads(history_json) == snapshot( + # [ + # { + # 'state': {'x': 2, 'y': ''}, + # 'node': {'node_id': 'Foo'}, + # 'start_ts': IsStr(regex=r'20\d\d-\d\d-\d\dT.+'), + # 'duration': IsFloat(), + # 'kind': 'node', + # }, + # { + # 'state': {'x': 2, 'y': 'y'}, + # 'node': {'node_id': 'Bar'}, + # 'start_ts': IsStr(regex=r'20\d\d-\d\d-\d\dT.+'), + # 'duration': IsFloat(), + # 'kind': 'node', + # }, + # {'result': {'data': 4}, 'ts': IsStr(regex=r'20\d\d-\d\d-\d\dT.+'), 'kind': 'end'}, + # ] + # ) + # history_loaded = graph.load_history(history_json) + # assert history == history_loaded + # + # custom_history = [ + # { + # 'state': {'x': 2, 'y': ''}, + # 'node': {'node_id': 'Foo'}, + # 'start_ts': '2025-01-01T00:00:00Z', + # 'duration': 123, + # 'kind': 'node', + # }, + # {'result': {'data': '42'}, 'ts': '2025-01-01T00:00:00Z', 'kind': 'end'}, + # ] + # history_loaded = graph.load_history(json.dumps(custom_history)) + # assert history_loaded == snapshot( + # [ + # NodeSnapshot( + # state=MyState(x=2, y=''), + # node=Foo(), + # start_ts=datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc), + # duration=123.0, + # ), + # EndSnapshot(result=End(data=42), ts=datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc)), + # ] + # ) + + +# def test_one_node(): +# @dataclass +# class MyNode(BaseNode[None, None, int]): +# async def run(self, ctx: GraphRunContext) -> End[int]: +# return End(123) +# +# g = Graph(nodes=[MyNode]) +# +# custom_history = [ +# {'result': {'data': '123'}, 'ts': '2025-01-01T00:00:00Z', 'kind': 'end'}, +# ] +# history_loaded = g.load_history(json.dumps(custom_history)) +# assert history_loaded == snapshot( +# [ +# EndSnapshot(result=End(data=123), ts=datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc)), +# ] +# ) def test_no_generic_arg(): @@ -129,18 +136,17 @@ async def run(self, ctx: GraphRunContext) -> NoGenericArgsNode: g = Graph(nodes=[NoGenericArgsNode]) assert g._get_state_type() is type(None) - with pytest.raises(GraphSetupError, match='Could not infer run end type from nodes, please set `run_end_type`.'): - g._get_run_end_type() + assert g._get_run_end_type() is type(None) g = Graph(nodes=[NoGenericArgsNode], run_end_type=None) # pyright: ignore[reportArgumentType] assert g._get_run_end_type() is None - custom_history = [ - {'result': {'data': None}, 'ts': '2025-01-01T00:00:00Z', 'kind': 'end'}, - ] - history_loaded = g.load_history(json.dumps(custom_history)) - assert history_loaded == snapshot( - [ - EndSnapshot(result=End(data=None), ts=datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc)), - ] - ) + # custom_history = [ + # {'result': {'data': None}, 'ts': '2025-01-01T00:00:00Z', 'kind': 'end'}, + # ] + # history_loaded = g.load_history(json.dumps(custom_history)) + # assert history_loaded == snapshot( + # [ + # EndSnapshot(result=End(data=None), ts=datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc)), + # ] + # ) diff --git a/tests/graph/test_mermaid.py b/tests/graph/test_mermaid.py index b90b87a06d..15e707fddb 100644 --- a/tests/graph/test_mermaid.py +++ b/tests/graph/test_mermaid.py @@ -11,7 +11,17 @@ import pytest from inline_snapshot import snapshot -from pydantic_graph import BaseNode, Edge, End, EndSnapshot, Graph, GraphRunContext, GraphSetupError, NodeSnapshot +from pydantic_graph import ( + BaseNode, + Edge, + End, + EndSnapshot, + FullStatePersistence, + Graph, + GraphRunContext, + GraphSetupError, + NodeSnapshot, +) from pydantic_graph.nodes import NodeDef from ..conftest import IsFloat, IsNow @@ -58,9 +68,10 @@ async def run(self, ctx: GraphRunContext) -> Annotated[End[None], Edge(label='eg async def test_run_graph(): - result, history = await graph1.run(Foo()) + sp = FullStatePersistence() + result = await graph1.run(Foo(), persistence=sp) assert result is None - assert history == snapshot( + assert sp.history == snapshot( [ NodeSnapshot( state=None, @@ -74,7 +85,7 @@ async def test_run_graph(): start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - EndSnapshot(result=End(data=None), ts=IsNow(tz=timezone.utc)), + EndSnapshot(state=None, result=End(None), ts=IsNow(tz=timezone.utc)), ] ) diff --git a/tests/graph/test_state.py b/tests/graph/test_state.py index 7eec60824f..2bdfc03f6b 100644 --- a/tests/graph/test_state.py +++ b/tests/graph/test_state.py @@ -1,4 +1,3 @@ -# pyright: reportPrivateUsage=false from __future__ import annotations as _annotations from dataclasses import dataclass @@ -7,7 +6,15 @@ import pytest from inline_snapshot import snapshot -from pydantic_graph import BaseNode, End, EndSnapshot, Graph, GraphRunContext, NodeSnapshot +from pydantic_graph import ( + BaseNode, + End, + EndSnapshot, + FullStatePersistence, + Graph, + GraphRunContext, + NodeSnapshot, +) from ..conftest import IsFloat, IsNow @@ -33,26 +40,27 @@ async def run(self, ctx: GraphRunContext[MyState]) -> End[str]: return End(f'x={ctx.state.x} y={ctx.state.y}') graph = Graph(nodes=(Foo, Bar)) - assert graph._get_state_type() is MyState - assert graph._get_run_end_type() is str + assert graph._get_state_type() is MyState # pyright: ignore[reportPrivateUsage] + assert graph._get_run_end_type() is str # pyright: ignore[reportPrivateUsage] state = MyState(1, '') - result, history = await graph.run(Foo(), state=state) + sp = FullStatePersistence.from_types(MyState, str) + result = await graph.run(Foo(), state=state, persistence=sp) assert result == snapshot('x=2 y=y') - assert history == snapshot( + assert sp.history == snapshot( [ NodeSnapshot( - state=MyState(x=2, y=''), + state=MyState(x=1, y=''), node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), NodeSnapshot( - state=MyState(x=2, y='y'), + state=MyState(x=2, y=''), node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - EndSnapshot(result=End('x=2 y=y'), ts=IsNow(tz=timezone.utc)), + EndSnapshot(state=MyState(x=2, y='y'), result=End('x=2 y=y'), ts=IsNow(tz=timezone.utc)), ] ) assert state == MyState(x=2, y='y') From b737c97ac0feb6ba0c9ac3a4fd25a24b08dff901 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sun, 23 Feb 2025 20:28:41 -0500 Subject: [PATCH 03/25] tests all passing --- docs/graph.md | 21 +- .../pydantic_ai_examples/question_graph.py | 37 ++-- pydantic_ai_slim/pydantic_ai/agent.py | 6 +- pydantic_graph/README.md | 5 +- pydantic_graph/pydantic_graph/graph.py | 84 ++++---- .../pydantic_graph/state/__init__.py | 34 +-- pydantic_graph/pydantic_graph/state/_utils.py | 17 +- pydantic_graph/pydantic_graph/state/memory.py | 38 ++-- tests/graph/test_graph.py | 11 +- tests/graph/test_history.py | 152 ------------- tests/graph/test_persistence.py | 202 ++++++++++++++++++ tests/graph/test_state.py | 5 +- tests/typed_graph.py | 29 ++- 13 files changed, 363 insertions(+), 278 deletions(-) delete mode 100644 tests/graph/test_history.py create mode 100644 tests/graph/test_persistence.py diff --git a/docs/graph.md b/docs/graph.md index b62d67e890..b4e68930db 100644 --- a/docs/graph.md +++ b/docs/graph.md @@ -156,12 +156,9 @@ class Increment(BaseNode): # (2)! fives_graph = Graph(nodes=[DivisibleBy5, Increment]) # (3)! -result, history = fives_graph.run_sync(DivisibleBy5(4)) # (4)! +result = fives_graph.run_sync(DivisibleBy5(4)) # (4)! print(result) #> 5 -# the full history is quite verbose (see below), so we'll just print the summary -print([item.data_snapshot() for item in history]) -#> [DivisibleBy5(foo=4), Increment(foo=4), DivisibleBy5(foo=5), End(data=5)] ``` 1. The `DivisibleBy5` node is parameterized with `None` for the state param and `None` for the deps param as this graph doesn't use state or deps, and `int` as it can end the run. @@ -464,7 +461,7 @@ async def main(): ) state = State(user) feedback_graph = Graph(nodes=(WriteEmail, Feedback)) - email, _ = await feedback_graph.run(WriteEmail(), state=state) + email = await feedback_graph.run(WriteEmail(), state=state) print(email) """ Email( @@ -579,7 +576,7 @@ In this example, an AI asks the user a question, the user provides an answer, th ```python {title="ai_q_and_a_run.py" noqa="I001" py="3.10"} from rich.prompt import Prompt -from pydantic_graph import End, Snapshot +from pydantic_graph import End, FullStatePersistence from ai_q_and_a_graph import Ask, question_graph, QuestionState, Answer @@ -587,15 +584,17 @@ from ai_q_and_a_graph import Ask, question_graph, QuestionState, Answer async def main(): state = QuestionState() # (1)! node = Ask() # (2)! - history: list[Snapshot[QuestionState]] = [] # (3)! + persistence = FullStatePersistence() # (3)! while True: - node = await question_graph.next(node, history, state=state) # (4)! + node = await question_graph.next( # (4)! + node, persistence=persistence, state=state + ) if isinstance(node, Answer): node.answer = Prompt.ask(node.question) # (5)! elif isinstance(node, End): # (6)! print(f'Correct answer! {node.data}') - # > Correct answer! Well done, 1 + 1 = 2 - print([e.data_snapshot() for e in history]) + #> Correct answer! Well done, 1 + 1 = 2 + print([e.node for e in persistence.history]) """ [ Ask(), @@ -613,7 +612,7 @@ async def main(): 1. Create the state object which will be mutated by [`next`][pydantic_graph.graph.Graph.next]. 2. The start node is `Ask` but will be updated by [`next`][pydantic_graph.graph.Graph.next] as the graph runs. -3. The history of the graph run is stored in a list of [`HistoryStep`][pydantic_graph.state.HistoryStep] objects. Again [`next`][pydantic_graph.graph.Graph.next] will update this list in place. +3. The history of the graph run is stored using [`FullStatePersistence`][pydantic_graph.state.memory.FullStatePersistence]. Again [`next`][pydantic_graph.graph.Graph.next] will update this list in place. 4. [Run][pydantic_graph.graph.Graph.next] the graph one node at a time, updating the state, current node and history as the graph runs. 5. If the current node is an `Answer` node, prompt the user for an answer. 6. Since we're using [`next`][pydantic_graph.graph.Graph.next] we have to manually check for an [`End`][pydantic_graph.nodes.End] and exit the loop if we get one. diff --git a/examples/pydantic_ai_examples/question_graph.py b/examples/pydantic_ai_examples/question_graph.py index 957e5d75b1..78f751b5df 100644 --- a/examples/pydantic_ai_examples/question_graph.py +++ b/examples/pydantic_ai_examples/question_graph.py @@ -13,7 +13,14 @@ import logfire from devtools import debug -from pydantic_graph import BaseNode, Edge, End, Graph, GraphRunContext, Snapshot +from pydantic_graph import ( + BaseNode, + Edge, + End, + FullStatePersistence, + Graph, + GraphRunContext, +) from pydantic_ai import Agent from pydantic_ai.format_as_xml import format_as_xml @@ -116,12 +123,12 @@ async def run(self, ctx: GraphRunContext[QuestionState]) -> Ask: async def run_as_continuous(): state = QuestionState() node = Ask() - history: list[Snapshot[QuestionState, None]] = [] + persistence = FullStatePersistence() with logfire.span('run questions graph'): while True: - node = await question_graph.next(node, history, state=state) + node = await question_graph.next(node, persistence=persistence, state=state) if isinstance(node, End): - debug([e.data_snapshot() for e in history]) + debug([e.node for e in persistence.history]) break elif isinstance(node, Answer): assert state.question @@ -131,14 +138,14 @@ async def run_as_continuous(): async def run_as_cli(answer: str | None): history_file = Path('question_graph_history.json') - history = ( - question_graph.load_history(history_file.read_bytes()) - if history_file.exists() - else [] - ) - - if history: - last = history[-1] + persistence = FullStatePersistence() + question_graph.set_persistence_types(persistence) + + if history_file.exists(): + persistence.load_json(history_file.read_bytes()) + + if persistence.history: + last = persistence.history[-1] assert last.kind == 'node', 'expected last step to be a node' state = last.state assert answer is not None, 'answer is required to continue from history' @@ -150,9 +157,9 @@ async def run_as_cli(answer: str | None): with logfire.span('run questions graph'): while True: - node = await question_graph.next(node, history, state=state) + node = await question_graph.next(node, persistence=persistence, state=state) if isinstance(node, End): - debug([e.data_snapshot() for e in history]) + debug([e.node for e in persistence.history]) print('Finished!') break elif isinstance(node, Answer): @@ -160,7 +167,7 @@ async def run_as_cli(answer: str | None): break # otherwise just continue - history_file.write_bytes(question_graph.dump_history(history, indent=2)) + history_file.write_bytes(persistence.dump_json(indent=2)) if __name__ == '__main__': diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 3978d0138d..edff583d9a 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -11,7 +11,7 @@ import logfire_api from typing_extensions import TypeVar, deprecated -from pydantic_graph import Graph, GraphRunContext, Snapshot +from pydantic_graph import Graph, GraphRunContext from pydantic_graph.nodes import End from . import ( @@ -337,7 +337,7 @@ async def main(): ) # Actually run - end_result, _ = await graph.run( + end_result = await graph.run( start_node, state=state, deps=graph_deps, @@ -583,7 +583,6 @@ async def main(): # Actually run node = start_node - history: list[Snapshot[_agent_graph.GraphAgentState, RunResultDataT]] = [] while True: if isinstance(node, _agent_graph.StreamModelRequestNode): node = cast( @@ -599,7 +598,6 @@ async def main(): assert not isinstance(node, End) # the previous line should be hit first node = await graph.next( node, - history, state=graph_state, deps=graph_deps, infer_name=False, diff --git a/pydantic_graph/README.md b/pydantic_graph/README.md index 15a4062e05..456dd676f0 100644 --- a/pydantic_graph/README.md +++ b/pydantic_graph/README.md @@ -50,10 +50,7 @@ class Increment(BaseNode): fives_graph = Graph(nodes=[DivisibleBy5, Increment]) -result, history = fives_graph.run_sync(DivisibleBy5(4)) +result = fives_graph.run_sync(DivisibleBy5(4)) print(result) #> 5 -# the full history is quite verbose (see below), so we'll just print the summary -print([item.data_snapshot() for item in history]) -#> [DivisibleBy5(foo=4), Increment(foo=4), DivisibleBy5(foo=5), End(data=5)] ``` diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index d4b79497d6..9ca70e54e9 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -10,13 +10,11 @@ from typing import TYPE_CHECKING, Any, Generic, TypeVar import logfire_api -import pydantic import typing_extensions -from inline_snapshot import Snapshot from . import _utils, exceptions, mermaid from .nodes import BaseNode, DepsT, End, GraphRunContext, NodeDef, RunEndT -from .state import StatePersistence, StateT, build_nodes_type_adapter +from .state import StatePersistence, StateT, set_nodes_type_context from .state.memory import SimpleStatePersistence # while waiting for https://github.com/pydantic/logfire/issues/745 @@ -150,18 +148,14 @@ async def run( async def main(): state = MyState(1) - _, history = await never_42_graph.run(Increment(), state=state) + await never_42_graph.run(Increment(), state=state) print(state) #> MyState(number=2) - print(len(history)) - #> 3 state = MyState(41) - _, history = await never_42_graph.run(Increment(), state=state) + await never_42_graph.run(Increment(), state=state) print(state) #> MyState(number=43) - print(len(history)) - #> 5 ``` """ if infer_name and self.name is None: @@ -170,6 +164,8 @@ async def main(): if persistence is None: persistence = SimpleStatePersistence() + self.set_persistence_types(persistence) + with ExitStack() as stack: if self._auto_instrument: stack.enter_context( @@ -178,9 +174,7 @@ async def main(): next_node = start_node while True: - next_node = await self.next( - next_node, persistence=persistence, state=state, deps=deps, infer_name=False - ) + next_node = await self._next(next_node, persistence, state, deps) if isinstance(next_node, End): await persistence.snapshot_end(state, next_node) return next_node.data @@ -251,7 +245,17 @@ async def next( if persistence is None: persistence = SimpleStatePersistence() + self.set_persistence_types(persistence) + + return await self._next(node, persistence, state, deps) + async def _next( + self: Graph[StateT, DepsT, T], + node: BaseNode[StateT, DepsT, T], + persistence: StatePersistence[StateT, T], + state: StateT, + deps: DepsT, + ) -> BaseNode[StateT, DepsT, Any] | End[T]: node_id = node.get_id() if node_id not in self.node_defs: raise exceptions.GraphRuntimeError(f'Node `{node}` is not in the graph.') @@ -285,6 +289,10 @@ async def next_from_persistence( infer_name=False, ) + def set_persistence_types(self, persistence: StatePersistence[StateT, RunEndT]) -> None: + with set_nodes_type_context([node_def.node for node_def in self.node_defs.values()]): + persistence.set_types(lambda: self._inferred_types) + def mermaid_code( self, *, @@ -409,50 +417,36 @@ def mermaid_save( mermaid.save_image(path, self, **kwargs) @cached_property - def _node_type_adapter(self) -> pydantic.TypeAdapter[BaseNode[StateT, Any, RunEndT]]: - nodes = [node_def.node for node_def in self.node_defs.values()] - return build_nodes_type_adapter(nodes, self._get_state_type(), self._get_run_end_type()) - - @cached_property - def _end_data_type_adapter(self) -> pydantic.TypeAdapter[RunEndT]: - end_t = self._get_run_end_type() - return pydantic.TypeAdapter(end_t) + def _inferred_types(self) -> tuple[type[StateT], type[RunEndT]]: + if _utils.is_set(self._state_type) and _utils.is_set(self._run_end_type): + return self._state_type, self._run_end_type - @cached_property - def _snapshot_type_adapter(self) -> pydantic.TypeAdapter[Snapshot[StateT, RunEndT]]: - pass - - def _get_state_type(self) -> type[StateT]: - if _utils.is_set(self._state_type): - return self._state_type + state_type = self._state_type + run_end_type = self._run_end_type for node_def in self.node_defs.values(): for base in typing_extensions.get_original_bases(node_def.node): if typing_extensions.get_origin(base) is BaseNode: args = typing_extensions.get_args(base) - if args: - return args[0] - # break the inner (bases) loop - break - # state defaults to None, so use that if we can't infer it - return type(None) # pyright: ignore[reportReturnType] - - def _get_run_end_type(self) -> type[RunEndT]: - if _utils.is_set(self._run_end_type): - return self._run_end_type + if not _utils.is_set(state_type) and args: + state_type = args[0] - for node_def in self.node_defs.values(): - for base in typing_extensions.get_original_bases(node_def.node): - if typing_extensions.get_origin(base) is BaseNode: - args = typing_extensions.get_args(base) - if len(args) == 3: + if not _utils.is_set(run_end_type) and len(args) == 3: t = args[2] if not _utils.is_never(t): - return t + run_end_type = t + if _utils.is_set(state_type) and _utils.is_set(run_end_type): + return state_type, run_end_type # break the inner (bases) loop break - # this happens if a graph has no return nodes, use None so any downstream errors a clear - return type(None) # pyright: ignore[reportReturnType] + + if not _utils.is_set(state_type): + # state defaults to None, so use that if we can't infer it + state_type = None + if not _utils.is_set(run_end_type): + # this happens if a graph has no return nodes, use None so any downstream errors are clear + run_end_type = None + return state_type, run_end_type # pyright: ignore[reportReturnType] def _register_node( self: Graph[StateT, DepsT, T], diff --git a/pydantic_graph/pydantic_graph/state/__init__.py b/pydantic_graph/pydantic_graph/state/__init__.py index 2ef10cd138..521e608572 100644 --- a/pydantic_graph/pydantic_graph/state/__init__.py +++ b/pydantic_graph/pydantic_graph/state/__init__.py @@ -1,8 +1,8 @@ from __future__ import annotations as _annotations from abc import ABC, abstractmethod -from collections.abc import AsyncIterator, Sequence -from contextlib import asynccontextmanager +from collections.abc import AsyncIterator, Iterator, Sequence +from contextlib import asynccontextmanager, contextmanager from dataclasses import dataclass, field from datetime import datetime from typing import Annotated, Any, Callable, Generic, Literal, Union @@ -14,7 +14,7 @@ from ..nodes import BaseNode, End, RunEndT from . import _utils -__all__ = 'StateT', 'NodeSnapshot', 'EndSnapshot', 'Snapshot', 'StatePersistence', 'build_nodes_type_adapter' +__all__ = 'StateT', 'NodeSnapshot', 'EndSnapshot', 'Snapshot', 'StatePersistence', 'set_nodes_type_context' StateT = TypeVar('StateT', default=None) """Type variable for the state in a graph.""" @@ -26,7 +26,7 @@ class NodeSnapshot(Generic[StateT, RunEndT]): state: StateT """The state of the graph before the node is run.""" - node: BaseNode[StateT, Any, RunEndT] + node: Annotated[BaseNode[StateT, Any, RunEndT], _utils.CustomNodeSchema()] """The node to run next.""" start_ts: datetime | None = None """The timestamp when the node started running, `None` until the run starts.""" @@ -112,13 +112,7 @@ async def restore(self) -> Snapshot[StateT, RunEndT] | None: """ raise NotImplementedError - def set_type_adapters( - self, - *, - get_node_type_adapter: Callable[[], pydantic.TypeAdapter[BaseNode[StateT, Any, RunEndT]]], - get_end_data_type_adapter: Callable[[], pydantic.TypeAdapter[RunEndT]], - get_snapshot_type_adapter: Callable[[], pydantic.TypeAdapter[Snapshot[StateT, RunEndT]]], - ): + def set_types(self, get_types: Callable[[], tuple[type[StateT], type[RunEndT]]]) -> None: pass async def restore_node_snapshot(self) -> NodeSnapshot[StateT, RunEndT]: @@ -130,10 +124,18 @@ async def restore_node_snapshot(self) -> NodeSnapshot[StateT, RunEndT]: return snapshot -def build_nodes_type_adapter( # noqa: D103 - nodes: Sequence[type[BaseNode[Any, Any, Any]]], state_t: type[StateT], end_t: type[RunEndT] -) -> pydantic.TypeAdapter[BaseNode[StateT, Any, RunEndT]]: +@contextmanager +def set_nodes_type_context(nodes: Sequence[type[BaseNode[Any, Any, Any]]]) -> Iterator[None]: # noqa: D103 + token = _utils.nodes_type_context.set(nodes) + try: + yield + finally: + _utils.nodes_type_context.reset(token) + + +def build_snapshots_type_adapter( + state_t: type[StateT], run_end_t: type[RunEndT] +) -> pydantic.TypeAdapter[list[Snapshot[StateT, RunEndT]]]: return pydantic.TypeAdapter( - Annotated[BaseNode[state_t, Any, end_t], _utils.CustomNodeSchema(nodes)], - config=pydantic.ConfigDict(defer_build=True), + list[Annotated[Snapshot[state_t, run_end_t], pydantic.Discriminator('kind')]], ) diff --git a/pydantic_graph/pydantic_graph/state/_utils.py b/pydantic_graph/pydantic_graph/state/_utils.py index b854df8d14..eaaf2d2c92 100644 --- a/pydantic_graph/pydantic_graph/state/_utils.py +++ b/pydantic_graph/pydantic_graph/state/_utils.py @@ -1,6 +1,7 @@ from __future__ import annotations as _annotations from collections.abc import Sequence +from contextvars import ContextVar from dataclasses import dataclass from datetime import datetime, timezone from typing import Annotated, Any, Union @@ -10,18 +11,24 @@ from ..nodes import BaseNode +nodes_type_context: ContextVar[Sequence[type[BaseNode[Any, Any, Any]]]] = ContextVar('nodes_type_context') + @dataclass class CustomNodeSchema: - nodes: Sequence[type[BaseNode[Any, Any, Any]]] - def __get_pydantic_core_schema__( self, _source_type: Any, handler: pydantic.GetCoreSchemaHandler ) -> core_schema.CoreSchema: - if len(self.nodes) == 1: - nodes_type = self.nodes[0] + try: + nodes = nodes_type_context.get() + except LookupError as e: + raise RuntimeError( + 'Unable to build a Pydantic schema for `BaseNode` without setting `nodes_type_context`.' + ) from e + if len(nodes) == 1: + nodes_type = nodes[0] else: - nodes_annotated = [Annotated[node, pydantic.Tag(node.get_id())] for node in self.nodes] + nodes_annotated = [Annotated[node, pydantic.Tag(node.get_id())] for node in nodes] nodes_type = Annotated[Union[tuple(nodes_annotated)], pydantic.Discriminator(self._node_discriminator)] schema = handler(nodes_type) diff --git a/pydantic_graph/pydantic_graph/state/memory.py b/pydantic_graph/pydantic_graph/state/memory.py index 3a83eba241..4697460ee0 100644 --- a/pydantic_graph/pydantic_graph/state/memory.py +++ b/pydantic_graph/pydantic_graph/state/memory.py @@ -5,10 +5,16 @@ from contextlib import asynccontextmanager from dataclasses import dataclass, field from time import perf_counter -from typing import Any, TypeVar +from typing import Any, Callable -from ..nodes import BaseNode, End, RunEndT -from . import EndSnapshot, NodeSnapshot, Snapshot, StatePersistence, StateT, _utils +import pydantic +from typing_extensions import TypeVar + +from ..nodes import BaseNode, End +from . import EndSnapshot, NodeSnapshot, Snapshot, StatePersistence, _utils, build_snapshots_type_adapter + +StateT = TypeVar('StateT', default=Any) +RunEndT = TypeVar('RunEndT', default=Any) S = TypeVar('S') R = TypeVar('R') @@ -21,11 +27,6 @@ class SimpleStatePersistence(StatePersistence[StateT, RunEndT]): deep_copy: bool = True last_snapshot: Snapshot[StateT, RunEndT] | None = None - @classmethod - def from_types(cls, state_type: type[S], run_end_type: type[R]) -> SimpleStatePersistence[S, R]: - """No-op init method that help type checkers.""" - return SimpleStatePersistence() - async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> None: self.last_snapshot = NodeSnapshot( state=self.prep_state(state), @@ -69,11 +70,9 @@ class FullStatePersistence(StatePersistence[StateT, RunEndT]): deep_copy: bool = True history: list[Snapshot[StateT, RunEndT]] = field(default_factory=list) - - @classmethod - def from_types(cls, state_type: type[S], run_end_type: type[R]) -> FullStatePersistence[S, R]: - """No-op init method that help type checkers.""" - return FullStatePersistence() + _snapshots_type_adapter: pydantic.TypeAdapter[list[Snapshot[StateT, RunEndT]]] | None = field( + default=None, init=False, repr=False + ) async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> None: self.history.append( @@ -109,6 +108,19 @@ async def restore(self) -> Snapshot[StateT, RunEndT] | None: if self.history: return self.history[-1] + def dump_json(self, *, indent: int | None = None) -> bytes: + assert self._snapshots_type_adapter is not None, 'type adapter must be set to use `dump_json`' + return self._snapshots_type_adapter.dump_json(self.history, indent=indent) + + def load_json(self, json_data: str | bytes | bytearray) -> None: + assert self._snapshots_type_adapter is not None, 'type adapter must be set to use `load_json`' + self.history = self._snapshots_type_adapter.validate_json(json_data) + + def set_types(self, get_types: Callable[[], tuple[type[StateT], type[RunEndT]]]) -> None: + if self._snapshots_type_adapter is None: + state_t, run_end_t = get_types() + self._snapshots_type_adapter = build_snapshots_type_adapter(state_t, run_end_t) + def prep_state(self, state: StateT) -> StateT: """Prepare state for snapshot, uses [`copy.deepcopy`][copy.deepcopy] by default.""" if not self.deep_copy or state is None: diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py index 6320ea89d3..4928b595a6 100644 --- a/tests/graph/test_graph.py +++ b/tests/graph/test_graph.py @@ -57,8 +57,7 @@ async def run(self, ctx: GraphRunContext) -> Union[String2Length, End[int]]: # async def test_graph(): my_graph = Graph(nodes=(Float2String, String2Length, Double)) assert my_graph.name is None - assert my_graph._get_state_type() is type(None) - assert my_graph._get_run_end_type() is int + assert my_graph._inferred_types == (type(None), int) result = await my_graph.run(Float2String(3.14)) # len('3.14') * 2 == 8 assert result == 8 @@ -68,8 +67,7 @@ async def test_graph(): async def test_graph_history(): my_graph = Graph[None, None, int](nodes=(Float2String, String2Length, Double)) assert my_graph.name is None - assert my_graph._get_state_type() is type(None) - assert my_graph._get_run_end_type() is int + assert my_graph._inferred_types == (type(None), int) sp = FullStatePersistence() result = await my_graph.run(Float2String(3.14), persistence=sp) # len('3.14') * 2 == 8 @@ -154,7 +152,7 @@ class Float2String(BaseNode): async def run(self, ctx: GraphRunContext) -> String2Length: raise NotImplementedError() - class String2Length(BaseNode[None, None, None]): + class String2Length(BaseNode[None, None, None]): # pyright: ignore[reportUnusedClass] async def run(self, ctx: GraphRunContext) -> End[None]: raise NotImplementedError() @@ -277,8 +275,7 @@ async def run(self, ctx: GraphRunContext) -> End[None]: return 42 # type: ignore g = Graph(nodes=(Foo, Bar)) - assert g._get_state_type() is type(None) - assert g._get_run_end_type() is type(None) + assert g._inferred_types == (type(None), type(None)) with pytest.raises(GraphRuntimeError) as exc_info: await g.run(Foo()) diff --git a/tests/graph/test_history.py b/tests/graph/test_history.py deleted file mode 100644 index 8dcfda570e..0000000000 --- a/tests/graph/test_history.py +++ /dev/null @@ -1,152 +0,0 @@ -# pyright: reportPrivateUsage=false -from __future__ import annotations as _annotations - -from dataclasses import dataclass -from datetime import timezone - -import pytest -from inline_snapshot import snapshot - -from pydantic_graph import ( - BaseNode, - End, - EndSnapshot, - FullStatePersistence, - Graph, - GraphRunContext, - NodeSnapshot, -) - -from ..conftest import IsFloat, IsNow - -pytestmark = pytest.mark.anyio - - -@dataclass -class MyState: - x: int - y: str - - -@dataclass -class Foo(BaseNode[MyState]): - async def run(self, ctx: GraphRunContext[MyState]) -> Bar: - ctx.state.x += 1 - return Bar() - - -@dataclass -class Bar(BaseNode[MyState, None, int]): - async def run(self, ctx: GraphRunContext[MyState]) -> End[int]: - ctx.state.y += 'y' - return End(ctx.state.x * 2) - - -@pytest.mark.parametrize( - 'graph', - [ - Graph(nodes=(Foo, Bar), state_type=MyState, run_end_type=int), - Graph(nodes=(Foo, Bar), state_type=MyState), - Graph(nodes=(Foo, Bar), run_end_type=int), - Graph(nodes=(Foo, Bar)), - ], -) -async def test_dump_load_history(graph: Graph[MyState, None, int]): - sp = FullStatePersistence.from_types(MyState, int) - result = await graph.run(Foo(), state=MyState(1, ''), persistence=sp) - assert result == snapshot(4) - assert sp.history == snapshot( - [ - NodeSnapshot(state=MyState(x=1, y=''), node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), - NodeSnapshot(state=MyState(x=2, y=''), node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), - EndSnapshot(state=MyState(x=2, y='y'), result=End(4), ts=IsNow(tz=timezone.utc)), - ] - ) - # history_json = graph.dump_history(history) - # assert json.loads(history_json) == snapshot( - # [ - # { - # 'state': {'x': 2, 'y': ''}, - # 'node': {'node_id': 'Foo'}, - # 'start_ts': IsStr(regex=r'20\d\d-\d\d-\d\dT.+'), - # 'duration': IsFloat(), - # 'kind': 'node', - # }, - # { - # 'state': {'x': 2, 'y': 'y'}, - # 'node': {'node_id': 'Bar'}, - # 'start_ts': IsStr(regex=r'20\d\d-\d\d-\d\dT.+'), - # 'duration': IsFloat(), - # 'kind': 'node', - # }, - # {'result': {'data': 4}, 'ts': IsStr(regex=r'20\d\d-\d\d-\d\dT.+'), 'kind': 'end'}, - # ] - # ) - # history_loaded = graph.load_history(history_json) - # assert history == history_loaded - # - # custom_history = [ - # { - # 'state': {'x': 2, 'y': ''}, - # 'node': {'node_id': 'Foo'}, - # 'start_ts': '2025-01-01T00:00:00Z', - # 'duration': 123, - # 'kind': 'node', - # }, - # {'result': {'data': '42'}, 'ts': '2025-01-01T00:00:00Z', 'kind': 'end'}, - # ] - # history_loaded = graph.load_history(json.dumps(custom_history)) - # assert history_loaded == snapshot( - # [ - # NodeSnapshot( - # state=MyState(x=2, y=''), - # node=Foo(), - # start_ts=datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc), - # duration=123.0, - # ), - # EndSnapshot(result=End(data=42), ts=datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc)), - # ] - # ) - - -# def test_one_node(): -# @dataclass -# class MyNode(BaseNode[None, None, int]): -# async def run(self, ctx: GraphRunContext) -> End[int]: -# return End(123) -# -# g = Graph(nodes=[MyNode]) -# -# custom_history = [ -# {'result': {'data': '123'}, 'ts': '2025-01-01T00:00:00Z', 'kind': 'end'}, -# ] -# history_loaded = g.load_history(json.dumps(custom_history)) -# assert history_loaded == snapshot( -# [ -# EndSnapshot(result=End(data=123), ts=datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc)), -# ] -# ) - - -def test_no_generic_arg(): - @dataclass - class NoGenericArgsNode(BaseNode): - async def run(self, ctx: GraphRunContext) -> NoGenericArgsNode: - return NoGenericArgsNode() - - g = Graph(nodes=[NoGenericArgsNode]) - assert g._get_state_type() is type(None) - assert g._get_run_end_type() is type(None) - - g = Graph(nodes=[NoGenericArgsNode], run_end_type=None) # pyright: ignore[reportArgumentType] - assert g._get_run_end_type() is None - - # custom_history = [ - # {'result': {'data': None}, 'ts': '2025-01-01T00:00:00Z', 'kind': 'end'}, - # ] - # history_loaded = g.load_history(json.dumps(custom_history)) - # assert history_loaded == snapshot( - # [ - # EndSnapshot(result=End(data=None), ts=datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc)), - # ] - # ) diff --git a/tests/graph/test_persistence.py b/tests/graph/test_persistence.py new file mode 100644 index 0000000000..0c9a830425 --- /dev/null +++ b/tests/graph/test_persistence.py @@ -0,0 +1,202 @@ +# pyright: reportPrivateUsage=false +from __future__ import annotations as _annotations + +import json +from dataclasses import dataclass +from datetime import datetime, timezone + +import pytest +from dirty_equals import IsStr +from inline_snapshot import snapshot + +from pydantic_graph import ( + BaseNode, + End, + EndSnapshot, + FullStatePersistence, + Graph, + GraphRunContext, + NodeSnapshot, +) + +from ..conftest import IsFloat, IsNow + +pytestmark = pytest.mark.anyio + + +@dataclass +class MyState: + x: int + y: str + + +@dataclass +class Foo(BaseNode[MyState]): + async def run(self, ctx: GraphRunContext[MyState]) -> Bar: + ctx.state.x += 1 + return Bar() + + +@dataclass +class Bar(BaseNode[MyState, None, int]): + async def run(self, ctx: GraphRunContext[MyState]) -> End[int]: + ctx.state.y += 'y' + return End(ctx.state.x * 2) + + +@pytest.mark.parametrize( + 'graph', + [ + Graph(nodes=(Foo, Bar), state_type=MyState, run_end_type=int), + Graph(nodes=(Foo, Bar), state_type=MyState), + Graph(nodes=(Foo, Bar), run_end_type=int), + Graph(nodes=(Foo, Bar)), + ], +) +async def test_dump_load_history(graph: Graph[MyState, None, int]): + sp = FullStatePersistence() + result = await graph.run(Foo(), state=MyState(1, ''), persistence=sp) + assert result == snapshot(4) + assert sp.history == snapshot( + [ + NodeSnapshot(state=MyState(x=1, y=''), node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), + NodeSnapshot(state=MyState(x=2, y=''), node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), + EndSnapshot(state=MyState(x=2, y='y'), result=End(4), ts=IsNow(tz=timezone.utc)), + ] + ) + history_json = sp.dump_json() + assert json.loads(history_json) == snapshot( + [ + { + 'state': {'x': 1, 'y': ''}, + 'node': {'node_id': 'Foo'}, + 'start_ts': IsStr(regex=r'20\d\d-\d\d-\d\dT.+'), + 'duration': IsFloat(), + 'kind': 'node', + }, + { + 'state': {'x': 2, 'y': ''}, + 'node': {'node_id': 'Bar'}, + 'start_ts': IsStr(regex=r'20\d\d-\d\d-\d\dT.+'), + 'duration': IsFloat(), + 'kind': 'node', + }, + { + 'state': {'x': 2, 'y': 'y'}, + 'result': {'data': 4}, + 'ts': IsStr(regex=r'20\d\d-\d\d-\d\dT.+'), + 'kind': 'end', + }, + ] + ) + + sp2 = FullStatePersistence() + graph.set_persistence_types(sp2) + + sp2.load_json(history_json) + assert sp.history == sp2.history + + custom_history = [ + { + 'state': {'x': 2, 'y': ''}, + 'node': {'node_id': 'Foo'}, + 'start_ts': '2025-01-01T00:00:00Z', + 'duration': 123, + 'kind': 'node', + }, + { + 'state': {'x': 42, 'y': 'new'}, + 'result': {'data': '42'}, + 'ts': '2025-01-01T00:00:00Z', + 'kind': 'end', + }, + ] + sp3 = FullStatePersistence() + graph.set_persistence_types(sp3) + sp3.load_json(json.dumps(custom_history)) + assert sp3.history == snapshot( + [ + NodeSnapshot( + state=MyState(x=2, y=''), + node=Foo(), + start_ts=datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc), + duration=123.0, + ), + EndSnapshot( + state=MyState(x=42, y='new'), result=End(data=42), ts=datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc) + ), + ] + ) + + +def test_one_node(): + @dataclass + class MyNode(BaseNode[None, None, int]): + node_field: int + + async def run(self, ctx: GraphRunContext) -> End[int]: + return End(123) + + g = Graph(nodes=[MyNode]) + + custom_history = [ + { + 'state': None, + 'node': {'node_id': 'MyNode', 'node_field': 42}, + 'start_ts': '2025-01-01T00:00:00Z', + 'duration': 123, + 'kind': 'node', + }, + ] + sp = FullStatePersistence() + g.set_persistence_types(sp) + sp.load_json(json.dumps(custom_history)) + assert sp.history == snapshot( + [ + NodeSnapshot( + state=None, + node=MyNode(node_field=42), + start_ts=datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc), + duration=123.0, + ) + ] + ) + + +def test_no_generic_arg(): + @dataclass + class NoGenericArgsNode(BaseNode): + async def run(self, ctx: GraphRunContext) -> NoGenericArgsNode: + return NoGenericArgsNode() + + g = Graph(nodes=[NoGenericArgsNode]) + assert g._inferred_types == (None, None) + + g = Graph(nodes=[NoGenericArgsNode], run_end_type=None) # pyright: ignore[reportArgumentType] + + assert g._inferred_types == (None, None) + + custom_history = [ + { + 'state': None, + 'node': {'node_id': 'NoGenericArgsNode'}, + 'start_ts': '2025-01-01T00:00:00Z', + 'duration': 123, + 'kind': 'node', + }, + ] + + sp = FullStatePersistence() + g.set_persistence_types(sp) + sp.load_json(json.dumps(custom_history)) + + assert sp.history == snapshot( + [ + NodeSnapshot( + state=None, + node=NoGenericArgsNode(), + start_ts=datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc), + duration=123.0, + ) + ] + ) diff --git a/tests/graph/test_state.py b/tests/graph/test_state.py index 2bdfc03f6b..ffc0637327 100644 --- a/tests/graph/test_state.py +++ b/tests/graph/test_state.py @@ -40,10 +40,9 @@ async def run(self, ctx: GraphRunContext[MyState]) -> End[str]: return End(f'x={ctx.state.x} y={ctx.state.y}') graph = Graph(nodes=(Foo, Bar)) - assert graph._get_state_type() is MyState # pyright: ignore[reportPrivateUsage] - assert graph._get_run_end_type() is str # pyright: ignore[reportPrivateUsage] + assert graph._inferred_types == (MyState, str) # pyright: ignore[reportPrivateUsage] state = MyState(1, '') - sp = FullStatePersistence.from_types(MyState, str) + sp = FullStatePersistence() result = await graph.run(Foo(), state=state, persistence=sp) assert result == snapshot('x=2 y=y') assert sp.history == snapshot( diff --git a/tests/typed_graph.py b/tests/typed_graph.py index b484ad0571..d1dd691e2d 100644 --- a/tests/typed_graph.py +++ b/tests/typed_graph.py @@ -1,10 +1,11 @@ from __future__ import annotations as _annotations from dataclasses import dataclass +from typing import Any from typing_extensions import assert_type -from pydantic_graph import BaseNode, End, Graph, GraphRunContext, Snapshot +from pydantic_graph import BaseNode, End, FullStatePersistence, Graph, GraphRunContext @dataclass @@ -109,6 +110,28 @@ def run_g5() -> None: g5.run_sync(A()) # pyright: ignore[reportArgumentType] g5.run_sync(A(), state=MyState(x=1)) # pyright: ignore[reportArgumentType] g5.run_sync(A(), deps=MyDeps(y='y')) # pyright: ignore[reportArgumentType] - answer, history = g5.run_sync(A(), state=MyState(x=1), deps=MyDeps(y='y')) + answer = g5.run_sync(A(), state=MyState(x=1), deps=MyDeps(y='y')) assert_type(answer, int) - assert_type(history, list[Snapshot[MyState, int]]) + + +p = FullStatePersistence() +assert_type(p, FullStatePersistence[Any, Any]) + + +def run_persistence_any() -> None: + p = FullStatePersistence() + answer = g5.run_sync(A(), persistence=p, state=MyState(x=1), deps=MyDeps(y='y')) + assert_type(answer, int) + assert_type(p, FullStatePersistence[Any, Any]) + + +def run_persistence_right() -> None: + p: FullStatePersistence[MyState, int] = FullStatePersistence() + answer = g5.run_sync(A(), persistence=p, state=MyState(x=1), deps=MyDeps(y='y')) + assert_type(answer, int) + assert_type(p, FullStatePersistence[MyState, int]) + + +def run_persistence_wrong() -> None: + p: FullStatePersistence[str, int] = FullStatePersistence() + g5.run_sync(A(), persistence=p, state=MyState(x=1), deps=MyDeps(y='y')) # type: ignore[arg-type] From 1c46cf977acf9f3369390c34f8765dafe453926d Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 26 Feb 2025 13:32:07 -0800 Subject: [PATCH 04/25] simplify --- pydantic_graph/pydantic_graph/state/memory.py | 3 --- tests/typed_graph.py | 4 ++-- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/pydantic_graph/pydantic_graph/state/memory.py b/pydantic_graph/pydantic_graph/state/memory.py index 4697460ee0..bca5b1b4fd 100644 --- a/pydantic_graph/pydantic_graph/state/memory.py +++ b/pydantic_graph/pydantic_graph/state/memory.py @@ -16,9 +16,6 @@ StateT = TypeVar('StateT', default=Any) RunEndT = TypeVar('RunEndT', default=Any) -S = TypeVar('S') -R = TypeVar('R') - @dataclass class SimpleStatePersistence(StatePersistence[StateT, RunEndT]): diff --git a/tests/typed_graph.py b/tests/typed_graph.py index d1dd691e2d..c08f404e7a 100644 --- a/tests/typed_graph.py +++ b/tests/typed_graph.py @@ -126,12 +126,12 @@ def run_persistence_any() -> None: def run_persistence_right() -> None: - p: FullStatePersistence[MyState, int] = FullStatePersistence() + p = FullStatePersistence[MyState, int]() answer = g5.run_sync(A(), persistence=p, state=MyState(x=1), deps=MyDeps(y='y')) assert_type(answer, int) assert_type(p, FullStatePersistence[MyState, int]) def run_persistence_wrong() -> None: - p: FullStatePersistence[str, int] = FullStatePersistence() + p = FullStatePersistence[str, int]() g5.run_sync(A(), persistence=p, state=MyState(x=1), deps=MyDeps(y='y')) # type: ignore[arg-type] From 456560d2a394dbbb04b2efaceef1476e3ad314cd Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 3 Mar 2025 15:24:29 +0000 Subject: [PATCH 05/25] fixing tests --- docs/graph.md | 20 +++++++++++-------- pydantic_graph/pydantic_graph/graph.py | 7 +++++++ pydantic_graph/pydantic_graph/state/memory.py | 2 +- tests/graph/test_graph.py | 6 +++--- tests/graph/test_mermaid.py | 2 +- tests/graph/test_persistence.py | 2 +- 6 files changed, 25 insertions(+), 14 deletions(-) diff --git a/docs/graph.md b/docs/graph.md index a8a8c5dcb3..f01b4e8b57 100644 --- a/docs/graph.md +++ b/docs/graph.md @@ -655,7 +655,7 @@ Here's an example: from __future__ import annotations as _annotations from dataclasses import dataclass -from pydantic_graph import Graph, BaseNode, End, GraphRunContext +from pydantic_graph import Graph, BaseNode, End, GraphRunContext, FullStatePersistence @dataclass @@ -677,7 +677,10 @@ count_down_graph = Graph(nodes=[CountDown]) async def main(): state = CountDownState(counter=3) - async with count_down_graph.iter(CountDown(), state=state) as run: # (1)! + persistence = FullStatePersistence() + async with count_down_graph.iter( + CountDown(), state=state, persistence=persistence + ) as run: # (1)! async for node in run: # (2)! print('Node:', node) #> Node: CountDown() @@ -686,7 +689,7 @@ async def main(): #> Node: End(data=0) print('Final result:', run.result.output) # (3)! #> Final result: 0 - print('History snapshots:', [step.data_snapshot() for step in run.history]) + print('History snapshots:', [step.node for step in persistence.history]) """ History snapshots: [CountDown(), CountDown(), CountDown(), CountDown(), End(data=0)] @@ -704,13 +707,14 @@ Alternatively, you can drive iteration manually with the [`GraphRun.next`][pydan Below is a contrived example that stops whenever the counter is at 2, ignoring any node runs beyond that: ```python {title="count_down_next.py" noqa="I001" py="3.10"} -from pydantic_graph import End +from pydantic_graph import End, FullStatePersistence from count_down import CountDown, CountDownState, count_down_graph async def main(): state = CountDownState(counter=5) - async with count_down_graph.iter(CountDown(), state=state) as run: + sp = FullStatePersistence() + async with count_down_graph.iter(CountDown(), state=state, persistence=sp) as run: node = run.next_node # (1)! while not isinstance(node, End): # (2)! print('Node:', node) @@ -725,11 +729,11 @@ async def main(): print(run.result) # (5)! #> None - for step in run.history: # (6)! - print('History Step:', step.data_snapshot(), step.state) + for step in sp.history: # (6)! + print('History Step:', step.node, step.state) + #> History Step: CountDown() CountDownState(counter=5) #> History Step: CountDown() CountDownState(counter=4) #> History Step: CountDown() CountDownState(counter=3) - #> History Step: CountDown() CountDownState(counter=2) ``` 1. We start by grabbing the first node that will be run in the agent's graph. diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 5d3ecf5eb2..e3713a9e47 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -322,6 +322,13 @@ async def _next( ctx = GraphRunContext(state, deps) async with persistence.record_run(): next_or_end = await node.run(ctx) + + if isinstance(next_or_end, End): + await persistence.snapshot_end(state, next_or_end) + elif not isinstance(next_or_end, BaseNode): + raise exceptions.GraphRuntimeError( + f'Invalid node return type: `{type(next_or_end).__name__}`. Expected `BaseNode` or `End`.' + ) return next_or_end async def next_from_persistence( diff --git a/pydantic_graph/pydantic_graph/state/memory.py b/pydantic_graph/pydantic_graph/state/memory.py index bca5b1b4fd..a9684a4edc 100644 --- a/pydantic_graph/pydantic_graph/state/memory.py +++ b/pydantic_graph/pydantic_graph/state/memory.py @@ -21,7 +21,7 @@ class SimpleStatePersistence(StatePersistence[StateT, RunEndT]): """Simple in memory state persistence that just hold the latest snapshot.""" - deep_copy: bool = True + deep_copy: bool = False last_snapshot: Snapshot[StateT, RunEndT] | None = None async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> None: diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py index e196933818..fdaec007d6 100644 --- a/tests/graph/test_graph.py +++ b/tests/graph/test_graph.py @@ -71,7 +71,7 @@ async def test_graph_history(): sp = FullStatePersistence() result = await my_graph.run(Float2String(3.14), persistence=sp) # len('3.14') * 2 == 8 - assert result == 8 + assert result.output == 8 assert my_graph.name == 'my_graph' assert sp.history == snapshot( [ @@ -93,7 +93,7 @@ async def test_graph_history(): start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - EndSnapshot(state=None, result=End(8), ts=IsNow(tz=timezone.utc)), + EndSnapshot(state=None, result=End(data=8), ts=IsNow(tz=timezone.utc)), ] ) sp = FullStatePersistence() @@ -132,7 +132,7 @@ async def test_graph_history(): start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - EndSnapshot(state=None, result=End(42), ts=IsNow(tz=timezone.utc)), + EndSnapshot(state=None, result=End(data=42), ts=IsNow(tz=timezone.utc)), ] ) assert [e.node for e in sp.history] == snapshot( diff --git a/tests/graph/test_mermaid.py b/tests/graph/test_mermaid.py index 334ed9b69a..24bdfb0680 100644 --- a/tests/graph/test_mermaid.py +++ b/tests/graph/test_mermaid.py @@ -71,7 +71,7 @@ async def test_run_graph(): sp = FullStatePersistence() result = await graph1.run(Foo(), persistence=sp) assert result.output is None - assert result.persistence == snapshot( + assert sp.history == snapshot( [ NodeSnapshot( state=None, diff --git a/tests/graph/test_persistence.py b/tests/graph/test_persistence.py index daf1a7827b..e394f8a7fc 100644 --- a/tests/graph/test_persistence.py +++ b/tests/graph/test_persistence.py @@ -58,7 +58,7 @@ async def test_dump_load_history(graph: Graph[MyState, None, int]): result = await graph.run(Foo(), state=MyState(1, ''), persistence=sp) assert result.output == snapshot(4) assert result.state == snapshot(MyState(x=2, y='y')) - assert result.persistence == snapshot( + assert sp.history == snapshot( [ NodeSnapshot(state=MyState(x=1, y=''), node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), NodeSnapshot(state=MyState(x=2, y=''), node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), From 767e08de5f3d862e283c214a86e4e5171093d16d Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 3 Mar 2025 19:43:11 +0000 Subject: [PATCH 06/25] fix tests for 3.9 etc --- Makefile | 10 +++++----- pydantic_graph/pydantic_graph/nodes.py | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/Makefile b/Makefile index c0f78236a9..0f62c060ff 100644 --- a/Makefile +++ b/Makefile @@ -49,11 +49,11 @@ test: ## Run tests and collect coverage data .PHONY: test-all-python test-all-python: ## Run tests on Python 3.9 to 3.13 - UV_PROJECT_ENVIRONMENT=.venv39 uv run --python 3.9 --all-extras coverage run -p -m pytest - UV_PROJECT_ENVIRONMENT=.venv310 uv run --python 3.10 --all-extras coverage run -p -m pytest - UV_PROJECT_ENVIRONMENT=.venv311 uv run --python 3.11 --all-extras coverage run -p -m pytest - UV_PROJECT_ENVIRONMENT=.venv312 uv run --python 3.12 --all-extras coverage run -p -m pytest - UV_PROJECT_ENVIRONMENT=.venv313 uv run --python 3.13 --all-extras coverage run -p -m pytest + UV_PROJECT_ENVIRONMENT=.venv39 uv run --python 3.9 --all-extras --all-packages coverage run -p -m pytest + UV_PROJECT_ENVIRONMENT=.venv310 uv run --python 3.10 --all-extras --all-packages coverage run -p -m pytest + UV_PROJECT_ENVIRONMENT=.venv311 uv run --python 3.11 --all-extras --all-packages coverage run -p -m pytest + UV_PROJECT_ENVIRONMENT=.venv312 uv run --python 3.12 --all-extras --all-packages coverage run -p -m pytest + UV_PROJECT_ENVIRONMENT=.venv313 uv run --python 3.13 --all-extras --all-packages coverage run -p -m pytest @uv run coverage combine @uv run coverage report diff --git a/pydantic_graph/pydantic_graph/nodes.py b/pydantic_graph/pydantic_graph/nodes.py index 1b3dad6cea..08440ee052 100644 --- a/pydantic_graph/pydantic_graph/nodes.py +++ b/pydantic_graph/pydantic_graph/nodes.py @@ -4,9 +4,9 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, is_dataclass from functools import cache -from typing import TYPE_CHECKING, Any, ClassVar, Generic, Self, get_origin, get_type_hints +from typing import TYPE_CHECKING, Any, ClassVar, Generic, get_origin, get_type_hints -from typing_extensions import Never, TypeVar +from typing_extensions import Never, Self, TypeVar from . import _utils, exceptions From daffea5d6c8ab0b7c8e285782e54aa427d69b864 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Thu, 6 Mar 2025 15:37:14 +0000 Subject: [PATCH 07/25] refactoring state persistence --- pydantic_graph/pydantic_graph/__init__.py | 4 +- pydantic_graph/pydantic_graph/_utils.py | 15 +- pydantic_graph/pydantic_graph/graph.py | 201 +++++++----------- pydantic_graph/pydantic_graph/nodes.py | 2 +- .../{state => persistence}/__init__.py | 60 ++++-- .../{state => persistence}/_utils.py | 0 .../{state => persistence}/memory.py | 84 +++++--- tests/graph/test_persistence.py | 29 +++ tests/typed_graph.py | 2 +- 9 files changed, 219 insertions(+), 178 deletions(-) rename pydantic_graph/pydantic_graph/{state => persistence}/__init__.py (64%) rename pydantic_graph/pydantic_graph/{state => persistence}/_utils.py (100%) rename pydantic_graph/pydantic_graph/{state => persistence}/memory.py (60%) diff --git a/pydantic_graph/pydantic_graph/__init__.py b/pydantic_graph/pydantic_graph/__init__.py index 9db8021889..80f601ad05 100644 --- a/pydantic_graph/pydantic_graph/__init__.py +++ b/pydantic_graph/pydantic_graph/__init__.py @@ -1,8 +1,8 @@ from .exceptions import GraphRuntimeError, GraphSetupError from .graph import Graph, GraphRun, GraphRunResult from .nodes import BaseNode, Edge, End, GraphRunContext -from .state import EndSnapshot, NodeSnapshot, Snapshot -from .state.memory import FullStatePersistence, SimpleStatePersistence +from .persistence import EndSnapshot, NodeSnapshot, Snapshot +from .persistence.memory import FullStatePersistence, SimpleStatePersistence __all__ = ( 'Graph', diff --git a/pydantic_graph/pydantic_graph/_utils.py b/pydantic_graph/pydantic_graph/_utils.py index fcb27b2b31..92f9b1ad3e 100644 --- a/pydantic_graph/pydantic_graph/_utils.py +++ b/pydantic_graph/pydantic_graph/_utils.py @@ -3,7 +3,8 @@ import asyncio import sys import types -from typing import Annotated, Any, TypeVar, Union, get_args, get_origin +from functools import partial +from typing import Annotated, Any, Callable, ParamSpec, TypeVar, Union, get_args, get_origin import typing_extensions @@ -104,3 +105,15 @@ class Unset: def is_set(t_or_unset: T | Unset) -> typing_extensions.TypeGuard[T]: return t_or_unset is not UNSET + + +_P = ParamSpec('_P') +_R = TypeVar('_R') + + +async def run_in_executor(func: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs) -> _R: + if kwargs: + # noinspection PyTypeChecker + return await asyncio.get_running_loop().run_in_executor(None, partial(func, *args, **kwargs)) + else: + return await asyncio.get_running_loop().run_in_executor(None, func, *args) # type: ignore diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index e3713a9e47..4eac2f7eb8 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -6,7 +6,7 @@ from contextlib import ExitStack, asynccontextmanager from dataclasses import dataclass, field from functools import cached_property -from typing import Any, Generic, TypeVar +from typing import Any, Generic, TypeVar, cast import logfire_api import typing_extensions @@ -14,8 +14,8 @@ from . import _utils, exceptions, mermaid from .nodes import BaseNode, DepsT, End, GraphRunContext, NodeDef, RunEndT -from .state import StatePersistence, StateT, set_nodes_type_context -from .state.memory import SimpleStatePersistence +from .persistence import NodeRunId, StatePersistence, StateT, set_nodes_type_context +from .persistence.memory import SimpleStatePersistence # while waiting for https://github.com/pydantic/logfire/issues/745 try: @@ -84,7 +84,7 @@ async def run(self, ctx: GraphRunContext) -> Increment | End[int]: node_defs: dict[str, NodeDef[StateT, DepsT, RunEndT]] _state_type: type[StateT] | _utils.Unset = field(repr=False) _run_end_type: type[RunEndT] | _utils.Unset = field(repr=False) - _auto_instrument: bool = field(repr=False) + auto_instrument: bool = field(repr=False) def __init__( self, @@ -109,7 +109,7 @@ def __init__( self.name = name self._state_type = state_type self._run_end_type = run_end_type - self._auto_instrument = auto_instrument + self.auto_instrument = auto_instrument parent_namespace = _utils.get_parent_namespace(inspect.currentframe()) self.node_defs: dict[str, NodeDef[StateT, DepsT, RunEndT]] = {} @@ -174,66 +174,6 @@ async def main(): assert final_result is not None, 'GraphRun should have a final result' return final_result - @asynccontextmanager - async def iter( - self: Graph[StateT, DepsT, T], - start_node: BaseNode[StateT, DepsT, T], - *, - state: StateT = None, - deps: DepsT = None, - persistence: StatePersistence[StateT, T] | None = None, - infer_name: bool = True, - span: LogfireSpan | None = None, - ) -> AsyncIterator[GraphRun[StateT, DepsT, T]]: - """A contextmanager which can be used to iterate over the graph's nodes as they are executed. - - This method returns a `GraphRun` object which can be used to async-iterate over the nodes of this `Graph` as - they are executed. This is the API to use if you want to record or interact with the nodes as the graph - execution unfolds. - - The `GraphRun` can also be used to manually drive the graph execution by calling - [`GraphRun.next`][pydantic_graph.graph.GraphRun.next]. - - The `GraphRun` provides access to the full run history, state, deps, and the final result of the run once - it has completed. - - For more details, see the API documentation of [`GraphRun`][pydantic_graph.graph.GraphRun]. - - Args: - start_node: the first node to run. Since the graph definition doesn't define the entry point in the graph, - you need to provide the starting node. - state: The initial state of the graph. - deps: The dependencies of the graph. - persistence: State persistence interface, defaults to - [`SimpleStatePersistence`][pydantic_graph.state.memory.SimpleStatePersistence] if `None`. - infer_name: Whether to infer the graph name from the calling frame. - span: The span to use for the graph run. If not provided, a new span will be created. - - Yields: - A GraphRun that can be async iterated over to drive the graph to completion. - """ - if infer_name and self.name is None: - self._infer_name(inspect.currentframe()) - - if self._auto_instrument and span is None: - span = logfire_api.span('run graph {graph.name}', graph=self) - - if persistence is None: - persistence = SimpleStatePersistence() - - with ExitStack() as stack: - if span is not None: - stack.enter_context(span) - yield GraphRun[StateT, DepsT, T]( - self, - start_node, - persistence=persistence, - state=state, - deps=deps, - auto_instrument=self._auto_instrument, - span=span, - ) - def run_sync( self: Graph[StateT, DepsT, T], start_node: BaseNode[StateT, DepsT, T], @@ -267,69 +207,67 @@ def run_sync( self.run(start_node, state=state, deps=deps, persistence=persistence, infer_name=False) ) - async def next( + @asynccontextmanager + async def iter( self: Graph[StateT, DepsT, T], - node: BaseNode[StateT, DepsT, T], + start_node: BaseNode[StateT, DepsT, T], *, - persistence: StatePersistence[StateT, T] | None = None, state: StateT = None, deps: DepsT = None, + persistence: StatePersistence[StateT, T] | None = None, infer_name: bool = True, - ) -> BaseNode[StateT, DepsT, Any] | End[T]: - """Run a node in the graph and return the next node to run. + span: LogfireSpan | None = None, + ) -> AsyncIterator[GraphRun[StateT, DepsT, T]]: + """A contextmanager which can be used to iterate over the graph's nodes as they are executed. + + This method returns a `GraphRun` object which can be used to async-iterate over the nodes of this `Graph` as + they are executed. This is the API to use if you want to record or interact with the nodes as the graph + execution unfolds. + + The `GraphRun` can also be used to manually drive the graph execution by calling + [`GraphRun.next`][pydantic_graph.graph.GraphRun.next]. + + The `GraphRun` provides access to the full run history, state, deps, and the final result of the run once + it has completed. + + For more details, see the API documentation of [`GraphRun`][pydantic_graph.graph.GraphRun]. Args: - node: The node to run. + start_node: the first node to run. Since the graph definition doesn't define the entry point in the graph, + you need to provide the starting node. + state: The initial state of the graph. + deps: The dependencies of the graph. persistence: State persistence interface, defaults to [`SimpleStatePersistence`][pydantic_graph.state.memory.SimpleStatePersistence] if `None`. - state: The current state of the graph. - deps: The dependencies of the graph. infer_name: Whether to infer the graph name from the calling frame. + span: The span to use for the graph run. If not provided, a new span will be created. - Returns: - The next node to run or [`End`][pydantic_graph.nodes.End] if the graph has finished. + Yields: + A GraphRun that can be async iterated over to drive the graph to completion. """ if infer_name and self.name is None: self._infer_name(inspect.currentframe()) - if isinstance(node, End): - # While technically this is not compatible with the documented method signature, it's an easy mistake to - # make, and we should eagerly provide a more helpful error message than you'd get otherwise. - raise exceptions.GraphRuntimeError(f'Cannot call `next` with an `End` node: {node!r}.') + if self.auto_instrument and span is None: + span = logfire_api.span('run graph {graph.name}', graph=self) if persistence is None: persistence = SimpleStatePersistence() self.set_persistence_types(persistence) - return await self._next(node, persistence, state, deps) - - async def _next( - self: Graph[StateT, DepsT, T], - node: BaseNode[StateT, DepsT, T], - persistence: StatePersistence[StateT, T], - state: StateT, - deps: DepsT, - ) -> BaseNode[StateT, DepsT, Any] | End[T]: - node_id = node.get_id() - if node_id not in self.node_defs: - raise exceptions.GraphRuntimeError(f'Node `{node}` is not in the graph.') - - await persistence.snapshot_node(state, node) - with ExitStack() as stack: - if self._auto_instrument: - stack.enter_context(_logfire.span('run node {node_id}', node_id=node_id, node=node)) - ctx = GraphRunContext(state, deps) - async with persistence.record_run(): - next_or_end = await node.run(ctx) - - if isinstance(next_or_end, End): - await persistence.snapshot_end(state, next_or_end) - elif not isinstance(next_or_end, BaseNode): - raise exceptions.GraphRuntimeError( - f'Invalid node return type: `{type(next_or_end).__name__}`. Expected `BaseNode` or `End`.' + if span is not None: + stack.enter_context(span) + node_run_id = await persistence.snapshot_node(state, start_node) + yield GraphRun[StateT, DepsT, T]( + graph=self, + start_node=start_node, + persistence=persistence, + node_run_id=node_run_id, + state=state, + deps=deps, + span=span, ) - return next_or_end async def next_from_persistence( self: Graph[StateT, DepsT, T], @@ -342,13 +280,16 @@ async def next_from_persistence( self._infer_name(inspect.currentframe()) snapshot = await persistence.restore_node_snapshot() - return await self.next( - snapshot.node, + node_run_id = NotImplemented + run = GraphRun[StateT, DepsT, T]( + graph=self, + start_node=snapshot.node, persistence=persistence, + node_run_id=node_run_id, state=snapshot.state, deps=deps, - infer_name=False, ) + return await run.next() def set_persistence_types(self, persistence: StatePersistence[StateT, RunEndT]) -> None: with set_nodes_type_context([node_def.node for node_def in self.node_defs.values()]): @@ -615,9 +556,9 @@ def __init__( graph: Graph[StateT, DepsT, RunEndT], start_node: BaseNode[StateT, DepsT, RunEndT], persistence: StatePersistence[StateT, RunEndT], + node_run_id: NodeRunId, state: StateT, deps: DepsT, - auto_instrument: bool, span: LogfireSpan | None = None, ): """Create a new run for a given graph, starting at the specified node. @@ -628,18 +569,18 @@ def __init__( graph: The [`Graph`][pydantic_graph.graph.Graph] to run. start_node: The node where execution will begin. persistence: State persistence interface. + node_run_id: The ID of the current node run. state: A shared state object or primitive (like a counter, dataclass, etc.) that is available to all nodes via `ctx.state`. deps: Optional dependencies that each node can access via `ctx.deps`, e.g. database connections, configuration, or logging clients. - auto_instrument: Whether to automatically create instrumentation spans during the run. span: An optional existing Logfire span to nest node-level spans under (advanced usage). """ self.graph = graph self.persistence = persistence + self._node_run_id = node_run_id self.state = state self.deps = deps - self._auto_instrument = auto_instrument self._span = span self._next_node: BaseNode[StateT, DepsT, RunEndT] | End[RunEndT] = start_node @@ -708,15 +649,33 @@ async def main(): the run has completed. """ if node is None: - if isinstance(self._next_node, End): - # Note: we could alternatively just return `self._next_node` here, but it's easier to start with an - # error and relax the behavior later, than vice versa. - raise exceptions.GraphRuntimeError('This graph run has already ended.') - node = self._next_node - - self._next_node = await self.graph.next( - node, persistence=self.persistence, state=self.state, deps=self.deps, infer_name=False - ) + node = cast(BaseNode[StateT, DepsT, T], self._next_node) + + if isinstance(node, End): + # While technically this is not compatible with the documented method signature, it's an easy mistake to + # make, and we should eagerly provide a more helpful error message than you'd get otherwise. + raise exceptions.GraphRuntimeError(f'Cannot call `next` with an `End` node: {node!r}.') + + node_id = node.get_id() + if node_id not in self.graph.node_defs: + raise exceptions.GraphRuntimeError(f'Node `{node}` is not in the graph.') + + with ExitStack() as stack: + if self.graph.auto_instrument: + stack.enter_context(_logfire.span('run node {node_id}', node_id=node_id, node=node)) + + async with self.persistence.record_run(self._node_run_id): + ctx = GraphRunContext(self.state, self.deps) + self._next_node = await node.run(ctx) + + if isinstance(self._next_node, End): + self._node_run_id = await self.persistence.snapshot_end(self.state, self._next_node) + elif isinstance(self._next_node, BaseNode): + self._node_run_id = await self.persistence.snapshot_node(self.state, self._next_node) + else: + raise exceptions.GraphRuntimeError( + f'Invalid node return type: `{type(self._next_node).__name__}`. Expected `BaseNode` or `End`.' + ) return self._next_node diff --git a/pydantic_graph/pydantic_graph/nodes.py b/pydantic_graph/pydantic_graph/nodes.py index 08440ee052..8f7863b71a 100644 --- a/pydantic_graph/pydantic_graph/nodes.py +++ b/pydantic_graph/pydantic_graph/nodes.py @@ -11,7 +11,7 @@ from . import _utils, exceptions if TYPE_CHECKING: - from .state import StateT + from .persistence import StateT else: StateT = TypeVar('StateT', default=None) diff --git a/pydantic_graph/pydantic_graph/state/__init__.py b/pydantic_graph/pydantic_graph/persistence/__init__.py similarity index 64% rename from pydantic_graph/pydantic_graph/state/__init__.py rename to pydantic_graph/pydantic_graph/persistence/__init__.py index 521e608572..c32aa10f70 100644 --- a/pydantic_graph/pydantic_graph/state/__init__.py +++ b/pydantic_graph/pydantic_graph/persistence/__init__.py @@ -1,11 +1,12 @@ from __future__ import annotations as _annotations from abc import ABC, abstractmethod -from collections.abc import AsyncIterator, Iterator, Sequence -from contextlib import asynccontextmanager, contextmanager +from collections.abc import Iterator, Sequence +from contextlib import AbstractAsyncContextManager, contextmanager from dataclasses import dataclass, field from datetime import datetime -from typing import Annotated, Any, Callable, Generic, Literal, Union +from typing import Annotated, Any, Callable, Generic, Literal, NewType, Union +from uuid import uuid4 import pydantic from typing_extensions import TypeVar @@ -14,10 +15,16 @@ from ..nodes import BaseNode, End, RunEndT from . import _utils -__all__ = 'StateT', 'NodeSnapshot', 'EndSnapshot', 'Snapshot', 'StatePersistence', 'set_nodes_type_context' +__all__ = 'StateT', 'NodeRunId', 'NodeSnapshot', 'EndSnapshot', 'Snapshot', 'StatePersistence', 'set_nodes_type_context' StateT = TypeVar('StateT', default=None) """Type variable for the state in a graph.""" +NodeRunId = NewType('NodeRunId', str) +"""Unique ID for a node run.""" + + +def new_run_id() -> NodeRunId: + return NodeRunId(uuid4().hex) @dataclass @@ -28,10 +35,13 @@ class NodeSnapshot(Generic[StateT, RunEndT]): """The state of the graph before the node is run.""" node: Annotated[BaseNode[StateT, Any, RunEndT], _utils.CustomNodeSchema()] """The node to run next.""" + run_id: NodeRunId = field(default_factory=new_run_id) + """Unique ID of the node run.""" start_ts: datetime | None = None """The timestamp when the node started running, `None` until the run starts.""" duration: float | None = None """The duration of the node run in seconds, if the node has been run.""" + status: Literal['not_started', 'pending', 'running', 'success', 'error'] = 'not_started' kind: Literal['node'] = 'node' """The kind of history step, can be used as a discriminator when deserializing history.""" @@ -44,6 +54,8 @@ class EndSnapshot(Generic[StateT, RunEndT]): """The state of the graph at the end of the run.""" result: End[RunEndT] """The result of the graph run.""" + run_id: NodeRunId = field(default_factory=new_run_id) + """Unique ID for the end of the graph run.""" ts: datetime = field(default_factory=_utils.now_utc) """The timestamp when the graph run ended.""" kind: Literal['end'] = 'end' @@ -70,21 +82,25 @@ class StatePersistence(ABC, Generic[StateT, RunEndT]): """Abstract base class for storing the state of a graph.""" @abstractmethod - async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> None: - """Snapshot the state of a graph before a node is run. + async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> NodeRunId: + """Snapshot the state of a graph, when the next step is to run a node. + + In particular this should set [`NodeSnapshot.duration`][pydantic_graph.state.NodeSnapshot.duration] + when the run finishes. + + Note: although the node Args: state: The state of the graph. next_node: The next node to run or end if the graph has ended - Returns: - The snapshot + Returns: an async context manager that wraps the run of the node. """ raise NotImplementedError @abstractmethod - async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None: - """Snapshot the state of a graph before a node is run. + async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> NodeRunId: + """Snapshot the state of a graph after a node has run, when the graph has ended. Args: state: The state of the graph. @@ -93,14 +109,16 @@ async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None: raise NotImplementedError @abstractmethod - @asynccontextmanager - async def record_run(self) -> AsyncIterator[None]: + def record_run(self, run_id: NodeRunId) -> AbstractAsyncContextManager[None]: """Record the run of the node. - In particular this should set [`NodeSnapshot.start_ts`][pydantic_graph.state.NodeSnapshot.start_ts] - and [`NodeSnapshot.duration`][pydantic_graph.state.NodeSnapshot.duration]. + In particular this should set: + + - [`NodeSnapshot.status`][pydantic_graph.state.NodeSnapshot.status] to `'running'` and + [`NodeSnapshot.start_ts`][pydantic_graph.state.NodeSnapshot.start_ts] when the run starts. + - [`NodeSnapshot.status`][pydantic_graph.state.NodeSnapshot.status] to `'success'` or `'error'` and + [`NodeSnapshot.duration`][pydantic_graph.state.NodeSnapshot.duration] when the run finishes. """ - yield raise NotImplementedError @abstractmethod @@ -133,9 +151,13 @@ def set_nodes_type_context(nodes: Sequence[type[BaseNode[Any, Any, Any]]]) -> It _utils.nodes_type_context.reset(token) -def build_snapshots_type_adapter( +def build_snapshot_list_type_adapter( state_t: type[StateT], run_end_t: type[RunEndT] ) -> pydantic.TypeAdapter[list[Snapshot[StateT, RunEndT]]]: - return pydantic.TypeAdapter( - list[Annotated[Snapshot[state_t, run_end_t], pydantic.Discriminator('kind')]], - ) + return pydantic.TypeAdapter(list[Annotated[Snapshot[state_t, run_end_t], pydantic.Discriminator('kind')]]) + + +def build_snapshot_single_type_adapter( + state_t: type[StateT], run_end_t: type[RunEndT] +) -> pydantic.TypeAdapter[Snapshot[StateT, RunEndT]]: + return pydantic.TypeAdapter(Annotated[Snapshot[state_t, run_end_t], pydantic.Discriminator('kind')]) diff --git a/pydantic_graph/pydantic_graph/state/_utils.py b/pydantic_graph/pydantic_graph/persistence/_utils.py similarity index 100% rename from pydantic_graph/pydantic_graph/state/_utils.py rename to pydantic_graph/pydantic_graph/persistence/_utils.py diff --git a/pydantic_graph/pydantic_graph/state/memory.py b/pydantic_graph/pydantic_graph/persistence/memory.py similarity index 60% rename from pydantic_graph/pydantic_graph/state/memory.py rename to pydantic_graph/pydantic_graph/persistence/memory.py index a9684a4edc..bb4d89b054 100644 --- a/pydantic_graph/pydantic_graph/state/memory.py +++ b/pydantic_graph/pydantic_graph/persistence/memory.py @@ -11,7 +11,14 @@ from typing_extensions import TypeVar from ..nodes import BaseNode, End -from . import EndSnapshot, NodeSnapshot, Snapshot, StatePersistence, _utils, build_snapshots_type_adapter +from . import ( + EndSnapshot, + NodeRunId, + NodeSnapshot, + Snapshot, + StatePersistence, + build_snapshot_list_type_adapter, +) StateT = TypeVar('StateT', default=Any) RunEndT = TypeVar('RunEndT', default=Any) @@ -24,31 +31,36 @@ class SimpleStatePersistence(StatePersistence[StateT, RunEndT]): deep_copy: bool = False last_snapshot: Snapshot[StateT, RunEndT] | None = None - async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> None: + async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> NodeRunId: self.last_snapshot = NodeSnapshot( state=self.prep_state(state), node=next_node.deep_copy() if self.deep_copy else next_node, ) + return self.last_snapshot.run_id - async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None: + async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> NodeRunId: self.last_snapshot = EndSnapshot( state=self.prep_state(state), result=end.deep_copy_data() if self.deep_copy else end, ) + return self.last_snapshot.run_id @asynccontextmanager - async def record_run(self) -> AsyncIterator[None]: - last_snapshot = await self.restore() - if not isinstance(last_snapshot, NodeSnapshot): - yield - return - - last_snapshot.start_ts = _utils.now_utc() + async def record_run(self, run_id: NodeRunId) -> AsyncIterator[None]: + assert self.last_snapshot is not None, 'No snapshot to record' + assert isinstance(self.last_snapshot, NodeSnapshot), 'Only NodeSnapshot can be recorded' + assert run_id == self.last_snapshot.run_id, 'run_id must match the last snapshot run_id' + self.last_snapshot.status = 'running' start = perf_counter() try: yield - finally: - last_snapshot.duration = perf_counter() - start + except Exception: + self.last_snapshot.duration = perf_counter() - start + self.last_snapshot.status = 'error' + raise + else: + self.last_snapshot.duration = perf_counter() - start + self.last_snapshot.status = 'success' async def restore(self) -> Snapshot[StateT, RunEndT] | None: return self.last_snapshot @@ -71,35 +83,41 @@ class FullStatePersistence(StatePersistence[StateT, RunEndT]): default=None, init=False, repr=False ) - async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> None: - self.history.append( - NodeSnapshot( - state=self.prep_state(state), - node=next_node.deep_copy() if self.deep_copy else next_node, - ) + async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> NodeRunId: + last_snapshot = NodeSnapshot( + state=self.prep_state(state), + node=next_node.deep_copy() if self.deep_copy else next_node, ) + self.history.append(last_snapshot) + return last_snapshot.run_id - async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None: - self.history.append( - EndSnapshot( - state=self.prep_state(state), - result=end.deep_copy_data() if self.deep_copy else end, - ) + async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> NodeRunId: + end = EndSnapshot( + state=self.prep_state(state), + result=end.deep_copy_data() if self.deep_copy else end, ) + self.history.append(end) + return end.run_id @asynccontextmanager - async def record_run(self) -> AsyncIterator[None]: - last_snapshot = await self.restore() - if not isinstance(last_snapshot, NodeSnapshot): - yield - return + async def record_run(self, run_id: NodeRunId) -> AsyncIterator[None]: + try: + snapshot = next(s for s in self.history if s.run_id == run_id) + except StopIteration as e: + raise LookupError(f'No snapshot found for run_id {run_id}') from e - last_snapshot.start_ts = _utils.now_utc() + assert isinstance(snapshot, NodeSnapshot), 'Only NodeSnapshot can be recorded' + snapshot.status = 'running' start = perf_counter() try: yield - finally: - last_snapshot.duration = perf_counter() - start + except Exception: + snapshot.duration = perf_counter() - start + snapshot.status = 'error' + raise + else: + snapshot.duration = perf_counter() - start + snapshot.status = 'success' async def restore(self) -> Snapshot[StateT, RunEndT] | None: if self.history: @@ -116,7 +134,7 @@ def load_json(self, json_data: str | bytes | bytearray) -> None: def set_types(self, get_types: Callable[[], tuple[type[StateT], type[RunEndT]]]) -> None: if self._snapshots_type_adapter is None: state_t, run_end_t = get_types() - self._snapshots_type_adapter = build_snapshots_type_adapter(state_t, run_end_t) + self._snapshots_type_adapter = build_snapshot_list_type_adapter(state_t, run_end_t) def prep_state(self, state: StateT) -> StateT: """Prepare state for snapshot, uses [`copy.deepcopy`][copy.deepcopy] by default.""" diff --git a/tests/graph/test_persistence.py b/tests/graph/test_persistence.py index e394f8a7fc..4fcce2c552 100644 --- a/tests/graph/test_persistence.py +++ b/tests/graph/test_persistence.py @@ -201,3 +201,32 @@ async def run(self, ctx: GraphRunContext) -> NoGenericArgsNode: ) ] ) + + +async def test_node_error(): + @dataclass + class Foo(BaseNode): + async def run(self, ctx: GraphRunContext) -> Bar: + return Bar() + + @dataclass + class Bar(BaseNode[None, None, int]): + async def run(self, ctx: GraphRunContext) -> End[int]: + raise RuntimeError('test error') + + graph = Graph(nodes=[Foo, Bar]) + + sp = FullStatePersistence() + with pytest.raises(RuntimeError, match='test error'): + await graph.run(Foo(), persistence=sp) + + assert sp.history == snapshot( + [ + NodeSnapshot( + state=None, + node=Foo(), + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(), + ), + ] + ) diff --git a/tests/typed_graph.py b/tests/typed_graph.py index 286085f76d..cda6b3c6c3 100644 --- a/tests/typed_graph.py +++ b/tests/typed_graph.py @@ -6,7 +6,7 @@ from typing_extensions import assert_type from pydantic_graph import BaseNode, End, FullStatePersistence, Graph, GraphRunContext -from pydantic_graph.state import StatePersistence +from pydantic_graph.persistence import StatePersistence @dataclass From 45622447e309f4a48bb958a4902410c0b8888a71 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Fri, 7 Mar 2025 21:28:26 +0000 Subject: [PATCH 08/25] snapshot id on node --- pydantic_graph/pydantic_graph/graph.py | 73 +++++++++++++---- pydantic_graph/pydantic_graph/mermaid.py | 2 +- pydantic_graph/pydantic_graph/nodes.py | 39 +++++++-- .../pydantic_graph/persistence/__init__.py | 40 ++++----- .../pydantic_graph/persistence/_utils.py | 4 +- .../pydantic_graph/persistence/memory.py | 38 ++++----- tests/conftest.py | 13 +++ tests/graph/test_graph.py | 81 ++++++++++++++++--- 8 files changed, 215 insertions(+), 75 deletions(-) diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 4eac2f7eb8..2890b17c93 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -13,8 +13,8 @@ from logfire_api import LogfireSpan from . import _utils, exceptions, mermaid -from .nodes import BaseNode, DepsT, End, GraphRunContext, NodeDef, RunEndT -from .persistence import NodeRunId, StatePersistence, StateT, set_nodes_type_context +from .nodes import BaseNode, DepsT, End, GraphRunContext, NodeDef, RunEndT, SnapshotId +from .persistence import StatePersistence, StateT, set_nodes_type_context from .persistence.memory import SimpleStatePersistence # while waiting for https://github.com/pydantic/logfire/issues/745 @@ -258,17 +258,50 @@ async def iter( with ExitStack() as stack: if span is not None: stack.enter_context(span) - node_run_id = await persistence.snapshot_node(state, start_node) yield GraphRun[StateT, DepsT, T]( graph=self, start_node=start_node, persistence=persistence, - node_run_id=node_run_id, state=state, deps=deps, span=span, ) + # @deprecated('`graph.next` is deprecated, use `graph.iter` ... `run.next` instead') + async def next( + self: Graph[StateT, DepsT, T], + node: BaseNode[StateT, DepsT, T], + persistence: StatePersistence[StateT, T], + *, + state: StateT = None, + deps: DepsT = None, + infer_name: bool = True, + ) -> BaseNode[StateT, DepsT, Any] | End[T]: + """Run a node in the graph and return the next node to run. + + Args: + node: The node to run. + persistence: State persistence interface, defaults to + [`SimpleStatePersistence`][pydantic_graph.state.memory.SimpleStatePersistence] if `None`. + state: The current state of the graph. + deps: The dependencies of the graph. + infer_name: Whether to infer the graph name from the calling frame. + + Returns: + The next node to run or [`End`][pydantic_graph.nodes.End] if the graph has finished. + """ + if infer_name and self.name is None: + self._infer_name(inspect.currentframe()) + + run = GraphRun[StateT, DepsT, T]( + graph=self, + start_node=node, + persistence=persistence, + state=state, + deps=deps, + ) + return await run.next(node) + async def next_from_persistence( self: Graph[StateT, DepsT, T], persistence: StatePersistence[StateT, T], @@ -280,14 +313,13 @@ async def next_from_persistence( self._infer_name(inspect.currentframe()) snapshot = await persistence.restore_node_snapshot() - node_run_id = NotImplemented run = GraphRun[StateT, DepsT, T]( graph=self, start_node=snapshot.node, persistence=persistence, - node_run_id=node_run_id, state=snapshot.state, deps=deps, + snapshot_id=snapshot.id, ) return await run.next() @@ -455,7 +487,7 @@ def _register_node( node: type[BaseNode[StateT, DepsT, T]], parent_namespace: dict[str, Any] | None, ) -> None: - node_id = node.get_id() + node_id = node.get_node_id() if existing_node := self.node_defs.get(node_id): raise exceptions.GraphSetupError( f'Node ID `{node_id}` is not unique — found on {existing_node.node} and {node}' @@ -553,12 +585,13 @@ async def main(): def __init__( self, + *, graph: Graph[StateT, DepsT, RunEndT], start_node: BaseNode[StateT, DepsT, RunEndT], persistence: StatePersistence[StateT, RunEndT], - node_run_id: NodeRunId, state: StateT, deps: DepsT, + snapshot_id: SnapshotId | None = None, span: LogfireSpan | None = None, ): """Create a new run for a given graph, starting at the specified node. @@ -569,16 +602,16 @@ def __init__( graph: The [`Graph`][pydantic_graph.graph.Graph] to run. start_node: The node where execution will begin. persistence: State persistence interface. - node_run_id: The ID of the current node run. state: A shared state object or primitive (like a counter, dataclass, etc.) that is available to all nodes via `ctx.state`. deps: Optional dependencies that each node can access via `ctx.deps`, e.g. database connections, configuration, or logging clients. + snapshot_id: The ID of the snapshot the node came from. span: An optional existing Logfire span to nest node-level spans under (advanced usage). """ self.graph = graph self.persistence = persistence - self._node_run_id = node_run_id + self._snapshot_id: SnapshotId | None = snapshot_id self.state = state self.deps = deps self._span = span @@ -650,13 +683,19 @@ async def main(): """ if node is None: node = cast(BaseNode[StateT, DepsT, T], self._next_node) + node_snapshot_id = node.get_snapshot_id() + else: + node_snapshot_id = node.get_snapshot_id() + if node_snapshot_id != self._snapshot_id: + await self.persistence.snapshot_node(self.state, node) + self._snapshot_id = node_snapshot_id - if isinstance(node, End): + if not isinstance(node, BaseNode): # While technically this is not compatible with the documented method signature, it's an easy mistake to # make, and we should eagerly provide a more helpful error message than you'd get otherwise. - raise exceptions.GraphRuntimeError(f'Cannot call `next` with an `End` node: {node!r}.') + raise exceptions.GraphRuntimeError(f'`next` must be called with a `BaseNode` instance: {node!r}.') - node_id = node.get_id() + node_id = node.get_node_id() if node_id not in self.graph.node_defs: raise exceptions.GraphRuntimeError(f'Node `{node}` is not in the graph.') @@ -664,14 +703,16 @@ async def main(): if self.graph.auto_instrument: stack.enter_context(_logfire.span('run node {node_id}', node_id=node_id, node=node)) - async with self.persistence.record_run(self._node_run_id): + async with self.persistence.record_run(node_snapshot_id): ctx = GraphRunContext(self.state, self.deps) self._next_node = await node.run(ctx) if isinstance(self._next_node, End): - self._node_run_id = await self.persistence.snapshot_end(self.state, self._next_node) + self._snapshot_id = self._next_node.get_snapshot_id() + await self.persistence.snapshot_end(self.state, self._next_node) elif isinstance(self._next_node, BaseNode): - self._node_run_id = await self.persistence.snapshot_node(self.state, self._next_node) + self._snapshot_id = self._next_node.get_snapshot_id() + await self.persistence.snapshot_node(self.state, self._next_node) else: raise exceptions.GraphRuntimeError( f'Invalid node return type: `{type(self._next_node).__name__}`. Expected `BaseNode` or `End`.' diff --git a/pydantic_graph/pydantic_graph/mermaid.py b/pydantic_graph/pydantic_graph/mermaid.py index 456bab81e9..1c4c4f2ada 100644 --- a/pydantic_graph/pydantic_graph/mermaid.py +++ b/pydantic_graph/pydantic_graph/mermaid.py @@ -127,7 +127,7 @@ def _node_ids(node_idents: Sequence[NodeIdent] | NodeIdent) -> Iterable[str]: if isinstance(node, str): yield node else: - yield node.get_id() + yield node.get_node_id() def request_image( diff --git a/pydantic_graph/pydantic_graph/nodes.py b/pydantic_graph/pydantic_graph/nodes.py index 8f7863b71a..402730964b 100644 --- a/pydantic_graph/pydantic_graph/nodes.py +++ b/pydantic_graph/pydantic_graph/nodes.py @@ -4,7 +4,8 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, is_dataclass from functools import cache -from typing import TYPE_CHECKING, Any, ClassVar, Generic, get_origin, get_type_hints +from typing import TYPE_CHECKING, Any, ClassVar, Generic, NewType, get_origin, get_type_hints +from uuid import uuid4 from typing_extensions import Never, Self, TypeVar @@ -15,7 +16,7 @@ else: StateT = TypeVar('StateT', default=None) -__all__ = 'GraphRunContext', 'BaseNode', 'End', 'Edge', 'NodeDef', 'DepsT', 'RunEndT' +__all__ = 'GraphRunContext', 'BaseNode', 'End', 'Edge', 'NodeDef', 'DepsT', 'RunEndT', 'SnapshotId' RunEndT = TypeVar('RunEndT', covariant=True, default=None) """Covariant type variable for the return type of a graph [`run`][pydantic_graph.graph.Graph.run].""" @@ -23,6 +24,8 @@ """Covariant type variable for the return type of a node [`run`][pydantic_graph.nodes.BaseNode.run].""" DepsT = TypeVar('DepsT', default=None, contravariant=True) """Type variable for the dependencies of a graph and node.""" +SnapshotId = NewType('SnapshotId', str) +"""Unique ID for a node run.""" @dataclass @@ -67,9 +70,16 @@ async def run(self, ctx: GraphRunContext[StateT, DepsT]) -> BaseNode[StateT, Dep """ ... + def get_snapshot_id(self) -> SnapshotId: + if snapshot_id := getattr(self, '__snapshot_id', None): + return snapshot_id + else: + self.__dict__['__snapshot_id'] = snapshot_id = generate_snapshot_id(self.get_node_id()) + return snapshot_id + @classmethod @cache - def get_id(cls) -> str: + def get_node_id(cls) -> str: """Get the ID of the node.""" return cls.__name__ @@ -115,13 +125,13 @@ def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, DepsT, # TODO: Should we disallow this? returns_base_node = True elif issubclass(return_type_origin, BaseNode): - next_node_edges[return_type.get_id()] = edge + next_node_edges[return_type.get_node_id()] = edge else: raise exceptions.GraphSetupError(f'Invalid return type: {return_type}') return NodeDef( cls, - cls.get_id(), + cls.get_node_id(), cls.get_note(), next_node_edges, end_edge, @@ -145,7 +155,24 @@ def deep_copy_data(self) -> End[RunEndT]: if self.data is None: return self else: - return End(copy.deepcopy(self.data)) + end = End(copy.deepcopy(self.data)) + end.set_snapshot_id(self.get_snapshot_id()) + return end + + def set_snapshot_id(self, set_id: str) -> None: + self.__dict__['__snapshot_id'] = SnapshotId(set_id) + + def get_snapshot_id(self) -> SnapshotId: + if snapshot_id := getattr(self, '__snapshot_id', None): + return snapshot_id + else: + self.__dict__['__snapshot_id'] = snapshot_id = generate_snapshot_id('end') + return snapshot_id + + +def generate_snapshot_id(node_id: str) -> SnapshotId: + # module method to allow mocking + return SnapshotId(f'{node_id}:{uuid4().hex}') @dataclass diff --git a/pydantic_graph/pydantic_graph/persistence/__init__.py b/pydantic_graph/pydantic_graph/persistence/__init__.py index c32aa10f70..c86db907a2 100644 --- a/pydantic_graph/pydantic_graph/persistence/__init__.py +++ b/pydantic_graph/pydantic_graph/persistence/__init__.py @@ -5,26 +5,21 @@ from contextlib import AbstractAsyncContextManager, contextmanager from dataclasses import dataclass, field from datetime import datetime -from typing import Annotated, Any, Callable, Generic, Literal, NewType, Union -from uuid import uuid4 +from typing import Annotated, Any, Callable, Generic, Literal, Union import pydantic from typing_extensions import TypeVar from .. import exceptions -from ..nodes import BaseNode, End, RunEndT +from ..nodes import BaseNode, End, RunEndT, SnapshotId from . import _utils -__all__ = 'StateT', 'NodeRunId', 'NodeSnapshot', 'EndSnapshot', 'Snapshot', 'StatePersistence', 'set_nodes_type_context' +__all__ = 'StateT', 'NodeSnapshot', 'EndSnapshot', 'Snapshot', 'StatePersistence', 'set_nodes_type_context' StateT = TypeVar('StateT', default=None) """Type variable for the state in a graph.""" -NodeRunId = NewType('NodeRunId', str) -"""Unique ID for a node run.""" - -def new_run_id() -> NodeRunId: - return NodeRunId(uuid4().hex) +UNSET_ID = SnapshotId('__unset__') @dataclass @@ -35,8 +30,6 @@ class NodeSnapshot(Generic[StateT, RunEndT]): """The state of the graph before the node is run.""" node: Annotated[BaseNode[StateT, Any, RunEndT], _utils.CustomNodeSchema()] """The node to run next.""" - run_id: NodeRunId = field(default_factory=new_run_id) - """Unique ID of the node run.""" start_ts: datetime | None = None """The timestamp when the node started running, `None` until the run starts.""" duration: float | None = None @@ -45,6 +38,13 @@ class NodeSnapshot(Generic[StateT, RunEndT]): kind: Literal['node'] = 'node' """The kind of history step, can be used as a discriminator when deserializing history.""" + id: SnapshotId = UNSET_ID + """Unique ID of the snapshot.""" + + def __post_init__(self) -> None: + if self.id == UNSET_ID: + self.id = self.node.get_snapshot_id() + @dataclass class EndSnapshot(Generic[StateT, RunEndT]): @@ -54,13 +54,18 @@ class EndSnapshot(Generic[StateT, RunEndT]): """The state of the graph at the end of the run.""" result: End[RunEndT] """The result of the graph run.""" - run_id: NodeRunId = field(default_factory=new_run_id) - """Unique ID for the end of the graph run.""" ts: datetime = field(default_factory=_utils.now_utc) """The timestamp when the graph run ended.""" kind: Literal['end'] = 'end' """The kind of history step, can be used as a discriminator when deserializing history.""" + id: SnapshotId = UNSET_ID + """Unique ID of the snapshot.""" + + def __post_init__(self) -> None: + if self.id == UNSET_ID: + self.id = self.node.get_snapshot_id() + @property def node(self) -> End[RunEndT]: """Shim to get the [`result`][pydantic_graph.state.EndSnapshot.result]. @@ -82,12 +87,9 @@ class StatePersistence(ABC, Generic[StateT, RunEndT]): """Abstract base class for storing the state of a graph.""" @abstractmethod - async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> NodeRunId: + async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> None: """Snapshot the state of a graph, when the next step is to run a node. - In particular this should set [`NodeSnapshot.duration`][pydantic_graph.state.NodeSnapshot.duration] - when the run finishes. - Note: although the node Args: @@ -99,7 +101,7 @@ async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, Ru raise NotImplementedError @abstractmethod - async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> NodeRunId: + async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None: """Snapshot the state of a graph after a node has run, when the graph has ended. Args: @@ -109,7 +111,7 @@ async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> NodeRunId: raise NotImplementedError @abstractmethod - def record_run(self, run_id: NodeRunId) -> AbstractAsyncContextManager[None]: + def record_run(self, snapshot_id: SnapshotId) -> AbstractAsyncContextManager[None]: """Record the run of the node. In particular this should set: diff --git a/pydantic_graph/pydantic_graph/persistence/_utils.py b/pydantic_graph/pydantic_graph/persistence/_utils.py index eaaf2d2c92..d396d28d45 100644 --- a/pydantic_graph/pydantic_graph/persistence/_utils.py +++ b/pydantic_graph/pydantic_graph/persistence/_utils.py @@ -28,7 +28,7 @@ def __get_pydantic_core_schema__( if len(nodes) == 1: nodes_type = nodes[0] else: - nodes_annotated = [Annotated[node, pydantic.Tag(node.get_id())] for node in nodes] + nodes_annotated = [Annotated[node, pydantic.Tag(node.get_node_id())] for node in nodes] nodes_type = Annotated[Union[tuple(nodes_annotated)], pydantic.Discriminator(self._node_discriminator)] schema = handler(nodes_type) @@ -45,7 +45,7 @@ def _node_discriminator(node_data: Any) -> str: @staticmethod def _node_serializer(node: Any, handler: pydantic.SerializerFunctionWrapHandler) -> dict[str, Any]: node_dict = handler(node) - node_dict['node_id'] = node.get_id() + node_dict['node_id'] = node.get_node_id() return node_dict diff --git a/pydantic_graph/pydantic_graph/persistence/memory.py b/pydantic_graph/pydantic_graph/persistence/memory.py index bb4d89b054..43f612db15 100644 --- a/pydantic_graph/pydantic_graph/persistence/memory.py +++ b/pydantic_graph/pydantic_graph/persistence/memory.py @@ -13,10 +13,11 @@ from ..nodes import BaseNode, End from . import ( EndSnapshot, - NodeRunId, NodeSnapshot, Snapshot, + SnapshotId, StatePersistence, + _utils, build_snapshot_list_type_adapter, ) @@ -31,26 +32,28 @@ class SimpleStatePersistence(StatePersistence[StateT, RunEndT]): deep_copy: bool = False last_snapshot: Snapshot[StateT, RunEndT] | None = None - async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> NodeRunId: + async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> None: self.last_snapshot = NodeSnapshot( state=self.prep_state(state), node=next_node.deep_copy() if self.deep_copy else next_node, ) - return self.last_snapshot.run_id - async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> NodeRunId: + async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None: self.last_snapshot = EndSnapshot( state=self.prep_state(state), result=end.deep_copy_data() if self.deep_copy else end, ) - return self.last_snapshot.run_id @asynccontextmanager - async def record_run(self, run_id: NodeRunId) -> AsyncIterator[None]: + async def record_run(self, snapshot_id: SnapshotId) -> AsyncIterator[None]: assert self.last_snapshot is not None, 'No snapshot to record' assert isinstance(self.last_snapshot, NodeSnapshot), 'Only NodeSnapshot can be recorded' - assert run_id == self.last_snapshot.run_id, 'run_id must match the last snapshot run_id' + assert snapshot_id == self.last_snapshot.id, ( + f'snapshot_id must match the last snapshot ID: {snapshot_id!r} != {self.last_snapshot.id!r}' + ) self.last_snapshot.status = 'running' + self.last_snapshot.start_ts = _utils.now_utc() + start = perf_counter() try: yield @@ -83,31 +86,30 @@ class FullStatePersistence(StatePersistence[StateT, RunEndT]): default=None, init=False, repr=False ) - async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> NodeRunId: - last_snapshot = NodeSnapshot( + async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> None: + snapshot = NodeSnapshot( state=self.prep_state(state), node=next_node.deep_copy() if self.deep_copy else next_node, ) - self.history.append(last_snapshot) - return last_snapshot.run_id + self.history.append(snapshot) - async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> NodeRunId: - end = EndSnapshot( + async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None: + snapshot = EndSnapshot( state=self.prep_state(state), result=end.deep_copy_data() if self.deep_copy else end, ) - self.history.append(end) - return end.run_id + self.history.append(snapshot) @asynccontextmanager - async def record_run(self, run_id: NodeRunId) -> AsyncIterator[None]: + async def record_run(self, snapshot_id: SnapshotId) -> AsyncIterator[None]: try: - snapshot = next(s for s in self.history if s.run_id == run_id) + snapshot = next(s for s in self.history if s.id == snapshot_id) except StopIteration as e: - raise LookupError(f'No snapshot found for run_id {run_id}') from e + raise LookupError(f'No snapshot found with id={snapshot_id}') from e assert isinstance(snapshot, NodeSnapshot), 'Only NodeSnapshot can be recorded' snapshot.status = 'running' + snapshot.start_ts = _utils.now_utc() start = perf_counter() try: yield diff --git a/tests/conftest.py b/tests/conftest.py index be0faf92ed..bd4d305189 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,6 +16,7 @@ import httpx import pytest from _pytest.assertion.rewrite import AssertionRewritingHook +from pytest_mock import MockerFixture from typing_extensions import TypeAlias from vcr import VCR @@ -247,3 +248,15 @@ def groq_api_key() -> str: @pytest.fixture(scope='session') def anthropic_api_key() -> str: return os.getenv('ANTHROPIC_API_KEY', 'mock-api-key') + + +@pytest.fixture +def mock_snapshot_id(mocker: MockerFixture): + i = 0 + + def generate_snapshot_id(node_id: str) -> str: + nonlocal i + i += 1 + return f'{node_id}:{i}' + + return mocker.patch('pydantic_graph.nodes.generate_snapshot_id', side_effect=generate_snapshot_id) diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py index fdaec007d6..abd88a341b 100644 --- a/tests/graph/test_graph.py +++ b/tests/graph/test_graph.py @@ -64,7 +64,7 @@ async def test_graph(): assert my_graph.name == 'my_graph' -async def test_graph_history(): +async def test_graph_history(mock_snapshot_id): my_graph = Graph[None, None, int](nodes=(Float2String, String2Length, Double)) assert my_graph.name is None assert my_graph._inferred_types == (type(None), int) @@ -79,21 +79,27 @@ async def test_graph_history(): state=None, node=Float2String(input_data=3.14), start_ts=IsNow(tz=timezone.utc), + status='success', + id='Float2String:1', duration=IsFloat(), ), NodeSnapshot( state=None, node=String2Length(input_data='3.14'), start_ts=IsNow(tz=timezone.utc), + status='success', + id='String2Length:2', duration=IsFloat(), ), NodeSnapshot( state=None, node=Double(input_data=4), start_ts=IsNow(tz=timezone.utc), + status='success', + id='Double:3', duration=IsFloat(), ), - EndSnapshot(state=None, result=End(data=8), ts=IsNow(tz=timezone.utc)), + EndSnapshot(state=None, result=End(data=8), ts=IsNow(tz=timezone.utc), id='end:4'), ] ) sp = FullStatePersistence() @@ -106,33 +112,43 @@ async def test_graph_history(): state=None, node=Float2String(input_data=3.14159), start_ts=IsNow(tz=timezone.utc), + status='success', + id='Float2String:5', duration=IsFloat(), ), NodeSnapshot( state=None, node=String2Length(input_data='3.14159'), start_ts=IsNow(tz=timezone.utc), + status='success', + id='String2Length:6', duration=IsFloat(), ), NodeSnapshot( state=None, node=Double(input_data=7), start_ts=IsNow(tz=timezone.utc), + status='success', + id='Double:7', duration=IsFloat(), ), NodeSnapshot( state=None, node=String2Length(input_data='xxxxxxxxxxxxxxxxxxxxx'), start_ts=IsNow(tz=timezone.utc), + status='success', + id='String2Length:8', duration=IsFloat(), ), NodeSnapshot( state=None, node=Double(input_data=21), start_ts=IsNow(tz=timezone.utc), + status='success', + id='Double:9', duration=IsFloat(), ), - EndSnapshot(state=None, result=End(data=42), ts=IsNow(tz=timezone.utc)), + EndSnapshot(state=None, result=End(data=42), ts=IsNow(tz=timezone.utc), id='end:10'), ] ) assert [e.node for e in sp.history] == snapshot( @@ -231,7 +247,7 @@ async def run(self, ctx: GraphRunContext) -> End[None]: @classmethod @cache - def get_id(cls) -> str: + def get_node_id(cls) -> str: return 'Foo' with pytest.raises(GraphSetupError) as exc_info: @@ -263,7 +279,7 @@ async def run(self, ctx: GraphRunContext) -> End[None]: assert exc_info.value.message == snapshot('Node `test_run_node_not_in_graph..Spam()` is not in the graph.') -async def test_run_return_other(): +async def test_run_return_other(mock_snapshot_id): @dataclass class Foo(BaseNode): async def run(self, ctx: GraphRunContext) -> Bar: @@ -282,7 +298,7 @@ async def run(self, ctx: GraphRunContext) -> End[None]: assert exc_info.value.message == snapshot('Invalid node return type: `int`. Expected `BaseNode` or `End`.') -async def test_next(): +async def test_next(mock_snapshot_id): @dataclass class Foo(BaseNode): async def run(self, ctx: GraphRunContext) -> Bar: @@ -300,7 +316,17 @@ async def run(self, ctx: GraphRunContext) -> Foo: assert n == Bar() assert g.name == 'g' assert sp.history == snapshot( - [NodeSnapshot(state=None, node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat())] + [ + NodeSnapshot( + state=None, + node=Foo(), + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(), + status='success', + id='Foo:1', + ), + NodeSnapshot(state=None, node=Bar(), id='Bar:2'), + ] ) assert isinstance(n, Bar) @@ -309,13 +335,28 @@ async def run(self, ctx: GraphRunContext) -> Foo: assert sp.history == snapshot( [ - NodeSnapshot(state=None, node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), - NodeSnapshot(state=None, node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), + NodeSnapshot( + state=None, + node=Foo(), + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(), + status='success', + id='Foo:1', + ), + NodeSnapshot( + state=None, + node=Bar(), + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(), + status='success', + id='Bar:2', + ), + NodeSnapshot(state=None, node=Foo(), id='Foo:3'), ] ) -async def test_deps(): +async def test_deps(mock_snapshot_id): @dataclass class Deps: a: int @@ -340,8 +381,22 @@ async def run(self, ctx: GraphRunContext[None, Deps]) -> End[int]: assert result.output == 123 assert sp.history == snapshot( [ - NodeSnapshot(state=None, node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), - NodeSnapshot(state=None, node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), - EndSnapshot(state=None, result=End(123), ts=IsNow(tz=timezone.utc)), + NodeSnapshot( + state=None, + node=Foo(), + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(), + status='success', + id='Foo:1', + ), + NodeSnapshot( + state=None, + node=Bar(), + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(), + status='success', + id='Bar:2', + ), + EndSnapshot(state=None, result=End(data=123), ts=IsNow(tz=timezone.utc), id='end:3'), ] ) From 2bdb0294c2457a03fe0e9010992f60b41cac5aa7 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sat, 8 Mar 2025 16:28:49 +0000 Subject: [PATCH 09/25] fixing snapshot id --- docs/graph.md | 5 +- pydantic_graph/pydantic_graph/graph.py | 10 +-- pydantic_graph/pydantic_graph/nodes.py | 23 +++---- .../pydantic_graph/persistence/__init__.py | 42 +++++++++--- .../pydantic_graph/persistence/memory.py | 26 +++++++- tests/graph/test_graph.py | 8 +-- tests/graph/test_mermaid.py | 8 ++- tests/graph/test_persistence.py | 64 +++++++++++++++---- tests/graph/test_state.py | 8 ++- 9 files changed, 145 insertions(+), 49 deletions(-) diff --git a/docs/graph.md b/docs/graph.md index f01b4e8b57..c17d758749 100644 --- a/docs/graph.md +++ b/docs/graph.md @@ -598,11 +598,11 @@ async def main(): """ [ Ask(), - Answer(question='What is the capital of France?', answer='Vichy'), + Answer(question='What is the capital of France?', answer=None), Evaluate(answer='Vichy'), Reprimand(comment='Vichy is no longer the capital of France.'), Ask(), - Answer(question='what is 1 + 1?', answer='2'), + Answer(question='what is 1 + 1?', answer=None), Evaluate(answer='2'), End(data='Well done, 1 + 1 = 2'), ] @@ -734,6 +734,7 @@ async def main(): #> History Step: CountDown() CountDownState(counter=5) #> History Step: CountDown() CountDownState(counter=4) #> History Step: CountDown() CountDownState(counter=3) + #> History Step: CountDown() CountDownState(counter=2) ``` 1. We start by grabbing the first node that will be run in the agent's graph. diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 2890b17c93..448f553c47 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -13,7 +13,7 @@ from logfire_api import LogfireSpan from . import _utils, exceptions, mermaid -from .nodes import BaseNode, DepsT, End, GraphRunContext, NodeDef, RunEndT, SnapshotId +from .nodes import BaseNode, DepsT, End, GraphRunContext, NodeDef, RunEndT from .persistence import StatePersistence, StateT, set_nodes_type_context from .persistence.memory import SimpleStatePersistence @@ -591,7 +591,7 @@ def __init__( persistence: StatePersistence[StateT, RunEndT], state: StateT, deps: DepsT, - snapshot_id: SnapshotId | None = None, + snapshot_id: str | None = None, span: LogfireSpan | None = None, ): """Create a new run for a given graph, starting at the specified node. @@ -611,7 +611,7 @@ def __init__( """ self.graph = graph self.persistence = persistence - self._snapshot_id: SnapshotId | None = snapshot_id + self._snapshot_id: str | None = snapshot_id self.state = state self.deps = deps self._span = span @@ -687,7 +687,9 @@ async def main(): else: node_snapshot_id = node.get_snapshot_id() if node_snapshot_id != self._snapshot_id: - await self.persistence.snapshot_node(self.state, node) + existing_snapshot = await self.persistence.get_node_snapshot(node_snapshot_id, status='created') + if not existing_snapshot: + await self.persistence.snapshot_node(self.state, node) self._snapshot_id = node_snapshot_id if not isinstance(node, BaseNode): diff --git a/pydantic_graph/pydantic_graph/nodes.py b/pydantic_graph/pydantic_graph/nodes.py index 402730964b..59bd9cae6b 100644 --- a/pydantic_graph/pydantic_graph/nodes.py +++ b/pydantic_graph/pydantic_graph/nodes.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, is_dataclass from functools import cache -from typing import TYPE_CHECKING, Any, ClassVar, Generic, NewType, get_origin, get_type_hints +from typing import TYPE_CHECKING, Any, ClassVar, Generic, get_origin, get_type_hints from uuid import uuid4 from typing_extensions import Never, Self, TypeVar @@ -16,7 +16,7 @@ else: StateT = TypeVar('StateT', default=None) -__all__ = 'GraphRunContext', 'BaseNode', 'End', 'Edge', 'NodeDef', 'DepsT', 'RunEndT', 'SnapshotId' +__all__ = 'GraphRunContext', 'BaseNode', 'End', 'Edge', 'NodeDef', 'DepsT', 'RunEndT' RunEndT = TypeVar('RunEndT', covariant=True, default=None) """Covariant type variable for the return type of a graph [`run`][pydantic_graph.graph.Graph.run].""" @@ -24,8 +24,6 @@ """Covariant type variable for the return type of a node [`run`][pydantic_graph.nodes.BaseNode.run].""" DepsT = TypeVar('DepsT', default=None, contravariant=True) """Type variable for the dependencies of a graph and node.""" -SnapshotId = NewType('SnapshotId', str) -"""Unique ID for a node run.""" @dataclass @@ -70,13 +68,16 @@ async def run(self, ctx: GraphRunContext[StateT, DepsT]) -> BaseNode[StateT, Dep """ ... - def get_snapshot_id(self) -> SnapshotId: + def get_snapshot_id(self) -> str: if snapshot_id := getattr(self, '__snapshot_id', None): return snapshot_id else: self.__dict__['__snapshot_id'] = snapshot_id = generate_snapshot_id(self.get_node_id()) return snapshot_id + def set_snapshot_id(self, snapshot_id: str) -> None: + self.__dict__['__snapshot_id'] = snapshot_id + @classmethod @cache def get_node_id(cls) -> str: @@ -159,20 +160,20 @@ def deep_copy_data(self) -> End[RunEndT]: end.set_snapshot_id(self.get_snapshot_id()) return end - def set_snapshot_id(self, set_id: str) -> None: - self.__dict__['__snapshot_id'] = SnapshotId(set_id) - - def get_snapshot_id(self) -> SnapshotId: + def get_snapshot_id(self) -> str: if snapshot_id := getattr(self, '__snapshot_id', None): return snapshot_id else: self.__dict__['__snapshot_id'] = snapshot_id = generate_snapshot_id('end') return snapshot_id + def set_snapshot_id(self, set_id: str) -> None: + self.__dict__['__snapshot_id'] = set_id + -def generate_snapshot_id(node_id: str) -> SnapshotId: +def generate_snapshot_id(node_id: str) -> str: # module method to allow mocking - return SnapshotId(f'{node_id}:{uuid4().hex}') + return f'{node_id}:{uuid4().hex}' @dataclass diff --git a/pydantic_graph/pydantic_graph/persistence/__init__.py b/pydantic_graph/pydantic_graph/persistence/__init__.py index c86db907a2..27716a830a 100644 --- a/pydantic_graph/pydantic_graph/persistence/__init__.py +++ b/pydantic_graph/pydantic_graph/persistence/__init__.py @@ -11,15 +11,24 @@ from typing_extensions import TypeVar from .. import exceptions -from ..nodes import BaseNode, End, RunEndT, SnapshotId +from ..nodes import BaseNode, End, RunEndT from . import _utils -__all__ = 'StateT', 'NodeSnapshot', 'EndSnapshot', 'Snapshot', 'StatePersistence', 'set_nodes_type_context' +__all__ = ( + 'StateT', + 'NodeSnapshot', + 'EndSnapshot', + 'Snapshot', + 'StatePersistence', + 'set_nodes_type_context', + 'SnapshotStatus', +) StateT = TypeVar('StateT', default=None) """Type variable for the state in a graph.""" -UNSET_ID = SnapshotId('__unset__') +UNSET_SNAPSHOT_ID = '__unset__' +SnapshotStatus = Literal['created', 'pending', 'running', 'success', 'error'] @dataclass @@ -34,15 +43,16 @@ class NodeSnapshot(Generic[StateT, RunEndT]): """The timestamp when the node started running, `None` until the run starts.""" duration: float | None = None """The duration of the node run in seconds, if the node has been run.""" - status: Literal['not_started', 'pending', 'running', 'success', 'error'] = 'not_started' + status: SnapshotStatus = 'created' + """The status of the snapshot.""" kind: Literal['node'] = 'node' """The kind of history step, can be used as a discriminator when deserializing history.""" - id: SnapshotId = UNSET_ID + id: str = UNSET_SNAPSHOT_ID """Unique ID of the snapshot.""" def __post_init__(self) -> None: - if self.id == UNSET_ID: + if self.id == UNSET_SNAPSHOT_ID: self.id = self.node.get_snapshot_id() @@ -59,11 +69,11 @@ class EndSnapshot(Generic[StateT, RunEndT]): kind: Literal['end'] = 'end' """The kind of history step, can be used as a discriminator when deserializing history.""" - id: SnapshotId = UNSET_ID + id: str = UNSET_SNAPSHOT_ID """Unique ID of the snapshot.""" def __post_init__(self) -> None: - if self.id == UNSET_ID: + if self.id == UNSET_SNAPSHOT_ID: self.id = self.node.get_snapshot_id() @property @@ -111,7 +121,7 @@ async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None: raise NotImplementedError @abstractmethod - def record_run(self, snapshot_id: SnapshotId) -> AbstractAsyncContextManager[None]: + def record_run(self, snapshot_id: str) -> AbstractAsyncContextManager[None]: """Record the run of the node. In particular this should set: @@ -132,6 +142,20 @@ async def restore(self) -> Snapshot[StateT, RunEndT] | None: """ raise NotImplementedError + @abstractmethod + async def get_node_snapshot( + self, snapshot_id: str, status: SnapshotStatus | None = None + ) -> Snapshot[StateT, RunEndT] | None: + """Get a snapshot by ID. + + Args: + snapshot_id: The ID of the snapshot to get. + status: The status of the snapshot to get, or `None` to get any status. + + Returns: The snapshot with the given ID and status, or `None` if no snapshot with that ID exists. + """ + raise NotImplementedError + def set_types(self, get_types: Callable[[], tuple[type[StateT], type[RunEndT]]]) -> None: pass diff --git a/pydantic_graph/pydantic_graph/persistence/memory.py b/pydantic_graph/pydantic_graph/persistence/memory.py index 43f612db15..d15290b31d 100644 --- a/pydantic_graph/pydantic_graph/persistence/memory.py +++ b/pydantic_graph/pydantic_graph/persistence/memory.py @@ -15,7 +15,7 @@ EndSnapshot, NodeSnapshot, Snapshot, - SnapshotId, + SnapshotStatus, StatePersistence, _utils, build_snapshot_list_type_adapter, @@ -45,7 +45,7 @@ async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None: ) @asynccontextmanager - async def record_run(self, snapshot_id: SnapshotId) -> AsyncIterator[None]: + async def record_run(self, snapshot_id: str) -> AsyncIterator[None]: assert self.last_snapshot is not None, 'No snapshot to record' assert isinstance(self.last_snapshot, NodeSnapshot), 'Only NodeSnapshot can be recorded' assert snapshot_id == self.last_snapshot.id, ( @@ -68,6 +68,15 @@ async def record_run(self, snapshot_id: SnapshotId) -> AsyncIterator[None]: async def restore(self) -> Snapshot[StateT, RunEndT] | None: return self.last_snapshot + async def get_node_snapshot( + self, snapshot_id: str, status: SnapshotStatus | None = None + ) -> NodeSnapshot[StateT, RunEndT] | None: + if isinstance(self.last_snapshot, NodeSnapshot) and self.last_snapshot.id == snapshot_id: + if status and self.last_snapshot.status != status: + return None + else: + return self.last_snapshot + def prep_state(self, state: StateT) -> StateT: """Prepare state for snapshot, uses [`copy.deepcopy`][copy.deepcopy] by default.""" if not self.deep_copy or state is None: @@ -101,7 +110,7 @@ async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None: self.history.append(snapshot) @asynccontextmanager - async def record_run(self, snapshot_id: SnapshotId) -> AsyncIterator[None]: + async def record_run(self, snapshot_id: str) -> AsyncIterator[None]: try: snapshot = next(s for s in self.history if s.id == snapshot_id) except StopIteration as e: @@ -125,6 +134,17 @@ async def restore(self) -> Snapshot[StateT, RunEndT] | None: if self.history: return self.history[-1] + async def get_node_snapshot( + self, snapshot_id: str, status: SnapshotStatus | None = None + ) -> Snapshot[StateT, RunEndT] | None: + for snapshot in self.history: + if ( + isinstance(snapshot, NodeSnapshot) + and snapshot.id == snapshot_id + and (status is None or snapshot.status == status) + ): + return snapshot + def dump_json(self, *, indent: int | None = None) -> bytes: assert self._snapshots_type_adapter is not None, 'type adapter must be set to use `dump_json`' return self._snapshots_type_adapter.dump_json(self.history, indent=indent) diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py index abd88a341b..c62712981d 100644 --- a/tests/graph/test_graph.py +++ b/tests/graph/test_graph.py @@ -64,7 +64,7 @@ async def test_graph(): assert my_graph.name == 'my_graph' -async def test_graph_history(mock_snapshot_id): +async def test_graph_history(mock_snapshot_id: object): my_graph = Graph[None, None, int](nodes=(Float2String, String2Length, Double)) assert my_graph.name is None assert my_graph._inferred_types == (type(None), int) @@ -279,7 +279,7 @@ async def run(self, ctx: GraphRunContext) -> End[None]: assert exc_info.value.message == snapshot('Node `test_run_node_not_in_graph..Spam()` is not in the graph.') -async def test_run_return_other(mock_snapshot_id): +async def test_run_return_other(mock_snapshot_id: object): @dataclass class Foo(BaseNode): async def run(self, ctx: GraphRunContext) -> Bar: @@ -298,7 +298,7 @@ async def run(self, ctx: GraphRunContext) -> End[None]: assert exc_info.value.message == snapshot('Invalid node return type: `int`. Expected `BaseNode` or `End`.') -async def test_next(mock_snapshot_id): +async def test_next(mock_snapshot_id: object): @dataclass class Foo(BaseNode): async def run(self, ctx: GraphRunContext) -> Bar: @@ -356,7 +356,7 @@ async def run(self, ctx: GraphRunContext) -> Foo: ) -async def test_deps(mock_snapshot_id): +async def test_deps(mock_snapshot_id: object): @dataclass class Deps: a: int diff --git a/tests/graph/test_mermaid.py b/tests/graph/test_mermaid.py index 24bdfb0680..55c02a9875 100644 --- a/tests/graph/test_mermaid.py +++ b/tests/graph/test_mermaid.py @@ -67,7 +67,7 @@ async def run(self, ctx: GraphRunContext) -> Annotated[End[None], Edge(label='eg graph2 = Graph(nodes=(Spam, Foo, Bar, Eggs)) -async def test_run_graph(): +async def test_run_graph(mock_snapshot_id: object): sp = FullStatePersistence() result = await graph1.run(Foo(), persistence=sp) assert result.output is None @@ -78,14 +78,18 @@ async def test_run_graph(): node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), + status='success', + id='Foo:1', ), NodeSnapshot( state=None, node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), + status='success', + id='Bar:2', ), - EndSnapshot(state=None, result=End(None), ts=IsNow(tz=timezone.utc)), + EndSnapshot(state=None, result=End(data=None), ts=IsNow(tz=timezone.utc), id='end:3'), ] ) diff --git a/tests/graph/test_persistence.py b/tests/graph/test_persistence.py index 4fcce2c552..6413255fdf 100644 --- a/tests/graph/test_persistence.py +++ b/tests/graph/test_persistence.py @@ -53,16 +53,35 @@ async def run(self, ctx: GraphRunContext[MyState]) -> End[int]: Graph(nodes=(Foo, Bar)), ], ) -async def test_dump_load_history(graph: Graph[MyState, None, int]): +async def test_dump_load_state(graph: Graph[MyState, None, int], mock_snapshot_id: object): sp = FullStatePersistence() result = await graph.run(Foo(), state=MyState(1, ''), persistence=sp) assert result.output == snapshot(4) assert result.state == snapshot(MyState(x=2, y='y')) assert sp.history == snapshot( [ - NodeSnapshot(state=MyState(x=1, y=''), node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), - NodeSnapshot(state=MyState(x=2, y=''), node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), - EndSnapshot(state=MyState(x=2, y='y'), result=End(4), ts=IsNow(tz=timezone.utc)), + NodeSnapshot( + state=MyState(x=1, y=''), + node=Foo(), + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(), + status='success', + id='Foo:1', + ), + NodeSnapshot( + state=MyState(x=2, y=''), + node=Bar(), + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(), + status='success', + id='Bar:2', + ), + EndSnapshot( + state=MyState(x=2, y='y'), + result=End(data=4), + ts=IsNow(tz=timezone.utc), + id='end:3', + ), ] ) history_json = sp.dump_json() @@ -73,20 +92,25 @@ async def test_dump_load_history(graph: Graph[MyState, None, int]): 'node': {'node_id': 'Foo'}, 'start_ts': IsStr(regex=r'20\d\d-\d\d-\d\dT.+'), 'duration': IsFloat(), + 'status': 'success', 'kind': 'node', + 'id': 'Foo:1', }, { 'state': {'x': 2, 'y': ''}, 'node': {'node_id': 'Bar'}, 'start_ts': IsStr(regex=r'20\d\d-\d\d-\d\dT.+'), 'duration': IsFloat(), + 'status': 'success', 'kind': 'node', + 'id': 'Bar:2', }, { 'state': {'x': 2, 'y': 'y'}, 'result': {'data': 4}, 'ts': IsStr(regex=r'20\d\d-\d\d-\d\dT.+'), 'kind': 'end', + 'id': 'end:3', }, ] ) @@ -122,15 +146,19 @@ async def test_dump_load_history(graph: Graph[MyState, None, int]): node=Foo(), start_ts=datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc), duration=123.0, + id='Foo:4', ), EndSnapshot( - state=MyState(x=42, y='new'), result=End(data=42), ts=datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc) + state=MyState(x=42, y='new'), + result=End(data=42), + ts=datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc), + id='end:5', ), ] ) -def test_one_node(): +def test_one_node(mock_snapshot_id: object): @dataclass class MyNode(BaseNode[None, None, int]): node_field: int @@ -159,12 +187,13 @@ async def run(self, ctx: GraphRunContext) -> End[int]: node=MyNode(node_field=42), start_ts=datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc), duration=123.0, + id='MyNode:1', ) ] ) -def test_no_generic_arg(): +def test_no_generic_arg(mock_snapshot_id: object): @dataclass class NoGenericArgsNode(BaseNode): async def run(self, ctx: GraphRunContext) -> NoGenericArgsNode: @@ -198,23 +227,24 @@ async def run(self, ctx: GraphRunContext) -> NoGenericArgsNode: node=NoGenericArgsNode(), start_ts=datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc), duration=123.0, + id='NoGenericArgsNode:1', ) ] ) -async def test_node_error(): +async def test_node_error(mock_snapshot_id: object): @dataclass class Foo(BaseNode): - async def run(self, ctx: GraphRunContext) -> Bar: - return Bar() + async def run(self, ctx: GraphRunContext) -> Spam: + return Spam() @dataclass - class Bar(BaseNode[None, None, int]): + class Spam(BaseNode[None, None, int]): async def run(self, ctx: GraphRunContext) -> End[int]: raise RuntimeError('test error') - graph = Graph(nodes=[Foo, Bar]) + graph = Graph(nodes=[Foo, Spam]) sp = FullStatePersistence() with pytest.raises(RuntimeError, match='test error'): @@ -225,8 +255,18 @@ async def run(self, ctx: GraphRunContext) -> End[int]: NodeSnapshot( state=None, node=Foo(), + status='success', + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(), + id='Foo:1', + ), + NodeSnapshot( + state=None, + node=Spam(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), + status='error', + id='Spam:2', ), ] ) diff --git a/tests/graph/test_state.py b/tests/graph/test_state.py index c9d79a4ec8..563cbd6b29 100644 --- a/tests/graph/test_state.py +++ b/tests/graph/test_state.py @@ -21,7 +21,7 @@ pytestmark = pytest.mark.anyio -async def test_run_graph(): +async def test_run_graph(mock_snapshot_id: object): @dataclass class MyState: x: int @@ -52,14 +52,18 @@ async def run(self, ctx: GraphRunContext[MyState]) -> End[str]: node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), + status='success', + id='Foo:1', ), NodeSnapshot( state=MyState(x=2, y=''), node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), + status='success', + id='Bar:2', ), - EndSnapshot(state=MyState(x=2, y='y'), result=End('x=2 y=y'), ts=IsNow(tz=timezone.utc)), + EndSnapshot(state=MyState(x=2, y='y'), result=End(data='x=2 y=y'), ts=IsNow(tz=timezone.utc), id='end:3'), ] ) assert state == MyState(x=2, y='y') From e9d80522500a72bd2160bcccc442b253978f4922 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sat, 8 Mar 2025 18:50:35 +0000 Subject: [PATCH 10/25] improving docs --- docs/api/pydantic_graph/persistence.md | 5 ++ docs/api/pydantic_graph/state.md | 3 -- docs/graph.md | 41 +++++++++------- mkdocs.yml | 2 +- pydantic_graph/pydantic_graph/__init__.py | 2 +- pydantic_graph/pydantic_graph/graph.py | 10 ++-- .../pydantic_graph/persistence/__init__.py | 20 +++++--- .../persistence/{memory.py => in_mem.py} | 47 ++++++++++++++----- 8 files changed, 85 insertions(+), 45 deletions(-) create mode 100644 docs/api/pydantic_graph/persistence.md delete mode 100644 docs/api/pydantic_graph/state.md rename pydantic_graph/pydantic_graph/persistence/{memory.py => in_mem.py} (83%) diff --git a/docs/api/pydantic_graph/persistence.md b/docs/api/pydantic_graph/persistence.md new file mode 100644 index 0000000000..4b68150307 --- /dev/null +++ b/docs/api/pydantic_graph/persistence.md @@ -0,0 +1,5 @@ +# `pydantic_graph.persistence` + +::: pydantic_graph.persistence + +::: pydantic_graph.persistence.in_mem diff --git a/docs/api/pydantic_graph/state.md b/docs/api/pydantic_graph/state.md deleted file mode 100644 index 480eea5a54..0000000000 --- a/docs/api/pydantic_graph/state.md +++ /dev/null @@ -1,3 +0,0 @@ -# `pydantic_graph.state` - -::: pydantic_graph.state diff --git a/docs/graph.md b/docs/graph.md index c17d758749..473a7fd56f 100644 --- a/docs/graph.md +++ b/docs/graph.md @@ -39,7 +39,7 @@ pip/uv-add pydantic-graph [`GraphRunContext`][pydantic_graph.nodes.GraphRunContext] — The context for the graph run, similar to PydanticAI's [`RunContext`][pydantic_ai.tools.RunContext]. This holds the state of the graph and dependencies and is passed to nodes when they're run. -`GraphRunContext` is generic in the state type of the graph it's used in, [`StateT`][pydantic_graph.state.StateT]. +`GraphRunContext` is generic in the state type of the graph it's used in, [`StateT`][pydantic_graph.persistence.StateT]. ### End @@ -59,7 +59,7 @@ Nodes, which are generally [`dataclass`es][dataclasses.dataclass], generally con Nodes are generic in: -* **state**, which must have the same type as the state of graphs they're included in, [`StateT`][pydantic_graph.state.StateT] has a default of `None`, so if you're not using state you can omit this generic parameter, see [stateful graphs](#stateful-graphs) for more information +* **state**, which must have the same type as the state of graphs they're included in, [`StateT`][pydantic_graph.persistence.StateT] has a default of `None`, so if you're not using state you can omit this generic parameter, see [stateful graphs](#stateful-graphs) for more information * **deps**, which must have the same type as the deps of the graph they're included in, [`DepsT`][pydantic_graph.nodes.DepsT] has a default of `None`, so if you're not using deps you can omit this generic parameter, see [dependency injection](#dependency-injection) for more information * **graph return type** — this only applies if the node returns [`End`][pydantic_graph.nodes.End]. [`RunEndT`][pydantic_graph.nodes.RunEndT] has a default of [Never][typing.Never] so this generic parameter can be omitted if the node doesn't return `End`, but must be included if it does. @@ -119,7 +119,7 @@ class MyNode(BaseNode[MyState, None, int]): # (1)! `Graph` is generic in: -* **state** the state type of the graph, [`StateT`][pydantic_graph.state.StateT] +* **state** the state type of the graph, [`StateT`][pydantic_graph.persistence.StateT] * **deps** the deps type of the graph, [`DepsT`][pydantic_graph.nodes.DepsT] * **graph return type** the return type of the graph run, [`RunEndT`][pydantic_graph.nodes.RunEndT] @@ -512,16 +512,22 @@ In this example, an AI asks the user a question, the user provides an answer, th ) ctx.state.ask_agent_messages += result.all_messages() ctx.state.question = result.data - return Answer(result.data) + return Question(result.data) @dataclass - class Answer(BaseNode[QuestionState]): + class Question(BaseNode[QuestionState]): question: str - answer: str | None = None + + async def run(self, ctx: GraphRunContext[QuestionState]) -> Answer: + raise NotImplementedError('Question nodes should not be run') + + + @dataclass + class Answer(BaseNode[QuestionState]): + answer: str async def run(self, ctx: GraphRunContext[QuestionState]) -> Evaluate: - assert self.answer is not None return Evaluate(self.answer) @@ -568,7 +574,7 @@ In this example, an AI asks the user a question, the user provides an answer, th return Ask() - question_graph = Graph(nodes=(Ask, Answer, Evaluate, Reprimand)) + question_graph = Graph(nodes=(Ask, Question, Answer, Evaluate, Reprimand)) ``` _(This example is complete, it can be run "as is" with Python 3.10+)_ @@ -578,7 +584,7 @@ from rich.prompt import Prompt from pydantic_graph import End, FullStatePersistence -from ai_q_and_a_graph import Ask, question_graph, QuestionState, Answer +from ai_q_and_a_graph import Ask, question_graph, Question, QuestionState, Answer async def main(): @@ -589,8 +595,8 @@ async def main(): node = await question_graph.next( # (4)! node, persistence=persistence, state=state ) - if isinstance(node, Answer): - node.answer = Prompt.ask(node.question) # (5)! + if isinstance(node, Question): + node = Answer(Prompt.ask(node.question)) # (5)! elif isinstance(node, End): # (6)! print(f'Correct answer! {node.data}') #> Correct answer! Well done, 1 + 1 = 2 @@ -598,11 +604,13 @@ async def main(): """ [ Ask(), - Answer(question='What is the capital of France?', answer=None), + Question(question='What is the capital of France?'), + Answer(answer='Vichy'), Evaluate(answer='Vichy'), Reprimand(comment='Vichy is no longer the capital of France.'), Ask(), - Answer(question='what is 1 + 1?', answer=None), + Question(question='what is 1 + 1?'), + Answer(answer='2'), Evaluate(answer='2'), End(data='Well done, 1 + 1 = 2'), ] @@ -613,9 +621,9 @@ async def main(): 1. Create the state object which will be mutated by [`next`][pydantic_graph.graph.Graph.next]. 2. The start node is `Ask` but will be updated by [`next`][pydantic_graph.graph.Graph.next] as the graph runs. -3. The history of the graph run is stored using [`FullStatePersistence`][pydantic_graph.state.memory.FullStatePersistence]. Again [`next`][pydantic_graph.graph.Graph.next] will update this list in place. +3. The history of the graph run is stored using [`FullStatePersistence`][pydantic_graph.FullStatePersistence]. Again [`next`][pydantic_graph.graph.Graph.next] will update this list in place. 4. [Run][pydantic_graph.graph.Graph.next] the graph one node at a time, updating the state, current node and history as the graph runs. -5. If the current node is an `Answer` node, prompt the user for an answer. +5. If the current node is an `Question` node, prompt the user for an answer. 6. Since we're using [`next`][pydantic_graph.graph.Graph.next] we have to manually check for an [`End`][pydantic_graph.nodes.End] and exit the loop if we get one. _(This example is complete, it can be run "as is" with Python 3.10+ — you'll need to add `asyncio.run(main())` to run `main`)_ @@ -634,7 +642,8 @@ title: question_graph --- stateDiagram-v2 [*] --> Ask - Ask --> Answer + Ask --> Question + Question --> Answer Answer --> Evaluate Evaluate --> Reprimand Evaluate --> [*] diff --git a/mkdocs.yml b/mkdocs.yml index d38c0f6fb6..56b1181eb1 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -63,7 +63,7 @@ nav: - api/models/fallback.md - api/pydantic_graph/graph.md - api/pydantic_graph/nodes.md - - api/pydantic_graph/state.md + - api/pydantic_graph/persistence.md - api/pydantic_graph/mermaid.md - api/pydantic_graph/exceptions.md diff --git a/pydantic_graph/pydantic_graph/__init__.py b/pydantic_graph/pydantic_graph/__init__.py index 80f601ad05..5fff288fae 100644 --- a/pydantic_graph/pydantic_graph/__init__.py +++ b/pydantic_graph/pydantic_graph/__init__.py @@ -2,7 +2,7 @@ from .graph import Graph, GraphRun, GraphRunResult from .nodes import BaseNode, Edge, End, GraphRunContext from .persistence import EndSnapshot, NodeSnapshot, Snapshot -from .persistence.memory import FullStatePersistence, SimpleStatePersistence +from .persistence.in_mem import FullStatePersistence, SimpleStatePersistence __all__ = ( 'Graph', diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 448f553c47..f579337b29 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -15,7 +15,7 @@ from . import _utils, exceptions, mermaid from .nodes import BaseNode, DepsT, End, GraphRunContext, NodeDef, RunEndT from .persistence import StatePersistence, StateT, set_nodes_type_context -from .persistence.memory import SimpleStatePersistence +from .persistence.in_mem import SimpleStatePersistence # while waiting for https://github.com/pydantic/logfire/issues/745 try: @@ -136,7 +136,7 @@ async def run( state: The initial state of the graph. deps: The dependencies of the graph. persistence: State persistence interface, defaults to - [`SimpleStatePersistence`][pydantic_graph.state.memory.SimpleStatePersistence] if `None`. + [`SimpleStatePersistence`][pydantic_graph.SimpleStatePersistence] if `None`. infer_name: Whether to infer the graph name from the calling frame. span: The span to use for the graph run. If not provided, a span will be created depending on the value of the `_auto_instrument` field. @@ -194,7 +194,7 @@ def run_sync( state: The initial state of the graph. deps: The dependencies of the graph. persistence: State persistence interface, defaults to - [`SimpleStatePersistence`][pydantic_graph.state.memory.SimpleStatePersistence] if `None`. + [`SimpleStatePersistence`][pydantic_graph.SimpleStatePersistence] if `None`. infer_name: Whether to infer the graph name from the calling frame. Returns: @@ -238,7 +238,7 @@ async def iter( state: The initial state of the graph. deps: The dependencies of the graph. persistence: State persistence interface, defaults to - [`SimpleStatePersistence`][pydantic_graph.state.memory.SimpleStatePersistence] if `None`. + [`SimpleStatePersistence`][pydantic_graph.SimpleStatePersistence] if `None`. infer_name: Whether to infer the graph name from the calling frame. span: The span to use for the graph run. If not provided, a new span will be created. @@ -282,7 +282,7 @@ async def next( Args: node: The node to run. persistence: State persistence interface, defaults to - [`SimpleStatePersistence`][pydantic_graph.state.memory.SimpleStatePersistence] if `None`. + [`SimpleStatePersistence`][pydantic_graph.SimpleStatePersistence] if `None`. state: The current state of the graph. deps: The dependencies of the graph. infer_name: Whether to infer the graph name from the calling frame. diff --git a/pydantic_graph/pydantic_graph/persistence/__init__.py b/pydantic_graph/pydantic_graph/persistence/__init__.py index 27716a830a..1c207cd266 100644 --- a/pydantic_graph/pydantic_graph/persistence/__init__.py +++ b/pydantic_graph/pydantic_graph/persistence/__init__.py @@ -78,7 +78,7 @@ def __post_init__(self) -> None: @property def node(self) -> End[RunEndT]: - """Shim to get the [`result`][pydantic_graph.state.EndSnapshot.result]. + """Shim to get the [`result`][pydantic_graph.persistence.EndSnapshot.result]. Useful to allow `[snapshot.node for snapshot in persistence.history]`. """ @@ -126,10 +126,10 @@ def record_run(self, snapshot_id: str) -> AbstractAsyncContextManager[None]: In particular this should set: - - [`NodeSnapshot.status`][pydantic_graph.state.NodeSnapshot.status] to `'running'` and - [`NodeSnapshot.start_ts`][pydantic_graph.state.NodeSnapshot.start_ts] when the run starts. - - [`NodeSnapshot.status`][pydantic_graph.state.NodeSnapshot.status] to `'success'` or `'error'` and - [`NodeSnapshot.duration`][pydantic_graph.state.NodeSnapshot.duration] when the run finishes. + - [`NodeSnapshot.status`][pydantic_graph.persistence.NodeSnapshot.status] to `'running'` and + [`NodeSnapshot.start_ts`][pydantic_graph.persistence.NodeSnapshot.start_ts] when the run starts. + - [`NodeSnapshot.status`][pydantic_graph.persistence.NodeSnapshot.status] to `'success'` or `'error'` and + [`NodeSnapshot.duration`][pydantic_graph.persistence.NodeSnapshot.duration] when the run finishes. """ raise NotImplementedError @@ -138,7 +138,7 @@ async def restore(self) -> Snapshot[StateT, RunEndT] | None: """Retrieve the latest snapshot. Returns: - The most recent [`Snapshot`][pydantic_graph.state.Snapshot] of the run. + The most recent [`Snapshot`][pydantic_graph.persistence.Snapshot] of the run. """ raise NotImplementedError @@ -157,6 +157,14 @@ async def get_node_snapshot( raise NotImplementedError def set_types(self, get_types: Callable[[], tuple[type[StateT], type[RunEndT]]]) -> None: + """Set the types of the state and run end. + + This can be used to create [type adapters][pydantic.TypeAdapter] for serializing and deserializing + snapshots. + + Args: + get_types: A callback that returns the types of the state and run end. + """ pass async def restore_node_snapshot(self) -> NodeSnapshot[StateT, RunEndT]: diff --git a/pydantic_graph/pydantic_graph/persistence/memory.py b/pydantic_graph/pydantic_graph/persistence/in_mem.py similarity index 83% rename from pydantic_graph/pydantic_graph/persistence/memory.py rename to pydantic_graph/pydantic_graph/persistence/in_mem.py index d15290b31d..18052f44d3 100644 --- a/pydantic_graph/pydantic_graph/persistence/memory.py +++ b/pydantic_graph/pydantic_graph/persistence/in_mem.py @@ -1,3 +1,8 @@ +"""In memory state persistence. + +This module provides simple in memory state persistence for graphs. +""" + from __future__ import annotations as _annotations import copy @@ -27,20 +32,28 @@ @dataclass class SimpleStatePersistence(StatePersistence[StateT, RunEndT]): - """Simple in memory state persistence that just hold the latest snapshot.""" + """Simple in memory state persistence that just hold the latest snapshot. + + If no state persistence implementation is provided when running a graph, this is used by default. + """ deep_copy: bool = False + """Whether to deep copy the state and nodes when storing them. + + Defaults to `False` so you can use nodes that don't support deep copying. + """ last_snapshot: Snapshot[StateT, RunEndT] | None = None + """The last snapshot.""" async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> None: self.last_snapshot = NodeSnapshot( - state=self.prep_state(state), + state=self._prep_state(state), node=next_node.deep_copy() if self.deep_copy else next_node, ) async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None: self.last_snapshot = EndSnapshot( - state=self.prep_state(state), + state=self._prep_state(state), result=end.deep_copy_data() if self.deep_copy else end, ) @@ -77,7 +90,7 @@ async def get_node_snapshot( else: return self.last_snapshot - def prep_state(self, state: StateT) -> StateT: + def _prep_state(self, state: StateT) -> StateT: """Prepare state for snapshot, uses [`copy.deepcopy`][copy.deepcopy] by default.""" if not self.deep_copy or state is None: return state @@ -87,24 +100,30 @@ def prep_state(self, state: StateT) -> StateT: @dataclass class FullStatePersistence(StatePersistence[StateT, RunEndT]): - """In memory state persistence that hold a history of nodes that were executed.""" + """In memory state persistence that hold a history of nodes.""" deep_copy: bool = True + """Whether to deep copy the state and nodes when storing them. + + Defaults to `True` so even if nodes or state are modified after the snapshot is taken, + the persistence history will record the value at the time of the snapshot. + """ history: list[Snapshot[StateT, RunEndT]] = field(default_factory=list) + """List of snapshots taken during the graph run.""" _snapshots_type_adapter: pydantic.TypeAdapter[list[Snapshot[StateT, RunEndT]]] | None = field( default=None, init=False, repr=False ) async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> None: snapshot = NodeSnapshot( - state=self.prep_state(state), + state=self._prep_state(state), node=next_node.deep_copy() if self.deep_copy else next_node, ) self.history.append(snapshot) async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None: snapshot = EndSnapshot( - state=self.prep_state(state), + state=self._prep_state(state), result=end.deep_copy_data() if self.deep_copy else end, ) self.history.append(snapshot) @@ -145,20 +164,22 @@ async def get_node_snapshot( ): return snapshot + def set_types(self, get_types: Callable[[], tuple[type[StateT], type[RunEndT]]]) -> None: + if self._snapshots_type_adapter is None: + state_t, run_end_t = get_types() + self._snapshots_type_adapter = build_snapshot_list_type_adapter(state_t, run_end_t) + def dump_json(self, *, indent: int | None = None) -> bytes: + """Dump the history to JSON bytes.""" assert self._snapshots_type_adapter is not None, 'type adapter must be set to use `dump_json`' return self._snapshots_type_adapter.dump_json(self.history, indent=indent) def load_json(self, json_data: str | bytes | bytearray) -> None: + """Load the history from JSON.""" assert self._snapshots_type_adapter is not None, 'type adapter must be set to use `load_json`' self.history = self._snapshots_type_adapter.validate_json(json_data) - def set_types(self, get_types: Callable[[], tuple[type[StateT], type[RunEndT]]]) -> None: - if self._snapshots_type_adapter is None: - state_t, run_end_t = get_types() - self._snapshots_type_adapter = build_snapshot_list_type_adapter(state_t, run_end_t) - - def prep_state(self, state: StateT) -> StateT: + def _prep_state(self, state: StateT) -> StateT: """Prepare state for snapshot, uses [`copy.deepcopy`][copy.deepcopy] by default.""" if not self.deep_copy or state is None: return state From 09a5174a268cd0f616e09bc59c3aaab7dbd55d03 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sat, 8 Mar 2025 19:12:55 +0000 Subject: [PATCH 11/25] fix spans --- pydantic_graph/pydantic_graph/_utils.py | 4 ++-- pydantic_graph/pydantic_graph/graph.py | 12 +++++++++--- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/pydantic_graph/pydantic_graph/_utils.py b/pydantic_graph/pydantic_graph/_utils.py index 4506149d54..bc49ed6fcb 100644 --- a/pydantic_graph/pydantic_graph/_utils.py +++ b/pydantic_graph/pydantic_graph/_utils.py @@ -3,9 +3,9 @@ import asyncio import types from functools import partial -from typing import Any, Callable, ParamSpec, TypeVar +from typing import Any, Callable, TypeVar -from typing_extensions import TypeIs, get_args, get_origin +from typing_extensions import ParamSpec, TypeIs, get_args, get_origin from typing_inspection import typing_objects from typing_inspection.introspection import is_union_origin diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 5f3d7f88aa..029c0e900e 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -253,9 +253,15 @@ async def iter( persistence = SimpleStatePersistence() self.set_persistence_types(persistence) - yield GraphRun[StateT, DepsT, T]( - graph=self, start_node=start_node, persistence=persistence, state=state, deps=deps - ) + if self.auto_instrument and span is None: + span = logfire_api.span('run graph {graph.name}', graph=self) + + with ExitStack() as stack: + if span is not None: + stack.enter_context(span) + yield GraphRun[StateT, DepsT, T]( + graph=self, start_node=start_node, persistence=persistence, state=state, deps=deps + ) # @deprecated('`graph.next` is deprecated, use `graph.iter` ... `run.next` instead') async def next( From e271af50b01cfc6dfde795c755d1fe9fa34a0bfd Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sat, 8 Mar 2025 19:28:28 +0000 Subject: [PATCH 12/25] more tests, improve coverage --- pydantic_graph/pydantic_graph/_utils.py | 17 ++--------------- pydantic_graph/pydantic_graph/graph.py | 7 ++++--- tests/graph/test_graph.py | 16 ++++++++++++++++ tests/graph/test_persistence.py | 4 ++-- 4 files changed, 24 insertions(+), 20 deletions(-) diff --git a/pydantic_graph/pydantic_graph/_utils.py b/pydantic_graph/pydantic_graph/_utils.py index bc49ed6fcb..f97b14fe03 100644 --- a/pydantic_graph/pydantic_graph/_utils.py +++ b/pydantic_graph/pydantic_graph/_utils.py @@ -2,10 +2,9 @@ import asyncio import types -from functools import partial -from typing import Any, Callable, TypeVar +from typing import Any, TypeVar -from typing_extensions import ParamSpec, TypeIs, get_args, get_origin +from typing_extensions import TypeIs, get_args, get_origin from typing_inspection import typing_objects from typing_inspection.introspection import is_union_origin @@ -88,15 +87,3 @@ class Unset: def is_set(t_or_unset: T | Unset) -> TypeIs[T]: return t_or_unset is not UNSET - - -_P = ParamSpec('_P') -_R = TypeVar('_R') - - -async def run_in_executor(func: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs) -> _R: - if kwargs: - # noinspection PyTypeChecker - return await asyncio.get_running_loop().run_in_executor(None, partial(func, *args, **kwargs)) - else: - return await asyncio.get_running_loop().run_in_executor(None, func, *args) # type: ignore diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 029c0e900e..f03f9e258a 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -247,7 +247,9 @@ async def iter( A GraphRun that can be async iterated over to drive the graph to completion. """ if infer_name and self.name is None: - self._infer_name(inspect.currentframe()) + # f_back because `asynccontextmanager` adds one frame + if frame := inspect.currentframe(): # pragma: no branch + self._infer_name(frame.f_back) if persistence is None: persistence = SimpleStatePersistence() @@ -725,8 +727,7 @@ async def __anext__(self) -> BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]: return await self.next(self._next_node) def __repr__(self) -> str: - step = -1 # TODO - return f'"} step={step}>' + return f'' @dataclass diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py index c62712981d..9bdbdc1314 100644 --- a/tests/graph/test_graph.py +++ b/tests/graph/test_graph.py @@ -298,6 +298,22 @@ async def run(self, ctx: GraphRunContext) -> End[None]: assert exc_info.value.message == snapshot('Invalid node return type: `int`. Expected `BaseNode` or `End`.') +async def test_iter(): + my_graph = Graph(nodes=(Float2String, String2Length, Double)) + assert my_graph.name is None + assert my_graph._inferred_types == (type(None), int) + node_reprs: list[str] = [] + async with my_graph.iter(Float2String(3.14)) as graph_iter: + assert repr(graph_iter) == snapshot('') + async for node in graph_iter: + node_reprs.append(repr(node)) + # len('3.14') * 2 == 8 + assert graph_iter.result + assert graph_iter.result.output == 8 + + assert node_reprs == snapshot(["String2Length(input_data='3.14')", 'Double(input_data=4)', 'End(data=8)']) + + async def test_next(mock_snapshot_id: object): @dataclass class Foo(BaseNode): diff --git a/tests/graph/test_persistence.py b/tests/graph/test_persistence.py index 6413255fdf..7220d80b8b 100644 --- a/tests/graph/test_persistence.py +++ b/tests/graph/test_persistence.py @@ -163,7 +163,7 @@ def test_one_node(mock_snapshot_id: object): class MyNode(BaseNode[None, None, int]): node_field: int - async def run(self, ctx: GraphRunContext) -> End[int]: + async def run(self, ctx: GraphRunContext) -> End[int]: # pragma: no cover return End(123) g = Graph(nodes=[MyNode]) @@ -196,7 +196,7 @@ async def run(self, ctx: GraphRunContext) -> End[int]: def test_no_generic_arg(mock_snapshot_id: object): @dataclass class NoGenericArgsNode(BaseNode): - async def run(self, ctx: GraphRunContext) -> NoGenericArgsNode: + async def run(self, ctx: GraphRunContext) -> NoGenericArgsNode: # pragma: no cover return NoGenericArgsNode() g = Graph(nodes=[NoGenericArgsNode]) From 3d93cb02f19f5d36986eb09048253d5b70faf035 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sun, 9 Mar 2025 15:05:20 +0000 Subject: [PATCH 13/25] add file persistence --- .../pydantic_ai_examples/question_graph.py | 102 +++++------ pydantic_graph/pydantic_graph/_utils.py | 17 +- pydantic_graph/pydantic_graph/exceptions.py | 20 +++ pydantic_graph/pydantic_graph/graph.py | 12 +- .../pydantic_graph/persistence/__init__.py | 66 ++++--- .../pydantic_graph/persistence/file.py | 162 ++++++++++++++++++ .../pydantic_graph/persistence/in_mem.py | 43 +++-- tests/graph/test_file_persistence.py | 138 +++++++++++++++ 8 files changed, 452 insertions(+), 108 deletions(-) create mode 100644 pydantic_graph/pydantic_graph/persistence/file.py create mode 100644 tests/graph/test_file_persistence.py diff --git a/examples/pydantic_ai_examples/question_graph.py b/examples/pydantic_ai_examples/question_graph.py index 1711e63e9f..2ce0861460 100644 --- a/examples/pydantic_ai_examples/question_graph.py +++ b/examples/pydantic_ai_examples/question_graph.py @@ -9,18 +9,16 @@ from dataclasses import dataclass, field from pathlib import Path -from typing import Annotated import logfire -from devtools import debug +from groq import BaseModel from pydantic_graph import ( BaseNode, - Edge, End, - FullStatePersistence, Graph, GraphRunContext, ) +from pydantic_graph.persistence.file import FileStatePersistence from pydantic_ai import Agent from pydantic_ai.format_as_xml import format_as_xml @@ -40,30 +38,31 @@ class QuestionState: @dataclass -class Ask(BaseNode[QuestionState]): - async def run(self, ctx: GraphRunContext[QuestionState]) -> Answer: +class GenerateQuestion(BaseNode[QuestionState]): + async def run(self, ctx: GraphRunContext[QuestionState]) -> Ask: result = await ask_agent.run( 'Ask a simple question with a single correct answer.', message_history=ctx.state.ask_agent_messages, ) ctx.state.ask_agent_messages += result.all_messages() ctx.state.question = result.data - return Answer() + return Ask(result.data) @dataclass -class Answer(BaseNode[QuestionState]): - answer: str | None = None +class Ask(BaseNode[QuestionState]): + question: str - async def run(self, ctx: GraphRunContext[QuestionState]) -> Evaluate: - assert self.answer is not None - return Evaluate(self.answer) + async def run(self, ctx: GraphRunContext[QuestionState]) -> Answer: + answer = input(f'{self.question}: ') + return Answer(answer) -@dataclass -class EvaluationResult: +class EvaluationResult(BaseModel, use_attribute_docstrings=True): correct: bool + """Whether the answer is correct.""" comment: str + """Comment on the answer, reprimand the user if the answer is wrong.""" evaluate_agent = Agent( @@ -74,13 +73,13 @@ class EvaluationResult: @dataclass -class Evaluate(BaseNode[QuestionState]): +class Answer(BaseNode[QuestionState, None, str]): answer: str async def run( self, ctx: GraphRunContext[QuestionState], - ) -> Congratulate | Reprimand: + ) -> End[str] | Reprimand: assert ctx.state.question is not None result = await evaluate_agent.run( format_as_xml({'question': ctx.state.question, 'answer': self.answer}), @@ -88,87 +87,63 @@ async def run( ) ctx.state.evaluate_agent_messages += result.all_messages() if result.data.correct: - return Congratulate(result.data.comment) + return End(result.data.comment) else: return Reprimand(result.data.comment) -@dataclass -class Congratulate(BaseNode[QuestionState, None, None]): - comment: str - - async def run( - self, ctx: GraphRunContext[QuestionState] - ) -> Annotated[End, Edge(label='success')]: - print(f'Correct answer! {self.comment}') - return End(None) - - @dataclass class Reprimand(BaseNode[QuestionState]): comment: str - async def run(self, ctx: GraphRunContext[QuestionState]) -> Ask: + async def run(self, ctx: GraphRunContext[QuestionState]) -> GenerateQuestion: print(f'Comment: {self.comment}') # > Comment: Vichy is no longer the capital of France. ctx.state.question = None - return Ask() + return GenerateQuestion() question_graph = Graph( - nodes=(Ask, Answer, Evaluate, Congratulate, Reprimand), state_type=QuestionState + nodes=(GenerateQuestion, Ask, Answer, Reprimand), state_type=QuestionState ) async def run_as_continuous(): state = QuestionState() - node = Ask() - persistence = FullStatePersistence() - with logfire.span('run questions graph'): - while True: - node = await question_graph.next(node, persistence=persistence, state=state) - if isinstance(node, End): - debug([e.node for e in persistence.history]) - break - elif isinstance(node, Answer): - assert state.question - node.answer = input(f'{state.question} ') - # otherwise just continue + node = GenerateQuestion() + end = await question_graph.run(node, state=state) + print('END:', end.output) async def run_as_cli(answer: str | None): - history_file = Path('question_graph_history.json') - persistence = FullStatePersistence() + persistence = FileStatePersistence(Path('question_graph.json')) question_graph.set_persistence_types(persistence) - if history_file.exists(): - persistence.load_json(history_file.read_bytes()) - - if persistence.history: - last = persistence.history[-1] - assert last.kind == 'node', 'expected last step to be a node' - state = last.state - assert answer is not None, 'answer is required to continue from history' + if snapshot := await persistence.retrieve_next(): + state = snapshot.state + assert answer is not None, ( + 'answer required, usage "uv run -m pydantic_ai_examples.question_graph cli "' + ) node = Answer(answer) else: state = QuestionState() - node = Ask() - debug(state, node) + node = GenerateQuestion() + # debug(state, node) with logfire.span('run questions graph'): while True: node = await question_graph.next(node, persistence=persistence, state=state) if isinstance(node, End): - debug([e.node for e in persistence.history]) + print('END:', node.data) + history = await persistence.load() + print('history:', '\n'.join(str(e.node) for e in history), sep='\n') print('Finished!') break - elif isinstance(node, Answer): - print(state.question) + elif isinstance(node, Ask): + print(node.question) break # otherwise just continue - history_file.write_bytes(persistence.dump_json(indent=2)) - if __name__ == '__main__': import asyncio @@ -190,7 +165,12 @@ async def run_as_cli(answer: str | None): sys.exit(1) if sub_command == 'mermaid': - print(question_graph.mermaid_code(start_node=Ask)) + print(question_graph.mermaid_code(start_node=GenerateQuestion)) + print( + question_graph.mermaid_save( + 'question_graph.jpg', start_node=GenerateQuestion + ) + ) elif sub_command == 'continuous': asyncio.run(run_as_continuous()) else: diff --git a/pydantic_graph/pydantic_graph/_utils.py b/pydantic_graph/pydantic_graph/_utils.py index f97b14fe03..bc49ed6fcb 100644 --- a/pydantic_graph/pydantic_graph/_utils.py +++ b/pydantic_graph/pydantic_graph/_utils.py @@ -2,9 +2,10 @@ import asyncio import types -from typing import Any, TypeVar +from functools import partial +from typing import Any, Callable, TypeVar -from typing_extensions import TypeIs, get_args, get_origin +from typing_extensions import ParamSpec, TypeIs, get_args, get_origin from typing_inspection import typing_objects from typing_inspection.introspection import is_union_origin @@ -87,3 +88,15 @@ class Unset: def is_set(t_or_unset: T | Unset) -> TypeIs[T]: return t_or_unset is not UNSET + + +_P = ParamSpec('_P') +_R = TypeVar('_R') + + +async def run_in_executor(func: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs) -> _R: + if kwargs: + # noinspection PyTypeChecker + return await asyncio.get_running_loop().run_in_executor(None, partial(func, *args, **kwargs)) + else: + return await asyncio.get_running_loop().run_in_executor(None, func, *args) # type: ignore diff --git a/pydantic_graph/pydantic_graph/exceptions.py b/pydantic_graph/pydantic_graph/exceptions.py index 5288402c36..1bfa4b5930 100644 --- a/pydantic_graph/pydantic_graph/exceptions.py +++ b/pydantic_graph/pydantic_graph/exceptions.py @@ -1,3 +1,9 @@ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .persistence import SnapshotStatus + + class GraphSetupError(TypeError): """Error caused by an incorrectly configured graph.""" @@ -18,3 +24,17 @@ class GraphRuntimeError(RuntimeError): def __init__(self, message: str): self.message = message super().__init__(message) + + +class GraphNodeStatusError(GraphRuntimeError): + """Error caused by trying to run a node that has status other than `'created'` or `'pending'`.""" + + def __init__(self, actual_status: 'SnapshotStatus'): + self.actual_status = actual_status + super().__init__(f"Snapshot status is {actual_status!r}, not 'created' or 'pending'.") + + @classmethod + def check(cls, status: 'SnapshotStatus') -> None: + """Check if the status is valid.""" + if status not in {'created', 'pending'}: + raise cls(status) diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index f03f9e258a..df2e48ab2d 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -265,7 +265,6 @@ async def iter( graph=self, start_node=start_node, persistence=persistence, state=state, deps=deps ) - # @deprecated('`graph.next` is deprecated, use `graph.iter` ... `run.next` instead') async def next( self: Graph[StateT, DepsT, T], node: BaseNode[StateT, DepsT, T], @@ -291,6 +290,7 @@ async def next( if infer_name and self.name is None: self._infer_name(inspect.currentframe()) + self.set_persistence_types(persistence) run = GraphRun[StateT, DepsT, T]( graph=self, start_node=node, @@ -310,7 +310,11 @@ async def next_from_persistence( if infer_name and self.name is None: self._infer_name(inspect.currentframe()) - snapshot = await persistence.restore_node_snapshot() + snapshot = await persistence.retrieve_next() + if snapshot is None: + raise exceptions.GraphRuntimeError('Unable to restore snapshot from state persistence.') + snapshot.node.set_snapshot_id(snapshot.id) + run = GraphRun[StateT, DepsT, T]( graph=self, start_node=snapshot.node, @@ -682,9 +686,7 @@ async def main(): else: node_snapshot_id = node.get_snapshot_id() if node_snapshot_id != self._snapshot_id: - existing_snapshot = await self.persistence.get_node_snapshot(node_snapshot_id, status='created') - if not existing_snapshot: - await self.persistence.snapshot_node(self.state, node) + await self.persistence.snapshot_node_if_new(node_snapshot_id, self.state, node) self._snapshot_id = node_snapshot_id if not isinstance(node, BaseNode): diff --git a/pydantic_graph/pydantic_graph/persistence/__init__.py b/pydantic_graph/pydantic_graph/persistence/__init__.py index 1c207cd266..63fd000dc6 100644 --- a/pydantic_graph/pydantic_graph/persistence/__init__.py +++ b/pydantic_graph/pydantic_graph/persistence/__init__.py @@ -10,7 +10,6 @@ import pydantic from typing_extensions import TypeVar -from .. import exceptions from ..nodes import BaseNode, End, RunEndT from . import _utils @@ -29,6 +28,15 @@ UNSET_SNAPSHOT_ID = '__unset__' SnapshotStatus = Literal['created', 'pending', 'running', 'success', 'error'] +"""The status of a snapshot. + +- `'created'`: The snapshot has been created but not yet run. +- `'pending'`: The snapshot has been retrieved with + [`retrieve_next`][pydantic_graph.persistence.StatePersistence.retrieve_next] but not yet run. +- `'running'`: The snapshot is currently running. +- `'success'`: The snapshot has been run successfully. +- `'error'`: The snapshot has been run but an error occurred. +""" @dataclass @@ -100,13 +108,22 @@ class StatePersistence(ABC, Generic[StateT, RunEndT]): async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> None: """Snapshot the state of a graph, when the next step is to run a node. - Note: although the node - Args: state: The state of the graph. - next_node: The next node to run or end if the graph has ended + next_node: The next node to run. + """ + raise NotImplementedError - Returns: an async context manager that wraps the run of the node. + @abstractmethod + async def snapshot_node_if_new( + self, snapshot_id: str, state: StateT, next_node: BaseNode[StateT, Any, RunEndT] + ) -> None: + """Snapshot the state of a graph, if the snapshot ID doesn't already exist in persistence. + + Args: + snapshot_id: The ID of the snapshot to check. + state: The state of the graph. + next_node: The next node to run. """ raise NotImplementedError @@ -122,7 +139,17 @@ async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None: @abstractmethod def record_run(self, snapshot_id: str) -> AbstractAsyncContextManager[None]: - """Record the run of the node. + """Record the run of the node, or error if the node is already running. + + Args: + snapshot_id: The ID of the snapshot to record. + + Raises: + GraphNodeRunningError: if the node status it not `'created'` or `'pending'`. + LookupError: if the snapshot ID is not found in persistence. + + Returns: + An async context manager that records the run of the node. In particular this should set: @@ -134,25 +161,18 @@ def record_run(self, snapshot_id: str) -> AbstractAsyncContextManager[None]: raise NotImplementedError @abstractmethod - async def restore(self) -> Snapshot[StateT, RunEndT] | None: - """Retrieve the latest snapshot. + async def retrieve_next(self) -> NodeSnapshot[StateT, RunEndT] | None: + """Retrieve a node snapshot with status `'created`' and set its status to `'pending'`. - Returns: - The most recent [`Snapshot`][pydantic_graph.persistence.Snapshot] of the run. + Returns: The snapshot, or `None` if no snapshot with status `'created`' exists. """ raise NotImplementedError @abstractmethod - async def get_node_snapshot( - self, snapshot_id: str, status: SnapshotStatus | None = None - ) -> Snapshot[StateT, RunEndT] | None: - """Get a snapshot by ID. - - Args: - snapshot_id: The ID of the snapshot to get. - status: The status of the snapshot to get, or `None` to get any status. + async def load(self) -> list[Snapshot[StateT, RunEndT]]: + """Load the entire history of snapshots. - Returns: The snapshot with the given ID and status, or `None` if no snapshot with that ID exists. + Returns: The list of snapshots. """ raise NotImplementedError @@ -167,14 +187,6 @@ def set_types(self, get_types: Callable[[], tuple[type[StateT], type[RunEndT]]]) """ pass - async def restore_node_snapshot(self) -> NodeSnapshot[StateT, RunEndT]: - snapshot = await self.restore() - if snapshot is None: - raise exceptions.GraphRuntimeError('Unable to restore snapshot from state persistence.') - elif not isinstance(snapshot, NodeSnapshot): - raise exceptions.GraphRuntimeError('Snapshot returned from persistence indicates the graph has ended.') - return snapshot - @contextmanager def set_nodes_type_context(nodes: Sequence[type[BaseNode[Any, Any, Any]]]) -> Iterator[None]: # noqa: D103 diff --git a/pydantic_graph/pydantic_graph/persistence/file.py b/pydantic_graph/pydantic_graph/persistence/file.py new file mode 100644 index 0000000000..30012eb55d --- /dev/null +++ b/pydantic_graph/pydantic_graph/persistence/file.py @@ -0,0 +1,162 @@ +from __future__ import annotations as _annotations + +import asyncio +import secrets +from collections.abc import AsyncIterator +from contextlib import AsyncExitStack, asynccontextmanager +from dataclasses import dataclass, field +from pathlib import Path +from time import perf_counter +from typing import Any, Callable + +import pydantic +from typing_extensions import TypeVar + +from .. import _utils as _graph_utils, exceptions +from ..nodes import BaseNode, End +from . import ( + EndSnapshot, + NodeSnapshot, + Snapshot, + SnapshotStatus, + StatePersistence, + _utils, + build_snapshot_list_type_adapter, +) + +StateT = TypeVar('StateT', default=Any) +RunEndT = TypeVar('RunEndT', default=Any) + + +@dataclass +class FileStatePersistence(StatePersistence[StateT, RunEndT]): + """State persistence that just hold the latest snapshot.""" + + json_file: Path + _snapshots_type_adapter: pydantic.TypeAdapter[list[Snapshot[StateT, RunEndT]]] | None = field( + default=None, init=False, repr=False + ) + + async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> None: + await self._append_save(NodeSnapshot(state=state, node=next_node)) + + async def snapshot_node_if_new( + self, snapshot_id: str, state: StateT, next_node: BaseNode[StateT, Any, RunEndT] + ) -> None: + async with self._lock(): + snapshots = await self.load() + if not any(s.id == snapshot_id for s in snapshots): + await self._append_save(NodeSnapshot(state=state, node=next_node), lock=False) + + async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None: + await self._append_save(EndSnapshot(state=state, result=end)) + + @asynccontextmanager + async def record_run(self, snapshot_id: str) -> AsyncIterator[None]: + async with self._lock(): + snapshots = await self.load() + try: + snapshot = next(s for s in snapshots if s.id == snapshot_id) + except StopIteration as e: + raise LookupError(f'No snapshot found with id={snapshot_id}') from e + + assert isinstance(snapshot, NodeSnapshot), 'Only NodeSnapshot can be recorded' + exceptions.GraphNodeStatusError.check(snapshot.status) + snapshot.status = 'running' + snapshot.start_ts = _utils.now_utc() + await self._save(snapshots) + + start = perf_counter() + try: + yield + except Exception: + duration = perf_counter() - start + async with self._lock(): + await _graph_utils.run_in_executor(self._after_run_sync, snapshot_id, duration, 'error') + raise + else: + snapshot.duration = perf_counter() - start + async with self._lock(): + await _graph_utils.run_in_executor(self._after_run_sync, snapshot_id, snapshot.duration, 'success') + + async def retrieve_next(self) -> NodeSnapshot[StateT, RunEndT] | None: + async with self._lock(): + snapshots = await self.load() + if snapshot := next((s for s in snapshots if isinstance(s, NodeSnapshot) and s.status == 'created'), None): + snapshot.status = 'pending' + await self._save(snapshots) + return snapshot + + def set_types(self, get_types: Callable[[], tuple[type[StateT], type[RunEndT]]]) -> None: + if self._snapshots_type_adapter is None: + state_t, run_end_t = get_types() + self._snapshots_type_adapter = build_snapshot_list_type_adapter(state_t, run_end_t) + + async def load(self) -> list[Snapshot[StateT, RunEndT]]: + return await _graph_utils.run_in_executor(self._load_sync) + + def _load_sync(self) -> list[Snapshot[StateT, RunEndT]]: + assert self._snapshots_type_adapter is not None, 'snapshots type adapter must be set' + try: + content = self.json_file.read_bytes() + except FileNotFoundError: + return [] + else: + return self._snapshots_type_adapter.validate_json(content) + + def _after_run_sync(self, snapshot_id: str, duration: float, status: SnapshotStatus) -> None: + snapshots = self._load_sync() + snapshot = next(s for s in snapshots if s.id == snapshot_id) + assert isinstance(snapshot, NodeSnapshot), 'Only NodeSnapshot can be recorded' + snapshot.duration = duration + snapshot.status = status + self._save_sync(snapshots) + + async def _save(self, snapshots: list[Snapshot[StateT, RunEndT]]) -> None: + await _graph_utils.run_in_executor(self._save_sync, snapshots) + + def _save_sync(self, snapshots: list[Snapshot[StateT, RunEndT]]) -> None: + assert self._snapshots_type_adapter is not None, 'snapshots type adapter must be set' + self.json_file.write_bytes(self._snapshots_type_adapter.dump_json(snapshots, indent=2)) + + async def _append_save(self, snapshot: Snapshot[StateT, RunEndT], *, lock: bool = True) -> None: + assert self._snapshots_type_adapter is not None, 'snapshots type adapter must be set' + async with AsyncExitStack() as stack: + if lock: + await stack.enter_async_context(self._lock()) + snapshots = await self.load() + snapshots.append(snapshot) + await self._save(snapshots) + + @asynccontextmanager + async def _lock(self, *, timeout: float = 1.0) -> AsyncIterator[None]: + """Lock a file by checking and writing a `.pydantic-graph-persistence-lock` to it. + + Args: + timeout: how long to wait for the lock + + Returns: an async context manager that holds the lock + """ + lock_file = self.json_file.parent / f'{self.json_file.name}.pydantic-graph-persistence-lock' + lock_id = secrets.token_urlsafe().encode() + await asyncio.wait_for(_get_lock(lock_file, lock_id), timeout=timeout) + try: + yield + finally: + await _graph_utils.run_in_executor(lock_file.unlink, missing_ok=True) + + +async def _get_lock(lock_file: Path, lock_id: bytes): + # TODO replace with inline code and `asyncio.timeout` when we drop 3.9 + while not await _graph_utils.run_in_executor(_file_append_check, lock_file, lock_id): + await asyncio.sleep(0.01) + + +def _file_append_check(file: Path, content: bytes) -> bool: + if file.exists(): + return False + + with file.open(mode='ab') as f: + f.write(content + b'\n') + + return file.read_bytes().startswith(content) diff --git a/pydantic_graph/pydantic_graph/persistence/in_mem.py b/pydantic_graph/pydantic_graph/persistence/in_mem.py index 18052f44d3..ddd2f5daeb 100644 --- a/pydantic_graph/pydantic_graph/persistence/in_mem.py +++ b/pydantic_graph/pydantic_graph/persistence/in_mem.py @@ -15,6 +15,7 @@ import pydantic from typing_extensions import TypeVar +from .. import exceptions from ..nodes import BaseNode, End from . import ( EndSnapshot, @@ -51,6 +52,14 @@ async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, Ru node=next_node.deep_copy() if self.deep_copy else next_node, ) + async def snapshot_node_if_new( + self, snapshot_id: str, state: StateT, next_node: BaseNode[StateT, Any, RunEndT] + ) -> None: + if self.last_snapshot and self.last_snapshot.id == snapshot_id: + return + else: + await self.snapshot_node(state, next_node) + async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None: self.last_snapshot = EndSnapshot( state=self._prep_state(state), @@ -64,6 +73,7 @@ async def record_run(self, snapshot_id: str) -> AsyncIterator[None]: assert snapshot_id == self.last_snapshot.id, ( f'snapshot_id must match the last snapshot ID: {snapshot_id!r} != {self.last_snapshot.id!r}' ) + exceptions.GraphNodeStatusError.check(self.last_snapshot.status) self.last_snapshot.status = 'running' self.last_snapshot.start_ts = _utils.now_utc() @@ -78,17 +88,13 @@ async def record_run(self, snapshot_id: str) -> AsyncIterator[None]: self.last_snapshot.duration = perf_counter() - start self.last_snapshot.status = 'success' - async def restore(self) -> Snapshot[StateT, RunEndT] | None: - return self.last_snapshot + async def retrieve_next(self) -> NodeSnapshot[StateT, RunEndT] | None: + if isinstance(self.last_snapshot, NodeSnapshot) and self.last_snapshot.status == 'created': + self.last_snapshot.status = 'pending' + return self.last_snapshot - async def get_node_snapshot( - self, snapshot_id: str, status: SnapshotStatus | None = None - ) -> NodeSnapshot[StateT, RunEndT] | None: - if isinstance(self.last_snapshot, NodeSnapshot) and self.last_snapshot.id == snapshot_id: - if status and self.last_snapshot.status != status: - return None - else: - return self.last_snapshot + async def load(self) -> list[Snapshot[StateT, RunEndT]]: + raise NotImplementedError('load is not supported for SimpleStatePersistence') def _prep_state(self, state: StateT) -> StateT: """Prepare state for snapshot, uses [`copy.deepcopy`][copy.deepcopy] by default.""" @@ -121,6 +127,12 @@ async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, Ru ) self.history.append(snapshot) + async def snapshot_node_if_new( + self, snapshot_id: str, state: StateT, next_node: BaseNode[StateT, Any, RunEndT] + ) -> None: + if not any(s.id == snapshot_id for s in self.history): + await self.snapshot_node(state, next_node) + async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None: snapshot = EndSnapshot( state=self._prep_state(state), @@ -136,6 +148,7 @@ async def record_run(self, snapshot_id: str) -> AsyncIterator[None]: raise LookupError(f'No snapshot found with id={snapshot_id}') from e assert isinstance(snapshot, NodeSnapshot), 'Only NodeSnapshot can be recorded' + exceptions.GraphNodeStatusError.check(snapshot.status) snapshot.status = 'running' snapshot.start_ts = _utils.now_utc() start = perf_counter() @@ -149,9 +162,10 @@ async def record_run(self, snapshot_id: str) -> AsyncIterator[None]: snapshot.duration = perf_counter() - start snapshot.status = 'success' - async def restore(self) -> Snapshot[StateT, RunEndT] | None: - if self.history: - return self.history[-1] + async def retrieve_next(self) -> NodeSnapshot[StateT, RunEndT] | None: + if snapshot := next((s for s in self.history if isinstance(s, NodeSnapshot) and s.status == 'created'), None): + snapshot.status = 'pending' + return snapshot async def get_node_snapshot( self, snapshot_id: str, status: SnapshotStatus | None = None @@ -164,6 +178,9 @@ async def get_node_snapshot( ): return snapshot + async def load(self) -> list[Snapshot[StateT, RunEndT]]: + return self.history + def set_types(self, get_types: Callable[[], tuple[type[StateT], type[RunEndT]]]) -> None: if self._snapshots_type_adapter is None: state_t, run_end_t = get_types() diff --git a/tests/graph/test_file_persistence.py b/tests/graph/test_file_persistence.py new file mode 100644 index 0000000000..27f131267e --- /dev/null +++ b/tests/graph/test_file_persistence.py @@ -0,0 +1,138 @@ +from __future__ import annotations as _annotations + +from dataclasses import dataclass +from datetime import timezone +from pathlib import Path +from typing import Union + +import pytest +from inline_snapshot import snapshot + +from pydantic_graph import ( + BaseNode, + End, + EndSnapshot, + Graph, + GraphRunContext, + NodeSnapshot, +) +from pydantic_graph.persistence.file import FileStatePersistence + +from ..conftest import IsFloat, IsNow + +pytestmark = pytest.mark.anyio + + +@dataclass +class Float2String(BaseNode): + input_data: float + + async def run(self, ctx: GraphRunContext) -> String2Length: + return String2Length(str(self.input_data)) + + +@dataclass +class String2Length(BaseNode): + input_data: str + + async def run(self, ctx: GraphRunContext) -> Double: + return Double(len(self.input_data)) + + +@dataclass +class Double(BaseNode[None, None, int]): + input_data: int + + async def run(self, ctx: GraphRunContext) -> Union[String2Length, End[int]]: # noqa: UP007 + if self.input_data == 7: + return String2Length('x' * 21) + else: + return End(self.input_data * 2) + + +async def test_run(tmp_path: Path, mock_snapshot_id: object): + my_graph = Graph(nodes=(Float2String, String2Length, Double)) + p = tmp_path / 'test_graph.json' + persistence = FileStatePersistence(p) + result = await my_graph.run(Float2String(3.14), persistence=persistence) + # len('3.14') * 2 == 8 + assert result.output == 8 + assert my_graph.name == 'my_graph' + assert await persistence.load() == snapshot( + [ + NodeSnapshot( + state=None, + node=Float2String(input_data=3.14), + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(), + status='success', + id='Float2String:1', + ), + NodeSnapshot( + state=None, + node=String2Length(input_data='3.14'), + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(), + status='success', + id='String2Length:2', + ), + NodeSnapshot( + state=None, + node=Double(input_data=4), + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(), + status='success', + id='Double:3', + ), + EndSnapshot(state=None, result=End(data=8), ts=IsNow(tz=timezone.utc), id='end:4'), + ] + ) + + +async def test_next_from_persistence(tmp_path: Path, mock_snapshot_id: object): + my_graph = Graph(nodes=(Float2String, String2Length, Double)) + p = tmp_path / 'test_graph.json' + persistence = FileStatePersistence(p) + + node = await my_graph.next(Float2String(3.14), persistence=persistence) + assert node == snapshot(String2Length(input_data='3.14')) + assert node.get_snapshot_id() == snapshot('String2Length:2') + assert my_graph.name == 'my_graph' + + node = await my_graph.next_from_persistence(persistence) + assert node == snapshot(Double(input_data=4)) + assert node.get_snapshot_id() == snapshot('Double:3') + + node = await my_graph.next_from_persistence(persistence) + assert node == snapshot(End(data=8)) + assert node.get_snapshot_id() == snapshot('end:4') + + assert await persistence.load() == snapshot( + [ + NodeSnapshot( + state=None, + node=Float2String(input_data=3.14), + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(), + status='success', + id='Float2String:1', + ), + NodeSnapshot( + state=None, + node=String2Length(input_data='3.14'), + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(), + status='success', + id='String2Length:2', + ), + NodeSnapshot( + state=None, + node=Double(input_data=4), + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(), + status='success', + id='Double:3', + ), + EndSnapshot(state=None, result=End(data=8), ts=IsNow(tz=timezone.utc), id='end:4'), + ] + ) From 95819794b7b5e4b8e455d36cbf4af5affdb1c19d Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sun, 9 Mar 2025 19:21:00 +0000 Subject: [PATCH 14/25] improve coverage --- pydantic_graph/pydantic_graph/exceptions.py | 2 +- pydantic_graph/pydantic_graph/graph.py | 2 +- .../pydantic_graph/persistence/__init__.py | 6 --- tests/graph/test_file_persistence.py | 51 ++++++++++++++++++ tests/graph/test_graph.py | 25 +++++++++ tests/graph/test_persistence.py | 52 ++++++++++++++++++- 6 files changed, 129 insertions(+), 9 deletions(-) diff --git a/pydantic_graph/pydantic_graph/exceptions.py b/pydantic_graph/pydantic_graph/exceptions.py index 1bfa4b5930..9eb47b6f35 100644 --- a/pydantic_graph/pydantic_graph/exceptions.py +++ b/pydantic_graph/pydantic_graph/exceptions.py @@ -31,7 +31,7 @@ class GraphNodeStatusError(GraphRuntimeError): def __init__(self, actual_status: 'SnapshotStatus'): self.actual_status = actual_status - super().__init__(f"Snapshot status is {actual_status!r}, not 'created' or 'pending'.") + super().__init__(f"Incorrect snapshot status {actual_status!r}, must be 'created' or 'pending'.") @classmethod def check(cls, status: 'SnapshotStatus') -> None: diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index df2e48ab2d..f9fff4929d 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -692,7 +692,7 @@ async def main(): if not isinstance(node, BaseNode): # While technically this is not compatible with the documented method signature, it's an easy mistake to # make, and we should eagerly provide a more helpful error message than you'd get otherwise. - raise exceptions.GraphRuntimeError(f'`next` must be called with a `BaseNode` instance: {node!r}.') + raise TypeError(f'`next` must be called with a `BaseNode` instance, got {node!r}.') node_id = node.get_node_id() if node_id not in self.graph.node_defs: diff --git a/pydantic_graph/pydantic_graph/persistence/__init__.py b/pydantic_graph/pydantic_graph/persistence/__init__.py index 63fd000dc6..e044effceb 100644 --- a/pydantic_graph/pydantic_graph/persistence/__init__.py +++ b/pydantic_graph/pydantic_graph/persistence/__init__.py @@ -201,9 +201,3 @@ def build_snapshot_list_type_adapter( state_t: type[StateT], run_end_t: type[RunEndT] ) -> pydantic.TypeAdapter[list[Snapshot[StateT, RunEndT]]]: return pydantic.TypeAdapter(list[Annotated[Snapshot[state_t, run_end_t], pydantic.Discriminator('kind')]]) - - -def build_snapshot_single_type_adapter( - state_t: type[StateT], run_end_t: type[RunEndT] -) -> pydantic.TypeAdapter[Snapshot[StateT, RunEndT]]: - return pydantic.TypeAdapter(Annotated[Snapshot[state_t, run_end_t], pydantic.Discriminator('kind')]) diff --git a/tests/graph/test_file_persistence.py b/tests/graph/test_file_persistence.py index 27f131267e..61a157f2b7 100644 --- a/tests/graph/test_file_persistence.py +++ b/tests/graph/test_file_persistence.py @@ -136,3 +136,54 @@ async def test_next_from_persistence(tmp_path: Path, mock_snapshot_id: object): EndSnapshot(state=None, result=End(data=8), ts=IsNow(tz=timezone.utc), id='end:4'), ] ) + + +async def test_node_error(tmp_path: Path, mock_snapshot_id: object): + @dataclass + class Foo(BaseNode): + async def run(self, ctx: GraphRunContext) -> Bar: + return Bar() + + @dataclass + class Bar(BaseNode[None, None, None]): + async def run(self, ctx: GraphRunContext) -> End[None]: + raise RuntimeError('test error') + + g = Graph(nodes=(Foo, Bar)) + p = tmp_path / 'test_graph.json' + persistence = FileStatePersistence(p) + with pytest.raises(RuntimeError, match='test error'): + await g.run(Foo(), persistence=persistence) + + assert await persistence.load() == snapshot( + [ + NodeSnapshot( + state=None, + node=Foo(), + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(), + status='success', + id='Foo:1', + ), + NodeSnapshot( + state=None, + node=Bar(), + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(), + status='error', + id='Bar:2', + ), + ] + ) + + +async def test_lock_timeout(tmp_path: Path): + p = tmp_path / 'test_graph.json' + persistence = FileStatePersistence(p) + async with persistence._lock(): # type: ignore[reportPrivateUsage] + pass + + async with persistence._lock(): # type: ignore[reportPrivateUsage] + with pytest.raises(TimeoutError): + async with persistence._lock(timeout=0.1): # type: ignore[reportPrivateUsage] + pass diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py index 9bdbdc1314..1714d2f35f 100644 --- a/tests/graph/test_graph.py +++ b/tests/graph/test_graph.py @@ -20,6 +20,7 @@ GraphRuntimeError, GraphSetupError, NodeSnapshot, + SimpleStatePersistence, ) from ..conftest import IsFloat, IsNow @@ -372,6 +373,30 @@ async def run(self, ctx: GraphRunContext) -> Foo: ) +async def test_next_error(mock_snapshot_id: object): + @dataclass + class Foo(BaseNode): + async def run(self, ctx: GraphRunContext) -> Bar: + return Bar() + + @dataclass + class Bar(BaseNode[None, None, None]): + async def run(self, ctx: GraphRunContext) -> End[None]: + return End(None) + + g = Graph(nodes=(Foo, Bar)) + sp = SimpleStatePersistence() + n = await g.next(Foo(), sp) + assert n == snapshot(Bar()) + + assert isinstance(n, BaseNode) + n = await g.next(n, sp) + assert n == snapshot(End(None)) + + with pytest.raises(TypeError, match=r'`next` must be called with a `BaseNode` instance, got End\(data=None\).'): + await g.next(n, sp) # pyright: ignore[reportArgumentType] + + async def test_deps(mock_snapshot_id: object): @dataclass class Deps: diff --git a/tests/graph/test_persistence.py b/tests/graph/test_persistence.py index 7220d80b8b..e16419a552 100644 --- a/tests/graph/test_persistence.py +++ b/tests/graph/test_persistence.py @@ -18,6 +18,7 @@ GraphRunContext, NodeSnapshot, ) +from pydantic_graph.exceptions import GraphNodeStatusError, GraphRuntimeError from ..conftest import IsFloat, IsNow @@ -58,7 +59,7 @@ async def test_dump_load_state(graph: Graph[MyState, None, int], mock_snapshot_i result = await graph.run(Foo(), state=MyState(1, ''), persistence=sp) assert result.output == snapshot(4) assert result.state == snapshot(MyState(x=2, y='y')) - assert sp.history == snapshot( + assert await sp.load() == snapshot( [ NodeSnapshot( state=MyState(x=1, y=''), @@ -270,3 +271,52 @@ async def run(self, ctx: GraphRunContext) -> End[int]: ), ] ) + + +async def test_rerun_node(mock_snapshot_id: object): + @dataclass + class Foo(BaseNode[None, None, int]): + async def run(self, ctx: GraphRunContext) -> End[int]: + return End(123) + + graph = Graph(nodes=[Foo]) + + sp = FullStatePersistence() + node = Foo() + end = await graph.next(node, sp) + assert end == snapshot(End(123)) + + msg = "Incorrect snapshot status 'success', must be 'created' or 'pending'." + with pytest.raises(GraphNodeStatusError, match=msg): + await graph.next(node, sp) + + +async def test_next_from_persistence(mock_snapshot_id: object): + @dataclass + class Foo(BaseNode): + async def run(self, ctx: GraphRunContext) -> Spam: + return Spam() + + @dataclass + class Spam(BaseNode[None, None, int]): + async def run(self, ctx: GraphRunContext) -> End[int]: + return End(123) + + g1 = Graph(nodes=[Foo, Spam]) + + sp = FullStatePersistence() + node = Foo() + assert g1.name is None + node = await g1.next(node, sp) + assert g1.name == 'g1' + assert node == snapshot(Spam()) + + end = await g1.next_from_persistence(sp) + assert end == snapshot(End(123)) + + g2 = Graph(nodes=[Foo, Spam]) + sp = FullStatePersistence() + assert g2.name is None + with pytest.raises(GraphRuntimeError, match='Unable to restore snapshot from state persistence.'): + await g2.next_from_persistence(sp) + assert g2.name == 'g2' From 88723e33ee56b2e67061072fa1ba7e4a8b4ef89a Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sun, 9 Mar 2025 19:23:31 +0000 Subject: [PATCH 15/25] fix for 3.9 and 3.10 --- tests/graph/test_file_persistence.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/graph/test_file_persistence.py b/tests/graph/test_file_persistence.py index 61a157f2b7..1e2c7886f9 100644 --- a/tests/graph/test_file_persistence.py +++ b/tests/graph/test_file_persistence.py @@ -1,5 +1,6 @@ from __future__ import annotations as _annotations +from asyncio.exceptions import TimeoutError from dataclasses import dataclass from datetime import timezone from pathlib import Path From 2a9f90e03180645a1c047112bed36eb9769a4571 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sun, 9 Mar 2025 19:50:24 +0000 Subject: [PATCH 16/25] improve coverage --- .../pydantic_graph/persistence/file.py | 2 +- .../pydantic_graph/persistence/in_mem.py | 43 +++---------------- tests/graph/test_file_persistence.py | 14 +++++- tests/graph/test_persistence.py | 26 ++++++++--- 4 files changed, 41 insertions(+), 44 deletions(-) diff --git a/pydantic_graph/pydantic_graph/persistence/file.py b/pydantic_graph/pydantic_graph/persistence/file.py index 30012eb55d..7894dfce30 100644 --- a/pydantic_graph/pydantic_graph/persistence/file.py +++ b/pydantic_graph/pydantic_graph/persistence/file.py @@ -58,7 +58,7 @@ async def record_run(self, snapshot_id: str) -> AsyncIterator[None]: try: snapshot = next(s for s in snapshots if s.id == snapshot_id) except StopIteration as e: - raise LookupError(f'No snapshot found with id={snapshot_id}') from e + raise LookupError(f'No snapshot found with id={snapshot_id!r}') from e assert isinstance(snapshot, NodeSnapshot), 'Only NodeSnapshot can be recorded' exceptions.GraphNodeStatusError.check(snapshot.status) diff --git a/pydantic_graph/pydantic_graph/persistence/in_mem.py b/pydantic_graph/pydantic_graph/persistence/in_mem.py index ddd2f5daeb..1854341fad 100644 --- a/pydantic_graph/pydantic_graph/persistence/in_mem.py +++ b/pydantic_graph/pydantic_graph/persistence/in_mem.py @@ -21,7 +21,6 @@ EndSnapshot, NodeSnapshot, Snapshot, - SnapshotStatus, StatePersistence, _utils, build_snapshot_list_type_adapter, @@ -38,19 +37,11 @@ class SimpleStatePersistence(StatePersistence[StateT, RunEndT]): If no state persistence implementation is provided when running a graph, this is used by default. """ - deep_copy: bool = False - """Whether to deep copy the state and nodes when storing them. - - Defaults to `False` so you can use nodes that don't support deep copying. - """ last_snapshot: Snapshot[StateT, RunEndT] | None = None """The last snapshot.""" async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> None: - self.last_snapshot = NodeSnapshot( - state=self._prep_state(state), - node=next_node.deep_copy() if self.deep_copy else next_node, - ) + self.last_snapshot = NodeSnapshot(state=state, node=next_node) async def snapshot_node_if_new( self, snapshot_id: str, state: StateT, next_node: BaseNode[StateT, Any, RunEndT] @@ -61,18 +52,14 @@ async def snapshot_node_if_new( await self.snapshot_node(state, next_node) async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None: - self.last_snapshot = EndSnapshot( - state=self._prep_state(state), - result=end.deep_copy_data() if self.deep_copy else end, - ) + self.last_snapshot = EndSnapshot(state=state, result=end) @asynccontextmanager async def record_run(self, snapshot_id: str) -> AsyncIterator[None]: - assert self.last_snapshot is not None, 'No snapshot to record' + if self.last_snapshot is None or snapshot_id != self.last_snapshot.id: + raise LookupError(f'No snapshot found with id={snapshot_id!r}') + assert isinstance(self.last_snapshot, NodeSnapshot), 'Only NodeSnapshot can be recorded' - assert snapshot_id == self.last_snapshot.id, ( - f'snapshot_id must match the last snapshot ID: {snapshot_id!r} != {self.last_snapshot.id!r}' - ) exceptions.GraphNodeStatusError.check(self.last_snapshot.status) self.last_snapshot.status = 'running' self.last_snapshot.start_ts = _utils.now_utc() @@ -96,13 +83,6 @@ async def retrieve_next(self) -> NodeSnapshot[StateT, RunEndT] | None: async def load(self) -> list[Snapshot[StateT, RunEndT]]: raise NotImplementedError('load is not supported for SimpleStatePersistence') - def _prep_state(self, state: StateT) -> StateT: - """Prepare state for snapshot, uses [`copy.deepcopy`][copy.deepcopy] by default.""" - if not self.deep_copy or state is None: - return state - else: - return copy.deepcopy(state) - @dataclass class FullStatePersistence(StatePersistence[StateT, RunEndT]): @@ -145,7 +125,7 @@ async def record_run(self, snapshot_id: str) -> AsyncIterator[None]: try: snapshot = next(s for s in self.history if s.id == snapshot_id) except StopIteration as e: - raise LookupError(f'No snapshot found with id={snapshot_id}') from e + raise LookupError(f'No snapshot found with id={snapshot_id!r}') from e assert isinstance(snapshot, NodeSnapshot), 'Only NodeSnapshot can be recorded' exceptions.GraphNodeStatusError.check(snapshot.status) @@ -167,17 +147,6 @@ async def retrieve_next(self) -> NodeSnapshot[StateT, RunEndT] | None: snapshot.status = 'pending' return snapshot - async def get_node_snapshot( - self, snapshot_id: str, status: SnapshotStatus | None = None - ) -> Snapshot[StateT, RunEndT] | None: - for snapshot in self.history: - if ( - isinstance(snapshot, NodeSnapshot) - and snapshot.id == snapshot_id - and (status is None or snapshot.status == status) - ): - return snapshot - async def load(self) -> list[Snapshot[StateT, RunEndT]]: return self.history diff --git a/tests/graph/test_file_persistence.py b/tests/graph/test_file_persistence.py index 1e2c7886f9..b178f58531 100644 --- a/tests/graph/test_file_persistence.py +++ b/tests/graph/test_file_persistence.py @@ -45,7 +45,7 @@ class Double(BaseNode[None, None, int]): input_data: int async def run(self, ctx: GraphRunContext) -> Union[String2Length, End[int]]: # noqa: UP007 - if self.input_data == 7: + if self.input_data == 7: # pragma: no cover return String2Length('x' * 21) else: return End(self.input_data * 2) @@ -188,3 +188,15 @@ async def test_lock_timeout(tmp_path: Path): with pytest.raises(TimeoutError): async with persistence._lock(timeout=0.1): # type: ignore[reportPrivateUsage] pass + + +async def test_record_lookup_error(tmp_path: Path): + p = tmp_path / 'test_graph.json' + persistence = FileStatePersistence(p) + my_graph = Graph(nodes=(Float2String, String2Length, Double)) + my_graph.set_persistence_types(persistence) + my_graph.set_persistence_types(persistence) + + with pytest.raises(LookupError, match="No snapshot found with id='foobar'"): + async with persistence.record_run('foobar'): + pass diff --git a/tests/graph/test_persistence.py b/tests/graph/test_persistence.py index e16419a552..b9bcd8f54e 100644 --- a/tests/graph/test_persistence.py +++ b/tests/graph/test_persistence.py @@ -4,6 +4,7 @@ import json from dataclasses import dataclass from datetime import datetime, timezone +from typing import Any import pytest from dirty_equals import IsStr @@ -17,8 +18,10 @@ Graph, GraphRunContext, NodeSnapshot, + SimpleStatePersistence, ) from pydantic_graph.exceptions import GraphNodeStatusError, GraphRuntimeError +from pydantic_graph.persistence import StatePersistence from ..conftest import IsFloat, IsNow @@ -291,7 +294,8 @@ async def run(self, ctx: GraphRunContext) -> End[int]: await graph.next(node, sp) -async def test_next_from_persistence(mock_snapshot_id: object): +@pytest.mark.parametrize('persistence_cls', [SimpleStatePersistence, FullStatePersistence]) +async def test_next_from_persistence(persistence_cls: type[StatePersistence[None, int]], mock_snapshot_id: object): @dataclass class Foo(BaseNode): async def run(self, ctx: GraphRunContext) -> Spam: @@ -304,19 +308,31 @@ async def run(self, ctx: GraphRunContext) -> End[int]: g1 = Graph(nodes=[Foo, Spam]) - sp = FullStatePersistence() + sp = persistence_cls() node = Foo() assert g1.name is None node = await g1.next(node, sp) assert g1.name == 'g1' - assert node == snapshot(Spam()) + assert node == Spam() end = await g1.next_from_persistence(sp) - assert end == snapshot(End(123)) + assert end == End(123) g2 = Graph(nodes=[Foo, Spam]) - sp = FullStatePersistence() + sp = persistence_cls() assert g2.name is None with pytest.raises(GraphRuntimeError, match='Unable to restore snapshot from state persistence.'): await g2.next_from_persistence(sp) assert g2.name == 'g2' + + +@pytest.mark.parametrize('persistence_cls', [SimpleStatePersistence, FullStatePersistence]) +async def test_record_lookup_error(persistence_cls: type[StatePersistence[Any, Any]]): + persistence = persistence_cls() + my_graph = Graph(nodes=(Foo, Bar)) + my_graph.set_persistence_types(persistence) + my_graph.set_persistence_types(persistence) + + with pytest.raises(LookupError, match="No snapshot found with id='foobar'"): + async with persistence.record_run('foobar'): + pass From 294cbd2ee2c910e9569605b38554038df8d31883 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 10 Mar 2025 14:58:12 +0000 Subject: [PATCH 17/25] more docs --- docs/api/pydantic_graph/nodes.md | 1 + docs/api/pydantic_graph/persistence.md | 2 + docs/graph.md | 370 +++++++++--------- .../pydantic_ai_examples/question_graph.py | 36 +- pydantic_graph/pydantic_graph/graph.py | 23 +- pydantic_graph/pydantic_graph/nodes.py | 10 +- .../pydantic_graph/persistence/__init__.py | 15 +- .../pydantic_graph/persistence/_utils.py | 3 +- .../pydantic_graph/persistence/file.py | 13 +- .../pydantic_graph/persistence/in_mem.py | 14 +- tests/graph/test_persistence.py | 12 +- tests/test_examples.py | 2 + tests/typed_graph.py | 4 +- 13 files changed, 264 insertions(+), 241 deletions(-) diff --git a/docs/api/pydantic_graph/nodes.md b/docs/api/pydantic_graph/nodes.md index ecf6d35f50..948beb988e 100644 --- a/docs/api/pydantic_graph/nodes.md +++ b/docs/api/pydantic_graph/nodes.md @@ -3,6 +3,7 @@ ::: pydantic_graph.nodes options: members: + - StateT - GraphRunContext - BaseNode - End diff --git a/docs/api/pydantic_graph/persistence.md b/docs/api/pydantic_graph/persistence.md index 4b68150307..e35c83f928 100644 --- a/docs/api/pydantic_graph/persistence.md +++ b/docs/api/pydantic_graph/persistence.md @@ -3,3 +3,5 @@ ::: pydantic_graph.persistence ::: pydantic_graph.persistence.in_mem + +::: pydantic_graph.persistence.file diff --git a/docs/graph.md b/docs/graph.md index 473a7fd56f..a6440598f3 100644 --- a/docs/graph.md +++ b/docs/graph.md @@ -39,7 +39,7 @@ pip/uv-add pydantic-graph [`GraphRunContext`][pydantic_graph.nodes.GraphRunContext] — The context for the graph run, similar to PydanticAI's [`RunContext`][pydantic_ai.tools.RunContext]. This holds the state of the graph and dependencies and is passed to nodes when they're run. -`GraphRunContext` is generic in the state type of the graph it's used in, [`StateT`][pydantic_graph.persistence.StateT]. +`GraphRunContext` is generic in the state type of the graph it's used in, [`StateT`][pydantic_graph.nodes.StateT]. ### End @@ -59,7 +59,7 @@ Nodes, which are generally [`dataclass`es][dataclasses.dataclass], generally con Nodes are generic in: -* **state**, which must have the same type as the state of graphs they're included in, [`StateT`][pydantic_graph.persistence.StateT] has a default of `None`, so if you're not using state you can omit this generic parameter, see [stateful graphs](#stateful-graphs) for more information +* **state**, which must have the same type as the state of graphs they're included in, [`StateT`][pydantic_graph.nodes.StateT] has a default of `None`, so if you're not using state you can omit this generic parameter, see [stateful graphs](#stateful-graphs) for more information * **deps**, which must have the same type as the deps of the graph they're included in, [`DepsT`][pydantic_graph.nodes.DepsT] has a default of `None`, so if you're not using deps you can omit this generic parameter, see [dependency injection](#dependency-injection) for more information * **graph return type** — this only applies if the node returns [`End`][pydantic_graph.nodes.End]. [`RunEndT`][pydantic_graph.nodes.RunEndT] has a default of [Never][typing.Never] so this generic parameter can be omitted if the node doesn't return `End`, but must be included if it does. @@ -119,7 +119,7 @@ class MyNode(BaseNode[MyState, None, int]): # (1)! `Graph` is generic in: -* **state** the state type of the graph, [`StateT`][pydantic_graph.persistence.StateT] +* **state** the state type of the graph, [`StateT`][pydantic_graph.nodes.StateT] * **deps** the deps type of the graph, [`DepsT`][pydantic_graph.nodes.DepsT] * **graph return type** the return type of the graph run, [`RunEndT`][pydantic_graph.nodes.RunEndT] @@ -473,11 +473,150 @@ async def main(): _(This example is complete, it can be run "as is" with Python 3.10+ — you'll need to add `asyncio.run(main())` to run `main`)_ -## Custom Control Flow +## Iterating Over a Graph + +### Using `Graph.iter` for `async for` iteration + +Sometimes you want direct control or insight into each node as the graph executes. The easiest way to do that is with the [`Graph.iter`][pydantic_graph.graph.Graph.iter] method, which returns a **context manager** that yields a [`GraphRun`][pydantic_graph.graph.GraphRun] object. The `GraphRun` is an async-iterable over the nodes of your graph, allowing you to record or modify them as they execute. + +Here's an example: + +```python {title="count_down.py" noqa="I001" py="3.10"} +from __future__ import annotations as _annotations + +from dataclasses import dataclass +from pydantic_graph import Graph, BaseNode, End, GraphRunContext + + +@dataclass +class CountDownState: + counter: int + + +@dataclass +class CountDown(BaseNode[CountDownState, None, int]): + async def run(self, ctx: GraphRunContext[CountDownState]) -> CountDown | End[int]: + if ctx.state.counter <= 0: + return End(ctx.state.counter) + ctx.state.counter -= 1 + return CountDown() + + +count_down_graph = Graph(nodes=[CountDown]) + + +async def main(): + state = CountDownState(counter=3) + async with count_down_graph.iter(CountDown(), state=state) as run: # (1)! + async for node in run: # (2)! + print('Node:', node) + #> Node: CountDown() + #> Node: CountDown() + #> Node: CountDown() + #> Node: End(data=0) + print('Final result:', run.result.output) # (3)! + #> Final result: 0 +``` + +1. `Graph.iter(...)` returns a [`GraphRun`][pydantic_graph.graph.GraphRun]. +2. Here, we step through each node as it is executed. +3. Once the graph returns an [`End`][pydantic_graph.nodes.End], the loop ends, and `run.final_result` becomes a [`GraphRunResult`][pydantic_graph.graph.GraphRunResult] containing the final outcome (`0` here). + +### Using `GraphRun.next(node)` manually + +Alternatively, you can drive iteration manually with the [`GraphRun.next`][pydantic_graph.graph.GraphRun.next] method, which allows you to pass in whichever node you want to run next. You can modify or selectively skip nodes this way. + +Below is a contrived example that stops whenever the counter is at 2, ignoring any node runs beyond that: + +```python {title="count_down_next.py" noqa="I001" py="3.10"} +from pydantic_graph import End, SimpleStatePersistence +from count_down import CountDown, CountDownState, count_down_graph + + +async def main(): + state = CountDownState(counter=5) + sp = SimpleStatePersistence() + async with count_down_graph.iter(CountDown(), state=state, persistence=sp) as run: + node = run.next_node # (1)! + while not isinstance(node, End): # (2)! + print('Node:', node) + #> Node: CountDown() + #> Node: CountDown() + #> Node: CountDown() + #> Node: CountDown() + if state.counter == 2: + break # (3)! + node = await run.next(node) # (4)! + + print(run.result) # (5)! + #> None +``` + +1. We start by grabbing the first node that will be run in the agent's graph. +2. The agent run is finished once an `End` node has been produced; instances of `End` cannot be passed to `next`. +3. If the user decides to stop early, we break out of the loop. The graph run won't have a real final result in that case (`run.final_result` remains `None`). +4. At each step, we call `await run.next(node)` to run it and get the next node (or an `End`). +5. Because we did not continue the run until it finished, the `result` is not set. +6. The run's history is still populated with the steps we executed so far. -In many real-world applications, Graphs cannot run uninterrupted from start to finish — they might require external input, or run over an extended period of time such that a single process cannot execute the entire graph run from start to finish without interruption. +## State Persistence -In these scenarios the [`next`][pydantic_graph.graph.Graph.next] method can be used to run the graph one node at a time. +The greatest value of finite state machine (FSM) graphs comes when their execution is interrupted. This can be for a variety of reasons: + +- because the logic they encompass must be paused — e.g. the returns workflow for an e-commerce order needs to wait for the item to be posted to the returns center or because execution of the next node needs input from a user so needs to wait for a new http request, +- because their execution takes long enough that the entire graph can't be executed in a single continuous run — e.g. a deep research agent that takes hours to run, +- or, because multiple nodes can be run in parallel on different instances (note: parallel node execution is not yet supported in `pydantic-graph`, see [#704](https://github.com/pydantic/pydantic-ai/issues/704)). + +In all these scenarios, conventional control flow (boolean logic and nest function calls) breaks down and application code becomes spaghetti with the logic required to interrupt and resume execution dominating the code. + +To allow graph runs to be interrupted and resumed, `pydantic-graph` provides state persistence — a system for snapshotting the state of a graph run before and after each node is run, allowing a graph run to be resumed from any point in the graph. + +`pydantic-graph` includes three state persistence implementations: + +- [`SimpleStatePersistence`][pydantic_graph.SimpleStatePersistence] — Simple in memory state persistence that just hold the latest snapshot. If no state persistence implementation is provided when running a graph, this is used by default. +- [`FullStatePersistence`][pydantic_graph.FullStatePersistence] — In memory state persistence that hold a list of snapshots. +- [`FileStatePersistence`][pydantic_graph.persistence.file.FileStatePersistence] — File based state persistence that saves snapshots to a JSON file. + +In production applications, developers should implement their own state persistence by subclassing [`BaseStatePersistence`][pydantic_graph.persistence.BaseStatePersistence] abstract base class. + +At a high level the role of `StatePersistence` implementations is to store and retrieve [`NodeSnapshot`][pydantic_graph.persistence.NodeSnapshot] and [`EndSnapshot`][pydantic_graph.persistence.EndSnapshot] objects. + +[`graph.next_from_persistence()`][pydantic_graph.graph.Graph.next_from_persistence] may be used to run nodes of a graph based on its state stored in persistence. + +We can run the `count_down_graph` from [above](#iterating-over-a-graph), using [`graph.next_from_persistence()`][pydantic_graph.graph.Graph.next_from_persistence] and [`FileStatePersistence`][pydantic_graph.persistence.file.FileStatePersistence]. + +```python {title="count_down_from_persistence.py" noqa="I001" py="3.10"} +from pathlib import Path +from pydantic_graph import End +from pydantic_graph.persistence.file import FileStatePersistence +from count_down import CountDown, CountDownState, count_down_graph + + +async def main(): + persistence = FileStatePersistence(Path('count_down.json')) # (1)! + state = CountDownState(counter=5) + await count_down_graph.next(CountDown(), state=state, persistence=persistence) + + done = False + while not done: + done = await run_node() + + +async def run_node() -> bool: # (2)! + persistence = FileStatePersistence(Path('count_down.json')) + node_or_end = await count_down_graph.next_from_persistence(persistence) # (3)! + print('Node:', node_or_end) + #> Node: CountDown() + #> Node: CountDown() + #> Node: CountDown() + #> Node: CountDown() + #> Node: End(data=0) + return isinstance(node_or_end, End) +``` + +_(This example is complete, it can be run "as is" with Python 3.10+ — you'll need to add `asyncio.run(main())` to run `main`)_ + +### Example: Q&A with GenAI In this example, an AI asks the user a question, the user provides an answer, the AI evaluates the answer and ends if the user got it right or asks another question if they got it wrong. @@ -487,13 +626,19 @@ In this example, an AI asks the user a question, the user provides an answer, th from dataclasses import dataclass, field - from pydantic_graph import BaseNode, End, Graph, GraphRunContext + from groq import BaseModel + from pydantic_graph import ( + BaseNode, + End, + Graph, + GraphRunContext, + ) from pydantic_ai import Agent from pydantic_ai.format_as_xml import format_as_xml from pydantic_ai.messages import ModelMessage - ask_agent = Agent('openai:gpt-4o', result_type=str) + ask_agent = Agent('openai:gpt-4o', result_type=str, instrument=True) @dataclass @@ -512,29 +657,23 @@ In this example, an AI asks the user a question, the user provides an answer, th ) ctx.state.ask_agent_messages += result.all_messages() ctx.state.question = result.data - return Question(result.data) - - - @dataclass - class Question(BaseNode[QuestionState]): - question: str - - async def run(self, ctx: GraphRunContext[QuestionState]) -> Answer: - raise NotImplementedError('Question nodes should not be run') + return Answer(result.data) @dataclass class Answer(BaseNode[QuestionState]): - answer: str + question: str async def run(self, ctx: GraphRunContext[QuestionState]) -> Evaluate: - return Evaluate(self.answer) + answer = input(f'{self.question}: ') + return Evaluate(answer) - @dataclass - class EvaluationResult: + class EvaluationResult(BaseModel, use_attribute_docstrings=True): correct: bool + """Whether the answer is correct.""" comment: str + """Comment on the answer, reprimand the user if the answer is wrong.""" evaluate_agent = Agent( @@ -545,7 +684,7 @@ In this example, an AI asks the user a question, the user provides an answer, th @dataclass - class Evaluate(BaseNode[QuestionState]): + class Evaluate(BaseNode[QuestionState, None, str]): answer: str async def run( @@ -574,48 +713,50 @@ In this example, an AI asks the user a question, the user provides an answer, th return Ask() - question_graph = Graph(nodes=(Ask, Question, Answer, Evaluate, Reprimand)) + question_graph = Graph( + nodes=(Ask, Answer, Evaluate, Reprimand), state_type=QuestionState + ) ``` _(This example is complete, it can be run "as is" with Python 3.10+)_ ```python {title="ai_q_and_a_run.py" noqa="I001" py="3.10"} -from rich.prompt import Prompt +import sys +from pathlib import Path -from pydantic_graph import End, FullStatePersistence +from pydantic_graph import End +from pydantic_graph.persistence.file import FileStatePersistence +from pydantic_ai.messages import ModelMessage # noqa: F401 -from ai_q_and_a_graph import Ask, question_graph, Question, QuestionState, Answer +from ai_q_and_a_graph import Ask, question_graph, Evaluate, QuestionState, Answer async def main(): - state = QuestionState() # (1)! - node = Ask() # (2)! - persistence = FullStatePersistence() # (3)! + answer: str | None = sys.argv[2] if len(sys.argv) > 2 else None # (1)! + persistence = FileStatePersistence(Path('question_graph.json')) # (2)! + question_graph.set_persistence_types(persistence) # (3)! + + if snapshot := await persistence.retrieve_next(): # (4)! + state = snapshot.state + assert answer is not None + node = Evaluate(answer) # (5)! + else: + state = QuestionState() + node = Ask() # (6)! + while True: - node = await question_graph.next( # (4)! + node = await question_graph.next( # (7)! node, persistence=persistence, state=state ) - if isinstance(node, Question): - node = Answer(Prompt.ask(node.question)) # (5)! - elif isinstance(node, End): # (6)! - print(f'Correct answer! {node.data}') - #> Correct answer! Well done, 1 + 1 = 2 - print([e.node for e in persistence.history]) - """ - [ - Ask(), - Question(question='What is the capital of France?'), - Answer(answer='Vichy'), - Evaluate(answer='Vichy'), - Reprimand(comment='Vichy is no longer the capital of France.'), - Ask(), - Question(question='what is 1 + 1?'), - Answer(answer='2'), - Evaluate(answer='2'), - End(data='Well done, 1 + 1 = 2'), - ] - """ - return + if isinstance(node, End): + print('END:', node.data) + history = await persistence.load() + print([e.node for e in history]) + break + elif isinstance(node, Answer): + print(node.question) + #> What is the capital of France? + break # otherwise just continue ``` @@ -626,132 +767,9 @@ async def main(): 5. If the current node is an `Question` node, prompt the user for an answer. 6. Since we're using [`next`][pydantic_graph.graph.Graph.next] we have to manually check for an [`End`][pydantic_graph.nodes.End] and exit the loop if we get one. -_(This example is complete, it can be run "as is" with Python 3.10+ — you'll need to add `asyncio.run(main())` to run `main`)_ - -A [mermaid diagram](#mermaid-diagrams) for this graph can be generated with the following code: - -```py {title="ai_q_and_a_diagram.py" py="3.10"} -from ai_q_and_a_graph import Ask, question_graph - -question_graph.mermaid_code(start_node=Ask) -``` - -```mermaid ---- -title: question_graph ---- -stateDiagram-v2 - [*] --> Ask - Ask --> Question - Question --> Answer - Answer --> Evaluate - Evaluate --> Reprimand - Evaluate --> [*] - Reprimand --> Ask -``` - -You maybe have noticed that although this example transfers control flow out of the graph run, we're still using [rich's `Prompt.ask`][rich.prompt.PromptBase.ask] to get user input, with the process hanging while we wait for the user to enter a response. For an example of genuine out-of-process control flow, see the [question graph example](examples/question-graph.md). - -## Iterating Over a Graph - -### Using `Graph.iter` for `async for` iteration - -Sometimes you want direct control or insight into each node as the graph executes. The easiest way to do that is with the [`Graph.iter`][pydantic_graph.graph.Graph.iter] method, which returns a **context manager** that yields a [`GraphRun`][pydantic_graph.graph.GraphRun] object. The `GraphRun` is an async-iterable over the nodes of your graph, allowing you to record or modify them as they execute. - -Here's an example: - -```python {title="count_down.py" noqa="I001" py="3.10"} -from __future__ import annotations as _annotations - -from dataclasses import dataclass -from pydantic_graph import Graph, BaseNode, End, GraphRunContext, FullStatePersistence - - -@dataclass -class CountDownState: - counter: int - - -@dataclass -class CountDown(BaseNode[CountDownState]): - async def run(self, ctx: GraphRunContext[CountDownState]) -> CountDown | End[int]: - if ctx.state.counter <= 0: - return End(ctx.state.counter) - ctx.state.counter -= 1 - return CountDown() - - -count_down_graph = Graph(nodes=[CountDown]) - +_(This example is complete, it can be run "as is" with Python 3.10+ — you'll need to add `asyncio.run(main(answer))` to run `main`)_ -async def main(): - state = CountDownState(counter=3) - persistence = FullStatePersistence() - async with count_down_graph.iter( - CountDown(), state=state, persistence=persistence - ) as run: # (1)! - async for node in run: # (2)! - print('Node:', node) - #> Node: CountDown() - #> Node: CountDown() - #> Node: CountDown() - #> Node: End(data=0) - print('Final result:', run.result.output) # (3)! - #> Final result: 0 - print('History snapshots:', [step.node for step in persistence.history]) - """ - History snapshots: - [CountDown(), CountDown(), CountDown(), CountDown(), End(data=0)] - """ -``` - -1. `Graph.iter(...)` returns a [`GraphRun`][pydantic_graph.graph.GraphRun]. -2. Here, we step through each node as it is executed. -3. Once the graph returns an [`End`][pydantic_graph.nodes.End], the loop ends, and `run.final_result` becomes a [`GraphRunResult`][pydantic_graph.graph.GraphRunResult] containing the final outcome (`0` here). - -### Using `GraphRun.next(node)` manually - -Alternatively, you can drive iteration manually with the [`GraphRun.next`][pydantic_graph.graph.GraphRun.next] method, which allows you to pass in whichever node you want to run next. You can modify or selectively skip nodes this way. - -Below is a contrived example that stops whenever the counter is at 2, ignoring any node runs beyond that: - -```python {title="count_down_next.py" noqa="I001" py="3.10"} -from pydantic_graph import End, FullStatePersistence -from count_down import CountDown, CountDownState, count_down_graph - - -async def main(): - state = CountDownState(counter=5) - sp = FullStatePersistence() - async with count_down_graph.iter(CountDown(), state=state, persistence=sp) as run: - node = run.next_node # (1)! - while not isinstance(node, End): # (2)! - print('Node:', node) - #> Node: CountDown() - #> Node: CountDown() - #> Node: CountDown() - #> Node: CountDown() - if state.counter == 2: - break # (3)! - node = await run.next(node) # (4)! - - print(run.result) # (5)! - #> None - - for step in sp.history: # (6)! - print('History Step:', step.node, step.state) - #> History Step: CountDown() CountDownState(counter=5) - #> History Step: CountDown() CountDownState(counter=4) - #> History Step: CountDown() CountDownState(counter=3) - #> History Step: CountDown() CountDownState(counter=2) -``` - -1. We start by grabbing the first node that will be run in the agent's graph. -2. The agent run is finished once an `End` node has been produced; instances of `End` cannot be passed to `next`. -3. If the user decides to stop early, we break out of the loop. The graph run won't have a real final result in that case (`run.final_result` remains `None`). -4. At each step, we call `await run.next(node)` to run it and get the next node (or an `End`). -5. Because we did not continue the run until it finished, the `result` is not set. -6. The run's history is still populated with the steps we executed so far. +For an example of genuine out-of-process control flow, see the [question graph example](examples/question-graph.md). ## Dependency Injection diff --git a/examples/pydantic_ai_examples/question_graph.py b/examples/pydantic_ai_examples/question_graph.py index 2ce0861460..a2009fab3f 100644 --- a/examples/pydantic_ai_examples/question_graph.py +++ b/examples/pydantic_ai_examples/question_graph.py @@ -38,24 +38,24 @@ class QuestionState: @dataclass -class GenerateQuestion(BaseNode[QuestionState]): - async def run(self, ctx: GraphRunContext[QuestionState]) -> Ask: +class Ask(BaseNode[QuestionState]): + async def run(self, ctx: GraphRunContext[QuestionState]) -> Answer: result = await ask_agent.run( 'Ask a simple question with a single correct answer.', message_history=ctx.state.ask_agent_messages, ) ctx.state.ask_agent_messages += result.all_messages() ctx.state.question = result.data - return Ask(result.data) + return Answer(result.data) @dataclass -class Ask(BaseNode[QuestionState]): +class Answer(BaseNode[QuestionState]): question: str - async def run(self, ctx: GraphRunContext[QuestionState]) -> Answer: + async def run(self, ctx: GraphRunContext[QuestionState]) -> Evaluate: answer = input(f'{self.question}: ') - return Answer(answer) + return Evaluate(answer) class EvaluationResult(BaseModel, use_attribute_docstrings=True): @@ -73,7 +73,7 @@ class EvaluationResult(BaseModel, use_attribute_docstrings=True): @dataclass -class Answer(BaseNode[QuestionState, None, str]): +class Evaluate(BaseNode[QuestionState, None, str]): answer: str async def run( @@ -96,21 +96,20 @@ async def run( class Reprimand(BaseNode[QuestionState]): comment: str - async def run(self, ctx: GraphRunContext[QuestionState]) -> GenerateQuestion: + async def run(self, ctx: GraphRunContext[QuestionState]) -> Ask: print(f'Comment: {self.comment}') - # > Comment: Vichy is no longer the capital of France. ctx.state.question = None - return GenerateQuestion() + return Ask() question_graph = Graph( - nodes=(GenerateQuestion, Ask, Answer, Reprimand), state_type=QuestionState + nodes=(Ask, Answer, Evaluate, Reprimand), state_type=QuestionState ) async def run_as_continuous(): state = QuestionState() - node = GenerateQuestion() + node = Ask() end = await question_graph.run(node, state=state) print('END:', end.output) @@ -124,10 +123,10 @@ async def run_as_cli(answer: str | None): assert answer is not None, ( 'answer required, usage "uv run -m pydantic_ai_examples.question_graph cli "' ) - node = Answer(answer) + node = Evaluate(answer) else: state = QuestionState() - node = GenerateQuestion() + node = Ask() # debug(state, node) with logfire.span('run questions graph'): @@ -139,7 +138,7 @@ async def run_as_cli(answer: str | None): print('history:', '\n'.join(str(e.node) for e in history), sep='\n') print('Finished!') break - elif isinstance(node, Ask): + elif isinstance(node, Answer): print(node.question) break # otherwise just continue @@ -165,12 +164,7 @@ async def run_as_cli(answer: str | None): sys.exit(1) if sub_command == 'mermaid': - print(question_graph.mermaid_code(start_node=GenerateQuestion)) - print( - question_graph.mermaid_save( - 'question_graph.jpg', start_node=GenerateQuestion - ) - ) + print(question_graph.mermaid_code(start_node=Ask)) elif sub_command == 'continuous': asyncio.run(run_as_continuous()) else: diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index f9fff4929d..a060758fac 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -14,8 +14,8 @@ from typing_inspection import typing_objects from . import _utils, exceptions, mermaid -from .nodes import BaseNode, DepsT, End, GraphRunContext, NodeDef, RunEndT -from .persistence import StatePersistence, StateT, set_nodes_type_context +from .nodes import BaseNode, DepsT, End, GraphRunContext, NodeDef, RunEndT, StateT +from .persistence import BaseStatePersistence, set_nodes_type_context from .persistence.in_mem import SimpleStatePersistence # while waiting for https://github.com/pydantic/logfire/issues/745 @@ -125,7 +125,7 @@ async def run( *, state: StateT = None, deps: DepsT = None, - persistence: StatePersistence[StateT, T] | None = None, + persistence: BaseStatePersistence[StateT, T] | None = None, infer_name: bool = True, span: LogfireSpan | None = None, ) -> GraphRunResult[StateT, T]: @@ -181,7 +181,7 @@ def run_sync( *, state: StateT = None, deps: DepsT = None, - persistence: StatePersistence[StateT, T] | None = None, + persistence: BaseStatePersistence[StateT, T] | None = None, infer_name: bool = True, ) -> GraphRunResult[StateT, T]: """Synchronously run the graph. @@ -215,7 +215,7 @@ async def iter( *, state: StateT = None, deps: DepsT = None, - persistence: StatePersistence[StateT, T] | None = None, + persistence: BaseStatePersistence[StateT, T] | None = None, infer_name: bool = True, span: AbstractContextManager[Any] | None = None, ) -> AsyncIterator[GraphRun[StateT, DepsT, T]]: @@ -268,7 +268,7 @@ async def iter( async def next( self: Graph[StateT, DepsT, T], node: BaseNode[StateT, DepsT, T], - persistence: StatePersistence[StateT, T], + persistence: BaseStatePersistence[StateT, T], *, state: StateT = None, deps: DepsT = None, @@ -302,17 +302,20 @@ async def next( async def next_from_persistence( self: Graph[StateT, DepsT, T], - persistence: StatePersistence[StateT, T], + persistence: BaseStatePersistence[StateT, T], *, deps: DepsT = None, infer_name: bool = True, ) -> BaseNode[StateT, DepsT, Any] | End[T]: + """Run the next node in the graph from a snapshot stored in persistence.""" if infer_name and self.name is None: self._infer_name(inspect.currentframe()) + self.set_persistence_types(persistence) snapshot = await persistence.retrieve_next() if snapshot is None: raise exceptions.GraphRuntimeError('Unable to restore snapshot from state persistence.') + assert snapshot.id is not None, 'Snapshot ID should be set' snapshot.node.set_snapshot_id(snapshot.id) run = GraphRun[StateT, DepsT, T]( @@ -325,7 +328,7 @@ async def next_from_persistence( ) return await run.next() - def set_persistence_types(self, persistence: StatePersistence[StateT, RunEndT]) -> None: + def set_persistence_types(self, persistence: BaseStatePersistence[StateT, RunEndT]) -> None: with set_nodes_type_context([node_def.node for node_def in self.node_defs.values()]): persistence.set_types(lambda: self._inferred_types) @@ -590,7 +593,7 @@ def __init__( *, graph: Graph[StateT, DepsT, RunEndT], start_node: BaseNode[StateT, DepsT, RunEndT], - persistence: StatePersistence[StateT, RunEndT], + persistence: BaseStatePersistence[StateT, RunEndT], state: StateT, deps: DepsT, snapshot_id: str | None = None, @@ -738,4 +741,4 @@ class GraphRunResult(Generic[StateT, RunEndT]): output: RunEndT state: StateT - persistence: StatePersistence[StateT, RunEndT] = field(repr=False) + persistence: BaseStatePersistence[StateT, RunEndT] = field(repr=False) diff --git a/pydantic_graph/pydantic_graph/nodes.py b/pydantic_graph/pydantic_graph/nodes.py index 5bf74e1a09..f460c556ab 100644 --- a/pydantic_graph/pydantic_graph/nodes.py +++ b/pydantic_graph/pydantic_graph/nodes.py @@ -4,20 +4,18 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, is_dataclass from functools import cache -from typing import TYPE_CHECKING, Any, ClassVar, Generic, get_type_hints +from typing import Any, ClassVar, Generic, get_type_hints from uuid import uuid4 from typing_extensions import Never, Self, TypeVar, get_origin from . import _utils, exceptions -if TYPE_CHECKING: - from .persistence import StateT -else: - StateT = TypeVar('StateT', default=None) +__all__ = 'GraphRunContext', 'BaseNode', 'End', 'Edge', 'NodeDef', 'DepsT', 'StateT', 'RunEndT' -__all__ = 'GraphRunContext', 'BaseNode', 'End', 'Edge', 'NodeDef', 'DepsT', 'RunEndT' +StateT = TypeVar('StateT', default=None) +"""Type variable for the state in a graph.""" RunEndT = TypeVar('RunEndT', covariant=True, default=None) """Covariant type variable for the return type of a graph [`run`][pydantic_graph.graph.Graph.run].""" NodeRunEndT = TypeVar('NodeRunEndT', covariant=True, default=Never) diff --git a/pydantic_graph/pydantic_graph/persistence/__init__.py b/pydantic_graph/pydantic_graph/persistence/__init__.py index e044effceb..a134bf8f46 100644 --- a/pydantic_graph/pydantic_graph/persistence/__init__.py +++ b/pydantic_graph/pydantic_graph/persistence/__init__.py @@ -10,7 +10,7 @@ import pydantic from typing_extensions import TypeVar -from ..nodes import BaseNode, End, RunEndT +from ..nodes import BaseNode, End from . import _utils __all__ = ( @@ -18,13 +18,13 @@ 'NodeSnapshot', 'EndSnapshot', 'Snapshot', - 'StatePersistence', + 'BaseStatePersistence', 'set_nodes_type_context', 'SnapshotStatus', ) -StateT = TypeVar('StateT', default=None) -"""Type variable for the state in a graph.""" +StateT = TypeVar('StateT', default=Any) +RunEndT = TypeVar('RunEndT', covariant=True, default=Any) UNSET_SNAPSHOT_ID = '__unset__' SnapshotStatus = Literal['created', 'pending', 'running', 'success', 'error'] @@ -32,7 +32,7 @@ - `'created'`: The snapshot has been created but not yet run. - `'pending'`: The snapshot has been retrieved with - [`retrieve_next`][pydantic_graph.persistence.StatePersistence.retrieve_next] but not yet run. + [`retrieve_next`][pydantic_graph.persistence.BaseStatePersistence.retrieve_next] but not yet run. - `'running'`: The snapshot is currently running. - `'success'`: The snapshot has been run successfully. - `'error'`: The snapshot has been run but an error occurred. @@ -101,7 +101,7 @@ def node(self) -> End[RunEndT]: """ -class StatePersistence(ABC, Generic[StateT, RunEndT]): +class BaseStatePersistence(ABC, Generic[StateT, RunEndT]): """Abstract base class for storing the state of a graph.""" @abstractmethod @@ -164,6 +164,9 @@ def record_run(self, snapshot_id: str) -> AbstractAsyncContextManager[None]: async def retrieve_next(self) -> NodeSnapshot[StateT, RunEndT] | None: """Retrieve a node snapshot with status `'created`' and set its status to `'pending'`. + This is used by [`Graph.next_from_persistence`][pydantic_graph.graph.Graph.next_from_persistence] + to get the next node to run. + Returns: The snapshot, or `None` if no snapshot with status `'created`' exists. """ raise NotImplementedError diff --git a/pydantic_graph/pydantic_graph/persistence/_utils.py b/pydantic_graph/pydantic_graph/persistence/_utils.py index d396d28d45..72f05a3277 100644 --- a/pydantic_graph/pydantic_graph/persistence/_utils.py +++ b/pydantic_graph/pydantic_graph/persistence/_utils.py @@ -23,7 +23,8 @@ def __get_pydantic_core_schema__( nodes = nodes_type_context.get() except LookupError as e: raise RuntimeError( - 'Unable to build a Pydantic schema for `BaseNode` without setting `nodes_type_context`.' + 'Unable to build a Pydantic schema for `BaseNode` without setting `nodes_type_context`. ' + 'You should build Pydantic schemas for snapshots using `StatePersistence.set_types()`.' ) from e if len(nodes) == 1: nodes_type = nodes[0] diff --git a/pydantic_graph/pydantic_graph/persistence/file.py b/pydantic_graph/pydantic_graph/persistence/file.py index 7894dfce30..69687743ba 100644 --- a/pydantic_graph/pydantic_graph/persistence/file.py +++ b/pydantic_graph/pydantic_graph/persistence/file.py @@ -10,29 +10,28 @@ from typing import Any, Callable import pydantic -from typing_extensions import TypeVar from .. import _utils as _graph_utils, exceptions from ..nodes import BaseNode, End from . import ( + BaseStatePersistence, EndSnapshot, NodeSnapshot, + RunEndT, Snapshot, SnapshotStatus, - StatePersistence, + StateT, _utils, build_snapshot_list_type_adapter, ) -StateT = TypeVar('StateT', default=Any) -RunEndT = TypeVar('RunEndT', default=Any) - @dataclass -class FileStatePersistence(StatePersistence[StateT, RunEndT]): - """State persistence that just hold the latest snapshot.""" +class FileStatePersistence(BaseStatePersistence[StateT, RunEndT]): + """File based state persistence that hold a list of snapshots in a JSON file.""" json_file: Path + """Path to the JSON file where the snapshots are stored.""" _snapshots_type_adapter: pydantic.TypeAdapter[list[Snapshot[StateT, RunEndT]]] | None = field( default=None, init=False, repr=False ) diff --git a/pydantic_graph/pydantic_graph/persistence/in_mem.py b/pydantic_graph/pydantic_graph/persistence/in_mem.py index 1854341fad..df7fe203fc 100644 --- a/pydantic_graph/pydantic_graph/persistence/in_mem.py +++ b/pydantic_graph/pydantic_graph/persistence/in_mem.py @@ -13,25 +13,23 @@ from typing import Any, Callable import pydantic -from typing_extensions import TypeVar from .. import exceptions from ..nodes import BaseNode, End from . import ( + BaseStatePersistence, EndSnapshot, NodeSnapshot, + RunEndT, Snapshot, - StatePersistence, + StateT, _utils, build_snapshot_list_type_adapter, ) -StateT = TypeVar('StateT', default=Any) -RunEndT = TypeVar('RunEndT', default=Any) - @dataclass -class SimpleStatePersistence(StatePersistence[StateT, RunEndT]): +class SimpleStatePersistence(BaseStatePersistence[StateT, RunEndT]): """Simple in memory state persistence that just hold the latest snapshot. If no state persistence implementation is provided when running a graph, this is used by default. @@ -85,8 +83,8 @@ async def load(self) -> list[Snapshot[StateT, RunEndT]]: @dataclass -class FullStatePersistence(StatePersistence[StateT, RunEndT]): - """In memory state persistence that hold a history of nodes.""" +class FullStatePersistence(BaseStatePersistence[StateT, RunEndT]): + """In memory state persistence that hold a list of snapshots.""" deep_copy: bool = True """Whether to deep copy the state and nodes when storing them. diff --git a/tests/graph/test_persistence.py b/tests/graph/test_persistence.py index b9bcd8f54e..4eef4140ca 100644 --- a/tests/graph/test_persistence.py +++ b/tests/graph/test_persistence.py @@ -4,7 +4,6 @@ import json from dataclasses import dataclass from datetime import datetime, timezone -from typing import Any import pytest from dirty_equals import IsStr @@ -21,7 +20,7 @@ SimpleStatePersistence, ) from pydantic_graph.exceptions import GraphNodeStatusError, GraphRuntimeError -from pydantic_graph.persistence import StatePersistence +from pydantic_graph.persistence import BaseStatePersistence, build_snapshot_list_type_adapter from ..conftest import IsFloat, IsNow @@ -295,7 +294,7 @@ async def run(self, ctx: GraphRunContext) -> End[int]: @pytest.mark.parametrize('persistence_cls', [SimpleStatePersistence, FullStatePersistence]) -async def test_next_from_persistence(persistence_cls: type[StatePersistence[None, int]], mock_snapshot_id: object): +async def test_next_from_persistence(persistence_cls: type[BaseStatePersistence[None, int]], mock_snapshot_id: object): @dataclass class Foo(BaseNode): async def run(self, ctx: GraphRunContext) -> Spam: @@ -327,7 +326,7 @@ async def run(self, ctx: GraphRunContext) -> End[int]: @pytest.mark.parametrize('persistence_cls', [SimpleStatePersistence, FullStatePersistence]) -async def test_record_lookup_error(persistence_cls: type[StatePersistence[Any, Any]]): +async def test_record_lookup_error(persistence_cls: type[BaseStatePersistence]): persistence = persistence_cls() my_graph = Graph(nodes=(Foo, Bar)) my_graph.set_persistence_types(persistence) @@ -336,3 +335,8 @@ async def test_record_lookup_error(persistence_cls: type[StatePersistence[Any, A with pytest.raises(LookupError, match="No snapshot found with id='foobar'"): async with persistence.record_run('foobar'): pass + + +def test_snapshot_type_adapter_error(): + with pytest.raises(RuntimeError, match='Unable to build a Pydantic schema for `BaseNode` without setting'): + build_snapshot_list_type_adapter(int, int) diff --git a/tests/test_examples.py b/tests/test_examples.py index 710f07b3d1..c07db0e9b2 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -107,6 +107,8 @@ def test_docs_examples( examples = [{'request': f'sql prompt {i}', 'sql': f'SELECT {i}'} for i in range(15)] with (tmp_path / 'examples.json').open('w') as f: json.dump(examples, f) + elif opt_title in {'ai_q_and_a_run.py', 'count_down_from_persistence.py'}: + os.chdir(tmp_path) ruff_ignore: list[str] = ['D', 'Q001'] # `from bank_database import DatabaseConn` wrongly sorted in imports diff --git a/tests/typed_graph.py b/tests/typed_graph.py index cda6b3c6c3..0a6cfa8a0c 100644 --- a/tests/typed_graph.py +++ b/tests/typed_graph.py @@ -6,7 +6,7 @@ from typing_extensions import assert_type from pydantic_graph import BaseNode, End, FullStatePersistence, Graph, GraphRunContext -from pydantic_graph.persistence import StatePersistence +from pydantic_graph.persistence import BaseStatePersistence @dataclass @@ -118,7 +118,7 @@ def run_g5() -> None: def run_g6() -> None: result = g5.run_sync(A(), state=MyState(x=1), deps=MyDeps(y='y')) assert_type(result.output, int) - assert_type(result.persistence, StatePersistence[MyState, int]) + assert_type(result.persistence, BaseStatePersistence[MyState, int]) p = FullStatePersistence() From f1e4ca14f173a3771e33b505b93b2082c48b6eae Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 10 Mar 2025 16:39:25 +0000 Subject: [PATCH 18/25] complete docs --- docs/graph.md | 54 ++++++++++++------- .../pydantic_ai_examples/question_graph.py | 2 +- pydantic_graph/pydantic_graph/graph.py | 19 +++---- .../pydantic_graph/persistence/__init__.py | 43 ++++++++------- .../pydantic_graph/persistence/_utils.py | 12 ++++- .../pydantic_graph/persistence/file.py | 12 +++-- .../pydantic_graph/persistence/in_mem.py | 11 ++-- tests/graph/test_file_persistence.py | 4 +- tests/graph/test_graph.py | 8 +-- tests/graph/test_persistence.py | 16 +++--- tests/graph/test_state.py | 2 +- 11 files changed, 106 insertions(+), 77 deletions(-) diff --git a/docs/graph.md b/docs/graph.md index a6440598f3..ed1525450f 100644 --- a/docs/graph.md +++ b/docs/graph.md @@ -291,7 +291,7 @@ async def main(): 1. The state of the vending machine is defined as a dataclass with the user's balance and the product they've selected, if any. 2. A dictionary of products mapped to prices. 3. The `InsertCoin` node, [`BaseNode`][pydantic_graph.nodes.BaseNode] is parameterized with `MachineState` as that's the state used in this graph. -4. The `InsertCoin` node prompts the user to insert coins. We keep things simple by just entering a monetary amount as a float. Before you start thinking this is a toy too since it's using [rich's `Prompt.ask`][rich.prompt.PromptBase.ask] within nodes, see [below](#custom-control-flow) for how control flow can be managed when nodes require external input. +4. The `InsertCoin` node prompts the user to insert coins. We keep things simple by just entering a monetary amount as a float. Before you start thinking this is a toy too since it's using [rich's `Prompt.ask`][rich.prompt.PromptBase.ask] within nodes, see [below](#example-qa-with-genai) for how control flow can be managed when nodes require external input. 5. The `CoinsInserted` node; again this is a [`dataclass`][dataclasses.dataclass] with one field `amount`. 6. Update the user's balance with the amount inserted. 7. If the user has already selected a product, go to `Purchase`, otherwise go to `SelectProduct`. @@ -585,41 +585,55 @@ At a high level the role of `StatePersistence` implementations is to store and r We can run the `count_down_graph` from [above](#iterating-over-a-graph), using [`graph.next_from_persistence()`][pydantic_graph.graph.Graph.next_from_persistence] and [`FileStatePersistence`][pydantic_graph.persistence.file.FileStatePersistence]. +As you can see in this code, `run_node` requires no external application state apart from state persistence to be run, meaning graphs can easily be executed by distributed execution and queueing systems. + ```python {title="count_down_from_persistence.py" noqa="I001" py="3.10"} from pathlib import Path + from pydantic_graph import End from pydantic_graph.persistence.file import FileStatePersistence + from count_down import CountDown, CountDownState, count_down_graph async def main(): persistence = FileStatePersistence(Path('count_down.json')) # (1)! state = CountDownState(counter=5) - await count_down_graph.next(CountDown(), state=state, persistence=persistence) + await count_down_graph.next( # (2)! + CountDown(), state=state, persistence=persistence + ) done = False while not done: done = await run_node() -async def run_node() -> bool: # (2)! +async def run_node() -> bool: # (3)! persistence = FileStatePersistence(Path('count_down.json')) - node_or_end = await count_down_graph.next_from_persistence(persistence) # (3)! + node_or_end = await count_down_graph.next_from_persistence(persistence) # (4)! print('Node:', node_or_end) #> Node: CountDown() #> Node: CountDown() #> Node: CountDown() #> Node: CountDown() #> Node: End(data=0) - return isinstance(node_or_end, End) + return isinstance(node_or_end, End) # (5)! ``` +1. Create a [`FileStatePersistence`][pydantic_graph.persistence.file.FileStatePersistence] to use to start the graph. +2. Call [`graph.next()`][pydantic_graph.graph.Graph.next] to start the graph with the initial state. +3. `run_node` is a pure function that doesn't need access to any other process state to run the next node of the graph. +4. Call [`graph.next_from_persistence()`][pydantic_graph.graph.Graph.next_from_persistence] to run the next node of the graph from the state stored in persistence. This will return either a node or an `End` object. +5. Check if the node is an `End` object, if it is, the graph run is complete. + _(This example is complete, it can be run "as is" with Python 3.10+ — you'll need to add `asyncio.run(main())` to run `main`)_ ### Example: Q&A with GenAI In this example, an AI asks the user a question, the user provides an answer, the AI evaluates the answer and ends if the user got it right or asks another question if they got it wrong. +Instead of running the entire graph in a single process invocation, we run the graph by running the process repeatedly, optionally providing an answer to the question as a command line argument. + ??? example "`ai_q_and_a_graph.py` — `question_graph` definition" ```python {title="ai_q_and_a_graph.py" noqa="I001" py="3.10"} from __future__ import annotations as _annotations @@ -734,42 +748,44 @@ from ai_q_and_a_graph import Ask, question_graph, Evaluate, QuestionState, Answe async def main(): answer: str | None = sys.argv[2] if len(sys.argv) > 2 else None # (1)! persistence = FileStatePersistence(Path('question_graph.json')) # (2)! - question_graph.set_persistence_types(persistence) # (3)! + persistence.set_graph_types(question_graph) # (3)! if snapshot := await persistence.retrieve_next(): # (4)! state = snapshot.state assert answer is not None - node = Evaluate(answer) # (5)! + node = Evaluate(answer) else: state = QuestionState() - node = Ask() # (6)! + node = Ask() # (5)! while True: - node = await question_graph.next( # (7)! + node = await question_graph.next( # (6)! node, persistence=persistence, state=state ) - if isinstance(node, End): + if isinstance(node, End): # (7)! print('END:', node.data) history = await persistence.load() print([e.node for e in history]) break - elif isinstance(node, Answer): + elif isinstance(node, Answer): # (8)! print(node.question) #> What is the capital of France? break # otherwise just continue ``` -1. Create the state object which will be mutated by [`next`][pydantic_graph.graph.Graph.next]. -2. The start node is `Ask` but will be updated by [`next`][pydantic_graph.graph.Graph.next] as the graph runs. -3. The history of the graph run is stored using [`FullStatePersistence`][pydantic_graph.FullStatePersistence]. Again [`next`][pydantic_graph.graph.Graph.next] will update this list in place. -4. [Run][pydantic_graph.graph.Graph.next] the graph one node at a time, updating the state, current node and history as the graph runs. -5. If the current node is an `Question` node, prompt the user for an answer. -6. Since we're using [`next`][pydantic_graph.graph.Graph.next] we have to manually check for an [`End`][pydantic_graph.nodes.End] and exit the loop if we get one. +1. Get the user's answer from the command line, if provided. See [question graph example](examples/question-graph.md) for a complete example. +2. Create a state persistence instance the `'question_graph.json'` file may or may not already exist. +3. Since we're using the [persistence interface][pydantic_graph.persistence.BaseStatePersistence] outside a graph, we need to call [`set_graph_types`][pydantic_graph.persistence.BaseStatePersistence.set_graph_types] to set the graph generic types `StateT` and `RunEndT` for the persistence instance. This is necessary to allow the persistence instance to know how to serialize and deserialize graph nodes. +4. If we're run the graph before, [`retrieve_next`][pydantic_graph.persistence.BaseStatePersistence.retrieve_next] will return a snapshot of the next node to run, here we use `state` from that snapshot, and create a new `Evaluate` node with the answer provided on the command line. +5. If the graph hasn't been run before, we create a new `QuestionState` and start with the `Ask` node. +6. Call [`graph.next()`][pydantic_graph.graph.Graph.next] to run the node. This will return either a node or an `End` object. +7. If the node is an `End` object, the graph run is complete. The `data` field of the `End` object contains the comment returned by the `evaluate_agent` about the correct answer. +8. If the node is an `Answer` object, we print the question and break out of the loop to end the process and wait for user input. _(This example is complete, it can be run "as is" with Python 3.10+ — you'll need to add `asyncio.run(main(answer))` to run `main`)_ -For an example of genuine out-of-process control flow, see the [question graph example](examples/question-graph.md). +For a complete example of this graph, see the [question graph example](examples/question-graph.md). ## Dependency Injection @@ -863,7 +879,7 @@ Beyond the diagrams shown above, you can also customize mermaid diagrams with th * [`BaseNode.docstring_notes`][pydantic_graph.nodes.BaseNode.docstring_notes] and [`BaseNode.get_note`][pydantic_graph.nodes.BaseNode.get_note] allows you to add notes to nodes * The [`highlighted_nodes`][pydantic_graph.graph.Graph.mermaid_code] parameter allows you to highlight specific node(s) in the diagram -Putting that together, we can edit the last [`ai_q_and_a_graph.py`](#custom-control-flow) example to: +Putting that together, we can edit the last [`ai_q_and_a_graph.py`](#example-qa-with-genai) example to: * add labels to some edges * add a note to the `Ask` node diff --git a/examples/pydantic_ai_examples/question_graph.py b/examples/pydantic_ai_examples/question_graph.py index a2009fab3f..79ee60b29f 100644 --- a/examples/pydantic_ai_examples/question_graph.py +++ b/examples/pydantic_ai_examples/question_graph.py @@ -116,7 +116,7 @@ async def run_as_continuous(): async def run_as_cli(answer: str | None): persistence = FileStatePersistence(Path('question_graph.json')) - question_graph.set_persistence_types(persistence) + persistence.set_graph_types(question_graph) if snapshot := await persistence.retrieve_next(): state = snapshot.state diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index a060758fac..5ec102d911 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -15,7 +15,7 @@ from . import _utils, exceptions, mermaid from .nodes import BaseNode, DepsT, End, GraphRunContext, NodeDef, RunEndT, StateT -from .persistence import BaseStatePersistence, set_nodes_type_context +from .persistence import BaseStatePersistence from .persistence.in_mem import SimpleStatePersistence # while waiting for https://github.com/pydantic/logfire/issues/745 @@ -253,7 +253,7 @@ async def iter( if persistence is None: persistence = SimpleStatePersistence() - self.set_persistence_types(persistence) + persistence.set_graph_types(self) if self.auto_instrument and span is None: span = logfire_api.span('run graph {graph.name}', graph=self) @@ -290,7 +290,7 @@ async def next( if infer_name and self.name is None: self._infer_name(inspect.currentframe()) - self.set_persistence_types(persistence) + persistence.set_graph_types(self) run = GraphRun[StateT, DepsT, T]( graph=self, start_node=node, @@ -310,7 +310,7 @@ async def next_from_persistence( """Run the next node in the graph from a snapshot stored in persistence.""" if infer_name and self.name is None: self._infer_name(inspect.currentframe()) - self.set_persistence_types(persistence) + persistence.set_graph_types(self) snapshot = await persistence.retrieve_next() if snapshot is None: @@ -328,10 +328,6 @@ async def next_from_persistence( ) return await run.next() - def set_persistence_types(self, persistence: BaseStatePersistence[StateT, RunEndT]) -> None: - with set_nodes_type_context([node_def.node for node_def in self.node_defs.values()]): - persistence.set_types(lambda: self._inferred_types) - def mermaid_code( self, *, @@ -455,8 +451,13 @@ def mermaid_save( kwargs['title'] = self.name mermaid.save_image(path, self, **kwargs) + def get_nodes(self) -> Sequence[type[BaseNode[StateT, DepsT, RunEndT]]]: + """Get the nodes in the graph.""" + return [node_def.node for node_def in self.node_defs.values()] + @cached_property - def _inferred_types(self) -> tuple[type[StateT], type[RunEndT]]: + def inferred_types(self) -> tuple[type[StateT], type[RunEndT]]: + # Get the types of the state and run end from the graph. if _utils.is_set(self._state_type) and _utils.is_set(self._run_end_type): return self._state_type, self._run_end_type diff --git a/pydantic_graph/pydantic_graph/persistence/__init__.py b/pydantic_graph/pydantic_graph/persistence/__init__.py index a134bf8f46..91c632b789 100644 --- a/pydantic_graph/pydantic_graph/persistence/__init__.py +++ b/pydantic_graph/pydantic_graph/persistence/__init__.py @@ -1,11 +1,10 @@ from __future__ import annotations as _annotations from abc import ABC, abstractmethod -from collections.abc import Iterator, Sequence -from contextlib import AbstractAsyncContextManager, contextmanager +from contextlib import AbstractAsyncContextManager from dataclasses import dataclass, field from datetime import datetime -from typing import Annotated, Any, Callable, Generic, Literal, Union +from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal, Union import pydantic from typing_extensions import TypeVar @@ -13,18 +12,13 @@ from ..nodes import BaseNode, End from . import _utils -__all__ = ( - 'StateT', - 'NodeSnapshot', - 'EndSnapshot', - 'Snapshot', - 'BaseStatePersistence', - 'set_nodes_type_context', - 'SnapshotStatus', -) +if TYPE_CHECKING: + from .. import Graph + +__all__ = 'StateT', 'NodeSnapshot', 'EndSnapshot', 'Snapshot', 'BaseStatePersistence', 'SnapshotStatus' StateT = TypeVar('StateT', default=Any) -RunEndT = TypeVar('RunEndT', covariant=True, default=Any) +RunEndT = TypeVar('RunEndT', default=Any) UNSET_SNAPSHOT_ID = '__unset__' SnapshotStatus = Literal['created', 'pending', 'running', 'success', 'error'] @@ -179,25 +173,30 @@ async def load(self) -> list[Snapshot[StateT, RunEndT]]: """ raise NotImplementedError - def set_types(self, get_types: Callable[[], tuple[type[StateT], type[RunEndT]]]) -> None: + def set_types(self, state_type: type[StateT], run_end_type: type[RunEndT]) -> None: """Set the types of the state and run end. This can be used to create [type adapters][pydantic.TypeAdapter] for serializing and deserializing snapshots. Args: - get_types: A callback that returns the types of the state and run end. + state_type: The state type. + run_end_type: The run end type. """ pass + def _should_set_types(self) -> bool: + """Whether types need to be set. + + Implementations should override this method to return `True` when types have not been set if they are needed. + """ + return False -@contextmanager -def set_nodes_type_context(nodes: Sequence[type[BaseNode[Any, Any, Any]]]) -> Iterator[None]: # noqa: D103 - token = _utils.nodes_type_context.set(nodes) - try: - yield - finally: - _utils.nodes_type_context.reset(token) + def set_graph_types(self, graph: Graph[StateT, Any, RunEndT]) -> None: + """Set the types of the state and run end from a graph.""" + if self._should_set_types(): + with _utils.set_nodes_type_context(graph.get_nodes()): + self.set_types(*graph.inferred_types) def build_snapshot_list_type_adapter( diff --git a/pydantic_graph/pydantic_graph/persistence/_utils.py b/pydantic_graph/pydantic_graph/persistence/_utils.py index 72f05a3277..ce28131465 100644 --- a/pydantic_graph/pydantic_graph/persistence/_utils.py +++ b/pydantic_graph/pydantic_graph/persistence/_utils.py @@ -1,6 +1,7 @@ from __future__ import annotations as _annotations -from collections.abc import Sequence +from collections.abc import Iterator, Sequence +from contextlib import contextmanager from contextvars import ContextVar from dataclasses import dataclass from datetime import datetime, timezone @@ -52,3 +53,12 @@ def _node_serializer(node: Any, handler: pydantic.SerializerFunctionWrapHandler) def now_utc() -> datetime: return datetime.now(tz=timezone.utc) + + +@contextmanager +def set_nodes_type_context(nodes: Sequence[type[BaseNode[Any, Any, Any]]]) -> Iterator[None]: + token = nodes_type_context.set(nodes) + try: + yield + finally: + nodes_type_context.reset(token) diff --git a/pydantic_graph/pydantic_graph/persistence/file.py b/pydantic_graph/pydantic_graph/persistence/file.py index 69687743ba..b41fb86cc3 100644 --- a/pydantic_graph/pydantic_graph/persistence/file.py +++ b/pydantic_graph/pydantic_graph/persistence/file.py @@ -7,7 +7,7 @@ from dataclasses import dataclass, field from pathlib import Path from time import perf_counter -from typing import Any, Callable +from typing import Any import pydantic @@ -86,10 +86,12 @@ async def retrieve_next(self) -> NodeSnapshot[StateT, RunEndT] | None: await self._save(snapshots) return snapshot - def set_types(self, get_types: Callable[[], tuple[type[StateT], type[RunEndT]]]) -> None: - if self._snapshots_type_adapter is None: - state_t, run_end_t = get_types() - self._snapshots_type_adapter = build_snapshot_list_type_adapter(state_t, run_end_t) + def set_types(self, state_type: type[StateT], run_end_type: type[RunEndT]) -> None: + self._snapshots_type_adapter = build_snapshot_list_type_adapter(state_type, run_end_type) + + def _should_set_types(self) -> bool: + """Whether types need to be set.""" + return self._snapshots_type_adapter is None async def load(self) -> list[Snapshot[StateT, RunEndT]]: return await _graph_utils.run_in_executor(self._load_sync) diff --git a/pydantic_graph/pydantic_graph/persistence/in_mem.py b/pydantic_graph/pydantic_graph/persistence/in_mem.py index df7fe203fc..eb275a5261 100644 --- a/pydantic_graph/pydantic_graph/persistence/in_mem.py +++ b/pydantic_graph/pydantic_graph/persistence/in_mem.py @@ -10,7 +10,7 @@ from contextlib import asynccontextmanager from dataclasses import dataclass, field from time import perf_counter -from typing import Any, Callable +from typing import Any import pydantic @@ -148,10 +148,11 @@ async def retrieve_next(self) -> NodeSnapshot[StateT, RunEndT] | None: async def load(self) -> list[Snapshot[StateT, RunEndT]]: return self.history - def set_types(self, get_types: Callable[[], tuple[type[StateT], type[RunEndT]]]) -> None: - if self._snapshots_type_adapter is None: - state_t, run_end_t = get_types() - self._snapshots_type_adapter = build_snapshot_list_type_adapter(state_t, run_end_t) + def set_types(self, state_type: type[StateT], run_end_type: type[RunEndT]) -> None: + self._snapshots_type_adapter = build_snapshot_list_type_adapter(state_type, run_end_type) + + def _should_set_types(self) -> bool: + return self._snapshots_type_adapter is None def dump_json(self, *, indent: int | None = None) -> bytes: """Dump the history to JSON bytes.""" diff --git a/tests/graph/test_file_persistence.py b/tests/graph/test_file_persistence.py index b178f58531..ec5dc30646 100644 --- a/tests/graph/test_file_persistence.py +++ b/tests/graph/test_file_persistence.py @@ -194,8 +194,8 @@ async def test_record_lookup_error(tmp_path: Path): p = tmp_path / 'test_graph.json' persistence = FileStatePersistence(p) my_graph = Graph(nodes=(Float2String, String2Length, Double)) - my_graph.set_persistence_types(persistence) - my_graph.set_persistence_types(persistence) + persistence.set_graph_types(my_graph) + persistence.set_graph_types(my_graph) with pytest.raises(LookupError, match="No snapshot found with id='foobar'"): async with persistence.record_run('foobar'): diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py index 1714d2f35f..cf07b170d4 100644 --- a/tests/graph/test_graph.py +++ b/tests/graph/test_graph.py @@ -58,7 +58,7 @@ async def run(self, ctx: GraphRunContext) -> Union[String2Length, End[int]]: # async def test_graph(): my_graph = Graph(nodes=(Float2String, String2Length, Double)) assert my_graph.name is None - assert my_graph._inferred_types == (type(None), int) + assert my_graph.inferred_types == (type(None), int) result = await my_graph.run(Float2String(3.14)) # len('3.14') * 2 == 8 assert result.output == 8 @@ -68,7 +68,7 @@ async def test_graph(): async def test_graph_history(mock_snapshot_id: object): my_graph = Graph[None, None, int](nodes=(Float2String, String2Length, Double)) assert my_graph.name is None - assert my_graph._inferred_types == (type(None), int) + assert my_graph.inferred_types == (type(None), int) sp = FullStatePersistence() result = await my_graph.run(Float2String(3.14), persistence=sp) # len('3.14') * 2 == 8 @@ -292,7 +292,7 @@ async def run(self, ctx: GraphRunContext) -> End[None]: return 42 # type: ignore g = Graph(nodes=(Foo, Bar)) - assert g._inferred_types == (type(None), type(None)) + assert g.inferred_types == (type(None), type(None)) with pytest.raises(GraphRuntimeError) as exc_info: await g.run(Foo()) @@ -302,7 +302,7 @@ async def run(self, ctx: GraphRunContext) -> End[None]: async def test_iter(): my_graph = Graph(nodes=(Float2String, String2Length, Double)) assert my_graph.name is None - assert my_graph._inferred_types == (type(None), int) + assert my_graph.inferred_types == (type(None), int) node_reprs: list[str] = [] async with my_graph.iter(Float2String(3.14)) as graph_iter: assert repr(graph_iter) == snapshot('') diff --git a/tests/graph/test_persistence.py b/tests/graph/test_persistence.py index 4eef4140ca..31f52c80af 100644 --- a/tests/graph/test_persistence.py +++ b/tests/graph/test_persistence.py @@ -119,7 +119,7 @@ async def test_dump_load_state(graph: Graph[MyState, None, int], mock_snapshot_i ) sp2 = FullStatePersistence() - graph.set_persistence_types(sp2) + sp2.set_graph_types(graph) sp2.load_json(history_json) assert sp.history == sp2.history @@ -140,7 +140,7 @@ async def test_dump_load_state(graph: Graph[MyState, None, int], mock_snapshot_i }, ] sp3 = FullStatePersistence() - graph.set_persistence_types(sp3) + sp3.set_graph_types(graph) sp3.load_json(json.dumps(custom_history)) assert sp3.history == snapshot( [ @@ -181,7 +181,7 @@ async def run(self, ctx: GraphRunContext) -> End[int]: # pragma: no cover }, ] sp = FullStatePersistence() - g.set_persistence_types(sp) + sp.set_graph_types(g) sp.load_json(json.dumps(custom_history)) assert sp.history == snapshot( [ @@ -203,11 +203,11 @@ async def run(self, ctx: GraphRunContext) -> NoGenericArgsNode: # pragma: no co return NoGenericArgsNode() g = Graph(nodes=[NoGenericArgsNode]) - assert g._inferred_types == (None, None) + assert g.inferred_types == (None, None) g = Graph(nodes=[NoGenericArgsNode], run_end_type=None) # pyright: ignore[reportArgumentType] - assert g._inferred_types == (None, None) + assert g.inferred_types == (None, None) custom_history = [ { @@ -220,7 +220,7 @@ async def run(self, ctx: GraphRunContext) -> NoGenericArgsNode: # pragma: no co ] sp = FullStatePersistence() - g.set_persistence_types(sp) + sp.set_graph_types(g) sp.load_json(json.dumps(custom_history)) assert sp.history == snapshot( @@ -329,8 +329,8 @@ async def run(self, ctx: GraphRunContext) -> End[int]: async def test_record_lookup_error(persistence_cls: type[BaseStatePersistence]): persistence = persistence_cls() my_graph = Graph(nodes=(Foo, Bar)) - my_graph.set_persistence_types(persistence) - my_graph.set_persistence_types(persistence) + persistence.set_graph_types(my_graph) + persistence.set_graph_types(my_graph) with pytest.raises(LookupError, match="No snapshot found with id='foobar'"): async with persistence.record_run('foobar'): diff --git a/tests/graph/test_state.py b/tests/graph/test_state.py index 563cbd6b29..c725b7aca7 100644 --- a/tests/graph/test_state.py +++ b/tests/graph/test_state.py @@ -40,7 +40,7 @@ async def run(self, ctx: GraphRunContext[MyState]) -> End[str]: return End(f'x={ctx.state.x} y={ctx.state.y}') graph = Graph(nodes=(Foo, Bar)) - assert graph._inferred_types == (MyState, str) # pyright: ignore[reportPrivateUsage] + assert graph.inferred_types == (MyState, str) state = MyState(1, '') sp = FullStatePersistence() result = await graph.run(Foo(), state=state, persistence=sp) From 0ad49fd584d2e12c2a3696377f493d3f354cb305 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 10 Mar 2025 17:23:01 +0000 Subject: [PATCH 19/25] replace human in the loop example --- docs/graph.md | 248 +------------------------------------------------- 1 file changed, 5 insertions(+), 243 deletions(-) diff --git a/docs/graph.md b/docs/graph.md index a38ef017f6..09f385d81e 100644 --- a/docs/graph.md +++ b/docs/graph.md @@ -291,7 +291,7 @@ async def main(): 1. The state of the vending machine is defined as a dataclass with the user's balance and the product they've selected, if any. 2. A dictionary of products mapped to prices. 3. The `InsertCoin` node, [`BaseNode`][pydantic_graph.nodes.BaseNode] is parameterized with `MachineState` as that's the state used in this graph. -4. The `InsertCoin` node prompts the user to insert coins. We keep things simple by just entering a monetary amount as a float. Before you start thinking this is a toy too since it's using [rich's `Prompt.ask`][rich.prompt.PromptBase.ask] within nodes, see [below](#example-qa-with-genai) for how control flow can be managed when nodes require external input. +4. The `InsertCoin` node prompts the user to insert coins. We keep things simple by just entering a monetary amount as a float. Before you start thinking this is a toy too since it's using [rich's `Prompt.ask`][rich.prompt.PromptBase.ask] within nodes, see [below](#example-human-in-the-loop) for how control flow can be managed when nodes require external input. 5. The `CoinsInserted` node; again this is a [`dataclass`][dataclasses.dataclass] with one field `amount`. 6. Update the user's balance with the amount inserted. 7. If the user has already selected a product, go to `Purchase`, otherwise go to `SelectProduct`. @@ -628,7 +628,9 @@ async def run_node() -> bool: # (3)! _(This example is complete, it can be run "as is" with Python 3.10+ — you'll need to add `asyncio.run(main())` to run `main`)_ -### Example: Q&A with GenAI +### Example: Human in the loop. + +As noted above, state persistence allows graphs to be interrupted and resumed. One use case of this is to allow user input to continue. In this example, an AI asks the user a question, the user provides an answer, the AI evaluates the answer and ends if the user got it right or asks another question if they got it wrong. @@ -787,246 +789,6 @@ _(This example is complete, it can be run "as is" with Python 3.10+ — you'll n For a complete example of this graph, see the [question graph example](examples/question-graph.md). -## **Interrupting Graph Execution** - -### Example: Pausing and Resuming with Human Review - -This example shows a simple graph that processes an order. -If the order amount is large, we require human review at a dedicated node and *pause* the workflow until that review occurs. - -We'll simulate persistence in a global dictionary rather than a real database. -We also show how to resume execution once the human has approved the order. - -```python {title="pause_and_resume.py" noqa="I001" py="3.10"} -import asyncio -from dataclasses import dataclass, field -from typing import Literal - -from typing_extensions import TypedDict -from pydantic import TypeAdapter - -from pydantic_graph import ( - BaseNode, - End, - Graph, - GraphRunContext, - HistoryStep, - GraphRunResult, -) - - -@dataclass -class OrderState: - """Order workflow state.""" - - order_id: str - amount: float - human_approved: bool = False # set to True after human review - - -class StoredRun(TypedDict): - """An object representing a mock-serialized run state.""" - - state: OrderState - history: bytes - node: bytes - - -# We'll use a global dictionary to simulate persist/load: -STORED_RUNS: dict[str, StoredRun] = {} - - -@dataclass -class CheckOrder(BaseNode[OrderState]): - """Check if this order needs human review.""" - - kind: Literal['check-order'] = field(default='check-order', init=False) - - async def run( - self, ctx: GraphRunContext[OrderState] - ) -> 'HumanReview | ProcessOrder': - if ctx.state.amount < 1000: - return ProcessOrder() # no human review required - else: - return HumanReview() # human review required - - -@dataclass -class HumanReview(BaseNode[OrderState]): - """Pause graph execution until a human sets `approved=True` in the order state.""" - - kind: Literal['human-review'] = field(default='human-review', init=False) - - async def run( - self, ctx: GraphRunContext[OrderState] - ) -> 'ProcessOrder | HumanReview': - if not ctx.state.human_approved: - # Still not approved: we'll stay on this node, effectively keeping the workflow paused - return self - return ProcessOrder() - - -@dataclass -class ProcessOrder(BaseNode[OrderState, None, str]): - """Final node: process the order.""" - - kind: Literal['process-order'] = field(default='process-order', init=False) - - async def run(self, ctx: GraphRunContext[OrderState]) -> End[str]: - # In a real system, you'd charge payment, update inventory, etc. - return End(f'Order {ctx.state.order_id} processed successfully!') - - -# Build the graph -order_graph = Graph[OrderState, None, str]( - nodes=[CheckOrder, HumanReview, ProcessOrder] -) -GraphNodeType = CheckOrder | HumanReview | ProcessOrder -node_adapter = TypeAdapter[GraphNodeType](GraphNodeType) - - -def persist_run_state( - run_id: str, - state: OrderState, - history: list[HistoryStep[OrderState, str]], - node: GraphNodeType, -) -> None: - """Simulate storing run state in a global dictionary.""" - STORED_RUNS[run_id] = StoredRun( - state=state, - history=order_graph.dump_history(history), - node=node_adapter.dump_json(node), - ) - - -def approve_order(run_id: str) -> None: - """Simulate a human approving an order.""" - stored_run = STORED_RUNS[run_id] - stored_run['state'].human_approved = True - - -def load_run_state( - run_id: str, -) -> tuple[OrderState, list[HistoryStep[OrderState, str]], GraphNodeType]: - """Simulate loading run state from a global dictionary.""" - stored_run = STORED_RUNS[run_id] - state = stored_run['state'] - history = order_graph.load_history(stored_run['history']) - node = node_adapter.validate_json(stored_run['node']) - return state, history, node - - -async def run_until_interrupted( - run_id: str, - state: OrderState, - history: list[HistoryStep[OrderState, str]], - start_node: GraphNodeType, -) -> GraphRunResult[OrderState, str] | tuple[HumanReview, OrderState]: - """Continue the workflow from any point.""" - async with order_graph.iter(start_node, state=state, history=history) as graph_run: - await graph_run.next() # The first node will be yielded before it has been run, so we ensure it runs first - async for node in graph_run: - if isinstance(node, HumanReview): - persist_run_state(run_id, state, history, node) - return node, state # Run is interrupted - - assert graph_run.result is not None # the graph run is complete at this point - return graph_run.result - - -async def begin_run( - run_id: str, amount: int -) -> GraphRunResult[OrderState, str] | tuple[HumanReview, OrderState]: - """Start the workflow. Possibly pause if human review is needed.""" - state = OrderState(order_id=run_id, amount=amount) - history: list[HistoryStep[OrderState, str]] = [] - node = CheckOrder() - return await run_until_interrupted(run_id, state, history, node) - - -async def resume_run( - run_id: str, -) -> GraphRunResult[OrderState, str] | tuple[HumanReview, OrderState]: - """Resume the workflow after human review.""" - state, history, node = load_run_state(run_id) - return await run_until_interrupted(run_id, state, history, node) - - -async def main(): - results = [] - - # Begin a run that will not require human review: - results.append(await begin_run('order-1', 100)) - - # Begin a run that _will_ require human review: - results.append(await begin_run('order-2', 1500)) - - # ... human review happens ... - approve_order('order-2') - - # Resume run after human review: - results.append(await resume_run('order-2')) - - return results - - -if __name__ == '__main__': - print(asyncio.run(main())) - """ - [ - GraphRunResult( - output='Order order-1 processed successfully!', - state=OrderState(order_id='order-1', amount=100, human_approved=False), - ), - ( - HumanReview(kind='human-review'), - OrderState(order_id='order-2', amount=1500, human_approved=True), - ), - GraphRunResult( - output='Order order-2 processed successfully!', - state=OrderState(order_id='order-2', amount=1500, human_approved=True), - ), - ] - """ -``` - -**How it works:** - -1. **`OrderState` and Node Classes** - - We define an `OrderState` dataclass that tracks the order ID, amount, and a `human_approved` flag. - - Three node classes (`CheckOrder`, `HumanReview`, `ProcessOrder`) use `pydantic-graph` generics to model a small state machine: - - `CheckOrder` decides whether we need human review (returns `HumanReview`) or can finalize directly. - - `HumanReview` loops on itself until someone sets `human_approved=True`. - - `ProcessOrder` completes the graph with an [`End`][pydantic_graph.nodes.End] node and a success message. - -2. **Global `STORED_RUNS` for Persistence** - - We simulate storing run state with a dictionary of typed-dict entries (`StoredRun`). - - For each "run," we store `OrderState`, serialized history (via [`graph.dump_history`][pydantic_graph.graph.Graph.dump_history]), and a serialized node. - -3. **`run_until_interrupted`** - - Accepts a starting node, plus the current `state` and `history`. - - Calls [`graph.iter`][pydantic_graph.graph.Graph.iter] to begin or continue the graph. - - If it encounters a `HumanReview` node, it persists the run and returns that node (thus "interrupting" the workflow). - - Otherwise, it continues until the graph ends. - -4. **`begin_run`** - - Creates a fresh `OrderState` (initializing the run) and starts from `CheckOrder`. - - It either completes immediately if no review is required or returns a `HumanReview` node if it needs sign-off. - -5. **`approve_order`** - - Emulates a real "human review" step by flipping `.human_approved` to True in the stored state. - -6. **`resume_run`** - - Loads the previously saved state, history, and node. - - Calls `run_until_interrupted` to continue from exactly where we left off, typically finalizing or pausing again. - -7. **In `main`** - - We run two orders: one small (`order-1`) that finishes immediately, and one large (`order-2`) that pauses. - - We call `approve_order("order-2")` to simulate a human approval, and then `resume_run("order-2")`. - - This finalizes the second order's workflow. - -While this is just a toy example, you can take a similar approach to build a persistent, interruptible workflow that uses `pydantic-graph` to pause execution at any node, store its state, and resume again after external events (like human approval) occur. - ## Dependency Injection As with PydanticAI, `pydantic-graph` supports dependency injection via a generic parameter on [`Graph`][pydantic_graph.graph.Graph] and [`BaseNode`][pydantic_graph.nodes.BaseNode], and the [`GraphRunContext.deps`][pydantic_graph.nodes.GraphRunContext.deps] field. @@ -1119,7 +881,7 @@ Beyond the diagrams shown above, you can also customize mermaid diagrams with th * [`BaseNode.docstring_notes`][pydantic_graph.nodes.BaseNode.docstring_notes] and [`BaseNode.get_note`][pydantic_graph.nodes.BaseNode.get_note] allows you to add notes to nodes * The [`highlighted_nodes`][pydantic_graph.graph.Graph.mermaid_code] parameter allows you to highlight specific node(s) in the diagram -Putting that together, we can edit the last [`ai_q_and_a_graph.py`](#example-qa-with-genai) example to: +Putting that together, we can edit the last [`ai_q_and_a_graph.py`](#example-human-in-the-loop) example to: * add labels to some edges * add a note to the `Ask` node From 27aee1aeaef399e7b7829049b8895852a7b55634 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 12 Mar 2025 10:55:28 +0000 Subject: [PATCH 20/25] Apply suggestions from code review Co-authored-by: David Montague <35119617+dmontagu@users.noreply.github.com> --- docs/graph.md | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/graph.md b/docs/graph.md index 09f385d81e..6fea0fa4a9 100644 --- a/docs/graph.md +++ b/docs/graph.md @@ -561,13 +561,13 @@ async def main(): ## State Persistence -The greatest value of finite state machine (FSM) graphs comes when their execution is interrupted. This can be for a variety of reasons: +One of the biggest benefits of finite state machine (FSM) graphs is how they simplify the handling of interrupted execution. This might happen for a variety of reasons: -- because the logic they encompass must be paused — e.g. the returns workflow for an e-commerce order needs to wait for the item to be posted to the returns center or because execution of the next node needs input from a user so needs to wait for a new http request, -- because their execution takes long enough that the entire graph can't be executed in a single continuous run — e.g. a deep research agent that takes hours to run, -- or, because multiple nodes can be run in parallel on different instances (note: parallel node execution is not yet supported in `pydantic-graph`, see [#704](https://github.com/pydantic/pydantic-ai/issues/704)). +- the state machine logic might fundamentally need to be paused — e.g. the returns workflow for an e-commerce order needs to wait for the item to be posted to the returns center or because execution of the next node needs input from a user so needs to wait for a new http request, +- the execution takes so long that the entire graph can't reliably be executed in a single continuous run — e.g. a deep research agent that might take hours to run, +- you want to run multiple graph nodes in parallel in different processes / hardware instances (note: parallel node execution is not yet supported in `pydantic-graph`, see [#704](https://github.com/pydantic/pydantic-ai/issues/704)). -In all these scenarios, conventional control flow (boolean logic and nest function calls) breaks down and application code becomes spaghetti with the logic required to interrupt and resume execution dominating the code. +Trying to make a conventional control flow (i.e., boolean logic and nested function calls) implementation compatible with these usage scenarios generally results in brittle and over-complicated spaghetti code, with the logic required to interrupt and resume execution dominating the implementation. To allow graph runs to be interrupted and resumed, `pydantic-graph` provides state persistence — a system for snapshotting the state of a graph run before and after each node is run, allowing a graph run to be resumed from any point in the graph. @@ -575,17 +575,17 @@ To allow graph runs to be interrupted and resumed, `pydantic-graph` provides sta - [`SimpleStatePersistence`][pydantic_graph.SimpleStatePersistence] — Simple in memory state persistence that just hold the latest snapshot. If no state persistence implementation is provided when running a graph, this is used by default. - [`FullStatePersistence`][pydantic_graph.FullStatePersistence] — In memory state persistence that hold a list of snapshots. -- [`FileStatePersistence`][pydantic_graph.persistence.file.FileStatePersistence] — File based state persistence that saves snapshots to a JSON file. +- [`FileStatePersistence`][pydantic_graph.persistence.file.FileStatePersistence] — File-based state persistence that saves snapshots to a JSON file. -In production applications, developers should implement their own state persistence by subclassing [`BaseStatePersistence`][pydantic_graph.persistence.BaseStatePersistence] abstract base class. +In production applications, developers should implement their own state persistence by subclassing [`BaseStatePersistence`][pydantic_graph.persistence.BaseStatePersistence] abstract base class, which might persist runs in a relational database like PostgresQL. At a high level the role of `StatePersistence` implementations is to store and retrieve [`NodeSnapshot`][pydantic_graph.persistence.NodeSnapshot] and [`EndSnapshot`][pydantic_graph.persistence.EndSnapshot] objects. -[`graph.next_from_persistence()`][pydantic_graph.graph.Graph.next_from_persistence] may be used to run nodes of a graph based on its state stored in persistence. +[`graph.next_from_persistence()`][pydantic_graph.graph.Graph.next_from_persistence] may be used to run the next node of a graph run based on the state stored in persistence. We can run the `count_down_graph` from [above](#iterating-over-a-graph), using [`graph.next_from_persistence()`][pydantic_graph.graph.Graph.next_from_persistence] and [`FileStatePersistence`][pydantic_graph.persistence.file.FileStatePersistence]. -As you can see in this code, `run_node` requires no external application state apart from state persistence to be run, meaning graphs can easily be executed by distributed execution and queueing systems. +As you can see in this code, `run_node` requires no external application state (apart from state persistence) to be run, meaning graphs can easily be executed by distributed execution and queueing systems. ```python {title="count_down_from_persistence.py" noqa="I001" py="3.10"} from pathlib import Path From 76cae5ab5fa1a1548312fc48486a68e068ee07e0 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 12 Mar 2025 10:56:46 +0000 Subject: [PATCH 21/25] Apply suggestions from code review Co-authored-by: David Montague <35119617+dmontagu@users.noreply.github.com> --- pydantic_graph/pydantic_graph/exceptions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydantic_graph/pydantic_graph/exceptions.py b/pydantic_graph/pydantic_graph/exceptions.py index 9eb47b6f35..01876461a3 100644 --- a/pydantic_graph/pydantic_graph/exceptions.py +++ b/pydantic_graph/pydantic_graph/exceptions.py @@ -27,7 +27,7 @@ def __init__(self, message: str): class GraphNodeStatusError(GraphRuntimeError): - """Error caused by trying to run a node that has status other than `'created'` or `'pending'`.""" + """Error caused by trying to run a node that already has status `'running'`, `'success'`, or `'error'`.""" def __init__(self, actual_status: 'SnapshotStatus'): self.actual_status = actual_status From f51699eddaa8e157b6a3e3cd2387b55b69d954d1 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 12 Mar 2025 10:57:24 +0000 Subject: [PATCH 22/25] tweak docs as suggested --- docs/graph.md | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/docs/graph.md b/docs/graph.md index 6fea0fa4a9..41aeb0a54e 100644 --- a/docs/graph.md +++ b/docs/graph.md @@ -529,14 +529,16 @@ Alternatively, you can drive iteration manually with the [`GraphRun.next`][pydan Below is a contrived example that stops whenever the counter is at 2, ignoring any node runs beyond that: ```python {title="count_down_next.py" noqa="I001" py="3.10"} -from pydantic_graph import End, SimpleStatePersistence +from pydantic_graph import End, FullStatePersistence from count_down import CountDown, CountDownState, count_down_graph async def main(): state = CountDownState(counter=5) - sp = SimpleStatePersistence() - async with count_down_graph.iter(CountDown(), state=state, persistence=sp) as run: + persistence = FullStatePersistence() # (7)! + async with count_down_graph.iter( + CountDown(), state=state, persistence=persistence + ) as run: node = run.next_node # (1)! while not isinstance(node, End): # (2)! print('Node:', node) @@ -550,6 +552,13 @@ async def main(): print(run.result) # (5)! #> None + + for step in persistence.history: # (6)! + print('History Step:', step.state, step.state) + #> History Step: CountDownState(counter=5) CountDownState(counter=5) + #> History Step: CountDownState(counter=4) CountDownState(counter=4) + #> History Step: CountDownState(counter=3) CountDownState(counter=3) + #> History Step: CountDownState(counter=2) CountDownState(counter=2) ``` 1. We start by grabbing the first node that will be run in the agent's graph. @@ -558,6 +567,7 @@ async def main(): 4. At each step, we call `await run.next(node)` to run it and get the next node (or an `End`). 5. Because we did not continue the run until it finished, the `result` is not set. 6. The run's history is still populated with the steps we executed so far. +7. Use [`FullStatePersistence`][pydantic_graph.FullStatePersistence] so we can show the history of the run, see [State Persistence](#state-persistence) below for more information. ## State Persistence From 1cccf75c5557f45f9e9478d3d0dbe260d3ac1754 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 12 Mar 2025 15:53:07 +0000 Subject: [PATCH 23/25] add iter_from_persistence, deprecate next() --- docs/graph.md | 60 ++++----- .../pydantic_ai_examples/question_graph.py | 4 +- pydantic_graph/pydantic_graph/graph.py | 119 ++++++++++++------ .../pydantic_graph/persistence/__init__.py | 7 +- .../pydantic_graph/persistence/file.py | 19 ++- tests/graph/test_file_persistence.py | 26 ++-- tests/graph/test_graph.py | 93 +++++++++----- tests/graph/test_persistence.py | 25 ++-- 8 files changed, 233 insertions(+), 120 deletions(-) diff --git a/docs/graph.md b/docs/graph.md index 41aeb0a54e..d438d3001f 100644 --- a/docs/graph.md +++ b/docs/graph.md @@ -591,9 +591,9 @@ In production applications, developers should implement their own state persiste At a high level the role of `StatePersistence` implementations is to store and retrieve [`NodeSnapshot`][pydantic_graph.persistence.NodeSnapshot] and [`EndSnapshot`][pydantic_graph.persistence.EndSnapshot] objects. -[`graph.next_from_persistence()`][pydantic_graph.graph.Graph.next_from_persistence] may be used to run the next node of a graph run based on the state stored in persistence. +[`graph.iter_from_persistence()`][pydantic_graph.graph.Graph.iter_from_persistence] may be used to run the graph based on the state stored in persistence. -We can run the `count_down_graph` from [above](#iterating-over-a-graph), using [`graph.next_from_persistence()`][pydantic_graph.graph.Graph.next_from_persistence] and [`FileStatePersistence`][pydantic_graph.persistence.file.FileStatePersistence]. +We can run the `count_down_graph` from [above](#iterating-over-a-graph), using [`graph.iter_from_persistence()`][pydantic_graph.graph.Graph.iter_from_persistence] and [`FileStatePersistence`][pydantic_graph.persistence.file.FileStatePersistence]. As you can see in this code, `run_node` requires no external application state (apart from state persistence) to be run, meaning graphs can easily be executed by distributed execution and queueing systems. @@ -607,34 +607,39 @@ from count_down import CountDown, CountDownState, count_down_graph async def main(): - persistence = FileStatePersistence(Path('count_down.json')) # (1)! + run_id = 'run_abc123' + persistence = FileStatePersistence(Path(f'count_down_{run_id}.json')) # (1)! state = CountDownState(counter=5) - await count_down_graph.next( # (2)! + await count_down_graph.initialize( # (2)! CountDown(), state=state, persistence=persistence ) done = False while not done: - done = await run_node() + done = await run_node(run_id) -async def run_node() -> bool: # (3)! - persistence = FileStatePersistence(Path('count_down.json')) - node_or_end = await count_down_graph.next_from_persistence(persistence) # (4)! +async def run_node(run_id: str) -> bool: # (3)! + persistence = FileStatePersistence(Path(f'count_down_{run_id}.json')) + async with count_down_graph.iter_from_persistence(persistence) as run: # (4)! + node_or_end = await run.next() # (5)! + print('Node:', node_or_end) #> Node: CountDown() #> Node: CountDown() #> Node: CountDown() #> Node: CountDown() + #> Node: CountDown() #> Node: End(data=0) - return isinstance(node_or_end, End) # (5)! + return isinstance(node_or_end, End) # (6)! ``` 1. Create a [`FileStatePersistence`][pydantic_graph.persistence.file.FileStatePersistence] to use to start the graph. -2. Call [`graph.next()`][pydantic_graph.graph.Graph.next] to start the graph with the initial state. -3. `run_node` is a pure function that doesn't need access to any other process state to run the next node of the graph. -4. Call [`graph.next_from_persistence()`][pydantic_graph.graph.Graph.next_from_persistence] to run the next node of the graph from the state stored in persistence. This will return either a node or an `End` object. -5. Check if the node is an `End` object, if it is, the graph run is complete. +2. Call [`graph.initialize()`][pydantic_graph.graph.Graph.initialize] to set the initial graph state in the persistence object. +3. `run_node` is a pure function that doesn't need access to any other process state to run the next node of the graph, except the ID of the run. +4. Call [`graph.iter_from_persistence()`][pydantic_graph.graph.Graph.iter_from_persistence] create a [`GraphRun`][pydantic_graph.graph.GraphRun] object that will run the next node of the graph from the state stored in persistence. This will return either a node or an `End` object. +5. [`graph.run()`][pydantic_graph.graph.Graph.run] will return either a [node][pydantic_graph.nodes.BaseNode] or an [`End`][pydantic_graph.nodes.End] object. +5. Check if the node is an [`End`][pydantic_graph.nodes.End] object, if it is, the graph run is complete. _(This example is complete, it can be run "as is" with Python 3.10+ — you'll need to add `asyncio.run(main())` to run `main`)_ @@ -770,20 +775,19 @@ async def main(): state = QuestionState() node = Ask() # (5)! - while True: - node = await question_graph.next( # (6)! - node, persistence=persistence, state=state - ) - if isinstance(node, End): # (7)! - print('END:', node.data) - history = await persistence.load() - print([e.node for e in history]) - break - elif isinstance(node, Answer): # (8)! - print(node.question) - #> What is the capital of France? - break - # otherwise just continue + async with question_graph.iter(node, state=state, persistence=persistence) as run: + while True: + node = await run.next() # (6)! + if isinstance(node, End): # (7)! + print('END:', node.data) + history = await persistence.load() + print([e.node for e in history]) + break + elif isinstance(node, Answer): # (8)! + print(node.question) + #> What is the capital of France? + break + # otherwise just continue ``` 1. Get the user's answer from the command line, if provided. See [question graph example](examples/question-graph.md) for a complete example. @@ -791,7 +795,7 @@ async def main(): 3. Since we're using the [persistence interface][pydantic_graph.persistence.BaseStatePersistence] outside a graph, we need to call [`set_graph_types`][pydantic_graph.persistence.BaseStatePersistence.set_graph_types] to set the graph generic types `StateT` and `RunEndT` for the persistence instance. This is necessary to allow the persistence instance to know how to serialize and deserialize graph nodes. 4. If we're run the graph before, [`retrieve_next`][pydantic_graph.persistence.BaseStatePersistence.retrieve_next] will return a snapshot of the next node to run, here we use `state` from that snapshot, and create a new `Evaluate` node with the answer provided on the command line. 5. If the graph hasn't been run before, we create a new `QuestionState` and start with the `Ask` node. -6. Call [`graph.next()`][pydantic_graph.graph.Graph.next] to run the node. This will return either a node or an `End` object. +6. Call [`GraphRun.next()`][pydantic_graph.graph.GraphRun.next] to run the node. This will return either a node or an `End` object. 7. If the node is an `End` object, the graph run is complete. The `data` field of the `End` object contains the comment returned by the `evaluate_agent` about the correct answer. 8. If the node is an `Answer` object, we print the question and break out of the loop to end the process and wait for user input. diff --git a/examples/pydantic_ai_examples/question_graph.py b/examples/pydantic_ai_examples/question_graph.py index 79ee60b29f..aa63befd8d 100644 --- a/examples/pydantic_ai_examples/question_graph.py +++ b/examples/pydantic_ai_examples/question_graph.py @@ -129,9 +129,9 @@ async def run_as_cli(answer: str | None): node = Ask() # debug(state, node) - with logfire.span('run questions graph'): + async with question_graph.iter(node, state=state, persistence=persistence) as run: while True: - node = await question_graph.next(node, persistence=persistence, state=state) + node = await run.next() if isinstance(node, End): print('END:', node.data) history = await persistence.load() diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 5ec102d911..c077667911 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -11,6 +11,7 @@ import logfire_api import typing_extensions from logfire_api import LogfireSpan +from typing_extensions import deprecated from typing_inspection import typing_objects from . import _utils, exceptions, mermaid @@ -216,8 +217,8 @@ async def iter( state: StateT = None, deps: DepsT = None, persistence: BaseStatePersistence[StateT, T] | None = None, - infer_name: bool = True, span: AbstractContextManager[Any] | None = None, + infer_name: bool = True, ) -> AsyncIterator[GraphRun[StateT, DepsT, T]]: """A contextmanager which can be used to iterate over the graph's nodes as they are executed. @@ -240,11 +241,10 @@ async def iter( deps: The dependencies of the graph. persistence: State persistence interface, defaults to [`SimpleStatePersistence`][pydantic_graph.SimpleStatePersistence] if `None`. - infer_name: Whether to infer the graph name from the calling frame. span: The span to use for the graph run. If not provided, a new span will be created. + infer_name: Whether to infer the graph name from the calling frame. - Yields: - A GraphRun that can be async iterated over to drive the graph to completion. + Returns: A GraphRun that can be async iterated over to drive the graph to completion. """ if infer_name and self.name is None: # f_back because `asynccontextmanager` adds one frame @@ -265,6 +265,82 @@ async def iter( graph=self, start_node=start_node, persistence=persistence, state=state, deps=deps ) + @asynccontextmanager + async def iter_from_persistence( + self: Graph[StateT, DepsT, T], + persistence: BaseStatePersistence[StateT, T], + *, + deps: DepsT = None, + span: AbstractContextManager[Any] | None = None, + infer_name: bool = True, + ) -> AsyncIterator[GraphRun[StateT, DepsT, T]]: + """A contextmanager to iterate over the graph's nodes as they are executed, created from a persistence object. + + This method has similar functionality to [`iter`][pydantic_graph.graph.Graph.iter], + but instead of passing the node to run, it will restore the node and state from state persistence. + + Args: + persistence: The state persistence interface to use. + deps: The dependencies of the graph. + span: The span to use for the graph run. If not provided, a new span will be created. + infer_name: Whether to infer the graph name from the calling frame. + + Returns: A GraphRun that can be async iterated over to drive the graph to completion. + """ + if infer_name and self.name is None: + # f_back because `asynccontextmanager` adds one frame + if frame := inspect.currentframe(): # pragma: no branch + self._infer_name(frame.f_back) + + persistence.set_graph_types(self) + + snapshot = await persistence.retrieve_next() + if snapshot is None: + raise exceptions.GraphRuntimeError('Unable to restore snapshot from state persistence.') + + snapshot.node.set_snapshot_id(snapshot.id) + + if self.auto_instrument and span is None: + span = logfire_api.span('run graph {graph.name}', graph=self) + + with ExitStack() as stack: + if span is not None: + stack.enter_context(span) + yield GraphRun[StateT, DepsT, T]( + graph=self, + start_node=snapshot.node, + persistence=persistence, + state=snapshot.state, + deps=deps, + snapshot_id=snapshot.id, + ) + + async def initialize( + self: Graph[StateT, DepsT, T], + node: BaseNode[StateT, DepsT, T], + persistence: BaseStatePersistence[StateT, T], + *, + state: StateT = None, + infer_name: bool = True, + ) -> None: + """Initialize a new graph run in persistence without running it. + + This is useful if you want to set up a graph run to be run later, e.g. via + [`iter_from_persistence`][pydantic_graph.graph.Graph.iter_from_persistence]. + + Args: + node: The node to run first. + persistence: State persistence interface. + state: The start state of the graph. + infer_name: Whether to infer the graph name from the calling frame. + """ + if infer_name and self.name is None: + self._infer_name(inspect.currentframe()) + + persistence.set_graph_types(self) + await persistence.snapshot_node(state, node) + + @deprecated('`next` is deprecated, use `async with graph.iter(...) as run: run.next()` instead') async def next( self: Graph[StateT, DepsT, T], node: BaseNode[StateT, DepsT, T], @@ -300,34 +376,6 @@ async def next( ) return await run.next(node) - async def next_from_persistence( - self: Graph[StateT, DepsT, T], - persistence: BaseStatePersistence[StateT, T], - *, - deps: DepsT = None, - infer_name: bool = True, - ) -> BaseNode[StateT, DepsT, Any] | End[T]: - """Run the next node in the graph from a snapshot stored in persistence.""" - if infer_name and self.name is None: - self._infer_name(inspect.currentframe()) - persistence.set_graph_types(self) - - snapshot = await persistence.retrieve_next() - if snapshot is None: - raise exceptions.GraphRuntimeError('Unable to restore snapshot from state persistence.') - assert snapshot.id is not None, 'Snapshot ID should be set' - snapshot.node.set_snapshot_id(snapshot.id) - - run = GraphRun[StateT, DepsT, T]( - graph=self, - start_node=snapshot.node, - persistence=persistence, - state=snapshot.state, - deps=deps, - snapshot_id=snapshot.id, - ) - return await run.next() - def mermaid_code( self, *, @@ -689,9 +737,10 @@ async def main(): node_snapshot_id = node.get_snapshot_id() else: node_snapshot_id = node.get_snapshot_id() - if node_snapshot_id != self._snapshot_id: - await self.persistence.snapshot_node_if_new(node_snapshot_id, self.state, node) - self._snapshot_id = node_snapshot_id + + if node_snapshot_id != self._snapshot_id: + await self.persistence.snapshot_node_if_new(node_snapshot_id, self.state, node) + self._snapshot_id = node_snapshot_id if not isinstance(node, BaseNode): # While technically this is not compatible with the documented method signature, it's an easy mistake to diff --git a/pydantic_graph/pydantic_graph/persistence/__init__.py b/pydantic_graph/pydantic_graph/persistence/__init__.py index 91c632b789..ba6d778eb9 100644 --- a/pydantic_graph/pydantic_graph/persistence/__init__.py +++ b/pydantic_graph/pydantic_graph/persistence/__init__.py @@ -96,7 +96,10 @@ def node(self) -> End[RunEndT]: class BaseStatePersistence(ABC, Generic[StateT, RunEndT]): - """Abstract base class for storing the state of a graph.""" + """Abstract base class for storing the state of a graph run. + + Each instance of a `BaseStatePersistence` subclass should be used for a single graph run. + """ @abstractmethod async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> None: @@ -158,7 +161,7 @@ def record_run(self, snapshot_id: str) -> AbstractAsyncContextManager[None]: async def retrieve_next(self) -> NodeSnapshot[StateT, RunEndT] | None: """Retrieve a node snapshot with status `'created`' and set its status to `'pending'`. - This is used by [`Graph.next_from_persistence`][pydantic_graph.graph.Graph.next_from_persistence] + This is used by [`Graph.iter_from_persistence`][pydantic_graph.graph.Graph.iter_from_persistence] to get the next node to run. Returns: The snapshot, or `None` if no snapshot with status `'created`' exists. diff --git a/pydantic_graph/pydantic_graph/persistence/file.py b/pydantic_graph/pydantic_graph/persistence/file.py index b41fb86cc3..64cfc11899 100644 --- a/pydantic_graph/pydantic_graph/persistence/file.py +++ b/pydantic_graph/pydantic_graph/persistence/file.py @@ -28,10 +28,25 @@ @dataclass class FileStatePersistence(BaseStatePersistence[StateT, RunEndT]): - """File based state persistence that hold a list of snapshots in a JSON file.""" + """File based state persistence that hold graph run state in a JSON file.""" json_file: Path - """Path to the JSON file where the snapshots are stored.""" + """Path to the JSON file where the snapshots are stored. + + You should use a different file for each graph run, but a single file should be reused for multiple + steps of the same run. + + For example if you have a run ID of the form `run_123abc`, you might create a `FileStatePersistence` thus: + + ```py + from pathlib import Path + + from pydantic_graph import FullStatePersistence + + run_id = 'run_123abc' + persistence = FullStatePersistence(Path('runs') / f'{run_id}.json') + ``` + """ _snapshots_type_adapter: pydantic.TypeAdapter[list[Snapshot[StateT, RunEndT]]] | None = field( default=None, init=False, repr=False ) diff --git a/tests/graph/test_file_persistence.py b/tests/graph/test_file_persistence.py index ec5dc30646..6472baab89 100644 --- a/tests/graph/test_file_persistence.py +++ b/tests/graph/test_file_persistence.py @@ -95,18 +95,20 @@ async def test_next_from_persistence(tmp_path: Path, mock_snapshot_id: object): p = tmp_path / 'test_graph.json' persistence = FileStatePersistence(p) - node = await my_graph.next(Float2String(3.14), persistence=persistence) - assert node == snapshot(String2Length(input_data='3.14')) - assert node.get_snapshot_id() == snapshot('String2Length:2') - assert my_graph.name == 'my_graph' - - node = await my_graph.next_from_persistence(persistence) - assert node == snapshot(Double(input_data=4)) - assert node.get_snapshot_id() == snapshot('Double:3') - - node = await my_graph.next_from_persistence(persistence) - assert node == snapshot(End(data=8)) - assert node.get_snapshot_id() == snapshot('end:4') + async with my_graph.iter(Float2String(3.14), persistence=persistence) as run: + node = await run.next() + assert node == snapshot(String2Length(input_data='3.14')) + assert node.get_snapshot_id() == snapshot('String2Length:2') + assert my_graph.name == 'my_graph' + + async with my_graph.iter_from_persistence(persistence) as run: + node = await run.next() + assert node == snapshot(Double(input_data=4)) + assert node.get_snapshot_id() == snapshot('Double:3') + + node = await run.next() + assert node == snapshot(End(data=8)) + assert node.get_snapshot_id() == snapshot('end:4') assert await persistence.load() == snapshot( [ diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py index cf07b170d4..3631a164ab 100644 --- a/tests/graph/test_graph.py +++ b/tests/graph/test_graph.py @@ -315,7 +315,7 @@ async def test_iter(): assert node_reprs == snapshot(["String2Length(input_data='3.14')", 'Double(input_data=4)', 'End(data=8)']) -async def test_next(mock_snapshot_id: object): +async def test_iter_next(mock_snapshot_id: object): @dataclass class Foo(BaseNode): async def run(self, ctx: GraphRunContext) -> Bar: @@ -329,26 +329,27 @@ async def run(self, ctx: GraphRunContext) -> Foo: g = Graph(nodes=(Foo, Bar)) assert g.name is None sp = FullStatePersistence() - n = await g.next(Foo(), persistence=sp) - assert n == Bar() - assert g.name == 'g' - assert sp.history == snapshot( - [ - NodeSnapshot( - state=None, - node=Foo(), - start_ts=IsNow(tz=timezone.utc), - duration=IsFloat(), - status='success', - id='Foo:1', - ), - NodeSnapshot(state=None, node=Bar(), id='Bar:2'), - ] - ) - - assert isinstance(n, Bar) - n2 = await g.next(n, persistence=sp) - assert n2 == Foo() + async with g.iter(Foo(), persistence=sp) as run: + assert g.name == 'g' + n = await run.next() + assert n == Bar() + assert sp.history == snapshot( + [ + NodeSnapshot( + state=None, + node=Foo(), + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(), + status='success', + id='Foo:1', + ), + NodeSnapshot(state=None, node=Bar(), id='Bar:2'), + ] + ) + + assert isinstance(n, Bar) + n2 = await run.next() + assert n2 == Foo() assert sp.history == snapshot( [ @@ -373,7 +374,7 @@ async def run(self, ctx: GraphRunContext) -> Foo: ) -async def test_next_error(mock_snapshot_id: object): +async def test_iter_next_error(mock_snapshot_id: object): @dataclass class Foo(BaseNode): async def run(self, ctx: GraphRunContext) -> Bar: @@ -386,15 +387,49 @@ async def run(self, ctx: GraphRunContext) -> End[None]: g = Graph(nodes=(Foo, Bar)) sp = SimpleStatePersistence() - n = await g.next(Foo(), sp) - assert n == snapshot(Bar()) + async with g.iter(Foo(), persistence=sp) as run: + n = await run.next() + assert n == snapshot(Bar()) + + assert isinstance(n, BaseNode) + n = await run.next() + assert n == snapshot(End(None)) + + with pytest.raises(TypeError, match=r'`next` must be called with a `BaseNode` instance, got End\(data=None\).'): + await run.next() + + +async def test_next(mock_snapshot_id: object): + @dataclass + class Foo(BaseNode): + async def run(self, ctx: GraphRunContext) -> Bar: + return Bar() - assert isinstance(n, BaseNode) - n = await g.next(n, sp) - assert n == snapshot(End(None)) + @dataclass + class Bar(BaseNode): + async def run(self, ctx: GraphRunContext) -> Foo: + return Foo() - with pytest.raises(TypeError, match=r'`next` must be called with a `BaseNode` instance, got End\(data=None\).'): - await g.next(n, sp) # pyright: ignore[reportArgumentType] + g = Graph(nodes=(Foo, Bar)) + assert g.name is None + sp = FullStatePersistence() + with pytest.warns(DeprecationWarning, match='`next` is deprecated, use `async with graph.iter(...)'): + n = await g.next(Foo(), persistence=sp) # pyright: ignore[reportDeprecated] + assert n == Bar() + assert g.name == 'g' + assert sp.history == snapshot( + [ + NodeSnapshot( + state=None, + node=Foo(), + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(), + status='success', + id='Foo:1', + ), + NodeSnapshot(state=None, node=Bar(), id='Bar:2'), + ] + ) async def test_deps(mock_snapshot_id: object): diff --git a/tests/graph/test_persistence.py b/tests/graph/test_persistence.py index 31f52c80af..389722605e 100644 --- a/tests/graph/test_persistence.py +++ b/tests/graph/test_persistence.py @@ -285,12 +285,14 @@ async def run(self, ctx: GraphRunContext) -> End[int]: sp = FullStatePersistence() node = Foo() - end = await graph.next(node, sp) - assert end == snapshot(End(123)) + async with graph.iter(node, persistence=sp) as run: + end = await run.next() + assert end == snapshot(End(123)) msg = "Incorrect snapshot status 'success', must be 'created' or 'pending'." with pytest.raises(GraphNodeStatusError, match=msg): - await graph.next(node, sp) + async with graph.iter(node, persistence=sp) as run: + await run.next() @pytest.mark.parametrize('persistence_cls', [SimpleStatePersistence, FullStatePersistence]) @@ -307,21 +309,24 @@ async def run(self, ctx: GraphRunContext) -> End[int]: g1 = Graph(nodes=[Foo, Spam]) - sp = persistence_cls() + persistence = persistence_cls() node = Foo() assert g1.name is None - node = await g1.next(node, sp) + await g1.initialize(node, persistence) assert g1.name == 'g1' - assert node == Spam() - end = await g1.next_from_persistence(sp) - assert end == End(123) + async with g1.iter_from_persistence(persistence) as run: + node = await run.next() + assert node == Spam() + end = await run.next() + assert end == End(123) g2 = Graph(nodes=[Foo, Spam]) - sp = persistence_cls() + persistence2 = persistence_cls() assert g2.name is None with pytest.raises(GraphRuntimeError, match='Unable to restore snapshot from state persistence.'): - await g2.next_from_persistence(sp) + async with g2.iter_from_persistence(persistence2): + pass assert g2.name == 'g2' From 2186e47eb306ba8a20f80943fe2429fb316c6638 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Thu, 13 Mar 2025 13:28:36 +0000 Subject: [PATCH 24/25] improve documentation --- docs/graph.md | 11 +-- .../pydantic_ai_examples/question_graph.py | 4 +- pydantic_graph/pydantic_graph/graph.py | 2 +- .../pydantic_graph/persistence/__init__.py | 72 +++++++++++++------ .../pydantic_graph/persistence/file.py | 14 ++-- .../pydantic_graph/persistence/in_mem.py | 10 +-- tests/graph/test_file_persistence.py | 6 +- tests/graph/test_persistence.py | 2 +- 8 files changed, 76 insertions(+), 45 deletions(-) diff --git a/docs/graph.md b/docs/graph.md index d438d3001f..06570e8a25 100644 --- a/docs/graph.md +++ b/docs/graph.md @@ -767,7 +767,7 @@ async def main(): persistence = FileStatePersistence(Path('question_graph.json')) # (2)! persistence.set_graph_types(question_graph) # (3)! - if snapshot := await persistence.retrieve_next(): # (4)! + if snapshot := await persistence.load_next(): # (4)! state = snapshot.state assert answer is not None node = Evaluate(answer) @@ -780,10 +780,10 @@ async def main(): node = await run.next() # (6)! if isinstance(node, End): # (7)! print('END:', node.data) - history = await persistence.load() + history = await persistence.load_all() # (8)! print([e.node for e in history]) break - elif isinstance(node, Answer): # (8)! + elif isinstance(node, Answer): # (9)! print(node.question) #> What is the capital of France? break @@ -793,11 +793,12 @@ async def main(): 1. Get the user's answer from the command line, if provided. See [question graph example](examples/question-graph.md) for a complete example. 2. Create a state persistence instance the `'question_graph.json'` file may or may not already exist. 3. Since we're using the [persistence interface][pydantic_graph.persistence.BaseStatePersistence] outside a graph, we need to call [`set_graph_types`][pydantic_graph.persistence.BaseStatePersistence.set_graph_types] to set the graph generic types `StateT` and `RunEndT` for the persistence instance. This is necessary to allow the persistence instance to know how to serialize and deserialize graph nodes. -4. If we're run the graph before, [`retrieve_next`][pydantic_graph.persistence.BaseStatePersistence.retrieve_next] will return a snapshot of the next node to run, here we use `state` from that snapshot, and create a new `Evaluate` node with the answer provided on the command line. +4. If we're run the graph before, [`load_next`][pydantic_graph.persistence.BaseStatePersistence.load_next] will return a snapshot of the next node to run, here we use `state` from that snapshot, and create a new `Evaluate` node with the answer provided on the command line. 5. If the graph hasn't been run before, we create a new `QuestionState` and start with the `Ask` node. 6. Call [`GraphRun.next()`][pydantic_graph.graph.GraphRun.next] to run the node. This will return either a node or an `End` object. 7. If the node is an `End` object, the graph run is complete. The `data` field of the `End` object contains the comment returned by the `evaluate_agent` about the correct answer. -8. If the node is an `Answer` object, we print the question and break out of the loop to end the process and wait for user input. +8. To demonstrate the state persistence, we call [`load_all`][pydantic_graph.persistence.BaseStatePersistence.load_all] to get all the snapshots from the persistence instance. This will return a list of [`Snapshot`][pydantic_graph.persistence.Snapshot] objects. +9. If the node is an `Answer` object, we print the question and break out of the loop to end the process and wait for user input. _(This example is complete, it can be run "as is" with Python 3.10+ — you'll need to add `asyncio.run(main(answer))` to run `main`)_ diff --git a/examples/pydantic_ai_examples/question_graph.py b/examples/pydantic_ai_examples/question_graph.py index aa63befd8d..f3da0008c4 100644 --- a/examples/pydantic_ai_examples/question_graph.py +++ b/examples/pydantic_ai_examples/question_graph.py @@ -118,7 +118,7 @@ async def run_as_cli(answer: str | None): persistence = FileStatePersistence(Path('question_graph.json')) persistence.set_graph_types(question_graph) - if snapshot := await persistence.retrieve_next(): + if snapshot := await persistence.load_next(): state = snapshot.state assert answer is not None, ( 'answer required, usage "uv run -m pydantic_ai_examples.question_graph cli "' @@ -134,7 +134,7 @@ async def run_as_cli(answer: str | None): node = await run.next() if isinstance(node, End): print('END:', node.data) - history = await persistence.load() + history = await persistence.load_all() print('history:', '\n'.join(str(e.node) for e in history), sep='\n') print('Finished!') break diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index c077667911..54beb38ce6 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -294,7 +294,7 @@ async def iter_from_persistence( persistence.set_graph_types(self) - snapshot = await persistence.retrieve_next() + snapshot = await persistence.load_next() if snapshot is None: raise exceptions.GraphRuntimeError('Unable to restore snapshot from state persistence.') diff --git a/pydantic_graph/pydantic_graph/persistence/__init__.py b/pydantic_graph/pydantic_graph/persistence/__init__.py index ba6d778eb9..16cd42d14c 100644 --- a/pydantic_graph/pydantic_graph/persistence/__init__.py +++ b/pydantic_graph/pydantic_graph/persistence/__init__.py @@ -15,7 +15,15 @@ if TYPE_CHECKING: from .. import Graph -__all__ = 'StateT', 'NodeSnapshot', 'EndSnapshot', 'Snapshot', 'BaseStatePersistence', 'SnapshotStatus' +__all__ = ( + 'StateT', + 'NodeSnapshot', + 'EndSnapshot', + 'Snapshot', + 'BaseStatePersistence', + 'SnapshotStatus', + 'build_snapshot_list_type_adapter', +) StateT = TypeVar('StateT', default=Any) RunEndT = TypeVar('RunEndT', default=Any) @@ -26,7 +34,7 @@ - `'created'`: The snapshot has been created but not yet run. - `'pending'`: The snapshot has been retrieved with - [`retrieve_next`][pydantic_graph.persistence.BaseStatePersistence.retrieve_next] but not yet run. + [`load_next`][pydantic_graph.persistence.BaseStatePersistence.load_next] but not yet run. - `'running'`: The snapshot is currently running. - `'success'`: The snapshot has been run successfully. - `'error'`: The snapshot has been run but an error occurred. @@ -105,6 +113,8 @@ class BaseStatePersistence(ABC, Generic[StateT, RunEndT]): async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> None: """Snapshot the state of a graph, when the next step is to run a node. + This method should add a [`NodeSnapshot`][pydantic_graph.persistence.NodeSnapshot] to persistence. + Args: state: The state of the graph. next_node: The next node to run. @@ -115,7 +125,10 @@ async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, Ru async def snapshot_node_if_new( self, snapshot_id: str, state: StateT, next_node: BaseNode[StateT, Any, RunEndT] ) -> None: - """Snapshot the state of a graph, if the snapshot ID doesn't already exist in persistence. + """Snapshot the state of a graph if the snapshot ID doesn't already exist in persistence. + + This method will generally call [`snapshot_node`][pydantic_graph.persistence.BaseStatePersistence.snapshot_node] + but should do so in an atomic way. Args: snapshot_id: The ID of the snapshot to check. @@ -126,7 +139,9 @@ async def snapshot_node_if_new( @abstractmethod async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None: - """Snapshot the state of a graph after a node has run, when the graph has ended. + """Snapshot the state of a graph when the graph has ended. + + This method should add an [`EndSnapshot`][pydantic_graph.persistence.EndSnapshot] to persistence. Args: state: The state of the graph. @@ -158,7 +173,7 @@ def record_run(self, snapshot_id: str) -> AbstractAsyncContextManager[None]: raise NotImplementedError @abstractmethod - async def retrieve_next(self) -> NodeSnapshot[StateT, RunEndT] | None: + async def load_next(self) -> NodeSnapshot[StateT, RunEndT] | None: """Retrieve a node snapshot with status `'created`' and set its status to `'pending'`. This is used by [`Graph.iter_from_persistence`][pydantic_graph.graph.Graph.iter_from_persistence] @@ -169,40 +184,55 @@ async def retrieve_next(self) -> NodeSnapshot[StateT, RunEndT] | None: raise NotImplementedError @abstractmethod - async def load(self) -> list[Snapshot[StateT, RunEndT]]: + async def load_all(self) -> list[Snapshot[StateT, RunEndT]]: """Load the entire history of snapshots. + `load_all` is not used by pydantic-graph itself, instead it's provided to make it convenient to + get all [snapshots][pydantic_graph.persistence.Snapshot] from persistence. + Returns: The list of snapshots. """ raise NotImplementedError - def set_types(self, state_type: type[StateT], run_end_type: type[RunEndT]) -> None: - """Set the types of the state and run end. - - This can be used to create [type adapters][pydantic.TypeAdapter] for serializing and deserializing - snapshots. + def set_graph_types(self, graph: Graph[StateT, Any, RunEndT]) -> None: + """Set the types of the state and run end from a graph. - Args: - state_type: The state type. - run_end_type: The run end type. + You generally won't need to customise this method, instead implement + [`set_types`][pydantic_graph.persistence.BaseStatePersistence.set_types] and + [`should_set_types`][pydantic_graph.persistence.BaseStatePersistence.should_set_types]. """ - pass + if self.should_set_types(): + with _utils.set_nodes_type_context(graph.get_nodes()): + self.set_types(*graph.inferred_types) - def _should_set_types(self) -> bool: + def should_set_types(self) -> bool: """Whether types need to be set. Implementations should override this method to return `True` when types have not been set if they are needed. """ return False - def set_graph_types(self, graph: Graph[StateT, Any, RunEndT]) -> None: - """Set the types of the state and run end from a graph.""" - if self._should_set_types(): - with _utils.set_nodes_type_context(graph.get_nodes()): - self.set_types(*graph.inferred_types) + def set_types(self, state_type: type[StateT], run_end_type: type[RunEndT]) -> None: + """Set the types of the state and run end. + + This can be used to create [type adapters][pydantic.TypeAdapter] for serializing and deserializing snapshots, + e.g. with [`build_snapshot_list_type_adapter`][pydantic_graph.persistence.build_snapshot_list_type_adapter]. + + Args: + state_type: The state type. + run_end_type: The run end type. + """ + pass def build_snapshot_list_type_adapter( state_t: type[StateT], run_end_t: type[RunEndT] ) -> pydantic.TypeAdapter[list[Snapshot[StateT, RunEndT]]]: + """Build a type adapter for a list of snapshots. + + This method should be called from within + [`set_types`][pydantic_graph.persistence.BaseStatePersistence.set_types] + where context variables will be set such that Pydantic can create a schema for + [`NodeSnapshot.node`][pydantic_graph.persistence.NodeSnapshot.node]. + """ return pydantic.TypeAdapter(list[Annotated[Snapshot[state_t, run_end_t], pydantic.Discriminator('kind')]]) diff --git a/pydantic_graph/pydantic_graph/persistence/file.py b/pydantic_graph/pydantic_graph/persistence/file.py index 64cfc11899..51d001c5bc 100644 --- a/pydantic_graph/pydantic_graph/persistence/file.py +++ b/pydantic_graph/pydantic_graph/persistence/file.py @@ -58,7 +58,7 @@ async def snapshot_node_if_new( self, snapshot_id: str, state: StateT, next_node: BaseNode[StateT, Any, RunEndT] ) -> None: async with self._lock(): - snapshots = await self.load() + snapshots = await self.load_all() if not any(s.id == snapshot_id for s in snapshots): await self._append_save(NodeSnapshot(state=state, node=next_node), lock=False) @@ -68,7 +68,7 @@ async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None: @asynccontextmanager async def record_run(self, snapshot_id: str) -> AsyncIterator[None]: async with self._lock(): - snapshots = await self.load() + snapshots = await self.load_all() try: snapshot = next(s for s in snapshots if s.id == snapshot_id) except StopIteration as e: @@ -93,22 +93,22 @@ async def record_run(self, snapshot_id: str) -> AsyncIterator[None]: async with self._lock(): await _graph_utils.run_in_executor(self._after_run_sync, snapshot_id, snapshot.duration, 'success') - async def retrieve_next(self) -> NodeSnapshot[StateT, RunEndT] | None: + async def load_next(self) -> NodeSnapshot[StateT, RunEndT] | None: async with self._lock(): - snapshots = await self.load() + snapshots = await self.load_all() if snapshot := next((s for s in snapshots if isinstance(s, NodeSnapshot) and s.status == 'created'), None): snapshot.status = 'pending' await self._save(snapshots) return snapshot - def set_types(self, state_type: type[StateT], run_end_type: type[RunEndT]) -> None: + def _set_types(self, state_type: type[StateT], run_end_type: type[RunEndT]) -> None: self._snapshots_type_adapter = build_snapshot_list_type_adapter(state_type, run_end_type) def _should_set_types(self) -> bool: """Whether types need to be set.""" return self._snapshots_type_adapter is None - async def load(self) -> list[Snapshot[StateT, RunEndT]]: + async def load_all(self) -> list[Snapshot[StateT, RunEndT]]: return await _graph_utils.run_in_executor(self._load_sync) def _load_sync(self) -> list[Snapshot[StateT, RunEndT]]: @@ -140,7 +140,7 @@ async def _append_save(self, snapshot: Snapshot[StateT, RunEndT], *, lock: bool async with AsyncExitStack() as stack: if lock: await stack.enter_async_context(self._lock()) - snapshots = await self.load() + snapshots = await self.load_all() snapshots.append(snapshot) await self._save(snapshots) diff --git a/pydantic_graph/pydantic_graph/persistence/in_mem.py b/pydantic_graph/pydantic_graph/persistence/in_mem.py index eb275a5261..980594c633 100644 --- a/pydantic_graph/pydantic_graph/persistence/in_mem.py +++ b/pydantic_graph/pydantic_graph/persistence/in_mem.py @@ -73,12 +73,12 @@ async def record_run(self, snapshot_id: str) -> AsyncIterator[None]: self.last_snapshot.duration = perf_counter() - start self.last_snapshot.status = 'success' - async def retrieve_next(self) -> NodeSnapshot[StateT, RunEndT] | None: + async def load_next(self) -> NodeSnapshot[StateT, RunEndT] | None: if isinstance(self.last_snapshot, NodeSnapshot) and self.last_snapshot.status == 'created': self.last_snapshot.status = 'pending' return self.last_snapshot - async def load(self) -> list[Snapshot[StateT, RunEndT]]: + async def load_all(self) -> list[Snapshot[StateT, RunEndT]]: raise NotImplementedError('load is not supported for SimpleStatePersistence') @@ -140,15 +140,15 @@ async def record_run(self, snapshot_id: str) -> AsyncIterator[None]: snapshot.duration = perf_counter() - start snapshot.status = 'success' - async def retrieve_next(self) -> NodeSnapshot[StateT, RunEndT] | None: + async def load_next(self) -> NodeSnapshot[StateT, RunEndT] | None: if snapshot := next((s for s in self.history if isinstance(s, NodeSnapshot) and s.status == 'created'), None): snapshot.status = 'pending' return snapshot - async def load(self) -> list[Snapshot[StateT, RunEndT]]: + async def load_all(self) -> list[Snapshot[StateT, RunEndT]]: return self.history - def set_types(self, state_type: type[StateT], run_end_type: type[RunEndT]) -> None: + def _set_types(self, state_type: type[StateT], run_end_type: type[RunEndT]) -> None: self._snapshots_type_adapter = build_snapshot_list_type_adapter(state_type, run_end_type) def _should_set_types(self) -> bool: diff --git a/tests/graph/test_file_persistence.py b/tests/graph/test_file_persistence.py index 6472baab89..1aa867acbe 100644 --- a/tests/graph/test_file_persistence.py +++ b/tests/graph/test_file_persistence.py @@ -59,7 +59,7 @@ async def test_run(tmp_path: Path, mock_snapshot_id: object): # len('3.14') * 2 == 8 assert result.output == 8 assert my_graph.name == 'my_graph' - assert await persistence.load() == snapshot( + assert await persistence.load_all() == snapshot( [ NodeSnapshot( state=None, @@ -110,7 +110,7 @@ async def test_next_from_persistence(tmp_path: Path, mock_snapshot_id: object): assert node == snapshot(End(data=8)) assert node.get_snapshot_id() == snapshot('end:4') - assert await persistence.load() == snapshot( + assert await persistence.load_all() == snapshot( [ NodeSnapshot( state=None, @@ -158,7 +158,7 @@ async def run(self, ctx: GraphRunContext) -> End[None]: with pytest.raises(RuntimeError, match='test error'): await g.run(Foo(), persistence=persistence) - assert await persistence.load() == snapshot( + assert await persistence.load_all() == snapshot( [ NodeSnapshot( state=None, diff --git a/tests/graph/test_persistence.py b/tests/graph/test_persistence.py index 389722605e..4f1a7579c4 100644 --- a/tests/graph/test_persistence.py +++ b/tests/graph/test_persistence.py @@ -61,7 +61,7 @@ async def test_dump_load_state(graph: Graph[MyState, None, int], mock_snapshot_i result = await graph.run(Foo(), state=MyState(1, ''), persistence=sp) assert result.output == snapshot(4) assert result.state == snapshot(MyState(x=2, y='y')) - assert await sp.load() == snapshot( + assert await sp.load_all() == snapshot( [ NodeSnapshot( state=MyState(x=1, y=''), From 894a590f0fa365663b872e51b26920feba5f1659 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Thu, 13 Mar 2025 13:33:15 +0000 Subject: [PATCH 25/25] fix tests --- pydantic_graph/pydantic_graph/persistence/file.py | 8 ++++---- pydantic_graph/pydantic_graph/persistence/in_mem.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pydantic_graph/pydantic_graph/persistence/file.py b/pydantic_graph/pydantic_graph/persistence/file.py index 51d001c5bc..8ecf1ac116 100644 --- a/pydantic_graph/pydantic_graph/persistence/file.py +++ b/pydantic_graph/pydantic_graph/persistence/file.py @@ -101,13 +101,13 @@ async def load_next(self) -> NodeSnapshot[StateT, RunEndT] | None: await self._save(snapshots) return snapshot - def _set_types(self, state_type: type[StateT], run_end_type: type[RunEndT]) -> None: - self._snapshots_type_adapter = build_snapshot_list_type_adapter(state_type, run_end_type) - - def _should_set_types(self) -> bool: + def should_set_types(self) -> bool: """Whether types need to be set.""" return self._snapshots_type_adapter is None + def set_types(self, state_type: type[StateT], run_end_type: type[RunEndT]) -> None: + self._snapshots_type_adapter = build_snapshot_list_type_adapter(state_type, run_end_type) + async def load_all(self) -> list[Snapshot[StateT, RunEndT]]: return await _graph_utils.run_in_executor(self._load_sync) diff --git a/pydantic_graph/pydantic_graph/persistence/in_mem.py b/pydantic_graph/pydantic_graph/persistence/in_mem.py index 980594c633..9118e087d2 100644 --- a/pydantic_graph/pydantic_graph/persistence/in_mem.py +++ b/pydantic_graph/pydantic_graph/persistence/in_mem.py @@ -148,12 +148,12 @@ async def load_next(self) -> NodeSnapshot[StateT, RunEndT] | None: async def load_all(self) -> list[Snapshot[StateT, RunEndT]]: return self.history - def _set_types(self, state_type: type[StateT], run_end_type: type[RunEndT]) -> None: - self._snapshots_type_adapter = build_snapshot_list_type_adapter(state_type, run_end_type) - - def _should_set_types(self) -> bool: + def should_set_types(self) -> bool: return self._snapshots_type_adapter is None + def set_types(self, state_type: type[StateT], run_end_type: type[RunEndT]) -> None: + self._snapshots_type_adapter = build_snapshot_list_type_adapter(state_type, run_end_type) + def dump_json(self, *, indent: int | None = None) -> bytes: """Dump the history to JSON bytes.""" assert self._snapshots_type_adapter is not None, 'type adapter must be set to use `dump_json`'