From dab200d63b32c937216730e1dd46b3243a4555d2 Mon Sep 17 00:00:00 2001
From: Samuel Colvin
Date: Sun, 22 Dec 2024 16:38:31 +0000
Subject: [PATCH 01/57] start pydantic-ai-graph

---
 pydantic_ai_graph/README.md | 16 +++
 .../pydantic_ai_graph/__init__.py | 0
 pydantic_ai_graph/pydantic_ai_graph/_utils.py | 37 +++++
 pydantic_ai_graph/pydantic_ai_graph/graph.py | 128 ++++++++++++++++++
 pydantic_ai_graph/pydantic_ai_graph/nodes.py | 116 ++++++++++++++++
 pydantic_ai_graph/pydantic_ai_graph/state.py | 17 +++
 pydantic_ai_graph/pyproject.toml | 43 ++++++
 pyproject.toml | 7 +-
 uv.lock | 18 +++
 9 files changed, 379 insertions(+), 3 deletions(-)
 create mode 100644 pydantic_ai_graph/README.md
 create mode 100644 pydantic_ai_graph/pydantic_ai_graph/__init__.py
 create mode 100644 pydantic_ai_graph/pydantic_ai_graph/_utils.py
 create mode 100644 pydantic_ai_graph/pydantic_ai_graph/graph.py
 create mode 100644 pydantic_ai_graph/pydantic_ai_graph/nodes.py
 create mode 100644 pydantic_ai_graph/pydantic_ai_graph/state.py
 create mode 100644 pydantic_ai_graph/pyproject.toml

diff --git a/pydantic_ai_graph/README.md b/pydantic_ai_graph/README.md
new file mode 100644
index 0000000000..2271f779cb
--- /dev/null
+++ b/pydantic_ai_graph/README.md
@@ -0,0 +1,16 @@
+# PydanticAI Graph
+
+[![CI](https://github.com/pydantic/pydantic-ai/actions/workflows/ci.yml/badge.svg?event=push)](https://github.com/pydantic/pydantic-ai/actions/workflows/ci.yml?query=branch%3Amain)
+[![Coverage](https://coverage-badge.samuelcolvin.workers.dev/pydantic/pydantic-ai.svg)](https://coverage-badge.samuelcolvin.workers.dev/redirect/pydantic/pydantic-ai)
+[![PyPI](https://img.shields.io/pypi/v/pydantic-ai-graph.svg)](https://pypi.python.org/pypi/pydantic-ai-graph)
+[![versions](https://img.shields.io/pypi/pyversions/pydantic-ai-graph.svg)](https://github.com/pydantic/pydantic-ai)
+[![license](https://img.shields.io/github/license/pydantic/pydantic-ai-graph.svg?v)](https://github.com/pydantic/pydantic-ai/blob/main/LICENSE)
+
+Graph and State Machine library.
+
+This library is developed as part of [PydanticAI](https://ai.pydantic.dev); however, it has no dependency
+on `pydantic-ai` or related packages and can be considered a pure graph library.
+
+As with PydanticAI, this library prioritizes type safety and the use of common Python syntax over esoteric, domain-specific syntax.
+
+`pydantic-ai-graph` allows you to define graphs using simple Python syntax. In particular, edges are defined using the return type hint of nodes.
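To make that last point concrete, here is a minimal sketch of a graph defined with the API as it stands at this point in the series. It mirrors `tests/test_graph.py` (added in PATCH 03) and assumes the package re-exports introduced there; later patches reorder the generic parameters and change `Graph.run` to also return a run history, so treat this as illustrative of the early API rather than the final one. The return annotation of each node's `run` method is what defines the outgoing edges, and `End[...]` marks where the graph may finish.

```python
from __future__ import annotations as _annotations

import asyncio
from typing import Union

from pydantic_ai_graph import BaseNode, End, Graph, GraphContext


class Float2String(BaseNode[float]):
    # The return annotation is the edge: this node can only hand off to String2Length.
    async def run(self, ctx: GraphContext) -> String2Length:
        return String2Length(str(self.input_data))


class String2Length(BaseNode[str]):
    async def run(self, ctx: GraphContext) -> Double:
        return Double(len(self.input_data))


class Double(BaseNode[int, int]):
    # A union return type declares two edges; End[int] means this node may finish
    # the graph with an `int` output. `Union` keeps the hint resolvable on Python 3.9,
    # as PATCH 08 does in the test suite.
    async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]:
        if self.input_data == 7:
            return String2Length('x' * 21)
        return End(self.input_data * 2)


g = Graph[float, int](Float2String, String2Length, Double)
print(g.mermaid_code())          # mermaid edges are derived from the return annotations above
print(asyncio.run(g.run(3.14)))  # (8, None) with the early return shape: len('3.14') == 4, doubled
```

Because the edges live in ordinary type hints, they are also visible to static type checkers; `tests/typed_graph.py` later in the series exercises exactly that with `assert_type`.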
diff --git a/pydantic_ai_graph/pydantic_ai_graph/__init__.py b/pydantic_ai_graph/pydantic_ai_graph/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/pydantic_ai_graph/pydantic_ai_graph/_utils.py b/pydantic_ai_graph/pydantic_ai_graph/_utils.py new file mode 100644 index 0000000000..b58e142f6e --- /dev/null +++ b/pydantic_ai_graph/pydantic_ai_graph/_utils.py @@ -0,0 +1,37 @@ +import sys +import types +from typing import Any, TypeAliasType, Union, get_args, get_origin + + +def get_union_args(tp: Any) -> tuple[Any, ...]: + """Extract the arguments of a Union type if `response_type` is a union, otherwise return the original type.""" + # similar to `pydantic_ai_slim/pydantic_ai/_result.py:get_union_args` + if isinstance(tp, TypeAliasType): + tp = tp.__value__ + + origin = get_origin(tp) + if origin_is_union(origin): + return get_args(tp) + else: + return (tp,) + + +# same as `pydantic_ai_slim/pydantic_ai/_result.py:origin_is_union` +if sys.version_info < (3, 10): + + def origin_is_union(tp: type[Any] | None) -> bool: + return tp is Union + +else: + + def origin_is_union(tp: type[Any] | None) -> bool: + return tp is Union or tp is types.UnionType + + +def comma_and(items: list[str]) -> str: + """Join with a comma and 'and' for the last item.""" + if len(items) == 1: + return items[0] + else: + # oxford comma ¯\_(ツ)_/¯ + return ', '.join(items[:-1]) + ', and ' + items[-1] diff --git a/pydantic_ai_graph/pydantic_ai_graph/graph.py b/pydantic_ai_graph/pydantic_ai_graph/graph.py new file mode 100644 index 0000000000..0180cfa325 --- /dev/null +++ b/pydantic_ai_graph/pydantic_ai_graph/graph.py @@ -0,0 +1,128 @@ +from __future__ import annotations as _annotations + +import base64 +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Any, Generic + +from typing_extensions import TypeVar, assert_never + +from . 
import _utils +from .nodes import ( + BaseNode, + DepsT, + End, + GraphContext, + GraphOutputT, + NodeDef, +) +from .state import StateT + +__all__ = ('Graph',) +GraphInputT = TypeVar('GraphInputT', default=Any) + + +# noinspection PyTypeHints +@dataclass(init=False) +class Graph(Generic[GraphInputT, GraphOutputT, DepsT, StateT]): + """Definition of a graph.""" + + first_node: NodeDef[Any, Any, DepsT, StateT] + nodes: dict[str, NodeDef[Any, Any, DepsT, StateT]] + state_type: type[StateT] | None + + # noinspection PyUnusedLocal + def __init__( + self, + first_node: type[BaseNode[GraphInputT, Any, DepsT, StateT]], + *other_nodes: type[BaseNode[Any, GraphOutputT, DepsT, StateT]], + deps_type: type[DepsT] | None = None, + state_type: type[StateT] | None = None, + ): + self.first_node = first_node.get_node_def() + self.nodes = nodes = {self.first_node.node_id: self.first_node} + for node in other_nodes: + node_def = node.get_node_def() + nodes[node_def.node_id] = node_def + + self._check() + self.state_type = state_type + + async def run( + self, input_data: GraphInputT, deps: DepsT = None, state: StateT = None + ) -> tuple[GraphOutputT, StateT]: + current_node_def = self.first_node + current_node = current_node_def.node(input_data) + ctx = GraphContext(deps, state) + while True: + # noinspection PyUnresolvedReferences + next_node = await current_node.run(ctx) + if isinstance(next_node, End): + if current_node_def.can_end: + return next_node.data, ctx.state + else: + raise ValueError(f'Node {current_node_def.node_id} cannot end the graph') + elif isinstance(next_node, BaseNode): + next_node_id = next_node.get_id() + try: + next_node_def = self.nodes[next_node_id] + except KeyError as e: + raise ValueError( + f'Node {current_node_def.node_id} cannot go to {next_node_id} which is not in the Graph' + ) from e + + if not current_node_def.dest_any and next_node_id not in current_node_def.next_node_ids: + raise ValueError( + f'Node {current_node_def.node_id} cannot go to {next_node_id} which is not in its ' + f'list of allowed next nodes' + ) + + current_node_def = next_node_def + current_node = next_node + else: + if TYPE_CHECKING: + assert_never(next_node) + else: + raise TypeError(f'Invalid node type: {type(next_node)} expected BaseNode or End') + + def mermaid_code(self) -> str: + lines = ['graph TD'] + # order of destination nodes should match their order in `self.nodes` + node_order = {nid: index for index, nid in enumerate(self.nodes.keys())} + for node_id, node in self.nodes.items(): + if node.dest_any: + for next_node_id in self.nodes: + lines.append(f' {node_id} --> {next_node_id}') + for _, next_node_id in sorted((node_order[nid], nid) for nid in node.next_node_ids): + lines.append(f' {node_id} --> {next_node_id}') + if node.can_end: + lines.append(f' {node_id} --> END') + return '\n'.join(lines) + + def mermaid_image(self, mermaid_ink_params: dict[str, str | int] | None = None) -> bytes: + import httpx + + code_base64 = base64.b64encode(self.mermaid_code().encode()).decode() + + response = httpx.get(f'https://mermaid.ink/img/{code_base64}', params=mermaid_ink_params) + response.raise_for_status() + return response.content + + def mermaid_save(self, path: Path, mermaid_ink_params: dict[str, str | int] | None = None) -> None: + image_data = self.mermaid_image(mermaid_ink_params) + path.write_bytes(image_data) + + def _check(self): + bad_edges: dict[str, list[str]] = {} + for node in self.nodes.values(): + node_bad_edges = node.next_node_ids - self.nodes.keys() + for bad_edge in 
node_bad_edges: + bad_edges.setdefault(bad_edge, []).append(f'"{node.node_id}"') + + if bad_edges: + bad_edges_list = [f'"{k}" is referenced by {_utils.comma_and(v)}' for k, v in bad_edges.items()] + if len(bad_edges_list) == 1: + raise ValueError(f'{bad_edges_list[0]} but not included in the graph.') + else: + b = '\n'.join(f' {be}' for be in bad_edges_list) + raise ValueError(f'Nodes are referenced in the graph but not included in the graph:\n{b}') diff --git a/pydantic_ai_graph/pydantic_ai_graph/nodes.py b/pydantic_ai_graph/pydantic_ai_graph/nodes.py new file mode 100644 index 0000000000..a790e01959 --- /dev/null +++ b/pydantic_ai_graph/pydantic_ai_graph/nodes.py @@ -0,0 +1,116 @@ +from __future__ import annotations as _annotations + +from abc import ABC, ABCMeta, abstractmethod +from dataclasses import dataclass +from functools import cache +from typing import Any, ClassVar, Generic, get_args, get_origin, get_type_hints + +from typing_extensions import TypeVar + +from . import _utils +from .state import StateT + +__all__ = ( + 'NodeInputT', + 'GraphOutputT', + 'DepsT', + 'GraphContext', + 'End', + 'BaseNode', + 'NodeDef', +) + +NodeInputT = TypeVar('NodeInputT', default=Any) +GraphOutputT = TypeVar('GraphOutputT', default=Any) +DepsT = TypeVar('DepsT', default=None) + + +# noinspection PyTypeHints +@dataclass +class GraphContext(Generic[DepsT, StateT]): + """Context for a graph.""" + + deps: DepsT + state: StateT + + +# noinspection PyTypeHints +class End(ABC, Generic[NodeInputT]): + """Type to return from a node to signal the end of the graph.""" + + __slots__ = ('data',) + + def __init__(self, input_data: NodeInputT) -> None: + self.data = input_data + + +class _BaseNodeMeta(ABCMeta): + def __repr__(cls): + base: Any = cls.__orig_bases__[0] # type: ignore + args = get_args(base) + if len(args) == 4 and args[3] is None: + if args[2] is None: + args = args[:2] + else: + args = args[:3] + args = ', '.join(a.__name__ for a in args) + return f'{cls.__name__}({base.__name__}[{args}])' + + +# noinspection PyTypeHints +class BaseNode(Generic[NodeInputT, GraphOutputT, DepsT, StateT], metaclass=_BaseNodeMeta): + """Base class for a node.""" + + node_id: ClassVar[str | None] = None + __slots__ = ('input_data',) + + def __init__(self, input_data: NodeInputT) -> None: + self.input_data = input_data + + @abstractmethod + async def run(self, ctx: GraphContext[DepsT, StateT]) -> BaseNode[Any, Any, DepsT, StateT] | End[GraphOutputT]: ... + + @classmethod + @cache + def get_id(cls) -> str: + return cls.node_id or cls.__qualname__ + + @classmethod + def get_node_def(cls) -> NodeDef[Any, Any, DepsT, StateT]: + type_hints = get_type_hints(cls.run) + next_node_ids: set[str] = set() + can_end: bool = False + dest_any: bool = False + for return_type in _utils.get_union_args(type_hints['return']): + return_type_origin = get_origin(return_type) or return_type + if return_type_origin is BaseNode: + dest_any = True + elif issubclass(return_type_origin, BaseNode): + next_node_ids.add(return_type.get_id()) + elif return_type_origin is End: + can_end = True + else: + raise TypeError(f'Invalid return type: {return_type}') + + return NodeDef( + cls, + cls.get_id(), + next_node_ids, + can_end, + dest_any, + ) + + +# noinspection PyTypeHints +@dataclass +class NodeDef(ABC, Generic[NodeInputT, GraphOutputT, DepsT, StateT]): + """Definition of a node. + + Used by [`Graph`][pydantic_ai_graph.graph.Graph] store information about a node. 
+ """ + + node: type[BaseNode[NodeInputT, GraphOutputT, DepsT, StateT]] + node_id: str + next_node_ids: set[str] + can_end: bool + dest_any: bool diff --git a/pydantic_ai_graph/pydantic_ai_graph/state.py b/pydantic_ai_graph/pydantic_ai_graph/state.py new file mode 100644 index 0000000000..f9fac1199f --- /dev/null +++ b/pydantic_ai_graph/pydantic_ai_graph/state.py @@ -0,0 +1,17 @@ +from abc import ABC + +from typing_extensions import TypeVar + +__all__ = 'AbstractState', 'StateT' + + +class AbstractState(ABC): + """Abstract class for a state object.""" + + def __init__(self): + pass + + # TODO serializing and deserialize state + + +StateT = TypeVar('StateT', None, AbstractState, default=None) diff --git a/pydantic_ai_graph/pyproject.toml b/pydantic_ai_graph/pyproject.toml new file mode 100644 index 0000000000..84cce90ea7 --- /dev/null +++ b/pydantic_ai_graph/pyproject.toml @@ -0,0 +1,43 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "pydantic-ai-graph" +version = "0.0.1" +description = "Graph and State Machine library" +authors = [ + { name = "Samuel Colvin", email = "samuel@pydantic.dev" }, +] +license = "MIT" +readme = "README.md" +classifiers = [ + "Development Status :: 4 - Beta", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Intended Audience :: Developers", + "Intended Audience :: Information Technology", + "Intended Audience :: System Administrators", + "License :: OSI Approved :: MIT License", + "Operating System :: Unix", + "Operating System :: POSIX :: Linux", + "Environment :: Console", + "Environment :: MacOS X", + "Topic :: Software Development :: Libraries :: Python Modules", + "Topic :: Internet", +] +requires-python = ">=3.9" +dependencies = [ + "httpx>=0.27.2", + "logfire-api>=1.2.0", + "pydantic>=2.10", +] + +[tool.hatch.build.targets.wheel] +packages = ["pydantic_ai_graph"] diff --git a/pyproject.toml b/pyproject.toml index c275f64962..b538e26821 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,7 @@ pydantic-ai-slim = { workspace = true } pydantic-ai-examples = { workspace = true } [tool.uv.workspace] -members = ["pydantic_ai_slim", "examples"] +members = ["pydantic_ai_slim", "pydantic_ai_graph", "examples"] [dependency-groups] # dev dependencies are defined in `pydantic-ai-slim/pyproject.toml` to allow for minimal testing @@ -82,6 +82,7 @@ line-length = 120 target-version = "py39" include = [ "pydantic_ai_slim/**/*.py", + "pydantic_ai_graph/**/*.py", "examples/**/*.py", "tests/**/*.py", "docs/**/*.py", @@ -125,7 +126,7 @@ typeCheckingMode = "strict" reportMissingTypeStubs = false reportUnnecessaryIsInstance = false reportUnnecessaryTypeIgnoreComment = true -include = ["pydantic_ai_slim", "tests", "examples"] +include = ["pydantic_ai_slim", "pydantic_ai_graph", "tests", "examples"] venvPath = ".venv" # see https://github.com/microsoft/pyright/issues/7771 - we don't want to error on decorated functions in tests # which are not otherwise used @@ -144,7 +145,7 @@ filterwarnings = [ # https://coverage.readthedocs.io/en/latest/config.html#run [tool.coverage.run] # required to avoid warnings about files created by create_module fixture -include = ["pydantic_ai_slim/**/*.py", "tests/**/*.py"] +include = 
["pydantic_ai_slim/**/*.py", "pydantic_ai_graph/**/*.py","tests/**/*.py"] omit = ["tests/test_live.py", "tests/example_modules/*.py"] branch = true diff --git a/uv.lock b/uv.lock index 850a4072e4..37ed8c6cee 100644 --- a/uv.lock +++ b/uv.lock @@ -12,6 +12,7 @@ resolution-markers = [ members = [ "pydantic-ai", "pydantic-ai-examples", + "pydantic-ai-graph", "pydantic-ai-slim", ] @@ -2095,6 +2096,23 @@ requires-dist = [ { name = "uvicorn", specifier = ">=0.32.0" }, ] +[[package]] +name = "pydantic-ai-graph" +version = "0.0.1" +source = { editable = "pydantic_ai_graph" } +dependencies = [ + { name = "httpx" }, + { name = "logfire-api" }, + { name = "pydantic" }, +] + +[package.metadata] +requires-dist = [ + { name = "httpx", specifier = ">=0.27.2" }, + { name = "logfire-api", specifier = ">=1.2.0" }, + { name = "pydantic", specifier = ">=2.10" }, +] + [[package]] name = "pydantic-ai-slim" version = "0.0.18" From bdd5c2c3fec50c9a541dc05d2e48cac3daa08089 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sun, 22 Dec 2024 16:44:56 +0000 Subject: [PATCH 02/57] lower case state machine --- pydantic_ai_graph/README.md | 2 +- pydantic_ai_graph/pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pydantic_ai_graph/README.md b/pydantic_ai_graph/README.md index 2271f779cb..ec0b531a25 100644 --- a/pydantic_ai_graph/README.md +++ b/pydantic_ai_graph/README.md @@ -6,7 +6,7 @@ [![versions](https://img.shields.io/pypi/pyversions/pydantic-ai-graph.svg)](https://github.com/pydantic/pydantic-ai) [![license](https://img.shields.io/github/license/pydantic/pydantic-ai-graph.svg?v)](https://github.com/pydantic/pydantic-ai/blob/main/LICENSE) -Graph and State Machine library. +Graph and state machine library. This library is developed as part of [PydanticAI](https://ai.pydantic.dev); however, it has no dependency on `pydantic-ai` or related packages and can be considered a pure graph library.
diff --git a/pydantic_ai_graph/pyproject.toml b/pydantic_ai_graph/pyproject.toml index 84cce90ea7..2231df5cd1 100644 --- a/pydantic_ai_graph/pyproject.toml +++ b/pydantic_ai_graph/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "hatchling.build" [project] name = "pydantic-ai-graph" version = "0.0.1" -description = "Graph and State Machine library" +description = "Graph and state machine library" authors = [ { name = "Samuel Colvin", email = "samuel@pydantic.dev" }, ] From a65df821a25039fe184b7b4c01c0eaf8b5252b1e Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sun, 22 Dec 2024 19:46:50 +0000 Subject: [PATCH 03/57] starting tests --- Makefile | 2 +- .../pydantic_ai_graph/__init__.py | 4 ++ pydantic_ai_graph/pydantic_ai_graph/graph.py | 21 ++++++- pydantic_ai_graph/pydantic_ai_graph/nodes.py | 4 +- pydantic_ai_graph/pydantic_ai_graph/py.typed | 0 pyproject.toml | 4 ++ tests/test_graph.py | 34 ++++++++++ tests/typed_graph.py | 63 +++++++++++++++++++ 8 files changed, 127 insertions(+), 5 deletions(-) create mode 100644 pydantic_ai_graph/pydantic_ai_graph/py.typed create mode 100644 tests/test_graph.py create mode 100644 tests/typed_graph.py diff --git a/Makefile b/Makefile index 8fea000af8..7b12e2a9f1 100644 --- a/Makefile +++ b/Makefile @@ -34,7 +34,7 @@ typecheck-pyright: .PHONY: typecheck-mypy typecheck-mypy: - uv run mypy --strict tests/typed_agent.py + uv run mypy .PHONY: typecheck typecheck: typecheck-pyright ## Run static type checking diff --git a/pydantic_ai_graph/pydantic_ai_graph/__init__.py b/pydantic_ai_graph/pydantic_ai_graph/__init__.py index e69de29bb2..6345bd7090 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/__init__.py +++ b/pydantic_ai_graph/pydantic_ai_graph/__init__.py @@ -0,0 +1,4 @@ +from .graph import Graph +from .nodes import BaseNode, End, GraphContext + +__all__ = 'BaseNode', 'End', 'GraphContext', 'Graph' diff --git a/pydantic_ai_graph/pydantic_ai_graph/graph.py b/pydantic_ai_graph/pydantic_ai_graph/graph.py index 0180cfa325..e21bfc7b54 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/graph.py +++ b/pydantic_ai_graph/pydantic_ai_graph/graph.py @@ -1,6 +1,8 @@ from __future__ import annotations as _annotations import base64 +import inspect +import types from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING, Any, Generic @@ -39,10 +41,11 @@ def __init__( deps_type: type[DepsT] | None = None, state_type: type[StateT] | None = None, ): - self.first_node = first_node.get_node_def() + parent_namespace = get_parent_namespace(inspect.currentframe()) + self.first_node = first_node.get_node_def(parent_namespace) self.nodes = nodes = {self.first_node.node_id: self.first_node} for node in other_nodes: - node_def = node.get_node_def() + node_def = node.get_node_def(parent_namespace) nodes[node_def.node_id] = node_def self._check() @@ -126,3 +129,17 @@ def _check(self): else: b = '\n'.join(f' {be}' for be in bad_edges_list) raise ValueError(f'Nodes are referenced in the graph but not included in the graph:\n{b}') + + +def get_parent_namespace(frame: types.FrameType | None) -> dict[str, Any] | None: + """Attempt to get the namespace where the graph was defined. + + If the graph is defined with generics `Graph[a, b]` then another frame is inserted, and we have to skip that + to get the correct namespace. 
+ """ + if frame is not None: + if back := frame.f_back: + if back.f_code.co_filename.endswith('/typing.py'): + return get_parent_namespace(back) + else: + return back.f_locals diff --git a/pydantic_ai_graph/pydantic_ai_graph/nodes.py b/pydantic_ai_graph/pydantic_ai_graph/nodes.py index a790e01959..c58102dfaf 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/nodes.py +++ b/pydantic_ai_graph/pydantic_ai_graph/nodes.py @@ -76,8 +76,8 @@ def get_id(cls) -> str: return cls.node_id or cls.__qualname__ @classmethod - def get_node_def(cls) -> NodeDef[Any, Any, DepsT, StateT]: - type_hints = get_type_hints(cls.run) + def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[Any, Any, DepsT, StateT]: + type_hints = get_type_hints(cls.run, localns=local_ns) next_node_ids: set[str] = set() can_end: bool = False dest_any: bool = False diff --git a/pydantic_ai_graph/pydantic_ai_graph/py.typed b/pydantic_ai_graph/pydantic_ai_graph/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/pyproject.toml b/pyproject.toml index b538e26821..47fcaf18bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -135,6 +135,10 @@ executionEnvironments = [ ] exclude = ["examples/pydantic_ai_examples/weather_agent_gradio.py"] +[tool.mypy] +files = "tests/typed_agent.py,tests/typed_graph.py" +strict = true + [tool.pytest.ini_options] testpaths = "tests" xfail_strict = true diff --git a/tests/test_graph.py b/tests/test_graph.py new file mode 100644 index 0000000000..3b8ede23eb --- /dev/null +++ b/tests/test_graph.py @@ -0,0 +1,34 @@ +from __future__ import annotations as _annotations + +import pytest + +from pydantic_ai_graph import BaseNode, End, Graph, GraphContext + +pytestmark = pytest.mark.anyio + + +async def test_graph(): + class Float2String(BaseNode[float]): + async def run(self, ctx: GraphContext) -> String2Length: + return String2Length(str(self.input_data)) + + class String2Length(BaseNode[str]): + async def run(self, ctx: GraphContext) -> Double: + return Double(len(self.input_data)) + + class Double(BaseNode[int, int]): + async def run(self, ctx: GraphContext) -> String2Length | End[int]: + if self.input_data == 7: + return String2Length('x' * 21) + else: + return End(self.input_data * 2) + + g = Graph[float, int]( + Float2String, + String2Length, + Double, + ) + # len('3.14') * 2 == 8 + assert await g.run(3.14) == (8, None) + # len('3.14159') == 7, 21 * 2 == 42 + assert await g.run(3.14159) == (42, None) diff --git a/tests/typed_graph.py b/tests/typed_graph.py new file mode 100644 index 0000000000..e6974e3070 --- /dev/null +++ b/tests/typed_graph.py @@ -0,0 +1,63 @@ +from __future__ import annotations as _annotations + +from typing import assert_type + +from pydantic_ai_graph import BaseNode, End, Graph, GraphContext + + +class Float2String(BaseNode[float]): + async def run(self, ctx: GraphContext) -> String2Length: + return String2Length(str(self.input_data)) + + +class String2Length(BaseNode[str]): + async def run(self, ctx: GraphContext) -> Double: + return Double(len(self.input_data)) + + +class Double(BaseNode[int, int]): + async def run(self, ctx: GraphContext) -> String2Length | End[int]: + if self.input_data == 7: + return String2Length('x' * 21) + else: + return End(self.input_data * 2) + + +def use_double(node: BaseNode[int, int]) -> None: + """Shoe that `Double` is valid as a `BaseNode[int, int]`.""" + + +use_double(Double(1)) + + +g1 = Graph[float, int]( + Float2String, + String2Length, + Double, +) +assert_type(g1, Graph[float, int]) + + +g2 = Graph(Float2String, Double) 
+assert_type(g2, Graph[float, int]) + +g3 = Graph( + Float2String, + Double, + String2Length, +) +MYPY = False +if MYPY: + # with mypy the presence of `String2Length` makes the output type Any + assert_type(g3, Graph[float]) # pyright: ignore[reportAssertTypeFailure] +else: + # pyright works correct and uses `Double` to infer the output type + assert_type(g3, Graph[float, int]) + +g4 = Graph( + Float2String, + String2Length, + Double, +) +# because String2Length came before Double, the output type is Any +assert_type(g4, Graph[float]) From af0ba327e949cf638f4aaa5a43c37b6e16617007 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sun, 22 Dec 2024 20:37:37 +0000 Subject: [PATCH 04/57] add history and logfire --- .../pydantic_ai_graph/__init__.py | 3 +- pydantic_ai_graph/pydantic_ai_graph/graph.py | 102 +++++++++++------- pydantic_ai_graph/pydantic_ai_graph/nodes.py | 2 +- pydantic_ai_graph/pydantic_ai_graph/state.py | 39 +++++-- pydantic_ai_graph/pyproject.toml | 2 +- tests/conftest.py | 5 +- tests/test_graph.py | 69 +++++++++++- uprev.py | 14 ++- uv.lock | 2 +- 9 files changed, 184 insertions(+), 54 deletions(-) diff --git a/pydantic_ai_graph/pydantic_ai_graph/__init__.py b/pydantic_ai_graph/pydantic_ai_graph/__init__.py index 6345bd7090..aaef92ef7a 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/__init__.py +++ b/pydantic_ai_graph/pydantic_ai_graph/__init__.py @@ -1,4 +1,5 @@ from .graph import Graph from .nodes import BaseNode, End, GraphContext +from .state import Snapshot -__all__ = 'BaseNode', 'End', 'GraphContext', 'Graph' +__all__ = 'BaseNode', 'End', 'GraphContext', 'Graph', 'Snapshot' diff --git a/pydantic_ai_graph/pydantic_ai_graph/graph.py b/pydantic_ai_graph/pydantic_ai_graph/graph.py index e21bfc7b54..acccccbbbe 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/graph.py +++ b/pydantic_ai_graph/pydantic_ai_graph/graph.py @@ -4,9 +4,12 @@ import inspect import types from dataclasses import dataclass +from datetime import datetime, timezone from pathlib import Path +from time import perf_counter from typing import TYPE_CHECKING, Any, Generic +import logfire_api from typing_extensions import TypeVar, assert_never from . 
import _utils @@ -18,9 +21,11 @@ GraphOutputT, NodeDef, ) -from .state import StateT +from .state import Snapshot, StateT __all__ = ('Graph',) + +_logfire = logfire_api.Logfire(otel_scope='pydantic-ai-graph') GraphInputT = TypeVar('GraphInputT', default=Any) @@ -31,15 +36,13 @@ class Graph(Generic[GraphInputT, GraphOutputT, DepsT, StateT]): first_node: NodeDef[Any, Any, DepsT, StateT] nodes: dict[str, NodeDef[Any, Any, DepsT, StateT]] - state_type: type[StateT] | None + name: str | None - # noinspection PyUnusedLocal def __init__( self, first_node: type[BaseNode[GraphInputT, Any, DepsT, StateT]], *other_nodes: type[BaseNode[Any, GraphOutputT, DepsT, StateT]], - deps_type: type[DepsT] | None = None, - state_type: type[StateT] | None = None, + name: str | None = None, ): parent_namespace = get_parent_namespace(inspect.currentframe()) self.first_node = first_node.get_node_def(parent_namespace) @@ -49,44 +52,71 @@ def __init__( nodes[node_def.node_id] = node_def self._check() - self.state_type = state_type + self.name = name async def run( - self, input_data: GraphInputT, deps: DepsT = None, state: StateT = None - ) -> tuple[GraphOutputT, StateT]: + self, + input_data: GraphInputT, + deps: DepsT = None, + state: StateT = None, + history: list[Snapshot] | None = None, + ) -> tuple[GraphOutputT, list[Snapshot]]: current_node_def = self.first_node current_node = current_node_def.node(input_data) ctx = GraphContext(deps, state) - while True: - # noinspection PyUnresolvedReferences - next_node = await current_node.run(ctx) - if isinstance(next_node, End): - if current_node_def.can_end: - return next_node.data, ctx.state - else: - raise ValueError(f'Node {current_node_def.node_id} cannot end the graph') - elif isinstance(next_node, BaseNode): - next_node_id = next_node.get_id() - try: - next_node_def = self.nodes[next_node_id] - except KeyError as e: - raise ValueError( - f'Node {current_node_def.node_id} cannot go to {next_node_id} which is not in the Graph' - ) from e - - if not current_node_def.dest_any and next_node_id not in current_node_def.next_node_ids: - raise ValueError( - f'Node {current_node_def.node_id} cannot go to {next_node_id} which is not in its ' - f'list of allowed next nodes' + if history: + run_history = history[:] + else: + run_history = [] + + with _logfire.span( + '{graph_name} run {input=}', + graph_name=self.name or 'graph', + input=input_data, + graph=self, + ) as run_span: + while True: + with _logfire.span('run node {node_id}', node_id=current_node_def.node_id): + start_ts = datetime.now(tz=timezone.utc) + start = perf_counter() + # noinspection PyUnresolvedReferences + next_node = await current_node.run(ctx) + duration = perf_counter() - start + + if isinstance(next_node, End): + if current_node_def.can_end: + run_history.append( + Snapshot.from_state(current_node_def.node_id, None, start_ts, duration, ctx.state) + ) + run_span.set_attribute('history', run_history) + return next_node.data, run_history + else: + raise ValueError(f'Node {current_node_def.node_id} cannot end the graph') + elif isinstance(next_node, BaseNode): + next_node_id = next_node.get_id() + run_history.append( + Snapshot.from_state(current_node_def.node_id, next_node_id, start_ts, duration, ctx.state) ) - - current_node_def = next_node_def - current_node = next_node - else: - if TYPE_CHECKING: - assert_never(next_node) + try: + next_node_def = self.nodes[next_node_id] + except KeyError as e: + raise ValueError( + f'Node {current_node_def.node_id} cannot go to {next_node_id} which is not in the 
Graph' + ) from e + + if not current_node_def.dest_any and next_node_id not in current_node_def.next_node_ids: + raise ValueError( + f'Node {current_node_def.node_id} cannot go to {next_node_id} which is not in its ' + f'list of allowed next nodes' + ) + + current_node_def = next_node_def + current_node = next_node else: - raise TypeError(f'Invalid node type: {type(next_node)} expected BaseNode or End') + if TYPE_CHECKING: + assert_never(next_node) + else: + raise TypeError(f'Invalid node type: {type(next_node)} expected BaseNode or End') def mermaid_code(self) -> str: lines = ['graph TD'] diff --git a/pydantic_ai_graph/pydantic_ai_graph/nodes.py b/pydantic_ai_graph/pydantic_ai_graph/nodes.py index c58102dfaf..376d9db1e0 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/nodes.py +++ b/pydantic_ai_graph/pydantic_ai_graph/nodes.py @@ -73,7 +73,7 @@ async def run(self, ctx: GraphContext[DepsT, StateT]) -> BaseNode[Any, Any, Deps @classmethod @cache def get_id(cls) -> str: - return cls.node_id or cls.__qualname__ + return cls.node_id or cls.__name__ @classmethod def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[Any, Any, DepsT, StateT]: diff --git a/pydantic_ai_graph/pydantic_ai_graph/state.py b/pydantic_ai_graph/pydantic_ai_graph/state.py index f9fac1199f..133085ec8b 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/state.py +++ b/pydantic_ai_graph/pydantic_ai_graph/state.py @@ -1,17 +1,44 @@ -from abc import ABC +from __future__ import annotations as _annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from datetime import datetime from typing_extensions import TypeVar -__all__ = 'AbstractState', 'StateT' +__all__ = 'AbstractState', 'StateT', 'Snapshot' class AbstractState(ABC): """Abstract class for a state object.""" - def __init__(self): - pass - - # TODO serializing and deserialize state + @abstractmethod + def serialize(self) -> bytes | None: + """Serialize the state object.""" + raise NotImplementedError StateT = TypeVar('StateT', None, AbstractState, default=None) + + +@dataclass +class Snapshot: + """Snapshot of a graph.""" + + last_node_id: str + next_node_id: str | None + start_ts: datetime + duration: float + state: bytes | None = None + + @classmethod + def from_state( + cls, last_node_id: str, next_node_id: str | None, start_ts: datetime, duration: float, state: StateT + ) -> Snapshot: + return cls( + last_node_id=last_node_id, + next_node_id=next_node_id, + start_ts=start_ts, + duration=duration, + state=state.serialize() if state is not None else None, + ) diff --git a/pydantic_ai_graph/pyproject.toml b/pydantic_ai_graph/pyproject.toml index 2231df5cd1..cb047a9a8e 100644 --- a/pydantic_ai_graph/pyproject.toml +++ b/pydantic_ai_graph/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "pydantic-ai-graph" -version = "0.0.1" +version = "0.0.14" description = "Graph and state machine library" authors = [ { name = "Samuel Colvin", email = "samuel@pydantic.dev" }, diff --git a/tests/conftest.py b/tests/conftest.py index 4a219ce00d..df874f8be1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,7 +20,7 @@ import pydantic_ai.models -__all__ = 'IsNow', 'TestEnv', 'ClientWithHandler', 'try_import' +__all__ = 'IsNow', 'IsFloat', 'TestEnv', 'ClientWithHandler', 'try_import' pydantic_ai.models.ALLOW_MODEL_REQUESTS = False @@ -28,8 +28,9 @@ if TYPE_CHECKING: def IsNow(*args: Any, **kwargs: Any) -> datetime: ... + def IsFloat(*args: Any, **kwargs: Any) -> float: ... 
else: - from dirty_equals import IsNow + from dirty_equals import IsFloat, IsNow try: from logfire.testing import CaptureLogfire diff --git a/tests/test_graph.py b/tests/test_graph.py index 3b8ede23eb..af27f03cde 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,8 +1,13 @@ from __future__ import annotations as _annotations +from datetime import timezone + import pytest +from inline_snapshot import snapshot + +from pydantic_ai_graph import BaseNode, End, Graph, GraphContext, Snapshot -from pydantic_ai_graph import BaseNode, End, Graph, GraphContext +from .conftest import IsFloat, IsNow pytestmark = pytest.mark.anyio @@ -28,7 +33,65 @@ async def run(self, ctx: GraphContext) -> String2Length | End[int]: String2Length, Double, ) + result, history = await g.run(3.14) # len('3.14') * 2 == 8 - assert await g.run(3.14) == (8, None) + assert result == 8 + assert history == snapshot( + [ + Snapshot( + last_node_id='Float2String', + next_node_id='String2Length', + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(gt=0, lt=1e-5), + ), + Snapshot( + last_node_id='String2Length', + next_node_id='Double', + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(gt=0, lt=1e-5), + ), + Snapshot( + last_node_id='Double', + next_node_id=None, + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(gt=0, lt=1e-5), + ), + ] + ) + result, history = await g.run(3.14159) # len('3.14159') == 7, 21 * 2 == 42 - assert await g.run(3.14159) == (42, None) + assert result == 42 + assert history == snapshot( + [ + Snapshot( + last_node_id='Float2String', + next_node_id='String2Length', + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(gt=0, lt=1e-5), + ), + Snapshot( + last_node_id='String2Length', + next_node_id='Double', + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(gt=0, lt=1e-5), + ), + Snapshot( + last_node_id='Double', + next_node_id='String2Length', + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(gt=0, lt=1e-5), + ), + Snapshot( + last_node_id='String2Length', + next_node_id='Double', + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(gt=0, lt=1e-5), + ), + Snapshot( + last_node_id='Double', + next_node_id=None, + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(gt=0, lt=1e-5), + ), + ] + ) diff --git a/uprev.py b/uprev.py index 97562b5d3f..3eeca56d01 100644 --- a/uprev.py +++ b/uprev.py @@ -3,8 +3,9 @@ Because we have multiple packages which depend on one-another, we have to update the version number in: * pyproject.toml -* pydantic_ai_examples/pyproject.toml +* examples/pyproject.toml * pydantic_ai_slim/pyproject.toml +* pydantic_ai_graph/pyproject.toml Usage: @@ -67,10 +68,15 @@ def replace_deps_version(text: str) -> tuple[str, int]: slim_pp_text = slim_pp.read_text() slim_pp_text, count_slim = replace_deps_version(slim_pp_text) -if count_root == 2 and count_ex == 2 and count_slim == 1: +graph_pp = ROOT_DIR / 'pydantic_ai_graph' / 'pyproject.toml' +graph_pp_text = graph_pp.read_text() +graph_pp_text, count_graph = replace_deps_version(graph_pp_text) + +if count_root == 2 and count_ex == 2 and count_slim == 1 and count_graph == 1: root_pp.write_text(root_pp_text) examples_pp.write_text(examples_pp_text) slim_pp.write_text(slim_pp_text) + graph_pp.write_text(graph_pp_text) print('running `make sync`...') subprocess.run(['make', 'sync'], check=True) print(f'running `git switch -c uprev-{version}`...') @@ -79,7 +85,8 @@ def replace_deps_version(text: str) -> tuple[str, int]: f'SUCCESS: replaced version in\n' f' {root_pp}\n' f' {examples_pp}\n' - f' {slim_pp}' + f' {slim_pp}\n' + 
f' {graph_pp}' ) else: print( @@ -87,6 +94,7 @@ def replace_deps_version(text: str) -> tuple[str, int]: f' {count_root} version references in {root_pp} (expected 2)\n' f' {count_ex} version references in {examples_pp} (expected 2)\n' f' {count_slim} version references in {slim_pp} (expected 1)', + f' {count_graph} version references in {graph_pp} (expected 1)', file=sys.stderr, ) sys.exit(1) diff --git a/uv.lock b/uv.lock index 37ed8c6cee..38dff08aa3 100644 --- a/uv.lock +++ b/uv.lock @@ -2098,7 +2098,7 @@ requires-dist = [ [[package]] name = "pydantic-ai-graph" -version = "0.0.1" +version = "0.0.14" source = { editable = "pydantic_ai_graph" } dependencies = [ { name = "httpx" }, From 877ee3657435599545a97d027d1192455e59fbce Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sun, 22 Dec 2024 23:41:50 +0000 Subject: [PATCH 05/57] add example, alter types --- .../email_extract_graph.py | 145 ++++++++++++++++++ .../pydantic_ai_graph/__init__.py | 4 +- pydantic_ai_graph/pydantic_ai_graph/_utils.py | 10 ++ pydantic_ai_graph/pydantic_ai_graph/graph.py | 29 ++-- pydantic_ai_graph/pydantic_ai_graph/nodes.py | 25 ++- pydantic_ai_graph/pydantic_ai_graph/state.py | 3 +- tests/test_graph.py | 8 +- tests/typed_graph.py | 47 +++--- 8 files changed, 208 insertions(+), 63 deletions(-) create mode 100644 examples/pydantic_ai_examples/email_extract_graph.py diff --git a/examples/pydantic_ai_examples/email_extract_graph.py b/examples/pydantic_ai_examples/email_extract_graph.py new file mode 100644 index 0000000000..c91613b4ec --- /dev/null +++ b/examples/pydantic_ai_examples/email_extract_graph.py @@ -0,0 +1,145 @@ +from __future__ import annotations as _annotations + +import asyncio +from datetime import datetime, timedelta + +import logfire +from devtools import debug +from pydantic import BaseModel +from pydantic_ai_graph import AbstractState, BaseNode, End, Graph, GraphContext + +from pydantic_ai import Agent, RunContext + +logfire.configure(send_to_logfire='if-token-present') + + +class EventDetails(BaseModel): + title: str + location: str + start_ts: datetime + end_ts: datetime + + +class State(AbstractState, BaseModel): + email_content: str + skip_events: list[str] = [] + attempt: int = 0 + + def serialize(self) -> bytes | None: + return self.model_dump_json(exclude={'email_content'}).encode() + + +class RawEventDetails(BaseModel): + title: str + location: str + start_ts: str + duration: str + + +extract_agent = Agent('openai:gpt-4o', result_type=RawEventDetails, deps_type=list[str]) + + +@extract_agent.system_prompt +def extract_system_prompt(ctx: RunContext[list[str]]): + prompt = 'Extract event details from the email body.' 
+ if ctx.deps: + skip_events = '\n'.join(ctx.deps) + prompt += f'\n\nDo not return the following events:\n{skip_events}' + return prompt + + +class ExtractEvent(BaseNode[State, None]): + async def run(self, ctx: GraphContext[State]) -> CleanEvent: + event = await extract_agent.run( + ctx.state.email_content, deps=ctx.state.skip_events + ) + return CleanEvent(event.data) + + +# agent used to extract the timestamp from the string in `CleanEvent` +timestamp_agent = Agent('openai:gpt-4o', result_type=datetime) + + +@timestamp_agent.system_prompt +def timestamp_system_prompt(): + return f'Extract the timestamp from the string, the current timestamp is: {datetime.now().isoformat()}' + + +# agent used to extract the duration from the string in `CleanEvent` +duration_agent = Agent( + 'openai:gpt-4o', + result_type=timedelta, + system_prompt='Extract the duration from the string as an ISO 8601 interval.', +) + + +class CleanEvent(BaseNode[State, RawEventDetails]): + async def run(self, ctx: GraphContext[State]) -> InspectEvent: + start_ts, duration = await asyncio.gather( + timestamp_agent.run(self.input_data.start_ts), + duration_agent.run(self.input_data.duration), + ) + return InspectEvent( + EventDetails( + title=self.input_data.title, + location=self.input_data.location, + start_ts=start_ts.data, + end_ts=start_ts.data + duration.data, + ) + ) + + +class InspectEvent(BaseNode[State, EventDetails, EventDetails | None]): + async def run( + self, ctx: GraphContext[State] + ) -> ExtractEvent | End[EventDetails | None]: + now = datetime.now() + if self.input_data.start_ts.tzinfo is not None: + now = now.astimezone(self.input_data.start_ts.tzinfo) + + if self.input_data.start_ts > now: + return End(self.input_data) + ctx.state.attempt += 1 + if ctx.state.attempt > 2: + return End(None) + else: + ctx.state.skip_events.append(self.input_data.title) + return ExtractEvent(None) + + +graph = Graph[State, None, EventDetails | None]( + ExtractEvent, + CleanEvent, + InspectEvent, +) +print(graph.mermaid_code()) + +email = """ +Hi Samuel, + +I hope this message finds you well! I wanted to share a quick update on our recent and upcoming team events. + +Firstly, a big thank you to everyone who participated in last month's +Team Building Retreat held on November 15th 2024 for 1 day. +It was a fantastic opportunity to enhance our collaboration and communication skills while having fun. Your +feedback was incredibly positive, and we're already planning to make the next retreat even better! + +Looking ahead, I'm excited to invite you all to our Annual Year-End Gala on January 20th 2025. +This event will be held at the Grand City Ballroom starting at 6 PM until 8pm. It promises to be an evening full +of entertainment, good food, and great company, celebrating the achievements and hard work of our amazing team +over the past year. + +Please mark your calendars and RSVP by January 10th. I hope to see all of you there! 
+ +Best regards, +""" + + +async def main(): + state = State(email_content=email) + result, history = await graph.run(None, state) + debug(result, history) + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/pydantic_ai_graph/pydantic_ai_graph/__init__.py b/pydantic_ai_graph/pydantic_ai_graph/__init__.py index aaef92ef7a..6a6a4ef9f4 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/__init__.py +++ b/pydantic_ai_graph/pydantic_ai_graph/__init__.py @@ -1,5 +1,5 @@ from .graph import Graph from .nodes import BaseNode, End, GraphContext -from .state import Snapshot +from .state import AbstractState, Snapshot -__all__ = 'BaseNode', 'End', 'GraphContext', 'Graph', 'Snapshot' +__all__ = 'BaseNode', 'End', 'GraphContext', 'Graph', 'Snapshot', 'AbstractState' diff --git a/pydantic_ai_graph/pydantic_ai_graph/_utils.py b/pydantic_ai_graph/pydantic_ai_graph/_utils.py index b58e142f6e..d7474494a5 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/_utils.py +++ b/pydantic_ai_graph/pydantic_ai_graph/_utils.py @@ -35,3 +35,13 @@ def comma_and(items: list[str]) -> str: else: # oxford comma ¯\_(ツ)_/¯ return ', '.join(items[:-1]) + ', and ' + items[-1] + + +_NoneType = type(None) + + +def type_arg_name(arg: Any) -> str: + if arg is _NoneType: + return 'None' + else: + return arg.__name__ diff --git a/pydantic_ai_graph/pydantic_ai_graph/graph.py b/pydantic_ai_graph/pydantic_ai_graph/graph.py index acccccbbbe..5774543a91 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/graph.py +++ b/pydantic_ai_graph/pydantic_ai_graph/graph.py @@ -13,14 +13,7 @@ from typing_extensions import TypeVar, assert_never from . import _utils -from .nodes import ( - BaseNode, - DepsT, - End, - GraphContext, - GraphOutputT, - NodeDef, -) +from .nodes import BaseNode, End, GraphContext, GraphOutputT, NodeDef from .state import Snapshot, StateT __all__ = ('Graph',) @@ -31,18 +24,19 @@ # noinspection PyTypeHints @dataclass(init=False) -class Graph(Generic[GraphInputT, GraphOutputT, DepsT, StateT]): +class Graph(Generic[StateT, GraphInputT, GraphOutputT]): """Definition of a graph.""" - first_node: NodeDef[Any, Any, DepsT, StateT] - nodes: dict[str, NodeDef[Any, Any, DepsT, StateT]] + first_node: NodeDef[StateT, Any, Any] + nodes: dict[str, NodeDef[StateT, Any, Any]] name: str | None def __init__( self, - first_node: type[BaseNode[GraphInputT, Any, DepsT, StateT]], - *other_nodes: type[BaseNode[Any, GraphOutputT, DepsT, StateT]], + first_node: type[BaseNode[StateT, GraphInputT, GraphOutputT]], + *other_nodes: type[BaseNode[StateT, Any, GraphOutputT]], name: str | None = None, + state_type: type[StateT] | None = None, ): parent_namespace = get_parent_namespace(inspect.currentframe()) self.first_node = first_node.get_node_def(parent_namespace) @@ -57,13 +51,12 @@ def __init__( async def run( self, input_data: GraphInputT, - deps: DepsT = None, state: StateT = None, history: list[Snapshot] | None = None, ) -> tuple[GraphOutputT, list[Snapshot]]: current_node_def = self.first_node current_node = current_node_def.node(input_data) - ctx = GraphContext(deps, state) + ctx = GraphContext(state) if history: run_history = history[:] else: @@ -123,6 +116,8 @@ def mermaid_code(self) -> str: # order of destination nodes should match their order in `self.nodes` node_order = {nid: index for index, nid in enumerate(self.nodes.keys())} for node_id, node in self.nodes.items(): + if node_id == self.first_node.node_id: + lines.append(f' START --> {node_id}') if node.dest_any: for next_node_id in self.nodes: lines.append(f' {node_id} --> 
{next_node_id}') @@ -141,9 +136,9 @@ def mermaid_image(self, mermaid_ink_params: dict[str, str | int] | None = None) response.raise_for_status() return response.content - def mermaid_save(self, path: Path, mermaid_ink_params: dict[str, str | int] | None = None) -> None: + def mermaid_save(self, path: Path | str, mermaid_ink_params: dict[str, str | int] | None = None) -> None: image_data = self.mermaid_image(mermaid_ink_params) - path.write_bytes(image_data) + Path(path).write_bytes(image_data) def _check(self): bad_edges: dict[str, list[str]] = {} diff --git a/pydantic_ai_graph/pydantic_ai_graph/nodes.py b/pydantic_ai_graph/pydantic_ai_graph/nodes.py index 376d9db1e0..9ed72dfc23 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/nodes.py +++ b/pydantic_ai_graph/pydantic_ai_graph/nodes.py @@ -13,7 +13,6 @@ __all__ = ( 'NodeInputT', 'GraphOutputT', - 'DepsT', 'GraphContext', 'End', 'BaseNode', @@ -22,15 +21,13 @@ NodeInputT = TypeVar('NodeInputT', default=Any) GraphOutputT = TypeVar('GraphOutputT', default=Any) -DepsT = TypeVar('DepsT', default=None) # noinspection PyTypeHints @dataclass -class GraphContext(Generic[DepsT, StateT]): +class GraphContext(Generic[StateT]): """Context for a graph.""" - deps: DepsT state: StateT @@ -48,17 +45,17 @@ class _BaseNodeMeta(ABCMeta): def __repr__(cls): base: Any = cls.__orig_bases__[0] # type: ignore args = get_args(base) - if len(args) == 4 and args[3] is None: - if args[2] is None: - args = args[:2] + if len(args) == 3 and args[2] is Any: + if args[1] is Any: + args = args[:1] else: - args = args[:3] - args = ', '.join(a.__name__ for a in args) + args = args[:2] + args = ', '.join(_utils.type_arg_name(a) for a in args) return f'{cls.__name__}({base.__name__}[{args}])' # noinspection PyTypeHints -class BaseNode(Generic[NodeInputT, GraphOutputT, DepsT, StateT], metaclass=_BaseNodeMeta): +class BaseNode(Generic[StateT, NodeInputT, GraphOutputT], metaclass=_BaseNodeMeta): """Base class for a node.""" node_id: ClassVar[str | None] = None @@ -68,7 +65,7 @@ def __init__(self, input_data: NodeInputT) -> None: self.input_data = input_data @abstractmethod - async def run(self, ctx: GraphContext[DepsT, StateT]) -> BaseNode[Any, Any, DepsT, StateT] | End[GraphOutputT]: ... + async def run(self, ctx: GraphContext[StateT]) -> BaseNode[StateT, Any, Any] | End[GraphOutputT]: ... @classmethod @cache @@ -76,7 +73,7 @@ def get_id(cls) -> str: return cls.node_id or cls.__name__ @classmethod - def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[Any, Any, DepsT, StateT]: + def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, Any, Any]: type_hints = get_type_hints(cls.run, localns=local_ns) next_node_ids: set[str] = set() can_end: bool = False @@ -103,13 +100,13 @@ def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[Any, Any, Deps # noinspection PyTypeHints @dataclass -class NodeDef(ABC, Generic[NodeInputT, GraphOutputT, DepsT, StateT]): +class NodeDef(ABC, Generic[StateT, NodeInputT, GraphOutputT]): """Definition of a node. Used by [`Graph`][pydantic_ai_graph.graph.Graph] store information about a node. 
""" - node: type[BaseNode[NodeInputT, GraphOutputT, DepsT, StateT]] + node: type[BaseNode[StateT, NodeInputT, GraphOutputT]] node_id: str next_node_ids: set[str] can_end: bool diff --git a/pydantic_ai_graph/pydantic_ai_graph/state.py b/pydantic_ai_graph/pydantic_ai_graph/state.py index 133085ec8b..cf64a32155 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/state.py +++ b/pydantic_ai_graph/pydantic_ai_graph/state.py @@ -3,6 +3,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from datetime import datetime +from typing import Union from typing_extensions import TypeVar @@ -18,7 +19,7 @@ def serialize(self) -> bytes | None: raise NotImplementedError -StateT = TypeVar('StateT', None, AbstractState, default=None) +StateT = TypeVar('StateT', bound=Union[None, AbstractState], default=None) @dataclass diff --git a/tests/test_graph.py b/tests/test_graph.py index af27f03cde..2e579f6410 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -13,22 +13,22 @@ async def test_graph(): - class Float2String(BaseNode[float]): + class Float2String(BaseNode[None, float]): async def run(self, ctx: GraphContext) -> String2Length: return String2Length(str(self.input_data)) - class String2Length(BaseNode[str]): + class String2Length(BaseNode[None, str]): async def run(self, ctx: GraphContext) -> Double: return Double(len(self.input_data)) - class Double(BaseNode[int, int]): + class Double(BaseNode[None, int, int]): async def run(self, ctx: GraphContext) -> String2Length | End[int]: if self.input_data == 7: return String2Length('x' * 21) else: return End(self.input_data * 2) - g = Graph[float, int]( + g = Graph[None, float, int]( Float2String, String2Length, Double, diff --git a/tests/typed_graph.py b/tests/typed_graph.py index e6974e3070..806347f8fb 100644 --- a/tests/typed_graph.py +++ b/tests/typed_graph.py @@ -1,63 +1,60 @@ from __future__ import annotations as _annotations +from dataclasses import dataclass from typing import assert_type from pydantic_ai_graph import BaseNode, End, Graph, GraphContext -class Float2String(BaseNode[float]): +class Float2String(BaseNode[None, float]): async def run(self, ctx: GraphContext) -> String2Length: return String2Length(str(self.input_data)) -class String2Length(BaseNode[str]): +class String2Length(BaseNode[None, str]): async def run(self, ctx: GraphContext) -> Double: return Double(len(self.input_data)) -class Double(BaseNode[int, int]): - async def run(self, ctx: GraphContext) -> String2Length | End[int]: +@dataclass +class X: + v: int + + +class Double(BaseNode[None, int, X]): + async def run(self, ctx: GraphContext) -> String2Length | End[X]: if self.input_data == 7: return String2Length('x' * 21) else: - return End(self.input_data * 2) + return End(X(self.input_data * 2)) -def use_double(node: BaseNode[int, int]) -> None: - """Shoe that `Double` is valid as a `BaseNode[int, int]`.""" +def use_double(node: BaseNode[None, int, X]) -> None: + """Shoe that `Double` is valid as a `BaseNode[None, int, X]`.""" + print(node) use_double(Double(1)) -g1 = Graph[float, int]( +g1 = Graph[None, float, X]( Float2String, String2Length, Double, ) -assert_type(g1, Graph[float, int]) +assert_type(g1, Graph[None, float, X]) -g2 = Graph(Float2String, Double) -assert_type(g2, Graph[float, int]) +g2 = Graph(Double) +assert_type(g2, Graph[None, int, X]) g3 = Graph( - Float2String, - Double, - String2Length, -) -MYPY = False -if MYPY: - # with mypy the presence of `String2Length` makes the output type Any - assert_type(g3, Graph[float]) # pyright: 
ignore[reportAssertTypeFailure] -else: - # pyright works correct and uses `Double` to infer the output type - assert_type(g3, Graph[float, int]) - -g4 = Graph( Float2String, String2Length, Double, ) # because String2Length came before Double, the output type is Any -assert_type(g4, Graph[float]) +assert_type(g3, Graph[None, float]) + +Graph[None, float, bytes](Float2String, String2Length, Double) # type: ignore[arg-type] +Graph[None, int, str](Double) # type: ignore[arg-type] From 5cf3ad0fb1203ee290209631c02992c0132e95d7 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 23 Dec 2024 00:13:37 +0000 Subject: [PATCH 06/57] fix dependencies --- .github/workflows/ci.yml | 4 ++-- pydantic_ai_graph/pydantic_ai_graph/graph.py | 2 +- pydantic_ai_graph/pydantic_ai_graph/state.py | 10 ++++++++-- pydantic_ai_slim/pyproject.toml | 4 ++++ pyproject.toml | 3 ++- tests/test_graph.py | 4 ++-- uv.lock | 8 ++++++-- 7 files changed, 25 insertions(+), 10 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0b6b855f4f..7de47d3257 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -128,8 +128,8 @@ jobs: - run: mkdir coverage - # run tests with just `pydantic-ai-slim` dependencies - - run: uv run --package pydantic-ai-slim coverage run -m pytest + # run tests with just `pydantic-ai-slim` and `pydantic-ai-graph` dependencies + - run: uv run --package pydantic-ai-slim --package pydantic-ai-graph coverage run -m pytest env: COVERAGE_FILE: coverage/.coverage.${{ runner.os }}-py${{ matrix.python-version }}-slim diff --git a/pydantic_ai_graph/pydantic_ai_graph/graph.py b/pydantic_ai_graph/pydantic_ai_graph/graph.py index 5774543a91..e0a9d10f8e 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/graph.py +++ b/pydantic_ai_graph/pydantic_ai_graph/graph.py @@ -79,7 +79,7 @@ async def run( if isinstance(next_node, End): if current_node_def.can_end: run_history.append( - Snapshot.from_state(current_node_def.node_id, None, start_ts, duration, ctx.state) + Snapshot.from_state(current_node_def.node_id, 'END', start_ts, duration, ctx.state) ) run_span.set_attribute('history', run_history) return next_node.data, run_history diff --git a/pydantic_ai_graph/pydantic_ai_graph/state.py b/pydantic_ai_graph/pydantic_ai_graph/state.py index cf64a32155..00789887ed 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/state.py +++ b/pydantic_ai_graph/pydantic_ai_graph/state.py @@ -27,14 +27,14 @@ class Snapshot: """Snapshot of a graph.""" last_node_id: str - next_node_id: str | None + next_node_id: str start_ts: datetime duration: float state: bytes | None = None @classmethod def from_state( - cls, last_node_id: str, next_node_id: str | None, start_ts: datetime, duration: float, state: StateT + cls, last_node_id: str, next_node_id: str, start_ts: datetime, duration: float, state: StateT ) -> Snapshot: return cls( last_node_id=last_node_id, @@ -43,3 +43,9 @@ def from_state( duration=duration, state=state.serialize() if state is not None else None, ) + + def summary(self) -> str: + s = f'{self.last_node_id} -> {self.next_node_id}' + if self.duration > 1e-5: + s += f' ({self.duration:.6f}s)' + return s diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index abc4c0ba99..693d27aaec 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -42,6 +42,7 @@ dependencies = [ ] [project.optional-dependencies] +graph = ["pydantic-ai-graph==0.0.14"] openai = ["openai>=1.54.3"] vertexai = ["google-auth>=2.36.0", "requests>=2.32.3"] anthropic = 
["anthropic>=0.40.0"] @@ -65,3 +66,6 @@ dev = [ [tool.hatch.build.targets.wheel] packages = ["pydantic_ai"] + +[tool.uv.sources] +pydantic-ai-graph = { workspace = true } diff --git a/pyproject.toml b/pyproject.toml index 47fcaf18bb..354b4b4448 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ classifiers = [ ] requires-python = ">=3.9" -dependencies = ["pydantic-ai-slim[openai,vertexai,groq,anthropic,mistral]==0.0.18"] +dependencies = ["pydantic-ai-slim[graph,openai,vertexai,groq,anthropic,mistral]==0.0.18"] [project.urls] Homepage = "https://ai.pydantic.dev" @@ -51,6 +51,7 @@ logfire = ["logfire>=2.3"] [tool.uv.sources] pydantic-ai-slim = { workspace = true } +pydantic-ai-graph = { workspace = true } pydantic-ai-examples = { workspace = true } [tool.uv.workspace] diff --git a/tests/test_graph.py b/tests/test_graph.py index 2e579f6410..dafdff5dce 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -52,7 +52,7 @@ async def run(self, ctx: GraphContext) -> String2Length | End[int]: ), Snapshot( last_node_id='Double', - next_node_id=None, + next_node_id='END', start_ts=IsNow(tz=timezone.utc), duration=IsFloat(gt=0, lt=1e-5), ), @@ -89,7 +89,7 @@ async def run(self, ctx: GraphContext) -> String2Length | End[int]: ), Snapshot( last_node_id='Double', - next_node_id=None, + next_node_id='END', start_ts=IsNow(tz=timezone.utc), duration=IsFloat(gt=0, lt=1e-5), ), diff --git a/uv.lock b/uv.lock index 38dff08aa3..2c36b984b3 100644 --- a/uv.lock +++ b/uv.lock @@ -2020,7 +2020,7 @@ name = "pydantic-ai" version = "0.0.18" source = { editable = "." } dependencies = [ - { name = "pydantic-ai-slim", extra = ["anthropic", "groq", "mistral", "openai", "vertexai"] }, + { name = "pydantic-ai-slim", extra = ["anthropic", "graph", "groq", "mistral", "openai", "vertexai"] }, ] [package.optional-dependencies] @@ -2049,7 +2049,7 @@ lint = [ requires-dist = [ { name = "logfire", marker = "extra == 'logfire'", specifier = ">=2.3" }, { name = "pydantic-ai-examples", marker = "extra == 'examples'", editable = "examples" }, - { name = "pydantic-ai-slim", extras = ["openai", "vertexai", "groq", "anthropic", "mistral"], editable = "pydantic_ai_slim" }, + { name = "pydantic-ai-slim", extras = ["graph", "openai", "vertexai", "groq", "anthropic", "mistral"], editable = "pydantic_ai_slim" }, ] [package.metadata.requires-dev] @@ -2129,6 +2129,9 @@ dependencies = [ anthropic = [ { name = "anthropic" }, ] +graph = [ + { name = "pydantic-ai-graph" }, +] groq = [ { name = "groq" }, ] @@ -2174,6 +2177,7 @@ requires-dist = [ { name = "mistralai", marker = "extra == 'mistral'", specifier = ">=1.2.5" }, { name = "openai", marker = "extra == 'openai'", specifier = ">=1.54.3" }, { name = "pydantic", specifier = ">=2.10" }, + { name = "pydantic-ai-graph", marker = "extra == 'graph'", editable = "pydantic_ai_graph" }, { name = "requests", marker = "extra == 'vertexai'", specifier = ">=2.32.3" }, ] From 544b6c804bafc8e3b6a52756fa9a452e200db232 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 23 Dec 2024 00:17:37 +0000 Subject: [PATCH 07/57] fix ci deps --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7de47d3257..12439b8320 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -129,7 +129,7 @@ jobs: - run: mkdir coverage # run tests with just `pydantic-ai-slim` and `pydantic-ai-graph` dependencies - - run: uv run --package pydantic-ai-slim --package pydantic-ai-graph coverage run -m 
pytest + - run: uv run --package pydantic-ai-slim --extra graph coverage run -m pytest env: COVERAGE_FILE: coverage/.coverage.${{ runner.os }}-py${{ matrix.python-version }}-slim From 7288cc90013d5eb0a41b8be4f8425dc13c81d811 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 23 Dec 2024 00:23:19 +0000 Subject: [PATCH 08/57] fix tests for other versions --- pydantic_ai_graph/pydantic_ai_graph/_utils.py | 6 +++++- tests/test_graph.py | 3 ++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/pydantic_ai_graph/pydantic_ai_graph/_utils.py b/pydantic_ai_graph/pydantic_ai_graph/_utils.py index d7474494a5..1136108809 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/_utils.py +++ b/pydantic_ai_graph/pydantic_ai_graph/_utils.py @@ -1,6 +1,10 @@ +from __future__ import annotations as _annotations + import sys import types -from typing import Any, TypeAliasType, Union, get_args, get_origin +from typing import Any, Union, get_args, get_origin + +from typing_extensions import TypeAliasType def get_union_args(tp: Any) -> tuple[Any, ...]: diff --git a/tests/test_graph.py b/tests/test_graph.py index dafdff5dce..5045d81027 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,6 +1,7 @@ from __future__ import annotations as _annotations from datetime import timezone +from typing import Union import pytest from inline_snapshot import snapshot @@ -22,7 +23,7 @@ async def run(self, ctx: GraphContext) -> Double: return Double(len(self.input_data)) class Double(BaseNode[None, int, int]): - async def run(self, ctx: GraphContext) -> String2Length | End[int]: + async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noqa: UP007 if self.input_data == 7: return String2Length('x' * 21) else: From 0572bdaa39860026dcc2133de9f12f26d9358504 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 23 Dec 2024 00:27:05 +0000 Subject: [PATCH 09/57] change node test times --- tests/test_graph.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/test_graph.py b/tests/test_graph.py index 5045d81027..81b54fa070 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -43,19 +43,19 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq last_node_id='Float2String', next_node_id='String2Length', start_ts=IsNow(tz=timezone.utc), - duration=IsFloat(gt=0, lt=1e-5), + duration=IsFloat(gt=0, lt=1e-3), ), Snapshot( last_node_id='String2Length', next_node_id='Double', start_ts=IsNow(tz=timezone.utc), - duration=IsFloat(gt=0, lt=1e-5), + duration=IsFloat(gt=0, lt=1e-3), ), Snapshot( last_node_id='Double', next_node_id='END', start_ts=IsNow(tz=timezone.utc), - duration=IsFloat(gt=0, lt=1e-5), + duration=IsFloat(gt=0, lt=1e-3), ), ] ) @@ -68,31 +68,31 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq last_node_id='Float2String', next_node_id='String2Length', start_ts=IsNow(tz=timezone.utc), - duration=IsFloat(gt=0, lt=1e-5), + duration=IsFloat(gt=0, lt=1e-3), ), Snapshot( last_node_id='String2Length', next_node_id='Double', start_ts=IsNow(tz=timezone.utc), - duration=IsFloat(gt=0, lt=1e-5), + duration=IsFloat(gt=0, lt=1e-3), ), Snapshot( last_node_id='Double', next_node_id='String2Length', start_ts=IsNow(tz=timezone.utc), - duration=IsFloat(gt=0, lt=1e-5), + duration=IsFloat(gt=0, lt=1e-3), ), Snapshot( last_node_id='String2Length', next_node_id='Double', start_ts=IsNow(tz=timezone.utc), - duration=IsFloat(gt=0, lt=1e-5), + duration=IsFloat(gt=0, lt=1e-3), ), Snapshot( last_node_id='Double', 
next_node_id='END', start_ts=IsNow(tz=timezone.utc), - duration=IsFloat(gt=0, lt=1e-5), + duration=IsFloat(gt=0, lt=1e-3), ), ] ) From d10dc87842806f08b82f85dee3dc9ce5f9c2c73b Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Thu, 2 Jan 2025 11:17:46 -0700 Subject: [PATCH 10/57] pydantic-ai-graph - simplify public generics (#539) --- .../email_extract_graph.py | 31 +- .../pydantic_ai_graph/__init__.py | 17 +- pydantic_ai_graph/pydantic_ai_graph/_utils.py | 14 + pydantic_ai_graph/pydantic_ai_graph/graph.py | 426 +++++++++++++----- pydantic_ai_graph/pydantic_ai_graph/nodes.py | 90 ++-- pydantic_ai_graph/pydantic_ai_graph/state.py | 74 +-- tests/test_graph.py | 103 +++-- tests/typed_agent.py | 4 +- tests/typed_graph.py | 50 +- 9 files changed, 527 insertions(+), 282 deletions(-) diff --git a/examples/pydantic_ai_examples/email_extract_graph.py b/examples/pydantic_ai_examples/email_extract_graph.py index c91613b4ec..8465f5bcd9 100644 --- a/examples/pydantic_ai_examples/email_extract_graph.py +++ b/examples/pydantic_ai_examples/email_extract_graph.py @@ -1,6 +1,7 @@ from __future__ import annotations as _annotations import asyncio +from dataclasses import dataclass from datetime import datetime, timedelta import logfire @@ -48,7 +49,8 @@ def extract_system_prompt(ctx: RunContext[list[str]]): return prompt -class ExtractEvent(BaseNode[State, None]): +@dataclass +class ExtractEvent(BaseNode[State]): async def run(self, ctx: GraphContext[State]) -> CleanEvent: event = await extract_agent.run( ctx.state.email_content, deps=ctx.state.skip_events @@ -73,7 +75,10 @@ def timestamp_system_prompt(): ) -class CleanEvent(BaseNode[State, RawEventDetails]): +@dataclass +class CleanEvent(BaseNode[State]): + input_data: RawEventDetails + async def run(self, ctx: GraphContext[State]) -> InspectEvent: start_ts, duration = await asyncio.gather( timestamp_agent.run(self.input_data.start_ts), @@ -89,7 +94,10 @@ async def run(self, ctx: GraphContext[State]) -> InspectEvent: ) -class InspectEvent(BaseNode[State, EventDetails, EventDetails | None]): +@dataclass +class InspectEvent(BaseNode[State, EventDetails | None]): + input_data: EventDetails + async def run( self, ctx: GraphContext[State] ) -> ExtractEvent | End[EventDetails | None]: @@ -104,15 +112,18 @@ async def run( return End(None) else: ctx.state.skip_events.append(self.input_data.title) - return ExtractEvent(None) + return ExtractEvent() -graph = Graph[State, None, EventDetails | None]( - ExtractEvent, - CleanEvent, - InspectEvent, +graph = Graph[State, EventDetails | None]( + nodes=( + ExtractEvent, + CleanEvent, + InspectEvent, + ) ) -print(graph.mermaid_code()) +graph_runner = graph.get_runner(ExtractEvent) +print(graph_runner.mermaid_code()) email = """ Hi Samuel, @@ -137,7 +148,7 @@ async def run( async def main(): state = State(email_content=email) - result, history = await graph.run(None, state) + result, history = await graph_runner.run(state, None) debug(result, history) diff --git a/pydantic_ai_graph/pydantic_ai_graph/__init__.py b/pydantic_ai_graph/pydantic_ai_graph/__init__.py index 6a6a4ef9f4..b7d79a6009 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/__init__.py +++ b/pydantic_ai_graph/pydantic_ai_graph/__init__.py @@ -1,5 +1,16 @@ -from .graph import Graph +from .graph import Graph, GraphRun, GraphRunner from .nodes import BaseNode, End, GraphContext -from .state import AbstractState, Snapshot +from .state import AbstractState, EndEvent, Step, StepOrEnd -__all__ = 'BaseNode', 'End', 
'GraphContext', 'Graph', 'Snapshot', 'AbstractState' +__all__ = ( + 'Graph', + 'GraphRunner', + 'GraphRun', + 'BaseNode', + 'End', + 'GraphContext', + 'AbstractState', + 'EndEvent', + 'StepOrEnd', + 'Step', +) diff --git a/pydantic_ai_graph/pydantic_ai_graph/_utils.py b/pydantic_ai_graph/pydantic_ai_graph/_utils.py index 1136108809..d1b4671800 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/_utils.py +++ b/pydantic_ai_graph/pydantic_ai_graph/_utils.py @@ -49,3 +49,17 @@ def type_arg_name(arg: Any) -> str: return 'None' else: return arg.__name__ + + +def get_parent_namespace(frame: types.FrameType | None) -> dict[str, Any] | None: + """Attempt to get the namespace where the graph was defined. + + If the graph is defined with generics `Graph[a, b]` then another frame is inserted, and we have to skip that + to get the correct namespace. + """ + if frame is not None: + if back := frame.f_back: + if back.f_code.co_filename.endswith('/typing.py'): + return get_parent_namespace(back) + else: + return back.f_locals diff --git a/pydantic_ai_graph/pydantic_ai_graph/graph.py b/pydantic_ai_graph/pydantic_ai_graph/graph.py index e0a9d10f8e..a6e09e887a 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/graph.py +++ b/pydantic_ai_graph/pydantic_ai_graph/graph.py @@ -2,169 +2,359 @@ import base64 import inspect -import types -from dataclasses import dataclass -from datetime import datetime, timezone +from dataclasses import dataclass, field from pathlib import Path from time import perf_counter -from typing import TYPE_CHECKING, Any, Generic +from typing import TYPE_CHECKING, Annotated, Any, Generic import logfire_api -from typing_extensions import TypeVar, assert_never +from annotated_types import Ge, Le +from typing_extensions import Literal, Never, ParamSpec, Protocol, TypeVar, assert_never from . import _utils -from .nodes import BaseNode, End, GraphContext, GraphOutputT, NodeDef -from .state import Snapshot, StateT +from ._utils import get_parent_namespace +from .nodes import BaseNode, End, GraphContext, NodeDef +from .state import EndEvent, StateT, Step, StepOrEnd -__all__ = ('Graph',) +__all__ = ('Graph', 'GraphRun', 'GraphRunner') _logfire = logfire_api.Logfire(otel_scope='pydantic-ai-graph') -GraphInputT = TypeVar('GraphInputT', default=Any) + +RunSignatureT = ParamSpec('RunSignatureT') +RunEndT = TypeVar('RunEndT', default=None) +NodeRunEndT = TypeVar('NodeRunEndT', covariant=True, default=Never) + + +class StartNodeProtocol(Protocol[RunSignatureT, StateT, NodeRunEndT]): + def get_id(self) -> str: ... + def __call__(self, *args: RunSignatureT.args, **kwargs: RunSignatureT.kwargs) -> BaseNode[StateT, NodeRunEndT]: ... -# noinspection PyTypeHints @dataclass(init=False) -class Graph(Generic[StateT, GraphInputT, GraphOutputT]): +class Graph(Generic[StateT, RunEndT]): """Definition of a graph.""" - first_node: NodeDef[StateT, Any, Any] - nodes: dict[str, NodeDef[StateT, Any, Any]] name: str | None + nodes: tuple[type[BaseNode[StateT, RunEndT]], ...] + node_defs: dict[str, NodeDef[StateT, RunEndT]] def __init__( self, - first_node: type[BaseNode[StateT, GraphInputT, GraphOutputT]], - *other_nodes: type[BaseNode[StateT, Any, GraphOutputT]], name: str | None = None, + nodes: tuple[type[BaseNode[StateT, RunEndT]], ...] 
= (), state_type: type[StateT] | None = None, ): + self.name = name + + _nodes_by_id: dict[str, type[BaseNode[StateT, RunEndT]]] = {} + for node in nodes: + node_id = node.get_id() + if (existing_node := _nodes_by_id.get(node_id)) and existing_node is not node: + raise ValueError(f'Node ID "{node_id}" is not unique — found in {existing_node} and {node}') + else: + _nodes_by_id[node_id] = node + self.nodes = tuple(_nodes_by_id.values()) + parent_namespace = get_parent_namespace(inspect.currentframe()) - self.first_node = first_node.get_node_def(parent_namespace) - self.nodes = nodes = {self.first_node.node_id: self.first_node} - for node in other_nodes: - node_def = node.get_node_def(parent_namespace) - nodes[node_def.node_id] = node_def + self.node_defs: dict[str, NodeDef[StateT, RunEndT]] = {} + for node in self.nodes: + self.node_defs[node.get_id()] = node.get_node_def(parent_namespace) - self._check() - self.name = name + self._validate_edges() + + def _validate_edges(self): + known_node_ids = set(self.node_defs.keys()) + bad_edges: dict[str, list[str]] = {} + + for node_id, node_def in self.node_defs.items(): + node_bad_edges = node_def.next_node_ids - known_node_ids + for bad_edge in node_bad_edges: + bad_edges.setdefault(bad_edge, []).append(f'"{node_id}"') + + if bad_edges: + bad_edges_list = [f'"{k}" is referenced by {_utils.comma_and(v)}' for k, v in bad_edges.items()] + if len(bad_edges_list) == 1: + raise ValueError(f'{bad_edges_list[0]} but not included in the graph.') + else: + b = '\n'.join(f' {be}' for be in bad_edges_list) + raise ValueError(f'Nodes are referenced in the graph but not included in the graph:\n{b}') async def run( + self, state: StateT, node: BaseNode[StateT, RunEndT] + ) -> tuple[RunEndT, list[StepOrEnd[StateT, RunEndT]]]: + if not isinstance(node, self.nodes): + raise ValueError(f'Node "{node}" is not in the graph.') + run = GraphRun[StateT, RunEndT](state=state) + # TODO: Infer the graph name properly + result = await run.run(self.name or 'graph', node) + history = run.history + return result, history + + def get_runner( self, - input_data: GraphInputT, - state: StateT = None, - history: list[Snapshot] | None = None, - ) -> tuple[GraphOutputT, list[Snapshot]]: - current_node_def = self.first_node - current_node = current_node_def.node(input_data) - ctx = GraphContext(state) - if history: - run_history = history[:] - else: - run_history = [] + first_node: StartNodeProtocol[RunSignatureT, StateT, RunEndT], + ) -> GraphRunner[RunSignatureT, StateT, RunEndT]: + return GraphRunner( + graph=self, + first_node=first_node, + ) + + +@dataclass +class GraphRunner(Generic[RunSignatureT, StateT, RunEndT]): + """Runner for a graph. + + This is a separate class from Graph so that you can get a type-safe runner from a graph definition + without needing to manually annotate the paramspec of the start node. 
+ """ + + graph: Graph[StateT, RunEndT] + first_node: StartNodeProtocol[RunSignatureT, StateT, RunEndT] + + def __post_init__(self): + if self.first_node not in self.graph.nodes: + raise ValueError(f'Start node "{self.first_node}" is not in the graph.') + + async def run( + self, state: StateT, /, *args: RunSignatureT.args, **kwargs: RunSignatureT.kwargs + ) -> tuple[RunEndT, list[StepOrEnd[StateT, RunEndT]]]: + run = GraphRun[StateT, RunEndT](state=state) + # TODO: Infer the graph name properly + result = await run.run(self.graph.name or 'graph', self.first_node(*args, **kwargs)) + history = run.history + return result, history + + def mermaid_code(self) -> str: + return mermaid_code(self.graph, self.first_node) + + def mermaid_image( + self, + image_type: Literal['jpeg', 'png', 'webp', 'svg', 'pdf'] | str | None = None, + pdf_fit: bool = False, + pdf_landscape: bool = False, + pdf_paper: Literal['letter', 'legal', 'tabloid', 'ledger', 'a0', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6'] + | str + | None = None, + bg_color: str | None = None, + theme: Literal['default', 'neutral', 'dark', 'forest'] | str | None = None, + width: int | None = None, + height: int | None = None, + scale: Annotated[float, Ge(1), Le(3)] | None = None, + ) -> bytes: + return mermaid_image( + self.graph, + self.first_node, + image_type=image_type, + pdf_fit=pdf_fit, + pdf_landscape=pdf_landscape, + pdf_paper=pdf_paper, + bg_color=bg_color, + theme=theme, + width=width, + height=height, + scale=scale, + ) + + def mermaid_save( + self, + path: Path | str, + image_type: Literal['jpeg', 'png', 'webp', 'svg', 'pdf'] | str | None = None, + pdf_fit: bool = False, + pdf_landscape: bool = False, + pdf_paper: Literal['letter', 'legal', 'tabloid', 'ledger', 'a0', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6'] + | str + | None = None, + bg_color: str | None = None, + theme: Literal['default', 'neutral', 'dark', 'forest'] | str | None = None, + width: int | None = None, + height: int | None = None, + scale: Annotated[float, Ge(1), Le(3)] | None = None, + ) -> None: + mermaid_save( + path, + self.graph, + self.first_node, + image_type=image_type, + pdf_fit=pdf_fit, + pdf_landscape=pdf_landscape, + pdf_paper=pdf_paper, + bg_color=bg_color, + theme=theme, + width=width, + height=height, + scale=scale, + ) + + +@dataclass +class GraphRun(Generic[StateT, RunEndT]): + """Stateful run of a graph.""" + + state: StateT + history: list[StepOrEnd[StateT, RunEndT]] = field(default_factory=list) + + async def run(self, graph_name: str, start: BaseNode[StateT, RunEndT], infer_name: bool = True) -> RunEndT: + current_node = start with _logfire.span( - '{graph_name} run {input=}', - graph_name=self.name or 'graph', - input=input_data, - graph=self, + '{graph_name} run {start=}', + graph_name=graph_name, + start=start, ) as run_span: while True: - with _logfire.span('run node {node_id}', node_id=current_node_def.node_id): - start_ts = datetime.now(tz=timezone.utc) - start = perf_counter() - # noinspection PyUnresolvedReferences - next_node = await current_node.run(ctx) - duration = perf_counter() - start - + next_node = await self.step(current_node) if isinstance(next_node, End): - if current_node_def.can_end: - run_history.append( - Snapshot.from_state(current_node_def.node_id, 'END', start_ts, duration, ctx.state) - ) - run_span.set_attribute('history', run_history) - return next_node.data, run_history - else: - raise ValueError(f'Node {current_node_def.node_id} cannot end the graph') + self.history.append(EndEvent(self.state, next_node)) + 
run_span.set_attribute('history', self.history) + return next_node.data elif isinstance(next_node, BaseNode): - next_node_id = next_node.get_id() - run_history.append( - Snapshot.from_state(current_node_def.node_id, next_node_id, start_ts, duration, ctx.state) - ) - try: - next_node_def = self.nodes[next_node_id] - except KeyError as e: - raise ValueError( - f'Node {current_node_def.node_id} cannot go to {next_node_id} which is not in the Graph' - ) from e - - if not current_node_def.dest_any and next_node_id not in current_node_def.next_node_ids: - raise ValueError( - f'Node {current_node_def.node_id} cannot go to {next_node_id} which is not in its ' - f'list of allowed next nodes' - ) - - current_node_def = next_node_def current_node = next_node else: if TYPE_CHECKING: assert_never(next_node) else: - raise TypeError(f'Invalid node type: {type(next_node)} expected BaseNode or End') + raise TypeError(f'Invalid node type: {type(next_node)}. Expected `BaseNode` or `End`.') - def mermaid_code(self) -> str: - lines = ['graph TD'] - # order of destination nodes should match their order in `self.nodes` - node_order = {nid: index for index, nid in enumerate(self.nodes.keys())} - for node_id, node in self.nodes.items(): - if node_id == self.first_node.node_id: - lines.append(f' START --> {node_id}') - if node.dest_any: - for next_node_id in self.nodes: - lines.append(f' {node_id} --> {next_node_id}') - for _, next_node_id in sorted((node_order[nid], nid) for nid in node.next_node_ids): - lines.append(f' {node_id} --> {next_node_id}') - if node.can_end: - lines.append(f' {node_id} --> END') - return '\n'.join(lines) + async def step(self, node: BaseNode[StateT, RunEndT]) -> BaseNode[StateT, RunEndT] | End[RunEndT]: + history_step = Step(self.state, node) + self.history.append(history_step) - def mermaid_image(self, mermaid_ink_params: dict[str, str | int] | None = None) -> bytes: - import httpx + ctx = GraphContext(self.state) + with _logfire.span('run node {node_id}', node_id=node.get_id()): + start = perf_counter() + next_node = await node.run(ctx) + history_step.duration = perf_counter() - start + return next_node - code_base64 = base64.b64encode(self.mermaid_code().encode()).decode() - response = httpx.get(f'https://mermaid.ink/img/{code_base64}', params=mermaid_ink_params) - response.raise_for_status() - return response.content +def mermaid_code( + graph: Graph[Any, Any], start: StartNodeProtocol[..., Any, Any] | tuple[StartNodeProtocol[..., Any, Any], ...] 
= () +) -> str: + if not isinstance(start, tuple): + start = (start,) - def mermaid_save(self, path: Path | str, mermaid_ink_params: dict[str, str | int] | None = None) -> None: - image_data = self.mermaid_image(mermaid_ink_params) - Path(path).write_bytes(image_data) + for node in start: + if node not in graph.nodes: + raise ValueError(f'Start node "{node}" is not in the graph.') - def _check(self): - bad_edges: dict[str, list[str]] = {} - for node in self.nodes.values(): - node_bad_edges = node.next_node_ids - self.nodes.keys() - for bad_edge in node_bad_edges: - bad_edges.setdefault(bad_edge, []).append(f'"{node.node_id}"') + node_order = {node_id: index for index, node_id in enumerate(graph.node_defs)} - if bad_edges: - bad_edges_list = [f'"{k}" is referenced by {_utils.comma_and(v)}' for k, v in bad_edges.items()] - if len(bad_edges_list) == 1: - raise ValueError(f'{bad_edges_list[0]} but not included in the graph.') - else: - b = '\n'.join(f' {be}' for be in bad_edges_list) - raise ValueError(f'Nodes are referenced in the graph but not included in the graph:\n{b}') + lines = ['graph TD'] + for node in graph.nodes: + node_id = node.get_id() + node_def = graph.node_defs[node_id] + if node in start: + lines.append(f' START --> {node_id}') + if node_def.returns_base_node: + for next_node_id in graph.nodes: + lines.append(f' {node_id} --> {next_node_id}') + else: + for _, next_node_id in sorted((node_order[node_id], node_id) for node_id in node_def.next_node_ids): + lines.append(f' {node_id} --> {next_node_id}') + if node_def.returns_end: + lines.append(f' {node_id} --> END') + + return '\n'.join(lines) -def get_parent_namespace(frame: types.FrameType | None) -> dict[str, Any] | None: - """Attempt to get the namespace where the graph was defined. +def mermaid_image( + graph: Graph[Any, Any], + start: StartNodeProtocol[..., Any, Any] | tuple[StartNodeProtocol[..., Any, Any], ...] = (), + image_type: Literal['jpeg', 'png', 'webp', 'svg', 'pdf'] | str | None = None, + pdf_fit: bool = False, + pdf_landscape: bool = False, + pdf_paper: Literal['letter', 'legal', 'tabloid', 'ledger', 'a0', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6'] + | str + | None = None, + bg_color: str | None = None, + theme: Literal['default', 'neutral', 'dark', 'forest'] | str | None = None, + width: int | None = None, + height: int | None = None, + scale: Annotated[float, Ge(1), Le(3)] | None = None, +) -> bytes: + """Generate an image of a Mermaid diagram using mermaid.ink. - If the graph is defined with generics `Graph[a, b]` then another frame is inserted, and we have to skip that - to get the correct namespace. + Args: + graph: The graph to generate the image for. + start: The start node(s) of the graph. + image_type: The image type to generate. If unspecified, the default behavior is `'jpeg'`. + pdf_fit: When using image_type='pdf', whether to fit the diagram to the PDF page. + pdf_landscape: When using image_type='pdf', whether to use landscape orientation for the PDF. + This has no effect if using `pdf_fit`. + pdf_paper: When using image_type='pdf', the paper size of the PDF. + bg_color: The background color of the diagram. If None, the default transparent background is used. + The color value is interpreted as a hexadecimal color code by default (and should not have a leading '#'), + but you can also use named colors by prefixing the value with '!'. + For example, valid choices include `bg_color='!white'` or `bg_color='FF0000'`. + theme: The theme of the diagram. Defaults to 'default'. + width: The width of the diagram. 
+ height: The height of the diagram. + scale: The scale of the diagram. The scale must be a number between 1 and 3, and you can only set + a scale if one or both of width and height are set. """ - if frame is not None: - if back := frame.f_back: - if back.f_code.co_filename.endswith('/typing.py'): - return get_parent_namespace(back) - else: - return back.f_locals + import httpx + + code_base64 = base64.b64encode(mermaid_code(graph, start).encode()).decode() + + params: dict[str, str] = {} + if image_type == 'pdf': + url = f'https://mermaid.ink/pdf/{code_base64}' + if pdf_fit: + params['fit'] = '' + if pdf_landscape: + params['landscape'] = '' + if pdf_paper: + params['paper'] = pdf_paper + else: + url = f'https://mermaid.ink/img/{code_base64}' + + if image_type: + params['type'] = image_type + + if bg_color: + params['bgColor'] = bg_color + if theme: + params['theme'] = theme + if width: + params['width'] = str(width) + if height: + params['height'] = str(height) + if scale: + params['scale'] = str(scale) + + response = httpx.get(url, params=params) + response.raise_for_status() + return response.content + + +def mermaid_save( + path: Path | str, + graph: Graph[Any, Any], + start: StartNodeProtocol[..., Any, Any] | tuple[StartNodeProtocol[..., Any, Any], ...] = (), + image_type: Literal['jpeg', 'png', 'webp', 'svg', 'pdf'] | str | None = None, + pdf_fit: bool = False, + pdf_landscape: bool = False, + pdf_paper: Literal['letter', 'legal', 'tabloid', 'ledger', 'a0', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6'] + | str + | None = None, + bg_color: str | None = None, + theme: Literal['default', 'neutral', 'dark', 'forest'] | str | None = None, + width: int | None = None, + height: int | None = None, + scale: Annotated[float, Ge(1), Le(3)] | None = None, +) -> None: + # TODO: do something with the path file extension, e.g. error if it's incompatible, or use it to specify a param + image_data = mermaid_image( + graph, + start, + image_type=image_type, + pdf_fit=pdf_fit, + pdf_landscape=pdf_landscape, + pdf_paper=pdf_paper, + bg_color=bg_color, + theme=theme, + width=width, + height=height, + scale=scale, + ) + Path(path).write_bytes(image_data) diff --git a/pydantic_ai_graph/pydantic_ai_graph/nodes.py b/pydantic_ai_graph/pydantic_ai_graph/nodes.py index 9ed72dfc23..1a1b8bcd5d 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/nodes.py +++ b/pydantic_ai_graph/pydantic_ai_graph/nodes.py @@ -1,29 +1,21 @@ from __future__ import annotations as _annotations -from abc import ABC, ABCMeta, abstractmethod +from abc import abstractmethod from dataclasses import dataclass from functools import cache -from typing import Any, ClassVar, Generic, get_args, get_origin, get_type_hints +from typing import Any, Generic, get_origin, get_type_hints -from typing_extensions import TypeVar +from typing_extensions import Never, TypeVar from . 
import _utils from .state import StateT -__all__ = ( - 'NodeInputT', - 'GraphOutputT', - 'GraphContext', - 'End', - 'BaseNode', - 'NodeDef', -) +__all__ = ('GraphContext', 'End', 'BaseNode', 'NodeDef') -NodeInputT = TypeVar('NodeInputT', default=Any) -GraphOutputT = TypeVar('GraphOutputT', default=Any) +RunEndT = TypeVar('RunEndT', default=None) +NodeRunEndT = TypeVar('NodeRunEndT', covariant=True, default=Never) -# noinspection PyTypeHints @dataclass class GraphContext(Generic[StateT]): """Context for a graph.""" @@ -31,61 +23,44 @@ class GraphContext(Generic[StateT]): state: StateT -# noinspection PyTypeHints -class End(ABC, Generic[NodeInputT]): +@dataclass +class End(Generic[RunEndT]): """Type to return from a node to signal the end of the graph.""" - __slots__ = ('data',) - - def __init__(self, input_data: NodeInputT) -> None: - self.data = input_data + data: RunEndT -class _BaseNodeMeta(ABCMeta): - def __repr__(cls): - base: Any = cls.__orig_bases__[0] # type: ignore - args = get_args(base) - if len(args) == 3 and args[2] is Any: - if args[1] is Any: - args = args[:1] - else: - args = args[:2] - args = ', '.join(_utils.type_arg_name(a) for a in args) - return f'{cls.__name__}({base.__name__}[{args}])' - - -# noinspection PyTypeHints -class BaseNode(Generic[StateT, NodeInputT, GraphOutputT], metaclass=_BaseNodeMeta): +class BaseNode(Generic[StateT, NodeRunEndT]): """Base class for a node.""" - node_id: ClassVar[str | None] = None - __slots__ = ('input_data',) - - def __init__(self, input_data: NodeInputT) -> None: - self.input_data = input_data - @abstractmethod - async def run(self, ctx: GraphContext[StateT]) -> BaseNode[StateT, Any, Any] | End[GraphOutputT]: ... + async def run(self, ctx: GraphContext[StateT]) -> BaseNode[StateT, Any] | End[NodeRunEndT]: ... @classmethod @cache def get_id(cls) -> str: - return cls.node_id or cls.__name__ + return cls.__name__ @classmethod - def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, Any, Any]: + def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, NodeRunEndT]: type_hints = get_type_hints(cls.run, localns=local_ns) next_node_ids: set[str] = set() - can_end: bool = False - dest_any: bool = False - for return_type in _utils.get_union_args(type_hints['return']): + returns_end: bool = False + returns_base_node: bool = False + try: + return_hint = type_hints['return'] + except KeyError: + raise TypeError(f'Node {cls} is missing a return type hint on its `run` method') + + for return_type in _utils.get_union_args(return_hint): return_type_origin = get_origin(return_type) or return_type - if return_type_origin is BaseNode: - dest_any = True + if return_type_origin is End: + returns_end = True + elif return_type_origin is BaseNode: + # TODO: Should we disallow this? More generally, what do we do about sub-subclasses? + returns_base_node = True elif issubclass(return_type_origin, BaseNode): next_node_ids.add(return_type.get_id()) - elif return_type_origin is End: - can_end = True else: raise TypeError(f'Invalid return type: {return_type}') @@ -93,21 +68,20 @@ def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, Any, A cls, cls.get_id(), next_node_ids, - can_end, - dest_any, + returns_end, + returns_base_node, ) -# noinspection PyTypeHints @dataclass -class NodeDef(ABC, Generic[StateT, NodeInputT, GraphOutputT]): +class NodeDef(Generic[StateT, NodeRunEndT]): """Definition of a node. Used by [`Graph`][pydantic_ai_graph.graph.Graph] store information about a node. 
""" - node: type[BaseNode[StateT, NodeInputT, GraphOutputT]] + node: type[BaseNode[StateT, NodeRunEndT]] node_id: str next_node_ids: set[str] - can_end: bool - dest_any: bool + returns_end: bool + returns_base_node: bool diff --git a/pydantic_ai_graph/pydantic_ai_graph/state.py b/pydantic_ai_graph/pydantic_ai_graph/state.py index 00789887ed..45b01aedfd 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/state.py +++ b/pydantic_ai_graph/pydantic_ai_graph/state.py @@ -1,13 +1,18 @@ from __future__ import annotations as _annotations +import copy from abc import ABC, abstractmethod -from dataclasses import dataclass -from datetime import datetime -from typing import Union +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Generic, Literal, Union -from typing_extensions import TypeVar +from typing_extensions import Never, TypeVar -__all__ = 'AbstractState', 'StateT', 'Snapshot' +__all__ = 'AbstractState', 'StateT', 'Step', 'EndEvent', 'StepOrEnd' + +if TYPE_CHECKING: + from pydantic_ai_graph import BaseNode + from pydantic_ai_graph.nodes import End class AbstractState(ABC): @@ -19,33 +24,40 @@ def serialize(self) -> bytes | None: raise NotImplementedError +RunEndT = TypeVar('RunEndT', default=None) +NodeRunEndT = TypeVar('NodeRunEndT', covariant=True, default=Never) StateT = TypeVar('StateT', bound=Union[None, AbstractState], default=None) @dataclass -class Snapshot: - """Snapshot of a graph.""" - - last_node_id: str - next_node_id: str - start_ts: datetime - duration: float - state: bytes | None = None - - @classmethod - def from_state( - cls, last_node_id: str, next_node_id: str, start_ts: datetime, duration: float, state: StateT - ) -> Snapshot: - return cls( - last_node_id=last_node_id, - next_node_id=next_node_id, - start_ts=start_ts, - duration=duration, - state=state.serialize() if state is not None else None, - ) - - def summary(self) -> str: - s = f'{self.last_node_id} -> {self.next_node_id}' - if self.duration > 1e-5: - s += f' ({self.duration:.6f}s)' - return s +class Step(Generic[StateT, RunEndT]): + """History item describing the execution of a step of a graph.""" + + state: StateT + node: BaseNode[StateT, RunEndT] + start_ts: datetime = field(default_factory=lambda: datetime.now(tz=timezone.utc)) + duration: float | None = None + + kind: Literal['start_step'] = 'start_step' + + def __post_init__(self): + # Copy the state to prevent it from being modified by other code + self.state = copy.deepcopy(self.state) + + +@dataclass +class EndEvent(Generic[StateT, RunEndT]): + """History item describing the end of a graph run.""" + + state: StateT + result: End[RunEndT] + ts: datetime = field(default_factory=lambda: datetime.now(tz=timezone.utc)) + + kind: Literal['end'] = 'end' + + def __post_init__(self): + # Copy the state to prevent it from being modified by other code + self.state = copy.deepcopy(self.state) + + +StepOrEnd = Union[Step[StateT, RunEndT], EndEvent[StateT, RunEndT]] diff --git a/tests/test_graph.py b/tests/test_graph.py index 81b54fa070..4a06603a7e 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,12 +1,13 @@ from __future__ import annotations as _annotations +from dataclasses import dataclass from datetime import timezone from typing import Union import pytest from inline_snapshot import snapshot -from pydantic_ai_graph import BaseNode, End, Graph, GraphContext, Snapshot +from pydantic_ai_graph import BaseNode, End, EndEvent, Graph, GraphContext, Step from .conftest import IsFloat, IsNow @@ 
-14,85 +15,101 @@ async def test_graph(): - class Float2String(BaseNode[None, float]): + @dataclass + class Float2String(BaseNode): + input_data: float + async def run(self, ctx: GraphContext) -> String2Length: return String2Length(str(self.input_data)) - class String2Length(BaseNode[None, str]): + @dataclass + class String2Length(BaseNode): + input_data: str + async def run(self, ctx: GraphContext) -> Double: return Double(len(self.input_data)) - class Double(BaseNode[None, int, int]): + @dataclass + class Double(BaseNode[None, int]): + input_data: int + async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noqa: UP007 if self.input_data == 7: return String2Length('x' * 21) else: return End(self.input_data * 2) - g = Graph[None, float, int]( - Float2String, - String2Length, - Double, - ) - result, history = await g.run(3.14) + g = Graph[None, int](nodes=(Float2String, String2Length, Double)) + runner = g.get_runner(Float2String) + result, history = await runner.run(None, 3.14) # len('3.14') * 2 == 8 assert result == 8 assert history == snapshot( [ - Snapshot( - last_node_id='Float2String', - next_node_id='String2Length', + Step( + state=None, + node=Float2String(input_data=3.14), start_ts=IsNow(tz=timezone.utc), - duration=IsFloat(gt=0, lt=1e-3), + duration=IsFloat(), ), - Snapshot( - last_node_id='String2Length', - next_node_id='Double', + Step( + state=None, + node=String2Length(input_data='3.14'), start_ts=IsNow(tz=timezone.utc), - duration=IsFloat(gt=0, lt=1e-3), + duration=IsFloat(), ), - Snapshot( - last_node_id='Double', - next_node_id='END', + Step( + state=None, + node=Double(input_data=4), start_ts=IsNow(tz=timezone.utc), - duration=IsFloat(gt=0, lt=1e-3), + duration=IsFloat(), + ), + EndEvent( + state=None, + result=End(data=8), + ts=IsNow(tz=timezone.utc), ), ] ) - result, history = await g.run(3.14159) + result, history = await runner.run(None, 3.14159) # len('3.14159') == 7, 21 * 2 == 42 assert result == 42 assert history == snapshot( [ - Snapshot( - last_node_id='Float2String', - next_node_id='String2Length', + Step( + state=None, + node=Float2String(input_data=3.14159), start_ts=IsNow(tz=timezone.utc), - duration=IsFloat(gt=0, lt=1e-3), + duration=IsFloat(), ), - Snapshot( - last_node_id='String2Length', - next_node_id='Double', + Step( + state=None, + node=String2Length(input_data='3.14159'), start_ts=IsNow(tz=timezone.utc), - duration=IsFloat(gt=0, lt=1e-3), + duration=IsFloat(), ), - Snapshot( - last_node_id='Double', - next_node_id='String2Length', + Step( + state=None, + node=Double(input_data=7), start_ts=IsNow(tz=timezone.utc), - duration=IsFloat(gt=0, lt=1e-3), + duration=IsFloat(), ), - Snapshot( - last_node_id='String2Length', - next_node_id='Double', + Step( + state=None, + node=String2Length(input_data='xxxxxxxxxxxxxxxxxxxxx'), start_ts=IsNow(tz=timezone.utc), - duration=IsFloat(gt=0, lt=1e-3), + duration=IsFloat(), ), - Snapshot( - last_node_id='Double', - next_node_id='END', + Step( + state=None, + node=Double(input_data=21), start_ts=IsNow(tz=timezone.utc), - duration=IsFloat(gt=0, lt=1e-3), + duration=IsFloat(), + ), + EndEvent( + state=None, + result=End(data=42), + ts=IsNow(tz=timezone.utc), ), ] ) diff --git a/tests/typed_agent.py b/tests/typed_agent.py index 040485af92..fdf9f1a255 100644 --- a/tests/typed_agent.py +++ b/tests/typed_agent.py @@ -3,7 +3,9 @@ from collections.abc import Awaitable, Iterator from contextlib import contextmanager from dataclasses import dataclass -from typing import Callable, TypeAlias, Union, 
assert_type +from typing import Callable, TypeAlias, Union + +from typing_extensions import assert_type from pydantic_ai import Agent, ModelRetry, RunContext, Tool from pydantic_ai.result import RunResult diff --git a/tests/typed_graph.py b/tests/typed_graph.py index 806347f8fb..cad0e7df16 100644 --- a/tests/typed_graph.py +++ b/tests/typed_graph.py @@ -1,17 +1,24 @@ from __future__ import annotations as _annotations from dataclasses import dataclass -from typing import assert_type + +from typing_extensions import assert_type from pydantic_ai_graph import BaseNode, End, Graph, GraphContext -class Float2String(BaseNode[None, float]): +@dataclass +class Float2String(BaseNode): + input_data: float + async def run(self, ctx: GraphContext) -> String2Length: return String2Length(str(self.input_data)) -class String2Length(BaseNode[None, str]): +@dataclass +class String2Length(BaseNode): + input_data: str + async def run(self, ctx: GraphContext) -> Double: return Double(len(self.input_data)) @@ -21,7 +28,10 @@ class X: v: int -class Double(BaseNode[None, int, X]): +@dataclass +class Double(BaseNode[None, X]): + input_data: int + async def run(self, ctx: GraphContext) -> String2Length | End[X]: if self.input_data == 7: return String2Length('x' * 21) @@ -29,7 +39,7 @@ async def run(self, ctx: GraphContext) -> String2Length | End[X]: return End(X(self.input_data * 2)) -def use_double(node: BaseNode[None, int, X]) -> None: +def use_double(node: BaseNode[None, X]) -> None: """Shoe that `Double` is valid as a `BaseNode[None, int, X]`.""" print(node) @@ -37,24 +47,28 @@ def use_double(node: BaseNode[None, int, X]) -> None: use_double(Double(1)) -g1 = Graph[None, float, X]( - Float2String, - String2Length, - Double, +g1 = Graph[None, X]( + nodes=( + Float2String, + String2Length, + Double, + ) ) -assert_type(g1, Graph[None, float, X]) +assert_type(g1, Graph[None, X]) -g2 = Graph(Double) -assert_type(g2, Graph[None, int, X]) +g2 = Graph(nodes=(Double,)) +assert_type(g2, Graph[None, X]) g3 = Graph( - Float2String, - String2Length, - Double, + nodes=( + Float2String, + String2Length, + Double, + ) ) # because String2Length came before Double, the output type is Any -assert_type(g3, Graph[None, float]) +assert_type(g3, Graph[None, X]) -Graph[None, float, bytes](Float2String, String2Length, Double) # type: ignore[arg-type] -Graph[None, int, str](Double) # type: ignore[arg-type] +Graph[None, bytes](Float2String, String2Length, Double) # type: ignore[arg-type] +Graph[None, str](Double) # type: ignore[arg-type] From 06428bbcfc2cc8f76da7a8327b12aa65ae158d91 Mon Sep 17 00:00:00 2001 From: Israel Ekpo <44282278+izzyacademy@users.noreply.github.com> Date: Fri, 3 Jan 2025 07:03:07 -0500 Subject: [PATCH 11/57] Typo in Graph Documentation (#596) --- pydantic_ai_graph/README.md | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/pydantic_ai_graph/README.md b/pydantic_ai_graph/README.md index ec0b531a25..90e7f787ff 100644 --- a/pydantic_ai_graph/README.md +++ b/pydantic_ai_graph/README.md @@ -11,6 +11,11 @@ Graph and state machine library. This library is developed as part of the [PydanticAI](https://ai.pydantic.dev), however it has no dependency on `pydantic-ai` or related packages and can be considered as a pure graph library. -As with PydanticAI, this library priorities type safety and use of common Python syntax over esoteric, domain-specific use of Python syntax. 
+As with PydanticAI, this library prioritizes type safety and use of common Python syntax over esoteric, domain-specific use of Python syntax. -`pydantic-ai-graph` allows you to define graphs using simple Python syntax. In particular, edges are defined using the return type hint of nodes. +`pydantic-ai-graph` allows you to define graphs using simple Python syntax. In particular, edges are defined using the return type hint of nodes. + +When designing your graph and state machine, you need to identify the data types for the overall graph input, the final graph output, the graph dependency object and graph state. Then for each specific node in the graph, you have to identify the specific data type each node is expected to receive as the input type from the prior node in the graph during transitions. + +Once the nodes in the graph are defined, you can use certain built-in methods on the Graph object to visualize the nodes +and state transitions on the graph as mermaid diagrams. From 1a5d3e22060dcd90e272bce4eaaa677909056351 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 7 Jan 2025 14:20:31 +0000 Subject: [PATCH 12/57] fix linting --- pydantic_ai_graph/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pydantic_ai_graph/README.md b/pydantic_ai_graph/README.md index 90e7f787ff..3f3a5bd72a 100644 --- a/pydantic_ai_graph/README.md +++ b/pydantic_ai_graph/README.md @@ -13,9 +13,9 @@ on `pydantic-ai` or related packages and can be considered as a pure graph libra As with PydanticAI, this library prioritizes type safety and use of common Python syntax over esoteric, domain-specific use of Python syntax. -`pydantic-ai-graph` allows you to define graphs using simple Python syntax. In particular, edges are defined using the return type hint of nodes. +`pydantic-ai-graph` allows you to define graphs using simple Python syntax. In particular, edges are defined using the return type hint of nodes. When designing your graph and state machine, you need to identify the data types for the overall graph input, the final graph output, the graph dependency object and graph state. Then for each specific node in the graph, you have to identify the specific data type each node is expected to receive as the input type from the prior node in the graph during transitions. -Once the nodes in the graph are defined, you can use certain built-in methods on the Graph object to visualize the nodes +Once the nodes in the graph are defined, you can use certain built-in methods on the Graph object to visualize the nodes and state transitions on the graph as mermaid diagrams. 
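The README text added in patches 11 and 12 above describes the intended usage: each node declares its outgoing edges through the return type hint of its `run` method, and built-in methods on the graph render the nodes and state transitions as a Mermaid diagram. A minimal sketch of that usage, assembled from the test code introduced in patch 10 (tests/test_graph.py and tests/typed_graph.py); the API keeps changing in later patches of this series, so this only reflects the code as of patch 10:

from __future__ import annotations

import asyncio
from dataclasses import dataclass
from typing import Union

from pydantic_ai_graph import BaseNode, End, Graph, GraphContext


@dataclass
class Float2String(BaseNode):
    input_data: float

    # The return type hint defines the outgoing edge: Float2String --> String2Length.
    async def run(self, ctx: GraphContext) -> String2Length:
        return String2Length(str(self.input_data))


@dataclass
class String2Length(BaseNode):
    input_data: str

    async def run(self, ctx: GraphContext) -> Double:
        return Double(len(self.input_data))


@dataclass
class Double(BaseNode[None, int]):
    input_data: int

    # A union return type produces two edges: back to String2Length, or to END via End[int].
    async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]:
        if self.input_data == 7:
            return String2Length('x' * 21)
        return End(self.input_data * 2)


graph = Graph[None, int](nodes=(Float2String, String2Length, Double))
runner = graph.get_runner(Float2String)

# Visualize the nodes and state transitions as a Mermaid diagram.
print(runner.mermaid_code())

# Run the graph: the state is None, and 3.14 is passed to the start node Float2String.
result, history = asyncio.run(runner.run(None, 3.14))
print(result)  # len('3.14') * 2 == 8

With these nodes, `runner.mermaid_code()` should emit a `graph TD` diagram with the edges START --> Float2String, Float2String --> String2Length, String2Length --> Double, Double --> String2Length, and Double --> END.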
From 6faaf97b374575e9e84a4376d86db58884f5d85a Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 7 Jan 2025 17:48:07 +0000 Subject: [PATCH 13/57] separate mermaid logic --- pydantic_ai_graph/pydantic_ai_graph/graph.py | 176 ++--------------- .../pydantic_ai_graph/mermaid.py | 186 ++++++++++++++++++ 2 files changed, 207 insertions(+), 155 deletions(-) create mode 100644 pydantic_ai_graph/pydantic_ai_graph/mermaid.py diff --git a/pydantic_ai_graph/pydantic_ai_graph/graph.py b/pydantic_ai_graph/pydantic_ai_graph/graph.py index a6e09e887a..ae3c926b49 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/graph.py +++ b/pydantic_ai_graph/pydantic_ai_graph/graph.py @@ -1,22 +1,22 @@ from __future__ import annotations as _annotations -import base64 import inspect +from collections.abc import Sequence from dataclasses import dataclass, field from pathlib import Path from time import perf_counter -from typing import TYPE_CHECKING, Annotated, Any, Generic +from typing import TYPE_CHECKING, Annotated, Generic import logfire_api from annotated_types import Ge, Le -from typing_extensions import Literal, Never, ParamSpec, Protocol, TypeVar, assert_never +from typing_extensions import Never, ParamSpec, Protocol, TypeVar, assert_never -from . import _utils +from . import _utils, mermaid from ._utils import get_parent_namespace from .nodes import BaseNode, End, GraphContext, NodeDef from .state import EndEvent, StateT, Step, StepOrEnd -__all__ = ('Graph', 'GraphRun', 'GraphRunner') +__all__ = 'Graph', 'GraphRun', 'GraphRunner' _logfire = logfire_api.Logfire(otel_scope='pydantic-ai-graph') @@ -27,6 +27,7 @@ class StartNodeProtocol(Protocol[RunSignatureT, StateT, NodeRunEndT]): def get_id(self) -> str: ... + def __call__(self, *args: RunSignatureT.args, **kwargs: RunSignatureT.kwargs) -> BaseNode[StateT, NodeRunEndT]: ... @@ -40,9 +41,10 @@ class Graph(Generic[StateT, RunEndT]): def __init__( self, - name: str | None = None, - nodes: tuple[type[BaseNode[StateT, RunEndT]], ...] 
= (), + *, + nodes: Sequence[type[BaseNode[StateT, RunEndT]]], state_type: type[StateT] | None = None, + name: str | None = None, ): self.name = name @@ -94,6 +96,7 @@ def get_runner( self, first_node: StartNodeProtocol[RunSignatureT, StateT, RunEndT], ) -> GraphRunner[RunSignatureT, StateT, RunEndT]: + # noinspection PyTypeChecker return GraphRunner( graph=self, first_node=first_node, @@ -125,25 +128,23 @@ async def run( return result, history def mermaid_code(self) -> str: - return mermaid_code(self.graph, self.first_node) + return mermaid.generate_code(self.graph, self.first_node.get_id()) def mermaid_image( self, - image_type: Literal['jpeg', 'png', 'webp', 'svg', 'pdf'] | str | None = None, + image_type: mermaid.ImageType | None = None, pdf_fit: bool = False, pdf_landscape: bool = False, - pdf_paper: Literal['letter', 'legal', 'tabloid', 'ledger', 'a0', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6'] - | str - | None = None, + pdf_paper: mermaid.PdfPaper | None = None, bg_color: str | None = None, - theme: Literal['default', 'neutral', 'dark', 'forest'] | str | None = None, + theme: mermaid.Theme | None = None, width: int | None = None, height: int | None = None, scale: Annotated[float, Ge(1), Le(3)] | None = None, ) -> bytes: - return mermaid_image( + return mermaid.request_image( self.graph, - self.first_node, + start_node_ids=self.first_node.get_id(), image_type=image_type, pdf_fit=pdf_fit, pdf_landscape=pdf_landscape, @@ -158,22 +159,20 @@ def mermaid_image( def mermaid_save( self, path: Path | str, - image_type: Literal['jpeg', 'png', 'webp', 'svg', 'pdf'] | str | None = None, + image_type: mermaid.ImageType | None = None, pdf_fit: bool = False, pdf_landscape: bool = False, - pdf_paper: Literal['letter', 'legal', 'tabloid', 'ledger', 'a0', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6'] - | str - | None = None, + pdf_paper: mermaid.PdfPaper | None = None, bg_color: str | None = None, - theme: Literal['default', 'neutral', 'dark', 'forest'] | str | None = None, + theme: mermaid.Theme | None = None, width: int | None = None, height: int | None = None, scale: Annotated[float, Ge(1), Le(3)] | None = None, ) -> None: - mermaid_save( + mermaid.save_image( path, self.graph, - self.first_node, + self.first_node.get_id(), image_type=image_type, pdf_fit=pdf_fit, pdf_landscape=pdf_landscape, @@ -225,136 +224,3 @@ async def step(self, node: BaseNode[StateT, RunEndT]) -> BaseNode[StateT, RunEnd next_node = await node.run(ctx) history_step.duration = perf_counter() - start return next_node - - -def mermaid_code( - graph: Graph[Any, Any], start: StartNodeProtocol[..., Any, Any] | tuple[StartNodeProtocol[..., Any, Any], ...] 
= () -) -> str: - if not isinstance(start, tuple): - start = (start,) - - for node in start: - if node not in graph.nodes: - raise ValueError(f'Start node "{node}" is not in the graph.') - - node_order = {node_id: index for index, node_id in enumerate(graph.node_defs)} - - lines = ['graph TD'] - for node in graph.nodes: - node_id = node.get_id() - node_def = graph.node_defs[node_id] - if node in start: - lines.append(f' START --> {node_id}') - if node_def.returns_base_node: - for next_node_id in graph.nodes: - lines.append(f' {node_id} --> {next_node_id}') - else: - for _, next_node_id in sorted((node_order[node_id], node_id) for node_id in node_def.next_node_ids): - lines.append(f' {node_id} --> {next_node_id}') - if node_def.returns_end: - lines.append(f' {node_id} --> END') - - return '\n'.join(lines) - - -def mermaid_image( - graph: Graph[Any, Any], - start: StartNodeProtocol[..., Any, Any] | tuple[StartNodeProtocol[..., Any, Any], ...] = (), - image_type: Literal['jpeg', 'png', 'webp', 'svg', 'pdf'] | str | None = None, - pdf_fit: bool = False, - pdf_landscape: bool = False, - pdf_paper: Literal['letter', 'legal', 'tabloid', 'ledger', 'a0', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6'] - | str - | None = None, - bg_color: str | None = None, - theme: Literal['default', 'neutral', 'dark', 'forest'] | str | None = None, - width: int | None = None, - height: int | None = None, - scale: Annotated[float, Ge(1), Le(3)] | None = None, -) -> bytes: - """Generate an image of a Mermaid diagram using mermaid.ink. - - Args: - graph: The graph to generate the image for. - start: The start node(s) of the graph. - image_type: The image type to generate. If unspecified, the default behavior is `'jpeg'`. - pdf_fit: When using image_type='pdf', whether to fit the diagram to the PDF page. - pdf_landscape: When using image_type='pdf', whether to use landscape orientation for the PDF. - This has no effect if using `pdf_fit`. - pdf_paper: When using image_type='pdf', the paper size of the PDF. - bg_color: The background color of the diagram. If None, the default transparent background is used. - The color value is interpreted as a hexadecimal color code by default (and should not have a leading '#'), - but you can also use named colors by prefixing the value with '!'. - For example, valid choices include `bg_color='!white'` or `bg_color='FF0000'`. - theme: The theme of the diagram. Defaults to 'default'. - width: The width of the diagram. - height: The height of the diagram. - scale: The scale of the diagram. The scale must be a number between 1 and 3, and you can only set - a scale if one or both of width and height are set. 
- """ - import httpx - - code_base64 = base64.b64encode(mermaid_code(graph, start).encode()).decode() - - params: dict[str, str] = {} - if image_type == 'pdf': - url = f'https://mermaid.ink/pdf/{code_base64}' - if pdf_fit: - params['fit'] = '' - if pdf_landscape: - params['landscape'] = '' - if pdf_paper: - params['paper'] = pdf_paper - else: - url = f'https://mermaid.ink/img/{code_base64}' - - if image_type: - params['type'] = image_type - - if bg_color: - params['bgColor'] = bg_color - if theme: - params['theme'] = theme - if width: - params['width'] = str(width) - if height: - params['height'] = str(height) - if scale: - params['scale'] = str(scale) - - response = httpx.get(url, params=params) - response.raise_for_status() - return response.content - - -def mermaid_save( - path: Path | str, - graph: Graph[Any, Any], - start: StartNodeProtocol[..., Any, Any] | tuple[StartNodeProtocol[..., Any, Any], ...] = (), - image_type: Literal['jpeg', 'png', 'webp', 'svg', 'pdf'] | str | None = None, - pdf_fit: bool = False, - pdf_landscape: bool = False, - pdf_paper: Literal['letter', 'legal', 'tabloid', 'ledger', 'a0', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6'] - | str - | None = None, - bg_color: str | None = None, - theme: Literal['default', 'neutral', 'dark', 'forest'] | str | None = None, - width: int | None = None, - height: int | None = None, - scale: Annotated[float, Ge(1), Le(3)] | None = None, -) -> None: - # TODO: do something with the path file extension, e.g. error if it's incompatible, or use it to specify a param - image_data = mermaid_image( - graph, - start, - image_type=image_type, - pdf_fit=pdf_fit, - pdf_landscape=pdf_landscape, - pdf_paper=pdf_paper, - bg_color=bg_color, - theme=theme, - width=width, - height=height, - scale=scale, - ) - Path(path).write_bytes(image_data) diff --git a/pydantic_ai_graph/pydantic_ai_graph/mermaid.py b/pydantic_ai_graph/pydantic_ai_graph/mermaid.py new file mode 100644 index 0000000000..431d26d642 --- /dev/null +++ b/pydantic_ai_graph/pydantic_ai_graph/mermaid.py @@ -0,0 +1,186 @@ +from __future__ import annotations as _annotations + +import base64 +from collections.abc import Sequence +from pathlib import Path +from typing import TYPE_CHECKING, Annotated, Any, Literal, cast + +from annotated_types import Ge, Le + +if TYPE_CHECKING: + from .graph import Graph + + +def generate_code(graph: Graph[Any, Any], start_node_ids: Sequence[str] | str) -> str: + """Generate Mermaid code for a graph. + + Args: + graph: The graph to generate the image for. + start_node_ids: IDs of start nodes of the graph. + + Returns: The Mermaid code for the graph. 
+ """ + if isinstance(start_node_ids, str): + start_node_ids = (start_node_ids,) + + for node_id in start_node_ids: + if node_id not in graph.node_defs: + raise LookupError(f'Start node "{node_id}" is not in the graph.') + + node_order = {node_id: index for index, node_id in enumerate(graph.node_defs)} + + lines = ['graph TD'] + for node in graph.nodes: + node_id = node.get_id() + node_def = graph.node_defs[node_id] + if node_id in start_node_ids: + lines.append(f' START --> {node_id}') + if node_def.returns_base_node: + for next_node_id in graph.nodes: + lines.append(f' {node_id} --> {next_node_id}') + else: + for _, next_node_id in sorted((node_order[node_id], node_id) for node_id in node_def.next_node_ids): + lines.append(f' {node_id} --> {next_node_id}') + if node_def.returns_end: + lines.append(f' {node_id} --> END') + + return '\n'.join(lines) + + +ImageType = Literal['jpeg', 'png', 'webp', 'svg', 'pdf'] +PdfPaper = Literal['letter', 'legal', 'tabloid', 'ledger', 'a0', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6'] +Theme = Literal['default', 'neutral', 'dark', 'forest'] + + +def request_image( + graph: Graph[Any, Any], + start_node_ids: Sequence[str] | str, + *, + image_type: ImageType | str | None = None, + pdf_fit: bool = False, + pdf_landscape: bool = False, + pdf_paper: PdfPaper | None = None, + bg_color: str | None = None, + theme: Theme | None = None, + width: int | None = None, + height: int | None = None, + scale: Annotated[float, Ge(1), Le(3)] | None = None, +) -> bytes: + """Generate an image of a Mermaid diagram using [mermaid.ink](https://mermaid.ink). + + Args: + graph: The graph to generate the image for. + start_node_ids: IDs of start nodes of the graph. + image_type: The image type to generate. If unspecified, the default behavior is `'jpeg'`. + pdf_fit: When using image_type='pdf', whether to fit the diagram to the PDF page. + pdf_landscape: When using image_type='pdf', whether to use landscape orientation for the PDF. + This has no effect if using `pdf_fit`. + pdf_paper: When using image_type='pdf', the paper size of the PDF. + bg_color: The background color of the diagram. If None, the default transparent background is used. + The color value is interpreted as a hexadecimal color code by default (and should not have a leading '#'), + but you can also use named colors by prefixing the value with `'!'`. + For example, valid choices include `bg_color='!white'` or `bg_color='FF0000'`. + theme: The theme of the diagram. Defaults to 'default'. + width: The width of the diagram. + height: The height of the diagram. + scale: The scale of the diagram. The scale must be a number between 1 and 3, and you can only set + a scale if one or both of width and height are set. + + Returns: The image data. 
+ """ + import httpx + + code_base64 = base64.b64encode(generate_code(graph, start_node_ids).encode()).decode() + + params: dict[str, str] = {} + if image_type == 'pdf': + url = f'https://mermaid.ink/pdf/{code_base64}' + if pdf_fit: + params['fit'] = '' + if pdf_landscape: + params['landscape'] = '' + if pdf_paper: + params['paper'] = pdf_paper + elif image_type == 'svg': + url = f'https://mermaid.ink/svg/{code_base64}' + else: + url = f'https://mermaid.ink/img/{code_base64}' + + if image_type: + params['type'] = image_type + + if bg_color: + params['bgColor'] = bg_color + if theme: + params['theme'] = theme + if width: + params['width'] = str(width) + if height: + params['height'] = str(height) + if scale: + params['scale'] = str(scale) + + response = httpx.get(url, params=params) + response.raise_for_status() + return response.content + + +def save_image( + path: Path | str, + graph: Graph[Any, Any], + start_node_ids: Sequence[str] | str, + *, + image_type: ImageType | None = None, + pdf_fit: bool = False, + pdf_landscape: bool = False, + pdf_paper: PdfPaper | None = None, + bg_color: str | None = None, + theme: Theme | None = None, + width: int | None = None, + height: int | None = None, + scale: Annotated[float, Ge(1), Le(3)] | None = None, +) -> None: + """Generate an image of a Mermaid diagram using [mermaid.ink](https://mermaid.ink) and save it to a local file. + + Args: + path: The path to save the image to. + graph: The graph to generate the image for. + start_node_ids: IDs of start nodes of the graph. + image_type: The image type to generate. If unspecified, the default behavior is `'jpeg'`. + pdf_fit: When using image_type='pdf', whether to fit the diagram to the PDF page. + pdf_landscape: When using image_type='pdf', whether to use landscape orientation for the PDF. + This has no effect if using `pdf_fit`. + pdf_paper: When using image_type='pdf', the paper size of the PDF. + bg_color: The background color of the diagram. If None, the default transparent background is used. + The color value is interpreted as a hexadecimal color code by default (and should not have a leading '#'), + but you can also use named colors by prefixing the value with `'!'`. + For example, valid choices include `bg_color='!white'` or `bg_color='FF0000'`. + theme: The theme of the diagram. Defaults to 'default'. + width: The width of the diagram. + height: The height of the diagram. + scale: The scale of the diagram. The scale must be a number between 1 and 3, and you can only set + a scale if one or both of width and height are set. 
+ """ + if isinstance(path, str): + path = Path(path) + + if image_type is None: + ext = path.suffix.lower() + # no need to check for .jpeg/.jpg, as it is the default + if ext in {'.png', '.webp', '.svg', '.pdf'}: + image_type = cast(ImageType, ext[1:]) + + image_data = request_image( + graph, + start_node_ids, + image_type=image_type, + pdf_fit=pdf_fit, + pdf_landscape=pdf_landscape, + pdf_paper=pdf_paper, + bg_color=bg_color, + theme=theme, + width=width, + height=height, + scale=scale, + ) + Path(path).write_bytes(image_data) From 892f6612bc0ccc0f148f8d81889992cb99ccd533 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 7 Jan 2025 21:09:21 +0000 Subject: [PATCH 14/57] fix graph type checking --- tests/typed_graph.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/typed_graph.py b/tests/typed_graph.py index cad0e7df16..c131f19b4c 100644 --- a/tests/typed_graph.py +++ b/tests/typed_graph.py @@ -70,5 +70,5 @@ def use_double(node: BaseNode[None, X]) -> None: # because String2Length came before Double, the output type is Any assert_type(g3, Graph[None, X]) -Graph[None, bytes](Float2String, String2Length, Double) # type: ignore[arg-type] -Graph[None, str](Double) # type: ignore[arg-type] +Graph[None, bytes](nodes=(Float2String, String2Length, Double)) # type: ignore[arg-type] +Graph[None, str](nodes=[Double]) # type: ignore[list-item] From 02b7f28881234b61ce2c4b00133b5dedca3f76aa Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 7 Jan 2025 21:16:35 +0000 Subject: [PATCH 15/57] bump From be3f689d2fabc34da28dfe8c3d45d3acf412f47f Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 7 Jan 2025 22:40:28 +0000 Subject: [PATCH 16/57] adding node highlighting to mermaid, testing locally --- pydantic_ai_graph/pydantic_ai_graph/_utils.py | 5 ++ pydantic_ai_graph/pydantic_ai_graph/graph.py | 24 ++++++++-- .../pydantic_ai_graph/mermaid.py | 46 +++++++++++++++---- pydantic_ai_graph/pydantic_ai_graph/state.py | 16 +++++-- 4 files changed, 75 insertions(+), 16 deletions(-) diff --git a/pydantic_ai_graph/pydantic_ai_graph/_utils.py b/pydantic_ai_graph/pydantic_ai_graph/_utils.py index d1b4671800..a451fc5533 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/_utils.py +++ b/pydantic_ai_graph/pydantic_ai_graph/_utils.py @@ -2,6 +2,7 @@ import sys import types +from datetime import datetime, timezone from typing import Any, Union, get_args, get_origin from typing_extensions import TypeAliasType @@ -63,3 +64,7 @@ def get_parent_namespace(frame: types.FrameType | None) -> dict[str, Any] | None return get_parent_namespace(back) else: return back.f_locals + + +def now_utc() -> datetime: + return datetime.now(tz=timezone.utc) diff --git a/pydantic_ai_graph/pydantic_ai_graph/graph.py b/pydantic_ai_graph/pydantic_ai_graph/graph.py index ae3c926b49..98437c4941 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/graph.py +++ b/pydantic_ai_graph/pydantic_ai_graph/graph.py @@ -127,11 +127,20 @@ async def run( history = run.history return result, history - def mermaid_code(self) -> str: - return mermaid.generate_code(self.graph, self.first_node.get_id()) + def mermaid_code( + self, + highlighted_nodes: Sequence[mermaid.NodeIdent] | None = None, + highlight_css: str = mermaid.DEFAULT_HIGHLIGHT_CSS, + ) -> str: + return mermaid.generate_code( + self.graph, {self.first_node.get_id()}, highlighted_nodes=highlighted_nodes, highlight_css=highlight_css + ) def mermaid_image( self, + *, + highlighted_nodes: Sequence[mermaid.NodeIdent] | None = None, + highlight_css: str = 
mermaid.DEFAULT_HIGHLIGHT_CSS, image_type: mermaid.ImageType | None = None, pdf_fit: bool = False, pdf_landscape: bool = False, @@ -144,7 +153,9 @@ def mermaid_image( ) -> bytes: return mermaid.request_image( self.graph, - start_node_ids=self.first_node.get_id(), + {self.first_node.get_id()}, + highlighted_nodes=highlighted_nodes, + highlight_css=highlight_css, image_type=image_type, pdf_fit=pdf_fit, pdf_landscape=pdf_landscape, @@ -159,6 +170,9 @@ def mermaid_image( def mermaid_save( self, path: Path | str, + *, + highlighted_nodes: Sequence[mermaid.NodeIdent] | None = None, + highlight_css: str = mermaid.DEFAULT_HIGHLIGHT_CSS, image_type: mermaid.ImageType | None = None, pdf_fit: bool = False, pdf_landscape: bool = False, @@ -172,7 +186,9 @@ def mermaid_save( mermaid.save_image( path, self.graph, - self.first_node.get_id(), + {self.first_node.get_id()}, + highlighted_nodes=highlighted_nodes, + highlight_css=highlight_css, image_type=image_type, pdf_fit=pdf_fit, pdf_landscape=pdf_landscape, diff --git a/pydantic_ai_graph/pydantic_ai_graph/mermaid.py b/pydantic_ai_graph/pydantic_ai_graph/mermaid.py index 431d26d642..f7d0e42611 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/mermaid.py +++ b/pydantic_ai_graph/pydantic_ai_graph/mermaid.py @@ -9,20 +9,30 @@ if TYPE_CHECKING: from .graph import Graph + from .nodes import BaseNode -def generate_code(graph: Graph[Any, Any], start_node_ids: Sequence[str] | str) -> str: +NodeIdent = type[BaseNode[Any, Any]] | str +DEFAULT_HIGHLIGHT_CSS = 'fill:#f9f' + + +def generate_code( + graph: Graph[Any, Any], + start_node_ids: set[str], + *, + highlighted_nodes: Sequence[NodeIdent] | None = None, + highlight_css: str = DEFAULT_HIGHLIGHT_CSS, +) -> str: """Generate Mermaid code for a graph. Args: graph: The graph to generate the image for. - start_node_ids: IDs of start nodes of the graph. + start_node_ids: Identifiers of nodes that start the graph. + highlighted_nodes: Identifiers of nodes to highlight. + highlight_css: CSS to use for highlighting nodes. Returns: The Mermaid code for the graph. """ - if isinstance(start_node_ids, str): - start_node_ids = (start_node_ids,) - for node_id in start_node_ids: if node_id not in graph.node_defs: raise LookupError(f'Start node "{node_id}" is not in the graph.') @@ -44,6 +54,15 @@ def generate_code(graph: Graph[Any, Any], start_node_ids: Sequence[str] | str) - if node_def.returns_end: lines.append(f' {node_id} --> END') + if highlighted_nodes: + lines.append('') + lines.append(f'classDef highlighted {highlight_css}') + for node in highlighted_nodes: + node_id = node if isinstance(node, str) else node.get_id() + if node_id not in graph.node_defs: + raise LookupError(f'Highlighted node "{node_id}" is not in the graph.') + lines.append(f'class {node_id} highlighted') + return '\n'.join(lines) @@ -54,8 +73,10 @@ def generate_code(graph: Graph[Any, Any], start_node_ids: Sequence[str] | str) - def request_image( graph: Graph[Any, Any], - start_node_ids: Sequence[str] | str, + start_node_ids: set[str], *, + highlighted_nodes: Sequence[NodeIdent] | None = None, + highlight_css: str = 'fill:#f9f,stroke:#333,stroke-width:4px', image_type: ImageType | str | None = None, pdf_fit: bool = False, pdf_landscape: bool = False, @@ -71,6 +92,8 @@ def request_image( Args: graph: The graph to generate the image for. start_node_ids: IDs of start nodes of the graph. + highlighted_nodes: Identifiers of nodes to highlight. + highlight_css: CSS to use for highlighting nodes. image_type: The image type to generate. 
If unspecified, the default behavior is `'jpeg'`. pdf_fit: When using image_type='pdf', whether to fit the diagram to the PDF page. pdf_landscape: When using image_type='pdf', whether to use landscape orientation for the PDF. @@ -90,7 +113,8 @@ def request_image( """ import httpx - code_base64 = base64.b64encode(generate_code(graph, start_node_ids).encode()).decode() + code = generate_code(graph, start_node_ids, highlighted_nodes=highlighted_nodes, highlight_css=highlight_css) + code_base64 = base64.b64encode(code.encode()).decode() params: dict[str, str] = {} if image_type == 'pdf': @@ -128,8 +152,10 @@ def request_image( def save_image( path: Path | str, graph: Graph[Any, Any], - start_node_ids: Sequence[str] | str, + start_node_ids: set[str], *, + highlighted_nodes: Sequence[NodeIdent] | None = None, + highlight_css: str = 'fill:#f9f,stroke:#333,stroke-width:4px', image_type: ImageType | None = None, pdf_fit: bool = False, pdf_landscape: bool = False, @@ -146,6 +172,8 @@ def save_image( path: The path to save the image to. graph: The graph to generate the image for. start_node_ids: IDs of start nodes of the graph. + highlighted_nodes: Identifiers of nodes to highlight. + highlight_css: CSS to use for highlighting nodes. image_type: The image type to generate. If unspecified, the default behavior is `'jpeg'`. pdf_fit: When using image_type='pdf', whether to fit the diagram to the PDF page. pdf_landscape: When using image_type='pdf', whether to use landscape orientation for the PDF. @@ -173,6 +201,8 @@ def save_image( image_data = request_image( graph, start_node_ids, + highlighted_nodes=highlighted_nodes, + highlight_css=highlight_css, image_type=image_type, pdf_fit=pdf_fit, pdf_landscape=pdf_landscape, diff --git a/pydantic_ai_graph/pydantic_ai_graph/state.py b/pydantic_ai_graph/pydantic_ai_graph/state.py index 45b01aedfd..66c527d882 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/state.py +++ b/pydantic_ai_graph/pydantic_ai_graph/state.py @@ -3,11 +3,13 @@ import copy from abc import ABC, abstractmethod from dataclasses import dataclass, field -from datetime import datetime, timezone +from datetime import datetime from typing import TYPE_CHECKING, Generic, Literal, Union from typing_extensions import Never, TypeVar +from . 
import _utils + __all__ = 'AbstractState', 'StateT', 'Step', 'EndEvent', 'StepOrEnd' if TYPE_CHECKING: @@ -35,15 +37,18 @@ class Step(Generic[StateT, RunEndT]): state: StateT node: BaseNode[StateT, RunEndT] - start_ts: datetime = field(default_factory=lambda: datetime.now(tz=timezone.utc)) + start_ts: datetime = field(default_factory=_utils.now_utc) duration: float | None = None - kind: Literal['start_step'] = 'start_step' + kind: Literal['step'] = 'step' def __post_init__(self): # Copy the state to prevent it from being modified by other code self.state = copy.deepcopy(self.state) + def node_summary(self) -> str: + return str(self.node) + @dataclass class EndEvent(Generic[StateT, RunEndT]): @@ -51,7 +56,7 @@ class EndEvent(Generic[StateT, RunEndT]): state: StateT result: End[RunEndT] - ts: datetime = field(default_factory=lambda: datetime.now(tz=timezone.utc)) + ts: datetime = field(default_factory=_utils.now_utc) kind: Literal['end'] = 'end' @@ -59,5 +64,8 @@ def __post_init__(self): # Copy the state to prevent it from being modified by other code self.state = copy.deepcopy(self.state) + def node_summary(self) -> str: + return str(self.result) + StepOrEnd = Union[Step[StateT, RunEndT], EndEvent[StateT, RunEndT]] From 749cc31775fa5da88fc9a922bde93acabb9202f6 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 7 Jan 2025 23:21:22 +0000 Subject: [PATCH 17/57] bump From c0d35dac58a1bbba8cdc94b9182db6c4e943702e Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 7 Jan 2025 23:23:33 +0000 Subject: [PATCH 18/57] fix type checking imports --- pydantic_ai_graph/pydantic_ai_graph/mermaid.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pydantic_ai_graph/pydantic_ai_graph/mermaid.py b/pydantic_ai_graph/pydantic_ai_graph/mermaid.py index f7d0e42611..cf83f60de1 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/mermaid.py +++ b/pydantic_ai_graph/pydantic_ai_graph/mermaid.py @@ -7,9 +7,10 @@ from annotated_types import Ge, Le +from .nodes import BaseNode + if TYPE_CHECKING: from .graph import Graph - from .nodes import BaseNode NodeIdent = type[BaseNode[Any, Any]] | str From 190fe40f2b7df9d015e6acdf22866f4a4009b4a5 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 7 Jan 2025 23:27:13 +0000 Subject: [PATCH 19/57] fix for python 3.9 --- pydantic_ai_graph/pydantic_ai_graph/mermaid.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pydantic_ai_graph/pydantic_ai_graph/mermaid.py b/pydantic_ai_graph/pydantic_ai_graph/mermaid.py index cf83f60de1..11fe9c6a59 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/mermaid.py +++ b/pydantic_ai_graph/pydantic_ai_graph/mermaid.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Annotated, Any, Literal, cast from annotated_types import Ge, Le +from typing_extensions import TypeAlias from .nodes import BaseNode @@ -13,7 +14,7 @@ from .graph import Graph -NodeIdent = type[BaseNode[Any, Any]] | str +NodeIdent: TypeAlias = 'type[BaseNode[Any, Any]] | str' DEFAULT_HIGHLIGHT_CSS = 'fill:#f9f' From ccc0c17e59ed9b77246f005032cb52f3a59cad67 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 8 Jan 2025 18:07:47 +0000 Subject: [PATCH 20/57] simplify mermaid config --- pydantic_ai_graph/pydantic_ai_graph/graph.py | 76 +------ .../pydantic_ai_graph/mermaid.py | 209 +++++++++--------- 2 files changed, 109 insertions(+), 176 deletions(-) diff --git a/pydantic_ai_graph/pydantic_ai_graph/graph.py b/pydantic_ai_graph/pydantic_ai_graph/graph.py index 98437c4941..ce39168d2c 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/graph.py 
+++ b/pydantic_ai_graph/pydantic_ai_graph/graph.py @@ -5,11 +5,10 @@ from dataclasses import dataclass, field from pathlib import Path from time import perf_counter -from typing import TYPE_CHECKING, Annotated, Generic +from typing import TYPE_CHECKING, Generic import logfire_api -from annotated_types import Ge, Le -from typing_extensions import Never, ParamSpec, Protocol, TypeVar, assert_never +from typing_extensions import Never, ParamSpec, Protocol, TypeVar, Unpack, assert_never from . import _utils, mermaid from ._utils import get_parent_namespace @@ -129,76 +128,19 @@ async def run( def mermaid_code( self, - highlighted_nodes: Sequence[mermaid.NodeIdent] | None = None, + *, + highlighted_nodes: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None, highlight_css: str = mermaid.DEFAULT_HIGHLIGHT_CSS, ) -> str: return mermaid.generate_code( - self.graph, {self.first_node.get_id()}, highlighted_nodes=highlighted_nodes, highlight_css=highlight_css + self.graph, self.first_node.get_id(), highlighted_nodes=highlighted_nodes, highlight_css=highlight_css ) - def mermaid_image( - self, - *, - highlighted_nodes: Sequence[mermaid.NodeIdent] | None = None, - highlight_css: str = mermaid.DEFAULT_HIGHLIGHT_CSS, - image_type: mermaid.ImageType | None = None, - pdf_fit: bool = False, - pdf_landscape: bool = False, - pdf_paper: mermaid.PdfPaper | None = None, - bg_color: str | None = None, - theme: mermaid.Theme | None = None, - width: int | None = None, - height: int | None = None, - scale: Annotated[float, Ge(1), Le(3)] | None = None, - ) -> bytes: - return mermaid.request_image( - self.graph, - {self.first_node.get_id()}, - highlighted_nodes=highlighted_nodes, - highlight_css=highlight_css, - image_type=image_type, - pdf_fit=pdf_fit, - pdf_landscape=pdf_landscape, - pdf_paper=pdf_paper, - bg_color=bg_color, - theme=theme, - width=width, - height=height, - scale=scale, - ) + def mermaid_image(self, **kwargs: Unpack[mermaid.MermaidConfig]) -> bytes: + return mermaid.request_image(self.graph, self.first_node.get_id(), **kwargs) - def mermaid_save( - self, - path: Path | str, - *, - highlighted_nodes: Sequence[mermaid.NodeIdent] | None = None, - highlight_css: str = mermaid.DEFAULT_HIGHLIGHT_CSS, - image_type: mermaid.ImageType | None = None, - pdf_fit: bool = False, - pdf_landscape: bool = False, - pdf_paper: mermaid.PdfPaper | None = None, - bg_color: str | None = None, - theme: mermaid.Theme | None = None, - width: int | None = None, - height: int | None = None, - scale: Annotated[float, Ge(1), Le(3)] | None = None, - ) -> None: - mermaid.save_image( - path, - self.graph, - {self.first_node.get_id()}, - highlighted_nodes=highlighted_nodes, - highlight_css=highlight_css, - image_type=image_type, - pdf_fit=pdf_fit, - pdf_landscape=pdf_landscape, - pdf_paper=pdf_paper, - bg_color=bg_color, - theme=theme, - width=width, - height=height, - scale=scale, - ) + def mermaid_save(self, path: Path | str, /, **kwargs: Unpack[mermaid.MermaidConfig]) -> None: + mermaid.save_image(path, self.graph, self.first_node.get_id(), **kwargs) @dataclass diff --git a/pydantic_ai_graph/pydantic_ai_graph/mermaid.py b/pydantic_ai_graph/pydantic_ai_graph/mermaid.py index 11fe9c6a59..00d2d3a9f2 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/mermaid.py +++ b/pydantic_ai_graph/pydantic_ai_graph/mermaid.py @@ -1,12 +1,12 @@ from __future__ import annotations as _annotations import base64 -from collections.abc import Sequence +from collections.abc import Iterable, Sequence from pathlib import Path -from typing import 
TYPE_CHECKING, Annotated, Any, Literal, cast +from typing import TYPE_CHECKING, Annotated, Any, Literal from annotated_types import Ge, Le -from typing_extensions import TypeAlias +from typing_extensions import TypeAlias, TypedDict, Unpack from .nodes import BaseNode @@ -14,27 +14,29 @@ from .graph import Graph -NodeIdent: TypeAlias = 'type[BaseNode[Any, Any]] | str' -DEFAULT_HIGHLIGHT_CSS = 'fill:#f9f' +NodeIdent: TypeAlias = 'type[BaseNode[Any, Any]] | BaseNode[Any, Any] | str' +DEFAULT_HIGHLIGHT_CSS = 'fill:#fdff32' def generate_code( graph: Graph[Any, Any], - start_node_ids: set[str], + start_nodes: Sequence[NodeIdent] | NodeIdent, + /, *, - highlighted_nodes: Sequence[NodeIdent] | None = None, + highlighted_nodes: Sequence[NodeIdent] | NodeIdent | None = None, highlight_css: str = DEFAULT_HIGHLIGHT_CSS, ) -> str: """Generate Mermaid code for a graph. Args: graph: The graph to generate the image for. - start_node_ids: Identifiers of nodes that start the graph. + start_nodes: Identifiers of nodes that start the graph. highlighted_nodes: Identifiers of nodes to highlight. highlight_css: CSS to use for highlighting nodes. Returns: The Mermaid code for the graph. """ + start_node_ids = set(node_ids(start_nodes)) for node_id in start_node_ids: if node_id not in graph.node_defs: raise LookupError(f'Start node "{node_id}" is not in the graph.') @@ -59,8 +61,7 @@ def generate_code( if highlighted_nodes: lines.append('') lines.append(f'classDef highlighted {highlight_css}') - for node in highlighted_nodes: - node_id = node if isinstance(node, str) else node.get_id() + for node_id in node_ids(highlighted_nodes): if node_id not in graph.node_defs: raise LookupError(f'Highlighted node "{node_id}" is not in the graph.') lines.append(f'class {node_id} highlighted') @@ -68,82 +69,111 @@ def generate_code( return '\n'.join(lines) -ImageType = Literal['jpeg', 'png', 'webp', 'svg', 'pdf'] -PdfPaper = Literal['letter', 'legal', 'tabloid', 'ledger', 'a0', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6'] -Theme = Literal['default', 'neutral', 'dark', 'forest'] +def node_ids(node_idents: Sequence[NodeIdent] | NodeIdent) -> Iterable[str]: + """Get the node IDs from a sequence of node identifiers.""" + if isinstance(node_idents, str): + node_iter = (node_idents,) + elif isinstance(node_idents, Sequence): + node_iter = node_idents + else: + node_iter = (node_idents,) + + for node in node_iter: + if isinstance(node, str): + yield node + else: + yield node.get_id() + + +class MermaidConfig(TypedDict, total=False): + """Parameters to configure mermaid chart generation.""" + + highlighted_nodes: Sequence[NodeIdent] | NodeIdent + """Identifiers of nodes to highlight.""" + highlight_css: str + """CSS to use for highlighting nodes.""" + image_type: Literal['jpeg', 'png', 'webp', 'svg', 'pdf'] + """The image type to generate. If unspecified, the default behavior is `'jpeg'`.""" + pdf_fit: bool + """When using image_type='pdf', whether to fit the diagram to the PDF page.""" + pdf_landscape: bool + """When using image_type='pdf', whether to use landscape orientation for the PDF. + + This has no effect if using `pdf_fit`. + """ + pdf_paper: Literal['letter', 'legal', 'tabloid', 'ledger', 'a0', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6'] + """When using image_type='pdf', the paper size of the PDF.""" + background_color: str + """The background color of the diagram. + + If None, the default transparent background is used. 
The color value is interpreted as a hexadecimal color + code by default (and should not have a leading '#'), but you can also use named colors by prefixing the + value with `'!'`. For example, valid choices include `background_color='!white'` or `background_color='FF0000'`. + """ + theme: Literal['default', 'neutral', 'dark', 'forest'] + """The theme of the diagram. Defaults to 'default'.""" + width: int + """The width of the diagram.""" + height: int + """The height of the diagram.""" + scale: Annotated[float, Ge(1), Le(3)] + """The scale of the diagram. + + The scale must be a number between 1 and 3, and you can only set a scale if one or both of width and height are set. + """ def request_image( graph: Graph[Any, Any], - start_node_ids: set[str], - *, - highlighted_nodes: Sequence[NodeIdent] | None = None, - highlight_css: str = 'fill:#f9f,stroke:#333,stroke-width:4px', - image_type: ImageType | str | None = None, - pdf_fit: bool = False, - pdf_landscape: bool = False, - pdf_paper: PdfPaper | None = None, - bg_color: str | None = None, - theme: Theme | None = None, - width: int | None = None, - height: int | None = None, - scale: Annotated[float, Ge(1), Le(3)] | None = None, + start_nodes: Sequence[NodeIdent] | NodeIdent, + /, + **kwargs: Unpack[MermaidConfig], ) -> bytes: """Generate an image of a Mermaid diagram using [mermaid.ink](https://mermaid.ink). Args: graph: The graph to generate the image for. - start_node_ids: IDs of start nodes of the graph. - highlighted_nodes: Identifiers of nodes to highlight. - highlight_css: CSS to use for highlighting nodes. - image_type: The image type to generate. If unspecified, the default behavior is `'jpeg'`. - pdf_fit: When using image_type='pdf', whether to fit the diagram to the PDF page. - pdf_landscape: When using image_type='pdf', whether to use landscape orientation for the PDF. - This has no effect if using `pdf_fit`. - pdf_paper: When using image_type='pdf', the paper size of the PDF. - bg_color: The background color of the diagram. If None, the default transparent background is used. - The color value is interpreted as a hexadecimal color code by default (and should not have a leading '#'), - but you can also use named colors by prefixing the value with `'!'`. - For example, valid choices include `bg_color='!white'` or `bg_color='FF0000'`. - theme: The theme of the diagram. Defaults to 'default'. - width: The width of the diagram. - height: The height of the diagram. - scale: The scale of the diagram. The scale must be a number between 1 and 3, and you can only set - a scale if one or both of width and height are set. + start_nodes: Identifiers of nodes that start the graph. + **kwargs: Additional parameters to configure mermaid chart generation. Returns: The image data. 
""" import httpx - code = generate_code(graph, start_node_ids, highlighted_nodes=highlighted_nodes, highlight_css=highlight_css) + code = generate_code( + graph, + start_nodes, + highlighted_nodes=kwargs.get('highlighted_nodes'), + highlight_css=kwargs.get('highlight_css', DEFAULT_HIGHLIGHT_CSS), + ) code_base64 = base64.b64encode(code.encode()).decode() - params: dict[str, str] = {} - if image_type == 'pdf': + params: dict[str, str | bool] = {} + if kwargs.get('image_type') == 'pdf': url = f'https://mermaid.ink/pdf/{code_base64}' - if pdf_fit: - params['fit'] = '' - if pdf_landscape: - params['landscape'] = '' - if pdf_paper: + if kwargs.get('pdf_fit'): + params['fit'] = True + if kwargs.get('pdf_landscape'): + params['landscape'] = True + if pdf_paper := kwargs.get('pdf_paper'): params['paper'] = pdf_paper - elif image_type == 'svg': + elif kwargs.get('image_type') == 'svg': url = f'https://mermaid.ink/svg/{code_base64}' else: url = f'https://mermaid.ink/img/{code_base64}' - if image_type: + if image_type := kwargs.get('image_type'): params['type'] = image_type - if bg_color: - params['bgColor'] = bg_color - if theme: + if background_color := kwargs.get('background_color'): + params['bgColor'] = background_color + if theme := kwargs.get('theme'): params['theme'] = theme - if width: + if width := kwargs.get('width'): params['width'] = str(width) - if height: + if height := kwargs.get('height'): params['height'] = str(height) - if scale: + if scale := kwargs.get('scale'): params['scale'] = str(scale) response = httpx.get(url, params=params) @@ -154,65 +184,26 @@ def request_image( def save_image( path: Path | str, graph: Graph[Any, Any], - start_node_ids: set[str], - *, - highlighted_nodes: Sequence[NodeIdent] | None = None, - highlight_css: str = 'fill:#f9f,stroke:#333,stroke-width:4px', - image_type: ImageType | None = None, - pdf_fit: bool = False, - pdf_landscape: bool = False, - pdf_paper: PdfPaper | None = None, - bg_color: str | None = None, - theme: Theme | None = None, - width: int | None = None, - height: int | None = None, - scale: Annotated[float, Ge(1), Le(3)] | None = None, + start_nodes: Sequence[NodeIdent] | NodeIdent, + /, + **kwargs: Unpack[MermaidConfig], ) -> None: """Generate an image of a Mermaid diagram using [mermaid.ink](https://mermaid.ink) and save it to a local file. Args: path: The path to save the image to. graph: The graph to generate the image for. - start_node_ids: IDs of start nodes of the graph. - highlighted_nodes: Identifiers of nodes to highlight. - highlight_css: CSS to use for highlighting nodes. - image_type: The image type to generate. If unspecified, the default behavior is `'jpeg'`. - pdf_fit: When using image_type='pdf', whether to fit the diagram to the PDF page. - pdf_landscape: When using image_type='pdf', whether to use landscape orientation for the PDF. - This has no effect if using `pdf_fit`. - pdf_paper: When using image_type='pdf', the paper size of the PDF. - bg_color: The background color of the diagram. If None, the default transparent background is used. - The color value is interpreted as a hexadecimal color code by default (and should not have a leading '#'), - but you can also use named colors by prefixing the value with `'!'`. - For example, valid choices include `bg_color='!white'` or `bg_color='FF0000'`. - theme: The theme of the diagram. Defaults to 'default'. - width: The width of the diagram. - height: The height of the diagram. - scale: The scale of the diagram. 
The scale must be a number between 1 and 3, and you can only set - a scale if one or both of width and height are set. + start_nodes: Identifiers of nodes that start the graph. + **kwargs: Additional parameters to configure mermaid chart generation. """ if isinstance(path, str): path = Path(path) - if image_type is None: - ext = path.suffix.lower() + if 'image_type' not in kwargs: + ext = path.suffix.lower()[1:] # no need to check for .jpeg/.jpg, as it is the default - if ext in {'.png', '.webp', '.svg', '.pdf'}: - image_type = cast(ImageType, ext[1:]) + if ext in ('png', 'webp', 'svg', 'pdf'): + kwargs['image_type'] = ext - image_data = request_image( - graph, - start_node_ids, - highlighted_nodes=highlighted_nodes, - highlight_css=highlight_css, - image_type=image_type, - pdf_fit=pdf_fit, - pdf_landscape=pdf_landscape, - pdf_paper=pdf_paper, - bg_color=bg_color, - theme=theme, - width=width, - height=height, - scale=scale, - ) - Path(path).write_bytes(image_data) + image_data = request_image(graph, start_nodes, **kwargs) + path.write_bytes(image_data) From 50b590f8056834f9d55450ec1ef3d03c48983494 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 8 Jan 2025 18:14:19 +0000 Subject: [PATCH 21/57] remove GraphRunner --- .../email_extract_graph.py | 5 +- .../pydantic_ai_graph/__init__.py | 3 +- pydantic_ai_graph/pydantic_ai_graph/graph.py | 63 +++++-------------- tests/test_graph.py | 5 +- 4 files changed, 20 insertions(+), 56 deletions(-) diff --git a/examples/pydantic_ai_examples/email_extract_graph.py b/examples/pydantic_ai_examples/email_extract_graph.py index 8465f5bcd9..a726b8afff 100644 --- a/examples/pydantic_ai_examples/email_extract_graph.py +++ b/examples/pydantic_ai_examples/email_extract_graph.py @@ -122,8 +122,7 @@ async def run( InspectEvent, ) ) -graph_runner = graph.get_runner(ExtractEvent) -print(graph_runner.mermaid_code()) +print(graph.mermaid_code(ExtractEvent)) email = """ Hi Samuel, @@ -148,7 +147,7 @@ async def run( async def main(): state = State(email_content=email) - result, history = await graph_runner.run(state, None) + result, history = await graph.run(state, ExtractEvent()) debug(result, history) diff --git a/pydantic_ai_graph/pydantic_ai_graph/__init__.py b/pydantic_ai_graph/pydantic_ai_graph/__init__.py index b7d79a6009..f6b7054845 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/__init__.py +++ b/pydantic_ai_graph/pydantic_ai_graph/__init__.py @@ -1,10 +1,9 @@ -from .graph import Graph, GraphRun, GraphRunner +from .graph import Graph, GraphRun from .nodes import BaseNode, End, GraphContext from .state import AbstractState, EndEvent, Step, StepOrEnd __all__ = ( 'Graph', - 'GraphRunner', 'GraphRun', 'BaseNode', 'End', diff --git a/pydantic_ai_graph/pydantic_ai_graph/graph.py b/pydantic_ai_graph/pydantic_ai_graph/graph.py index ce39168d2c..d1541f29e9 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/graph.py +++ b/pydantic_ai_graph/pydantic_ai_graph/graph.py @@ -8,14 +8,14 @@ from typing import TYPE_CHECKING, Generic import logfire_api -from typing_extensions import Never, ParamSpec, Protocol, TypeVar, Unpack, assert_never +from typing_extensions import Never, ParamSpec, TypeVar, Unpack, assert_never from . 
import _utils, mermaid from ._utils import get_parent_namespace from .nodes import BaseNode, End, GraphContext, NodeDef from .state import EndEvent, StateT, Step, StepOrEnd -__all__ = 'Graph', 'GraphRun', 'GraphRunner' +__all__ = 'Graph', 'GraphRun' _logfire = logfire_api.Logfire(otel_scope='pydantic-ai-graph') @@ -24,12 +24,6 @@ NodeRunEndT = TypeVar('NodeRunEndT', covariant=True, default=Never) -class StartNodeProtocol(Protocol[RunSignatureT, StateT, NodeRunEndT]): - def get_id(self) -> str: ... - - def __call__(self, *args: RunSignatureT.args, **kwargs: RunSignatureT.kwargs) -> BaseNode[StateT, NodeRunEndT]: ... - - @dataclass(init=False) class Graph(Generic[StateT, RunEndT]): """Definition of a graph.""" @@ -91,56 +85,29 @@ async def run( history = run.history return result, history - def get_runner( - self, - first_node: StartNodeProtocol[RunSignatureT, StateT, RunEndT], - ) -> GraphRunner[RunSignatureT, StateT, RunEndT]: - # noinspection PyTypeChecker - return GraphRunner( - graph=self, - first_node=first_node, - ) - - -@dataclass -class GraphRunner(Generic[RunSignatureT, StateT, RunEndT]): - """Runner for a graph. - - This is a separate class from Graph so that you can get a type-safe runner from a graph definition - without needing to manually annotate the paramspec of the start node. - """ - - graph: Graph[StateT, RunEndT] - first_node: StartNodeProtocol[RunSignatureT, StateT, RunEndT] - - def __post_init__(self): - if self.first_node not in self.graph.nodes: - raise ValueError(f'Start node "{self.first_node}" is not in the graph.') - - async def run( - self, state: StateT, /, *args: RunSignatureT.args, **kwargs: RunSignatureT.kwargs - ) -> tuple[RunEndT, list[StepOrEnd[StateT, RunEndT]]]: - run = GraphRun[StateT, RunEndT](state=state) - # TODO: Infer the graph name properly - result = await run.run(self.graph.name or 'graph', self.first_node(*args, **kwargs)) - history = run.history - return result, history - def mermaid_code( self, + start_nodes: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent, *, highlighted_nodes: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None, highlight_css: str = mermaid.DEFAULT_HIGHLIGHT_CSS, ) -> str: return mermaid.generate_code( - self.graph, self.first_node.get_id(), highlighted_nodes=highlighted_nodes, highlight_css=highlight_css + self, start_nodes, highlighted_nodes=highlighted_nodes, highlight_css=highlight_css ) - def mermaid_image(self, **kwargs: Unpack[mermaid.MermaidConfig]) -> bytes: - return mermaid.request_image(self.graph, self.first_node.get_id(), **kwargs) + def mermaid_image( + self, start_nodes: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent, **kwargs: Unpack[mermaid.MermaidConfig] + ) -> bytes: + return mermaid.request_image(self, start_nodes, **kwargs) - def mermaid_save(self, path: Path | str, /, **kwargs: Unpack[mermaid.MermaidConfig]) -> None: - mermaid.save_image(path, self.graph, self.first_node.get_id(), **kwargs) + def mermaid_save( + self, + start_nodes: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent, + path: Path | str, + **kwargs: Unpack[mermaid.MermaidConfig], + ) -> None: + mermaid.save_image(path, self, start_nodes, **kwargs) @dataclass diff --git a/tests/test_graph.py b/tests/test_graph.py index 4a06603a7e..6187b764e6 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -40,8 +40,7 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq return End(self.input_data * 2) g = Graph[None, int](nodes=(Float2String, String2Length, Double)) - runner = g.get_runner(Float2String) 
- result, history = await runner.run(None, 3.14) + result, history = await g.run(None, Float2String(3.14)) # len('3.14') * 2 == 8 assert result == 8 assert history == snapshot( @@ -71,7 +70,7 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq ), ] ) - result, history = await runner.run(None, 3.14159) + result, history = await g.run(None, Float2String(3.14159)) # len('3.14159') == 7, 21 * 2 == 42 assert result == 42 assert history == snapshot( From 8ac10c776228085f9f8e31bc485c6b22b2a929bb Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Thu, 9 Jan 2025 00:19:38 +0000 Subject: [PATCH 22/57] add Interrupt --- .../email_extract_graph.py | 2 +- .../pydantic_ai_graph/__init__.py | 9 +-- pydantic_ai_graph/pydantic_ai_graph/graph.py | 46 +++++++------- .../pydantic_ai_graph/mermaid.py | 33 ++++++---- pydantic_ai_graph/pydantic_ai_graph/nodes.py | 61 ++++++++++++++----- pydantic_ai_graph/pydantic_ai_graph/state.py | 47 +++++++++++--- tests/test_graph.py | 22 +++---- 7 files changed, 144 insertions(+), 76 deletions(-) diff --git a/examples/pydantic_ai_examples/email_extract_graph.py b/examples/pydantic_ai_examples/email_extract_graph.py index a726b8afff..8cf30803e1 100644 --- a/examples/pydantic_ai_examples/email_extract_graph.py +++ b/examples/pydantic_ai_examples/email_extract_graph.py @@ -122,7 +122,7 @@ async def run( InspectEvent, ) ) -print(graph.mermaid_code(ExtractEvent)) +print(graph.mermaid_code(start_node=ExtractEvent)) email = """ Hi Samuel, diff --git a/pydantic_ai_graph/pydantic_ai_graph/__init__.py b/pydantic_ai_graph/pydantic_ai_graph/__init__.py index f6b7054845..afdad1bdc2 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/__init__.py +++ b/pydantic_ai_graph/pydantic_ai_graph/__init__.py @@ -1,15 +1,16 @@ from .graph import Graph, GraphRun -from .nodes import BaseNode, End, GraphContext -from .state import AbstractState, EndEvent, Step, StepOrEnd +from .nodes import BaseNode, End, GraphContext, Interrupt +from .state import AbstractState, EndEvent, HistoryStep, NextNodeEvent __all__ = ( 'Graph', 'GraphRun', 'BaseNode', 'End', + 'Interrupt', 'GraphContext', 'AbstractState', 'EndEvent', - 'StepOrEnd', - 'Step', + 'HistoryStep', + 'NextNodeEvent', ) diff --git a/pydantic_ai_graph/pydantic_ai_graph/graph.py b/pydantic_ai_graph/pydantic_ai_graph/graph.py index d1541f29e9..22db3d6497 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/graph.py +++ b/pydantic_ai_graph/pydantic_ai_graph/graph.py @@ -12,8 +12,8 @@ from . 
import _utils, mermaid from ._utils import get_parent_namespace -from .nodes import BaseNode, End, GraphContext, NodeDef -from .state import EndEvent, StateT, Step, StepOrEnd +from .nodes import BaseNode, End, GraphContext, Interrupt, NodeDef, RunInterrupt +from .state import EndEvent, HistoryStep, InterruptEvent, NextNodeEvent, StateT __all__ = 'Graph', 'GraphRun' @@ -76,7 +76,7 @@ def _validate_edges(self): async def run( self, state: StateT, node: BaseNode[StateT, RunEndT] - ) -> tuple[RunEndT, list[StepOrEnd[StateT, RunEndT]]]: + ) -> tuple[End[RunEndT] | RunInterrupt[StateT], list[HistoryStep[StateT, RunEndT]]]: if not isinstance(node, self.nodes): raise ValueError(f'Node "{node}" is not in the graph.') run = GraphRun[StateT, RunEndT](state=state) @@ -87,27 +87,20 @@ async def run( def mermaid_code( self, - start_nodes: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent, *, + start_node: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None, highlighted_nodes: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None, highlight_css: str = mermaid.DEFAULT_HIGHLIGHT_CSS, ) -> str: return mermaid.generate_code( - self, start_nodes, highlighted_nodes=highlighted_nodes, highlight_css=highlight_css + self, start_node=start_node, highlighted_nodes=highlighted_nodes, highlight_css=highlight_css ) - def mermaid_image( - self, start_nodes: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent, **kwargs: Unpack[mermaid.MermaidConfig] - ) -> bytes: - return mermaid.request_image(self, start_nodes, **kwargs) + def mermaid_image(self, **kwargs: Unpack[mermaid.MermaidConfig]) -> bytes: + return mermaid.request_image(self, **kwargs) - def mermaid_save( - self, - start_nodes: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent, - path: Path | str, - **kwargs: Unpack[mermaid.MermaidConfig], - ) -> None: - mermaid.save_image(path, self, start_nodes, **kwargs) + def mermaid_save(self, path: Path | str, /, **kwargs: Unpack[mermaid.MermaidConfig]) -> None: + mermaid.save_image(path, self, **kwargs) @dataclass @@ -115,9 +108,11 @@ class GraphRun(Generic[StateT, RunEndT]): """Stateful run of a graph.""" state: StateT - history: list[StepOrEnd[StateT, RunEndT]] = field(default_factory=list) + history: list[HistoryStep[StateT, RunEndT]] = field(default_factory=list) - async def run(self, graph_name: str, start: BaseNode[StateT, RunEndT], infer_name: bool = True) -> RunEndT: + async def run( + self, graph_name: str, start: BaseNode[StateT, RunEndT], infer_name: bool = True + ) -> End[RunEndT] | RunInterrupt[StateT]: current_node = start with _logfire.span( @@ -127,10 +122,13 @@ async def run(self, graph_name: str, start: BaseNode[StateT, RunEndT], infer_nam ) as run_span: while True: next_node = await self.step(current_node) - if isinstance(next_node, End): - self.history.append(EndEvent(self.state, next_node)) + if isinstance(next_node, (Interrupt, End)): + if isinstance(next_node, End): + self.history.append(EndEvent(self.state, next_node)) + else: + self.history.append(InterruptEvent(self.state, next_node)) run_span.set_attribute('history', self.history) - return next_node.data + return next_node elif isinstance(next_node, BaseNode): current_node = next_node else: @@ -139,8 +137,10 @@ async def run(self, graph_name: str, start: BaseNode[StateT, RunEndT], infer_nam else: raise TypeError(f'Invalid node type: {type(next_node)}. 
Expected `BaseNode` or `End`.') - async def step(self, node: BaseNode[StateT, RunEndT]) -> BaseNode[StateT, RunEndT] | End[RunEndT]: - history_step = Step(self.state, node) + async def step( + self, node: BaseNode[StateT, RunEndT] + ) -> BaseNode[StateT, RunEndT] | End[RunEndT] | RunInterrupt[StateT]: + history_step = NextNodeEvent(self.state, node) self.history.append(history_step) ctx = GraphContext(self.state) diff --git a/pydantic_ai_graph/pydantic_ai_graph/mermaid.py b/pydantic_ai_graph/pydantic_ai_graph/mermaid.py index 00d2d3a9f2..109423bcee 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/mermaid.py +++ b/pydantic_ai_graph/pydantic_ai_graph/mermaid.py @@ -2,6 +2,7 @@ import base64 from collections.abc import Iterable, Sequence +from itertools import chain from pathlib import Path from typing import TYPE_CHECKING, Annotated, Any, Literal @@ -20,9 +21,9 @@ def generate_code( graph: Graph[Any, Any], - start_nodes: Sequence[NodeIdent] | NodeIdent, /, *, + start_node: Sequence[NodeIdent] | NodeIdent | None = None, highlighted_nodes: Sequence[NodeIdent] | NodeIdent | None = None, highlight_css: str = DEFAULT_HIGHLIGHT_CSS, ) -> str: @@ -30,33 +31,41 @@ def generate_code( Args: graph: The graph to generate the image for. - start_nodes: Identifiers of nodes that start the graph. + start_node: Identifiers of nodes that start the graph. highlighted_nodes: Identifiers of nodes to highlight. highlight_css: CSS to use for highlighting nodes. Returns: The Mermaid code for the graph. """ - start_node_ids = set(node_ids(start_nodes)) + start_node_ids = set(node_ids(start_node or ())) for node_id in start_node_ids: if node_id not in graph.node_defs: raise LookupError(f'Start node "{node_id}" is not in the graph.') node_order = {node_id: index for index, node_id in enumerate(graph.node_defs)} + after_interrupt_nodes = set(chain(*(node_def.after_interrupt_nodes for node_def in graph.node_defs.values()))) lines = ['graph TD'] for node in graph.nodes: node_id = node.get_id() node_def = graph.node_defs[node_id] + + # we use square brackets (square box) for nodes that can interrupt the flow, + # and round brackets (rounded box) for nodes that cannot interrupt the flow + if node_id in after_interrupt_nodes: + mermaid_name = f'[{node_id}]' + else: + mermaid_name = f'({node_id})' if node_id in start_node_ids: - lines.append(f' START --> {node_id}') + lines.append(f' START --> {node_id}{mermaid_name}') if node_def.returns_base_node: for next_node_id in graph.nodes: - lines.append(f' {node_id} --> {next_node_id}') + lines.append(f' {node_id}{mermaid_name} --> {next_node_id}') else: for _, next_node_id in sorted((node_order[node_id], node_id) for node_id in node_def.next_node_ids): - lines.append(f' {node_id} --> {next_node_id}') + lines.append(f' {node_id}{mermaid_name} --> {next_node_id}') if node_def.returns_end: - lines.append(f' {node_id} --> END') + lines.append(f' {node_id}{mermaid_name} --> END') if highlighted_nodes: lines.append('') @@ -88,6 +97,8 @@ def node_ids(node_idents: Sequence[NodeIdent] | NodeIdent) -> Iterable[str]: class MermaidConfig(TypedDict, total=False): """Parameters to configure mermaid chart generation.""" + start_node: Sequence[NodeIdent] | NodeIdent + """Identifiers of nodes that start the graph.""" highlighted_nodes: Sequence[NodeIdent] | NodeIdent """Identifiers of nodes to highlight.""" highlight_css: str @@ -125,7 +136,6 @@ class MermaidConfig(TypedDict, total=False): def request_image( graph: Graph[Any, Any], - start_nodes: Sequence[NodeIdent] | NodeIdent, /, **kwargs: 
Unpack[MermaidConfig], ) -> bytes: @@ -133,7 +143,6 @@ def request_image( Args: graph: The graph to generate the image for. - start_nodes: Identifiers of nodes that start the graph. **kwargs: Additional parameters to configure mermaid chart generation. Returns: The image data. @@ -142,7 +151,7 @@ def request_image( code = generate_code( graph, - start_nodes, + start_node=kwargs.get('start_node'), highlighted_nodes=kwargs.get('highlighted_nodes'), highlight_css=kwargs.get('highlight_css', DEFAULT_HIGHLIGHT_CSS), ) @@ -184,7 +193,6 @@ def request_image( def save_image( path: Path | str, graph: Graph[Any, Any], - start_nodes: Sequence[NodeIdent] | NodeIdent, /, **kwargs: Unpack[MermaidConfig], ) -> None: @@ -193,7 +201,6 @@ def save_image( Args: path: The path to save the image to. graph: The graph to generate the image for. - start_nodes: Identifiers of nodes that start the graph. **kwargs: Additional parameters to configure mermaid chart generation. """ if isinstance(path, str): @@ -205,5 +212,5 @@ def save_image( if ext in ('png', 'webp', 'svg', 'pdf'): kwargs['image_type'] = ext - image_data = request_image(graph, start_nodes, **kwargs) + image_data = request_image(graph, **kwargs) path.write_bytes(image_data) diff --git a/pydantic_ai_graph/pydantic_ai_graph/nodes.py b/pydantic_ai_graph/pydantic_ai_graph/nodes.py index 1a1b8bcd5d..168339d0a9 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/nodes.py +++ b/pydantic_ai_graph/pydantic_ai_graph/nodes.py @@ -3,14 +3,14 @@ from abc import abstractmethod from dataclasses import dataclass from functools import cache -from typing import Any, Generic, get_origin, get_type_hints +from typing import Any, Generic, get_args, get_origin, get_type_hints from typing_extensions import Never, TypeVar from . import _utils from .state import StateT -__all__ = ('GraphContext', 'End', 'BaseNode', 'NodeDef') +__all__ = 'GraphContext', 'BaseNode', 'End', 'Interrupt', 'RunInterrupt', 'NodeDef' RunEndT = TypeVar('RunEndT', default=None) NodeRunEndT = TypeVar('NodeRunEndT', covariant=True, default=Never) @@ -23,18 +23,13 @@ class GraphContext(Generic[StateT]): state: StateT -@dataclass -class End(Generic[RunEndT]): - """Type to return from a node to signal the end of the graph.""" - - data: RunEndT - - class BaseNode(Generic[StateT, NodeRunEndT]): """Base class for a node.""" @abstractmethod - async def run(self, ctx: GraphContext[StateT]) -> BaseNode[StateT, Any] | End[NodeRunEndT]: ... + async def run( + self, ctx: GraphContext[StateT] + ) -> BaseNode[StateT, Any] | End[NodeRunEndT] | RunInterrupt[StateT]: ... 
@classmethod @cache def get_id(cls) -> str: @@ -44,20 +39,27 @@ def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, NodeRunEndT]: type_hints = get_type_hints(cls.run, localns=local_ns) - next_node_ids: set[str] = set() - returns_end: bool = False - returns_base_node: bool = False try: return_hint = type_hints['return'] except KeyError: raise TypeError(f'Node {cls} is missing a return type hint on its `run` method') + next_node_ids: set[str] = set() + returns_end: bool = False + returns_base_node: bool = False + after_interrupt_nodes: list[str] = [] for return_type in _utils.get_union_args(return_hint): return_type_origin = get_origin(return_type) or return_type if return_type_origin is End: returns_end = True + elif issubclass(return_type_origin, Interrupt): + interrupt_args = get_args(return_type) + assert len(interrupt_args) == 1, f'Invalid Interrupt return type: {return_type}' + next_node_id = interrupt_args[0].get_id() + next_node_ids.add(next_node_id) + after_interrupt_nodes.append(next_node_id) elif return_type_origin is BaseNode: - # TODO: Should we disallow this? More generally, what do we do about sub-subclasses? + # TODO: Should we disallow this? returns_base_node = True elif issubclass(return_type_origin, BaseNode): next_node_ids.add(return_type.get_id()) @@ -69,19 +71,48 @@ def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, NodeRu cls.get_id(), next_node_ids, returns_end, + after_interrupt_nodes, returns_base_node, ) +@dataclass +class End(Generic[RunEndT]): + """Type to return from a node to signal the end of the graph.""" + + data: RunEndT + + +InterruptNextNodeT = TypeVar('InterruptNextNodeT', covariant=True, bound=BaseNode[Any, Any]) + + +@dataclass +class Interrupt(Generic[InterruptNextNodeT]): + """Type to return from a node to signal that the run should be interrupted.""" + + node: InterruptNextNodeT + + +RunInterrupt = Interrupt[BaseNode[StateT, Any]] + + @dataclass class NodeDef(Generic[StateT, NodeRunEndT]): """Definition of a node. - Used by [`Graph`][pydantic_ai_graph.graph.Graph] store information about a node. + Used by [`Graph`][pydantic_ai_graph.graph.Graph] to store information about a node, and when generating + mermaid graphs. """ node: type[BaseNode[StateT, NodeRunEndT]] + """The node definition itself.""" node_id: str + """ID of the node.""" next_node_ids: set[str] + """IDs of the nodes that can be called next.""" returns_end: bool + """The node definition returns an `End`, hence the node can end the run.""" + after_interrupt_nodes: list[str] + """Nodes that can be returned within an `Interrupt`.""" returns_base_node: bool + """The node definition returns a `BaseNode`, hence any node in the graph can be called next.""" diff --git a/pydantic_ai_graph/pydantic_ai_graph/state.py b/pydantic_ai_graph/pydantic_ai_graph/state.py index 66c527d882..41c81d03eb 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/state.py +++ b/pydantic_ai_graph/pydantic_ai_graph/state.py @@ -4,17 +4,17 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from datetime import datetime -from typing import TYPE_CHECKING, Generic, Literal, Union +from typing import TYPE_CHECKING, Generic, Literal, Self, Union from typing_extensions import Never, TypeVar from .
import _utils -__all__ = 'AbstractState', 'StateT', 'Step', 'EndEvent', 'StepOrEnd' +__all__ = 'AbstractState', 'StateT', 'NextNodeEvent', 'EndEvent', 'InterruptEvent', 'HistoryStep' if TYPE_CHECKING: from pydantic_ai_graph import BaseNode - from pydantic_ai_graph.nodes import End + from pydantic_ai_graph.nodes import End, RunInterrupt class AbstractState(ABC): @@ -25,6 +25,10 @@ def serialize(self) -> bytes | None: """Serialize the state object.""" raise NotImplementedError + def deep_copy(self) -> Self: + """Create a deep copy of the state object.""" + return copy.deepcopy(self) + RunEndT = TypeVar('RunEndT', default=None) NodeRunEndT = TypeVar('NodeRunEndT', covariant=True, default=Never) @@ -32,8 +36,8 @@ def serialize(self) -> bytes | None: @dataclass -class Step(Generic[StateT, RunEndT]): - """History item describing the execution of a step of a graph.""" +class NextNodeEvent(Generic[StateT, RunEndT]): + """History step describing the execution of a step of a graph.""" state: StateT node: BaseNode[StateT, RunEndT] @@ -44,15 +48,33 @@ class Step(Generic[StateT, RunEndT]): def __post_init__(self): # Copy the state to prevent it from being modified by other code - self.state = copy.deepcopy(self.state) + self.state = _deep_copy_state(self.state) def node_summary(self) -> str: return str(self.node) +@dataclass +class InterruptEvent(Generic[StateT]): + """History step describing the interruption of a graph run.""" + + state: StateT + result: RunInterrupt[StateT] + ts: datetime = field(default_factory=_utils.now_utc) + + kind: Literal['interrupt'] = 'interrupt' + + def __post_init__(self): + # Copy the state to prevent it from being modified by other code + self.state = _deep_copy_state(self.state) + + def node_summary(self) -> str: + return str(self.result) + + @dataclass class EndEvent(Generic[StateT, RunEndT]): - """History item describing the end of a graph run.""" + """History step describing the end of a graph run.""" state: StateT result: End[RunEndT] @@ -62,10 +84,17 @@ class EndEvent(Generic[StateT, RunEndT]): def __post_init__(self): # Copy the state to prevent it from being modified by other code - self.state = copy.deepcopy(self.state) + self.state = _deep_copy_state(self.state) def node_summary(self) -> str: return str(self.result) -StepOrEnd = Union[Step[StateT, RunEndT], EndEvent[StateT, RunEndT]] +def _deep_copy_state(state: StateT) -> StateT: + if state is None: + return None # pyright: ignore[reportReturnType] + else: + return state.deep_copy() + + +HistoryStep = Union[NextNodeEvent[StateT, RunEndT], InterruptEvent[StateT], EndEvent[StateT, RunEndT]] diff --git a/tests/test_graph.py b/tests/test_graph.py index 6187b764e6..cdef4f27d3 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -7,7 +7,7 @@ import pytest from inline_snapshot import snapshot -from pydantic_ai_graph import BaseNode, End, EndEvent, Graph, GraphContext, Step +from pydantic_ai_graph import BaseNode, End, EndEvent, Graph, GraphContext, NextNodeEvent from .conftest import IsFloat, IsNow @@ -42,22 +42,22 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq g = Graph[None, int](nodes=(Float2String, String2Length, Double)) result, history = await g.run(None, Float2String(3.14)) # len('3.14') * 2 == 8 - assert result == 8 + assert result == End(8) assert history == snapshot( [ - Step( + NextNodeEvent( state=None, node=Float2String(input_data=3.14), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - Step( + NextNodeEvent( state=None, 
node=String2Length(input_data='3.14'), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - Step( + NextNodeEvent( state=None, node=Double(input_data=4), start_ts=IsNow(tz=timezone.utc), @@ -72,34 +72,34 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq ) result, history = await g.run(None, Float2String(3.14159)) # len('3.14159') == 7, 21 * 2 == 42 - assert result == 42 + assert result == End(42) assert history == snapshot( [ - Step( + NextNodeEvent( state=None, node=Float2String(input_data=3.14159), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - Step( + NextNodeEvent( state=None, node=String2Length(input_data='3.14159'), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - Step( + NextNodeEvent( state=None, node=Double(input_data=7), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - Step( + NextNodeEvent( state=None, node=String2Length(input_data='xxxxxxxxxxxxxxxxxxxxx'), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - Step( + NextNodeEvent( state=None, node=Double(input_data=21), start_ts=IsNow(tz=timezone.utc), From b63ca74f5bd8729c8b9c0dabe2558b98233e160d Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Thu, 9 Jan 2025 18:20:39 +0000 Subject: [PATCH 23/57] remove interrupt, replace with "next()" --- .../email_extract_graph.py | 3 +- .../pydantic_ai_graph/__init__.py | 6 +- pydantic_ai_graph/pydantic_ai_graph/graph.py | 108 ++++++++---------- .../pydantic_ai_graph/mermaid.py | 10 +- pydantic_ai_graph/pydantic_ai_graph/nodes.py | 31 +---- pydantic_ai_graph/pydantic_ai_graph/state.py | 28 +---- 6 files changed, 60 insertions(+), 126 deletions(-) diff --git a/examples/pydantic_ai_examples/email_extract_graph.py b/examples/pydantic_ai_examples/email_extract_graph.py index 8cf30803e1..0c16e8bee0 100644 --- a/examples/pydantic_ai_examples/email_extract_graph.py +++ b/examples/pydantic_ai_examples/email_extract_graph.py @@ -147,7 +147,8 @@ async def run( async def main(): state = State(email_content=email) - result, history = await graph.run(state, ExtractEvent()) + history = [] + result = await graph.run(state, ExtractEvent()) debug(result, history) diff --git a/pydantic_ai_graph/pydantic_ai_graph/__init__.py b/pydantic_ai_graph/pydantic_ai_graph/__init__.py index afdad1bdc2..3fbdb587bb 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/__init__.py +++ b/pydantic_ai_graph/pydantic_ai_graph/__init__.py @@ -1,13 +1,11 @@ -from .graph import Graph, GraphRun -from .nodes import BaseNode, End, GraphContext, Interrupt +from .graph import Graph +from .nodes import BaseNode, End, GraphContext from .state import AbstractState, EndEvent, HistoryStep, NextNodeEvent __all__ = ( 'Graph', - 'GraphRun', 'BaseNode', 'End', - 'Interrupt', 'GraphContext', 'AbstractState', 'EndEvent', diff --git a/pydantic_ai_graph/pydantic_ai_graph/graph.py b/pydantic_ai_graph/pydantic_ai_graph/graph.py index 22db3d6497..32e8a7caa8 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/graph.py +++ b/pydantic_ai_graph/pydantic_ai_graph/graph.py @@ -2,20 +2,20 @@ import inspect from collections.abc import Sequence -from dataclasses import dataclass, field +from dataclasses import dataclass from pathlib import Path from time import perf_counter -from typing import TYPE_CHECKING, Generic +from typing import TYPE_CHECKING, Any, Generic import logfire_api from typing_extensions import Never, ParamSpec, TypeVar, Unpack, assert_never from . 
import _utils, mermaid from ._utils import get_parent_namespace -from .nodes import BaseNode, End, GraphContext, Interrupt, NodeDef, RunInterrupt -from .state import EndEvent, HistoryStep, InterruptEvent, NextNodeEvent, StateT +from .nodes import BaseNode, End, GraphContext, NodeDef +from .state import EndEvent, HistoryStep, NextNodeEvent, StateT -__all__ = 'Graph', 'GraphRun' +__all__ = ('Graph',) _logfire = logfire_api.Logfire(otel_scope='pydantic-ai-graph') @@ -74,16 +74,48 @@ def _validate_edges(self): b = '\n'.join(f' {be}' for be in bad_edges_list) raise ValueError(f'Nodes are referenced in the graph but not included in the graph:\n{b}') + async def next( + self, state: StateT, node: BaseNode[StateT, RunEndT], history: list[HistoryStep[StateT, RunEndT]] + ) -> BaseNode[StateT, Any] | End[RunEndT]: + node_id = node.get_id() + if node_id not in self.node_defs: + raise TypeError(f'Node "{node}" is not in the graph.') + + history_step: NextNodeEvent[StateT, RunEndT] | None = NextNodeEvent(state, node) + history.append(history_step) + + ctx = GraphContext(state) + with _logfire.span('run node {node_id}', node_id=node_id, node=node): + start = perf_counter() + next_node = await node.run(ctx) + history_step.duration = perf_counter() - start + return next_node + async def run( - self, state: StateT, node: BaseNode[StateT, RunEndT] - ) -> tuple[End[RunEndT] | RunInterrupt[StateT], list[HistoryStep[StateT, RunEndT]]]: - if not isinstance(node, self.nodes): - raise ValueError(f'Node "{node}" is not in the graph.') - run = GraphRun[StateT, RunEndT](state=state) - # TODO: Infer the graph name properly - result = await run.run(self.name or 'graph', node) - history = run.history - return result, history + self, + state: StateT, + node: BaseNode[StateT, RunEndT], + ) -> tuple[End[RunEndT], list[HistoryStep[StateT, RunEndT]]]: + history: list[HistoryStep[StateT, RunEndT]] = [] + + with _logfire.span( + '{graph_name} run {start=}', + graph_name=self.name or 'graph', + start=node, + ) as run_span: + while True: + next_node = await self.next(state, node, history=history) + if isinstance(next_node, End): + history.append(EndEvent(state, next_node)) + run_span.set_attribute('history', history) + return next_node, history + elif isinstance(next_node, BaseNode): + node = next_node + else: + if TYPE_CHECKING: + assert_never(next_node) + else: + raise TypeError(f'Invalid node type: {type(next_node)}. 
Expected `BaseNode` or `End`.') def mermaid_code( self, @@ -101,51 +133,3 @@ def mermaid_image(self, **kwargs: Unpack[mermaid.MermaidConfig]) -> bytes: def mermaid_save(self, path: Path | str, /, **kwargs: Unpack[mermaid.MermaidConfig]) -> None: mermaid.save_image(path, self, **kwargs) - - -@dataclass -class GraphRun(Generic[StateT, RunEndT]): - """Stateful run of a graph.""" - - state: StateT - history: list[HistoryStep[StateT, RunEndT]] = field(default_factory=list) - - async def run( - self, graph_name: str, start: BaseNode[StateT, RunEndT], infer_name: bool = True - ) -> End[RunEndT] | RunInterrupt[StateT]: - current_node = start - - with _logfire.span( - '{graph_name} run {start=}', - graph_name=graph_name, - start=start, - ) as run_span: - while True: - next_node = await self.step(current_node) - if isinstance(next_node, (Interrupt, End)): - if isinstance(next_node, End): - self.history.append(EndEvent(self.state, next_node)) - else: - self.history.append(InterruptEvent(self.state, next_node)) - run_span.set_attribute('history', self.history) - return next_node - elif isinstance(next_node, BaseNode): - current_node = next_node - else: - if TYPE_CHECKING: - assert_never(next_node) - else: - raise TypeError(f'Invalid node type: {type(next_node)}. Expected `BaseNode` or `End`.') - - async def step( - self, node: BaseNode[StateT, RunEndT] - ) -> BaseNode[StateT, RunEndT] | End[RunEndT] | RunInterrupt[StateT]: - history_step = NextNodeEvent(self.state, node) - self.history.append(history_step) - - ctx = GraphContext(self.state) - with _logfire.span('run node {node_id}', node_id=node.get_id()): - start = perf_counter() - next_node = await node.run(ctx) - history_step.duration = perf_counter() - start - return next_node diff --git a/pydantic_ai_graph/pydantic_ai_graph/mermaid.py b/pydantic_ai_graph/pydantic_ai_graph/mermaid.py index 109423bcee..534e1b5227 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/mermaid.py +++ b/pydantic_ai_graph/pydantic_ai_graph/mermaid.py @@ -2,7 +2,6 @@ import base64 from collections.abc import Iterable, Sequence -from itertools import chain from pathlib import Path from typing import TYPE_CHECKING, Annotated, Any, Literal @@ -43,19 +42,14 @@ def generate_code( raise LookupError(f'Start node "{node_id}" is not in the graph.') node_order = {node_id: index for index, node_id in enumerate(graph.node_defs)} - after_interrupt_nodes = set(chain(*(node_def.after_interrupt_nodes for node_def in graph.node_defs.values()))) lines = ['graph TD'] for node in graph.nodes: node_id = node.get_id() node_def = graph.node_defs[node_id] - # we use square brackets (square box) for nodes that can interrupt the flow, - # and round brackets (rounded box) for nodes that cannot interrupt the flow - if node_id in after_interrupt_nodes: - mermaid_name = f'[{node_id}]' - else: - mermaid_name = f'({node_id})' + # we use round brackets (rounded box) for nodes other than the start and end + mermaid_name = f'({node_id})' if node_id in start_node_ids: lines.append(f' START --> {node_id}{mermaid_name}') if node_def.returns_base_node: diff --git a/pydantic_ai_graph/pydantic_ai_graph/nodes.py b/pydantic_ai_graph/pydantic_ai_graph/nodes.py index 168339d0a9..c5beb7b6d5 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/nodes.py +++ b/pydantic_ai_graph/pydantic_ai_graph/nodes.py @@ -3,14 +3,14 @@ from abc import abstractmethod from dataclasses import dataclass from functools import cache -from typing import Any, Generic, get_args, get_origin, get_type_hints +from typing import Any, Generic, get_origin, 
get_type_hints from typing_extensions import Never, TypeVar from . import _utils from .state import StateT -__all__ = 'GraphContext', 'BaseNode', 'End', 'Interrupt', 'RunInterrupt', 'NodeDef' +__all__ = 'GraphContext', 'BaseNode', 'End', 'NodeDef' RunEndT = TypeVar('RunEndT', default=None) NodeRunEndT = TypeVar('NodeRunEndT', covariant=True, default=Never) @@ -27,9 +27,7 @@ class BaseNode(Generic[StateT, NodeRunEndT]): """Base class for a node.""" @abstractmethod - async def run( - self, ctx: GraphContext[StateT] - ) -> BaseNode[StateT, Any] | End[NodeRunEndT] | RunInterrupt[StateT]: ... + async def run(self, ctx: GraphContext[StateT]) -> BaseNode[StateT, Any] | End[NodeRunEndT]: ... @classmethod @cache @@ -47,17 +45,10 @@ def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, NodeRu next_node_ids: set[str] = set() returns_end: bool = False returns_base_node: bool = False - after_interrupt_nodes: list[str] = [] for return_type in _utils.get_union_args(return_hint): return_type_origin = get_origin(return_type) or return_type if return_type_origin is End: returns_end = True - elif issubclass(return_type_origin, Interrupt): - interrupt_args = get_args(return_type) - assert len(interrupt_args) == 1, f'Invalid Interrupt return type: {return_type}' - next_node_id = interrupt_args[0].get_id() - next_node_ids.add(next_node_id) - after_interrupt_nodes.append(next_node_id) elif return_type_origin is BaseNode: # TODO: Should we disallow this? returns_base_node = True @@ -71,7 +62,6 @@ def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, NodeRu cls.get_id(), next_node_ids, returns_end, - after_interrupt_nodes, returns_base_node, ) @@ -83,19 +73,6 @@ class End(Generic[RunEndT]): data: RunEndT -InterruptNextNodeT = TypeVar('InterruptNextNodeT', covariant=True, bound=BaseNode[Any, Any]) - - -@dataclass -class Interrupt(Generic[InterruptNextNodeT]): - """Type to return from a node to signal that the run should be interrupted.""" - - node: InterruptNextNodeT - - -RunInterrupt = Interrupt[BaseNode[StateT, Any]] - - @dataclass class NodeDef(Generic[StateT, NodeRunEndT]): """Definition of a node. @@ -112,7 +89,5 @@ class NodeDef(Generic[StateT, NodeRunEndT]): """IDs of the nodes that can be called next.""" returns_end: bool """The node definition returns an `End`, hence the node and end the run.""" - after_interrupt_nodes: list[str] - """Nodes that can be returned within an `Interrupt`.""" returns_base_node: bool """The node definition returns a `BaseNode`, hence any node in the next can be called next.""" diff --git a/pydantic_ai_graph/pydantic_ai_graph/state.py b/pydantic_ai_graph/pydantic_ai_graph/state.py index 41c81d03eb..ad8858c4a4 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/state.py +++ b/pydantic_ai_graph/pydantic_ai_graph/state.py @@ -10,11 +10,11 @@ from . 
import _utils -__all__ = 'AbstractState', 'StateT', 'NextNodeEvent', 'EndEvent', 'InterruptEvent', 'HistoryStep' +__all__ = 'AbstractState', 'StateT', 'NextNodeEvent', 'EndEvent', 'HistoryStep' if TYPE_CHECKING: from pydantic_ai_graph import BaseNode - from pydantic_ai_graph.nodes import End, RunInterrupt + from pydantic_ai_graph.nodes import End class AbstractState(ABC): @@ -50,28 +50,10 @@ def __post_init__(self): # Copy the state to prevent it from being modified by other code self.state = _deep_copy_state(self.state) - def node_summary(self) -> str: + def summary(self) -> str: return str(self.node) -@dataclass -class InterruptEvent(Generic[StateT]): - """History step describing the interruption of a graph run.""" - - state: StateT - result: RunInterrupt[StateT] - ts: datetime = field(default_factory=_utils.now_utc) - - kind: Literal['interrupt'] = 'interrupt' - - def __post_init__(self): - # Copy the state to prevent it from being modified by other code - self.state = _deep_copy_state(self.state) - - def node_summary(self) -> str: - return str(self.result) - - @dataclass class EndEvent(Generic[StateT, RunEndT]): """History step describing the end of a graph run.""" @@ -86,7 +68,7 @@ def __post_init__(self): # Copy the state to prevent it from being modified by other code self.state = _deep_copy_state(self.state) - def node_summary(self) -> str: + def summary(self) -> str: return str(self.result) @@ -97,4 +79,4 @@ def _deep_copy_state(state: StateT) -> StateT: return state.deep_copy() -HistoryStep = Union[NextNodeEvent[StateT, RunEndT], InterruptEvent[StateT], EndEvent[StateT, RunEndT]] +HistoryStep = Union[NextNodeEvent[StateT, RunEndT], EndEvent[StateT, RunEndT]] From 745e3d585409e177971ab1353b88edf456f1125e Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Thu, 9 Jan 2025 18:29:23 +0000 Subject: [PATCH 24/57] address comments --- .../pydantic_ai_graph/__init__.py | 4 ++-- pydantic_ai_graph/pydantic_ai_graph/graph.py | 4 ++-- pydantic_ai_graph/pydantic_ai_graph/nodes.py | 2 +- pydantic_ai_graph/pydantic_ai_graph/state.py | 10 +++++----- tests/test_graph.py | 18 +++++++++--------- 5 files changed, 19 insertions(+), 19 deletions(-) diff --git a/pydantic_ai_graph/pydantic_ai_graph/__init__.py b/pydantic_ai_graph/pydantic_ai_graph/__init__.py index 3fbdb587bb..51db45c53c 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/__init__.py +++ b/pydantic_ai_graph/pydantic_ai_graph/__init__.py @@ -1,6 +1,6 @@ from .graph import Graph from .nodes import BaseNode, End, GraphContext -from .state import AbstractState, EndEvent, HistoryStep, NextNodeEvent +from .state import AbstractState, EndEvent, HistoryStep, NodeEvent __all__ = ( 'Graph', @@ -10,5 +10,5 @@ 'AbstractState', 'EndEvent', 'HistoryStep', - 'NextNodeEvent', + 'NodeEvent', ) diff --git a/pydantic_ai_graph/pydantic_ai_graph/graph.py b/pydantic_ai_graph/pydantic_ai_graph/graph.py index 32e8a7caa8..7c008469da 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/graph.py +++ b/pydantic_ai_graph/pydantic_ai_graph/graph.py @@ -13,7 +13,7 @@ from . 
import _utils, mermaid from ._utils import get_parent_namespace from .nodes import BaseNode, End, GraphContext, NodeDef -from .state import EndEvent, HistoryStep, NextNodeEvent, StateT +from .state import EndEvent, HistoryStep, NodeEvent, StateT __all__ = ('Graph',) @@ -81,7 +81,7 @@ async def next( if node_id not in self.node_defs: raise TypeError(f'Node "{node}" is not in the graph.') - history_step: NextNodeEvent[StateT, RunEndT] | None = NextNodeEvent(state, node) + history_step: NodeEvent[StateT, RunEndT] | None = NodeEvent(state, node) history.append(history_step) ctx = GraphContext(state) diff --git a/pydantic_ai_graph/pydantic_ai_graph/nodes.py b/pydantic_ai_graph/pydantic_ai_graph/nodes.py index c5beb7b6d5..76b39aae64 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/nodes.py +++ b/pydantic_ai_graph/pydantic_ai_graph/nodes.py @@ -88,6 +88,6 @@ class NodeDef(Generic[StateT, NodeRunEndT]): next_node_ids: set[str] """IDs of the nodes that can be called next.""" returns_end: bool - """The node definition returns an `End`, hence the node and end the run.""" + """The node definition returns an `End`, hence the node can end the run.""" returns_base_node: bool """The node definition returns a `BaseNode`, hence any node in the next can be called next.""" diff --git a/pydantic_ai_graph/pydantic_ai_graph/state.py b/pydantic_ai_graph/pydantic_ai_graph/state.py index ad8858c4a4..876038f9c4 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/state.py +++ b/pydantic_ai_graph/pydantic_ai_graph/state.py @@ -10,7 +10,7 @@ from . import _utils -__all__ = 'AbstractState', 'StateT', 'NextNodeEvent', 'EndEvent', 'HistoryStep' +__all__ = 'AbstractState', 'StateT', 'NodeEvent', 'EndEvent', 'HistoryStep' if TYPE_CHECKING: from pydantic_ai_graph import BaseNode @@ -36,8 +36,8 @@ def deep_copy(self) -> Self: @dataclass -class NextNodeEvent(Generic[StateT, RunEndT]): - """History step describing the execution of a step of a graph.""" +class NodeEvent(Generic[StateT, RunEndT]): + """History step describing the execution of a node in a graph.""" state: StateT node: BaseNode[StateT, RunEndT] @@ -74,9 +74,9 @@ def summary(self) -> str: def _deep_copy_state(state: StateT) -> StateT: if state is None: - return None # pyright: ignore[reportReturnType] + return state else: return state.deep_copy() -HistoryStep = Union[NextNodeEvent[StateT, RunEndT], EndEvent[StateT, RunEndT]] +HistoryStep = Union[NodeEvent[StateT, RunEndT], EndEvent[StateT, RunEndT]] diff --git a/tests/test_graph.py b/tests/test_graph.py index cdef4f27d3..3d70bec92c 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -7,7 +7,7 @@ import pytest from inline_snapshot import snapshot -from pydantic_ai_graph import BaseNode, End, EndEvent, Graph, GraphContext, NextNodeEvent +from pydantic_ai_graph import BaseNode, End, EndEvent, Graph, GraphContext, NodeEvent from .conftest import IsFloat, IsNow @@ -45,19 +45,19 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq assert result == End(8) assert history == snapshot( [ - NextNodeEvent( + NodeEvent( state=None, node=Float2String(input_data=3.14), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - NextNodeEvent( + NodeEvent( state=None, node=String2Length(input_data='3.14'), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - NextNodeEvent( + NodeEvent( state=None, node=Double(input_data=4), start_ts=IsNow(tz=timezone.utc), @@ -75,31 +75,31 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq assert result == End(42) assert 
history == snapshot( [ - NextNodeEvent( + NodeEvent( state=None, node=Float2String(input_data=3.14159), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - NextNodeEvent( + NodeEvent( state=None, node=String2Length(input_data='3.14159'), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - NextNodeEvent( + NodeEvent( state=None, node=Double(input_data=7), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - NextNodeEvent( + NodeEvent( state=None, node=String2Length(input_data='xxxxxxxxxxxxxxxxxxxxx'), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - NextNodeEvent( + NodeEvent( state=None, node=Double(input_data=21), start_ts=IsNow(tz=timezone.utc), From 1370f88ffffc541f717d71576b5c4f18157929ea Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Fri, 10 Jan 2025 12:15:56 +0000 Subject: [PATCH 25/57] switch name to pydantic-graph --- .github/workflows/ci.yml | 2 +- .../email_extract_graph.py | 2 +- mkdocs.yml | 2 +- pydantic_ai_graph/README.md | 21 ---------- pydantic_ai_slim/pyproject.toml | 4 +- pydantic_graph/README.md | 16 ++++++++ .../pydantic_graph}/__init__.py | 0 .../pydantic_graph}/_utils.py | 0 .../pydantic_graph}/graph.py | 2 +- .../pydantic_graph}/mermaid.py | 0 .../pydantic_graph}/nodes.py | 2 +- .../pydantic_graph}/py.typed | 0 .../pydantic_graph}/state.py | 4 +- .../pyproject.toml | 6 +-- pyproject.toml | 10 ++--- tests/test_graph.py | 2 +- tests/typed_graph.py | 2 +- uprev.py | 4 +- uv.lock | 40 +++++++++---------- 19 files changed, 57 insertions(+), 62 deletions(-) delete mode 100644 pydantic_ai_graph/README.md create mode 100644 pydantic_graph/README.md rename {pydantic_ai_graph/pydantic_ai_graph => pydantic_graph/pydantic_graph}/__init__.py (100%) rename {pydantic_ai_graph/pydantic_ai_graph => pydantic_graph/pydantic_graph}/_utils.py (100%) rename {pydantic_ai_graph/pydantic_ai_graph => pydantic_graph/pydantic_graph}/graph.py (98%) rename {pydantic_ai_graph/pydantic_ai_graph => pydantic_graph/pydantic_graph}/mermaid.py (100%) rename {pydantic_ai_graph/pydantic_ai_graph => pydantic_graph/pydantic_graph}/nodes.py (96%) rename {pydantic_ai_graph/pydantic_ai_graph => pydantic_graph/pydantic_graph}/py.typed (100%) rename {pydantic_ai_graph/pydantic_ai_graph => pydantic_graph/pydantic_graph}/state.py (96%) rename {pydantic_ai_graph => pydantic_graph}/pyproject.toml (93%) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 12439b8320..cc694f175c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -128,7 +128,7 @@ jobs: - run: mkdir coverage - # run tests with just `pydantic-ai-slim` and `pydantic-ai-graph` dependencies + # run tests with just `pydantic-ai-slim` and `pydantic-graph` dependencies - run: uv run --package pydantic-ai-slim --extra graph coverage run -m pytest env: COVERAGE_FILE: coverage/.coverage.${{ runner.os }}-py${{ matrix.python-version }}-slim diff --git a/examples/pydantic_ai_examples/email_extract_graph.py b/examples/pydantic_ai_examples/email_extract_graph.py index 0c16e8bee0..82c98a370e 100644 --- a/examples/pydantic_ai_examples/email_extract_graph.py +++ b/examples/pydantic_ai_examples/email_extract_graph.py @@ -7,7 +7,7 @@ import logfire from devtools import debug from pydantic import BaseModel -from pydantic_ai_graph import AbstractState, BaseNode, End, Graph, GraphContext +from pydantic_graph import AbstractState, BaseNode, End, Graph, GraphContext from pydantic_ai import Agent, RunContext diff --git a/mkdocs.yml b/mkdocs.yml index 60d27c96c2..915701c64a 100644 --- a/mkdocs.yml +++ 
b/mkdocs.yml @@ -159,7 +159,7 @@ plugins: - mkdocstrings: handlers: python: - paths: [src/packages/pydantic_ai_slim/pydantic_ai] + paths: [pydantic_ai_slim/pydantic_ai] options: relative_crossrefs: true members_order: source diff --git a/pydantic_ai_graph/README.md b/pydantic_ai_graph/README.md deleted file mode 100644 index 3f3a5bd72a..0000000000 --- a/pydantic_ai_graph/README.md +++ /dev/null @@ -1,21 +0,0 @@ -# PydanticAI Graph - -[![CI](https://github.com/pydantic/pydantic-ai/actions/workflows/ci.yml/badge.svg?event=push)](https://github.com/pydantic/pydantic-ai/actions/workflows/ci.yml?query=branch%3Amain) -[![Coverage](https://coverage-badge.samuelcolvin.workers.dev/pydantic/pydantic-ai.svg)](https://coverage-badge.samuelcolvin.workers.dev/redirect/pydantic/pydantic-ai) -[![PyPI](https://img.shields.io/pypi/v/pydantic-ai-graph.svg)](https://pypi.python.org/pypi/pydantic-ai-graph) -[![versions](https://img.shields.io/pypi/pyversions/pydantic-ai-graph.svg)](https://github.com/pydantic/pydantic-ai) -[![license](https://img.shields.io/github/license/pydantic/pydantic-ai-graph.svg?v)](https://github.com/pydantic/pydantic-ai/blob/main/LICENSE) - -Graph and state machine library. - -This library is developed as part of the [PydanticAI](https://ai.pydantic.dev), however it has no dependency -on `pydantic-ai` or related packages and can be considered as a pure graph library. - -As with PydanticAI, this library prioritizes type safety and use of common Python syntax over esoteric, domain-specific use of Python syntax. - -`pydantic-ai-graph` allows you to define graphs using simple Python syntax. In particular, edges are defined using the return type hint of nodes. - -When designing your graph and state machine, you need to identify the data types for the overall graph input, the final graph output, the graph dependency object and graph state. Then for each specific node in the graph, you have to identify the specific data type each node is expected to receive as the input type from the prior node in the graph during transitions. - -Once the nodes in the graph are defined, you can use certain built-in methods on the Graph object to visualize the nodes -and state transitions on the graph as mermaid diagrams. 
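The README text deleted above describes visualising a graph's nodes and state transitions as mermaid diagrams via built-in methods on `Graph`. A minimal sketch of that workflow, assembled from the `mermaid_code`, `mermaid_image` and `mermaid_save` methods in `graph.py` and mirroring the `Foo`/`Bar` nodes in `tests/test_graph_mermaid.py` added later in this series (treat it as illustrative, not a committed example):

    from __future__ import annotations

    from dataclasses import dataclass

    from pydantic_graph import BaseNode, End, Graph, GraphContext


    @dataclass
    class Foo(BaseNode):
        async def run(self, ctx: GraphContext) -> Bar:
            return Bar()


    @dataclass
    class Bar(BaseNode[None, None]):
        async def run(self, ctx: GraphContext) -> End[None]:
            return End(None)


    graph = Graph(nodes=(Foo, Bar))

    # Mermaid source for the graph (rendered as a state diagram from PATCH 26 onwards).
    print(graph.mermaid_code(start_node=Foo))

    # A rendered image can be fetched over HTTP and written to disk:
    # graph.mermaid_image(...) returns the image bytes,
    # graph.mermaid_save('graph.png') writes them to a file.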
diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index 693d27aaec..912ad47479 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -42,7 +42,7 @@ dependencies = [ ] [project.optional-dependencies] -graph = ["pydantic-ai-graph==0.0.14"] +graph = ["pydantic-graph==0.0.14"] openai = ["openai>=1.54.3"] vertexai = ["google-auth>=2.36.0", "requests>=2.32.3"] anthropic = ["anthropic>=0.40.0"] @@ -68,4 +68,4 @@ dev = [ packages = ["pydantic_ai"] [tool.uv.sources] -pydantic-ai-graph = { workspace = true } +pydantic-graph = { workspace = true } diff --git a/pydantic_graph/README.md b/pydantic_graph/README.md new file mode 100644 index 0000000000..2789d18ac1 --- /dev/null +++ b/pydantic_graph/README.md @@ -0,0 +1,16 @@ +# Pydantic Graph + +[![CI](https://github.com/pydantic/pydantic-ai/actions/workflows/ci.yml/badge.svg?event=push)](https://github.com/pydantic/pydantic-ai/actions/workflows/ci.yml?query=branch%3Amain) +[![Coverage](https://coverage-badge.samuelcolvin.workers.dev/pydantic/pydantic-ai.svg)](https://coverage-badge.samuelcolvin.workers.dev/redirect/pydantic/pydantic-ai) +[![PyPI](https://img.shields.io/pypi/v/pydantic-graph.svg)](https://pypi.python.org/pypi/pydantic-graph) +[![versions](https://img.shields.io/pypi/pyversions/pydantic-graph.svg)](https://github.com/pydantic/pydantic-ai) +[![license](https://img.shields.io/github/license/pydantic/pydantic-ai.svg?v)](https://github.com/pydantic/pydantic-ai/blob/main/LICENSE) + +Graph and finite state machine library. + +This library is developed as part of the [PydanticAI](https://ai.pydantic.dev), however it has no dependency +on `pydantic-ai` or related packages and does and can be considered as a pure graph library. + +As with PydanticAI, this library prioritizes type safety and use of common Python syntax over esoteric, domain-specific use of Python syntax. + +`pydantic-graph` allows you to define graphs using simple Python syntax. In particular, edges are defined using the return type hint of nodes. 
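The new README's final sentence, that edges are defined by the return type hints of nodes, is exercised directly by `tests/test_graph.py` in this series. The sketch below reconstructs that test graph; the `run` bodies are inferred from the snapshot history in the tests rather than copied from them, and note that `Graph.run` returns the raw `End` object up to PATCH 28, after which it returns the unwrapped end data together with the history (the assertion below uses the later convention):

    from __future__ import annotations

    import asyncio
    from dataclasses import dataclass
    from typing import Union

    from pydantic_graph import BaseNode, End, Graph, GraphContext


    @dataclass
    class Float2String(BaseNode):
        input_data: float

        async def run(self, ctx: GraphContext) -> String2Length:
            # this return hint declares the Float2String --> String2Length edge
            return String2Length(str(self.input_data))


    @dataclass
    class String2Length(BaseNode):
        input_data: str

        async def run(self, ctx: GraphContext) -> Double:
            return Double(len(self.input_data))


    @dataclass
    class Double(BaseNode[None, int]):
        input_data: int

        async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]:
            # a union hint declares two edges: one back to String2Length, one ending the run
            if self.input_data == 7:
                return String2Length('x' * 21)
            return End(self.input_data * 2)


    g = Graph[None, int](nodes=(Float2String, String2Length, Double))


    async def main() -> None:
        result, history = await g.run(None, Float2String(3.14))
        assert result == 8  # len('3.14') * 2
        print(result, [type(step).__name__ for step in history])


    asyncio.run(main())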
diff --git a/pydantic_ai_graph/pydantic_ai_graph/__init__.py b/pydantic_graph/pydantic_graph/__init__.py similarity index 100% rename from pydantic_ai_graph/pydantic_ai_graph/__init__.py rename to pydantic_graph/pydantic_graph/__init__.py diff --git a/pydantic_ai_graph/pydantic_ai_graph/_utils.py b/pydantic_graph/pydantic_graph/_utils.py similarity index 100% rename from pydantic_ai_graph/pydantic_ai_graph/_utils.py rename to pydantic_graph/pydantic_graph/_utils.py diff --git a/pydantic_ai_graph/pydantic_ai_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py similarity index 98% rename from pydantic_ai_graph/pydantic_ai_graph/graph.py rename to pydantic_graph/pydantic_graph/graph.py index 7c008469da..10e1b35e55 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -17,7 +17,7 @@ __all__ = ('Graph',) -_logfire = logfire_api.Logfire(otel_scope='pydantic-ai-graph') +_logfire = logfire_api.Logfire(otel_scope='pydantic-graph') RunSignatureT = ParamSpec('RunSignatureT') RunEndT = TypeVar('RunEndT', default=None) diff --git a/pydantic_ai_graph/pydantic_ai_graph/mermaid.py b/pydantic_graph/pydantic_graph/mermaid.py similarity index 100% rename from pydantic_ai_graph/pydantic_ai_graph/mermaid.py rename to pydantic_graph/pydantic_graph/mermaid.py diff --git a/pydantic_ai_graph/pydantic_ai_graph/nodes.py b/pydantic_graph/pydantic_graph/nodes.py similarity index 96% rename from pydantic_ai_graph/pydantic_ai_graph/nodes.py rename to pydantic_graph/pydantic_graph/nodes.py index 76b39aae64..bbadd8b380 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/nodes.py +++ b/pydantic_graph/pydantic_graph/nodes.py @@ -77,7 +77,7 @@ class End(Generic[RunEndT]): class NodeDef(Generic[StateT, NodeRunEndT]): """Definition of a node. - Used by [`Graph`][pydantic_ai_graph.graph.Graph] to store information about a node, and when generating + Used by [`Graph`][pydantic_graph.graph.Graph] to store information about a node, and when generating mermaid graphs. 
""" diff --git a/pydantic_ai_graph/pydantic_ai_graph/py.typed b/pydantic_graph/pydantic_graph/py.typed similarity index 100% rename from pydantic_ai_graph/pydantic_ai_graph/py.typed rename to pydantic_graph/pydantic_graph/py.typed diff --git a/pydantic_ai_graph/pydantic_ai_graph/state.py b/pydantic_graph/pydantic_graph/state.py similarity index 96% rename from pydantic_ai_graph/pydantic_ai_graph/state.py rename to pydantic_graph/pydantic_graph/state.py index 876038f9c4..9b007836ee 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/state.py +++ b/pydantic_graph/pydantic_graph/state.py @@ -13,8 +13,8 @@ __all__ = 'AbstractState', 'StateT', 'NodeEvent', 'EndEvent', 'HistoryStep' if TYPE_CHECKING: - from pydantic_ai_graph import BaseNode - from pydantic_ai_graph.nodes import End + from pydantic_graph import BaseNode + from pydantic_graph.nodes import End class AbstractState(ABC): diff --git a/pydantic_ai_graph/pyproject.toml b/pydantic_graph/pyproject.toml similarity index 93% rename from pydantic_ai_graph/pyproject.toml rename to pydantic_graph/pyproject.toml index cb047a9a8e..17aaa1fba7 100644 --- a/pydantic_ai_graph/pyproject.toml +++ b/pydantic_graph/pyproject.toml @@ -3,8 +3,8 @@ requires = ["hatchling"] build-backend = "hatchling.build" [project] -name = "pydantic-ai-graph" -version = "0.0.14" +name = "pydantic-graph" +version = "0.0.17" description = "Graph and state machine library" authors = [ { name = "Samuel Colvin", email = "samuel@pydantic.dev" }, @@ -40,4 +40,4 @@ dependencies = [ ] [tool.hatch.build.targets.wheel] -packages = ["pydantic_ai_graph"] +packages = ["pydantic_graph"] diff --git a/pyproject.toml b/pyproject.toml index 354b4b4448..746b8ee7ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,11 +51,11 @@ logfire = ["logfire>=2.3"] [tool.uv.sources] pydantic-ai-slim = { workspace = true } -pydantic-ai-graph = { workspace = true } +pydantic-graph = { workspace = true } pydantic-ai-examples = { workspace = true } [tool.uv.workspace] -members = ["pydantic_ai_slim", "pydantic_ai_graph", "examples"] +members = ["pydantic_ai_slim", "pydantic_graph", "examples"] [dependency-groups] # dev dependencies are defined in `pydantic-ai-slim/pyproject.toml` to allow for minimal testing @@ -83,7 +83,7 @@ line-length = 120 target-version = "py39" include = [ "pydantic_ai_slim/**/*.py", - "pydantic_ai_graph/**/*.py", + "pydantic_graph/**/*.py", "examples/**/*.py", "tests/**/*.py", "docs/**/*.py", @@ -127,7 +127,7 @@ typeCheckingMode = "strict" reportMissingTypeStubs = false reportUnnecessaryIsInstance = false reportUnnecessaryTypeIgnoreComment = true -include = ["pydantic_ai_slim", "pydantic_ai_graph", "tests", "examples"] +include = ["pydantic_ai_slim", "pydantic_graph", "tests", "examples"] venvPath = ".venv" # see https://github.com/microsoft/pyright/issues/7771 - we don't want to error on decorated functions in tests # which are not otherwise used @@ -150,7 +150,7 @@ filterwarnings = [ # https://coverage.readthedocs.io/en/latest/config.html#run [tool.coverage.run] # required to avoid warnings about files created by create_module fixture -include = ["pydantic_ai_slim/**/*.py", "pydantic_ai_graph/**/*.py","tests/**/*.py"] +include = ["pydantic_ai_slim/**/*.py", "pydantic_graph/**/*.py","tests/**/*.py"] omit = ["tests/test_live.py", "tests/example_modules/*.py"] branch = true diff --git a/tests/test_graph.py b/tests/test_graph.py index 3d70bec92c..7c3eba97aa 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -7,7 +7,7 @@ import pytest from inline_snapshot import snapshot 
-from pydantic_ai_graph import BaseNode, End, EndEvent, Graph, GraphContext, NodeEvent +from pydantic_graph import BaseNode, End, EndEvent, Graph, GraphContext, NodeEvent from .conftest import IsFloat, IsNow diff --git a/tests/typed_graph.py b/tests/typed_graph.py index c131f19b4c..8c6ff9b0ea 100644 --- a/tests/typed_graph.py +++ b/tests/typed_graph.py @@ -4,7 +4,7 @@ from typing_extensions import assert_type -from pydantic_ai_graph import BaseNode, End, Graph, GraphContext +from pydantic_graph import BaseNode, End, Graph, GraphContext @dataclass diff --git a/uprev.py b/uprev.py index 3eeca56d01..9ded772252 100644 --- a/uprev.py +++ b/uprev.py @@ -5,7 +5,7 @@ * pyproject.toml * examples/pyproject.toml * pydantic_ai_slim/pyproject.toml -* pydantic_ai_graph/pyproject.toml +* pydantic_graph/pyproject.toml Usage: @@ -68,7 +68,7 @@ def replace_deps_version(text: str) -> tuple[str, int]: slim_pp_text = slim_pp.read_text() slim_pp_text, count_slim = replace_deps_version(slim_pp_text) -graph_pp = ROOT_DIR / 'pydantic_ai_graph' / 'pyproject.toml' +graph_pp = ROOT_DIR / 'pydantic_graph' / 'pyproject.toml' graph_pp_text = graph_pp.read_text() graph_pp_text, count_graph = replace_deps_version(graph_pp_text) diff --git a/uv.lock b/uv.lock index 2c36b984b3..a5b648133e 100644 --- a/uv.lock +++ b/uv.lock @@ -12,8 +12,8 @@ resolution-markers = [ members = [ "pydantic-ai", "pydantic-ai-examples", - "pydantic-ai-graph", "pydantic-ai-slim", + "pydantic-graph", ] [[package]] @@ -2096,23 +2096,6 @@ requires-dist = [ { name = "uvicorn", specifier = ">=0.32.0" }, ] -[[package]] -name = "pydantic-ai-graph" -version = "0.0.14" -source = { editable = "pydantic_ai_graph" } -dependencies = [ - { name = "httpx" }, - { name = "logfire-api" }, - { name = "pydantic" }, -] - -[package.metadata] -requires-dist = [ - { name = "httpx", specifier = ">=0.27.2" }, - { name = "logfire-api", specifier = ">=1.2.0" }, - { name = "pydantic", specifier = ">=2.10" }, -] - [[package]] name = "pydantic-ai-slim" version = "0.0.18" @@ -2130,7 +2113,7 @@ anthropic = [ { name = "anthropic" }, ] graph = [ - { name = "pydantic-ai-graph" }, + { name = "pydantic-graph" }, ] groq = [ { name = "groq" }, @@ -2177,7 +2160,7 @@ requires-dist = [ { name = "mistralai", marker = "extra == 'mistral'", specifier = ">=1.2.5" }, { name = "openai", marker = "extra == 'openai'", specifier = ">=1.54.3" }, { name = "pydantic", specifier = ">=2.10" }, - { name = "pydantic-ai-graph", marker = "extra == 'graph'", editable = "pydantic_ai_graph" }, + { name = "pydantic-graph", marker = "extra == 'graph'", editable = "pydantic_graph" }, { name = "requests", marker = "extra == 'vertexai'", specifier = ">=2.32.3" }, ] @@ -2301,6 +2284,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a6/53/d78dc063216e62fc55f6b2eebb447f6a4b0a59f55c8406376f76bf959b08/pydub-0.25.1-py2.py3-none-any.whl", hash = "sha256:65617e33033874b59d87db603aa1ed450633288aefead953b30bded59cb599a6", size = 32327 }, ] +[[package]] +name = "pydantic-graph" +version = "0.0.17" +source = { editable = "pydantic_graph" } +dependencies = [ + { name = "httpx" }, + { name = "logfire-api" }, + { name = "pydantic" }, +] + +[package.metadata] +requires-dist = [ + { name = "httpx", specifier = ">=0.27.2" }, + { name = "logfire-api", specifier = ">=1.2.0" }, + { name = "pydantic", specifier = ">=2.10" }, +] + [[package]] name = "pygments" version = "2.18.0" From 24cdd35dc2019914c087c8ccb6525fd053ec8a60 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Fri, 10 Jan 2025 15:34:08 +0000 Subject: [PATCH 
26/57] allow labeling edges and notes for docstrings --- pydantic_graph/pydantic_graph/__init__.py | 3 +- pydantic_graph/pydantic_graph/_utils.py | 14 ++++++- pydantic_graph/pydantic_graph/graph.py | 2 +- pydantic_graph/pydantic_graph/mermaid.py | 44 ++++++++++++++++------ pydantic_graph/pydantic_graph/nodes.py | 45 +++++++++++++++++------ pydantic_graph/pydantic_graph/state.py | 4 +- 6 files changed, 84 insertions(+), 28 deletions(-) diff --git a/pydantic_graph/pydantic_graph/__init__.py b/pydantic_graph/pydantic_graph/__init__.py index 51db45c53c..9284c250f8 100644 --- a/pydantic_graph/pydantic_graph/__init__.py +++ b/pydantic_graph/pydantic_graph/__init__.py @@ -1,5 +1,5 @@ from .graph import Graph -from .nodes import BaseNode, End, GraphContext +from .nodes import BaseNode, Edge, End, GraphContext from .state import AbstractState, EndEvent, HistoryStep, NodeEvent __all__ = ( @@ -7,6 +7,7 @@ 'BaseNode', 'End', 'GraphContext', + 'Edge', 'AbstractState', 'EndEvent', 'HistoryStep', diff --git a/pydantic_graph/pydantic_graph/_utils.py b/pydantic_graph/pydantic_graph/_utils.py index a451fc5533..26743fb49d 100644 --- a/pydantic_graph/pydantic_graph/_utils.py +++ b/pydantic_graph/pydantic_graph/_utils.py @@ -3,7 +3,7 @@ import sys import types from datetime import datetime, timezone -from typing import Any, Union, get_args, get_origin +from typing import Annotated, Any, Union, get_args, get_origin from typing_extensions import TypeAliasType @@ -21,6 +21,18 @@ def get_union_args(tp: Any) -> tuple[Any, ...]: return (tp,) +def strip_annotated(tp: Any) -> tuple[Any, list[Any]]: + """Strip `Annotated` from the type if present. + + Returns: `(tp argument, ())` if not annotated, otherwise `(stripped type, annotations)`. + """ + if get_origin(tp) is Annotated: + inner_tp, *args = get_args(tp) + return inner_tp, args + else: + return tp, [] + + # same as `pydantic_ai_slim/pydantic_ai/_result.py:origin_is_union` if sys.version_info < (3, 10): diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 10e1b35e55..82ac4d2576 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -62,7 +62,7 @@ def _validate_edges(self): bad_edges: dict[str, list[str]] = {} for node_id, node_def in self.node_defs.items(): - node_bad_edges = node_def.next_node_ids - known_node_ids + node_bad_edges = node_def.next_node_edges.keys() - known_node_ids for bad_edge in node_bad_edges: bad_edges.setdefault(bad_edge, []).append(f'"{node_id}"') diff --git a/pydantic_graph/pydantic_graph/mermaid.py b/pydantic_graph/pydantic_graph/mermaid.py index 534e1b5227..e0a66984e3 100644 --- a/pydantic_graph/pydantic_graph/mermaid.py +++ b/pydantic_graph/pydantic_graph/mermaid.py @@ -1,8 +1,10 @@ from __future__ import annotations as _annotations import base64 +import re from collections.abc import Iterable, Sequence from pathlib import Path +from textwrap import indent from typing import TYPE_CHECKING, Annotated, Any, Literal from annotated_types import Ge, Le @@ -25,6 +27,8 @@ def generate_code( start_node: Sequence[NodeIdent] | NodeIdent | None = None, highlighted_nodes: Sequence[NodeIdent] | NodeIdent | None = None, highlight_css: str = DEFAULT_HIGHLIGHT_CSS, + edge_labels: bool = True, + docstring_notes: bool = True, ) -> str: """Generate Mermaid code for a graph. @@ -33,6 +37,8 @@ def generate_code( start_node: Identifiers of nodes that start the graph. highlighted_nodes: Identifiers of nodes to highlight. highlight_css: CSS to use for highlighting nodes. 
+ edge_labels: Whether to include edge labels in the diagram, defaults to true. + docstring_notes: Whether to include docstrings as notes in the diagram, defaults to true. Returns: The Mermaid code for the graph. """ @@ -41,25 +47,34 @@ def generate_code( if node_id not in graph.node_defs: raise LookupError(f'Start node "{node_id}" is not in the graph.') - node_order = {node_id: index for index, node_id in enumerate(graph.node_defs)} - - lines = ['graph TD'] + lines = ['stateDiagram-v2'] for node in graph.nodes: node_id = node.get_id() node_def = graph.node_defs[node_id] # we use round brackets (rounded box) for nodes other than the start and end - mermaid_name = f'({node_id})' if node_id in start_node_ids: - lines.append(f' START --> {node_id}{mermaid_name}') + lines.append(f' [*] --> {node_id}') if node_def.returns_base_node: for next_node_id in graph.nodes: - lines.append(f' {node_id}{mermaid_name} --> {next_node_id}') + lines.append(f' {node_id} --> {next_node_id}') else: - for _, next_node_id in sorted((node_order[node_id], node_id) for node_id in node_def.next_node_ids): - lines.append(f' {node_id}{mermaid_name} --> {next_node_id}') - if node_def.returns_end: - lines.append(f' {node_id}{mermaid_name} --> END') + for next_node_id, edge in node_def.next_node_edges.items(): + if edge_labels and (label := edge.label): + lines.append(f' {node_id} --> {next_node_id}: {label}') + else: + lines.append(f' {node_id} --> {next_node_id}') + if end_edge := node_def.end_edge: + if edge_labels and (label := end_edge.label): + lines.append(f' {node_id} --> [*]: {label}') + else: + lines.append(f' {node_id} --> [*]') + + if docstring_notes and node_def.doc_string: + lines.append(f' note right of {node_id}') + clean_docs = re.sub('\n\n+', '\n', node_def.doc_string) + lines.append(indent(clean_docs, ' ')) + lines.append(' end note') if highlighted_nodes: lines.append('') @@ -97,6 +112,10 @@ class MermaidConfig(TypedDict, total=False): """Identifiers of nodes to highlight.""" highlight_css: str """CSS to use for highlighting nodes.""" + edge_labels: bool + """Whether to include edge labels in the diagram.""" + docstring_notes: bool + """Whether to include docstrings as notes in the diagram, defaults to true.""" image_type: Literal['jpeg', 'png', 'webp', 'svg', 'pdf'] """The image type to generate. If unspecified, the default behavior is `'jpeg'`.""" pdf_fit: bool @@ -148,6 +167,8 @@ def request_image( start_node=kwargs.get('start_node'), highlighted_nodes=kwargs.get('highlighted_nodes'), highlight_css=kwargs.get('highlight_css', DEFAULT_HIGHLIGHT_CSS), + edge_labels=kwargs.get('edge_labels', True), + docstring_notes=kwargs.get('docstring_notes', True), ) code_base64 = base64.b64encode(code.encode()).decode() @@ -180,7 +201,8 @@ def request_image( params['scale'] = str(scale) response = httpx.get(url, params=params) - response.raise_for_status() + if not response.is_success: + raise ValueError(f'{response.status_code} error generating image:\n{response.text}') return response.content diff --git a/pydantic_graph/pydantic_graph/nodes.py b/pydantic_graph/pydantic_graph/nodes.py index bbadd8b380..c984c6f76a 100644 --- a/pydantic_graph/pydantic_graph/nodes.py +++ b/pydantic_graph/pydantic_graph/nodes.py @@ -1,7 +1,8 @@ from __future__ import annotations as _annotations +import inspect from abc import abstractmethod -from dataclasses import dataclass +from dataclasses import dataclass, is_dataclass from functools import cache from typing import Any, Generic, get_origin, get_type_hints @@ -10,7 +11,7 @@ from . 
import _utils from .state import StateT -__all__ = 'GraphContext', 'BaseNode', 'End', 'NodeDef' +__all__ = 'GraphContext', 'BaseNode', 'End', 'Edge', 'NodeDef' RunEndT = TypeVar('RunEndT', default=None) NodeRunEndT = TypeVar('NodeRunEndT', covariant=True, default=Never) @@ -36,32 +37,43 @@ def get_id(cls) -> str: @classmethod def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, NodeRunEndT]: - type_hints = get_type_hints(cls.run, localns=local_ns) + type_hints = get_type_hints(cls.run, localns=local_ns, include_extras=True) try: return_hint = type_hints['return'] except KeyError: raise TypeError(f'Node {cls} is missing a return type hint on its `run` method') - next_node_ids: set[str] = set() - returns_end: bool = False + next_node_edges: dict[str, Edge] = {} + end_edge: Edge | None = None returns_base_node: bool = False for return_type in _utils.get_union_args(return_hint): + return_type, annotations = _utils.strip_annotated(return_type) + edge = next((a for a in annotations if isinstance(a, Edge)), Edge(None)) return_type_origin = get_origin(return_type) or return_type if return_type_origin is End: - returns_end = True + end_edge = edge elif return_type_origin is BaseNode: # TODO: Should we disallow this? returns_base_node = True elif issubclass(return_type_origin, BaseNode): - next_node_ids.add(return_type.get_id()) + next_node_edges[return_type.get_id()] = edge else: raise TypeError(f'Invalid return type: {return_type}') + docstring = cls.__doc__ + # dataclasses get an automatic docstring which is just their signature, we don't want that + if docstring and is_dataclass(cls) and docstring.startswith(f'{cls.__name__}('): + docstring = None + if docstring: + # remove indentation from docstring + docstring = inspect.cleandoc(docstring) + return NodeDef( cls, cls.get_id(), - next_node_ids, - returns_end, + docstring, + next_node_edges, + end_edge, returns_base_node, ) @@ -73,6 +85,13 @@ class End(Generic[RunEndT]): data: RunEndT +@dataclass +class Edge: + """Annotation to apply a label to an edge in a graph.""" + + label: str | None + + @dataclass class NodeDef(Generic[StateT, NodeRunEndT]): """Definition of a node. @@ -85,9 +104,11 @@ class NodeDef(Generic[StateT, NodeRunEndT]): """The node definition itself.""" node_id: str """ID of the node.""" - next_node_ids: set[str] + doc_string: str | None + """Docstring of the node.""" + next_node_edges: dict[str, Edge] """IDs of the nodes that can be called next.""" - returns_end: bool - """The node definition returns an `End`, hence the node can end the run.""" + end_edge: Edge | None + """If node definition returns an `End` this is an Edge, indicating the node can end the run.""" returns_base_node: bool """The node definition returns a `BaseNode`, hence any node in the next can be called next.""" diff --git a/pydantic_graph/pydantic_graph/state.py b/pydantic_graph/pydantic_graph/state.py index 9b007836ee..b8dbff7a72 100644 --- a/pydantic_graph/pydantic_graph/state.py +++ b/pydantic_graph/pydantic_graph/state.py @@ -4,9 +4,9 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from datetime import datetime -from typing import TYPE_CHECKING, Generic, Literal, Self, Union +from typing import TYPE_CHECKING, Generic, Literal, Union -from typing_extensions import Never, TypeVar +from typing_extensions import Never, Self, TypeVar from . 
import _utils From 6990c49dd7df406d25e204dd6536ee3adae831a2 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Fri, 10 Jan 2025 18:11:47 +0000 Subject: [PATCH 27/57] allow notes to be disabled --- pydantic_graph/pydantic_graph/_utils.py | 2 +- pydantic_graph/pydantic_graph/mermaid.py | 42 +++++++++++++----------- pydantic_graph/pydantic_graph/nodes.py | 41 ++++++++++++++--------- 3 files changed, 49 insertions(+), 36 deletions(-) diff --git a/pydantic_graph/pydantic_graph/_utils.py b/pydantic_graph/pydantic_graph/_utils.py index 26743fb49d..498c5bdde6 100644 --- a/pydantic_graph/pydantic_graph/_utils.py +++ b/pydantic_graph/pydantic_graph/_utils.py @@ -21,7 +21,7 @@ def get_union_args(tp: Any) -> tuple[Any, ...]: return (tp,) -def strip_annotated(tp: Any) -> tuple[Any, list[Any]]: +def unpack_annotated(tp: Any) -> tuple[Any, list[Any]]: """Strip `Annotated` from the type if present. Returns: `(tp argument, ())` if not annotated, otherwise `(stripped type, annotations)`. diff --git a/pydantic_graph/pydantic_graph/mermaid.py b/pydantic_graph/pydantic_graph/mermaid.py index e0a66984e3..e3e03888d7 100644 --- a/pydantic_graph/pydantic_graph/mermaid.py +++ b/pydantic_graph/pydantic_graph/mermaid.py @@ -16,6 +16,9 @@ from .graph import Graph +__all__ = 'NodeIdent', 'DEFAULT_HIGHLIGHT_CSS', 'generate_code', 'MermaidConfig', 'request_image', 'save_image' + + NodeIdent: TypeAlias = 'type[BaseNode[Any, Any]] | BaseNode[Any, Any] | str' DEFAULT_HIGHLIGHT_CSS = 'fill:#fdff32' @@ -28,21 +31,21 @@ def generate_code( highlighted_nodes: Sequence[NodeIdent] | NodeIdent | None = None, highlight_css: str = DEFAULT_HIGHLIGHT_CSS, edge_labels: bool = True, - docstring_notes: bool = True, + notes: bool = True, ) -> str: - """Generate Mermaid code for a graph. + """Generate [Mermaid state diagram](https://mermaid.js.org/syntax/stateDiagram.html) code for a graph. Args: graph: The graph to generate the image for. start_node: Identifiers of nodes that start the graph. highlighted_nodes: Identifiers of nodes to highlight. highlight_css: CSS to use for highlighting nodes. - edge_labels: Whether to include edge labels in the diagram, defaults to true. - docstring_notes: Whether to include docstrings as notes in the diagram, defaults to true. + edge_labels: Whether to include edge labels in the diagram. + notes: Whether to include notes in the diagram. Returns: The Mermaid code for the graph. 
""" - start_node_ids = set(node_ids(start_node or ())) + start_node_ids = set(_node_ids(start_node or ())) for node_id in start_node_ids: if node_id not in graph.node_defs: raise LookupError(f'Start node "{node_id}" is not in the graph.') @@ -60,26 +63,27 @@ def generate_code( lines.append(f' {node_id} --> {next_node_id}') else: for next_node_id, edge in node_def.next_node_edges.items(): - if edge_labels and (label := edge.label): - lines.append(f' {node_id} --> {next_node_id}: {label}') - else: - lines.append(f' {node_id} --> {next_node_id}') + line = f' {node_id} --> {next_node_id}' + if edge_labels and edge.label: + line += f': {edge.label}' + lines.append(line) if end_edge := node_def.end_edge: - if edge_labels and (label := end_edge.label): - lines.append(f' {node_id} --> [*]: {label}') - else: - lines.append(f' {node_id} --> [*]') + line = f' {node_id} --> [*]' + if edge_labels and end_edge.label: + line += f': {end_edge.label}' + lines.append(line) - if docstring_notes and node_def.doc_string: + if notes and node_def.note: lines.append(f' note right of {node_id}') - clean_docs = re.sub('\n\n+', '\n', node_def.doc_string) + # mermaid doesn't like multiple paragraphs in a note, and shows if so + clean_docs = re.sub('\n{2,}', '\n', node_def.note) lines.append(indent(clean_docs, ' ')) lines.append(' end note') if highlighted_nodes: lines.append('') lines.append(f'classDef highlighted {highlight_css}') - for node_id in node_ids(highlighted_nodes): + for node_id in _node_ids(highlighted_nodes): if node_id not in graph.node_defs: raise LookupError(f'Highlighted node "{node_id}" is not in the graph.') lines.append(f'class {node_id} highlighted') @@ -87,7 +91,7 @@ def generate_code( return '\n'.join(lines) -def node_ids(node_idents: Sequence[NodeIdent] | NodeIdent) -> Iterable[str]: +def _node_ids(node_idents: Sequence[NodeIdent] | NodeIdent) -> Iterable[str]: """Get the node IDs from a sequence of node identifiers.""" if isinstance(node_idents, str): node_iter = (node_idents,) @@ -114,7 +118,7 @@ class MermaidConfig(TypedDict, total=False): """CSS to use for highlighting nodes.""" edge_labels: bool """Whether to include edge labels in the diagram.""" - docstring_notes: bool + notes: bool """Whether to include docstrings as notes in the diagram, defaults to true.""" image_type: Literal['jpeg', 'png', 'webp', 'svg', 'pdf'] """The image type to generate. 
If unspecified, the default behavior is `'jpeg'`.""" @@ -168,7 +172,7 @@ def request_image( highlighted_nodes=kwargs.get('highlighted_nodes'), highlight_css=kwargs.get('highlight_css', DEFAULT_HIGHLIGHT_CSS), edge_labels=kwargs.get('edge_labels', True), - docstring_notes=kwargs.get('docstring_notes', True), + notes=kwargs.get('notes', True), ) code_base64 = base64.b64encode(code.encode()).decode() diff --git a/pydantic_graph/pydantic_graph/nodes.py b/pydantic_graph/pydantic_graph/nodes.py index c984c6f76a..19b42cb7ae 100644 --- a/pydantic_graph/pydantic_graph/nodes.py +++ b/pydantic_graph/pydantic_graph/nodes.py @@ -1,10 +1,9 @@ from __future__ import annotations as _annotations -import inspect -from abc import abstractmethod +from abc import ABC, abstractmethod from dataclasses import dataclass, is_dataclass from functools import cache -from typing import Any, Generic, get_origin, get_type_hints +from typing import Any, ClassVar, Generic, get_origin, get_type_hints from typing_extensions import Never, TypeVar @@ -24,9 +23,12 @@ class GraphContext(Generic[StateT]): state: StateT -class BaseNode(Generic[StateT, NodeRunEndT]): +class BaseNode(ABC, Generic[StateT, NodeRunEndT]): """Base class for a node.""" + enable_docstring_notes: ClassVar[bool] = True + """Set to `False` to not generate mermaid diagram notes from the class's docstring.""" + @abstractmethod async def run(self, ctx: GraphContext[StateT]) -> BaseNode[StateT, Any] | End[NodeRunEndT]: ... @@ -35,6 +37,21 @@ async def run(self, ctx: GraphContext[StateT]) -> BaseNode[StateT, Any] | End[No def get_id(cls) -> str: return cls.__name__ + @classmethod + def get_note(cls) -> str | None: + if not cls.enable_docstring_notes: + return None + docstring = cls.__doc__ + # dataclasses get an automatic docstring which is just their signature, we don't want that + if docstring and is_dataclass(cls) and docstring.startswith(f'{cls.__name__}('): + docstring = None + if docstring: + # remove indentation from docstring + import inspect + + docstring = inspect.cleandoc(docstring) + return docstring + @classmethod def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, NodeRunEndT]: type_hints = get_type_hints(cls.run, localns=local_ns, include_extras=True) @@ -47,7 +64,7 @@ def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, NodeRu end_edge: Edge | None = None returns_base_node: bool = False for return_type in _utils.get_union_args(return_hint): - return_type, annotations = _utils.strip_annotated(return_type) + return_type, annotations = _utils.unpack_annotated(return_type) edge = next((a for a in annotations if isinstance(a, Edge)), Edge(None)) return_type_origin = get_origin(return_type) or return_type if return_type_origin is End: @@ -60,18 +77,10 @@ def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, NodeRu else: raise TypeError(f'Invalid return type: {return_type}') - docstring = cls.__doc__ - # dataclasses get an automatic docstring which is just their signature, we don't want that - if docstring and is_dataclass(cls) and docstring.startswith(f'{cls.__name__}('): - docstring = None - if docstring: - # remove indentation from docstring - docstring = inspect.cleandoc(docstring) - return NodeDef( cls, cls.get_id(), - docstring, + cls.get_note(), next_node_edges, end_edge, returns_base_node, @@ -104,8 +113,8 @@ class NodeDef(Generic[StateT, NodeRunEndT]): """The node definition itself.""" node_id: str """ID of the node.""" - doc_string: str | None - """Docstring of the node.""" + note: 
str | None + """Note about the node to render on mermaid charts.""" next_node_edges: dict[str, Edge] """IDs of the nodes that can be called next.""" end_edge: Edge | None From 4f69960044de1aaa9444c3331a0bc1093e57d110 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Fri, 10 Jan 2025 20:02:03 +0000 Subject: [PATCH 28/57] adding graph tests --- pydantic_graph/pydantic_graph/__init__.py | 3 + pydantic_graph/pydantic_graph/exceptions.py | 20 +++ pydantic_graph/pydantic_graph/graph.py | 44 ++++-- pydantic_graph/pydantic_graph/mermaid.py | 2 +- tests/test_graph.py | 113 +++++++++++++- tests/test_graph_mermaid.py | 165 ++++++++++++++++++++ 6 files changed, 326 insertions(+), 21 deletions(-) create mode 100644 pydantic_graph/pydantic_graph/exceptions.py create mode 100644 tests/test_graph_mermaid.py diff --git a/pydantic_graph/pydantic_graph/__init__.py b/pydantic_graph/pydantic_graph/__init__.py index 9284c250f8..192b62f98c 100644 --- a/pydantic_graph/pydantic_graph/__init__.py +++ b/pydantic_graph/pydantic_graph/__init__.py @@ -1,3 +1,4 @@ +from .exceptions import GraphRuntimeError, GraphSetupError from .graph import Graph from .nodes import BaseNode, Edge, End, GraphContext from .state import AbstractState, EndEvent, HistoryStep, NodeEvent @@ -12,4 +13,6 @@ 'EndEvent', 'HistoryStep', 'NodeEvent', + 'GraphSetupError', + 'GraphRuntimeError', ) diff --git a/pydantic_graph/pydantic_graph/exceptions.py b/pydantic_graph/pydantic_graph/exceptions.py new file mode 100644 index 0000000000..5288402c36 --- /dev/null +++ b/pydantic_graph/pydantic_graph/exceptions.py @@ -0,0 +1,20 @@ +class GraphSetupError(TypeError): + """Error caused by an incorrectly configured graph.""" + + message: str + """Description of the mistake.""" + + def __init__(self, message: str): + self.message = message + super().__init__(message) + + +class GraphRuntimeError(RuntimeError): + """Error caused by an issue during graph execution.""" + + message: str + """The error message.""" + + def __init__(self, message: str): + self.message = message + super().__init__(message) diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 82ac4d2576..2141333623 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -12,6 +12,7 @@ from . 
import _utils, mermaid from ._utils import get_parent_namespace +from .exceptions import GraphRuntimeError, GraphSetupError from .nodes import BaseNode, End, GraphContext, NodeDef from .state import EndEvent, HistoryStep, NodeEvent, StateT @@ -44,42 +45,42 @@ def __init__( _nodes_by_id: dict[str, type[BaseNode[StateT, RunEndT]]] = {} for node in nodes: node_id = node.get_id() - if (existing_node := _nodes_by_id.get(node_id)) and existing_node is not node: - raise ValueError(f'Node ID "{node_id}" is not unique — found in {existing_node} and {node}') + if existing_node := _nodes_by_id.get(node_id): + raise GraphSetupError(f'Node ID `{node_id}` is not unique — found in {existing_node} and {node}') else: _nodes_by_id[node_id] = node self.nodes = tuple(_nodes_by_id.values()) parent_namespace = get_parent_namespace(inspect.currentframe()) - self.node_defs: dict[str, NodeDef[StateT, RunEndT]] = {} - for node in self.nodes: - self.node_defs[node.get_id()] = node.get_node_def(parent_namespace) + self.node_defs: dict[str, NodeDef[StateT, RunEndT]] = { + node.get_id(): node.get_node_def(parent_namespace) for node in self.nodes + } self._validate_edges() def _validate_edges(self): - known_node_ids = set(self.node_defs.keys()) + known_node_ids = self.node_defs.keys() bad_edges: dict[str, list[str]] = {} for node_id, node_def in self.node_defs.items(): - node_bad_edges = node_def.next_node_edges.keys() - known_node_ids - for bad_edge in node_bad_edges: - bad_edges.setdefault(bad_edge, []).append(f'"{node_id}"') + for edge in node_def.next_node_edges.keys(): + if edge not in known_node_ids: + bad_edges.setdefault(edge, []).append(f'`{node_id}`') if bad_edges: - bad_edges_list = [f'"{k}" is referenced by {_utils.comma_and(v)}' for k, v in bad_edges.items()] + bad_edges_list = [f'`{k}` is referenced by {_utils.comma_and(v)}' for k, v in bad_edges.items()] if len(bad_edges_list) == 1: - raise ValueError(f'{bad_edges_list[0]} but not included in the graph.') + raise GraphSetupError(f'{bad_edges_list[0]} but not included in the graph.') else: b = '\n'.join(f' {be}' for be in bad_edges_list) - raise ValueError(f'Nodes are referenced in the graph but not included in the graph:\n{b}') + raise GraphSetupError(f'Nodes are referenced in the graph but not included in the graph:\n{b}') async def next( self, state: StateT, node: BaseNode[StateT, RunEndT], history: list[HistoryStep[StateT, RunEndT]] ) -> BaseNode[StateT, Any] | End[RunEndT]: node_id = node.get_id() if node_id not in self.node_defs: - raise TypeError(f'Node "{node}" is not in the graph.') + raise GraphRuntimeError(f'Node `{node}` is not in the graph.') history_step: NodeEvent[StateT, RunEndT] | None = NodeEvent(state, node) history.append(history_step) @@ -95,7 +96,7 @@ async def run( self, state: StateT, node: BaseNode[StateT, RunEndT], - ) -> tuple[End[RunEndT], list[HistoryStep[StateT, RunEndT]]]: + ) -> tuple[RunEndT, list[HistoryStep[StateT, RunEndT]]]: history: list[HistoryStep[StateT, RunEndT]] = [] with _logfire.span( @@ -108,14 +109,16 @@ async def run( if isinstance(next_node, End): history.append(EndEvent(state, next_node)) run_span.set_attribute('history', history) - return next_node, history + return next_node.data, history elif isinstance(next_node, BaseNode): node = next_node else: if TYPE_CHECKING: assert_never(next_node) else: - raise TypeError(f'Invalid node type: {type(next_node)}. Expected `BaseNode` or `End`.') + raise GraphRuntimeError( + f'Invalid node return type: `{type(next_node).__name__}`. Expected `BaseNode` or `End`.' 
+ ) def mermaid_code( self, @@ -123,9 +126,16 @@ def mermaid_code( start_node: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None, highlighted_nodes: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None, highlight_css: str = mermaid.DEFAULT_HIGHLIGHT_CSS, + edge_labels: bool = True, + notes: bool = True, ) -> str: return mermaid.generate_code( - self, start_node=start_node, highlighted_nodes=highlighted_nodes, highlight_css=highlight_css + self, + start_node=start_node, + highlighted_nodes=highlighted_nodes, + highlight_css=highlight_css, + edge_labels=edge_labels, + notes=notes, ) def mermaid_image(self, **kwargs: Unpack[mermaid.MermaidConfig]) -> bytes: diff --git a/pydantic_graph/pydantic_graph/mermaid.py b/pydantic_graph/pydantic_graph/mermaid.py index e3e03888d7..192d9f1e25 100644 --- a/pydantic_graph/pydantic_graph/mermaid.py +++ b/pydantic_graph/pydantic_graph/mermaid.py @@ -59,7 +59,7 @@ def generate_code( if node_id in start_node_ids: lines.append(f' [*] --> {node_id}') if node_def.returns_base_node: - for next_node_id in graph.nodes: + for next_node_id in graph.node_defs: lines.append(f' {node_id} --> {next_node_id}') else: for next_node_id, edge in node_def.next_node_edges.items(): diff --git a/tests/test_graph.py b/tests/test_graph.py index 7c3eba97aa..801bf1d876 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -2,12 +2,14 @@ from dataclasses import dataclass from datetime import timezone +from functools import cache from typing import Union import pytest +from dirty_equals import IsStr from inline_snapshot import snapshot -from pydantic_graph import BaseNode, End, EndEvent, Graph, GraphContext, NodeEvent +from pydantic_graph import BaseNode, End, EndEvent, Graph, GraphContext, GraphRuntimeError, GraphSetupError, NodeEvent from .conftest import IsFloat, IsNow @@ -42,7 +44,7 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq g = Graph[None, int](nodes=(Float2String, String2Length, Double)) result, history = await g.run(None, Float2String(3.14)) # len('3.14') * 2 == 8 - assert result == End(8) + assert result == 8 assert history == snapshot( [ NodeEvent( @@ -72,7 +74,7 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq ) result, history = await g.run(None, Float2String(3.14159)) # len('3.14159') == 7, 21 * 2 == 42 - assert result == End(42) + assert result == 42 assert history == snapshot( [ NodeEvent( @@ -112,3 +114,108 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq ), ] ) + + +def test_one_bad_node(): + class Float2String(BaseNode): + async def run(self, ctx: GraphContext) -> String2Length: + return String2Length() + + class String2Length(BaseNode[None, None]): + async def run(self, ctx: GraphContext) -> End[None]: + return End(None) + + with pytest.raises(GraphSetupError) as exc_info: + Graph(nodes=(Float2String,)) + + assert exc_info.value.message == snapshot( + '`String2Length` is referenced by `Float2String` but not included in the graph.' 
+ ) + + +def test_two_bad_nodes(): + class Float2String(BaseNode): + input_data: float + + async def run(self, ctx: GraphContext) -> String2Length | Double: + raise NotImplementedError() + + class String2Length(BaseNode[None, None]): + input_data: str + + async def run(self, ctx: GraphContext) -> End[None]: + return End(None) + + class Double(BaseNode[None, None]): + async def run(self, ctx: GraphContext) -> End[None]: + return End(None) + + with pytest.raises(GraphSetupError) as exc_info: + Graph(nodes=(Float2String,)) + + assert exc_info.value.message == snapshot("""\ +Nodes are referenced in the graph but not included in the graph: + `String2Length` is referenced by `Float2String` + `Double` is referenced by `Float2String`\ +""") + + +def test_duplicate_id(): + class Foo(BaseNode): + async def run(self, ctx: GraphContext) -> Bar: + return Bar() + + class Bar(BaseNode[None, None]): + async def run(self, ctx: GraphContext) -> End[None]: + return End(None) + + @classmethod + @cache + def get_id(cls) -> str: + return 'Foo' + + with pytest.raises(GraphSetupError) as exc_info: + Graph(nodes=(Foo, Bar)) + + assert exc_info.value.message == snapshot(IsStr(regex='Node ID `Foo` is not unique — found in.+')) + + +async def test_run_node_not_in_graph(): + @dataclass + class Foo(BaseNode): + async def run(self, ctx: GraphContext) -> Bar: + return Bar() + + @dataclass + class Bar(BaseNode[None, None]): + async def run(self, ctx: GraphContext) -> End[None]: + return Spam() # type: ignore + + @dataclass + class Spam(BaseNode[None, None]): + async def run(self, ctx: GraphContext) -> End[None]: + return End(None) + + g = Graph(nodes=(Foo, Bar)) + with pytest.raises(GraphRuntimeError) as exc_info: + await g.run(None, Foo()) + + assert exc_info.value.message == snapshot('Node `test_run_node_not_in_graph.<locals>.Spam()` is not in the graph.') + + +async def test_run_return_other(): + @dataclass + class Foo(BaseNode): + async def run(self, ctx: GraphContext) -> Bar: + return Bar() + + @dataclass + class Bar(BaseNode[None, None]): + async def run(self, ctx: GraphContext) -> End[None]: + return 42 # type: ignore + + g = Graph(nodes=(Foo, Bar)) + with pytest.raises(GraphRuntimeError) as exc_info: + await g.run(None, Foo()) + + assert exc_info.value.message == snapshot('Invalid node return type: `int`.
Expected `BaseNode` or `End`.') diff --git a/tests/test_graph_mermaid.py b/tests/test_graph_mermaid.py new file mode 100644 index 0000000000..bb3931fcfa --- /dev/null +++ b/tests/test_graph_mermaid.py @@ -0,0 +1,165 @@ +from __future__ import annotations as _annotations + +from dataclasses import dataclass +from datetime import timezone +from typing import Annotated + +import pytest +from inline_snapshot import snapshot + +from pydantic_graph import BaseNode, Edge, End, EndEvent, Graph, GraphContext, NodeEvent + +from .conftest import IsFloat, IsNow + +pytestmark = pytest.mark.anyio + + +@dataclass +class Foo(BaseNode): + async def run(self, ctx: GraphContext) -> Bar: + return Bar() + + +@dataclass +class Bar(BaseNode[None, None]): + async def run(self, ctx: GraphContext) -> End[None]: + return End(None) + + +graph1 = Graph(nodes=(Foo, Bar)) + + +@dataclass +class Spam(BaseNode): + """This is the docstring for Spam.""" + + async def run(self, ctx: GraphContext) -> Annotated[Foo, Edge(label='spam to foo')]: + return Foo() + + +@dataclass +class Eggs(BaseNode[None, None]): + async def run(self, ctx: GraphContext) -> Annotated[End[None], Edge(label='eggs to end')]: + return End(None) + + +graph2 = Graph(nodes=(Spam, Foo, Bar, Eggs)) + + +async def test_run_graph(): + result, history = await graph1.run(None, Foo()) + assert result is None + assert history == snapshot( + [ + NodeEvent( + state=None, + node=Foo(), + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(), + ), + NodeEvent( + state=None, + node=Bar(), + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(), + ), + EndEvent(state=None, result=End(data=None), ts=IsNow(tz=timezone.utc)), + ] + ) + + +def test_mermaid_code_no_start(): + assert graph1.mermaid_code() == snapshot("""\ +stateDiagram-v2 + Foo --> Bar + Bar --> [*]\ +""") + + +def test_mermaid_code_start(): + assert graph1.mermaid_code(start_node=Foo) == snapshot("""\ +stateDiagram-v2 + [*] --> Foo + Foo --> Bar + Bar --> [*]\ +""") + + +def test_mermaid_code_start_wrong(): + with pytest.raises(LookupError): + graph1.mermaid_code(start_node=Spam) + + +def test_mermaid_highlight(): + code = graph1.mermaid_code(highlighted_nodes=Foo) + assert code == snapshot("""\ +stateDiagram-v2 + Foo --> Bar + Bar --> [*] + +classDef highlighted fill:#fdff32 +class Foo highlighted\ +""") + assert code == graph1.mermaid_code(highlighted_nodes='Foo') + + +def test_mermaid_highlight_multiple(): + code = graph1.mermaid_code(highlighted_nodes=(Foo, Bar)) + assert code == snapshot("""\ +stateDiagram-v2 + Foo --> Bar + Bar --> [*] + +classDef highlighted fill:#fdff32 +class Foo highlighted +class Bar highlighted\ +""") + assert code == graph1.mermaid_code(highlighted_nodes=('Foo', 'Bar')) + + +def test_mermaid_highlight_wrong(): + with pytest.raises(LookupError): + graph1.mermaid_code(highlighted_nodes=Spam) + + +def test_mermaid_code_with_edge_labels(): + assert graph2.mermaid_code() == snapshot("""\ +stateDiagram-v2 + Spam --> Foo: spam to foo + note right of Spam + This is the docstring for Spam. 
+ end note + Foo --> Bar + Bar --> [*] + Eggs --> [*]: eggs to end\ +""") + + +def test_mermaid_code_without_edge_labels(): + assert graph2.mermaid_code(edge_labels=False, notes=False) == snapshot("""\ +stateDiagram-v2 + Spam --> Foo + Foo --> Bar + Bar --> [*] + Eggs --> [*]\ +""") + + +@dataclass +class AllNodes(BaseNode): + async def run(self, ctx: GraphContext) -> BaseNode: + return Foo() + + +graph3 = Graph(nodes=(AllNodes, Foo, Bar)) + + +def test_mermaid_code_all_nodes(): + assert graph3.mermaid_code() == snapshot("""\ +stateDiagram-v2 + AllNodes --> AllNodes + AllNodes --> Foo + AllNodes --> Bar + Foo --> Bar + Bar --> [*]\ +""") From 29f8a95918bac02624c77a01450a7977334610d0 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Fri, 10 Jan 2025 20:50:57 +0000 Subject: [PATCH 29/57] more mermaid tests, fix 3.9 --- pydantic_graph/pydantic_graph/mermaid.py | 13 ++++-- tests/{test_graph.py => graph/test_main.py} | 4 +- .../test_mermaid.py} | 43 ++++++++++++++++++- 3 files changed, 52 insertions(+), 8 deletions(-) rename tests/{test_graph.py => graph/test_main.py} (98%) rename tests/{test_graph_mermaid.py => graph/test_mermaid.py} (68%) diff --git a/pydantic_graph/pydantic_graph/mermaid.py b/pydantic_graph/pydantic_graph/mermaid.py index 192d9f1e25..00d2ccb744 100644 --- a/pydantic_graph/pydantic_graph/mermaid.py +++ b/pydantic_graph/pydantic_graph/mermaid.py @@ -7,6 +7,7 @@ from textwrap import indent from typing import TYPE_CHECKING, Annotated, Any, Literal +import httpx from annotated_types import Ge, Le from typing_extensions import TypeAlias, TypedDict, Unpack @@ -149,6 +150,7 @@ class MermaidConfig(TypedDict, total=False): The scale must be a number between 1 and 3, and you can only set a scale if one or both of width and height are set. """ + httpx_client: httpx.Client def request_image( @@ -164,8 +166,6 @@ def request_image( Returns: The image data. 
""" - import httpx - code = generate_code( graph, start_node=kwargs.get('start_node'), @@ -204,9 +204,14 @@ def request_image( if scale := kwargs.get('scale'): params['scale'] = str(scale) - response = httpx.get(url, params=params) + httpx_client = kwargs.get('httpx_client') or httpx.Client() + response = httpx_client.get(url, params=params) if not response.is_success: - raise ValueError(f'{response.status_code} error generating image:\n{response.text}') + raise httpx.HTTPStatusError( + f'{response.status_code} error generating image:\n{response.text}', + request=response.request, + response=response, + ) return response.content diff --git a/tests/test_graph.py b/tests/graph/test_main.py similarity index 98% rename from tests/test_graph.py rename to tests/graph/test_main.py index 801bf1d876..717eaefb7b 100644 --- a/tests/test_graph.py +++ b/tests/graph/test_main.py @@ -11,7 +11,7 @@ from pydantic_graph import BaseNode, End, EndEvent, Graph, GraphContext, GraphRuntimeError, GraphSetupError, NodeEvent -from .conftest import IsFloat, IsNow +from ..conftest import IsFloat, IsNow pytestmark = pytest.mark.anyio @@ -137,7 +137,7 @@ def test_two_bad_nodes(): class Float2String(BaseNode): input_data: float - async def run(self, ctx: GraphContext) -> String2Length | Double: + async def run(self, ctx: GraphContext) -> Union[String2Length, Double]: # noqa: UP007 raise NotImplementedError() class String2Length(BaseNode[None, None]): diff --git a/tests/test_graph_mermaid.py b/tests/graph/test_mermaid.py similarity index 68% rename from tests/test_graph_mermaid.py rename to tests/graph/test_mermaid.py index bb3931fcfa..40a461c96a 100644 --- a/tests/test_graph_mermaid.py +++ b/tests/graph/test_mermaid.py @@ -1,17 +1,42 @@ from __future__ import annotations as _annotations +from collections.abc import Iterator from dataclasses import dataclass from datetime import timezone -from typing import Annotated +from typing import Annotated, Callable +import httpx import pytest from inline_snapshot import snapshot +from typing_extensions import TypeAlias from pydantic_graph import BaseNode, Edge, End, EndEvent, Graph, GraphContext, NodeEvent -from .conftest import IsFloat, IsNow +from ..conftest import IsFloat, IsNow pytestmark = pytest.mark.anyio +HttpxWithHandler: TypeAlias = 'Callable[[Callable[[httpx.Request], httpx.Response] | httpx.Response], httpx.Client]' + + +@pytest.fixture +def httpx_with_handler() -> Iterator[HttpxWithHandler]: + client: httpx.Client | None = None + + def create_client(handler: Callable[[httpx.Request], httpx.Response] | httpx.Response) -> httpx.Client: + nonlocal client + assert client is None, 'client_with_handler can only be called once' + if isinstance(handler, httpx.Response): + transport = httpx.MockTransport(lambda _: handler) + else: + transport = httpx.MockTransport(handler) + client = httpx.Client(mounts={'all://': transport}) + return client + + try: + yield create_client + finally: + if client: # pragma: no cover + client.close() @dataclass @@ -163,3 +188,17 @@ def test_mermaid_code_all_nodes(): Foo --> Bar Bar --> [*]\ """) + + +def test_image(httpx_with_handler: HttpxWithHandler): + response = httpx.Response(200, content=b'fake image') + img = graph1.mermaid_image(start_node=Foo(), httpx_client=httpx_with_handler(response)) + assert img == b'fake image' + + +def test_image_bad(httpx_with_handler: HttpxWithHandler): + response = httpx.Response(404, content=b'not found') + with pytest.raises(httpx.HTTPStatusError, match='404 error generating image:\nnot found') as exc_info: 
+ graph1.mermaid_image(start_node=Foo(), httpx_client=httpx_with_handler(response)) + assert exc_info.value.response.status_code == 404 + assert exc_info.value.response.content == b'not found' From 02c4dc087a49d6960914d5ac975a8fdc63baba14 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Fri, 10 Jan 2025 20:51:28 +0000 Subject: [PATCH 30/57] rename node to start_node in graph.run() --- pydantic_graph/pydantic_graph/graph.py | 8 ++++---- tests/graph/__init__.py | 0 2 files changed, 4 insertions(+), 4 deletions(-) create mode 100644 tests/graph/__init__.py diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 2141333623..4e87142865 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -95,23 +95,23 @@ async def next( async def run( self, state: StateT, - node: BaseNode[StateT, RunEndT], + start_node: BaseNode[StateT, RunEndT], ) -> tuple[RunEndT, list[HistoryStep[StateT, RunEndT]]]: history: list[HistoryStep[StateT, RunEndT]] = [] with _logfire.span( '{graph_name} run {start=}', graph_name=self.name or 'graph', - start=node, + start=start_node, ) as run_span: while True: - next_node = await self.next(state, node, history=history) + next_node = await self.next(state, start_node, history=history) if isinstance(next_node, End): history.append(EndEvent(state, next_node)) run_span.set_attribute('history', history) return next_node.data, history elif isinstance(next_node, BaseNode): - node = next_node + start_node = next_node else: if TYPE_CHECKING: assert_never(next_node) diff --git a/tests/graph/__init__.py b/tests/graph/__init__.py new file mode 100644 index 0000000000..e69de29bb2 From fc7dfc6c65f817d85f9dfb284b3706436e4f187a Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Fri, 10 Jan 2025 22:01:24 +0000 Subject: [PATCH 31/57] more tests for graphs --- pydantic_graph/pydantic_graph/_utils.py | 17 +-- pydantic_graph/pydantic_graph/graph.py | 17 +-- pydantic_graph/pydantic_graph/mermaid.py | 14 +-- pydantic_graph/pydantic_graph/nodes.py | 8 +- pyproject.toml | 1 + tests/graph/test_main.py | 53 ++++++-- tests/graph/test_mermaid.py | 151 ++++++++++++++++++++++- 7 files changed, 218 insertions(+), 43 deletions(-) diff --git a/pydantic_graph/pydantic_graph/_utils.py b/pydantic_graph/pydantic_graph/_utils.py index 498c5bdde6..81123afde6 100644 --- a/pydantic_graph/pydantic_graph/_utils.py +++ b/pydantic_graph/pydantic_graph/_utils.py @@ -5,13 +5,13 @@ from datetime import datetime, timezone from typing import Annotated, Any, Union, get_args, get_origin -from typing_extensions import TypeAliasType +import typing_extensions def get_union_args(tp: Any) -> tuple[Any, ...]: """Extract the arguments of a Union type if `response_type` is a union, otherwise return the original type.""" # similar to `pydantic_ai_slim/pydantic_ai/_result.py:get_union_args` - if isinstance(tp, TypeAliasType): + if isinstance(tp, typing_extensions.TypeAliasType): tp = tp.__value__ origin = get_origin(tp) @@ -26,7 +26,8 @@ def unpack_annotated(tp: Any) -> tuple[Any, list[Any]]: Returns: `(tp argument, ())` if not annotated, otherwise `(stripped type, annotations)`. 
""" - if get_origin(tp) is Annotated: + origin = get_origin(tp) + if origin is Annotated or origin is typing_extensions.Annotated: inner_tp, *args = get_args(tp) return inner_tp, args else: @@ -54,16 +55,6 @@ def comma_and(items: list[str]) -> str: return ', '.join(items[:-1]) + ', and ' + items[-1] -_NoneType = type(None) - - -def type_arg_name(arg: Any) -> str: - if arg is _NoneType: - return 'None' - else: - return arg.__name__ - - def get_parent_namespace(frame: types.FrameType | None) -> dict[str, Any] | None: """Attempt to get the namespace where the graph was defined. diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 4e87142865..83f5c92845 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -10,9 +10,8 @@ import logfire_api from typing_extensions import Never, ParamSpec, TypeVar, Unpack, assert_never -from . import _utils, mermaid +from . import _utils, exceptions, mermaid from ._utils import get_parent_namespace -from .exceptions import GraphRuntimeError, GraphSetupError from .nodes import BaseNode, End, GraphContext, NodeDef from .state import EndEvent, HistoryStep, NodeEvent, StateT @@ -46,7 +45,9 @@ def __init__( for node in nodes: node_id = node.get_id() if existing_node := _nodes_by_id.get(node_id): - raise GraphSetupError(f'Node ID `{node_id}` is not unique — found in {existing_node} and {node}') + raise exceptions.GraphSetupError( + f'Node ID `{node_id}` is not unique — found in {existing_node} and {node}' + ) else: _nodes_by_id[node_id] = node self.nodes = tuple(_nodes_by_id.values()) @@ -70,17 +71,19 @@ def _validate_edges(self): if bad_edges: bad_edges_list = [f'`{k}` is referenced by {_utils.comma_and(v)}' for k, v in bad_edges.items()] if len(bad_edges_list) == 1: - raise GraphSetupError(f'{bad_edges_list[0]} but not included in the graph.') + raise exceptions.GraphSetupError(f'{bad_edges_list[0]} but not included in the graph.') else: b = '\n'.join(f' {be}' for be in bad_edges_list) - raise GraphSetupError(f'Nodes are referenced in the graph but not included in the graph:\n{b}') + raise exceptions.GraphSetupError( + f'Nodes are referenced in the graph but not included in the graph:\n{b}' + ) async def next( self, state: StateT, node: BaseNode[StateT, RunEndT], history: list[HistoryStep[StateT, RunEndT]] ) -> BaseNode[StateT, Any] | End[RunEndT]: node_id = node.get_id() if node_id not in self.node_defs: - raise GraphRuntimeError(f'Node `{node}` is not in the graph.') + raise exceptions.GraphRuntimeError(f'Node `{node}` is not in the graph.') history_step: NodeEvent[StateT, RunEndT] | None = NodeEvent(state, node) history.append(history_step) @@ -116,7 +119,7 @@ async def run( if TYPE_CHECKING: assert_never(next_node) else: - raise GraphRuntimeError( + raise exceptions.GraphRuntimeError( f'Invalid node return type: `{type(next_node).__name__}`. Expected `BaseNode` or `End`.' 
) diff --git a/pydantic_graph/pydantic_graph/mermaid.py b/pydantic_graph/pydantic_graph/mermaid.py index 00d2ccb744..b21b50509d 100644 --- a/pydantic_graph/pydantic_graph/mermaid.py +++ b/pydantic_graph/pydantic_graph/mermaid.py @@ -120,7 +120,7 @@ class MermaidConfig(TypedDict, total=False): edge_labels: bool """Whether to include edge labels in the diagram.""" notes: bool - """Whether to include docstrings as notes in the diagram, defaults to true.""" + """Whether to include notes on nodes in the diagram, defaults to true.""" image_type: Literal['jpeg', 'png', 'webp', 'svg', 'pdf'] """The image type to generate. If unspecified, the default behavior is `'jpeg'`.""" pdf_fit: bool @@ -176,13 +176,13 @@ def request_image( ) code_base64 = base64.b64encode(code.encode()).decode() - params: dict[str, str | bool] = {} + params: dict[str, str | float] = {} if kwargs.get('image_type') == 'pdf': url = f'https://mermaid.ink/pdf/{code_base64}' if kwargs.get('pdf_fit'): - params['fit'] = True + params['fit'] = '' if kwargs.get('pdf_landscape'): - params['landscape'] = True + params['landscape'] = '' if pdf_paper := kwargs.get('pdf_paper'): params['paper'] = pdf_paper elif kwargs.get('image_type') == 'svg': @@ -198,11 +198,11 @@ def request_image( if theme := kwargs.get('theme'): params['theme'] = theme if width := kwargs.get('width'): - params['width'] = str(width) + params['width'] = width if height := kwargs.get('height'): - params['height'] = str(height) + params['height'] = height if scale := kwargs.get('scale'): - params['scale'] = str(scale) + params['scale'] = scale httpx_client = kwargs.get('httpx_client') or httpx.Client() response = httpx_client.get(url, params=params) diff --git a/pydantic_graph/pydantic_graph/nodes.py b/pydantic_graph/pydantic_graph/nodes.py index 19b42cb7ae..5fd5755881 100644 --- a/pydantic_graph/pydantic_graph/nodes.py +++ b/pydantic_graph/pydantic_graph/nodes.py @@ -7,7 +7,7 @@ from typing_extensions import Never, TypeVar -from . import _utils +from . 
import _utils, exceptions from .state import StateT __all__ = 'GraphContext', 'BaseNode', 'End', 'Edge', 'NodeDef' @@ -57,8 +57,8 @@ def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, NodeRu type_hints = get_type_hints(cls.run, localns=local_ns, include_extras=True) try: return_hint = type_hints['return'] - except KeyError: - raise TypeError(f'Node {cls} is missing a return type hint on its `run` method') + except KeyError as e: + raise exceptions.GraphSetupError(f'Node {cls} is missing a return type hint on its `run` method') from e next_node_edges: dict[str, Edge] = {} end_edge: Edge | None = None @@ -75,7 +75,7 @@ def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, NodeRu elif issubclass(return_type_origin, BaseNode): next_node_edges[return_type.get_id()] = edge else: - raise TypeError(f'Invalid return type: {return_type}') + raise exceptions.GraphSetupError(f'Invalid return type: {return_type}') return NodeDef( cls, diff --git a/pyproject.toml b/pyproject.toml index 746b8ee7ce..cfaf61880b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -167,6 +167,7 @@ exclude_lines = [ 'if typing.TYPE_CHECKING:', '@overload', '@typing.overload', + '@abstractmethod', '\(Protocol\):$', 'typing.assert_never', '$\s*assert_never\(', diff --git a/tests/graph/test_main.py b/tests/graph/test_main.py index 717eaefb7b..5b1d61164f 100644 --- a/tests/graph/test_main.py +++ b/tests/graph/test_main.py @@ -114,6 +114,16 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq ), ] ) + assert [e.summary() for e in history] == snapshot( + [ + 'test_graph..Float2String(input_data=3.14159)', + "test_graph..String2Length(input_data='3.14159')", + 'test_graph..Double(input_data=7)', + "test_graph..String2Length(input_data='xxxxxxxxxxxxxxxxxxxxx')", + 'test_graph..Double(input_data=21)', + 'End(data=42)', + ] + ) def test_one_bad_node(): @@ -134,32 +144,61 @@ async def run(self, ctx: GraphContext) -> End[None]: def test_two_bad_nodes(): - class Float2String(BaseNode): + class Foo(BaseNode): input_data: float - async def run(self, ctx: GraphContext) -> Union[String2Length, Double]: # noqa: UP007 + async def run(self, ctx: GraphContext) -> Union[Bar, Spam]: # noqa: UP007 raise NotImplementedError() - class String2Length(BaseNode[None, None]): + class Bar(BaseNode[None, None]): input_data: str async def run(self, ctx: GraphContext) -> End[None]: return End(None) - class Double(BaseNode[None, None]): + class Spam(BaseNode[None, None]): async def run(self, ctx: GraphContext) -> End[None]: return End(None) with pytest.raises(GraphSetupError) as exc_info: - Graph(nodes=(Float2String,)) + Graph(nodes=(Foo,)) assert exc_info.value.message == snapshot("""\ Nodes are referenced in the graph but not included in the graph: - `String2Length` is referenced by `Float2String` - `Double` is referenced by `Float2String`\ + `Bar` is referenced by `Foo` + `Spam` is referenced by `Foo`\ """) +def test_three_bad_nodes_separate(): + class Foo(BaseNode): + input_data: float + + async def run(self, ctx: GraphContext) -> Eggs: + raise NotImplementedError() + + class Bar(BaseNode[None, None]): + input_data: str + + async def run(self, ctx: GraphContext) -> Eggs: + raise NotImplementedError() + + class Spam(BaseNode[None, None]): + async def run(self, ctx: GraphContext) -> Eggs: + raise NotImplementedError() + + class Eggs(BaseNode[None, None]): + async def run(self, ctx: GraphContext) -> End[None]: + return End(None) + + with pytest.raises(GraphSetupError) as exc_info: + 
Graph(nodes=(Foo, Bar, Spam)) + + assert exc_info.value.message == snapshot( + '`Eggs` is referenced by `Foo`, `Bar`, and `Spam` but not included in the graph.' + ) + + def test_duplicate_id(): class Foo(BaseNode): async def run(self, ctx: GraphContext) -> Bar: diff --git a/tests/graph/test_mermaid.py b/tests/graph/test_mermaid.py index 40a461c96a..3546296c42 100644 --- a/tests/graph/test_mermaid.py +++ b/tests/graph/test_mermaid.py @@ -3,6 +3,7 @@ from collections.abc import Iterator from dataclasses import dataclass from datetime import timezone +from pathlib import Path from typing import Annotated, Callable import httpx @@ -10,7 +11,8 @@ from inline_snapshot import snapshot from typing_extensions import TypeAlias -from pydantic_graph import BaseNode, Edge, End, EndEvent, Graph, GraphContext, NodeEvent +from pydantic_graph import BaseNode, Edge, End, EndEvent, Graph, GraphContext, GraphSetupError, NodeEvent +from pydantic_graph.nodes import NodeDef from ..conftest import IsFloat, IsNow @@ -64,6 +66,10 @@ async def run(self, ctx: GraphContext) -> Annotated[Foo, Edge(label='spam to foo @dataclass class Eggs(BaseNode[None, None]): + """This is the docstring for Eggs.""" + + enable_docstring_notes = False + async def run(self, ctx: GraphContext) -> Annotated[End[None], Edge(label='eggs to end')]: return End(None) @@ -190,10 +196,42 @@ def test_mermaid_code_all_nodes(): """) -def test_image(httpx_with_handler: HttpxWithHandler): - response = httpx.Response(200, content=b'fake image') - img = graph1.mermaid_image(start_node=Foo(), httpx_client=httpx_with_handler(response)) - assert img == b'fake image' +def test_image_jpg(httpx_with_handler: HttpxWithHandler): + def get_jpg(request: httpx.Request) -> httpx.Response: + assert dict(request.url.params) == snapshot({}) + assert request.url.path.startswith('/img/') + return httpx.Response(200, content=b'fake jpg') + + img = graph1.mermaid_image(start_node=Foo(), httpx_client=httpx_with_handler(get_jpg)) + assert img == b'fake jpg' + + +def test_image_png(httpx_with_handler: HttpxWithHandler): + def get_png(request: httpx.Request) -> httpx.Response: + assert dict(request.url.params) == snapshot( + { + 'type': 'png', + 'bgColor': '123', + 'theme': 'forest', + 'width': '100', + 'height': '200', + 'scale': '3', + } + ) + assert request.url.path.startswith('/img/') + return httpx.Response(200, content=b'fake png') + + img = graph1.mermaid_image( + start_node=Foo(), + image_type='png', + background_color='123', + theme='forest', + width=100, + height=200, + scale=3, + httpx_client=httpx_with_handler(get_png), + ) + assert img == b'fake png' def test_image_bad(httpx_with_handler: HttpxWithHandler): @@ -202,3 +240,106 @@ def test_image_bad(httpx_with_handler: HttpxWithHandler): graph1.mermaid_image(start_node=Foo(), httpx_client=httpx_with_handler(response)) assert exc_info.value.response.status_code == 404 assert exc_info.value.response.content == b'not found' + + +def test_pdf(httpx_with_handler: HttpxWithHandler): + def get_pdf(request: httpx.Request) -> httpx.Response: + assert dict(request.url.params) == snapshot({}) + assert request.url.path.startswith('/pdf/') + return httpx.Response(200, content=b'fake pdf') + + pdf = graph1.mermaid_image(start_node=Foo(), image_type='pdf', httpx_client=httpx_with_handler(get_pdf)) + assert pdf == b'fake pdf' + + +def test_pdf_config(httpx_with_handler: HttpxWithHandler): + def get_pdf(request: httpx.Request) -> httpx.Response: + assert dict(request.url.params) == snapshot({'fit': '', 'landscape': '', 'paper': 
'letter'}) + assert request.url.path.startswith('/pdf/') + return httpx.Response(200, content=b'fake pdf') + + pdf = graph1.mermaid_image( + start_node=Foo(), + image_type='pdf', + pdf_fit=True, + pdf_landscape=True, + pdf_paper='letter', + httpx_client=httpx_with_handler(get_pdf), + ) + assert pdf == b'fake pdf' + + +def test_svg(httpx_with_handler: HttpxWithHandler): + def get_svg(request: httpx.Request) -> httpx.Response: + assert dict(request.url.params) == snapshot({}) + assert request.url.path.startswith('/svg/') + return httpx.Response(200, content=b'fake svg') + + svg = graph1.mermaid_image(start_node=Foo(), image_type='svg', httpx_client=httpx_with_handler(get_svg)) + assert svg == b'fake svg' + + +def test_save_jpg(tmp_path: Path, httpx_with_handler: HttpxWithHandler): + def get_jpg(request: httpx.Request) -> httpx.Response: + assert dict(request.url.params) == snapshot({}) + assert request.url.path.startswith('/img/') + return httpx.Response(200, content=b'fake jpg') + + path = tmp_path / 'graph.jpg' + graph1.mermaid_save(path, start_node=Foo(), httpx_client=httpx_with_handler(get_jpg)) + assert path.read_bytes() == b'fake jpg' + + +def test_save_png(tmp_path: Path, httpx_with_handler: HttpxWithHandler): + def get_png(request: httpx.Request) -> httpx.Response: + assert dict(request.url.params) == snapshot({'type': 'png'}) + assert request.url.path.startswith('/img/') + return httpx.Response(200, content=b'fake png') + + path2 = tmp_path / 'graph.png' + graph1.mermaid_save(str(path2), start_node=Foo(), httpx_client=httpx_with_handler(get_png)) + assert path2.read_bytes() == b'fake png' + + +def test_save_pdf_known(tmp_path: Path, httpx_with_handler: HttpxWithHandler): + def get_pdf(request: httpx.Request) -> httpx.Response: + assert dict(request.url.params) == snapshot({}) + assert request.url.path.startswith('/pdf/') + return httpx.Response(200, content=b'fake pdf') + + path2 = tmp_path / 'graph' + graph1.mermaid_save(str(path2), start_node=Foo(), image_type='pdf', httpx_client=httpx_with_handler(get_pdf)) + assert path2.read_bytes() == b'fake pdf' + + +def test_get_node_def(): + assert Foo.get_node_def({}) == snapshot( + NodeDef( + node=Foo, + node_id='Foo', + note=None, + next_node_edges={'Bar': Edge(label=None)}, + end_edge=None, + returns_base_node=False, + ) + ) + + +def test_no_return_type(): + @dataclass + class NoReturnType(BaseNode): + async def run(self, ctx: GraphContext): # type: ignore + raise NotImplementedError() + + with pytest.raises(GraphSetupError, match=r".*\.NoReturnType'> is missing a return type hint on its `run` method"): + NoReturnType.get_node_def({}) + + +def test_wrong_return_type(): + @dataclass + class NoReturnType(BaseNode): + async def run(self, ctx: GraphContext) -> int: # type: ignore + raise NotImplementedError() + + with pytest.raises(GraphSetupError, match="Invalid return type: "): + NoReturnType.get_node_def({}) From db9543eb97c4d768f8dd2021cd1cf9edd3372611 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Fri, 10 Jan 2025 22:08:10 +0000 Subject: [PATCH 32/57] coverage in tests --- tests/graph/test_main.py | 16 ++++++++-------- tests/graph/test_mermaid.py | 6 +++--- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/graph/test_main.py b/tests/graph/test_main.py index 5b1d61164f..43438ff7bc 100644 --- a/tests/graph/test_main.py +++ b/tests/graph/test_main.py @@ -129,11 +129,11 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq def test_one_bad_node(): class Float2String(BaseNode): async def 
run(self, ctx: GraphContext) -> String2Length: - return String2Length() + raise NotImplementedError() class String2Length(BaseNode[None, None]): async def run(self, ctx: GraphContext) -> End[None]: - return End(None) + raise NotImplementedError() with pytest.raises(GraphSetupError) as exc_info: Graph(nodes=(Float2String,)) @@ -154,11 +154,11 @@ class Bar(BaseNode[None, None]): input_data: str async def run(self, ctx: GraphContext) -> End[None]: - return End(None) + raise NotImplementedError() class Spam(BaseNode[None, None]): async def run(self, ctx: GraphContext) -> End[None]: - return End(None) + raise NotImplementedError() with pytest.raises(GraphSetupError) as exc_info: Graph(nodes=(Foo,)) @@ -189,7 +189,7 @@ async def run(self, ctx: GraphContext) -> Eggs: class Eggs(BaseNode[None, None]): async def run(self, ctx: GraphContext) -> End[None]: - return End(None) + raise NotImplementedError() with pytest.raises(GraphSetupError) as exc_info: Graph(nodes=(Foo, Bar, Spam)) @@ -202,11 +202,11 @@ async def run(self, ctx: GraphContext) -> End[None]: def test_duplicate_id(): class Foo(BaseNode): async def run(self, ctx: GraphContext) -> Bar: - return Bar() + raise NotImplementedError() class Bar(BaseNode[None, None]): async def run(self, ctx: GraphContext) -> End[None]: - return End(None) + raise NotImplementedError() @classmethod @cache @@ -233,7 +233,7 @@ async def run(self, ctx: GraphContext) -> End[None]: @dataclass class Spam(BaseNode[None, None]): async def run(self, ctx: GraphContext) -> End[None]: - return End(None) + raise NotImplementedError() g = Graph(nodes=(Foo, Bar)) with pytest.raises(GraphRuntimeError) as exc_info: diff --git a/tests/graph/test_mermaid.py b/tests/graph/test_mermaid.py index 3546296c42..422b7ecdee 100644 --- a/tests/graph/test_mermaid.py +++ b/tests/graph/test_mermaid.py @@ -61,7 +61,7 @@ class Spam(BaseNode): """This is the docstring for Spam.""" async def run(self, ctx: GraphContext) -> Annotated[Foo, Edge(label='spam to foo')]: - return Foo() + raise NotImplementedError() @dataclass @@ -71,7 +71,7 @@ class Eggs(BaseNode[None, None]): enable_docstring_notes = False async def run(self, ctx: GraphContext) -> Annotated[End[None], Edge(label='eggs to end')]: - return End(None) + raise NotImplementedError() graph2 = Graph(nodes=(Spam, Foo, Bar, Eggs)) @@ -179,7 +179,7 @@ def test_mermaid_code_without_edge_labels(): @dataclass class AllNodes(BaseNode): async def run(self, ctx: GraphContext) -> BaseNode: - return Foo() + raise NotImplementedError() graph3 = Graph(nodes=(AllNodes, Foo, Bar)) From e9d1d2bb23940c1e1d214de3b569effee0527d0b Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Fri, 10 Jan 2025 23:43:46 +0000 Subject: [PATCH 33/57] cleanup graph properties --- pydantic_graph/pydantic_graph/graph.py | 27 +++++++++++------------- pydantic_graph/pydantic_graph/mermaid.py | 5 +---- tests/graph/test_main.py | 2 +- 3 files changed, 14 insertions(+), 20 deletions(-) diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 83f5c92845..8ae518ac42 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -29,7 +29,6 @@ class Graph(Generic[StateT, RunEndT]): """Definition of a graph.""" name: str | None - nodes: tuple[type[BaseNode[StateT, RunEndT]], ...] 
node_defs: dict[str, NodeDef[StateT, RunEndT]] def __init__( @@ -41,24 +40,22 @@ def __init__( ): self.name = name - _nodes_by_id: dict[str, type[BaseNode[StateT, RunEndT]]] = {} - for node in nodes: - node_id = node.get_id() - if existing_node := _nodes_by_id.get(node_id): - raise exceptions.GraphSetupError( - f'Node ID `{node_id}` is not unique — found in {existing_node} and {node}' - ) - else: - _nodes_by_id[node_id] = node - self.nodes = tuple(_nodes_by_id.values()) - parent_namespace = get_parent_namespace(inspect.currentframe()) - self.node_defs: dict[str, NodeDef[StateT, RunEndT]] = { - node.get_id(): node.get_node_def(parent_namespace) for node in self.nodes - } + parent_namespace = get_parent_namespace(inspect.currentframe()) + self.node_defs: dict[str, NodeDef[StateT, RunEndT]] = {} + for node in nodes: + self._register_node(node, parent_namespace) self._validate_edges() + def _register_node(self, node: type[BaseNode[StateT, RunEndT]], parent_namespace: dict[str, Any] | None) -> None: + node_id = node.get_id() + if existing_node := self.node_defs.get(node_id): + raise exceptions.GraphSetupError( + f'Node ID `{node_id}` is not unique — found on {existing_node.node} and {node}' + ) + else: + self.node_defs[node_id] = node.get_node_def(parent_namespace) + def _validate_edges(self): known_node_ids = self.node_defs.keys() bad_edges: dict[str, list[str]] = {} diff --git a/pydantic_graph/pydantic_graph/mermaid.py b/pydantic_graph/pydantic_graph/mermaid.py index b21b50509d..0b498e685f 100644 --- a/pydantic_graph/pydantic_graph/mermaid.py +++ b/pydantic_graph/pydantic_graph/mermaid.py @@ -52,10 +52,7 @@ def generate_code( raise LookupError(f'Start node "{node_id}" is not in the graph.') lines = ['stateDiagram-v2'] - for node in graph.nodes: - node_id = node.get_id() - node_def = graph.node_defs[node_id] - + for node_id, node_def in graph.node_defs.items(): # we use round brackets (rounded box) for nodes other than the start and end if node_id in start_node_ids: lines.append(f' [*] --> {node_id}') diff --git a/tests/graph/test_main.py b/tests/graph/test_main.py index 43438ff7bc..7e076baf47 100644 --- a/tests/graph/test_main.py +++ b/tests/graph/test_main.py @@ -216,7 +216,7 @@ def get_id(cls) -> str: with pytest.raises(GraphSetupError) as exc_info: Graph(nodes=(Foo, Bar)) - assert exc_info.value.message == snapshot(IsStr(regex='Node ID `Foo` is not unique — found in.+')) + assert exc_info.value.message == snapshot(IsStr(regex='Node ID `Foo` is not unique — found on <class.+')) From: Samuel Colvin Date: Sat, 11 Jan 2025 00:39:03 +0000 Subject: [PATCH 34/57] infer graph name --- pydantic_graph/pydantic_graph/graph.py | 59 ++++++++++++-- pydantic_graph/pydantic_graph/mermaid.py | 12 ++- tests/graph/test_main.py | 53 +++++++++++-- tests/graph/test_mermaid.py | 99 +++++++++++++++--------- 4 files changed, 176 insertions(+), 47 deletions(-) diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 8ae518ac42..f7f45020ba 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -5,10 +5,11 @@ from dataclasses import dataclass from pathlib import Path from time import perf_counter +from types import FrameType from typing import TYPE_CHECKING, Any, Generic import logfire_api -from typing_extensions import Never, ParamSpec, TypeVar, Unpack, assert_never +from typing_extensions import Literal, Never, ParamSpec, TypeVar, Unpack, assert_never from .
import _utils, exceptions, mermaid from ._utils import get_parent_namespace @@ -76,8 +77,15 @@ def _validate_edges(self): ) async def next( - self, state: StateT, node: BaseNode[StateT, RunEndT], history: list[HistoryStep[StateT, RunEndT]] + self, + state: StateT, + node: BaseNode[StateT, RunEndT], + history: list[HistoryStep[StateT, RunEndT]], + *, + infer_name: bool = True, ) -> BaseNode[StateT, Any] | End[RunEndT]: + if infer_name and self.name is None: + self._infer_name(inspect.currentframe()) node_id = node.get_id() if node_id not in self.node_defs: raise exceptions.GraphRuntimeError(f'Node `{node}` is not in the graph.') @@ -96,8 +104,12 @@ async def run( self, state: StateT, start_node: BaseNode[StateT, RunEndT], + *, + infer_name: bool = True, ) -> tuple[RunEndT, list[HistoryStep[StateT, RunEndT]]]: history: list[HistoryStep[StateT, RunEndT]] = [] + if infer_name and self.name is None: + self._infer_name(inspect.currentframe()) with _logfire.span( '{graph_name} run {start=}', @@ -105,7 +117,7 @@ async def run( start=start_node, ) as run_span: while True: - next_node = await self.next(state, start_node, history=history) + next_node = await self.next(state, start_node, history, infer_name=False) if isinstance(next_node, End): history.append(EndEvent(state, next_node)) run_span.set_attribute('history', history) @@ -127,19 +139,56 @@ def mermaid_code( highlighted_nodes: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None, highlight_css: str = mermaid.DEFAULT_HIGHLIGHT_CSS, edge_labels: bool = True, + title: str | None | Literal[False] = None, notes: bool = True, + infer_name: bool = True, ) -> str: + if infer_name and self.name is None: + self._infer_name(inspect.currentframe()) + if title is None and self.name: + title = self.name return mermaid.generate_code( self, start_node=start_node, highlighted_nodes=highlighted_nodes, highlight_css=highlight_css, + title=title or None, edge_labels=edge_labels, notes=notes, ) - def mermaid_image(self, **kwargs: Unpack[mermaid.MermaidConfig]) -> bytes: + def mermaid_image(self, infer_name: bool = True, **kwargs: Unpack[mermaid.MermaidConfig]) -> bytes: + if infer_name and self.name is None: + self._infer_name(inspect.currentframe()) + if 'title' not in kwargs and self.name: + kwargs['title'] = self.name return mermaid.request_image(self, **kwargs) - def mermaid_save(self, path: Path | str, /, **kwargs: Unpack[mermaid.MermaidConfig]) -> None: + def mermaid_save( + self, path: Path | str, /, *, infer_name: bool = True, **kwargs: Unpack[mermaid.MermaidConfig] + ) -> None: + if infer_name and self.name is None: + self._infer_name(inspect.currentframe()) + if 'title' not in kwargs and self.name: + kwargs['title'] = self.name mermaid.save_image(path, self, **kwargs) + + def _infer_name(self, function_frame: FrameType | None) -> None: + """Infer the agent name from the call frame. + + Usage should be `self._infer_name(inspect.currentframe())`. + + Copied from `Agent`. 
+ """ + assert self.name is None, 'Name already set' + if function_frame is not None and (parent_frame := function_frame.f_back): # pragma: no branch + for name, item in parent_frame.f_locals.items(): + if item is self: + self.name = name + return + if parent_frame.f_locals != parent_frame.f_globals: + # if we couldn't find the agent in locals and globals are a different dict, try globals + for name, item in parent_frame.f_globals.items(): + if item is self: + self.name = name + return diff --git a/pydantic_graph/pydantic_graph/mermaid.py b/pydantic_graph/pydantic_graph/mermaid.py index 0b498e685f..e89a00da49 100644 --- a/pydantic_graph/pydantic_graph/mermaid.py +++ b/pydantic_graph/pydantic_graph/mermaid.py @@ -24,13 +24,14 @@ DEFAULT_HIGHLIGHT_CSS = 'fill:#fdff32' -def generate_code( +def generate_code( # noqa: C901 graph: Graph[Any, Any], /, *, start_node: Sequence[NodeIdent] | NodeIdent | None = None, highlighted_nodes: Sequence[NodeIdent] | NodeIdent | None = None, highlight_css: str = DEFAULT_HIGHLIGHT_CSS, + title: str | None = None, edge_labels: bool = True, notes: bool = True, ) -> str: @@ -41,6 +42,7 @@ def generate_code( start_node: Identifiers of nodes that start the graph. highlighted_nodes: Identifiers of nodes to highlight. highlight_css: CSS to use for highlighting nodes. + title: The title of the diagram. edge_labels: Whether to include edge labels in the diagram. notes: Whether to include notes in the diagram. @@ -51,7 +53,10 @@ def generate_code( if node_id not in graph.node_defs: raise LookupError(f'Start node "{node_id}" is not in the graph.') - lines = ['stateDiagram-v2'] + lines: list[str] = [] + if title: + lines = ['---', f'title: {title}', '---'] + lines.append('stateDiagram-v2') for node_id, node_def in graph.node_defs.items(): # we use round brackets (rounded box) for nodes other than the start and end if node_id in start_node_ids: @@ -114,6 +119,8 @@ class MermaidConfig(TypedDict, total=False): """Identifiers of nodes to highlight.""" highlight_css: str """CSS to use for highlighting nodes.""" + title: str | None + """The title of the diagram.""" edge_labels: bool """Whether to include edge labels in the diagram.""" notes: bool @@ -168,6 +175,7 @@ def request_image( start_node=kwargs.get('start_node'), highlighted_nodes=kwargs.get('highlighted_nodes'), highlight_css=kwargs.get('highlight_css', DEFAULT_HIGHLIGHT_CSS), + title=kwargs.get('title'), edge_labels=kwargs.get('edge_labels', True), notes=kwargs.get('notes', True), ) diff --git a/tests/graph/test_main.py b/tests/graph/test_main.py index 7e076baf47..22463341d9 100644 --- a/tests/graph/test_main.py +++ b/tests/graph/test_main.py @@ -3,13 +3,23 @@ from dataclasses import dataclass from datetime import timezone from functools import cache -from typing import Union +from typing import Never, Union import pytest from dirty_equals import IsStr from inline_snapshot import snapshot -from pydantic_graph import BaseNode, End, EndEvent, Graph, GraphContext, GraphRuntimeError, GraphSetupError, NodeEvent +from pydantic_graph import ( + BaseNode, + End, + EndEvent, + Graph, + GraphContext, + GraphRuntimeError, + GraphSetupError, + HistoryStep, + NodeEvent, +) from ..conftest import IsFloat, IsNow @@ -41,10 +51,12 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq else: return End(self.input_data * 2) - g = Graph[None, int](nodes=(Float2String, String2Length, Double)) - result, history = await g.run(None, Float2String(3.14)) + my_graph = Graph[None, int](nodes=(Float2String, 
String2Length, Double)) + assert my_graph.name is None + result, history = await my_graph.run(None, Float2String(3.14)) # len('3.14') * 2 == 8 assert result == 8 + assert my_graph.name == 'my_graph' assert history == snapshot( [ NodeEvent( @@ -72,7 +84,7 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq ), ] ) - result, history = await g.run(None, Float2String(3.14159)) + result, history = await my_graph.run(None, Float2String(3.14159)) # len('3.14159') == 7, 21 * 2 == 42 assert result == 42 assert history == snapshot( @@ -258,3 +270,34 @@ async def run(self, ctx: GraphContext) -> End[None]: await g.run(None, Foo()) assert exc_info.value.message == snapshot('Invalid node return type: `int`. Expected `BaseNode` or `End`.') + + +async def test_next(): + @dataclass + class Foo(BaseNode): + async def run(self, ctx: GraphContext) -> Bar: + return Bar() + + @dataclass + class Bar(BaseNode): + async def run(self, ctx: GraphContext) -> Foo: + return Foo() + + g = Graph(nodes=(Foo, Bar)) + assert g.name is None + history: list[HistoryStep[None, Never]] = [] + n = await g.next(None, Foo(), history) + assert n == Bar() + assert g.name == 'g' + assert history == snapshot([NodeEvent(state=None, node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat())]) + + assert isinstance(n, Bar) + n2 = await g.next(None, n, history) + assert n2 == Foo() + + assert history == snapshot( + [ + NodeEvent(state=None, node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), + NodeEvent(state=None, node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), + ] + ) diff --git a/tests/graph/test_mermaid.py b/tests/graph/test_mermaid.py index 422b7ecdee..9f790fc0ca 100644 --- a/tests/graph/test_mermaid.py +++ b/tests/graph/test_mermaid.py @@ -1,5 +1,6 @@ from __future__ import annotations as _annotations +import base64 from collections.abc import Iterator from dataclasses import dataclass from datetime import timezone @@ -9,7 +10,6 @@ import httpx import pytest from inline_snapshot import snapshot -from typing_extensions import TypeAlias from pydantic_graph import BaseNode, Edge, End, EndEvent, Graph, GraphContext, GraphSetupError, NodeEvent from pydantic_graph.nodes import NodeDef @@ -17,28 +17,6 @@ from ..conftest import IsFloat, IsNow pytestmark = pytest.mark.anyio -HttpxWithHandler: TypeAlias = 'Callable[[Callable[[httpx.Request], httpx.Response] | httpx.Response], httpx.Client]' - - -@pytest.fixture -def httpx_with_handler() -> Iterator[HttpxWithHandler]: - client: httpx.Client | None = None - - def create_client(handler: Callable[[httpx.Request], httpx.Response] | httpx.Response) -> httpx.Client: - nonlocal client - assert client is None, 'client_with_handler can only be called once' - if isinstance(handler, httpx.Response): - transport = httpx.MockTransport(lambda _: handler) - else: - transport = httpx.MockTransport(handler) - client = httpx.Client(mounts={'all://': transport}) - return client - - try: - yield create_client - finally: - if client: # pragma: no cover - client.close() @dataclass @@ -100,7 +78,7 @@ async def test_run_graph(): def test_mermaid_code_no_start(): - assert graph1.mermaid_code() == snapshot("""\ + assert graph1.mermaid_code(title=False) == snapshot("""\ stateDiagram-v2 Foo --> Bar Bar --> [*]\ @@ -109,6 +87,9 @@ def test_mermaid_code_no_start(): def test_mermaid_code_start(): assert graph1.mermaid_code(start_node=Foo) == snapshot("""\ +--- +title: graph1 +--- stateDiagram-v2 [*] --> Foo Foo --> Bar @@ -124,6 +105,9 @@ def 
test_mermaid_code_start_wrong(): def test_mermaid_highlight(): code = graph1.mermaid_code(highlighted_nodes=Foo) assert code == snapshot("""\ +--- +title: graph1 +--- stateDiagram-v2 Foo --> Bar Bar --> [*] @@ -137,6 +121,9 @@ class Foo highlighted\ def test_mermaid_highlight_multiple(): code = graph1.mermaid_code(highlighted_nodes=(Foo, Bar)) assert code == snapshot("""\ +--- +title: graph1 +--- stateDiagram-v2 Foo --> Bar Bar --> [*] @@ -155,6 +142,9 @@ def test_mermaid_highlight_wrong(): def test_mermaid_code_with_edge_labels(): assert graph2.mermaid_code() == snapshot("""\ +--- +title: graph2 +--- stateDiagram-v2 Spam --> Foo: spam to foo note right of Spam @@ -168,6 +158,9 @@ def test_mermaid_code_with_edge_labels(): def test_mermaid_code_without_edge_labels(): assert graph2.mermaid_code(edge_labels=False, notes=False) == snapshot("""\ +--- +title: graph2 +--- stateDiagram-v2 Spam --> Foo Foo --> Bar @@ -187,6 +180,9 @@ async def run(self, ctx: GraphContext) -> BaseNode: def test_mermaid_code_all_nodes(): assert graph3.mermaid_code() == snapshot("""\ +--- +title: graph3 +--- stateDiagram-v2 AllNodes --> AllNodes AllNodes --> Foo @@ -196,14 +192,37 @@ def test_mermaid_code_all_nodes(): """) +@pytest.fixture +def httpx_with_handler() -> Iterator[HttpxWithHandler]: + client: httpx.Client | None = None + + def create_client(handler: Callable[[httpx.Request], httpx.Response]) -> httpx.Client: + nonlocal client + assert client is None, 'client_with_handler can only be called once' + client = httpx.Client(mounts={'all://': httpx.MockTransport(handler)}) + return client + + try: + yield create_client + finally: + if client: # pragma: no cover + client.close() + + +HttpxWithHandler = Callable[[Callable[[httpx.Request], httpx.Response]], httpx.Client] + + def test_image_jpg(httpx_with_handler: HttpxWithHandler): def get_jpg(request: httpx.Request) -> httpx.Response: assert dict(request.url.params) == snapshot({}) assert request.url.path.startswith('/img/') - return httpx.Response(200, content=b'fake jpg') + mermaid = base64.b64decode(request.url.path[5:].encode()) + return httpx.Response(200, content=mermaid) + graph1.name = None img = graph1.mermaid_image(start_node=Foo(), httpx_client=httpx_with_handler(get_jpg)) - assert img == b'fake jpg' + assert graph1.name == 'graph1' + assert img == snapshot(b'---\ntitle: graph1\n---\nstateDiagram-v2\n [*] --> Foo\n Foo --> Bar\n Bar --> [*]') def test_image_png(httpx_with_handler: HttpxWithHandler): @@ -219,10 +238,12 @@ def get_png(request: httpx.Request) -> httpx.Response: } ) assert request.url.path.startswith('/img/') - return httpx.Response(200, content=b'fake png') + mermaid = base64.b64decode(request.url.path[5:].encode()) + return httpx.Response(200, content=mermaid) img = graph1.mermaid_image( start_node=Foo(), + title=None, image_type='png', background_color='123', theme='forest', @@ -231,13 +252,15 @@ def get_png(request: httpx.Request) -> httpx.Response: scale=3, httpx_client=httpx_with_handler(get_png), ) - assert img == b'fake png' + assert img == snapshot(b'stateDiagram-v2\n [*] --> Foo\n Foo --> Bar\n Bar --> [*]') def test_image_bad(httpx_with_handler: HttpxWithHandler): - response = httpx.Response(404, content=b'not found') + def get_404(request: httpx.Request) -> httpx.Response: + return httpx.Response(404, content=b'not found') + with pytest.raises(httpx.HTTPStatusError, match='404 error generating image:\nnot found') as exc_info: - graph1.mermaid_image(start_node=Foo(), httpx_client=httpx_with_handler(response)) + 
graph1.mermaid_image(start_node=Foo(), httpx_client=httpx_with_handler(get_404)) assert exc_info.value.response.status_code == 404 assert exc_info.value.response.content == b'not found' @@ -283,22 +306,28 @@ def test_save_jpg(tmp_path: Path, httpx_with_handler: HttpxWithHandler): def get_jpg(request: httpx.Request) -> httpx.Response: assert dict(request.url.params) == snapshot({}) assert request.url.path.startswith('/img/') - return httpx.Response(200, content=b'fake jpg') + mermaid = base64.b64decode(request.url.path[5:].encode()) + return httpx.Response(200, content=mermaid) path = tmp_path / 'graph.jpg' graph1.mermaid_save(path, start_node=Foo(), httpx_client=httpx_with_handler(get_jpg)) - assert path.read_bytes() == b'fake jpg' + assert path.read_bytes() == snapshot( + b'---\ntitle: graph1\n---\nstateDiagram-v2\n [*] --> Foo\n Foo --> Bar\n Bar --> [*]' + ) def test_save_png(tmp_path: Path, httpx_with_handler: HttpxWithHandler): def get_png(request: httpx.Request) -> httpx.Response: assert dict(request.url.params) == snapshot({'type': 'png'}) assert request.url.path.startswith('/img/') - return httpx.Response(200, content=b'fake png') + mermaid = base64.b64decode(request.url.path[5:].encode()) + return httpx.Response(200, content=mermaid) path2 = tmp_path / 'graph.png' - graph1.mermaid_save(str(path2), start_node=Foo(), httpx_client=httpx_with_handler(get_png)) - assert path2.read_bytes() == b'fake png' + graph1.name = None + graph1.mermaid_save(str(path2), title=None, start_node=Foo(), httpx_client=httpx_with_handler(get_png)) + assert graph1.name == 'graph1' + assert path2.read_bytes() == snapshot(b'stateDiagram-v2\n [*] --> Foo\n Foo --> Bar\n Bar --> [*]') def test_save_pdf_known(tmp_path: Path, httpx_with_handler: HttpxWithHandler): From 88c1d4604fd09a945e11cf30a6ec9d9b168596c6 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sat, 11 Jan 2025 00:50:38 +0000 Subject: [PATCH 35/57] fix for 3.9 --- tests/graph/test_main.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/graph/test_main.py b/tests/graph/test_main.py index 22463341d9..a26e9d7663 100644 --- a/tests/graph/test_main.py +++ b/tests/graph/test_main.py @@ -3,11 +3,12 @@ from dataclasses import dataclass from datetime import timezone from functools import cache -from typing import Never, Union +from typing import Union import pytest from dirty_equals import IsStr from inline_snapshot import snapshot +from typing_extensions import Never from pydantic_graph import ( BaseNode, From 0e3ecb37ad06e83956b8f25d7820f99c3d5d6d2d Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sat, 11 Jan 2025 12:55:44 +0000 Subject: [PATCH 36/57] adding API docs --- docs/api/pydantic_graph/graph.md | 3 + docs/api/pydantic_graph/mermaid.md | 3 + docs/api/pydantic_graph/nodes.md | 11 ++ docs/api/pydantic_graph/state.md | 3 + mkdocs.yml | 5 + pydantic_graph/pydantic_graph/__init__.py | 6 +- pydantic_graph/pydantic_graph/graph.py | 143 ++++++++++++++++------ pydantic_graph/pydantic_graph/mermaid.py | 113 +++++++++-------- pydantic_graph/pydantic_graph/nodes.py | 27 +++- pydantic_graph/pydantic_graph/state.py | 32 +++-- pydantic_graph/pyproject.toml | 3 +- tests/graph/test_main.py | 30 ++--- tests/graph/test_mermaid.py | 8 +- uv.lock | 49 ++++---- 14 files changed, 287 insertions(+), 149 deletions(-) create mode 100644 docs/api/pydantic_graph/graph.md create mode 100644 docs/api/pydantic_graph/mermaid.md create mode 100644 docs/api/pydantic_graph/nodes.md create mode 100644 docs/api/pydantic_graph/state.md diff --git 
a/docs/api/pydantic_graph/graph.md b/docs/api/pydantic_graph/graph.md new file mode 100644 index 0000000000..c8e3ea7944 --- /dev/null +++ b/docs/api/pydantic_graph/graph.md @@ -0,0 +1,3 @@ +# `pydantic_graph` + +::: pydantic_graph.graph diff --git a/docs/api/pydantic_graph/mermaid.md b/docs/api/pydantic_graph/mermaid.md new file mode 100644 index 0000000000..ccc7cec5ad --- /dev/null +++ b/docs/api/pydantic_graph/mermaid.md @@ -0,0 +1,3 @@ +# `pydantic_graph.mermaid` + +::: pydantic_graph.mermaid diff --git a/docs/api/pydantic_graph/nodes.md b/docs/api/pydantic_graph/nodes.md new file mode 100644 index 0000000000..7540920a02 --- /dev/null +++ b/docs/api/pydantic_graph/nodes.md @@ -0,0 +1,11 @@ +# `pydantic_graph.nodes` + +::: pydantic_graph.nodes + options: + members: + - GraphContext + - BaseNode + - End + - Edge + - RunEndT + - NodeRunEndT diff --git a/docs/api/pydantic_graph/state.md b/docs/api/pydantic_graph/state.md new file mode 100644 index 0000000000..480eea5a54 --- /dev/null +++ b/docs/api/pydantic_graph/state.md @@ -0,0 +1,3 @@ +# `pydantic_graph.state` + +::: pydantic_graph.state diff --git a/mkdocs.yml b/mkdocs.yml index 915701c64a..0b13b451c5 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -55,6 +55,10 @@ nav: - api/models/ollama.md - api/models/test.md - api/models/function.md + - api/pydantic_graph/graph.md + - api/pydantic_graph/nodes.md + - api/pydantic_graph/state.md + - api/pydantic_graph/mermaid.md extra: # hide the "Made with Material for MkDocs" message @@ -150,6 +154,7 @@ markdown_extensions: watch: - pydantic_ai_slim + - pydantic_graph - examples plugins: diff --git a/pydantic_graph/pydantic_graph/__init__.py b/pydantic_graph/pydantic_graph/__init__.py index 192b62f98c..5102117738 100644 --- a/pydantic_graph/pydantic_graph/__init__.py +++ b/pydantic_graph/pydantic_graph/__init__.py @@ -1,7 +1,7 @@ from .exceptions import GraphRuntimeError, GraphSetupError from .graph import Graph from .nodes import BaseNode, Edge, End, GraphContext -from .state import AbstractState, EndEvent, HistoryStep, NodeEvent +from .state import AbstractState, EndStep, HistoryStep, NodeStep __all__ = ( 'Graph', @@ -10,9 +10,9 @@ 'GraphContext', 'Edge', 'AbstractState', - 'EndEvent', + 'EndStep', 'HistoryStep', - 'NodeEvent', + 'NodeStep', 'GraphSetupError', 'GraphRuntimeError', ) diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index f7f45020ba..eed901fe95 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -14,7 +14,7 @@ from . import _utils, exceptions, mermaid from ._utils import get_parent_namespace from .nodes import BaseNode, End, GraphContext, NodeDef -from .state import EndEvent, HistoryStep, NodeEvent, StateT +from .state import EndStep, HistoryStep, NodeStep, StateT __all__ = ('Graph',) @@ -36,9 +36,16 @@ def __init__( self, *, nodes: Sequence[type[BaseNode[StateT, RunEndT]]], - state_type: type[StateT] | None = None, name: str | None = None, ): + """Create a graph from a sequence of nodes. + + Args: + nodes: The nodes which make up the graph, nodes need to be unique and all be generic in the same + state type. + name: Optional name for the graph, if not provided the name will be inferred from the calling frame + on the first call to a graph method. 
+ """ self.name = name parent_namespace = get_parent_namespace(inspect.currentframe()) @@ -48,34 +55,6 @@ def __init__( self._validate_edges() - def _register_node(self, node: type[BaseNode[StateT, RunEndT]], parent_namespace: dict[str, Any] | None) -> None: - node_id = node.get_id() - if existing_node := self.node_defs.get(node_id): - raise exceptions.GraphSetupError( - f'Node ID `{node_id}` is not unique — found on {existing_node.node} and {node}' - ) - else: - self.node_defs[node_id] = node.get_node_def(parent_namespace) - - def _validate_edges(self): - known_node_ids = self.node_defs.keys() - bad_edges: dict[str, list[str]] = {} - - for node_id, node_def in self.node_defs.items(): - for edge in node_def.next_node_edges.keys(): - if edge not in known_node_ids: - bad_edges.setdefault(edge, []).append(f'`{node_id}`') - - if bad_edges: - bad_edges_list = [f'`{k}` is referenced by {_utils.comma_and(v)}' for k, v in bad_edges.items()] - if len(bad_edges_list) == 1: - raise exceptions.GraphSetupError(f'{bad_edges_list[0]} but not included in the graph.') - else: - b = '\n'.join(f' {be}' for be in bad_edges_list) - raise exceptions.GraphSetupError( - f'Nodes are referenced in the graph but not included in the graph:\n{b}' - ) - async def next( self, state: StateT, @@ -84,13 +63,24 @@ async def next( *, infer_name: bool = True, ) -> BaseNode[StateT, Any] | End[RunEndT]: + """Run a node in the graph and return the next node to run. + + Args: + state: The current state of the graph. + node: The node to run. + history: The history of the graph run so far. NOTE: this will be mutated to add the new step. + infer_name: Whether to infer the graph name from the calling frame. + + Returns: + The next node to run or [`End`][pydantic_graph.nodes.End] if the graph has finished. + """ if infer_name and self.name is None: self._infer_name(inspect.currentframe()) node_id = node.get_id() if node_id not in self.node_defs: raise exceptions.GraphRuntimeError(f'Node `{node}` is not in the graph.') - history_step: NodeEvent[StateT, RunEndT] | None = NodeEvent(state, node) + history_step: NodeStep[StateT, RunEndT] | None = NodeStep(state, node) history.append(history_step) ctx = GraphContext(state) @@ -107,6 +97,15 @@ async def run( *, infer_name: bool = True, ) -> tuple[RunEndT, list[HistoryStep[StateT, RunEndT]]]: + """Run the graph from a starting node until it ends. + + Args: + state: The initial state of the graph. + start_node: the first node to run. + infer_name: Whether to infer the graph name from the calling frame. + + Returns: The result type from ending the run and the history of the run. 
+ """ history: list[HistoryStep[StateT, RunEndT]] = [] if infer_name and self.name is None: self._infer_name(inspect.currentframe()) @@ -119,7 +118,7 @@ async def run( while True: next_node = await self.next(state, start_node, history, infer_name=False) if isinstance(next_node, End): - history.append(EndEvent(state, next_node)) + history.append(EndStep(state, next_node)) run_span.set_attribute('history', history) return next_node.data, history elif isinstance(next_node, BaseNode): @@ -136,13 +135,29 @@ def mermaid_code( self, *, start_node: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None, - highlighted_nodes: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None, - highlight_css: str = mermaid.DEFAULT_HIGHLIGHT_CSS, - edge_labels: bool = True, title: str | None | Literal[False] = None, + edge_labels: bool = True, notes: bool = True, + highlighted_nodes: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None, + highlight_css: str = mermaid.DEFAULT_HIGHLIGHT_CSS, infer_name: bool = True, ) -> str: + """Generate a diagram representing the graph as [mermaid](https://mermaid.js.org/) chart. + + This method calls [`pydantic_graph.mermaid.generate_code`][pydantic_graph.mermaid.generate_code]. + + Args: + start_node: The node or nodes to start the graph from. + title: The title of the diagram, use `False` to not include a title. + edge_labels: Whether to include edge labels. + notes: Whether to include notes on each node. + highlighted_nodes: Optional node or nodes to highlight. + highlight_css: The CSS to use for highlighting nodes. + infer_name: Whether to infer the graph name from the calling frame. + + Returns: + The mermaid code for the graph, which can then be rendered as a diagram. + """ if infer_name and self.name is None: self._infer_name(inspect.currentframe()) if title is None and self.name: @@ -158,6 +173,22 @@ def mermaid_code( ) def mermaid_image(self, infer_name: bool = True, **kwargs: Unpack[mermaid.MermaidConfig]) -> bytes: + """Generate a diagram representing the graph as an image. + + The format and diagram can be customized using `kwargs`, + see [`pydantic_graph.mermaid.MermaidConfig`][pydantic_graph.mermaid.MermaidConfig]. + + !!! note "Uses external service" + This method makes a request to [mermaid.ink](https://mermaid.ink) to render the image, `mermaid.ink` + is a free service not affiliated with Pydantic. + + Args: + infer_name: Whether to infer the graph name from the calling frame. + **kwargs: Additional arguments to pass to `mermaid.request_image`. + + Returns: + The image bytes. + """ if infer_name and self.name is None: self._infer_name(inspect.currentframe()) if 'title' not in kwargs and self.name: @@ -167,12 +198,54 @@ def mermaid_image(self, infer_name: bool = True, **kwargs: Unpack[mermaid.Mermai def mermaid_save( self, path: Path | str, /, *, infer_name: bool = True, **kwargs: Unpack[mermaid.MermaidConfig] ) -> None: + """Generate a diagram representing the graph and save it as an image. + + The format and diagram can be customized using `kwargs`, + see [`pydantic_graph.mermaid.MermaidConfig`][pydantic_graph.mermaid.MermaidConfig]. + + !!! note "Uses external service" + This method makes a request to [mermaid.ink](https://mermaid.ink) to render the image, `mermaid.ink` + is a free service not affiliated with Pydantic. + + Args: + path: The path to save the image to. + infer_name: Whether to infer the graph name from the calling frame. + **kwargs: Additional arguments to pass to `mermaid.save_image`. 
+ """ if infer_name and self.name is None: self._infer_name(inspect.currentframe()) if 'title' not in kwargs and self.name: kwargs['title'] = self.name mermaid.save_image(path, self, **kwargs) + def _register_node(self, node: type[BaseNode[StateT, RunEndT]], parent_namespace: dict[str, Any] | None) -> None: + node_id = node.get_id() + if existing_node := self.node_defs.get(node_id): + raise exceptions.GraphSetupError( + f'Node ID `{node_id}` is not unique — found on {existing_node.node} and {node}' + ) + else: + self.node_defs[node_id] = node.get_node_def(parent_namespace) + + def _validate_edges(self): + known_node_ids = self.node_defs.keys() + bad_edges: dict[str, list[str]] = {} + + for node_id, node_def in self.node_defs.items(): + for edge in node_def.next_node_edges.keys(): + if edge not in known_node_ids: + bad_edges.setdefault(edge, []).append(f'`{node_id}`') + + if bad_edges: + bad_edges_list = [f'`{k}` is referenced by {_utils.comma_and(v)}' for k, v in bad_edges.items()] + if len(bad_edges_list) == 1: + raise exceptions.GraphSetupError(f'{bad_edges_list[0]} but not included in the graph.') + else: + b = '\n'.join(f' {be}' for be in bad_edges_list) + raise exceptions.GraphSetupError( + f'Nodes are referenced in the graph but not included in the graph:\n{b}' + ) + def _infer_name(self, function_frame: FrameType | None) -> None: """Infer the agent name from the call frame. diff --git a/pydantic_graph/pydantic_graph/mermaid.py b/pydantic_graph/pydantic_graph/mermaid.py index e89a00da49..1e72493759 100644 --- a/pydantic_graph/pydantic_graph/mermaid.py +++ b/pydantic_graph/pydantic_graph/mermaid.py @@ -18,10 +18,8 @@ __all__ = 'NodeIdent', 'DEFAULT_HIGHLIGHT_CSS', 'generate_code', 'MermaidConfig', 'request_image', 'save_image' - - -NodeIdent: TypeAlias = 'type[BaseNode[Any, Any]] | BaseNode[Any, Any] | str' DEFAULT_HIGHLIGHT_CSS = 'fill:#fdff32' +"""The default CSS to use for highlighting nodes.""" def generate_code( # noqa: C901 @@ -46,7 +44,8 @@ def generate_code( # noqa: C901 edge_labels: Whether to include edge labels in the diagram. notes: Whether to include notes in the diagram. - Returns: The Mermaid code for the graph. + Returns: + The Mermaid code for the graph. """ start_node_ids = set(_node_ids(start_node or ())) for node_id in start_node_ids: @@ -110,53 +109,6 @@ def _node_ids(node_idents: Sequence[NodeIdent] | NodeIdent) -> Iterable[str]: yield node.get_id() -class MermaidConfig(TypedDict, total=False): - """Parameters to configure mermaid chart generation.""" - - start_node: Sequence[NodeIdent] | NodeIdent - """Identifiers of nodes that start the graph.""" - highlighted_nodes: Sequence[NodeIdent] | NodeIdent - """Identifiers of nodes to highlight.""" - highlight_css: str - """CSS to use for highlighting nodes.""" - title: str | None - """The title of the diagram.""" - edge_labels: bool - """Whether to include edge labels in the diagram.""" - notes: bool - """Whether to include notes on nodes in the diagram, defaults to true.""" - image_type: Literal['jpeg', 'png', 'webp', 'svg', 'pdf'] - """The image type to generate. If unspecified, the default behavior is `'jpeg'`.""" - pdf_fit: bool - """When using image_type='pdf', whether to fit the diagram to the PDF page.""" - pdf_landscape: bool - """When using image_type='pdf', whether to use landscape orientation for the PDF. - - This has no effect if using `pdf_fit`. 
- """ - pdf_paper: Literal['letter', 'legal', 'tabloid', 'ledger', 'a0', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6'] - """When using image_type='pdf', the paper size of the PDF.""" - background_color: str - """The background color of the diagram. - - If None, the default transparent background is used. The color value is interpreted as a hexadecimal color - code by default (and should not have a leading '#'), but you can also use named colors by prefixing the - value with `'!'`. For example, valid choices include `background_color='!white'` or `background_color='FF0000'`. - """ - theme: Literal['default', 'neutral', 'dark', 'forest'] - """The theme of the diagram. Defaults to 'default'.""" - width: int - """The width of the diagram.""" - height: int - """The height of the diagram.""" - scale: Annotated[float, Ge(1), Le(3)] - """The scale of the diagram. - - The scale must be a number between 1 and 3, and you can only set a scale if one or both of width and height are set. - """ - httpx_client: httpx.Client - - def request_image( graph: Graph[Any, Any], /, @@ -244,3 +196,62 @@ def save_image( image_data = request_image(graph, **kwargs) path.write_bytes(image_data) + + +class MermaidConfig(TypedDict, total=False): + """Parameters to configure mermaid chart generation.""" + + start_node: Sequence[NodeIdent] | NodeIdent + """Identifiers of nodes that start the graph.""" + highlighted_nodes: Sequence[NodeIdent] | NodeIdent + """Identifiers of nodes to highlight.""" + highlight_css: str + """CSS to use for highlighting nodes.""" + title: str | None + """The title of the diagram.""" + edge_labels: bool + """Whether to include edge labels in the diagram.""" + notes: bool + """Whether to include notes on nodes in the diagram, defaults to true.""" + image_type: Literal['jpeg', 'png', 'webp', 'svg', 'pdf'] + """The image type to generate. If unspecified, the default behavior is `'jpeg'`.""" + pdf_fit: bool + """When using image_type='pdf', whether to fit the diagram to the PDF page.""" + pdf_landscape: bool + """When using image_type='pdf', whether to use landscape orientation for the PDF. + + This has no effect if using `pdf_fit`. + """ + pdf_paper: Literal['letter', 'legal', 'tabloid', 'ledger', 'a0', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6'] + """When using image_type='pdf', the paper size of the PDF.""" + background_color: str + """The background color of the diagram. + + If None, the default transparent background is used. The color value is interpreted as a hexadecimal color + code by default (and should not have a leading '#'), but you can also use named colors by prefixing the + value with `'!'`. For example, valid choices include `background_color='!white'` or `background_color='FF0000'`. + """ + theme: Literal['default', 'neutral', 'dark', 'forest'] + """The theme of the diagram. Defaults to 'default'.""" + width: int + """The width of the diagram.""" + height: int + """The height of the diagram.""" + scale: Annotated[float, Ge(1), Le(3)] + """The scale of the diagram. + + The scale must be a number between 1 and 3, and you can only set a scale if one or both of width and height are set. + """ + httpx_client: httpx.Client + """An HTTPX client to use for requests, mostly for testing purposes.""" + + +NodeIdent: TypeAlias = 'type[BaseNode[Any, Any]] | BaseNode[Any, Any] | str' +"""A type alias for a node identifier. + +This can be: + +- A node instance (instance of a subclass of [`BaseNode`][pydantic_graph.nodes.BaseNode]). +- A node class (subclass of [`BaseNode`][pydantic_graph.nodes.BaseNode]). 
+- A string representing the node ID. +""" diff --git a/pydantic_graph/pydantic_graph/nodes.py b/pydantic_graph/pydantic_graph/nodes.py index 5fd5755881..1881b10b44 100644 --- a/pydantic_graph/pydantic_graph/nodes.py +++ b/pydantic_graph/pydantic_graph/nodes.py @@ -10,10 +10,12 @@ from . import _utils, exceptions from .state import StateT -__all__ = 'GraphContext', 'BaseNode', 'End', 'Edge', 'NodeDef' +__all__ = 'GraphContext', 'BaseNode', 'End', 'Edge', 'NodeDef', 'RunEndT', 'NodeRunEndT' RunEndT = TypeVar('RunEndT', default=None) +"""Type variable for the return type of a graph [`run`][pydantic_graph.graph.Graph.run].""" NodeRunEndT = TypeVar('NodeRunEndT', covariant=True, default=Never) +"""Type variable for the return type of a node [`run`][pydantic_graph.nodes.BaseNode.run].""" @dataclass @@ -21,6 +23,7 @@ class GraphContext(Generic[StateT]): """Context for a graph.""" state: StateT + """The state of the graph.""" class BaseNode(ABC, Generic[StateT, NodeRunEndT]): @@ -30,7 +33,23 @@ class BaseNode(ABC, Generic[StateT, NodeRunEndT]): """Set to `False` to not generate mermaid diagram notes from the class's docstring.""" @abstractmethod - async def run(self, ctx: GraphContext[StateT]) -> BaseNode[StateT, Any] | End[NodeRunEndT]: ... + async def run(self, ctx: GraphContext[StateT]) -> BaseNode[StateT, Any] | End[NodeRunEndT]: + """Run the node. + + This is an abstract method that must be implemented by subclasses. + + !!! note "Return types used at runtime" + The return type of this method are read by `pydantic_graph` at runtime and used to define which + nodes can be called next in the graph. This is displayed in [mermaid diagrams](mermaid.md) + and enforced when running the graph. + + Args: + ctx: The graph context. + + Returns: + The next node to run or [`End`][pydantic_graph.nodes.End] to signal the end of the graph. + """ + ... @classmethod @cache @@ -92,6 +111,7 @@ class End(Generic[RunEndT]): """Type to return from a node to signal the end of the graph.""" data: RunEndT + """Data to return from the graph.""" @dataclass @@ -99,12 +119,15 @@ class Edge: """Annotation to apply a label to an edge in a graph.""" label: str | None + """Label for the edge.""" @dataclass class NodeDef(Generic[StateT, NodeRunEndT]): """Definition of a node. + This is an internal representation of a node, it shouldn't be necessary to use it directly. + Used by [`Graph`][pydantic_graph.graph.Graph] to store information about a node, and when generating mermaid graphs. """ diff --git a/pydantic_graph/pydantic_graph/state.py b/pydantic_graph/pydantic_graph/state.py index b8dbff7a72..c01e23f1b4 100644 --- a/pydantic_graph/pydantic_graph/state.py +++ b/pydantic_graph/pydantic_graph/state.py @@ -6,15 +6,16 @@ from datetime import datetime from typing import TYPE_CHECKING, Generic, Literal, Union -from typing_extensions import Never, Self, TypeVar +from typing_extensions import Self, TypeVar from . 
import _utils -__all__ = 'AbstractState', 'StateT', 'NodeEvent', 'EndEvent', 'HistoryStep' +__all__ = 'AbstractState', 'StateT', 'NodeStep', 'EndStep', 'HistoryStep' if TYPE_CHECKING: - from pydantic_graph import BaseNode - from pydantic_graph.nodes import End + from .nodes import BaseNode, End, RunEndT +else: + RunEndT = TypeVar('RunEndT', default=None) class AbstractState(ABC): @@ -30,21 +31,24 @@ def deep_copy(self) -> Self: return copy.deepcopy(self) -RunEndT = TypeVar('RunEndT', default=None) -NodeRunEndT = TypeVar('NodeRunEndT', covariant=True, default=Never) StateT = TypeVar('StateT', bound=Union[None, AbstractState], default=None) +"""Type variable for the state in a graph.""" @dataclass -class NodeEvent(Generic[StateT, RunEndT]): +class NodeStep(Generic[StateT, RunEndT]): """History step describing the execution of a node in a graph.""" state: StateT + """The state of the graph after the node has run.""" node: BaseNode[StateT, RunEndT] + """The node that was run.""" start_ts: datetime = field(default_factory=_utils.now_utc) + """The timestamp when the node started running.""" duration: float | None = None - - kind: Literal['step'] = 'step' + """The duration of the node run in seconds.""" + kind: Literal['node'] = 'node' + """The kind of history step, can be used as a discriminator when deserializing history.""" def __post_init__(self): # Copy the state to prevent it from being modified by other code @@ -55,14 +59,17 @@ def summary(self) -> str: @dataclass -class EndEvent(Generic[StateT, RunEndT]): +class EndStep(Generic[StateT, RunEndT]): """History step describing the end of a graph run.""" state: StateT + """The state of the graph after the run.""" result: End[RunEndT] + """The result of the graph run.""" ts: datetime = field(default_factory=_utils.now_utc) - + """The timestamp when the graph run ended.""" kind: Literal['end'] = 'end' + """The kind of history step, can be used as a discriminator when deserializing history.""" def __post_init__(self): # Copy the state to prevent it from being modified by other code @@ -79,4 +86,5 @@ def _deep_copy_state(state: StateT) -> StateT: return state.deep_copy() -HistoryStep = Union[NodeEvent[StateT, RunEndT], EndEvent[StateT, RunEndT]] +HistoryStep = Union[NodeStep[StateT, RunEndT], EndStep[StateT, RunEndT]] +"""A step in the history of a graph run.""" diff --git a/pydantic_graph/pyproject.toml b/pydantic_graph/pyproject.toml index 17aaa1fba7..89cbad516c 100644 --- a/pydantic_graph/pyproject.toml +++ b/pydantic_graph/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "pydantic-graph" -version = "0.0.17" +version = "0.0.18" description = "Graph and state machine library" authors = [ { name = "Samuel Colvin", email = "samuel@pydantic.dev" }, @@ -36,7 +36,6 @@ requires-python = ">=3.9" dependencies = [ "httpx>=0.27.2", "logfire-api>=1.2.0", - "pydantic>=2.10", ] [tool.hatch.build.targets.wheel] diff --git a/tests/graph/test_main.py b/tests/graph/test_main.py index a26e9d7663..f82e92e5e6 100644 --- a/tests/graph/test_main.py +++ b/tests/graph/test_main.py @@ -13,13 +13,13 @@ from pydantic_graph import ( BaseNode, End, - EndEvent, + EndStep, Graph, GraphContext, GraphRuntimeError, GraphSetupError, HistoryStep, - NodeEvent, + NodeStep, ) from ..conftest import IsFloat, IsNow @@ -60,25 +60,25 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq assert my_graph.name == 'my_graph' assert history == snapshot( [ - NodeEvent( + NodeStep( state=None, node=Float2String(input_data=3.14), 
start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - NodeEvent( + NodeStep( state=None, node=String2Length(input_data='3.14'), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - NodeEvent( + NodeStep( state=None, node=Double(input_data=4), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - EndEvent( + EndStep( state=None, result=End(data=8), ts=IsNow(tz=timezone.utc), @@ -90,37 +90,37 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq assert result == 42 assert history == snapshot( [ - NodeEvent( + NodeStep( state=None, node=Float2String(input_data=3.14159), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - NodeEvent( + NodeStep( state=None, node=String2Length(input_data='3.14159'), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - NodeEvent( + NodeStep( state=None, node=Double(input_data=7), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - NodeEvent( + NodeStep( state=None, node=String2Length(input_data='xxxxxxxxxxxxxxxxxxxxx'), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - NodeEvent( + NodeStep( state=None, node=Double(input_data=21), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - EndEvent( + EndStep( state=None, result=End(data=42), ts=IsNow(tz=timezone.utc), @@ -290,7 +290,7 @@ async def run(self, ctx: GraphContext) -> Foo: n = await g.next(None, Foo(), history) assert n == Bar() assert g.name == 'g' - assert history == snapshot([NodeEvent(state=None, node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat())]) + assert history == snapshot([NodeStep(state=None, node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat())]) assert isinstance(n, Bar) n2 = await g.next(None, n, history) @@ -298,7 +298,7 @@ async def run(self, ctx: GraphContext) -> Foo: assert history == snapshot( [ - NodeEvent(state=None, node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), - NodeEvent(state=None, node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), + NodeStep(state=None, node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), + NodeStep(state=None, node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), ] ) diff --git a/tests/graph/test_mermaid.py b/tests/graph/test_mermaid.py index 9f790fc0ca..1f40df5c8c 100644 --- a/tests/graph/test_mermaid.py +++ b/tests/graph/test_mermaid.py @@ -11,7 +11,7 @@ import pytest from inline_snapshot import snapshot -from pydantic_graph import BaseNode, Edge, End, EndEvent, Graph, GraphContext, GraphSetupError, NodeEvent +from pydantic_graph import BaseNode, Edge, End, EndStep, Graph, GraphContext, GraphSetupError, NodeStep from pydantic_graph.nodes import NodeDef from ..conftest import IsFloat, IsNow @@ -60,19 +60,19 @@ async def test_run_graph(): assert result is None assert history == snapshot( [ - NodeEvent( + NodeStep( state=None, node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - NodeEvent( + NodeStep( state=None, node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - EndEvent(state=None, result=End(data=None), ts=IsNow(tz=timezone.utc)), + EndStep(state=None, result=End(data=None), ts=IsNow(tz=timezone.utc)), ] ) diff --git a/uv.lock b/uv.lock index a5b648133e..1e998e93c7 100644 --- a/uv.lock +++ b/uv.lock @@ -594,13 +594,13 @@ name = "diff-cover" version = "9.2.0" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version < '3.9.17'", + "python_full_version < '3.10'", ] dependencies = [ - { name = "chardet", marker = "python_full_version < 
'3.9.17'" }, - { name = "jinja2", marker = "python_full_version < '3.9.17'" }, - { name = "pluggy", marker = "python_full_version < '3.9.17'" }, - { name = "pygments", marker = "python_full_version < '3.9.17'" }, + { name = "chardet", marker = "python_full_version < '3.10'" }, + { name = "jinja2", marker = "python_full_version < '3.10'" }, + { name = "pluggy", marker = "python_full_version < '3.10'" }, + { name = "pygments", marker = "python_full_version < '3.10'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/44/3a/e49ccba052a4dda264fbad4f467739ecc63498f7223bfc03d4bfac23ea95/diff_cover-9.2.0.tar.gz", hash = "sha256:85a0b353ebbb678f9e87ea303f75b545bd0baca38f563219bb72f2ae862bba36", size = 94857 } wheels = [ @@ -612,15 +612,16 @@ name = "diff-cover" version = "9.2.1" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.9.17' and python_full_version < '3.11'", - "python_full_version >= '3.11' and python_full_version < '3.13'", + "python_full_version == '3.10.*'", + "python_full_version == '3.12.*'", + "python_full_version == '3.11.*'", "python_full_version >= '3.13'", ] dependencies = [ - { name = "chardet", marker = "python_full_version >= '3.9.17'" }, - { name = "jinja2", marker = "python_full_version >= '3.9.17'" }, - { name = "pluggy", marker = "python_full_version >= '3.9.17'" }, - { name = "pygments", marker = "python_full_version >= '3.9.17'" }, + { name = "chardet", marker = "python_full_version >= '3.10'" }, + { name = "jinja2", marker = "python_full_version >= '3.10'" }, + { name = "pluggy", marker = "python_full_version >= '3.10'" }, + { name = "pygments", marker = "python_full_version >= '3.10'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/36/b0/f3ccf97926f6e5cc76d5ece42f3c685d75673d1886fcec62886b8b00c51a/diff_cover-9.2.1.tar.gz", hash = "sha256:5fa5b2d71ccf5d16cd222a71c2ca069d9bf5fa3d657f6fac9b4d9c23379323bf", size = 94964 } wheels = [ @@ -2137,8 +2138,8 @@ dev = [ { name = "anyio" }, { name = "coverage", extra = ["toml"] }, { name = "devtools" }, - { name = "diff-cover", version = "9.2.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9.17'" }, - { name = "diff-cover", version = "9.2.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9.17'" }, + { name = "diff-cover", version = "9.2.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, + { name = "diff-cover", version = "9.2.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, { name = "dirty-equals" }, { name = "inline-snapshot" }, { name = "pytest" }, @@ -2275,30 +2276,28 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fd/ab/718d9a1c41bb8d3e0e04d15b68b8afc135f8fcf552705b62f226225065c7/pydantic_core-2.27.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:62ba45e21cf6571d7f716d903b5b7b6d2617e2d5d67c0923dc47b9d41369f840", size = 2002035 }, ] -[[package]] -name = "pydub" -version = "0.25.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/fe/9a/e6bca0eed82db26562c73b5076539a4a08d3cffd19c3cc5913a3e61145fd/pydub-0.25.1.tar.gz", hash = "sha256:980a33ce9949cab2a569606b65674d748ecbca4f0796887fd6f46173a7b0d30f", size = 38326 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a6/53/d78dc063216e62fc55f6b2eebb447f6a4b0a59f55c8406376f76bf959b08/pydub-0.25.1-py2.py3-none-any.whl", hash = 
"sha256:65617e33033874b59d87db603aa1ed450633288aefead953b30bded59cb599a6", size = 32327 }, -] - [[package]] name = "pydantic-graph" -version = "0.0.17" +version = "0.0.18" source = { editable = "pydantic_graph" } dependencies = [ { name = "httpx" }, { name = "logfire-api" }, - { name = "pydantic" }, ] [package.metadata] requires-dist = [ { name = "httpx", specifier = ">=0.27.2" }, { name = "logfire-api", specifier = ">=1.2.0" }, - { name = "pydantic", specifier = ">=2.10" }, +] + +[[package]] +name = "pydub" +version = "0.25.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/fe/9a/e6bca0eed82db26562c73b5076539a4a08d3cffd19c3cc5913a3e61145fd/pydub-0.25.1.tar.gz", hash = "sha256:980a33ce9949cab2a569606b65674d748ecbca4f0796887fd6f46173a7b0d30f", size = 38326 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a6/53/d78dc063216e62fc55f6b2eebb447f6a4b0a59f55c8406376f76bf959b08/pydub-0.25.1-py2.py3-none-any.whl", hash = "sha256:65617e33033874b59d87db603aa1ed450633288aefead953b30bded59cb599a6", size = 32327 }, ] [[package]] From 3284ff1876edfd15932a166d1ca5dc20805891ab Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sat, 11 Jan 2025 14:49:22 +0000 Subject: [PATCH 37/57] fix state, more docs --- .../email_extract_graph.py | 156 ------------------ pydantic_graph/pydantic_graph/__init__.py | 3 +- pydantic_graph/pydantic_graph/graph.py | 92 ++++++++++- pydantic_graph/pydantic_graph/state.py | 51 +++--- tests/graph/{test_main.py => test_graph.py} | 0 tests/graph/test_state.py | 58 +++++++ tests/test_examples.py | 2 +- 7 files changed, 167 insertions(+), 195 deletions(-) delete mode 100644 examples/pydantic_ai_examples/email_extract_graph.py rename tests/graph/{test_main.py => test_graph.py} (100%) create mode 100644 tests/graph/test_state.py diff --git a/examples/pydantic_ai_examples/email_extract_graph.py b/examples/pydantic_ai_examples/email_extract_graph.py deleted file mode 100644 index 82c98a370e..0000000000 --- a/examples/pydantic_ai_examples/email_extract_graph.py +++ /dev/null @@ -1,156 +0,0 @@ -from __future__ import annotations as _annotations - -import asyncio -from dataclasses import dataclass -from datetime import datetime, timedelta - -import logfire -from devtools import debug -from pydantic import BaseModel -from pydantic_graph import AbstractState, BaseNode, End, Graph, GraphContext - -from pydantic_ai import Agent, RunContext - -logfire.configure(send_to_logfire='if-token-present') - - -class EventDetails(BaseModel): - title: str - location: str - start_ts: datetime - end_ts: datetime - - -class State(AbstractState, BaseModel): - email_content: str - skip_events: list[str] = [] - attempt: int = 0 - - def serialize(self) -> bytes | None: - return self.model_dump_json(exclude={'email_content'}).encode() - - -class RawEventDetails(BaseModel): - title: str - location: str - start_ts: str - duration: str - - -extract_agent = Agent('openai:gpt-4o', result_type=RawEventDetails, deps_type=list[str]) - - -@extract_agent.system_prompt -def extract_system_prompt(ctx: RunContext[list[str]]): - prompt = 'Extract event details from the email body.' 
- if ctx.deps: - skip_events = '\n'.join(ctx.deps) - prompt += f'\n\nDo not return the following events:\n{skip_events}' - return prompt - - -@dataclass -class ExtractEvent(BaseNode[State]): - async def run(self, ctx: GraphContext[State]) -> CleanEvent: - event = await extract_agent.run( - ctx.state.email_content, deps=ctx.state.skip_events - ) - return CleanEvent(event.data) - - -# agent used to extract the timestamp from the string in `CleanEvent` -timestamp_agent = Agent('openai:gpt-4o', result_type=datetime) - - -@timestamp_agent.system_prompt -def timestamp_system_prompt(): - return f'Extract the timestamp from the string, the current timestamp is: {datetime.now().isoformat()}' - - -# agent used to extract the duration from the string in `CleanEvent` -duration_agent = Agent( - 'openai:gpt-4o', - result_type=timedelta, - system_prompt='Extract the duration from the string as an ISO 8601 interval.', -) - - -@dataclass -class CleanEvent(BaseNode[State]): - input_data: RawEventDetails - - async def run(self, ctx: GraphContext[State]) -> InspectEvent: - start_ts, duration = await asyncio.gather( - timestamp_agent.run(self.input_data.start_ts), - duration_agent.run(self.input_data.duration), - ) - return InspectEvent( - EventDetails( - title=self.input_data.title, - location=self.input_data.location, - start_ts=start_ts.data, - end_ts=start_ts.data + duration.data, - ) - ) - - -@dataclass -class InspectEvent(BaseNode[State, EventDetails | None]): - input_data: EventDetails - - async def run( - self, ctx: GraphContext[State] - ) -> ExtractEvent | End[EventDetails | None]: - now = datetime.now() - if self.input_data.start_ts.tzinfo is not None: - now = now.astimezone(self.input_data.start_ts.tzinfo) - - if self.input_data.start_ts > now: - return End(self.input_data) - ctx.state.attempt += 1 - if ctx.state.attempt > 2: - return End(None) - else: - ctx.state.skip_events.append(self.input_data.title) - return ExtractEvent() - - -graph = Graph[State, EventDetails | None]( - nodes=( - ExtractEvent, - CleanEvent, - InspectEvent, - ) -) -print(graph.mermaid_code(start_node=ExtractEvent)) - -email = """ -Hi Samuel, - -I hope this message finds you well! I wanted to share a quick update on our recent and upcoming team events. - -Firstly, a big thank you to everyone who participated in last month's -Team Building Retreat held on November 15th 2024 for 1 day. -It was a fantastic opportunity to enhance our collaboration and communication skills while having fun. Your -feedback was incredibly positive, and we're already planning to make the next retreat even better! - -Looking ahead, I'm excited to invite you all to our Annual Year-End Gala on January 20th 2025. -This event will be held at the Grand City Ballroom starting at 6 PM until 8pm. It promises to be an evening full -of entertainment, good food, and great company, celebrating the achievements and hard work of our amazing team -over the past year. - -Please mark your calendars and RSVP by January 10th. I hope to see all of you there! 
- -Best regards, -""" - - -async def main(): - state = State(email_content=email) - history = [] - result = await graph.run(state, ExtractEvent()) - debug(result, history) - - -if __name__ == '__main__': - asyncio.run(main()) diff --git a/pydantic_graph/pydantic_graph/__init__.py b/pydantic_graph/pydantic_graph/__init__.py index 5102117738..135f7529a5 100644 --- a/pydantic_graph/pydantic_graph/__init__.py +++ b/pydantic_graph/pydantic_graph/__init__.py @@ -1,7 +1,7 @@ from .exceptions import GraphRuntimeError, GraphSetupError from .graph import Graph from .nodes import BaseNode, Edge, End, GraphContext -from .state import AbstractState, EndStep, HistoryStep, NodeStep +from .state import EndStep, HistoryStep, NodeStep __all__ = ( 'Graph', @@ -9,7 +9,6 @@ 'End', 'GraphContext', 'Edge', - 'AbstractState', 'EndStep', 'HistoryStep', 'NodeStep', diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index eed901fe95..db586f9525 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -6,7 +6,7 @@ from pathlib import Path from time import perf_counter from types import FrameType -from typing import TYPE_CHECKING, Any, Generic +from typing import TYPE_CHECKING, Any, Callable, Generic import logfire_api from typing_extensions import Literal, Never, ParamSpec, TypeVar, Unpack, assert_never @@ -14,7 +14,7 @@ from . import _utils, exceptions, mermaid from ._utils import get_parent_namespace from .nodes import BaseNode, End, GraphContext, NodeDef -from .state import EndStep, HistoryStep, NodeStep, StateT +from .state import EndStep, HistoryStep, NodeStep, StateT, deep_copy_state __all__ = ('Graph',) @@ -27,16 +27,74 @@ @dataclass(init=False) class Graph(Generic[StateT, RunEndT]): - """Definition of a graph.""" + """Definition of a graph. + + In `pydantic-graph`, a graph is a collection of nodes that can be run in sequence. The nodes define + their outgoing edges — e.g. which nodes may be run next, and thereby the structure of the graph. + + Here's a very simple example of a graph which increments a number by 1, but makes sure the number is never + 42 at the end. Note in this example we don't run the graph, but instead just generate a mermaid diagram + using [`mermaid_code`][pydantic_graph.graph.Graph.mermaid_code]: + + ```py {title="never_42.py"} + from __future__ import annotations + + from dataclasses import dataclass + from pydantic_graph import BaseNode, End, Graph, GraphContext + + + @dataclass + class MyState: + number: int + + + @dataclass + class Increment(BaseNode[MyState]): + async def run(self, ctx: GraphContext) -> Check42: + ctx.state.number += 1 + return Check42() + + + @dataclass + class Check42(BaseNode[MyState]): + async def run(self, ctx: GraphContext) -> Increment | End: + if ctx.state.number == 42: + return Increment() + else: + return End(None) + + + never_42_graph = Graph(nodes=(Increment, Check42)) + print(never_42_graph.mermaid_code(start_node=Increment)) + ``` + _(This example is complete, it can be run "as is")_ + + The rendered mermaid diagram will look like this: + + ```mermaid + --- + title: never_42_graph + --- + stateDiagram-v2 + [*] --> Increment + Increment --> Check42 + Check42 --> Increment + Check42 --> [*] + ``` + + See [`run`][pydantic_graph.graph.Graph.run] For an example of running graph. 
+ """ name: str | None node_defs: dict[str, NodeDef[StateT, RunEndT]] + snapshot_state: Callable[[StateT], StateT] def __init__( self, *, nodes: Sequence[type[BaseNode[StateT, RunEndT]]], name: str | None = None, + snapshot_state: Callable[[StateT], StateT] = deep_copy_state, ): """Create a graph from a sequence of nodes. @@ -45,8 +103,12 @@ def __init__( state type. name: Optional name for the graph, if not provided the name will be inferred from the calling frame on the first call to a graph method. + snapshot_state: A function to snapshot the state of the graph, this is used in + [`NodeStep`][pydantic_graph.state.NodeStep] and [`EndStep`][pydantic_graph.state.EndStep] to record + the state before each step. """ self.name = name + self.snapshot_state = snapshot_state parent_namespace = get_parent_namespace(inspect.currentframe()) self.node_defs: dict[str, NodeDef[StateT, RunEndT]] = {} @@ -80,7 +142,7 @@ async def next( if node_id not in self.node_defs: raise exceptions.GraphRuntimeError(f'Node `{node}` is not in the graph.') - history_step: NodeStep[StateT, RunEndT] | None = NodeStep(state, node) + history_step: NodeStep[StateT, RunEndT] = NodeStep(state, node) history.append(history_step) ctx = GraphContext(state) @@ -101,10 +163,28 @@ async def run( Args: state: The initial state of the graph. - start_node: the first node to run. + start_node: the first node to run, since the graph definition doesn't define the entry point in the graph, + you need to provide the starting node. infer_name: Whether to infer the graph name from the calling frame. Returns: The result type from ending the run and the history of the run. + + Here's an example of running the graph from [above][pydantic_graph.graph.Graph]: + + ```py {title="run_never_42.py"} + from never_42 import MyState, Increment, never_42_graph + + async def main(): + state = MyState(1) + _, history = await never_42_graph.run(state, Increment()) + print(state) + print(history) + + state = MyState(41) + _, history = await never_42_graph.run(state, Increment()) + print(state) + print(history) + ``` """ history: list[HistoryStep[StateT, RunEndT]] = [] if infer_name and self.name is None: @@ -118,7 +198,7 @@ async def run( while True: next_node = await self.next(state, start_node, history, infer_name=False) if isinstance(next_node, End): - history.append(EndStep(state, next_node)) + history.append(EndStep(state, next_node, snapshot_state=self.snapshot_state)) run_span.set_attribute('history', history) return next_node.data, history elif isinstance(next_node, BaseNode): diff --git a/pydantic_graph/pydantic_graph/state.py b/pydantic_graph/pydantic_graph/state.py index c01e23f1b4..9cd197a637 100644 --- a/pydantic_graph/pydantic_graph/state.py +++ b/pydantic_graph/pydantic_graph/state.py @@ -1,16 +1,15 @@ from __future__ import annotations as _annotations import copy -from abc import ABC, abstractmethod -from dataclasses import dataclass, field +from dataclasses import InitVar, dataclass, field from datetime import datetime -from typing import TYPE_CHECKING, Generic, Literal, Union +from typing import TYPE_CHECKING, Callable, Generic, Literal, Union -from typing_extensions import Self, TypeVar +from typing_extensions import TypeVar from . 
import _utils -__all__ = 'AbstractState', 'StateT', 'NodeStep', 'EndStep', 'HistoryStep' +__all__ = 'StateT', 'NodeStep', 'EndStep', 'HistoryStep', 'deep_copy_state' if TYPE_CHECKING: from .nodes import BaseNode, End, RunEndT @@ -18,21 +17,16 @@ RunEndT = TypeVar('RunEndT', default=None) -class AbstractState(ABC): - """Abstract class for a state object.""" - - @abstractmethod - def serialize(self) -> bytes | None: - """Serialize the state object.""" - raise NotImplementedError - - def deep_copy(self) -> Self: - """Create a deep copy of the state object.""" - return copy.deepcopy(self) +StateT = TypeVar('StateT', default=None) +"""Type variable for the state in a graph.""" -StateT = TypeVar('StateT', bound=Union[None, AbstractState], default=None) -"""Type variable for the state in a graph.""" +def deep_copy_state(state: StateT) -> StateT: + """Default method for snapshotting the state in a graph run, uses [`copy.deepcopy`][copy.deepcopy].""" + if state is None: + return state + else: + return copy.deepcopy(state) @dataclass @@ -40,7 +34,7 @@ class NodeStep(Generic[StateT, RunEndT]): """History step describing the execution of a node in a graph.""" state: StateT - """The state of the graph after the node has run.""" + """The state of the graph before the node is run.""" node: BaseNode[StateT, RunEndT] """The node that was run.""" start_ts: datetime = field(default_factory=_utils.now_utc) @@ -49,10 +43,12 @@ class NodeStep(Generic[StateT, RunEndT]): """The duration of the node run in seconds.""" kind: Literal['node'] = 'node' """The kind of history step, can be used as a discriminator when deserializing history.""" + snapshot_state: InitVar[Callable[[StateT], StateT]] = deep_copy_state + """Function to snapshot the state of the graph.""" - def __post_init__(self): + def __post_init__(self, snapshot_state: Callable[[StateT], StateT]): # Copy the state to prevent it from being modified by other code - self.state = _deep_copy_state(self.state) + self.state = snapshot_state(self.state) def summary(self) -> str: return str(self.node) @@ -70,21 +66,16 @@ class EndStep(Generic[StateT, RunEndT]): """The timestamp when the graph run ended.""" kind: Literal['end'] = 'end' """The kind of history step, can be used as a discriminator when deserializing history.""" + snapshot_state: InitVar[Callable[[StateT], StateT]] = deep_copy_state + """Function to snapshot the state of the graph.""" - def __post_init__(self): + def __post_init__(self, snapshot_state: Callable[[StateT], StateT]): # Copy the state to prevent it from being modified by other code - self.state = _deep_copy_state(self.state) + self.state = snapshot_state(self.state) def summary(self) -> str: return str(self.result) -def _deep_copy_state(state: StateT) -> StateT: - if state is None: - return state - else: - return state.deep_copy() - - HistoryStep = Union[NodeStep[StateT, RunEndT], EndStep[StateT, RunEndT]] """A step in the history of a graph run.""" diff --git a/tests/graph/test_main.py b/tests/graph/test_graph.py similarity index 100% rename from tests/graph/test_main.py rename to tests/graph/test_graph.py diff --git a/tests/graph/test_state.py b/tests/graph/test_state.py new file mode 100644 index 0000000000..b363e37674 --- /dev/null +++ b/tests/graph/test_state.py @@ -0,0 +1,58 @@ +from __future__ import annotations as _annotations + +from dataclasses import dataclass +from datetime import timezone + +import pytest +from inline_snapshot import snapshot + +from pydantic_graph import BaseNode, End, EndStep, Graph, GraphContext, NodeStep + 
+from ..conftest import IsFloat, IsNow + +pytestmark = pytest.mark.anyio + + +async def test_run_graph(): + @dataclass + class MyState: + x: int + y: str + + @dataclass + class Foo(BaseNode[MyState]): + async def run(self, ctx: GraphContext[MyState]) -> Bar: + ctx.state.x += 1 + return Bar() + + @dataclass + class Bar(BaseNode[MyState, str]): + async def run(self, ctx: GraphContext[MyState]) -> End[str]: + ctx.state.y += 'y' + return End(f'x={ctx.state.x} y={ctx.state.y}') + + graph = Graph(nodes=(Foo, Bar)) + s = MyState(1, '') + result, history = await graph.run(s, Foo()) + assert result == snapshot('x=2 y=y') + assert history == snapshot( + [ + NodeStep( + state=MyState(x=1, y=''), + node=Foo(), + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(), + ), + NodeStep( + state=MyState(x=2, y=''), + node=Bar(), + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(), + ), + EndStep( + state=MyState(x=2, y='y'), + result=End('x=2 y=y'), + ts=IsNow(tz=timezone.utc), + ), + ] + ) diff --git a/tests/test_examples.py b/tests/test_examples.py index 9f29699990..2035b9b887 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -48,7 +48,7 @@ def find_filter_examples() -> Iterable[CodeExample]: - for ex in find_examples('docs', 'pydantic_ai_slim'): + for ex in find_examples('docs', 'pydantic_ai_slim', 'pydantic_graph'): if ex.path.name != '_utils.py': yield ex From b4d6c1cdc7ee215ad3ea62234eb6bebde483d1b6 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sat, 11 Jan 2025 15:13:29 +0000 Subject: [PATCH 38/57] fix graph api examples --- pydantic_graph/pydantic_graph/graph.py | 125 ++++++++++++++++++++----- tests/test_examples.py | 1 + 2 files changed, 103 insertions(+), 23 deletions(-) diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index db586f9525..d2ff4a155f 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -33,28 +33,25 @@ class Graph(Generic[StateT, RunEndT]): their outgoing edges — e.g. which nodes may be run next, and thereby the structure of the graph. Here's a very simple example of a graph which increments a number by 1, but makes sure the number is never - 42 at the end. Note in this example we don't run the graph, but instead just generate a mermaid diagram - using [`mermaid_code`][pydantic_graph.graph.Graph.mermaid_code]: + 42 at the end. ```py {title="never_42.py"} from __future__ import annotations from dataclasses import dataclass - from pydantic_graph import BaseNode, End, Graph, GraphContext + from pydantic_graph import BaseNode, End, Graph, GraphContext @dataclass class MyState: number: int - @dataclass class Increment(BaseNode[MyState]): async def run(self, ctx: GraphContext) -> Check42: ctx.state.number += 1 return Check42() - @dataclass class Check42(BaseNode[MyState]): async def run(self, ctx: GraphContext) -> Increment | End: @@ -63,26 +60,13 @@ async def run(self, ctx: GraphContext) -> Increment | End: else: return End(None) - never_42_graph = Graph(nodes=(Increment, Check42)) - print(never_42_graph.mermaid_code(start_node=Increment)) ``` _(This example is complete, it can be run "as is")_ - The rendered mermaid diagram will look like this: - - ```mermaid - --- - title: never_42_graph - --- - stateDiagram-v2 - [*] --> Increment - Increment --> Check42 - Check42 --> Increment - Check42 --> [*] - ``` - - See [`run`][pydantic_graph.graph.Graph.run] For an example of running graph. 
+ See [`run`][pydantic_graph.graph.Graph.run] For an example of running graph, and + [`mermaid_code`][pydantic_graph.graph.Graph.mermaid_code] for an example of generating a mermaid diagram + from the graph. """ name: str | None @@ -172,18 +156,82 @@ async def run( Here's an example of running the graph from [above][pydantic_graph.graph.Graph]: ```py {title="run_never_42.py"} - from never_42 import MyState, Increment, never_42_graph + from never_42 import Increment, MyState, never_42_graph async def main(): state = MyState(1) _, history = await never_42_graph.run(state, Increment()) print(state) + #> MyState(number=2) print(history) + ''' + [ + NodeStep( + state=MyState(number=1), + node=Increment(), + start_ts=datetime.datetime(...), + duration=0.0..., + kind='node', + ), + NodeStep( + state=MyState(number=2), + node=Check42(), + start_ts=datetime.datetime(...), + duration=0.0..., + kind='node', + ), + EndStep( + state=MyState(number=2), + result=End(data=None), + ts=datetime.datetime(...), + kind='end', + ), + ] + ''' state = MyState(41) _, history = await never_42_graph.run(state, Increment()) print(state) + #> MyState(number=43) print(history) + ''' + [ + NodeStep( + state=MyState(number=41), + node=Increment(), + start_ts=datetime.datetime(...), + duration=0.0..., + kind='node', + ), + NodeStep( + state=MyState(number=42), + node=Check42(), + start_ts=datetime.datetime(...), + duration=0.0..., + kind='node', + ), + NodeStep( + state=MyState(number=42), + node=Increment(), + start_ts=datetime.datetime(...), + duration=0.0..., + kind='node', + ), + NodeStep( + state=MyState(number=43), + node=Check42(), + start_ts=datetime.datetime(...), + duration=0.0..., + kind='node', + ), + EndStep( + state=MyState(number=43), + result=End(data=None), + ts=datetime.datetime(...), + kind='end', + ), + ] + ''' ``` """ history: list[HistoryStep[StateT, RunEndT]] = [] @@ -227,7 +275,7 @@ def mermaid_code( This method calls [`pydantic_graph.mermaid.generate_code`][pydantic_graph.mermaid.generate_code]. Args: - start_node: The node or nodes to start the graph from. + start_node: The node or nodes which can start the graph. title: The title of the diagram, use `False` to not include a title. edge_labels: Whether to include edge labels. notes: Whether to include notes on each node. @@ -237,6 +285,37 @@ def mermaid_code( Returns: The mermaid code for the graph, which can then be rendered as a diagram. 
+ + Here's an example of generating a diagram for the graph from [above][pydantic_graph.graph.Graph]: + + ```py {title="never_42.py"} + from never_42 import Increment, never_42_graph + + print(never_42_graph.mermaid_code(start_node=Increment)) + ''' + --- + title: never_42_graph + --- + stateDiagram-v2 + [*] --> Increment + Increment --> Check42 + Check42 --> Increment + Check42 --> [*] + ''' + ``` + + The rendered diagram will look like this: + + ```mermaid + --- + title: never_42_graph + --- + stateDiagram-v2 + [*] --> Increment + Increment --> Check42 + Check42 --> Increment + Check42 --> [*] + ``` """ if infer_name and self.name is None: self._infer_name(inspect.currentframe()) diff --git a/tests/test_examples.py b/tests/test_examples.py index 2035b9b887..e6897d727f 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -133,6 +133,7 @@ def test_docs_examples( def print_callback(s: str) -> str: s = re.sub(r'datetime\.datetime\(.+?\)', 'datetime.datetime(...)', s, flags=re.DOTALL) + s = re.sub(r'\d\.\d{4,}e-0\d', '0.0...', s) return re.sub(r'datetime.date\(', 'date(', s) From 22708d99780c979cd731d98cac7c7503e6f1bbf8 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sat, 11 Jan 2025 16:36:52 +0000 Subject: [PATCH 39/57] starting graph documentation --- docs/examples/question-graph.md | 37 +++++ docs/graph.md | 45 ++++++ docs/install.md | 10 ++ docs/multi-agent-applications.md | 7 +- .../pydantic_ai_examples/question_graph.py | 149 ++++++++++++++++++ mkdocs.yml | 2 + pydantic_ai_slim/pyproject.toml | 5 +- pydantic_graph/README.md | 2 +- 8 files changed, 248 insertions(+), 9 deletions(-) create mode 100644 docs/examples/question-graph.md create mode 100644 docs/graph.md create mode 100644 examples/pydantic_ai_examples/question_graph.py diff --git a/docs/examples/question-graph.md b/docs/examples/question-graph.md new file mode 100644 index 0000000000..e208135286 --- /dev/null +++ b/docs/examples/question-graph.md @@ -0,0 +1,37 @@ +# Question Graph + +Example of a graph for asking and evaluating questions. + +Demonstrates: + +* [`pydantic_graph`](../graph.md) + +## Running the Example + +With [dependencies installed and environment variables set](./index.md#usage), run: + +```bash +python/uv-run -m pydantic_ai_examples.question_graph +``` + +## Example Code + +```python {title="question_graph.py"} +#! examples/pydantic_ai_examples/question_graph.py +``` + +The mermaid diagram generated in this example looks like this: + +```mermaid +--- +title: question_graph +--- +stateDiagram-v2 + [*] --> Ask + Ask --> Answer: ask the question + Answer --> Evaluate: answer the question + Evaluate --> Congratulate + Evaluate --> Castigate + Congratulate --> [*]: success + Castigate --> Ask: try again +``` diff --git a/docs/graph.md b/docs/graph.md new file mode 100644 index 0000000000..e1e01bc421 --- /dev/null +++ b/docs/graph.md @@ -0,0 +1,45 @@ +# Graphs + +!!! danger "Don't use a nail gun unless you need a nail gun" + If PydanticAI [agents](agents.md) are a hammer, and [multi-agent workflows](multi-agent-applications.md) are a sledgehammer, then graphs are a nail gun, with flames down the side: + + * sure, nail guns look cooler than hammers + * but nail guns take a lot more setup than hammers + * and nail guns don't make you a better builder, they make you a builder with a nail gun + * Lastly, (and at the risk of torturing this metaphor), if you're a fan of medieval tools like mallets and untyped Python, you probably won't like nail guns or PydanticAI approach to graphs. 
(But then again, if you're not a fan of type hints in Python, you've probably already bounced off PydanticAI to use one of the toy agent frameworks — good luck, and feel free to borrow my sledgehammer) + + In short, graphs are a powerful tool, but they're not the right tool for every job. Please consider other [multi-agent approaches](multi-agent-applications.md) before proceeding. Unless you're sure you need a graph, you probably don't. + +Graphs and associated finite state machines (FSMs) are a powerful abstraction to model, control and visualize complex workflows. + +Alongside PydanticAI, we've developed `pydantic-graph` — an async graph and state machine library for Python where nodes and edges are defined in pure Python using type hints. + +While this library is developed as part of the PydanticAI; it has no dependency on `pydantic-ai` and can be considered as a pure graph library. You may find it useful whether or not you're using PydanticAI or even building with GenAI. + +## Installation + +`pydantic-graph` a required dependency of `pydantic-ai`, and an optional dependency of `pydantic-ai-slim`, see [installation instructions](install.md) for more information. You can also install it directly: + +```bash +pip/uv-add pydantic-graph +``` + +## Basic Usage + +TODO + +## Typing + +TODO + +## Running Graphs + +TODO + +## State Machines + +TODO + +## Mermaid Diagrams + +TODO diff --git a/docs/install.md b/docs/install.md index 1cd336611e..c450d8d885 100644 --- a/docs/install.md +++ b/docs/install.md @@ -44,6 +44,16 @@ For example, if you're using just [`OpenAIModel`][pydantic_ai.models.openai.Open pip/uv-add 'pydantic-ai-slim[openai]' ``` +`pydantic-ai-slim` has the following optional groups: + +* `logfire` — installs [`logfire`](logfire.md) [PyPI ↗](https://pypi.org/project/logfire){:target="_blank"} +* `graph` - installs [`pydantic-graph`](graph.md) [PyPI ↗](https://pypi.org/project/pydantic-graph){:target="_blank"} +* `openai` — installs `openai` [PyPI ↗](https://pypi.org/project/openai){:target="_blank"} +* `vertexai` — installs `google-auth` [PyPI ↗](https://pypi.org/project/google-auth){:target="_blank"} and `requests` [PyPI ↗](https://pypi.org/project/requests){:target="_blank"} +* `anthropic` — installs `anthropic` [PyPI ↗](https://pypi.org/project/anthropic){:target="_blank"} +* `groq` — installs `groq` [PyPI ↗](https://pypi.org/project/groq){:target="_blank"} +* `mistral` — installs `mistralai` [PyPI ↗](https://pypi.org/project/mistralai){:target="_blank"} + See the [models](models.md) documentation for information on which optional dependencies are required for each model. You can also install dependencies for multiple models and use cases, for example: diff --git a/docs/multi-agent-applications.md b/docs/multi-agent-applications.md index 6556fdbf5f..e5ced781e0 100644 --- a/docs/multi-agent-applications.md +++ b/docs/multi-agent-applications.md @@ -5,7 +5,7 @@ There are roughly four levels of complexity when building applications with Pyda 1. Single agent workflows — what most of the `pydantic_ai` documentation covers 2. [Agent delegation](#agent-delegation) — agents using another agent via tools 3. [Programmatic agent hand-off](#programmatic-agent-hand-off) — one agent runs, then application code calls another agent -4. [Graph based control flow](#pydanticai-graphs) — for the most complex cases, a graph-based state machine can be used to control the execution of multiple agents +4. 
[Graph based control flow](graph.md) — for the most complex cases, a graph-based state machine can be used to control the execution of multiple agents Of course, you can combine multiple strategies in a single application. @@ -330,11 +330,6 @@ graph TB seat_preference_agent --> END ``` -## PydanticAI Graphs - -!!! example "Work in progress" - This is a work in progress and not yet documented, see [#528](https://github.com/pydantic/pydantic-ai/issues/528) and [#539](https://github.com/pydantic/pydantic-ai/issues/539) - ## Examples The following examples demonstrate how to use dependencies in PydanticAI: diff --git a/examples/pydantic_ai_examples/question_graph.py b/examples/pydantic_ai_examples/question_graph.py new file mode 100644 index 0000000000..5f2c42bca9 --- /dev/null +++ b/examples/pydantic_ai_examples/question_graph.py @@ -0,0 +1,149 @@ +"""Example of a graph for asking and evaluating questions. + +Run with: + + uv run -m pydantic_ai_examples.question_graph +""" + +from __future__ import annotations as _annotations + +from dataclasses import dataclass, field +from typing import Annotated + +import logfire +from pydantic_graph import BaseNode, Edge, End, Graph, GraphContext, HistoryStep + +from pydantic_ai import Agent +from pydantic_ai.format_as_xml import format_as_xml +from pydantic_ai.messages import ModelMessage + +# 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured +logfire.configure(send_to_logfire='if-token-present') + +ask_agent = Agent('openai:gpt-4o', result_type=str) + + +@dataclass +class QuestionState: + ask_agent_messages: list[ModelMessage] = field(default_factory=list) + evaluate_agent_messages: list[ModelMessage] = field(default_factory=list) + + +@dataclass +class Ask(BaseNode[QuestionState]): + """Generate a question to ask the user. + + Uses the GPT-4o model to generate the question. + """ + + async def run( + self, ctx: GraphContext[QuestionState] + ) -> Annotated[Answer, Edge(label='ask the question')]: + result = await ask_agent.run( + 'Ask a simple question with a single correct answer.', + message_history=ctx.state.ask_agent_messages, + ) + ctx.state.ask_agent_messages += result.all_messages() + return Answer(result.data) + + +@dataclass +class Answer(BaseNode[QuestionState]): + """Get the answer to the question from the user. + + This node must be completed outside the graph run. 
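+
+    The caller sets `answer` on this node (e.g. from user input, as in `main` below) before passing the node
+    back to `question_graph.next`, which then runs it.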
+ """ + + question: str + answer: str | None = None + + async def run( + self, ctx: GraphContext[QuestionState] + ) -> Annotated[Evaluate, Edge(label='answer the question')]: + assert self.answer is not None + return Evaluate(self.question, self.answer) + + +@dataclass +class EvaluationResult: + correct: bool + comment: str + + +evaluate_agent = Agent( + 'openai:gpt-4o', + result_type=EvaluationResult, + system_prompt='Given a question and answer, evaluate if the answer is correct.', + result_tool_name='evaluation', +) + + +@dataclass +class Evaluate(BaseNode[QuestionState]): + question: str + answer: str + + async def run( + self, + ctx: GraphContext[QuestionState], + ) -> Congratulate | Castigate: + result = await evaluate_agent.run( + format_as_xml({'question': self.question, 'answer': self.answer}), + message_history=ctx.state.evaluate_agent_messages, + ) + ctx.state.evaluate_agent_messages += result.all_messages() + if result.data.correct: + return Congratulate(result.data.comment) + else: + return Castigate(result.data.comment) + + +@dataclass +class Congratulate(BaseNode[QuestionState, None]): + """Congratulate the user and end.""" + + comment: str + + async def run( + self, ctx: GraphContext[QuestionState] + ) -> Annotated[End, Edge(label='success')]: + print(f'Correct answer! {self.comment}') + return End(None) + + +@dataclass +class Castigate(BaseNode[QuestionState]): + """Castigate the user, then ask another question.""" + + comment: str + + async def run( + self, ctx: GraphContext[QuestionState] + ) -> Annotated[Ask, Edge(label='try again')]: + print(f'Comment: {self.comment}') + return Ask() + + +question_graph = Graph(nodes=(Ask, Answer, Evaluate, Congratulate, Castigate)) +print(question_graph.mermaid_code(start_node=Ask, notes=False)) + + +async def main(): + state = QuestionState() + node = Ask() + history: list[HistoryStep[QuestionState]] = [] + with logfire.span('run questions graph'): + while True: + node = await question_graph.next(state, node, history) + if isinstance(node, End): + print('\n'.join(e.summary() for e in history)) + break + elif isinstance(node, Answer): + node.answer = input(f'{node.question} ') + # otherwise just continue + + +if __name__ == '__main__': + import asyncio + + asyncio.run(main()) diff --git a/mkdocs.yml b/mkdocs.yml index 0b13b451c5..81faa18225 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -25,6 +25,7 @@ nav: - testing-evals.md - logfire.md - multi-agent-applications.md + - graph.md - Examples: - examples/index.md - examples/pydantic-model.md @@ -36,6 +37,7 @@ nav: - examples/stream-markdown.md - examples/stream-whales.md - examples/chat-app.md + - examples/question-graph.md - API Reference: - api/agent.md - api/tools.md diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index 912ad47479..8cc686ce53 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -42,13 +42,14 @@ dependencies = [ ] [project.optional-dependencies] -graph = ["pydantic-graph==0.0.14"] +# WARNING if you add optional groups, please update docs/install.md +logfire = ["logfire>=2.3"] +graph = ["pydantic-graph==0.0.18"] openai = ["openai>=1.54.3"] vertexai = ["google-auth>=2.36.0", "requests>=2.32.3"] anthropic = ["anthropic>=0.40.0"] groq = ["groq>=0.12.0"] mistral = ["mistralai>=1.2.5"] -logfire = ["logfire>=2.3"] [dependency-groups] dev = [ diff --git a/pydantic_graph/README.md b/pydantic_graph/README.md index 2789d18ac1..069f58924d 100644 --- a/pydantic_graph/README.md +++ b/pydantic_graph/README.md @@ -9,7 +9,7 
@@ Graph and finite state machine library. This library is developed as part of the [PydanticAI](https://ai.pydantic.dev), however it has no dependency -on `pydantic-ai` or related packages and does and can be considered as a pure graph library. +on `pydantic-ai` or related packages and can be considered as a pure graph library. As with PydanticAI, this library prioritizes type safety and use of common Python syntax over esoteric, domain-specific use of Python syntax. From d1af5619c68c589cde69974e5695dc82dec88ed7 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sat, 11 Jan 2025 16:45:54 +0000 Subject: [PATCH 40/57] fix examples --- pydantic_graph/pydantic_graph/graph.py | 4 ++-- tests/test_examples.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index d2ff4a155f..789ba4a546 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -35,7 +35,7 @@ class Graph(Generic[StateT, RunEndT]): Here's a very simple example of a graph which increments a number by 1, but makes sure the number is never 42 at the end. - ```py {title="never_42.py"} + ```py {title="never_42.py" lint="not-imports"} from __future__ import annotations from dataclasses import dataclass @@ -155,7 +155,7 @@ async def run( Here's an example of running the graph from [above][pydantic_graph.graph.Graph]: - ```py {title="run_never_42.py"} + ```py {title="run_never_42.py" lint="not-imports"} from never_42 import Increment, MyState, never_42_graph async def main(): diff --git a/tests/test_examples.py b/tests/test_examples.py index e6897d727f..ce43b69e34 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -95,7 +95,7 @@ def test_docs_examples( with (tmp_path / 'examples.json').open('w') as f: json.dump(examples, f) - ruff_ignore: list[str] = ['D'] + ruff_ignore: list[str] = ['D', 'Q001'] # `from bank_database import DatabaseConn` wrongly sorted in imports # waiting for https://github.com/pydantic/pytest-examples/issues/43 # and https://github.com/pydantic/pytest-examples/issues/46 From 9d5f45c3d6012a4f553491f65484c5f25ed7abf9 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sat, 11 Jan 2025 21:49:47 +0000 Subject: [PATCH 41/57] more graph documentation --- docs/.hooks/main.py | 17 +++ docs/graph.md | 197 ++++++++++++++++++++++++- pydantic_graph/pydantic_graph/graph.py | 10 +- pydantic_graph/pydantic_graph/state.py | 6 +- tests/test_examples.py | 9 +- 5 files changed, 221 insertions(+), 18 deletions(-) diff --git a/docs/.hooks/main.py b/docs/.hooks/main.py index 6b339ad6a8..606e8cb616 100644 --- a/docs/.hooks/main.py +++ b/docs/.hooks/main.py @@ -19,11 +19,28 @@ def on_page_markdown(markdown: str, page: Page, config: Config, files: Files) -> return markdown +# path to the main mkdocs material bundle file, found during `on_env` +bundle_path: Path | None = None + + def on_env(env: Environment, config: Config, files: Files) -> Environment: + global bundle_path + for file in files: + if re.match('assets/javascripts/bundle.[a-z0-9]+.min.js', file.src_uri): + bundle_path = Path(file.dest_dir) / file.src_uri + env.globals['build_timestamp'] = str(int(time.time())) return env +def on_post_build(config: Config) -> None: + """Inject extra CSS into mermaid styles to avoid titles being the same color as the background in dark mode.""" + if bundle_path.exists(): + content = bundle_path.read_text() + content, _ = re.subn(r'}(\.statediagram)', '}.statediagramTitleText{fill:#888}\1', 
content, count=1) + bundle_path.write_text(content) + + def replace_uv_python_run(markdown: str) -> str: return re.sub(r'```bash\n(.*?)(python/uv[\- ]run|pip/uv[\- ]add|py-cli)(.+?)\n```', sub_run, markdown) diff --git a/docs/graph.md b/docs/graph.md index e1e01bc421..0765d5e265 100644 --- a/docs/graph.md +++ b/docs/graph.md @@ -6,29 +6,210 @@ * sure, nail guns look cooler than hammers * but nail guns take a lot more setup than hammers * and nail guns don't make you a better builder, they make you a builder with a nail gun - * Lastly, (and at the risk of torturing this metaphor), if you're a fan of medieval tools like mallets and untyped Python, you probably won't like nail guns or PydanticAI approach to graphs. (But then again, if you're not a fan of type hints in Python, you've probably already bounced off PydanticAI to use one of the toy agent frameworks — good luck, and feel free to borrow my sledgehammer) + * Lastly, (and at the risk of torturing this metaphor), if you're a fan of medieval tools like mallets and untyped Python, you probably won't like nail guns or our approach to graphs. (But then again, if you're not a fan of type hints in Python, you've probably already bounced off PydanticAI to use one of the toy agent frameworks — good luck, and feel free to borrow my sledgehammer when you realize you need it) - In short, graphs are a powerful tool, but they're not the right tool for every job. Please consider other [multi-agent approaches](multi-agent-applications.md) before proceeding. Unless you're sure you need a graph, you probably don't. + In short, graphs are a powerful tool, but they're not the right tool for every job. Please consider other [multi-agent approaches](multi-agent-applications.md) before proceeding. -Graphs and associated finite state machines (FSMs) are a powerful abstraction to model, control and visualize complex workflows. + Unless you're sure you need a graph, you probably don't. -Alongside PydanticAI, we've developed `pydantic-graph` — an async graph and state machine library for Python where nodes and edges are defined in pure Python using type hints. +Graphs and finite state machines (FSMs) are a powerful abstraction to model, control and visualize complex workflows. + +Alongside PydanticAI, we've developed `pydantic-graph` — an async graph and state machine library for Python where nodes and edges are defined using type hints. While this library is developed as part of the PydanticAI; it has no dependency on `pydantic-ai` and can be considered as a pure graph library. You may find it useful whether or not you're using PydanticAI or even building with GenAI. +`pydantic-graph` is designed for advanced users and makes heavy use of Python generics and types hints. It is not designed to b as beginner-friendly as PydanticAI. + ## Installation -`pydantic-graph` a required dependency of `pydantic-ai`, and an optional dependency of `pydantic-ai-slim`, see [installation instructions](install.md) for more information. You can also install it directly: +`pydantic-graph` is a required dependency of `pydantic-ai`, and an optional dependency of `pydantic-ai-slim`, see [installation instructions](install.md#slim-install) for more information. You can also install it directly: ```bash pip/uv-add pydantic-graph ``` +## Graph Types + +!!! note "Every Early beta" + Graph support was [introduced](https://github.com/pydantic/pydantic-ai/pull/528) in v0.0.19 and is in very earlier beta. The API is subject to change. The documentation is incomplete. 
The implementation is incomplete. + +Graphs are made up of a few key components: + +### Nodes + +Subclasses of [`BaseNode`][pydantic_graph.nodes.BaseNode] to define nodes. + +Nodes which are generally [`dataclass`es][dataclasses.dataclass] include: + +* any parameters required when calling the node +* the business logic to execute the node +* return annotations which are read by `pydantic-graph` to determine the outgoing edges of the node + +Nodes are generic in: + +* **state** which must have the same type as the state of graphs they're included in, [`StateT`][pydantic_graph.state.StateT] has a default of `None`, so if you're not using state you can omit this generic parameter +* **graph return type** this only applies if the node returns [`End`][pydantic_graph.nodes.End], [`RunEndT`][pydantic_graph.nodes.RunEndT] has a default of [Never][typing.Never] so this generic parameter can be omitted if the node doesn't return `End`, but must be included if it does. + +### GraphContext + +[`GraphContext`][pydantic_graph.nodes.GraphContext] — The context for the graph run, similar to PydanticAI's [`RunContext`][pydantic_ai.tools.RunContext], this holds the state of the graph and is passed to nodes when they're run. + +`GraphContext` is generic in the state type of the graph it's used in, [`StateT`][pydantic_graph.state.StateT]. + +### End + +[`End`][pydantic_graph.nodes.End] — return value to indicates the graph run should end. + +`End` is generic in the graph return type of the graph it's used in, [`RunEndT`][pydantic_graph.nodes.RunEndT]. + +### Graph + +[`Graph`][pydantic_graph.graph.Graph] — The graph itself, made up of a set of nodes. + +`Graph` is generic in: + +* **state** the state type of the graph, [`StateT`][pydantic_graph.state.StateT] +* **graph return type** the return type of the graph run, [`RunEndT`][pydantic_graph.nodes.RunEndT] + ## Basic Usage -TODO +Here's an example of a graph which represents a vending machine where the user may insert coins and select a product to purchase. + +```python {title="vending_machine.py"} +from __future__ import annotations + +from dataclasses import dataclass + +from rich.prompt import Prompt + +from pydantic_graph import BaseNode, End, Graph, GraphContext + + +@dataclass +class MachineState: # (1)! + user_balance: float = 0.0 + product: str | None = None + + +@dataclass +class InsertCoin(BaseNode[MachineState]): # (3)! + async def run(self, ctx: GraphContext[MachineState]) -> CoinsInserted: # (16)! + return CoinsInserted(float(Prompt.ask('Insert coins'))) # (4)! + + +@dataclass +class CoinsInserted(BaseNode[MachineState]): + amount: float # (5)! + + async def run( + self, ctx: GraphContext[MachineState] + ) -> SelectProduct | Purchase: # (17)! + ctx.state.user_balance += self.amount # (6)! + if ctx.state.product is not None: # (7)! + return Purchase(ctx.state.product) + else: + return SelectProduct() + + +@dataclass +class SelectProduct(BaseNode[MachineState]): + async def run(self, ctx: GraphContext[MachineState]) -> Purchase: + return Purchase(Prompt.ask('Select product')) + + +PRODUCT_PRICES = { # (2)! + 'water': 1.25, + 'soda': 1.50, + 'crisps': 1.75, + 'chocolate': 2.00, +} + + +@dataclass +class Purchase(BaseNode[MachineState, None]): # (18)! + product: str + + async def run( + self, ctx: GraphContext[MachineState] + ) -> End | InsertCoin | SelectProduct: + if price := PRODUCT_PRICES.get(self.product): # (8)! + ctx.state.product = self.product # (9)! + if ctx.state.user_balance >= price: # (10)! 
+ ctx.state.user_balance -= price + return End(None) + else: + diff = price - ctx.state.user_balance + print(f'Not enough money for {self.product}, need {diff:0.2f} more') + #> Not enough money for crisps, need 0.75 more + return InsertCoin() # (11)! + else: + print(f'No such product: {self.product}, try again') + return SelectProduct() # (12)! + + +vending_machine_graph = Graph( # (13)! + nodes=[InsertCoin, CoinsInserted, SelectProduct, Purchase] +) + + +async def main(): + state = MachineState() # (14)! + await vending_machine_graph.run(state, InsertCoin()) # (15)! + print(f'purchase successful item={state.product} change={state.user_balance:0.2f}') + #> purchase successful item=crisps change=0.25 +``` + +1. The state of the vending machine is defined as a dataclass with the user's balance and the product they've selected, if any. +2. A dictionary of products mapped to prices. +3. The `InsertCoin` node, [`BaseNode`][pydantic_graph.nodes.BaseNode] is parameterized with `MachineState` as that's the state used in this graph. +4. The `InsertCoin` node prompts the user to insert coins. Keep things simple by just entering a monetary amount as a float. Before you start thinking this is a toy too since it's using `input` within node, see [below](#running-graphs) for how control flow can be managed when nodes require external input. +5. The `CoinsInserted` node, again this is a [`dataclass`][dataclasses.dataclass], in this case with one field `amount`, thus nodes calling `CoinsInserted` must provide an amount. +6. Update the user's balance with the amount inserted. +7. If the user has already selected a product, go to `Purchase`, otherwise go to `SelectProduct`. +8. In the `Purchase` node, look up the price of the product if the user entered a valid product. +9. If the user did enter a valid product, set the product in the state so we don't revisit `SelectProduct`. +10. If the balance is enough to purchase the product, adjust the balance to reflect the purchase and return [`End`][pydantic_graph.nodes.End] to end the graph. We're not using the run return type, so we call `End` with `None`. +11. If the balance is insufficient, to go `InsertCoin` to prompt the user to insert more coins. +12. If the product is invalid, go to `SelectProduct` to prompt the user to select a product again. +13. The graph is created by passing a list of nodes to [`Graph`][pydantic_graph.graph.Graph]. Order of nodes is not important, but will alter how [diagramss](#mermaid-diagrams) are displayed. +14. Initialize the state, this will be passed to the graph run and mutated as the graph runs. +15. Run the graph with the initial state, since the graph can be run from any node, we must pass the start node, in this case `InsertCoin`. [`Graph.run`][pydantic_graph.graph.Graph.run] returns a tuple of the return value (`None`) in this case, and the [history][pydantic_graph.state.HistoryStep] of the graph run. +16. The return type of the node's [`run`][pydantic_graph.nodes.BaseNode.run] method is important, it's used to determine the outgoing edges of the node, this in turn is used to render [mermaid diagrams](#mermaid-diagrams) and is enforced at runtime. +17. The return type of `CoinsInserted`s [`run`][pydantic_graph.nodes.BaseNode.run] method is a union, meaning multiple outgoing edges are possible. +18. Unlike other nodes `Purchase` can end the run, so the [`RunEndT`][pydantic_graph.nodes.RunEndT] generic parameter must be set, in this case it's `None` since the graph run return type is `None`. 
+ +_(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)_ + +A [mermaid diagram](#mermaid-diagrams) for this graph can be generated with the following code: + +```py {title="vending_machine_diagram.py"} +from vending_machine import InsertCoin, vending_machine_graph + +vending_machine_graph.mermaid_code(start_node=InsertCoin) +``` + +_(This example is complete, it can be run "as is")_ + +The diagram generated by the above code is: + +```mermaid +--- +title: vending_machine_graph +--- +stateDiagram-v2 + [*] --> InsertCoin + InsertCoin --> CoinsInserted + CoinsInserted --> SelectProduct + CoinsInserted --> Purchase + SelectProduct --> Purchase + Purchase --> InsertCoin + Purchase --> SelectProduct + Purchase --> [*] +``` + +See [below](#mermaid-diagrams) for more information on generating diagrams. -## Typing +## GenAI Example TODO @@ -36,7 +217,7 @@ TODO TODO -## State Machines +## State Machine TODO diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 789ba4a546..469905da2e 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -9,21 +9,17 @@ from typing import TYPE_CHECKING, Any, Callable, Generic import logfire_api -from typing_extensions import Literal, Never, ParamSpec, TypeVar, Unpack, assert_never +from typing_extensions import Literal, Unpack, assert_never from . import _utils, exceptions, mermaid from ._utils import get_parent_namespace -from .nodes import BaseNode, End, GraphContext, NodeDef +from .nodes import BaseNode, End, GraphContext, NodeDef, RunEndT from .state import EndStep, HistoryStep, NodeStep, StateT, deep_copy_state __all__ = ('Graph',) _logfire = logfire_api.Logfire(otel_scope='pydantic-graph') -RunSignatureT = ParamSpec('RunSignatureT') -RunEndT = TypeVar('RunEndT', default=None) -NodeRunEndT = TypeVar('NodeRunEndT', covariant=True, default=Never) - @dataclass(init=False) class Graph(Generic[StateT, RunEndT]): @@ -270,7 +266,7 @@ def mermaid_code( highlight_css: str = mermaid.DEFAULT_HIGHLIGHT_CSS, infer_name: bool = True, ) -> str: - """Generate a diagram representing the graph as [mermaid](https://mermaid.js.org/) chart. + """Generate a diagram representing the graph as [mermaid](https://mermaid.js.org/) diagram. This method calls [`pydantic_graph.mermaid.generate_code`][pydantic_graph.mermaid.generate_code]. diff --git a/pydantic_graph/pydantic_graph/state.py b/pydantic_graph/pydantic_graph/state.py index 9cd197a637..7d4a977292 100644 --- a/pydantic_graph/pydantic_graph/state.py +++ b/pydantic_graph/pydantic_graph/state.py @@ -78,4 +78,8 @@ def summary(self) -> str: HistoryStep = Union[NodeStep[StateT, RunEndT], EndStep[StateT, RunEndT]] -"""A step in the history of a graph run.""" +"""A step in the history of a graph run. + +[`Graph.run`][pydantic_graph.graph.Graph.run] returns a list of these steps describing the execution of the graph, +together with the run return value. 
+""" diff --git a/tests/test_examples.py b/tests/test_examples.py index ce43b69e34..93ba1b90f9 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -150,9 +150,14 @@ async def async_http_request(url: str, **kwargs: Any) -> httpx.Response: def rich_prompt_ask(prompt: str, *_args: Any, **_kwargs: Any) -> str: if prompt == 'Where would you like to fly from and to?': return 'SFO to ANC' - else: - assert prompt == 'What seat would you like?', prompt + elif prompt == 'What seat would you like?': return 'window seat with leg room' + if prompt == 'Insert coins': + return '1' + elif prompt == 'Select product': + return 'crisps' + else: # pragma: no cover + raise ValueError(f'Unexpected prompt: {prompt}') text_responses: dict[str, str | ToolCallPart] = { From a3a0ddc56d28d06ada0ac6934dec1ba638bb5e24 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sat, 11 Jan 2025 23:15:55 +0000 Subject: [PATCH 42/57] add GenAI example --- docs/graph.md | 139 ++++++++++++++++++++++++- pydantic_graph/pydantic_graph/graph.py | 73 ++++++------- tests/test_examples.py | 16 +++ 3 files changed, 189 insertions(+), 39 deletions(-) diff --git a/docs/graph.md b/docs/graph.md index 0765d5e265..7875c0b52c 100644 --- a/docs/graph.md +++ b/docs/graph.md @@ -162,7 +162,7 @@ async def main(): 1. The state of the vending machine is defined as a dataclass with the user's balance and the product they've selected, if any. 2. A dictionary of products mapped to prices. 3. The `InsertCoin` node, [`BaseNode`][pydantic_graph.nodes.BaseNode] is parameterized with `MachineState` as that's the state used in this graph. -4. The `InsertCoin` node prompts the user to insert coins. Keep things simple by just entering a monetary amount as a float. Before you start thinking this is a toy too since it's using `input` within node, see [below](#running-graphs) for how control flow can be managed when nodes require external input. +4. The `InsertCoin` node prompts the user to insert coins. Keep things simple by just entering a monetary amount as a float. Before you start thinking this is a toy too since it's using `input` within node, see [below](#custom-control-flow) for how control flow can be managed when nodes require external input. 5. The `CoinsInserted` node, again this is a [`dataclass`][dataclasses.dataclass], in this case with one field `amount`, thus nodes calling `CoinsInserted` must provide an amount. 6. Update the user's balance with the amount inserted. 7. If the user has already selected a product, go to `Purchase`, otherwise go to `SelectProduct`. @@ -211,9 +211,142 @@ See [below](#mermaid-diagrams) for more information on generating diagrams. ## GenAI Example -TODO +So far we haven't shown an example of a Graph that actually uses PydanticAI or GenAI at all. + +In this example, one agent generates a welcome email to a user and the other agent provides feedback on the email. 
+ +This graph has avery simple structure: + +```mermaid +--- +title: feedback_graph +--- +stateDiagram-v2 + [*] --> WriteEmail + WriteEmail --> Feedback + Feedback --> WriteEmail + Feedback --> [*] +``` + + +```python {title="genai_email_feedback.py"} +from __future__ import annotations as _annotations + +from dataclasses import dataclass, field + +from pydantic import BaseModel, EmailStr + +from pydantic_ai import Agent +from pydantic_ai.format_as_xml import format_as_xml +from pydantic_ai.messages import ModelMessage +from pydantic_graph import BaseNode, End, Graph, GraphContext + + +@dataclass +class User: + name: str + email: EmailStr + interests: list[str] + + +@dataclass +class Email: + subject: str + body: str + + +@dataclass +class State: + user: User + write_agent_messages: list[ModelMessage] = field(default_factory=list) + + +email_writer_agent = Agent( + 'google-vertex:gemini-1.5-pro', + result_type=Email, + system_prompt='Write a welcome email to our tech blog.', +) + + +@dataclass +class WriteEmail(BaseNode[State]): + email_feedback: str | None = None + + async def run(self, ctx: GraphContext[State]) -> Feedback: + if self.email_feedback: + prompt = ( + f'Rewrite the email for the user:\n' + f'{format_as_xml(ctx.state.user)}\n' + f'Feedback: {self.email_feedback}' + ) + else: + prompt = ( + f'Write a welcome email for the user:\n' + f'{format_as_xml(ctx.state.user)}' + ) + + result = await email_writer_agent.run( + prompt, + message_history=ctx.state.write_agent_messages, + ) + ctx.state.write_agent_messages += result.all_messages() + return Feedback(result.data) + + +class EmailRequiresWrite(BaseModel): + feedback: str + + +class EmailOk(BaseModel): + pass + + +feedback_agent = Agent[None, EmailRequiresWrite | EmailOk]( + 'openai:gpt-4o', + result_type=EmailRequiresWrite | EmailOk, # type: ignore + system_prompt=( + 'Review the email and provide feedback, email must reference the users specific interests.' + ), +) + + +@dataclass +class Feedback(BaseNode[State, Email]): + email: Email + + async def run( + self, + ctx: GraphContext[State], + ) -> WriteEmail | End[Email]: + prompt = format_as_xml({'user': ctx.state.user, 'email': self.email}) + result = await feedback_agent.run(prompt) + if isinstance(result.data, EmailRequiresWrite): + return WriteEmail(email_feedback=result.data.feedback) + else: + return End(self.email) + + +async def main(): + user = User( + name='John Doe', + email='john.joe@exmaple.com', + interests=['Haskel', 'Lisp', 'Fortran'], + ) + state = State(user) + feedback_graph = Graph(nodes=(WriteEmail, Feedback)) + email, _ = await feedback_graph.run(state, WriteEmail()) + print(email) + """ + Email( + subject='Welcome to our tech blog!', + body='Hello John, Welcome to our tech blog! ...', + ) + """ +``` + +_(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)_ -## Running Graphs +## Custom Control Flow TODO diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 469905da2e..4cfd74541b 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -97,41 +97,6 @@ def __init__( self._validate_edges() - async def next( - self, - state: StateT, - node: BaseNode[StateT, RunEndT], - history: list[HistoryStep[StateT, RunEndT]], - *, - infer_name: bool = True, - ) -> BaseNode[StateT, Any] | End[RunEndT]: - """Run a node in the graph and return the next node to run. - - Args: - state: The current state of the graph. 
- node: The node to run. - history: The history of the graph run so far. NOTE: this will be mutated to add the new step. - infer_name: Whether to infer the graph name from the calling frame. - - Returns: - The next node to run or [`End`][pydantic_graph.nodes.End] if the graph has finished. - """ - if infer_name and self.name is None: - self._infer_name(inspect.currentframe()) - node_id = node.get_id() - if node_id not in self.node_defs: - raise exceptions.GraphRuntimeError(f'Node `{node}` is not in the graph.') - - history_step: NodeStep[StateT, RunEndT] = NodeStep(state, node) - history.append(history_step) - - ctx = GraphContext(state) - with _logfire.span('run node {node_id}', node_id=node_id, node=node): - start = perf_counter() - next_node = await node.run(ctx) - history_step.duration = perf_counter() - start - return next_node - async def run( self, state: StateT, @@ -147,7 +112,8 @@ async def run( you need to provide the starting node. infer_name: Whether to infer the graph name from the calling frame. - Returns: The result type from ending the run and the history of the run. + Returns: + The result type from ending the run and the history of the run. Here's an example of running the graph from [above][pydantic_graph.graph.Graph]: @@ -255,6 +221,41 @@ async def main(): f'Invalid node return type: `{type(next_node).__name__}`. Expected `BaseNode` or `End`.' ) + async def next( + self, + state: StateT, + node: BaseNode[StateT, RunEndT], + history: list[HistoryStep[StateT, RunEndT]], + *, + infer_name: bool = True, + ) -> BaseNode[StateT, Any] | End[RunEndT]: + """Run a node in the graph and return the next node to run. + + Args: + state: The current state of the graph. + node: The node to run. + history: The history of the graph run so far. NOTE: this will be mutated to add the new step. + infer_name: Whether to infer the graph name from the calling frame. + + Returns: + The next node to run or [`End`][pydantic_graph.nodes.End] if the graph has finished. + """ + if infer_name and self.name is None: + self._infer_name(inspect.currentframe()) + node_id = node.get_id() + if node_id not in self.node_defs: + raise exceptions.GraphRuntimeError(f'Node `{node}` is not in the graph.') + + history_step: NodeStep[StateT, RunEndT] = NodeStep(state, node) + history.append(history_step) + + ctx = GraphContext(state) + with _logfire.span('run node {node_id}', node_id=node_id, node=node): + start = perf_counter() + next_node = await node.run(ctx) + history_step.duration = perf_counter() - start + return next_node + def mermaid_code( self, *, diff --git a/tests/test_examples.py b/tests/test_examples.py index 93ba1b90f9..4adf2dada5 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -257,6 +257,22 @@ async def model_logic(messages: list[ModelMessage], info: AgentInfo) -> ModelRes return ModelResponse(parts=[ToolCallPart(tool_name='get_jokes', args=ArgsDict({'count': 5}))]) elif re.fullmatch(r'sql prompt \d+', m.content): return ModelResponse.from_text(content='SELECT 1') + elif m.content.startswith('Write a welcome email for the user:'): + return ModelResponse( + parts=[ + ToolCallPart( + tool_name='final_result', + args=ArgsDict( + { + 'subject': 'Welcome to our tech blog!', + 'body': 'Hello John, Welcome to our tech blog! 
...',
+                            }
+                        ),
+                    )
+                ]
+            )
+        elif m.content.startswith('\n  '):
+            return ModelResponse(parts=[ToolCallPart(tool_name='final_result_EmailOk', args=ArgsDict({}))])
         elif response := text_responses.get(m.content):
             if isinstance(response, str):
                 return ModelResponse.from_text(content=response)

From 3994899c75604dc0d0bbb1c1518c2ea82a4bbefa Mon Sep 17 00:00:00 2001
From: Samuel Colvin
Date: Sun, 12 Jan 2025 11:17:34 +0000
Subject: [PATCH 43/57] more graph docs

---
 docs/api/models/function.md                   |   2 +-
 docs/api/models/test.md                       |   2 +-
 docs/graph.md                                 | 162 +++++++++++++---
 .../pydantic_ai_examples/question_graph.py    |   3 +-
 pydantic_ai_slim/pydantic_ai/agent.py         |   7 +-
 pydantic_ai_slim/pydantic_ai/format_as_xml.py |   3 +-
 pydantic_ai_slim/pydantic_ai/tools.py         |   6 +-
 pydantic_graph/pydantic_graph/_utils.py       |   3 +-
 pydantic_graph/pydantic_graph/graph.py        |  34 +++-
 pydantic_graph/pydantic_graph/mermaid.py      |   3 +-
 pydantic_graph/pydantic_graph/state.py        |  16 +-
 tests/graph/test_graph.py                     |  14 +-
 tests/test_examples.py                        |  15 +-
 13 files changed, 218 insertions(+), 52 deletions(-)

diff --git a/docs/api/models/function.md b/docs/api/models/function.md
index 6280ae47f1..7d16124fdb 100644
--- a/docs/api/models/function.md
+++ b/docs/api/models/function.md
@@ -9,7 +9,7 @@ Its primary use case is for more advanced unit testing than is possible with `Te
 
 Here's a minimal example:
 
-```py {title="function_model_usage.py" call_name="test_my_agent" lint="not-imports"}
+```py {title="function_model_usage.py" call_name="test_my_agent" noqa="I001"}
 from pydantic_ai import Agent
 from pydantic_ai.messages import ModelMessage, ModelResponse
 from pydantic_ai.models.function import FunctionModel, AgentInfo
diff --git a/docs/api/models/test.md b/docs/api/models/test.md
index 66bfac7932..e6c4411d29 100644
--- a/docs/api/models/test.md
+++ b/docs/api/models/test.md
@@ -4,7 +4,7 @@ Utility model for quickly testing apps built with PydanticAI.
 
 Here's a minimal example:
 
-```py {title="test_model_usage.py" call_name="test_my_agent" lint="not-imports"}
+```py {title="test_model_usage.py" call_name="test_my_agent" noqa="I001"}
 from pydantic_ai import Agent
 from pydantic_ai.models.test import TestModel
 
diff --git a/docs/graph.md b/docs/graph.md
index 7875c0b52c..829d660a82 100644
--- a/docs/graph.md
+++ b/docs/graph.md
@@ -18,7 +18,10 @@ Alongside PydanticAI, we've developed `pydantic-graph` — an async graph and st
 
 While this library is developed as part of the PydanticAI; it has no dependency on `pydantic-ai` and can be considered as a pure graph library. You may find it useful whether or not you're using PydanticAI or even building with GenAI.
 
-`pydantic-graph` is designed for advanced users and makes heavy use of Python generics and types hints. It is not designed to b as beginner-friendly as PydanticAI.
+`pydantic-graph` is designed for advanced users and makes heavy use of Python generics and type hints. It is not designed to be as beginner-friendly as PydanticAI.
+
+!!! note "Very early beta"
+    Graph support was [introduced](https://github.com/pydantic/pydantic-ai/pull/528) in v0.0.19 and is in very early beta. The API is subject to change. The documentation is incomplete. The implementation is incomplete.
 
 ## Installation
 
@@ -30,11 +33,20 @@ pip/uv-add pydantic-graph
 ```
 
 ## Graph Types
 
-!!! note "Every Early beta"
-    Graph support was [introduced](https://github.com/pydantic/pydantic-ai/pull/528) in v0.0.19 and is in very earlier beta. The API is subject to change. The documentation is incomplete. The implementation is incomplete.
- Graphs are made up of a few key components: +### GraphContext + +[`GraphContext`][pydantic_graph.nodes.GraphContext] — The context for the graph run, similar to PydanticAI's [`RunContext`][pydantic_ai.tools.RunContext], this holds the state of the graph and is passed to nodes when they're run. + +`GraphContext` is generic in the state type of the graph it's used in, [`StateT`][pydantic_graph.state.StateT]. + +### End + +[`End`][pydantic_graph.nodes.End] — return value to indicates the graph run should end. + +`End` is generic in the graph return type of the graph it's used in, [`RunEndT`][pydantic_graph.nodes.RunEndT]. + ### Nodes Subclasses of [`BaseNode`][pydantic_graph.nodes.BaseNode] to define nodes. @@ -50,17 +62,55 @@ Nodes are generic in: * **state** which must have the same type as the state of graphs they're included in, [`StateT`][pydantic_graph.state.StateT] has a default of `None`, so if you're not using state you can omit this generic parameter * **graph return type** this only applies if the node returns [`End`][pydantic_graph.nodes.End], [`RunEndT`][pydantic_graph.nodes.RunEndT] has a default of [Never][typing.Never] so this generic parameter can be omitted if the node doesn't return `End`, but must be included if it does. -### GraphContext +Here's an example of a start or intermediate node in a graph — it can't end the run as it doesn't return [`End`][pydantic_graph.nodes.End]: -[`GraphContext`][pydantic_graph.nodes.GraphContext] — The context for the graph run, similar to PydanticAI's [`RunContext`][pydantic_ai.tools.RunContext], this holds the state of the graph and is passed to nodes when they're run. +```py {title="intermediate_node.py" noqa="F821" test="skip"} +from dataclasses import dataclass -`GraphContext` is generic in the state type of the graph it's used in, [`StateT`][pydantic_graph.state.StateT]. +from pydantic_graph import BaseNode, GraphContext -### End -[`End`][pydantic_graph.nodes.End] — return value to indicates the graph run should end. +@dataclass +class MyNode(BaseNode[MyState]): # (1)! + foo: int # (2)! -`End` is generic in the graph return type of the graph it's used in, [`RunEndT`][pydantic_graph.nodes.RunEndT]. + async def run( + self, + ctx: GraphContext[MyState], # (3)! + ) -> AnotherNode: # (4)! + ... + return AnotherNode() +``` + +1. State in this example is `MyState` (not shown), hence `BaseNode` is parameterized with `MyState`. This node can't end the run, so the `RunEndT` generic parameter is omitted and defaults to `Never`. +2. `MyNode` is a dataclass and has a single field `foo`, an `int`. +3. The `run` method takes a `GraphContext` parameter, again parameterized with state `MyState`. +4. The return type of the `run` method is `AnotherNode` (not shown), this is used to determine the outgoing edges of the node. + +We could extend `MyNode` to optionally end the run if `foo` is divisible by 5: + +```py {title="intermediate_or_end_node.py" hl_lines="7 13" noqa="F821" test="skip"} +from dataclasses import dataclass + +from pydantic_graph import BaseNode, End, GraphContext + + +@dataclass +class MyNode(BaseNode[MyState, int]): # (1)! + foo: int + + async def run( + self, + ctx: GraphContext[MyState], + ) -> AnotherNode | End[int]: # (2)! + if self.foo % 5 == 0: + return End(self.foo) + else: + return AnotherNode() +``` + +1. We parameterize the node with the return type (`int` in this case) as well as state. +2. 
The return type of the `run` method is now a union of `AnotherNode` and `End[int]`, this allows the node to end the run if `foo` is divisible by 5. ### Graph @@ -71,11 +121,76 @@ Nodes are generic in: * **state** the state type of the graph, [`StateT`][pydantic_graph.state.StateT] * **graph return type** the return type of the graph run, [`RunEndT`][pydantic_graph.nodes.RunEndT] -## Basic Usage +Here's an example of a simple graph: + +```py {title="graph_example.py" py="3.10"} +from __future__ import annotations + +from dataclasses import dataclass + +from pydantic_graph import BaseNode, End, Graph, GraphContext + + +@dataclass +class DivisibleBy5(BaseNode[None, int]): # (1)! + foo: int + + async def run( + self, + ctx: GraphContext, + ) -> Increment | End[int]: + if self.foo % 5 == 0: + return End(self.foo) + else: + return Increment(self.foo) + + +@dataclass +class Increment(BaseNode): # (2)! + foo: int + + async def run(self, ctx: GraphContext) -> DivisibleBy5: + return DivisibleBy5(self.foo + 1) + + +fives_graph = Graph(nodes=[DivisibleBy5, Increment]) +result, history = fives_graph.run_sync(None, DivisibleBy5(4)) +print(result) +#> 5 +# the full history is quite verbose (see below), so we'll just print the summary +print([item.data_snapshot() for item in history]) +#> [DivisibleBy5(foo=4), Increment(foo=4), DivisibleBy5(foo=5), End(data=5)] +``` +_(This example is complete, it can be run "as is" with Python 3.10+)_ + +A [mermaid diagram](#mermaid-diagrams) for this graph can be generated with the following code: + +```py {title="graph_example_diagram.py" py="3.10"} +from graph_example import DivisibleBy5, fives_graph + +fives_graph.mermaid_code(start_node=DivisibleBy5) +``` + +```mermaid +--- +title: fives_graph +--- +stateDiagram-v2 + [*] --> DivisibleBy5 + DivisibleBy5 --> Increment + DivisibleBy5 --> [*] + Increment --> DivisibleBy5 +``` + +## Stateful Graphs + +TODO introduce state + +TODO link to issue about persistent state. Here's an example of a graph which represents a vending machine where the user may insert coins and select a product to purchase. -```python {title="vending_machine.py"} +```python {title="vending_machine.py" py="3.10"} from __future__ import annotations from dataclasses import dataclass @@ -162,7 +277,7 @@ async def main(): 1. The state of the vending machine is defined as a dataclass with the user's balance and the product they've selected, if any. 2. A dictionary of products mapped to prices. 3. The `InsertCoin` node, [`BaseNode`][pydantic_graph.nodes.BaseNode] is parameterized with `MachineState` as that's the state used in this graph. -4. The `InsertCoin` node prompts the user to insert coins. Keep things simple by just entering a monetary amount as a float. Before you start thinking this is a toy too since it's using `input` within node, see [below](#custom-control-flow) for how control flow can be managed when nodes require external input. +4. The `InsertCoin` node prompts the user to insert coins. We keep things simple by just entering a monetary amount as a float. Before you start thinking this is a toy too since it's using `input` within nodes, see [below](#custom-control-flow) for how control flow can be managed when nodes require external input. 5. The `CoinsInserted` node, again this is a [`dataclass`][dataclasses.dataclass], in this case with one field `amount`, thus nodes calling `CoinsInserted` must provide an amount. 6. Update the user's balance with the amount inserted. 7. 
If the user has already selected a product, go to `Purchase`, otherwise go to `SelectProduct`. @@ -171,18 +286,18 @@ async def main(): 10. If the balance is enough to purchase the product, adjust the balance to reflect the purchase and return [`End`][pydantic_graph.nodes.End] to end the graph. We're not using the run return type, so we call `End` with `None`. 11. If the balance is insufficient, to go `InsertCoin` to prompt the user to insert more coins. 12. If the product is invalid, go to `SelectProduct` to prompt the user to select a product again. -13. The graph is created by passing a list of nodes to [`Graph`][pydantic_graph.graph.Graph]. Order of nodes is not important, but will alter how [diagramss](#mermaid-diagrams) are displayed. +13. The graph is created by passing a list of nodes to [`Graph`][pydantic_graph.graph.Graph]. Order of nodes is not important, but will alter how [diagrams](#mermaid-diagrams) are displayed. 14. Initialize the state, this will be passed to the graph run and mutated as the graph runs. 15. Run the graph with the initial state, since the graph can be run from any node, we must pass the start node, in this case `InsertCoin`. [`Graph.run`][pydantic_graph.graph.Graph.run] returns a tuple of the return value (`None`) in this case, and the [history][pydantic_graph.state.HistoryStep] of the graph run. 16. The return type of the node's [`run`][pydantic_graph.nodes.BaseNode.run] method is important, it's used to determine the outgoing edges of the node, this in turn is used to render [mermaid diagrams](#mermaid-diagrams) and is enforced at runtime. 17. The return type of `CoinsInserted`s [`run`][pydantic_graph.nodes.BaseNode.run] method is a union, meaning multiple outgoing edges are possible. 18. Unlike other nodes `Purchase` can end the run, so the [`RunEndT`][pydantic_graph.nodes.RunEndT] generic parameter must be set, in this case it's `None` since the graph run return type is `None`. -_(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)_ +_(This example is complete, it can be run "as is" with Python 3.10+ — you'll need to add `asyncio.run(main())` to run `main`)_ A [mermaid diagram](#mermaid-diagrams) for this graph can be generated with the following code: -```py {title="vending_machine_diagram.py"} +```py {title="vending_machine_diagram.py" py="3.10"} from vending_machine import InsertCoin, vending_machine_graph vending_machine_graph.mermaid_code(start_node=InsertCoin) @@ -215,7 +330,7 @@ So far we haven't shown an example of a Graph that actually uses PydanticAI or G In this example, one agent generates a welcome email to a user and the other agent provides feedback on the email. 
-This graph has avery simple structure: +This graph has a very simple structure: ```mermaid --- @@ -229,7 +344,7 @@ stateDiagram-v2 ``` -```python {title="genai_email_feedback.py"} +```python {title="genai_email_feedback.py" py="3.10"} from __future__ import annotations as _annotations from dataclasses import dataclass, field @@ -344,15 +459,18 @@ async def main(): """ ``` -_(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)_ +_(This example is complete, it can be run "as is" with Python 3.10+ — you'll need to add `asyncio.run(main())` to run `main`)_ ## Custom Control Flow -TODO +In many real-world applications, Graphs cannot run uninterrupted from start to finish — they require external input or run over an extended period of time where a single process cannot execute the entire graph. -## State Machine +In these scenarios the [`next`][pydantic_graph.graph.Graph.next] method can be used to run the graph one node at a time. -TODO +In this example, an AI asks the user a question, the user provides an answer, the AI evaluates the answer and ends if the user got it right or asks another question if they got it wrong. + +```python {title="ai_q_and_a.py"} +``` ## Mermaid Diagrams diff --git a/examples/pydantic_ai_examples/question_graph.py b/examples/pydantic_ai_examples/question_graph.py index 5f2c42bca9..d8158778c0 100644 --- a/examples/pydantic_ai_examples/question_graph.py +++ b/examples/pydantic_ai_examples/question_graph.py @@ -11,6 +11,7 @@ from typing import Annotated import logfire +from devtools import debug from pydantic_graph import BaseNode, Edge, End, Graph, GraphContext, HistoryStep from pydantic_ai import Agent @@ -136,7 +137,7 @@ async def main(): while True: node = await question_graph.next(state, node, history) if isinstance(node, End): - print('\n'.join(e.summary() for e in history)) + debug([e.data_snapshot() for e in history]) break elif isinstance(node, Answer): node.answer = input(f'{node.question} ') diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 0a8a254ffd..f44d1a91df 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -383,10 +383,9 @@ def run_sync( agent = Agent('openai:gpt-4o') - async def main(): - result = await agent.run('What is the capital of France?') - print(result.data) - #> Paris + result = agent.run_sync('What is the capital of France?') + print(result.data) + #> Paris ``` Args: diff --git a/pydantic_ai_slim/pydantic_ai/format_as_xml.py b/pydantic_ai_slim/pydantic_ai/format_as_xml.py index 3c6f67fd42..ce9973691d 100644 --- a/pydantic_ai_slim/pydantic_ai/format_as_xml.py +++ b/pydantic_ai_slim/pydantic_ai/format_as_xml.py @@ -37,7 +37,8 @@ def format_as_xml( none_str: String to use for `None` values. indent: Indentation string to use for pretty printing. - Returns: XML representation of the object. + Returns: + XML representation of the object. 
Example: ```python {title="format_as_xml_example.py" lint="skip"} diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index 5f07f66ef7..9a99bbba22 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -107,7 +107,7 @@ def replace_with( Example — here `only_if_42` is valid as a `ToolPrepareFunc`: -```python {lint="not-imports"} +```python {noqa="I001"} from typing import Union from pydantic_ai import RunContext, Tool @@ -176,7 +176,7 @@ def __init__( Example usage: - ```python {lint="not-imports"} + ```python {noqa="I001"} from pydantic_ai import Agent, RunContext, Tool async def my_tool(ctx: RunContext[int], x: int, y: int) -> str: @@ -187,7 +187,7 @@ async def my_tool(ctx: RunContext[int], x: int, y: int) -> str: or with a custom prepare method: - ```python {lint="not-imports"} + ```python {noqa="I001"} from typing import Union from pydantic_ai import Agent, RunContext, Tool diff --git a/pydantic_graph/pydantic_graph/_utils.py b/pydantic_graph/pydantic_graph/_utils.py index 81123afde6..0211f4fd17 100644 --- a/pydantic_graph/pydantic_graph/_utils.py +++ b/pydantic_graph/pydantic_graph/_utils.py @@ -24,7 +24,8 @@ def get_union_args(tp: Any) -> tuple[Any, ...]: def unpack_annotated(tp: Any) -> tuple[Any, list[Any]]: """Strip `Annotated` from the type if present. - Returns: `(tp argument, ())` if not annotated, otherwise `(stripped type, annotations)`. + Returns: + `(tp argument, ())` if not annotated, otherwise `(stripped type, annotations)`. """ origin = get_origin(tp) if origin is Annotated or origin is typing_extensions.Annotated: diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 4cfd74541b..8085307f85 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -1,5 +1,6 @@ from __future__ import annotations as _annotations +import asyncio import inspect from collections.abc import Sequence from dataclasses import dataclass @@ -31,7 +32,7 @@ class Graph(Generic[StateT, RunEndT]): Here's a very simple example of a graph which increments a number by 1, but makes sure the number is never 42 at the end. - ```py {title="never_42.py" lint="not-imports"} + ```py {title="never_42.py" noqa="I001" py="3.10"} from __future__ import annotations from dataclasses import dataclass @@ -117,7 +118,7 @@ async def run( Here's an example of running the graph from [above][pydantic_graph.graph.Graph]: - ```py {title="run_never_42.py" lint="not-imports"} + ```py {title="run_never_42.py" noqa="I001" py="3.10"} from never_42 import Increment, MyState, never_42_graph async def main(): @@ -196,10 +197,10 @@ async def main(): ''' ``` """ - history: list[HistoryStep[StateT, RunEndT]] = [] if infer_name and self.name is None: self._infer_name(inspect.currentframe()) + history: list[HistoryStep[StateT, RunEndT]] = [] with _logfire.span( '{graph_name} run {start=}', graph_name=self.name or 'graph', @@ -221,6 +222,31 @@ async def main(): f'Invalid node return type: `{type(next_node).__name__}`. Expected `BaseNode` or `End`.' ) + def run_sync( + self, + state: StateT, + start_node: BaseNode[StateT, RunEndT], + *, + infer_name: bool = True, + ) -> tuple[RunEndT, list[HistoryStep[StateT, RunEndT]]]: + """Run the graph synchronously. + + This is a convenience method that wraps [`self.run`][pydantic_graph.Graph.run] with `loop.run_until_complete(...)`. + You therefore can't use this method inside async code or if there's an active event loop. 
+ + Args: + state: The initial state of the graph. + start_node: the first node to run, since the graph definition doesn't define the entry point in the graph, + you need to provide the starting node. + infer_name: Whether to infer the graph name from the calling frame. + + Returns: + The result type from ending the run and the history of the run. + """ + if infer_name and self.name is None: + self._infer_name(inspect.currentframe()) + return asyncio.get_event_loop().run_until_complete(self.run(state, start_node, infer_name=False)) + async def next( self, state: StateT, @@ -285,7 +311,7 @@ def mermaid_code( Here's an example of generating a diagram for the graph from [above][pydantic_graph.graph.Graph]: - ```py {title="never_42.py"} + ```py {title="never_42.py" py="3.10"} from never_42 import Increment, never_42_graph print(never_42_graph.mermaid_code(start_node=Increment)) diff --git a/pydantic_graph/pydantic_graph/mermaid.py b/pydantic_graph/pydantic_graph/mermaid.py index 1e72493759..49e41ee267 100644 --- a/pydantic_graph/pydantic_graph/mermaid.py +++ b/pydantic_graph/pydantic_graph/mermaid.py @@ -120,7 +120,8 @@ def request_image( graph: The graph to generate the image for. **kwargs: Additional parameters to configure mermaid chart generation. - Returns: The image data. + Returns: + The image data. """ code = generate_code( graph, diff --git a/pydantic_graph/pydantic_graph/state.py b/pydantic_graph/pydantic_graph/state.py index 7d4a977292..174607362c 100644 --- a/pydantic_graph/pydantic_graph/state.py +++ b/pydantic_graph/pydantic_graph/state.py @@ -50,8 +50,12 @@ def __post_init__(self, snapshot_state: Callable[[StateT], StateT]): # Copy the state to prevent it from being modified by other code self.state = snapshot_state(self.state) - def summary(self) -> str: - return str(self.node) + def data_snapshot(self) -> BaseNode[StateT, RunEndT]: + """Returns a deep copy of [`self.node`][pydantic_graph.state.NodeStep.node]. + + Useful for summarizing history. + """ + return copy.deepcopy(self.node) @dataclass @@ -73,8 +77,12 @@ def __post_init__(self, snapshot_state: Callable[[StateT], StateT]): # Copy the state to prevent it from being modified by other code self.state = snapshot_state(self.state) - def summary(self) -> str: - return str(self.result) + def data_snapshot(self) -> End[RunEndT]: + """Returns a deep copy of [`self.result`][pydantic_graph.state.EndStep.result]. + + Useful for summarizing history. 
+ """ + return copy.deepcopy(self.result) HistoryStep = Union[NodeStep[StateT, RunEndT], EndStep[StateT, RunEndT]] diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py index f82e92e5e6..5549ffaa8b 100644 --- a/tests/graph/test_graph.py +++ b/tests/graph/test_graph.py @@ -127,14 +127,14 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq ), ] ) - assert [e.summary() for e in history] == snapshot( + assert [e.data_snapshot() for e in history] == snapshot( [ - 'test_graph..Float2String(input_data=3.14159)', - "test_graph..String2Length(input_data='3.14159')", - 'test_graph..Double(input_data=7)', - "test_graph..String2Length(input_data='xxxxxxxxxxxxxxxxxxxxx')", - 'test_graph..Double(input_data=21)', - 'End(data=42)', + Float2String(input_data=3.14159), + String2Length(input_data='3.14159'), + Double(input_data=7), + String2Length(input_data='xxxxxxxxxxxxxxxxxxxxx'), + Double(input_data=21), + End(data=42), ] ) diff --git a/tests/test_examples.py b/tests/test_examples.py index 4adf2dada5..5d113943aa 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -84,6 +84,14 @@ def test_docs_examples( opt_title = prefix_settings.get('title') opt_test = prefix_settings.get('test', '') opt_lint = prefix_settings.get('lint', '') + noqa = prefix_settings.get('noqa', '') + python_version = prefix_settings.get('py', None) + + if python_version: + python_version_info = tuple(int(v) for v in python_version.split('.')) + if sys.version_info < python_version_info: + pytest.skip(f'Python version {python_version} required') + cwd = Path.cwd() if opt_test.startswith('skip') and opt_lint.startswith('skip'): @@ -99,9 +107,12 @@ def test_docs_examples( # `from bank_database import DatabaseConn` wrongly sorted in imports # waiting for https://github.com/pydantic/pytest-examples/issues/43 # and https://github.com/pydantic/pytest-examples/issues/46 - if opt_lint == 'not-imports' or 'import DatabaseConn' in example.source: + if 'import DatabaseConn' in example.source: ruff_ignore.append('I001') + if noqa: + ruff_ignore.extend(noqa.upper().split()) + line_length = int(prefix_settings.get('line_length', '88')) eval_example.set_config(ruff_ignore=ruff_ignore, target_version='py39', line_length=line_length) @@ -116,7 +127,7 @@ def test_docs_examples( eval_example.lint(example) if opt_test.startswith('skip'): - pytest.skip(opt_test[4:].lstrip(' -') or 'running code skipped') + print(opt_test[4:].lstrip(' -') or 'running code skipped') else: if eval_example.update_examples: module_dict = eval_example.run_print_update(example, call=call_name) From ecc243435dc66e0274e0e3f9c06f6c7552a373bb Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 13 Jan 2025 14:32:04 +0000 Subject: [PATCH 44/57] extending graph docs --- docs/graph.md | 176 +++++++++++++++++- .../pydantic_ai_examples/question_graph.py | 120 ++++++++---- pydantic_graph/pydantic_graph/nodes.py | 3 + tests/test_examples.py | 15 ++ 4 files changed, 269 insertions(+), 45 deletions(-) diff --git a/docs/graph.md b/docs/graph.md index 829d660a82..1e2a8f2ffb 100644 --- a/docs/graph.md +++ b/docs/graph.md @@ -153,14 +153,20 @@ class Increment(BaseNode): # (2)! return DivisibleBy5(self.foo + 1) -fives_graph = Graph(nodes=[DivisibleBy5, Increment]) -result, history = fives_graph.run_sync(None, DivisibleBy5(4)) +fives_graph = Graph(nodes=[DivisibleBy5, Increment]) # (3)! +result, history = fives_graph.run_sync(None, DivisibleBy5(4)) # (4)! 
 print(result)
 #> 5
 # the full history is quite verbose (see below), so we'll just print the summary
 print([item.data_snapshot() for item in history])
 #> [DivisibleBy5(foo=4), Increment(foo=4), DivisibleBy5(foo=5), End(data=5)]
 ```
+
+1. The `DivisibleBy5` node is parameterized with `None` as this graph doesn't use state, and `int` as it can end the run.
+2. The `Increment` node doesn't return `End`, so the `RunEndT` generic parameter is omitted; state can also be omitted as the graph doesn't use state.
+3. The graph is created with a sequence of nodes.
+4. The graph is run synchronously with [`run_sync`][pydantic_graph.graph.Graph.run_sync]; the initial state `None` and the start node `DivisibleBy5(4)` are passed as arguments.
+
 _(This example is complete, it can be run "as is" with Python 3.10+)_
 
 A [mermaid diagram](#mermaid-diagrams) for this graph can be generated with the following code:
@@ -277,7 +283,7 @@ async def main():
 1. The state of the vending machine is defined as a dataclass with the user's balance and the product they've selected, if any.
 2. A dictionary of products mapped to prices.
 3. The `InsertCoin` node, [`BaseNode`][pydantic_graph.nodes.BaseNode] is parameterized with `MachineState` as that's the state used in this graph.
-4. The `InsertCoin` node prompts the user to insert coins. We keep things simple by just entering a monetary amount as a float. Before you start thinking this is a toy too since it's using `input` within nodes, see [below](#custom-control-flow) for how control flow can be managed when nodes require external input.
+4. The `InsertCoin` node prompts the user to insert coins. We keep things simple by just entering a monetary amount as a float. Before you start thinking this is a toy too since it's using [rich's `Prompt.ask`][rich.prompt.PromptBase.ask] within nodes, see [below](#custom-control-flow) for how control flow can be managed when nodes require external input.
 5. The `CoinsInserted` node, again this is a [`dataclass`][dataclasses.dataclass], in this case with one field `amount`, thus nodes calling `CoinsInserted` must provide an amount.
 6. Update the user's balance with the amount inserted.
 7. If the user has already selected a product, go to `Purchase`, otherwise go to `SelectProduct`.
@@ -469,9 +473,169 @@ In these scenarios the [`next`][pydantic_graph.graph.Graph.next] method can be u
 
 In this example, an AI asks the user a question, the user provides an answer, the AI evaluates the answer and ends if the user got it right or asks another question if they got it wrong.
 
-```python {title="ai_q_and_a.py"}
+??? 
example "`ai_q_and_a_graph.py` — `question_graph` definition" + ```python {title="ai_q_and_a_graph.py" noqa="I001" py="3.10"} + from __future__ import annotations as _annotations + + from dataclasses import dataclass, field + + from pydantic_graph import BaseNode, End, Graph, GraphContext + + from pydantic_ai import Agent + from pydantic_ai.format_as_xml import format_as_xml + from pydantic_ai.messages import ModelMessage + + ask_agent = Agent('openai:gpt-4o', result_type=str) + + + @dataclass + class QuestionState: + question: str | None = None + ask_agent_messages: list[ModelMessage] = field(default_factory=list) + evaluate_agent_messages: list[ModelMessage] = field(default_factory=list) + + + @dataclass + class Ask(BaseNode[QuestionState]): + async def run(self, ctx: GraphContext[QuestionState]) -> Answer: + result = await ask_agent.run( + 'Ask a simple question with a single correct answer.', + message_history=ctx.state.ask_agent_messages, + ) + ctx.state.ask_agent_messages += result.all_messages() + ctx.state.question = result.data + return Answer(result.data) + + + @dataclass + class Answer(BaseNode[QuestionState]): + question: str + answer: str | None = None + + async def run(self, ctx: GraphContext[QuestionState]) -> Evaluate: + assert self.answer is not None + return Evaluate(self.answer) + + + @dataclass + class EvaluationResult: + correct: bool + comment: str + + + evaluate_agent = Agent( + 'openai:gpt-4o', + result_type=EvaluationResult, + system_prompt='Given a question and answer, evaluate if the answer is correct.', + ) + + + @dataclass + class Evaluate(BaseNode[QuestionState]): + answer: str + + async def run( + self, + ctx: GraphContext[QuestionState], + ) -> End[str] | Reprimand: + assert ctx.state.question is not None + result = await evaluate_agent.run( + format_as_xml({'question': ctx.state.question, 'answer': self.answer}), + message_history=ctx.state.evaluate_agent_messages, + ) + ctx.state.evaluate_agent_messages += result.all_messages() + if result.data.correct: + return End(result.data.comment) + else: + return Reprimand(result.data.comment) + + + @dataclass + class Reprimand(BaseNode[QuestionState]): + comment: str + + async def run(self, ctx: GraphContext[QuestionState]) -> Ask: + print(f'Comment: {self.comment}') + #> Comment: Vichy is no longer the capital of France. + ctx.state.question = None + return Ask() + + + question_graph = Graph(nodes=(Ask, Answer, Evaluate, Reprimand)) + ``` + + _(This example is complete, it can be run "as is" with Python 3.10+)_ + + +```python {title="ai_q_and_a_run.py" noqa="I001" py="3.10"} +from rich.prompt import Prompt + +from pydantic_graph import End, HistoryStep + +from ai_q_and_a_graph import Ask, question_graph, QuestionState, Answer + + +async def main(): + state = QuestionState() # (1)! + node = Ask() # (2)! + history: list[HistoryStep[QuestionState]] = [] # (3)! + while True: + node = await question_graph.next(state, node, history) # (4)! + if isinstance(node, Answer): + node.answer = Prompt.ask(node.question) # (5)! + elif isinstance(node, End): # (6)! + print(f'Correct answer! {node.data}') + #> Correct answer! Well done, 1 + 1 = 2 + print([e.data_snapshot() for e in history]) + """ + [ + Ask(), + Answer(question='What is the capital of France?', answer='Vichy'), + Evaluate(answer='Vichy'), + Reprimand(comment='Vichy is no longer the capital of France.'), + Ask(), + Answer(question='what is 1 + 1?', answer='2'), + Evaluate(answer='2'), + ] + """ + return + # otherwise just continue +``` + +1. 
Create the state object which will be mutated by [`next`][pydantic_graph.graph.Graph.next]. +2. The start node is `Ask` but will be updated by [`next`][pydantic_graph.graph.Graph.next] as the graph runs. +3. The history of the graph run is stored in a list of [`HistoryStep`][pydantic_graph.state.HistoryStep] objects, again [`next`][pydantic_graph.graph.Graph.next] will update this list in place. +4. [Run][pydantic_graph.graph.Graph.next] the graph one node at a time, updating the state, current node and history as the graph runs. +5. If the current node is an `Answer` node, prompt the user for an answer. +6. Since we're using [`next`][pydantic_graph.graph.Graph.next] we have to manually check for an [`End`][pydantic_graph.nodes.End] and exit the loop if we get one. + +_(This example is complete, it can be run "as is" with Python 3.10+ — you'll need to add `asyncio.run(main())` to run `main`)_ + +A [mermaid diagram](#mermaid-diagrams) for this graph can be generated with the following code: + +```py {title="ai_q_and_a_diagram.py" py="3.10"} +from ai_q_and_a_graph import Ask, question_graph + +question_graph.mermaid_code(start_node=Ask) +``` + +```mermaid +--- +title: question_graph +--- +stateDiagram-v2 + [*] --> Ask + Ask --> Answer + Answer --> Evaluate + Evaluate --> Reprimand + Evaluate --> [*] + Reprimand --> Ask ``` +You maybe have noticed that although this examples transfers control flow out of the graph run, we're still using [rich's `Prompt.ask`][rich.prompt.PromptBase.ask] to get user input, with the process hanging while we wait for the user to enter a response. + +TODO + ## Mermaid Diagrams TODO diff --git a/examples/pydantic_ai_examples/question_graph.py b/examples/pydantic_ai_examples/question_graph.py index d8158778c0..1ad7649d75 100644 --- a/examples/pydantic_ai_examples/question_graph.py +++ b/examples/pydantic_ai_examples/question_graph.py @@ -8,9 +8,11 @@ from __future__ import annotations as _annotations from dataclasses import dataclass, field +from pathlib import Path from typing import Annotated import logfire +import pydantic from devtools import debug from pydantic_graph import BaseNode, Edge, End, Graph, GraphContext, HistoryStep @@ -26,43 +28,30 @@ @dataclass class QuestionState: + question: str | None = None ask_agent_messages: list[ModelMessage] = field(default_factory=list) evaluate_agent_messages: list[ModelMessage] = field(default_factory=list) @dataclass class Ask(BaseNode[QuestionState]): - """Generate a question to ask the user. - - Uses the GPT-4o model to generate the question. - """ - - async def run( - self, ctx: GraphContext[QuestionState] - ) -> Annotated[Answer, Edge(label='ask the question')]: + async def run(self, ctx: GraphContext[QuestionState]) -> Answer: result = await ask_agent.run( 'Ask a simple question with a single correct answer.', message_history=ctx.state.ask_agent_messages, ) ctx.state.ask_agent_messages += result.all_messages() - return Answer(result.data) + ctx.state.question = result.data + return Answer() @dataclass class Answer(BaseNode[QuestionState]): - """Get the answer to the question from the user. - - This node must be completed outside the graph run. 
- """ - - question: str answer: str | None = None - async def run( - self, ctx: GraphContext[QuestionState] - ) -> Annotated[Evaluate, Edge(label='answer the question')]: + async def run(self, ctx: GraphContext[QuestionState]) -> Evaluate: assert self.answer is not None - return Evaluate(self.question, self.answer) + return Evaluate(self.answer) @dataclass @@ -75,34 +64,31 @@ class EvaluationResult: 'openai:gpt-4o', result_type=EvaluationResult, system_prompt='Given a question and answer, evaluate if the answer is correct.', - result_tool_name='evaluation', ) @dataclass class Evaluate(BaseNode[QuestionState]): - question: str answer: str async def run( self, ctx: GraphContext[QuestionState], - ) -> Congratulate | Castigate: + ) -> Congratulate | Reprimand: + assert ctx.state.question is not None result = await evaluate_agent.run( - format_as_xml({'question': self.question, 'answer': self.answer}), + format_as_xml({'question': ctx.state.question, 'answer': self.answer}), message_history=ctx.state.evaluate_agent_messages, ) ctx.state.evaluate_agent_messages += result.all_messages() if result.data.correct: return Congratulate(result.data.comment) else: - return Castigate(result.data.comment) + return Reprimand(result.data.comment) @dataclass class Congratulate(BaseNode[QuestionState, None]): - """Congratulate the user and end.""" - comment: str async def run( @@ -113,26 +99,23 @@ async def run( @dataclass -class Castigate(BaseNode[QuestionState]): - """Castigate the user, then ask another question.""" - +class Reprimand(BaseNode[QuestionState]): comment: str - async def run( - self, ctx: GraphContext[QuestionState] - ) -> Annotated[Ask, Edge(label='try again')]: + async def run(self, ctx: GraphContext[QuestionState]) -> Ask: print(f'Comment: {self.comment}') + # > Comment: Vichy is no longer the capital of France. 
+ ctx.state.question = None return Ask() -question_graph = Graph(nodes=(Ask, Answer, Evaluate, Congratulate, Castigate)) -print(question_graph.mermaid_code(start_node=Ask, notes=False)) +question_graph = Graph(nodes=(Ask, Answer, Evaluate, Congratulate, Reprimand)) -async def main(): +async def run_as_continuous(): state = QuestionState() node = Ask() - history: list[HistoryStep[QuestionState]] = [] + history: list[HistoryStep[QuestionState, None]] = [] with logfire.span('run questions graph'): while True: node = await question_graph.next(state, node, history) @@ -140,11 +123,70 @@ async def main(): debug([e.data_snapshot() for e in history]) break elif isinstance(node, Answer): - node.answer = input(f'{node.question} ') + assert state.question + node.answer = input(f'{state.question} ') # otherwise just continue +ta = pydantic.TypeAdapter( + list[Annotated[HistoryStep[QuestionState, None], pydantic.Discriminator('kind')]] +) + + +async def run_as_cli(answer: str | None): + history_file = Path('question_graph_history.json') + if history_file.exists(): + history = ta.validate_json(history_file.read_bytes()) + else: + history = [] + + if history: + last = history[-1] + assert last.kind != 'node', 'expected last step to be a node' + state = last.state + assert answer is not None, 'answer is required to continue from history' + node = Answer(answer) + else: + state = QuestionState() + node = Ask() + + with logfire.span('run questions graph'): + while True: + node = await question_graph.next(state, node, history) + if isinstance(node, End): + debug([e.data_snapshot() for e in history]) + print('Finished!') + break + elif isinstance(node, Answer): + break + # otherwise just continue + + history_file.write_bytes(ta.dump_json(history)) + + if __name__ == '__main__': import asyncio - - asyncio.run(main()) + import sys + + try: + sub_command = sys.argv[1] + assert sub_command in ('continuous', 'cli', 'mermaid') + except (IndexError, AssertionError): + print( + 'Usage:\n' + ' uv run -m pydantic_ai_examples.question_graph meriad\n' + 'or:\n' + ' uv run -m pydantic_ai_examples.question_graph continuous\n' + 'or:\n' + ' uv run -m pydantic_ai_examples.question_graph cli [answer]', + file=sys.stderr, + ) + sys.exit(1) + + if sub_command == 'mermaid': + print(question_graph.mermaid_code(start_node=Ask)) + elif sub_command == 'continuous': + asyncio.run(run_as_continuous()) + else: + a = sys.argv[2] if len(sys.argv) > 2 else None + asyncio.run(run_as_cli(a)) diff --git a/pydantic_graph/pydantic_graph/nodes.py b/pydantic_graph/pydantic_graph/nodes.py index 1881b10b44..62848fa599 100644 --- a/pydantic_graph/pydantic_graph/nodes.py +++ b/pydantic_graph/pydantic_graph/nodes.py @@ -54,10 +54,12 @@ async def run(self, ctx: GraphContext[StateT]) -> BaseNode[StateT, Any] | End[No @classmethod @cache def get_id(cls) -> str: + """Get the ID of the node.""" return cls.__name__ @classmethod def get_note(cls) -> str | None: + """Get a note about the node to render on mermaid charts.""" if not cls.enable_docstring_notes: return None docstring = cls.__doc__ @@ -73,6 +75,7 @@ def get_note(cls) -> str | None: @classmethod def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, NodeRunEndT]: + """Get the node definition.""" type_hints = get_type_hints(cls.run, localns=local_ns, include_extras=True) try: return_hint = type_hints['return'] diff --git a/tests/test_examples.py b/tests/test_examples.py index 5d113943aa..d8063e1d26 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -167,6 +167,10 
@@ def rich_prompt_ask(prompt: str, *_args: Any, **_kwargs: Any) -> str: return '1' elif prompt == 'Select product': return 'crisps' + elif prompt == 'What is the capital of France?': + return 'Vichy' + elif prompt == 'what is 1 + 1?': + return '2' else: # pragma: no cover raise ValueError(f'Unexpected prompt: {prompt}') @@ -256,6 +260,15 @@ def rich_prompt_ask(prompt: str, *_args: Any, **_kwargs: Any) -> str: tool_name='final_result_SeatPreference', args=ArgsDict({'row': 1, 'seat': 'A'}), ), + 'Ask a simple question with a single correct answer.': 'What is the capital of France?', + '\n What is the capital of France?\n Vichy\n': ToolCallPart( + tool_name='final_result', + args=ArgsDict({'correct': False, 'comment': 'Vichy is no longer the capital of France.'}), + ), + '\n what is 1 + 1?\n 2\n': ToolCallPart( + tool_name='final_result', + args=ArgsDict({'correct': True, 'comment': 'Well done, 1 + 1 = 2'}), + ), } @@ -284,6 +297,8 @@ async def model_logic(messages: list[ModelMessage], info: AgentInfo) -> ModelRes ) elif m.content.startswith('\n '): return ModelResponse(parts=[ToolCallPart(tool_name='final_result_EmailOk', args=ArgsDict({}))]) + elif m.content == 'Ask a simple question with a single correct answer.' and len(messages) > 2: + return ModelResponse.from_text(content='what is 1 + 1?') elif response := text_responses.get(m.content): if isinstance(response, str): return ModelResponse.from_text(content=response) From 5717bd5deb5daf544713ba815ebe80c9f6b61daa Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 14 Jan 2025 00:36:13 +0000 Subject: [PATCH 45/57] fix history serialization --- .../pydantic_ai_examples/question_graph.py | 21 +++--- pydantic_graph/pydantic_graph/graph.py | 46 ++++++++++++- pydantic_graph/pydantic_graph/nodes.py | 9 ++- pydantic_graph/pydantic_graph/state.py | 66 +++++++++++++++---- pydantic_graph/pyproject.toml | 1 + uv.lock | 2 + 6 files changed, 117 insertions(+), 28 deletions(-) diff --git a/examples/pydantic_ai_examples/question_graph.py b/examples/pydantic_ai_examples/question_graph.py index 1ad7649d75..1681eba184 100644 --- a/examples/pydantic_ai_examples/question_graph.py +++ b/examples/pydantic_ai_examples/question_graph.py @@ -7,12 +7,12 @@ from __future__ import annotations as _annotations +import copy from dataclasses import dataclass, field from pathlib import Path from typing import Annotated import logfire -import pydantic from devtools import debug from pydantic_graph import BaseNode, Edge, End, Graph, GraphContext, HistoryStep @@ -109,7 +109,9 @@ async def run(self, ctx: GraphContext[QuestionState]) -> Ask: return Ask() -question_graph = Graph(nodes=(Ask, Answer, Evaluate, Congratulate, Reprimand)) +question_graph = Graph( + nodes=(Ask, Answer, Evaluate, Congratulate, Reprimand), state_type=QuestionState +) async def run_as_continuous(): @@ -128,27 +130,23 @@ async def run_as_continuous(): # otherwise just continue -ta = pydantic.TypeAdapter( - list[Annotated[HistoryStep[QuestionState, None], pydantic.Discriminator('kind')]] -) - - async def run_as_cli(answer: str | None): history_file = Path('question_graph_history.json') if history_file.exists(): - history = ta.validate_json(history_file.read_bytes()) + history = question_graph.load_history(history_file.read_bytes()) else: history = [] if history: last = history[-1] - assert last.kind != 'node', 'expected last step to be a node' + assert last.kind == 'node', 'expected last step to be a node' state = last.state assert answer is not None, 'answer is required to continue from history' node 
= Answer(answer) else: state = QuestionState() node = Ask() + debug(state, node) with logfire.span('run questions graph'): while True: @@ -158,10 +156,13 @@ async def run_as_cli(answer: str | None): print('Finished!') break elif isinstance(node, Answer): + print(state.question) + # hack - history should show the state at the end of the node + history[-1].state = copy.deepcopy(state) break # otherwise just continue - history_file.write_bytes(ta.dump_json(history)) + history_file.write_bytes(question_graph.dump_history(history, indent=2)) if __name__ == '__main__': diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 8085307f85..0d1c02e362 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -4,18 +4,20 @@ import inspect from collections.abc import Sequence from dataclasses import dataclass +from functools import cached_property from pathlib import Path from time import perf_counter from types import FrameType -from typing import TYPE_CHECKING, Any, Callable, Generic +from typing import TYPE_CHECKING, Annotated, Any, Callable, Generic import logfire_api +import pydantic from typing_extensions import Literal, Unpack, assert_never from . import _utils, exceptions, mermaid from ._utils import get_parent_namespace from .nodes import BaseNode, End, GraphContext, NodeDef, RunEndT -from .state import EndStep, HistoryStep, NodeStep, StateT, deep_copy_state +from .state import EndStep, HistoryStep, NodeStep, StateT, deep_copy_state, nodes_schema_var __all__ = ('Graph',) @@ -66,6 +68,7 @@ async def run(self, ctx: GraphContext) -> Increment | End: from the graph. """ + state_type: type[StateT] | None name: str | None node_defs: dict[str, NodeDef[StateT, RunEndT]] snapshot_state: Callable[[StateT], StateT] @@ -74,6 +77,7 @@ def __init__( self, *, nodes: Sequence[type[BaseNode[StateT, RunEndT]]], + state_type: type[StateT] | None = None, name: str | None = None, snapshot_state: Callable[[StateT], StateT] = deep_copy_state, ): @@ -82,12 +86,14 @@ def __init__( Args: nodes: The nodes which make up the graph, nodes need to be unique and all be generic in the same state type. + state_type: The type of the state for the graph. name: Optional name for the graph, if not provided the name will be inferred from the calling frame on the first call to a graph method. snapshot_state: A function to snapshot the state of the graph, this is used in [`NodeStep`][pydantic_graph.state.NodeStep] and [`EndStep`][pydantic_graph.state.EndStep] to record the state before each step. """ + self.state_type = state_type self.name = name self.snapshot_state = snapshot_state @@ -282,6 +288,42 @@ async def next( history_step.duration = perf_counter() - start return next_node + def dump_history(self, history: list[HistoryStep[StateT, RunEndT]], *, indent: int | None = None) -> bytes: + """Dump the history of a graph run as JSON. + + Args: + history: The history of the graph run. + indent: The number of spaces to indent the JSON. + + Returns: + The JSON representation of the history. + """ + return self.history_type_adapter.dump_json(history, indent=indent) + + def load_history(self, json_bytes: str | bytes | bytearray) -> list[HistoryStep[StateT, RunEndT]]: + """Load the history of a graph run from JSON. + + Args: + json_bytes: The JSON representation of the history. + + Returns: + The history of the graph run. 
+ """ + return self.history_type_adapter.validate_json(json_bytes) + + @cached_property + def history_type_adapter(self) -> pydantic.TypeAdapter[list[HistoryStep[StateT, RunEndT]]]: + nodes = [node_def.node for node_def in self.node_defs.values()] + token = nodes_schema_var.set(nodes) + try: + # TODO get the return type RunEndT + ta = pydantic.TypeAdapter( + list[Annotated[HistoryStep[self.state_type or Any, Any], pydantic.Discriminator('kind')]], + ) + finally: + nodes_schema_var.reset(token) + return ta # type: ignore + def mermaid_code( self, *, diff --git a/pydantic_graph/pydantic_graph/nodes.py b/pydantic_graph/pydantic_graph/nodes.py index 62848fa599..fff06f55f9 100644 --- a/pydantic_graph/pydantic_graph/nodes.py +++ b/pydantic_graph/pydantic_graph/nodes.py @@ -3,12 +3,16 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, is_dataclass from functools import cache -from typing import Any, ClassVar, Generic, get_origin, get_type_hints +from typing import TYPE_CHECKING, Any, ClassVar, Generic, get_origin, get_type_hints from typing_extensions import Never, TypeVar from . import _utils, exceptions -from .state import StateT + +if TYPE_CHECKING: + from .state import StateT +else: + StateT = TypeVar('StateT', default=None) __all__ = 'GraphContext', 'BaseNode', 'End', 'Edge', 'NodeDef', 'RunEndT', 'NodeRunEndT' @@ -26,6 +30,7 @@ class GraphContext(Generic[StateT]): """The state of the graph.""" +@dataclass class BaseNode(ABC, Generic[StateT, NodeRunEndT]): """Base class for a node.""" diff --git a/pydantic_graph/pydantic_graph/state.py b/pydantic_graph/pydantic_graph/state.py index 174607362c..b828d1f9de 100644 --- a/pydantic_graph/pydantic_graph/state.py +++ b/pydantic_graph/pydantic_graph/state.py @@ -1,21 +1,21 @@ from __future__ import annotations as _annotations import copy -from dataclasses import InitVar, dataclass, field +from collections.abc import Sequence +from contextvars import ContextVar +from dataclasses import dataclass, field from datetime import datetime -from typing import TYPE_CHECKING, Callable, Generic, Literal, Union +from typing import Annotated, Any, Callable, Generic, Literal, Union +import pydantic +from pydantic_core import core_schema from typing_extensions import TypeVar from . 
import _utils +from .nodes import BaseNode, End, RunEndT __all__ = 'StateT', 'NodeStep', 'EndStep', 'HistoryStep', 'deep_copy_state' -if TYPE_CHECKING: - from .nodes import BaseNode, End, RunEndT -else: - RunEndT = TypeVar('RunEndT', default=None) - StateT = TypeVar('StateT', default=None) """Type variable for the state in a graph.""" @@ -35,7 +35,8 @@ class NodeStep(Generic[StateT, RunEndT]): state: StateT """The state of the graph before the node is run.""" - node: BaseNode[StateT, RunEndT] + # node: Annotated[BaseNode[StateT, RunEndT], pydantic.WrapSerializer(node_serializer), pydantic.PlainValidator(node_validator)] + node: Annotated[BaseNode[StateT, RunEndT], CustomNodeSchema()] """The node that was run.""" start_ts: datetime = field(default_factory=_utils.now_utc) """The timestamp when the node started running.""" @@ -43,12 +44,13 @@ class NodeStep(Generic[StateT, RunEndT]): """The duration of the node run in seconds.""" kind: Literal['node'] = 'node' """The kind of history step, can be used as a discriminator when deserializing history.""" - snapshot_state: InitVar[Callable[[StateT], StateT]] = deep_copy_state + # waiting for https://github.com/pydantic/pydantic/issues/11264, should in InitVar + snapshot_state: Annotated[Callable[[StateT], StateT], pydantic.Field(exclude=True)] = deep_copy_state """Function to snapshot the state of the graph.""" - def __post_init__(self, snapshot_state: Callable[[StateT], StateT]): + def __post_init__(self): # Copy the state to prevent it from being modified by other code - self.state = snapshot_state(self.state) + self.state = self.snapshot_state(self.state) def data_snapshot(self) -> BaseNode[StateT, RunEndT]: """Returns a deep copy of [`self.node`][pydantic_graph.state.NodeStep.node]. @@ -70,12 +72,13 @@ class EndStep(Generic[StateT, RunEndT]): """The timestamp when the graph run ended.""" kind: Literal['end'] = 'end' """The kind of history step, can be used as a discriminator when deserializing history.""" - snapshot_state: InitVar[Callable[[StateT], StateT]] = deep_copy_state + # waiting for https://github.com/pydantic/pydantic/issues/11264, should in InitVar + snapshot_state: Annotated[Callable[[StateT], StateT], pydantic.Field(exclude=True)] = deep_copy_state """Function to snapshot the state of the graph.""" - def __post_init__(self, snapshot_state: Callable[[StateT], StateT]): + def __post_init__(self): # Copy the state to prevent it from being modified by other code - self.state = snapshot_state(self.state) + self.state = self.snapshot_state(self.state) def data_snapshot(self) -> End[RunEndT]: """Returns a deep copy of [`self.result`][pydantic_graph.state.EndStep.result]. @@ -91,3 +94,38 @@ def data_snapshot(self) -> End[RunEndT]: [`Graph.run`][pydantic_graph.graph.Graph.run] returns a list of these steps describing the execution of the graph, together with the run return value. """ + + +nodes_schema_var: ContextVar[Sequence[type[BaseNode[Any, Any]]]] = ContextVar('nodes_var') + + +class CustomNodeSchema: + def __get_pydantic_core_schema__( + self, _source_type: Any, handler: pydantic.GetCoreSchemaHandler + ) -> core_schema.CoreSchema: + try: + nodes = nodes_schema_var.get() + except LookupError as e: + raise RuntimeError( + 'Unable to build a Pydantic schema for `NodeStep` or `HistoryStep` without setting `nodes_schema_var`. 
' + 'You probably want to use ' + ) from e + nodes_annotated = [Annotated[node, pydantic.Tag(node.get_id())] for node in nodes] + nodes_union = Annotated[Union[tuple(nodes_annotated)], pydantic.Discriminator(self._node_discriminator)] + + schema = handler(nodes_union) + schema['serialization'] = core_schema.wrap_serializer_function_ser_schema( + function=self._node_serializer, + return_schema=core_schema.dict_schema(core_schema.str_schema(), core_schema.any_schema()), + ) + return schema + + @staticmethod + def _node_discriminator(node_data: Any) -> str: + return node_data.get('node_id') + + @staticmethod + def _node_serializer(node: Any, handler: pydantic.SerializerFunctionWrapHandler) -> dict[str, Any]: + node_dict = handler(node) + node_dict['node_id'] = node.get_id() + return node_dict diff --git a/pydantic_graph/pyproject.toml b/pydantic_graph/pyproject.toml index 89cbad516c..743814f6b6 100644 --- a/pydantic_graph/pyproject.toml +++ b/pydantic_graph/pyproject.toml @@ -36,6 +36,7 @@ requires-python = ">=3.9" dependencies = [ "httpx>=0.27.2", "logfire-api>=1.2.0", + "pydantic>=2.10", ] [tool.hatch.build.targets.wheel] diff --git a/uv.lock b/uv.lock index 1e998e93c7..38e2fc1712 100644 --- a/uv.lock +++ b/uv.lock @@ -2283,12 +2283,14 @@ source = { editable = "pydantic_graph" } dependencies = [ { name = "httpx" }, { name = "logfire-api" }, + { name = "pydantic" }, ] [package.metadata] requires-dist = [ { name = "httpx", specifier = ">=0.27.2" }, { name = "logfire-api", specifier = ">=1.2.0" }, + { name = "pydantic", specifier = ">=2.10" }, ] [[package]] From 3cb79c872f556830522e0b65839846f16d6c7c64 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 14 Jan 2025 18:21:50 +0000 Subject: [PATCH 46/57] add history (de)serialization tests --- pydantic_graph/pydantic_graph/graph.py | 68 ++------------------------ pydantic_graph/pydantic_graph/nodes.py | 1 - pydantic_graph/pydantic_graph/state.py | 4 +- tests/graph/test_history.py | 52 ++++++++++++++++++++ 4 files changed, 59 insertions(+), 66 deletions(-) create mode 100644 tests/graph/test_history.py diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 0d1c02e362..2effdd8dfd 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -132,75 +132,15 @@ async def main(): _, history = await never_42_graph.run(state, Increment()) print(state) #> MyState(number=2) - print(history) - ''' - [ - NodeStep( - state=MyState(number=1), - node=Increment(), - start_ts=datetime.datetime(...), - duration=0.0..., - kind='node', - ), - NodeStep( - state=MyState(number=2), - node=Check42(), - start_ts=datetime.datetime(...), - duration=0.0..., - kind='node', - ), - EndStep( - state=MyState(number=2), - result=End(data=None), - ts=datetime.datetime(...), - kind='end', - ), - ] - ''' + print(len(history)) + #> 3 state = MyState(41) _, history = await never_42_graph.run(state, Increment()) print(state) #> MyState(number=43) - print(history) - ''' - [ - NodeStep( - state=MyState(number=41), - node=Increment(), - start_ts=datetime.datetime(...), - duration=0.0..., - kind='node', - ), - NodeStep( - state=MyState(number=42), - node=Check42(), - start_ts=datetime.datetime(...), - duration=0.0..., - kind='node', - ), - NodeStep( - state=MyState(number=42), - node=Increment(), - start_ts=datetime.datetime(...), - duration=0.0..., - kind='node', - ), - NodeStep( - state=MyState(number=43), - node=Check42(), - start_ts=datetime.datetime(...), - duration=0.0..., - kind='node', - ), - 
EndStep( - state=MyState(number=43), - result=End(data=None), - ts=datetime.datetime(...), - kind='end', - ), - ] - ''' + print(len(history)) + #> 5 ``` """ if infer_name and self.name is None: diff --git a/pydantic_graph/pydantic_graph/nodes.py b/pydantic_graph/pydantic_graph/nodes.py index fff06f55f9..ef826bb2d0 100644 --- a/pydantic_graph/pydantic_graph/nodes.py +++ b/pydantic_graph/pydantic_graph/nodes.py @@ -30,7 +30,6 @@ class GraphContext(Generic[StateT]): """The state of the graph.""" -@dataclass class BaseNode(ABC, Generic[StateT, NodeRunEndT]): """Base class for a node.""" diff --git a/pydantic_graph/pydantic_graph/state.py b/pydantic_graph/pydantic_graph/state.py index b828d1f9de..c145be54c4 100644 --- a/pydantic_graph/pydantic_graph/state.py +++ b/pydantic_graph/pydantic_graph/state.py @@ -45,7 +45,9 @@ class NodeStep(Generic[StateT, RunEndT]): kind: Literal['node'] = 'node' """The kind of history step, can be used as a discriminator when deserializing history.""" # waiting for https://github.com/pydantic/pydantic/issues/11264, should in InitVar - snapshot_state: Annotated[Callable[[StateT], StateT], pydantic.Field(exclude=True)] = deep_copy_state + snapshot_state: Annotated[Callable[[StateT], StateT], pydantic.Field(exclude=True, repr=False)] = field( + default=deep_copy_state, repr=False + ) """Function to snapshot the state of the graph.""" def __post_init__(self): diff --git a/tests/graph/test_history.py b/tests/graph/test_history.py new file mode 100644 index 0000000000..ef4a455d0e --- /dev/null +++ b/tests/graph/test_history.py @@ -0,0 +1,52 @@ +from __future__ import annotations as _annotations + +from dataclasses import dataclass +from datetime import timezone + +import pytest +from inline_snapshot import snapshot + +from pydantic_graph import BaseNode, End, EndStep, Graph, GraphContext, NodeStep + +from ..conftest import IsFloat, IsNow + +pytestmark = pytest.mark.anyio + + +@dataclass +class MyState: + x: int + y: str + + +@dataclass +class Foo(BaseNode[MyState]): + async def run(self, ctx: GraphContext[MyState]) -> Bar: + ctx.state.x += 1 + return Bar() + + +@dataclass +class Bar(BaseNode[MyState, str]): + async def run(self, ctx: GraphContext[MyState]) -> End[str]: + ctx.state.y += 'y' + return End(f'x={ctx.state.x} y={ctx.state.y}') + + +graph = Graph(nodes=(Foo, Bar), state_type=MyState) + + +async def test_dump_history(): + result, history = await graph.run(MyState(1, ''), Foo()) + assert result == snapshot('x=2 y=y') + assert history == snapshot( + [ + NodeStep(state=MyState(x=1, y=''), node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), + NodeStep(state=MyState(x=2, y=''), node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), + EndStep(state=MyState(x=2, y='y'), result=End(data='x=2 y=y'), ts=IsNow(tz=timezone.utc)), + ] + ) + history_json = graph.dump_history(history) + assert history_json.startswith(b'[{"state":') + history_loaded = graph.load_history(history_json) + assert history == history_loaded From 08bb7dd56ffd36846f2fb4de8ee5d13e754000b3 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 14 Jan 2025 19:18:10 +0000 Subject: [PATCH 47/57] add mermaid diagram section to graph docs --- docs/graph.md | 83 ++++++++++++++++++++++++-- pydantic_graph/pydantic_graph/nodes.py | 17 ++++-- tests/graph/test_mermaid.py | 2 +- 3 files changed, 92 insertions(+), 10 deletions(-) diff --git a/docs/graph.md b/docs/graph.md index 1e2a8f2ffb..79ea118c8d 100644 --- a/docs/graph.md +++ b/docs/graph.md @@ -556,7 +556,6 @@ In this example, an AI 
asks the user a question, the user provides an answer, th async def run(self, ctx: GraphContext[QuestionState]) -> Ask: print(f'Comment: {self.comment}') - #> Comment: Vichy is no longer the capital of France. ctx.state.question = None return Ask() @@ -632,10 +631,84 @@ stateDiagram-v2 Reprimand --> Ask ``` -You maybe have noticed that although this examples transfers control flow out of the graph run, we're still using [rich's `Prompt.ask`][rich.prompt.PromptBase.ask] to get user input, with the process hanging while we wait for the user to enter a response. - -TODO +You maybe have noticed that although this examples transfers control flow out of the graph run, we're still using [rich's `Prompt.ask`][rich.prompt.PromptBase.ask] to get user input, with the process hanging while we wait for the user to enter a response. For an example of genuine out-of-process control flow, see the [question graph example](examples/question-graph.md). ## Mermaid Diagrams -TODO +Pydantic Graph, can generate [mermaid](https://mermaid.js.org/) [`stateDiagram-v2`](https://mermaid.js.org/syntax/stateDiagram.html) diagrams for graphs, as shown above. + +These diagrams can be generated with: + +* [`Graph.mermaid_code`][pydantic_graph.graph.Graph.mermaid_code] to generate the mermaid code for a graph +* [`Graph.mermaid_image`][pydantic_graph.graph.Graph.mermaid_image] to generate an image of the graph using [mermaid.ink](https://mermaid.ink/) +* ['Graph.mermaid_save`][pydantic_graph.graph.Graph.mermaid_save] to generate an image of the graph using [mermaid.ink](https://mermaid.ink/) and save it to a file + +Beyond the diagrams shown above, you can also customise mermaid diagrams with the following options: + +* [`Edge`][pydantic_graph.nodes.Edge] allows you to apply a label to an edge +* [`BaseNode.docstring_notes`][pydantic_graph.nodes.BaseNode.docstring_notes] and [`BaseNode.get_note`][pydantic_graph.nodes.BaseNode.get_note] allows you to add notes to nodes +* The [`highlighted_nodes`][pydantic_graph.graph.Graph.mermaid_code] parameter allows you to highlight specific node(s) in the diagram + +Putting that together, we can edit the last [`ai_q_and_a_graph.py`](#custom-control-flow) example to: + +* add labels to some edges +* add a note to the `Ask` node +* highlight the `Answer` node +* save the diagram as a `PNG` image to file + +```python {title="ai_q_and_a_graph_extra.py" test="skip" lint="skip" hl_lines="2 4 10-11 14 26 31"} +... +from typing import Annotated + +from pydantic_graph import BaseNode, End, Graph, GraphContext, Edge + +... + +@dataclass +class Ask(BaseNode[QuestionState]): + """Generate question using GPT-4o.""" + docstring_notes = True + async def run( + self, ctx: GraphContext[QuestionState] + ) -> Annotated[Answer, Edge(label='Ask the question')]: + ... + +... + +@dataclass +class Evaluate(BaseNode[QuestionState]): + answer: str + + async def run( + self, + ctx: GraphContext[QuestionState], + ) -> Annotated[End[str], Edge(label='success')] | Reprimand: + ... + +... + +question_graph.mermaid_save('image.png', highlighted_nodes=[Answer]) +``` + +_(This example is not complete and cannot be run directly)_ + +Would generate and image that looks like this: + +```mermaid +--- +title: question_graph +--- +stateDiagram-v2 + Ask --> Answer: Ask the question + note right of Ask + Judge the answer. + Decide on next step. 
+ end note + Answer --> Evaluate + Evaluate --> Reprimand + Evaluate --> [*]: success + Reprimand --> Ask + +classDef highlighted fill:#fdff32 +class Answer highlighted +``` diff --git a/pydantic_graph/pydantic_graph/nodes.py b/pydantic_graph/pydantic_graph/nodes.py index ef826bb2d0..11e793c82c 100644 --- a/pydantic_graph/pydantic_graph/nodes.py +++ b/pydantic_graph/pydantic_graph/nodes.py @@ -33,8 +33,13 @@ class GraphContext(Generic[StateT]): class BaseNode(ABC, Generic[StateT, NodeRunEndT]): """Base class for a node.""" - enable_docstring_notes: ClassVar[bool] = True - """Set to `False` to not generate mermaid diagram notes from the class's docstring.""" + docstring_notes: ClassVar[bool] = False + """Set to `True` to generate mermaid diagram notes from the class's docstring. + + While this can add valuable information to the diagram, it can make diagrams harder to view, hence + it is disabled by default. You can also customise notes overriding the + [`get_note`][pydantic_graph.nodes.BaseNode.get_note] method. + """ @abstractmethod async def run(self, ctx: GraphContext[StateT]) -> BaseNode[StateT, Any] | End[NodeRunEndT]: @@ -63,8 +68,12 @@ def get_id(cls) -> str: @classmethod def get_note(cls) -> str | None: - """Get a note about the node to render on mermaid charts.""" - if not cls.enable_docstring_notes: + """Get a note about the node to render on mermaid charts. + + By default, this returns a note only if [`docstring_notes`][pydantic_graph.nodes.BaseNode.docstring_notes] + is `True`. You can override this method to customise the node notes. + """ + if not cls.docstring_notes: return None docstring = cls.__doc__ # dataclasses get an automatic docstring which is just their signature, we don't want that diff --git a/tests/graph/test_mermaid.py b/tests/graph/test_mermaid.py index 1f40df5c8c..fc9a99c123 100644 --- a/tests/graph/test_mermaid.py +++ b/tests/graph/test_mermaid.py @@ -46,7 +46,7 @@ async def run(self, ctx: GraphContext) -> Annotated[Foo, Edge(label='spam to foo class Eggs(BaseNode[None, None]): """This is the docstring for Eggs.""" - enable_docstring_notes = False + docstring_notes = False async def run(self, ctx: GraphContext) -> Annotated[End[None], Edge(label='eggs to end')]: raise NotImplementedError() From 466a7dff817ae195a782303c984fae7f18d4fec9 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 14 Jan 2025 19:27:37 +0000 Subject: [PATCH 48/57] fix tests --- .github/workflows/ci.yml | 2 +- tests/graph/test_mermaid.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index cc694f175c..efbce8fc51 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -10,7 +10,7 @@ on: env: CI: true - COLUMNS: 120 + COLUMNS: 150 UV_PYTHON: 3.12 UV_FROZEN: '1' diff --git a/tests/graph/test_mermaid.py b/tests/graph/test_mermaid.py index fc9a99c123..2cf2b01423 100644 --- a/tests/graph/test_mermaid.py +++ b/tests/graph/test_mermaid.py @@ -38,6 +38,8 @@ async def run(self, ctx: GraphContext) -> End[None]: class Spam(BaseNode): """This is the docstring for Spam.""" + docstring_notes = True + async def run(self, ctx: GraphContext) -> Annotated[Foo, Edge(label='spam to foo')]: raise NotImplementedError() @@ -192,6 +194,11 @@ def test_mermaid_code_all_nodes(): """) +def test_docstring_notes_classvar(): + assert Spam.docstring_notes is True + assert repr(Spam()) == 'Spam()' + + @pytest.fixture def httpx_with_handler() -> Iterator[HttpxWithHandler]: client: httpx.Client | None = None From 
8098d34462ff34a87d5deecf3a620e4251ee7e66 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 14 Jan 2025 19:31:52 +0000 Subject: [PATCH 49/57] add exceptions docs --- docs/api/pydantic_graph/exceptions.md | 3 +++ mkdocs.yml | 1 + 2 files changed, 4 insertions(+) create mode 100644 docs/api/pydantic_graph/exceptions.md diff --git a/docs/api/pydantic_graph/exceptions.md b/docs/api/pydantic_graph/exceptions.md new file mode 100644 index 0000000000..ab9f3782ee --- /dev/null +++ b/docs/api/pydantic_graph/exceptions.md @@ -0,0 +1,3 @@ +# `pydantic_graph.exceptions` + +::: pydantic_graph.exceptions diff --git a/mkdocs.yml b/mkdocs.yml index 81faa18225..1649d3dd48 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -61,6 +61,7 @@ nav: - api/pydantic_graph/nodes.md - api/pydantic_graph/state.md - api/pydantic_graph/mermaid.md + - api/pydantic_graph/exceptions.md extra: # hide the "Made with Material for MkDocs" message From a3f507a996cc020c0f04117371829f87fc18f7ab Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 14 Jan 2025 19:40:51 +0000 Subject: [PATCH 50/57] docs tweaks --- docs/graph.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/graph.md b/docs/graph.md index 79ea118c8d..0aa2156911 100644 --- a/docs/graph.md +++ b/docs/graph.md @@ -20,7 +20,7 @@ While this library is developed as part of the PydanticAI; it has no dependency `pydantic-graph` is designed for advanced users and makes heavy use of Python generics and types hints. It is not designed to be as beginner-friendly as PydanticAI. -!!! note "Every Early beta" +!!! note "Very Early beta" Graph support was [introduced](https://github.com/pydantic/pydantic-ai/pull/528) in v0.0.19 and is in very earlier beta. The API is subject to change. The documentation is incomplete. The implementation is incomplete. ## Installation @@ -89,7 +89,7 @@ class MyNode(BaseNode[MyState]): # (1)! 
We could extend `MyNode` to optionally end the run if `foo` is divisible by 5: -```py {title="intermediate_or_end_node.py" hl_lines="7 13" noqa="F821" test="skip"} +```py {title="intermediate_or_end_node.py" hl_lines="7 13 15" noqa="F821" test="skip"} from dataclasses import dataclass from pydantic_graph import BaseNode, End, GraphContext @@ -641,7 +641,7 @@ These diagrams can be generated with: * [`Graph.mermaid_code`][pydantic_graph.graph.Graph.mermaid_code] to generate the mermaid code for a graph * [`Graph.mermaid_image`][pydantic_graph.graph.Graph.mermaid_image] to generate an image of the graph using [mermaid.ink](https://mermaid.ink/) -* ['Graph.mermaid_save`][pydantic_graph.graph.Graph.mermaid_save] to generate an image of the graph using [mermaid.ink](https://mermaid.ink/) and save it to a file +* [`Graph.mermaid_save`][pydantic_graph.graph.Graph.mermaid_save] to generate an image of the graph using [mermaid.ink](https://mermaid.ink/) and save it to a file Beyond the diagrams shown above, you can also customise mermaid diagrams with the following options: From 4e9b5162e980e172cc8994e3ab11f6382e640290 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 15 Jan 2025 13:19:47 +0000 Subject: [PATCH 51/57] copy edits from @dmontagu --- docs/graph.md | 32 ++++++++-------- docs/multi-agent-applications.md | 4 ++ pydantic_graph/README.md | 53 +++++++++++++++++++++++--- pydantic_graph/pydantic_graph/nodes.py | 2 +- pydantic_graph/pydantic_graph/state.py | 3 +- pydantic_graph/pyproject.toml | 6 +++ 6 files changed, 76 insertions(+), 24 deletions(-) diff --git a/docs/graph.md b/docs/graph.md index 0aa2156911..ba5e00f511 100644 --- a/docs/graph.md +++ b/docs/graph.md @@ -53,14 +53,14 @@ Subclasses of [`BaseNode`][pydantic_graph.nodes.BaseNode] to define nodes. Nodes which are generally [`dataclass`es][dataclasses.dataclass] include: -* any parameters required when calling the node -* the business logic to execute the node -* return annotations which are read by `pydantic-graph` to determine the outgoing edges of the node +* fields containing any parameters required/optional when calling the node +* the business logic to execute the node, in the [`run`][pydantic_graph.nodes.BaseNode.run] method +* return annotations of the [`run`][pydantic_graph.nodes.BaseNode.run] method, which are read by `pydantic-graph` to determine the outgoing edges of the node Nodes are generic in: -* **state** which must have the same type as the state of graphs they're included in, [`StateT`][pydantic_graph.state.StateT] has a default of `None`, so if you're not using state you can omit this generic parameter -* **graph return type** this only applies if the node returns [`End`][pydantic_graph.nodes.End], [`RunEndT`][pydantic_graph.nodes.RunEndT] has a default of [Never][typing.Never] so this generic parameter can be omitted if the node doesn't return `End`, but must be included if it does. +* **state**, which must have the same type as the state of graphs they're included in, [`StateT`][pydantic_graph.state.StateT] has a default of `None`, so if you're not using state you can omit this generic parameter +* **graph return type** — this only applies if the node returns [`End`][pydantic_graph.nodes.End]. [`RunEndT`][pydantic_graph.nodes.RunEndT] has a default of [Never][typing.Never] so this generic parameter can be omitted if the node doesn't return `End`, but must be included if it does. 
Here's an example of a start or intermediate node in a graph — it can't end the run as it doesn't return [`End`][pydantic_graph.nodes.End]: @@ -114,7 +114,7 @@ class MyNode(BaseNode[MyState, int]): # (1)! ### Graph -[`Graph`][pydantic_graph.graph.Graph] — The graph itself, made up of a set of nodes. +[`Graph`][pydantic_graph.graph.Graph] — this is the execution graph itself, made up of a set of [node classes](#nodes) (i.e., `BaseNode` subclasses). `Graph` is generic in: @@ -284,7 +284,7 @@ async def main(): 2. A dictionary of products mapped to prices. 3. The `InsertCoin` node, [`BaseNode`][pydantic_graph.nodes.BaseNode] is parameterized with `MachineState` as that's the state used in this graph. 4. The `InsertCoin` node prompts the user to insert coins. We keep things simple by just entering a monetary amount as a float. Before you start thinking this is a toy too since it's using [rich's `Prompt.ask`][rich.prompt.PromptBase.ask] within nodes, see [below](#custom-control-flow) for how control flow can be managed when nodes require external input. -5. The `CoinsInserted` node, again this is a [`dataclass`][dataclasses.dataclass], in this case with one field `amount`, thus nodes calling `CoinsInserted` must provide an amount. +5. The `CoinsInserted` node; again this is a [`dataclass`][dataclasses.dataclass], in this case with one field `amount`, thus nodes calling `CoinsInserted` must provide an amount. 6. Update the user's balance with the amount inserted. 7. If the user has already selected a product, go to `Purchase`, otherwise go to `SelectProduct`. 8. In the `Purchase` node, look up the price of the product if the user entered a valid product. @@ -293,11 +293,11 @@ async def main(): 11. If the balance is insufficient, to go `InsertCoin` to prompt the user to insert more coins. 12. If the product is invalid, go to `SelectProduct` to prompt the user to select a product again. 13. The graph is created by passing a list of nodes to [`Graph`][pydantic_graph.graph.Graph]. Order of nodes is not important, but will alter how [diagrams](#mermaid-diagrams) are displayed. -14. Initialize the state, this will be passed to the graph run and mutated as the graph runs. -15. Run the graph with the initial state, since the graph can be run from any node, we must pass the start node, in this case `InsertCoin`. [`Graph.run`][pydantic_graph.graph.Graph.run] returns a tuple of the return value (`None`) in this case, and the [history][pydantic_graph.state.HistoryStep] of the graph run. -16. The return type of the node's [`run`][pydantic_graph.nodes.BaseNode.run] method is important, it's used to determine the outgoing edges of the node, this in turn is used to render [mermaid diagrams](#mermaid-diagrams) and is enforced at runtime. -17. The return type of `CoinsInserted`s [`run`][pydantic_graph.nodes.BaseNode.run] method is a union, meaning multiple outgoing edges are possible. -18. Unlike other nodes `Purchase` can end the run, so the [`RunEndT`][pydantic_graph.nodes.RunEndT] generic parameter must be set, in this case it's `None` since the graph run return type is `None`. +14. Initialize the state. This will be passed to the graph run and mutated as the graph runs. +15. Run the graph with the initial state. Since the graph can be run from any node, we must pass the start node — in this case, `InsertCoin`. [`Graph.run`][pydantic_graph.graph.Graph.run] returns a tuple of the return value (`None`) in this case, and the [history][pydantic_graph.state.HistoryStep] of the graph run. +16. 
The return type of the node's [`run`][pydantic_graph.nodes.BaseNode.run] method is important as it is used to determine the outgoing edges of the node. This information in turn is used to render [mermaid diagrams](#mermaid-diagrams) and is enforced at runtime to detect misbehavior as soon as possible. +17. The return type of `CoinsInserted`'s [`run`][pydantic_graph.nodes.BaseNode.run] method is a union, meaning multiple outgoing edges are possible. +18. Unlike other nodes, `Purchase` can end the run, so the [`RunEndT`][pydantic_graph.nodes.RunEndT] generic parameter must be set. In this case it's `None` since the graph run return type is `None`. _(This example is complete, it can be run "as is" with Python 3.10+ — you'll need to add `asyncio.run(main())` to run `main`)_ @@ -467,7 +467,7 @@ _(This example is complete, it can be run "as is" with Python 3.10+ — you'll n ## Custom Control Flow -In many real-world applications, Graphs cannot run uninterrupted from start to finish — they require external input or run over an extended period of time where a single process cannot execute the entire graph. +In many real-world applications, Graphs cannot run uninterrupted from start to finish — they might require external input, or run over an extended period of time such that a single process cannot execute the entire graph run from start to finish without interruption. In these scenarios the [`next`][pydantic_graph.graph.Graph.next] method can be used to run the graph one node at a time. @@ -603,7 +603,7 @@ async def main(): 1. Create the state object which will be mutated by [`next`][pydantic_graph.graph.Graph.next]. 2. The start node is `Ask` but will be updated by [`next`][pydantic_graph.graph.Graph.next] as the graph runs. -3. The history of the graph run is stored in a list of [`HistoryStep`][pydantic_graph.state.HistoryStep] objects, again [`next`][pydantic_graph.graph.Graph.next] will update this list in place. +3. The history of the graph run is stored in a list of [`HistoryStep`][pydantic_graph.state.HistoryStep] objects. Again [`next`][pydantic_graph.graph.Graph.next] will update this list in place. 4. [Run][pydantic_graph.graph.Graph.next] the graph one node at a time, updating the state, current node and history as the graph runs. 5. If the current node is an `Answer` node, prompt the user for an answer. 6. Since we're using [`next`][pydantic_graph.graph.Graph.next] we have to manually check for an [`End`][pydantic_graph.nodes.End] and exit the loop if we get one. @@ -635,7 +635,7 @@ You maybe have noticed that although this examples transfers control flow out of ## Mermaid Diagrams -Pydantic Graph, can generate [mermaid](https://mermaid.js.org/) [`stateDiagram-v2`](https://mermaid.js.org/syntax/stateDiagram.html) diagrams for graphs, as shown above. +Pydantic Graph can generate [mermaid](https://mermaid.js.org/) [`stateDiagram-v2`](https://mermaid.js.org/syntax/stateDiagram.html) diagrams for graphs, as shown above. 
These diagrams can be generated with: @@ -643,7 +643,7 @@ These diagrams can be generated with: * [`Graph.mermaid_image`][pydantic_graph.graph.Graph.mermaid_image] to generate an image of the graph using [mermaid.ink](https://mermaid.ink/) * [`Graph.mermaid_save`][pydantic_graph.graph.Graph.mermaid_save] to generate an image of the graph using [mermaid.ink](https://mermaid.ink/) and save it to a file -Beyond the diagrams shown above, you can also customise mermaid diagrams with the following options: +Beyond the diagrams shown above, you can also customize mermaid diagrams with the following options: * [`Edge`][pydantic_graph.nodes.Edge] allows you to apply a label to an edge * [`BaseNode.docstring_notes`][pydantic_graph.nodes.BaseNode.docstring_notes] and [`BaseNode.get_note`][pydantic_graph.nodes.BaseNode.get_note] allows you to add notes to nodes diff --git a/docs/multi-agent-applications.md b/docs/multi-agent-applications.md index e5ced781e0..38e2bbd074 100644 --- a/docs/multi-agent-applications.md +++ b/docs/multi-agent-applications.md @@ -330,6 +330,10 @@ graph TB seat_preference_agent --> END ``` +## Pydantic Graphs + +See the [graph](graph.md) documentation on when and how to use graphs. + ## Examples The following examples demonstrate how to use dependencies in PydanticAI: diff --git a/pydantic_graph/README.md b/pydantic_graph/README.md index 069f58924d..e01f9bf80d 100644 --- a/pydantic_graph/README.md +++ b/pydantic_graph/README.md @@ -3,14 +3,57 @@ [![CI](https://github.com/pydantic/pydantic-ai/actions/workflows/ci.yml/badge.svg?event=push)](https://github.com/pydantic/pydantic-ai/actions/workflows/ci.yml?query=branch%3Amain) [![Coverage](https://coverage-badge.samuelcolvin.workers.dev/pydantic/pydantic-ai.svg)](https://coverage-badge.samuelcolvin.workers.dev/redirect/pydantic/pydantic-ai) [![PyPI](https://img.shields.io/pypi/v/pydantic-graph.svg)](https://pypi.python.org/pypi/pydantic-graph) -[![versions](https://img.shields.io/pypi/pyversions/pydantic-graph.svg)](https://github.com/pydantic/pydantic-ai) -[![license](https://img.shields.io/github/license/pydantic/pydantic-ai.svg?v)](https://github.com/pydantic/pydantic-ai/blob/main/LICENSE) +[![python versions](https://img.shields.io/pypi/pyversions/pydantic-graph.svg)](https://github.com/pydantic/pydantic-ai) +[![license](https://img.shields.io/github/license/pydantic/pydantic-ai.svg)](https://github.com/pydantic/pydantic-ai/blob/main/LICENSE) Graph and finite state machine library. -This library is developed as part of the [PydanticAI](https://ai.pydantic.dev), however it has no dependency -on `pydantic-ai` or related packages and can be considered as a pure graph library. +This library is developed as part of [PydanticAI](https://ai.pydantic.dev), however it has no dependency +on `pydantic-ai` or related packages and can be considered as a pure graph-based state machine library. You may find it useful whether or not you're using PydanticAI or even building with GenAI. As with PydanticAI, this library prioritizes type safety and use of common Python syntax over esoteric, domain-specific use of Python syntax. -`pydantic-graph` allows you to define graphs using simple Python syntax. In particular, edges are defined using the return type hint of nodes. +`pydantic-graph` allows you to define graphs using standard Python syntax. In particular, edges are defined using the return type hint of nodes. + +Full documentation is available at [ai.pydantic.dev/graph](https://ai.pydantic.dev/graph). 
+ +Here's a basic example: + +```python +from __future__ import annotations + +from dataclasses import dataclass + +from pydantic_graph import BaseNode, End, Graph, GraphContext + + +@dataclass +class DivisibleBy5(BaseNode[None, int]): + foo: int + + async def run( + self, + ctx: GraphContext, + ) -> Increment | End[int]: + if self.foo % 5 == 0: + return End(self.foo) + else: + return Increment(self.foo) + + +@dataclass +class Increment(BaseNode): + foo: int + + async def run(self, ctx: GraphContext) -> DivisibleBy5: + return DivisibleBy5(self.foo + 1) + + +fives_graph = Graph(nodes=[DivisibleBy5, Increment]) +result, history = fives_graph.run_sync(None, DivisibleBy5(4)) +print(result) +#> 5 +# the full history is quite verbose (see below), so we'll just print the summary +print([item.data_snapshot() for item in history]) +#> [DivisibleBy5(foo=4), Increment(foo=4), DivisibleBy5(foo=5), End(data=5)] +``` diff --git a/pydantic_graph/pydantic_graph/nodes.py b/pydantic_graph/pydantic_graph/nodes.py index 11e793c82c..ec97b16e84 100644 --- a/pydantic_graph/pydantic_graph/nodes.py +++ b/pydantic_graph/pydantic_graph/nodes.py @@ -142,7 +142,7 @@ class Edge: class NodeDef(Generic[StateT, NodeRunEndT]): """Definition of a node. - This is an internal representation of a node, it shouldn't be necessary to use it directly. + This is a primarily internal representation of a node; in general, it shouldn't be necessary to use it directly. Used by [`Graph`][pydantic_graph.graph.Graph] to store information about a node, and when generating mermaid graphs. diff --git a/pydantic_graph/pydantic_graph/state.py b/pydantic_graph/pydantic_graph/state.py index c145be54c4..b1b596b855 100644 --- a/pydantic_graph/pydantic_graph/state.py +++ b/pydantic_graph/pydantic_graph/state.py @@ -35,7 +35,6 @@ class NodeStep(Generic[StateT, RunEndT]): state: StateT """The state of the graph before the node is run.""" - # node: Annotated[BaseNode[StateT, RunEndT], pydantic.WrapSerializer(node_serializer), pydantic.PlainValidator(node_validator)] node: Annotated[BaseNode[StateT, RunEndT], CustomNodeSchema()] """The node that was run.""" start_ts: datetime = field(default_factory=_utils.now_utc) @@ -44,7 +43,7 @@ class NodeStep(Generic[StateT, RunEndT]): """The duration of the node run in seconds.""" kind: Literal['node'] = 'node' """The kind of history step, can be used as a discriminator when deserializing history.""" - # waiting for https://github.com/pydantic/pydantic/issues/11264, should in InitVar + # waiting for https://github.com/pydantic/pydantic/issues/11264, should be an InitVar snapshot_state: Annotated[Callable[[StateT], StateT], pydantic.Field(exclude=True, repr=False)] = field( default=deep_copy_state, repr=False ) diff --git a/pydantic_graph/pyproject.toml b/pydantic_graph/pyproject.toml index 743814f6b6..948e48e462 100644 --- a/pydantic_graph/pyproject.toml +++ b/pydantic_graph/pyproject.toml @@ -39,5 +39,11 @@ dependencies = [ "pydantic>=2.10", ] +[project.urls] +Homepage = "https://ai.pydantic.dev/graph/tree/main/pydantic_graph" +Source = "https://github.com/pydantic/pydantic-ai" +Documentation = "https://ai.pydantic.dev/graph" +Changelog = "https://github.com/pydantic/pydantic-ai/releases" + [tool.hatch.build.targets.wheel] packages = ["pydantic_graph"] From a834eed97ea0a6cf1a836a44c19d517bb68ab93c Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 15 Jan 2025 14:01:41 +0000 Subject: [PATCH 52/57] fix pydantic-graph readme --- pydantic_graph/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff 
--git a/pydantic_graph/README.md b/pydantic_graph/README.md index e01f9bf80d..6ac17fb216 100644 --- a/pydantic_graph/README.md +++ b/pydantic_graph/README.md @@ -19,7 +19,7 @@ Full documentation is available at [ai.pydantic.dev/graph](https://ai.pydantic.d Here's a basic example: -```python +```python {noqa="I001" py="3.10"} from __future__ import annotations from dataclasses import dataclass From 447a259c65578eeb72b5720a2bad9054852da04d Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 15 Jan 2025 14:46:27 +0000 Subject: [PATCH 53/57] adding deps to graphs --- docs/graph.md | 21 ++++---- .../pydantic_ai_examples/question_graph.py | 6 +-- pydantic_graph/README.md | 4 +- pydantic_graph/pydantic_graph/graph.py | 54 +++++++++++-------- pydantic_graph/pydantic_graph/mermaid.py | 8 +-- pydantic_graph/pydantic_graph/nodes.py | 18 ++++--- pydantic_graph/pydantic_graph/state.py | 8 +-- tests/graph/test_graph.py | 36 ++++++------- tests/graph/test_history.py | 4 +- tests/graph/test_mermaid.py | 6 +-- tests/graph/test_state.py | 4 +- tests/typed_graph.py | 54 ++++++++++++++++--- 12 files changed, 139 insertions(+), 84 deletions(-) diff --git a/docs/graph.md b/docs/graph.md index ba5e00f511..1128b2561c 100644 --- a/docs/graph.md +++ b/docs/graph.md @@ -12,7 +12,7 @@ Unless you're sure you need a graph, you probably don't. -Graphs and finite state machines (FSMs) are a powerful abstraction to model, control and visualize complex workflows. +Graphs and finite state machines (FSMs) are a powerful abstraction to model, execute, control and visualize complex workflows. Alongside PydanticAI, we've developed `pydantic-graph` — an async graph and state machine library for Python where nodes and edges are defined using type hints. @@ -60,6 +60,7 @@ Nodes which are generally [`dataclass`es][dataclasses.dataclass] include: Nodes are generic in: * **state**, which must have the same type as the state of graphs they're included in, [`StateT`][pydantic_graph.state.StateT] has a default of `None`, so if you're not using state you can omit this generic parameter +* **deps**, which must have the same type as the deps of the graph they're included in, [`DepsT`][pydantic_graph.state.DepsT] has a default of `None`, so if you're not using deps you can omit this generic parameter * **graph return type** — this only applies if the node returns [`End`][pydantic_graph.nodes.End]. [`RunEndT`][pydantic_graph.nodes.RunEndT] has a default of [Never][typing.Never] so this generic parameter can be omitted if the node doesn't return `End`, but must be included if it does. Here's an example of a start or intermediate node in a graph — it can't end the run as it doesn't return [`End`][pydantic_graph.nodes.End]: @@ -96,7 +97,7 @@ from pydantic_graph import BaseNode, End, GraphContext @dataclass -class MyNode(BaseNode[MyState, int]): # (1)! +class MyNode(BaseNode[MyState, None, int]): # (1)! foo: int async def run( @@ -109,7 +110,7 @@ class MyNode(BaseNode[MyState, int]): # (1)! return AnotherNode() ``` -1. We parameterize the node with the return type (`int` in this case) as well as state. +1. We parameterize the node with the return type (`int` in this case) as well as state. Because generic parameters are positional-only, we have to include `None` as the second parameter representing deps. 2. The return type of the `run` method is now a union of `AnotherNode` and `End[int]`, this allows the node to end the run if `foo` is divisible by 5. 
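To restate note 1 above with a minimal sketch (an illustration against the `BaseNode[StateT, DepsT, RunEndT]` signature this patch introduces, not code taken from the patch): a node that never returns `End` can name only its state type, while a node that can end the run must also spell out the deps slot, even if only as `None`.

```python
# Sketch only (assumption): shows BaseNode generic parameter ordering; not part of the original patch.
from __future__ import annotations

from dataclasses import dataclass

from pydantic_graph import BaseNode, End, Graph, GraphContext


@dataclass
class MyState:
    number: int


@dataclass
class NeverEnds(BaseNode[MyState]):  # only StateT is given; DepsT and RunEndT keep their defaults
    async def run(self, ctx: GraphContext[MyState]) -> CanEnd:
        ctx.state.number += 1
        return CanEnd()


@dataclass
class CanEnd(BaseNode[MyState, None, int]):  # DepsT must be written out (here `None`) to reach RunEndT
    async def run(self, ctx: GraphContext[MyState]) -> NeverEnds | End[int]:
        if ctx.state.number % 2 == 0:
            return End(ctx.state.number)
        else:
            return NeverEnds()


example_graph = Graph(nodes=(NeverEnds, CanEnd))
```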
### Graph @@ -132,7 +133,7 @@ from pydantic_graph import BaseNode, End, Graph, GraphContext @dataclass -class DivisibleBy5(BaseNode[None, int]): # (1)! +class DivisibleBy5(BaseNode[None, None, int]): # (1)! foo: int async def run( @@ -154,7 +155,7 @@ class Increment(BaseNode): # (2)! fives_graph = Graph(nodes=[DivisibleBy5, Increment]) # (3)! -result, history = fives_graph.run_sync(None, DivisibleBy5(4)) # (4)! +result, history = fives_graph.run_sync(DivisibleBy5(4)) # (4)! print(result) #> 5 # the full history is quite verbose (see below), so we'll just print the summary @@ -247,7 +248,7 @@ PRODUCT_PRICES = { # (2)! @dataclass -class Purchase(BaseNode[MachineState, None]): # (18)! +class Purchase(BaseNode[MachineState, None, None]): # (18)! product: str async def run( @@ -275,7 +276,7 @@ vending_machine_graph = Graph( # (13)! async def main(): state = MachineState() # (14)! - await vending_machine_graph.run(state, InsertCoin()) # (15)! + await vending_machine_graph.run(InsertCoin(), state=state) # (15)! print(f'purchase successful item={state.product} change={state.user_balance:0.2f}') #> purchase successful item=crisps change=0.25 ``` @@ -430,7 +431,7 @@ feedback_agent = Agent[None, EmailRequiresWrite | EmailOk]( @dataclass -class Feedback(BaseNode[State, Email]): +class Feedback(BaseNode[State, None, Email]): email: Email async def run( @@ -453,7 +454,7 @@ async def main(): ) state = State(user) feedback_graph = Graph(nodes=(WriteEmail, Feedback)) - email, _ = await feedback_graph.run(state, WriteEmail()) + email, _ = await feedback_graph.run(WriteEmail(), state=state) print(email) """ Email( @@ -579,7 +580,7 @@ async def main(): node = Ask() # (2)! history: list[HistoryStep[QuestionState]] = [] # (3)! while True: - node = await question_graph.next(state, node, history) # (4)! + node = await question_graph.next(node, history, state=state) # (4)! if isinstance(node, Answer): node.answer = Prompt.ask(node.question) # (5)! elif isinstance(node, End): # (6)! 
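The same `Graph.next` loop can also be driven without any external input. Here's a minimal sketch under the new keyword-only `state`/`deps` signature, assuming the `DivisibleBy5`/`Increment` nodes and `fives_graph` from the README example earlier are in scope:

```python
import asyncio

from pydantic_graph import End, HistoryStep


async def drive_manually() -> int:
    node = DivisibleBy5(4)  # start node, as in the README example
    history: list[HistoryStep[None, int]] = []
    while True:
        # state and deps both default to None for this graph
        next_node = await fives_graph.next(node, history)
        if isinstance(next_node, End):
            return next_node.data
        node = next_node


print(asyncio.run(drive_manually()))
#> 5
```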
diff --git a/examples/pydantic_ai_examples/question_graph.py b/examples/pydantic_ai_examples/question_graph.py index 1681eba184..636d790605 100644 --- a/examples/pydantic_ai_examples/question_graph.py +++ b/examples/pydantic_ai_examples/question_graph.py @@ -88,7 +88,7 @@ async def run( @dataclass -class Congratulate(BaseNode[QuestionState, None]): +class Congratulate(BaseNode[QuestionState, None, None]): comment: str async def run( @@ -120,7 +120,7 @@ async def run_as_continuous(): history: list[HistoryStep[QuestionState, None]] = [] with logfire.span('run questions graph'): while True: - node = await question_graph.next(state, node, history) + node = await question_graph.next(node, history, state=state) if isinstance(node, End): debug([e.data_snapshot() for e in history]) break @@ -150,7 +150,7 @@ async def run_as_cli(answer: str | None): with logfire.span('run questions graph'): while True: - node = await question_graph.next(state, node, history) + node = await question_graph.next(node, history, state=state) if isinstance(node, End): debug([e.data_snapshot() for e in history]) print('Finished!') diff --git a/pydantic_graph/README.md b/pydantic_graph/README.md index 6ac17fb216..0b2508f033 100644 --- a/pydantic_graph/README.md +++ b/pydantic_graph/README.md @@ -28,7 +28,7 @@ from pydantic_graph import BaseNode, End, Graph, GraphContext @dataclass -class DivisibleBy5(BaseNode[None, int]): +class DivisibleBy5(BaseNode[None, None, int]): foo: int async def run( @@ -50,7 +50,7 @@ class Increment(BaseNode): fives_graph = Graph(nodes=[DivisibleBy5, Increment]) -result, history = fives_graph.run_sync(None, DivisibleBy5(4)) +result, history = fives_graph.run_sync(DivisibleBy5(4)) print(result) #> 5 # the full history is quite verbose (see below), so we'll just print the summary diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 2effdd8dfd..d146ca507c 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -16,7 +16,7 @@ from . import _utils, exceptions, mermaid from ._utils import get_parent_namespace -from .nodes import BaseNode, End, GraphContext, NodeDef, RunEndT +from .nodes import BaseNode, DepsT, End, GraphContext, NodeDef, RunEndT from .state import EndStep, HistoryStep, NodeStep, StateT, deep_copy_state, nodes_schema_var __all__ = ('Graph',) @@ -25,7 +25,7 @@ @dataclass(init=False) -class Graph(Generic[StateT, RunEndT]): +class Graph(Generic[StateT, DepsT, RunEndT]): """Definition of a graph. In `pydantic-graph`, a graph is a collection of nodes that can be run in sequence. 
The nodes define @@ -52,7 +52,7 @@ async def run(self, ctx: GraphContext) -> Check42: return Check42() @dataclass - class Check42(BaseNode[MyState]): + class Check42(BaseNode[MyState, None, None]): async def run(self, ctx: GraphContext) -> Increment | End: if ctx.state.number == 42: return Increment() @@ -70,13 +70,13 @@ async def run(self, ctx: GraphContext) -> Increment | End: state_type: type[StateT] | None name: str | None - node_defs: dict[str, NodeDef[StateT, RunEndT]] + node_defs: dict[str, NodeDef[StateT, DepsT, RunEndT]] snapshot_state: Callable[[StateT], StateT] def __init__( self, *, - nodes: Sequence[type[BaseNode[StateT, RunEndT]]], + nodes: Sequence[type[BaseNode[StateT, DepsT, RunEndT]]], state_type: type[StateT] | None = None, name: str | None = None, snapshot_state: Callable[[StateT], StateT] = deep_copy_state, @@ -98,7 +98,7 @@ def __init__( self.snapshot_state = snapshot_state parent_namespace = get_parent_namespace(inspect.currentframe()) - self.node_defs: dict[str, NodeDef[StateT, RunEndT]] = {} + self.node_defs: dict[str, NodeDef[StateT, DepsT, RunEndT]] = {} for node in nodes: self._register_node(node, parent_namespace) @@ -106,17 +106,19 @@ def __init__( async def run( self, - state: StateT, - start_node: BaseNode[StateT, RunEndT], + start_node: BaseNode[StateT, DepsT, RunEndT], *, + state: StateT = None, + deps: DepsT = None, infer_name: bool = True, ) -> tuple[RunEndT, list[HistoryStep[StateT, RunEndT]]]: """Run the graph from a starting node until it ends. Args: - state: The initial state of the graph. start_node: the first node to run, since the graph definition doesn't define the entry point in the graph, you need to provide the starting node. + state: The initial state of the graph. + deps: The dependencies of the graph. infer_name: Whether to infer the graph name from the calling frame. Returns: @@ -129,14 +131,14 @@ async def run( async def main(): state = MyState(1) - _, history = await never_42_graph.run(state, Increment()) + _, history = await never_42_graph.run(Increment(), state=state) print(state) #> MyState(number=2) print(len(history)) #> 3 state = MyState(41) - _, history = await never_42_graph.run(state, Increment()) + _, history = await never_42_graph.run(Increment(), state=state) print(state) #> MyState(number=43) print(len(history)) @@ -153,7 +155,7 @@ async def main(): start=start_node, ) as run_span: while True: - next_node = await self.next(state, start_node, history, infer_name=False) + next_node = await self.next(start_node, history, state=state, deps=deps, infer_name=False) if isinstance(next_node, End): history.append(EndStep(state, next_node, snapshot_state=self.snapshot_state)) run_span.set_attribute('history', history) @@ -170,9 +172,10 @@ async def main(): def run_sync( self, - state: StateT, - start_node: BaseNode[StateT, RunEndT], + start_node: BaseNode[StateT, DepsT, RunEndT], *, + state: StateT = None, + deps: DepsT = None, infer_name: bool = True, ) -> tuple[RunEndT, list[HistoryStep[StateT, RunEndT]]]: """Run the graph synchronously. @@ -181,9 +184,10 @@ def run_sync( You therefore can't use this method inside async code or if there's an active event loop. Args: - state: The initial state of the graph. start_node: the first node to run, since the graph definition doesn't define the entry point in the graph, you need to provide the starting node. + state: The initial state of the graph. + deps: The dependencies of the graph. infer_name: Whether to infer the graph name from the calling frame. 
Returns: @@ -191,22 +195,26 @@ def run_sync( """ if infer_name and self.name is None: self._infer_name(inspect.currentframe()) - return asyncio.get_event_loop().run_until_complete(self.run(state, start_node, infer_name=False)) + return asyncio.get_event_loop().run_until_complete( + self.run(start_node, state=state, deps=deps, infer_name=False) + ) async def next( self, - state: StateT, - node: BaseNode[StateT, RunEndT], + node: BaseNode[StateT, DepsT, RunEndT], history: list[HistoryStep[StateT, RunEndT]], *, + state: StateT = None, + deps: DepsT = None, infer_name: bool = True, - ) -> BaseNode[StateT, Any] | End[RunEndT]: + ) -> BaseNode[StateT, DepsT, Any] | End[RunEndT]: """Run a node in the graph and return the next node to run. Args: - state: The current state of the graph. node: The node to run. history: The history of the graph run so far. NOTE: this will be mutated to add the new step. + state: The current state of the graph. + deps: The dependencies of the graph. infer_name: Whether to infer the graph name from the calling frame. Returns: @@ -221,7 +229,7 @@ async def next( history_step: NodeStep[StateT, RunEndT] = NodeStep(state, node) history.append(history_step) - ctx = GraphContext(state) + ctx = GraphContext(state, deps) with _logfire.span('run node {node_id}', node_id=node_id, node=node): start = perf_counter() next_node = await node.run(ctx) @@ -382,7 +390,9 @@ def mermaid_save( kwargs['title'] = self.name mermaid.save_image(path, self, **kwargs) - def _register_node(self, node: type[BaseNode[StateT, RunEndT]], parent_namespace: dict[str, Any] | None) -> None: + def _register_node( + self, node: type[BaseNode[StateT, DepsT, RunEndT]], parent_namespace: dict[str, Any] | None + ) -> None: node_id = node.get_id() if existing_node := self.node_defs.get(node_id): raise exceptions.GraphSetupError( diff --git a/pydantic_graph/pydantic_graph/mermaid.py b/pydantic_graph/pydantic_graph/mermaid.py index 49e41ee267..ffa1fc0190 100644 --- a/pydantic_graph/pydantic_graph/mermaid.py +++ b/pydantic_graph/pydantic_graph/mermaid.py @@ -23,7 +23,7 @@ def generate_code( # noqa: C901 - graph: Graph[Any, Any], + graph: Graph[Any, Any, Any], /, *, start_node: Sequence[NodeIdent] | NodeIdent | None = None, @@ -110,7 +110,7 @@ def _node_ids(node_idents: Sequence[NodeIdent] | NodeIdent) -> Iterable[str]: def request_image( - graph: Graph[Any, Any], + graph: Graph[Any, Any, Any], /, **kwargs: Unpack[MermaidConfig], ) -> bytes: @@ -175,7 +175,7 @@ def request_image( def save_image( path: Path | str, - graph: Graph[Any, Any], + graph: Graph[Any, Any, Any], /, **kwargs: Unpack[MermaidConfig], ) -> None: @@ -247,7 +247,7 @@ class MermaidConfig(TypedDict, total=False): """An HTTPX client to use for requests, mostly for testing purposes.""" -NodeIdent: TypeAlias = 'type[BaseNode[Any, Any]] | BaseNode[Any, Any] | str' +NodeIdent: TypeAlias = 'type[BaseNode[Any, Any, Any]] | BaseNode[Any, Any, Any] | str' """A type alias for a node identifier. 
This can be: diff --git a/pydantic_graph/pydantic_graph/nodes.py b/pydantic_graph/pydantic_graph/nodes.py index ec97b16e84..3591e1b5be 100644 --- a/pydantic_graph/pydantic_graph/nodes.py +++ b/pydantic_graph/pydantic_graph/nodes.py @@ -14,23 +14,27 @@ else: StateT = TypeVar('StateT', default=None) -__all__ = 'GraphContext', 'BaseNode', 'End', 'Edge', 'NodeDef', 'RunEndT', 'NodeRunEndT' +__all__ = 'GraphContext', 'BaseNode', 'End', 'Edge', 'NodeDef', 'RunEndT', 'NodeRunEndT', 'DepsT' RunEndT = TypeVar('RunEndT', default=None) """Type variable for the return type of a graph [`run`][pydantic_graph.graph.Graph.run].""" NodeRunEndT = TypeVar('NodeRunEndT', covariant=True, default=Never) """Type variable for the return type of a node [`run`][pydantic_graph.nodes.BaseNode.run].""" +DepsT = TypeVar('DepsT', default=None) +"""Type variable for the dependencies of a graph and node.""" @dataclass -class GraphContext(Generic[StateT]): +class GraphContext(Generic[StateT, DepsT]): """Context for a graph.""" state: StateT """The state of the graph.""" + deps: DepsT + """Dependencies for the graph.""" -class BaseNode(ABC, Generic[StateT, NodeRunEndT]): +class BaseNode(ABC, Generic[StateT, DepsT, NodeRunEndT]): """Base class for a node.""" docstring_notes: ClassVar[bool] = False @@ -42,7 +46,7 @@ class BaseNode(ABC, Generic[StateT, NodeRunEndT]): """ @abstractmethod - async def run(self, ctx: GraphContext[StateT]) -> BaseNode[StateT, Any] | End[NodeRunEndT]: + async def run(self, ctx: GraphContext[StateT, DepsT]) -> BaseNode[StateT, DepsT, Any] | End[NodeRunEndT]: """Run the node. This is an abstract method that must be implemented by subclasses. @@ -87,7 +91,7 @@ def get_note(cls) -> str | None: return docstring @classmethod - def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, NodeRunEndT]: + def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, DepsT, NodeRunEndT]: """Get the node definition.""" type_hints = get_type_hints(cls.run, localns=local_ns, include_extras=True) try: @@ -139,7 +143,7 @@ class Edge: @dataclass -class NodeDef(Generic[StateT, NodeRunEndT]): +class NodeDef(Generic[StateT, DepsT, NodeRunEndT]): """Definition of a node. This is a primarily internal representation of a node; in general, it shouldn't be necessary to use it directly. @@ -148,7 +152,7 @@ class NodeDef(Generic[StateT, NodeRunEndT]): mermaid graphs. """ - node: type[BaseNode[StateT, NodeRunEndT]] + node: type[BaseNode[StateT, DepsT, NodeRunEndT]] """The node definition itself.""" node_id: str """ID of the node.""" diff --git a/pydantic_graph/pydantic_graph/state.py b/pydantic_graph/pydantic_graph/state.py index b1b596b855..0fe37abe52 100644 --- a/pydantic_graph/pydantic_graph/state.py +++ b/pydantic_graph/pydantic_graph/state.py @@ -14,7 +14,7 @@ from . 
import _utils from .nodes import BaseNode, End, RunEndT -__all__ = 'StateT', 'NodeStep', 'EndStep', 'HistoryStep', 'deep_copy_state' +__all__ = 'StateT', 'NodeStep', 'EndStep', 'HistoryStep', 'deep_copy_state', 'nodes_schema_var' StateT = TypeVar('StateT', default=None) @@ -35,7 +35,7 @@ class NodeStep(Generic[StateT, RunEndT]): state: StateT """The state of the graph before the node is run.""" - node: Annotated[BaseNode[StateT, RunEndT], CustomNodeSchema()] + node: Annotated[BaseNode[StateT, Any, RunEndT], CustomNodeSchema()] """The node that was run.""" start_ts: datetime = field(default_factory=_utils.now_utc) """The timestamp when the node started running.""" @@ -53,7 +53,7 @@ def __post_init__(self): # Copy the state to prevent it from being modified by other code self.state = self.snapshot_state(self.state) - def data_snapshot(self) -> BaseNode[StateT, RunEndT]: + def data_snapshot(self) -> BaseNode[StateT, Any, RunEndT]: """Returns a deep copy of [`self.node`][pydantic_graph.state.NodeStep.node]. Useful for summarizing history. @@ -97,7 +97,7 @@ def data_snapshot(self) -> End[RunEndT]: """ -nodes_schema_var: ContextVar[Sequence[type[BaseNode[Any, Any]]]] = ContextVar('nodes_var') +nodes_schema_var: ContextVar[Sequence[type[BaseNode[Any, Any, Any]]]] = ContextVar('nodes_var') class CustomNodeSchema: diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py index 5549ffaa8b..2f65538eec 100644 --- a/tests/graph/test_graph.py +++ b/tests/graph/test_graph.py @@ -43,7 +43,7 @@ async def run(self, ctx: GraphContext) -> Double: return Double(len(self.input_data)) @dataclass - class Double(BaseNode[None, int]): + class Double(BaseNode[None, None, int]): input_data: int async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noqa: UP007 @@ -52,9 +52,9 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq else: return End(self.input_data * 2) - my_graph = Graph[None, int](nodes=(Float2String, String2Length, Double)) + my_graph = Graph[None, None, int](nodes=(Float2String, String2Length, Double)) assert my_graph.name is None - result, history = await my_graph.run(None, Float2String(3.14)) + result, history = await my_graph.run(Float2String(3.14)) # len('3.14') * 2 == 8 assert result == 8 assert my_graph.name == 'my_graph' @@ -85,7 +85,7 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq ), ] ) - result, history = await my_graph.run(None, Float2String(3.14159)) + result, history = await my_graph.run(Float2String(3.14159)) # len('3.14159') == 7, 21 * 2 == 42 assert result == 42 assert history == snapshot( @@ -144,7 +144,7 @@ class Float2String(BaseNode): async def run(self, ctx: GraphContext) -> String2Length: raise NotImplementedError() - class String2Length(BaseNode[None, None]): + class String2Length(BaseNode[None, None, None]): async def run(self, ctx: GraphContext) -> End[None]: raise NotImplementedError() @@ -163,13 +163,13 @@ class Foo(BaseNode): async def run(self, ctx: GraphContext) -> Union[Bar, Spam]: # noqa: UP007 raise NotImplementedError() - class Bar(BaseNode[None, None]): + class Bar(BaseNode[None, None, None]): input_data: str async def run(self, ctx: GraphContext) -> End[None]: raise NotImplementedError() - class Spam(BaseNode[None, None]): + class Spam(BaseNode[None, None, None]): async def run(self, ctx: GraphContext) -> End[None]: raise NotImplementedError() @@ -190,17 +190,17 @@ class Foo(BaseNode): async def run(self, ctx: GraphContext) -> Eggs: raise NotImplementedError() - class 
Bar(BaseNode[None, None]): + class Bar(BaseNode[None, None, None]): input_data: str async def run(self, ctx: GraphContext) -> Eggs: raise NotImplementedError() - class Spam(BaseNode[None, None]): + class Spam(BaseNode[None, None, None]): async def run(self, ctx: GraphContext) -> Eggs: raise NotImplementedError() - class Eggs(BaseNode[None, None]): + class Eggs(BaseNode[None, None, None]): async def run(self, ctx: GraphContext) -> End[None]: raise NotImplementedError() @@ -217,7 +217,7 @@ class Foo(BaseNode): async def run(self, ctx: GraphContext) -> Bar: raise NotImplementedError() - class Bar(BaseNode[None, None]): + class Bar(BaseNode[None, None, None]): async def run(self, ctx: GraphContext) -> End[None]: raise NotImplementedError() @@ -239,18 +239,18 @@ async def run(self, ctx: GraphContext) -> Bar: return Bar() @dataclass - class Bar(BaseNode[None, None]): + class Bar(BaseNode[None, None, None]): async def run(self, ctx: GraphContext) -> End[None]: return Spam() # type: ignore @dataclass - class Spam(BaseNode[None, None]): + class Spam(BaseNode[None, None, None]): async def run(self, ctx: GraphContext) -> End[None]: raise NotImplementedError() g = Graph(nodes=(Foo, Bar)) with pytest.raises(GraphRuntimeError) as exc_info: - await g.run(None, Foo()) + await g.run(Foo()) assert exc_info.value.message == snapshot('Node `test_run_node_not_in_graph..Spam()` is not in the graph.') @@ -262,13 +262,13 @@ async def run(self, ctx: GraphContext) -> Bar: return Bar() @dataclass - class Bar(BaseNode[None, None]): + class Bar(BaseNode[None, None, None]): async def run(self, ctx: GraphContext) -> End[None]: return 42 # type: ignore g = Graph(nodes=(Foo, Bar)) with pytest.raises(GraphRuntimeError) as exc_info: - await g.run(None, Foo()) + await g.run(Foo()) assert exc_info.value.message == snapshot('Invalid node return type: `int`. 
Expected `BaseNode` or `End`.') @@ -287,13 +287,13 @@ async def run(self, ctx: GraphContext) -> Foo: g = Graph(nodes=(Foo, Bar)) assert g.name is None history: list[HistoryStep[None, Never]] = [] - n = await g.next(None, Foo(), history) + n = await g.next(Foo(), history) assert n == Bar() assert g.name == 'g' assert history == snapshot([NodeStep(state=None, node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat())]) assert isinstance(n, Bar) - n2 = await g.next(None, n, history) + n2 = await g.next(n, history) assert n2 == Foo() assert history == snapshot( diff --git a/tests/graph/test_history.py b/tests/graph/test_history.py index ef4a455d0e..56ca7c8ab2 100644 --- a/tests/graph/test_history.py +++ b/tests/graph/test_history.py @@ -27,7 +27,7 @@ async def run(self, ctx: GraphContext[MyState]) -> Bar: @dataclass -class Bar(BaseNode[MyState, str]): +class Bar(BaseNode[MyState, None, str]): async def run(self, ctx: GraphContext[MyState]) -> End[str]: ctx.state.y += 'y' return End(f'x={ctx.state.x} y={ctx.state.y}') @@ -37,7 +37,7 @@ async def run(self, ctx: GraphContext[MyState]) -> End[str]: async def test_dump_history(): - result, history = await graph.run(MyState(1, ''), Foo()) + result, history = await graph.run(Foo(), state=MyState(1, '')) assert result == snapshot('x=2 y=y') assert history == snapshot( [ diff --git a/tests/graph/test_mermaid.py b/tests/graph/test_mermaid.py index 2cf2b01423..35676e694f 100644 --- a/tests/graph/test_mermaid.py +++ b/tests/graph/test_mermaid.py @@ -26,7 +26,7 @@ async def run(self, ctx: GraphContext) -> Bar: @dataclass -class Bar(BaseNode[None, None]): +class Bar(BaseNode[None, None, None]): async def run(self, ctx: GraphContext) -> End[None]: return End(None) @@ -45,7 +45,7 @@ async def run(self, ctx: GraphContext) -> Annotated[Foo, Edge(label='spam to foo @dataclass -class Eggs(BaseNode[None, None]): +class Eggs(BaseNode[None, None, None]): """This is the docstring for Eggs.""" docstring_notes = False @@ -58,7 +58,7 @@ async def run(self, ctx: GraphContext) -> Annotated[End[None], Edge(label='eggs async def test_run_graph(): - result, history = await graph1.run(None, Foo()) + result, history = await graph1.run(Foo()) assert result is None assert history == snapshot( [ diff --git a/tests/graph/test_state.py b/tests/graph/test_state.py index b363e37674..d68f7452fb 100644 --- a/tests/graph/test_state.py +++ b/tests/graph/test_state.py @@ -26,14 +26,14 @@ async def run(self, ctx: GraphContext[MyState]) -> Bar: return Bar() @dataclass - class Bar(BaseNode[MyState, str]): + class Bar(BaseNode[MyState, None, str]): async def run(self, ctx: GraphContext[MyState]) -> End[str]: ctx.state.y += 'y' return End(f'x={ctx.state.x} y={ctx.state.y}') graph = Graph(nodes=(Foo, Bar)) s = MyState(1, '') - result, history = await graph.run(s, Foo()) + result, history = await graph.run(Foo(), state=s) assert result == snapshot('x=2 y=y') assert history == snapshot( [ diff --git a/tests/typed_graph.py b/tests/typed_graph.py index 8c6ff9b0ea..347782c182 100644 --- a/tests/typed_graph.py +++ b/tests/typed_graph.py @@ -4,7 +4,7 @@ from typing_extensions import assert_type -from pydantic_graph import BaseNode, End, Graph, GraphContext +from pydantic_graph import BaseNode, End, Graph, GraphContext, HistoryStep @dataclass @@ -29,7 +29,7 @@ class X: @dataclass -class Double(BaseNode[None, X]): +class Double(BaseNode[None, None, X]): input_data: int async def run(self, ctx: GraphContext) -> String2Length | End[X]: @@ -39,7 +39,7 @@ async def run(self, ctx: GraphContext) -> 
String2Length | End[X]: return End(X(self.input_data * 2)) -def use_double(node: BaseNode[None, X]) -> None: +def use_double(node: BaseNode[None, None, X]) -> None: """Shoe that `Double` is valid as a `BaseNode[None, int, X]`.""" print(node) @@ -47,18 +47,18 @@ def use_double(node: BaseNode[None, X]) -> None: use_double(Double(1)) -g1 = Graph[None, X]( +g1 = Graph[None, None, X]( nodes=( Float2String, String2Length, Double, ) ) -assert_type(g1, Graph[None, X]) +assert_type(g1, Graph[None, None, X]) g2 = Graph(nodes=(Double,)) -assert_type(g2, Graph[None, X]) +assert_type(g2, Graph[None, None, X]) g3 = Graph( nodes=( @@ -68,7 +68,47 @@ def use_double(node: BaseNode[None, X]) -> None: ) ) # because String2Length came before Double, the output type is Any -assert_type(g3, Graph[None, X]) +assert_type(g3, Graph[None, None, X]) Graph[None, bytes](nodes=(Float2String, String2Length, Double)) # type: ignore[arg-type] Graph[None, str](nodes=[Double]) # type: ignore[list-item] + + +@dataclass +class MyState: + x: int + + +@dataclass +class MyDeps: + y: str + + +@dataclass +class A(BaseNode[MyState, MyDeps]): + async def run(self, ctx: GraphContext[MyState, MyDeps]) -> B: + assert ctx.state.x == 1 + assert ctx.deps.y == 'y' + return B() + + +@dataclass +class B(BaseNode[MyState, MyDeps, int]): + async def run(self, ctx: GraphContext[MyState, MyDeps]) -> End[int]: + return End(42) + + +g4 = Graph[MyState, MyDeps, int](nodes=(A, B)) +assert_type(g4, Graph[MyState, MyDeps, int]) + +g5 = Graph(nodes=(A, B)) +assert_type(g5, Graph[MyState, MyDeps, int]) + + +def run_g5(): + g5.run_sync(A()) # pyright: ignore[reportArgumentType] + g5.run_sync(A(), state=MyState(x=1)) # pyright: ignore[reportArgumentType] + g5.run_sync(A(), deps=MyDeps(y='y')) # pyright: ignore[reportArgumentType] + ans, history = g5.run_sync(A(), state=MyState(x=1), deps=MyDeps(y='y')) + assert_type(ans, int) + assert_type(history, list[HistoryStep[MyState, int]]) From 3a1cddd94b05e28fcbe93d56c869f1f52054908d Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 15 Jan 2025 14:49:24 +0000 Subject: [PATCH 54/57] fix build --- docs/api/pydantic_graph/nodes.md | 1 + docs/graph.md | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/api/pydantic_graph/nodes.md b/docs/api/pydantic_graph/nodes.md index 7540920a02..e58ddf7012 100644 --- a/docs/api/pydantic_graph/nodes.md +++ b/docs/api/pydantic_graph/nodes.md @@ -7,5 +7,6 @@ - BaseNode - End - Edge + - DepsT - RunEndT - NodeRunEndT diff --git a/docs/graph.md b/docs/graph.md index 1128b2561c..b751203838 100644 --- a/docs/graph.md +++ b/docs/graph.md @@ -60,7 +60,7 @@ Nodes which are generally [`dataclass`es][dataclasses.dataclass] include: Nodes are generic in: * **state**, which must have the same type as the state of graphs they're included in, [`StateT`][pydantic_graph.state.StateT] has a default of `None`, so if you're not using state you can omit this generic parameter -* **deps**, which must have the same type as the deps of the graph they're included in, [`DepsT`][pydantic_graph.state.DepsT] has a default of `None`, so if you're not using deps you can omit this generic parameter +* **deps**, which must have the same type as the deps of the graph they're included in, [`DepsT`][pydantic_graph.nodes.DepsT] has a default of `None`, so if you're not using deps you can omit this generic parameter * **graph return type** — this only applies if the node returns [`End`][pydantic_graph.nodes.End]. 
[`RunEndT`][pydantic_graph.nodes.RunEndT] has a default of [Never][typing.Never] so this generic parameter can be omitted if the node doesn't return `End`, but must be included if it does. Here's an example of a start or intermediate node in a graph — it can't end the run as it doesn't return [`End`][pydantic_graph.nodes.End]: From d78a9d1c8bf1803a064218906bddaeb37b7ec36e Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 15 Jan 2025 17:29:37 +0000 Subject: [PATCH 55/57] fix type hint --- docs/graph.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/graph.md b/docs/graph.md index e109df92d6..e747e0b6bc 100644 --- a/docs/graph.md +++ b/docs/graph.md @@ -133,7 +133,7 @@ from pydantic_graph import BaseNode, End, Graph, GraphContext @dataclass -class DivisibleBy5(BaseNode[None, int]): # (1)! +class DivisibleBy5(BaseNode[None, None, int]): # (1)! foo: int async def run( @@ -163,7 +163,7 @@ print([item.data_snapshot() for item in history]) #> [DivisibleBy5(foo=4), Increment(foo=4), DivisibleBy5(foo=5), End(data=5)] ``` -1. The `DivisibleBy5` node is parameterized with `None` as this graph doesn't use state, and `int` as it can end the run. +1. The `DivisibleBy5` node is parameterized with `None` for the state param and `None` for the deps param as this graph doesn't use state or deps, and `int` as it can end the run. 2. The `Increment` node doesn't return `End`, so the `RunEndT` generic parameter is omitted, state can also be omitted as the graph doesn't use state. 3. The graph is created with a sequence of nodes. 4. The graph is run synchronously with [`run_sync`][pydantic_graph.graph.Graph.run_sync] the initial state `None` and the start node `DivisibleBy5(4)` are passed as arguments. From d653c0a6cf92839f1f2f072922d73296281ba0a9 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 15 Jan 2025 19:14:38 +0000 Subject: [PATCH 56/57] add deps example and tests --- docs/graph.md | 76 +++++++++++++++++++++++++++++++++++++++ tests/graph/test_graph.py | 31 ++++++++++++++++ 2 files changed, 107 insertions(+) diff --git a/docs/graph.md b/docs/graph.md index e747e0b6bc..e11ad76b10 100644 --- a/docs/graph.md +++ b/docs/graph.md @@ -634,6 +634,82 @@ stateDiagram-v2 You maybe have noticed that although this examples transfers control flow out of the graph run, we're still using [rich's `Prompt.ask`][rich.prompt.PromptBase.ask] to get user input, with the process hanging while we wait for the user to enter a response. For an example of genuine out-of-process control flow, see the [question graph example](examples/question-graph.md). +## Dependency Injection + +As with PydanticAI, `pydantic-graph` supports dependency injection via a generic parameter on [`Graph`][pydantic_graph.graph.Graph] and [`BaseNode`][pydantic_graph.nodes.BaseNode], and the [`GraphContext.deps`][pydantic_graph.nodes.GraphContext.deps] fields. 
+ +As an example of dependency injection, let's modify the `DivisibleBy5` example [above](#graph) to use a [`ProcessPoolExecutor`][concurrent.futures.ProcessPoolExecutor] to run the compute load in a separate process (this is a contrived example, `ProcessPoolExecutor` wouldn't actually improve performance in this example): + +```py {title="deps_example.py" py="3.10" test="skip" hl_lines="4 10-12 35-37 48-49"} +from __future__ import annotations + +import asyncio +from concurrent.futures import ProcessPoolExecutor +from dataclasses import dataclass + +from pydantic_graph import BaseNode, End, Graph, GraphContext + + +@dataclass +class GraphDeps: + executor: ProcessPoolExecutor + + +@dataclass +class DivisibleBy5(BaseNode[None, None, int]): + foo: int + + async def run( + self, + ctx: GraphContext, + ) -> Increment | End[int]: + if self.foo % 5 == 0: + return End(self.foo) + else: + return Increment(self.foo) + + +@dataclass +class Increment(BaseNode): + foo: int + + async def run(self, ctx: GraphContext) -> DivisibleBy5: + loop = asyncio.get_running_loop() + compute_result = await loop.run_in_executor( + ctx.deps.executor, + self.compute, + ) + return DivisibleBy5(compute_result) + + def compute(self) -> int: + return self.foo + 1 + + +fives_graph = Graph(nodes=[DivisibleBy5, Increment]) + + +async def main(): + with ProcessPoolExecutor() as executor: + deps = GraphDeps(executor) + result, history = await fives_graph.run(DivisibleBy5(3), deps=deps) + print(result) + #> 5 + # the full history is quite verbose (see below), so we'll just print the summary + print([item.data_snapshot() for item in history]) + """ + [ + DivisibleBy5(foo=3), + Increment(foo=3), + DivisibleBy5(foo=4), + Increment(foo=4), + DivisibleBy5(foo=5), + End(data=5), + ] + """ +``` + +_(This example is complete, it can be run "as is" with Python 3.10+ — you'll need to add `asyncio.run(main())` to run `main`)_ + ## Mermaid Diagrams Pydantic Graph can generate [mermaid](https://mermaid.js.org/) [`stateDiagram-v2`](https://mermaid.js.org/syntax/stateDiagram.html) diagrams for graphs, as shown above. 
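A rough usage sketch of the helpers visible in the `pydantic_graph/mermaid.py` diff above, reusing `fives_graph` and `DivisibleBy5` from the earlier example; anything beyond the signatures shown there (such as the exact return value of `generate_code`) is an assumption:

```python
from pydantic_graph import mermaid

# generate the mermaid `stateDiagram-v2` source for the graph,
# optionally marking DivisibleBy5 as the start node
diagram_source = mermaid.generate_code(fives_graph, start_node=DivisibleBy5)
print(diagram_source)

# Graph.mermaid_save writes a rendered image to disk via mermaid.save_image,
# passing the graph name through as the diagram title
fives_graph.mermaid_save('fives_graph.png')
```

Judging by `request_image` and the `httpx_client` option in `MermaidConfig`, saving an image appears to involve an HTTP call to a rendering service, so `generate_code` is the offline option.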
diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py index d14571bed4..2ae8bc5a9d 100644 --- a/tests/graph/test_graph.py +++ b/tests/graph/test_graph.py @@ -299,3 +299,34 @@ async def run(self, ctx: GraphContext) -> Foo: NodeStep(state=None, node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), ] ) + + +async def test_deps(): + @dataclass + class Deps: + a: int + b: int + + @dataclass + class Foo(BaseNode[None, Deps]): + async def run(self, ctx: GraphContext[None, Deps]) -> Bar: + assert isinstance(ctx.deps, Deps) + return Bar() + + @dataclass + class Bar(BaseNode[None, Deps, int]): + async def run(self, ctx: GraphContext[None, Deps]) -> End[int]: + assert isinstance(ctx.deps, Deps) + return End(123) + + g = Graph(nodes=(Foo, Bar)) + result, history = await g.run(Foo(), deps=Deps(1, 2)) + + assert result == 123 + assert history == snapshot( + [ + NodeStep(state=None, node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), + NodeStep(state=None, node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), + EndStep(result=End(data=123), ts=IsNow(tz=timezone.utc)), + ] + ) From e0ab64b2eaaaae64aa57d8f0056ffe870f77da95 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 15 Jan 2025 19:23:32 +0000 Subject: [PATCH 57/57] cleanup --- docs/graph.md | 5 +++-- pydantic_graph/pydantic_graph/graph.py | 7 +++---- tests/typed_graph.py | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/graph.md b/docs/graph.md index e11ad76b10..bbd546ddc3 100644 --- a/docs/graph.md +++ b/docs/graph.md @@ -59,8 +59,8 @@ Nodes which are generally [`dataclass`es][dataclasses.dataclass] include: Nodes are generic in: -* **state**, which must have the same type as the state of graphs they're included in, [`StateT`][pydantic_graph.state.StateT] has a default of `None`, so if you're not using state you can omit this generic parameter -* **deps**, which must have the same type as the deps of the graph they're included in, [`DepsT`][pydantic_graph.nodes.DepsT] has a default of `None`, so if you're not using deps you can omit this generic parameter +* **state**, which must have the same type as the state of graphs they're included in, [`StateT`][pydantic_graph.state.StateT] has a default of `None`, so if you're not using state you can omit this generic parameter, see [stateful graphs](#stateful-graphs) for more information +* **deps**, which must have the same type as the deps of the graph they're included in, [`DepsT`][pydantic_graph.nodes.DepsT] has a default of `None`, so if you're not using deps you can omit this generic parameter, see [dependency injection](#dependency-injection) for more information * **graph return type** — this only applies if the node returns [`End`][pydantic_graph.nodes.End]. [`RunEndT`][pydantic_graph.nodes.RunEndT] has a default of [Never][typing.Never] so this generic parameter can be omitted if the node doesn't return `End`, but must be included if it does. Here's an example of a start or intermediate node in a graph — it can't end the run as it doesn't return [`End`][pydantic_graph.nodes.End]: @@ -120,6 +120,7 @@ class MyNode(BaseNode[MyState, None, int]): # (1)! 
`Graph` is generic in: * **state** the state type of the graph, [`StateT`][pydantic_graph.state.StateT] +* **deps** the deps type of the graph, [`DepsT`][pydantic_graph.nodes.DepsT] * **graph return type** the return type of the graph run, [`RunEndT`][pydantic_graph.nodes.RunEndT] Here's an example of a simple graph: diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index ce22591a67..0aed1f3706 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -51,12 +51,12 @@ async def run(self, ctx: GraphContext) -> Check42: return Check42() @dataclass - class Check42(BaseNode[MyState, None, None]): - async def run(self, ctx: GraphContext) -> Increment | End: + class Check42(BaseNode[MyState, None, int]): + async def run(self, ctx: GraphContext) -> Increment | End[int]: if ctx.state.number == 42: return Increment() else: - return End(None) + return End(ctx.state.number) never_42_graph = Graph(nodes=(Increment, Check42)) ``` @@ -70,7 +70,6 @@ async def run(self, ctx: GraphContext) -> Increment | End: name: str | None node_defs: dict[str, NodeDef[StateT, DepsT, RunEndT]] snapshot_state: Callable[[StateT], StateT] - snapshot_state: Callable[[StateT], StateT] _state_type: type[StateT] | _utils.Unset = field(repr=False) _run_end_type: type[RunEndT] | _utils.Unset = field(repr=False) diff --git a/tests/typed_graph.py b/tests/typed_graph.py index 347782c182..14bd171669 100644 --- a/tests/typed_graph.py +++ b/tests/typed_graph.py @@ -105,7 +105,7 @@ async def run(self, ctx: GraphContext[MyState, MyDeps]) -> End[int]: assert_type(g5, Graph[MyState, MyDeps, int]) -def run_g5(): +def run_g5() -> None: g5.run_sync(A()) # pyright: ignore[reportArgumentType] g5.run_sync(A(), state=MyState(x=1)) # pyright: ignore[reportArgumentType] g5.run_sync(A(), deps=MyDeps(y='y')) # pyright: ignore[reportArgumentType]