diff --git a/docs/api/pydantic_graph/nodes.md b/docs/api/pydantic_graph/nodes.md index 7540920a02..e58ddf7012 100644 --- a/docs/api/pydantic_graph/nodes.md +++ b/docs/api/pydantic_graph/nodes.md @@ -7,5 +7,6 @@ - BaseNode - End - Edge + - DepsT - RunEndT - NodeRunEndT diff --git a/docs/graph.md b/docs/graph.md index ba5e00f511..bbd546ddc3 100644 --- a/docs/graph.md +++ b/docs/graph.md @@ -12,7 +12,7 @@ Unless you're sure you need a graph, you probably don't. -Graphs and finite state machines (FSMs) are a powerful abstraction to model, control and visualize complex workflows. +Graphs and finite state machines (FSMs) are a powerful abstraction to model, execute, control and visualize complex workflows. Alongside PydanticAI, we've developed `pydantic-graph` — an async graph and state machine library for Python where nodes and edges are defined using type hints. @@ -59,7 +59,8 @@ Nodes which are generally [`dataclass`es][dataclasses.dataclass] include: Nodes are generic in: -* **state**, which must have the same type as the state of graphs they're included in, [`StateT`][pydantic_graph.state.StateT] has a default of `None`, so if you're not using state you can omit this generic parameter +* **state**, which must have the same type as the state of graphs they're included in, [`StateT`][pydantic_graph.state.StateT] has a default of `None`, so if you're not using state you can omit this generic parameter, see [stateful graphs](#stateful-graphs) for more information +* **deps**, which must have the same type as the deps of the graph they're included in, [`DepsT`][pydantic_graph.nodes.DepsT] has a default of `None`, so if you're not using deps you can omit this generic parameter, see [dependency injection](#dependency-injection) for more information * **graph return type** — this only applies if the node returns [`End`][pydantic_graph.nodes.End]. [`RunEndT`][pydantic_graph.nodes.RunEndT] has a default of [Never][typing.Never] so this generic parameter can be omitted if the node doesn't return `End`, but must be included if it does. Here's an example of a start or intermediate node in a graph — it can't end the run as it doesn't return [`End`][pydantic_graph.nodes.End]: @@ -96,7 +97,7 @@ from pydantic_graph import BaseNode, End, GraphContext @dataclass -class MyNode(BaseNode[MyState, int]): # (1)! +class MyNode(BaseNode[MyState, None, int]): # (1)! foo: int async def run( @@ -109,7 +110,7 @@ class MyNode(BaseNode[MyState, int]): # (1)! return AnotherNode() ``` -1. We parameterize the node with the return type (`int` in this case) as well as state. +1. We parameterize the node with the return type (`int` in this case) as well as state. Because generic parameters are positional-only, we have to include `None` as the second parameter representing deps. 2. The return type of the `run` method is now a union of `AnotherNode` and `End[int]`, this allows the node to end the run if `foo` is divisible by 5. ### Graph @@ -119,6 +120,7 @@ class MyNode(BaseNode[MyState, int]): # (1)! `Graph` is generic in: * **state** the state type of the graph, [`StateT`][pydantic_graph.state.StateT] +* **deps** the deps type of the graph, [`DepsT`][pydantic_graph.nodes.DepsT] * **graph return type** the return type of the graph run, [`RunEndT`][pydantic_graph.nodes.RunEndT] Here's an example of a simple graph: @@ -132,7 +134,7 @@ from pydantic_graph import BaseNode, End, Graph, GraphContext @dataclass -class DivisibleBy5(BaseNode[None, int]): # (1)! +class DivisibleBy5(BaseNode[None, None, int]): # (1)! foo: int async def run( @@ -154,7 +156,7 @@ class Increment(BaseNode): # (2)! fives_graph = Graph(nodes=[DivisibleBy5, Increment]) # (3)! -result, history = fives_graph.run_sync(None, DivisibleBy5(4)) # (4)! +result, history = fives_graph.run_sync(DivisibleBy5(4)) # (4)! print(result) #> 5 # the full history is quite verbose (see below), so we'll just print the summary @@ -162,7 +164,7 @@ print([item.data_snapshot() for item in history]) #> [DivisibleBy5(foo=4), Increment(foo=4), DivisibleBy5(foo=5), End(data=5)] ``` -1. The `DivisibleBy5` node is parameterized with `None` as this graph doesn't use state, and `int` as it can end the run. +1. The `DivisibleBy5` node is parameterized with `None` for the state param and `None` for the deps param as this graph doesn't use state or deps, and `int` as it can end the run. 2. The `Increment` node doesn't return `End`, so the `RunEndT` generic parameter is omitted, state can also be omitted as the graph doesn't use state. 3. The graph is created with a sequence of nodes. 4. The graph is run synchronously with [`run_sync`][pydantic_graph.graph.Graph.run_sync] the initial state `None` and the start node `DivisibleBy5(4)` are passed as arguments. @@ -247,7 +249,7 @@ PRODUCT_PRICES = { # (2)! @dataclass -class Purchase(BaseNode[MachineState, None]): # (18)! +class Purchase(BaseNode[MachineState, None, None]): # (18)! product: str async def run( @@ -275,7 +277,7 @@ vending_machine_graph = Graph( # (13)! async def main(): state = MachineState() # (14)! - await vending_machine_graph.run(state, InsertCoin()) # (15)! + await vending_machine_graph.run(InsertCoin(), state=state) # (15)! print(f'purchase successful item={state.product} change={state.user_balance:0.2f}') #> purchase successful item=crisps change=0.25 ``` @@ -430,7 +432,7 @@ feedback_agent = Agent[None, EmailRequiresWrite | EmailOk]( @dataclass -class Feedback(BaseNode[State, Email]): +class Feedback(BaseNode[State, None, Email]): email: Email async def run( @@ -453,7 +455,7 @@ async def main(): ) state = State(user) feedback_graph = Graph(nodes=(WriteEmail, Feedback)) - email, _ = await feedback_graph.run(state, WriteEmail()) + email, _ = await feedback_graph.run(WriteEmail(), state=state) print(email) """ Email( @@ -579,7 +581,7 @@ async def main(): node = Ask() # (2)! history: list[HistoryStep[QuestionState]] = [] # (3)! while True: - node = await question_graph.next(state, node, history) # (4)! + node = await question_graph.next(node, history, state=state) # (4)! if isinstance(node, Answer): node.answer = Prompt.ask(node.question) # (5)! elif isinstance(node, End): # (6)! @@ -633,6 +635,82 @@ stateDiagram-v2 You maybe have noticed that although this examples transfers control flow out of the graph run, we're still using [rich's `Prompt.ask`][rich.prompt.PromptBase.ask] to get user input, with the process hanging while we wait for the user to enter a response. For an example of genuine out-of-process control flow, see the [question graph example](examples/question-graph.md). +## Dependency Injection + +As with PydanticAI, `pydantic-graph` supports dependency injection via a generic parameter on [`Graph`][pydantic_graph.graph.Graph] and [`BaseNode`][pydantic_graph.nodes.BaseNode], and the [`GraphContext.deps`][pydantic_graph.nodes.GraphContext.deps] fields. + +As an example of dependency injection, let's modify the `DivisibleBy5` example [above](#graph) to use a [`ProcessPoolExecutor`][concurrent.futures.ProcessPoolExecutor] to run the compute load in a separate process (this is a contrived example, `ProcessPoolExecutor` wouldn't actually improve performance in this example): + +```py {title="deps_example.py" py="3.10" test="skip" hl_lines="4 10-12 35-37 48-49"} +from __future__ import annotations + +import asyncio +from concurrent.futures import ProcessPoolExecutor +from dataclasses import dataclass + +from pydantic_graph import BaseNode, End, Graph, GraphContext + + +@dataclass +class GraphDeps: + executor: ProcessPoolExecutor + + +@dataclass +class DivisibleBy5(BaseNode[None, None, int]): + foo: int + + async def run( + self, + ctx: GraphContext, + ) -> Increment | End[int]: + if self.foo % 5 == 0: + return End(self.foo) + else: + return Increment(self.foo) + + +@dataclass +class Increment(BaseNode): + foo: int + + async def run(self, ctx: GraphContext) -> DivisibleBy5: + loop = asyncio.get_running_loop() + compute_result = await loop.run_in_executor( + ctx.deps.executor, + self.compute, + ) + return DivisibleBy5(compute_result) + + def compute(self) -> int: + return self.foo + 1 + + +fives_graph = Graph(nodes=[DivisibleBy5, Increment]) + + +async def main(): + with ProcessPoolExecutor() as executor: + deps = GraphDeps(executor) + result, history = await fives_graph.run(DivisibleBy5(3), deps=deps) + print(result) + #> 5 + # the full history is quite verbose (see below), so we'll just print the summary + print([item.data_snapshot() for item in history]) + """ + [ + DivisibleBy5(foo=3), + Increment(foo=3), + DivisibleBy5(foo=4), + Increment(foo=4), + DivisibleBy5(foo=5), + End(data=5), + ] + """ +``` + +_(This example is complete, it can be run "as is" with Python 3.10+ — you'll need to add `asyncio.run(main())` to run `main`)_ + ## Mermaid Diagrams Pydantic Graph can generate [mermaid](https://mermaid.js.org/) [`stateDiagram-v2`](https://mermaid.js.org/syntax/stateDiagram.html) diagrams for graphs, as shown above. diff --git a/examples/pydantic_ai_examples/question_graph.py b/examples/pydantic_ai_examples/question_graph.py index f73dbf1684..6ef092d3ba 100644 --- a/examples/pydantic_ai_examples/question_graph.py +++ b/examples/pydantic_ai_examples/question_graph.py @@ -87,7 +87,7 @@ async def run( @dataclass -class Congratulate(BaseNode[QuestionState, None]): +class Congratulate(BaseNode[QuestionState, None, None]): comment: str async def run( @@ -119,7 +119,7 @@ async def run_as_continuous(): history: list[HistoryStep[QuestionState, None]] = [] with logfire.span('run questions graph'): while True: - node = await question_graph.next(state, node, history) + node = await question_graph.next(node, history, state=state) if isinstance(node, End): debug([e.data_snapshot() for e in history]) break @@ -150,7 +150,7 @@ async def run_as_cli(answer: str | None): with logfire.span('run questions graph'): while True: - node = await question_graph.next(state, node, history) + node = await question_graph.next(node, history, state=state) if isinstance(node, End): debug([e.data_snapshot() for e in history]) print('Finished!') diff --git a/pydantic_graph/README.md b/pydantic_graph/README.md index 6ac17fb216..0b2508f033 100644 --- a/pydantic_graph/README.md +++ b/pydantic_graph/README.md @@ -28,7 +28,7 @@ from pydantic_graph import BaseNode, End, Graph, GraphContext @dataclass -class DivisibleBy5(BaseNode[None, int]): +class DivisibleBy5(BaseNode[None, None, int]): foo: int async def run( @@ -50,7 +50,7 @@ class Increment(BaseNode): fives_graph = Graph(nodes=[DivisibleBy5, Increment]) -result, history = fives_graph.run_sync(None, DivisibleBy5(4)) +result, history = fives_graph.run_sync(DivisibleBy5(4)) print(result) #> 5 # the full history is quite verbose (see below), so we'll just print the summary diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 5f25330cee..0aed1f3706 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -15,7 +15,7 @@ import typing_extensions from . import _utils, exceptions, mermaid -from .nodes import BaseNode, End, GraphContext, NodeDef, RunEndT +from .nodes import BaseNode, DepsT, End, GraphContext, NodeDef, RunEndT from .state import EndStep, HistoryStep, NodeStep, StateT, deep_copy_state, nodes_schema_var __all__ = ('Graph',) @@ -24,7 +24,7 @@ @dataclass(init=False) -class Graph(Generic[StateT, RunEndT]): +class Graph(Generic[StateT, DepsT, RunEndT]): """Definition of a graph. In `pydantic-graph`, a graph is a collection of nodes that can be run in sequence. The nodes define @@ -51,12 +51,12 @@ async def run(self, ctx: GraphContext) -> Check42: return Check42() @dataclass - class Check42(BaseNode[MyState]): - async def run(self, ctx: GraphContext) -> Increment | End: + class Check42(BaseNode[MyState, None, int]): + async def run(self, ctx: GraphContext) -> Increment | End[int]: if ctx.state.number == 42: return Increment() else: - return End(None) + return End(ctx.state.number) never_42_graph = Graph(nodes=(Increment, Check42)) ``` @@ -68,7 +68,7 @@ async def run(self, ctx: GraphContext) -> Increment | End: """ name: str | None - node_defs: dict[str, NodeDef[StateT, RunEndT]] + node_defs: dict[str, NodeDef[StateT, DepsT, RunEndT]] snapshot_state: Callable[[StateT], StateT] _state_type: type[StateT] | _utils.Unset = field(repr=False) _run_end_type: type[RunEndT] | _utils.Unset = field(repr=False) @@ -76,7 +76,7 @@ async def run(self, ctx: GraphContext) -> Increment | End: def __init__( self, *, - nodes: Sequence[type[BaseNode[StateT, RunEndT]]], + nodes: Sequence[type[BaseNode[StateT, DepsT, RunEndT]]], name: str | None = None, state_type: type[StateT] | _utils.Unset = _utils.UNSET, run_end_type: type[RunEndT] | _utils.Unset = _utils.UNSET, @@ -101,7 +101,7 @@ def __init__( self.snapshot_state = snapshot_state parent_namespace = _utils.get_parent_namespace(inspect.currentframe()) - self.node_defs: dict[str, NodeDef[StateT, RunEndT]] = {} + self.node_defs: dict[str, NodeDef[StateT, DepsT, RunEndT]] = {} for node in nodes: self._register_node(node, parent_namespace) @@ -109,17 +109,19 @@ def __init__( async def run( self, - state: StateT, - start_node: BaseNode[StateT, RunEndT], + start_node: BaseNode[StateT, DepsT, RunEndT], *, + state: StateT = None, + deps: DepsT = None, infer_name: bool = True, ) -> tuple[RunEndT, list[HistoryStep[StateT, RunEndT]]]: """Run the graph from a starting node until it ends. Args: - state: The initial state of the graph. start_node: the first node to run, since the graph definition doesn't define the entry point in the graph, you need to provide the starting node. + state: The initial state of the graph. + deps: The dependencies of the graph. infer_name: Whether to infer the graph name from the calling frame. Returns: @@ -132,14 +134,14 @@ async def run( async def main(): state = MyState(1) - _, history = await never_42_graph.run(state, Increment()) + _, history = await never_42_graph.run(Increment(), state=state) print(state) #> MyState(number=2) print(len(history)) #> 3 state = MyState(41) - _, history = await never_42_graph.run(state, Increment()) + _, history = await never_42_graph.run(Increment(), state=state) print(state) #> MyState(number=43) print(len(history)) @@ -156,7 +158,7 @@ async def main(): start=start_node, ) as run_span: while True: - next_node = await self.next(state, start_node, history, infer_name=False) + next_node = await self.next(start_node, history, state=state, deps=deps, infer_name=False) if isinstance(next_node, End): history.append(EndStep(result=next_node)) run_span.set_attribute('history', history) @@ -173,9 +175,10 @@ async def main(): def run_sync( self, - state: StateT, - start_node: BaseNode[StateT, RunEndT], + start_node: BaseNode[StateT, DepsT, RunEndT], *, + state: StateT = None, + deps: DepsT = None, infer_name: bool = True, ) -> tuple[RunEndT, list[HistoryStep[StateT, RunEndT]]]: """Run the graph synchronously. @@ -184,9 +187,10 @@ def run_sync( You therefore can't use this method inside async code or if there's an active event loop. Args: - state: The initial state of the graph. start_node: the first node to run, since the graph definition doesn't define the entry point in the graph, you need to provide the starting node. + state: The initial state of the graph. + deps: The dependencies of the graph. infer_name: Whether to infer the graph name from the calling frame. Returns: @@ -194,22 +198,26 @@ def run_sync( """ if infer_name and self.name is None: self._infer_name(inspect.currentframe()) - return asyncio.get_event_loop().run_until_complete(self.run(state, start_node, infer_name=False)) + return asyncio.get_event_loop().run_until_complete( + self.run(start_node, state=state, deps=deps, infer_name=False) + ) async def next( self, - state: StateT, - node: BaseNode[StateT, RunEndT], + node: BaseNode[StateT, DepsT, RunEndT], history: list[HistoryStep[StateT, RunEndT]], *, + state: StateT = None, + deps: DepsT = None, infer_name: bool = True, - ) -> BaseNode[StateT, Any] | End[RunEndT]: + ) -> BaseNode[StateT, DepsT, Any] | End[RunEndT]: """Run a node in the graph and return the next node to run. Args: - state: The current state of the graph. node: The node to run. history: The history of the graph run so far. NOTE: this will be mutated to add the new step. + state: The current state of the graph. + deps: The dependencies of the graph. infer_name: Whether to infer the graph name from the calling frame. Returns: @@ -221,7 +229,7 @@ async def next( if node_id not in self.node_defs: raise exceptions.GraphRuntimeError(f'Node `{node}` is not in the graph.') - ctx = GraphContext(state) + ctx = GraphContext(state, deps) with _logfire.span('run node {node_id}', node_id=node_id, node=node): start_ts = _utils.now_utc() start = perf_counter() @@ -411,15 +419,17 @@ def _get_run_end_type(self) -> type[RunEndT]: for base in typing_extensions.get_original_bases(node_def.node): if typing_extensions.get_origin(base) is BaseNode: args = typing_extensions.get_args(base) - if len(args) == 2: - t = args[1] + if len(args) == 3: + t = args[2] if not _utils.is_never(t): return t # break the inner (bases) loop break raise exceptions.GraphSetupError('Could not infer run end type from nodes, please set `run_end_type`.') - def _register_node(self, node: type[BaseNode[StateT, RunEndT]], parent_namespace: dict[str, Any] | None) -> None: + def _register_node( + self, node: type[BaseNode[StateT, DepsT, RunEndT]], parent_namespace: dict[str, Any] | None + ) -> None: node_id = node.get_id() if existing_node := self.node_defs.get(node_id): raise exceptions.GraphSetupError( diff --git a/pydantic_graph/pydantic_graph/mermaid.py b/pydantic_graph/pydantic_graph/mermaid.py index 49e41ee267..ffa1fc0190 100644 --- a/pydantic_graph/pydantic_graph/mermaid.py +++ b/pydantic_graph/pydantic_graph/mermaid.py @@ -23,7 +23,7 @@ def generate_code( # noqa: C901 - graph: Graph[Any, Any], + graph: Graph[Any, Any, Any], /, *, start_node: Sequence[NodeIdent] | NodeIdent | None = None, @@ -110,7 +110,7 @@ def _node_ids(node_idents: Sequence[NodeIdent] | NodeIdent) -> Iterable[str]: def request_image( - graph: Graph[Any, Any], + graph: Graph[Any, Any, Any], /, **kwargs: Unpack[MermaidConfig], ) -> bytes: @@ -175,7 +175,7 @@ def request_image( def save_image( path: Path | str, - graph: Graph[Any, Any], + graph: Graph[Any, Any, Any], /, **kwargs: Unpack[MermaidConfig], ) -> None: @@ -247,7 +247,7 @@ class MermaidConfig(TypedDict, total=False): """An HTTPX client to use for requests, mostly for testing purposes.""" -NodeIdent: TypeAlias = 'type[BaseNode[Any, Any]] | BaseNode[Any, Any] | str' +NodeIdent: TypeAlias = 'type[BaseNode[Any, Any, Any]] | BaseNode[Any, Any, Any] | str' """A type alias for a node identifier. This can be: diff --git a/pydantic_graph/pydantic_graph/nodes.py b/pydantic_graph/pydantic_graph/nodes.py index ec97b16e84..3591e1b5be 100644 --- a/pydantic_graph/pydantic_graph/nodes.py +++ b/pydantic_graph/pydantic_graph/nodes.py @@ -14,23 +14,27 @@ else: StateT = TypeVar('StateT', default=None) -__all__ = 'GraphContext', 'BaseNode', 'End', 'Edge', 'NodeDef', 'RunEndT', 'NodeRunEndT' +__all__ = 'GraphContext', 'BaseNode', 'End', 'Edge', 'NodeDef', 'RunEndT', 'NodeRunEndT', 'DepsT' RunEndT = TypeVar('RunEndT', default=None) """Type variable for the return type of a graph [`run`][pydantic_graph.graph.Graph.run].""" NodeRunEndT = TypeVar('NodeRunEndT', covariant=True, default=Never) """Type variable for the return type of a node [`run`][pydantic_graph.nodes.BaseNode.run].""" +DepsT = TypeVar('DepsT', default=None) +"""Type variable for the dependencies of a graph and node.""" @dataclass -class GraphContext(Generic[StateT]): +class GraphContext(Generic[StateT, DepsT]): """Context for a graph.""" state: StateT """The state of the graph.""" + deps: DepsT + """Dependencies for the graph.""" -class BaseNode(ABC, Generic[StateT, NodeRunEndT]): +class BaseNode(ABC, Generic[StateT, DepsT, NodeRunEndT]): """Base class for a node.""" docstring_notes: ClassVar[bool] = False @@ -42,7 +46,7 @@ class BaseNode(ABC, Generic[StateT, NodeRunEndT]): """ @abstractmethod - async def run(self, ctx: GraphContext[StateT]) -> BaseNode[StateT, Any] | End[NodeRunEndT]: + async def run(self, ctx: GraphContext[StateT, DepsT]) -> BaseNode[StateT, DepsT, Any] | End[NodeRunEndT]: """Run the node. This is an abstract method that must be implemented by subclasses. @@ -87,7 +91,7 @@ def get_note(cls) -> str | None: return docstring @classmethod - def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, NodeRunEndT]: + def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, DepsT, NodeRunEndT]: """Get the node definition.""" type_hints = get_type_hints(cls.run, localns=local_ns, include_extras=True) try: @@ -139,7 +143,7 @@ class Edge: @dataclass -class NodeDef(Generic[StateT, NodeRunEndT]): +class NodeDef(Generic[StateT, DepsT, NodeRunEndT]): """Definition of a node. This is a primarily internal representation of a node; in general, it shouldn't be necessary to use it directly. @@ -148,7 +152,7 @@ class NodeDef(Generic[StateT, NodeRunEndT]): mermaid graphs. """ - node: type[BaseNode[StateT, NodeRunEndT]] + node: type[BaseNode[StateT, DepsT, NodeRunEndT]] """The node definition itself.""" node_id: str """ID of the node.""" diff --git a/pydantic_graph/pydantic_graph/state.py b/pydantic_graph/pydantic_graph/state.py index 8db69fb0df..99bddd6138 100644 --- a/pydantic_graph/pydantic_graph/state.py +++ b/pydantic_graph/pydantic_graph/state.py @@ -14,7 +14,7 @@ from . import _utils from .nodes import BaseNode, End, RunEndT -__all__ = 'StateT', 'NodeStep', 'EndStep', 'HistoryStep', 'deep_copy_state' +__all__ = 'StateT', 'NodeStep', 'EndStep', 'HistoryStep', 'deep_copy_state', 'nodes_schema_var' StateT = TypeVar('StateT', default=None) @@ -35,7 +35,7 @@ class NodeStep(Generic[StateT, RunEndT]): state: StateT """The state of the graph after the node has been run.""" - node: Annotated[BaseNode[StateT, RunEndT], CustomNodeSchema()] + node: Annotated[BaseNode[StateT, Any, RunEndT], CustomNodeSchema()] """The node that was run.""" start_ts: datetime = field(default_factory=_utils.now_utc) """The timestamp when the node started running.""" @@ -53,7 +53,7 @@ def __post_init__(self): # Copy the state to prevent it from being modified by other code self.state = self.snapshot_state(self.state) - def data_snapshot(self) -> BaseNode[StateT, RunEndT]: + def data_snapshot(self) -> BaseNode[StateT, Any, RunEndT]: """Returns a deep copy of [`self.node`][pydantic_graph.state.NodeStep.node]. Useful for summarizing history. @@ -88,7 +88,7 @@ def data_snapshot(self) -> End[RunEndT]: """ -nodes_schema_var: ContextVar[Sequence[type[BaseNode[Any, Any]]]] = ContextVar('nodes_var') +nodes_schema_var: ContextVar[Sequence[type[BaseNode[Any, Any, Any]]]] = ContextVar('nodes_var') class CustomNodeSchema: diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py index 7be7577e2e..2ae8bc5a9d 100644 --- a/tests/graph/test_graph.py +++ b/tests/graph/test_graph.py @@ -44,7 +44,7 @@ async def run(self, ctx: GraphContext) -> Double: return Double(len(self.input_data)) @dataclass - class Double(BaseNode[None, int]): + class Double(BaseNode[None, None, int]): input_data: int async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noqa: UP007 @@ -53,11 +53,11 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq else: return End(self.input_data * 2) - my_graph = Graph(nodes=(Float2String, String2Length, Double)) + my_graph = Graph[None, None, int](nodes=(Float2String, String2Length, Double)) + assert my_graph.name is None assert my_graph._get_state_type() is type(None) assert my_graph._get_run_end_type() is int - assert my_graph.name is None - result, history = await my_graph.run(None, Float2String(3.14)) + result, history = await my_graph.run(Float2String(3.14)) # len('3.14') * 2 == 8 assert result == 8 assert my_graph.name == 'my_graph' @@ -84,7 +84,7 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq EndStep(result=End(data=8), ts=IsNow(tz=timezone.utc)), ] ) - result, history = await my_graph.run(None, Float2String(3.14159)) + result, history = await my_graph.run(Float2String(3.14159)) # len('3.14159') == 7, 21 * 2 == 42 assert result == 42 assert history == snapshot( @@ -139,7 +139,7 @@ class Float2String(BaseNode): async def run(self, ctx: GraphContext) -> String2Length: raise NotImplementedError() - class String2Length(BaseNode[None, None]): + class String2Length(BaseNode[None, None, None]): async def run(self, ctx: GraphContext) -> End[None]: raise NotImplementedError() @@ -158,13 +158,13 @@ class Foo(BaseNode): async def run(self, ctx: GraphContext) -> Union[Bar, Spam]: # noqa: UP007 raise NotImplementedError() - class Bar(BaseNode[None, None]): + class Bar(BaseNode[None, None, None]): input_data: str async def run(self, ctx: GraphContext) -> End[None]: raise NotImplementedError() - class Spam(BaseNode[None, None]): + class Spam(BaseNode[None, None, None]): async def run(self, ctx: GraphContext) -> End[None]: raise NotImplementedError() @@ -185,17 +185,17 @@ class Foo(BaseNode): async def run(self, ctx: GraphContext) -> Eggs: raise NotImplementedError() - class Bar(BaseNode[None, None]): + class Bar(BaseNode[None, None, None]): input_data: str async def run(self, ctx: GraphContext) -> Eggs: raise NotImplementedError() - class Spam(BaseNode[None, None]): + class Spam(BaseNode[None, None, None]): async def run(self, ctx: GraphContext) -> Eggs: raise NotImplementedError() - class Eggs(BaseNode[None, None]): + class Eggs(BaseNode[None, None, None]): async def run(self, ctx: GraphContext) -> End[None]: raise NotImplementedError() @@ -212,7 +212,7 @@ class Foo(BaseNode): async def run(self, ctx: GraphContext) -> Bar: raise NotImplementedError() - class Bar(BaseNode[None, None]): + class Bar(BaseNode[None, None, None]): async def run(self, ctx: GraphContext) -> End[None]: raise NotImplementedError() @@ -234,18 +234,18 @@ async def run(self, ctx: GraphContext) -> Bar: return Bar() @dataclass - class Bar(BaseNode[None, None]): + class Bar(BaseNode[None, None, None]): async def run(self, ctx: GraphContext) -> End[None]: return Spam() # type: ignore @dataclass - class Spam(BaseNode[None, None]): + class Spam(BaseNode[None, None, None]): async def run(self, ctx: GraphContext) -> End[None]: raise NotImplementedError() g = Graph(nodes=(Foo, Bar)) with pytest.raises(GraphRuntimeError) as exc_info: - await g.run(None, Foo()) + await g.run(Foo()) assert exc_info.value.message == snapshot('Node `test_run_node_not_in_graph..Spam()` is not in the graph.') @@ -257,7 +257,7 @@ async def run(self, ctx: GraphContext) -> Bar: return Bar() @dataclass - class Bar(BaseNode[None, None]): + class Bar(BaseNode[None, None, None]): async def run(self, ctx: GraphContext) -> End[None]: return 42 # type: ignore @@ -265,7 +265,7 @@ async def run(self, ctx: GraphContext) -> End[None]: assert g._get_state_type() is type(None) assert g._get_run_end_type() is type(None) with pytest.raises(GraphRuntimeError) as exc_info: - await g.run(None, Foo()) + await g.run(Foo()) assert exc_info.value.message == snapshot('Invalid node return type: `int`. Expected `BaseNode` or `End`.') @@ -284,13 +284,13 @@ async def run(self, ctx: GraphContext) -> Foo: g = Graph(nodes=(Foo, Bar)) assert g.name is None history: list[HistoryStep[None, Never]] = [] - n = await g.next(None, Foo(), history) + n = await g.next(Foo(), history) assert n == Bar() assert g.name == 'g' assert history == snapshot([NodeStep(state=None, node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat())]) assert isinstance(n, Bar) - n2 = await g.next(None, n, history) + n2 = await g.next(n, history) assert n2 == Foo() assert history == snapshot( @@ -299,3 +299,34 @@ async def run(self, ctx: GraphContext) -> Foo: NodeStep(state=None, node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), ] ) + + +async def test_deps(): + @dataclass + class Deps: + a: int + b: int + + @dataclass + class Foo(BaseNode[None, Deps]): + async def run(self, ctx: GraphContext[None, Deps]) -> Bar: + assert isinstance(ctx.deps, Deps) + return Bar() + + @dataclass + class Bar(BaseNode[None, Deps, int]): + async def run(self, ctx: GraphContext[None, Deps]) -> End[int]: + assert isinstance(ctx.deps, Deps) + return End(123) + + g = Graph(nodes=(Foo, Bar)) + result, history = await g.run(Foo(), deps=Deps(1, 2)) + + assert result == 123 + assert history == snapshot( + [ + NodeStep(state=None, node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), + NodeStep(state=None, node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), + EndStep(result=End(data=123), ts=IsNow(tz=timezone.utc)), + ] + ) diff --git a/tests/graph/test_history.py b/tests/graph/test_history.py index 89b5b73a4d..ef3a88532e 100644 --- a/tests/graph/test_history.py +++ b/tests/graph/test_history.py @@ -30,7 +30,7 @@ async def run(self, ctx: GraphContext[MyState]) -> Bar: @dataclass -class Bar(BaseNode[MyState, int]): +class Bar(BaseNode[MyState, None, int]): async def run(self, ctx: GraphContext[MyState]) -> End[int]: ctx.state.y += 'y' return End(ctx.state.x * 2) @@ -45,8 +45,8 @@ async def run(self, ctx: GraphContext[MyState]) -> End[int]: Graph(nodes=(Foo, Bar)), ], ) -async def test_dump_load_history(graph: Graph[MyState, int]): - result, history = await graph.run(MyState(1, ''), Foo()) +async def test_dump_load_history(graph: Graph[MyState, None, int]): + result, history = await graph.run(Foo(), state=MyState(1, '')) assert result == snapshot(4) assert history == snapshot( [ @@ -104,7 +104,7 @@ async def test_dump_load_history(graph: Graph[MyState, int]): def test_one_node(): @dataclass - class MyNode(BaseNode[None, int]): + class MyNode(BaseNode[None, None, int]): async def run(self, ctx: GraphContext) -> End[int]: return End(123) diff --git a/tests/graph/test_mermaid.py b/tests/graph/test_mermaid.py index 1d33f2c611..0a4302c2e1 100644 --- a/tests/graph/test_mermaid.py +++ b/tests/graph/test_mermaid.py @@ -26,7 +26,7 @@ async def run(self, ctx: GraphContext) -> Bar: @dataclass -class Bar(BaseNode[None, None]): +class Bar(BaseNode[None, None, None]): async def run(self, ctx: GraphContext) -> End[None]: return End(None) @@ -45,7 +45,7 @@ async def run(self, ctx: GraphContext) -> Annotated[Foo, Edge(label='spam to foo @dataclass -class Eggs(BaseNode[None, None]): +class Eggs(BaseNode[None, None, None]): """This is the docstring for Eggs.""" docstring_notes = False @@ -58,7 +58,7 @@ async def run(self, ctx: GraphContext) -> Annotated[End[None], Edge(label='eggs async def test_run_graph(): - result, history = await graph1.run(None, Foo()) + result, history = await graph1.run(Foo()) assert result is None assert history == snapshot( [ diff --git a/tests/graph/test_state.py b/tests/graph/test_state.py index 6d9899dc78..597a420258 100644 --- a/tests/graph/test_state.py +++ b/tests/graph/test_state.py @@ -27,7 +27,7 @@ async def run(self, ctx: GraphContext[MyState]) -> Bar: return Bar() @dataclass - class Bar(BaseNode[MyState, str]): + class Bar(BaseNode[MyState, None, str]): async def run(self, ctx: GraphContext[MyState]) -> End[str]: ctx.state.y += 'y' return End(f'x={ctx.state.x} y={ctx.state.y}') @@ -36,7 +36,7 @@ async def run(self, ctx: GraphContext[MyState]) -> End[str]: assert graph._get_state_type() is MyState assert graph._get_run_end_type() is str state = MyState(1, '') - result, history = await graph.run(state, Foo()) + result, history = await graph.run(Foo(), state=state) assert result == snapshot('x=2 y=y') assert history == snapshot( [ diff --git a/tests/typed_graph.py b/tests/typed_graph.py index 8c6ff9b0ea..14bd171669 100644 --- a/tests/typed_graph.py +++ b/tests/typed_graph.py @@ -4,7 +4,7 @@ from typing_extensions import assert_type -from pydantic_graph import BaseNode, End, Graph, GraphContext +from pydantic_graph import BaseNode, End, Graph, GraphContext, HistoryStep @dataclass @@ -29,7 +29,7 @@ class X: @dataclass -class Double(BaseNode[None, X]): +class Double(BaseNode[None, None, X]): input_data: int async def run(self, ctx: GraphContext) -> String2Length | End[X]: @@ -39,7 +39,7 @@ async def run(self, ctx: GraphContext) -> String2Length | End[X]: return End(X(self.input_data * 2)) -def use_double(node: BaseNode[None, X]) -> None: +def use_double(node: BaseNode[None, None, X]) -> None: """Shoe that `Double` is valid as a `BaseNode[None, int, X]`.""" print(node) @@ -47,18 +47,18 @@ def use_double(node: BaseNode[None, X]) -> None: use_double(Double(1)) -g1 = Graph[None, X]( +g1 = Graph[None, None, X]( nodes=( Float2String, String2Length, Double, ) ) -assert_type(g1, Graph[None, X]) +assert_type(g1, Graph[None, None, X]) g2 = Graph(nodes=(Double,)) -assert_type(g2, Graph[None, X]) +assert_type(g2, Graph[None, None, X]) g3 = Graph( nodes=( @@ -68,7 +68,47 @@ def use_double(node: BaseNode[None, X]) -> None: ) ) # because String2Length came before Double, the output type is Any -assert_type(g3, Graph[None, X]) +assert_type(g3, Graph[None, None, X]) Graph[None, bytes](nodes=(Float2String, String2Length, Double)) # type: ignore[arg-type] Graph[None, str](nodes=[Double]) # type: ignore[list-item] + + +@dataclass +class MyState: + x: int + + +@dataclass +class MyDeps: + y: str + + +@dataclass +class A(BaseNode[MyState, MyDeps]): + async def run(self, ctx: GraphContext[MyState, MyDeps]) -> B: + assert ctx.state.x == 1 + assert ctx.deps.y == 'y' + return B() + + +@dataclass +class B(BaseNode[MyState, MyDeps, int]): + async def run(self, ctx: GraphContext[MyState, MyDeps]) -> End[int]: + return End(42) + + +g4 = Graph[MyState, MyDeps, int](nodes=(A, B)) +assert_type(g4, Graph[MyState, MyDeps, int]) + +g5 = Graph(nodes=(A, B)) +assert_type(g5, Graph[MyState, MyDeps, int]) + + +def run_g5() -> None: + g5.run_sync(A()) # pyright: ignore[reportArgumentType] + g5.run_sync(A(), state=MyState(x=1)) # pyright: ignore[reportArgumentType] + g5.run_sync(A(), deps=MyDeps(y='y')) # pyright: ignore[reportArgumentType] + ans, history = g5.run_sync(A(), state=MyState(x=1), deps=MyDeps(y='y')) + assert_type(ans, int) + assert_type(history, list[HistoryStep[MyState, int]])