From 68f5fe2f63e225afb5ec6582607d699f1c2cab7c Mon Sep 17 00:00:00 2001
From: Samuel Colvin
Date: Sun, 22 Dec 2024 16:38:31 +0000
Subject: [PATCH 01/57] start pydantic-ai-graph

---
 pydantic_ai_graph/README.md | 16 +++
 .../pydantic_ai_graph/__init__.py | 0
 pydantic_ai_graph/pydantic_ai_graph/_utils.py | 37 +++++
 pydantic_ai_graph/pydantic_ai_graph/graph.py | 128 ++++++++++++++++++
 pydantic_ai_graph/pydantic_ai_graph/nodes.py | 116 ++++++++++++++++
 pydantic_ai_graph/pydantic_ai_graph/state.py | 17 +++
 pydantic_ai_graph/pyproject.toml | 43 ++++++
 pyproject.toml | 7 +-
 uv.lock | 18 +++
 9 files changed, 379 insertions(+), 3 deletions(-)
 create mode 100644 pydantic_ai_graph/README.md
 create mode 100644 pydantic_ai_graph/pydantic_ai_graph/__init__.py
 create mode 100644 pydantic_ai_graph/pydantic_ai_graph/_utils.py
 create mode 100644 pydantic_ai_graph/pydantic_ai_graph/graph.py
 create mode 100644 pydantic_ai_graph/pydantic_ai_graph/nodes.py
 create mode 100644 pydantic_ai_graph/pydantic_ai_graph/state.py
 create mode 100644 pydantic_ai_graph/pyproject.toml

diff --git a/pydantic_ai_graph/README.md b/pydantic_ai_graph/README.md
new file mode 100644
index 0000000000..2271f779cb
--- /dev/null
+++ b/pydantic_ai_graph/README.md
@@ -0,0 +1,16 @@
+# PydanticAI Graph
+
+[![CI](https://github.com/pydantic/pydantic-ai/actions/workflows/ci.yml/badge.svg?event=push)](https://github.com/pydantic/pydantic-ai/actions/workflows/ci.yml?query=branch%3Amain)
+[![Coverage](https://coverage-badge.samuelcolvin.workers.dev/pydantic/pydantic-ai.svg)](https://coverage-badge.samuelcolvin.workers.dev/redirect/pydantic/pydantic-ai)
+[![PyPI](https://img.shields.io/pypi/v/pydantic-ai-graph.svg)](https://pypi.python.org/pypi/pydantic-ai-graph)
+[![versions](https://img.shields.io/pypi/pyversions/pydantic-ai-graph.svg)](https://github.com/pydantic/pydantic-ai)
+[![license](https://img.shields.io/github/license/pydantic/pydantic-ai-graph.svg?v)](https://github.com/pydantic/pydantic-ai/blob/main/LICENSE)
+
+Graph and State Machine library.
+
+This library is developed as part of [PydanticAI](https://ai.pydantic.dev); however, it has no dependency
+on `pydantic-ai` or related packages and can be considered a pure graph library.
+
+As with PydanticAI, this library prioritizes type safety and the use of common Python syntax over esoteric, domain-specific use of Python syntax.
+
+`pydantic-ai-graph` allows you to define graphs using simple Python syntax. In particular, edges are defined using the return type hint of nodes.
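For example, here is a minimal sketch of a two-node graph written in this style. It mirrors the nodes used in the tests added later in this series; the `Increment` and `Double` names are purely illustrative:

```python
from __future__ import annotations

from pydantic_ai_graph import BaseNode, End, Graph, GraphContext


class Increment(BaseNode[int]):
    async def run(self, ctx: GraphContext) -> Double:
        # the `-> Double` return hint is what defines the edge Increment --> Double
        return Double(self.input_data + 1)


class Double(BaseNode[int, int]):
    async def run(self, ctx: GraphContext) -> End[int]:
        # returning `End` (declared in the return hint) ends the run and sets the graph output
        return End(self.input_data * 2)


graph = Graph[int, int](Increment, Double)
# `await graph.run(3)` returns `(8, None)`: the value passed to `End` plus the (unused) state
```

Because the outgoing edges are ordinary return type hints, type checkers can follow them, and the `Graph` constructor can verify at runtime that every referenced node is actually included in the graph.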
diff --git a/pydantic_ai_graph/pydantic_ai_graph/__init__.py b/pydantic_ai_graph/pydantic_ai_graph/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/pydantic_ai_graph/pydantic_ai_graph/_utils.py b/pydantic_ai_graph/pydantic_ai_graph/_utils.py new file mode 100644 index 0000000000..b58e142f6e --- /dev/null +++ b/pydantic_ai_graph/pydantic_ai_graph/_utils.py @@ -0,0 +1,37 @@ +import sys +import types +from typing import Any, TypeAliasType, Union, get_args, get_origin + + +def get_union_args(tp: Any) -> tuple[Any, ...]: + """Extract the arguments of a Union type if `response_type` is a union, otherwise return the original type.""" + # similar to `pydantic_ai_slim/pydantic_ai/_result.py:get_union_args` + if isinstance(tp, TypeAliasType): + tp = tp.__value__ + + origin = get_origin(tp) + if origin_is_union(origin): + return get_args(tp) + else: + return (tp,) + + +# same as `pydantic_ai_slim/pydantic_ai/_result.py:origin_is_union` +if sys.version_info < (3, 10): + + def origin_is_union(tp: type[Any] | None) -> bool: + return tp is Union + +else: + + def origin_is_union(tp: type[Any] | None) -> bool: + return tp is Union or tp is types.UnionType + + +def comma_and(items: list[str]) -> str: + """Join with a comma and 'and' for the last item.""" + if len(items) == 1: + return items[0] + else: + # oxford comma ¯\_(ツ)_/¯ + return ', '.join(items[:-1]) + ', and ' + items[-1] diff --git a/pydantic_ai_graph/pydantic_ai_graph/graph.py b/pydantic_ai_graph/pydantic_ai_graph/graph.py new file mode 100644 index 0000000000..0180cfa325 --- /dev/null +++ b/pydantic_ai_graph/pydantic_ai_graph/graph.py @@ -0,0 +1,128 @@ +from __future__ import annotations as _annotations + +import base64 +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Any, Generic + +from typing_extensions import TypeVar, assert_never + +from . 
import _utils +from .nodes import ( + BaseNode, + DepsT, + End, + GraphContext, + GraphOutputT, + NodeDef, +) +from .state import StateT + +__all__ = ('Graph',) +GraphInputT = TypeVar('GraphInputT', default=Any) + + +# noinspection PyTypeHints +@dataclass(init=False) +class Graph(Generic[GraphInputT, GraphOutputT, DepsT, StateT]): + """Definition of a graph.""" + + first_node: NodeDef[Any, Any, DepsT, StateT] + nodes: dict[str, NodeDef[Any, Any, DepsT, StateT]] + state_type: type[StateT] | None + + # noinspection PyUnusedLocal + def __init__( + self, + first_node: type[BaseNode[GraphInputT, Any, DepsT, StateT]], + *other_nodes: type[BaseNode[Any, GraphOutputT, DepsT, StateT]], + deps_type: type[DepsT] | None = None, + state_type: type[StateT] | None = None, + ): + self.first_node = first_node.get_node_def() + self.nodes = nodes = {self.first_node.node_id: self.first_node} + for node in other_nodes: + node_def = node.get_node_def() + nodes[node_def.node_id] = node_def + + self._check() + self.state_type = state_type + + async def run( + self, input_data: GraphInputT, deps: DepsT = None, state: StateT = None + ) -> tuple[GraphOutputT, StateT]: + current_node_def = self.first_node + current_node = current_node_def.node(input_data) + ctx = GraphContext(deps, state) + while True: + # noinspection PyUnresolvedReferences + next_node = await current_node.run(ctx) + if isinstance(next_node, End): + if current_node_def.can_end: + return next_node.data, ctx.state + else: + raise ValueError(f'Node {current_node_def.node_id} cannot end the graph') + elif isinstance(next_node, BaseNode): + next_node_id = next_node.get_id() + try: + next_node_def = self.nodes[next_node_id] + except KeyError as e: + raise ValueError( + f'Node {current_node_def.node_id} cannot go to {next_node_id} which is not in the Graph' + ) from e + + if not current_node_def.dest_any and next_node_id not in current_node_def.next_node_ids: + raise ValueError( + f'Node {current_node_def.node_id} cannot go to {next_node_id} which is not in its ' + f'list of allowed next nodes' + ) + + current_node_def = next_node_def + current_node = next_node + else: + if TYPE_CHECKING: + assert_never(next_node) + else: + raise TypeError(f'Invalid node type: {type(next_node)} expected BaseNode or End') + + def mermaid_code(self) -> str: + lines = ['graph TD'] + # order of destination nodes should match their order in `self.nodes` + node_order = {nid: index for index, nid in enumerate(self.nodes.keys())} + for node_id, node in self.nodes.items(): + if node.dest_any: + for next_node_id in self.nodes: + lines.append(f' {node_id} --> {next_node_id}') + for _, next_node_id in sorted((node_order[nid], nid) for nid in node.next_node_ids): + lines.append(f' {node_id} --> {next_node_id}') + if node.can_end: + lines.append(f' {node_id} --> END') + return '\n'.join(lines) + + def mermaid_image(self, mermaid_ink_params: dict[str, str | int] | None = None) -> bytes: + import httpx + + code_base64 = base64.b64encode(self.mermaid_code().encode()).decode() + + response = httpx.get(f'https://mermaid.ink/img/{code_base64}', params=mermaid_ink_params) + response.raise_for_status() + return response.content + + def mermaid_save(self, path: Path, mermaid_ink_params: dict[str, str | int] | None = None) -> None: + image_data = self.mermaid_image(mermaid_ink_params) + path.write_bytes(image_data) + + def _check(self): + bad_edges: dict[str, list[str]] = {} + for node in self.nodes.values(): + node_bad_edges = node.next_node_ids - self.nodes.keys() + for bad_edge in 
node_bad_edges: + bad_edges.setdefault(bad_edge, []).append(f'"{node.node_id}"') + + if bad_edges: + bad_edges_list = [f'"{k}" is referenced by {_utils.comma_and(v)}' for k, v in bad_edges.items()] + if len(bad_edges_list) == 1: + raise ValueError(f'{bad_edges_list[0]} but not included in the graph.') + else: + b = '\n'.join(f' {be}' for be in bad_edges_list) + raise ValueError(f'Nodes are referenced in the graph but not included in the graph:\n{b}') diff --git a/pydantic_ai_graph/pydantic_ai_graph/nodes.py b/pydantic_ai_graph/pydantic_ai_graph/nodes.py new file mode 100644 index 0000000000..a790e01959 --- /dev/null +++ b/pydantic_ai_graph/pydantic_ai_graph/nodes.py @@ -0,0 +1,116 @@ +from __future__ import annotations as _annotations + +from abc import ABC, ABCMeta, abstractmethod +from dataclasses import dataclass +from functools import cache +from typing import Any, ClassVar, Generic, get_args, get_origin, get_type_hints + +from typing_extensions import TypeVar + +from . import _utils +from .state import StateT + +__all__ = ( + 'NodeInputT', + 'GraphOutputT', + 'DepsT', + 'GraphContext', + 'End', + 'BaseNode', + 'NodeDef', +) + +NodeInputT = TypeVar('NodeInputT', default=Any) +GraphOutputT = TypeVar('GraphOutputT', default=Any) +DepsT = TypeVar('DepsT', default=None) + + +# noinspection PyTypeHints +@dataclass +class GraphContext(Generic[DepsT, StateT]): + """Context for a graph.""" + + deps: DepsT + state: StateT + + +# noinspection PyTypeHints +class End(ABC, Generic[NodeInputT]): + """Type to return from a node to signal the end of the graph.""" + + __slots__ = ('data',) + + def __init__(self, input_data: NodeInputT) -> None: + self.data = input_data + + +class _BaseNodeMeta(ABCMeta): + def __repr__(cls): + base: Any = cls.__orig_bases__[0] # type: ignore + args = get_args(base) + if len(args) == 4 and args[3] is None: + if args[2] is None: + args = args[:2] + else: + args = args[:3] + args = ', '.join(a.__name__ for a in args) + return f'{cls.__name__}({base.__name__}[{args}])' + + +# noinspection PyTypeHints +class BaseNode(Generic[NodeInputT, GraphOutputT, DepsT, StateT], metaclass=_BaseNodeMeta): + """Base class for a node.""" + + node_id: ClassVar[str | None] = None + __slots__ = ('input_data',) + + def __init__(self, input_data: NodeInputT) -> None: + self.input_data = input_data + + @abstractmethod + async def run(self, ctx: GraphContext[DepsT, StateT]) -> BaseNode[Any, Any, DepsT, StateT] | End[GraphOutputT]: ... + + @classmethod + @cache + def get_id(cls) -> str: + return cls.node_id or cls.__qualname__ + + @classmethod + def get_node_def(cls) -> NodeDef[Any, Any, DepsT, StateT]: + type_hints = get_type_hints(cls.run) + next_node_ids: set[str] = set() + can_end: bool = False + dest_any: bool = False + for return_type in _utils.get_union_args(type_hints['return']): + return_type_origin = get_origin(return_type) or return_type + if return_type_origin is BaseNode: + dest_any = True + elif issubclass(return_type_origin, BaseNode): + next_node_ids.add(return_type.get_id()) + elif return_type_origin is End: + can_end = True + else: + raise TypeError(f'Invalid return type: {return_type}') + + return NodeDef( + cls, + cls.get_id(), + next_node_ids, + can_end, + dest_any, + ) + + +# noinspection PyTypeHints +@dataclass +class NodeDef(ABC, Generic[NodeInputT, GraphOutputT, DepsT, StateT]): + """Definition of a node. + + Used by [`Graph`][pydantic_ai_graph.graph.Graph] store information about a node. 
+ """ + + node: type[BaseNode[NodeInputT, GraphOutputT, DepsT, StateT]] + node_id: str + next_node_ids: set[str] + can_end: bool + dest_any: bool diff --git a/pydantic_ai_graph/pydantic_ai_graph/state.py b/pydantic_ai_graph/pydantic_ai_graph/state.py new file mode 100644 index 0000000000..f9fac1199f --- /dev/null +++ b/pydantic_ai_graph/pydantic_ai_graph/state.py @@ -0,0 +1,17 @@ +from abc import ABC + +from typing_extensions import TypeVar + +__all__ = 'AbstractState', 'StateT' + + +class AbstractState(ABC): + """Abstract class for a state object.""" + + def __init__(self): + pass + + # TODO serializing and deserialize state + + +StateT = TypeVar('StateT', None, AbstractState, default=None) diff --git a/pydantic_ai_graph/pyproject.toml b/pydantic_ai_graph/pyproject.toml new file mode 100644 index 0000000000..84cce90ea7 --- /dev/null +++ b/pydantic_ai_graph/pyproject.toml @@ -0,0 +1,43 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "pydantic-ai-graph" +version = "0.0.1" +description = "Graph and State Machine library" +authors = [ + { name = "Samuel Colvin", email = "samuel@pydantic.dev" }, +] +license = "MIT" +readme = "README.md" +classifiers = [ + "Development Status :: 4 - Beta", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Intended Audience :: Developers", + "Intended Audience :: Information Technology", + "Intended Audience :: System Administrators", + "License :: OSI Approved :: MIT License", + "Operating System :: Unix", + "Operating System :: POSIX :: Linux", + "Environment :: Console", + "Environment :: MacOS X", + "Topic :: Software Development :: Libraries :: Python Modules", + "Topic :: Internet", +] +requires-python = ">=3.9" +dependencies = [ + "httpx>=0.27.2", + "logfire-api>=1.2.0", + "pydantic>=2.10", +] + +[tool.hatch.build.targets.wheel] +packages = ["pydantic_ai_graph"] diff --git a/pyproject.toml b/pyproject.toml index 8dbfdc36a5..82728545c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,7 @@ pydantic-ai-slim = { workspace = true } pydantic-ai-examples = { workspace = true } [tool.uv.workspace] -members = ["pydantic_ai_slim", "examples"] +members = ["pydantic_ai_slim", "pydantic_ai_graph", "examples"] [dependency-groups] # dev dependencies are defined in `pydantic-ai-slim/pyproject.toml` to allow for minimal testing @@ -85,6 +85,7 @@ line-length = 120 target-version = "py39" include = [ "pydantic_ai_slim/**/*.py", + "pydantic_ai_graph/**/*.py", "examples/**/*.py", "tests/**/*.py", "docs/**/*.py", @@ -128,7 +129,7 @@ typeCheckingMode = "strict" reportMissingTypeStubs = false reportUnnecessaryIsInstance = false reportUnnecessaryTypeIgnoreComment = true -include = ["pydantic_ai_slim", "tests", "examples"] +include = ["pydantic_ai_slim", "pydantic_ai_graph", "tests", "examples"] venvPath = ".venv" # see https://github.com/microsoft/pyright/issues/7771 - we don't want to error on decorated functions in tests # which are not otherwise used @@ -147,7 +148,7 @@ filterwarnings = [ # https://coverage.readthedocs.io/en/latest/config.html#run [tool.coverage.run] # required to avoid warnings about files created by create_module fixture -include = ["pydantic_ai_slim/**/*.py", "tests/**/*.py"] +include = 
["pydantic_ai_slim/**/*.py", "pydantic_ai_graph/**/*.py","tests/**/*.py"] omit = ["tests/test_live.py", "tests/example_modules/*.py"] branch = true diff --git a/uv.lock b/uv.lock index cdd9976bad..2234f25d33 100644 --- a/uv.lock +++ b/uv.lock @@ -12,6 +12,7 @@ resolution-markers = [ members = [ "pydantic-ai", "pydantic-ai-examples", + "pydantic-ai-graph", "pydantic-ai-slim", ] @@ -2539,6 +2540,23 @@ requires-dist = [ { name = "uvicorn", specifier = ">=0.32.0" }, ] +[[package]] +name = "pydantic-ai-graph" +version = "0.0.1" +source = { editable = "pydantic_ai_graph" } +dependencies = [ + { name = "httpx" }, + { name = "logfire-api" }, + { name = "pydantic" }, +] + +[package.metadata] +requires-dist = [ + { name = "httpx", specifier = ">=0.27.2" }, + { name = "logfire-api", specifier = ">=1.2.0" }, + { name = "pydantic", specifier = ">=2.10" }, +] + [[package]] name = "pydantic-ai-slim" version = "0.0.18" From 615c5e2af4ece957772922f86155b126aea7ef74 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sun, 22 Dec 2024 16:44:56 +0000 Subject: [PATCH 02/57] lower case state machine --- pydantic_ai_graph/README.md | 2 +- pydantic_ai_graph/pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pydantic_ai_graph/README.md b/pydantic_ai_graph/README.md index 2271f779cb..ec0b531a25 100644 --- a/pydantic_ai_graph/README.md +++ b/pydantic_ai_graph/README.md @@ -6,7 +6,7 @@ [![versions](https://img.shields.io/pypi/pyversions/pydantic-ai-graph.svg)](https://github.com/pydantic/pydantic-ai) [![license](https://img.shields.io/github/license/pydantic/pydantic-ai-graph.svg?v)](https://github.com/pydantic/pydantic-ai/blob/main/LICENSE) -Graph and State Machine library. +Graph and state machine library. This library is developed as part of the [PydanticAI](https://ai.pydantic.dev), however it has no dependency on `pydantic-ai` or related packages and can be considered as a pure graph library. 
diff --git a/pydantic_ai_graph/pyproject.toml b/pydantic_ai_graph/pyproject.toml index 84cce90ea7..2231df5cd1 100644 --- a/pydantic_ai_graph/pyproject.toml +++ b/pydantic_ai_graph/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "hatchling.build" [project] name = "pydantic-ai-graph" version = "0.0.1" -description = "Graph and State Machine library" +description = "Graph and state machine library" authors = [ { name = "Samuel Colvin", email = "samuel@pydantic.dev" }, ] From 9ffe8f03481fdfc7513f465be921ae0d5e75dc9b Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sun, 22 Dec 2024 19:46:50 +0000 Subject: [PATCH 03/57] starting tests --- Makefile | 2 +- .../pydantic_ai_graph/__init__.py | 4 ++ pydantic_ai_graph/pydantic_ai_graph/graph.py | 21 ++++++- pydantic_ai_graph/pydantic_ai_graph/nodes.py | 4 +- pydantic_ai_graph/pydantic_ai_graph/py.typed | 0 pyproject.toml | 4 ++ tests/test_graph.py | 34 ++++++++++ tests/typed_graph.py | 63 +++++++++++++++++++ 8 files changed, 127 insertions(+), 5 deletions(-) create mode 100644 pydantic_ai_graph/pydantic_ai_graph/py.typed create mode 100644 tests/test_graph.py create mode 100644 tests/typed_graph.py diff --git a/Makefile b/Makefile index 8fea000af8..7b12e2a9f1 100644 --- a/Makefile +++ b/Makefile @@ -34,7 +34,7 @@ typecheck-pyright: .PHONY: typecheck-mypy typecheck-mypy: - uv run mypy --strict tests/typed_agent.py + uv run mypy .PHONY: typecheck typecheck: typecheck-pyright ## Run static type checking diff --git a/pydantic_ai_graph/pydantic_ai_graph/__init__.py b/pydantic_ai_graph/pydantic_ai_graph/__init__.py index e69de29bb2..6345bd7090 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/__init__.py +++ b/pydantic_ai_graph/pydantic_ai_graph/__init__.py @@ -0,0 +1,4 @@ +from .graph import Graph +from .nodes import BaseNode, End, GraphContext + +__all__ = 'BaseNode', 'End', 'GraphContext', 'Graph' diff --git a/pydantic_ai_graph/pydantic_ai_graph/graph.py b/pydantic_ai_graph/pydantic_ai_graph/graph.py index 0180cfa325..e21bfc7b54 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/graph.py +++ b/pydantic_ai_graph/pydantic_ai_graph/graph.py @@ -1,6 +1,8 @@ from __future__ import annotations as _annotations import base64 +import inspect +import types from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING, Any, Generic @@ -39,10 +41,11 @@ def __init__( deps_type: type[DepsT] | None = None, state_type: type[StateT] | None = None, ): - self.first_node = first_node.get_node_def() + parent_namespace = get_parent_namespace(inspect.currentframe()) + self.first_node = first_node.get_node_def(parent_namespace) self.nodes = nodes = {self.first_node.node_id: self.first_node} for node in other_nodes: - node_def = node.get_node_def() + node_def = node.get_node_def(parent_namespace) nodes[node_def.node_id] = node_def self._check() @@ -126,3 +129,17 @@ def _check(self): else: b = '\n'.join(f' {be}' for be in bad_edges_list) raise ValueError(f'Nodes are referenced in the graph but not included in the graph:\n{b}') + + +def get_parent_namespace(frame: types.FrameType | None) -> dict[str, Any] | None: + """Attempt to get the namespace where the graph was defined. + + If the graph is defined with generics `Graph[a, b]` then another frame is inserted, and we have to skip that + to get the correct namespace. 
+ """ + if frame is not None: + if back := frame.f_back: + if back.f_code.co_filename.endswith('/typing.py'): + return get_parent_namespace(back) + else: + return back.f_locals diff --git a/pydantic_ai_graph/pydantic_ai_graph/nodes.py b/pydantic_ai_graph/pydantic_ai_graph/nodes.py index a790e01959..c58102dfaf 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/nodes.py +++ b/pydantic_ai_graph/pydantic_ai_graph/nodes.py @@ -76,8 +76,8 @@ def get_id(cls) -> str: return cls.node_id or cls.__qualname__ @classmethod - def get_node_def(cls) -> NodeDef[Any, Any, DepsT, StateT]: - type_hints = get_type_hints(cls.run) + def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[Any, Any, DepsT, StateT]: + type_hints = get_type_hints(cls.run, localns=local_ns) next_node_ids: set[str] = set() can_end: bool = False dest_any: bool = False diff --git a/pydantic_ai_graph/pydantic_ai_graph/py.typed b/pydantic_ai_graph/pydantic_ai_graph/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/pyproject.toml b/pyproject.toml index 82728545c9..fc64bd8263 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -138,6 +138,10 @@ executionEnvironments = [ ] exclude = ["examples/pydantic_ai_examples/weather_agent_gradio.py"] +[tool.mypy] +files = "tests/typed_agent.py,tests/typed_graph.py" +strict = true + [tool.pytest.ini_options] testpaths = "tests" xfail_strict = true diff --git a/tests/test_graph.py b/tests/test_graph.py new file mode 100644 index 0000000000..3b8ede23eb --- /dev/null +++ b/tests/test_graph.py @@ -0,0 +1,34 @@ +from __future__ import annotations as _annotations + +import pytest + +from pydantic_ai_graph import BaseNode, End, Graph, GraphContext + +pytestmark = pytest.mark.anyio + + +async def test_graph(): + class Float2String(BaseNode[float]): + async def run(self, ctx: GraphContext) -> String2Length: + return String2Length(str(self.input_data)) + + class String2Length(BaseNode[str]): + async def run(self, ctx: GraphContext) -> Double: + return Double(len(self.input_data)) + + class Double(BaseNode[int, int]): + async def run(self, ctx: GraphContext) -> String2Length | End[int]: + if self.input_data == 7: + return String2Length('x' * 21) + else: + return End(self.input_data * 2) + + g = Graph[float, int]( + Float2String, + String2Length, + Double, + ) + # len('3.14') * 2 == 8 + assert await g.run(3.14) == (8, None) + # len('3.14159') == 7, 21 * 2 == 42 + assert await g.run(3.14159) == (42, None) diff --git a/tests/typed_graph.py b/tests/typed_graph.py new file mode 100644 index 0000000000..e6974e3070 --- /dev/null +++ b/tests/typed_graph.py @@ -0,0 +1,63 @@ +from __future__ import annotations as _annotations + +from typing import assert_type + +from pydantic_ai_graph import BaseNode, End, Graph, GraphContext + + +class Float2String(BaseNode[float]): + async def run(self, ctx: GraphContext) -> String2Length: + return String2Length(str(self.input_data)) + + +class String2Length(BaseNode[str]): + async def run(self, ctx: GraphContext) -> Double: + return Double(len(self.input_data)) + + +class Double(BaseNode[int, int]): + async def run(self, ctx: GraphContext) -> String2Length | End[int]: + if self.input_data == 7: + return String2Length('x' * 21) + else: + return End(self.input_data * 2) + + +def use_double(node: BaseNode[int, int]) -> None: + """Shoe that `Double` is valid as a `BaseNode[int, int]`.""" + + +use_double(Double(1)) + + +g1 = Graph[float, int]( + Float2String, + String2Length, + Double, +) +assert_type(g1, Graph[float, int]) + + +g2 = Graph(Float2String, Double) 
+assert_type(g2, Graph[float, int]) + +g3 = Graph( + Float2String, + Double, + String2Length, +) +MYPY = False +if MYPY: + # with mypy the presence of `String2Length` makes the output type Any + assert_type(g3, Graph[float]) # pyright: ignore[reportAssertTypeFailure] +else: + # pyright works correct and uses `Double` to infer the output type + assert_type(g3, Graph[float, int]) + +g4 = Graph( + Float2String, + String2Length, + Double, +) +# because String2Length came before Double, the output type is Any +assert_type(g4, Graph[float]) From dd60d15491dd54a73957ac413f10d2192a2757e7 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sun, 22 Dec 2024 20:37:37 +0000 Subject: [PATCH 04/57] add history and logfire --- .../pydantic_ai_graph/__init__.py | 3 +- pydantic_ai_graph/pydantic_ai_graph/graph.py | 102 +++++++++++------- pydantic_ai_graph/pydantic_ai_graph/nodes.py | 2 +- pydantic_ai_graph/pydantic_ai_graph/state.py | 39 +++++-- pydantic_ai_graph/pyproject.toml | 2 +- tests/conftest.py | 5 +- tests/test_graph.py | 69 +++++++++++- uprev.py | 14 ++- uv.lock | 2 +- 9 files changed, 184 insertions(+), 54 deletions(-) diff --git a/pydantic_ai_graph/pydantic_ai_graph/__init__.py b/pydantic_ai_graph/pydantic_ai_graph/__init__.py index 6345bd7090..aaef92ef7a 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/__init__.py +++ b/pydantic_ai_graph/pydantic_ai_graph/__init__.py @@ -1,4 +1,5 @@ from .graph import Graph from .nodes import BaseNode, End, GraphContext +from .state import Snapshot -__all__ = 'BaseNode', 'End', 'GraphContext', 'Graph' +__all__ = 'BaseNode', 'End', 'GraphContext', 'Graph', 'Snapshot' diff --git a/pydantic_ai_graph/pydantic_ai_graph/graph.py b/pydantic_ai_graph/pydantic_ai_graph/graph.py index e21bfc7b54..acccccbbbe 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/graph.py +++ b/pydantic_ai_graph/pydantic_ai_graph/graph.py @@ -4,9 +4,12 @@ import inspect import types from dataclasses import dataclass +from datetime import datetime, timezone from pathlib import Path +from time import perf_counter from typing import TYPE_CHECKING, Any, Generic +import logfire_api from typing_extensions import TypeVar, assert_never from . 
import _utils @@ -18,9 +21,11 @@ GraphOutputT, NodeDef, ) -from .state import StateT +from .state import Snapshot, StateT __all__ = ('Graph',) + +_logfire = logfire_api.Logfire(otel_scope='pydantic-ai-graph') GraphInputT = TypeVar('GraphInputT', default=Any) @@ -31,15 +36,13 @@ class Graph(Generic[GraphInputT, GraphOutputT, DepsT, StateT]): first_node: NodeDef[Any, Any, DepsT, StateT] nodes: dict[str, NodeDef[Any, Any, DepsT, StateT]] - state_type: type[StateT] | None + name: str | None - # noinspection PyUnusedLocal def __init__( self, first_node: type[BaseNode[GraphInputT, Any, DepsT, StateT]], *other_nodes: type[BaseNode[Any, GraphOutputT, DepsT, StateT]], - deps_type: type[DepsT] | None = None, - state_type: type[StateT] | None = None, + name: str | None = None, ): parent_namespace = get_parent_namespace(inspect.currentframe()) self.first_node = first_node.get_node_def(parent_namespace) @@ -49,44 +52,71 @@ def __init__( nodes[node_def.node_id] = node_def self._check() - self.state_type = state_type + self.name = name async def run( - self, input_data: GraphInputT, deps: DepsT = None, state: StateT = None - ) -> tuple[GraphOutputT, StateT]: + self, + input_data: GraphInputT, + deps: DepsT = None, + state: StateT = None, + history: list[Snapshot] | None = None, + ) -> tuple[GraphOutputT, list[Snapshot]]: current_node_def = self.first_node current_node = current_node_def.node(input_data) ctx = GraphContext(deps, state) - while True: - # noinspection PyUnresolvedReferences - next_node = await current_node.run(ctx) - if isinstance(next_node, End): - if current_node_def.can_end: - return next_node.data, ctx.state - else: - raise ValueError(f'Node {current_node_def.node_id} cannot end the graph') - elif isinstance(next_node, BaseNode): - next_node_id = next_node.get_id() - try: - next_node_def = self.nodes[next_node_id] - except KeyError as e: - raise ValueError( - f'Node {current_node_def.node_id} cannot go to {next_node_id} which is not in the Graph' - ) from e - - if not current_node_def.dest_any and next_node_id not in current_node_def.next_node_ids: - raise ValueError( - f'Node {current_node_def.node_id} cannot go to {next_node_id} which is not in its ' - f'list of allowed next nodes' + if history: + run_history = history[:] + else: + run_history = [] + + with _logfire.span( + '{graph_name} run {input=}', + graph_name=self.name or 'graph', + input=input_data, + graph=self, + ) as run_span: + while True: + with _logfire.span('run node {node_id}', node_id=current_node_def.node_id): + start_ts = datetime.now(tz=timezone.utc) + start = perf_counter() + # noinspection PyUnresolvedReferences + next_node = await current_node.run(ctx) + duration = perf_counter() - start + + if isinstance(next_node, End): + if current_node_def.can_end: + run_history.append( + Snapshot.from_state(current_node_def.node_id, None, start_ts, duration, ctx.state) + ) + run_span.set_attribute('history', run_history) + return next_node.data, run_history + else: + raise ValueError(f'Node {current_node_def.node_id} cannot end the graph') + elif isinstance(next_node, BaseNode): + next_node_id = next_node.get_id() + run_history.append( + Snapshot.from_state(current_node_def.node_id, next_node_id, start_ts, duration, ctx.state) ) - - current_node_def = next_node_def - current_node = next_node - else: - if TYPE_CHECKING: - assert_never(next_node) + try: + next_node_def = self.nodes[next_node_id] + except KeyError as e: + raise ValueError( + f'Node {current_node_def.node_id} cannot go to {next_node_id} which is not in the 
Graph' + ) from e + + if not current_node_def.dest_any and next_node_id not in current_node_def.next_node_ids: + raise ValueError( + f'Node {current_node_def.node_id} cannot go to {next_node_id} which is not in its ' + f'list of allowed next nodes' + ) + + current_node_def = next_node_def + current_node = next_node else: - raise TypeError(f'Invalid node type: {type(next_node)} expected BaseNode or End') + if TYPE_CHECKING: + assert_never(next_node) + else: + raise TypeError(f'Invalid node type: {type(next_node)} expected BaseNode or End') def mermaid_code(self) -> str: lines = ['graph TD'] diff --git a/pydantic_ai_graph/pydantic_ai_graph/nodes.py b/pydantic_ai_graph/pydantic_ai_graph/nodes.py index c58102dfaf..376d9db1e0 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/nodes.py +++ b/pydantic_ai_graph/pydantic_ai_graph/nodes.py @@ -73,7 +73,7 @@ async def run(self, ctx: GraphContext[DepsT, StateT]) -> BaseNode[Any, Any, Deps @classmethod @cache def get_id(cls) -> str: - return cls.node_id or cls.__qualname__ + return cls.node_id or cls.__name__ @classmethod def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[Any, Any, DepsT, StateT]: diff --git a/pydantic_ai_graph/pydantic_ai_graph/state.py b/pydantic_ai_graph/pydantic_ai_graph/state.py index f9fac1199f..133085ec8b 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/state.py +++ b/pydantic_ai_graph/pydantic_ai_graph/state.py @@ -1,17 +1,44 @@ -from abc import ABC +from __future__ import annotations as _annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from datetime import datetime from typing_extensions import TypeVar -__all__ = 'AbstractState', 'StateT' +__all__ = 'AbstractState', 'StateT', 'Snapshot' class AbstractState(ABC): """Abstract class for a state object.""" - def __init__(self): - pass - - # TODO serializing and deserialize state + @abstractmethod + def serialize(self) -> bytes | None: + """Serialize the state object.""" + raise NotImplementedError StateT = TypeVar('StateT', None, AbstractState, default=None) + + +@dataclass +class Snapshot: + """Snapshot of a graph.""" + + last_node_id: str + next_node_id: str | None + start_ts: datetime + duration: float + state: bytes | None = None + + @classmethod + def from_state( + cls, last_node_id: str, next_node_id: str | None, start_ts: datetime, duration: float, state: StateT + ) -> Snapshot: + return cls( + last_node_id=last_node_id, + next_node_id=next_node_id, + start_ts=start_ts, + duration=duration, + state=state.serialize() if state is not None else None, + ) diff --git a/pydantic_ai_graph/pyproject.toml b/pydantic_ai_graph/pyproject.toml index 2231df5cd1..cb047a9a8e 100644 --- a/pydantic_ai_graph/pyproject.toml +++ b/pydantic_ai_graph/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "pydantic-ai-graph" -version = "0.0.1" +version = "0.0.14" description = "Graph and state machine library" authors = [ { name = "Samuel Colvin", email = "samuel@pydantic.dev" }, diff --git a/tests/conftest.py b/tests/conftest.py index 4a219ce00d..df874f8be1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,7 +20,7 @@ import pydantic_ai.models -__all__ = 'IsNow', 'TestEnv', 'ClientWithHandler', 'try_import' +__all__ = 'IsNow', 'IsFloat', 'TestEnv', 'ClientWithHandler', 'try_import' pydantic_ai.models.ALLOW_MODEL_REQUESTS = False @@ -28,8 +28,9 @@ if TYPE_CHECKING: def IsNow(*args: Any, **kwargs: Any) -> datetime: ... + def IsFloat(*args: Any, **kwargs: Any) -> float: ... 
else: - from dirty_equals import IsNow + from dirty_equals import IsFloat, IsNow try: from logfire.testing import CaptureLogfire diff --git a/tests/test_graph.py b/tests/test_graph.py index 3b8ede23eb..af27f03cde 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,8 +1,13 @@ from __future__ import annotations as _annotations +from datetime import timezone + import pytest +from inline_snapshot import snapshot + +from pydantic_ai_graph import BaseNode, End, Graph, GraphContext, Snapshot -from pydantic_ai_graph import BaseNode, End, Graph, GraphContext +from .conftest import IsFloat, IsNow pytestmark = pytest.mark.anyio @@ -28,7 +33,65 @@ async def run(self, ctx: GraphContext) -> String2Length | End[int]: String2Length, Double, ) + result, history = await g.run(3.14) # len('3.14') * 2 == 8 - assert await g.run(3.14) == (8, None) + assert result == 8 + assert history == snapshot( + [ + Snapshot( + last_node_id='Float2String', + next_node_id='String2Length', + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(gt=0, lt=1e-5), + ), + Snapshot( + last_node_id='String2Length', + next_node_id='Double', + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(gt=0, lt=1e-5), + ), + Snapshot( + last_node_id='Double', + next_node_id=None, + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(gt=0, lt=1e-5), + ), + ] + ) + result, history = await g.run(3.14159) # len('3.14159') == 7, 21 * 2 == 42 - assert await g.run(3.14159) == (42, None) + assert result == 42 + assert history == snapshot( + [ + Snapshot( + last_node_id='Float2String', + next_node_id='String2Length', + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(gt=0, lt=1e-5), + ), + Snapshot( + last_node_id='String2Length', + next_node_id='Double', + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(gt=0, lt=1e-5), + ), + Snapshot( + last_node_id='Double', + next_node_id='String2Length', + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(gt=0, lt=1e-5), + ), + Snapshot( + last_node_id='String2Length', + next_node_id='Double', + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(gt=0, lt=1e-5), + ), + Snapshot( + last_node_id='Double', + next_node_id=None, + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(gt=0, lt=1e-5), + ), + ] + ) diff --git a/uprev.py b/uprev.py index 97562b5d3f..3eeca56d01 100644 --- a/uprev.py +++ b/uprev.py @@ -3,8 +3,9 @@ Because we have multiple packages which depend on one-another, we have to update the version number in: * pyproject.toml -* pydantic_ai_examples/pyproject.toml +* examples/pyproject.toml * pydantic_ai_slim/pyproject.toml +* pydantic_ai_graph/pyproject.toml Usage: @@ -67,10 +68,15 @@ def replace_deps_version(text: str) -> tuple[str, int]: slim_pp_text = slim_pp.read_text() slim_pp_text, count_slim = replace_deps_version(slim_pp_text) -if count_root == 2 and count_ex == 2 and count_slim == 1: +graph_pp = ROOT_DIR / 'pydantic_ai_graph' / 'pyproject.toml' +graph_pp_text = graph_pp.read_text() +graph_pp_text, count_graph = replace_deps_version(graph_pp_text) + +if count_root == 2 and count_ex == 2 and count_slim == 1 and count_graph == 1: root_pp.write_text(root_pp_text) examples_pp.write_text(examples_pp_text) slim_pp.write_text(slim_pp_text) + graph_pp.write_text(graph_pp_text) print('running `make sync`...') subprocess.run(['make', 'sync'], check=True) print(f'running `git switch -c uprev-{version}`...') @@ -79,7 +85,8 @@ def replace_deps_version(text: str) -> tuple[str, int]: f'SUCCESS: replaced version in\n' f' {root_pp}\n' f' {examples_pp}\n' - f' {slim_pp}' + f' {slim_pp}\n' + 
f' {graph_pp}' ) else: print( @@ -87,6 +94,7 @@ def replace_deps_version(text: str) -> tuple[str, int]: f' {count_root} version references in {root_pp} (expected 2)\n' f' {count_ex} version references in {examples_pp} (expected 2)\n' f' {count_slim} version references in {slim_pp} (expected 1)', + f' {count_graph} version references in {graph_pp} (expected 1)', file=sys.stderr, ) sys.exit(1) diff --git a/uv.lock b/uv.lock index 2234f25d33..6e43242d62 100644 --- a/uv.lock +++ b/uv.lock @@ -2542,7 +2542,7 @@ requires-dist = [ [[package]] name = "pydantic-ai-graph" -version = "0.0.1" +version = "0.0.14" source = { editable = "pydantic_ai_graph" } dependencies = [ { name = "httpx" }, From 73db28277fc09899db9f17ca60e3a1b1598a1c53 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sun, 22 Dec 2024 23:41:50 +0000 Subject: [PATCH 05/57] add example, alter types --- .../email_extract_graph.py | 145 ++++++++++++++++++ .../pydantic_ai_graph/__init__.py | 4 +- pydantic_ai_graph/pydantic_ai_graph/_utils.py | 10 ++ pydantic_ai_graph/pydantic_ai_graph/graph.py | 29 ++-- pydantic_ai_graph/pydantic_ai_graph/nodes.py | 25 ++- pydantic_ai_graph/pydantic_ai_graph/state.py | 3 +- tests/test_graph.py | 8 +- tests/typed_graph.py | 47 +++--- 8 files changed, 208 insertions(+), 63 deletions(-) create mode 100644 examples/pydantic_ai_examples/email_extract_graph.py diff --git a/examples/pydantic_ai_examples/email_extract_graph.py b/examples/pydantic_ai_examples/email_extract_graph.py new file mode 100644 index 0000000000..c91613b4ec --- /dev/null +++ b/examples/pydantic_ai_examples/email_extract_graph.py @@ -0,0 +1,145 @@ +from __future__ import annotations as _annotations + +import asyncio +from datetime import datetime, timedelta + +import logfire +from devtools import debug +from pydantic import BaseModel +from pydantic_ai_graph import AbstractState, BaseNode, End, Graph, GraphContext + +from pydantic_ai import Agent, RunContext + +logfire.configure(send_to_logfire='if-token-present') + + +class EventDetails(BaseModel): + title: str + location: str + start_ts: datetime + end_ts: datetime + + +class State(AbstractState, BaseModel): + email_content: str + skip_events: list[str] = [] + attempt: int = 0 + + def serialize(self) -> bytes | None: + return self.model_dump_json(exclude={'email_content'}).encode() + + +class RawEventDetails(BaseModel): + title: str + location: str + start_ts: str + duration: str + + +extract_agent = Agent('openai:gpt-4o', result_type=RawEventDetails, deps_type=list[str]) + + +@extract_agent.system_prompt +def extract_system_prompt(ctx: RunContext[list[str]]): + prompt = 'Extract event details from the email body.' 
+ if ctx.deps: + skip_events = '\n'.join(ctx.deps) + prompt += f'\n\nDo not return the following events:\n{skip_events}' + return prompt + + +class ExtractEvent(BaseNode[State, None]): + async def run(self, ctx: GraphContext[State]) -> CleanEvent: + event = await extract_agent.run( + ctx.state.email_content, deps=ctx.state.skip_events + ) + return CleanEvent(event.data) + + +# agent used to extract the timestamp from the string in `CleanEvent` +timestamp_agent = Agent('openai:gpt-4o', result_type=datetime) + + +@timestamp_agent.system_prompt +def timestamp_system_prompt(): + return f'Extract the timestamp from the string, the current timestamp is: {datetime.now().isoformat()}' + + +# agent used to extract the duration from the string in `CleanEvent` +duration_agent = Agent( + 'openai:gpt-4o', + result_type=timedelta, + system_prompt='Extract the duration from the string as an ISO 8601 interval.', +) + + +class CleanEvent(BaseNode[State, RawEventDetails]): + async def run(self, ctx: GraphContext[State]) -> InspectEvent: + start_ts, duration = await asyncio.gather( + timestamp_agent.run(self.input_data.start_ts), + duration_agent.run(self.input_data.duration), + ) + return InspectEvent( + EventDetails( + title=self.input_data.title, + location=self.input_data.location, + start_ts=start_ts.data, + end_ts=start_ts.data + duration.data, + ) + ) + + +class InspectEvent(BaseNode[State, EventDetails, EventDetails | None]): + async def run( + self, ctx: GraphContext[State] + ) -> ExtractEvent | End[EventDetails | None]: + now = datetime.now() + if self.input_data.start_ts.tzinfo is not None: + now = now.astimezone(self.input_data.start_ts.tzinfo) + + if self.input_data.start_ts > now: + return End(self.input_data) + ctx.state.attempt += 1 + if ctx.state.attempt > 2: + return End(None) + else: + ctx.state.skip_events.append(self.input_data.title) + return ExtractEvent(None) + + +graph = Graph[State, None, EventDetails | None]( + ExtractEvent, + CleanEvent, + InspectEvent, +) +print(graph.mermaid_code()) + +email = """ +Hi Samuel, + +I hope this message finds you well! I wanted to share a quick update on our recent and upcoming team events. + +Firstly, a big thank you to everyone who participated in last month's +Team Building Retreat held on November 15th 2024 for 1 day. +It was a fantastic opportunity to enhance our collaboration and communication skills while having fun. Your +feedback was incredibly positive, and we're already planning to make the next retreat even better! + +Looking ahead, I'm excited to invite you all to our Annual Year-End Gala on January 20th 2025. +This event will be held at the Grand City Ballroom starting at 6 PM until 8pm. It promises to be an evening full +of entertainment, good food, and great company, celebrating the achievements and hard work of our amazing team +over the past year. + +Please mark your calendars and RSVP by January 10th. I hope to see all of you there! 
+ +Best regards, +""" + + +async def main(): + state = State(email_content=email) + result, history = await graph.run(None, state) + debug(result, history) + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/pydantic_ai_graph/pydantic_ai_graph/__init__.py b/pydantic_ai_graph/pydantic_ai_graph/__init__.py index aaef92ef7a..6a6a4ef9f4 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/__init__.py +++ b/pydantic_ai_graph/pydantic_ai_graph/__init__.py @@ -1,5 +1,5 @@ from .graph import Graph from .nodes import BaseNode, End, GraphContext -from .state import Snapshot +from .state import AbstractState, Snapshot -__all__ = 'BaseNode', 'End', 'GraphContext', 'Graph', 'Snapshot' +__all__ = 'BaseNode', 'End', 'GraphContext', 'Graph', 'Snapshot', 'AbstractState' diff --git a/pydantic_ai_graph/pydantic_ai_graph/_utils.py b/pydantic_ai_graph/pydantic_ai_graph/_utils.py index b58e142f6e..d7474494a5 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/_utils.py +++ b/pydantic_ai_graph/pydantic_ai_graph/_utils.py @@ -35,3 +35,13 @@ def comma_and(items: list[str]) -> str: else: # oxford comma ¯\_(ツ)_/¯ return ', '.join(items[:-1]) + ', and ' + items[-1] + + +_NoneType = type(None) + + +def type_arg_name(arg: Any) -> str: + if arg is _NoneType: + return 'None' + else: + return arg.__name__ diff --git a/pydantic_ai_graph/pydantic_ai_graph/graph.py b/pydantic_ai_graph/pydantic_ai_graph/graph.py index acccccbbbe..5774543a91 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/graph.py +++ b/pydantic_ai_graph/pydantic_ai_graph/graph.py @@ -13,14 +13,7 @@ from typing_extensions import TypeVar, assert_never from . import _utils -from .nodes import ( - BaseNode, - DepsT, - End, - GraphContext, - GraphOutputT, - NodeDef, -) +from .nodes import BaseNode, End, GraphContext, GraphOutputT, NodeDef from .state import Snapshot, StateT __all__ = ('Graph',) @@ -31,18 +24,19 @@ # noinspection PyTypeHints @dataclass(init=False) -class Graph(Generic[GraphInputT, GraphOutputT, DepsT, StateT]): +class Graph(Generic[StateT, GraphInputT, GraphOutputT]): """Definition of a graph.""" - first_node: NodeDef[Any, Any, DepsT, StateT] - nodes: dict[str, NodeDef[Any, Any, DepsT, StateT]] + first_node: NodeDef[StateT, Any, Any] + nodes: dict[str, NodeDef[StateT, Any, Any]] name: str | None def __init__( self, - first_node: type[BaseNode[GraphInputT, Any, DepsT, StateT]], - *other_nodes: type[BaseNode[Any, GraphOutputT, DepsT, StateT]], + first_node: type[BaseNode[StateT, GraphInputT, GraphOutputT]], + *other_nodes: type[BaseNode[StateT, Any, GraphOutputT]], name: str | None = None, + state_type: type[StateT] | None = None, ): parent_namespace = get_parent_namespace(inspect.currentframe()) self.first_node = first_node.get_node_def(parent_namespace) @@ -57,13 +51,12 @@ def __init__( async def run( self, input_data: GraphInputT, - deps: DepsT = None, state: StateT = None, history: list[Snapshot] | None = None, ) -> tuple[GraphOutputT, list[Snapshot]]: current_node_def = self.first_node current_node = current_node_def.node(input_data) - ctx = GraphContext(deps, state) + ctx = GraphContext(state) if history: run_history = history[:] else: @@ -123,6 +116,8 @@ def mermaid_code(self) -> str: # order of destination nodes should match their order in `self.nodes` node_order = {nid: index for index, nid in enumerate(self.nodes.keys())} for node_id, node in self.nodes.items(): + if node_id == self.first_node.node_id: + lines.append(f' START --> {node_id}') if node.dest_any: for next_node_id in self.nodes: lines.append(f' {node_id} --> 
{next_node_id}') @@ -141,9 +136,9 @@ def mermaid_image(self, mermaid_ink_params: dict[str, str | int] | None = None) response.raise_for_status() return response.content - def mermaid_save(self, path: Path, mermaid_ink_params: dict[str, str | int] | None = None) -> None: + def mermaid_save(self, path: Path | str, mermaid_ink_params: dict[str, str | int] | None = None) -> None: image_data = self.mermaid_image(mermaid_ink_params) - path.write_bytes(image_data) + Path(path).write_bytes(image_data) def _check(self): bad_edges: dict[str, list[str]] = {} diff --git a/pydantic_ai_graph/pydantic_ai_graph/nodes.py b/pydantic_ai_graph/pydantic_ai_graph/nodes.py index 376d9db1e0..9ed72dfc23 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/nodes.py +++ b/pydantic_ai_graph/pydantic_ai_graph/nodes.py @@ -13,7 +13,6 @@ __all__ = ( 'NodeInputT', 'GraphOutputT', - 'DepsT', 'GraphContext', 'End', 'BaseNode', @@ -22,15 +21,13 @@ NodeInputT = TypeVar('NodeInputT', default=Any) GraphOutputT = TypeVar('GraphOutputT', default=Any) -DepsT = TypeVar('DepsT', default=None) # noinspection PyTypeHints @dataclass -class GraphContext(Generic[DepsT, StateT]): +class GraphContext(Generic[StateT]): """Context for a graph.""" - deps: DepsT state: StateT @@ -48,17 +45,17 @@ class _BaseNodeMeta(ABCMeta): def __repr__(cls): base: Any = cls.__orig_bases__[0] # type: ignore args = get_args(base) - if len(args) == 4 and args[3] is None: - if args[2] is None: - args = args[:2] + if len(args) == 3 and args[2] is Any: + if args[1] is Any: + args = args[:1] else: - args = args[:3] - args = ', '.join(a.__name__ for a in args) + args = args[:2] + args = ', '.join(_utils.type_arg_name(a) for a in args) return f'{cls.__name__}({base.__name__}[{args}])' # noinspection PyTypeHints -class BaseNode(Generic[NodeInputT, GraphOutputT, DepsT, StateT], metaclass=_BaseNodeMeta): +class BaseNode(Generic[StateT, NodeInputT, GraphOutputT], metaclass=_BaseNodeMeta): """Base class for a node.""" node_id: ClassVar[str | None] = None @@ -68,7 +65,7 @@ def __init__(self, input_data: NodeInputT) -> None: self.input_data = input_data @abstractmethod - async def run(self, ctx: GraphContext[DepsT, StateT]) -> BaseNode[Any, Any, DepsT, StateT] | End[GraphOutputT]: ... + async def run(self, ctx: GraphContext[StateT]) -> BaseNode[StateT, Any, Any] | End[GraphOutputT]: ... @classmethod @cache @@ -76,7 +73,7 @@ def get_id(cls) -> str: return cls.node_id or cls.__name__ @classmethod - def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[Any, Any, DepsT, StateT]: + def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, Any, Any]: type_hints = get_type_hints(cls.run, localns=local_ns) next_node_ids: set[str] = set() can_end: bool = False @@ -103,13 +100,13 @@ def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[Any, Any, Deps # noinspection PyTypeHints @dataclass -class NodeDef(ABC, Generic[NodeInputT, GraphOutputT, DepsT, StateT]): +class NodeDef(ABC, Generic[StateT, NodeInputT, GraphOutputT]): """Definition of a node. Used by [`Graph`][pydantic_ai_graph.graph.Graph] store information about a node. 
""" - node: type[BaseNode[NodeInputT, GraphOutputT, DepsT, StateT]] + node: type[BaseNode[StateT, NodeInputT, GraphOutputT]] node_id: str next_node_ids: set[str] can_end: bool diff --git a/pydantic_ai_graph/pydantic_ai_graph/state.py b/pydantic_ai_graph/pydantic_ai_graph/state.py index 133085ec8b..cf64a32155 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/state.py +++ b/pydantic_ai_graph/pydantic_ai_graph/state.py @@ -3,6 +3,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from datetime import datetime +from typing import Union from typing_extensions import TypeVar @@ -18,7 +19,7 @@ def serialize(self) -> bytes | None: raise NotImplementedError -StateT = TypeVar('StateT', None, AbstractState, default=None) +StateT = TypeVar('StateT', bound=Union[None, AbstractState], default=None) @dataclass diff --git a/tests/test_graph.py b/tests/test_graph.py index af27f03cde..2e579f6410 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -13,22 +13,22 @@ async def test_graph(): - class Float2String(BaseNode[float]): + class Float2String(BaseNode[None, float]): async def run(self, ctx: GraphContext) -> String2Length: return String2Length(str(self.input_data)) - class String2Length(BaseNode[str]): + class String2Length(BaseNode[None, str]): async def run(self, ctx: GraphContext) -> Double: return Double(len(self.input_data)) - class Double(BaseNode[int, int]): + class Double(BaseNode[None, int, int]): async def run(self, ctx: GraphContext) -> String2Length | End[int]: if self.input_data == 7: return String2Length('x' * 21) else: return End(self.input_data * 2) - g = Graph[float, int]( + g = Graph[None, float, int]( Float2String, String2Length, Double, diff --git a/tests/typed_graph.py b/tests/typed_graph.py index e6974e3070..806347f8fb 100644 --- a/tests/typed_graph.py +++ b/tests/typed_graph.py @@ -1,63 +1,60 @@ from __future__ import annotations as _annotations +from dataclasses import dataclass from typing import assert_type from pydantic_ai_graph import BaseNode, End, Graph, GraphContext -class Float2String(BaseNode[float]): +class Float2String(BaseNode[None, float]): async def run(self, ctx: GraphContext) -> String2Length: return String2Length(str(self.input_data)) -class String2Length(BaseNode[str]): +class String2Length(BaseNode[None, str]): async def run(self, ctx: GraphContext) -> Double: return Double(len(self.input_data)) -class Double(BaseNode[int, int]): - async def run(self, ctx: GraphContext) -> String2Length | End[int]: +@dataclass +class X: + v: int + + +class Double(BaseNode[None, int, X]): + async def run(self, ctx: GraphContext) -> String2Length | End[X]: if self.input_data == 7: return String2Length('x' * 21) else: - return End(self.input_data * 2) + return End(X(self.input_data * 2)) -def use_double(node: BaseNode[int, int]) -> None: - """Shoe that `Double` is valid as a `BaseNode[int, int]`.""" +def use_double(node: BaseNode[None, int, X]) -> None: + """Shoe that `Double` is valid as a `BaseNode[None, int, X]`.""" + print(node) use_double(Double(1)) -g1 = Graph[float, int]( +g1 = Graph[None, float, X]( Float2String, String2Length, Double, ) -assert_type(g1, Graph[float, int]) +assert_type(g1, Graph[None, float, X]) -g2 = Graph(Float2String, Double) -assert_type(g2, Graph[float, int]) +g2 = Graph(Double) +assert_type(g2, Graph[None, int, X]) g3 = Graph( - Float2String, - Double, - String2Length, -) -MYPY = False -if MYPY: - # with mypy the presence of `String2Length` makes the output type Any - assert_type(g3, Graph[float]) # pyright: 
ignore[reportAssertTypeFailure] -else: - # pyright works correct and uses `Double` to infer the output type - assert_type(g3, Graph[float, int]) - -g4 = Graph( Float2String, String2Length, Double, ) # because String2Length came before Double, the output type is Any -assert_type(g4, Graph[float]) +assert_type(g3, Graph[None, float]) + +Graph[None, float, bytes](Float2String, String2Length, Double) # type: ignore[arg-type] +Graph[None, int, str](Double) # type: ignore[arg-type] From 03fcaf29120fe0502cc7142be94c193924bb13c4 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 23 Dec 2024 00:13:37 +0000 Subject: [PATCH 06/57] fix dependencies --- .github/workflows/ci.yml | 4 ++-- pydantic_ai_graph/pydantic_ai_graph/graph.py | 2 +- pydantic_ai_graph/pydantic_ai_graph/state.py | 10 ++++++++-- pydantic_ai_slim/pyproject.toml | 4 ++++ pyproject.toml | 3 ++- tests/test_graph.py | 4 ++-- uv.lock | 8 ++++++-- 7 files changed, 25 insertions(+), 10 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1bb3094f7a..8451a9a402 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -129,8 +129,8 @@ jobs: - run: mkdir coverage - # run tests with just `pydantic-ai-slim` dependencies - - run: uv run --package pydantic-ai-slim coverage run -m pytest + # run tests with just `pydantic-ai-slim` and `pydantic-ai-graph` dependencies + - run: uv run --package pydantic-ai-slim --package pydantic-ai-graph coverage run -m pytest env: COVERAGE_FILE: coverage/.coverage.${{ runner.os }}-py${{ matrix.python-version }}-slim diff --git a/pydantic_ai_graph/pydantic_ai_graph/graph.py b/pydantic_ai_graph/pydantic_ai_graph/graph.py index 5774543a91..e0a9d10f8e 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/graph.py +++ b/pydantic_ai_graph/pydantic_ai_graph/graph.py @@ -79,7 +79,7 @@ async def run( if isinstance(next_node, End): if current_node_def.can_end: run_history.append( - Snapshot.from_state(current_node_def.node_id, None, start_ts, duration, ctx.state) + Snapshot.from_state(current_node_def.node_id, 'END', start_ts, duration, ctx.state) ) run_span.set_attribute('history', run_history) return next_node.data, run_history diff --git a/pydantic_ai_graph/pydantic_ai_graph/state.py b/pydantic_ai_graph/pydantic_ai_graph/state.py index cf64a32155..00789887ed 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/state.py +++ b/pydantic_ai_graph/pydantic_ai_graph/state.py @@ -27,14 +27,14 @@ class Snapshot: """Snapshot of a graph.""" last_node_id: str - next_node_id: str | None + next_node_id: str start_ts: datetime duration: float state: bytes | None = None @classmethod def from_state( - cls, last_node_id: str, next_node_id: str | None, start_ts: datetime, duration: float, state: StateT + cls, last_node_id: str, next_node_id: str, start_ts: datetime, duration: float, state: StateT ) -> Snapshot: return cls( last_node_id=last_node_id, @@ -43,3 +43,9 @@ def from_state( duration=duration, state=state.serialize() if state is not None else None, ) + + def summary(self) -> str: + s = f'{self.last_node_id} -> {self.next_node_id}' + if self.duration > 1e-5: + s += f' ({self.duration:.6f}s)' + return s diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index abc4c0ba99..693d27aaec 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -42,6 +42,7 @@ dependencies = [ ] [project.optional-dependencies] +graph = ["pydantic-ai-graph==0.0.14"] openai = ["openai>=1.54.3"] vertexai = ["google-auth>=2.36.0", "requests>=2.32.3"] anthropic = 
["anthropic>=0.40.0"] @@ -65,3 +66,6 @@ dev = [ [tool.hatch.build.targets.wheel] packages = ["pydantic_ai"] + +[tool.uv.sources] +pydantic-ai-graph = { workspace = true } diff --git a/pyproject.toml b/pyproject.toml index fc64bd8263..e1edd3ca85 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ classifiers = [ ] requires-python = ">=3.9" -dependencies = ["pydantic-ai-slim[openai,vertexai,groq,anthropic,mistral]==0.0.18"] +dependencies = ["pydantic-ai-slim[graph,openai,vertexai,groq,anthropic,mistral]==0.0.18"] [project.urls] Homepage = "https://ai.pydantic.dev" @@ -51,6 +51,7 @@ logfire = ["logfire>=2.3"] [tool.uv.sources] pydantic-ai-slim = { workspace = true } +pydantic-ai-graph = { workspace = true } pydantic-ai-examples = { workspace = true } [tool.uv.workspace] diff --git a/tests/test_graph.py b/tests/test_graph.py index 2e579f6410..dafdff5dce 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -52,7 +52,7 @@ async def run(self, ctx: GraphContext) -> String2Length | End[int]: ), Snapshot( last_node_id='Double', - next_node_id=None, + next_node_id='END', start_ts=IsNow(tz=timezone.utc), duration=IsFloat(gt=0, lt=1e-5), ), @@ -89,7 +89,7 @@ async def run(self, ctx: GraphContext) -> String2Length | End[int]: ), Snapshot( last_node_id='Double', - next_node_id=None, + next_node_id='END', start_ts=IsNow(tz=timezone.utc), duration=IsFloat(gt=0, lt=1e-5), ), diff --git a/uv.lock b/uv.lock index 6e43242d62..f7627c796b 100644 --- a/uv.lock +++ b/uv.lock @@ -2458,7 +2458,7 @@ name = "pydantic-ai" version = "0.0.18" source = { editable = "." } dependencies = [ - { name = "pydantic-ai-slim", extra = ["anthropic", "groq", "mistral", "openai", "vertexai"] }, + { name = "pydantic-ai-slim", extra = ["anthropic", "graph", "groq", "mistral", "openai", "vertexai"] }, ] [package.optional-dependencies] @@ -2490,7 +2490,7 @@ lint = [ requires-dist = [ { name = "logfire", marker = "extra == 'logfire'", specifier = ">=2.3" }, { name = "pydantic-ai-examples", marker = "extra == 'examples'", editable = "examples" }, - { name = "pydantic-ai-slim", extras = ["openai", "vertexai", "groq", "anthropic", "mistral"], editable = "pydantic_ai_slim" }, + { name = "pydantic-ai-slim", extras = ["graph", "openai", "vertexai", "groq", "anthropic", "mistral"], editable = "pydantic_ai_slim" }, ] [package.metadata.requires-dev] @@ -2573,6 +2573,9 @@ dependencies = [ anthropic = [ { name = "anthropic" }, ] +graph = [ + { name = "pydantic-ai-graph" }, +] groq = [ { name = "groq" }, ] @@ -2618,6 +2621,7 @@ requires-dist = [ { name = "mistralai", marker = "extra == 'mistral'", specifier = ">=1.2.5" }, { name = "openai", marker = "extra == 'openai'", specifier = ">=1.54.3" }, { name = "pydantic", specifier = ">=2.10" }, + { name = "pydantic-ai-graph", marker = "extra == 'graph'", editable = "pydantic_ai_graph" }, { name = "requests", marker = "extra == 'vertexai'", specifier = ">=2.32.3" }, ] From 1bae9bf58439e2107fae5cc5b9233200a1c40ac9 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 23 Dec 2024 00:17:37 +0000 Subject: [PATCH 07/57] fix ci deps --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8451a9a402..9cab6307bd 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -130,7 +130,7 @@ jobs: - run: mkdir coverage # run tests with just `pydantic-ai-slim` and `pydantic-ai-graph` dependencies - - run: uv run --package pydantic-ai-slim --package pydantic-ai-graph coverage run -m 
pytest + - run: uv run --package pydantic-ai-slim --extra graph coverage run -m pytest env: COVERAGE_FILE: coverage/.coverage.${{ runner.os }}-py${{ matrix.python-version }}-slim From 61f3d2b7710ee0e943526be85b9e8e11a1101e0d Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 23 Dec 2024 00:23:19 +0000 Subject: [PATCH 08/57] fix tests for other versions --- pydantic_ai_graph/pydantic_ai_graph/_utils.py | 6 +++++- tests/test_graph.py | 3 ++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/pydantic_ai_graph/pydantic_ai_graph/_utils.py b/pydantic_ai_graph/pydantic_ai_graph/_utils.py index d7474494a5..1136108809 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/_utils.py +++ b/pydantic_ai_graph/pydantic_ai_graph/_utils.py @@ -1,6 +1,10 @@ +from __future__ import annotations as _annotations + import sys import types -from typing import Any, TypeAliasType, Union, get_args, get_origin +from typing import Any, Union, get_args, get_origin + +from typing_extensions import TypeAliasType def get_union_args(tp: Any) -> tuple[Any, ...]: diff --git a/tests/test_graph.py b/tests/test_graph.py index dafdff5dce..5045d81027 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,6 +1,7 @@ from __future__ import annotations as _annotations from datetime import timezone +from typing import Union import pytest from inline_snapshot import snapshot @@ -22,7 +23,7 @@ async def run(self, ctx: GraphContext) -> Double: return Double(len(self.input_data)) class Double(BaseNode[None, int, int]): - async def run(self, ctx: GraphContext) -> String2Length | End[int]: + async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noqa: UP007 if self.input_data == 7: return String2Length('x' * 21) else: From dc769c909a17a19ec1ea80a2fa8b4c4dce9c5103 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 23 Dec 2024 00:27:05 +0000 Subject: [PATCH 09/57] change node test times --- tests/test_graph.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/test_graph.py b/tests/test_graph.py index 5045d81027..81b54fa070 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -43,19 +43,19 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq last_node_id='Float2String', next_node_id='String2Length', start_ts=IsNow(tz=timezone.utc), - duration=IsFloat(gt=0, lt=1e-5), + duration=IsFloat(gt=0, lt=1e-3), ), Snapshot( last_node_id='String2Length', next_node_id='Double', start_ts=IsNow(tz=timezone.utc), - duration=IsFloat(gt=0, lt=1e-5), + duration=IsFloat(gt=0, lt=1e-3), ), Snapshot( last_node_id='Double', next_node_id='END', start_ts=IsNow(tz=timezone.utc), - duration=IsFloat(gt=0, lt=1e-5), + duration=IsFloat(gt=0, lt=1e-3), ), ] ) @@ -68,31 +68,31 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq last_node_id='Float2String', next_node_id='String2Length', start_ts=IsNow(tz=timezone.utc), - duration=IsFloat(gt=0, lt=1e-5), + duration=IsFloat(gt=0, lt=1e-3), ), Snapshot( last_node_id='String2Length', next_node_id='Double', start_ts=IsNow(tz=timezone.utc), - duration=IsFloat(gt=0, lt=1e-5), + duration=IsFloat(gt=0, lt=1e-3), ), Snapshot( last_node_id='Double', next_node_id='String2Length', start_ts=IsNow(tz=timezone.utc), - duration=IsFloat(gt=0, lt=1e-5), + duration=IsFloat(gt=0, lt=1e-3), ), Snapshot( last_node_id='String2Length', next_node_id='Double', start_ts=IsNow(tz=timezone.utc), - duration=IsFloat(gt=0, lt=1e-5), + duration=IsFloat(gt=0, lt=1e-3), ), Snapshot( last_node_id='Double', 
next_node_id='END', start_ts=IsNow(tz=timezone.utc), - duration=IsFloat(gt=0, lt=1e-5), + duration=IsFloat(gt=0, lt=1e-3), ), ] ) From b03e7bc64107016e28456e270c95ea1c00819499 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Thu, 2 Jan 2025 11:17:46 -0700 Subject: [PATCH 10/57] pydantic-ai-graph - simplify public generics (#539) --- .../email_extract_graph.py | 31 +- .../pydantic_ai_graph/__init__.py | 17 +- pydantic_ai_graph/pydantic_ai_graph/_utils.py | 14 + pydantic_ai_graph/pydantic_ai_graph/graph.py | 426 +++++++++++++----- pydantic_ai_graph/pydantic_ai_graph/nodes.py | 90 ++-- pydantic_ai_graph/pydantic_ai_graph/state.py | 74 +-- tests/test_graph.py | 103 +++-- tests/typed_agent.py | 4 +- tests/typed_graph.py | 50 +- 9 files changed, 527 insertions(+), 282 deletions(-) diff --git a/examples/pydantic_ai_examples/email_extract_graph.py b/examples/pydantic_ai_examples/email_extract_graph.py index c91613b4ec..8465f5bcd9 100644 --- a/examples/pydantic_ai_examples/email_extract_graph.py +++ b/examples/pydantic_ai_examples/email_extract_graph.py @@ -1,6 +1,7 @@ from __future__ import annotations as _annotations import asyncio +from dataclasses import dataclass from datetime import datetime, timedelta import logfire @@ -48,7 +49,8 @@ def extract_system_prompt(ctx: RunContext[list[str]]): return prompt -class ExtractEvent(BaseNode[State, None]): +@dataclass +class ExtractEvent(BaseNode[State]): async def run(self, ctx: GraphContext[State]) -> CleanEvent: event = await extract_agent.run( ctx.state.email_content, deps=ctx.state.skip_events @@ -73,7 +75,10 @@ def timestamp_system_prompt(): ) -class CleanEvent(BaseNode[State, RawEventDetails]): +@dataclass +class CleanEvent(BaseNode[State]): + input_data: RawEventDetails + async def run(self, ctx: GraphContext[State]) -> InspectEvent: start_ts, duration = await asyncio.gather( timestamp_agent.run(self.input_data.start_ts), @@ -89,7 +94,10 @@ async def run(self, ctx: GraphContext[State]) -> InspectEvent: ) -class InspectEvent(BaseNode[State, EventDetails, EventDetails | None]): +@dataclass +class InspectEvent(BaseNode[State, EventDetails | None]): + input_data: EventDetails + async def run( self, ctx: GraphContext[State] ) -> ExtractEvent | End[EventDetails | None]: @@ -104,15 +112,18 @@ async def run( return End(None) else: ctx.state.skip_events.append(self.input_data.title) - return ExtractEvent(None) + return ExtractEvent() -graph = Graph[State, None, EventDetails | None]( - ExtractEvent, - CleanEvent, - InspectEvent, +graph = Graph[State, EventDetails | None]( + nodes=( + ExtractEvent, + CleanEvent, + InspectEvent, + ) ) -print(graph.mermaid_code()) +graph_runner = graph.get_runner(ExtractEvent) +print(graph_runner.mermaid_code()) email = """ Hi Samuel, @@ -137,7 +148,7 @@ async def run( async def main(): state = State(email_content=email) - result, history = await graph.run(None, state) + result, history = await graph_runner.run(state, None) debug(result, history) diff --git a/pydantic_ai_graph/pydantic_ai_graph/__init__.py b/pydantic_ai_graph/pydantic_ai_graph/__init__.py index 6a6a4ef9f4..b7d79a6009 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/__init__.py +++ b/pydantic_ai_graph/pydantic_ai_graph/__init__.py @@ -1,5 +1,16 @@ -from .graph import Graph +from .graph import Graph, GraphRun, GraphRunner from .nodes import BaseNode, End, GraphContext -from .state import AbstractState, Snapshot +from .state import AbstractState, EndEvent, Step, StepOrEnd -__all__ = 'BaseNode', 'End', 
'GraphContext', 'Graph', 'Snapshot', 'AbstractState' +__all__ = ( + 'Graph', + 'GraphRunner', + 'GraphRun', + 'BaseNode', + 'End', + 'GraphContext', + 'AbstractState', + 'EndEvent', + 'StepOrEnd', + 'Step', +) diff --git a/pydantic_ai_graph/pydantic_ai_graph/_utils.py b/pydantic_ai_graph/pydantic_ai_graph/_utils.py index 1136108809..d1b4671800 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/_utils.py +++ b/pydantic_ai_graph/pydantic_ai_graph/_utils.py @@ -49,3 +49,17 @@ def type_arg_name(arg: Any) -> str: return 'None' else: return arg.__name__ + + +def get_parent_namespace(frame: types.FrameType | None) -> dict[str, Any] | None: + """Attempt to get the namespace where the graph was defined. + + If the graph is defined with generics `Graph[a, b]` then another frame is inserted, and we have to skip that + to get the correct namespace. + """ + if frame is not None: + if back := frame.f_back: + if back.f_code.co_filename.endswith('/typing.py'): + return get_parent_namespace(back) + else: + return back.f_locals diff --git a/pydantic_ai_graph/pydantic_ai_graph/graph.py b/pydantic_ai_graph/pydantic_ai_graph/graph.py index e0a9d10f8e..a6e09e887a 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/graph.py +++ b/pydantic_ai_graph/pydantic_ai_graph/graph.py @@ -2,169 +2,359 @@ import base64 import inspect -import types -from dataclasses import dataclass -from datetime import datetime, timezone +from dataclasses import dataclass, field from pathlib import Path from time import perf_counter -from typing import TYPE_CHECKING, Any, Generic +from typing import TYPE_CHECKING, Annotated, Any, Generic import logfire_api -from typing_extensions import TypeVar, assert_never +from annotated_types import Ge, Le +from typing_extensions import Literal, Never, ParamSpec, Protocol, TypeVar, assert_never from . import _utils -from .nodes import BaseNode, End, GraphContext, GraphOutputT, NodeDef -from .state import Snapshot, StateT +from ._utils import get_parent_namespace +from .nodes import BaseNode, End, GraphContext, NodeDef +from .state import EndEvent, StateT, Step, StepOrEnd -__all__ = ('Graph',) +__all__ = ('Graph', 'GraphRun', 'GraphRunner') _logfire = logfire_api.Logfire(otel_scope='pydantic-ai-graph') -GraphInputT = TypeVar('GraphInputT', default=Any) + +RunSignatureT = ParamSpec('RunSignatureT') +RunEndT = TypeVar('RunEndT', default=None) +NodeRunEndT = TypeVar('NodeRunEndT', covariant=True, default=Never) + + +class StartNodeProtocol(Protocol[RunSignatureT, StateT, NodeRunEndT]): + def get_id(self) -> str: ... + def __call__(self, *args: RunSignatureT.args, **kwargs: RunSignatureT.kwargs) -> BaseNode[StateT, NodeRunEndT]: ... -# noinspection PyTypeHints @dataclass(init=False) -class Graph(Generic[StateT, GraphInputT, GraphOutputT]): +class Graph(Generic[StateT, RunEndT]): """Definition of a graph.""" - first_node: NodeDef[StateT, Any, Any] - nodes: dict[str, NodeDef[StateT, Any, Any]] name: str | None + nodes: tuple[type[BaseNode[StateT, RunEndT]], ...] + node_defs: dict[str, NodeDef[StateT, RunEndT]] def __init__( self, - first_node: type[BaseNode[StateT, GraphInputT, GraphOutputT]], - *other_nodes: type[BaseNode[StateT, Any, GraphOutputT]], name: str | None = None, + nodes: tuple[type[BaseNode[StateT, RunEndT]], ...] 
= (), state_type: type[StateT] | None = None, ): + self.name = name + + _nodes_by_id: dict[str, type[BaseNode[StateT, RunEndT]]] = {} + for node in nodes: + node_id = node.get_id() + if (existing_node := _nodes_by_id.get(node_id)) and existing_node is not node: + raise ValueError(f'Node ID "{node_id}" is not unique — found in {existing_node} and {node}') + else: + _nodes_by_id[node_id] = node + self.nodes = tuple(_nodes_by_id.values()) + parent_namespace = get_parent_namespace(inspect.currentframe()) - self.first_node = first_node.get_node_def(parent_namespace) - self.nodes = nodes = {self.first_node.node_id: self.first_node} - for node in other_nodes: - node_def = node.get_node_def(parent_namespace) - nodes[node_def.node_id] = node_def + self.node_defs: dict[str, NodeDef[StateT, RunEndT]] = {} + for node in self.nodes: + self.node_defs[node.get_id()] = node.get_node_def(parent_namespace) - self._check() - self.name = name + self._validate_edges() + + def _validate_edges(self): + known_node_ids = set(self.node_defs.keys()) + bad_edges: dict[str, list[str]] = {} + + for node_id, node_def in self.node_defs.items(): + node_bad_edges = node_def.next_node_ids - known_node_ids + for bad_edge in node_bad_edges: + bad_edges.setdefault(bad_edge, []).append(f'"{node_id}"') + + if bad_edges: + bad_edges_list = [f'"{k}" is referenced by {_utils.comma_and(v)}' for k, v in bad_edges.items()] + if len(bad_edges_list) == 1: + raise ValueError(f'{bad_edges_list[0]} but not included in the graph.') + else: + b = '\n'.join(f' {be}' for be in bad_edges_list) + raise ValueError(f'Nodes are referenced in the graph but not included in the graph:\n{b}') async def run( + self, state: StateT, node: BaseNode[StateT, RunEndT] + ) -> tuple[RunEndT, list[StepOrEnd[StateT, RunEndT]]]: + if not isinstance(node, self.nodes): + raise ValueError(f'Node "{node}" is not in the graph.') + run = GraphRun[StateT, RunEndT](state=state) + # TODO: Infer the graph name properly + result = await run.run(self.name or 'graph', node) + history = run.history + return result, history + + def get_runner( self, - input_data: GraphInputT, - state: StateT = None, - history: list[Snapshot] | None = None, - ) -> tuple[GraphOutputT, list[Snapshot]]: - current_node_def = self.first_node - current_node = current_node_def.node(input_data) - ctx = GraphContext(state) - if history: - run_history = history[:] - else: - run_history = [] + first_node: StartNodeProtocol[RunSignatureT, StateT, RunEndT], + ) -> GraphRunner[RunSignatureT, StateT, RunEndT]: + return GraphRunner( + graph=self, + first_node=first_node, + ) + + +@dataclass +class GraphRunner(Generic[RunSignatureT, StateT, RunEndT]): + """Runner for a graph. + + This is a separate class from Graph so that you can get a type-safe runner from a graph definition + without needing to manually annotate the paramspec of the start node. 
+ """ + + graph: Graph[StateT, RunEndT] + first_node: StartNodeProtocol[RunSignatureT, StateT, RunEndT] + + def __post_init__(self): + if self.first_node not in self.graph.nodes: + raise ValueError(f'Start node "{self.first_node}" is not in the graph.') + + async def run( + self, state: StateT, /, *args: RunSignatureT.args, **kwargs: RunSignatureT.kwargs + ) -> tuple[RunEndT, list[StepOrEnd[StateT, RunEndT]]]: + run = GraphRun[StateT, RunEndT](state=state) + # TODO: Infer the graph name properly + result = await run.run(self.graph.name or 'graph', self.first_node(*args, **kwargs)) + history = run.history + return result, history + + def mermaid_code(self) -> str: + return mermaid_code(self.graph, self.first_node) + + def mermaid_image( + self, + image_type: Literal['jpeg', 'png', 'webp', 'svg', 'pdf'] | str | None = None, + pdf_fit: bool = False, + pdf_landscape: bool = False, + pdf_paper: Literal['letter', 'legal', 'tabloid', 'ledger', 'a0', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6'] + | str + | None = None, + bg_color: str | None = None, + theme: Literal['default', 'neutral', 'dark', 'forest'] | str | None = None, + width: int | None = None, + height: int | None = None, + scale: Annotated[float, Ge(1), Le(3)] | None = None, + ) -> bytes: + return mermaid_image( + self.graph, + self.first_node, + image_type=image_type, + pdf_fit=pdf_fit, + pdf_landscape=pdf_landscape, + pdf_paper=pdf_paper, + bg_color=bg_color, + theme=theme, + width=width, + height=height, + scale=scale, + ) + + def mermaid_save( + self, + path: Path | str, + image_type: Literal['jpeg', 'png', 'webp', 'svg', 'pdf'] | str | None = None, + pdf_fit: bool = False, + pdf_landscape: bool = False, + pdf_paper: Literal['letter', 'legal', 'tabloid', 'ledger', 'a0', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6'] + | str + | None = None, + bg_color: str | None = None, + theme: Literal['default', 'neutral', 'dark', 'forest'] | str | None = None, + width: int | None = None, + height: int | None = None, + scale: Annotated[float, Ge(1), Le(3)] | None = None, + ) -> None: + mermaid_save( + path, + self.graph, + self.first_node, + image_type=image_type, + pdf_fit=pdf_fit, + pdf_landscape=pdf_landscape, + pdf_paper=pdf_paper, + bg_color=bg_color, + theme=theme, + width=width, + height=height, + scale=scale, + ) + + +@dataclass +class GraphRun(Generic[StateT, RunEndT]): + """Stateful run of a graph.""" + + state: StateT + history: list[StepOrEnd[StateT, RunEndT]] = field(default_factory=list) + + async def run(self, graph_name: str, start: BaseNode[StateT, RunEndT], infer_name: bool = True) -> RunEndT: + current_node = start with _logfire.span( - '{graph_name} run {input=}', - graph_name=self.name or 'graph', - input=input_data, - graph=self, + '{graph_name} run {start=}', + graph_name=graph_name, + start=start, ) as run_span: while True: - with _logfire.span('run node {node_id}', node_id=current_node_def.node_id): - start_ts = datetime.now(tz=timezone.utc) - start = perf_counter() - # noinspection PyUnresolvedReferences - next_node = await current_node.run(ctx) - duration = perf_counter() - start - + next_node = await self.step(current_node) if isinstance(next_node, End): - if current_node_def.can_end: - run_history.append( - Snapshot.from_state(current_node_def.node_id, 'END', start_ts, duration, ctx.state) - ) - run_span.set_attribute('history', run_history) - return next_node.data, run_history - else: - raise ValueError(f'Node {current_node_def.node_id} cannot end the graph') + self.history.append(EndEvent(self.state, next_node)) + 
run_span.set_attribute('history', self.history) + return next_node.data elif isinstance(next_node, BaseNode): - next_node_id = next_node.get_id() - run_history.append( - Snapshot.from_state(current_node_def.node_id, next_node_id, start_ts, duration, ctx.state) - ) - try: - next_node_def = self.nodes[next_node_id] - except KeyError as e: - raise ValueError( - f'Node {current_node_def.node_id} cannot go to {next_node_id} which is not in the Graph' - ) from e - - if not current_node_def.dest_any and next_node_id not in current_node_def.next_node_ids: - raise ValueError( - f'Node {current_node_def.node_id} cannot go to {next_node_id} which is not in its ' - f'list of allowed next nodes' - ) - - current_node_def = next_node_def current_node = next_node else: if TYPE_CHECKING: assert_never(next_node) else: - raise TypeError(f'Invalid node type: {type(next_node)} expected BaseNode or End') + raise TypeError(f'Invalid node type: {type(next_node)}. Expected `BaseNode` or `End`.') - def mermaid_code(self) -> str: - lines = ['graph TD'] - # order of destination nodes should match their order in `self.nodes` - node_order = {nid: index for index, nid in enumerate(self.nodes.keys())} - for node_id, node in self.nodes.items(): - if node_id == self.first_node.node_id: - lines.append(f' START --> {node_id}') - if node.dest_any: - for next_node_id in self.nodes: - lines.append(f' {node_id} --> {next_node_id}') - for _, next_node_id in sorted((node_order[nid], nid) for nid in node.next_node_ids): - lines.append(f' {node_id} --> {next_node_id}') - if node.can_end: - lines.append(f' {node_id} --> END') - return '\n'.join(lines) + async def step(self, node: BaseNode[StateT, RunEndT]) -> BaseNode[StateT, RunEndT] | End[RunEndT]: + history_step = Step(self.state, node) + self.history.append(history_step) - def mermaid_image(self, mermaid_ink_params: dict[str, str | int] | None = None) -> bytes: - import httpx + ctx = GraphContext(self.state) + with _logfire.span('run node {node_id}', node_id=node.get_id()): + start = perf_counter() + next_node = await node.run(ctx) + history_step.duration = perf_counter() - start + return next_node - code_base64 = base64.b64encode(self.mermaid_code().encode()).decode() - response = httpx.get(f'https://mermaid.ink/img/{code_base64}', params=mermaid_ink_params) - response.raise_for_status() - return response.content +def mermaid_code( + graph: Graph[Any, Any], start: StartNodeProtocol[..., Any, Any] | tuple[StartNodeProtocol[..., Any, Any], ...] 
= () +) -> str: + if not isinstance(start, tuple): + start = (start,) - def mermaid_save(self, path: Path | str, mermaid_ink_params: dict[str, str | int] | None = None) -> None: - image_data = self.mermaid_image(mermaid_ink_params) - Path(path).write_bytes(image_data) + for node in start: + if node not in graph.nodes: + raise ValueError(f'Start node "{node}" is not in the graph.') - def _check(self): - bad_edges: dict[str, list[str]] = {} - for node in self.nodes.values(): - node_bad_edges = node.next_node_ids - self.nodes.keys() - for bad_edge in node_bad_edges: - bad_edges.setdefault(bad_edge, []).append(f'"{node.node_id}"') + node_order = {node_id: index for index, node_id in enumerate(graph.node_defs)} - if bad_edges: - bad_edges_list = [f'"{k}" is referenced by {_utils.comma_and(v)}' for k, v in bad_edges.items()] - if len(bad_edges_list) == 1: - raise ValueError(f'{bad_edges_list[0]} but not included in the graph.') - else: - b = '\n'.join(f' {be}' for be in bad_edges_list) - raise ValueError(f'Nodes are referenced in the graph but not included in the graph:\n{b}') + lines = ['graph TD'] + for node in graph.nodes: + node_id = node.get_id() + node_def = graph.node_defs[node_id] + if node in start: + lines.append(f' START --> {node_id}') + if node_def.returns_base_node: + for next_node_id in graph.nodes: + lines.append(f' {node_id} --> {next_node_id}') + else: + for _, next_node_id in sorted((node_order[node_id], node_id) for node_id in node_def.next_node_ids): + lines.append(f' {node_id} --> {next_node_id}') + if node_def.returns_end: + lines.append(f' {node_id} --> END') + + return '\n'.join(lines) -def get_parent_namespace(frame: types.FrameType | None) -> dict[str, Any] | None: - """Attempt to get the namespace where the graph was defined. +def mermaid_image( + graph: Graph[Any, Any], + start: StartNodeProtocol[..., Any, Any] | tuple[StartNodeProtocol[..., Any, Any], ...] = (), + image_type: Literal['jpeg', 'png', 'webp', 'svg', 'pdf'] | str | None = None, + pdf_fit: bool = False, + pdf_landscape: bool = False, + pdf_paper: Literal['letter', 'legal', 'tabloid', 'ledger', 'a0', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6'] + | str + | None = None, + bg_color: str | None = None, + theme: Literal['default', 'neutral', 'dark', 'forest'] | str | None = None, + width: int | None = None, + height: int | None = None, + scale: Annotated[float, Ge(1), Le(3)] | None = None, +) -> bytes: + """Generate an image of a Mermaid diagram using mermaid.ink. - If the graph is defined with generics `Graph[a, b]` then another frame is inserted, and we have to skip that - to get the correct namespace. + Args: + graph: The graph to generate the image for. + start: The start node(s) of the graph. + image_type: The image type to generate. If unspecified, the default behavior is `'jpeg'`. + pdf_fit: When using image_type='pdf', whether to fit the diagram to the PDF page. + pdf_landscape: When using image_type='pdf', whether to use landscape orientation for the PDF. + This has no effect if using `pdf_fit`. + pdf_paper: When using image_type='pdf', the paper size of the PDF. + bg_color: The background color of the diagram. If None, the default transparent background is used. + The color value is interpreted as a hexadecimal color code by default (and should not have a leading '#'), + but you can also use named colors by prefixing the value with '!'. + For example, valid choices include `bg_color='!white'` or `bg_color='FF0000'`. + theme: The theme of the diagram. Defaults to 'default'. + width: The width of the diagram. 
+ height: The height of the diagram. + scale: The scale of the diagram. The scale must be a number between 1 and 3, and you can only set + a scale if one or both of width and height are set. """ - if frame is not None: - if back := frame.f_back: - if back.f_code.co_filename.endswith('/typing.py'): - return get_parent_namespace(back) - else: - return back.f_locals + import httpx + + code_base64 = base64.b64encode(mermaid_code(graph, start).encode()).decode() + + params: dict[str, str] = {} + if image_type == 'pdf': + url = f'https://mermaid.ink/pdf/{code_base64}' + if pdf_fit: + params['fit'] = '' + if pdf_landscape: + params['landscape'] = '' + if pdf_paper: + params['paper'] = pdf_paper + else: + url = f'https://mermaid.ink/img/{code_base64}' + + if image_type: + params['type'] = image_type + + if bg_color: + params['bgColor'] = bg_color + if theme: + params['theme'] = theme + if width: + params['width'] = str(width) + if height: + params['height'] = str(height) + if scale: + params['scale'] = str(scale) + + response = httpx.get(url, params=params) + response.raise_for_status() + return response.content + + +def mermaid_save( + path: Path | str, + graph: Graph[Any, Any], + start: StartNodeProtocol[..., Any, Any] | tuple[StartNodeProtocol[..., Any, Any], ...] = (), + image_type: Literal['jpeg', 'png', 'webp', 'svg', 'pdf'] | str | None = None, + pdf_fit: bool = False, + pdf_landscape: bool = False, + pdf_paper: Literal['letter', 'legal', 'tabloid', 'ledger', 'a0', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6'] + | str + | None = None, + bg_color: str | None = None, + theme: Literal['default', 'neutral', 'dark', 'forest'] | str | None = None, + width: int | None = None, + height: int | None = None, + scale: Annotated[float, Ge(1), Le(3)] | None = None, +) -> None: + # TODO: do something with the path file extension, e.g. error if it's incompatible, or use it to specify a param + image_data = mermaid_image( + graph, + start, + image_type=image_type, + pdf_fit=pdf_fit, + pdf_landscape=pdf_landscape, + pdf_paper=pdf_paper, + bg_color=bg_color, + theme=theme, + width=width, + height=height, + scale=scale, + ) + Path(path).write_bytes(image_data) diff --git a/pydantic_ai_graph/pydantic_ai_graph/nodes.py b/pydantic_ai_graph/pydantic_ai_graph/nodes.py index 9ed72dfc23..1a1b8bcd5d 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/nodes.py +++ b/pydantic_ai_graph/pydantic_ai_graph/nodes.py @@ -1,29 +1,21 @@ from __future__ import annotations as _annotations -from abc import ABC, ABCMeta, abstractmethod +from abc import abstractmethod from dataclasses import dataclass from functools import cache -from typing import Any, ClassVar, Generic, get_args, get_origin, get_type_hints +from typing import Any, Generic, get_origin, get_type_hints -from typing_extensions import TypeVar +from typing_extensions import Never, TypeVar from . 
import _utils from .state import StateT -__all__ = ( - 'NodeInputT', - 'GraphOutputT', - 'GraphContext', - 'End', - 'BaseNode', - 'NodeDef', -) +__all__ = ('GraphContext', 'End', 'BaseNode', 'NodeDef') -NodeInputT = TypeVar('NodeInputT', default=Any) -GraphOutputT = TypeVar('GraphOutputT', default=Any) +RunEndT = TypeVar('RunEndT', default=None) +NodeRunEndT = TypeVar('NodeRunEndT', covariant=True, default=Never) -# noinspection PyTypeHints @dataclass class GraphContext(Generic[StateT]): """Context for a graph.""" @@ -31,61 +23,44 @@ class GraphContext(Generic[StateT]): state: StateT -# noinspection PyTypeHints -class End(ABC, Generic[NodeInputT]): +@dataclass +class End(Generic[RunEndT]): """Type to return from a node to signal the end of the graph.""" - __slots__ = ('data',) - - def __init__(self, input_data: NodeInputT) -> None: - self.data = input_data + data: RunEndT -class _BaseNodeMeta(ABCMeta): - def __repr__(cls): - base: Any = cls.__orig_bases__[0] # type: ignore - args = get_args(base) - if len(args) == 3 and args[2] is Any: - if args[1] is Any: - args = args[:1] - else: - args = args[:2] - args = ', '.join(_utils.type_arg_name(a) for a in args) - return f'{cls.__name__}({base.__name__}[{args}])' - - -# noinspection PyTypeHints -class BaseNode(Generic[StateT, NodeInputT, GraphOutputT], metaclass=_BaseNodeMeta): +class BaseNode(Generic[StateT, NodeRunEndT]): """Base class for a node.""" - node_id: ClassVar[str | None] = None - __slots__ = ('input_data',) - - def __init__(self, input_data: NodeInputT) -> None: - self.input_data = input_data - @abstractmethod - async def run(self, ctx: GraphContext[StateT]) -> BaseNode[StateT, Any, Any] | End[GraphOutputT]: ... + async def run(self, ctx: GraphContext[StateT]) -> BaseNode[StateT, Any] | End[NodeRunEndT]: ... @classmethod @cache def get_id(cls) -> str: - return cls.node_id or cls.__name__ + return cls.__name__ @classmethod - def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, Any, Any]: + def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, NodeRunEndT]: type_hints = get_type_hints(cls.run, localns=local_ns) next_node_ids: set[str] = set() - can_end: bool = False - dest_any: bool = False - for return_type in _utils.get_union_args(type_hints['return']): + returns_end: bool = False + returns_base_node: bool = False + try: + return_hint = type_hints['return'] + except KeyError: + raise TypeError(f'Node {cls} is missing a return type hint on its `run` method') + + for return_type in _utils.get_union_args(return_hint): return_type_origin = get_origin(return_type) or return_type - if return_type_origin is BaseNode: - dest_any = True + if return_type_origin is End: + returns_end = True + elif return_type_origin is BaseNode: + # TODO: Should we disallow this? More generally, what do we do about sub-subclasses? + returns_base_node = True elif issubclass(return_type_origin, BaseNode): next_node_ids.add(return_type.get_id()) - elif return_type_origin is End: - can_end = True else: raise TypeError(f'Invalid return type: {return_type}') @@ -93,21 +68,20 @@ def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, Any, A cls, cls.get_id(), next_node_ids, - can_end, - dest_any, + returns_end, + returns_base_node, ) -# noinspection PyTypeHints @dataclass -class NodeDef(ABC, Generic[StateT, NodeInputT, GraphOutputT]): +class NodeDef(Generic[StateT, NodeRunEndT]): """Definition of a node. Used by [`Graph`][pydantic_ai_graph.graph.Graph] store information about a node. 
""" - node: type[BaseNode[StateT, NodeInputT, GraphOutputT]] + node: type[BaseNode[StateT, NodeRunEndT]] node_id: str next_node_ids: set[str] - can_end: bool - dest_any: bool + returns_end: bool + returns_base_node: bool diff --git a/pydantic_ai_graph/pydantic_ai_graph/state.py b/pydantic_ai_graph/pydantic_ai_graph/state.py index 00789887ed..45b01aedfd 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/state.py +++ b/pydantic_ai_graph/pydantic_ai_graph/state.py @@ -1,13 +1,18 @@ from __future__ import annotations as _annotations +import copy from abc import ABC, abstractmethod -from dataclasses import dataclass -from datetime import datetime -from typing import Union +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Generic, Literal, Union -from typing_extensions import TypeVar +from typing_extensions import Never, TypeVar -__all__ = 'AbstractState', 'StateT', 'Snapshot' +__all__ = 'AbstractState', 'StateT', 'Step', 'EndEvent', 'StepOrEnd' + +if TYPE_CHECKING: + from pydantic_ai_graph import BaseNode + from pydantic_ai_graph.nodes import End class AbstractState(ABC): @@ -19,33 +24,40 @@ def serialize(self) -> bytes | None: raise NotImplementedError +RunEndT = TypeVar('RunEndT', default=None) +NodeRunEndT = TypeVar('NodeRunEndT', covariant=True, default=Never) StateT = TypeVar('StateT', bound=Union[None, AbstractState], default=None) @dataclass -class Snapshot: - """Snapshot of a graph.""" - - last_node_id: str - next_node_id: str - start_ts: datetime - duration: float - state: bytes | None = None - - @classmethod - def from_state( - cls, last_node_id: str, next_node_id: str, start_ts: datetime, duration: float, state: StateT - ) -> Snapshot: - return cls( - last_node_id=last_node_id, - next_node_id=next_node_id, - start_ts=start_ts, - duration=duration, - state=state.serialize() if state is not None else None, - ) - - def summary(self) -> str: - s = f'{self.last_node_id} -> {self.next_node_id}' - if self.duration > 1e-5: - s += f' ({self.duration:.6f}s)' - return s +class Step(Generic[StateT, RunEndT]): + """History item describing the execution of a step of a graph.""" + + state: StateT + node: BaseNode[StateT, RunEndT] + start_ts: datetime = field(default_factory=lambda: datetime.now(tz=timezone.utc)) + duration: float | None = None + + kind: Literal['start_step'] = 'start_step' + + def __post_init__(self): + # Copy the state to prevent it from being modified by other code + self.state = copy.deepcopy(self.state) + + +@dataclass +class EndEvent(Generic[StateT, RunEndT]): + """History item describing the end of a graph run.""" + + state: StateT + result: End[RunEndT] + ts: datetime = field(default_factory=lambda: datetime.now(tz=timezone.utc)) + + kind: Literal['end'] = 'end' + + def __post_init__(self): + # Copy the state to prevent it from being modified by other code + self.state = copy.deepcopy(self.state) + + +StepOrEnd = Union[Step[StateT, RunEndT], EndEvent[StateT, RunEndT]] diff --git a/tests/test_graph.py b/tests/test_graph.py index 81b54fa070..4a06603a7e 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,12 +1,13 @@ from __future__ import annotations as _annotations +from dataclasses import dataclass from datetime import timezone from typing import Union import pytest from inline_snapshot import snapshot -from pydantic_ai_graph import BaseNode, End, Graph, GraphContext, Snapshot +from pydantic_ai_graph import BaseNode, End, EndEvent, Graph, GraphContext, Step from .conftest import IsFloat, IsNow @@ 
-14,85 +15,101 @@ async def test_graph(): - class Float2String(BaseNode[None, float]): + @dataclass + class Float2String(BaseNode): + input_data: float + async def run(self, ctx: GraphContext) -> String2Length: return String2Length(str(self.input_data)) - class String2Length(BaseNode[None, str]): + @dataclass + class String2Length(BaseNode): + input_data: str + async def run(self, ctx: GraphContext) -> Double: return Double(len(self.input_data)) - class Double(BaseNode[None, int, int]): + @dataclass + class Double(BaseNode[None, int]): + input_data: int + async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noqa: UP007 if self.input_data == 7: return String2Length('x' * 21) else: return End(self.input_data * 2) - g = Graph[None, float, int]( - Float2String, - String2Length, - Double, - ) - result, history = await g.run(3.14) + g = Graph[None, int](nodes=(Float2String, String2Length, Double)) + runner = g.get_runner(Float2String) + result, history = await runner.run(None, 3.14) # len('3.14') * 2 == 8 assert result == 8 assert history == snapshot( [ - Snapshot( - last_node_id='Float2String', - next_node_id='String2Length', + Step( + state=None, + node=Float2String(input_data=3.14), start_ts=IsNow(tz=timezone.utc), - duration=IsFloat(gt=0, lt=1e-3), + duration=IsFloat(), ), - Snapshot( - last_node_id='String2Length', - next_node_id='Double', + Step( + state=None, + node=String2Length(input_data='3.14'), start_ts=IsNow(tz=timezone.utc), - duration=IsFloat(gt=0, lt=1e-3), + duration=IsFloat(), ), - Snapshot( - last_node_id='Double', - next_node_id='END', + Step( + state=None, + node=Double(input_data=4), start_ts=IsNow(tz=timezone.utc), - duration=IsFloat(gt=0, lt=1e-3), + duration=IsFloat(), + ), + EndEvent( + state=None, + result=End(data=8), + ts=IsNow(tz=timezone.utc), ), ] ) - result, history = await g.run(3.14159) + result, history = await runner.run(None, 3.14159) # len('3.14159') == 7, 21 * 2 == 42 assert result == 42 assert history == snapshot( [ - Snapshot( - last_node_id='Float2String', - next_node_id='String2Length', + Step( + state=None, + node=Float2String(input_data=3.14159), start_ts=IsNow(tz=timezone.utc), - duration=IsFloat(gt=0, lt=1e-3), + duration=IsFloat(), ), - Snapshot( - last_node_id='String2Length', - next_node_id='Double', + Step( + state=None, + node=String2Length(input_data='3.14159'), start_ts=IsNow(tz=timezone.utc), - duration=IsFloat(gt=0, lt=1e-3), + duration=IsFloat(), ), - Snapshot( - last_node_id='Double', - next_node_id='String2Length', + Step( + state=None, + node=Double(input_data=7), start_ts=IsNow(tz=timezone.utc), - duration=IsFloat(gt=0, lt=1e-3), + duration=IsFloat(), ), - Snapshot( - last_node_id='String2Length', - next_node_id='Double', + Step( + state=None, + node=String2Length(input_data='xxxxxxxxxxxxxxxxxxxxx'), start_ts=IsNow(tz=timezone.utc), - duration=IsFloat(gt=0, lt=1e-3), + duration=IsFloat(), ), - Snapshot( - last_node_id='Double', - next_node_id='END', + Step( + state=None, + node=Double(input_data=21), start_ts=IsNow(tz=timezone.utc), - duration=IsFloat(gt=0, lt=1e-3), + duration=IsFloat(), + ), + EndEvent( + state=None, + result=End(data=42), + ts=IsNow(tz=timezone.utc), ), ] ) diff --git a/tests/typed_agent.py b/tests/typed_agent.py index 040485af92..fdf9f1a255 100644 --- a/tests/typed_agent.py +++ b/tests/typed_agent.py @@ -3,7 +3,9 @@ from collections.abc import Awaitable, Iterator from contextlib import contextmanager from dataclasses import dataclass -from typing import Callable, TypeAlias, Union, 
assert_type +from typing import Callable, TypeAlias, Union + +from typing_extensions import assert_type from pydantic_ai import Agent, ModelRetry, RunContext, Tool from pydantic_ai.result import RunResult diff --git a/tests/typed_graph.py b/tests/typed_graph.py index 806347f8fb..cad0e7df16 100644 --- a/tests/typed_graph.py +++ b/tests/typed_graph.py @@ -1,17 +1,24 @@ from __future__ import annotations as _annotations from dataclasses import dataclass -from typing import assert_type + +from typing_extensions import assert_type from pydantic_ai_graph import BaseNode, End, Graph, GraphContext -class Float2String(BaseNode[None, float]): +@dataclass +class Float2String(BaseNode): + input_data: float + async def run(self, ctx: GraphContext) -> String2Length: return String2Length(str(self.input_data)) -class String2Length(BaseNode[None, str]): +@dataclass +class String2Length(BaseNode): + input_data: str + async def run(self, ctx: GraphContext) -> Double: return Double(len(self.input_data)) @@ -21,7 +28,10 @@ class X: v: int -class Double(BaseNode[None, int, X]): +@dataclass +class Double(BaseNode[None, X]): + input_data: int + async def run(self, ctx: GraphContext) -> String2Length | End[X]: if self.input_data == 7: return String2Length('x' * 21) @@ -29,7 +39,7 @@ async def run(self, ctx: GraphContext) -> String2Length | End[X]: return End(X(self.input_data * 2)) -def use_double(node: BaseNode[None, int, X]) -> None: +def use_double(node: BaseNode[None, X]) -> None: """Shoe that `Double` is valid as a `BaseNode[None, int, X]`.""" print(node) @@ -37,24 +47,28 @@ def use_double(node: BaseNode[None, int, X]) -> None: use_double(Double(1)) -g1 = Graph[None, float, X]( - Float2String, - String2Length, - Double, +g1 = Graph[None, X]( + nodes=( + Float2String, + String2Length, + Double, + ) ) -assert_type(g1, Graph[None, float, X]) +assert_type(g1, Graph[None, X]) -g2 = Graph(Double) -assert_type(g2, Graph[None, int, X]) +g2 = Graph(nodes=(Double,)) +assert_type(g2, Graph[None, X]) g3 = Graph( - Float2String, - String2Length, - Double, + nodes=( + Float2String, + String2Length, + Double, + ) ) # because String2Length came before Double, the output type is Any -assert_type(g3, Graph[None, float]) +assert_type(g3, Graph[None, X]) -Graph[None, float, bytes](Float2String, String2Length, Double) # type: ignore[arg-type] -Graph[None, int, str](Double) # type: ignore[arg-type] +Graph[None, bytes](Float2String, String2Length, Double) # type: ignore[arg-type] +Graph[None, str](Double) # type: ignore[arg-type] From d0bdb87f85e0b99bda61d1a19462da73afffca40 Mon Sep 17 00:00:00 2001 From: Israel Ekpo <44282278+izzyacademy@users.noreply.github.com> Date: Fri, 3 Jan 2025 07:03:07 -0500 Subject: [PATCH 11/57] Typo in Graph Documentation (#596) --- pydantic_ai_graph/README.md | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/pydantic_ai_graph/README.md b/pydantic_ai_graph/README.md index ec0b531a25..90e7f787ff 100644 --- a/pydantic_ai_graph/README.md +++ b/pydantic_ai_graph/README.md @@ -11,6 +11,11 @@ Graph and state machine library. This library is developed as part of the [PydanticAI](https://ai.pydantic.dev), however it has no dependency on `pydantic-ai` or related packages and can be considered as a pure graph library. -As with PydanticAI, this library priorities type safety and use of common Python syntax over esoteric, domain-specific use of Python syntax. 
+As with PydanticAI, this library prioritizes type safety and use of common Python syntax over esoteric, domain-specific use of Python syntax. -`pydantic-ai-graph` allows you to define graphs using simple Python syntax. In particular, edges are defined using the return type hint of nodes. +`pydantic-ai-graph` allows you to define graphs using simple Python syntax. In particular, edges are defined using the return type hint of nodes. + +When designing your graph and state machine, you need to identify the data types for the overall graph input, the final graph output, the graph dependency object and graph state. Then for each specific node in the graph, you have to identify the specific data type each node is expected to receive as the input type from the prior node in the graph during transitions. + +Once the nodes in the graph are defined, you can use certain built-in methods on the Graph object to visualize the nodes +and state transitions on the graph as mermaid diagrams. From 16ccd030a5057f61fa8cdde039b9b0f837671ec3 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 7 Jan 2025 14:20:31 +0000 Subject: [PATCH 12/57] fix linting --- pydantic_ai_graph/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pydantic_ai_graph/README.md b/pydantic_ai_graph/README.md index 90e7f787ff..3f3a5bd72a 100644 --- a/pydantic_ai_graph/README.md +++ b/pydantic_ai_graph/README.md @@ -13,9 +13,9 @@ on `pydantic-ai` or related packages and can be considered as a pure graph libra As with PydanticAI, this library prioritizes type safety and use of common Python syntax over esoteric, domain-specific use of Python syntax. -`pydantic-ai-graph` allows you to define graphs using simple Python syntax. In particular, edges are defined using the return type hint of nodes. +`pydantic-ai-graph` allows you to define graphs using simple Python syntax. In particular, edges are defined using the return type hint of nodes. When designing your graph and state machine, you need to identify the data types for the overall graph input, the final graph output, the graph dependency object and graph state. Then for each specific node in the graph, you have to identify the specific data type each node is expected to receive as the input type from the prior node in the graph during transitions. -Once the nodes in the graph are defined, you can use certain built-in methods on the Graph object to visualize the nodes +Once the nodes in the graph are defined, you can use certain built-in methods on the Graph object to visualize the nodes and state transitions on the graph as mermaid diagrams. 
From bda6dfbbce5b74dd3e76dca33f261a911de521fa Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 7 Jan 2025 17:48:07 +0000 Subject: [PATCH 13/57] separate mermaid logic --- pydantic_ai_graph/pydantic_ai_graph/graph.py | 176 ++--------------- .../pydantic_ai_graph/mermaid.py | 186 ++++++++++++++++++ 2 files changed, 207 insertions(+), 155 deletions(-) create mode 100644 pydantic_ai_graph/pydantic_ai_graph/mermaid.py diff --git a/pydantic_ai_graph/pydantic_ai_graph/graph.py b/pydantic_ai_graph/pydantic_ai_graph/graph.py index a6e09e887a..ae3c926b49 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/graph.py +++ b/pydantic_ai_graph/pydantic_ai_graph/graph.py @@ -1,22 +1,22 @@ from __future__ import annotations as _annotations -import base64 import inspect +from collections.abc import Sequence from dataclasses import dataclass, field from pathlib import Path from time import perf_counter -from typing import TYPE_CHECKING, Annotated, Any, Generic +from typing import TYPE_CHECKING, Annotated, Generic import logfire_api from annotated_types import Ge, Le -from typing_extensions import Literal, Never, ParamSpec, Protocol, TypeVar, assert_never +from typing_extensions import Never, ParamSpec, Protocol, TypeVar, assert_never -from . import _utils +from . import _utils, mermaid from ._utils import get_parent_namespace from .nodes import BaseNode, End, GraphContext, NodeDef from .state import EndEvent, StateT, Step, StepOrEnd -__all__ = ('Graph', 'GraphRun', 'GraphRunner') +__all__ = 'Graph', 'GraphRun', 'GraphRunner' _logfire = logfire_api.Logfire(otel_scope='pydantic-ai-graph') @@ -27,6 +27,7 @@ class StartNodeProtocol(Protocol[RunSignatureT, StateT, NodeRunEndT]): def get_id(self) -> str: ... + def __call__(self, *args: RunSignatureT.args, **kwargs: RunSignatureT.kwargs) -> BaseNode[StateT, NodeRunEndT]: ... @@ -40,9 +41,10 @@ class Graph(Generic[StateT, RunEndT]): def __init__( self, - name: str | None = None, - nodes: tuple[type[BaseNode[StateT, RunEndT]], ...] 
= (), + *, + nodes: Sequence[type[BaseNode[StateT, RunEndT]]], state_type: type[StateT] | None = None, + name: str | None = None, ): self.name = name @@ -94,6 +96,7 @@ def get_runner( self, first_node: StartNodeProtocol[RunSignatureT, StateT, RunEndT], ) -> GraphRunner[RunSignatureT, StateT, RunEndT]: + # noinspection PyTypeChecker return GraphRunner( graph=self, first_node=first_node, @@ -125,25 +128,23 @@ async def run( return result, history def mermaid_code(self) -> str: - return mermaid_code(self.graph, self.first_node) + return mermaid.generate_code(self.graph, self.first_node.get_id()) def mermaid_image( self, - image_type: Literal['jpeg', 'png', 'webp', 'svg', 'pdf'] | str | None = None, + image_type: mermaid.ImageType | None = None, pdf_fit: bool = False, pdf_landscape: bool = False, - pdf_paper: Literal['letter', 'legal', 'tabloid', 'ledger', 'a0', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6'] - | str - | None = None, + pdf_paper: mermaid.PdfPaper | None = None, bg_color: str | None = None, - theme: Literal['default', 'neutral', 'dark', 'forest'] | str | None = None, + theme: mermaid.Theme | None = None, width: int | None = None, height: int | None = None, scale: Annotated[float, Ge(1), Le(3)] | None = None, ) -> bytes: - return mermaid_image( + return mermaid.request_image( self.graph, - self.first_node, + start_node_ids=self.first_node.get_id(), image_type=image_type, pdf_fit=pdf_fit, pdf_landscape=pdf_landscape, @@ -158,22 +159,20 @@ def mermaid_image( def mermaid_save( self, path: Path | str, - image_type: Literal['jpeg', 'png', 'webp', 'svg', 'pdf'] | str | None = None, + image_type: mermaid.ImageType | None = None, pdf_fit: bool = False, pdf_landscape: bool = False, - pdf_paper: Literal['letter', 'legal', 'tabloid', 'ledger', 'a0', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6'] - | str - | None = None, + pdf_paper: mermaid.PdfPaper | None = None, bg_color: str | None = None, - theme: Literal['default', 'neutral', 'dark', 'forest'] | str | None = None, + theme: mermaid.Theme | None = None, width: int | None = None, height: int | None = None, scale: Annotated[float, Ge(1), Le(3)] | None = None, ) -> None: - mermaid_save( + mermaid.save_image( path, self.graph, - self.first_node, + self.first_node.get_id(), image_type=image_type, pdf_fit=pdf_fit, pdf_landscape=pdf_landscape, @@ -225,136 +224,3 @@ async def step(self, node: BaseNode[StateT, RunEndT]) -> BaseNode[StateT, RunEnd next_node = await node.run(ctx) history_step.duration = perf_counter() - start return next_node - - -def mermaid_code( - graph: Graph[Any, Any], start: StartNodeProtocol[..., Any, Any] | tuple[StartNodeProtocol[..., Any, Any], ...] 
= () -) -> str: - if not isinstance(start, tuple): - start = (start,) - - for node in start: - if node not in graph.nodes: - raise ValueError(f'Start node "{node}" is not in the graph.') - - node_order = {node_id: index for index, node_id in enumerate(graph.node_defs)} - - lines = ['graph TD'] - for node in graph.nodes: - node_id = node.get_id() - node_def = graph.node_defs[node_id] - if node in start: - lines.append(f' START --> {node_id}') - if node_def.returns_base_node: - for next_node_id in graph.nodes: - lines.append(f' {node_id} --> {next_node_id}') - else: - for _, next_node_id in sorted((node_order[node_id], node_id) for node_id in node_def.next_node_ids): - lines.append(f' {node_id} --> {next_node_id}') - if node_def.returns_end: - lines.append(f' {node_id} --> END') - - return '\n'.join(lines) - - -def mermaid_image( - graph: Graph[Any, Any], - start: StartNodeProtocol[..., Any, Any] | tuple[StartNodeProtocol[..., Any, Any], ...] = (), - image_type: Literal['jpeg', 'png', 'webp', 'svg', 'pdf'] | str | None = None, - pdf_fit: bool = False, - pdf_landscape: bool = False, - pdf_paper: Literal['letter', 'legal', 'tabloid', 'ledger', 'a0', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6'] - | str - | None = None, - bg_color: str | None = None, - theme: Literal['default', 'neutral', 'dark', 'forest'] | str | None = None, - width: int | None = None, - height: int | None = None, - scale: Annotated[float, Ge(1), Le(3)] | None = None, -) -> bytes: - """Generate an image of a Mermaid diagram using mermaid.ink. - - Args: - graph: The graph to generate the image for. - start: The start node(s) of the graph. - image_type: The image type to generate. If unspecified, the default behavior is `'jpeg'`. - pdf_fit: When using image_type='pdf', whether to fit the diagram to the PDF page. - pdf_landscape: When using image_type='pdf', whether to use landscape orientation for the PDF. - This has no effect if using `pdf_fit`. - pdf_paper: When using image_type='pdf', the paper size of the PDF. - bg_color: The background color of the diagram. If None, the default transparent background is used. - The color value is interpreted as a hexadecimal color code by default (and should not have a leading '#'), - but you can also use named colors by prefixing the value with '!'. - For example, valid choices include `bg_color='!white'` or `bg_color='FF0000'`. - theme: The theme of the diagram. Defaults to 'default'. - width: The width of the diagram. - height: The height of the diagram. - scale: The scale of the diagram. The scale must be a number between 1 and 3, and you can only set - a scale if one or both of width and height are set. 
- """ - import httpx - - code_base64 = base64.b64encode(mermaid_code(graph, start).encode()).decode() - - params: dict[str, str] = {} - if image_type == 'pdf': - url = f'https://mermaid.ink/pdf/{code_base64}' - if pdf_fit: - params['fit'] = '' - if pdf_landscape: - params['landscape'] = '' - if pdf_paper: - params['paper'] = pdf_paper - else: - url = f'https://mermaid.ink/img/{code_base64}' - - if image_type: - params['type'] = image_type - - if bg_color: - params['bgColor'] = bg_color - if theme: - params['theme'] = theme - if width: - params['width'] = str(width) - if height: - params['height'] = str(height) - if scale: - params['scale'] = str(scale) - - response = httpx.get(url, params=params) - response.raise_for_status() - return response.content - - -def mermaid_save( - path: Path | str, - graph: Graph[Any, Any], - start: StartNodeProtocol[..., Any, Any] | tuple[StartNodeProtocol[..., Any, Any], ...] = (), - image_type: Literal['jpeg', 'png', 'webp', 'svg', 'pdf'] | str | None = None, - pdf_fit: bool = False, - pdf_landscape: bool = False, - pdf_paper: Literal['letter', 'legal', 'tabloid', 'ledger', 'a0', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6'] - | str - | None = None, - bg_color: str | None = None, - theme: Literal['default', 'neutral', 'dark', 'forest'] | str | None = None, - width: int | None = None, - height: int | None = None, - scale: Annotated[float, Ge(1), Le(3)] | None = None, -) -> None: - # TODO: do something with the path file extension, e.g. error if it's incompatible, or use it to specify a param - image_data = mermaid_image( - graph, - start, - image_type=image_type, - pdf_fit=pdf_fit, - pdf_landscape=pdf_landscape, - pdf_paper=pdf_paper, - bg_color=bg_color, - theme=theme, - width=width, - height=height, - scale=scale, - ) - Path(path).write_bytes(image_data) diff --git a/pydantic_ai_graph/pydantic_ai_graph/mermaid.py b/pydantic_ai_graph/pydantic_ai_graph/mermaid.py new file mode 100644 index 0000000000..431d26d642 --- /dev/null +++ b/pydantic_ai_graph/pydantic_ai_graph/mermaid.py @@ -0,0 +1,186 @@ +from __future__ import annotations as _annotations + +import base64 +from collections.abc import Sequence +from pathlib import Path +from typing import TYPE_CHECKING, Annotated, Any, Literal, cast + +from annotated_types import Ge, Le + +if TYPE_CHECKING: + from .graph import Graph + + +def generate_code(graph: Graph[Any, Any], start_node_ids: Sequence[str] | str) -> str: + """Generate Mermaid code for a graph. + + Args: + graph: The graph to generate the image for. + start_node_ids: IDs of start nodes of the graph. + + Returns: The Mermaid code for the graph. 
+ """ + if isinstance(start_node_ids, str): + start_node_ids = (start_node_ids,) + + for node_id in start_node_ids: + if node_id not in graph.node_defs: + raise LookupError(f'Start node "{node_id}" is not in the graph.') + + node_order = {node_id: index for index, node_id in enumerate(graph.node_defs)} + + lines = ['graph TD'] + for node in graph.nodes: + node_id = node.get_id() + node_def = graph.node_defs[node_id] + if node_id in start_node_ids: + lines.append(f' START --> {node_id}') + if node_def.returns_base_node: + for next_node_id in graph.nodes: + lines.append(f' {node_id} --> {next_node_id}') + else: + for _, next_node_id in sorted((node_order[node_id], node_id) for node_id in node_def.next_node_ids): + lines.append(f' {node_id} --> {next_node_id}') + if node_def.returns_end: + lines.append(f' {node_id} --> END') + + return '\n'.join(lines) + + +ImageType = Literal['jpeg', 'png', 'webp', 'svg', 'pdf'] +PdfPaper = Literal['letter', 'legal', 'tabloid', 'ledger', 'a0', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6'] +Theme = Literal['default', 'neutral', 'dark', 'forest'] + + +def request_image( + graph: Graph[Any, Any], + start_node_ids: Sequence[str] | str, + *, + image_type: ImageType | str | None = None, + pdf_fit: bool = False, + pdf_landscape: bool = False, + pdf_paper: PdfPaper | None = None, + bg_color: str | None = None, + theme: Theme | None = None, + width: int | None = None, + height: int | None = None, + scale: Annotated[float, Ge(1), Le(3)] | None = None, +) -> bytes: + """Generate an image of a Mermaid diagram using [mermaid.ink](https://mermaid.ink). + + Args: + graph: The graph to generate the image for. + start_node_ids: IDs of start nodes of the graph. + image_type: The image type to generate. If unspecified, the default behavior is `'jpeg'`. + pdf_fit: When using image_type='pdf', whether to fit the diagram to the PDF page. + pdf_landscape: When using image_type='pdf', whether to use landscape orientation for the PDF. + This has no effect if using `pdf_fit`. + pdf_paper: When using image_type='pdf', the paper size of the PDF. + bg_color: The background color of the diagram. If None, the default transparent background is used. + The color value is interpreted as a hexadecimal color code by default (and should not have a leading '#'), + but you can also use named colors by prefixing the value with `'!'`. + For example, valid choices include `bg_color='!white'` or `bg_color='FF0000'`. + theme: The theme of the diagram. Defaults to 'default'. + width: The width of the diagram. + height: The height of the diagram. + scale: The scale of the diagram. The scale must be a number between 1 and 3, and you can only set + a scale if one or both of width and height are set. + + Returns: The image data. 
+ """ + import httpx + + code_base64 = base64.b64encode(generate_code(graph, start_node_ids).encode()).decode() + + params: dict[str, str] = {} + if image_type == 'pdf': + url = f'https://mermaid.ink/pdf/{code_base64}' + if pdf_fit: + params['fit'] = '' + if pdf_landscape: + params['landscape'] = '' + if pdf_paper: + params['paper'] = pdf_paper + elif image_type == 'svg': + url = f'https://mermaid.ink/svg/{code_base64}' + else: + url = f'https://mermaid.ink/img/{code_base64}' + + if image_type: + params['type'] = image_type + + if bg_color: + params['bgColor'] = bg_color + if theme: + params['theme'] = theme + if width: + params['width'] = str(width) + if height: + params['height'] = str(height) + if scale: + params['scale'] = str(scale) + + response = httpx.get(url, params=params) + response.raise_for_status() + return response.content + + +def save_image( + path: Path | str, + graph: Graph[Any, Any], + start_node_ids: Sequence[str] | str, + *, + image_type: ImageType | None = None, + pdf_fit: bool = False, + pdf_landscape: bool = False, + pdf_paper: PdfPaper | None = None, + bg_color: str | None = None, + theme: Theme | None = None, + width: int | None = None, + height: int | None = None, + scale: Annotated[float, Ge(1), Le(3)] | None = None, +) -> None: + """Generate an image of a Mermaid diagram using [mermaid.ink](https://mermaid.ink) and save it to a local file. + + Args: + path: The path to save the image to. + graph: The graph to generate the image for. + start_node_ids: IDs of start nodes of the graph. + image_type: The image type to generate. If unspecified, the default behavior is `'jpeg'`. + pdf_fit: When using image_type='pdf', whether to fit the diagram to the PDF page. + pdf_landscape: When using image_type='pdf', whether to use landscape orientation for the PDF. + This has no effect if using `pdf_fit`. + pdf_paper: When using image_type='pdf', the paper size of the PDF. + bg_color: The background color of the diagram. If None, the default transparent background is used. + The color value is interpreted as a hexadecimal color code by default (and should not have a leading '#'), + but you can also use named colors by prefixing the value with `'!'`. + For example, valid choices include `bg_color='!white'` or `bg_color='FF0000'`. + theme: The theme of the diagram. Defaults to 'default'. + width: The width of the diagram. + height: The height of the diagram. + scale: The scale of the diagram. The scale must be a number between 1 and 3, and you can only set + a scale if one or both of width and height are set. 
+ """ + if isinstance(path, str): + path = Path(path) + + if image_type is None: + ext = path.suffix.lower() + # no need to check for .jpeg/.jpg, as it is the default + if ext in {'.png', '.webp', '.svg', '.pdf'}: + image_type = cast(ImageType, ext[1:]) + + image_data = request_image( + graph, + start_node_ids, + image_type=image_type, + pdf_fit=pdf_fit, + pdf_landscape=pdf_landscape, + pdf_paper=pdf_paper, + bg_color=bg_color, + theme=theme, + width=width, + height=height, + scale=scale, + ) + Path(path).write_bytes(image_data) From cce71e166bb34ffdf979083f396e8b36b2210d69 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 7 Jan 2025 21:09:21 +0000 Subject: [PATCH 14/57] fix graph type checking --- tests/typed_graph.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/typed_graph.py b/tests/typed_graph.py index cad0e7df16..c131f19b4c 100644 --- a/tests/typed_graph.py +++ b/tests/typed_graph.py @@ -70,5 +70,5 @@ def use_double(node: BaseNode[None, X]) -> None: # because String2Length came before Double, the output type is Any assert_type(g3, Graph[None, X]) -Graph[None, bytes](Float2String, String2Length, Double) # type: ignore[arg-type] -Graph[None, str](Double) # type: ignore[arg-type] +Graph[None, bytes](nodes=(Float2String, String2Length, Double)) # type: ignore[arg-type] +Graph[None, str](nodes=[Double]) # type: ignore[list-item] From 2d9f9f396f5f56c05bf54f6eea8f00a192193f3f Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 7 Jan 2025 21:16:35 +0000 Subject: [PATCH 15/57] bump From 7e98bf7be3871236b8965b662b7059f79ab87e96 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 7 Jan 2025 22:40:28 +0000 Subject: [PATCH 16/57] adding node highlighting to mermaid, testing locally --- pydantic_ai_graph/pydantic_ai_graph/_utils.py | 5 ++ pydantic_ai_graph/pydantic_ai_graph/graph.py | 24 ++++++++-- .../pydantic_ai_graph/mermaid.py | 46 +++++++++++++++---- pydantic_ai_graph/pydantic_ai_graph/state.py | 16 +++++-- 4 files changed, 75 insertions(+), 16 deletions(-) diff --git a/pydantic_ai_graph/pydantic_ai_graph/_utils.py b/pydantic_ai_graph/pydantic_ai_graph/_utils.py index d1b4671800..a451fc5533 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/_utils.py +++ b/pydantic_ai_graph/pydantic_ai_graph/_utils.py @@ -2,6 +2,7 @@ import sys import types +from datetime import datetime, timezone from typing import Any, Union, get_args, get_origin from typing_extensions import TypeAliasType @@ -63,3 +64,7 @@ def get_parent_namespace(frame: types.FrameType | None) -> dict[str, Any] | None return get_parent_namespace(back) else: return back.f_locals + + +def now_utc() -> datetime: + return datetime.now(tz=timezone.utc) diff --git a/pydantic_ai_graph/pydantic_ai_graph/graph.py b/pydantic_ai_graph/pydantic_ai_graph/graph.py index ae3c926b49..98437c4941 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/graph.py +++ b/pydantic_ai_graph/pydantic_ai_graph/graph.py @@ -127,11 +127,20 @@ async def run( history = run.history return result, history - def mermaid_code(self) -> str: - return mermaid.generate_code(self.graph, self.first_node.get_id()) + def mermaid_code( + self, + highlighted_nodes: Sequence[mermaid.NodeIdent] | None = None, + highlight_css: str = mermaid.DEFAULT_HIGHLIGHT_CSS, + ) -> str: + return mermaid.generate_code( + self.graph, {self.first_node.get_id()}, highlighted_nodes=highlighted_nodes, highlight_css=highlight_css + ) def mermaid_image( self, + *, + highlighted_nodes: Sequence[mermaid.NodeIdent] | None = None, + highlight_css: str = 
mermaid.DEFAULT_HIGHLIGHT_CSS, image_type: mermaid.ImageType | None = None, pdf_fit: bool = False, pdf_landscape: bool = False, @@ -144,7 +153,9 @@ def mermaid_image( ) -> bytes: return mermaid.request_image( self.graph, - start_node_ids=self.first_node.get_id(), + {self.first_node.get_id()}, + highlighted_nodes=highlighted_nodes, + highlight_css=highlight_css, image_type=image_type, pdf_fit=pdf_fit, pdf_landscape=pdf_landscape, @@ -159,6 +170,9 @@ def mermaid_image( def mermaid_save( self, path: Path | str, + *, + highlighted_nodes: Sequence[mermaid.NodeIdent] | None = None, + highlight_css: str = mermaid.DEFAULT_HIGHLIGHT_CSS, image_type: mermaid.ImageType | None = None, pdf_fit: bool = False, pdf_landscape: bool = False, @@ -172,7 +186,9 @@ def mermaid_save( mermaid.save_image( path, self.graph, - self.first_node.get_id(), + {self.first_node.get_id()}, + highlighted_nodes=highlighted_nodes, + highlight_css=highlight_css, image_type=image_type, pdf_fit=pdf_fit, pdf_landscape=pdf_landscape, diff --git a/pydantic_ai_graph/pydantic_ai_graph/mermaid.py b/pydantic_ai_graph/pydantic_ai_graph/mermaid.py index 431d26d642..f7d0e42611 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/mermaid.py +++ b/pydantic_ai_graph/pydantic_ai_graph/mermaid.py @@ -9,20 +9,30 @@ if TYPE_CHECKING: from .graph import Graph + from .nodes import BaseNode -def generate_code(graph: Graph[Any, Any], start_node_ids: Sequence[str] | str) -> str: +NodeIdent = type[BaseNode[Any, Any]] | str +DEFAULT_HIGHLIGHT_CSS = 'fill:#f9f' + + +def generate_code( + graph: Graph[Any, Any], + start_node_ids: set[str], + *, + highlighted_nodes: Sequence[NodeIdent] | None = None, + highlight_css: str = DEFAULT_HIGHLIGHT_CSS, +) -> str: """Generate Mermaid code for a graph. Args: graph: The graph to generate the image for. - start_node_ids: IDs of start nodes of the graph. + start_node_ids: Identifiers of nodes that start the graph. + highlighted_nodes: Identifiers of nodes to highlight. + highlight_css: CSS to use for highlighting nodes. Returns: The Mermaid code for the graph. """ - if isinstance(start_node_ids, str): - start_node_ids = (start_node_ids,) - for node_id in start_node_ids: if node_id not in graph.node_defs: raise LookupError(f'Start node "{node_id}" is not in the graph.') @@ -44,6 +54,15 @@ def generate_code(graph: Graph[Any, Any], start_node_ids: Sequence[str] | str) - if node_def.returns_end: lines.append(f' {node_id} --> END') + if highlighted_nodes: + lines.append('') + lines.append(f'classDef highlighted {highlight_css}') + for node in highlighted_nodes: + node_id = node if isinstance(node, str) else node.get_id() + if node_id not in graph.node_defs: + raise LookupError(f'Highlighted node "{node_id}" is not in the graph.') + lines.append(f'class {node_id} highlighted') + return '\n'.join(lines) @@ -54,8 +73,10 @@ def generate_code(graph: Graph[Any, Any], start_node_ids: Sequence[str] | str) - def request_image( graph: Graph[Any, Any], - start_node_ids: Sequence[str] | str, + start_node_ids: set[str], *, + highlighted_nodes: Sequence[NodeIdent] | None = None, + highlight_css: str = 'fill:#f9f,stroke:#333,stroke-width:4px', image_type: ImageType | str | None = None, pdf_fit: bool = False, pdf_landscape: bool = False, @@ -71,6 +92,8 @@ def request_image( Args: graph: The graph to generate the image for. start_node_ids: IDs of start nodes of the graph. + highlighted_nodes: Identifiers of nodes to highlight. + highlight_css: CSS to use for highlighting nodes. image_type: The image type to generate. 
If unspecified, the default behavior is `'jpeg'`. pdf_fit: When using image_type='pdf', whether to fit the diagram to the PDF page. pdf_landscape: When using image_type='pdf', whether to use landscape orientation for the PDF. @@ -90,7 +113,8 @@ def request_image( """ import httpx - code_base64 = base64.b64encode(generate_code(graph, start_node_ids).encode()).decode() + code = generate_code(graph, start_node_ids, highlighted_nodes=highlighted_nodes, highlight_css=highlight_css) + code_base64 = base64.b64encode(code.encode()).decode() params: dict[str, str] = {} if image_type == 'pdf': @@ -128,8 +152,10 @@ def request_image( def save_image( path: Path | str, graph: Graph[Any, Any], - start_node_ids: Sequence[str] | str, + start_node_ids: set[str], *, + highlighted_nodes: Sequence[NodeIdent] | None = None, + highlight_css: str = 'fill:#f9f,stroke:#333,stroke-width:4px', image_type: ImageType | None = None, pdf_fit: bool = False, pdf_landscape: bool = False, @@ -146,6 +172,8 @@ def save_image( path: The path to save the image to. graph: The graph to generate the image for. start_node_ids: IDs of start nodes of the graph. + highlighted_nodes: Identifiers of nodes to highlight. + highlight_css: CSS to use for highlighting nodes. image_type: The image type to generate. If unspecified, the default behavior is `'jpeg'`. pdf_fit: When using image_type='pdf', whether to fit the diagram to the PDF page. pdf_landscape: When using image_type='pdf', whether to use landscape orientation for the PDF. @@ -173,6 +201,8 @@ def save_image( image_data = request_image( graph, start_node_ids, + highlighted_nodes=highlighted_nodes, + highlight_css=highlight_css, image_type=image_type, pdf_fit=pdf_fit, pdf_landscape=pdf_landscape, diff --git a/pydantic_ai_graph/pydantic_ai_graph/state.py b/pydantic_ai_graph/pydantic_ai_graph/state.py index 45b01aedfd..66c527d882 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/state.py +++ b/pydantic_ai_graph/pydantic_ai_graph/state.py @@ -3,11 +3,13 @@ import copy from abc import ABC, abstractmethod from dataclasses import dataclass, field -from datetime import datetime, timezone +from datetime import datetime from typing import TYPE_CHECKING, Generic, Literal, Union from typing_extensions import Never, TypeVar +from . 
import _utils + __all__ = 'AbstractState', 'StateT', 'Step', 'EndEvent', 'StepOrEnd' if TYPE_CHECKING: @@ -35,15 +37,18 @@ class Step(Generic[StateT, RunEndT]): state: StateT node: BaseNode[StateT, RunEndT] - start_ts: datetime = field(default_factory=lambda: datetime.now(tz=timezone.utc)) + start_ts: datetime = field(default_factory=_utils.now_utc) duration: float | None = None - kind: Literal['start_step'] = 'start_step' + kind: Literal['step'] = 'step' def __post_init__(self): # Copy the state to prevent it from being modified by other code self.state = copy.deepcopy(self.state) + def node_summary(self) -> str: + return str(self.node) + @dataclass class EndEvent(Generic[StateT, RunEndT]): @@ -51,7 +56,7 @@ class EndEvent(Generic[StateT, RunEndT]): state: StateT result: End[RunEndT] - ts: datetime = field(default_factory=lambda: datetime.now(tz=timezone.utc)) + ts: datetime = field(default_factory=_utils.now_utc) kind: Literal['end'] = 'end' @@ -59,5 +64,8 @@ def __post_init__(self): # Copy the state to prevent it from being modified by other code self.state = copy.deepcopy(self.state) + def node_summary(self) -> str: + return str(self.result) + StepOrEnd = Union[Step[StateT, RunEndT], EndEvent[StateT, RunEndT]] From 743fa5a9d13d0287ab7705da03770f4ab7a7fedb Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 7 Jan 2025 23:21:22 +0000 Subject: [PATCH 17/57] bump From f6aa9297e0dfeff93b15ccc591afaa529d978c6e Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 7 Jan 2025 23:23:33 +0000 Subject: [PATCH 18/57] fix type checking imports --- pydantic_ai_graph/pydantic_ai_graph/mermaid.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pydantic_ai_graph/pydantic_ai_graph/mermaid.py b/pydantic_ai_graph/pydantic_ai_graph/mermaid.py index f7d0e42611..cf83f60de1 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/mermaid.py +++ b/pydantic_ai_graph/pydantic_ai_graph/mermaid.py @@ -7,9 +7,10 @@ from annotated_types import Ge, Le +from .nodes import BaseNode + if TYPE_CHECKING: from .graph import Graph - from .nodes import BaseNode NodeIdent = type[BaseNode[Any, Any]] | str From 246755df4c0155a87af1b2e4b19bd5246e5bc164 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 7 Jan 2025 23:27:13 +0000 Subject: [PATCH 19/57] fix for python 3.9 --- pydantic_ai_graph/pydantic_ai_graph/mermaid.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pydantic_ai_graph/pydantic_ai_graph/mermaid.py b/pydantic_ai_graph/pydantic_ai_graph/mermaid.py index cf83f60de1..11fe9c6a59 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/mermaid.py +++ b/pydantic_ai_graph/pydantic_ai_graph/mermaid.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Annotated, Any, Literal, cast from annotated_types import Ge, Le +from typing_extensions import TypeAlias from .nodes import BaseNode @@ -13,7 +14,7 @@ from .graph import Graph -NodeIdent = type[BaseNode[Any, Any]] | str +NodeIdent: TypeAlias = 'type[BaseNode[Any, Any]] | str' DEFAULT_HIGHLIGHT_CSS = 'fill:#f9f' From d985db409fb606910f2b4d9469c98362b42de915 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 8 Jan 2025 18:07:47 +0000 Subject: [PATCH 20/57] simplify mermaid config --- pydantic_ai_graph/pydantic_ai_graph/graph.py | 76 +------ .../pydantic_ai_graph/mermaid.py | 209 +++++++++--------- 2 files changed, 109 insertions(+), 176 deletions(-) diff --git a/pydantic_ai_graph/pydantic_ai_graph/graph.py b/pydantic_ai_graph/pydantic_ai_graph/graph.py index 98437c4941..ce39168d2c 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/graph.py 
+++ b/pydantic_ai_graph/pydantic_ai_graph/graph.py @@ -5,11 +5,10 @@ from dataclasses import dataclass, field from pathlib import Path from time import perf_counter -from typing import TYPE_CHECKING, Annotated, Generic +from typing import TYPE_CHECKING, Generic import logfire_api -from annotated_types import Ge, Le -from typing_extensions import Never, ParamSpec, Protocol, TypeVar, assert_never +from typing_extensions import Never, ParamSpec, Protocol, TypeVar, Unpack, assert_never from . import _utils, mermaid from ._utils import get_parent_namespace @@ -129,76 +128,19 @@ async def run( def mermaid_code( self, - highlighted_nodes: Sequence[mermaid.NodeIdent] | None = None, + *, + highlighted_nodes: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None, highlight_css: str = mermaid.DEFAULT_HIGHLIGHT_CSS, ) -> str: return mermaid.generate_code( - self.graph, {self.first_node.get_id()}, highlighted_nodes=highlighted_nodes, highlight_css=highlight_css + self.graph, self.first_node.get_id(), highlighted_nodes=highlighted_nodes, highlight_css=highlight_css ) - def mermaid_image( - self, - *, - highlighted_nodes: Sequence[mermaid.NodeIdent] | None = None, - highlight_css: str = mermaid.DEFAULT_HIGHLIGHT_CSS, - image_type: mermaid.ImageType | None = None, - pdf_fit: bool = False, - pdf_landscape: bool = False, - pdf_paper: mermaid.PdfPaper | None = None, - bg_color: str | None = None, - theme: mermaid.Theme | None = None, - width: int | None = None, - height: int | None = None, - scale: Annotated[float, Ge(1), Le(3)] | None = None, - ) -> bytes: - return mermaid.request_image( - self.graph, - {self.first_node.get_id()}, - highlighted_nodes=highlighted_nodes, - highlight_css=highlight_css, - image_type=image_type, - pdf_fit=pdf_fit, - pdf_landscape=pdf_landscape, - pdf_paper=pdf_paper, - bg_color=bg_color, - theme=theme, - width=width, - height=height, - scale=scale, - ) + def mermaid_image(self, **kwargs: Unpack[mermaid.MermaidConfig]) -> bytes: + return mermaid.request_image(self.graph, self.first_node.get_id(), **kwargs) - def mermaid_save( - self, - path: Path | str, - *, - highlighted_nodes: Sequence[mermaid.NodeIdent] | None = None, - highlight_css: str = mermaid.DEFAULT_HIGHLIGHT_CSS, - image_type: mermaid.ImageType | None = None, - pdf_fit: bool = False, - pdf_landscape: bool = False, - pdf_paper: mermaid.PdfPaper | None = None, - bg_color: str | None = None, - theme: mermaid.Theme | None = None, - width: int | None = None, - height: int | None = None, - scale: Annotated[float, Ge(1), Le(3)] | None = None, - ) -> None: - mermaid.save_image( - path, - self.graph, - {self.first_node.get_id()}, - highlighted_nodes=highlighted_nodes, - highlight_css=highlight_css, - image_type=image_type, - pdf_fit=pdf_fit, - pdf_landscape=pdf_landscape, - pdf_paper=pdf_paper, - bg_color=bg_color, - theme=theme, - width=width, - height=height, - scale=scale, - ) + def mermaid_save(self, path: Path | str, /, **kwargs: Unpack[mermaid.MermaidConfig]) -> None: + mermaid.save_image(path, self.graph, self.first_node.get_id(), **kwargs) @dataclass diff --git a/pydantic_ai_graph/pydantic_ai_graph/mermaid.py b/pydantic_ai_graph/pydantic_ai_graph/mermaid.py index 11fe9c6a59..00d2d3a9f2 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/mermaid.py +++ b/pydantic_ai_graph/pydantic_ai_graph/mermaid.py @@ -1,12 +1,12 @@ from __future__ import annotations as _annotations import base64 -from collections.abc import Sequence +from collections.abc import Iterable, Sequence from pathlib import Path -from typing import 
TYPE_CHECKING, Annotated, Any, Literal, cast +from typing import TYPE_CHECKING, Annotated, Any, Literal from annotated_types import Ge, Le -from typing_extensions import TypeAlias +from typing_extensions import TypeAlias, TypedDict, Unpack from .nodes import BaseNode @@ -14,27 +14,29 @@ from .graph import Graph -NodeIdent: TypeAlias = 'type[BaseNode[Any, Any]] | str' -DEFAULT_HIGHLIGHT_CSS = 'fill:#f9f' +NodeIdent: TypeAlias = 'type[BaseNode[Any, Any]] | BaseNode[Any, Any] | str' +DEFAULT_HIGHLIGHT_CSS = 'fill:#fdff32' def generate_code( graph: Graph[Any, Any], - start_node_ids: set[str], + start_nodes: Sequence[NodeIdent] | NodeIdent, + /, *, - highlighted_nodes: Sequence[NodeIdent] | None = None, + highlighted_nodes: Sequence[NodeIdent] | NodeIdent | None = None, highlight_css: str = DEFAULT_HIGHLIGHT_CSS, ) -> str: """Generate Mermaid code for a graph. Args: graph: The graph to generate the image for. - start_node_ids: Identifiers of nodes that start the graph. + start_nodes: Identifiers of nodes that start the graph. highlighted_nodes: Identifiers of nodes to highlight. highlight_css: CSS to use for highlighting nodes. Returns: The Mermaid code for the graph. """ + start_node_ids = set(node_ids(start_nodes)) for node_id in start_node_ids: if node_id not in graph.node_defs: raise LookupError(f'Start node "{node_id}" is not in the graph.') @@ -59,8 +61,7 @@ def generate_code( if highlighted_nodes: lines.append('') lines.append(f'classDef highlighted {highlight_css}') - for node in highlighted_nodes: - node_id = node if isinstance(node, str) else node.get_id() + for node_id in node_ids(highlighted_nodes): if node_id not in graph.node_defs: raise LookupError(f'Highlighted node "{node_id}" is not in the graph.') lines.append(f'class {node_id} highlighted') @@ -68,82 +69,111 @@ def generate_code( return '\n'.join(lines) -ImageType = Literal['jpeg', 'png', 'webp', 'svg', 'pdf'] -PdfPaper = Literal['letter', 'legal', 'tabloid', 'ledger', 'a0', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6'] -Theme = Literal['default', 'neutral', 'dark', 'forest'] +def node_ids(node_idents: Sequence[NodeIdent] | NodeIdent) -> Iterable[str]: + """Get the node IDs from a sequence of node identifiers.""" + if isinstance(node_idents, str): + node_iter = (node_idents,) + elif isinstance(node_idents, Sequence): + node_iter = node_idents + else: + node_iter = (node_idents,) + + for node in node_iter: + if isinstance(node, str): + yield node + else: + yield node.get_id() + + +class MermaidConfig(TypedDict, total=False): + """Parameters to configure mermaid chart generation.""" + + highlighted_nodes: Sequence[NodeIdent] | NodeIdent + """Identifiers of nodes to highlight.""" + highlight_css: str + """CSS to use for highlighting nodes.""" + image_type: Literal['jpeg', 'png', 'webp', 'svg', 'pdf'] + """The image type to generate. If unspecified, the default behavior is `'jpeg'`.""" + pdf_fit: bool + """When using image_type='pdf', whether to fit the diagram to the PDF page.""" + pdf_landscape: bool + """When using image_type='pdf', whether to use landscape orientation for the PDF. + + This has no effect if using `pdf_fit`. + """ + pdf_paper: Literal['letter', 'legal', 'tabloid', 'ledger', 'a0', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6'] + """When using image_type='pdf', the paper size of the PDF.""" + background_color: str + """The background color of the diagram. + + If None, the default transparent background is used. 
The color value is interpreted as a hexadecimal color + code by default (and should not have a leading '#'), but you can also use named colors by prefixing the + value with `'!'`. For example, valid choices include `background_color='!white'` or `background_color='FF0000'`. + """ + theme: Literal['default', 'neutral', 'dark', 'forest'] + """The theme of the diagram. Defaults to 'default'.""" + width: int + """The width of the diagram.""" + height: int + """The height of the diagram.""" + scale: Annotated[float, Ge(1), Le(3)] + """The scale of the diagram. + + The scale must be a number between 1 and 3, and you can only set a scale if one or both of width and height are set. + """ def request_image( graph: Graph[Any, Any], - start_node_ids: set[str], - *, - highlighted_nodes: Sequence[NodeIdent] | None = None, - highlight_css: str = 'fill:#f9f,stroke:#333,stroke-width:4px', - image_type: ImageType | str | None = None, - pdf_fit: bool = False, - pdf_landscape: bool = False, - pdf_paper: PdfPaper | None = None, - bg_color: str | None = None, - theme: Theme | None = None, - width: int | None = None, - height: int | None = None, - scale: Annotated[float, Ge(1), Le(3)] | None = None, + start_nodes: Sequence[NodeIdent] | NodeIdent, + /, + **kwargs: Unpack[MermaidConfig], ) -> bytes: """Generate an image of a Mermaid diagram using [mermaid.ink](https://mermaid.ink). Args: graph: The graph to generate the image for. - start_node_ids: IDs of start nodes of the graph. - highlighted_nodes: Identifiers of nodes to highlight. - highlight_css: CSS to use for highlighting nodes. - image_type: The image type to generate. If unspecified, the default behavior is `'jpeg'`. - pdf_fit: When using image_type='pdf', whether to fit the diagram to the PDF page. - pdf_landscape: When using image_type='pdf', whether to use landscape orientation for the PDF. - This has no effect if using `pdf_fit`. - pdf_paper: When using image_type='pdf', the paper size of the PDF. - bg_color: The background color of the diagram. If None, the default transparent background is used. - The color value is interpreted as a hexadecimal color code by default (and should not have a leading '#'), - but you can also use named colors by prefixing the value with `'!'`. - For example, valid choices include `bg_color='!white'` or `bg_color='FF0000'`. - theme: The theme of the diagram. Defaults to 'default'. - width: The width of the diagram. - height: The height of the diagram. - scale: The scale of the diagram. The scale must be a number between 1 and 3, and you can only set - a scale if one or both of width and height are set. + start_nodes: Identifiers of nodes that start the graph. + **kwargs: Additional parameters to configure mermaid chart generation. Returns: The image data. 
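With the options gathered into the `MermaidConfig` TypedDict and forwarded via `Unpack`, callers now pass them as ordinary keyword arguments. A minimal sketch, assuming a graph `my_graph` and a node class `Extract` defined elsewhere:

svg_bytes = request_image(
    my_graph,
    Extract,                    # start node: a class, an instance, or an id string
    highlighted_nodes=Extract,
    image_type='svg',
    background_color='!white',  # named colour, hence the '!' prefix
    theme='forest',
)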
""" import httpx - code = generate_code(graph, start_node_ids, highlighted_nodes=highlighted_nodes, highlight_css=highlight_css) + code = generate_code( + graph, + start_nodes, + highlighted_nodes=kwargs.get('highlighted_nodes'), + highlight_css=kwargs.get('highlight_css', DEFAULT_HIGHLIGHT_CSS), + ) code_base64 = base64.b64encode(code.encode()).decode() - params: dict[str, str] = {} - if image_type == 'pdf': + params: dict[str, str | bool] = {} + if kwargs.get('image_type') == 'pdf': url = f'https://mermaid.ink/pdf/{code_base64}' - if pdf_fit: - params['fit'] = '' - if pdf_landscape: - params['landscape'] = '' - if pdf_paper: + if kwargs.get('pdf_fit'): + params['fit'] = True + if kwargs.get('pdf_landscape'): + params['landscape'] = True + if pdf_paper := kwargs.get('pdf_paper'): params['paper'] = pdf_paper - elif image_type == 'svg': + elif kwargs.get('image_type') == 'svg': url = f'https://mermaid.ink/svg/{code_base64}' else: url = f'https://mermaid.ink/img/{code_base64}' - if image_type: + if image_type := kwargs.get('image_type'): params['type'] = image_type - if bg_color: - params['bgColor'] = bg_color - if theme: + if background_color := kwargs.get('background_color'): + params['bgColor'] = background_color + if theme := kwargs.get('theme'): params['theme'] = theme - if width: + if width := kwargs.get('width'): params['width'] = str(width) - if height: + if height := kwargs.get('height'): params['height'] = str(height) - if scale: + if scale := kwargs.get('scale'): params['scale'] = str(scale) response = httpx.get(url, params=params) @@ -154,65 +184,26 @@ def request_image( def save_image( path: Path | str, graph: Graph[Any, Any], - start_node_ids: set[str], - *, - highlighted_nodes: Sequence[NodeIdent] | None = None, - highlight_css: str = 'fill:#f9f,stroke:#333,stroke-width:4px', - image_type: ImageType | None = None, - pdf_fit: bool = False, - pdf_landscape: bool = False, - pdf_paper: PdfPaper | None = None, - bg_color: str | None = None, - theme: Theme | None = None, - width: int | None = None, - height: int | None = None, - scale: Annotated[float, Ge(1), Le(3)] | None = None, + start_nodes: Sequence[NodeIdent] | NodeIdent, + /, + **kwargs: Unpack[MermaidConfig], ) -> None: """Generate an image of a Mermaid diagram using [mermaid.ink](https://mermaid.ink) and save it to a local file. Args: path: The path to save the image to. graph: The graph to generate the image for. - start_node_ids: IDs of start nodes of the graph. - highlighted_nodes: Identifiers of nodes to highlight. - highlight_css: CSS to use for highlighting nodes. - image_type: The image type to generate. If unspecified, the default behavior is `'jpeg'`. - pdf_fit: When using image_type='pdf', whether to fit the diagram to the PDF page. - pdf_landscape: When using image_type='pdf', whether to use landscape orientation for the PDF. - This has no effect if using `pdf_fit`. - pdf_paper: When using image_type='pdf', the paper size of the PDF. - bg_color: The background color of the diagram. If None, the default transparent background is used. - The color value is interpreted as a hexadecimal color code by default (and should not have a leading '#'), - but you can also use named colors by prefixing the value with `'!'`. - For example, valid choices include `bg_color='!white'` or `bg_color='FF0000'`. - theme: The theme of the diagram. Defaults to 'default'. - width: The width of the diagram. - height: The height of the diagram. - scale: The scale of the diagram. 
The scale must be a number between 1 and 3, and you can only set - a scale if one or both of width and height are set. + start_nodes: Identifiers of nodes that start the graph. + **kwargs: Additional parameters to configure mermaid chart generation. """ if isinstance(path, str): path = Path(path) - if image_type is None: - ext = path.suffix.lower() + if 'image_type' not in kwargs: + ext = path.suffix.lower()[1:] # no need to check for .jpeg/.jpg, as it is the default - if ext in {'.png', '.webp', '.svg', '.pdf'}: - image_type = cast(ImageType, ext[1:]) + if ext in ('png', 'webp', 'svg', 'pdf'): + kwargs['image_type'] = ext - image_data = request_image( - graph, - start_node_ids, - highlighted_nodes=highlighted_nodes, - highlight_css=highlight_css, - image_type=image_type, - pdf_fit=pdf_fit, - pdf_landscape=pdf_landscape, - pdf_paper=pdf_paper, - bg_color=bg_color, - theme=theme, - width=width, - height=height, - scale=scale, - ) - Path(path).write_bytes(image_data) + image_data = request_image(graph, start_nodes, **kwargs) + path.write_bytes(image_data) From c325789101312aa302d9df594458b8792f9f9de8 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 8 Jan 2025 18:14:19 +0000 Subject: [PATCH 21/57] remove GraphRunner --- .../email_extract_graph.py | 5 +- .../pydantic_ai_graph/__init__.py | 3 +- pydantic_ai_graph/pydantic_ai_graph/graph.py | 63 +++++-------------- tests/test_graph.py | 5 +- 4 files changed, 20 insertions(+), 56 deletions(-) diff --git a/examples/pydantic_ai_examples/email_extract_graph.py b/examples/pydantic_ai_examples/email_extract_graph.py index 8465f5bcd9..a726b8afff 100644 --- a/examples/pydantic_ai_examples/email_extract_graph.py +++ b/examples/pydantic_ai_examples/email_extract_graph.py @@ -122,8 +122,7 @@ async def run( InspectEvent, ) ) -graph_runner = graph.get_runner(ExtractEvent) -print(graph_runner.mermaid_code()) +print(graph.mermaid_code(ExtractEvent)) email = """ Hi Samuel, @@ -148,7 +147,7 @@ async def run( async def main(): state = State(email_content=email) - result, history = await graph_runner.run(state, None) + result, history = await graph.run(state, ExtractEvent()) debug(result, history) diff --git a/pydantic_ai_graph/pydantic_ai_graph/__init__.py b/pydantic_ai_graph/pydantic_ai_graph/__init__.py index b7d79a6009..f6b7054845 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/__init__.py +++ b/pydantic_ai_graph/pydantic_ai_graph/__init__.py @@ -1,10 +1,9 @@ -from .graph import Graph, GraphRun, GraphRunner +from .graph import Graph, GraphRun from .nodes import BaseNode, End, GraphContext from .state import AbstractState, EndEvent, Step, StepOrEnd __all__ = ( 'Graph', - 'GraphRunner', 'GraphRun', 'BaseNode', 'End', diff --git a/pydantic_ai_graph/pydantic_ai_graph/graph.py b/pydantic_ai_graph/pydantic_ai_graph/graph.py index ce39168d2c..d1541f29e9 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/graph.py +++ b/pydantic_ai_graph/pydantic_ai_graph/graph.py @@ -8,14 +8,14 @@ from typing import TYPE_CHECKING, Generic import logfire_api -from typing_extensions import Never, ParamSpec, Protocol, TypeVar, Unpack, assert_never +from typing_extensions import Never, ParamSpec, TypeVar, Unpack, assert_never from . 
import _utils, mermaid from ._utils import get_parent_namespace from .nodes import BaseNode, End, GraphContext, NodeDef from .state import EndEvent, StateT, Step, StepOrEnd -__all__ = 'Graph', 'GraphRun', 'GraphRunner' +__all__ = 'Graph', 'GraphRun' _logfire = logfire_api.Logfire(otel_scope='pydantic-ai-graph') @@ -24,12 +24,6 @@ NodeRunEndT = TypeVar('NodeRunEndT', covariant=True, default=Never) -class StartNodeProtocol(Protocol[RunSignatureT, StateT, NodeRunEndT]): - def get_id(self) -> str: ... - - def __call__(self, *args: RunSignatureT.args, **kwargs: RunSignatureT.kwargs) -> BaseNode[StateT, NodeRunEndT]: ... - - @dataclass(init=False) class Graph(Generic[StateT, RunEndT]): """Definition of a graph.""" @@ -91,56 +85,29 @@ async def run( history = run.history return result, history - def get_runner( - self, - first_node: StartNodeProtocol[RunSignatureT, StateT, RunEndT], - ) -> GraphRunner[RunSignatureT, StateT, RunEndT]: - # noinspection PyTypeChecker - return GraphRunner( - graph=self, - first_node=first_node, - ) - - -@dataclass -class GraphRunner(Generic[RunSignatureT, StateT, RunEndT]): - """Runner for a graph. - - This is a separate class from Graph so that you can get a type-safe runner from a graph definition - without needing to manually annotate the paramspec of the start node. - """ - - graph: Graph[StateT, RunEndT] - first_node: StartNodeProtocol[RunSignatureT, StateT, RunEndT] - - def __post_init__(self): - if self.first_node not in self.graph.nodes: - raise ValueError(f'Start node "{self.first_node}" is not in the graph.') - - async def run( - self, state: StateT, /, *args: RunSignatureT.args, **kwargs: RunSignatureT.kwargs - ) -> tuple[RunEndT, list[StepOrEnd[StateT, RunEndT]]]: - run = GraphRun[StateT, RunEndT](state=state) - # TODO: Infer the graph name properly - result = await run.run(self.graph.name or 'graph', self.first_node(*args, **kwargs)) - history = run.history - return result, history - def mermaid_code( self, + start_nodes: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent, *, highlighted_nodes: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None, highlight_css: str = mermaid.DEFAULT_HIGHLIGHT_CSS, ) -> str: return mermaid.generate_code( - self.graph, self.first_node.get_id(), highlighted_nodes=highlighted_nodes, highlight_css=highlight_css + self, start_nodes, highlighted_nodes=highlighted_nodes, highlight_css=highlight_css ) - def mermaid_image(self, **kwargs: Unpack[mermaid.MermaidConfig]) -> bytes: - return mermaid.request_image(self.graph, self.first_node.get_id(), **kwargs) + def mermaid_image( + self, start_nodes: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent, **kwargs: Unpack[mermaid.MermaidConfig] + ) -> bytes: + return mermaid.request_image(self, start_nodes, **kwargs) - def mermaid_save(self, path: Path | str, /, **kwargs: Unpack[mermaid.MermaidConfig]) -> None: - mermaid.save_image(path, self.graph, self.first_node.get_id(), **kwargs) + def mermaid_save( + self, + start_nodes: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent, + path: Path | str, + **kwargs: Unpack[mermaid.MermaidConfig], + ) -> None: + mermaid.save_image(path, self, start_nodes, **kwargs) @dataclass diff --git a/tests/test_graph.py b/tests/test_graph.py index 4a06603a7e..6187b764e6 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -40,8 +40,7 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq return End(self.input_data * 2) g = Graph[None, int](nodes=(Float2String, String2Length, Double)) - runner = g.get_runner(Float2String) 
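This patch drops the `GraphRunner` indirection: the caller instantiates the start node and passes it straight to `Graph.run`, and the mermaid helpers take the start node(s) as an explicit first argument. A sketch of the new calling shape, reusing `g` and `Float2String` from the surrounding test (illustration only, not part of the patch):

import asyncio

async def main() -> None:
    result, history = await g.run(None, Float2String(3.14))
    print(result, len(history))
    print(g.mermaid_code(Float2String))       # start node(s) are now passed explicitly
    g.mermaid_save(Float2String, 'flow.svg')  # start node(s) first, then the output path

asyncio.run(main())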
- result, history = await runner.run(None, 3.14) + result, history = await g.run(None, Float2String(3.14)) # len('3.14') * 2 == 8 assert result == 8 assert history == snapshot( @@ -71,7 +70,7 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq ), ] ) - result, history = await runner.run(None, 3.14159) + result, history = await g.run(None, Float2String(3.14159)) # len('3.14159') == 7, 21 * 2 == 42 assert result == 42 assert history == snapshot( From c41d59a56058856a361b2786f47f5d27bfa6a97a Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Thu, 9 Jan 2025 00:19:38 +0000 Subject: [PATCH 22/57] add Interrupt --- .../email_extract_graph.py | 2 +- .../pydantic_ai_graph/__init__.py | 9 +-- pydantic_ai_graph/pydantic_ai_graph/graph.py | 46 +++++++------- .../pydantic_ai_graph/mermaid.py | 33 ++++++---- pydantic_ai_graph/pydantic_ai_graph/nodes.py | 61 ++++++++++++++----- pydantic_ai_graph/pydantic_ai_graph/state.py | 47 +++++++++++--- tests/test_graph.py | 22 +++---- 7 files changed, 144 insertions(+), 76 deletions(-) diff --git a/examples/pydantic_ai_examples/email_extract_graph.py b/examples/pydantic_ai_examples/email_extract_graph.py index a726b8afff..8cf30803e1 100644 --- a/examples/pydantic_ai_examples/email_extract_graph.py +++ b/examples/pydantic_ai_examples/email_extract_graph.py @@ -122,7 +122,7 @@ async def run( InspectEvent, ) ) -print(graph.mermaid_code(ExtractEvent)) +print(graph.mermaid_code(start_node=ExtractEvent)) email = """ Hi Samuel, diff --git a/pydantic_ai_graph/pydantic_ai_graph/__init__.py b/pydantic_ai_graph/pydantic_ai_graph/__init__.py index f6b7054845..afdad1bdc2 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/__init__.py +++ b/pydantic_ai_graph/pydantic_ai_graph/__init__.py @@ -1,15 +1,16 @@ from .graph import Graph, GraphRun -from .nodes import BaseNode, End, GraphContext -from .state import AbstractState, EndEvent, Step, StepOrEnd +from .nodes import BaseNode, End, GraphContext, Interrupt +from .state import AbstractState, EndEvent, HistoryStep, NextNodeEvent __all__ = ( 'Graph', 'GraphRun', 'BaseNode', 'End', + 'Interrupt', 'GraphContext', 'AbstractState', 'EndEvent', - 'StepOrEnd', - 'Step', + 'HistoryStep', + 'NextNodeEvent', ) diff --git a/pydantic_ai_graph/pydantic_ai_graph/graph.py b/pydantic_ai_graph/pydantic_ai_graph/graph.py index d1541f29e9..22db3d6497 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/graph.py +++ b/pydantic_ai_graph/pydantic_ai_graph/graph.py @@ -12,8 +12,8 @@ from . 
import _utils, mermaid from ._utils import get_parent_namespace -from .nodes import BaseNode, End, GraphContext, NodeDef -from .state import EndEvent, StateT, Step, StepOrEnd +from .nodes import BaseNode, End, GraphContext, Interrupt, NodeDef, RunInterrupt +from .state import EndEvent, HistoryStep, InterruptEvent, NextNodeEvent, StateT __all__ = 'Graph', 'GraphRun' @@ -76,7 +76,7 @@ def _validate_edges(self): async def run( self, state: StateT, node: BaseNode[StateT, RunEndT] - ) -> tuple[RunEndT, list[StepOrEnd[StateT, RunEndT]]]: + ) -> tuple[End[RunEndT] | RunInterrupt[StateT], list[HistoryStep[StateT, RunEndT]]]: if not isinstance(node, self.nodes): raise ValueError(f'Node "{node}" is not in the graph.') run = GraphRun[StateT, RunEndT](state=state) @@ -87,27 +87,20 @@ async def run( def mermaid_code( self, - start_nodes: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent, *, + start_node: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None, highlighted_nodes: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None, highlight_css: str = mermaid.DEFAULT_HIGHLIGHT_CSS, ) -> str: return mermaid.generate_code( - self, start_nodes, highlighted_nodes=highlighted_nodes, highlight_css=highlight_css + self, start_node=start_node, highlighted_nodes=highlighted_nodes, highlight_css=highlight_css ) - def mermaid_image( - self, start_nodes: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent, **kwargs: Unpack[mermaid.MermaidConfig] - ) -> bytes: - return mermaid.request_image(self, start_nodes, **kwargs) + def mermaid_image(self, **kwargs: Unpack[mermaid.MermaidConfig]) -> bytes: + return mermaid.request_image(self, **kwargs) - def mermaid_save( - self, - start_nodes: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent, - path: Path | str, - **kwargs: Unpack[mermaid.MermaidConfig], - ) -> None: - mermaid.save_image(path, self, start_nodes, **kwargs) + def mermaid_save(self, path: Path | str, /, **kwargs: Unpack[mermaid.MermaidConfig]) -> None: + mermaid.save_image(path, self, **kwargs) @dataclass @@ -115,9 +108,11 @@ class GraphRun(Generic[StateT, RunEndT]): """Stateful run of a graph.""" state: StateT - history: list[StepOrEnd[StateT, RunEndT]] = field(default_factory=list) + history: list[HistoryStep[StateT, RunEndT]] = field(default_factory=list) - async def run(self, graph_name: str, start: BaseNode[StateT, RunEndT], infer_name: bool = True) -> RunEndT: + async def run( + self, graph_name: str, start: BaseNode[StateT, RunEndT], infer_name: bool = True + ) -> End[RunEndT] | RunInterrupt[StateT]: current_node = start with _logfire.span( @@ -127,10 +122,13 @@ async def run(self, graph_name: str, start: BaseNode[StateT, RunEndT], infer_nam ) as run_span: while True: next_node = await self.step(current_node) - if isinstance(next_node, End): - self.history.append(EndEvent(self.state, next_node)) + if isinstance(next_node, (Interrupt, End)): + if isinstance(next_node, End): + self.history.append(EndEvent(self.state, next_node)) + else: + self.history.append(InterruptEvent(self.state, next_node)) run_span.set_attribute('history', self.history) - return next_node.data + return next_node elif isinstance(next_node, BaseNode): current_node = next_node else: @@ -139,8 +137,10 @@ async def run(self, graph_name: str, start: BaseNode[StateT, RunEndT], infer_nam else: raise TypeError(f'Invalid node type: {type(next_node)}. 
Expected `BaseNode` or `End`.') - async def step(self, node: BaseNode[StateT, RunEndT]) -> BaseNode[StateT, RunEndT] | End[RunEndT]: - history_step = Step(self.state, node) + async def step( + self, node: BaseNode[StateT, RunEndT] + ) -> BaseNode[StateT, RunEndT] | End[RunEndT] | RunInterrupt[StateT]: + history_step = NextNodeEvent(self.state, node) self.history.append(history_step) ctx = GraphContext(self.state) diff --git a/pydantic_ai_graph/pydantic_ai_graph/mermaid.py b/pydantic_ai_graph/pydantic_ai_graph/mermaid.py index 00d2d3a9f2..109423bcee 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/mermaid.py +++ b/pydantic_ai_graph/pydantic_ai_graph/mermaid.py @@ -2,6 +2,7 @@ import base64 from collections.abc import Iterable, Sequence +from itertools import chain from pathlib import Path from typing import TYPE_CHECKING, Annotated, Any, Literal @@ -20,9 +21,9 @@ def generate_code( graph: Graph[Any, Any], - start_nodes: Sequence[NodeIdent] | NodeIdent, /, *, + start_node: Sequence[NodeIdent] | NodeIdent | None = None, highlighted_nodes: Sequence[NodeIdent] | NodeIdent | None = None, highlight_css: str = DEFAULT_HIGHLIGHT_CSS, ) -> str: @@ -30,33 +31,41 @@ def generate_code( Args: graph: The graph to generate the image for. - start_nodes: Identifiers of nodes that start the graph. + start_node: Identifiers of nodes that start the graph. highlighted_nodes: Identifiers of nodes to highlight. highlight_css: CSS to use for highlighting nodes. Returns: The Mermaid code for the graph. """ - start_node_ids = set(node_ids(start_nodes)) + start_node_ids = set(node_ids(start_node or ())) for node_id in start_node_ids: if node_id not in graph.node_defs: raise LookupError(f'Start node "{node_id}" is not in the graph.') node_order = {node_id: index for index, node_id in enumerate(graph.node_defs)} + after_interrupt_nodes = set(chain(*(node_def.after_interrupt_nodes for node_def in graph.node_defs.values()))) lines = ['graph TD'] for node in graph.nodes: node_id = node.get_id() node_def = graph.node_defs[node_id] + + # we use square brackets (square box) for nodes that can interrupt the flow, + # and round brackets (rounded box) for nodes that cannot interrupt the flow + if node_id in after_interrupt_nodes: + mermaid_name = f'[{node_id}]' + else: + mermaid_name = f'({node_id})' if node_id in start_node_ids: - lines.append(f' START --> {node_id}') + lines.append(f' START --> {node_id}{mermaid_name}') if node_def.returns_base_node: for next_node_id in graph.nodes: - lines.append(f' {node_id} --> {next_node_id}') + lines.append(f' {node_id}{mermaid_name} --> {next_node_id}') else: for _, next_node_id in sorted((node_order[node_id], node_id) for node_id in node_def.next_node_ids): - lines.append(f' {node_id} --> {next_node_id}') + lines.append(f' {node_id}{mermaid_name} --> {next_node_id}') if node_def.returns_end: - lines.append(f' {node_id} --> END') + lines.append(f' {node_id}{mermaid_name} --> END') if highlighted_nodes: lines.append('') @@ -88,6 +97,8 @@ def node_ids(node_idents: Sequence[NodeIdent] | NodeIdent) -> Iterable[str]: class MermaidConfig(TypedDict, total=False): """Parameters to configure mermaid chart generation.""" + start_node: Sequence[NodeIdent] | NodeIdent + """Identifiers of nodes that start the graph.""" highlighted_nodes: Sequence[NodeIdent] | NodeIdent """Identifiers of nodes to highlight.""" highlight_css: str @@ -125,7 +136,6 @@ class MermaidConfig(TypedDict, total=False): def request_image( graph: Graph[Any, Any], - start_nodes: Sequence[NodeIdent] | NodeIdent, /, **kwargs: 
Unpack[MermaidConfig], ) -> bytes: @@ -133,7 +143,6 @@ def request_image( Args: graph: The graph to generate the image for. - start_nodes: Identifiers of nodes that start the graph. **kwargs: Additional parameters to configure mermaid chart generation. Returns: The image data. @@ -142,7 +151,7 @@ def request_image( code = generate_code( graph, - start_nodes, + start_node=kwargs.get('start_node'), highlighted_nodes=kwargs.get('highlighted_nodes'), highlight_css=kwargs.get('highlight_css', DEFAULT_HIGHLIGHT_CSS), ) @@ -184,7 +193,6 @@ def request_image( def save_image( path: Path | str, graph: Graph[Any, Any], - start_nodes: Sequence[NodeIdent] | NodeIdent, /, **kwargs: Unpack[MermaidConfig], ) -> None: @@ -193,7 +201,6 @@ def save_image( Args: path: The path to save the image to. graph: The graph to generate the image for. - start_nodes: Identifiers of nodes that start the graph. **kwargs: Additional parameters to configure mermaid chart generation. """ if isinstance(path, str): @@ -205,5 +212,5 @@ def save_image( if ext in ('png', 'webp', 'svg', 'pdf'): kwargs['image_type'] = ext - image_data = request_image(graph, start_nodes, **kwargs) + image_data = request_image(graph, **kwargs) path.write_bytes(image_data) diff --git a/pydantic_ai_graph/pydantic_ai_graph/nodes.py b/pydantic_ai_graph/pydantic_ai_graph/nodes.py index 1a1b8bcd5d..168339d0a9 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/nodes.py +++ b/pydantic_ai_graph/pydantic_ai_graph/nodes.py @@ -3,14 +3,14 @@ from abc import abstractmethod from dataclasses import dataclass from functools import cache -from typing import Any, Generic, get_origin, get_type_hints +from typing import Any, Generic, get_args, get_origin, get_type_hints from typing_extensions import Never, TypeVar from . import _utils from .state import StateT -__all__ = ('GraphContext', 'End', 'BaseNode', 'NodeDef') +__all__ = 'GraphContext', 'BaseNode', 'End', 'Interrupt', 'RunInterrupt', 'NodeDef' RunEndT = TypeVar('RunEndT', default=None) NodeRunEndT = TypeVar('NodeRunEndT', covariant=True, default=Never) @@ -23,18 +23,13 @@ class GraphContext(Generic[StateT]): state: StateT -@dataclass -class End(Generic[RunEndT]): - """Type to return from a node to signal the end of the graph.""" - - data: RunEndT - - class BaseNode(Generic[StateT, NodeRunEndT]): """Base class for a node.""" @abstractmethod - async def run(self, ctx: GraphContext[StateT]) -> BaseNode[StateT, Any] | End[NodeRunEndT]: ... + async def run( + self, ctx: GraphContext[StateT] + ) -> BaseNode[StateT, Any] | End[NodeRunEndT] | RunInterrupt[StateT]: ... 
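The widened return type above is the heart of this patch: besides another node or `End`, `run` may now return an `Interrupt` wrapping the node to continue from, which stops the run so the caller can persist state and pick the graph up again later. A sketch under those assumptions; the `SendEmail` and `AwaitApproval` nodes are invented for illustration and are not part of the patch.

from dataclasses import dataclass

from pydantic_ai_graph import BaseNode, End, GraphContext, Interrupt


@dataclass
class SendEmail(BaseNode[None, None]):
    draft: str

    async def run(self, ctx: GraphContext[None]) -> End[None]:
        ...  # pretend to send the email here
        return End(None)


@dataclass
class AwaitApproval(BaseNode[None, None]):
    draft: str

    async def run(self, ctx: GraphContext[None]) -> Interrupt[SendEmail]:
        # Returning Interrupt ends this run early; the returned object carries the
        # node to start from, so the caller could resume via graph.run(state, interrupt.node).
        return Interrupt(SendEmail(self.draft))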
@classmethod @cache @@ -44,20 +39,27 @@ def get_id(cls) -> str: @classmethod def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, NodeRunEndT]: type_hints = get_type_hints(cls.run, localns=local_ns) - next_node_ids: set[str] = set() - returns_end: bool = False - returns_base_node: bool = False try: return_hint = type_hints['return'] except KeyError: raise TypeError(f'Node {cls} is missing a return type hint on its `run` method') + next_node_ids: set[str] = set() + returns_end: bool = False + returns_base_node: bool = False + after_interrupt_nodes: list[str] = [] for return_type in _utils.get_union_args(return_hint): return_type_origin = get_origin(return_type) or return_type if return_type_origin is End: returns_end = True + elif issubclass(return_type_origin, Interrupt): + interrupt_args = get_args(return_type) + assert len(interrupt_args) == 1, f'Invalid Interrupt return type: {return_type}' + next_node_id = interrupt_args[0].get_id() + next_node_ids.add(next_node_id) + after_interrupt_nodes.append(next_node_id) elif return_type_origin is BaseNode: - # TODO: Should we disallow this? More generally, what do we do about sub-subclasses? + # TODO: Should we disallow this? returns_base_node = True elif issubclass(return_type_origin, BaseNode): next_node_ids.add(return_type.get_id()) @@ -69,19 +71,48 @@ def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, NodeRu cls.get_id(), next_node_ids, returns_end, + after_interrupt_nodes, returns_base_node, ) +@dataclass +class End(Generic[RunEndT]): + """Type to return from a node to signal the end of the graph.""" + + data: RunEndT + + +InterruptNextNodeT = TypeVar('InterruptNextNodeT', covariant=True, bound=BaseNode[Any, Any]) + + +@dataclass +class Interrupt(Generic[InterruptNextNodeT]): + """Type to return from a node to signal that the run should be interrupted.""" + + node: InterruptNextNodeT + + +RunInterrupt = Interrupt[BaseNode[StateT, Any]] + + @dataclass class NodeDef(Generic[StateT, NodeRunEndT]): """Definition of a node. - Used by [`Graph`][pydantic_ai_graph.graph.Graph] store information about a node. + Used by [`Graph`][pydantic_ai_graph.graph.Graph] to store information about a node, and when generating + mermaid graphs. """ node: type[BaseNode[StateT, NodeRunEndT]] + """The node definition itself.""" node_id: str + """ID of the node.""" next_node_ids: set[str] + """IDs of the nodes that can be called next.""" returns_end: bool + """The node definition returns an `End`, hence the node and end the run.""" + after_interrupt_nodes: list[str] + """Nodes that can be returned within an `Interrupt`.""" returns_base_node: bool + """The node definition returns a `BaseNode`, hence any node in the next can be called next.""" diff --git a/pydantic_ai_graph/pydantic_ai_graph/state.py b/pydantic_ai_graph/pydantic_ai_graph/state.py index 66c527d882..41c81d03eb 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/state.py +++ b/pydantic_ai_graph/pydantic_ai_graph/state.py @@ -4,17 +4,17 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from datetime import datetime -from typing import TYPE_CHECKING, Generic, Literal, Union +from typing import TYPE_CHECKING, Generic, Literal, Self, Union from typing_extensions import Never, TypeVar from . 
import _utils -__all__ = 'AbstractState', 'StateT', 'Step', 'EndEvent', 'StepOrEnd' +__all__ = 'AbstractState', 'StateT', 'NextNodeEvent', 'EndEvent', 'InterruptEvent', 'HistoryStep' if TYPE_CHECKING: from pydantic_ai_graph import BaseNode - from pydantic_ai_graph.nodes import End + from pydantic_ai_graph.nodes import End, RunInterrupt class AbstractState(ABC): @@ -25,6 +25,10 @@ def serialize(self) -> bytes | None: """Serialize the state object.""" raise NotImplementedError + def deep_copy(self) -> Self: + """Create a deep copy of the state object.""" + return copy.deepcopy(self) + RunEndT = TypeVar('RunEndT', default=None) NodeRunEndT = TypeVar('NodeRunEndT', covariant=True, default=Never) @@ -32,8 +36,8 @@ def serialize(self) -> bytes | None: @dataclass -class Step(Generic[StateT, RunEndT]): - """History item describing the execution of a step of a graph.""" +class NextNodeEvent(Generic[StateT, RunEndT]): + """History step describing the execution of a step of a graph.""" state: StateT node: BaseNode[StateT, RunEndT] @@ -44,15 +48,33 @@ class Step(Generic[StateT, RunEndT]): def __post_init__(self): # Copy the state to prevent it from being modified by other code - self.state = copy.deepcopy(self.state) + self.state = _deep_copy_state(self.state) def node_summary(self) -> str: return str(self.node) +@dataclass +class InterruptEvent(Generic[StateT]): + """History step describing the interruption of a graph run.""" + + state: StateT + result: RunInterrupt[StateT] + ts: datetime = field(default_factory=_utils.now_utc) + + kind: Literal['interrupt'] = 'interrupt' + + def __post_init__(self): + # Copy the state to prevent it from being modified by other code + self.state = _deep_copy_state(self.state) + + def node_summary(self) -> str: + return str(self.result) + + @dataclass class EndEvent(Generic[StateT, RunEndT]): - """History item describing the end of a graph run.""" + """History step describing the end of a graph run.""" state: StateT result: End[RunEndT] @@ -62,10 +84,17 @@ class EndEvent(Generic[StateT, RunEndT]): def __post_init__(self): # Copy the state to prevent it from being modified by other code - self.state = copy.deepcopy(self.state) + self.state = _deep_copy_state(self.state) def node_summary(self) -> str: return str(self.result) -StepOrEnd = Union[Step[StateT, RunEndT], EndEvent[StateT, RunEndT]] +def _deep_copy_state(state: StateT) -> StateT: + if state is None: + return None # pyright: ignore[reportReturnType] + else: + return state.deep_copy() + + +HistoryStep = Union[NextNodeEvent[StateT, RunEndT], InterruptEvent[StateT], EndEvent[StateT, RunEndT]] diff --git a/tests/test_graph.py b/tests/test_graph.py index 6187b764e6..cdef4f27d3 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -7,7 +7,7 @@ import pytest from inline_snapshot import snapshot -from pydantic_ai_graph import BaseNode, End, EndEvent, Graph, GraphContext, Step +from pydantic_ai_graph import BaseNode, End, EndEvent, Graph, GraphContext, NextNodeEvent from .conftest import IsFloat, IsNow @@ -42,22 +42,22 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq g = Graph[None, int](nodes=(Float2String, String2Length, Double)) result, history = await g.run(None, Float2String(3.14)) # len('3.14') * 2 == 8 - assert result == 8 + assert result == End(8) assert history == snapshot( [ - Step( + NextNodeEvent( state=None, node=Float2String(input_data=3.14), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - Step( + NextNodeEvent( state=None, 
node=String2Length(input_data='3.14'), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - Step( + NextNodeEvent( state=None, node=Double(input_data=4), start_ts=IsNow(tz=timezone.utc), @@ -72,34 +72,34 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq ) result, history = await g.run(None, Float2String(3.14159)) # len('3.14159') == 7, 21 * 2 == 42 - assert result == 42 + assert result == End(42) assert history == snapshot( [ - Step( + NextNodeEvent( state=None, node=Float2String(input_data=3.14159), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - Step( + NextNodeEvent( state=None, node=String2Length(input_data='3.14159'), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - Step( + NextNodeEvent( state=None, node=Double(input_data=7), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - Step( + NextNodeEvent( state=None, node=String2Length(input_data='xxxxxxxxxxxxxxxxxxxxx'), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - Step( + NextNodeEvent( state=None, node=Double(input_data=21), start_ts=IsNow(tz=timezone.utc), From 0b19632d230690f373ac2d451f6ea6cc2d65ac0f Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Thu, 9 Jan 2025 18:20:39 +0000 Subject: [PATCH 23/57] remove interrupt, replace with "next()" --- .../email_extract_graph.py | 3 +- .../pydantic_ai_graph/__init__.py | 6 +- pydantic_ai_graph/pydantic_ai_graph/graph.py | 108 ++++++++---------- .../pydantic_ai_graph/mermaid.py | 10 +- pydantic_ai_graph/pydantic_ai_graph/nodes.py | 31 +---- pydantic_ai_graph/pydantic_ai_graph/state.py | 28 +---- 6 files changed, 60 insertions(+), 126 deletions(-) diff --git a/examples/pydantic_ai_examples/email_extract_graph.py b/examples/pydantic_ai_examples/email_extract_graph.py index 8cf30803e1..0c16e8bee0 100644 --- a/examples/pydantic_ai_examples/email_extract_graph.py +++ b/examples/pydantic_ai_examples/email_extract_graph.py @@ -147,7 +147,8 @@ async def run( async def main(): state = State(email_content=email) - result, history = await graph.run(state, ExtractEvent()) + history = [] + result = await graph.run(state, ExtractEvent()) debug(result, history) diff --git a/pydantic_ai_graph/pydantic_ai_graph/__init__.py b/pydantic_ai_graph/pydantic_ai_graph/__init__.py index afdad1bdc2..3fbdb587bb 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/__init__.py +++ b/pydantic_ai_graph/pydantic_ai_graph/__init__.py @@ -1,13 +1,11 @@ -from .graph import Graph, GraphRun -from .nodes import BaseNode, End, GraphContext, Interrupt +from .graph import Graph +from .nodes import BaseNode, End, GraphContext from .state import AbstractState, EndEvent, HistoryStep, NextNodeEvent __all__ = ( 'Graph', - 'GraphRun', 'BaseNode', 'End', - 'Interrupt', 'GraphContext', 'AbstractState', 'EndEvent', diff --git a/pydantic_ai_graph/pydantic_ai_graph/graph.py b/pydantic_ai_graph/pydantic_ai_graph/graph.py index 22db3d6497..32e8a7caa8 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/graph.py +++ b/pydantic_ai_graph/pydantic_ai_graph/graph.py @@ -2,20 +2,20 @@ import inspect from collections.abc import Sequence -from dataclasses import dataclass, field +from dataclasses import dataclass from pathlib import Path from time import perf_counter -from typing import TYPE_CHECKING, Generic +from typing import TYPE_CHECKING, Any, Generic import logfire_api from typing_extensions import Never, ParamSpec, TypeVar, Unpack, assert_never from . 
import _utils, mermaid from ._utils import get_parent_namespace -from .nodes import BaseNode, End, GraphContext, Interrupt, NodeDef, RunInterrupt -from .state import EndEvent, HistoryStep, InterruptEvent, NextNodeEvent, StateT +from .nodes import BaseNode, End, GraphContext, NodeDef +from .state import EndEvent, HistoryStep, NextNodeEvent, StateT -__all__ = 'Graph', 'GraphRun' +__all__ = ('Graph',) _logfire = logfire_api.Logfire(otel_scope='pydantic-ai-graph') @@ -74,16 +74,48 @@ def _validate_edges(self): b = '\n'.join(f' {be}' for be in bad_edges_list) raise ValueError(f'Nodes are referenced in the graph but not included in the graph:\n{b}') + async def next( + self, state: StateT, node: BaseNode[StateT, RunEndT], history: list[HistoryStep[StateT, RunEndT]] + ) -> BaseNode[StateT, Any] | End[RunEndT]: + node_id = node.get_id() + if node_id not in self.node_defs: + raise TypeError(f'Node "{node}" is not in the graph.') + + history_step: NextNodeEvent[StateT, RunEndT] | None = NextNodeEvent(state, node) + history.append(history_step) + + ctx = GraphContext(state) + with _logfire.span('run node {node_id}', node_id=node_id, node=node): + start = perf_counter() + next_node = await node.run(ctx) + history_step.duration = perf_counter() - start + return next_node + async def run( - self, state: StateT, node: BaseNode[StateT, RunEndT] - ) -> tuple[End[RunEndT] | RunInterrupt[StateT], list[HistoryStep[StateT, RunEndT]]]: - if not isinstance(node, self.nodes): - raise ValueError(f'Node "{node}" is not in the graph.') - run = GraphRun[StateT, RunEndT](state=state) - # TODO: Infer the graph name properly - result = await run.run(self.name or 'graph', node) - history = run.history - return result, history + self, + state: StateT, + node: BaseNode[StateT, RunEndT], + ) -> tuple[End[RunEndT], list[HistoryStep[StateT, RunEndT]]]: + history: list[HistoryStep[StateT, RunEndT]] = [] + + with _logfire.span( + '{graph_name} run {start=}', + graph_name=self.name or 'graph', + start=node, + ) as run_span: + while True: + next_node = await self.next(state, node, history=history) + if isinstance(next_node, End): + history.append(EndEvent(state, next_node)) + run_span.set_attribute('history', history) + return next_node, history + elif isinstance(next_node, BaseNode): + node = next_node + else: + if TYPE_CHECKING: + assert_never(next_node) + else: + raise TypeError(f'Invalid node type: {type(next_node)}. 
Expected `BaseNode` or `End`.') def mermaid_code( self, @@ -101,51 +133,3 @@ def mermaid_image(self, **kwargs: Unpack[mermaid.MermaidConfig]) -> bytes: def mermaid_save(self, path: Path | str, /, **kwargs: Unpack[mermaid.MermaidConfig]) -> None: mermaid.save_image(path, self, **kwargs) - - -@dataclass -class GraphRun(Generic[StateT, RunEndT]): - """Stateful run of a graph.""" - - state: StateT - history: list[HistoryStep[StateT, RunEndT]] = field(default_factory=list) - - async def run( - self, graph_name: str, start: BaseNode[StateT, RunEndT], infer_name: bool = True - ) -> End[RunEndT] | RunInterrupt[StateT]: - current_node = start - - with _logfire.span( - '{graph_name} run {start=}', - graph_name=graph_name, - start=start, - ) as run_span: - while True: - next_node = await self.step(current_node) - if isinstance(next_node, (Interrupt, End)): - if isinstance(next_node, End): - self.history.append(EndEvent(self.state, next_node)) - else: - self.history.append(InterruptEvent(self.state, next_node)) - run_span.set_attribute('history', self.history) - return next_node - elif isinstance(next_node, BaseNode): - current_node = next_node - else: - if TYPE_CHECKING: - assert_never(next_node) - else: - raise TypeError(f'Invalid node type: {type(next_node)}. Expected `BaseNode` or `End`.') - - async def step( - self, node: BaseNode[StateT, RunEndT] - ) -> BaseNode[StateT, RunEndT] | End[RunEndT] | RunInterrupt[StateT]: - history_step = NextNodeEvent(self.state, node) - self.history.append(history_step) - - ctx = GraphContext(self.state) - with _logfire.span('run node {node_id}', node_id=node.get_id()): - start = perf_counter() - next_node = await node.run(ctx) - history_step.duration = perf_counter() - start - return next_node diff --git a/pydantic_ai_graph/pydantic_ai_graph/mermaid.py b/pydantic_ai_graph/pydantic_ai_graph/mermaid.py index 109423bcee..534e1b5227 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/mermaid.py +++ b/pydantic_ai_graph/pydantic_ai_graph/mermaid.py @@ -2,7 +2,6 @@ import base64 from collections.abc import Iterable, Sequence -from itertools import chain from pathlib import Path from typing import TYPE_CHECKING, Annotated, Any, Literal @@ -43,19 +42,14 @@ def generate_code( raise LookupError(f'Start node "{node_id}" is not in the graph.') node_order = {node_id: index for index, node_id in enumerate(graph.node_defs)} - after_interrupt_nodes = set(chain(*(node_def.after_interrupt_nodes for node_def in graph.node_defs.values()))) lines = ['graph TD'] for node in graph.nodes: node_id = node.get_id() node_def = graph.node_defs[node_id] - # we use square brackets (square box) for nodes that can interrupt the flow, - # and round brackets (rounded box) for nodes that cannot interrupt the flow - if node_id in after_interrupt_nodes: - mermaid_name = f'[{node_id}]' - else: - mermaid_name = f'({node_id})' + # we use round brackets (rounded box) for nodes other than the start and end + mermaid_name = f'({node_id})' if node_id in start_node_ids: lines.append(f' START --> {node_id}{mermaid_name}') if node_def.returns_base_node: diff --git a/pydantic_ai_graph/pydantic_ai_graph/nodes.py b/pydantic_ai_graph/pydantic_ai_graph/nodes.py index 168339d0a9..c5beb7b6d5 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/nodes.py +++ b/pydantic_ai_graph/pydantic_ai_graph/nodes.py @@ -3,14 +3,14 @@ from abc import abstractmethod from dataclasses import dataclass from functools import cache -from typing import Any, Generic, get_args, get_origin, get_type_hints +from typing import Any, Generic, get_origin, 
get_type_hints from typing_extensions import Never, TypeVar from . import _utils from .state import StateT -__all__ = 'GraphContext', 'BaseNode', 'End', 'Interrupt', 'RunInterrupt', 'NodeDef' +__all__ = 'GraphContext', 'BaseNode', 'End', 'NodeDef' RunEndT = TypeVar('RunEndT', default=None) NodeRunEndT = TypeVar('NodeRunEndT', covariant=True, default=Never) @@ -27,9 +27,7 @@ class BaseNode(Generic[StateT, NodeRunEndT]): """Base class for a node.""" @abstractmethod - async def run( - self, ctx: GraphContext[StateT] - ) -> BaseNode[StateT, Any] | End[NodeRunEndT] | RunInterrupt[StateT]: ... + async def run(self, ctx: GraphContext[StateT]) -> BaseNode[StateT, Any] | End[NodeRunEndT]: ... @classmethod @cache @@ -47,17 +45,10 @@ def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, NodeRu next_node_ids: set[str] = set() returns_end: bool = False returns_base_node: bool = False - after_interrupt_nodes: list[str] = [] for return_type in _utils.get_union_args(return_hint): return_type_origin = get_origin(return_type) or return_type if return_type_origin is End: returns_end = True - elif issubclass(return_type_origin, Interrupt): - interrupt_args = get_args(return_type) - assert len(interrupt_args) == 1, f'Invalid Interrupt return type: {return_type}' - next_node_id = interrupt_args[0].get_id() - next_node_ids.add(next_node_id) - after_interrupt_nodes.append(next_node_id) elif return_type_origin is BaseNode: # TODO: Should we disallow this? returns_base_node = True @@ -71,7 +62,6 @@ def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, NodeRu cls.get_id(), next_node_ids, returns_end, - after_interrupt_nodes, returns_base_node, ) @@ -83,19 +73,6 @@ class End(Generic[RunEndT]): data: RunEndT -InterruptNextNodeT = TypeVar('InterruptNextNodeT', covariant=True, bound=BaseNode[Any, Any]) - - -@dataclass -class Interrupt(Generic[InterruptNextNodeT]): - """Type to return from a node to signal that the run should be interrupted.""" - - node: InterruptNextNodeT - - -RunInterrupt = Interrupt[BaseNode[StateT, Any]] - - @dataclass class NodeDef(Generic[StateT, NodeRunEndT]): """Definition of a node. @@ -112,7 +89,5 @@ class NodeDef(Generic[StateT, NodeRunEndT]): """IDs of the nodes that can be called next.""" returns_end: bool """The node definition returns an `End`, hence the node and end the run.""" - after_interrupt_nodes: list[str] - """Nodes that can be returned within an `Interrupt`.""" returns_base_node: bool """The node definition returns a `BaseNode`, hence any node in the next can be called next.""" diff --git a/pydantic_ai_graph/pydantic_ai_graph/state.py b/pydantic_ai_graph/pydantic_ai_graph/state.py index 41c81d03eb..ad8858c4a4 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/state.py +++ b/pydantic_ai_graph/pydantic_ai_graph/state.py @@ -10,11 +10,11 @@ from . 
import _utils -__all__ = 'AbstractState', 'StateT', 'NextNodeEvent', 'EndEvent', 'InterruptEvent', 'HistoryStep' +__all__ = 'AbstractState', 'StateT', 'NextNodeEvent', 'EndEvent', 'HistoryStep' if TYPE_CHECKING: from pydantic_ai_graph import BaseNode - from pydantic_ai_graph.nodes import End, RunInterrupt + from pydantic_ai_graph.nodes import End class AbstractState(ABC): @@ -50,28 +50,10 @@ def __post_init__(self): # Copy the state to prevent it from being modified by other code self.state = _deep_copy_state(self.state) - def node_summary(self) -> str: + def summary(self) -> str: return str(self.node) -@dataclass -class InterruptEvent(Generic[StateT]): - """History step describing the interruption of a graph run.""" - - state: StateT - result: RunInterrupt[StateT] - ts: datetime = field(default_factory=_utils.now_utc) - - kind: Literal['interrupt'] = 'interrupt' - - def __post_init__(self): - # Copy the state to prevent it from being modified by other code - self.state = _deep_copy_state(self.state) - - def node_summary(self) -> str: - return str(self.result) - - @dataclass class EndEvent(Generic[StateT, RunEndT]): """History step describing the end of a graph run.""" @@ -86,7 +68,7 @@ def __post_init__(self): # Copy the state to prevent it from being modified by other code self.state = _deep_copy_state(self.state) - def node_summary(self) -> str: + def summary(self) -> str: return str(self.result) @@ -97,4 +79,4 @@ def _deep_copy_state(state: StateT) -> StateT: return state.deep_copy() -HistoryStep = Union[NextNodeEvent[StateT, RunEndT], InterruptEvent[StateT], EndEvent[StateT, RunEndT]] +HistoryStep = Union[NextNodeEvent[StateT, RunEndT], EndEvent[StateT, RunEndT]] From c1b80356102a690babe0cac2f4938ebd25517dc9 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Thu, 9 Jan 2025 18:29:23 +0000 Subject: [PATCH 24/57] address comments --- .../pydantic_ai_graph/__init__.py | 4 ++-- pydantic_ai_graph/pydantic_ai_graph/graph.py | 4 ++-- pydantic_ai_graph/pydantic_ai_graph/nodes.py | 2 +- pydantic_ai_graph/pydantic_ai_graph/state.py | 10 +++++----- tests/test_graph.py | 18 +++++++++--------- 5 files changed, 19 insertions(+), 19 deletions(-) diff --git a/pydantic_ai_graph/pydantic_ai_graph/__init__.py b/pydantic_ai_graph/pydantic_ai_graph/__init__.py index 3fbdb587bb..51db45c53c 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/__init__.py +++ b/pydantic_ai_graph/pydantic_ai_graph/__init__.py @@ -1,6 +1,6 @@ from .graph import Graph from .nodes import BaseNode, End, GraphContext -from .state import AbstractState, EndEvent, HistoryStep, NextNodeEvent +from .state import AbstractState, EndEvent, HistoryStep, NodeEvent __all__ = ( 'Graph', @@ -10,5 +10,5 @@ 'AbstractState', 'EndEvent', 'HistoryStep', - 'NextNodeEvent', + 'NodeEvent', ) diff --git a/pydantic_ai_graph/pydantic_ai_graph/graph.py b/pydantic_ai_graph/pydantic_ai_graph/graph.py index 32e8a7caa8..7c008469da 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/graph.py +++ b/pydantic_ai_graph/pydantic_ai_graph/graph.py @@ -13,7 +13,7 @@ from . 
import _utils, mermaid from ._utils import get_parent_namespace from .nodes import BaseNode, End, GraphContext, NodeDef -from .state import EndEvent, HistoryStep, NextNodeEvent, StateT +from .state import EndEvent, HistoryStep, NodeEvent, StateT __all__ = ('Graph',) @@ -81,7 +81,7 @@ async def next( if node_id not in self.node_defs: raise TypeError(f'Node "{node}" is not in the graph.') - history_step: NextNodeEvent[StateT, RunEndT] | None = NextNodeEvent(state, node) + history_step: NodeEvent[StateT, RunEndT] | None = NodeEvent(state, node) history.append(history_step) ctx = GraphContext(state) diff --git a/pydantic_ai_graph/pydantic_ai_graph/nodes.py b/pydantic_ai_graph/pydantic_ai_graph/nodes.py index c5beb7b6d5..76b39aae64 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/nodes.py +++ b/pydantic_ai_graph/pydantic_ai_graph/nodes.py @@ -88,6 +88,6 @@ class NodeDef(Generic[StateT, NodeRunEndT]): next_node_ids: set[str] """IDs of the nodes that can be called next.""" returns_end: bool - """The node definition returns an `End`, hence the node and end the run.""" + """The node definition returns an `End`, hence the node can end the run.""" returns_base_node: bool """The node definition returns a `BaseNode`, hence any node in the next can be called next.""" diff --git a/pydantic_ai_graph/pydantic_ai_graph/state.py b/pydantic_ai_graph/pydantic_ai_graph/state.py index ad8858c4a4..876038f9c4 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/state.py +++ b/pydantic_ai_graph/pydantic_ai_graph/state.py @@ -10,7 +10,7 @@ from . import _utils -__all__ = 'AbstractState', 'StateT', 'NextNodeEvent', 'EndEvent', 'HistoryStep' +__all__ = 'AbstractState', 'StateT', 'NodeEvent', 'EndEvent', 'HistoryStep' if TYPE_CHECKING: from pydantic_ai_graph import BaseNode @@ -36,8 +36,8 @@ def deep_copy(self) -> Self: @dataclass -class NextNodeEvent(Generic[StateT, RunEndT]): - """History step describing the execution of a step of a graph.""" +class NodeEvent(Generic[StateT, RunEndT]): + """History step describing the execution of a node in a graph.""" state: StateT node: BaseNode[StateT, RunEndT] @@ -74,9 +74,9 @@ def summary(self) -> str: def _deep_copy_state(state: StateT) -> StateT: if state is None: - return None # pyright: ignore[reportReturnType] + return state else: return state.deep_copy() -HistoryStep = Union[NextNodeEvent[StateT, RunEndT], EndEvent[StateT, RunEndT]] +HistoryStep = Union[NodeEvent[StateT, RunEndT], EndEvent[StateT, RunEndT]] diff --git a/tests/test_graph.py b/tests/test_graph.py index cdef4f27d3..3d70bec92c 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -7,7 +7,7 @@ import pytest from inline_snapshot import snapshot -from pydantic_ai_graph import BaseNode, End, EndEvent, Graph, GraphContext, NextNodeEvent +from pydantic_ai_graph import BaseNode, End, EndEvent, Graph, GraphContext, NodeEvent from .conftest import IsFloat, IsNow @@ -45,19 +45,19 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq assert result == End(8) assert history == snapshot( [ - NextNodeEvent( + NodeEvent( state=None, node=Float2String(input_data=3.14), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - NextNodeEvent( + NodeEvent( state=None, node=String2Length(input_data='3.14'), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - NextNodeEvent( + NodeEvent( state=None, node=Double(input_data=4), start_ts=IsNow(tz=timezone.utc), @@ -75,31 +75,31 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq assert result == End(42) assert 
history == snapshot( [ - NextNodeEvent( + NodeEvent( state=None, node=Float2String(input_data=3.14159), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - NextNodeEvent( + NodeEvent( state=None, node=String2Length(input_data='3.14159'), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - NextNodeEvent( + NodeEvent( state=None, node=Double(input_data=7), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - NextNodeEvent( + NodeEvent( state=None, node=String2Length(input_data='xxxxxxxxxxxxxxxxxxxxx'), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - NextNodeEvent( + NodeEvent( state=None, node=Double(input_data=21), start_ts=IsNow(tz=timezone.utc), From 7e24d9d06fe485fd3947bf4b2a66656490ccaaf4 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Fri, 10 Jan 2025 12:15:56 +0000 Subject: [PATCH 25/57] switch name to pydantic-graph --- .github/workflows/ci.yml | 2 +- .../email_extract_graph.py | 2 +- mkdocs.yml | 2 +- pydantic_ai_graph/README.md | 21 ---------- pydantic_ai_slim/pyproject.toml | 4 +- pydantic_graph/README.md | 16 ++++++++ .../pydantic_graph}/__init__.py | 0 .../pydantic_graph}/_utils.py | 0 .../pydantic_graph}/graph.py | 2 +- .../pydantic_graph}/mermaid.py | 0 .../pydantic_graph}/nodes.py | 2 +- .../pydantic_graph}/py.typed | 0 .../pydantic_graph}/state.py | 4 +- .../pyproject.toml | 6 +-- pyproject.toml | 10 ++--- tests/test_graph.py | 2 +- tests/typed_graph.py | 2 +- uprev.py | 4 +- uv.lock | 40 +++++++++---------- 19 files changed, 57 insertions(+), 62 deletions(-) delete mode 100644 pydantic_ai_graph/README.md create mode 100644 pydantic_graph/README.md rename {pydantic_ai_graph/pydantic_ai_graph => pydantic_graph/pydantic_graph}/__init__.py (100%) rename {pydantic_ai_graph/pydantic_ai_graph => pydantic_graph/pydantic_graph}/_utils.py (100%) rename {pydantic_ai_graph/pydantic_ai_graph => pydantic_graph/pydantic_graph}/graph.py (98%) rename {pydantic_ai_graph/pydantic_ai_graph => pydantic_graph/pydantic_graph}/mermaid.py (100%) rename {pydantic_ai_graph/pydantic_ai_graph => pydantic_graph/pydantic_graph}/nodes.py (96%) rename {pydantic_ai_graph/pydantic_ai_graph => pydantic_graph/pydantic_graph}/py.typed (100%) rename {pydantic_ai_graph/pydantic_ai_graph => pydantic_graph/pydantic_graph}/state.py (96%) rename {pydantic_ai_graph => pydantic_graph}/pyproject.toml (93%) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9cab6307bd..a3c7cf99a5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -129,7 +129,7 @@ jobs: - run: mkdir coverage - # run tests with just `pydantic-ai-slim` and `pydantic-ai-graph` dependencies + # run tests with just `pydantic-ai-slim` and `pydantic-graph` dependencies - run: uv run --package pydantic-ai-slim --extra graph coverage run -m pytest env: COVERAGE_FILE: coverage/.coverage.${{ runner.os }}-py${{ matrix.python-version }}-slim diff --git a/examples/pydantic_ai_examples/email_extract_graph.py b/examples/pydantic_ai_examples/email_extract_graph.py index 0c16e8bee0..82c98a370e 100644 --- a/examples/pydantic_ai_examples/email_extract_graph.py +++ b/examples/pydantic_ai_examples/email_extract_graph.py @@ -7,7 +7,7 @@ import logfire from devtools import debug from pydantic import BaseModel -from pydantic_ai_graph import AbstractState, BaseNode, End, Graph, GraphContext +from pydantic_graph import AbstractState, BaseNode, End, Graph, GraphContext from pydantic_ai import Agent, RunContext diff --git a/mkdocs.yml b/mkdocs.yml index d163fded70..781b1494e3 100644 --- a/mkdocs.yml +++ 
b/mkdocs.yml @@ -159,7 +159,7 @@ plugins: - mkdocstrings: handlers: python: - paths: [src/packages/pydantic_ai_slim/pydantic_ai] + paths: [pydantic_ai_slim/pydantic_ai] options: relative_crossrefs: true members_order: source diff --git a/pydantic_ai_graph/README.md b/pydantic_ai_graph/README.md deleted file mode 100644 index 3f3a5bd72a..0000000000 --- a/pydantic_ai_graph/README.md +++ /dev/null @@ -1,21 +0,0 @@ -# PydanticAI Graph - -[![CI](https://github.com/pydantic/pydantic-ai/actions/workflows/ci.yml/badge.svg?event=push)](https://github.com/pydantic/pydantic-ai/actions/workflows/ci.yml?query=branch%3Amain) -[![Coverage](https://coverage-badge.samuelcolvin.workers.dev/pydantic/pydantic-ai.svg)](https://coverage-badge.samuelcolvin.workers.dev/redirect/pydantic/pydantic-ai) -[![PyPI](https://img.shields.io/pypi/v/pydantic-ai-graph.svg)](https://pypi.python.org/pypi/pydantic-ai-graph) -[![versions](https://img.shields.io/pypi/pyversions/pydantic-ai-graph.svg)](https://github.com/pydantic/pydantic-ai) -[![license](https://img.shields.io/github/license/pydantic/pydantic-ai-graph.svg?v)](https://github.com/pydantic/pydantic-ai/blob/main/LICENSE) - -Graph and state machine library. - -This library is developed as part of the [PydanticAI](https://ai.pydantic.dev), however it has no dependency -on `pydantic-ai` or related packages and can be considered as a pure graph library. - -As with PydanticAI, this library prioritizes type safety and use of common Python syntax over esoteric, domain-specific use of Python syntax. - -`pydantic-ai-graph` allows you to define graphs using simple Python syntax. In particular, edges are defined using the return type hint of nodes. - -When designing your graph and state machine, you need to identify the data types for the overall graph input, the final graph output, the graph dependency object and graph state. Then for each specific node in the graph, you have to identify the specific data type each node is expected to receive as the input type from the prior node in the graph during transitions. - -Once the nodes in the graph are defined, you can use certain built-in methods on the Graph object to visualize the nodes -and state transitions on the graph as mermaid diagrams. 
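The README paragraph above mentions built-in methods on the `Graph` object for visualising nodes and state transitions as Mermaid diagrams. A rough, hypothetical sketch of that usage follows; the `Foo`/`Bar` node names are borrowed from `tests/test_graph_mermaid.py` added later in this series, and only `mermaid_code` is called because `mermaid_image`/`mermaid_save` make an HTTP request:

```python
from __future__ import annotations

from dataclasses import dataclass

from pydantic_graph import BaseNode, End, Graph, GraphContext


@dataclass
class Foo(BaseNode):
    async def run(self, ctx: GraphContext) -> Bar:
        return Bar()


@dataclass
class Bar(BaseNode[None, None]):
    async def run(self, ctx: GraphContext) -> End[None]:
        return End(None)


graph = Graph(nodes=(Foo, Bar))

# Mermaid source derived from the nodes' return type hints; the exact output
# format evolves over this series (from `graph TD` to `stateDiagram-v2`).
print(graph.mermaid_code(start_node=Foo, highlighted_nodes=Bar))

# graph.mermaid_image(...) and graph.mermaid_save('graph.png') render the same
# diagram to an image, via an HTTP request to an external rendering service.
```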
diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index 693d27aaec..912ad47479 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -42,7 +42,7 @@ dependencies = [ ] [project.optional-dependencies] -graph = ["pydantic-ai-graph==0.0.14"] +graph = ["pydantic-graph==0.0.14"] openai = ["openai>=1.54.3"] vertexai = ["google-auth>=2.36.0", "requests>=2.32.3"] anthropic = ["anthropic>=0.40.0"] @@ -68,4 +68,4 @@ dev = [ packages = ["pydantic_ai"] [tool.uv.sources] -pydantic-ai-graph = { workspace = true } +pydantic-graph = { workspace = true } diff --git a/pydantic_graph/README.md b/pydantic_graph/README.md new file mode 100644 index 0000000000..2789d18ac1 --- /dev/null +++ b/pydantic_graph/README.md @@ -0,0 +1,16 @@ +# Pydantic Graph + +[![CI](https://github.com/pydantic/pydantic-ai/actions/workflows/ci.yml/badge.svg?event=push)](https://github.com/pydantic/pydantic-ai/actions/workflows/ci.yml?query=branch%3Amain) +[![Coverage](https://coverage-badge.samuelcolvin.workers.dev/pydantic/pydantic-ai.svg)](https://coverage-badge.samuelcolvin.workers.dev/redirect/pydantic/pydantic-ai) +[![PyPI](https://img.shields.io/pypi/v/pydantic-graph.svg)](https://pypi.python.org/pypi/pydantic-graph) +[![versions](https://img.shields.io/pypi/pyversions/pydantic-graph.svg)](https://github.com/pydantic/pydantic-ai) +[![license](https://img.shields.io/github/license/pydantic/pydantic-ai.svg?v)](https://github.com/pydantic/pydantic-ai/blob/main/LICENSE) + +Graph and finite state machine library. + +This library is developed as part of the [PydanticAI](https://ai.pydantic.dev), however it has no dependency +on `pydantic-ai` or related packages and can be considered as a pure graph library. + +As with PydanticAI, this library prioritizes type safety and use of common Python syntax over esoteric, domain-specific use of Python syntax. + +`pydantic-graph` allows you to define graphs using simple Python syntax. In particular, edges are defined using the return type hint of nodes. 
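The README stops at stating that edges come from return type hints; the sketch below illustrates the idea using the API defined in these patches (`BaseNode`, `End`, `Graph`, `GraphContext`). The node names and values are made up, and `Graph.run`'s return value changes within the series: before the "adding graph tests" patch it returns `(End(...), history)`, afterwards `(End(...).data, history)` as assumed here.

```python
from __future__ import annotations

import asyncio
from dataclasses import dataclass
from typing import Union

from pydantic_graph import BaseNode, End, Graph, GraphContext


@dataclass
class Increment(BaseNode):
    value: int

    # The return type hint defines the only outgoing edge: Increment --> Check
    async def run(self, ctx: GraphContext) -> Check:
        return Check(self.value + 1)


@dataclass
class Check(BaseNode[None, int]):
    value: int

    # Two outgoing edges: back to Increment, or End[int], which ends the run
    async def run(self, ctx: GraphContext) -> Union[Increment, End[int]]:
        if self.value >= 3:
            return End(self.value)
        return Increment(self.value)


graph = Graph[None, int](nodes=(Increment, Check))


async def main() -> None:
    result, history = await graph.run(None, Increment(0))
    print(result)  # 3 (an End(3) instance with the earlier run() signature)
    print(len(history))  # one NodeEvent per executed node plus a final EndEvent


asyncio.run(main())
```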
diff --git a/pydantic_ai_graph/pydantic_ai_graph/__init__.py b/pydantic_graph/pydantic_graph/__init__.py similarity index 100% rename from pydantic_ai_graph/pydantic_ai_graph/__init__.py rename to pydantic_graph/pydantic_graph/__init__.py diff --git a/pydantic_ai_graph/pydantic_ai_graph/_utils.py b/pydantic_graph/pydantic_graph/_utils.py similarity index 100% rename from pydantic_ai_graph/pydantic_ai_graph/_utils.py rename to pydantic_graph/pydantic_graph/_utils.py diff --git a/pydantic_ai_graph/pydantic_ai_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py similarity index 98% rename from pydantic_ai_graph/pydantic_ai_graph/graph.py rename to pydantic_graph/pydantic_graph/graph.py index 7c008469da..10e1b35e55 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -17,7 +17,7 @@ __all__ = ('Graph',) -_logfire = logfire_api.Logfire(otel_scope='pydantic-ai-graph') +_logfire = logfire_api.Logfire(otel_scope='pydantic-graph') RunSignatureT = ParamSpec('RunSignatureT') RunEndT = TypeVar('RunEndT', default=None) diff --git a/pydantic_ai_graph/pydantic_ai_graph/mermaid.py b/pydantic_graph/pydantic_graph/mermaid.py similarity index 100% rename from pydantic_ai_graph/pydantic_ai_graph/mermaid.py rename to pydantic_graph/pydantic_graph/mermaid.py diff --git a/pydantic_ai_graph/pydantic_ai_graph/nodes.py b/pydantic_graph/pydantic_graph/nodes.py similarity index 96% rename from pydantic_ai_graph/pydantic_ai_graph/nodes.py rename to pydantic_graph/pydantic_graph/nodes.py index 76b39aae64..bbadd8b380 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/nodes.py +++ b/pydantic_graph/pydantic_graph/nodes.py @@ -77,7 +77,7 @@ class End(Generic[RunEndT]): class NodeDef(Generic[StateT, NodeRunEndT]): """Definition of a node. - Used by [`Graph`][pydantic_ai_graph.graph.Graph] to store information about a node, and when generating + Used by [`Graph`][pydantic_graph.graph.Graph] to store information about a node, and when generating mermaid graphs. 
""" diff --git a/pydantic_ai_graph/pydantic_ai_graph/py.typed b/pydantic_graph/pydantic_graph/py.typed similarity index 100% rename from pydantic_ai_graph/pydantic_ai_graph/py.typed rename to pydantic_graph/pydantic_graph/py.typed diff --git a/pydantic_ai_graph/pydantic_ai_graph/state.py b/pydantic_graph/pydantic_graph/state.py similarity index 96% rename from pydantic_ai_graph/pydantic_ai_graph/state.py rename to pydantic_graph/pydantic_graph/state.py index 876038f9c4..9b007836ee 100644 --- a/pydantic_ai_graph/pydantic_ai_graph/state.py +++ b/pydantic_graph/pydantic_graph/state.py @@ -13,8 +13,8 @@ __all__ = 'AbstractState', 'StateT', 'NodeEvent', 'EndEvent', 'HistoryStep' if TYPE_CHECKING: - from pydantic_ai_graph import BaseNode - from pydantic_ai_graph.nodes import End + from pydantic_graph import BaseNode + from pydantic_graph.nodes import End class AbstractState(ABC): diff --git a/pydantic_ai_graph/pyproject.toml b/pydantic_graph/pyproject.toml similarity index 93% rename from pydantic_ai_graph/pyproject.toml rename to pydantic_graph/pyproject.toml index cb047a9a8e..17aaa1fba7 100644 --- a/pydantic_ai_graph/pyproject.toml +++ b/pydantic_graph/pyproject.toml @@ -3,8 +3,8 @@ requires = ["hatchling"] build-backend = "hatchling.build" [project] -name = "pydantic-ai-graph" -version = "0.0.14" +name = "pydantic-graph" +version = "0.0.17" description = "Graph and state machine library" authors = [ { name = "Samuel Colvin", email = "samuel@pydantic.dev" }, @@ -40,4 +40,4 @@ dependencies = [ ] [tool.hatch.build.targets.wheel] -packages = ["pydantic_ai_graph"] +packages = ["pydantic_graph"] diff --git a/pyproject.toml b/pyproject.toml index e1edd3ca85..f26156bbed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,11 +51,11 @@ logfire = ["logfire>=2.3"] [tool.uv.sources] pydantic-ai-slim = { workspace = true } -pydantic-ai-graph = { workspace = true } +pydantic-graph = { workspace = true } pydantic-ai-examples = { workspace = true } [tool.uv.workspace] -members = ["pydantic_ai_slim", "pydantic_ai_graph", "examples"] +members = ["pydantic_ai_slim", "pydantic_graph", "examples"] [dependency-groups] # dev dependencies are defined in `pydantic-ai-slim/pyproject.toml` to allow for minimal testing @@ -86,7 +86,7 @@ line-length = 120 target-version = "py39" include = [ "pydantic_ai_slim/**/*.py", - "pydantic_ai_graph/**/*.py", + "pydantic_graph/**/*.py", "examples/**/*.py", "tests/**/*.py", "docs/**/*.py", @@ -130,7 +130,7 @@ typeCheckingMode = "strict" reportMissingTypeStubs = false reportUnnecessaryIsInstance = false reportUnnecessaryTypeIgnoreComment = true -include = ["pydantic_ai_slim", "pydantic_ai_graph", "tests", "examples"] +include = ["pydantic_ai_slim", "pydantic_graph", "tests", "examples"] venvPath = ".venv" # see https://github.com/microsoft/pyright/issues/7771 - we don't want to error on decorated functions in tests # which are not otherwise used @@ -153,7 +153,7 @@ filterwarnings = [ # https://coverage.readthedocs.io/en/latest/config.html#run [tool.coverage.run] # required to avoid warnings about files created by create_module fixture -include = ["pydantic_ai_slim/**/*.py", "pydantic_ai_graph/**/*.py","tests/**/*.py"] +include = ["pydantic_ai_slim/**/*.py", "pydantic_graph/**/*.py","tests/**/*.py"] omit = ["tests/test_live.py", "tests/example_modules/*.py"] branch = true diff --git a/tests/test_graph.py b/tests/test_graph.py index 3d70bec92c..7c3eba97aa 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -7,7 +7,7 @@ import pytest from inline_snapshot import snapshot 
-from pydantic_ai_graph import BaseNode, End, EndEvent, Graph, GraphContext, NodeEvent +from pydantic_graph import BaseNode, End, EndEvent, Graph, GraphContext, NodeEvent from .conftest import IsFloat, IsNow diff --git a/tests/typed_graph.py b/tests/typed_graph.py index c131f19b4c..8c6ff9b0ea 100644 --- a/tests/typed_graph.py +++ b/tests/typed_graph.py @@ -4,7 +4,7 @@ from typing_extensions import assert_type -from pydantic_ai_graph import BaseNode, End, Graph, GraphContext +from pydantic_graph import BaseNode, End, Graph, GraphContext @dataclass diff --git a/uprev.py b/uprev.py index 3eeca56d01..9ded772252 100644 --- a/uprev.py +++ b/uprev.py @@ -5,7 +5,7 @@ * pyproject.toml * examples/pyproject.toml * pydantic_ai_slim/pyproject.toml -* pydantic_ai_graph/pyproject.toml +* pydantic_graph/pyproject.toml Usage: @@ -68,7 +68,7 @@ def replace_deps_version(text: str) -> tuple[str, int]: slim_pp_text = slim_pp.read_text() slim_pp_text, count_slim = replace_deps_version(slim_pp_text) -graph_pp = ROOT_DIR / 'pydantic_ai_graph' / 'pyproject.toml' +graph_pp = ROOT_DIR / 'pydantic_graph' / 'pyproject.toml' graph_pp_text = graph_pp.read_text() graph_pp_text, count_graph = replace_deps_version(graph_pp_text) diff --git a/uv.lock b/uv.lock index f7627c796b..498d850d4c 100644 --- a/uv.lock +++ b/uv.lock @@ -12,8 +12,8 @@ resolution-markers = [ members = [ "pydantic-ai", "pydantic-ai-examples", - "pydantic-ai-graph", "pydantic-ai-slim", + "pydantic-graph", ] [[package]] @@ -2540,23 +2540,6 @@ requires-dist = [ { name = "uvicorn", specifier = ">=0.32.0" }, ] -[[package]] -name = "pydantic-ai-graph" -version = "0.0.14" -source = { editable = "pydantic_ai_graph" } -dependencies = [ - { name = "httpx" }, - { name = "logfire-api" }, - { name = "pydantic" }, -] - -[package.metadata] -requires-dist = [ - { name = "httpx", specifier = ">=0.27.2" }, - { name = "logfire-api", specifier = ">=1.2.0" }, - { name = "pydantic", specifier = ">=2.10" }, -] - [[package]] name = "pydantic-ai-slim" version = "0.0.18" @@ -2574,7 +2557,7 @@ anthropic = [ { name = "anthropic" }, ] graph = [ - { name = "pydantic-ai-graph" }, + { name = "pydantic-graph" }, ] groq = [ { name = "groq" }, @@ -2621,7 +2604,7 @@ requires-dist = [ { name = "mistralai", marker = "extra == 'mistral'", specifier = ">=1.2.5" }, { name = "openai", marker = "extra == 'openai'", specifier = ">=1.54.3" }, { name = "pydantic", specifier = ">=2.10" }, - { name = "pydantic-ai-graph", marker = "extra == 'graph'", editable = "pydantic_ai_graph" }, + { name = "pydantic-graph", marker = "extra == 'graph'", editable = "pydantic_graph" }, { name = "requests", marker = "extra == 'vertexai'", specifier = ">=2.32.3" }, ] @@ -2745,6 +2728,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a6/53/d78dc063216e62fc55f6b2eebb447f6a4b0a59f55c8406376f76bf959b08/pydub-0.25.1-py2.py3-none-any.whl", hash = "sha256:65617e33033874b59d87db603aa1ed450633288aefead953b30bded59cb599a6", size = 32327 }, ] +[[package]] +name = "pydantic-graph" +version = "0.0.17" +source = { editable = "pydantic_graph" } +dependencies = [ + { name = "httpx" }, + { name = "logfire-api" }, + { name = "pydantic" }, +] + +[package.metadata] +requires-dist = [ + { name = "httpx", specifier = ">=0.27.2" }, + { name = "logfire-api", specifier = ">=1.2.0" }, + { name = "pydantic", specifier = ">=2.10" }, +] + [[package]] name = "pygments" version = "2.18.0" From b74d0e495ac9e2522ea5ee8060d72994dfadcfb5 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Fri, 10 Jan 2025 15:34:08 +0000 Subject: [PATCH 
26/57] allow labeling edges and notes for docstrings --- pydantic_graph/pydantic_graph/__init__.py | 3 +- pydantic_graph/pydantic_graph/_utils.py | 14 ++++++- pydantic_graph/pydantic_graph/graph.py | 2 +- pydantic_graph/pydantic_graph/mermaid.py | 44 ++++++++++++++++------ pydantic_graph/pydantic_graph/nodes.py | 45 +++++++++++++++++------ pydantic_graph/pydantic_graph/state.py | 4 +- 6 files changed, 84 insertions(+), 28 deletions(-) diff --git a/pydantic_graph/pydantic_graph/__init__.py b/pydantic_graph/pydantic_graph/__init__.py index 51db45c53c..9284c250f8 100644 --- a/pydantic_graph/pydantic_graph/__init__.py +++ b/pydantic_graph/pydantic_graph/__init__.py @@ -1,5 +1,5 @@ from .graph import Graph -from .nodes import BaseNode, End, GraphContext +from .nodes import BaseNode, Edge, End, GraphContext from .state import AbstractState, EndEvent, HistoryStep, NodeEvent __all__ = ( @@ -7,6 +7,7 @@ 'BaseNode', 'End', 'GraphContext', + 'Edge', 'AbstractState', 'EndEvent', 'HistoryStep', diff --git a/pydantic_graph/pydantic_graph/_utils.py b/pydantic_graph/pydantic_graph/_utils.py index a451fc5533..26743fb49d 100644 --- a/pydantic_graph/pydantic_graph/_utils.py +++ b/pydantic_graph/pydantic_graph/_utils.py @@ -3,7 +3,7 @@ import sys import types from datetime import datetime, timezone -from typing import Any, Union, get_args, get_origin +from typing import Annotated, Any, Union, get_args, get_origin from typing_extensions import TypeAliasType @@ -21,6 +21,18 @@ def get_union_args(tp: Any) -> tuple[Any, ...]: return (tp,) +def strip_annotated(tp: Any) -> tuple[Any, list[Any]]: + """Strip `Annotated` from the type if present. + + Returns: `(tp argument, ())` if not annotated, otherwise `(stripped type, annotations)`. + """ + if get_origin(tp) is Annotated: + inner_tp, *args = get_args(tp) + return inner_tp, args + else: + return tp, [] + + # same as `pydantic_ai_slim/pydantic_ai/_result.py:origin_is_union` if sys.version_info < (3, 10): diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 10e1b35e55..82ac4d2576 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -62,7 +62,7 @@ def _validate_edges(self): bad_edges: dict[str, list[str]] = {} for node_id, node_def in self.node_defs.items(): - node_bad_edges = node_def.next_node_ids - known_node_ids + node_bad_edges = node_def.next_node_edges.keys() - known_node_ids for bad_edge in node_bad_edges: bad_edges.setdefault(bad_edge, []).append(f'"{node_id}"') diff --git a/pydantic_graph/pydantic_graph/mermaid.py b/pydantic_graph/pydantic_graph/mermaid.py index 534e1b5227..e0a66984e3 100644 --- a/pydantic_graph/pydantic_graph/mermaid.py +++ b/pydantic_graph/pydantic_graph/mermaid.py @@ -1,8 +1,10 @@ from __future__ import annotations as _annotations import base64 +import re from collections.abc import Iterable, Sequence from pathlib import Path +from textwrap import indent from typing import TYPE_CHECKING, Annotated, Any, Literal from annotated_types import Ge, Le @@ -25,6 +27,8 @@ def generate_code( start_node: Sequence[NodeIdent] | NodeIdent | None = None, highlighted_nodes: Sequence[NodeIdent] | NodeIdent | None = None, highlight_css: str = DEFAULT_HIGHLIGHT_CSS, + edge_labels: bool = True, + docstring_notes: bool = True, ) -> str: """Generate Mermaid code for a graph. @@ -33,6 +37,8 @@ def generate_code( start_node: Identifiers of nodes that start the graph. highlighted_nodes: Identifiers of nodes to highlight. highlight_css: CSS to use for highlighting nodes. 
+ edge_labels: Whether to include edge labels in the diagram, defaults to true. + docstring_notes: Whether to include docstrings as notes in the diagram, defaults to true. Returns: The Mermaid code for the graph. """ @@ -41,25 +47,34 @@ def generate_code( if node_id not in graph.node_defs: raise LookupError(f'Start node "{node_id}" is not in the graph.') - node_order = {node_id: index for index, node_id in enumerate(graph.node_defs)} - - lines = ['graph TD'] + lines = ['stateDiagram-v2'] for node in graph.nodes: node_id = node.get_id() node_def = graph.node_defs[node_id] # we use round brackets (rounded box) for nodes other than the start and end - mermaid_name = f'({node_id})' if node_id in start_node_ids: - lines.append(f' START --> {node_id}{mermaid_name}') + lines.append(f' [*] --> {node_id}') if node_def.returns_base_node: for next_node_id in graph.nodes: - lines.append(f' {node_id}{mermaid_name} --> {next_node_id}') + lines.append(f' {node_id} --> {next_node_id}') else: - for _, next_node_id in sorted((node_order[node_id], node_id) for node_id in node_def.next_node_ids): - lines.append(f' {node_id}{mermaid_name} --> {next_node_id}') - if node_def.returns_end: - lines.append(f' {node_id}{mermaid_name} --> END') + for next_node_id, edge in node_def.next_node_edges.items(): + if edge_labels and (label := edge.label): + lines.append(f' {node_id} --> {next_node_id}: {label}') + else: + lines.append(f' {node_id} --> {next_node_id}') + if end_edge := node_def.end_edge: + if edge_labels and (label := end_edge.label): + lines.append(f' {node_id} --> [*]: {label}') + else: + lines.append(f' {node_id} --> [*]') + + if docstring_notes and node_def.doc_string: + lines.append(f' note right of {node_id}') + clean_docs = re.sub('\n\n+', '\n', node_def.doc_string) + lines.append(indent(clean_docs, ' ')) + lines.append(' end note') if highlighted_nodes: lines.append('') @@ -97,6 +112,10 @@ class MermaidConfig(TypedDict, total=False): """Identifiers of nodes to highlight.""" highlight_css: str """CSS to use for highlighting nodes.""" + edge_labels: bool + """Whether to include edge labels in the diagram.""" + docstring_notes: bool + """Whether to include docstrings as notes in the diagram, defaults to true.""" image_type: Literal['jpeg', 'png', 'webp', 'svg', 'pdf'] """The image type to generate. If unspecified, the default behavior is `'jpeg'`.""" pdf_fit: bool @@ -148,6 +167,8 @@ def request_image( start_node=kwargs.get('start_node'), highlighted_nodes=kwargs.get('highlighted_nodes'), highlight_css=kwargs.get('highlight_css', DEFAULT_HIGHLIGHT_CSS), + edge_labels=kwargs.get('edge_labels', True), + docstring_notes=kwargs.get('docstring_notes', True), ) code_base64 = base64.b64encode(code.encode()).decode() @@ -180,7 +201,8 @@ def request_image( params['scale'] = str(scale) response = httpx.get(url, params=params) - response.raise_for_status() + if not response.is_success: + raise ValueError(f'{response.status_code} error generating image:\n{response.text}') return response.content diff --git a/pydantic_graph/pydantic_graph/nodes.py b/pydantic_graph/pydantic_graph/nodes.py index bbadd8b380..c984c6f76a 100644 --- a/pydantic_graph/pydantic_graph/nodes.py +++ b/pydantic_graph/pydantic_graph/nodes.py @@ -1,7 +1,8 @@ from __future__ import annotations as _annotations +import inspect from abc import abstractmethod -from dataclasses import dataclass +from dataclasses import dataclass, is_dataclass from functools import cache from typing import Any, Generic, get_origin, get_type_hints @@ -10,7 +11,7 @@ from . 
import _utils from .state import StateT -__all__ = 'GraphContext', 'BaseNode', 'End', 'NodeDef' +__all__ = 'GraphContext', 'BaseNode', 'End', 'Edge', 'NodeDef' RunEndT = TypeVar('RunEndT', default=None) NodeRunEndT = TypeVar('NodeRunEndT', covariant=True, default=Never) @@ -36,32 +37,43 @@ def get_id(cls) -> str: @classmethod def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, NodeRunEndT]: - type_hints = get_type_hints(cls.run, localns=local_ns) + type_hints = get_type_hints(cls.run, localns=local_ns, include_extras=True) try: return_hint = type_hints['return'] except KeyError: raise TypeError(f'Node {cls} is missing a return type hint on its `run` method') - next_node_ids: set[str] = set() - returns_end: bool = False + next_node_edges: dict[str, Edge] = {} + end_edge: Edge | None = None returns_base_node: bool = False for return_type in _utils.get_union_args(return_hint): + return_type, annotations = _utils.strip_annotated(return_type) + edge = next((a for a in annotations if isinstance(a, Edge)), Edge(None)) return_type_origin = get_origin(return_type) or return_type if return_type_origin is End: - returns_end = True + end_edge = edge elif return_type_origin is BaseNode: # TODO: Should we disallow this? returns_base_node = True elif issubclass(return_type_origin, BaseNode): - next_node_ids.add(return_type.get_id()) + next_node_edges[return_type.get_id()] = edge else: raise TypeError(f'Invalid return type: {return_type}') + docstring = cls.__doc__ + # dataclasses get an automatic docstring which is just their signature, we don't want that + if docstring and is_dataclass(cls) and docstring.startswith(f'{cls.__name__}('): + docstring = None + if docstring: + # remove indentation from docstring + docstring = inspect.cleandoc(docstring) + return NodeDef( cls, cls.get_id(), - next_node_ids, - returns_end, + docstring, + next_node_edges, + end_edge, returns_base_node, ) @@ -73,6 +85,13 @@ class End(Generic[RunEndT]): data: RunEndT +@dataclass +class Edge: + """Annotation to apply a label to an edge in a graph.""" + + label: str | None + + @dataclass class NodeDef(Generic[StateT, NodeRunEndT]): """Definition of a node. @@ -85,9 +104,11 @@ class NodeDef(Generic[StateT, NodeRunEndT]): """The node definition itself.""" node_id: str """ID of the node.""" - next_node_ids: set[str] + doc_string: str | None + """Docstring of the node.""" + next_node_edges: dict[str, Edge] """IDs of the nodes that can be called next.""" - returns_end: bool - """The node definition returns an `End`, hence the node can end the run.""" + end_edge: Edge | None + """If node definition returns an `End` this is an Edge, indicating the node can end the run.""" returns_base_node: bool """The node definition returns a `BaseNode`, hence any node in the next can be called next.""" diff --git a/pydantic_graph/pydantic_graph/state.py b/pydantic_graph/pydantic_graph/state.py index 9b007836ee..b8dbff7a72 100644 --- a/pydantic_graph/pydantic_graph/state.py +++ b/pydantic_graph/pydantic_graph/state.py @@ -4,9 +4,9 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from datetime import datetime -from typing import TYPE_CHECKING, Generic, Literal, Self, Union +from typing import TYPE_CHECKING, Generic, Literal, Union -from typing_extensions import Never, TypeVar +from typing_extensions import Never, Self, TypeVar from . 
import _utils From 7f34a0d1a29b293403a4a517f0069de081edbd01 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Fri, 10 Jan 2025 18:11:47 +0000 Subject: [PATCH 27/57] allow notes to be disabled --- pydantic_graph/pydantic_graph/_utils.py | 2 +- pydantic_graph/pydantic_graph/mermaid.py | 42 +++++++++++++----------- pydantic_graph/pydantic_graph/nodes.py | 41 ++++++++++++++--------- 3 files changed, 49 insertions(+), 36 deletions(-) diff --git a/pydantic_graph/pydantic_graph/_utils.py b/pydantic_graph/pydantic_graph/_utils.py index 26743fb49d..498c5bdde6 100644 --- a/pydantic_graph/pydantic_graph/_utils.py +++ b/pydantic_graph/pydantic_graph/_utils.py @@ -21,7 +21,7 @@ def get_union_args(tp: Any) -> tuple[Any, ...]: return (tp,) -def strip_annotated(tp: Any) -> tuple[Any, list[Any]]: +def unpack_annotated(tp: Any) -> tuple[Any, list[Any]]: """Strip `Annotated` from the type if present. Returns: `(tp argument, ())` if not annotated, otherwise `(stripped type, annotations)`. diff --git a/pydantic_graph/pydantic_graph/mermaid.py b/pydantic_graph/pydantic_graph/mermaid.py index e0a66984e3..e3e03888d7 100644 --- a/pydantic_graph/pydantic_graph/mermaid.py +++ b/pydantic_graph/pydantic_graph/mermaid.py @@ -16,6 +16,9 @@ from .graph import Graph +__all__ = 'NodeIdent', 'DEFAULT_HIGHLIGHT_CSS', 'generate_code', 'MermaidConfig', 'request_image', 'save_image' + + NodeIdent: TypeAlias = 'type[BaseNode[Any, Any]] | BaseNode[Any, Any] | str' DEFAULT_HIGHLIGHT_CSS = 'fill:#fdff32' @@ -28,21 +31,21 @@ def generate_code( highlighted_nodes: Sequence[NodeIdent] | NodeIdent | None = None, highlight_css: str = DEFAULT_HIGHLIGHT_CSS, edge_labels: bool = True, - docstring_notes: bool = True, + notes: bool = True, ) -> str: - """Generate Mermaid code for a graph. + """Generate [Mermaid state diagram](https://mermaid.js.org/syntax/stateDiagram.html) code for a graph. Args: graph: The graph to generate the image for. start_node: Identifiers of nodes that start the graph. highlighted_nodes: Identifiers of nodes to highlight. highlight_css: CSS to use for highlighting nodes. - edge_labels: Whether to include edge labels in the diagram, defaults to true. - docstring_notes: Whether to include docstrings as notes in the diagram, defaults to true. + edge_labels: Whether to include edge labels in the diagram. + notes: Whether to include notes in the diagram. Returns: The Mermaid code for the graph. 
""" - start_node_ids = set(node_ids(start_node or ())) + start_node_ids = set(_node_ids(start_node or ())) for node_id in start_node_ids: if node_id not in graph.node_defs: raise LookupError(f'Start node "{node_id}" is not in the graph.') @@ -60,26 +63,27 @@ def generate_code( lines.append(f' {node_id} --> {next_node_id}') else: for next_node_id, edge in node_def.next_node_edges.items(): - if edge_labels and (label := edge.label): - lines.append(f' {node_id} --> {next_node_id}: {label}') - else: - lines.append(f' {node_id} --> {next_node_id}') + line = f' {node_id} --> {next_node_id}' + if edge_labels and edge.label: + line += f': {edge.label}' + lines.append(line) if end_edge := node_def.end_edge: - if edge_labels and (label := end_edge.label): - lines.append(f' {node_id} --> [*]: {label}') - else: - lines.append(f' {node_id} --> [*]') + line = f' {node_id} --> [*]' + if edge_labels and end_edge.label: + line += f': {end_edge.label}' + lines.append(line) - if docstring_notes and node_def.doc_string: + if notes and node_def.note: lines.append(f' note right of {node_id}') - clean_docs = re.sub('\n\n+', '\n', node_def.doc_string) + # mermaid doesn't like multiple paragraphs in a note, and shows if so + clean_docs = re.sub('\n{2,}', '\n', node_def.note) lines.append(indent(clean_docs, ' ')) lines.append(' end note') if highlighted_nodes: lines.append('') lines.append(f'classDef highlighted {highlight_css}') - for node_id in node_ids(highlighted_nodes): + for node_id in _node_ids(highlighted_nodes): if node_id not in graph.node_defs: raise LookupError(f'Highlighted node "{node_id}" is not in the graph.') lines.append(f'class {node_id} highlighted') @@ -87,7 +91,7 @@ def generate_code( return '\n'.join(lines) -def node_ids(node_idents: Sequence[NodeIdent] | NodeIdent) -> Iterable[str]: +def _node_ids(node_idents: Sequence[NodeIdent] | NodeIdent) -> Iterable[str]: """Get the node IDs from a sequence of node identifiers.""" if isinstance(node_idents, str): node_iter = (node_idents,) @@ -114,7 +118,7 @@ class MermaidConfig(TypedDict, total=False): """CSS to use for highlighting nodes.""" edge_labels: bool """Whether to include edge labels in the diagram.""" - docstring_notes: bool + notes: bool """Whether to include docstrings as notes in the diagram, defaults to true.""" image_type: Literal['jpeg', 'png', 'webp', 'svg', 'pdf'] """The image type to generate. 
If unspecified, the default behavior is `'jpeg'`.""" @@ -168,7 +172,7 @@ def request_image( highlighted_nodes=kwargs.get('highlighted_nodes'), highlight_css=kwargs.get('highlight_css', DEFAULT_HIGHLIGHT_CSS), edge_labels=kwargs.get('edge_labels', True), - docstring_notes=kwargs.get('docstring_notes', True), + notes=kwargs.get('notes', True), ) code_base64 = base64.b64encode(code.encode()).decode() diff --git a/pydantic_graph/pydantic_graph/nodes.py b/pydantic_graph/pydantic_graph/nodes.py index c984c6f76a..19b42cb7ae 100644 --- a/pydantic_graph/pydantic_graph/nodes.py +++ b/pydantic_graph/pydantic_graph/nodes.py @@ -1,10 +1,9 @@ from __future__ import annotations as _annotations -import inspect -from abc import abstractmethod +from abc import ABC, abstractmethod from dataclasses import dataclass, is_dataclass from functools import cache -from typing import Any, Generic, get_origin, get_type_hints +from typing import Any, ClassVar, Generic, get_origin, get_type_hints from typing_extensions import Never, TypeVar @@ -24,9 +23,12 @@ class GraphContext(Generic[StateT]): state: StateT -class BaseNode(Generic[StateT, NodeRunEndT]): +class BaseNode(ABC, Generic[StateT, NodeRunEndT]): """Base class for a node.""" + enable_docstring_notes: ClassVar[bool] = True + """Set to `False` to not generate mermaid diagram notes from the class's docstring.""" + @abstractmethod async def run(self, ctx: GraphContext[StateT]) -> BaseNode[StateT, Any] | End[NodeRunEndT]: ... @@ -35,6 +37,21 @@ async def run(self, ctx: GraphContext[StateT]) -> BaseNode[StateT, Any] | End[No def get_id(cls) -> str: return cls.__name__ + @classmethod + def get_note(cls) -> str | None: + if not cls.enable_docstring_notes: + return None + docstring = cls.__doc__ + # dataclasses get an automatic docstring which is just their signature, we don't want that + if docstring and is_dataclass(cls) and docstring.startswith(f'{cls.__name__}('): + docstring = None + if docstring: + # remove indentation from docstring + import inspect + + docstring = inspect.cleandoc(docstring) + return docstring + @classmethod def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, NodeRunEndT]: type_hints = get_type_hints(cls.run, localns=local_ns, include_extras=True) @@ -47,7 +64,7 @@ def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, NodeRu end_edge: Edge | None = None returns_base_node: bool = False for return_type in _utils.get_union_args(return_hint): - return_type, annotations = _utils.strip_annotated(return_type) + return_type, annotations = _utils.unpack_annotated(return_type) edge = next((a for a in annotations if isinstance(a, Edge)), Edge(None)) return_type_origin = get_origin(return_type) or return_type if return_type_origin is End: @@ -60,18 +77,10 @@ def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, NodeRu else: raise TypeError(f'Invalid return type: {return_type}') - docstring = cls.__doc__ - # dataclasses get an automatic docstring which is just their signature, we don't want that - if docstring and is_dataclass(cls) and docstring.startswith(f'{cls.__name__}('): - docstring = None - if docstring: - # remove indentation from docstring - docstring = inspect.cleandoc(docstring) - return NodeDef( cls, cls.get_id(), - docstring, + cls.get_note(), next_node_edges, end_edge, returns_base_node, @@ -104,8 +113,8 @@ class NodeDef(Generic[StateT, NodeRunEndT]): """The node definition itself.""" node_id: str """ID of the node.""" - doc_string: str | None - """Docstring of the node.""" + note: 
str | None + """Note about the node to render on mermaid charts.""" next_node_edges: dict[str, Edge] """IDs of the nodes that can be called next.""" end_edge: Edge | None From 15573e937972b3fee676b8a8f8478c589474f3a1 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Fri, 10 Jan 2025 20:02:03 +0000 Subject: [PATCH 28/57] adding graph tests --- pydantic_graph/pydantic_graph/__init__.py | 3 + pydantic_graph/pydantic_graph/exceptions.py | 20 +++ pydantic_graph/pydantic_graph/graph.py | 44 ++++-- pydantic_graph/pydantic_graph/mermaid.py | 2 +- tests/test_graph.py | 113 +++++++++++++- tests/test_graph_mermaid.py | 165 ++++++++++++++++++++ 6 files changed, 326 insertions(+), 21 deletions(-) create mode 100644 pydantic_graph/pydantic_graph/exceptions.py create mode 100644 tests/test_graph_mermaid.py diff --git a/pydantic_graph/pydantic_graph/__init__.py b/pydantic_graph/pydantic_graph/__init__.py index 9284c250f8..192b62f98c 100644 --- a/pydantic_graph/pydantic_graph/__init__.py +++ b/pydantic_graph/pydantic_graph/__init__.py @@ -1,3 +1,4 @@ +from .exceptions import GraphRuntimeError, GraphSetupError from .graph import Graph from .nodes import BaseNode, Edge, End, GraphContext from .state import AbstractState, EndEvent, HistoryStep, NodeEvent @@ -12,4 +13,6 @@ 'EndEvent', 'HistoryStep', 'NodeEvent', + 'GraphSetupError', + 'GraphRuntimeError', ) diff --git a/pydantic_graph/pydantic_graph/exceptions.py b/pydantic_graph/pydantic_graph/exceptions.py new file mode 100644 index 0000000000..5288402c36 --- /dev/null +++ b/pydantic_graph/pydantic_graph/exceptions.py @@ -0,0 +1,20 @@ +class GraphSetupError(TypeError): + """Error caused by an incorrectly configured graph.""" + + message: str + """Description of the mistake.""" + + def __init__(self, message: str): + self.message = message + super().__init__(message) + + +class GraphRuntimeError(RuntimeError): + """Error caused by an issue during graph execution.""" + + message: str + """The error message.""" + + def __init__(self, message: str): + self.message = message + super().__init__(message) diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 82ac4d2576..2141333623 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -12,6 +12,7 @@ from . 
import _utils, mermaid from ._utils import get_parent_namespace +from .exceptions import GraphRuntimeError, GraphSetupError from .nodes import BaseNode, End, GraphContext, NodeDef from .state import EndEvent, HistoryStep, NodeEvent, StateT @@ -44,42 +45,42 @@ def __init__( _nodes_by_id: dict[str, type[BaseNode[StateT, RunEndT]]] = {} for node in nodes: node_id = node.get_id() - if (existing_node := _nodes_by_id.get(node_id)) and existing_node is not node: - raise ValueError(f'Node ID "{node_id}" is not unique — found in {existing_node} and {node}') + if existing_node := _nodes_by_id.get(node_id): + raise GraphSetupError(f'Node ID `{node_id}` is not unique — found in {existing_node} and {node}') else: _nodes_by_id[node_id] = node self.nodes = tuple(_nodes_by_id.values()) parent_namespace = get_parent_namespace(inspect.currentframe()) - self.node_defs: dict[str, NodeDef[StateT, RunEndT]] = {} - for node in self.nodes: - self.node_defs[node.get_id()] = node.get_node_def(parent_namespace) + self.node_defs: dict[str, NodeDef[StateT, RunEndT]] = { + node.get_id(): node.get_node_def(parent_namespace) for node in self.nodes + } self._validate_edges() def _validate_edges(self): - known_node_ids = set(self.node_defs.keys()) + known_node_ids = self.node_defs.keys() bad_edges: dict[str, list[str]] = {} for node_id, node_def in self.node_defs.items(): - node_bad_edges = node_def.next_node_edges.keys() - known_node_ids - for bad_edge in node_bad_edges: - bad_edges.setdefault(bad_edge, []).append(f'"{node_id}"') + for edge in node_def.next_node_edges.keys(): + if edge not in known_node_ids: + bad_edges.setdefault(edge, []).append(f'`{node_id}`') if bad_edges: - bad_edges_list = [f'"{k}" is referenced by {_utils.comma_and(v)}' for k, v in bad_edges.items()] + bad_edges_list = [f'`{k}` is referenced by {_utils.comma_and(v)}' for k, v in bad_edges.items()] if len(bad_edges_list) == 1: - raise ValueError(f'{bad_edges_list[0]} but not included in the graph.') + raise GraphSetupError(f'{bad_edges_list[0]} but not included in the graph.') else: b = '\n'.join(f' {be}' for be in bad_edges_list) - raise ValueError(f'Nodes are referenced in the graph but not included in the graph:\n{b}') + raise GraphSetupError(f'Nodes are referenced in the graph but not included in the graph:\n{b}') async def next( self, state: StateT, node: BaseNode[StateT, RunEndT], history: list[HistoryStep[StateT, RunEndT]] ) -> BaseNode[StateT, Any] | End[RunEndT]: node_id = node.get_id() if node_id not in self.node_defs: - raise TypeError(f'Node "{node}" is not in the graph.') + raise GraphRuntimeError(f'Node `{node}` is not in the graph.') history_step: NodeEvent[StateT, RunEndT] | None = NodeEvent(state, node) history.append(history_step) @@ -95,7 +96,7 @@ async def run( self, state: StateT, node: BaseNode[StateT, RunEndT], - ) -> tuple[End[RunEndT], list[HistoryStep[StateT, RunEndT]]]: + ) -> tuple[RunEndT, list[HistoryStep[StateT, RunEndT]]]: history: list[HistoryStep[StateT, RunEndT]] = [] with _logfire.span( @@ -108,14 +109,16 @@ async def run( if isinstance(next_node, End): history.append(EndEvent(state, next_node)) run_span.set_attribute('history', history) - return next_node, history + return next_node.data, history elif isinstance(next_node, BaseNode): node = next_node else: if TYPE_CHECKING: assert_never(next_node) else: - raise TypeError(f'Invalid node type: {type(next_node)}. Expected `BaseNode` or `End`.') + raise GraphRuntimeError( + f'Invalid node return type: `{type(next_node).__name__}`. Expected `BaseNode` or `End`.' 
+ ) def mermaid_code( self, @@ -123,9 +126,16 @@ def mermaid_code( start_node: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None, highlighted_nodes: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None, highlight_css: str = mermaid.DEFAULT_HIGHLIGHT_CSS, + edge_labels: bool = True, + notes: bool = True, ) -> str: return mermaid.generate_code( - self, start_node=start_node, highlighted_nodes=highlighted_nodes, highlight_css=highlight_css + self, + start_node=start_node, + highlighted_nodes=highlighted_nodes, + highlight_css=highlight_css, + edge_labels=edge_labels, + notes=notes, ) def mermaid_image(self, **kwargs: Unpack[mermaid.MermaidConfig]) -> bytes: diff --git a/pydantic_graph/pydantic_graph/mermaid.py b/pydantic_graph/pydantic_graph/mermaid.py index e3e03888d7..192d9f1e25 100644 --- a/pydantic_graph/pydantic_graph/mermaid.py +++ b/pydantic_graph/pydantic_graph/mermaid.py @@ -59,7 +59,7 @@ def generate_code( if node_id in start_node_ids: lines.append(f' [*] --> {node_id}') if node_def.returns_base_node: - for next_node_id in graph.nodes: + for next_node_id in graph.node_defs: lines.append(f' {node_id} --> {next_node_id}') else: for next_node_id, edge in node_def.next_node_edges.items(): diff --git a/tests/test_graph.py b/tests/test_graph.py index 7c3eba97aa..801bf1d876 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -2,12 +2,14 @@ from dataclasses import dataclass from datetime import timezone +from functools import cache from typing import Union import pytest +from dirty_equals import IsStr from inline_snapshot import snapshot -from pydantic_graph import BaseNode, End, EndEvent, Graph, GraphContext, NodeEvent +from pydantic_graph import BaseNode, End, EndEvent, Graph, GraphContext, GraphRuntimeError, GraphSetupError, NodeEvent from .conftest import IsFloat, IsNow @@ -42,7 +44,7 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq g = Graph[None, int](nodes=(Float2String, String2Length, Double)) result, history = await g.run(None, Float2String(3.14)) # len('3.14') * 2 == 8 - assert result == End(8) + assert result == 8 assert history == snapshot( [ NodeEvent( @@ -72,7 +74,7 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq ) result, history = await g.run(None, Float2String(3.14159)) # len('3.14159') == 7, 21 * 2 == 42 - assert result == End(42) + assert result == 42 assert history == snapshot( [ NodeEvent( @@ -112,3 +114,108 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq ), ] ) + + +def test_one_bad_node(): + class Float2String(BaseNode): + async def run(self, ctx: GraphContext) -> String2Length: + return String2Length() + + class String2Length(BaseNode[None, None]): + async def run(self, ctx: GraphContext) -> End[None]: + return End(None) + + with pytest.raises(GraphSetupError) as exc_info: + Graph(nodes=(Float2String,)) + + assert exc_info.value.message == snapshot( + '`String2Length` is referenced by `Float2String` but not included in the graph.' 
+ ) + + +def test_two_bad_nodes(): + class Float2String(BaseNode): + input_data: float + + async def run(self, ctx: GraphContext) -> String2Length | Double: + raise NotImplementedError() + + class String2Length(BaseNode[None, None]): + input_data: str + + async def run(self, ctx: GraphContext) -> End[None]: + return End(None) + + class Double(BaseNode[None, None]): + async def run(self, ctx: GraphContext) -> End[None]: + return End(None) + + with pytest.raises(GraphSetupError) as exc_info: + Graph(nodes=(Float2String,)) + + assert exc_info.value.message == snapshot("""\ +Nodes are referenced in the graph but not included in the graph: + `String2Length` is referenced by `Float2String` + `Double` is referenced by `Float2String`\ +""") + + +def test_duplicate_id(): + class Foo(BaseNode): + async def run(self, ctx: GraphContext) -> Bar: + return Bar() + + class Bar(BaseNode[None, None]): + async def run(self, ctx: GraphContext) -> End[None]: + return End(None) + + @classmethod + @cache + def get_id(cls) -> str: + return 'Foo' + + with pytest.raises(GraphSetupError) as exc_info: + Graph(nodes=(Foo, Bar)) + + assert exc_info.value.message == snapshot(IsStr(regex='Node ID `Foo` is not unique — found in.+')) + + +async def test_run_node_not_in_graph(): + @dataclass + class Foo(BaseNode): + async def run(self, ctx: GraphContext) -> Bar: + return Bar() + + @dataclass + class Bar(BaseNode[None, None]): + async def run(self, ctx: GraphContext) -> End[None]: + return Spam() # type: ignore + + @dataclass + class Spam(BaseNode[None, None]): + async def run(self, ctx: GraphContext) -> End[None]: + return End(None) + + g = Graph(nodes=(Foo, Bar)) + with pytest.raises(GraphRuntimeError) as exc_info: + await g.run(None, Foo()) + + assert exc_info.value.message == snapshot('Node `test_run_node_not_in_graph..Spam()` is not in the graph.') + + +async def test_run_return_other(): + @dataclass + class Foo(BaseNode): + async def run(self, ctx: GraphContext) -> Bar: + return Bar() + + @dataclass + class Bar(BaseNode[None, None]): + async def run(self, ctx: GraphContext) -> End[None]: + return 42 # type: ignore + + g = Graph(nodes=(Foo, Bar)) + with pytest.raises(GraphRuntimeError) as exc_info: + await g.run(None, Foo()) + + assert exc_info.value.message == snapshot('Invalid node return type: `int`. 
Expected `BaseNode` or `End`.') diff --git a/tests/test_graph_mermaid.py b/tests/test_graph_mermaid.py new file mode 100644 index 0000000000..bb3931fcfa --- /dev/null +++ b/tests/test_graph_mermaid.py @@ -0,0 +1,165 @@ +from __future__ import annotations as _annotations + +from dataclasses import dataclass +from datetime import timezone +from typing import Annotated + +import pytest +from inline_snapshot import snapshot + +from pydantic_graph import BaseNode, Edge, End, EndEvent, Graph, GraphContext, NodeEvent + +from .conftest import IsFloat, IsNow + +pytestmark = pytest.mark.anyio + + +@dataclass +class Foo(BaseNode): + async def run(self, ctx: GraphContext) -> Bar: + return Bar() + + +@dataclass +class Bar(BaseNode[None, None]): + async def run(self, ctx: GraphContext) -> End[None]: + return End(None) + + +graph1 = Graph(nodes=(Foo, Bar)) + + +@dataclass +class Spam(BaseNode): + """This is the docstring for Spam.""" + + async def run(self, ctx: GraphContext) -> Annotated[Foo, Edge(label='spam to foo')]: + return Foo() + + +@dataclass +class Eggs(BaseNode[None, None]): + async def run(self, ctx: GraphContext) -> Annotated[End[None], Edge(label='eggs to end')]: + return End(None) + + +graph2 = Graph(nodes=(Spam, Foo, Bar, Eggs)) + + +async def test_run_graph(): + result, history = await graph1.run(None, Foo()) + assert result is None + assert history == snapshot( + [ + NodeEvent( + state=None, + node=Foo(), + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(), + ), + NodeEvent( + state=None, + node=Bar(), + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(), + ), + EndEvent(state=None, result=End(data=None), ts=IsNow(tz=timezone.utc)), + ] + ) + + +def test_mermaid_code_no_start(): + assert graph1.mermaid_code() == snapshot("""\ +stateDiagram-v2 + Foo --> Bar + Bar --> [*]\ +""") + + +def test_mermaid_code_start(): + assert graph1.mermaid_code(start_node=Foo) == snapshot("""\ +stateDiagram-v2 + [*] --> Foo + Foo --> Bar + Bar --> [*]\ +""") + + +def test_mermaid_code_start_wrong(): + with pytest.raises(LookupError): + graph1.mermaid_code(start_node=Spam) + + +def test_mermaid_highlight(): + code = graph1.mermaid_code(highlighted_nodes=Foo) + assert code == snapshot("""\ +stateDiagram-v2 + Foo --> Bar + Bar --> [*] + +classDef highlighted fill:#fdff32 +class Foo highlighted\ +""") + assert code == graph1.mermaid_code(highlighted_nodes='Foo') + + +def test_mermaid_highlight_multiple(): + code = graph1.mermaid_code(highlighted_nodes=(Foo, Bar)) + assert code == snapshot("""\ +stateDiagram-v2 + Foo --> Bar + Bar --> [*] + +classDef highlighted fill:#fdff32 +class Foo highlighted +class Bar highlighted\ +""") + assert code == graph1.mermaid_code(highlighted_nodes=('Foo', 'Bar')) + + +def test_mermaid_highlight_wrong(): + with pytest.raises(LookupError): + graph1.mermaid_code(highlighted_nodes=Spam) + + +def test_mermaid_code_with_edge_labels(): + assert graph2.mermaid_code() == snapshot("""\ +stateDiagram-v2 + Spam --> Foo: spam to foo + note right of Spam + This is the docstring for Spam. 
+ end note + Foo --> Bar + Bar --> [*] + Eggs --> [*]: eggs to end\ +""") + + +def test_mermaid_code_without_edge_labels(): + assert graph2.mermaid_code(edge_labels=False, notes=False) == snapshot("""\ +stateDiagram-v2 + Spam --> Foo + Foo --> Bar + Bar --> [*] + Eggs --> [*]\ +""") + + +@dataclass +class AllNodes(BaseNode): + async def run(self, ctx: GraphContext) -> BaseNode: + return Foo() + + +graph3 = Graph(nodes=(AllNodes, Foo, Bar)) + + +def test_mermaid_code_all_nodes(): + assert graph3.mermaid_code() == snapshot("""\ +stateDiagram-v2 + AllNodes --> AllNodes + AllNodes --> Foo + AllNodes --> Bar + Foo --> Bar + Bar --> [*]\ +""") From de6b9e75357b0ccd6e966273d58a981fdd8c451b Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Fri, 10 Jan 2025 20:50:57 +0000 Subject: [PATCH 29/57] more mermaid tests, fix 3.9 --- pydantic_graph/pydantic_graph/mermaid.py | 13 ++++-- tests/{test_graph.py => graph/test_main.py} | 4 +- .../test_mermaid.py} | 43 ++++++++++++++++++- 3 files changed, 52 insertions(+), 8 deletions(-) rename tests/{test_graph.py => graph/test_main.py} (98%) rename tests/{test_graph_mermaid.py => graph/test_mermaid.py} (68%) diff --git a/pydantic_graph/pydantic_graph/mermaid.py b/pydantic_graph/pydantic_graph/mermaid.py index 192d9f1e25..00d2ccb744 100644 --- a/pydantic_graph/pydantic_graph/mermaid.py +++ b/pydantic_graph/pydantic_graph/mermaid.py @@ -7,6 +7,7 @@ from textwrap import indent from typing import TYPE_CHECKING, Annotated, Any, Literal +import httpx from annotated_types import Ge, Le from typing_extensions import TypeAlias, TypedDict, Unpack @@ -149,6 +150,7 @@ class MermaidConfig(TypedDict, total=False): The scale must be a number between 1 and 3, and you can only set a scale if one or both of width and height are set. """ + httpx_client: httpx.Client def request_image( @@ -164,8 +166,6 @@ def request_image( Returns: The image data. 
""" - import httpx - code = generate_code( graph, start_node=kwargs.get('start_node'), @@ -204,9 +204,14 @@ def request_image( if scale := kwargs.get('scale'): params['scale'] = str(scale) - response = httpx.get(url, params=params) + httpx_client = kwargs.get('httpx_client') or httpx.Client() + response = httpx_client.get(url, params=params) if not response.is_success: - raise ValueError(f'{response.status_code} error generating image:\n{response.text}') + raise httpx.HTTPStatusError( + f'{response.status_code} error generating image:\n{response.text}', + request=response.request, + response=response, + ) return response.content diff --git a/tests/test_graph.py b/tests/graph/test_main.py similarity index 98% rename from tests/test_graph.py rename to tests/graph/test_main.py index 801bf1d876..717eaefb7b 100644 --- a/tests/test_graph.py +++ b/tests/graph/test_main.py @@ -11,7 +11,7 @@ from pydantic_graph import BaseNode, End, EndEvent, Graph, GraphContext, GraphRuntimeError, GraphSetupError, NodeEvent -from .conftest import IsFloat, IsNow +from ..conftest import IsFloat, IsNow pytestmark = pytest.mark.anyio @@ -137,7 +137,7 @@ def test_two_bad_nodes(): class Float2String(BaseNode): input_data: float - async def run(self, ctx: GraphContext) -> String2Length | Double: + async def run(self, ctx: GraphContext) -> Union[String2Length, Double]: # noqa: UP007 raise NotImplementedError() class String2Length(BaseNode[None, None]): diff --git a/tests/test_graph_mermaid.py b/tests/graph/test_mermaid.py similarity index 68% rename from tests/test_graph_mermaid.py rename to tests/graph/test_mermaid.py index bb3931fcfa..40a461c96a 100644 --- a/tests/test_graph_mermaid.py +++ b/tests/graph/test_mermaid.py @@ -1,17 +1,42 @@ from __future__ import annotations as _annotations +from collections.abc import Iterator from dataclasses import dataclass from datetime import timezone -from typing import Annotated +from typing import Annotated, Callable +import httpx import pytest from inline_snapshot import snapshot +from typing_extensions import TypeAlias from pydantic_graph import BaseNode, Edge, End, EndEvent, Graph, GraphContext, NodeEvent -from .conftest import IsFloat, IsNow +from ..conftest import IsFloat, IsNow pytestmark = pytest.mark.anyio +HttpxWithHandler: TypeAlias = 'Callable[[Callable[[httpx.Request], httpx.Response] | httpx.Response], httpx.Client]' + + +@pytest.fixture +def httpx_with_handler() -> Iterator[HttpxWithHandler]: + client: httpx.Client | None = None + + def create_client(handler: Callable[[httpx.Request], httpx.Response] | httpx.Response) -> httpx.Client: + nonlocal client + assert client is None, 'client_with_handler can only be called once' + if isinstance(handler, httpx.Response): + transport = httpx.MockTransport(lambda _: handler) + else: + transport = httpx.MockTransport(handler) + client = httpx.Client(mounts={'all://': transport}) + return client + + try: + yield create_client + finally: + if client: # pragma: no cover + client.close() @dataclass @@ -163,3 +188,17 @@ def test_mermaid_code_all_nodes(): Foo --> Bar Bar --> [*]\ """) + + +def test_image(httpx_with_handler: HttpxWithHandler): + response = httpx.Response(200, content=b'fake image') + img = graph1.mermaid_image(start_node=Foo(), httpx_client=httpx_with_handler(response)) + assert img == b'fake image' + + +def test_image_bad(httpx_with_handler: HttpxWithHandler): + response = httpx.Response(404, content=b'not found') + with pytest.raises(httpx.HTTPStatusError, match='404 error generating image:\nnot found') as exc_info: 
+ graph1.mermaid_image(start_node=Foo(), httpx_client=httpx_with_handler(response)) + assert exc_info.value.response.status_code == 404 + assert exc_info.value.response.content == b'not found' From 08e87aad868a155620832b8e67f27942b4330974 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Fri, 10 Jan 2025 20:51:28 +0000 Subject: [PATCH 30/57] rename node to start_node in graph.run() --- pydantic_graph/pydantic_graph/graph.py | 8 ++++---- tests/graph/__init__.py | 0 2 files changed, 4 insertions(+), 4 deletions(-) create mode 100644 tests/graph/__init__.py diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 2141333623..4e87142865 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -95,23 +95,23 @@ async def next( async def run( self, state: StateT, - node: BaseNode[StateT, RunEndT], + start_node: BaseNode[StateT, RunEndT], ) -> tuple[RunEndT, list[HistoryStep[StateT, RunEndT]]]: history: list[HistoryStep[StateT, RunEndT]] = [] with _logfire.span( '{graph_name} run {start=}', graph_name=self.name or 'graph', - start=node, + start=start_node, ) as run_span: while True: - next_node = await self.next(state, node, history=history) + next_node = await self.next(state, start_node, history=history) if isinstance(next_node, End): history.append(EndEvent(state, next_node)) run_span.set_attribute('history', history) return next_node.data, history elif isinstance(next_node, BaseNode): - node = next_node + start_node = next_node else: if TYPE_CHECKING: assert_never(next_node) diff --git a/tests/graph/__init__.py b/tests/graph/__init__.py new file mode 100644 index 0000000000..e69de29bb2 From 25d79aa192ec973388fb2677c5dad4ede4004e93 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Fri, 10 Jan 2025 22:01:24 +0000 Subject: [PATCH 31/57] more tests for graphs --- pydantic_graph/pydantic_graph/_utils.py | 17 +-- pydantic_graph/pydantic_graph/graph.py | 17 +-- pydantic_graph/pydantic_graph/mermaid.py | 14 +-- pydantic_graph/pydantic_graph/nodes.py | 8 +- pyproject.toml | 1 + tests/graph/test_main.py | 53 ++++++-- tests/graph/test_mermaid.py | 151 ++++++++++++++++++++++- 7 files changed, 218 insertions(+), 43 deletions(-) diff --git a/pydantic_graph/pydantic_graph/_utils.py b/pydantic_graph/pydantic_graph/_utils.py index 498c5bdde6..81123afde6 100644 --- a/pydantic_graph/pydantic_graph/_utils.py +++ b/pydantic_graph/pydantic_graph/_utils.py @@ -5,13 +5,13 @@ from datetime import datetime, timezone from typing import Annotated, Any, Union, get_args, get_origin -from typing_extensions import TypeAliasType +import typing_extensions def get_union_args(tp: Any) -> tuple[Any, ...]: """Extract the arguments of a Union type if `response_type` is a union, otherwise return the original type.""" # similar to `pydantic_ai_slim/pydantic_ai/_result.py:get_union_args` - if isinstance(tp, TypeAliasType): + if isinstance(tp, typing_extensions.TypeAliasType): tp = tp.__value__ origin = get_origin(tp) @@ -26,7 +26,8 @@ def unpack_annotated(tp: Any) -> tuple[Any, list[Any]]: Returns: `(tp argument, ())` if not annotated, otherwise `(stripped type, annotations)`. 
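As a minimal sketch of where the `Annotated` metadata handled above comes from (the `Finish` node and its label are illustrative, mirroring the `Eggs` node in the mermaid tests): an edge label is attached to a node's return type hint, and `unpack_annotated` is what strips it back off.

```python
from dataclasses import dataclass
from typing import Annotated

from pydantic_graph import BaseNode, Edge, End, GraphContext


@dataclass
class Finish(BaseNode[None, None]):
    # Edge only carries metadata; unpack_annotated() on this return hint
    # yields (End[None], [Edge(label='all done')])
    async def run(self, ctx: GraphContext) -> Annotated[End[None], Edge(label='all done')]:
        return End(None)
```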
""" - if get_origin(tp) is Annotated: + origin = get_origin(tp) + if origin is Annotated or origin is typing_extensions.Annotated: inner_tp, *args = get_args(tp) return inner_tp, args else: @@ -54,16 +55,6 @@ def comma_and(items: list[str]) -> str: return ', '.join(items[:-1]) + ', and ' + items[-1] -_NoneType = type(None) - - -def type_arg_name(arg: Any) -> str: - if arg is _NoneType: - return 'None' - else: - return arg.__name__ - - def get_parent_namespace(frame: types.FrameType | None) -> dict[str, Any] | None: """Attempt to get the namespace where the graph was defined. diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 4e87142865..83f5c92845 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -10,9 +10,8 @@ import logfire_api from typing_extensions import Never, ParamSpec, TypeVar, Unpack, assert_never -from . import _utils, mermaid +from . import _utils, exceptions, mermaid from ._utils import get_parent_namespace -from .exceptions import GraphRuntimeError, GraphSetupError from .nodes import BaseNode, End, GraphContext, NodeDef from .state import EndEvent, HistoryStep, NodeEvent, StateT @@ -46,7 +45,9 @@ def __init__( for node in nodes: node_id = node.get_id() if existing_node := _nodes_by_id.get(node_id): - raise GraphSetupError(f'Node ID `{node_id}` is not unique — found in {existing_node} and {node}') + raise exceptions.GraphSetupError( + f'Node ID `{node_id}` is not unique — found in {existing_node} and {node}' + ) else: _nodes_by_id[node_id] = node self.nodes = tuple(_nodes_by_id.values()) @@ -70,17 +71,19 @@ def _validate_edges(self): if bad_edges: bad_edges_list = [f'`{k}` is referenced by {_utils.comma_and(v)}' for k, v in bad_edges.items()] if len(bad_edges_list) == 1: - raise GraphSetupError(f'{bad_edges_list[0]} but not included in the graph.') + raise exceptions.GraphSetupError(f'{bad_edges_list[0]} but not included in the graph.') else: b = '\n'.join(f' {be}' for be in bad_edges_list) - raise GraphSetupError(f'Nodes are referenced in the graph but not included in the graph:\n{b}') + raise exceptions.GraphSetupError( + f'Nodes are referenced in the graph but not included in the graph:\n{b}' + ) async def next( self, state: StateT, node: BaseNode[StateT, RunEndT], history: list[HistoryStep[StateT, RunEndT]] ) -> BaseNode[StateT, Any] | End[RunEndT]: node_id = node.get_id() if node_id not in self.node_defs: - raise GraphRuntimeError(f'Node `{node}` is not in the graph.') + raise exceptions.GraphRuntimeError(f'Node `{node}` is not in the graph.') history_step: NodeEvent[StateT, RunEndT] | None = NodeEvent(state, node) history.append(history_step) @@ -116,7 +119,7 @@ async def run( if TYPE_CHECKING: assert_never(next_node) else: - raise GraphRuntimeError( + raise exceptions.GraphRuntimeError( f'Invalid node return type: `{type(next_node).__name__}`. Expected `BaseNode` or `End`.' 
) diff --git a/pydantic_graph/pydantic_graph/mermaid.py b/pydantic_graph/pydantic_graph/mermaid.py index 00d2ccb744..b21b50509d 100644 --- a/pydantic_graph/pydantic_graph/mermaid.py +++ b/pydantic_graph/pydantic_graph/mermaid.py @@ -120,7 +120,7 @@ class MermaidConfig(TypedDict, total=False): edge_labels: bool """Whether to include edge labels in the diagram.""" notes: bool - """Whether to include docstrings as notes in the diagram, defaults to true.""" + """Whether to include notes on nodes in the diagram, defaults to true.""" image_type: Literal['jpeg', 'png', 'webp', 'svg', 'pdf'] """The image type to generate. If unspecified, the default behavior is `'jpeg'`.""" pdf_fit: bool @@ -176,13 +176,13 @@ def request_image( ) code_base64 = base64.b64encode(code.encode()).decode() - params: dict[str, str | bool] = {} + params: dict[str, str | float] = {} if kwargs.get('image_type') == 'pdf': url = f'https://mermaid.ink/pdf/{code_base64}' if kwargs.get('pdf_fit'): - params['fit'] = True + params['fit'] = '' if kwargs.get('pdf_landscape'): - params['landscape'] = True + params['landscape'] = '' if pdf_paper := kwargs.get('pdf_paper'): params['paper'] = pdf_paper elif kwargs.get('image_type') == 'svg': @@ -198,11 +198,11 @@ def request_image( if theme := kwargs.get('theme'): params['theme'] = theme if width := kwargs.get('width'): - params['width'] = str(width) + params['width'] = width if height := kwargs.get('height'): - params['height'] = str(height) + params['height'] = height if scale := kwargs.get('scale'): - params['scale'] = str(scale) + params['scale'] = scale httpx_client = kwargs.get('httpx_client') or httpx.Client() response = httpx_client.get(url, params=params) diff --git a/pydantic_graph/pydantic_graph/nodes.py b/pydantic_graph/pydantic_graph/nodes.py index 19b42cb7ae..5fd5755881 100644 --- a/pydantic_graph/pydantic_graph/nodes.py +++ b/pydantic_graph/pydantic_graph/nodes.py @@ -7,7 +7,7 @@ from typing_extensions import Never, TypeVar -from . import _utils +from . 
import _utils, exceptions from .state import StateT __all__ = 'GraphContext', 'BaseNode', 'End', 'Edge', 'NodeDef' @@ -57,8 +57,8 @@ def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, NodeRu type_hints = get_type_hints(cls.run, localns=local_ns, include_extras=True) try: return_hint = type_hints['return'] - except KeyError: - raise TypeError(f'Node {cls} is missing a return type hint on its `run` method') + except KeyError as e: + raise exceptions.GraphSetupError(f'Node {cls} is missing a return type hint on its `run` method') from e next_node_edges: dict[str, Edge] = {} end_edge: Edge | None = None @@ -75,7 +75,7 @@ def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, NodeRu elif issubclass(return_type_origin, BaseNode): next_node_edges[return_type.get_id()] = edge else: - raise TypeError(f'Invalid return type: {return_type}') + raise exceptions.GraphSetupError(f'Invalid return type: {return_type}') return NodeDef( cls, diff --git a/pyproject.toml b/pyproject.toml index f26156bbed..24ddbfbe38 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -170,6 +170,7 @@ exclude_lines = [ 'if typing.TYPE_CHECKING:', '@overload', '@typing.overload', + '@abstractmethod', '\(Protocol\):$', 'typing.assert_never', '$\s*assert_never\(', diff --git a/tests/graph/test_main.py b/tests/graph/test_main.py index 717eaefb7b..5b1d61164f 100644 --- a/tests/graph/test_main.py +++ b/tests/graph/test_main.py @@ -114,6 +114,16 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq ), ] ) + assert [e.summary() for e in history] == snapshot( + [ + 'test_graph..Float2String(input_data=3.14159)', + "test_graph..String2Length(input_data='3.14159')", + 'test_graph..Double(input_data=7)', + "test_graph..String2Length(input_data='xxxxxxxxxxxxxxxxxxxxx')", + 'test_graph..Double(input_data=21)', + 'End(data=42)', + ] + ) def test_one_bad_node(): @@ -134,32 +144,61 @@ async def run(self, ctx: GraphContext) -> End[None]: def test_two_bad_nodes(): - class Float2String(BaseNode): + class Foo(BaseNode): input_data: float - async def run(self, ctx: GraphContext) -> Union[String2Length, Double]: # noqa: UP007 + async def run(self, ctx: GraphContext) -> Union[Bar, Spam]: # noqa: UP007 raise NotImplementedError() - class String2Length(BaseNode[None, None]): + class Bar(BaseNode[None, None]): input_data: str async def run(self, ctx: GraphContext) -> End[None]: return End(None) - class Double(BaseNode[None, None]): + class Spam(BaseNode[None, None]): async def run(self, ctx: GraphContext) -> End[None]: return End(None) with pytest.raises(GraphSetupError) as exc_info: - Graph(nodes=(Float2String,)) + Graph(nodes=(Foo,)) assert exc_info.value.message == snapshot("""\ Nodes are referenced in the graph but not included in the graph: - `String2Length` is referenced by `Float2String` - `Double` is referenced by `Float2String`\ + `Bar` is referenced by `Foo` + `Spam` is referenced by `Foo`\ """) +def test_three_bad_nodes_separate(): + class Foo(BaseNode): + input_data: float + + async def run(self, ctx: GraphContext) -> Eggs: + raise NotImplementedError() + + class Bar(BaseNode[None, None]): + input_data: str + + async def run(self, ctx: GraphContext) -> Eggs: + raise NotImplementedError() + + class Spam(BaseNode[None, None]): + async def run(self, ctx: GraphContext) -> Eggs: + raise NotImplementedError() + + class Eggs(BaseNode[None, None]): + async def run(self, ctx: GraphContext) -> End[None]: + return End(None) + + with pytest.raises(GraphSetupError) as exc_info: + 
Graph(nodes=(Foo, Bar, Spam)) + + assert exc_info.value.message == snapshot( + '`Eggs` is referenced by `Foo`, `Bar`, and `Spam` but not included in the graph.' + ) + + def test_duplicate_id(): class Foo(BaseNode): async def run(self, ctx: GraphContext) -> Bar: diff --git a/tests/graph/test_mermaid.py b/tests/graph/test_mermaid.py index 40a461c96a..3546296c42 100644 --- a/tests/graph/test_mermaid.py +++ b/tests/graph/test_mermaid.py @@ -3,6 +3,7 @@ from collections.abc import Iterator from dataclasses import dataclass from datetime import timezone +from pathlib import Path from typing import Annotated, Callable import httpx @@ -10,7 +11,8 @@ from inline_snapshot import snapshot from typing_extensions import TypeAlias -from pydantic_graph import BaseNode, Edge, End, EndEvent, Graph, GraphContext, NodeEvent +from pydantic_graph import BaseNode, Edge, End, EndEvent, Graph, GraphContext, GraphSetupError, NodeEvent +from pydantic_graph.nodes import NodeDef from ..conftest import IsFloat, IsNow @@ -64,6 +66,10 @@ async def run(self, ctx: GraphContext) -> Annotated[Foo, Edge(label='spam to foo @dataclass class Eggs(BaseNode[None, None]): + """This is the docstring for Eggs.""" + + enable_docstring_notes = False + async def run(self, ctx: GraphContext) -> Annotated[End[None], Edge(label='eggs to end')]: return End(None) @@ -190,10 +196,42 @@ def test_mermaid_code_all_nodes(): """) -def test_image(httpx_with_handler: HttpxWithHandler): - response = httpx.Response(200, content=b'fake image') - img = graph1.mermaid_image(start_node=Foo(), httpx_client=httpx_with_handler(response)) - assert img == b'fake image' +def test_image_jpg(httpx_with_handler: HttpxWithHandler): + def get_jpg(request: httpx.Request) -> httpx.Response: + assert dict(request.url.params) == snapshot({}) + assert request.url.path.startswith('/img/') + return httpx.Response(200, content=b'fake jpg') + + img = graph1.mermaid_image(start_node=Foo(), httpx_client=httpx_with_handler(get_jpg)) + assert img == b'fake jpg' + + +def test_image_png(httpx_with_handler: HttpxWithHandler): + def get_png(request: httpx.Request) -> httpx.Response: + assert dict(request.url.params) == snapshot( + { + 'type': 'png', + 'bgColor': '123', + 'theme': 'forest', + 'width': '100', + 'height': '200', + 'scale': '3', + } + ) + assert request.url.path.startswith('/img/') + return httpx.Response(200, content=b'fake png') + + img = graph1.mermaid_image( + start_node=Foo(), + image_type='png', + background_color='123', + theme='forest', + width=100, + height=200, + scale=3, + httpx_client=httpx_with_handler(get_png), + ) + assert img == b'fake png' def test_image_bad(httpx_with_handler: HttpxWithHandler): @@ -202,3 +240,106 @@ def test_image_bad(httpx_with_handler: HttpxWithHandler): graph1.mermaid_image(start_node=Foo(), httpx_client=httpx_with_handler(response)) assert exc_info.value.response.status_code == 404 assert exc_info.value.response.content == b'not found' + + +def test_pdf(httpx_with_handler: HttpxWithHandler): + def get_pdf(request: httpx.Request) -> httpx.Response: + assert dict(request.url.params) == snapshot({}) + assert request.url.path.startswith('/pdf/') + return httpx.Response(200, content=b'fake pdf') + + pdf = graph1.mermaid_image(start_node=Foo(), image_type='pdf', httpx_client=httpx_with_handler(get_pdf)) + assert pdf == b'fake pdf' + + +def test_pdf_config(httpx_with_handler: HttpxWithHandler): + def get_pdf(request: httpx.Request) -> httpx.Response: + assert dict(request.url.params) == snapshot({'fit': '', 'landscape': '', 'paper': 
'letter'}) + assert request.url.path.startswith('/pdf/') + return httpx.Response(200, content=b'fake pdf') + + pdf = graph1.mermaid_image( + start_node=Foo(), + image_type='pdf', + pdf_fit=True, + pdf_landscape=True, + pdf_paper='letter', + httpx_client=httpx_with_handler(get_pdf), + ) + assert pdf == b'fake pdf' + + +def test_svg(httpx_with_handler: HttpxWithHandler): + def get_svg(request: httpx.Request) -> httpx.Response: + assert dict(request.url.params) == snapshot({}) + assert request.url.path.startswith('/svg/') + return httpx.Response(200, content=b'fake svg') + + svg = graph1.mermaid_image(start_node=Foo(), image_type='svg', httpx_client=httpx_with_handler(get_svg)) + assert svg == b'fake svg' + + +def test_save_jpg(tmp_path: Path, httpx_with_handler: HttpxWithHandler): + def get_jpg(request: httpx.Request) -> httpx.Response: + assert dict(request.url.params) == snapshot({}) + assert request.url.path.startswith('/img/') + return httpx.Response(200, content=b'fake jpg') + + path = tmp_path / 'graph.jpg' + graph1.mermaid_save(path, start_node=Foo(), httpx_client=httpx_with_handler(get_jpg)) + assert path.read_bytes() == b'fake jpg' + + +def test_save_png(tmp_path: Path, httpx_with_handler: HttpxWithHandler): + def get_png(request: httpx.Request) -> httpx.Response: + assert dict(request.url.params) == snapshot({'type': 'png'}) + assert request.url.path.startswith('/img/') + return httpx.Response(200, content=b'fake png') + + path2 = tmp_path / 'graph.png' + graph1.mermaid_save(str(path2), start_node=Foo(), httpx_client=httpx_with_handler(get_png)) + assert path2.read_bytes() == b'fake png' + + +def test_save_pdf_known(tmp_path: Path, httpx_with_handler: HttpxWithHandler): + def get_pdf(request: httpx.Request) -> httpx.Response: + assert dict(request.url.params) == snapshot({}) + assert request.url.path.startswith('/pdf/') + return httpx.Response(200, content=b'fake pdf') + + path2 = tmp_path / 'graph' + graph1.mermaid_save(str(path2), start_node=Foo(), image_type='pdf', httpx_client=httpx_with_handler(get_pdf)) + assert path2.read_bytes() == b'fake pdf' + + +def test_get_node_def(): + assert Foo.get_node_def({}) == snapshot( + NodeDef( + node=Foo, + node_id='Foo', + note=None, + next_node_edges={'Bar': Edge(label=None)}, + end_edge=None, + returns_base_node=False, + ) + ) + + +def test_no_return_type(): + @dataclass + class NoReturnType(BaseNode): + async def run(self, ctx: GraphContext): # type: ignore + raise NotImplementedError() + + with pytest.raises(GraphSetupError, match=r".*\.NoReturnType'> is missing a return type hint on its `run` method"): + NoReturnType.get_node_def({}) + + +def test_wrong_return_type(): + @dataclass + class NoReturnType(BaseNode): + async def run(self, ctx: GraphContext) -> int: # type: ignore + raise NotImplementedError() + + with pytest.raises(GraphSetupError, match="Invalid return type: "): + NoReturnType.get_node_def({}) From 6e629067dd5e90fbec9c8b1888a594fe2dbc12c5 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Fri, 10 Jan 2025 22:08:10 +0000 Subject: [PATCH 32/57] coverage in tests --- tests/graph/test_main.py | 16 ++++++++-------- tests/graph/test_mermaid.py | 6 +++--- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/graph/test_main.py b/tests/graph/test_main.py index 5b1d61164f..43438ff7bc 100644 --- a/tests/graph/test_main.py +++ b/tests/graph/test_main.py @@ -129,11 +129,11 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq def test_one_bad_node(): class Float2String(BaseNode): async def 
run(self, ctx: GraphContext) -> String2Length: - return String2Length() + raise NotImplementedError() class String2Length(BaseNode[None, None]): async def run(self, ctx: GraphContext) -> End[None]: - return End(None) + raise NotImplementedError() with pytest.raises(GraphSetupError) as exc_info: Graph(nodes=(Float2String,)) @@ -154,11 +154,11 @@ class Bar(BaseNode[None, None]): input_data: str async def run(self, ctx: GraphContext) -> End[None]: - return End(None) + raise NotImplementedError() class Spam(BaseNode[None, None]): async def run(self, ctx: GraphContext) -> End[None]: - return End(None) + raise NotImplementedError() with pytest.raises(GraphSetupError) as exc_info: Graph(nodes=(Foo,)) @@ -189,7 +189,7 @@ async def run(self, ctx: GraphContext) -> Eggs: class Eggs(BaseNode[None, None]): async def run(self, ctx: GraphContext) -> End[None]: - return End(None) + raise NotImplementedError() with pytest.raises(GraphSetupError) as exc_info: Graph(nodes=(Foo, Bar, Spam)) @@ -202,11 +202,11 @@ async def run(self, ctx: GraphContext) -> End[None]: def test_duplicate_id(): class Foo(BaseNode): async def run(self, ctx: GraphContext) -> Bar: - return Bar() + raise NotImplementedError() class Bar(BaseNode[None, None]): async def run(self, ctx: GraphContext) -> End[None]: - return End(None) + raise NotImplementedError() @classmethod @cache @@ -233,7 +233,7 @@ async def run(self, ctx: GraphContext) -> End[None]: @dataclass class Spam(BaseNode[None, None]): async def run(self, ctx: GraphContext) -> End[None]: - return End(None) + raise NotImplementedError() g = Graph(nodes=(Foo, Bar)) with pytest.raises(GraphRuntimeError) as exc_info: diff --git a/tests/graph/test_mermaid.py b/tests/graph/test_mermaid.py index 3546296c42..422b7ecdee 100644 --- a/tests/graph/test_mermaid.py +++ b/tests/graph/test_mermaid.py @@ -61,7 +61,7 @@ class Spam(BaseNode): """This is the docstring for Spam.""" async def run(self, ctx: GraphContext) -> Annotated[Foo, Edge(label='spam to foo')]: - return Foo() + raise NotImplementedError() @dataclass @@ -71,7 +71,7 @@ class Eggs(BaseNode[None, None]): enable_docstring_notes = False async def run(self, ctx: GraphContext) -> Annotated[End[None], Edge(label='eggs to end')]: - return End(None) + raise NotImplementedError() graph2 = Graph(nodes=(Spam, Foo, Bar, Eggs)) @@ -179,7 +179,7 @@ def test_mermaid_code_without_edge_labels(): @dataclass class AllNodes(BaseNode): async def run(self, ctx: GraphContext) -> BaseNode: - return Foo() + raise NotImplementedError() graph3 = Graph(nodes=(AllNodes, Foo, Bar)) From c9ebc49e28a193abf73f4e02ff004e358000644f Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Fri, 10 Jan 2025 23:43:46 +0000 Subject: [PATCH 33/57] cleanup graph properties --- pydantic_graph/pydantic_graph/graph.py | 27 +++++++++++------------- pydantic_graph/pydantic_graph/mermaid.py | 5 +---- tests/graph/test_main.py | 2 +- 3 files changed, 14 insertions(+), 20 deletions(-) diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 83f5c92845..8ae518ac42 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -29,7 +29,6 @@ class Graph(Generic[StateT, RunEndT]): """Definition of a graph.""" name: str | None - nodes: tuple[type[BaseNode[StateT, RunEndT]], ...] 
node_defs: dict[str, NodeDef[StateT, RunEndT]] def __init__( @@ -41,24 +40,22 @@ def __init__( ): self.name = name - _nodes_by_id: dict[str, type[BaseNode[StateT, RunEndT]]] = {} - for node in nodes: - node_id = node.get_id() - if existing_node := _nodes_by_id.get(node_id): - raise exceptions.GraphSetupError( - f'Node ID `{node_id}` is not unique — found in {existing_node} and {node}' - ) - else: - _nodes_by_id[node_id] = node - self.nodes = tuple(_nodes_by_id.values()) - parent_namespace = get_parent_namespace(inspect.currentframe()) - self.node_defs: dict[str, NodeDef[StateT, RunEndT]] = { - node.get_id(): node.get_node_def(parent_namespace) for node in self.nodes - } + self.node_defs: dict[str, NodeDef[StateT, RunEndT]] = {} + for node in nodes: + self._register_node(node, parent_namespace) self._validate_edges() + def _register_node(self, node: type[BaseNode[StateT, RunEndT]], parent_namespace: dict[str, Any] | None) -> None: + node_id = node.get_id() + if existing_node := self.node_defs.get(node_id): + raise exceptions.GraphSetupError( + f'Node ID `{node_id}` is not unique — found on {existing_node.node} and {node}' + ) + else: + self.node_defs[node_id] = node.get_node_def(parent_namespace) + def _validate_edges(self): known_node_ids = self.node_defs.keys() bad_edges: dict[str, list[str]] = {} diff --git a/pydantic_graph/pydantic_graph/mermaid.py b/pydantic_graph/pydantic_graph/mermaid.py index b21b50509d..0b498e685f 100644 --- a/pydantic_graph/pydantic_graph/mermaid.py +++ b/pydantic_graph/pydantic_graph/mermaid.py @@ -52,10 +52,7 @@ def generate_code( raise LookupError(f'Start node "{node_id}" is not in the graph.') lines = ['stateDiagram-v2'] - for node in graph.nodes: - node_id = node.get_id() - node_def = graph.node_defs[node_id] - + for node_id, node_def in graph.node_defs.items(): # we use round brackets (rounded box) for nodes other than the start and end if node_id in start_node_ids: lines.append(f' [*] --> {node_id}') diff --git a/tests/graph/test_main.py b/tests/graph/test_main.py index 43438ff7bc..7e076baf47 100644 --- a/tests/graph/test_main.py +++ b/tests/graph/test_main.py @@ -216,7 +216,7 @@ def get_id(cls) -> str: with pytest.raises(GraphSetupError) as exc_info: Graph(nodes=(Foo, Bar)) - assert exc_info.value.message == snapshot(IsStr(regex='Node ID `Foo` is not unique — found in.+')) + assert exc_info.value.message == snapshot(IsStr(regex='Node ID `Foo` is not unique — found on Date: Sat, 11 Jan 2025 00:39:03 +0000 Subject: [PATCH 34/57] infer graph name --- pydantic_graph/pydantic_graph/graph.py | 59 ++++++++++++-- pydantic_graph/pydantic_graph/mermaid.py | 12 ++- tests/graph/test_main.py | 53 +++++++++++-- tests/graph/test_mermaid.py | 99 +++++++++++++++--------- 4 files changed, 176 insertions(+), 47 deletions(-) diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 8ae518ac42..f7f45020ba 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -5,10 +5,11 @@ from dataclasses import dataclass from pathlib import Path from time import perf_counter +from types import FrameType from typing import TYPE_CHECKING, Any, Generic import logfire_api -from typing_extensions import Never, ParamSpec, TypeVar, Unpack, assert_never +from typing_extensions import Literal, Never, ParamSpec, TypeVar, Unpack, assert_never from . 
import _utils, exceptions, mermaid from ._utils import get_parent_namespace @@ -76,8 +77,15 @@ def _validate_edges(self): ) async def next( - self, state: StateT, node: BaseNode[StateT, RunEndT], history: list[HistoryStep[StateT, RunEndT]] + self, + state: StateT, + node: BaseNode[StateT, RunEndT], + history: list[HistoryStep[StateT, RunEndT]], + *, + infer_name: bool = True, ) -> BaseNode[StateT, Any] | End[RunEndT]: + if infer_name and self.name is None: + self._infer_name(inspect.currentframe()) node_id = node.get_id() if node_id not in self.node_defs: raise exceptions.GraphRuntimeError(f'Node `{node}` is not in the graph.') @@ -96,8 +104,12 @@ async def run( self, state: StateT, start_node: BaseNode[StateT, RunEndT], + *, + infer_name: bool = True, ) -> tuple[RunEndT, list[HistoryStep[StateT, RunEndT]]]: history: list[HistoryStep[StateT, RunEndT]] = [] + if infer_name and self.name is None: + self._infer_name(inspect.currentframe()) with _logfire.span( '{graph_name} run {start=}', @@ -105,7 +117,7 @@ async def run( start=start_node, ) as run_span: while True: - next_node = await self.next(state, start_node, history=history) + next_node = await self.next(state, start_node, history, infer_name=False) if isinstance(next_node, End): history.append(EndEvent(state, next_node)) run_span.set_attribute('history', history) @@ -127,19 +139,56 @@ def mermaid_code( highlighted_nodes: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None, highlight_css: str = mermaid.DEFAULT_HIGHLIGHT_CSS, edge_labels: bool = True, + title: str | None | Literal[False] = None, notes: bool = True, + infer_name: bool = True, ) -> str: + if infer_name and self.name is None: + self._infer_name(inspect.currentframe()) + if title is None and self.name: + title = self.name return mermaid.generate_code( self, start_node=start_node, highlighted_nodes=highlighted_nodes, highlight_css=highlight_css, + title=title or None, edge_labels=edge_labels, notes=notes, ) - def mermaid_image(self, **kwargs: Unpack[mermaid.MermaidConfig]) -> bytes: + def mermaid_image(self, infer_name: bool = True, **kwargs: Unpack[mermaid.MermaidConfig]) -> bytes: + if infer_name and self.name is None: + self._infer_name(inspect.currentframe()) + if 'title' not in kwargs and self.name: + kwargs['title'] = self.name return mermaid.request_image(self, **kwargs) - def mermaid_save(self, path: Path | str, /, **kwargs: Unpack[mermaid.MermaidConfig]) -> None: + def mermaid_save( + self, path: Path | str, /, *, infer_name: bool = True, **kwargs: Unpack[mermaid.MermaidConfig] + ) -> None: + if infer_name and self.name is None: + self._infer_name(inspect.currentframe()) + if 'title' not in kwargs and self.name: + kwargs['title'] = self.name mermaid.save_image(path, self, **kwargs) + + def _infer_name(self, function_frame: FrameType | None) -> None: + """Infer the agent name from the call frame. + + Usage should be `self._infer_name(inspect.currentframe())`. + + Copied from `Agent`. 
+ """ + assert self.name is None, 'Name already set' + if function_frame is not None and (parent_frame := function_frame.f_back): # pragma: no branch + for name, item in parent_frame.f_locals.items(): + if item is self: + self.name = name + return + if parent_frame.f_locals != parent_frame.f_globals: + # if we couldn't find the agent in locals and globals are a different dict, try globals + for name, item in parent_frame.f_globals.items(): + if item is self: + self.name = name + return diff --git a/pydantic_graph/pydantic_graph/mermaid.py b/pydantic_graph/pydantic_graph/mermaid.py index 0b498e685f..e89a00da49 100644 --- a/pydantic_graph/pydantic_graph/mermaid.py +++ b/pydantic_graph/pydantic_graph/mermaid.py @@ -24,13 +24,14 @@ DEFAULT_HIGHLIGHT_CSS = 'fill:#fdff32' -def generate_code( +def generate_code( # noqa: C901 graph: Graph[Any, Any], /, *, start_node: Sequence[NodeIdent] | NodeIdent | None = None, highlighted_nodes: Sequence[NodeIdent] | NodeIdent | None = None, highlight_css: str = DEFAULT_HIGHLIGHT_CSS, + title: str | None = None, edge_labels: bool = True, notes: bool = True, ) -> str: @@ -41,6 +42,7 @@ def generate_code( start_node: Identifiers of nodes that start the graph. highlighted_nodes: Identifiers of nodes to highlight. highlight_css: CSS to use for highlighting nodes. + title: The title of the diagram. edge_labels: Whether to include edge labels in the diagram. notes: Whether to include notes in the diagram. @@ -51,7 +53,10 @@ def generate_code( if node_id not in graph.node_defs: raise LookupError(f'Start node "{node_id}" is not in the graph.') - lines = ['stateDiagram-v2'] + lines: list[str] = [] + if title: + lines = ['---', f'title: {title}', '---'] + lines.append('stateDiagram-v2') for node_id, node_def in graph.node_defs.items(): # we use round brackets (rounded box) for nodes other than the start and end if node_id in start_node_ids: @@ -114,6 +119,8 @@ class MermaidConfig(TypedDict, total=False): """Identifiers of nodes to highlight.""" highlight_css: str """CSS to use for highlighting nodes.""" + title: str | None + """The title of the diagram.""" edge_labels: bool """Whether to include edge labels in the diagram.""" notes: bool @@ -168,6 +175,7 @@ def request_image( start_node=kwargs.get('start_node'), highlighted_nodes=kwargs.get('highlighted_nodes'), highlight_css=kwargs.get('highlight_css', DEFAULT_HIGHLIGHT_CSS), + title=kwargs.get('title'), edge_labels=kwargs.get('edge_labels', True), notes=kwargs.get('notes', True), ) diff --git a/tests/graph/test_main.py b/tests/graph/test_main.py index 7e076baf47..22463341d9 100644 --- a/tests/graph/test_main.py +++ b/tests/graph/test_main.py @@ -3,13 +3,23 @@ from dataclasses import dataclass from datetime import timezone from functools import cache -from typing import Union +from typing import Never, Union import pytest from dirty_equals import IsStr from inline_snapshot import snapshot -from pydantic_graph import BaseNode, End, EndEvent, Graph, GraphContext, GraphRuntimeError, GraphSetupError, NodeEvent +from pydantic_graph import ( + BaseNode, + End, + EndEvent, + Graph, + GraphContext, + GraphRuntimeError, + GraphSetupError, + HistoryStep, + NodeEvent, +) from ..conftest import IsFloat, IsNow @@ -41,10 +51,12 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq else: return End(self.input_data * 2) - g = Graph[None, int](nodes=(Float2String, String2Length, Double)) - result, history = await g.run(None, Float2String(3.14)) + my_graph = Graph[None, int](nodes=(Float2String, 
String2Length, Double)) + assert my_graph.name is None + result, history = await my_graph.run(None, Float2String(3.14)) # len('3.14') * 2 == 8 assert result == 8 + assert my_graph.name == 'my_graph' assert history == snapshot( [ NodeEvent( @@ -72,7 +84,7 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq ), ] ) - result, history = await g.run(None, Float2String(3.14159)) + result, history = await my_graph.run(None, Float2String(3.14159)) # len('3.14159') == 7, 21 * 2 == 42 assert result == 42 assert history == snapshot( @@ -258,3 +270,34 @@ async def run(self, ctx: GraphContext) -> End[None]: await g.run(None, Foo()) assert exc_info.value.message == snapshot('Invalid node return type: `int`. Expected `BaseNode` or `End`.') + + +async def test_next(): + @dataclass + class Foo(BaseNode): + async def run(self, ctx: GraphContext) -> Bar: + return Bar() + + @dataclass + class Bar(BaseNode): + async def run(self, ctx: GraphContext) -> Foo: + return Foo() + + g = Graph(nodes=(Foo, Bar)) + assert g.name is None + history: list[HistoryStep[None, Never]] = [] + n = await g.next(None, Foo(), history) + assert n == Bar() + assert g.name == 'g' + assert history == snapshot([NodeEvent(state=None, node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat())]) + + assert isinstance(n, Bar) + n2 = await g.next(None, n, history) + assert n2 == Foo() + + assert history == snapshot( + [ + NodeEvent(state=None, node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), + NodeEvent(state=None, node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), + ] + ) diff --git a/tests/graph/test_mermaid.py b/tests/graph/test_mermaid.py index 422b7ecdee..9f790fc0ca 100644 --- a/tests/graph/test_mermaid.py +++ b/tests/graph/test_mermaid.py @@ -1,5 +1,6 @@ from __future__ import annotations as _annotations +import base64 from collections.abc import Iterator from dataclasses import dataclass from datetime import timezone @@ -9,7 +10,6 @@ import httpx import pytest from inline_snapshot import snapshot -from typing_extensions import TypeAlias from pydantic_graph import BaseNode, Edge, End, EndEvent, Graph, GraphContext, GraphSetupError, NodeEvent from pydantic_graph.nodes import NodeDef @@ -17,28 +17,6 @@ from ..conftest import IsFloat, IsNow pytestmark = pytest.mark.anyio -HttpxWithHandler: TypeAlias = 'Callable[[Callable[[httpx.Request], httpx.Response] | httpx.Response], httpx.Client]' - - -@pytest.fixture -def httpx_with_handler() -> Iterator[HttpxWithHandler]: - client: httpx.Client | None = None - - def create_client(handler: Callable[[httpx.Request], httpx.Response] | httpx.Response) -> httpx.Client: - nonlocal client - assert client is None, 'client_with_handler can only be called once' - if isinstance(handler, httpx.Response): - transport = httpx.MockTransport(lambda _: handler) - else: - transport = httpx.MockTransport(handler) - client = httpx.Client(mounts={'all://': transport}) - return client - - try: - yield create_client - finally: - if client: # pragma: no cover - client.close() @dataclass @@ -100,7 +78,7 @@ async def test_run_graph(): def test_mermaid_code_no_start(): - assert graph1.mermaid_code() == snapshot("""\ + assert graph1.mermaid_code(title=False) == snapshot("""\ stateDiagram-v2 Foo --> Bar Bar --> [*]\ @@ -109,6 +87,9 @@ def test_mermaid_code_no_start(): def test_mermaid_code_start(): assert graph1.mermaid_code(start_node=Foo) == snapshot("""\ +--- +title: graph1 +--- stateDiagram-v2 [*] --> Foo Foo --> Bar @@ -124,6 +105,9 @@ def 
test_mermaid_code_start_wrong(): def test_mermaid_highlight(): code = graph1.mermaid_code(highlighted_nodes=Foo) assert code == snapshot("""\ +--- +title: graph1 +--- stateDiagram-v2 Foo --> Bar Bar --> [*] @@ -137,6 +121,9 @@ class Foo highlighted\ def test_mermaid_highlight_multiple(): code = graph1.mermaid_code(highlighted_nodes=(Foo, Bar)) assert code == snapshot("""\ +--- +title: graph1 +--- stateDiagram-v2 Foo --> Bar Bar --> [*] @@ -155,6 +142,9 @@ def test_mermaid_highlight_wrong(): def test_mermaid_code_with_edge_labels(): assert graph2.mermaid_code() == snapshot("""\ +--- +title: graph2 +--- stateDiagram-v2 Spam --> Foo: spam to foo note right of Spam @@ -168,6 +158,9 @@ def test_mermaid_code_with_edge_labels(): def test_mermaid_code_without_edge_labels(): assert graph2.mermaid_code(edge_labels=False, notes=False) == snapshot("""\ +--- +title: graph2 +--- stateDiagram-v2 Spam --> Foo Foo --> Bar @@ -187,6 +180,9 @@ async def run(self, ctx: GraphContext) -> BaseNode: def test_mermaid_code_all_nodes(): assert graph3.mermaid_code() == snapshot("""\ +--- +title: graph3 +--- stateDiagram-v2 AllNodes --> AllNodes AllNodes --> Foo @@ -196,14 +192,37 @@ def test_mermaid_code_all_nodes(): """) +@pytest.fixture +def httpx_with_handler() -> Iterator[HttpxWithHandler]: + client: httpx.Client | None = None + + def create_client(handler: Callable[[httpx.Request], httpx.Response]) -> httpx.Client: + nonlocal client + assert client is None, 'client_with_handler can only be called once' + client = httpx.Client(mounts={'all://': httpx.MockTransport(handler)}) + return client + + try: + yield create_client + finally: + if client: # pragma: no cover + client.close() + + +HttpxWithHandler = Callable[[Callable[[httpx.Request], httpx.Response]], httpx.Client] + + def test_image_jpg(httpx_with_handler: HttpxWithHandler): def get_jpg(request: httpx.Request) -> httpx.Response: assert dict(request.url.params) == snapshot({}) assert request.url.path.startswith('/img/') - return httpx.Response(200, content=b'fake jpg') + mermaid = base64.b64decode(request.url.path[5:].encode()) + return httpx.Response(200, content=mermaid) + graph1.name = None img = graph1.mermaid_image(start_node=Foo(), httpx_client=httpx_with_handler(get_jpg)) - assert img == b'fake jpg' + assert graph1.name == 'graph1' + assert img == snapshot(b'---\ntitle: graph1\n---\nstateDiagram-v2\n [*] --> Foo\n Foo --> Bar\n Bar --> [*]') def test_image_png(httpx_with_handler: HttpxWithHandler): @@ -219,10 +238,12 @@ def get_png(request: httpx.Request) -> httpx.Response: } ) assert request.url.path.startswith('/img/') - return httpx.Response(200, content=b'fake png') + mermaid = base64.b64decode(request.url.path[5:].encode()) + return httpx.Response(200, content=mermaid) img = graph1.mermaid_image( start_node=Foo(), + title=None, image_type='png', background_color='123', theme='forest', @@ -231,13 +252,15 @@ def get_png(request: httpx.Request) -> httpx.Response: scale=3, httpx_client=httpx_with_handler(get_png), ) - assert img == b'fake png' + assert img == snapshot(b'stateDiagram-v2\n [*] --> Foo\n Foo --> Bar\n Bar --> [*]') def test_image_bad(httpx_with_handler: HttpxWithHandler): - response = httpx.Response(404, content=b'not found') + def get_404(request: httpx.Request) -> httpx.Response: + return httpx.Response(404, content=b'not found') + with pytest.raises(httpx.HTTPStatusError, match='404 error generating image:\nnot found') as exc_info: - graph1.mermaid_image(start_node=Foo(), httpx_client=httpx_with_handler(response)) + 
graph1.mermaid_image(start_node=Foo(), httpx_client=httpx_with_handler(get_404)) assert exc_info.value.response.status_code == 404 assert exc_info.value.response.content == b'not found' @@ -283,22 +306,28 @@ def test_save_jpg(tmp_path: Path, httpx_with_handler: HttpxWithHandler): def get_jpg(request: httpx.Request) -> httpx.Response: assert dict(request.url.params) == snapshot({}) assert request.url.path.startswith('/img/') - return httpx.Response(200, content=b'fake jpg') + mermaid = base64.b64decode(request.url.path[5:].encode()) + return httpx.Response(200, content=mermaid) path = tmp_path / 'graph.jpg' graph1.mermaid_save(path, start_node=Foo(), httpx_client=httpx_with_handler(get_jpg)) - assert path.read_bytes() == b'fake jpg' + assert path.read_bytes() == snapshot( + b'---\ntitle: graph1\n---\nstateDiagram-v2\n [*] --> Foo\n Foo --> Bar\n Bar --> [*]' + ) def test_save_png(tmp_path: Path, httpx_with_handler: HttpxWithHandler): def get_png(request: httpx.Request) -> httpx.Response: assert dict(request.url.params) == snapshot({'type': 'png'}) assert request.url.path.startswith('/img/') - return httpx.Response(200, content=b'fake png') + mermaid = base64.b64decode(request.url.path[5:].encode()) + return httpx.Response(200, content=mermaid) path2 = tmp_path / 'graph.png' - graph1.mermaid_save(str(path2), start_node=Foo(), httpx_client=httpx_with_handler(get_png)) - assert path2.read_bytes() == b'fake png' + graph1.name = None + graph1.mermaid_save(str(path2), title=None, start_node=Foo(), httpx_client=httpx_with_handler(get_png)) + assert graph1.name == 'graph1' + assert path2.read_bytes() == snapshot(b'stateDiagram-v2\n [*] --> Foo\n Foo --> Bar\n Bar --> [*]') def test_save_pdf_known(tmp_path: Path, httpx_with_handler: HttpxWithHandler): From 3b22850b83daa8d8ea1eaa8071b4df83b7134f21 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sat, 11 Jan 2025 00:50:38 +0000 Subject: [PATCH 35/57] fix for 3.9 --- tests/graph/test_main.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/graph/test_main.py b/tests/graph/test_main.py index 22463341d9..a26e9d7663 100644 --- a/tests/graph/test_main.py +++ b/tests/graph/test_main.py @@ -3,11 +3,12 @@ from dataclasses import dataclass from datetime import timezone from functools import cache -from typing import Never, Union +from typing import Union import pytest from dirty_equals import IsStr from inline_snapshot import snapshot +from typing_extensions import Never from pydantic_graph import ( BaseNode, From 452a62f3ff760afc7ce788e607d6aa6aa9012d9a Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sat, 11 Jan 2025 12:55:44 +0000 Subject: [PATCH 36/57] adding API docs --- docs/api/pydantic_graph/graph.md | 3 + docs/api/pydantic_graph/mermaid.md | 3 + docs/api/pydantic_graph/nodes.md | 11 ++ docs/api/pydantic_graph/state.md | 3 + mkdocs.yml | 5 + pydantic_graph/pydantic_graph/__init__.py | 6 +- pydantic_graph/pydantic_graph/graph.py | 143 ++++++++++++++++------ pydantic_graph/pydantic_graph/mermaid.py | 113 +++++++++-------- pydantic_graph/pydantic_graph/nodes.py | 27 +++- pydantic_graph/pydantic_graph/state.py | 32 +++-- pydantic_graph/pyproject.toml | 3 +- tests/graph/test_main.py | 30 ++--- tests/graph/test_mermaid.py | 8 +- uv.lock | 22 ++-- 14 files changed, 273 insertions(+), 136 deletions(-) create mode 100644 docs/api/pydantic_graph/graph.md create mode 100644 docs/api/pydantic_graph/mermaid.md create mode 100644 docs/api/pydantic_graph/nodes.md create mode 100644 docs/api/pydantic_graph/state.md diff --git 
a/docs/api/pydantic_graph/graph.md b/docs/api/pydantic_graph/graph.md new file mode 100644 index 0000000000..c8e3ea7944 --- /dev/null +++ b/docs/api/pydantic_graph/graph.md @@ -0,0 +1,3 @@ +# `pydantic_graph` + +::: pydantic_graph.graph diff --git a/docs/api/pydantic_graph/mermaid.md b/docs/api/pydantic_graph/mermaid.md new file mode 100644 index 0000000000..ccc7cec5ad --- /dev/null +++ b/docs/api/pydantic_graph/mermaid.md @@ -0,0 +1,3 @@ +# `pydantic_graph.mermaid` + +::: pydantic_graph.mermaid diff --git a/docs/api/pydantic_graph/nodes.md b/docs/api/pydantic_graph/nodes.md new file mode 100644 index 0000000000..7540920a02 --- /dev/null +++ b/docs/api/pydantic_graph/nodes.md @@ -0,0 +1,11 @@ +# `pydantic_graph.nodes` + +::: pydantic_graph.nodes + options: + members: + - GraphContext + - BaseNode + - End + - Edge + - RunEndT + - NodeRunEndT diff --git a/docs/api/pydantic_graph/state.md b/docs/api/pydantic_graph/state.md new file mode 100644 index 0000000000..480eea5a54 --- /dev/null +++ b/docs/api/pydantic_graph/state.md @@ -0,0 +1,3 @@ +# `pydantic_graph.state` + +::: pydantic_graph.state diff --git a/mkdocs.yml b/mkdocs.yml index 781b1494e3..92fe21f938 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -55,6 +55,10 @@ nav: - api/models/ollama.md - api/models/test.md - api/models/function.md + - api/pydantic_graph/graph.md + - api/pydantic_graph/nodes.md + - api/pydantic_graph/state.md + - api/pydantic_graph/mermaid.md extra: # hide the "Made with Material for MkDocs" message @@ -150,6 +154,7 @@ markdown_extensions: watch: - pydantic_ai_slim + - pydantic_graph - examples plugins: diff --git a/pydantic_graph/pydantic_graph/__init__.py b/pydantic_graph/pydantic_graph/__init__.py index 192b62f98c..5102117738 100644 --- a/pydantic_graph/pydantic_graph/__init__.py +++ b/pydantic_graph/pydantic_graph/__init__.py @@ -1,7 +1,7 @@ from .exceptions import GraphRuntimeError, GraphSetupError from .graph import Graph from .nodes import BaseNode, Edge, End, GraphContext -from .state import AbstractState, EndEvent, HistoryStep, NodeEvent +from .state import AbstractState, EndStep, HistoryStep, NodeStep __all__ = ( 'Graph', @@ -10,9 +10,9 @@ 'GraphContext', 'Edge', 'AbstractState', - 'EndEvent', + 'EndStep', 'HistoryStep', - 'NodeEvent', + 'NodeStep', 'GraphSetupError', 'GraphRuntimeError', ) diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index f7f45020ba..eed901fe95 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -14,7 +14,7 @@ from . import _utils, exceptions, mermaid from ._utils import get_parent_namespace from .nodes import BaseNode, End, GraphContext, NodeDef -from .state import EndEvent, HistoryStep, NodeEvent, StateT +from .state import EndStep, HistoryStep, NodeStep, StateT __all__ = ('Graph',) @@ -36,9 +36,16 @@ def __init__( self, *, nodes: Sequence[type[BaseNode[StateT, RunEndT]]], - state_type: type[StateT] | None = None, name: str | None = None, ): + """Create a graph from a sequence of nodes. + + Args: + nodes: The nodes which make up the graph, nodes need to be unique and all be generic in the same + state type. + name: Optional name for the graph, if not provided the name will be inferred from the calling frame + on the first call to a graph method. 
+ """ self.name = name parent_namespace = get_parent_namespace(inspect.currentframe()) @@ -48,34 +55,6 @@ def __init__( self._validate_edges() - def _register_node(self, node: type[BaseNode[StateT, RunEndT]], parent_namespace: dict[str, Any] | None) -> None: - node_id = node.get_id() - if existing_node := self.node_defs.get(node_id): - raise exceptions.GraphSetupError( - f'Node ID `{node_id}` is not unique — found on {existing_node.node} and {node}' - ) - else: - self.node_defs[node_id] = node.get_node_def(parent_namespace) - - def _validate_edges(self): - known_node_ids = self.node_defs.keys() - bad_edges: dict[str, list[str]] = {} - - for node_id, node_def in self.node_defs.items(): - for edge in node_def.next_node_edges.keys(): - if edge not in known_node_ids: - bad_edges.setdefault(edge, []).append(f'`{node_id}`') - - if bad_edges: - bad_edges_list = [f'`{k}` is referenced by {_utils.comma_and(v)}' for k, v in bad_edges.items()] - if len(bad_edges_list) == 1: - raise exceptions.GraphSetupError(f'{bad_edges_list[0]} but not included in the graph.') - else: - b = '\n'.join(f' {be}' for be in bad_edges_list) - raise exceptions.GraphSetupError( - f'Nodes are referenced in the graph but not included in the graph:\n{b}' - ) - async def next( self, state: StateT, @@ -84,13 +63,24 @@ async def next( *, infer_name: bool = True, ) -> BaseNode[StateT, Any] | End[RunEndT]: + """Run a node in the graph and return the next node to run. + + Args: + state: The current state of the graph. + node: The node to run. + history: The history of the graph run so far. NOTE: this will be mutated to add the new step. + infer_name: Whether to infer the graph name from the calling frame. + + Returns: + The next node to run or [`End`][pydantic_graph.nodes.End] if the graph has finished. + """ if infer_name and self.name is None: self._infer_name(inspect.currentframe()) node_id = node.get_id() if node_id not in self.node_defs: raise exceptions.GraphRuntimeError(f'Node `{node}` is not in the graph.') - history_step: NodeEvent[StateT, RunEndT] | None = NodeEvent(state, node) + history_step: NodeStep[StateT, RunEndT] | None = NodeStep(state, node) history.append(history_step) ctx = GraphContext(state) @@ -107,6 +97,15 @@ async def run( *, infer_name: bool = True, ) -> tuple[RunEndT, list[HistoryStep[StateT, RunEndT]]]: + """Run the graph from a starting node until it ends. + + Args: + state: The initial state of the graph. + start_node: the first node to run. + infer_name: Whether to infer the graph name from the calling frame. + + Returns: The result type from ending the run and the history of the run. 
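A minimal end-to-end sketch of the API documented here, consistent with the tests earlier in the series (the `Increment`/`Stop` nodes and their values are illustrative):

```python
from __future__ import annotations

from dataclasses import dataclass

from pydantic_graph import BaseNode, End, Graph, GraphContext


@dataclass
class Increment(BaseNode):
    value: int

    async def run(self, ctx: GraphContext) -> Stop:
        # the return type hint is what defines this node's outgoing edge
        return Stop(self.value + 1)


@dataclass
class Stop(BaseNode[None, int]):
    value: int

    async def run(self, ctx: GraphContext) -> End[int]:
        return End(self.value)


my_graph = Graph(nodes=(Increment, Stop))


async def main() -> None:
    # run() follows edges from the start node until a node returns End,
    # then hands back the End data plus the list of history steps
    result, history = await my_graph.run(None, Increment(41))
    assert result == 42
    assert my_graph.name == 'my_graph'  # inferred from the calling frame
```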
+ """ history: list[HistoryStep[StateT, RunEndT]] = [] if infer_name and self.name is None: self._infer_name(inspect.currentframe()) @@ -119,7 +118,7 @@ async def run( while True: next_node = await self.next(state, start_node, history, infer_name=False) if isinstance(next_node, End): - history.append(EndEvent(state, next_node)) + history.append(EndStep(state, next_node)) run_span.set_attribute('history', history) return next_node.data, history elif isinstance(next_node, BaseNode): @@ -136,13 +135,29 @@ def mermaid_code( self, *, start_node: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None, - highlighted_nodes: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None, - highlight_css: str = mermaid.DEFAULT_HIGHLIGHT_CSS, - edge_labels: bool = True, title: str | None | Literal[False] = None, + edge_labels: bool = True, notes: bool = True, + highlighted_nodes: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None, + highlight_css: str = mermaid.DEFAULT_HIGHLIGHT_CSS, infer_name: bool = True, ) -> str: + """Generate a diagram representing the graph as [mermaid](https://mermaid.js.org/) chart. + + This method calls [`pydantic_graph.mermaid.generate_code`][pydantic_graph.mermaid.generate_code]. + + Args: + start_node: The node or nodes to start the graph from. + title: The title of the diagram, use `False` to not include a title. + edge_labels: Whether to include edge labels. + notes: Whether to include notes on each node. + highlighted_nodes: Optional node or nodes to highlight. + highlight_css: The CSS to use for highlighting nodes. + infer_name: Whether to infer the graph name from the calling frame. + + Returns: + The mermaid code for the graph, which can then be rendered as a diagram. + """ if infer_name and self.name is None: self._infer_name(inspect.currentframe()) if title is None and self.name: @@ -158,6 +173,22 @@ def mermaid_code( ) def mermaid_image(self, infer_name: bool = True, **kwargs: Unpack[mermaid.MermaidConfig]) -> bytes: + """Generate a diagram representing the graph as an image. + + The format and diagram can be customized using `kwargs`, + see [`pydantic_graph.mermaid.MermaidConfig`][pydantic_graph.mermaid.MermaidConfig]. + + !!! note "Uses external service" + This method makes a request to [mermaid.ink](https://mermaid.ink) to render the image, `mermaid.ink` + is a free service not affiliated with Pydantic. + + Args: + infer_name: Whether to infer the graph name from the calling frame. + **kwargs: Additional arguments to pass to `mermaid.request_image`. + + Returns: + The image bytes. + """ if infer_name and self.name is None: self._infer_name(inspect.currentframe()) if 'title' not in kwargs and self.name: @@ -167,12 +198,54 @@ def mermaid_image(self, infer_name: bool = True, **kwargs: Unpack[mermaid.Mermai def mermaid_save( self, path: Path | str, /, *, infer_name: bool = True, **kwargs: Unpack[mermaid.MermaidConfig] ) -> None: + """Generate a diagram representing the graph and save it as an image. + + The format and diagram can be customized using `kwargs`, + see [`pydantic_graph.mermaid.MermaidConfig`][pydantic_graph.mermaid.MermaidConfig]. + + !!! note "Uses external service" + This method makes a request to [mermaid.ink](https://mermaid.ink) to render the image, `mermaid.ink` + is a free service not affiliated with Pydantic. + + Args: + path: The path to save the image to. + infer_name: Whether to infer the graph name from the calling frame. + **kwargs: Additional arguments to pass to `mermaid.save_image`. 
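+
+        For example, a minimal sketch — assuming a `Graph` instance called `my_graph` defined elsewhere,
+        and that `mermaid.ink` is reachable from the machine running the code:
+
+            my_graph.mermaid_save('my_graph.png', image_type='png')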
+ """ if infer_name and self.name is None: self._infer_name(inspect.currentframe()) if 'title' not in kwargs and self.name: kwargs['title'] = self.name mermaid.save_image(path, self, **kwargs) + def _register_node(self, node: type[BaseNode[StateT, RunEndT]], parent_namespace: dict[str, Any] | None) -> None: + node_id = node.get_id() + if existing_node := self.node_defs.get(node_id): + raise exceptions.GraphSetupError( + f'Node ID `{node_id}` is not unique — found on {existing_node.node} and {node}' + ) + else: + self.node_defs[node_id] = node.get_node_def(parent_namespace) + + def _validate_edges(self): + known_node_ids = self.node_defs.keys() + bad_edges: dict[str, list[str]] = {} + + for node_id, node_def in self.node_defs.items(): + for edge in node_def.next_node_edges.keys(): + if edge not in known_node_ids: + bad_edges.setdefault(edge, []).append(f'`{node_id}`') + + if bad_edges: + bad_edges_list = [f'`{k}` is referenced by {_utils.comma_and(v)}' for k, v in bad_edges.items()] + if len(bad_edges_list) == 1: + raise exceptions.GraphSetupError(f'{bad_edges_list[0]} but not included in the graph.') + else: + b = '\n'.join(f' {be}' for be in bad_edges_list) + raise exceptions.GraphSetupError( + f'Nodes are referenced in the graph but not included in the graph:\n{b}' + ) + def _infer_name(self, function_frame: FrameType | None) -> None: """Infer the agent name from the call frame. diff --git a/pydantic_graph/pydantic_graph/mermaid.py b/pydantic_graph/pydantic_graph/mermaid.py index e89a00da49..1e72493759 100644 --- a/pydantic_graph/pydantic_graph/mermaid.py +++ b/pydantic_graph/pydantic_graph/mermaid.py @@ -18,10 +18,8 @@ __all__ = 'NodeIdent', 'DEFAULT_HIGHLIGHT_CSS', 'generate_code', 'MermaidConfig', 'request_image', 'save_image' - - -NodeIdent: TypeAlias = 'type[BaseNode[Any, Any]] | BaseNode[Any, Any] | str' DEFAULT_HIGHLIGHT_CSS = 'fill:#fdff32' +"""The default CSS to use for highlighting nodes.""" def generate_code( # noqa: C901 @@ -46,7 +44,8 @@ def generate_code( # noqa: C901 edge_labels: Whether to include edge labels in the diagram. notes: Whether to include notes in the diagram. - Returns: The Mermaid code for the graph. + Returns: + The Mermaid code for the graph. """ start_node_ids = set(_node_ids(start_node or ())) for node_id in start_node_ids: @@ -110,53 +109,6 @@ def _node_ids(node_idents: Sequence[NodeIdent] | NodeIdent) -> Iterable[str]: yield node.get_id() -class MermaidConfig(TypedDict, total=False): - """Parameters to configure mermaid chart generation.""" - - start_node: Sequence[NodeIdent] | NodeIdent - """Identifiers of nodes that start the graph.""" - highlighted_nodes: Sequence[NodeIdent] | NodeIdent - """Identifiers of nodes to highlight.""" - highlight_css: str - """CSS to use for highlighting nodes.""" - title: str | None - """The title of the diagram.""" - edge_labels: bool - """Whether to include edge labels in the diagram.""" - notes: bool - """Whether to include notes on nodes in the diagram, defaults to true.""" - image_type: Literal['jpeg', 'png', 'webp', 'svg', 'pdf'] - """The image type to generate. If unspecified, the default behavior is `'jpeg'`.""" - pdf_fit: bool - """When using image_type='pdf', whether to fit the diagram to the PDF page.""" - pdf_landscape: bool - """When using image_type='pdf', whether to use landscape orientation for the PDF. - - This has no effect if using `pdf_fit`. 
- """ - pdf_paper: Literal['letter', 'legal', 'tabloid', 'ledger', 'a0', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6'] - """When using image_type='pdf', the paper size of the PDF.""" - background_color: str - """The background color of the diagram. - - If None, the default transparent background is used. The color value is interpreted as a hexadecimal color - code by default (and should not have a leading '#'), but you can also use named colors by prefixing the - value with `'!'`. For example, valid choices include `background_color='!white'` or `background_color='FF0000'`. - """ - theme: Literal['default', 'neutral', 'dark', 'forest'] - """The theme of the diagram. Defaults to 'default'.""" - width: int - """The width of the diagram.""" - height: int - """The height of the diagram.""" - scale: Annotated[float, Ge(1), Le(3)] - """The scale of the diagram. - - The scale must be a number between 1 and 3, and you can only set a scale if one or both of width and height are set. - """ - httpx_client: httpx.Client - - def request_image( graph: Graph[Any, Any], /, @@ -244,3 +196,62 @@ def save_image( image_data = request_image(graph, **kwargs) path.write_bytes(image_data) + + +class MermaidConfig(TypedDict, total=False): + """Parameters to configure mermaid chart generation.""" + + start_node: Sequence[NodeIdent] | NodeIdent + """Identifiers of nodes that start the graph.""" + highlighted_nodes: Sequence[NodeIdent] | NodeIdent + """Identifiers of nodes to highlight.""" + highlight_css: str + """CSS to use for highlighting nodes.""" + title: str | None + """The title of the diagram.""" + edge_labels: bool + """Whether to include edge labels in the diagram.""" + notes: bool + """Whether to include notes on nodes in the diagram, defaults to true.""" + image_type: Literal['jpeg', 'png', 'webp', 'svg', 'pdf'] + """The image type to generate. If unspecified, the default behavior is `'jpeg'`.""" + pdf_fit: bool + """When using image_type='pdf', whether to fit the diagram to the PDF page.""" + pdf_landscape: bool + """When using image_type='pdf', whether to use landscape orientation for the PDF. + + This has no effect if using `pdf_fit`. + """ + pdf_paper: Literal['letter', 'legal', 'tabloid', 'ledger', 'a0', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6'] + """When using image_type='pdf', the paper size of the PDF.""" + background_color: str + """The background color of the diagram. + + If None, the default transparent background is used. The color value is interpreted as a hexadecimal color + code by default (and should not have a leading '#'), but you can also use named colors by prefixing the + value with `'!'`. For example, valid choices include `background_color='!white'` or `background_color='FF0000'`. + """ + theme: Literal['default', 'neutral', 'dark', 'forest'] + """The theme of the diagram. Defaults to 'default'.""" + width: int + """The width of the diagram.""" + height: int + """The height of the diagram.""" + scale: Annotated[float, Ge(1), Le(3)] + """The scale of the diagram. + + The scale must be a number between 1 and 3, and you can only set a scale if one or both of width and height are set. + """ + httpx_client: httpx.Client + """An HTTPX client to use for requests, mostly for testing purposes.""" + + +NodeIdent: TypeAlias = 'type[BaseNode[Any, Any]] | BaseNode[Any, Any] | str' +"""A type alias for a node identifier. + +This can be: + +- A node instance (instance of a subclass of [`BaseNode`][pydantic_graph.nodes.BaseNode]). +- A node class (subclass of [`BaseNode`][pydantic_graph.nodes.BaseNode]). 
+- A string representing the node ID.
+"""
diff --git a/pydantic_graph/pydantic_graph/nodes.py b/pydantic_graph/pydantic_graph/nodes.py
index 5fd5755881..1881b10b44 100644
--- a/pydantic_graph/pydantic_graph/nodes.py
+++ b/pydantic_graph/pydantic_graph/nodes.py
@@ -10,10 +10,12 @@
 from . import _utils, exceptions
 from .state import StateT
 
-__all__ = 'GraphContext', 'BaseNode', 'End', 'Edge', 'NodeDef'
+__all__ = 'GraphContext', 'BaseNode', 'End', 'Edge', 'NodeDef', 'RunEndT', 'NodeRunEndT'
 
 RunEndT = TypeVar('RunEndT', default=None)
+"""Type variable for the return type of a graph [`run`][pydantic_graph.graph.Graph.run]."""
 NodeRunEndT = TypeVar('NodeRunEndT', covariant=True, default=Never)
+"""Type variable for the return type of a node [`run`][pydantic_graph.nodes.BaseNode.run]."""
 
 
 @dataclass
@@ -21,6 +23,7 @@ class GraphContext(Generic[StateT]):
     """Context for a graph."""
 
     state: StateT
+    """The state of the graph."""
 
 
 class BaseNode(ABC, Generic[StateT, NodeRunEndT]):
@@ -30,7 +33,23 @@ class BaseNode(ABC, Generic[StateT, NodeRunEndT]):
     """Set to `False` to not generate mermaid diagram notes from the class's docstring."""
 
     @abstractmethod
-    async def run(self, ctx: GraphContext[StateT]) -> BaseNode[StateT, Any] | End[NodeRunEndT]: ...
+    async def run(self, ctx: GraphContext[StateT]) -> BaseNode[StateT, Any] | End[NodeRunEndT]:
+        """Run the node.
+
+        This is an abstract method that must be implemented by subclasses.
+
+        !!! note "Return types used at runtime"
+            The return type of this method is read by `pydantic_graph` at runtime and used to define which
+            nodes can be called next in the graph. This is displayed in [mermaid diagrams](mermaid.md)
+            and enforced when running the graph.
+
+        Args:
+            ctx: The graph context.
+
+        Returns:
+            The next node to run or [`End`][pydantic_graph.nodes.End] to signal the end of the graph.
+        """
+        ...
 
     @classmethod
     @cache
@@ -92,6 +111,7 @@ class End(Generic[RunEndT]):
     """Type to return from a node to signal the end of the graph."""
 
     data: RunEndT
+    """Data to return from the graph."""
 
 
 @dataclass
@@ -99,12 +119,15 @@ class Edge:
     """Annotation to apply a label to an edge in a graph."""
 
     label: str | None
+    """Label for the edge."""
 
 
 @dataclass
 class NodeDef(Generic[StateT, NodeRunEndT]):
     """Definition of a node.
 
+    This is an internal representation of a node; it shouldn't be necessary to use it directly.
+
     Used by [`Graph`][pydantic_graph.graph.Graph] to store information about a node,
     and when generating mermaid graphs.
     """
diff --git a/pydantic_graph/pydantic_graph/state.py b/pydantic_graph/pydantic_graph/state.py
index b8dbff7a72..c01e23f1b4 100644
--- a/pydantic_graph/pydantic_graph/state.py
+++ b/pydantic_graph/pydantic_graph/state.py
@@ -6,15 +6,16 @@
 from datetime import datetime
 from typing import TYPE_CHECKING, Generic, Literal, Union
 
-from typing_extensions import Never, Self, TypeVar
+from typing_extensions import Self, TypeVar
 
 from .
import _utils -__all__ = 'AbstractState', 'StateT', 'NodeEvent', 'EndEvent', 'HistoryStep' +__all__ = 'AbstractState', 'StateT', 'NodeStep', 'EndStep', 'HistoryStep' if TYPE_CHECKING: - from pydantic_graph import BaseNode - from pydantic_graph.nodes import End + from .nodes import BaseNode, End, RunEndT +else: + RunEndT = TypeVar('RunEndT', default=None) class AbstractState(ABC): @@ -30,21 +31,24 @@ def deep_copy(self) -> Self: return copy.deepcopy(self) -RunEndT = TypeVar('RunEndT', default=None) -NodeRunEndT = TypeVar('NodeRunEndT', covariant=True, default=Never) StateT = TypeVar('StateT', bound=Union[None, AbstractState], default=None) +"""Type variable for the state in a graph.""" @dataclass -class NodeEvent(Generic[StateT, RunEndT]): +class NodeStep(Generic[StateT, RunEndT]): """History step describing the execution of a node in a graph.""" state: StateT + """The state of the graph after the node has run.""" node: BaseNode[StateT, RunEndT] + """The node that was run.""" start_ts: datetime = field(default_factory=_utils.now_utc) + """The timestamp when the node started running.""" duration: float | None = None - - kind: Literal['step'] = 'step' + """The duration of the node run in seconds.""" + kind: Literal['node'] = 'node' + """The kind of history step, can be used as a discriminator when deserializing history.""" def __post_init__(self): # Copy the state to prevent it from being modified by other code @@ -55,14 +59,17 @@ def summary(self) -> str: @dataclass -class EndEvent(Generic[StateT, RunEndT]): +class EndStep(Generic[StateT, RunEndT]): """History step describing the end of a graph run.""" state: StateT + """The state of the graph after the run.""" result: End[RunEndT] + """The result of the graph run.""" ts: datetime = field(default_factory=_utils.now_utc) - + """The timestamp when the graph run ended.""" kind: Literal['end'] = 'end' + """The kind of history step, can be used as a discriminator when deserializing history.""" def __post_init__(self): # Copy the state to prevent it from being modified by other code @@ -79,4 +86,5 @@ def _deep_copy_state(state: StateT) -> StateT: return state.deep_copy() -HistoryStep = Union[NodeEvent[StateT, RunEndT], EndEvent[StateT, RunEndT]] +HistoryStep = Union[NodeStep[StateT, RunEndT], EndStep[StateT, RunEndT]] +"""A step in the history of a graph run.""" diff --git a/pydantic_graph/pyproject.toml b/pydantic_graph/pyproject.toml index 17aaa1fba7..89cbad516c 100644 --- a/pydantic_graph/pyproject.toml +++ b/pydantic_graph/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "pydantic-graph" -version = "0.0.17" +version = "0.0.18" description = "Graph and state machine library" authors = [ { name = "Samuel Colvin", email = "samuel@pydantic.dev" }, @@ -36,7 +36,6 @@ requires-python = ">=3.9" dependencies = [ "httpx>=0.27.2", "logfire-api>=1.2.0", - "pydantic>=2.10", ] [tool.hatch.build.targets.wheel] diff --git a/tests/graph/test_main.py b/tests/graph/test_main.py index a26e9d7663..f82e92e5e6 100644 --- a/tests/graph/test_main.py +++ b/tests/graph/test_main.py @@ -13,13 +13,13 @@ from pydantic_graph import ( BaseNode, End, - EndEvent, + EndStep, Graph, GraphContext, GraphRuntimeError, GraphSetupError, HistoryStep, - NodeEvent, + NodeStep, ) from ..conftest import IsFloat, IsNow @@ -60,25 +60,25 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq assert my_graph.name == 'my_graph' assert history == snapshot( [ - NodeEvent( + NodeStep( state=None, node=Float2String(input_data=3.14), 
start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - NodeEvent( + NodeStep( state=None, node=String2Length(input_data='3.14'), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - NodeEvent( + NodeStep( state=None, node=Double(input_data=4), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - EndEvent( + EndStep( state=None, result=End(data=8), ts=IsNow(tz=timezone.utc), @@ -90,37 +90,37 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq assert result == 42 assert history == snapshot( [ - NodeEvent( + NodeStep( state=None, node=Float2String(input_data=3.14159), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - NodeEvent( + NodeStep( state=None, node=String2Length(input_data='3.14159'), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - NodeEvent( + NodeStep( state=None, node=Double(input_data=7), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - NodeEvent( + NodeStep( state=None, node=String2Length(input_data='xxxxxxxxxxxxxxxxxxxxx'), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - NodeEvent( + NodeStep( state=None, node=Double(input_data=21), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - EndEvent( + EndStep( state=None, result=End(data=42), ts=IsNow(tz=timezone.utc), @@ -290,7 +290,7 @@ async def run(self, ctx: GraphContext) -> Foo: n = await g.next(None, Foo(), history) assert n == Bar() assert g.name == 'g' - assert history == snapshot([NodeEvent(state=None, node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat())]) + assert history == snapshot([NodeStep(state=None, node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat())]) assert isinstance(n, Bar) n2 = await g.next(None, n, history) @@ -298,7 +298,7 @@ async def run(self, ctx: GraphContext) -> Foo: assert history == snapshot( [ - NodeEvent(state=None, node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), - NodeEvent(state=None, node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), + NodeStep(state=None, node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), + NodeStep(state=None, node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), ] ) diff --git a/tests/graph/test_mermaid.py b/tests/graph/test_mermaid.py index 9f790fc0ca..1f40df5c8c 100644 --- a/tests/graph/test_mermaid.py +++ b/tests/graph/test_mermaid.py @@ -11,7 +11,7 @@ import pytest from inline_snapshot import snapshot -from pydantic_graph import BaseNode, Edge, End, EndEvent, Graph, GraphContext, GraphSetupError, NodeEvent +from pydantic_graph import BaseNode, Edge, End, EndStep, Graph, GraphContext, GraphSetupError, NodeStep from pydantic_graph.nodes import NodeDef from ..conftest import IsFloat, IsNow @@ -60,19 +60,19 @@ async def test_run_graph(): assert result is None assert history == snapshot( [ - NodeEvent( + NodeStep( state=None, node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - NodeEvent( + NodeStep( state=None, node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - EndEvent(state=None, result=End(data=None), ts=IsNow(tz=timezone.utc)), + EndStep(state=None, result=End(data=None), ts=IsNow(tz=timezone.utc)), ] ) diff --git a/uv.lock b/uv.lock index 498d850d4c..89e2fb1456 100644 --- a/uv.lock +++ b/uv.lock @@ -2719,30 +2719,28 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fd/ab/718d9a1c41bb8d3e0e04d15b68b8afc135f8fcf552705b62f226225065c7/pydantic_core-2.27.1-pp39-pypy39_pp73-win_amd64.whl", hash = 
"sha256:62ba45e21cf6571d7f716d903b5b7b6d2617e2d5d67c0923dc47b9d41369f840", size = 2002035 }, ] -[[package]] -name = "pydub" -version = "0.25.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/fe/9a/e6bca0eed82db26562c73b5076539a4a08d3cffd19c3cc5913a3e61145fd/pydub-0.25.1.tar.gz", hash = "sha256:980a33ce9949cab2a569606b65674d748ecbca4f0796887fd6f46173a7b0d30f", size = 38326 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a6/53/d78dc063216e62fc55f6b2eebb447f6a4b0a59f55c8406376f76bf959b08/pydub-0.25.1-py2.py3-none-any.whl", hash = "sha256:65617e33033874b59d87db603aa1ed450633288aefead953b30bded59cb599a6", size = 32327 }, -] - [[package]] name = "pydantic-graph" -version = "0.0.17" +version = "0.0.18" source = { editable = "pydantic_graph" } dependencies = [ { name = "httpx" }, { name = "logfire-api" }, - { name = "pydantic" }, ] [package.metadata] requires-dist = [ { name = "httpx", specifier = ">=0.27.2" }, { name = "logfire-api", specifier = ">=1.2.0" }, - { name = "pydantic", specifier = ">=2.10" }, +] + +[[package]] +name = "pydub" +version = "0.25.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/fe/9a/e6bca0eed82db26562c73b5076539a4a08d3cffd19c3cc5913a3e61145fd/pydub-0.25.1.tar.gz", hash = "sha256:980a33ce9949cab2a569606b65674d748ecbca4f0796887fd6f46173a7b0d30f", size = 38326 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a6/53/d78dc063216e62fc55f6b2eebb447f6a4b0a59f55c8406376f76bf959b08/pydub-0.25.1-py2.py3-none-any.whl", hash = "sha256:65617e33033874b59d87db603aa1ed450633288aefead953b30bded59cb599a6", size = 32327 }, ] [[package]] From 707129ffde7227b73da904034909db0d09c8aafb Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sat, 11 Jan 2025 14:49:22 +0000 Subject: [PATCH 37/57] fix state, more docs --- .../email_extract_graph.py | 156 ------------------ pydantic_graph/pydantic_graph/__init__.py | 3 +- pydantic_graph/pydantic_graph/graph.py | 92 ++++++++++- pydantic_graph/pydantic_graph/state.py | 51 +++--- tests/graph/{test_main.py => test_graph.py} | 0 tests/graph/test_state.py | 58 +++++++ tests/test_examples.py | 2 +- 7 files changed, 167 insertions(+), 195 deletions(-) delete mode 100644 examples/pydantic_ai_examples/email_extract_graph.py rename tests/graph/{test_main.py => test_graph.py} (100%) create mode 100644 tests/graph/test_state.py diff --git a/examples/pydantic_ai_examples/email_extract_graph.py b/examples/pydantic_ai_examples/email_extract_graph.py deleted file mode 100644 index 82c98a370e..0000000000 --- a/examples/pydantic_ai_examples/email_extract_graph.py +++ /dev/null @@ -1,156 +0,0 @@ -from __future__ import annotations as _annotations - -import asyncio -from dataclasses import dataclass -from datetime import datetime, timedelta - -import logfire -from devtools import debug -from pydantic import BaseModel -from pydantic_graph import AbstractState, BaseNode, End, Graph, GraphContext - -from pydantic_ai import Agent, RunContext - -logfire.configure(send_to_logfire='if-token-present') - - -class EventDetails(BaseModel): - title: str - location: str - start_ts: datetime - end_ts: datetime - - -class State(AbstractState, BaseModel): - email_content: str - skip_events: list[str] = [] - attempt: int = 0 - - def serialize(self) -> bytes | None: - return self.model_dump_json(exclude={'email_content'}).encode() - - -class RawEventDetails(BaseModel): - title: str - location: str - start_ts: str - duration: str - - 
-extract_agent = Agent('openai:gpt-4o', result_type=RawEventDetails, deps_type=list[str]) - - -@extract_agent.system_prompt -def extract_system_prompt(ctx: RunContext[list[str]]): - prompt = 'Extract event details from the email body.' - if ctx.deps: - skip_events = '\n'.join(ctx.deps) - prompt += f'\n\nDo not return the following events:\n{skip_events}' - return prompt - - -@dataclass -class ExtractEvent(BaseNode[State]): - async def run(self, ctx: GraphContext[State]) -> CleanEvent: - event = await extract_agent.run( - ctx.state.email_content, deps=ctx.state.skip_events - ) - return CleanEvent(event.data) - - -# agent used to extract the timestamp from the string in `CleanEvent` -timestamp_agent = Agent('openai:gpt-4o', result_type=datetime) - - -@timestamp_agent.system_prompt -def timestamp_system_prompt(): - return f'Extract the timestamp from the string, the current timestamp is: {datetime.now().isoformat()}' - - -# agent used to extract the duration from the string in `CleanEvent` -duration_agent = Agent( - 'openai:gpt-4o', - result_type=timedelta, - system_prompt='Extract the duration from the string as an ISO 8601 interval.', -) - - -@dataclass -class CleanEvent(BaseNode[State]): - input_data: RawEventDetails - - async def run(self, ctx: GraphContext[State]) -> InspectEvent: - start_ts, duration = await asyncio.gather( - timestamp_agent.run(self.input_data.start_ts), - duration_agent.run(self.input_data.duration), - ) - return InspectEvent( - EventDetails( - title=self.input_data.title, - location=self.input_data.location, - start_ts=start_ts.data, - end_ts=start_ts.data + duration.data, - ) - ) - - -@dataclass -class InspectEvent(BaseNode[State, EventDetails | None]): - input_data: EventDetails - - async def run( - self, ctx: GraphContext[State] - ) -> ExtractEvent | End[EventDetails | None]: - now = datetime.now() - if self.input_data.start_ts.tzinfo is not None: - now = now.astimezone(self.input_data.start_ts.tzinfo) - - if self.input_data.start_ts > now: - return End(self.input_data) - ctx.state.attempt += 1 - if ctx.state.attempt > 2: - return End(None) - else: - ctx.state.skip_events.append(self.input_data.title) - return ExtractEvent() - - -graph = Graph[State, EventDetails | None]( - nodes=( - ExtractEvent, - CleanEvent, - InspectEvent, - ) -) -print(graph.mermaid_code(start_node=ExtractEvent)) - -email = """ -Hi Samuel, - -I hope this message finds you well! I wanted to share a quick update on our recent and upcoming team events. - -Firstly, a big thank you to everyone who participated in last month's -Team Building Retreat held on November 15th 2024 for 1 day. -It was a fantastic opportunity to enhance our collaboration and communication skills while having fun. Your -feedback was incredibly positive, and we're already planning to make the next retreat even better! - -Looking ahead, I'm excited to invite you all to our Annual Year-End Gala on January 20th 2025. -This event will be held at the Grand City Ballroom starting at 6 PM until 8pm. It promises to be an evening full -of entertainment, good food, and great company, celebrating the achievements and hard work of our amazing team -over the past year. - -Please mark your calendars and RSVP by January 10th. I hope to see all of you there! 
- -Best regards, -""" - - -async def main(): - state = State(email_content=email) - history = [] - result = await graph.run(state, ExtractEvent()) - debug(result, history) - - -if __name__ == '__main__': - asyncio.run(main()) diff --git a/pydantic_graph/pydantic_graph/__init__.py b/pydantic_graph/pydantic_graph/__init__.py index 5102117738..135f7529a5 100644 --- a/pydantic_graph/pydantic_graph/__init__.py +++ b/pydantic_graph/pydantic_graph/__init__.py @@ -1,7 +1,7 @@ from .exceptions import GraphRuntimeError, GraphSetupError from .graph import Graph from .nodes import BaseNode, Edge, End, GraphContext -from .state import AbstractState, EndStep, HistoryStep, NodeStep +from .state import EndStep, HistoryStep, NodeStep __all__ = ( 'Graph', @@ -9,7 +9,6 @@ 'End', 'GraphContext', 'Edge', - 'AbstractState', 'EndStep', 'HistoryStep', 'NodeStep', diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index eed901fe95..db586f9525 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -6,7 +6,7 @@ from pathlib import Path from time import perf_counter from types import FrameType -from typing import TYPE_CHECKING, Any, Generic +from typing import TYPE_CHECKING, Any, Callable, Generic import logfire_api from typing_extensions import Literal, Never, ParamSpec, TypeVar, Unpack, assert_never @@ -14,7 +14,7 @@ from . import _utils, exceptions, mermaid from ._utils import get_parent_namespace from .nodes import BaseNode, End, GraphContext, NodeDef -from .state import EndStep, HistoryStep, NodeStep, StateT +from .state import EndStep, HistoryStep, NodeStep, StateT, deep_copy_state __all__ = ('Graph',) @@ -27,16 +27,74 @@ @dataclass(init=False) class Graph(Generic[StateT, RunEndT]): - """Definition of a graph.""" + """Definition of a graph. + + In `pydantic-graph`, a graph is a collection of nodes that can be run in sequence. The nodes define + their outgoing edges — e.g. which nodes may be run next, and thereby the structure of the graph. + + Here's a very simple example of a graph which increments a number by 1, but makes sure the number is never + 42 at the end. Note in this example we don't run the graph, but instead just generate a mermaid diagram + using [`mermaid_code`][pydantic_graph.graph.Graph.mermaid_code]: + + ```py {title="never_42.py"} + from __future__ import annotations + + from dataclasses import dataclass + from pydantic_graph import BaseNode, End, Graph, GraphContext + + + @dataclass + class MyState: + number: int + + + @dataclass + class Increment(BaseNode[MyState]): + async def run(self, ctx: GraphContext) -> Check42: + ctx.state.number += 1 + return Check42() + + + @dataclass + class Check42(BaseNode[MyState]): + async def run(self, ctx: GraphContext) -> Increment | End: + if ctx.state.number == 42: + return Increment() + else: + return End(None) + + + never_42_graph = Graph(nodes=(Increment, Check42)) + print(never_42_graph.mermaid_code(start_node=Increment)) + ``` + _(This example is complete, it can be run "as is")_ + + The rendered mermaid diagram will look like this: + + ```mermaid + --- + title: never_42_graph + --- + stateDiagram-v2 + [*] --> Increment + Increment --> Check42 + Check42 --> Increment + Check42 --> [*] + ``` + + See [`run`][pydantic_graph.graph.Graph.run] For an example of running graph. 
+ """ name: str | None node_defs: dict[str, NodeDef[StateT, RunEndT]] + snapshot_state: Callable[[StateT], StateT] def __init__( self, *, nodes: Sequence[type[BaseNode[StateT, RunEndT]]], name: str | None = None, + snapshot_state: Callable[[StateT], StateT] = deep_copy_state, ): """Create a graph from a sequence of nodes. @@ -45,8 +103,12 @@ def __init__( state type. name: Optional name for the graph, if not provided the name will be inferred from the calling frame on the first call to a graph method. + snapshot_state: A function to snapshot the state of the graph, this is used in + [`NodeStep`][pydantic_graph.state.NodeStep] and [`EndStep`][pydantic_graph.state.EndStep] to record + the state before each step. """ self.name = name + self.snapshot_state = snapshot_state parent_namespace = get_parent_namespace(inspect.currentframe()) self.node_defs: dict[str, NodeDef[StateT, RunEndT]] = {} @@ -80,7 +142,7 @@ async def next( if node_id not in self.node_defs: raise exceptions.GraphRuntimeError(f'Node `{node}` is not in the graph.') - history_step: NodeStep[StateT, RunEndT] | None = NodeStep(state, node) + history_step: NodeStep[StateT, RunEndT] = NodeStep(state, node) history.append(history_step) ctx = GraphContext(state) @@ -101,10 +163,28 @@ async def run( Args: state: The initial state of the graph. - start_node: the first node to run. + start_node: the first node to run, since the graph definition doesn't define the entry point in the graph, + you need to provide the starting node. infer_name: Whether to infer the graph name from the calling frame. Returns: The result type from ending the run and the history of the run. + + Here's an example of running the graph from [above][pydantic_graph.graph.Graph]: + + ```py {title="run_never_42.py"} + from never_42 import MyState, Increment, never_42_graph + + async def main(): + state = MyState(1) + _, history = await never_42_graph.run(state, Increment()) + print(state) + print(history) + + state = MyState(41) + _, history = await never_42_graph.run(state, Increment()) + print(state) + print(history) + ``` """ history: list[HistoryStep[StateT, RunEndT]] = [] if infer_name and self.name is None: @@ -118,7 +198,7 @@ async def run( while True: next_node = await self.next(state, start_node, history, infer_name=False) if isinstance(next_node, End): - history.append(EndStep(state, next_node)) + history.append(EndStep(state, next_node, snapshot_state=self.snapshot_state)) run_span.set_attribute('history', history) return next_node.data, history elif isinstance(next_node, BaseNode): diff --git a/pydantic_graph/pydantic_graph/state.py b/pydantic_graph/pydantic_graph/state.py index c01e23f1b4..9cd197a637 100644 --- a/pydantic_graph/pydantic_graph/state.py +++ b/pydantic_graph/pydantic_graph/state.py @@ -1,16 +1,15 @@ from __future__ import annotations as _annotations import copy -from abc import ABC, abstractmethod -from dataclasses import dataclass, field +from dataclasses import InitVar, dataclass, field from datetime import datetime -from typing import TYPE_CHECKING, Generic, Literal, Union +from typing import TYPE_CHECKING, Callable, Generic, Literal, Union -from typing_extensions import Self, TypeVar +from typing_extensions import TypeVar from . 
import _utils -__all__ = 'AbstractState', 'StateT', 'NodeStep', 'EndStep', 'HistoryStep' +__all__ = 'StateT', 'NodeStep', 'EndStep', 'HistoryStep', 'deep_copy_state' if TYPE_CHECKING: from .nodes import BaseNode, End, RunEndT @@ -18,21 +17,16 @@ RunEndT = TypeVar('RunEndT', default=None) -class AbstractState(ABC): - """Abstract class for a state object.""" - - @abstractmethod - def serialize(self) -> bytes | None: - """Serialize the state object.""" - raise NotImplementedError - - def deep_copy(self) -> Self: - """Create a deep copy of the state object.""" - return copy.deepcopy(self) +StateT = TypeVar('StateT', default=None) +"""Type variable for the state in a graph.""" -StateT = TypeVar('StateT', bound=Union[None, AbstractState], default=None) -"""Type variable for the state in a graph.""" +def deep_copy_state(state: StateT) -> StateT: + """Default method for snapshotting the state in a graph run, uses [`copy.deepcopy`][copy.deepcopy].""" + if state is None: + return state + else: + return copy.deepcopy(state) @dataclass @@ -40,7 +34,7 @@ class NodeStep(Generic[StateT, RunEndT]): """History step describing the execution of a node in a graph.""" state: StateT - """The state of the graph after the node has run.""" + """The state of the graph before the node is run.""" node: BaseNode[StateT, RunEndT] """The node that was run.""" start_ts: datetime = field(default_factory=_utils.now_utc) @@ -49,10 +43,12 @@ class NodeStep(Generic[StateT, RunEndT]): """The duration of the node run in seconds.""" kind: Literal['node'] = 'node' """The kind of history step, can be used as a discriminator when deserializing history.""" + snapshot_state: InitVar[Callable[[StateT], StateT]] = deep_copy_state + """Function to snapshot the state of the graph.""" - def __post_init__(self): + def __post_init__(self, snapshot_state: Callable[[StateT], StateT]): # Copy the state to prevent it from being modified by other code - self.state = _deep_copy_state(self.state) + self.state = snapshot_state(self.state) def summary(self) -> str: return str(self.node) @@ -70,21 +66,16 @@ class EndStep(Generic[StateT, RunEndT]): """The timestamp when the graph run ended.""" kind: Literal['end'] = 'end' """The kind of history step, can be used as a discriminator when deserializing history.""" + snapshot_state: InitVar[Callable[[StateT], StateT]] = deep_copy_state + """Function to snapshot the state of the graph.""" - def __post_init__(self): + def __post_init__(self, snapshot_state: Callable[[StateT], StateT]): # Copy the state to prevent it from being modified by other code - self.state = _deep_copy_state(self.state) + self.state = snapshot_state(self.state) def summary(self) -> str: return str(self.result) -def _deep_copy_state(state: StateT) -> StateT: - if state is None: - return state - else: - return state.deep_copy() - - HistoryStep = Union[NodeStep[StateT, RunEndT], EndStep[StateT, RunEndT]] """A step in the history of a graph run.""" diff --git a/tests/graph/test_main.py b/tests/graph/test_graph.py similarity index 100% rename from tests/graph/test_main.py rename to tests/graph/test_graph.py diff --git a/tests/graph/test_state.py b/tests/graph/test_state.py new file mode 100644 index 0000000000..b363e37674 --- /dev/null +++ b/tests/graph/test_state.py @@ -0,0 +1,58 @@ +from __future__ import annotations as _annotations + +from dataclasses import dataclass +from datetime import timezone + +import pytest +from inline_snapshot import snapshot + +from pydantic_graph import BaseNode, End, EndStep, Graph, GraphContext, NodeStep + 
+from ..conftest import IsFloat, IsNow + +pytestmark = pytest.mark.anyio + + +async def test_run_graph(): + @dataclass + class MyState: + x: int + y: str + + @dataclass + class Foo(BaseNode[MyState]): + async def run(self, ctx: GraphContext[MyState]) -> Bar: + ctx.state.x += 1 + return Bar() + + @dataclass + class Bar(BaseNode[MyState, str]): + async def run(self, ctx: GraphContext[MyState]) -> End[str]: + ctx.state.y += 'y' + return End(f'x={ctx.state.x} y={ctx.state.y}') + + graph = Graph(nodes=(Foo, Bar)) + s = MyState(1, '') + result, history = await graph.run(s, Foo()) + assert result == snapshot('x=2 y=y') + assert history == snapshot( + [ + NodeStep( + state=MyState(x=1, y=''), + node=Foo(), + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(), + ), + NodeStep( + state=MyState(x=2, y=''), + node=Bar(), + start_ts=IsNow(tz=timezone.utc), + duration=IsFloat(), + ), + EndStep( + state=MyState(x=2, y='y'), + result=End('x=2 y=y'), + ts=IsNow(tz=timezone.utc), + ), + ] + ) diff --git a/tests/test_examples.py b/tests/test_examples.py index eb12c18cd0..ebe448148e 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -48,7 +48,7 @@ def find_filter_examples() -> Iterable[CodeExample]: - for ex in find_examples('docs', 'pydantic_ai_slim'): + for ex in find_examples('docs', 'pydantic_ai_slim', 'pydantic_graph'): if ex.path.name != '_utils.py': yield ex From 80d4713b54ee8ae7d9f921a80eb5e83bcda18d78 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sat, 11 Jan 2025 15:13:29 +0000 Subject: [PATCH 38/57] fix graph api examples --- pydantic_graph/pydantic_graph/graph.py | 125 ++++++++++++++++++++----- tests/test_examples.py | 1 + 2 files changed, 103 insertions(+), 23 deletions(-) diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index db586f9525..d2ff4a155f 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -33,28 +33,25 @@ class Graph(Generic[StateT, RunEndT]): their outgoing edges — e.g. which nodes may be run next, and thereby the structure of the graph. Here's a very simple example of a graph which increments a number by 1, but makes sure the number is never - 42 at the end. Note in this example we don't run the graph, but instead just generate a mermaid diagram - using [`mermaid_code`][pydantic_graph.graph.Graph.mermaid_code]: + 42 at the end. ```py {title="never_42.py"} from __future__ import annotations from dataclasses import dataclass - from pydantic_graph import BaseNode, End, Graph, GraphContext + from pydantic_graph import BaseNode, End, Graph, GraphContext @dataclass class MyState: number: int - @dataclass class Increment(BaseNode[MyState]): async def run(self, ctx: GraphContext) -> Check42: ctx.state.number += 1 return Check42() - @dataclass class Check42(BaseNode[MyState]): async def run(self, ctx: GraphContext) -> Increment | End: @@ -63,26 +60,13 @@ async def run(self, ctx: GraphContext) -> Increment | End: else: return End(None) - never_42_graph = Graph(nodes=(Increment, Check42)) - print(never_42_graph.mermaid_code(start_node=Increment)) ``` _(This example is complete, it can be run "as is")_ - The rendered mermaid diagram will look like this: - - ```mermaid - --- - title: never_42_graph - --- - stateDiagram-v2 - [*] --> Increment - Increment --> Check42 - Check42 --> Increment - Check42 --> [*] - ``` - - See [`run`][pydantic_graph.graph.Graph.run] For an example of running graph. 
+ See [`run`][pydantic_graph.graph.Graph.run] For an example of running graph, and + [`mermaid_code`][pydantic_graph.graph.Graph.mermaid_code] for an example of generating a mermaid diagram + from the graph. """ name: str | None @@ -172,18 +156,82 @@ async def run( Here's an example of running the graph from [above][pydantic_graph.graph.Graph]: ```py {title="run_never_42.py"} - from never_42 import MyState, Increment, never_42_graph + from never_42 import Increment, MyState, never_42_graph async def main(): state = MyState(1) _, history = await never_42_graph.run(state, Increment()) print(state) + #> MyState(number=2) print(history) + ''' + [ + NodeStep( + state=MyState(number=1), + node=Increment(), + start_ts=datetime.datetime(...), + duration=0.0..., + kind='node', + ), + NodeStep( + state=MyState(number=2), + node=Check42(), + start_ts=datetime.datetime(...), + duration=0.0..., + kind='node', + ), + EndStep( + state=MyState(number=2), + result=End(data=None), + ts=datetime.datetime(...), + kind='end', + ), + ] + ''' state = MyState(41) _, history = await never_42_graph.run(state, Increment()) print(state) + #> MyState(number=43) print(history) + ''' + [ + NodeStep( + state=MyState(number=41), + node=Increment(), + start_ts=datetime.datetime(...), + duration=0.0..., + kind='node', + ), + NodeStep( + state=MyState(number=42), + node=Check42(), + start_ts=datetime.datetime(...), + duration=0.0..., + kind='node', + ), + NodeStep( + state=MyState(number=42), + node=Increment(), + start_ts=datetime.datetime(...), + duration=0.0..., + kind='node', + ), + NodeStep( + state=MyState(number=43), + node=Check42(), + start_ts=datetime.datetime(...), + duration=0.0..., + kind='node', + ), + EndStep( + state=MyState(number=43), + result=End(data=None), + ts=datetime.datetime(...), + kind='end', + ), + ] + ''' ``` """ history: list[HistoryStep[StateT, RunEndT]] = [] @@ -227,7 +275,7 @@ def mermaid_code( This method calls [`pydantic_graph.mermaid.generate_code`][pydantic_graph.mermaid.generate_code]. Args: - start_node: The node or nodes to start the graph from. + start_node: The node or nodes which can start the graph. title: The title of the diagram, use `False` to not include a title. edge_labels: Whether to include edge labels. notes: Whether to include notes on each node. @@ -237,6 +285,37 @@ def mermaid_code( Returns: The mermaid code for the graph, which can then be rendered as a diagram. 
+ + Here's an example of generating a diagram for the graph from [above][pydantic_graph.graph.Graph]: + + ```py {title="never_42.py"} + from never_42 import Increment, never_42_graph + + print(never_42_graph.mermaid_code(start_node=Increment)) + ''' + --- + title: never_42_graph + --- + stateDiagram-v2 + [*] --> Increment + Increment --> Check42 + Check42 --> Increment + Check42 --> [*] + ''' + ``` + + The rendered diagram will look like this: + + ```mermaid + --- + title: never_42_graph + --- + stateDiagram-v2 + [*] --> Increment + Increment --> Check42 + Check42 --> Increment + Check42 --> [*] + ``` """ if infer_name and self.name is None: self._infer_name(inspect.currentframe()) diff --git a/tests/test_examples.py b/tests/test_examples.py index ebe448148e..f4ae2d35b5 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -133,6 +133,7 @@ def test_docs_examples( def print_callback(s: str) -> str: s = re.sub(r'datetime\.datetime\(.+?\)', 'datetime.datetime(...)', s, flags=re.DOTALL) + s = re.sub(r'\d\.\d{4,}e-0\d', '0.0...', s) return re.sub(r'datetime.date\(', 'date(', s) From be1563df3fdf75d58c9df4c96f1cb169d9ad8408 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sat, 11 Jan 2025 16:36:52 +0000 Subject: [PATCH 39/57] starting graph documentation --- docs/examples/question-graph.md | 37 +++++ docs/graph.md | 45 ++++++ docs/install.md | 10 ++ docs/multi-agent-applications.md | 7 +- .../pydantic_ai_examples/question_graph.py | 149 ++++++++++++++++++ mkdocs.yml | 2 + pydantic_ai_slim/pyproject.toml | 5 +- pydantic_graph/README.md | 2 +- 8 files changed, 248 insertions(+), 9 deletions(-) create mode 100644 docs/examples/question-graph.md create mode 100644 docs/graph.md create mode 100644 examples/pydantic_ai_examples/question_graph.py diff --git a/docs/examples/question-graph.md b/docs/examples/question-graph.md new file mode 100644 index 0000000000..e208135286 --- /dev/null +++ b/docs/examples/question-graph.md @@ -0,0 +1,37 @@ +# Question Graph + +Example of a graph for asking and evaluating questions. + +Demonstrates: + +* [`pydantic_graph`](../graph.md) + +## Running the Example + +With [dependencies installed and environment variables set](./index.md#usage), run: + +```bash +python/uv-run -m pydantic_ai_examples.question_graph +``` + +## Example Code + +```python {title="question_graph.py"} +#! examples/pydantic_ai_examples/question_graph.py +``` + +The mermaid diagram generated in this example looks like this: + +```mermaid +--- +title: question_graph +--- +stateDiagram-v2 + [*] --> Ask + Ask --> Answer: ask the question + Answer --> Evaluate: answer the question + Evaluate --> Congratulate + Evaluate --> Castigate + Congratulate --> [*]: success + Castigate --> Ask: try again +``` diff --git a/docs/graph.md b/docs/graph.md new file mode 100644 index 0000000000..e1e01bc421 --- /dev/null +++ b/docs/graph.md @@ -0,0 +1,45 @@ +# Graphs + +!!! danger "Don't use a nail gun unless you need a nail gun" + If PydanticAI [agents](agents.md) are a hammer, and [multi-agent workflows](multi-agent-applications.md) are a sledgehammer, then graphs are a nail gun, with flames down the side: + + * sure, nail guns look cooler than hammers + * but nail guns take a lot more setup than hammers + * and nail guns don't make you a better builder, they make you a builder with a nail gun + * Lastly, (and at the risk of torturing this metaphor), if you're a fan of medieval tools like mallets and untyped Python, you probably won't like nail guns or PydanticAI approach to graphs. 
(But then again, if you're not a fan of type hints in Python, you've probably already bounced off PydanticAI to use one of the toy agent frameworks — good luck, and feel free to borrow my sledgehammer) + + In short, graphs are a powerful tool, but they're not the right tool for every job. Please consider other [multi-agent approaches](multi-agent-applications.md) before proceeding. Unless you're sure you need a graph, you probably don't. + +Graphs and associated finite state machines (FSMs) are a powerful abstraction to model, control and visualize complex workflows. + +Alongside PydanticAI, we've developed `pydantic-graph` — an async graph and state machine library for Python where nodes and edges are defined in pure Python using type hints. + +While this library is developed as part of the PydanticAI; it has no dependency on `pydantic-ai` and can be considered as a pure graph library. You may find it useful whether or not you're using PydanticAI or even building with GenAI. + +## Installation + +`pydantic-graph` a required dependency of `pydantic-ai`, and an optional dependency of `pydantic-ai-slim`, see [installation instructions](install.md) for more information. You can also install it directly: + +```bash +pip/uv-add pydantic-graph +``` + +## Basic Usage + +TODO + +## Typing + +TODO + +## Running Graphs + +TODO + +## State Machines + +TODO + +## Mermaid Diagrams + +TODO diff --git a/docs/install.md b/docs/install.md index 1cd336611e..c450d8d885 100644 --- a/docs/install.md +++ b/docs/install.md @@ -44,6 +44,16 @@ For example, if you're using just [`OpenAIModel`][pydantic_ai.models.openai.Open pip/uv-add 'pydantic-ai-slim[openai]' ``` +`pydantic-ai-slim` has the following optional groups: + +* `logfire` — installs [`logfire`](logfire.md) [PyPI ↗](https://pypi.org/project/logfire){:target="_blank"} +* `graph` - installs [`pydantic-graph`](graph.md) [PyPI ↗](https://pypi.org/project/pydantic-graph){:target="_blank"} +* `openai` — installs `openai` [PyPI ↗](https://pypi.org/project/openai){:target="_blank"} +* `vertexai` — installs `google-auth` [PyPI ↗](https://pypi.org/project/google-auth){:target="_blank"} and `requests` [PyPI ↗](https://pypi.org/project/requests){:target="_blank"} +* `anthropic` — installs `anthropic` [PyPI ↗](https://pypi.org/project/anthropic){:target="_blank"} +* `groq` — installs `groq` [PyPI ↗](https://pypi.org/project/groq){:target="_blank"} +* `mistral` — installs `mistralai` [PyPI ↗](https://pypi.org/project/mistralai){:target="_blank"} + See the [models](models.md) documentation for information on which optional dependencies are required for each model. You can also install dependencies for multiple models and use cases, for example: diff --git a/docs/multi-agent-applications.md b/docs/multi-agent-applications.md index f2fb508150..1506883141 100644 --- a/docs/multi-agent-applications.md +++ b/docs/multi-agent-applications.md @@ -5,7 +5,7 @@ There are roughly four levels of complexity when building applications with Pyda 1. Single agent workflows — what most of the `pydantic_ai` documentation covers 2. [Agent delegation](#agent-delegation) — agents using another agent via tools 3. [Programmatic agent hand-off](#programmatic-agent-hand-off) — one agent runs, then application code calls another agent -4. [Graph based control flow](#pydanticai-graphs) — for the most complex cases, a graph-based state machine can be used to control the execution of multiple agents +4. 
[Graph based control flow](graph.md) — for the most complex cases, a graph-based state machine can be used to control the execution of multiple agents Of course, you can combine multiple strategies in a single application. @@ -330,11 +330,6 @@ graph TB seat_preference_agent --> END ``` -## PydanticAI Graphs - -!!! example "Work in progress" - This is a work in progress and not yet documented, see [#528](https://github.com/pydantic/pydantic-ai/issues/528) and [#539](https://github.com/pydantic/pydantic-ai/issues/539) - ## Examples The following examples demonstrate how to use dependencies in PydanticAI: diff --git a/examples/pydantic_ai_examples/question_graph.py b/examples/pydantic_ai_examples/question_graph.py new file mode 100644 index 0000000000..5f2c42bca9 --- /dev/null +++ b/examples/pydantic_ai_examples/question_graph.py @@ -0,0 +1,149 @@ +"""Example of a graph for asking and evaluating questions. + +Run with: + + uv run -m pydantic_ai_examples.question_graph +""" + +from __future__ import annotations as _annotations + +from dataclasses import dataclass, field +from typing import Annotated + +import logfire +from pydantic_graph import BaseNode, Edge, End, Graph, GraphContext, HistoryStep + +from pydantic_ai import Agent +from pydantic_ai.format_as_xml import format_as_xml +from pydantic_ai.messages import ModelMessage + +# 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured +logfire.configure(send_to_logfire='if-token-present') + +ask_agent = Agent('openai:gpt-4o', result_type=str) + + +@dataclass +class QuestionState: + ask_agent_messages: list[ModelMessage] = field(default_factory=list) + evaluate_agent_messages: list[ModelMessage] = field(default_factory=list) + + +@dataclass +class Ask(BaseNode[QuestionState]): + """Generate a question to ask the user. + + Uses the GPT-4o model to generate the question. + """ + + async def run( + self, ctx: GraphContext[QuestionState] + ) -> Annotated[Answer, Edge(label='ask the question')]: + result = await ask_agent.run( + 'Ask a simple question with a single correct answer.', + message_history=ctx.state.ask_agent_messages, + ) + ctx.state.ask_agent_messages += result.all_messages() + return Answer(result.data) + + +@dataclass +class Answer(BaseNode[QuestionState]): + """Get the answer to the question from the user. + + This node must be completed outside the graph run. 
+ """ + + question: str + answer: str | None = None + + async def run( + self, ctx: GraphContext[QuestionState] + ) -> Annotated[Evaluate, Edge(label='answer the question')]: + assert self.answer is not None + return Evaluate(self.question, self.answer) + + +@dataclass +class EvaluationResult: + correct: bool + comment: str + + +evaluate_agent = Agent( + 'openai:gpt-4o', + result_type=EvaluationResult, + system_prompt='Given a question and answer, evaluate if the answer is correct.', + result_tool_name='evaluation', +) + + +@dataclass +class Evaluate(BaseNode[QuestionState]): + question: str + answer: str + + async def run( + self, + ctx: GraphContext[QuestionState], + ) -> Congratulate | Castigate: + result = await evaluate_agent.run( + format_as_xml({'question': self.question, 'answer': self.answer}), + message_history=ctx.state.evaluate_agent_messages, + ) + ctx.state.evaluate_agent_messages += result.all_messages() + if result.data.correct: + return Congratulate(result.data.comment) + else: + return Castigate(result.data.comment) + + +@dataclass +class Congratulate(BaseNode[QuestionState, None]): + """Congratulate the user and end.""" + + comment: str + + async def run( + self, ctx: GraphContext[QuestionState] + ) -> Annotated[End, Edge(label='success')]: + print(f'Correct answer! {self.comment}') + return End(None) + + +@dataclass +class Castigate(BaseNode[QuestionState]): + """Castigate the user, then ask another question.""" + + comment: str + + async def run( + self, ctx: GraphContext[QuestionState] + ) -> Annotated[Ask, Edge(label='try again')]: + print(f'Comment: {self.comment}') + return Ask() + + +question_graph = Graph(nodes=(Ask, Answer, Evaluate, Congratulate, Castigate)) +print(question_graph.mermaid_code(start_node=Ask, notes=False)) + + +async def main(): + state = QuestionState() + node = Ask() + history: list[HistoryStep[QuestionState]] = [] + with logfire.span('run questions graph'): + while True: + node = await question_graph.next(state, node, history) + if isinstance(node, End): + print('\n'.join(e.summary() for e in history)) + break + elif isinstance(node, Answer): + node.answer = input(f'{node.question} ') + # otherwise just continue + + +if __name__ == '__main__': + import asyncio + + asyncio.run(main()) diff --git a/mkdocs.yml b/mkdocs.yml index 92fe21f938..3ff7b55503 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -25,6 +25,7 @@ nav: - testing-evals.md - logfire.md - multi-agent-applications.md + - graph.md - Examples: - examples/index.md - examples/pydantic-model.md @@ -36,6 +37,7 @@ nav: - examples/stream-markdown.md - examples/stream-whales.md - examples/chat-app.md + - examples/question-graph.md - API Reference: - api/agent.md - api/tools.md diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index 912ad47479..8cc686ce53 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -42,13 +42,14 @@ dependencies = [ ] [project.optional-dependencies] -graph = ["pydantic-graph==0.0.14"] +# WARNING if you add optional groups, please update docs/install.md +logfire = ["logfire>=2.3"] +graph = ["pydantic-graph==0.0.18"] openai = ["openai>=1.54.3"] vertexai = ["google-auth>=2.36.0", "requests>=2.32.3"] anthropic = ["anthropic>=0.40.0"] groq = ["groq>=0.12.0"] mistral = ["mistralai>=1.2.5"] -logfire = ["logfire>=2.3"] [dependency-groups] dev = [ diff --git a/pydantic_graph/README.md b/pydantic_graph/README.md index 2789d18ac1..069f58924d 100644 --- a/pydantic_graph/README.md +++ b/pydantic_graph/README.md @@ -9,7 +9,7 
@@ Graph and finite state machine library. This library is developed as part of the [PydanticAI](https://ai.pydantic.dev), however it has no dependency -on `pydantic-ai` or related packages and does and can be considered as a pure graph library. +on `pydantic-ai` or related packages and can be considered as a pure graph library. As with PydanticAI, this library prioritizes type safety and use of common Python syntax over esoteric, domain-specific use of Python syntax. From 6fdd1e91a28c4ca27056682ac37339b1bbc472dd Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sat, 11 Jan 2025 16:45:54 +0000 Subject: [PATCH 40/57] fix examples --- pydantic_graph/pydantic_graph/graph.py | 4 ++-- tests/test_examples.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index d2ff4a155f..789ba4a546 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -35,7 +35,7 @@ class Graph(Generic[StateT, RunEndT]): Here's a very simple example of a graph which increments a number by 1, but makes sure the number is never 42 at the end. - ```py {title="never_42.py"} + ```py {title="never_42.py" lint="not-imports"} from __future__ import annotations from dataclasses import dataclass @@ -155,7 +155,7 @@ async def run( Here's an example of running the graph from [above][pydantic_graph.graph.Graph]: - ```py {title="run_never_42.py"} + ```py {title="run_never_42.py" lint="not-imports"} from never_42 import Increment, MyState, never_42_graph async def main(): diff --git a/tests/test_examples.py b/tests/test_examples.py index f4ae2d35b5..f52248a52a 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -95,7 +95,7 @@ def test_docs_examples( with (tmp_path / 'examples.json').open('w') as f: json.dump(examples, f) - ruff_ignore: list[str] = ['D'] + ruff_ignore: list[str] = ['D', 'Q001'] # `from bank_database import DatabaseConn` wrongly sorted in imports # waiting for https://github.com/pydantic/pytest-examples/issues/43 # and https://github.com/pydantic/pytest-examples/issues/46 From a882a6c1a1167a2053b0250304008011fc47dd53 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sat, 11 Jan 2025 21:49:47 +0000 Subject: [PATCH 41/57] more graph documentation --- docs/.hooks/main.py | 17 +++ docs/graph.md | 197 ++++++++++++++++++++++++- pydantic_graph/pydantic_graph/graph.py | 10 +- pydantic_graph/pydantic_graph/state.py | 6 +- tests/test_examples.py | 9 +- 5 files changed, 221 insertions(+), 18 deletions(-) diff --git a/docs/.hooks/main.py b/docs/.hooks/main.py index 6b339ad6a8..606e8cb616 100644 --- a/docs/.hooks/main.py +++ b/docs/.hooks/main.py @@ -19,11 +19,28 @@ def on_page_markdown(markdown: str, page: Page, config: Config, files: Files) -> return markdown +# path to the main mkdocs material bundle file, found during `on_env` +bundle_path: Path | None = None + + def on_env(env: Environment, config: Config, files: Files) -> Environment: + global bundle_path + for file in files: + if re.match('assets/javascripts/bundle.[a-z0-9]+.min.js', file.src_uri): + bundle_path = Path(file.dest_dir) / file.src_uri + env.globals['build_timestamp'] = str(int(time.time())) return env +def on_post_build(config: Config) -> None: + """Inject extra CSS into mermaid styles to avoid titles being the same color as the background in dark mode.""" + if bundle_path.exists(): + content = bundle_path.read_text() + content, _ = re.subn(r'}(\.statediagram)', '}.statediagramTitleText{fill:#888}\1', 
content, count=1) + bundle_path.write_text(content) + + def replace_uv_python_run(markdown: str) -> str: return re.sub(r'```bash\n(.*?)(python/uv[\- ]run|pip/uv[\- ]add|py-cli)(.+?)\n```', sub_run, markdown) diff --git a/docs/graph.md b/docs/graph.md index e1e01bc421..0765d5e265 100644 --- a/docs/graph.md +++ b/docs/graph.md @@ -6,29 +6,210 @@ * sure, nail guns look cooler than hammers * but nail guns take a lot more setup than hammers * and nail guns don't make you a better builder, they make you a builder with a nail gun - * Lastly, (and at the risk of torturing this metaphor), if you're a fan of medieval tools like mallets and untyped Python, you probably won't like nail guns or PydanticAI approach to graphs. (But then again, if you're not a fan of type hints in Python, you've probably already bounced off PydanticAI to use one of the toy agent frameworks — good luck, and feel free to borrow my sledgehammer) + * Lastly, (and at the risk of torturing this metaphor), if you're a fan of medieval tools like mallets and untyped Python, you probably won't like nail guns or our approach to graphs. (But then again, if you're not a fan of type hints in Python, you've probably already bounced off PydanticAI to use one of the toy agent frameworks — good luck, and feel free to borrow my sledgehammer when you realize you need it) - In short, graphs are a powerful tool, but they're not the right tool for every job. Please consider other [multi-agent approaches](multi-agent-applications.md) before proceeding. Unless you're sure you need a graph, you probably don't. + In short, graphs are a powerful tool, but they're not the right tool for every job. Please consider other [multi-agent approaches](multi-agent-applications.md) before proceeding. -Graphs and associated finite state machines (FSMs) are a powerful abstraction to model, control and visualize complex workflows. + Unless you're sure you need a graph, you probably don't. -Alongside PydanticAI, we've developed `pydantic-graph` — an async graph and state machine library for Python where nodes and edges are defined in pure Python using type hints. +Graphs and finite state machines (FSMs) are a powerful abstraction to model, control and visualize complex workflows. + +Alongside PydanticAI, we've developed `pydantic-graph` — an async graph and state machine library for Python where nodes and edges are defined using type hints. While this library is developed as part of the PydanticAI; it has no dependency on `pydantic-ai` and can be considered as a pure graph library. You may find it useful whether or not you're using PydanticAI or even building with GenAI. +`pydantic-graph` is designed for advanced users and makes heavy use of Python generics and types hints. It is not designed to b as beginner-friendly as PydanticAI. + ## Installation -`pydantic-graph` a required dependency of `pydantic-ai`, and an optional dependency of `pydantic-ai-slim`, see [installation instructions](install.md) for more information. You can also install it directly: +`pydantic-graph` is a required dependency of `pydantic-ai`, and an optional dependency of `pydantic-ai-slim`, see [installation instructions](install.md#slim-install) for more information. You can also install it directly: ```bash pip/uv-add pydantic-graph ``` +## Graph Types + +!!! note "Every Early beta" + Graph support was [introduced](https://github.com/pydantic/pydantic-ai/pull/528) in v0.0.19 and is in very earlier beta. The API is subject to change. The documentation is incomplete. 
The implementation is incomplete. + +Graphs are made up of a few key components: + +### Nodes + +Subclasses of [`BaseNode`][pydantic_graph.nodes.BaseNode] to define nodes. + +Nodes which are generally [`dataclass`es][dataclasses.dataclass] include: + +* any parameters required when calling the node +* the business logic to execute the node +* return annotations which are read by `pydantic-graph` to determine the outgoing edges of the node + +Nodes are generic in: + +* **state** which must have the same type as the state of graphs they're included in, [`StateT`][pydantic_graph.state.StateT] has a default of `None`, so if you're not using state you can omit this generic parameter +* **graph return type** this only applies if the node returns [`End`][pydantic_graph.nodes.End], [`RunEndT`][pydantic_graph.nodes.RunEndT] has a default of [Never][typing.Never] so this generic parameter can be omitted if the node doesn't return `End`, but must be included if it does. + +### GraphContext + +[`GraphContext`][pydantic_graph.nodes.GraphContext] — The context for the graph run, similar to PydanticAI's [`RunContext`][pydantic_ai.tools.RunContext], this holds the state of the graph and is passed to nodes when they're run. + +`GraphContext` is generic in the state type of the graph it's used in, [`StateT`][pydantic_graph.state.StateT]. + +### End + +[`End`][pydantic_graph.nodes.End] — return value to indicates the graph run should end. + +`End` is generic in the graph return type of the graph it's used in, [`RunEndT`][pydantic_graph.nodes.RunEndT]. + +### Graph + +[`Graph`][pydantic_graph.graph.Graph] — The graph itself, made up of a set of nodes. + +`Graph` is generic in: + +* **state** the state type of the graph, [`StateT`][pydantic_graph.state.StateT] +* **graph return type** the return type of the graph run, [`RunEndT`][pydantic_graph.nodes.RunEndT] + ## Basic Usage -TODO +Here's an example of a graph which represents a vending machine where the user may insert coins and select a product to purchase. + +```python {title="vending_machine.py"} +from __future__ import annotations + +from dataclasses import dataclass + +from rich.prompt import Prompt + +from pydantic_graph import BaseNode, End, Graph, GraphContext + + +@dataclass +class MachineState: # (1)! + user_balance: float = 0.0 + product: str | None = None + + +@dataclass +class InsertCoin(BaseNode[MachineState]): # (3)! + async def run(self, ctx: GraphContext[MachineState]) -> CoinsInserted: # (16)! + return CoinsInserted(float(Prompt.ask('Insert coins'))) # (4)! + + +@dataclass +class CoinsInserted(BaseNode[MachineState]): + amount: float # (5)! + + async def run( + self, ctx: GraphContext[MachineState] + ) -> SelectProduct | Purchase: # (17)! + ctx.state.user_balance += self.amount # (6)! + if ctx.state.product is not None: # (7)! + return Purchase(ctx.state.product) + else: + return SelectProduct() + + +@dataclass +class SelectProduct(BaseNode[MachineState]): + async def run(self, ctx: GraphContext[MachineState]) -> Purchase: + return Purchase(Prompt.ask('Select product')) + + +PRODUCT_PRICES = { # (2)! + 'water': 1.25, + 'soda': 1.50, + 'crisps': 1.75, + 'chocolate': 2.00, +} + + +@dataclass +class Purchase(BaseNode[MachineState, None]): # (18)! + product: str + + async def run( + self, ctx: GraphContext[MachineState] + ) -> End | InsertCoin | SelectProduct: + if price := PRODUCT_PRICES.get(self.product): # (8)! + ctx.state.product = self.product # (9)! + if ctx.state.user_balance >= price: # (10)! 
+ ctx.state.user_balance -= price + return End(None) + else: + diff = price - ctx.state.user_balance + print(f'Not enough money for {self.product}, need {diff:0.2f} more') + #> Not enough money for crisps, need 0.75 more + return InsertCoin() # (11)! + else: + print(f'No such product: {self.product}, try again') + return SelectProduct() # (12)! + + +vending_machine_graph = Graph( # (13)! + nodes=[InsertCoin, CoinsInserted, SelectProduct, Purchase] +) + + +async def main(): + state = MachineState() # (14)! + await vending_machine_graph.run(state, InsertCoin()) # (15)! + print(f'purchase successful item={state.product} change={state.user_balance:0.2f}') + #> purchase successful item=crisps change=0.25 +``` + +1. The state of the vending machine is defined as a dataclass with the user's balance and the product they've selected, if any. +2. A dictionary of products mapped to prices. +3. The `InsertCoin` node, [`BaseNode`][pydantic_graph.nodes.BaseNode] is parameterized with `MachineState` as that's the state used in this graph. +4. The `InsertCoin` node prompts the user to insert coins. Keep things simple by just entering a monetary amount as a float. Before you start thinking this is a toy too since it's using `input` within node, see [below](#running-graphs) for how control flow can be managed when nodes require external input. +5. The `CoinsInserted` node, again this is a [`dataclass`][dataclasses.dataclass], in this case with one field `amount`, thus nodes calling `CoinsInserted` must provide an amount. +6. Update the user's balance with the amount inserted. +7. If the user has already selected a product, go to `Purchase`, otherwise go to `SelectProduct`. +8. In the `Purchase` node, look up the price of the product if the user entered a valid product. +9. If the user did enter a valid product, set the product in the state so we don't revisit `SelectProduct`. +10. If the balance is enough to purchase the product, adjust the balance to reflect the purchase and return [`End`][pydantic_graph.nodes.End] to end the graph. We're not using the run return type, so we call `End` with `None`. +11. If the balance is insufficient, to go `InsertCoin` to prompt the user to insert more coins. +12. If the product is invalid, go to `SelectProduct` to prompt the user to select a product again. +13. The graph is created by passing a list of nodes to [`Graph`][pydantic_graph.graph.Graph]. Order of nodes is not important, but will alter how [diagramss](#mermaid-diagrams) are displayed. +14. Initialize the state, this will be passed to the graph run and mutated as the graph runs. +15. Run the graph with the initial state, since the graph can be run from any node, we must pass the start node, in this case `InsertCoin`. [`Graph.run`][pydantic_graph.graph.Graph.run] returns a tuple of the return value (`None`) in this case, and the [history][pydantic_graph.state.HistoryStep] of the graph run. +16. The return type of the node's [`run`][pydantic_graph.nodes.BaseNode.run] method is important, it's used to determine the outgoing edges of the node, this in turn is used to render [mermaid diagrams](#mermaid-diagrams) and is enforced at runtime. +17. The return type of `CoinsInserted`s [`run`][pydantic_graph.nodes.BaseNode.run] method is a union, meaning multiple outgoing edges are possible. +18. Unlike other nodes `Purchase` can end the run, so the [`RunEndT`][pydantic_graph.nodes.RunEndT] generic parameter must be set, in this case it's `None` since the graph run return type is `None`. 
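+
+As an aside on note 15 above, here's a minimal sketch (not part of the original example) of how the `(return value, history)` tuple returned by [`Graph.run`][pydantic_graph.graph.Graph.run] might be inspected. It assumes the `vending_machine.py` module above is importable; the snippet and its printed values are illustrative only, not a tested example:
+
+```python
+import asyncio
+
+from vending_machine import InsertCoin, MachineState, vending_machine_graph
+
+
+async def main():
+    state = MachineState()
+    # `Graph.run` returns the run's return value and the history of steps taken
+    result, history = await vending_machine_graph.run(state, InsertCoin())
+    print(result)  # this graph's run return type is `None`
+    print(len(history))  # one step per node run, plus the final `End` step
+
+
+asyncio.run(main())
+```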
+ +_(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)_ + +A [mermaid diagram](#mermaid-diagrams) for this graph can be generated with the following code: + +```py {title="vending_machine_diagram.py"} +from vending_machine import InsertCoin, vending_machine_graph + +vending_machine_graph.mermaid_code(start_node=InsertCoin) +``` + +_(This example is complete, it can be run "as is")_ + +The diagram generated by the above code is: + +```mermaid +--- +title: vending_machine_graph +--- +stateDiagram-v2 + [*] --> InsertCoin + InsertCoin --> CoinsInserted + CoinsInserted --> SelectProduct + CoinsInserted --> Purchase + SelectProduct --> Purchase + Purchase --> InsertCoin + Purchase --> SelectProduct + Purchase --> [*] +``` + +See [below](#mermaid-diagrams) for more information on generating diagrams. -## Typing +## GenAI Example TODO @@ -36,7 +217,7 @@ TODO TODO -## State Machines +## State Machine TODO diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 789ba4a546..469905da2e 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -9,21 +9,17 @@ from typing import TYPE_CHECKING, Any, Callable, Generic import logfire_api -from typing_extensions import Literal, Never, ParamSpec, TypeVar, Unpack, assert_never +from typing_extensions import Literal, Unpack, assert_never from . import _utils, exceptions, mermaid from ._utils import get_parent_namespace -from .nodes import BaseNode, End, GraphContext, NodeDef +from .nodes import BaseNode, End, GraphContext, NodeDef, RunEndT from .state import EndStep, HistoryStep, NodeStep, StateT, deep_copy_state __all__ = ('Graph',) _logfire = logfire_api.Logfire(otel_scope='pydantic-graph') -RunSignatureT = ParamSpec('RunSignatureT') -RunEndT = TypeVar('RunEndT', default=None) -NodeRunEndT = TypeVar('NodeRunEndT', covariant=True, default=Never) - @dataclass(init=False) class Graph(Generic[StateT, RunEndT]): @@ -270,7 +266,7 @@ def mermaid_code( highlight_css: str = mermaid.DEFAULT_HIGHLIGHT_CSS, infer_name: bool = True, ) -> str: - """Generate a diagram representing the graph as [mermaid](https://mermaid.js.org/) chart. + """Generate a diagram representing the graph as [mermaid](https://mermaid.js.org/) diagram. This method calls [`pydantic_graph.mermaid.generate_code`][pydantic_graph.mermaid.generate_code]. diff --git a/pydantic_graph/pydantic_graph/state.py b/pydantic_graph/pydantic_graph/state.py index 9cd197a637..7d4a977292 100644 --- a/pydantic_graph/pydantic_graph/state.py +++ b/pydantic_graph/pydantic_graph/state.py @@ -78,4 +78,8 @@ def summary(self) -> str: HistoryStep = Union[NodeStep[StateT, RunEndT], EndStep[StateT, RunEndT]] -"""A step in the history of a graph run.""" +"""A step in the history of a graph run. + +[`Graph.run`][pydantic_graph.graph.Graph.run] returns a list of these steps describing the execution of the graph, +together with the run return value. 
+""" diff --git a/tests/test_examples.py b/tests/test_examples.py index f52248a52a..975258a04a 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -150,9 +150,14 @@ async def async_http_request(url: str, **kwargs: Any) -> httpx.Response: def rich_prompt_ask(prompt: str, *_args: Any, **_kwargs: Any) -> str: if prompt == 'Where would you like to fly from and to?': return 'SFO to ANC' - else: - assert prompt == 'What seat would you like?', prompt + elif prompt == 'What seat would you like?': return 'window seat with leg room' + if prompt == 'Insert coins': + return '1' + elif prompt == 'Select product': + return 'crisps' + else: # pragma: no cover + raise ValueError(f'Unexpected prompt: {prompt}') text_responses: dict[str, str | ToolCallPart] = { From f2cd72a4dbac7b7797353fc0027bcaa1ee6a610c Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sat, 11 Jan 2025 23:15:55 +0000 Subject: [PATCH 42/57] add GenAI example --- docs/graph.md | 139 ++++++++++++++++++++++++- pydantic_graph/pydantic_graph/graph.py | 73 ++++++------- tests/test_examples.py | 16 +++ 3 files changed, 189 insertions(+), 39 deletions(-) diff --git a/docs/graph.md b/docs/graph.md index 0765d5e265..7875c0b52c 100644 --- a/docs/graph.md +++ b/docs/graph.md @@ -162,7 +162,7 @@ async def main(): 1. The state of the vending machine is defined as a dataclass with the user's balance and the product they've selected, if any. 2. A dictionary of products mapped to prices. 3. The `InsertCoin` node, [`BaseNode`][pydantic_graph.nodes.BaseNode] is parameterized with `MachineState` as that's the state used in this graph. -4. The `InsertCoin` node prompts the user to insert coins. Keep things simple by just entering a monetary amount as a float. Before you start thinking this is a toy too since it's using `input` within node, see [below](#running-graphs) for how control flow can be managed when nodes require external input. +4. The `InsertCoin` node prompts the user to insert coins. Keep things simple by just entering a monetary amount as a float. Before you start thinking this is a toy too since it's using `input` within node, see [below](#custom-control-flow) for how control flow can be managed when nodes require external input. 5. The `CoinsInserted` node, again this is a [`dataclass`][dataclasses.dataclass], in this case with one field `amount`, thus nodes calling `CoinsInserted` must provide an amount. 6. Update the user's balance with the amount inserted. 7. If the user has already selected a product, go to `Purchase`, otherwise go to `SelectProduct`. @@ -211,9 +211,142 @@ See [below](#mermaid-diagrams) for more information on generating diagrams. ## GenAI Example -TODO +So far we haven't shown an example of a Graph that actually uses PydanticAI or GenAI at all. + +In this example, one agent generates a welcome email to a user and the other agent provides feedback on the email. 
+ +This graph has avery simple structure: + +```mermaid +--- +title: feedback_graph +--- +stateDiagram-v2 + [*] --> WriteEmail + WriteEmail --> Feedback + Feedback --> WriteEmail + Feedback --> [*] +``` + + +```python {title="genai_email_feedback.py"} +from __future__ import annotations as _annotations + +from dataclasses import dataclass, field + +from pydantic import BaseModel, EmailStr + +from pydantic_ai import Agent +from pydantic_ai.format_as_xml import format_as_xml +from pydantic_ai.messages import ModelMessage +from pydantic_graph import BaseNode, End, Graph, GraphContext + + +@dataclass +class User: + name: str + email: EmailStr + interests: list[str] + + +@dataclass +class Email: + subject: str + body: str + + +@dataclass +class State: + user: User + write_agent_messages: list[ModelMessage] = field(default_factory=list) + + +email_writer_agent = Agent( + 'google-vertex:gemini-1.5-pro', + result_type=Email, + system_prompt='Write a welcome email to our tech blog.', +) + + +@dataclass +class WriteEmail(BaseNode[State]): + email_feedback: str | None = None + + async def run(self, ctx: GraphContext[State]) -> Feedback: + if self.email_feedback: + prompt = ( + f'Rewrite the email for the user:\n' + f'{format_as_xml(ctx.state.user)}\n' + f'Feedback: {self.email_feedback}' + ) + else: + prompt = ( + f'Write a welcome email for the user:\n' + f'{format_as_xml(ctx.state.user)}' + ) + + result = await email_writer_agent.run( + prompt, + message_history=ctx.state.write_agent_messages, + ) + ctx.state.write_agent_messages += result.all_messages() + return Feedback(result.data) + + +class EmailRequiresWrite(BaseModel): + feedback: str + + +class EmailOk(BaseModel): + pass + + +feedback_agent = Agent[None, EmailRequiresWrite | EmailOk]( + 'openai:gpt-4o', + result_type=EmailRequiresWrite | EmailOk, # type: ignore + system_prompt=( + 'Review the email and provide feedback, email must reference the users specific interests.' + ), +) + + +@dataclass +class Feedback(BaseNode[State, Email]): + email: Email + + async def run( + self, + ctx: GraphContext[State], + ) -> WriteEmail | End[Email]: + prompt = format_as_xml({'user': ctx.state.user, 'email': self.email}) + result = await feedback_agent.run(prompt) + if isinstance(result.data, EmailRequiresWrite): + return WriteEmail(email_feedback=result.data.feedback) + else: + return End(self.email) + + +async def main(): + user = User( + name='John Doe', + email='john.joe@exmaple.com', + interests=['Haskel', 'Lisp', 'Fortran'], + ) + state = State(user) + feedback_graph = Graph(nodes=(WriteEmail, Feedback)) + email, _ = await feedback_graph.run(state, WriteEmail()) + print(email) + """ + Email( + subject='Welcome to our tech blog!', + body='Hello John, Welcome to our tech blog! ...', + ) + """ +``` + +_(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)_ -## Running Graphs +## Custom Control Flow TODO diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 469905da2e..4cfd74541b 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -97,41 +97,6 @@ def __init__( self._validate_edges() - async def next( - self, - state: StateT, - node: BaseNode[StateT, RunEndT], - history: list[HistoryStep[StateT, RunEndT]], - *, - infer_name: bool = True, - ) -> BaseNode[StateT, Any] | End[RunEndT]: - """Run a node in the graph and return the next node to run. - - Args: - state: The current state of the graph. 
- node: The node to run. - history: The history of the graph run so far. NOTE: this will be mutated to add the new step. - infer_name: Whether to infer the graph name from the calling frame. - - Returns: - The next node to run or [`End`][pydantic_graph.nodes.End] if the graph has finished. - """ - if infer_name and self.name is None: - self._infer_name(inspect.currentframe()) - node_id = node.get_id() - if node_id not in self.node_defs: - raise exceptions.GraphRuntimeError(f'Node `{node}` is not in the graph.') - - history_step: NodeStep[StateT, RunEndT] = NodeStep(state, node) - history.append(history_step) - - ctx = GraphContext(state) - with _logfire.span('run node {node_id}', node_id=node_id, node=node): - start = perf_counter() - next_node = await node.run(ctx) - history_step.duration = perf_counter() - start - return next_node - async def run( self, state: StateT, @@ -147,7 +112,8 @@ async def run( you need to provide the starting node. infer_name: Whether to infer the graph name from the calling frame. - Returns: The result type from ending the run and the history of the run. + Returns: + The result type from ending the run and the history of the run. Here's an example of running the graph from [above][pydantic_graph.graph.Graph]: @@ -255,6 +221,41 @@ async def main(): f'Invalid node return type: `{type(next_node).__name__}`. Expected `BaseNode` or `End`.' ) + async def next( + self, + state: StateT, + node: BaseNode[StateT, RunEndT], + history: list[HistoryStep[StateT, RunEndT]], + *, + infer_name: bool = True, + ) -> BaseNode[StateT, Any] | End[RunEndT]: + """Run a node in the graph and return the next node to run. + + Args: + state: The current state of the graph. + node: The node to run. + history: The history of the graph run so far. NOTE: this will be mutated to add the new step. + infer_name: Whether to infer the graph name from the calling frame. + + Returns: + The next node to run or [`End`][pydantic_graph.nodes.End] if the graph has finished. + """ + if infer_name and self.name is None: + self._infer_name(inspect.currentframe()) + node_id = node.get_id() + if node_id not in self.node_defs: + raise exceptions.GraphRuntimeError(f'Node `{node}` is not in the graph.') + + history_step: NodeStep[StateT, RunEndT] = NodeStep(state, node) + history.append(history_step) + + ctx = GraphContext(state) + with _logfire.span('run node {node_id}', node_id=node_id, node=node): + start = perf_counter() + next_node = await node.run(ctx) + history_step.duration = perf_counter() - start + return next_node + def mermaid_code( self, *, diff --git a/tests/test_examples.py b/tests/test_examples.py index 975258a04a..2f0d1db441 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -257,6 +257,22 @@ async def model_logic(messages: list[ModelMessage], info: AgentInfo) -> ModelRes return ModelResponse(parts=[ToolCallPart(tool_name='get_jokes', args=ArgsDict({'count': 5}))]) elif re.fullmatch(r'sql prompt \d+', m.content): return ModelResponse.from_text(content='SELECT 1') + elif m.content.startswith('Write a welcome email for the user:'): + return ModelResponse( + parts=[ + ToolCallPart( + tool_name='final_result', + args=ArgsDict( + { + 'subject': 'Welcome to our tech blog!', + 'body': 'Hello John, Welcome to our tech blog! 
...', + } + ), + ) + ] + ) + elif m.content.startswith('\n '): + return ModelResponse(parts=[ToolCallPart(tool_name='final_result_EmailOk', args=ArgsDict({}))]) elif response := text_responses.get(m.content): if isinstance(response, str): return ModelResponse.from_text(content=response) From 111f2d017b6ad3e1af291f68427f9b162c44ca20 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sun, 12 Jan 2025 11:17:34 +0000 Subject: [PATCH 43/57] more graph docs --- docs/api/models/function.md | 2 +- docs/api/models/test.md | 2 +- docs/graph.md | 162 +++++++++++++++--- .../pydantic_ai_examples/question_graph.py | 3 +- pydantic_ai_slim/pydantic_ai/format_as_xml.py | 3 +- pydantic_ai_slim/pydantic_ai/tools.py | 6 +- pydantic_graph/pydantic_graph/_utils.py | 3 +- pydantic_graph/pydantic_graph/graph.py | 34 +++- pydantic_graph/pydantic_graph/mermaid.py | 3 +- pydantic_graph/pydantic_graph/state.py | 16 +- tests/graph/test_graph.py | 14 +- tests/test_examples.py | 15 +- 12 files changed, 215 insertions(+), 48 deletions(-) diff --git a/docs/api/models/function.md b/docs/api/models/function.md index 6280ae47f1..7d16124fdb 100644 --- a/docs/api/models/function.md +++ b/docs/api/models/function.md @@ -9,7 +9,7 @@ Its primary use case is for more advanced unit testing than is possible with `Te Here's a minimal example: -```py {title="function_model_usage.py" call_name="test_my_agent" lint="not-imports"} +```py {title="function_model_usage.py" call_name="test_my_agent" noqa="I001"} from pydantic_ai import Agent from pydantic_ai.messages import ModelMessage, ModelResponse from pydantic_ai.models.function import FunctionModel, AgentInfo diff --git a/docs/api/models/test.md b/docs/api/models/test.md index 66bfac7932..e6c4411d29 100644 --- a/docs/api/models/test.md +++ b/docs/api/models/test.md @@ -4,7 +4,7 @@ Utility model for quickly testing apps built with PydanticAI. Here's a minimal example: -```py {title="test_model_usage.py" call_name="test_my_agent" lint="not-imports"} +```py {title="test_model_usage.py" call_name="test_my_agent" noqa="I001"} from pydantic_ai import Agent from pydantic_ai.models.test import TestModel diff --git a/docs/graph.md b/docs/graph.md index 7875c0b52c..829d660a82 100644 --- a/docs/graph.md +++ b/docs/graph.md @@ -18,7 +18,10 @@ Alongside PydanticAI, we've developed `pydantic-graph` — an async graph and st While this library is developed as part of the PydanticAI; it has no dependency on `pydantic-ai` and can be considered as a pure graph library. You may find it useful whether or not you're using PydanticAI or even building with GenAI. -`pydantic-graph` is designed for advanced users and makes heavy use of Python generics and types hints. It is not designed to b as beginner-friendly as PydanticAI. +`pydantic-graph` is designed for advanced users and makes heavy use of Python generics and types hints. It is not designed to be as beginner-friendly as PydanticAI. + +!!! note "Every Early beta" + Graph support was [introduced](https://github.com/pydantic/pydantic-ai/pull/528) in v0.0.19 and is in very earlier beta. The API is subject to change. The documentation is incomplete. The implementation is incomplete. ## Installation @@ -30,11 +33,20 @@ pip/uv-add pydantic-graph ## Graph Types -!!! note "Every Early beta" - Graph support was [introduced](https://github.com/pydantic/pydantic-ai/pull/528) in v0.0.19 and is in very earlier beta. The API is subject to change. The documentation is incomplete. The implementation is incomplete. 
- Graphs are made up of a few key components: +### GraphContext + +[`GraphContext`][pydantic_graph.nodes.GraphContext] — The context for the graph run, similar to PydanticAI's [`RunContext`][pydantic_ai.tools.RunContext], this holds the state of the graph and is passed to nodes when they're run. + +`GraphContext` is generic in the state type of the graph it's used in, [`StateT`][pydantic_graph.state.StateT]. + +### End + +[`End`][pydantic_graph.nodes.End] — return value to indicates the graph run should end. + +`End` is generic in the graph return type of the graph it's used in, [`RunEndT`][pydantic_graph.nodes.RunEndT]. + ### Nodes Subclasses of [`BaseNode`][pydantic_graph.nodes.BaseNode] to define nodes. @@ -50,17 +62,55 @@ Nodes are generic in: * **state** which must have the same type as the state of graphs they're included in, [`StateT`][pydantic_graph.state.StateT] has a default of `None`, so if you're not using state you can omit this generic parameter * **graph return type** this only applies if the node returns [`End`][pydantic_graph.nodes.End], [`RunEndT`][pydantic_graph.nodes.RunEndT] has a default of [Never][typing.Never] so this generic parameter can be omitted if the node doesn't return `End`, but must be included if it does. -### GraphContext +Here's an example of a start or intermediate node in a graph — it can't end the run as it doesn't return [`End`][pydantic_graph.nodes.End]: -[`GraphContext`][pydantic_graph.nodes.GraphContext] — The context for the graph run, similar to PydanticAI's [`RunContext`][pydantic_ai.tools.RunContext], this holds the state of the graph and is passed to nodes when they're run. +```py {title="intermediate_node.py" noqa="F821" test="skip"} +from dataclasses import dataclass -`GraphContext` is generic in the state type of the graph it's used in, [`StateT`][pydantic_graph.state.StateT]. +from pydantic_graph import BaseNode, GraphContext -### End -[`End`][pydantic_graph.nodes.End] — return value to indicates the graph run should end. +@dataclass +class MyNode(BaseNode[MyState]): # (1)! + foo: int # (2)! -`End` is generic in the graph return type of the graph it's used in, [`RunEndT`][pydantic_graph.nodes.RunEndT]. + async def run( + self, + ctx: GraphContext[MyState], # (3)! + ) -> AnotherNode: # (4)! + ... + return AnotherNode() +``` + +1. State in this example is `MyState` (not shown), hence `BaseNode` is parameterized with `MyState`. This node can't end the run, so the `RunEndT` generic parameter is omitted and defaults to `Never`. +2. `MyNode` is a dataclass and has a single field `foo`, an `int`. +3. The `run` method takes a `GraphContext` parameter, again parameterized with state `MyState`. +4. The return type of the `run` method is `AnotherNode` (not shown), this is used to determine the outgoing edges of the node. + +We could extend `MyNode` to optionally end the run if `foo` is divisible by 5: + +```py {title="intermediate_or_end_node.py" hl_lines="7 13" noqa="F821" test="skip"} +from dataclasses import dataclass + +from pydantic_graph import BaseNode, End, GraphContext + + +@dataclass +class MyNode(BaseNode[MyState, int]): # (1)! + foo: int + + async def run( + self, + ctx: GraphContext[MyState], + ) -> AnotherNode | End[int]: # (2)! + if self.foo % 5 == 0: + return End(self.foo) + else: + return AnotherNode() +``` + +1. We parameterize the node with the return type (`int` in this case) as well as state. +2. 
The return type of the `run` method is now a union of `AnotherNode` and `End[int]`, this allows the node to end the run if `foo` is divisible by 5. ### Graph @@ -71,11 +121,76 @@ Nodes are generic in: * **state** the state type of the graph, [`StateT`][pydantic_graph.state.StateT] * **graph return type** the return type of the graph run, [`RunEndT`][pydantic_graph.nodes.RunEndT] -## Basic Usage +Here's an example of a simple graph: + +```py {title="graph_example.py" py="3.10"} +from __future__ import annotations + +from dataclasses import dataclass + +from pydantic_graph import BaseNode, End, Graph, GraphContext + + +@dataclass +class DivisibleBy5(BaseNode[None, int]): # (1)! + foo: int + + async def run( + self, + ctx: GraphContext, + ) -> Increment | End[int]: + if self.foo % 5 == 0: + return End(self.foo) + else: + return Increment(self.foo) + + +@dataclass +class Increment(BaseNode): # (2)! + foo: int + + async def run(self, ctx: GraphContext) -> DivisibleBy5: + return DivisibleBy5(self.foo + 1) + + +fives_graph = Graph(nodes=[DivisibleBy5, Increment]) +result, history = fives_graph.run_sync(None, DivisibleBy5(4)) +print(result) +#> 5 +# the full history is quite verbose (see below), so we'll just print the summary +print([item.data_snapshot() for item in history]) +#> [DivisibleBy5(foo=4), Increment(foo=4), DivisibleBy5(foo=5), End(data=5)] +``` +_(This example is complete, it can be run "as is" with Python 3.10+)_ + +A [mermaid diagram](#mermaid-diagrams) for this graph can be generated with the following code: + +```py {title="graph_example_diagram.py" py="3.10"} +from graph_example import DivisibleBy5, fives_graph + +fives_graph.mermaid_code(start_node=DivisibleBy5) +``` + +```mermaid +--- +title: fives_graph +--- +stateDiagram-v2 + [*] --> DivisibleBy5 + DivisibleBy5 --> Increment + DivisibleBy5 --> [*] + Increment --> DivisibleBy5 +``` + +## Stateful Graphs + +TODO introduce state + +TODO link to issue about persistent state. Here's an example of a graph which represents a vending machine where the user may insert coins and select a product to purchase. -```python {title="vending_machine.py"} +```python {title="vending_machine.py" py="3.10"} from __future__ import annotations from dataclasses import dataclass @@ -162,7 +277,7 @@ async def main(): 1. The state of the vending machine is defined as a dataclass with the user's balance and the product they've selected, if any. 2. A dictionary of products mapped to prices. 3. The `InsertCoin` node, [`BaseNode`][pydantic_graph.nodes.BaseNode] is parameterized with `MachineState` as that's the state used in this graph. -4. The `InsertCoin` node prompts the user to insert coins. Keep things simple by just entering a monetary amount as a float. Before you start thinking this is a toy too since it's using `input` within node, see [below](#custom-control-flow) for how control flow can be managed when nodes require external input. +4. The `InsertCoin` node prompts the user to insert coins. We keep things simple by just entering a monetary amount as a float. Before you start thinking this is a toy too since it's using `input` within nodes, see [below](#custom-control-flow) for how control flow can be managed when nodes require external input. 5. The `CoinsInserted` node, again this is a [`dataclass`][dataclasses.dataclass], in this case with one field `amount`, thus nodes calling `CoinsInserted` must provide an amount. 6. Update the user's balance with the amount inserted. 7. 
If the user has already selected a product, go to `Purchase`, otherwise go to `SelectProduct`. @@ -171,18 +286,18 @@ async def main(): 10. If the balance is enough to purchase the product, adjust the balance to reflect the purchase and return [`End`][pydantic_graph.nodes.End] to end the graph. We're not using the run return type, so we call `End` with `None`. 11. If the balance is insufficient, to go `InsertCoin` to prompt the user to insert more coins. 12. If the product is invalid, go to `SelectProduct` to prompt the user to select a product again. -13. The graph is created by passing a list of nodes to [`Graph`][pydantic_graph.graph.Graph]. Order of nodes is not important, but will alter how [diagramss](#mermaid-diagrams) are displayed. +13. The graph is created by passing a list of nodes to [`Graph`][pydantic_graph.graph.Graph]. Order of nodes is not important, but will alter how [diagrams](#mermaid-diagrams) are displayed. 14. Initialize the state, this will be passed to the graph run and mutated as the graph runs. 15. Run the graph with the initial state, since the graph can be run from any node, we must pass the start node, in this case `InsertCoin`. [`Graph.run`][pydantic_graph.graph.Graph.run] returns a tuple of the return value (`None`) in this case, and the [history][pydantic_graph.state.HistoryStep] of the graph run. 16. The return type of the node's [`run`][pydantic_graph.nodes.BaseNode.run] method is important, it's used to determine the outgoing edges of the node, this in turn is used to render [mermaid diagrams](#mermaid-diagrams) and is enforced at runtime. 17. The return type of `CoinsInserted`s [`run`][pydantic_graph.nodes.BaseNode.run] method is a union, meaning multiple outgoing edges are possible. 18. Unlike other nodes `Purchase` can end the run, so the [`RunEndT`][pydantic_graph.nodes.RunEndT] generic parameter must be set, in this case it's `None` since the graph run return type is `None`. -_(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)_ +_(This example is complete, it can be run "as is" with Python 3.10+ — you'll need to add `asyncio.run(main())` to run `main`)_ A [mermaid diagram](#mermaid-diagrams) for this graph can be generated with the following code: -```py {title="vending_machine_diagram.py"} +```py {title="vending_machine_diagram.py" py="3.10"} from vending_machine import InsertCoin, vending_machine_graph vending_machine_graph.mermaid_code(start_node=InsertCoin) @@ -215,7 +330,7 @@ So far we haven't shown an example of a Graph that actually uses PydanticAI or G In this example, one agent generates a welcome email to a user and the other agent provides feedback on the email. 
-This graph has avery simple structure: +This graph has a very simple structure: ```mermaid --- @@ -229,7 +344,7 @@ stateDiagram-v2 ``` -```python {title="genai_email_feedback.py"} +```python {title="genai_email_feedback.py" py="3.10"} from __future__ import annotations as _annotations from dataclasses import dataclass, field @@ -344,15 +459,18 @@ async def main(): """ ``` -_(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)_ +_(This example is complete, it can be run "as is" with Python 3.10+ — you'll need to add `asyncio.run(main())` to run `main`)_ ## Custom Control Flow -TODO +In many real-world applications, Graphs cannot run uninterrupted from start to finish — they require external input or run over an extended period of time where a single process cannot execute the entire graph. -## State Machine +In these scenarios the [`next`][pydantic_graph.graph.Graph.next] method can be used to run the graph one node at a time. -TODO +In this example, an AI asks the user a question, the user provides an answer, the AI evaluates the answer and ends if the user got it right or asks another question if they got it wrong. + +```python {title="ai_q_and_a.py"} +``` ## Mermaid Diagrams diff --git a/examples/pydantic_ai_examples/question_graph.py b/examples/pydantic_ai_examples/question_graph.py index 5f2c42bca9..d8158778c0 100644 --- a/examples/pydantic_ai_examples/question_graph.py +++ b/examples/pydantic_ai_examples/question_graph.py @@ -11,6 +11,7 @@ from typing import Annotated import logfire +from devtools import debug from pydantic_graph import BaseNode, Edge, End, Graph, GraphContext, HistoryStep from pydantic_ai import Agent @@ -136,7 +137,7 @@ async def main(): while True: node = await question_graph.next(state, node, history) if isinstance(node, End): - print('\n'.join(e.summary() for e in history)) + debug([e.data_snapshot() for e in history]) break elif isinstance(node, Answer): node.answer = input(f'{node.question} ') diff --git a/pydantic_ai_slim/pydantic_ai/format_as_xml.py b/pydantic_ai_slim/pydantic_ai/format_as_xml.py index 3c6f67fd42..ce9973691d 100644 --- a/pydantic_ai_slim/pydantic_ai/format_as_xml.py +++ b/pydantic_ai_slim/pydantic_ai/format_as_xml.py @@ -37,7 +37,8 @@ def format_as_xml( none_str: String to use for `None` values. indent: Indentation string to use for pretty printing. - Returns: XML representation of the object. + Returns: + XML representation of the object. 
Example: ```python {title="format_as_xml_example.py" lint="skip"} diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index 5f07f66ef7..9a99bbba22 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -107,7 +107,7 @@ def replace_with( Example — here `only_if_42` is valid as a `ToolPrepareFunc`: -```python {lint="not-imports"} +```python {noqa="I001"} from typing import Union from pydantic_ai import RunContext, Tool @@ -176,7 +176,7 @@ def __init__( Example usage: - ```python {lint="not-imports"} + ```python {noqa="I001"} from pydantic_ai import Agent, RunContext, Tool async def my_tool(ctx: RunContext[int], x: int, y: int) -> str: @@ -187,7 +187,7 @@ async def my_tool(ctx: RunContext[int], x: int, y: int) -> str: or with a custom prepare method: - ```python {lint="not-imports"} + ```python {noqa="I001"} from typing import Union from pydantic_ai import Agent, RunContext, Tool diff --git a/pydantic_graph/pydantic_graph/_utils.py b/pydantic_graph/pydantic_graph/_utils.py index 81123afde6..0211f4fd17 100644 --- a/pydantic_graph/pydantic_graph/_utils.py +++ b/pydantic_graph/pydantic_graph/_utils.py @@ -24,7 +24,8 @@ def get_union_args(tp: Any) -> tuple[Any, ...]: def unpack_annotated(tp: Any) -> tuple[Any, list[Any]]: """Strip `Annotated` from the type if present. - Returns: `(tp argument, ())` if not annotated, otherwise `(stripped type, annotations)`. + Returns: + `(tp argument, ())` if not annotated, otherwise `(stripped type, annotations)`. """ origin = get_origin(tp) if origin is Annotated or origin is typing_extensions.Annotated: diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 4cfd74541b..8085307f85 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -1,5 +1,6 @@ from __future__ import annotations as _annotations +import asyncio import inspect from collections.abc import Sequence from dataclasses import dataclass @@ -31,7 +32,7 @@ class Graph(Generic[StateT, RunEndT]): Here's a very simple example of a graph which increments a number by 1, but makes sure the number is never 42 at the end. - ```py {title="never_42.py" lint="not-imports"} + ```py {title="never_42.py" noqa="I001" py="3.10"} from __future__ import annotations from dataclasses import dataclass @@ -117,7 +118,7 @@ async def run( Here's an example of running the graph from [above][pydantic_graph.graph.Graph]: - ```py {title="run_never_42.py" lint="not-imports"} + ```py {title="run_never_42.py" noqa="I001" py="3.10"} from never_42 import Increment, MyState, never_42_graph async def main(): @@ -196,10 +197,10 @@ async def main(): ''' ``` """ - history: list[HistoryStep[StateT, RunEndT]] = [] if infer_name and self.name is None: self._infer_name(inspect.currentframe()) + history: list[HistoryStep[StateT, RunEndT]] = [] with _logfire.span( '{graph_name} run {start=}', graph_name=self.name or 'graph', @@ -221,6 +222,31 @@ async def main(): f'Invalid node return type: `{type(next_node).__name__}`. Expected `BaseNode` or `End`.' ) + def run_sync( + self, + state: StateT, + start_node: BaseNode[StateT, RunEndT], + *, + infer_name: bool = True, + ) -> tuple[RunEndT, list[HistoryStep[StateT, RunEndT]]]: + """Run the graph synchronously. + + This is a convenience method that wraps [`self.run`][pydantic_graph.Graph.run] with `loop.run_until_complete(...)`. + You therefore can't use this method inside async code or if there's an active event loop. 
+ + Args: + state: The initial state of the graph. + start_node: the first node to run, since the graph definition doesn't define the entry point in the graph, + you need to provide the starting node. + infer_name: Whether to infer the graph name from the calling frame. + + Returns: + The result type from ending the run and the history of the run. + """ + if infer_name and self.name is None: + self._infer_name(inspect.currentframe()) + return asyncio.get_event_loop().run_until_complete(self.run(state, start_node, infer_name=False)) + async def next( self, state: StateT, @@ -285,7 +311,7 @@ def mermaid_code( Here's an example of generating a diagram for the graph from [above][pydantic_graph.graph.Graph]: - ```py {title="never_42.py"} + ```py {title="never_42.py" py="3.10"} from never_42 import Increment, never_42_graph print(never_42_graph.mermaid_code(start_node=Increment)) diff --git a/pydantic_graph/pydantic_graph/mermaid.py b/pydantic_graph/pydantic_graph/mermaid.py index 1e72493759..49e41ee267 100644 --- a/pydantic_graph/pydantic_graph/mermaid.py +++ b/pydantic_graph/pydantic_graph/mermaid.py @@ -120,7 +120,8 @@ def request_image( graph: The graph to generate the image for. **kwargs: Additional parameters to configure mermaid chart generation. - Returns: The image data. + Returns: + The image data. """ code = generate_code( graph, diff --git a/pydantic_graph/pydantic_graph/state.py b/pydantic_graph/pydantic_graph/state.py index 7d4a977292..174607362c 100644 --- a/pydantic_graph/pydantic_graph/state.py +++ b/pydantic_graph/pydantic_graph/state.py @@ -50,8 +50,12 @@ def __post_init__(self, snapshot_state: Callable[[StateT], StateT]): # Copy the state to prevent it from being modified by other code self.state = snapshot_state(self.state) - def summary(self) -> str: - return str(self.node) + def data_snapshot(self) -> BaseNode[StateT, RunEndT]: + """Returns a deep copy of [`self.node`][pydantic_graph.state.NodeStep.node]. + + Useful for summarizing history. + """ + return copy.deepcopy(self.node) @dataclass @@ -73,8 +77,12 @@ def __post_init__(self, snapshot_state: Callable[[StateT], StateT]): # Copy the state to prevent it from being modified by other code self.state = snapshot_state(self.state) - def summary(self) -> str: - return str(self.result) + def data_snapshot(self) -> End[RunEndT]: + """Returns a deep copy of [`self.result`][pydantic_graph.state.EndStep.result]. + + Useful for summarizing history. 
+ """ + return copy.deepcopy(self.result) HistoryStep = Union[NodeStep[StateT, RunEndT], EndStep[StateT, RunEndT]] diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py index f82e92e5e6..5549ffaa8b 100644 --- a/tests/graph/test_graph.py +++ b/tests/graph/test_graph.py @@ -127,14 +127,14 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq ), ] ) - assert [e.summary() for e in history] == snapshot( + assert [e.data_snapshot() for e in history] == snapshot( [ - 'test_graph..Float2String(input_data=3.14159)', - "test_graph..String2Length(input_data='3.14159')", - 'test_graph..Double(input_data=7)', - "test_graph..String2Length(input_data='xxxxxxxxxxxxxxxxxxxxx')", - 'test_graph..Double(input_data=21)', - 'End(data=42)', + Float2String(input_data=3.14159), + String2Length(input_data='3.14159'), + Double(input_data=7), + String2Length(input_data='xxxxxxxxxxxxxxxxxxxxx'), + Double(input_data=21), + End(data=42), ] ) diff --git a/tests/test_examples.py b/tests/test_examples.py index 2f0d1db441..26ec8ae9a0 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -84,6 +84,14 @@ def test_docs_examples( opt_title = prefix_settings.get('title') opt_test = prefix_settings.get('test', '') opt_lint = prefix_settings.get('lint', '') + noqa = prefix_settings.get('noqa', '') + python_version = prefix_settings.get('py', None) + + if python_version: + python_version_info = tuple(int(v) for v in python_version.split('.')) + if sys.version_info < python_version_info: + pytest.skip(f'Python version {python_version} required') + cwd = Path.cwd() if opt_test.startswith('skip') and opt_lint.startswith('skip'): @@ -99,9 +107,12 @@ def test_docs_examples( # `from bank_database import DatabaseConn` wrongly sorted in imports # waiting for https://github.com/pydantic/pytest-examples/issues/43 # and https://github.com/pydantic/pytest-examples/issues/46 - if opt_lint == 'not-imports' or 'import DatabaseConn' in example.source: + if 'import DatabaseConn' in example.source: ruff_ignore.append('I001') + if noqa: + ruff_ignore.extend(noqa.upper().split()) + line_length = int(prefix_settings.get('line_length', '88')) eval_example.set_config(ruff_ignore=ruff_ignore, target_version='py39', line_length=line_length) @@ -116,7 +127,7 @@ def test_docs_examples( eval_example.lint(example) if opt_test.startswith('skip'): - pytest.skip(opt_test[4:].lstrip(' -') or 'running code skipped') + print(opt_test[4:].lstrip(' -') or 'running code skipped') else: if eval_example.update_examples: # pragma: no cover module_dict = eval_example.run_print_update(example, call=call_name) From 5d0a834e34663db47fd8db89d9bbe1c45ca7aeba Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 13 Jan 2025 14:32:04 +0000 Subject: [PATCH 44/57] extending graph docs --- docs/graph.md | 176 +++++++++++++++++- .../pydantic_ai_examples/question_graph.py | 120 ++++++++---- pydantic_graph/pydantic_graph/nodes.py | 3 + tests/test_examples.py | 15 ++ 4 files changed, 269 insertions(+), 45 deletions(-) diff --git a/docs/graph.md b/docs/graph.md index 829d660a82..1e2a8f2ffb 100644 --- a/docs/graph.md +++ b/docs/graph.md @@ -153,14 +153,20 @@ class Increment(BaseNode): # (2)! return DivisibleBy5(self.foo + 1) -fives_graph = Graph(nodes=[DivisibleBy5, Increment]) -result, history = fives_graph.run_sync(None, DivisibleBy5(4)) +fives_graph = Graph(nodes=[DivisibleBy5, Increment]) # (3)! +result, history = fives_graph.run_sync(None, DivisibleBy5(4)) # (4)! 
print(result) #> 5 # the full history is quite verbose (see below), so we'll just print the summary print([item.data_snapshot() for item in history]) #> [DivisibleBy5(foo=4), Increment(foo=4), DivisibleBy5(foo=5), End(data=5)] ``` + +1. The `DivisibleBy5` node is parameterized with `None` as this graph doesn't use state, and `int` as it can end the run. +2. The `Increment` node doesn't return `End`, so the `RunEndT` generic parameter is omitted, state can also be omitted as the graph doesn't use state. +3. The graph is created with a sequence of nodes. +4. The graph is run synchronously with [`run_sync`][pydantic_graph.graph.Graph.run_sync] the initial state `None` and the start node `DivisibleBy5(4)` are passed as arguments. + _(This example is complete, it can be run "as is" with Python 3.10+)_ A [mermaid diagram](#mermaid-diagrams) for this graph can be generated with the following code: @@ -277,7 +283,7 @@ async def main(): 1. The state of the vending machine is defined as a dataclass with the user's balance and the product they've selected, if any. 2. A dictionary of products mapped to prices. 3. The `InsertCoin` node, [`BaseNode`][pydantic_graph.nodes.BaseNode] is parameterized with `MachineState` as that's the state used in this graph. -4. The `InsertCoin` node prompts the user to insert coins. We keep things simple by just entering a monetary amount as a float. Before you start thinking this is a toy too since it's using `input` within nodes, see [below](#custom-control-flow) for how control flow can be managed when nodes require external input. +4. The `InsertCoin` node prompts the user to insert coins. We keep things simple by just entering a monetary amount as a float. Before you start thinking this is a toy too since it's using [rich's `Prompt.ask`][rich.prompt.PromptBase.ask] within nodes, see [below](#custom-control-flow) for how control flow can be managed when nodes require external input. 5. The `CoinsInserted` node, again this is a [`dataclass`][dataclasses.dataclass], in this case with one field `amount`, thus nodes calling `CoinsInserted` must provide an amount. 6. Update the user's balance with the amount inserted. 7. If the user has already selected a product, go to `Purchase`, otherwise go to `SelectProduct`. @@ -303,8 +309,6 @@ from vending_machine import InsertCoin, vending_machine_graph vending_machine_graph.mermaid_code(start_node=InsertCoin) ``` -_(This example is complete, it can be run "as is")_ - The diagram generated by the above code is: ```mermaid @@ -469,9 +473,169 @@ In these scenarios the [`next`][pydantic_graph.graph.Graph.next] method can be u In this example, an AI asks the user a question, the user provides an answer, the AI evaluates the answer and ends if the user got it right or asks another question if they got it wrong. -```python {title="ai_q_and_a.py"} +??? 
example "`ai_q_and_a_graph.py` — `question_graph` definition" + ```python {title="ai_q_and_a_graph.py" noqa="I001" py="3.10"} + from __future__ import annotations as _annotations + + from dataclasses import dataclass, field + + from pydantic_graph import BaseNode, End, Graph, GraphContext + + from pydantic_ai import Agent + from pydantic_ai.format_as_xml import format_as_xml + from pydantic_ai.messages import ModelMessage + + ask_agent = Agent('openai:gpt-4o', result_type=str) + + + @dataclass + class QuestionState: + question: str | None = None + ask_agent_messages: list[ModelMessage] = field(default_factory=list) + evaluate_agent_messages: list[ModelMessage] = field(default_factory=list) + + + @dataclass + class Ask(BaseNode[QuestionState]): + async def run(self, ctx: GraphContext[QuestionState]) -> Answer: + result = await ask_agent.run( + 'Ask a simple question with a single correct answer.', + message_history=ctx.state.ask_agent_messages, + ) + ctx.state.ask_agent_messages += result.all_messages() + ctx.state.question = result.data + return Answer(result.data) + + + @dataclass + class Answer(BaseNode[QuestionState]): + question: str + answer: str | None = None + + async def run(self, ctx: GraphContext[QuestionState]) -> Evaluate: + assert self.answer is not None + return Evaluate(self.answer) + + + @dataclass + class EvaluationResult: + correct: bool + comment: str + + + evaluate_agent = Agent( + 'openai:gpt-4o', + result_type=EvaluationResult, + system_prompt='Given a question and answer, evaluate if the answer is correct.', + ) + + + @dataclass + class Evaluate(BaseNode[QuestionState]): + answer: str + + async def run( + self, + ctx: GraphContext[QuestionState], + ) -> End[str] | Reprimand: + assert ctx.state.question is not None + result = await evaluate_agent.run( + format_as_xml({'question': ctx.state.question, 'answer': self.answer}), + message_history=ctx.state.evaluate_agent_messages, + ) + ctx.state.evaluate_agent_messages += result.all_messages() + if result.data.correct: + return End(result.data.comment) + else: + return Reprimand(result.data.comment) + + + @dataclass + class Reprimand(BaseNode[QuestionState]): + comment: str + + async def run(self, ctx: GraphContext[QuestionState]) -> Ask: + print(f'Comment: {self.comment}') + #> Comment: Vichy is no longer the capital of France. + ctx.state.question = None + return Ask() + + + question_graph = Graph(nodes=(Ask, Answer, Evaluate, Reprimand)) + ``` + + _(This example is complete, it can be run "as is" with Python 3.10+)_ + + +```python {title="ai_q_and_a_run.py" noqa="I001" py="3.10"} +from rich.prompt import Prompt + +from pydantic_graph import End, HistoryStep + +from ai_q_and_a_graph import Ask, question_graph, QuestionState, Answer + + +async def main(): + state = QuestionState() # (1)! + node = Ask() # (2)! + history: list[HistoryStep[QuestionState]] = [] # (3)! + while True: + node = await question_graph.next(state, node, history) # (4)! + if isinstance(node, Answer): + node.answer = Prompt.ask(node.question) # (5)! + elif isinstance(node, End): # (6)! + print(f'Correct answer! {node.data}') + #> Correct answer! Well done, 1 + 1 = 2 + print([e.data_snapshot() for e in history]) + """ + [ + Ask(), + Answer(question='What is the capital of France?', answer='Vichy'), + Evaluate(answer='Vichy'), + Reprimand(comment='Vichy is no longer the capital of France.'), + Ask(), + Answer(question='what is 1 + 1?', answer='2'), + Evaluate(answer='2'), + ] + """ + return + # otherwise just continue +``` + +1. 
Create the state object which will be mutated by [`next`][pydantic_graph.graph.Graph.next]. +2. The start node is `Ask` but will be updated by [`next`][pydantic_graph.graph.Graph.next] as the graph runs. +3. The history of the graph run is stored in a list of [`HistoryStep`][pydantic_graph.state.HistoryStep] objects, again [`next`][pydantic_graph.graph.Graph.next] will update this list in place. +4. [Run][pydantic_graph.graph.Graph.next] the graph one node at a time, updating the state, current node and history as the graph runs. +5. If the current node is an `Answer` node, prompt the user for an answer. +6. Since we're using [`next`][pydantic_graph.graph.Graph.next] we have to manually check for an [`End`][pydantic_graph.nodes.End] and exit the loop if we get one. + +_(This example is complete, it can be run "as is" with Python 3.10+ — you'll need to add `asyncio.run(main())` to run `main`)_ + +A [mermaid diagram](#mermaid-diagrams) for this graph can be generated with the following code: + +```py {title="ai_q_and_a_diagram.py" py="3.10"} +from ai_q_and_a_graph import Ask, question_graph + +question_graph.mermaid_code(start_node=Ask) +``` + +```mermaid +--- +title: question_graph +--- +stateDiagram-v2 + [*] --> Ask + Ask --> Answer + Answer --> Evaluate + Evaluate --> Reprimand + Evaluate --> [*] + Reprimand --> Ask ``` +You maybe have noticed that although this examples transfers control flow out of the graph run, we're still using [rich's `Prompt.ask`][rich.prompt.PromptBase.ask] to get user input, with the process hanging while we wait for the user to enter a response. + +TODO + ## Mermaid Diagrams TODO diff --git a/examples/pydantic_ai_examples/question_graph.py b/examples/pydantic_ai_examples/question_graph.py index d8158778c0..1ad7649d75 100644 --- a/examples/pydantic_ai_examples/question_graph.py +++ b/examples/pydantic_ai_examples/question_graph.py @@ -8,9 +8,11 @@ from __future__ import annotations as _annotations from dataclasses import dataclass, field +from pathlib import Path from typing import Annotated import logfire +import pydantic from devtools import debug from pydantic_graph import BaseNode, Edge, End, Graph, GraphContext, HistoryStep @@ -26,43 +28,30 @@ @dataclass class QuestionState: + question: str | None = None ask_agent_messages: list[ModelMessage] = field(default_factory=list) evaluate_agent_messages: list[ModelMessage] = field(default_factory=list) @dataclass class Ask(BaseNode[QuestionState]): - """Generate a question to ask the user. - - Uses the GPT-4o model to generate the question. - """ - - async def run( - self, ctx: GraphContext[QuestionState] - ) -> Annotated[Answer, Edge(label='ask the question')]: + async def run(self, ctx: GraphContext[QuestionState]) -> Answer: result = await ask_agent.run( 'Ask a simple question with a single correct answer.', message_history=ctx.state.ask_agent_messages, ) ctx.state.ask_agent_messages += result.all_messages() - return Answer(result.data) + ctx.state.question = result.data + return Answer() @dataclass class Answer(BaseNode[QuestionState]): - """Get the answer to the question from the user. - - This node must be completed outside the graph run. 
- """ - - question: str answer: str | None = None - async def run( - self, ctx: GraphContext[QuestionState] - ) -> Annotated[Evaluate, Edge(label='answer the question')]: + async def run(self, ctx: GraphContext[QuestionState]) -> Evaluate: assert self.answer is not None - return Evaluate(self.question, self.answer) + return Evaluate(self.answer) @dataclass @@ -75,34 +64,31 @@ class EvaluationResult: 'openai:gpt-4o', result_type=EvaluationResult, system_prompt='Given a question and answer, evaluate if the answer is correct.', - result_tool_name='evaluation', ) @dataclass class Evaluate(BaseNode[QuestionState]): - question: str answer: str async def run( self, ctx: GraphContext[QuestionState], - ) -> Congratulate | Castigate: + ) -> Congratulate | Reprimand: + assert ctx.state.question is not None result = await evaluate_agent.run( - format_as_xml({'question': self.question, 'answer': self.answer}), + format_as_xml({'question': ctx.state.question, 'answer': self.answer}), message_history=ctx.state.evaluate_agent_messages, ) ctx.state.evaluate_agent_messages += result.all_messages() if result.data.correct: return Congratulate(result.data.comment) else: - return Castigate(result.data.comment) + return Reprimand(result.data.comment) @dataclass class Congratulate(BaseNode[QuestionState, None]): - """Congratulate the user and end.""" - comment: str async def run( @@ -113,26 +99,23 @@ async def run( @dataclass -class Castigate(BaseNode[QuestionState]): - """Castigate the user, then ask another question.""" - +class Reprimand(BaseNode[QuestionState]): comment: str - async def run( - self, ctx: GraphContext[QuestionState] - ) -> Annotated[Ask, Edge(label='try again')]: + async def run(self, ctx: GraphContext[QuestionState]) -> Ask: print(f'Comment: {self.comment}') + # > Comment: Vichy is no longer the capital of France. 
+ ctx.state.question = None return Ask() -question_graph = Graph(nodes=(Ask, Answer, Evaluate, Congratulate, Castigate)) -print(question_graph.mermaid_code(start_node=Ask, notes=False)) +question_graph = Graph(nodes=(Ask, Answer, Evaluate, Congratulate, Reprimand)) -async def main(): +async def run_as_continuous(): state = QuestionState() node = Ask() - history: list[HistoryStep[QuestionState]] = [] + history: list[HistoryStep[QuestionState, None]] = [] with logfire.span('run questions graph'): while True: node = await question_graph.next(state, node, history) @@ -140,11 +123,70 @@ async def main(): debug([e.data_snapshot() for e in history]) break elif isinstance(node, Answer): - node.answer = input(f'{node.question} ') + assert state.question + node.answer = input(f'{state.question} ') # otherwise just continue +ta = pydantic.TypeAdapter( + list[Annotated[HistoryStep[QuestionState, None], pydantic.Discriminator('kind')]] +) + + +async def run_as_cli(answer: str | None): + history_file = Path('question_graph_history.json') + if history_file.exists(): + history = ta.validate_json(history_file.read_bytes()) + else: + history = [] + + if history: + last = history[-1] + assert last.kind != 'node', 'expected last step to be a node' + state = last.state + assert answer is not None, 'answer is required to continue from history' + node = Answer(answer) + else: + state = QuestionState() + node = Ask() + + with logfire.span('run questions graph'): + while True: + node = await question_graph.next(state, node, history) + if isinstance(node, End): + debug([e.data_snapshot() for e in history]) + print('Finished!') + break + elif isinstance(node, Answer): + break + # otherwise just continue + + history_file.write_bytes(ta.dump_json(history)) + + if __name__ == '__main__': import asyncio - - asyncio.run(main()) + import sys + + try: + sub_command = sys.argv[1] + assert sub_command in ('continuous', 'cli', 'mermaid') + except (IndexError, AssertionError): + print( + 'Usage:\n' + ' uv run -m pydantic_ai_examples.question_graph meriad\n' + 'or:\n' + ' uv run -m pydantic_ai_examples.question_graph continuous\n' + 'or:\n' + ' uv run -m pydantic_ai_examples.question_graph cli [answer]', + file=sys.stderr, + ) + sys.exit(1) + + if sub_command == 'mermaid': + print(question_graph.mermaid_code(start_node=Ask)) + elif sub_command == 'continuous': + asyncio.run(run_as_continuous()) + else: + a = sys.argv[2] if len(sys.argv) > 2 else None + asyncio.run(run_as_cli(a)) diff --git a/pydantic_graph/pydantic_graph/nodes.py b/pydantic_graph/pydantic_graph/nodes.py index 1881b10b44..62848fa599 100644 --- a/pydantic_graph/pydantic_graph/nodes.py +++ b/pydantic_graph/pydantic_graph/nodes.py @@ -54,10 +54,12 @@ async def run(self, ctx: GraphContext[StateT]) -> BaseNode[StateT, Any] | End[No @classmethod @cache def get_id(cls) -> str: + """Get the ID of the node.""" return cls.__name__ @classmethod def get_note(cls) -> str | None: + """Get a note about the node to render on mermaid charts.""" if not cls.enable_docstring_notes: return None docstring = cls.__doc__ @@ -73,6 +75,7 @@ def get_note(cls) -> str | None: @classmethod def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, NodeRunEndT]: + """Get the node definition.""" type_hints = get_type_hints(cls.run, localns=local_ns, include_extras=True) try: return_hint = type_hints['return'] diff --git a/tests/test_examples.py b/tests/test_examples.py index 26ec8ae9a0..8de53ba0e4 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -167,6 +167,10 
@@ def rich_prompt_ask(prompt: str, *_args: Any, **_kwargs: Any) -> str: return '1' elif prompt == 'Select product': return 'crisps' + elif prompt == 'What is the capital of France?': + return 'Vichy' + elif prompt == 'what is 1 + 1?': + return '2' else: # pragma: no cover raise ValueError(f'Unexpected prompt: {prompt}') @@ -256,6 +260,15 @@ def rich_prompt_ask(prompt: str, *_args: Any, **_kwargs: Any) -> str: tool_name='final_result_SeatPreference', args=ArgsDict({'row': 1, 'seat': 'A'}), ), + 'Ask a simple question with a single correct answer.': 'What is the capital of France?', + '\n What is the capital of France?\n Vichy\n': ToolCallPart( + tool_name='final_result', + args=ArgsDict({'correct': False, 'comment': 'Vichy is no longer the capital of France.'}), + ), + '\n what is 1 + 1?\n 2\n': ToolCallPart( + tool_name='final_result', + args=ArgsDict({'correct': True, 'comment': 'Well done, 1 + 1 = 2'}), + ), } @@ -284,6 +297,8 @@ async def model_logic(messages: list[ModelMessage], info: AgentInfo) -> ModelRes ) elif m.content.startswith('\n '): return ModelResponse(parts=[ToolCallPart(tool_name='final_result_EmailOk', args=ArgsDict({}))]) + elif m.content == 'Ask a simple question with a single correct answer.' and len(messages) > 2: + return ModelResponse.from_text(content='what is 1 + 1?') elif response := text_responses.get(m.content): if isinstance(response, str): return ModelResponse.from_text(content=response) From 9a7ab8ef2ac8e1bd9fd944bc01a8af4051e50a80 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 14 Jan 2025 00:36:13 +0000 Subject: [PATCH 45/57] fix history serialization --- .../pydantic_ai_examples/question_graph.py | 21 +++--- pydantic_graph/pydantic_graph/graph.py | 46 ++++++++++++- pydantic_graph/pydantic_graph/nodes.py | 9 ++- pydantic_graph/pydantic_graph/state.py | 66 +++++++++++++++---- pydantic_graph/pyproject.toml | 1 + uv.lock | 2 + 6 files changed, 117 insertions(+), 28 deletions(-) diff --git a/examples/pydantic_ai_examples/question_graph.py b/examples/pydantic_ai_examples/question_graph.py index 1ad7649d75..1681eba184 100644 --- a/examples/pydantic_ai_examples/question_graph.py +++ b/examples/pydantic_ai_examples/question_graph.py @@ -7,12 +7,12 @@ from __future__ import annotations as _annotations +import copy from dataclasses import dataclass, field from pathlib import Path from typing import Annotated import logfire -import pydantic from devtools import debug from pydantic_graph import BaseNode, Edge, End, Graph, GraphContext, HistoryStep @@ -109,7 +109,9 @@ async def run(self, ctx: GraphContext[QuestionState]) -> Ask: return Ask() -question_graph = Graph(nodes=(Ask, Answer, Evaluate, Congratulate, Reprimand)) +question_graph = Graph( + nodes=(Ask, Answer, Evaluate, Congratulate, Reprimand), state_type=QuestionState +) async def run_as_continuous(): @@ -128,27 +130,23 @@ async def run_as_continuous(): # otherwise just continue -ta = pydantic.TypeAdapter( - list[Annotated[HistoryStep[QuestionState, None], pydantic.Discriminator('kind')]] -) - - async def run_as_cli(answer: str | None): history_file = Path('question_graph_history.json') if history_file.exists(): - history = ta.validate_json(history_file.read_bytes()) + history = question_graph.load_history(history_file.read_bytes()) else: history = [] if history: last = history[-1] - assert last.kind != 'node', 'expected last step to be a node' + assert last.kind == 'node', 'expected last step to be a node' state = last.state assert answer is not None, 'answer is required to continue from history' node 
= Answer(answer) else: state = QuestionState() node = Ask() + debug(state, node) with logfire.span('run questions graph'): while True: @@ -158,10 +156,13 @@ async def run_as_cli(answer: str | None): print('Finished!') break elif isinstance(node, Answer): + print(state.question) + # hack - history should show the state at the end of the node + history[-1].state = copy.deepcopy(state) break # otherwise just continue - history_file.write_bytes(ta.dump_json(history)) + history_file.write_bytes(question_graph.dump_history(history, indent=2)) if __name__ == '__main__': diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 8085307f85..0d1c02e362 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -4,18 +4,20 @@ import inspect from collections.abc import Sequence from dataclasses import dataclass +from functools import cached_property from pathlib import Path from time import perf_counter from types import FrameType -from typing import TYPE_CHECKING, Any, Callable, Generic +from typing import TYPE_CHECKING, Annotated, Any, Callable, Generic import logfire_api +import pydantic from typing_extensions import Literal, Unpack, assert_never from . import _utils, exceptions, mermaid from ._utils import get_parent_namespace from .nodes import BaseNode, End, GraphContext, NodeDef, RunEndT -from .state import EndStep, HistoryStep, NodeStep, StateT, deep_copy_state +from .state import EndStep, HistoryStep, NodeStep, StateT, deep_copy_state, nodes_schema_var __all__ = ('Graph',) @@ -66,6 +68,7 @@ async def run(self, ctx: GraphContext) -> Increment | End: from the graph. """ + state_type: type[StateT] | None name: str | None node_defs: dict[str, NodeDef[StateT, RunEndT]] snapshot_state: Callable[[StateT], StateT] @@ -74,6 +77,7 @@ def __init__( self, *, nodes: Sequence[type[BaseNode[StateT, RunEndT]]], + state_type: type[StateT] | None = None, name: str | None = None, snapshot_state: Callable[[StateT], StateT] = deep_copy_state, ): @@ -82,12 +86,14 @@ def __init__( Args: nodes: The nodes which make up the graph, nodes need to be unique and all be generic in the same state type. + state_type: The type of the state for the graph. name: Optional name for the graph, if not provided the name will be inferred from the calling frame on the first call to a graph method. snapshot_state: A function to snapshot the state of the graph, this is used in [`NodeStep`][pydantic_graph.state.NodeStep] and [`EndStep`][pydantic_graph.state.EndStep] to record the state before each step. """ + self.state_type = state_type self.name = name self.snapshot_state = snapshot_state @@ -282,6 +288,42 @@ async def next( history_step.duration = perf_counter() - start return next_node + def dump_history(self, history: list[HistoryStep[StateT, RunEndT]], *, indent: int | None = None) -> bytes: + """Dump the history of a graph run as JSON. + + Args: + history: The history of the graph run. + indent: The number of spaces to indent the JSON. + + Returns: + The JSON representation of the history. + """ + return self.history_type_adapter.dump_json(history, indent=indent) + + def load_history(self, json_bytes: str | bytes | bytearray) -> list[HistoryStep[StateT, RunEndT]]: + """Load the history of a graph run from JSON. + + Args: + json_bytes: The JSON representation of the history. + + Returns: + The history of the graph run. 
+ """ + return self.history_type_adapter.validate_json(json_bytes) + + @cached_property + def history_type_adapter(self) -> pydantic.TypeAdapter[list[HistoryStep[StateT, RunEndT]]]: + nodes = [node_def.node for node_def in self.node_defs.values()] + token = nodes_schema_var.set(nodes) + try: + # TODO get the return type RunEndT + ta = pydantic.TypeAdapter( + list[Annotated[HistoryStep[self.state_type or Any, Any], pydantic.Discriminator('kind')]], + ) + finally: + nodes_schema_var.reset(token) + return ta # type: ignore + def mermaid_code( self, *, diff --git a/pydantic_graph/pydantic_graph/nodes.py b/pydantic_graph/pydantic_graph/nodes.py index 62848fa599..fff06f55f9 100644 --- a/pydantic_graph/pydantic_graph/nodes.py +++ b/pydantic_graph/pydantic_graph/nodes.py @@ -3,12 +3,16 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, is_dataclass from functools import cache -from typing import Any, ClassVar, Generic, get_origin, get_type_hints +from typing import TYPE_CHECKING, Any, ClassVar, Generic, get_origin, get_type_hints from typing_extensions import Never, TypeVar from . import _utils, exceptions -from .state import StateT + +if TYPE_CHECKING: + from .state import StateT +else: + StateT = TypeVar('StateT', default=None) __all__ = 'GraphContext', 'BaseNode', 'End', 'Edge', 'NodeDef', 'RunEndT', 'NodeRunEndT' @@ -26,6 +30,7 @@ class GraphContext(Generic[StateT]): """The state of the graph.""" +@dataclass class BaseNode(ABC, Generic[StateT, NodeRunEndT]): """Base class for a node.""" diff --git a/pydantic_graph/pydantic_graph/state.py b/pydantic_graph/pydantic_graph/state.py index 174607362c..b828d1f9de 100644 --- a/pydantic_graph/pydantic_graph/state.py +++ b/pydantic_graph/pydantic_graph/state.py @@ -1,21 +1,21 @@ from __future__ import annotations as _annotations import copy -from dataclasses import InitVar, dataclass, field +from collections.abc import Sequence +from contextvars import ContextVar +from dataclasses import dataclass, field from datetime import datetime -from typing import TYPE_CHECKING, Callable, Generic, Literal, Union +from typing import Annotated, Any, Callable, Generic, Literal, Union +import pydantic +from pydantic_core import core_schema from typing_extensions import TypeVar from . 
import _utils +from .nodes import BaseNode, End, RunEndT __all__ = 'StateT', 'NodeStep', 'EndStep', 'HistoryStep', 'deep_copy_state' -if TYPE_CHECKING: - from .nodes import BaseNode, End, RunEndT -else: - RunEndT = TypeVar('RunEndT', default=None) - StateT = TypeVar('StateT', default=None) """Type variable for the state in a graph.""" @@ -35,7 +35,8 @@ class NodeStep(Generic[StateT, RunEndT]): state: StateT """The state of the graph before the node is run.""" - node: BaseNode[StateT, RunEndT] + # node: Annotated[BaseNode[StateT, RunEndT], pydantic.WrapSerializer(node_serializer), pydantic.PlainValidator(node_validator)] + node: Annotated[BaseNode[StateT, RunEndT], CustomNodeSchema()] """The node that was run.""" start_ts: datetime = field(default_factory=_utils.now_utc) """The timestamp when the node started running.""" @@ -43,12 +44,13 @@ class NodeStep(Generic[StateT, RunEndT]): """The duration of the node run in seconds.""" kind: Literal['node'] = 'node' """The kind of history step, can be used as a discriminator when deserializing history.""" - snapshot_state: InitVar[Callable[[StateT], StateT]] = deep_copy_state + # waiting for https://github.com/pydantic/pydantic/issues/11264, should in InitVar + snapshot_state: Annotated[Callable[[StateT], StateT], pydantic.Field(exclude=True)] = deep_copy_state """Function to snapshot the state of the graph.""" - def __post_init__(self, snapshot_state: Callable[[StateT], StateT]): + def __post_init__(self): # Copy the state to prevent it from being modified by other code - self.state = snapshot_state(self.state) + self.state = self.snapshot_state(self.state) def data_snapshot(self) -> BaseNode[StateT, RunEndT]: """Returns a deep copy of [`self.node`][pydantic_graph.state.NodeStep.node]. @@ -70,12 +72,13 @@ class EndStep(Generic[StateT, RunEndT]): """The timestamp when the graph run ended.""" kind: Literal['end'] = 'end' """The kind of history step, can be used as a discriminator when deserializing history.""" - snapshot_state: InitVar[Callable[[StateT], StateT]] = deep_copy_state + # waiting for https://github.com/pydantic/pydantic/issues/11264, should in InitVar + snapshot_state: Annotated[Callable[[StateT], StateT], pydantic.Field(exclude=True)] = deep_copy_state """Function to snapshot the state of the graph.""" - def __post_init__(self, snapshot_state: Callable[[StateT], StateT]): + def __post_init__(self): # Copy the state to prevent it from being modified by other code - self.state = snapshot_state(self.state) + self.state = self.snapshot_state(self.state) def data_snapshot(self) -> End[RunEndT]: """Returns a deep copy of [`self.result`][pydantic_graph.state.EndStep.result]. @@ -91,3 +94,38 @@ def data_snapshot(self) -> End[RunEndT]: [`Graph.run`][pydantic_graph.graph.Graph.run] returns a list of these steps describing the execution of the graph, together with the run return value. """ + + +nodes_schema_var: ContextVar[Sequence[type[BaseNode[Any, Any]]]] = ContextVar('nodes_var') + + +class CustomNodeSchema: + def __get_pydantic_core_schema__( + self, _source_type: Any, handler: pydantic.GetCoreSchemaHandler + ) -> core_schema.CoreSchema: + try: + nodes = nodes_schema_var.get() + except LookupError as e: + raise RuntimeError( + 'Unable to build a Pydantic schema for `NodeStep` or `HistoryStep` without setting `nodes_schema_var`. 
' + 'You probably want to use ' + ) from e + nodes_annotated = [Annotated[node, pydantic.Tag(node.get_id())] for node in nodes] + nodes_union = Annotated[Union[tuple(nodes_annotated)], pydantic.Discriminator(self._node_discriminator)] + + schema = handler(nodes_union) + schema['serialization'] = core_schema.wrap_serializer_function_ser_schema( + function=self._node_serializer, + return_schema=core_schema.dict_schema(core_schema.str_schema(), core_schema.any_schema()), + ) + return schema + + @staticmethod + def _node_discriminator(node_data: Any) -> str: + return node_data.get('node_id') + + @staticmethod + def _node_serializer(node: Any, handler: pydantic.SerializerFunctionWrapHandler) -> dict[str, Any]: + node_dict = handler(node) + node_dict['node_id'] = node.get_id() + return node_dict diff --git a/pydantic_graph/pyproject.toml b/pydantic_graph/pyproject.toml index 89cbad516c..743814f6b6 100644 --- a/pydantic_graph/pyproject.toml +++ b/pydantic_graph/pyproject.toml @@ -36,6 +36,7 @@ requires-python = ">=3.9" dependencies = [ "httpx>=0.27.2", "logfire-api>=1.2.0", + "pydantic>=2.10", ] [tool.hatch.build.targets.wheel] diff --git a/uv.lock b/uv.lock index 89e2fb1456..8d42a60d0b 100644 --- a/uv.lock +++ b/uv.lock @@ -2726,12 +2726,14 @@ source = { editable = "pydantic_graph" } dependencies = [ { name = "httpx" }, { name = "logfire-api" }, + { name = "pydantic" }, ] [package.metadata] requires-dist = [ { name = "httpx", specifier = ">=0.27.2" }, { name = "logfire-api", specifier = ">=1.2.0" }, + { name = "pydantic", specifier = ">=2.10" }, ] [[package]] From 4cd9142dedf0c2e5fe83fb9bd8d30774c2c24ddf Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 14 Jan 2025 18:21:50 +0000 Subject: [PATCH 46/57] add history (de)serialization tests --- pydantic_graph/pydantic_graph/graph.py | 68 ++------------------------ pydantic_graph/pydantic_graph/nodes.py | 1 - pydantic_graph/pydantic_graph/state.py | 4 +- tests/graph/test_history.py | 52 ++++++++++++++++++++ 4 files changed, 59 insertions(+), 66 deletions(-) create mode 100644 tests/graph/test_history.py diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 0d1c02e362..2effdd8dfd 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -132,75 +132,15 @@ async def main(): _, history = await never_42_graph.run(state, Increment()) print(state) #> MyState(number=2) - print(history) - ''' - [ - NodeStep( - state=MyState(number=1), - node=Increment(), - start_ts=datetime.datetime(...), - duration=0.0..., - kind='node', - ), - NodeStep( - state=MyState(number=2), - node=Check42(), - start_ts=datetime.datetime(...), - duration=0.0..., - kind='node', - ), - EndStep( - state=MyState(number=2), - result=End(data=None), - ts=datetime.datetime(...), - kind='end', - ), - ] - ''' + print(len(history)) + #> 3 state = MyState(41) _, history = await never_42_graph.run(state, Increment()) print(state) #> MyState(number=43) - print(history) - ''' - [ - NodeStep( - state=MyState(number=41), - node=Increment(), - start_ts=datetime.datetime(...), - duration=0.0..., - kind='node', - ), - NodeStep( - state=MyState(number=42), - node=Check42(), - start_ts=datetime.datetime(...), - duration=0.0..., - kind='node', - ), - NodeStep( - state=MyState(number=42), - node=Increment(), - start_ts=datetime.datetime(...), - duration=0.0..., - kind='node', - ), - NodeStep( - state=MyState(number=43), - node=Check42(), - start_ts=datetime.datetime(...), - duration=0.0..., - kind='node', - ), - 
EndStep( - state=MyState(number=43), - result=End(data=None), - ts=datetime.datetime(...), - kind='end', - ), - ] - ''' + print(len(history)) + #> 5 ``` """ if infer_name and self.name is None: diff --git a/pydantic_graph/pydantic_graph/nodes.py b/pydantic_graph/pydantic_graph/nodes.py index fff06f55f9..ef826bb2d0 100644 --- a/pydantic_graph/pydantic_graph/nodes.py +++ b/pydantic_graph/pydantic_graph/nodes.py @@ -30,7 +30,6 @@ class GraphContext(Generic[StateT]): """The state of the graph.""" -@dataclass class BaseNode(ABC, Generic[StateT, NodeRunEndT]): """Base class for a node.""" diff --git a/pydantic_graph/pydantic_graph/state.py b/pydantic_graph/pydantic_graph/state.py index b828d1f9de..c145be54c4 100644 --- a/pydantic_graph/pydantic_graph/state.py +++ b/pydantic_graph/pydantic_graph/state.py @@ -45,7 +45,9 @@ class NodeStep(Generic[StateT, RunEndT]): kind: Literal['node'] = 'node' """The kind of history step, can be used as a discriminator when deserializing history.""" # waiting for https://github.com/pydantic/pydantic/issues/11264, should in InitVar - snapshot_state: Annotated[Callable[[StateT], StateT], pydantic.Field(exclude=True)] = deep_copy_state + snapshot_state: Annotated[Callable[[StateT], StateT], pydantic.Field(exclude=True, repr=False)] = field( + default=deep_copy_state, repr=False + ) """Function to snapshot the state of the graph.""" def __post_init__(self): diff --git a/tests/graph/test_history.py b/tests/graph/test_history.py new file mode 100644 index 0000000000..ef4a455d0e --- /dev/null +++ b/tests/graph/test_history.py @@ -0,0 +1,52 @@ +from __future__ import annotations as _annotations + +from dataclasses import dataclass +from datetime import timezone + +import pytest +from inline_snapshot import snapshot + +from pydantic_graph import BaseNode, End, EndStep, Graph, GraphContext, NodeStep + +from ..conftest import IsFloat, IsNow + +pytestmark = pytest.mark.anyio + + +@dataclass +class MyState: + x: int + y: str + + +@dataclass +class Foo(BaseNode[MyState]): + async def run(self, ctx: GraphContext[MyState]) -> Bar: + ctx.state.x += 1 + return Bar() + + +@dataclass +class Bar(BaseNode[MyState, str]): + async def run(self, ctx: GraphContext[MyState]) -> End[str]: + ctx.state.y += 'y' + return End(f'x={ctx.state.x} y={ctx.state.y}') + + +graph = Graph(nodes=(Foo, Bar), state_type=MyState) + + +async def test_dump_history(): + result, history = await graph.run(MyState(1, ''), Foo()) + assert result == snapshot('x=2 y=y') + assert history == snapshot( + [ + NodeStep(state=MyState(x=1, y=''), node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), + NodeStep(state=MyState(x=2, y=''), node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), + EndStep(state=MyState(x=2, y='y'), result=End(data='x=2 y=y'), ts=IsNow(tz=timezone.utc)), + ] + ) + history_json = graph.dump_history(history) + assert history_json.startswith(b'[{"state":') + history_loaded = graph.load_history(history_json) + assert history == history_loaded From 598fdd6321335e9c5ce747421f8a84e90d006a75 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 14 Jan 2025 19:18:10 +0000 Subject: [PATCH 47/57] add mermaid diagram section to graph docs --- docs/graph.md | 83 ++++++++++++++++++++++++-- pydantic_graph/pydantic_graph/nodes.py | 17 ++++-- tests/graph/test_mermaid.py | 2 +- 3 files changed, 92 insertions(+), 10 deletions(-) diff --git a/docs/graph.md b/docs/graph.md index 1e2a8f2ffb..79ea118c8d 100644 --- a/docs/graph.md +++ b/docs/graph.md @@ -556,7 +556,6 @@ In this example, an AI 
asks the user a question, the user provides an answer, th async def run(self, ctx: GraphContext[QuestionState]) -> Ask: print(f'Comment: {self.comment}') - #> Comment: Vichy is no longer the capital of France. ctx.state.question = None return Ask() @@ -632,10 +631,84 @@ stateDiagram-v2 Reprimand --> Ask ``` -You maybe have noticed that although this examples transfers control flow out of the graph run, we're still using [rich's `Prompt.ask`][rich.prompt.PromptBase.ask] to get user input, with the process hanging while we wait for the user to enter a response. - -TODO +You maybe have noticed that although this examples transfers control flow out of the graph run, we're still using [rich's `Prompt.ask`][rich.prompt.PromptBase.ask] to get user input, with the process hanging while we wait for the user to enter a response. For an example of genuine out-of-process control flow, see the [question graph example](examples/question-graph.md). ## Mermaid Diagrams -TODO +Pydantic Graph, can generate [mermaid](https://mermaid.js.org/) [`stateDiagram-v2`](https://mermaid.js.org/syntax/stateDiagram.html) diagrams for graphs, as shown above. + +These diagrams can be generated with: + +* [`Graph.mermaid_code`][pydantic_graph.graph.Graph.mermaid_code] to generate the mermaid code for a graph +* [`Graph.mermaid_image`][pydantic_graph.graph.Graph.mermaid_image] to generate an image of the graph using [mermaid.ink](https://mermaid.ink/) +* ['Graph.mermaid_save`][pydantic_graph.graph.Graph.mermaid_save] to generate an image of the graph using [mermaid.ink](https://mermaid.ink/) and save it to a file + +Beyond the diagrams shown above, you can also customise mermaid diagrams with the following options: + +* [`Edge`][pydantic_graph.nodes.Edge] allows you to apply a label to an edge +* [`BaseNode.docstring_notes`][pydantic_graph.nodes.BaseNode.docstring_notes] and [`BaseNode.get_note`][pydantic_graph.nodes.BaseNode.get_note] allows you to add notes to nodes +* The [`highlighted_nodes`][pydantic_graph.graph.Graph.mermaid_code] parameter allows you to highlight specific node(s) in the diagram + +Putting that together, we can edit the last [`ai_q_and_a_graph.py`](#custom-control-flow) example to: + +* add labels to some edges +* add a note to the `Ask` node +* highlight the `Answer` node +* save the diagram as a `PNG` image to file + +```python {title="ai_q_and_a_graph_extra.py" test="skip" lint="skip" hl_lines="2 4 10-11 14 26 31"} +... +from typing import Annotated + +from pydantic_graph import BaseNode, End, Graph, GraphContext, Edge + +... + +@dataclass +class Ask(BaseNode[QuestionState]): + """Generate question using GPT-4o.""" + docstring_notes = True + async def run( + self, ctx: GraphContext[QuestionState] + ) -> Annotated[Answer, Edge(label='Ask the question')]: + ... + +... + +@dataclass +class Evaluate(BaseNode[QuestionState]): + answer: str + + async def run( + self, + ctx: GraphContext[QuestionState], + ) -> Annotated[End[str], Edge(label='success')] | Reprimand: + ... + +... + +question_graph.mermaid_save('image.png', highlighted_nodes=[Answer]) +``` + +_(This example is not complete and cannot be run directly)_ + +Would generate and image that looks like this: + +```mermaid +--- +title: question_graph +--- +stateDiagram-v2 + Ask --> Answer: Ask the question + note right of Ask + Judge the answer. + Decide on next step. 
+ end note + Answer --> Evaluate + Evaluate --> Reprimand + Evaluate --> [*]: success + Reprimand --> Ask + +classDef highlighted fill:#fdff32 +class Answer highlighted +``` diff --git a/pydantic_graph/pydantic_graph/nodes.py b/pydantic_graph/pydantic_graph/nodes.py index ef826bb2d0..11e793c82c 100644 --- a/pydantic_graph/pydantic_graph/nodes.py +++ b/pydantic_graph/pydantic_graph/nodes.py @@ -33,8 +33,13 @@ class GraphContext(Generic[StateT]): class BaseNode(ABC, Generic[StateT, NodeRunEndT]): """Base class for a node.""" - enable_docstring_notes: ClassVar[bool] = True - """Set to `False` to not generate mermaid diagram notes from the class's docstring.""" + docstring_notes: ClassVar[bool] = False + """Set to `True` to generate mermaid diagram notes from the class's docstring. + + While this can add valuable information to the diagram, it can make diagrams harder to view, hence + it is disabled by default. You can also customise notes overriding the + [`get_note`][pydantic_graph.nodes.BaseNode.get_note] method. + """ @abstractmethod async def run(self, ctx: GraphContext[StateT]) -> BaseNode[StateT, Any] | End[NodeRunEndT]: @@ -63,8 +68,12 @@ def get_id(cls) -> str: @classmethod def get_note(cls) -> str | None: - """Get a note about the node to render on mermaid charts.""" - if not cls.enable_docstring_notes: + """Get a note about the node to render on mermaid charts. + + By default, this returns a note only if [`docstring_notes`][pydantic_graph.nodes.BaseNode.docstring_notes] + is `True`. You can override this method to customise the node notes. + """ + if not cls.docstring_notes: return None docstring = cls.__doc__ # dataclasses get an automatic docstring which is just their signature, we don't want that diff --git a/tests/graph/test_mermaid.py b/tests/graph/test_mermaid.py index 1f40df5c8c..fc9a99c123 100644 --- a/tests/graph/test_mermaid.py +++ b/tests/graph/test_mermaid.py @@ -46,7 +46,7 @@ async def run(self, ctx: GraphContext) -> Annotated[Foo, Edge(label='spam to foo class Eggs(BaseNode[None, None]): """This is the docstring for Eggs.""" - enable_docstring_notes = False + docstring_notes = False async def run(self, ctx: GraphContext) -> Annotated[End[None], Edge(label='eggs to end')]: raise NotImplementedError() From bf8b824015caed834812cac5eba5080e1b33b3fd Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 14 Jan 2025 19:27:37 +0000 Subject: [PATCH 48/57] fix tests --- .github/workflows/ci.yml | 2 +- tests/graph/test_mermaid.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a3c7cf99a5..56003faa93 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -10,7 +10,7 @@ on: env: CI: true - COLUMNS: 120 + COLUMNS: 150 UV_PYTHON: 3.12 UV_FROZEN: '1' diff --git a/tests/graph/test_mermaid.py b/tests/graph/test_mermaid.py index fc9a99c123..2cf2b01423 100644 --- a/tests/graph/test_mermaid.py +++ b/tests/graph/test_mermaid.py @@ -38,6 +38,8 @@ async def run(self, ctx: GraphContext) -> End[None]: class Spam(BaseNode): """This is the docstring for Spam.""" + docstring_notes = True + async def run(self, ctx: GraphContext) -> Annotated[Foo, Edge(label='spam to foo')]: raise NotImplementedError() @@ -192,6 +194,11 @@ def test_mermaid_code_all_nodes(): """) +def test_docstring_notes_classvar(): + assert Spam.docstring_notes is True + assert repr(Spam()) == 'Spam()' + + @pytest.fixture def httpx_with_handler() -> Iterator[HttpxWithHandler]: client: httpx.Client | None = None From 
aec8fab041d207889abe5d2bf6bc9344e6528f6b Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 14 Jan 2025 19:31:52 +0000 Subject: [PATCH 49/57] add exceptions docs --- docs/api/pydantic_graph/exceptions.md | 3 +++ mkdocs.yml | 1 + 2 files changed, 4 insertions(+) create mode 100644 docs/api/pydantic_graph/exceptions.md diff --git a/docs/api/pydantic_graph/exceptions.md b/docs/api/pydantic_graph/exceptions.md new file mode 100644 index 0000000000..ab9f3782ee --- /dev/null +++ b/docs/api/pydantic_graph/exceptions.md @@ -0,0 +1,3 @@ +# `pydantic_graph.exceptions` + +::: pydantic_graph.exceptions diff --git a/mkdocs.yml b/mkdocs.yml index 3ff7b55503..c363e37ef8 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -61,6 +61,7 @@ nav: - api/pydantic_graph/nodes.md - api/pydantic_graph/state.md - api/pydantic_graph/mermaid.md + - api/pydantic_graph/exceptions.md extra: # hide the "Made with Material for MkDocs" message From 7d4f31d7f6bd3bea487d577aa5d974210598ef4b Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 14 Jan 2025 19:40:51 +0000 Subject: [PATCH 50/57] docs tweaks --- docs/graph.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/graph.md b/docs/graph.md index 79ea118c8d..0aa2156911 100644 --- a/docs/graph.md +++ b/docs/graph.md @@ -20,7 +20,7 @@ While this library is developed as part of the PydanticAI; it has no dependency `pydantic-graph` is designed for advanced users and makes heavy use of Python generics and types hints. It is not designed to be as beginner-friendly as PydanticAI. -!!! note "Every Early beta" +!!! note "Very Early beta" Graph support was [introduced](https://github.com/pydantic/pydantic-ai/pull/528) in v0.0.19 and is in very earlier beta. The API is subject to change. The documentation is incomplete. The implementation is incomplete. ## Installation @@ -89,7 +89,7 @@ class MyNode(BaseNode[MyState]): # (1)! 
We could extend `MyNode` to optionally end the run if `foo` is divisible by 5: -```py {title="intermediate_or_end_node.py" hl_lines="7 13" noqa="F821" test="skip"} +```py {title="intermediate_or_end_node.py" hl_lines="7 13 15" noqa="F821" test="skip"} from dataclasses import dataclass from pydantic_graph import BaseNode, End, GraphContext @@ -641,7 +641,7 @@ These diagrams can be generated with: * [`Graph.mermaid_code`][pydantic_graph.graph.Graph.mermaid_code] to generate the mermaid code for a graph * [`Graph.mermaid_image`][pydantic_graph.graph.Graph.mermaid_image] to generate an image of the graph using [mermaid.ink](https://mermaid.ink/) -* ['Graph.mermaid_save`][pydantic_graph.graph.Graph.mermaid_save] to generate an image of the graph using [mermaid.ink](https://mermaid.ink/) and save it to a file +* [`Graph.mermaid_save`][pydantic_graph.graph.Graph.mermaid_save] to generate an image of the graph using [mermaid.ink](https://mermaid.ink/) and save it to a file Beyond the diagrams shown above, you can also customise mermaid diagrams with the following options: From 79cb2b3db9dde8af935b7f113f8efb02ec39fe90 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 15 Jan 2025 13:19:47 +0000 Subject: [PATCH 51/57] copy edits from @dmontagu --- docs/graph.md | 32 ++++++++-------- docs/multi-agent-applications.md | 4 ++ pydantic_graph/README.md | 53 +++++++++++++++++++++++--- pydantic_graph/pydantic_graph/nodes.py | 2 +- pydantic_graph/pydantic_graph/state.py | 3 +- pydantic_graph/pyproject.toml | 6 +++ 6 files changed, 76 insertions(+), 24 deletions(-) diff --git a/docs/graph.md b/docs/graph.md index 0aa2156911..ba5e00f511 100644 --- a/docs/graph.md +++ b/docs/graph.md @@ -53,14 +53,14 @@ Subclasses of [`BaseNode`][pydantic_graph.nodes.BaseNode] to define nodes. Nodes which are generally [`dataclass`es][dataclasses.dataclass] include: -* any parameters required when calling the node -* the business logic to execute the node -* return annotations which are read by `pydantic-graph` to determine the outgoing edges of the node +* fields containing any parameters required/optional when calling the node +* the business logic to execute the node, in the [`run`][pydantic_graph.nodes.BaseNode.run] method +* return annotations of the [`run`][pydantic_graph.nodes.BaseNode.run] method, which are read by `pydantic-graph` to determine the outgoing edges of the node Nodes are generic in: -* **state** which must have the same type as the state of graphs they're included in, [`StateT`][pydantic_graph.state.StateT] has a default of `None`, so if you're not using state you can omit this generic parameter -* **graph return type** this only applies if the node returns [`End`][pydantic_graph.nodes.End], [`RunEndT`][pydantic_graph.nodes.RunEndT] has a default of [Never][typing.Never] so this generic parameter can be omitted if the node doesn't return `End`, but must be included if it does. +* **state**, which must have the same type as the state of graphs they're included in, [`StateT`][pydantic_graph.state.StateT] has a default of `None`, so if you're not using state you can omit this generic parameter +* **graph return type** — this only applies if the node returns [`End`][pydantic_graph.nodes.End]. [`RunEndT`][pydantic_graph.nodes.RunEndT] has a default of [Never][typing.Never] so this generic parameter can be omitted if the node doesn't return `End`, but must be included if it does. 
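To tie those pieces together before the more focused examples below, here is a compact, self-contained sketch of a two-node graph. The `CountState`, `Increment` and `CheckLimit` names are made up for illustration, and the calls follow the API as it stands in this PR (Python 3.10+ is needed for the `X | Y` return annotation):

```python
from __future__ import annotations

from dataclasses import dataclass

from pydantic_graph import BaseNode, End, Graph, GraphContext


@dataclass
class CountState:
    count: int


@dataclass
class Increment(BaseNode[CountState]):  # generic in state only, so this node can never end the run
    step: int = 1  # a plain dataclass field: callers write `Increment(step=2)`

    async def run(self, ctx: GraphContext[CountState]) -> CheckLimit:  # the return hint is the single outgoing edge
        ctx.state.count += self.step
        return CheckLimit()


@dataclass
class CheckLimit(BaseNode[CountState, int]):  # second parameter set because this node can return `End[int]`
    async def run(self, ctx: GraphContext[CountState]) -> Increment | End[int]:  # a union gives two outgoing edges
        if ctx.state.count >= 3:
            return End(ctx.state.count)
        else:
            return Increment()


count_graph = Graph(nodes=(Increment, CheckLimit))
result, history = count_graph.run_sync(CountState(count=0), Increment())
print(result)
#> 3
```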
Here's an example of a start or intermediate node in a graph — it can't end the run as it doesn't return [`End`][pydantic_graph.nodes.End]: @@ -114,7 +114,7 @@ class MyNode(BaseNode[MyState, int]): # (1)! ### Graph -[`Graph`][pydantic_graph.graph.Graph] — The graph itself, made up of a set of nodes. +[`Graph`][pydantic_graph.graph.Graph] — this is the execution graph itself, made up of a set of [node classes](#nodes) (i.e., `BaseNode` subclasses). `Graph` is generic in: @@ -284,7 +284,7 @@ async def main(): 2. A dictionary of products mapped to prices. 3. The `InsertCoin` node, [`BaseNode`][pydantic_graph.nodes.BaseNode] is parameterized with `MachineState` as that's the state used in this graph. 4. The `InsertCoin` node prompts the user to insert coins. We keep things simple by just entering a monetary amount as a float. Before you start thinking this is a toy too since it's using [rich's `Prompt.ask`][rich.prompt.PromptBase.ask] within nodes, see [below](#custom-control-flow) for how control flow can be managed when nodes require external input. -5. The `CoinsInserted` node, again this is a [`dataclass`][dataclasses.dataclass], in this case with one field `amount`, thus nodes calling `CoinsInserted` must provide an amount. +5. The `CoinsInserted` node; again this is a [`dataclass`][dataclasses.dataclass], in this case with one field `amount`, thus nodes calling `CoinsInserted` must provide an amount. 6. Update the user's balance with the amount inserted. 7. If the user has already selected a product, go to `Purchase`, otherwise go to `SelectProduct`. 8. In the `Purchase` node, look up the price of the product if the user entered a valid product. @@ -293,11 +293,11 @@ async def main(): 11. If the balance is insufficient, to go `InsertCoin` to prompt the user to insert more coins. 12. If the product is invalid, go to `SelectProduct` to prompt the user to select a product again. 13. The graph is created by passing a list of nodes to [`Graph`][pydantic_graph.graph.Graph]. Order of nodes is not important, but will alter how [diagrams](#mermaid-diagrams) are displayed. -14. Initialize the state, this will be passed to the graph run and mutated as the graph runs. -15. Run the graph with the initial state, since the graph can be run from any node, we must pass the start node, in this case `InsertCoin`. [`Graph.run`][pydantic_graph.graph.Graph.run] returns a tuple of the return value (`None`) in this case, and the [history][pydantic_graph.state.HistoryStep] of the graph run. -16. The return type of the node's [`run`][pydantic_graph.nodes.BaseNode.run] method is important, it's used to determine the outgoing edges of the node, this in turn is used to render [mermaid diagrams](#mermaid-diagrams) and is enforced at runtime. -17. The return type of `CoinsInserted`s [`run`][pydantic_graph.nodes.BaseNode.run] method is a union, meaning multiple outgoing edges are possible. -18. Unlike other nodes `Purchase` can end the run, so the [`RunEndT`][pydantic_graph.nodes.RunEndT] generic parameter must be set, in this case it's `None` since the graph run return type is `None`. +14. Initialize the state. This will be passed to the graph run and mutated as the graph runs. +15. Run the graph with the initial state. Since the graph can be run from any node, we must pass the start node — in this case, `InsertCoin`. [`Graph.run`][pydantic_graph.graph.Graph.run] returns a tuple of the return value (`None`) in this case, and the [history][pydantic_graph.state.HistoryStep] of the graph run. +16. 
The return type of the node's [`run`][pydantic_graph.nodes.BaseNode.run] method is important as it is used to determine the outgoing edges of the node. This information in turn is used to render [mermaid diagrams](#mermaid-diagrams) and is enforced at runtime to detect misbehavior as soon as possible. +17. The return type of `CoinsInserted`'s [`run`][pydantic_graph.nodes.BaseNode.run] method is a union, meaning multiple outgoing edges are possible. +18. Unlike other nodes, `Purchase` can end the run, so the [`RunEndT`][pydantic_graph.nodes.RunEndT] generic parameter must be set. In this case it's `None` since the graph run return type is `None`. _(This example is complete, it can be run "as is" with Python 3.10+ — you'll need to add `asyncio.run(main())` to run `main`)_ @@ -467,7 +467,7 @@ _(This example is complete, it can be run "as is" with Python 3.10+ — you'll n ## Custom Control Flow -In many real-world applications, Graphs cannot run uninterrupted from start to finish — they require external input or run over an extended period of time where a single process cannot execute the entire graph. +In many real-world applications, Graphs cannot run uninterrupted from start to finish — they might require external input, or run over an extended period of time such that a single process cannot execute the entire graph run from start to finish without interruption. In these scenarios the [`next`][pydantic_graph.graph.Graph.next] method can be used to run the graph one node at a time. @@ -603,7 +603,7 @@ async def main(): 1. Create the state object which will be mutated by [`next`][pydantic_graph.graph.Graph.next]. 2. The start node is `Ask` but will be updated by [`next`][pydantic_graph.graph.Graph.next] as the graph runs. -3. The history of the graph run is stored in a list of [`HistoryStep`][pydantic_graph.state.HistoryStep] objects, again [`next`][pydantic_graph.graph.Graph.next] will update this list in place. +3. The history of the graph run is stored in a list of [`HistoryStep`][pydantic_graph.state.HistoryStep] objects. Again [`next`][pydantic_graph.graph.Graph.next] will update this list in place. 4. [Run][pydantic_graph.graph.Graph.next] the graph one node at a time, updating the state, current node and history as the graph runs. 5. If the current node is an `Answer` node, prompt the user for an answer. 6. Since we're using [`next`][pydantic_graph.graph.Graph.next] we have to manually check for an [`End`][pydantic_graph.nodes.End] and exit the loop if we get one. @@ -635,7 +635,7 @@ You maybe have noticed that although this examples transfers control flow out of ## Mermaid Diagrams -Pydantic Graph, can generate [mermaid](https://mermaid.js.org/) [`stateDiagram-v2`](https://mermaid.js.org/syntax/stateDiagram.html) diagrams for graphs, as shown above. +Pydantic Graph can generate [mermaid](https://mermaid.js.org/) [`stateDiagram-v2`](https://mermaid.js.org/syntax/stateDiagram.html) diagrams for graphs, as shown above. 
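In the simplest case that is a single method call. Here's a brief sketch reusing the `ai_q_and_a_graph.py` example from above; the full set of methods and options is listed below, and it is assumed here that `mermaid_save` accepts the same keyword options as `mermaid_code`:

```python
from ai_q_and_a_graph import Ask, question_graph

# the mermaid source as a string, the same call used in ai_q_and_a_diagram.py above
print(question_graph.mermaid_code(start_node=Ask))

# assumed: mermaid_save takes a destination path plus the same keyword options as
# mermaid_code, and writes the image rendered via mermaid.ink to that path
question_graph.mermaid_save('question_graph_diagram.png', start_node=Ask)
```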
These diagrams can be generated with: @@ -643,7 +643,7 @@ These diagrams can be generated with: * [`Graph.mermaid_image`][pydantic_graph.graph.Graph.mermaid_image] to generate an image of the graph using [mermaid.ink](https://mermaid.ink/) * [`Graph.mermaid_save`][pydantic_graph.graph.Graph.mermaid_save] to generate an image of the graph using [mermaid.ink](https://mermaid.ink/) and save it to a file -Beyond the diagrams shown above, you can also customise mermaid diagrams with the following options: +Beyond the diagrams shown above, you can also customize mermaid diagrams with the following options: * [`Edge`][pydantic_graph.nodes.Edge] allows you to apply a label to an edge * [`BaseNode.docstring_notes`][pydantic_graph.nodes.BaseNode.docstring_notes] and [`BaseNode.get_note`][pydantic_graph.nodes.BaseNode.get_note] allows you to add notes to nodes diff --git a/docs/multi-agent-applications.md b/docs/multi-agent-applications.md index 1506883141..eecf43c608 100644 --- a/docs/multi-agent-applications.md +++ b/docs/multi-agent-applications.md @@ -330,6 +330,10 @@ graph TB seat_preference_agent --> END ``` +## Pydantic Graphs + +See the [graph](graph.md) documentation on when and how to use graphs. + ## Examples The following examples demonstrate how to use dependencies in PydanticAI: diff --git a/pydantic_graph/README.md b/pydantic_graph/README.md index 069f58924d..e01f9bf80d 100644 --- a/pydantic_graph/README.md +++ b/pydantic_graph/README.md @@ -3,14 +3,57 @@ [![CI](https://github.com/pydantic/pydantic-ai/actions/workflows/ci.yml/badge.svg?event=push)](https://github.com/pydantic/pydantic-ai/actions/workflows/ci.yml?query=branch%3Amain) [![Coverage](https://coverage-badge.samuelcolvin.workers.dev/pydantic/pydantic-ai.svg)](https://coverage-badge.samuelcolvin.workers.dev/redirect/pydantic/pydantic-ai) [![PyPI](https://img.shields.io/pypi/v/pydantic-graph.svg)](https://pypi.python.org/pypi/pydantic-graph) -[![versions](https://img.shields.io/pypi/pyversions/pydantic-graph.svg)](https://github.com/pydantic/pydantic-ai) -[![license](https://img.shields.io/github/license/pydantic/pydantic-ai.svg?v)](https://github.com/pydantic/pydantic-ai/blob/main/LICENSE) +[![python versions](https://img.shields.io/pypi/pyversions/pydantic-graph.svg)](https://github.com/pydantic/pydantic-ai) +[![license](https://img.shields.io/github/license/pydantic/pydantic-ai.svg)](https://github.com/pydantic/pydantic-ai/blob/main/LICENSE) Graph and finite state machine library. -This library is developed as part of the [PydanticAI](https://ai.pydantic.dev), however it has no dependency -on `pydantic-ai` or related packages and can be considered as a pure graph library. +This library is developed as part of [PydanticAI](https://ai.pydantic.dev), however it has no dependency +on `pydantic-ai` or related packages and can be considered as a pure graph-based state machine library. You may find it useful whether or not you're using PydanticAI or even building with GenAI. As with PydanticAI, this library prioritizes type safety and use of common Python syntax over esoteric, domain-specific use of Python syntax. -`pydantic-graph` allows you to define graphs using simple Python syntax. In particular, edges are defined using the return type hint of nodes. +`pydantic-graph` allows you to define graphs using standard Python syntax. In particular, edges are defined using the return type hint of nodes. + +Full documentation is available at [ai.pydantic.dev/graph](https://ai.pydantic.dev/graph). 
+ +Here's a basic example: + +```python +from __future__ import annotations + +from dataclasses import dataclass + +from pydantic_graph import BaseNode, End, Graph, GraphContext + + +@dataclass +class DivisibleBy5(BaseNode[None, int]): + foo: int + + async def run( + self, + ctx: GraphContext, + ) -> Increment | End[int]: + if self.foo % 5 == 0: + return End(self.foo) + else: + return Increment(self.foo) + + +@dataclass +class Increment(BaseNode): + foo: int + + async def run(self, ctx: GraphContext) -> DivisibleBy5: + return DivisibleBy5(self.foo + 1) + + +fives_graph = Graph(nodes=[DivisibleBy5, Increment]) +result, history = fives_graph.run_sync(None, DivisibleBy5(4)) +print(result) +#> 5 +# the full history is quite verbose (see below), so we'll just print the summary +print([item.data_snapshot() for item in history]) +#> [DivisibleBy5(foo=4), Increment(foo=4), DivisibleBy5(foo=5), End(data=5)] +``` diff --git a/pydantic_graph/pydantic_graph/nodes.py b/pydantic_graph/pydantic_graph/nodes.py index 11e793c82c..ec97b16e84 100644 --- a/pydantic_graph/pydantic_graph/nodes.py +++ b/pydantic_graph/pydantic_graph/nodes.py @@ -142,7 +142,7 @@ class Edge: class NodeDef(Generic[StateT, NodeRunEndT]): """Definition of a node. - This is an internal representation of a node, it shouldn't be necessary to use it directly. + This is a primarily internal representation of a node; in general, it shouldn't be necessary to use it directly. Used by [`Graph`][pydantic_graph.graph.Graph] to store information about a node, and when generating mermaid graphs. diff --git a/pydantic_graph/pydantic_graph/state.py b/pydantic_graph/pydantic_graph/state.py index c145be54c4..b1b596b855 100644 --- a/pydantic_graph/pydantic_graph/state.py +++ b/pydantic_graph/pydantic_graph/state.py @@ -35,7 +35,6 @@ class NodeStep(Generic[StateT, RunEndT]): state: StateT """The state of the graph before the node is run.""" - # node: Annotated[BaseNode[StateT, RunEndT], pydantic.WrapSerializer(node_serializer), pydantic.PlainValidator(node_validator)] node: Annotated[BaseNode[StateT, RunEndT], CustomNodeSchema()] """The node that was run.""" start_ts: datetime = field(default_factory=_utils.now_utc) @@ -44,7 +43,7 @@ class NodeStep(Generic[StateT, RunEndT]): """The duration of the node run in seconds.""" kind: Literal['node'] = 'node' """The kind of history step, can be used as a discriminator when deserializing history.""" - # waiting for https://github.com/pydantic/pydantic/issues/11264, should in InitVar + # waiting for https://github.com/pydantic/pydantic/issues/11264, should be an InitVar snapshot_state: Annotated[Callable[[StateT], StateT], pydantic.Field(exclude=True, repr=False)] = field( default=deep_copy_state, repr=False ) diff --git a/pydantic_graph/pyproject.toml b/pydantic_graph/pyproject.toml index 743814f6b6..948e48e462 100644 --- a/pydantic_graph/pyproject.toml +++ b/pydantic_graph/pyproject.toml @@ -39,5 +39,11 @@ dependencies = [ "pydantic>=2.10", ] +[project.urls] +Homepage = "https://ai.pydantic.dev/graph/tree/main/pydantic_graph" +Source = "https://github.com/pydantic/pydantic-ai" +Documentation = "https://ai.pydantic.dev/graph" +Changelog = "https://github.com/pydantic/pydantic-ai/releases" + [tool.hatch.build.targets.wheel] packages = ["pydantic_graph"] From 38a787e78364f2d75aa5ff8fac768e553af80702 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 15 Jan 2025 14:01:41 +0000 Subject: [PATCH 52/57] fix pydantic-graph readme --- pydantic_graph/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff 
--git a/pydantic_graph/README.md b/pydantic_graph/README.md index e01f9bf80d..6ac17fb216 100644 --- a/pydantic_graph/README.md +++ b/pydantic_graph/README.md @@ -19,7 +19,7 @@ Full documentation is available at [ai.pydantic.dev/graph](https://ai.pydantic.d Here's a basic example: -```python +```python {noqa="I001" py="3.10"} from __future__ import annotations from dataclasses import dataclass From 6db58d6bffbc6987314925e5bd7f8c716eb0a848 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 15 Jan 2025 15:32:02 +0000 Subject: [PATCH 53/57] snapshot state after node execution, not before --- .gitignore | 1 + examples/pydantic_ai_examples/question_graph.py | 12 +++++------- pydantic_graph/pydantic_graph/graph.py | 12 +++++++----- pydantic_graph/pydantic_graph/state.py | 17 ++++------------- tests/graph/test_graph.py | 12 ++---------- tests/graph/test_history.py | 6 +++--- tests/graph/test_mermaid.py | 2 +- tests/graph/test_state.py | 15 ++++++--------- 8 files changed, 29 insertions(+), 48 deletions(-) diff --git a/.gitignore b/.gitignore index 73c84030ba..5813be689e 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ env*/ examples/pydantic_ai_examples/.chat_app_messages.sqlite .cache/ .vscode/ +/question_graph_history.json diff --git a/examples/pydantic_ai_examples/question_graph.py b/examples/pydantic_ai_examples/question_graph.py index 1681eba184..f73dbf1684 100644 --- a/examples/pydantic_ai_examples/question_graph.py +++ b/examples/pydantic_ai_examples/question_graph.py @@ -7,7 +7,6 @@ from __future__ import annotations as _annotations -import copy from dataclasses import dataclass, field from pathlib import Path from typing import Annotated @@ -132,10 +131,11 @@ async def run_as_continuous(): async def run_as_cli(answer: str | None): history_file = Path('question_graph_history.json') - if history_file.exists(): - history = question_graph.load_history(history_file.read_bytes()) - else: - history = [] + history = ( + question_graph.load_history(history_file.read_bytes()) + if history_file.exists() + else [] + ) if history: last = history[-1] @@ -157,8 +157,6 @@ async def run_as_cli(answer: str | None): break elif isinstance(node, Answer): print(state.question) - # hack - history should show the state at the end of the node - history[-1].state = copy.deepcopy(state) break # otherwise just continue diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 2effdd8dfd..a4c31271a1 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -155,7 +155,7 @@ async def main(): while True: next_node = await self.next(state, start_node, history, infer_name=False) if isinstance(next_node, End): - history.append(EndStep(state, next_node, snapshot_state=self.snapshot_state)) + history.append(EndStep(result=next_node)) run_span.set_attribute('history', history) return next_node.data, history elif isinstance(next_node, BaseNode): @@ -218,14 +218,16 @@ async def next( if node_id not in self.node_defs: raise exceptions.GraphRuntimeError(f'Node `{node}` is not in the graph.') - history_step: NodeStep[StateT, RunEndT] = NodeStep(state, node) - history.append(history_step) - ctx = GraphContext(state) with _logfire.span('run node {node_id}', node_id=node_id, node=node): + start_ts = _utils.now_utc() start = perf_counter() next_node = await node.run(ctx) - history_step.duration = perf_counter() - start + duration = perf_counter() - start + + history.append( + NodeStep(state=state, node=node, start_ts=start_ts, 
duration=duration, snapshot_state=self.snapshot_state) + ) return next_node def dump_history(self, history: list[HistoryStep[StateT, RunEndT]], *, indent: int | None = None) -> bytes: diff --git a/pydantic_graph/pydantic_graph/state.py b/pydantic_graph/pydantic_graph/state.py index b1b596b855..52cf8a7f7b 100644 --- a/pydantic_graph/pydantic_graph/state.py +++ b/pydantic_graph/pydantic_graph/state.py @@ -34,7 +34,7 @@ class NodeStep(Generic[StateT, RunEndT]): """History step describing the execution of a node in a graph.""" state: StateT - """The state of the graph before the node is run.""" + """The state of the graph after the node has been run.""" node: Annotated[BaseNode[StateT, RunEndT], CustomNodeSchema()] """The node that was run.""" start_ts: datetime = field(default_factory=_utils.now_utc) @@ -43,7 +43,7 @@ class NodeStep(Generic[StateT, RunEndT]): """The duration of the node run in seconds.""" kind: Literal['node'] = 'node' """The kind of history step, can be used as a discriminator when deserializing history.""" - # waiting for https://github.com/pydantic/pydantic/issues/11264, should be an InitVar + # TODO waiting for https://github.com/pydantic/pydantic/issues/11264, should be an InitVar snapshot_state: Annotated[Callable[[StateT], StateT], pydantic.Field(exclude=True, repr=False)] = field( default=deep_copy_state, repr=False ) @@ -62,24 +62,15 @@ def data_snapshot(self) -> BaseNode[StateT, RunEndT]: @dataclass -class EndStep(Generic[StateT, RunEndT]): +class EndStep(Generic[RunEndT]): """History step describing the end of a graph run.""" - state: StateT - """The state of the graph after the run.""" result: End[RunEndT] """The result of the graph run.""" ts: datetime = field(default_factory=_utils.now_utc) """The timestamp when the graph run ended.""" kind: Literal['end'] = 'end' """The kind of history step, can be used as a discriminator when deserializing history.""" - # waiting for https://github.com/pydantic/pydantic/issues/11264, should in InitVar - snapshot_state: Annotated[Callable[[StateT], StateT], pydantic.Field(exclude=True)] = deep_copy_state - """Function to snapshot the state of the graph.""" - - def __post_init__(self): - # Copy the state to prevent it from being modified by other code - self.state = self.snapshot_state(self.state) def data_snapshot(self) -> End[RunEndT]: """Returns a deep copy of [`self.result`][pydantic_graph.state.EndStep.result]. @@ -89,7 +80,7 @@ def data_snapshot(self) -> End[RunEndT]: return copy.deepcopy(self.result) -HistoryStep = Union[NodeStep[StateT, RunEndT], EndStep[StateT, RunEndT]] +HistoryStep = Union[NodeStep[StateT, RunEndT], EndStep[RunEndT]] """A step in the history of a graph run. 
[`Graph.run`][pydantic_graph.graph.Graph.run] returns a list of these steps describing the execution of the graph, diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py index 5549ffaa8b..0f226857da 100644 --- a/tests/graph/test_graph.py +++ b/tests/graph/test_graph.py @@ -78,11 +78,7 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - EndStep( - state=None, - result=End(data=8), - ts=IsNow(tz=timezone.utc), - ), + EndStep(result=End(data=8), ts=IsNow(tz=timezone.utc)), ] ) result, history = await my_graph.run(None, Float2String(3.14159)) @@ -120,11 +116,7 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - EndStep( - state=None, - result=End(data=42), - ts=IsNow(tz=timezone.utc), - ), + EndStep(result=End(data=42), ts=IsNow(tz=timezone.utc)), ] ) assert [e.data_snapshot() for e in history] == snapshot( diff --git a/tests/graph/test_history.py b/tests/graph/test_history.py index ef4a455d0e..6dff468a3e 100644 --- a/tests/graph/test_history.py +++ b/tests/graph/test_history.py @@ -41,9 +41,9 @@ async def test_dump_history(): assert result == snapshot('x=2 y=y') assert history == snapshot( [ - NodeStep(state=MyState(x=1, y=''), node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), - NodeStep(state=MyState(x=2, y=''), node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), - EndStep(state=MyState(x=2, y='y'), result=End(data='x=2 y=y'), ts=IsNow(tz=timezone.utc)), + NodeStep(state=MyState(x=2, y=''), node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), + NodeStep(state=MyState(x=2, y='y'), node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), + EndStep(result=End('x=2 y=y'), ts=IsNow(tz=timezone.utc)), ] ) history_json = graph.dump_history(history) diff --git a/tests/graph/test_mermaid.py b/tests/graph/test_mermaid.py index 2cf2b01423..1d33f2c611 100644 --- a/tests/graph/test_mermaid.py +++ b/tests/graph/test_mermaid.py @@ -74,7 +74,7 @@ async def test_run_graph(): start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - EndStep(state=None, result=End(data=None), ts=IsNow(tz=timezone.utc)), + EndStep(result=End(data=None), ts=IsNow(tz=timezone.utc)), ] ) diff --git a/tests/graph/test_state.py b/tests/graph/test_state.py index b363e37674..a19df012d7 100644 --- a/tests/graph/test_state.py +++ b/tests/graph/test_state.py @@ -32,27 +32,24 @@ async def run(self, ctx: GraphContext[MyState]) -> End[str]: return End(f'x={ctx.state.x} y={ctx.state.y}') graph = Graph(nodes=(Foo, Bar)) - s = MyState(1, '') - result, history = await graph.run(s, Foo()) + state = MyState(1, '') + result, history = await graph.run(state, Foo()) assert result == snapshot('x=2 y=y') assert history == snapshot( [ NodeStep( - state=MyState(x=1, y=''), + state=MyState(x=2, y=''), node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), NodeStep( - state=MyState(x=2, y=''), + state=MyState(x=2, y='y'), node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat(), ), - EndStep( - state=MyState(x=2, y='y'), - result=End('x=2 y=y'), - ts=IsNow(tz=timezone.utc), - ), + EndStep(result=End('x=2 y=y'), ts=IsNow(tz=timezone.utc)), ] ) + assert state == MyState(x=2, y='y') From 326e5f50dc234b75be3bca142ed1ddeaa2f64e34 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 15 Jan 2025 16:51:21 +0000 Subject: [PATCH 54/57] improve history (de)serialization --- 
pydantic_graph/pydantic_graph/_utils.py | 29 +++++- pydantic_graph/pydantic_graph/graph.py | 73 +++++++++++---- pydantic_graph/pydantic_graph/state.py | 9 +- tests/graph/test_graph.py | 7 +- tests/graph/test_history.py | 118 +++++++++++++++++++++--- tests/graph/test_state.py | 3 + 6 files changed, 203 insertions(+), 36 deletions(-) diff --git a/pydantic_graph/pydantic_graph/_utils.py b/pydantic_graph/pydantic_graph/_utils.py index 0211f4fd17..6753138c60 100644 --- a/pydantic_graph/pydantic_graph/_utils.py +++ b/pydantic_graph/pydantic_graph/_utils.py @@ -3,7 +3,7 @@ import sys import types from datetime import datetime, timezone -from typing import Annotated, Any, Union, get_args, get_origin +from typing import Annotated, Any, TypeVar, Union, get_args, get_origin import typing_extensions @@ -35,6 +35,16 @@ def unpack_annotated(tp: Any) -> tuple[Any, list[Any]]: return tp, [] +def is_never(tp: Any) -> bool: + """Check if a type is `Never`.""" + if tp is typing_extensions.Never: + return True + elif typing_never := getattr(typing_extensions, 'Never', None): + return tp is typing_never + else: + return False + + # same as `pydantic_ai_slim/pydantic_ai/_result.py:origin_is_union` if sys.version_info < (3, 10): @@ -72,3 +82,20 @@ def get_parent_namespace(frame: types.FrameType | None) -> dict[str, Any] | None def now_utc() -> datetime: return datetime.now(tz=timezone.utc) + + +class Unset: + """A singleton to represent an unset value. + + Copied from pydantic_ai/_utils.py. + """ + + pass + + +UNSET = Unset() +T = TypeVar('T') + + +def is_set(t_or_unset: T | Unset) -> typing_extensions.TypeGuard[T]: + return t_or_unset is not UNSET diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index a4c31271a1..3cc1adf730 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -2,20 +2,19 @@ import asyncio import inspect +import types from collections.abc import Sequence -from dataclasses import dataclass +from dataclasses import dataclass, field from functools import cached_property from pathlib import Path from time import perf_counter -from types import FrameType from typing import TYPE_CHECKING, Annotated, Any, Callable, Generic import logfire_api import pydantic -from typing_extensions import Literal, Unpack, assert_never +import typing_extensions from . import _utils, exceptions, mermaid -from ._utils import get_parent_namespace from .nodes import BaseNode, End, GraphContext, NodeDef, RunEndT from .state import EndStep, HistoryStep, NodeStep, StateT, deep_copy_state, nodes_schema_var @@ -68,17 +67,19 @@ async def run(self, ctx: GraphContext) -> Increment | End: from the graph. """ - state_type: type[StateT] | None name: str | None node_defs: dict[str, NodeDef[StateT, RunEndT]] snapshot_state: Callable[[StateT], StateT] + _state_type: type[StateT] | _utils.Unset = field(repr=False) + _run_end_type: type[RunEndT] | _utils.Unset = field(repr=False) def __init__( self, *, nodes: Sequence[type[BaseNode[StateT, RunEndT]]], - state_type: type[StateT] | None = None, name: str | None = None, + state_type: type[StateT] | _utils.Unset = _utils.UNSET, + run_end_type: type[RunEndT] | _utils.Unset = _utils.UNSET, snapshot_state: Callable[[StateT], StateT] = deep_copy_state, ): """Create a graph from a sequence of nodes. @@ -86,18 +87,20 @@ def __init__( Args: nodes: The nodes which make up the graph, nodes need to be unique and all be generic in the same state type. - state_type: The type of the state for the graph. 
name: Optional name for the graph, if not provided the name will be inferred from the calling frame on the first call to a graph method. + state_type: The type of the state for the graph, this can generally be inferred from `nodes`. + run_end_type: The type of the result of running the graph, this can generally be inferred from `nodes`. snapshot_state: A function to snapshot the state of the graph, this is used in [`NodeStep`][pydantic_graph.state.NodeStep] and [`EndStep`][pydantic_graph.state.EndStep] to record the state before each step. """ - self.state_type = state_type self.name = name + self._state_type = state_type + self._run_end_type = run_end_type self.snapshot_state = snapshot_state - parent_namespace = get_parent_namespace(inspect.currentframe()) + parent_namespace = _utils.get_parent_namespace(inspect.currentframe()) self.node_defs: dict[str, NodeDef[StateT, RunEndT]] = {} for node in nodes: self._register_node(node, parent_namespace) @@ -162,7 +165,7 @@ async def main(): start_node = next_node else: if TYPE_CHECKING: - assert_never(next_node) + typing_extensions.assert_never(next_node) else: raise exceptions.GraphRuntimeError( f'Invalid node return type: `{type(next_node).__name__}`. Expected `BaseNode` or `End`.' @@ -256,21 +259,20 @@ def load_history(self, json_bytes: str | bytes | bytearray) -> list[HistoryStep[ @cached_property def history_type_adapter(self) -> pydantic.TypeAdapter[list[HistoryStep[StateT, RunEndT]]]: nodes = [node_def.node for node_def in self.node_defs.values()] + state_t = self._get_state_type() + end_t = self._get_run_end_type() token = nodes_schema_var.set(nodes) try: - # TODO get the return type RunEndT - ta = pydantic.TypeAdapter( - list[Annotated[HistoryStep[self.state_type or Any, Any], pydantic.Discriminator('kind')]], - ) + ta = pydantic.TypeAdapter(list[Annotated[HistoryStep[state_t, end_t], pydantic.Discriminator('kind')]]) finally: nodes_schema_var.reset(token) - return ta # type: ignore + return ta def mermaid_code( self, *, start_node: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None, - title: str | None | Literal[False] = None, + title: str | None | typing_extensions.Literal[False] = None, edge_labels: bool = True, notes: bool = True, highlighted_nodes: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None, @@ -338,7 +340,9 @@ def mermaid_code( notes=notes, ) - def mermaid_image(self, infer_name: bool = True, **kwargs: Unpack[mermaid.MermaidConfig]) -> bytes: + def mermaid_image( + self, infer_name: bool = True, **kwargs: typing_extensions.Unpack[mermaid.MermaidConfig] + ) -> bytes: """Generate a diagram representing the graph as an image. The format and diagram can be customized using `kwargs`, @@ -362,7 +366,7 @@ def mermaid_image(self, infer_name: bool = True, **kwargs: Unpack[mermaid.Mermai return mermaid.request_image(self, **kwargs) def mermaid_save( - self, path: Path | str, /, *, infer_name: bool = True, **kwargs: Unpack[mermaid.MermaidConfig] + self, path: Path | str, /, *, infer_name: bool = True, **kwargs: typing_extensions.Unpack[mermaid.MermaidConfig] ) -> None: """Generate a diagram representing the graph and save it as an image. 
@@ -384,6 +388,37 @@ def mermaid_save( kwargs['title'] = self.name mermaid.save_image(path, self, **kwargs) + def _get_state_type(self) -> type[StateT]: + if _utils.is_set(self._state_type): + return self._state_type + + for node_def in self.node_defs.values(): + for base in types.get_original_bases(node_def.node): + if typing_extensions.get_origin(base) is BaseNode: + args = typing_extensions.get_args(base) + if args: + return args[0] + # break the inner (bases) loop + break + # state defaults to None, so use that if we can't infer it + return type(None) # pyright: ignore[reportReturnType] + + def _get_run_end_type(self) -> type[RunEndT]: + if _utils.is_set(self._run_end_type): + return self._run_end_type + + for node_def in self.node_defs.values(): + for base in types.get_original_bases(node_def.node): + if typing_extensions.get_origin(base) is BaseNode: + args = typing_extensions.get_args(base) + if len(args) == 2: + t = args[1] + if not _utils.is_never(t): + return t + # break the inner (bases) loop + break + raise exceptions.GraphSetupError('Could not infer run end type from nodes, please set `run_end_type`.') + def _register_node(self, node: type[BaseNode[StateT, RunEndT]], parent_namespace: dict[str, Any] | None) -> None: node_id = node.get_id() if existing_node := self.node_defs.get(node_id): @@ -412,7 +447,7 @@ def _validate_edges(self): f'Nodes are referenced in the graph but not included in the graph:\n{b}' ) - def _infer_name(self, function_frame: FrameType | None) -> None: + def _infer_name(self, function_frame: types.FrameType | None) -> None: """Infer the agent name from the call frame. Usage should be `self._infer_name(inspect.currentframe())`. diff --git a/pydantic_graph/pydantic_graph/state.py b/pydantic_graph/pydantic_graph/state.py index 52cf8a7f7b..8db69fb0df 100644 --- a/pydantic_graph/pydantic_graph/state.py +++ b/pydantic_graph/pydantic_graph/state.py @@ -102,10 +102,13 @@ def __get_pydantic_core_schema__( 'Unable to build a Pydantic schema for `NodeStep` or `HistoryStep` without setting `nodes_schema_var`. 
' 'You probably want to use ' ) from e - nodes_annotated = [Annotated[node, pydantic.Tag(node.get_id())] for node in nodes] - nodes_union = Annotated[Union[tuple(nodes_annotated)], pydantic.Discriminator(self._node_discriminator)] + if len(nodes) == 1: + nodes_type = nodes[0] + else: + nodes_annotated = [Annotated[node, pydantic.Tag(node.get_id())] for node in nodes] + nodes_type = Annotated[Union[tuple(nodes_annotated)], pydantic.Discriminator(self._node_discriminator)] - schema = handler(nodes_union) + schema = handler(nodes_type) schema['serialization'] = core_schema.wrap_serializer_function_ser_schema( function=self._node_serializer, return_schema=core_schema.dict_schema(core_schema.str_schema(), core_schema.any_schema()), diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py index 0f226857da..7be7577e2e 100644 --- a/tests/graph/test_graph.py +++ b/tests/graph/test_graph.py @@ -1,3 +1,4 @@ +# pyright: reportPrivateUsage=false from __future__ import annotations as _annotations from dataclasses import dataclass @@ -52,7 +53,9 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq else: return End(self.input_data * 2) - my_graph = Graph[None, int](nodes=(Float2String, String2Length, Double)) + my_graph = Graph(nodes=(Float2String, String2Length, Double)) + assert my_graph._get_state_type() is type(None) + assert my_graph._get_run_end_type() is int assert my_graph.name is None result, history = await my_graph.run(None, Float2String(3.14)) # len('3.14') * 2 == 8 @@ -259,6 +262,8 @@ async def run(self, ctx: GraphContext) -> End[None]: return 42 # type: ignore g = Graph(nodes=(Foo, Bar)) + assert g._get_state_type() is type(None) + assert g._get_run_end_type() is type(None) with pytest.raises(GraphRuntimeError) as exc_info: await g.run(None, Foo()) diff --git a/tests/graph/test_history.py b/tests/graph/test_history.py index 6dff468a3e..89b5b73a4d 100644 --- a/tests/graph/test_history.py +++ b/tests/graph/test_history.py @@ -1,12 +1,15 @@ +# pyright: reportPrivateUsage=false from __future__ import annotations as _annotations +import json from dataclasses import dataclass -from datetime import timezone +from datetime import datetime, timezone import pytest +from dirty_equals import IsStr from inline_snapshot import snapshot -from pydantic_graph import BaseNode, End, EndStep, Graph, GraphContext, NodeStep +from pydantic_graph import BaseNode, End, EndStep, Graph, GraphContext, GraphSetupError, NodeStep from ..conftest import IsFloat, IsNow @@ -27,26 +30,117 @@ async def run(self, ctx: GraphContext[MyState]) -> Bar: @dataclass -class Bar(BaseNode[MyState, str]): - async def run(self, ctx: GraphContext[MyState]) -> End[str]: +class Bar(BaseNode[MyState, int]): + async def run(self, ctx: GraphContext[MyState]) -> End[int]: ctx.state.y += 'y' - return End(f'x={ctx.state.x} y={ctx.state.y}') + return End(ctx.state.x * 2) -graph = Graph(nodes=(Foo, Bar), state_type=MyState) - - -async def test_dump_history(): +@pytest.mark.parametrize( + 'graph', + [ + Graph(nodes=(Foo, Bar), state_type=MyState, run_end_type=int), + Graph(nodes=(Foo, Bar), state_type=MyState), + Graph(nodes=(Foo, Bar), run_end_type=int), + Graph(nodes=(Foo, Bar)), + ], +) +async def test_dump_load_history(graph: Graph[MyState, int]): result, history = await graph.run(MyState(1, ''), Foo()) - assert result == snapshot('x=2 y=y') + assert result == snapshot(4) assert history == snapshot( [ NodeStep(state=MyState(x=2, y=''), node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), 
NodeStep(state=MyState(x=2, y='y'), node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), - EndStep(result=End('x=2 y=y'), ts=IsNow(tz=timezone.utc)), + EndStep(result=End(4), ts=IsNow(tz=timezone.utc)), ] ) history_json = graph.dump_history(history) - assert history_json.startswith(b'[{"state":') + assert json.loads(history_json) == snapshot( + [ + { + 'state': {'x': 2, 'y': ''}, + 'node': {'node_id': 'Foo'}, + 'start_ts': IsStr(regex=r'20\d\d-\d\d-\d\dT.+'), + 'duration': IsFloat(), + 'kind': 'node', + }, + { + 'state': {'x': 2, 'y': 'y'}, + 'node': {'node_id': 'Bar'}, + 'start_ts': IsStr(regex=r'20\d\d-\d\d-\d\dT.+'), + 'duration': IsFloat(), + 'kind': 'node', + }, + {'result': {'data': 4}, 'ts': IsStr(regex=r'20\d\d-\d\d-\d\dT.+'), 'kind': 'end'}, + ] + ) history_loaded = graph.load_history(history_json) assert history == history_loaded + + custom_history = [ + { + 'state': {'x': 2, 'y': ''}, + 'node': {'node_id': 'Foo'}, + 'start_ts': '2025-01-01T00:00:00Z', + 'duration': 123, + 'kind': 'node', + }, + {'result': {'data': '42'}, 'ts': '2025-01-01T00:00:00Z', 'kind': 'end'}, + ] + history_loaded = graph.load_history(json.dumps(custom_history)) + assert history_loaded == snapshot( + [ + NodeStep( + state=MyState(x=2, y=''), + node=Foo(), + start_ts=datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc), + duration=123.0, + ), + EndStep(result=End(data=42), ts=datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc)), + ] + ) + + +def test_one_node(): + @dataclass + class MyNode(BaseNode[None, int]): + async def run(self, ctx: GraphContext) -> End[int]: + return End(123) + + g = Graph(nodes=[MyNode]) + + custom_history = [ + {'result': {'data': '123'}, 'ts': '2025-01-01T00:00:00Z', 'kind': 'end'}, + ] + history_loaded = g.load_history(json.dumps(custom_history)) + assert history_loaded == snapshot( + [ + EndStep(result=End(data=123), ts=datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc)), + ] + ) + + +def test_no_generic_arg(): + @dataclass + class NoGenericArgsNode(BaseNode): + async def run(self, ctx: GraphContext) -> NoGenericArgsNode: + return NoGenericArgsNode() + + g = Graph(nodes=[NoGenericArgsNode]) + assert g._get_state_type() is type(None) + with pytest.raises(GraphSetupError, match='Could not infer run end type from nodes, please set `run_end_type`.'): + g._get_run_end_type() + + g = Graph(nodes=[NoGenericArgsNode], run_end_type=None) # pyright: ignore[reportArgumentType] + assert g._get_run_end_type() is None + + custom_history = [ + {'result': {'data': None}, 'ts': '2025-01-01T00:00:00Z', 'kind': 'end'}, + ] + history_loaded = g.load_history(json.dumps(custom_history)) + assert history_loaded == snapshot( + [ + EndStep(result=End(data=None), ts=datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc)), + ] + ) diff --git a/tests/graph/test_state.py b/tests/graph/test_state.py index a19df012d7..6d9899dc78 100644 --- a/tests/graph/test_state.py +++ b/tests/graph/test_state.py @@ -1,3 +1,4 @@ +# pyright: reportPrivateUsage=false from __future__ import annotations as _annotations from dataclasses import dataclass @@ -32,6 +33,8 @@ async def run(self, ctx: GraphContext[MyState]) -> End[str]: return End(f'x={ctx.state.x} y={ctx.state.y}') graph = Graph(nodes=(Foo, Bar)) + assert graph._get_state_type() is MyState + assert graph._get_run_end_type() is str state = MyState(1, '') result, history = await graph.run(state, Foo()) assert result == snapshot('x=2 y=y') From 5b8433bd6ecaaa9261b74a288d2c036c8eb895f3 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 15 Jan 2025 16:55:58 +0000 Subject: 
[PATCH 55/57] fix for older python --- pydantic_graph/pydantic_graph/graph.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 3cc1adf730..5f25330cee 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -393,7 +393,7 @@ def _get_state_type(self) -> type[StateT]: return self._state_type for node_def in self.node_defs.values(): - for base in types.get_original_bases(node_def.node): + for base in typing_extensions.get_original_bases(node_def.node): if typing_extensions.get_origin(base) is BaseNode: args = typing_extensions.get_args(base) if args: @@ -408,7 +408,7 @@ def _get_run_end_type(self) -> type[RunEndT]: return self._run_end_type for node_def in self.node_defs.values(): - for base in types.get_original_bases(node_def.node): + for base in typing_extensions.get_original_bases(node_def.node): if typing_extensions.get_origin(base) is BaseNode: args = typing_extensions.get_args(base) if len(args) == 2: From b4e44bf3b4a0dbd18ab0cb51c39ae84d86ccf81a Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 15 Jan 2025 19:27:55 +0000 Subject: [PATCH 56/57] Graph deps (#693) Co-authored-by: David Montague <35119617+dmontagu@users.noreply.github.com> Co-authored-by: Israel Ekpo <44282278+izzyacademy@users.noreply.github.com> --- docs/api/pydantic_graph/nodes.md | 1 + docs/graph.md | 102 +++++++++++++++--- .../pydantic_ai_examples/question_graph.py | 6 +- pydantic_graph/README.md | 4 +- pydantic_graph/pydantic_graph/graph.py | 62 ++++++----- pydantic_graph/pydantic_graph/mermaid.py | 8 +- pydantic_graph/pydantic_graph/nodes.py | 18 ++-- pydantic_graph/pydantic_graph/state.py | 8 +- tests/graph/test_graph.py | 69 ++++++++---- tests/graph/test_history.py | 8 +- tests/graph/test_mermaid.py | 6 +- tests/graph/test_state.py | 4 +- tests/typed_graph.py | 54 ++++++++-- 13 files changed, 257 insertions(+), 93 deletions(-) diff --git a/docs/api/pydantic_graph/nodes.md b/docs/api/pydantic_graph/nodes.md index 7540920a02..e58ddf7012 100644 --- a/docs/api/pydantic_graph/nodes.md +++ b/docs/api/pydantic_graph/nodes.md @@ -7,5 +7,6 @@ - BaseNode - End - Edge + - DepsT - RunEndT - NodeRunEndT diff --git a/docs/graph.md b/docs/graph.md index ba5e00f511..bbd546ddc3 100644 --- a/docs/graph.md +++ b/docs/graph.md @@ -12,7 +12,7 @@ Unless you're sure you need a graph, you probably don't. -Graphs and finite state machines (FSMs) are a powerful abstraction to model, control and visualize complex workflows. +Graphs and finite state machines (FSMs) are a powerful abstraction to model, execute, control and visualize complex workflows. Alongside PydanticAI, we've developed `pydantic-graph` — an async graph and state machine library for Python where nodes and edges are defined using type hints. 
@@ -59,7 +59,8 @@ Nodes which are generally [`dataclass`es][dataclasses.dataclass] include: Nodes are generic in: -* **state**, which must have the same type as the state of graphs they're included in, [`StateT`][pydantic_graph.state.StateT] has a default of `None`, so if you're not using state you can omit this generic parameter +* **state**, which must have the same type as the state of graphs they're included in, [`StateT`][pydantic_graph.state.StateT] has a default of `None`, so if you're not using state you can omit this generic parameter, see [stateful graphs](#stateful-graphs) for more information +* **deps**, which must have the same type as the deps of the graph they're included in, [`DepsT`][pydantic_graph.nodes.DepsT] has a default of `None`, so if you're not using deps you can omit this generic parameter, see [dependency injection](#dependency-injection) for more information * **graph return type** — this only applies if the node returns [`End`][pydantic_graph.nodes.End]. [`RunEndT`][pydantic_graph.nodes.RunEndT] has a default of [Never][typing.Never] so this generic parameter can be omitted if the node doesn't return `End`, but must be included if it does. Here's an example of a start or intermediate node in a graph — it can't end the run as it doesn't return [`End`][pydantic_graph.nodes.End]: @@ -96,7 +97,7 @@ from pydantic_graph import BaseNode, End, GraphContext @dataclass -class MyNode(BaseNode[MyState, int]): # (1)! +class MyNode(BaseNode[MyState, None, int]): # (1)! foo: int async def run( @@ -109,7 +110,7 @@ class MyNode(BaseNode[MyState, int]): # (1)! return AnotherNode() ``` -1. We parameterize the node with the return type (`int` in this case) as well as state. +1. We parameterize the node with the return type (`int` in this case) as well as state. Because generic parameters are positional-only, we have to include `None` as the second parameter representing deps. 2. The return type of the `run` method is now a union of `AnotherNode` and `End[int]`, this allows the node to end the run if `foo` is divisible by 5. ### Graph @@ -119,6 +120,7 @@ class MyNode(BaseNode[MyState, int]): # (1)! `Graph` is generic in: * **state** the state type of the graph, [`StateT`][pydantic_graph.state.StateT] +* **deps** the deps type of the graph, [`DepsT`][pydantic_graph.nodes.DepsT] * **graph return type** the return type of the graph run, [`RunEndT`][pydantic_graph.nodes.RunEndT] Here's an example of a simple graph: @@ -132,7 +134,7 @@ from pydantic_graph import BaseNode, End, Graph, GraphContext @dataclass -class DivisibleBy5(BaseNode[None, int]): # (1)! +class DivisibleBy5(BaseNode[None, None, int]): # (1)! foo: int async def run( @@ -154,7 +156,7 @@ class Increment(BaseNode): # (2)! fives_graph = Graph(nodes=[DivisibleBy5, Increment]) # (3)! -result, history = fives_graph.run_sync(None, DivisibleBy5(4)) # (4)! +result, history = fives_graph.run_sync(DivisibleBy5(4)) # (4)! print(result) #> 5 # the full history is quite verbose (see below), so we'll just print the summary @@ -162,7 +164,7 @@ print([item.data_snapshot() for item in history]) #> [DivisibleBy5(foo=4), Increment(foo=4), DivisibleBy5(foo=5), End(data=5)] ``` -1. The `DivisibleBy5` node is parameterized with `None` as this graph doesn't use state, and `int` as it can end the run. +1. The `DivisibleBy5` node is parameterized with `None` for the state param and `None` for the deps param as this graph doesn't use state or deps, and `int` as it can end the run. 2. 
The `Increment` node doesn't return `End`, so the `RunEndT` generic parameter is omitted, state can also be omitted as the graph doesn't use state. 3. The graph is created with a sequence of nodes. 4. The graph is run synchronously with [`run_sync`][pydantic_graph.graph.Graph.run_sync] the initial state `None` and the start node `DivisibleBy5(4)` are passed as arguments. @@ -247,7 +249,7 @@ PRODUCT_PRICES = { # (2)! @dataclass -class Purchase(BaseNode[MachineState, None]): # (18)! +class Purchase(BaseNode[MachineState, None, None]): # (18)! product: str async def run( @@ -275,7 +277,7 @@ vending_machine_graph = Graph( # (13)! async def main(): state = MachineState() # (14)! - await vending_machine_graph.run(state, InsertCoin()) # (15)! + await vending_machine_graph.run(InsertCoin(), state=state) # (15)! print(f'purchase successful item={state.product} change={state.user_balance:0.2f}') #> purchase successful item=crisps change=0.25 ``` @@ -430,7 +432,7 @@ feedback_agent = Agent[None, EmailRequiresWrite | EmailOk]( @dataclass -class Feedback(BaseNode[State, Email]): +class Feedback(BaseNode[State, None, Email]): email: Email async def run( @@ -453,7 +455,7 @@ async def main(): ) state = State(user) feedback_graph = Graph(nodes=(WriteEmail, Feedback)) - email, _ = await feedback_graph.run(state, WriteEmail()) + email, _ = await feedback_graph.run(WriteEmail(), state=state) print(email) """ Email( @@ -579,7 +581,7 @@ async def main(): node = Ask() # (2)! history: list[HistoryStep[QuestionState]] = [] # (3)! while True: - node = await question_graph.next(state, node, history) # (4)! + node = await question_graph.next(node, history, state=state) # (4)! if isinstance(node, Answer): node.answer = Prompt.ask(node.question) # (5)! elif isinstance(node, End): # (6)! @@ -633,6 +635,82 @@ stateDiagram-v2 You maybe have noticed that although this examples transfers control flow out of the graph run, we're still using [rich's `Prompt.ask`][rich.prompt.PromptBase.ask] to get user input, with the process hanging while we wait for the user to enter a response. For an example of genuine out-of-process control flow, see the [question graph example](examples/question-graph.md). +## Dependency Injection + +As with PydanticAI, `pydantic-graph` supports dependency injection via a generic parameter on [`Graph`][pydantic_graph.graph.Graph] and [`BaseNode`][pydantic_graph.nodes.BaseNode], and the [`GraphContext.deps`][pydantic_graph.nodes.GraphContext.deps] fields. 
+ +As an example of dependency injection, let's modify the `DivisibleBy5` example [above](#graph) to use a [`ProcessPoolExecutor`][concurrent.futures.ProcessPoolExecutor] to run the compute load in a separate process (this is a contrived example, `ProcessPoolExecutor` wouldn't actually improve performance in this example): + +```py {title="deps_example.py" py="3.10" test="skip" hl_lines="4 10-12 35-37 48-49"} +from __future__ import annotations + +import asyncio +from concurrent.futures import ProcessPoolExecutor +from dataclasses import dataclass + +from pydantic_graph import BaseNode, End, Graph, GraphContext + + +@dataclass +class GraphDeps: + executor: ProcessPoolExecutor + + +@dataclass +class DivisibleBy5(BaseNode[None, None, int]): + foo: int + + async def run( + self, + ctx: GraphContext, + ) -> Increment | End[int]: + if self.foo % 5 == 0: + return End(self.foo) + else: + return Increment(self.foo) + + +@dataclass +class Increment(BaseNode): + foo: int + + async def run(self, ctx: GraphContext) -> DivisibleBy5: + loop = asyncio.get_running_loop() + compute_result = await loop.run_in_executor( + ctx.deps.executor, + self.compute, + ) + return DivisibleBy5(compute_result) + + def compute(self) -> int: + return self.foo + 1 + + +fives_graph = Graph(nodes=[DivisibleBy5, Increment]) + + +async def main(): + with ProcessPoolExecutor() as executor: + deps = GraphDeps(executor) + result, history = await fives_graph.run(DivisibleBy5(3), deps=deps) + print(result) + #> 5 + # the full history is quite verbose (see below), so we'll just print the summary + print([item.data_snapshot() for item in history]) + """ + [ + DivisibleBy5(foo=3), + Increment(foo=3), + DivisibleBy5(foo=4), + Increment(foo=4), + DivisibleBy5(foo=5), + End(data=5), + ] + """ +``` + +_(This example is complete, it can be run "as is" with Python 3.10+ — you'll need to add `asyncio.run(main())` to run `main`)_ + ## Mermaid Diagrams Pydantic Graph can generate [mermaid](https://mermaid.js.org/) [`stateDiagram-v2`](https://mermaid.js.org/syntax/stateDiagram.html) diagrams for graphs, as shown above. 
diff --git a/examples/pydantic_ai_examples/question_graph.py b/examples/pydantic_ai_examples/question_graph.py index f73dbf1684..6ef092d3ba 100644 --- a/examples/pydantic_ai_examples/question_graph.py +++ b/examples/pydantic_ai_examples/question_graph.py @@ -87,7 +87,7 @@ async def run( @dataclass -class Congratulate(BaseNode[QuestionState, None]): +class Congratulate(BaseNode[QuestionState, None, None]): comment: str async def run( @@ -119,7 +119,7 @@ async def run_as_continuous(): history: list[HistoryStep[QuestionState, None]] = [] with logfire.span('run questions graph'): while True: - node = await question_graph.next(state, node, history) + node = await question_graph.next(node, history, state=state) if isinstance(node, End): debug([e.data_snapshot() for e in history]) break @@ -150,7 +150,7 @@ async def run_as_cli(answer: str | None): with logfire.span('run questions graph'): while True: - node = await question_graph.next(state, node, history) + node = await question_graph.next(node, history, state=state) if isinstance(node, End): debug([e.data_snapshot() for e in history]) print('Finished!') diff --git a/pydantic_graph/README.md b/pydantic_graph/README.md index 6ac17fb216..0b2508f033 100644 --- a/pydantic_graph/README.md +++ b/pydantic_graph/README.md @@ -28,7 +28,7 @@ from pydantic_graph import BaseNode, End, Graph, GraphContext @dataclass -class DivisibleBy5(BaseNode[None, int]): +class DivisibleBy5(BaseNode[None, None, int]): foo: int async def run( @@ -50,7 +50,7 @@ class Increment(BaseNode): fives_graph = Graph(nodes=[DivisibleBy5, Increment]) -result, history = fives_graph.run_sync(None, DivisibleBy5(4)) +result, history = fives_graph.run_sync(DivisibleBy5(4)) print(result) #> 5 # the full history is quite verbose (see below), so we'll just print the summary diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 5f25330cee..0aed1f3706 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -15,7 +15,7 @@ import typing_extensions from . import _utils, exceptions, mermaid -from .nodes import BaseNode, End, GraphContext, NodeDef, RunEndT +from .nodes import BaseNode, DepsT, End, GraphContext, NodeDef, RunEndT from .state import EndStep, HistoryStep, NodeStep, StateT, deep_copy_state, nodes_schema_var __all__ = ('Graph',) @@ -24,7 +24,7 @@ @dataclass(init=False) -class Graph(Generic[StateT, RunEndT]): +class Graph(Generic[StateT, DepsT, RunEndT]): """Definition of a graph. In `pydantic-graph`, a graph is a collection of nodes that can be run in sequence. 
The nodes define @@ -51,12 +51,12 @@ async def run(self, ctx: GraphContext) -> Check42: return Check42() @dataclass - class Check42(BaseNode[MyState]): - async def run(self, ctx: GraphContext) -> Increment | End: + class Check42(BaseNode[MyState, None, int]): + async def run(self, ctx: GraphContext) -> Increment | End[int]: if ctx.state.number == 42: return Increment() else: - return End(None) + return End(ctx.state.number) never_42_graph = Graph(nodes=(Increment, Check42)) ``` @@ -68,7 +68,7 @@ async def run(self, ctx: GraphContext) -> Increment | End: """ name: str | None - node_defs: dict[str, NodeDef[StateT, RunEndT]] + node_defs: dict[str, NodeDef[StateT, DepsT, RunEndT]] snapshot_state: Callable[[StateT], StateT] _state_type: type[StateT] | _utils.Unset = field(repr=False) _run_end_type: type[RunEndT] | _utils.Unset = field(repr=False) @@ -76,7 +76,7 @@ async def run(self, ctx: GraphContext) -> Increment | End: def __init__( self, *, - nodes: Sequence[type[BaseNode[StateT, RunEndT]]], + nodes: Sequence[type[BaseNode[StateT, DepsT, RunEndT]]], name: str | None = None, state_type: type[StateT] | _utils.Unset = _utils.UNSET, run_end_type: type[RunEndT] | _utils.Unset = _utils.UNSET, @@ -101,7 +101,7 @@ def __init__( self.snapshot_state = snapshot_state parent_namespace = _utils.get_parent_namespace(inspect.currentframe()) - self.node_defs: dict[str, NodeDef[StateT, RunEndT]] = {} + self.node_defs: dict[str, NodeDef[StateT, DepsT, RunEndT]] = {} for node in nodes: self._register_node(node, parent_namespace) @@ -109,17 +109,19 @@ def __init__( async def run( self, - state: StateT, - start_node: BaseNode[StateT, RunEndT], + start_node: BaseNode[StateT, DepsT, RunEndT], *, + state: StateT = None, + deps: DepsT = None, infer_name: bool = True, ) -> tuple[RunEndT, list[HistoryStep[StateT, RunEndT]]]: """Run the graph from a starting node until it ends. Args: - state: The initial state of the graph. start_node: the first node to run, since the graph definition doesn't define the entry point in the graph, you need to provide the starting node. + state: The initial state of the graph. + deps: The dependencies of the graph. infer_name: Whether to infer the graph name from the calling frame. Returns: @@ -132,14 +134,14 @@ async def run( async def main(): state = MyState(1) - _, history = await never_42_graph.run(state, Increment()) + _, history = await never_42_graph.run(Increment(), state=state) print(state) #> MyState(number=2) print(len(history)) #> 3 state = MyState(41) - _, history = await never_42_graph.run(state, Increment()) + _, history = await never_42_graph.run(Increment(), state=state) print(state) #> MyState(number=43) print(len(history)) @@ -156,7 +158,7 @@ async def main(): start=start_node, ) as run_span: while True: - next_node = await self.next(state, start_node, history, infer_name=False) + next_node = await self.next(start_node, history, state=state, deps=deps, infer_name=False) if isinstance(next_node, End): history.append(EndStep(result=next_node)) run_span.set_attribute('history', history) @@ -173,9 +175,10 @@ async def main(): def run_sync( self, - state: StateT, - start_node: BaseNode[StateT, RunEndT], + start_node: BaseNode[StateT, DepsT, RunEndT], *, + state: StateT = None, + deps: DepsT = None, infer_name: bool = True, ) -> tuple[RunEndT, list[HistoryStep[StateT, RunEndT]]]: """Run the graph synchronously. @@ -184,9 +187,10 @@ def run_sync( You therefore can't use this method inside async code or if there's an active event loop. 
Args: - state: The initial state of the graph. start_node: the first node to run, since the graph definition doesn't define the entry point in the graph, you need to provide the starting node. + state: The initial state of the graph. + deps: The dependencies of the graph. infer_name: Whether to infer the graph name from the calling frame. Returns: @@ -194,22 +198,26 @@ def run_sync( """ if infer_name and self.name is None: self._infer_name(inspect.currentframe()) - return asyncio.get_event_loop().run_until_complete(self.run(state, start_node, infer_name=False)) + return asyncio.get_event_loop().run_until_complete( + self.run(start_node, state=state, deps=deps, infer_name=False) + ) async def next( self, - state: StateT, - node: BaseNode[StateT, RunEndT], + node: BaseNode[StateT, DepsT, RunEndT], history: list[HistoryStep[StateT, RunEndT]], *, + state: StateT = None, + deps: DepsT = None, infer_name: bool = True, - ) -> BaseNode[StateT, Any] | End[RunEndT]: + ) -> BaseNode[StateT, DepsT, Any] | End[RunEndT]: """Run a node in the graph and return the next node to run. Args: - state: The current state of the graph. node: The node to run. history: The history of the graph run so far. NOTE: this will be mutated to add the new step. + state: The current state of the graph. + deps: The dependencies of the graph. infer_name: Whether to infer the graph name from the calling frame. Returns: @@ -221,7 +229,7 @@ async def next( if node_id not in self.node_defs: raise exceptions.GraphRuntimeError(f'Node `{node}` is not in the graph.') - ctx = GraphContext(state) + ctx = GraphContext(state, deps) with _logfire.span('run node {node_id}', node_id=node_id, node=node): start_ts = _utils.now_utc() start = perf_counter() @@ -411,15 +419,17 @@ def _get_run_end_type(self) -> type[RunEndT]: for base in typing_extensions.get_original_bases(node_def.node): if typing_extensions.get_origin(base) is BaseNode: args = typing_extensions.get_args(base) - if len(args) == 2: - t = args[1] + if len(args) == 3: + t = args[2] if not _utils.is_never(t): return t # break the inner (bases) loop break raise exceptions.GraphSetupError('Could not infer run end type from nodes, please set `run_end_type`.') - def _register_node(self, node: type[BaseNode[StateT, RunEndT]], parent_namespace: dict[str, Any] | None) -> None: + def _register_node( + self, node: type[BaseNode[StateT, DepsT, RunEndT]], parent_namespace: dict[str, Any] | None + ) -> None: node_id = node.get_id() if existing_node := self.node_defs.get(node_id): raise exceptions.GraphSetupError( diff --git a/pydantic_graph/pydantic_graph/mermaid.py b/pydantic_graph/pydantic_graph/mermaid.py index 49e41ee267..ffa1fc0190 100644 --- a/pydantic_graph/pydantic_graph/mermaid.py +++ b/pydantic_graph/pydantic_graph/mermaid.py @@ -23,7 +23,7 @@ def generate_code( # noqa: C901 - graph: Graph[Any, Any], + graph: Graph[Any, Any, Any], /, *, start_node: Sequence[NodeIdent] | NodeIdent | None = None, @@ -110,7 +110,7 @@ def _node_ids(node_idents: Sequence[NodeIdent] | NodeIdent) -> Iterable[str]: def request_image( - graph: Graph[Any, Any], + graph: Graph[Any, Any, Any], /, **kwargs: Unpack[MermaidConfig], ) -> bytes: @@ -175,7 +175,7 @@ def request_image( def save_image( path: Path | str, - graph: Graph[Any, Any], + graph: Graph[Any, Any, Any], /, **kwargs: Unpack[MermaidConfig], ) -> None: @@ -247,7 +247,7 @@ class MermaidConfig(TypedDict, total=False): """An HTTPX client to use for requests, mostly for testing purposes.""" -NodeIdent: TypeAlias = 'type[BaseNode[Any, Any]] | BaseNode[Any, 
Any] | str' +NodeIdent: TypeAlias = 'type[BaseNode[Any, Any, Any]] | BaseNode[Any, Any, Any] | str' """A type alias for a node identifier. This can be: diff --git a/pydantic_graph/pydantic_graph/nodes.py b/pydantic_graph/pydantic_graph/nodes.py index ec97b16e84..3591e1b5be 100644 --- a/pydantic_graph/pydantic_graph/nodes.py +++ b/pydantic_graph/pydantic_graph/nodes.py @@ -14,23 +14,27 @@ else: StateT = TypeVar('StateT', default=None) -__all__ = 'GraphContext', 'BaseNode', 'End', 'Edge', 'NodeDef', 'RunEndT', 'NodeRunEndT' +__all__ = 'GraphContext', 'BaseNode', 'End', 'Edge', 'NodeDef', 'RunEndT', 'NodeRunEndT', 'DepsT' RunEndT = TypeVar('RunEndT', default=None) """Type variable for the return type of a graph [`run`][pydantic_graph.graph.Graph.run].""" NodeRunEndT = TypeVar('NodeRunEndT', covariant=True, default=Never) """Type variable for the return type of a node [`run`][pydantic_graph.nodes.BaseNode.run].""" +DepsT = TypeVar('DepsT', default=None) +"""Type variable for the dependencies of a graph and node.""" @dataclass -class GraphContext(Generic[StateT]): +class GraphContext(Generic[StateT, DepsT]): """Context for a graph.""" state: StateT """The state of the graph.""" + deps: DepsT + """Dependencies for the graph.""" -class BaseNode(ABC, Generic[StateT, NodeRunEndT]): +class BaseNode(ABC, Generic[StateT, DepsT, NodeRunEndT]): """Base class for a node.""" docstring_notes: ClassVar[bool] = False @@ -42,7 +46,7 @@ class BaseNode(ABC, Generic[StateT, NodeRunEndT]): """ @abstractmethod - async def run(self, ctx: GraphContext[StateT]) -> BaseNode[StateT, Any] | End[NodeRunEndT]: + async def run(self, ctx: GraphContext[StateT, DepsT]) -> BaseNode[StateT, DepsT, Any] | End[NodeRunEndT]: """Run the node. This is an abstract method that must be implemented by subclasses. @@ -87,7 +91,7 @@ def get_note(cls) -> str | None: return docstring @classmethod - def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, NodeRunEndT]: + def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, DepsT, NodeRunEndT]: """Get the node definition.""" type_hints = get_type_hints(cls.run, localns=local_ns, include_extras=True) try: @@ -139,7 +143,7 @@ class Edge: @dataclass -class NodeDef(Generic[StateT, NodeRunEndT]): +class NodeDef(Generic[StateT, DepsT, NodeRunEndT]): """Definition of a node. This is a primarily internal representation of a node; in general, it shouldn't be necessary to use it directly. @@ -148,7 +152,7 @@ class NodeDef(Generic[StateT, NodeRunEndT]): mermaid graphs. """ - node: type[BaseNode[StateT, NodeRunEndT]] + node: type[BaseNode[StateT, DepsT, NodeRunEndT]] """The node definition itself.""" node_id: str """ID of the node.""" diff --git a/pydantic_graph/pydantic_graph/state.py b/pydantic_graph/pydantic_graph/state.py index 8db69fb0df..99bddd6138 100644 --- a/pydantic_graph/pydantic_graph/state.py +++ b/pydantic_graph/pydantic_graph/state.py @@ -14,7 +14,7 @@ from . 
import _utils from .nodes import BaseNode, End, RunEndT -__all__ = 'StateT', 'NodeStep', 'EndStep', 'HistoryStep', 'deep_copy_state' +__all__ = 'StateT', 'NodeStep', 'EndStep', 'HistoryStep', 'deep_copy_state', 'nodes_schema_var' StateT = TypeVar('StateT', default=None) @@ -35,7 +35,7 @@ class NodeStep(Generic[StateT, RunEndT]): state: StateT """The state of the graph after the node has been run.""" - node: Annotated[BaseNode[StateT, RunEndT], CustomNodeSchema()] + node: Annotated[BaseNode[StateT, Any, RunEndT], CustomNodeSchema()] """The node that was run.""" start_ts: datetime = field(default_factory=_utils.now_utc) """The timestamp when the node started running.""" @@ -53,7 +53,7 @@ def __post_init__(self): # Copy the state to prevent it from being modified by other code self.state = self.snapshot_state(self.state) - def data_snapshot(self) -> BaseNode[StateT, RunEndT]: + def data_snapshot(self) -> BaseNode[StateT, Any, RunEndT]: """Returns a deep copy of [`self.node`][pydantic_graph.state.NodeStep.node]. Useful for summarizing history. @@ -88,7 +88,7 @@ def data_snapshot(self) -> End[RunEndT]: """ -nodes_schema_var: ContextVar[Sequence[type[BaseNode[Any, Any]]]] = ContextVar('nodes_var') +nodes_schema_var: ContextVar[Sequence[type[BaseNode[Any, Any, Any]]]] = ContextVar('nodes_var') class CustomNodeSchema: diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py index 7be7577e2e..2ae8bc5a9d 100644 --- a/tests/graph/test_graph.py +++ b/tests/graph/test_graph.py @@ -44,7 +44,7 @@ async def run(self, ctx: GraphContext) -> Double: return Double(len(self.input_data)) @dataclass - class Double(BaseNode[None, int]): + class Double(BaseNode[None, None, int]): input_data: int async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noqa: UP007 @@ -53,11 +53,11 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq else: return End(self.input_data * 2) - my_graph = Graph(nodes=(Float2String, String2Length, Double)) + my_graph = Graph[None, None, int](nodes=(Float2String, String2Length, Double)) + assert my_graph.name is None assert my_graph._get_state_type() is type(None) assert my_graph._get_run_end_type() is int - assert my_graph.name is None - result, history = await my_graph.run(None, Float2String(3.14)) + result, history = await my_graph.run(Float2String(3.14)) # len('3.14') * 2 == 8 assert result == 8 assert my_graph.name == 'my_graph' @@ -84,7 +84,7 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq EndStep(result=End(data=8), ts=IsNow(tz=timezone.utc)), ] ) - result, history = await my_graph.run(None, Float2String(3.14159)) + result, history = await my_graph.run(Float2String(3.14159)) # len('3.14159') == 7, 21 * 2 == 42 assert result == 42 assert history == snapshot( @@ -139,7 +139,7 @@ class Float2String(BaseNode): async def run(self, ctx: GraphContext) -> String2Length: raise NotImplementedError() - class String2Length(BaseNode[None, None]): + class String2Length(BaseNode[None, None, None]): async def run(self, ctx: GraphContext) -> End[None]: raise NotImplementedError() @@ -158,13 +158,13 @@ class Foo(BaseNode): async def run(self, ctx: GraphContext) -> Union[Bar, Spam]: # noqa: UP007 raise NotImplementedError() - class Bar(BaseNode[None, None]): + class Bar(BaseNode[None, None, None]): input_data: str async def run(self, ctx: GraphContext) -> End[None]: raise NotImplementedError() - class Spam(BaseNode[None, None]): + class Spam(BaseNode[None, None, None]): async def run(self, ctx: GraphContext) 
-> End[None]: raise NotImplementedError() @@ -185,17 +185,17 @@ class Foo(BaseNode): async def run(self, ctx: GraphContext) -> Eggs: raise NotImplementedError() - class Bar(BaseNode[None, None]): + class Bar(BaseNode[None, None, None]): input_data: str async def run(self, ctx: GraphContext) -> Eggs: raise NotImplementedError() - class Spam(BaseNode[None, None]): + class Spam(BaseNode[None, None, None]): async def run(self, ctx: GraphContext) -> Eggs: raise NotImplementedError() - class Eggs(BaseNode[None, None]): + class Eggs(BaseNode[None, None, None]): async def run(self, ctx: GraphContext) -> End[None]: raise NotImplementedError() @@ -212,7 +212,7 @@ class Foo(BaseNode): async def run(self, ctx: GraphContext) -> Bar: raise NotImplementedError() - class Bar(BaseNode[None, None]): + class Bar(BaseNode[None, None, None]): async def run(self, ctx: GraphContext) -> End[None]: raise NotImplementedError() @@ -234,18 +234,18 @@ async def run(self, ctx: GraphContext) -> Bar: return Bar() @dataclass - class Bar(BaseNode[None, None]): + class Bar(BaseNode[None, None, None]): async def run(self, ctx: GraphContext) -> End[None]: return Spam() # type: ignore @dataclass - class Spam(BaseNode[None, None]): + class Spam(BaseNode[None, None, None]): async def run(self, ctx: GraphContext) -> End[None]: raise NotImplementedError() g = Graph(nodes=(Foo, Bar)) with pytest.raises(GraphRuntimeError) as exc_info: - await g.run(None, Foo()) + await g.run(Foo()) assert exc_info.value.message == snapshot('Node `test_run_node_not_in_graph..Spam()` is not in the graph.') @@ -257,7 +257,7 @@ async def run(self, ctx: GraphContext) -> Bar: return Bar() @dataclass - class Bar(BaseNode[None, None]): + class Bar(BaseNode[None, None, None]): async def run(self, ctx: GraphContext) -> End[None]: return 42 # type: ignore @@ -265,7 +265,7 @@ async def run(self, ctx: GraphContext) -> End[None]: assert g._get_state_type() is type(None) assert g._get_run_end_type() is type(None) with pytest.raises(GraphRuntimeError) as exc_info: - await g.run(None, Foo()) + await g.run(Foo()) assert exc_info.value.message == snapshot('Invalid node return type: `int`. 
Expected `BaseNode` or `End`.') @@ -284,13 +284,13 @@ async def run(self, ctx: GraphContext) -> Foo: g = Graph(nodes=(Foo, Bar)) assert g.name is None history: list[HistoryStep[None, Never]] = [] - n = await g.next(None, Foo(), history) + n = await g.next(Foo(), history) assert n == Bar() assert g.name == 'g' assert history == snapshot([NodeStep(state=None, node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat())]) assert isinstance(n, Bar) - n2 = await g.next(None, n, history) + n2 = await g.next(n, history) assert n2 == Foo() assert history == snapshot( @@ -299,3 +299,34 @@ async def run(self, ctx: GraphContext) -> Foo: NodeStep(state=None, node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), ] ) + + +async def test_deps(): + @dataclass + class Deps: + a: int + b: int + + @dataclass + class Foo(BaseNode[None, Deps]): + async def run(self, ctx: GraphContext[None, Deps]) -> Bar: + assert isinstance(ctx.deps, Deps) + return Bar() + + @dataclass + class Bar(BaseNode[None, Deps, int]): + async def run(self, ctx: GraphContext[None, Deps]) -> End[int]: + assert isinstance(ctx.deps, Deps) + return End(123) + + g = Graph(nodes=(Foo, Bar)) + result, history = await g.run(Foo(), deps=Deps(1, 2)) + + assert result == 123 + assert history == snapshot( + [ + NodeStep(state=None, node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), + NodeStep(state=None, node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), + EndStep(result=End(data=123), ts=IsNow(tz=timezone.utc)), + ] + ) diff --git a/tests/graph/test_history.py b/tests/graph/test_history.py index 89b5b73a4d..ef3a88532e 100644 --- a/tests/graph/test_history.py +++ b/tests/graph/test_history.py @@ -30,7 +30,7 @@ async def run(self, ctx: GraphContext[MyState]) -> Bar: @dataclass -class Bar(BaseNode[MyState, int]): +class Bar(BaseNode[MyState, None, int]): async def run(self, ctx: GraphContext[MyState]) -> End[int]: ctx.state.y += 'y' return End(ctx.state.x * 2) @@ -45,8 +45,8 @@ async def run(self, ctx: GraphContext[MyState]) -> End[int]: Graph(nodes=(Foo, Bar)), ], ) -async def test_dump_load_history(graph: Graph[MyState, int]): - result, history = await graph.run(MyState(1, ''), Foo()) +async def test_dump_load_history(graph: Graph[MyState, None, int]): + result, history = await graph.run(Foo(), state=MyState(1, '')) assert result == snapshot(4) assert history == snapshot( [ @@ -104,7 +104,7 @@ async def test_dump_load_history(graph: Graph[MyState, int]): def test_one_node(): @dataclass - class MyNode(BaseNode[None, int]): + class MyNode(BaseNode[None, None, int]): async def run(self, ctx: GraphContext) -> End[int]: return End(123) diff --git a/tests/graph/test_mermaid.py b/tests/graph/test_mermaid.py index 1d33f2c611..0a4302c2e1 100644 --- a/tests/graph/test_mermaid.py +++ b/tests/graph/test_mermaid.py @@ -26,7 +26,7 @@ async def run(self, ctx: GraphContext) -> Bar: @dataclass -class Bar(BaseNode[None, None]): +class Bar(BaseNode[None, None, None]): async def run(self, ctx: GraphContext) -> End[None]: return End(None) @@ -45,7 +45,7 @@ async def run(self, ctx: GraphContext) -> Annotated[Foo, Edge(label='spam to foo @dataclass -class Eggs(BaseNode[None, None]): +class Eggs(BaseNode[None, None, None]): """This is the docstring for Eggs.""" docstring_notes = False @@ -58,7 +58,7 @@ async def run(self, ctx: GraphContext) -> Annotated[End[None], Edge(label='eggs async def test_run_graph(): - result, history = await graph1.run(None, Foo()) + result, history = await graph1.run(Foo()) assert result is None 
assert history == snapshot( [ diff --git a/tests/graph/test_state.py b/tests/graph/test_state.py index 6d9899dc78..597a420258 100644 --- a/tests/graph/test_state.py +++ b/tests/graph/test_state.py @@ -27,7 +27,7 @@ async def run(self, ctx: GraphContext[MyState]) -> Bar: return Bar() @dataclass - class Bar(BaseNode[MyState, str]): + class Bar(BaseNode[MyState, None, str]): async def run(self, ctx: GraphContext[MyState]) -> End[str]: ctx.state.y += 'y' return End(f'x={ctx.state.x} y={ctx.state.y}') @@ -36,7 +36,7 @@ async def run(self, ctx: GraphContext[MyState]) -> End[str]: assert graph._get_state_type() is MyState assert graph._get_run_end_type() is str state = MyState(1, '') - result, history = await graph.run(state, Foo()) + result, history = await graph.run(Foo(), state=state) assert result == snapshot('x=2 y=y') assert history == snapshot( [ diff --git a/tests/typed_graph.py b/tests/typed_graph.py index 8c6ff9b0ea..14bd171669 100644 --- a/tests/typed_graph.py +++ b/tests/typed_graph.py @@ -4,7 +4,7 @@ from typing_extensions import assert_type -from pydantic_graph import BaseNode, End, Graph, GraphContext +from pydantic_graph import BaseNode, End, Graph, GraphContext, HistoryStep @dataclass @@ -29,7 +29,7 @@ class X: @dataclass -class Double(BaseNode[None, X]): +class Double(BaseNode[None, None, X]): input_data: int async def run(self, ctx: GraphContext) -> String2Length | End[X]: @@ -39,7 +39,7 @@ async def run(self, ctx: GraphContext) -> String2Length | End[X]: return End(X(self.input_data * 2)) -def use_double(node: BaseNode[None, X]) -> None: +def use_double(node: BaseNode[None, None, X]) -> None: """Shoe that `Double` is valid as a `BaseNode[None, int, X]`.""" print(node) @@ -47,18 +47,18 @@ def use_double(node: BaseNode[None, X]) -> None: use_double(Double(1)) -g1 = Graph[None, X]( +g1 = Graph[None, None, X]( nodes=( Float2String, String2Length, Double, ) ) -assert_type(g1, Graph[None, X]) +assert_type(g1, Graph[None, None, X]) g2 = Graph(nodes=(Double,)) -assert_type(g2, Graph[None, X]) +assert_type(g2, Graph[None, None, X]) g3 = Graph( nodes=( @@ -68,7 +68,47 @@ def use_double(node: BaseNode[None, X]) -> None: ) ) # because String2Length came before Double, the output type is Any -assert_type(g3, Graph[None, X]) +assert_type(g3, Graph[None, None, X]) Graph[None, bytes](nodes=(Float2String, String2Length, Double)) # type: ignore[arg-type] Graph[None, str](nodes=[Double]) # type: ignore[list-item] + + +@dataclass +class MyState: + x: int + + +@dataclass +class MyDeps: + y: str + + +@dataclass +class A(BaseNode[MyState, MyDeps]): + async def run(self, ctx: GraphContext[MyState, MyDeps]) -> B: + assert ctx.state.x == 1 + assert ctx.deps.y == 'y' + return B() + + +@dataclass +class B(BaseNode[MyState, MyDeps, int]): + async def run(self, ctx: GraphContext[MyState, MyDeps]) -> End[int]: + return End(42) + + +g4 = Graph[MyState, MyDeps, int](nodes=(A, B)) +assert_type(g4, Graph[MyState, MyDeps, int]) + +g5 = Graph(nodes=(A, B)) +assert_type(g5, Graph[MyState, MyDeps, int]) + + +def run_g5() -> None: + g5.run_sync(A()) # pyright: ignore[reportArgumentType] + g5.run_sync(A(), state=MyState(x=1)) # pyright: ignore[reportArgumentType] + g5.run_sync(A(), deps=MyDeps(y='y')) # pyright: ignore[reportArgumentType] + ans, history = g5.run_sync(A(), state=MyState(x=1), deps=MyDeps(y='y')) + assert_type(ans, int) + assert_type(history, list[HistoryStep[MyState, int]]) From a2144d6f97675b95f8ff4210654a38b429820472 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 15 Jan 2025 19:43:43 +0000 
Subject: [PATCH 57/57] docs comments, GraphContext -> GraphRunContext --- docs/api/pydantic_graph/nodes.md | 2 +- docs/graph.md | 80 +++++++++---------- .../pydantic_ai_examples/question_graph.py | 12 +-- pydantic_graph/README.md | 6 +- pydantic_graph/pydantic_graph/__init__.py | 4 +- pydantic_graph/pydantic_graph/graph.py | 10 +-- pydantic_graph/pydantic_graph/nodes.py | 6 +- tests/graph/test_graph.py | 48 +++++------ tests/graph/test_history.py | 10 +-- tests/graph/test_mermaid.py | 16 ++-- tests/graph/test_state.py | 6 +- tests/typed_graph.py | 12 +-- 12 files changed, 106 insertions(+), 106 deletions(-) diff --git a/docs/api/pydantic_graph/nodes.md b/docs/api/pydantic_graph/nodes.md index e58ddf7012..ecf6d35f50 100644 --- a/docs/api/pydantic_graph/nodes.md +++ b/docs/api/pydantic_graph/nodes.md @@ -3,7 +3,7 @@ ::: pydantic_graph.nodes options: members: - - GraphContext + - GraphRunContext - BaseNode - End - Edge diff --git a/docs/graph.md b/docs/graph.md index bbd546ddc3..a394a252f5 100644 --- a/docs/graph.md +++ b/docs/graph.md @@ -1,7 +1,7 @@ # Graphs !!! danger "Don't use a nail gun unless you need a nail gun" - If PydanticAI [agents](agents.md) are a hammer, and [multi-agent workflows](multi-agent-applications.md) are a sledgehammer, then graphs are a nail gun, with flames down the side: + If PydanticAI [agents](agents.md) are a hammer, and [multi-agent workflows](multi-agent-applications.md) are a sledgehammer, then graphs are a nail gun: * sure, nail guns look cooler than hammers * but nail guns take a lot more setup than hammers @@ -10,13 +10,13 @@ In short, graphs are a powerful tool, but they're not the right tool for every job. Please consider other [multi-agent approaches](multi-agent-applications.md) before proceeding. - Unless you're sure you need a graph, you probably don't. + If you're not confident a graph-based approach is a good idea, it might be unnecessary. Graphs and finite state machines (FSMs) are a powerful abstraction to model, execute, control and visualize complex workflows. Alongside PydanticAI, we've developed `pydantic-graph` — an async graph and state machine library for Python where nodes and edges are defined using type hints. -While this library is developed as part of the PydanticAI; it has no dependency on `pydantic-ai` and can be considered as a pure graph library. You may find it useful whether or not you're using PydanticAI or even building with GenAI. +While this library is developed as part of PydanticAI; it has no dependency on `pydantic-ai` and can be considered as a pure graph-based state machine library. You may find it useful whether or not you're using PydanticAI or even building with GenAI. `pydantic-graph` is designed for advanced users and makes heavy use of Python generics and types hints. It is not designed to be as beginner-friendly as PydanticAI. @@ -33,25 +33,25 @@ pip/uv-add pydantic-graph ## Graph Types -Graphs are made up of a few key components: +`pydantic-graph` made up of a few key components: -### GraphContext +### GraphRunContext -[`GraphContext`][pydantic_graph.nodes.GraphContext] — The context for the graph run, similar to PydanticAI's [`RunContext`][pydantic_ai.tools.RunContext], this holds the state of the graph and is passed to nodes when they're run. +[`GraphRunContext`][pydantic_graph.nodes.GraphRunContext] — The context for the graph run, similar to PydanticAI's [`RunContext`][pydantic_ai.tools.RunContext]. This holds the state of the graph and dependencies and is passed to nodes when they're run. 
-`GraphContext` is generic in the state type of the graph it's used in, [`StateT`][pydantic_graph.state.StateT]. +`GraphRunContext` is generic in the state type of the graph it's used in, [`StateT`][pydantic_graph.state.StateT]. ### End -[`End`][pydantic_graph.nodes.End] — return value to indicates the graph run should end. +[`End`][pydantic_graph.nodes.End] — return value to indicate the graph run should end. `End` is generic in the graph return type of the graph it's used in, [`RunEndT`][pydantic_graph.nodes.RunEndT]. ### Nodes -Subclasses of [`BaseNode`][pydantic_graph.nodes.BaseNode] to define nodes. +Subclasses of [`BaseNode`][pydantic_graph.nodes.BaseNode] define nodes for execution in the graph. -Nodes which are generally [`dataclass`es][dataclasses.dataclass] include: +Nodes, which are generally [`dataclass`es][dataclasses.dataclass], generally consist of: * fields containing any parameters required/optional when calling the node * the business logic to execute the node, in the [`run`][pydantic_graph.nodes.BaseNode.run] method @@ -68,7 +68,7 @@ Here's an example of a start or intermediate node in a graph — it can't end th ```py {title="intermediate_node.py" noqa="F821" test="skip"} from dataclasses import dataclass -from pydantic_graph import BaseNode, GraphContext +from pydantic_graph import BaseNode, GraphRunContext @dataclass @@ -77,7 +77,7 @@ class MyNode(BaseNode[MyState]): # (1)! async def run( self, - ctx: GraphContext[MyState], # (3)! + ctx: GraphRunContext[MyState], # (3)! ) -> AnotherNode: # (4)! ... return AnotherNode() @@ -85,7 +85,7 @@ class MyNode(BaseNode[MyState]): # (1)! 1. State in this example is `MyState` (not shown), hence `BaseNode` is parameterized with `MyState`. This node can't end the run, so the `RunEndT` generic parameter is omitted and defaults to `Never`. 2. `MyNode` is a dataclass and has a single field `foo`, an `int`. -3. The `run` method takes a `GraphContext` parameter, again parameterized with state `MyState`. +3. The `run` method takes a `GraphRunContext` parameter, again parameterized with state `MyState`. 4. The return type of the `run` method is `AnotherNode` (not shown), this is used to determine the outgoing edges of the node. We could extend `MyNode` to optionally end the run if `foo` is divisible by 5: @@ -93,7 +93,7 @@ We could extend `MyNode` to optionally end the run if `foo` is divisible by 5: ```py {title="intermediate_or_end_node.py" hl_lines="7 13 15" noqa="F821" test="skip"} from dataclasses import dataclass -from pydantic_graph import BaseNode, End, GraphContext +from pydantic_graph import BaseNode, End, GraphRunContext @dataclass @@ -102,7 +102,7 @@ class MyNode(BaseNode[MyState, None, int]): # (1)! async def run( self, - ctx: GraphContext[MyState], + ctx: GraphRunContext[MyState], ) -> AnotherNode | End[int]: # (2)! if self.foo % 5 == 0: return End(self.foo) @@ -130,7 +130,7 @@ from __future__ import annotations from dataclasses import dataclass -from pydantic_graph import BaseNode, End, Graph, GraphContext +from pydantic_graph import BaseNode, End, Graph, GraphRunContext @dataclass @@ -139,7 +139,7 @@ class DivisibleBy5(BaseNode[None, None, int]): # (1)! async def run( self, - ctx: GraphContext, + ctx: GraphRunContext, ) -> Increment | End[int]: if self.foo % 5 == 0: return End(self.foo) @@ -151,7 +151,7 @@ class DivisibleBy5(BaseNode[None, None, int]): # (1)! class Increment(BaseNode): # (2)! 
foo: int - async def run(self, ctx: GraphContext) -> DivisibleBy5: + async def run(self, ctx: GraphRunContext) -> DivisibleBy5: return DivisibleBy5(self.foo + 1) @@ -192,9 +192,9 @@ stateDiagram-v2 ## Stateful Graphs -TODO introduce state +The "state" concept in `pydantic-graph` provides an optional way to access and mutate an object (often a `dataclass` or Pydantic model) as nodes run in a graph. If you think of Graphs as a production line, then your state is the engine being passed along the line and built up by each node as the graph is run. -TODO link to issue about persistent state. +In future, we intend to extend `pydantic-graph` to provide state persistence with the state recorded after each node is run, see [#695](https://github.com/pydantic/pydantic-ai/issues/695). Here's an example of a graph which represents a vending machine where the user may insert coins and select a product to purchase. @@ -205,7 +205,7 @@ from dataclasses import dataclass from rich.prompt import Prompt -from pydantic_graph import BaseNode, End, Graph, GraphContext +from pydantic_graph import BaseNode, End, Graph, GraphRunContext @dataclass @@ -216,7 +216,7 @@ class MachineState: # (1)! @dataclass class InsertCoin(BaseNode[MachineState]): # (3)! - async def run(self, ctx: GraphContext[MachineState]) -> CoinsInserted: # (16)! + async def run(self, ctx: GraphRunContext[MachineState]) -> CoinsInserted: # (16)! return CoinsInserted(float(Prompt.ask('Insert coins'))) # (4)! @@ -225,7 +225,7 @@ class CoinsInserted(BaseNode[MachineState]): amount: float # (5)! async def run( - self, ctx: GraphContext[MachineState] + self, ctx: GraphRunContext[MachineState] ) -> SelectProduct | Purchase: # (17)! ctx.state.user_balance += self.amount # (6)! if ctx.state.product is not None: # (7)! @@ -236,7 +236,7 @@ class CoinsInserted(BaseNode[MachineState]): @dataclass class SelectProduct(BaseNode[MachineState]): - async def run(self, ctx: GraphContext[MachineState]) -> Purchase: + async def run(self, ctx: GraphRunContext[MachineState]) -> Purchase: return Purchase(Prompt.ask('Select product')) @@ -253,7 +253,7 @@ class Purchase(BaseNode[MachineState, None, None]): # (18)! product: str async def run( - self, ctx: GraphContext[MachineState] + self, ctx: GraphRunContext[MachineState] ) -> End | InsertCoin | SelectProduct: if price := PRODUCT_PRICES.get(self.product): # (8)! ctx.state.product = self.product # (9)!
@@ -360,7 +360,7 @@ from pydantic import BaseModel, EmailStr from pydantic_ai import Agent from pydantic_ai.format_as_xml import format_as_xml from pydantic_ai.messages import ModelMessage -from pydantic_graph import BaseNode, End, Graph, GraphContext +from pydantic_graph import BaseNode, End, Graph, GraphRunContext @dataclass @@ -393,7 +393,7 @@ email_writer_agent = Agent( class WriteEmail(BaseNode[State]): email_feedback: str | None = None - async def run(self, ctx: GraphContext[State]) -> Feedback: + async def run(self, ctx: GraphRunContext[State]) -> Feedback: if self.email_feedback: prompt = ( f'Rewrite the email for the user:\n' @@ -437,7 +437,7 @@ class Feedback(BaseNode[State, None, Email]): async def run( self, - ctx: GraphContext[State], + ctx: GraphRunContext[State], ) -> WriteEmail | End[Email]: prompt = format_as_xml({'user': ctx.state.user, 'email': self.email}) result = await feedback_agent.run(prompt) @@ -481,7 +481,7 @@ In this example, an AI asks the user a question, the user provides an answer, th from dataclasses import dataclass, field - from pydantic_graph import BaseNode, End, Graph, GraphContext + from pydantic_graph import BaseNode, End, Graph, GraphRunContext from pydantic_ai import Agent from pydantic_ai.format_as_xml import format_as_xml @@ -499,7 +499,7 @@ In this example, an AI asks the user a question, the user provides an answer, th @dataclass class Ask(BaseNode[QuestionState]): - async def run(self, ctx: GraphContext[QuestionState]) -> Answer: + async def run(self, ctx: GraphRunContext[QuestionState]) -> Answer: result = await ask_agent.run( 'Ask a simple question with a single correct answer.', message_history=ctx.state.ask_agent_messages, @@ -514,7 +514,7 @@ In this example, an AI asks the user a question, the user provides an answer, th question: str answer: str | None = None - async def run(self, ctx: GraphContext[QuestionState]) -> Evaluate: + async def run(self, ctx: GraphRunContext[QuestionState]) -> Evaluate: assert self.answer is not None return Evaluate(self.answer) @@ -538,7 +538,7 @@ In this example, an AI asks the user a question, the user provides an answer, th async def run( self, - ctx: GraphContext[QuestionState], + ctx: GraphRunContext[QuestionState], ) -> End[str] | Reprimand: assert ctx.state.question is not None result = await evaluate_agent.run( @@ -556,7 +556,7 @@ In this example, an AI asks the user a question, the user provides an answer, th class Reprimand(BaseNode[QuestionState]): comment: str - async def run(self, ctx: GraphContext[QuestionState]) -> Ask: + async def run(self, ctx: GraphRunContext[QuestionState]) -> Ask: print(f'Comment: {self.comment}') ctx.state.question = None return Ask() @@ -637,7 +637,7 @@ You maybe have noticed that although this examples transfers control flow out of ## Dependency Injection -As with PydanticAI, `pydantic-graph` supports dependency injection via a generic parameter on [`Graph`][pydantic_graph.graph.Graph] and [`BaseNode`][pydantic_graph.nodes.BaseNode], and the [`GraphContext.deps`][pydantic_graph.nodes.GraphContext.deps] fields. +As with PydanticAI, `pydantic-graph` supports dependency injection via a generic parameter on [`Graph`][pydantic_graph.graph.Graph] and [`BaseNode`][pydantic_graph.nodes.BaseNode], and the [`GraphRunContext.deps`][pydantic_graph.nodes.GraphRunContext.deps] fields. 
As an example of dependency injection, let's modify the `DivisibleBy5` example [above](#graph) to use a [`ProcessPoolExecutor`][concurrent.futures.ProcessPoolExecutor] to run the compute load in a separate process (this is a contrived example, `ProcessPoolExecutor` wouldn't actually improve performance in this example): @@ -648,7 +648,7 @@ import asyncio from concurrent.futures import ProcessPoolExecutor from dataclasses import dataclass -from pydantic_graph import BaseNode, End, Graph, GraphContext +from pydantic_graph import BaseNode, End, Graph, GraphRunContext @dataclass @@ -662,7 +662,7 @@ class DivisibleBy5(BaseNode[None, None, int]): async def run( self, - ctx: GraphContext, + ctx: GraphRunContext, ) -> Increment | End[int]: if self.foo % 5 == 0: return End(self.foo) @@ -674,7 +674,7 @@ class DivisibleBy5(BaseNode[None, None, int]): class Increment(BaseNode): foo: int - async def run(self, ctx: GraphContext) -> DivisibleBy5: + async def run(self, ctx: GraphRunContext) -> DivisibleBy5: loop = asyncio.get_running_loop() compute_result = await loop.run_in_executor( ctx.deps.executor, @@ -738,7 +738,7 @@ Putting that together, we can edit the last [`ai_q_and_a_graph.py`](#custom-cont ... from typing import Annotated -from pydantic_graph import BaseNode, End, Graph, GraphContext, Edge +from pydantic_graph import BaseNode, End, Graph, GraphRunContext, Edge ... @@ -747,7 +747,7 @@ class Ask(BaseNode[QuestionState]): """Generate question using GPT-4o.""" docstring_notes = True async def run( - self, ctx: GraphContext[QuestionState] + self, ctx: GraphRunContext[QuestionState] ) -> Annotated[Answer, Edge(label='Ask the question')]: ... @@ -759,7 +759,7 @@ class Evaluate(BaseNode[QuestionState]): async def run( self, - ctx: GraphContext[QuestionState], + ctx: GraphRunContext[QuestionState], ) -> Annotated[End[str], Edge(label='success')] | Reprimand: ... 
diff --git a/examples/pydantic_ai_examples/question_graph.py b/examples/pydantic_ai_examples/question_graph.py index 6ef092d3ba..47bae54c8c 100644 --- a/examples/pydantic_ai_examples/question_graph.py +++ b/examples/pydantic_ai_examples/question_graph.py @@ -13,7 +13,7 @@ import logfire from devtools import debug -from pydantic_graph import BaseNode, Edge, End, Graph, GraphContext, HistoryStep +from pydantic_graph import BaseNode, Edge, End, Graph, GraphRunContext, HistoryStep from pydantic_ai import Agent from pydantic_ai.format_as_xml import format_as_xml @@ -34,7 +34,7 @@ class QuestionState: @dataclass class Ask(BaseNode[QuestionState]): - async def run(self, ctx: GraphContext[QuestionState]) -> Answer: + async def run(self, ctx: GraphRunContext[QuestionState]) -> Answer: result = await ask_agent.run( 'Ask a simple question with a single correct answer.', message_history=ctx.state.ask_agent_messages, @@ -48,7 +48,7 @@ async def run(self, ctx: GraphContext[QuestionState]) -> Answer: class Answer(BaseNode[QuestionState]): answer: str | None = None - async def run(self, ctx: GraphContext[QuestionState]) -> Evaluate: + async def run(self, ctx: GraphRunContext[QuestionState]) -> Evaluate: assert self.answer is not None return Evaluate(self.answer) @@ -72,7 +72,7 @@ class Evaluate(BaseNode[QuestionState]): async def run( self, - ctx: GraphContext[QuestionState], + ctx: GraphRunContext[QuestionState], ) -> Congratulate | Reprimand: assert ctx.state.question is not None result = await evaluate_agent.run( @@ -91,7 +91,7 @@ class Congratulate(BaseNode[QuestionState, None, None]): comment: str async def run( - self, ctx: GraphContext[QuestionState] + self, ctx: GraphRunContext[QuestionState] ) -> Annotated[End, Edge(label='success')]: print(f'Correct answer! {self.comment}') return End(None) @@ -101,7 +101,7 @@ async def run( class Reprimand(BaseNode[QuestionState]): comment: str - async def run(self, ctx: GraphContext[QuestionState]) -> Ask: + async def run(self, ctx: GraphRunContext[QuestionState]) -> Ask: print(f'Comment: {self.comment}') # > Comment: Vichy is no longer the capital of France. 
ctx.state.question = None diff --git a/pydantic_graph/README.md b/pydantic_graph/README.md index 0b2508f033..15a4062e05 100644 --- a/pydantic_graph/README.md +++ b/pydantic_graph/README.md @@ -24,7 +24,7 @@ from __future__ import annotations from dataclasses import dataclass -from pydantic_graph import BaseNode, End, Graph, GraphContext +from pydantic_graph import BaseNode, End, Graph, GraphRunContext @dataclass @@ -33,7 +33,7 @@ class DivisibleBy5(BaseNode[None, None, int]): async def run( self, - ctx: GraphContext, + ctx: GraphRunContext, ) -> Increment | End[int]: if self.foo % 5 == 0: return End(self.foo) @@ -45,7 +45,7 @@ class DivisibleBy5(BaseNode[None, None, int]): class Increment(BaseNode): foo: int - async def run(self, ctx: GraphContext) -> DivisibleBy5: + async def run(self, ctx: GraphRunContext) -> DivisibleBy5: return DivisibleBy5(self.foo + 1) diff --git a/pydantic_graph/pydantic_graph/__init__.py b/pydantic_graph/pydantic_graph/__init__.py index 135f7529a5..d4c6074e1a 100644 --- a/pydantic_graph/pydantic_graph/__init__.py +++ b/pydantic_graph/pydantic_graph/__init__.py @@ -1,13 +1,13 @@ from .exceptions import GraphRuntimeError, GraphSetupError from .graph import Graph -from .nodes import BaseNode, Edge, End, GraphContext +from .nodes import BaseNode, Edge, End, GraphRunContext from .state import EndStep, HistoryStep, NodeStep __all__ = ( 'Graph', 'BaseNode', 'End', - 'GraphContext', + 'GraphRunContext', 'Edge', 'EndStep', 'HistoryStep', diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 0aed1f3706..878918ebba 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -15,7 +15,7 @@ import typing_extensions from . import _utils, exceptions, mermaid -from .nodes import BaseNode, DepsT, End, GraphContext, NodeDef, RunEndT +from .nodes import BaseNode, DepsT, End, GraphRunContext, NodeDef, RunEndT from .state import EndStep, HistoryStep, NodeStep, StateT, deep_copy_state, nodes_schema_var __all__ = ('Graph',) @@ -38,7 +38,7 @@ class Graph(Generic[StateT, DepsT, RunEndT]): from dataclasses import dataclass - from pydantic_graph import BaseNode, End, Graph, GraphContext + from pydantic_graph import BaseNode, End, Graph, GraphRunContext @dataclass class MyState: @@ -46,13 +46,13 @@ class MyState: @dataclass class Increment(BaseNode[MyState]): - async def run(self, ctx: GraphContext) -> Check42: + async def run(self, ctx: GraphRunContext) -> Check42: ctx.state.number += 1 return Check42() @dataclass class Check42(BaseNode[MyState, None, int]): - async def run(self, ctx: GraphContext) -> Increment | End[int]: + async def run(self, ctx: GraphRunContext) -> Increment | End[int]: if ctx.state.number == 42: return Increment() else: @@ -229,7 +229,7 @@ async def next( if node_id not in self.node_defs: raise exceptions.GraphRuntimeError(f'Node `{node}` is not in the graph.') - ctx = GraphContext(state, deps) + ctx = GraphRunContext(state, deps) with _logfire.span('run node {node_id}', node_id=node_id, node=node): start_ts = _utils.now_utc() start = perf_counter() diff --git a/pydantic_graph/pydantic_graph/nodes.py b/pydantic_graph/pydantic_graph/nodes.py index 3591e1b5be..31363f2f1e 100644 --- a/pydantic_graph/pydantic_graph/nodes.py +++ b/pydantic_graph/pydantic_graph/nodes.py @@ -14,7 +14,7 @@ else: StateT = TypeVar('StateT', default=None) -__all__ = 'GraphContext', 'BaseNode', 'End', 'Edge', 'NodeDef', 'RunEndT', 'NodeRunEndT', 'DepsT' +__all__ = 'GraphRunContext', 'BaseNode', 'End', 'Edge', 
'NodeDef', 'RunEndT', 'NodeRunEndT', 'DepsT' RunEndT = TypeVar('RunEndT', default=None) """Type variable for the return type of a graph [`run`][pydantic_graph.graph.Graph.run].""" @@ -25,7 +25,7 @@ @dataclass -class GraphContext(Generic[StateT, DepsT]): +class GraphRunContext(Generic[StateT, DepsT]): """Context for a graph.""" state: StateT @@ -46,7 +46,7 @@ class BaseNode(ABC, Generic[StateT, DepsT, NodeRunEndT]): """ @abstractmethod - async def run(self, ctx: GraphContext[StateT, DepsT]) -> BaseNode[StateT, DepsT, Any] | End[NodeRunEndT]: + async def run(self, ctx: GraphRunContext[StateT, DepsT]) -> BaseNode[StateT, DepsT, Any] | End[NodeRunEndT]: """Run the node. This is an abstract method that must be implemented by subclasses. diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py index 2ae8bc5a9d..ebd254a370 100644 --- a/tests/graph/test_graph.py +++ b/tests/graph/test_graph.py @@ -16,7 +16,7 @@ End, EndStep, Graph, - GraphContext, + GraphRunContext, GraphRuntimeError, GraphSetupError, HistoryStep, @@ -33,21 +33,21 @@ async def test_graph(): class Float2String(BaseNode): input_data: float - async def run(self, ctx: GraphContext) -> String2Length: + async def run(self, ctx: GraphRunContext) -> String2Length: return String2Length(str(self.input_data)) @dataclass class String2Length(BaseNode): input_data: str - async def run(self, ctx: GraphContext) -> Double: + async def run(self, ctx: GraphRunContext) -> Double: return Double(len(self.input_data)) @dataclass class Double(BaseNode[None, None, int]): input_data: int - async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noqa: UP007 + async def run(self, ctx: GraphRunContext) -> Union[String2Length, End[int]]: # noqa: UP007 if self.input_data == 7: return String2Length('x' * 21) else: @@ -136,11 +136,11 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq def test_one_bad_node(): class Float2String(BaseNode): - async def run(self, ctx: GraphContext) -> String2Length: + async def run(self, ctx: GraphRunContext) -> String2Length: raise NotImplementedError() class String2Length(BaseNode[None, None, None]): - async def run(self, ctx: GraphContext) -> End[None]: + async def run(self, ctx: GraphRunContext) -> End[None]: raise NotImplementedError() with pytest.raises(GraphSetupError) as exc_info: @@ -155,17 +155,17 @@ def test_two_bad_nodes(): class Foo(BaseNode): input_data: float - async def run(self, ctx: GraphContext) -> Union[Bar, Spam]: # noqa: UP007 + async def run(self, ctx: GraphRunContext) -> Union[Bar, Spam]: # noqa: UP007 raise NotImplementedError() class Bar(BaseNode[None, None, None]): input_data: str - async def run(self, ctx: GraphContext) -> End[None]: + async def run(self, ctx: GraphRunContext) -> End[None]: raise NotImplementedError() class Spam(BaseNode[None, None, None]): - async def run(self, ctx: GraphContext) -> End[None]: + async def run(self, ctx: GraphRunContext) -> End[None]: raise NotImplementedError() with pytest.raises(GraphSetupError) as exc_info: @@ -182,21 +182,21 @@ def test_three_bad_nodes_separate(): class Foo(BaseNode): input_data: float - async def run(self, ctx: GraphContext) -> Eggs: + async def run(self, ctx: GraphRunContext) -> Eggs: raise NotImplementedError() class Bar(BaseNode[None, None, None]): input_data: str - async def run(self, ctx: GraphContext) -> Eggs: + async def run(self, ctx: GraphRunContext) -> Eggs: raise NotImplementedError() class Spam(BaseNode[None, None, None]): - async def run(self, ctx: GraphContext) -> Eggs: + 
async def run(self, ctx: GraphRunContext) -> Eggs: raise NotImplementedError() class Eggs(BaseNode[None, None, None]): - async def run(self, ctx: GraphContext) -> End[None]: + async def run(self, ctx: GraphRunContext) -> End[None]: raise NotImplementedError() with pytest.raises(GraphSetupError) as exc_info: @@ -209,11 +209,11 @@ async def run(self, ctx: GraphContext) -> End[None]: def test_duplicate_id(): class Foo(BaseNode): - async def run(self, ctx: GraphContext) -> Bar: + async def run(self, ctx: GraphRunContext) -> Bar: raise NotImplementedError() class Bar(BaseNode[None, None, None]): - async def run(self, ctx: GraphContext) -> End[None]: + async def run(self, ctx: GraphRunContext) -> End[None]: raise NotImplementedError() @classmethod @@ -230,17 +230,17 @@ def get_id(cls) -> str: async def test_run_node_not_in_graph(): @dataclass class Foo(BaseNode): - async def run(self, ctx: GraphContext) -> Bar: + async def run(self, ctx: GraphRunContext) -> Bar: return Bar() @dataclass class Bar(BaseNode[None, None, None]): - async def run(self, ctx: GraphContext) -> End[None]: + async def run(self, ctx: GraphRunContext) -> End[None]: return Spam() # type: ignore @dataclass class Spam(BaseNode[None, None, None]): - async def run(self, ctx: GraphContext) -> End[None]: + async def run(self, ctx: GraphRunContext) -> End[None]: raise NotImplementedError() g = Graph(nodes=(Foo, Bar)) @@ -253,12 +253,12 @@ async def run(self, ctx: GraphContext) -> End[None]: async def test_run_return_other(): @dataclass class Foo(BaseNode): - async def run(self, ctx: GraphContext) -> Bar: + async def run(self, ctx: GraphRunContext) -> Bar: return Bar() @dataclass class Bar(BaseNode[None, None, None]): - async def run(self, ctx: GraphContext) -> End[None]: + async def run(self, ctx: GraphRunContext) -> End[None]: return 42 # type: ignore g = Graph(nodes=(Foo, Bar)) @@ -273,12 +273,12 @@ async def run(self, ctx: GraphContext) -> End[None]: async def test_next(): @dataclass class Foo(BaseNode): - async def run(self, ctx: GraphContext) -> Bar: + async def run(self, ctx: GraphRunContext) -> Bar: return Bar() @dataclass class Bar(BaseNode): - async def run(self, ctx: GraphContext) -> Foo: + async def run(self, ctx: GraphRunContext) -> Foo: return Foo() g = Graph(nodes=(Foo, Bar)) @@ -309,13 +309,13 @@ class Deps: @dataclass class Foo(BaseNode[None, Deps]): - async def run(self, ctx: GraphContext[None, Deps]) -> Bar: + async def run(self, ctx: GraphRunContext[None, Deps]) -> Bar: assert isinstance(ctx.deps, Deps) return Bar() @dataclass class Bar(BaseNode[None, Deps, int]): - async def run(self, ctx: GraphContext[None, Deps]) -> End[int]: + async def run(self, ctx: GraphRunContext[None, Deps]) -> End[int]: assert isinstance(ctx.deps, Deps) return End(123) diff --git a/tests/graph/test_history.py b/tests/graph/test_history.py index ef3a88532e..2508a53475 100644 --- a/tests/graph/test_history.py +++ b/tests/graph/test_history.py @@ -9,7 +9,7 @@ from dirty_equals import IsStr from inline_snapshot import snapshot -from pydantic_graph import BaseNode, End, EndStep, Graph, GraphContext, GraphSetupError, NodeStep +from pydantic_graph import BaseNode, End, EndStep, Graph, GraphRunContext, GraphSetupError, NodeStep from ..conftest import IsFloat, IsNow @@ -24,14 +24,14 @@ class MyState: @dataclass class Foo(BaseNode[MyState]): - async def run(self, ctx: GraphContext[MyState]) -> Bar: + async def run(self, ctx: GraphRunContext[MyState]) -> Bar: ctx.state.x += 1 return Bar() @dataclass class Bar(BaseNode[MyState, None, int]): - async 
def run(self, ctx: GraphContext[MyState]) -> End[int]: + async def run(self, ctx: GraphRunContext[MyState]) -> End[int]: ctx.state.y += 'y' return End(ctx.state.x * 2) @@ -105,7 +105,7 @@ async def test_dump_load_history(graph: Graph[MyState, None, int]): def test_one_node(): @dataclass class MyNode(BaseNode[None, None, int]): - async def run(self, ctx: GraphContext) -> End[int]: + async def run(self, ctx: GraphRunContext) -> End[int]: return End(123) g = Graph(nodes=[MyNode]) @@ -124,7 +124,7 @@ async def run(self, ctx: GraphContext) -> End[int]: def test_no_generic_arg(): @dataclass class NoGenericArgsNode(BaseNode): - async def run(self, ctx: GraphContext) -> NoGenericArgsNode: + async def run(self, ctx: GraphRunContext) -> NoGenericArgsNode: return NoGenericArgsNode() g = Graph(nodes=[NoGenericArgsNode]) diff --git a/tests/graph/test_mermaid.py b/tests/graph/test_mermaid.py index 0a4302c2e1..5588fb6707 100644 --- a/tests/graph/test_mermaid.py +++ b/tests/graph/test_mermaid.py @@ -11,7 +11,7 @@ import pytest from inline_snapshot import snapshot -from pydantic_graph import BaseNode, Edge, End, EndStep, Graph, GraphContext, GraphSetupError, NodeStep +from pydantic_graph import BaseNode, Edge, End, EndStep, Graph, GraphRunContext, GraphSetupError, NodeStep from pydantic_graph.nodes import NodeDef from ..conftest import IsFloat, IsNow @@ -21,13 +21,13 @@ @dataclass class Foo(BaseNode): - async def run(self, ctx: GraphContext) -> Bar: + async def run(self, ctx: GraphRunContext) -> Bar: return Bar() @dataclass class Bar(BaseNode[None, None, None]): - async def run(self, ctx: GraphContext) -> End[None]: + async def run(self, ctx: GraphRunContext) -> End[None]: return End(None) @@ -40,7 +40,7 @@ class Spam(BaseNode): docstring_notes = True - async def run(self, ctx: GraphContext) -> Annotated[Foo, Edge(label='spam to foo')]: + async def run(self, ctx: GraphRunContext) -> Annotated[Foo, Edge(label='spam to foo')]: raise NotImplementedError() @@ -50,7 +50,7 @@ class Eggs(BaseNode[None, None, None]): docstring_notes = False - async def run(self, ctx: GraphContext) -> Annotated[End[None], Edge(label='eggs to end')]: + async def run(self, ctx: GraphRunContext) -> Annotated[End[None], Edge(label='eggs to end')]: raise NotImplementedError() @@ -173,7 +173,7 @@ def test_mermaid_code_without_edge_labels(): @dataclass class AllNodes(BaseNode): - async def run(self, ctx: GraphContext) -> BaseNode: + async def run(self, ctx: GraphRunContext) -> BaseNode: raise NotImplementedError() @@ -364,7 +364,7 @@ def test_get_node_def(): def test_no_return_type(): @dataclass class NoReturnType(BaseNode): - async def run(self, ctx: GraphContext): # type: ignore + async def run(self, ctx: GraphRunContext): # type: ignore raise NotImplementedError() with pytest.raises(GraphSetupError, match=r".*\.NoReturnType'> is missing a return type hint on its `run` method"): @@ -374,7 +374,7 @@ async def run(self, ctx: GraphContext): # type: ignore def test_wrong_return_type(): @dataclass class NoReturnType(BaseNode): - async def run(self, ctx: GraphContext) -> int: # type: ignore + async def run(self, ctx: GraphRunContext) -> int: # type: ignore raise NotImplementedError() with pytest.raises(GraphSetupError, match="Invalid return type: "): diff --git a/tests/graph/test_state.py b/tests/graph/test_state.py index 597a420258..fbb570cf0c 100644 --- a/tests/graph/test_state.py +++ b/tests/graph/test_state.py @@ -7,7 +7,7 @@ import pytest from inline_snapshot import snapshot -from pydantic_graph import BaseNode, End, EndStep, Graph, 
GraphContext, NodeStep +from pydantic_graph import BaseNode, End, EndStep, Graph, GraphRunContext, NodeStep from ..conftest import IsFloat, IsNow @@ -22,13 +22,13 @@ class MyState: @dataclass class Foo(BaseNode[MyState]): - async def run(self, ctx: GraphContext[MyState]) -> Bar: + async def run(self, ctx: GraphRunContext[MyState]) -> Bar: ctx.state.x += 1 return Bar() @dataclass class Bar(BaseNode[MyState, None, str]): - async def run(self, ctx: GraphContext[MyState]) -> End[str]: + async def run(self, ctx: GraphRunContext[MyState]) -> End[str]: ctx.state.y += 'y' return End(f'x={ctx.state.x} y={ctx.state.y}') diff --git a/tests/typed_graph.py b/tests/typed_graph.py index 14bd171669..06ef0b6a55 100644 --- a/tests/typed_graph.py +++ b/tests/typed_graph.py @@ -4,14 +4,14 @@ from typing_extensions import assert_type -from pydantic_graph import BaseNode, End, Graph, GraphContext, HistoryStep +from pydantic_graph import BaseNode, End, Graph, GraphRunContext, HistoryStep @dataclass class Float2String(BaseNode): input_data: float - async def run(self, ctx: GraphContext) -> String2Length: + async def run(self, ctx: GraphRunContext) -> String2Length: return String2Length(str(self.input_data)) @@ -19,7 +19,7 @@ async def run(self, ctx: GraphContext) -> String2Length: class String2Length(BaseNode): input_data: str - async def run(self, ctx: GraphContext) -> Double: + async def run(self, ctx: GraphRunContext) -> Double: return Double(len(self.input_data)) @@ -32,7 +32,7 @@ class X: class Double(BaseNode[None, None, X]): input_data: int - async def run(self, ctx: GraphContext) -> String2Length | End[X]: + async def run(self, ctx: GraphRunContext) -> String2Length | End[X]: if self.input_data == 7: return String2Length('x' * 21) else: @@ -86,7 +86,7 @@ class MyDeps: @dataclass class A(BaseNode[MyState, MyDeps]): - async def run(self, ctx: GraphContext[MyState, MyDeps]) -> B: + async def run(self, ctx: GraphRunContext[MyState, MyDeps]) -> B: assert ctx.state.x == 1 assert ctx.deps.y == 'y' return B() @@ -94,7 +94,7 @@ async def run(self, ctx: GraphContext[MyState, MyDeps]) -> B: @dataclass class B(BaseNode[MyState, MyDeps, int]): - async def run(self, ctx: GraphContext[MyState, MyDeps]) -> End[int]: + async def run(self, ctx: GraphRunContext[MyState, MyDeps]) -> End[int]: return End(42)
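For reference, here is a self-contained sketch of the renamed API as exercised by the `tests/typed_graph.py` changes above — `A`, `B`, `MyState` and `MyDeps` mirror that test, while the increment logic and the printed value are illustrative assumptions rather than part of the patch:

```py
from __future__ import annotations

from dataclasses import dataclass

from pydantic_graph import BaseNode, End, Graph, GraphRunContext


@dataclass
class MyState:
    x: int


@dataclass
class MyDeps:
    y: str


@dataclass
class A(BaseNode[MyState, MyDeps]):
    async def run(self, ctx: GraphRunContext[MyState, MyDeps]) -> B:
        ctx.state.x += 1  # mutate the shared state
        return B()


@dataclass
class B(BaseNode[MyState, MyDeps, int]):
    async def run(self, ctx: GraphRunContext[MyState, MyDeps]) -> End[int]:
        assert ctx.deps.y == 'y'  # read the injected dependencies
        return End(ctx.state.x)


graph = Graph[MyState, MyDeps, int](nodes=(A, B))
result, history = graph.run_sync(A(), state=MyState(x=1), deps=MyDeps(y='y'))
print(result)
#> 2
```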