Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
dab200d
start pydantic-ai-graph
samuelcolvin Dec 22, 2024
bdd5c2c
lower case state machine
samuelcolvin Dec 22, 2024
a65df82
starting tests
samuelcolvin Dec 22, 2024
af0ba32
add history and logfire
samuelcolvin Dec 22, 2024
877ee36
add example, alter types
samuelcolvin Dec 22, 2024
5cf3ad0
fix dependencies
samuelcolvin Dec 23, 2024
544b6c8
fix ci deps
samuelcolvin Dec 23, 2024
7288cc9
fix tests for other versions
samuelcolvin Dec 23, 2024
0572bda
change node test times
samuelcolvin Dec 23, 2024
d10dc87
pydantic-ai-graph - simplify public generics (#539)
dmontagu Jan 2, 2025
06428bb
Typo in Graph Documentation (#596)
izzyacademy Jan 3, 2025
1a5d3e2
fix linting
samuelcolvin Jan 7, 2025
6faaf97
separate mermaid logic
samuelcolvin Jan 7, 2025
892f661
fix graph type checking
samuelcolvin Jan 7, 2025
02b7f28
bump
samuelcolvin Jan 7, 2025
be3f689
adding node highlighting to mermaid, testing locally
samuelcolvin Jan 7, 2025
749cc31
bump
samuelcolvin Jan 7, 2025
c0d35da
fix type checking imports
samuelcolvin Jan 7, 2025
190fe40
fix for python 3.9
samuelcolvin Jan 7, 2025
ccc0c17
simplify mermaid config
samuelcolvin Jan 8, 2025
50b590f
remove GraphRunner
samuelcolvin Jan 8, 2025
8ac10c7
add Interrupt
samuelcolvin Jan 9, 2025
b63ca74
remove interrupt, replace with "next()"
samuelcolvin Jan 9, 2025
745e3d5
address comments
samuelcolvin Jan 9, 2025
1370f88
switch name to pydantic-graph
samuelcolvin Jan 10, 2025
24cdd35
allow labeling edges and notes for docstrings
samuelcolvin Jan 10, 2025
6990c49
allow notes to be disabled
samuelcolvin Jan 10, 2025
4f69960
adding graph tests
samuelcolvin Jan 10, 2025
29f8a95
more mermaid tests, fix 3.9
samuelcolvin Jan 10, 2025
02c4dc0
rename node to start_node in graph.run()
samuelcolvin Jan 10, 2025
fc7dfc6
more tests for graphs
samuelcolvin Jan 10, 2025
db9543e
coverage in tests
samuelcolvin Jan 10, 2025
e9d1d2b
cleanup graph properties
samuelcolvin Jan 10, 2025
81cb333
infer graph name
samuelcolvin Jan 11, 2025
88c1d46
fix for 3.9
samuelcolvin Jan 11, 2025
0e3ecb3
adding API docs
samuelcolvin Jan 11, 2025
3284ff1
fix state, more docs
samuelcolvin Jan 11, 2025
b4d6c1c
fix graph api examples
samuelcolvin Jan 11, 2025
22708d9
starting graph documentation
samuelcolvin Jan 11, 2025
d1af561
fix examples
samuelcolvin Jan 11, 2025
9d5f45c
more graph documentation
samuelcolvin Jan 11, 2025
a3a0ddc
add GenAI example
samuelcolvin Jan 11, 2025
3994899
more graph docs
samuelcolvin Jan 12, 2025
ecc2434
extending graph docs
samuelcolvin Jan 13, 2025
5717bd5
fix history serialization
samuelcolvin Jan 14, 2025
3cb79c8
add history (de)serialization tests
samuelcolvin Jan 14, 2025
08bb7dd
add mermaid diagram section to graph docs
samuelcolvin Jan 14, 2025
466a7df
fix tests
samuelcolvin Jan 14, 2025
8098d34
add exceptions docs
samuelcolvin Jan 14, 2025
a3f507a
docs tweaks
samuelcolvin Jan 14, 2025
4e9b516
copy edits from @dmontagu
samuelcolvin Jan 15, 2025
a834eed
fix pydantic-graph readme
samuelcolvin Jan 15, 2025
447a259
adding deps to graphs
samuelcolvin Jan 15, 2025
3a1cddd
fix build
samuelcolvin Jan 15, 2025
46d8833
Merge branch 'graph' into graph-deps
samuelcolvin Jan 15, 2025
d78a9d1
fix type hint
samuelcolvin Jan 15, 2025
d653c0a
add deps example and tests
samuelcolvin Jan 15, 2025
e0ab64b
cleanup
samuelcolvin Jan 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
remove interrupt, replace with "next()"
  • Loading branch information
samuelcolvin committed Jan 10, 2025
commit b63ca74f5bd8729c8b9c0dabe2558b98233e160d
3 changes: 2 additions & 1 deletion examples/pydantic_ai_examples/email_extract_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,8 @@ async def run(

async def main():
state = State(email_content=email)
result, history = await graph.run(state, ExtractEvent())
history = []
result = await graph.run(state, ExtractEvent())
debug(result, history)


Expand Down
6 changes: 2 additions & 4 deletions pydantic_ai_graph/pydantic_ai_graph/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from .graph import Graph, GraphRun
from .nodes import BaseNode, End, GraphContext, Interrupt
from .graph import Graph
from .nodes import BaseNode, End, GraphContext
from .state import AbstractState, EndEvent, HistoryStep, NextNodeEvent

__all__ = (
'Graph',
'GraphRun',
'BaseNode',
'End',
'Interrupt',
'GraphContext',
'AbstractState',
'EndEvent',
Expand Down
108 changes: 46 additions & 62 deletions pydantic_ai_graph/pydantic_ai_graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,20 @@

import inspect
from collections.abc import Sequence
from dataclasses import dataclass, field
from dataclasses import dataclass
from pathlib import Path
from time import perf_counter
from typing import TYPE_CHECKING, Generic
from typing import TYPE_CHECKING, Any, Generic

import logfire_api
from typing_extensions import Never, ParamSpec, TypeVar, Unpack, assert_never

from . import _utils, mermaid
from ._utils import get_parent_namespace
from .nodes import BaseNode, End, GraphContext, Interrupt, NodeDef, RunInterrupt
from .state import EndEvent, HistoryStep, InterruptEvent, NextNodeEvent, StateT
from .nodes import BaseNode, End, GraphContext, NodeDef
from .state import EndEvent, HistoryStep, NextNodeEvent, StateT

__all__ = 'Graph', 'GraphRun'
__all__ = ('Graph',)

_logfire = logfire_api.Logfire(otel_scope='pydantic-ai-graph')

Expand Down Expand Up @@ -74,16 +74,48 @@ def _validate_edges(self):
b = '\n'.join(f' {be}' for be in bad_edges_list)
raise ValueError(f'Nodes are referenced in the graph but not included in the graph:\n{b}')

async def next(
self, state: StateT, node: BaseNode[StateT, RunEndT], history: list[HistoryStep[StateT, RunEndT]]
) -> BaseNode[StateT, Any] | End[RunEndT]:
node_id = node.get_id()
if node_id not in self.node_defs:
raise TypeError(f'Node "{node}" is not in the graph.')

history_step: NextNodeEvent[StateT, RunEndT] | None = NextNodeEvent(state, node)
history.append(history_step)

ctx = GraphContext(state)
with _logfire.span('run node {node_id}', node_id=node_id, node=node):
start = perf_counter()
next_node = await node.run(ctx)
history_step.duration = perf_counter() - start
return next_node

async def run(
self, state: StateT, node: BaseNode[StateT, RunEndT]
) -> tuple[End[RunEndT] | RunInterrupt[StateT], list[HistoryStep[StateT, RunEndT]]]:
if not isinstance(node, self.nodes):
raise ValueError(f'Node "{node}" is not in the graph.')
run = GraphRun[StateT, RunEndT](state=state)
# TODO: Infer the graph name properly
result = await run.run(self.name or 'graph', node)
history = run.history
return result, history
self,
state: StateT,
node: BaseNode[StateT, RunEndT],
) -> tuple[End[RunEndT], list[HistoryStep[StateT, RunEndT]]]:
history: list[HistoryStep[StateT, RunEndT]] = []

with _logfire.span(
'{graph_name} run {start=}',
graph_name=self.name or 'graph',
start=node,
) as run_span:
while True:
next_node = await self.next(state, node, history=history)
if isinstance(next_node, End):
history.append(EndEvent(state, next_node))
run_span.set_attribute('history', history)
return next_node, history
elif isinstance(next_node, BaseNode):
node = next_node
else:
if TYPE_CHECKING:
assert_never(next_node)
else:
raise TypeError(f'Invalid node type: {type(next_node)}. Expected `BaseNode` or `End`.')

def mermaid_code(
self,
Expand All @@ -101,51 +133,3 @@ def mermaid_image(self, **kwargs: Unpack[mermaid.MermaidConfig]) -> bytes:

def mermaid_save(self, path: Path | str, /, **kwargs: Unpack[mermaid.MermaidConfig]) -> None:
mermaid.save_image(path, self, **kwargs)


@dataclass
class GraphRun(Generic[StateT, RunEndT]):
"""Stateful run of a graph."""

state: StateT
history: list[HistoryStep[StateT, RunEndT]] = field(default_factory=list)

async def run(
self, graph_name: str, start: BaseNode[StateT, RunEndT], infer_name: bool = True
) -> End[RunEndT] | RunInterrupt[StateT]:
current_node = start

with _logfire.span(
'{graph_name} run {start=}',
graph_name=graph_name,
start=start,
) as run_span:
while True:
next_node = await self.step(current_node)
if isinstance(next_node, (Interrupt, End)):
if isinstance(next_node, End):
self.history.append(EndEvent(self.state, next_node))
else:
self.history.append(InterruptEvent(self.state, next_node))
run_span.set_attribute('history', self.history)
return next_node
elif isinstance(next_node, BaseNode):
current_node = next_node
else:
if TYPE_CHECKING:
assert_never(next_node)
else:
raise TypeError(f'Invalid node type: {type(next_node)}. Expected `BaseNode` or `End`.')

async def step(
self, node: BaseNode[StateT, RunEndT]
) -> BaseNode[StateT, RunEndT] | End[RunEndT] | RunInterrupt[StateT]:
history_step = NextNodeEvent(self.state, node)
self.history.append(history_step)

ctx = GraphContext(self.state)
with _logfire.span('run node {node_id}', node_id=node.get_id()):
start = perf_counter()
next_node = await node.run(ctx)
history_step.duration = perf_counter() - start
return next_node
10 changes: 2 additions & 8 deletions pydantic_ai_graph/pydantic_ai_graph/mermaid.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import base64
from collections.abc import Iterable, Sequence
from itertools import chain
from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Any, Literal

Expand Down Expand Up @@ -43,19 +42,14 @@ def generate_code(
raise LookupError(f'Start node "{node_id}" is not in the graph.')

node_order = {node_id: index for index, node_id in enumerate(graph.node_defs)}
after_interrupt_nodes = set(chain(*(node_def.after_interrupt_nodes for node_def in graph.node_defs.values())))

lines = ['graph TD']
for node in graph.nodes:
node_id = node.get_id()
node_def = graph.node_defs[node_id]

# we use square brackets (square box) for nodes that can interrupt the flow,
# and round brackets (rounded box) for nodes that cannot interrupt the flow
if node_id in after_interrupt_nodes:
mermaid_name = f'[{node_id}]'
else:
mermaid_name = f'({node_id})'
# we use round brackets (rounded box) for nodes other than the start and end
mermaid_name = f'({node_id})'
if node_id in start_node_ids:
lines.append(f' START --> {node_id}{mermaid_name}')
if node_def.returns_base_node:
Expand Down
31 changes: 3 additions & 28 deletions pydantic_ai_graph/pydantic_ai_graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
from abc import abstractmethod
from dataclasses import dataclass
from functools import cache
from typing import Any, Generic, get_args, get_origin, get_type_hints
from typing import Any, Generic, get_origin, get_type_hints

from typing_extensions import Never, TypeVar

from . import _utils
from .state import StateT

__all__ = 'GraphContext', 'BaseNode', 'End', 'Interrupt', 'RunInterrupt', 'NodeDef'
__all__ = 'GraphContext', 'BaseNode', 'End', 'NodeDef'

RunEndT = TypeVar('RunEndT', default=None)
NodeRunEndT = TypeVar('NodeRunEndT', covariant=True, default=Never)
Expand All @@ -27,9 +27,7 @@ class BaseNode(Generic[StateT, NodeRunEndT]):
"""Base class for a node."""

@abstractmethod
async def run(
self, ctx: GraphContext[StateT]
) -> BaseNode[StateT, Any] | End[NodeRunEndT] | RunInterrupt[StateT]: ...
async def run(self, ctx: GraphContext[StateT]) -> BaseNode[StateT, Any] | End[NodeRunEndT]: ...

@classmethod
@cache
Expand All @@ -47,17 +45,10 @@ def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, NodeRu
next_node_ids: set[str] = set()
returns_end: bool = False
returns_base_node: bool = False
after_interrupt_nodes: list[str] = []
for return_type in _utils.get_union_args(return_hint):
return_type_origin = get_origin(return_type) or return_type
if return_type_origin is End:
returns_end = True
elif issubclass(return_type_origin, Interrupt):
interrupt_args = get_args(return_type)
assert len(interrupt_args) == 1, f'Invalid Interrupt return type: {return_type}'
next_node_id = interrupt_args[0].get_id()
next_node_ids.add(next_node_id)
after_interrupt_nodes.append(next_node_id)
elif return_type_origin is BaseNode:
# TODO: Should we disallow this?
returns_base_node = True
Expand All @@ -71,7 +62,6 @@ def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, NodeRu
cls.get_id(),
next_node_ids,
returns_end,
after_interrupt_nodes,
returns_base_node,
)

Expand All @@ -83,19 +73,6 @@ class End(Generic[RunEndT]):
data: RunEndT


InterruptNextNodeT = TypeVar('InterruptNextNodeT', covariant=True, bound=BaseNode[Any, Any])


@dataclass
class Interrupt(Generic[InterruptNextNodeT]):
"""Type to return from a node to signal that the run should be interrupted."""

node: InterruptNextNodeT


RunInterrupt = Interrupt[BaseNode[StateT, Any]]


@dataclass
class NodeDef(Generic[StateT, NodeRunEndT]):
"""Definition of a node.
Expand All @@ -112,7 +89,5 @@ class NodeDef(Generic[StateT, NodeRunEndT]):
"""IDs of the nodes that can be called next."""
returns_end: bool
"""The node definition returns an `End`, hence the node and end the run."""
after_interrupt_nodes: list[str]
"""Nodes that can be returned within an `Interrupt`."""
returns_base_node: bool
"""The node definition returns a `BaseNode`, hence any node in the next can be called next."""
28 changes: 5 additions & 23 deletions pydantic_ai_graph/pydantic_ai_graph/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@

from . import _utils

__all__ = 'AbstractState', 'StateT', 'NextNodeEvent', 'EndEvent', 'InterruptEvent', 'HistoryStep'
__all__ = 'AbstractState', 'StateT', 'NextNodeEvent', 'EndEvent', 'HistoryStep'

if TYPE_CHECKING:
from pydantic_ai_graph import BaseNode
from pydantic_ai_graph.nodes import End, RunInterrupt
from pydantic_ai_graph.nodes import End


class AbstractState(ABC):
Expand Down Expand Up @@ -50,28 +50,10 @@ def __post_init__(self):
# Copy the state to prevent it from being modified by other code
self.state = _deep_copy_state(self.state)

def node_summary(self) -> str:
def summary(self) -> str:
return str(self.node)


@dataclass
class InterruptEvent(Generic[StateT]):
"""History step describing the interruption of a graph run."""

state: StateT
result: RunInterrupt[StateT]
ts: datetime = field(default_factory=_utils.now_utc)

kind: Literal['interrupt'] = 'interrupt'

def __post_init__(self):
# Copy the state to prevent it from being modified by other code
self.state = _deep_copy_state(self.state)

def node_summary(self) -> str:
return str(self.result)


@dataclass
class EndEvent(Generic[StateT, RunEndT]):
"""History step describing the end of a graph run."""
Expand All @@ -86,7 +68,7 @@ def __post_init__(self):
# Copy the state to prevent it from being modified by other code
self.state = _deep_copy_state(self.state)

def node_summary(self) -> str:
def summary(self) -> str:
return str(self.result)


Expand All @@ -97,4 +79,4 @@ def _deep_copy_state(state: StateT) -> StateT:
return state.deep_copy()


HistoryStep = Union[NextNodeEvent[StateT, RunEndT], InterruptEvent[StateT], EndEvent[StateT, RunEndT]]
HistoryStep = Union[NextNodeEvent[StateT, RunEndT], EndEvent[StateT, RunEndT]]