Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
dab200d
start pydantic-ai-graph
samuelcolvin Dec 22, 2024
bdd5c2c
lower case state machine
samuelcolvin Dec 22, 2024
a65df82
starting tests
samuelcolvin Dec 22, 2024
af0ba32
add history and logfire
samuelcolvin Dec 22, 2024
877ee36
add example, alter types
samuelcolvin Dec 22, 2024
5cf3ad0
fix dependencies
samuelcolvin Dec 23, 2024
544b6c8
fix ci deps
samuelcolvin Dec 23, 2024
7288cc9
fix tests for other versions
samuelcolvin Dec 23, 2024
0572bda
change node test times
samuelcolvin Dec 23, 2024
d10dc87
pydantic-ai-graph - simplify public generics (#539)
dmontagu Jan 2, 2025
06428bb
Typo in Graph Documentation (#596)
izzyacademy Jan 3, 2025
1a5d3e2
fix linting
samuelcolvin Jan 7, 2025
6faaf97
separate mermaid logic
samuelcolvin Jan 7, 2025
892f661
fix graph type checking
samuelcolvin Jan 7, 2025
02b7f28
bump
samuelcolvin Jan 7, 2025
be3f689
adding node highlighting to mermaid, testing locally
samuelcolvin Jan 7, 2025
749cc31
bump
samuelcolvin Jan 7, 2025
c0d35da
fix type checking imports
samuelcolvin Jan 7, 2025
190fe40
fix for python 3.9
samuelcolvin Jan 7, 2025
ccc0c17
simplify mermaid config
samuelcolvin Jan 8, 2025
50b590f
remove GraphRunner
samuelcolvin Jan 8, 2025
8ac10c7
add Interrupt
samuelcolvin Jan 9, 2025
b63ca74
remove interrupt, replace with "next()"
samuelcolvin Jan 9, 2025
745e3d5
address comments
samuelcolvin Jan 9, 2025
1370f88
switch name to pydantic-graph
samuelcolvin Jan 10, 2025
24cdd35
allow labeling edges and notes for docstrings
samuelcolvin Jan 10, 2025
6990c49
allow notes to be disabled
samuelcolvin Jan 10, 2025
4f69960
adding graph tests
samuelcolvin Jan 10, 2025
29f8a95
more mermaid tests, fix 3.9
samuelcolvin Jan 10, 2025
02c4dc0
rename node to start_node in graph.run()
samuelcolvin Jan 10, 2025
fc7dfc6
more tests for graphs
samuelcolvin Jan 10, 2025
db9543e
coverage in tests
samuelcolvin Jan 10, 2025
e9d1d2b
cleanup graph properties
samuelcolvin Jan 10, 2025
81cb333
infer graph name
samuelcolvin Jan 11, 2025
88c1d46
fix for 3.9
samuelcolvin Jan 11, 2025
0e3ecb3
adding API docs
samuelcolvin Jan 11, 2025
3284ff1
fix state, more docs
samuelcolvin Jan 11, 2025
b4d6c1c
fix graph api examples
samuelcolvin Jan 11, 2025
22708d9
starting graph documentation
samuelcolvin Jan 11, 2025
d1af561
fix examples
samuelcolvin Jan 11, 2025
9d5f45c
more graph documentation
samuelcolvin Jan 11, 2025
a3a0ddc
add GenAI example
samuelcolvin Jan 11, 2025
3994899
more graph docs
samuelcolvin Jan 12, 2025
ecc2434
extending graph docs
samuelcolvin Jan 13, 2025
5717bd5
fix history serialization
samuelcolvin Jan 14, 2025
3cb79c8
add history (de)serialization tests
samuelcolvin Jan 14, 2025
08bb7dd
add mermaid diagram section to graph docs
samuelcolvin Jan 14, 2025
466a7df
fix tests
samuelcolvin Jan 14, 2025
8098d34
add exceptions docs
samuelcolvin Jan 14, 2025
a3f507a
docs tweaks
samuelcolvin Jan 14, 2025
4e9b516
copy edits from @dmontagu
samuelcolvin Jan 15, 2025
a834eed
fix pydantic-graph readme
samuelcolvin Jan 15, 2025
447a259
adding deps to graphs
samuelcolvin Jan 15, 2025
3a1cddd
fix build
samuelcolvin Jan 15, 2025
46d8833
Merge branch 'graph' into graph-deps
samuelcolvin Jan 15, 2025
d78a9d1
fix type hint
samuelcolvin Jan 15, 2025
d653c0a
add deps example and tests
samuelcolvin Jan 15, 2025
e0ab64b
cleanup
samuelcolvin Jan 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
adding graph tests
  • Loading branch information
samuelcolvin committed Jan 10, 2025
commit 4f69960044de1aaa9444c3331a0bc1093e57d110
3 changes: 3 additions & 0 deletions pydantic_graph/pydantic_graph/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .exceptions import GraphRuntimeError, GraphSetupError
from .graph import Graph
from .nodes import BaseNode, Edge, End, GraphContext
from .state import AbstractState, EndEvent, HistoryStep, NodeEvent
Expand All @@ -12,4 +13,6 @@
'EndEvent',
'HistoryStep',
'NodeEvent',
'GraphSetupError',
'GraphRuntimeError',
)
20 changes: 20 additions & 0 deletions pydantic_graph/pydantic_graph/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
class GraphSetupError(TypeError):
"""Error caused by an incorrectly configured graph."""

message: str
"""Description of the mistake."""

def __init__(self, message: str):
self.message = message
super().__init__(message)


class GraphRuntimeError(RuntimeError):
"""Error caused by an issue during graph execution."""

message: str
"""The error message."""

def __init__(self, message: str):
self.message = message
super().__init__(message)
44 changes: 27 additions & 17 deletions pydantic_graph/pydantic_graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from . import _utils, mermaid
from ._utils import get_parent_namespace
from .exceptions import GraphRuntimeError, GraphSetupError
from .nodes import BaseNode, End, GraphContext, NodeDef
from .state import EndEvent, HistoryStep, NodeEvent, StateT

Expand Down Expand Up @@ -44,42 +45,42 @@ def __init__(
_nodes_by_id: dict[str, type[BaseNode[StateT, RunEndT]]] = {}
for node in nodes:
node_id = node.get_id()
if (existing_node := _nodes_by_id.get(node_id)) and existing_node is not node:
raise ValueError(f'Node ID "{node_id}" is not unique — found in {existing_node} and {node}')
if existing_node := _nodes_by_id.get(node_id):
raise GraphSetupError(f'Node ID `{node_id}` is not unique — found in {existing_node} and {node}')
else:
_nodes_by_id[node_id] = node
self.nodes = tuple(_nodes_by_id.values())

parent_namespace = get_parent_namespace(inspect.currentframe())
self.node_defs: dict[str, NodeDef[StateT, RunEndT]] = {}
for node in self.nodes:
self.node_defs[node.get_id()] = node.get_node_def(parent_namespace)
self.node_defs: dict[str, NodeDef[StateT, RunEndT]] = {
node.get_id(): node.get_node_def(parent_namespace) for node in self.nodes
}

self._validate_edges()

def _validate_edges(self):
known_node_ids = set(self.node_defs.keys())
known_node_ids = self.node_defs.keys()
bad_edges: dict[str, list[str]] = {}

for node_id, node_def in self.node_defs.items():
node_bad_edges = node_def.next_node_edges.keys() - known_node_ids
for bad_edge in node_bad_edges:
bad_edges.setdefault(bad_edge, []).append(f'"{node_id}"')
for edge in node_def.next_node_edges.keys():
if edge not in known_node_ids:
bad_edges.setdefault(edge, []).append(f'`{node_id}`')

if bad_edges:
bad_edges_list = [f'"{k}" is referenced by {_utils.comma_and(v)}' for k, v in bad_edges.items()]
bad_edges_list = [f'`{k}` is referenced by {_utils.comma_and(v)}' for k, v in bad_edges.items()]
if len(bad_edges_list) == 1:
raise ValueError(f'{bad_edges_list[0]} but not included in the graph.')
raise GraphSetupError(f'{bad_edges_list[0]} but not included in the graph.')
else:
b = '\n'.join(f' {be}' for be in bad_edges_list)
raise ValueError(f'Nodes are referenced in the graph but not included in the graph:\n{b}')
raise GraphSetupError(f'Nodes are referenced in the graph but not included in the graph:\n{b}')

async def next(
self, state: StateT, node: BaseNode[StateT, RunEndT], history: list[HistoryStep[StateT, RunEndT]]
) -> BaseNode[StateT, Any] | End[RunEndT]:
node_id = node.get_id()
if node_id not in self.node_defs:
raise TypeError(f'Node "{node}" is not in the graph.')
raise GraphRuntimeError(f'Node `{node}` is not in the graph.')

history_step: NodeEvent[StateT, RunEndT] | None = NodeEvent(state, node)
history.append(history_step)
Expand All @@ -95,7 +96,7 @@ async def run(
self,
state: StateT,
node: BaseNode[StateT, RunEndT],
) -> tuple[End[RunEndT], list[HistoryStep[StateT, RunEndT]]]:
) -> tuple[RunEndT, list[HistoryStep[StateT, RunEndT]]]:
history: list[HistoryStep[StateT, RunEndT]] = []

with _logfire.span(
Expand All @@ -108,24 +109,33 @@ async def run(
if isinstance(next_node, End):
history.append(EndEvent(state, next_node))
run_span.set_attribute('history', history)
return next_node, history
return next_node.data, history
elif isinstance(next_node, BaseNode):
node = next_node
else:
if TYPE_CHECKING:
assert_never(next_node)
else:
raise TypeError(f'Invalid node type: {type(next_node)}. Expected `BaseNode` or `End`.')
raise GraphRuntimeError(
f'Invalid node return type: `{type(next_node).__name__}`. Expected `BaseNode` or `End`.'
)

def mermaid_code(
self,
*,
start_node: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None,
highlighted_nodes: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None,
highlight_css: str = mermaid.DEFAULT_HIGHLIGHT_CSS,
edge_labels: bool = True,
notes: bool = True,
) -> str:
return mermaid.generate_code(
self, start_node=start_node, highlighted_nodes=highlighted_nodes, highlight_css=highlight_css
self,
start_node=start_node,
highlighted_nodes=highlighted_nodes,
highlight_css=highlight_css,
edge_labels=edge_labels,
notes=notes,
)

def mermaid_image(self, **kwargs: Unpack[mermaid.MermaidConfig]) -> bytes:
Expand Down
2 changes: 1 addition & 1 deletion pydantic_graph/pydantic_graph/mermaid.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def generate_code(
if node_id in start_node_ids:
lines.append(f' [*] --> {node_id}')
if node_def.returns_base_node:
for next_node_id in graph.nodes:
for next_node_id in graph.node_defs:
lines.append(f' {node_id} --> {next_node_id}')
else:
for next_node_id, edge in node_def.next_node_edges.items():
Expand Down
113 changes: 110 additions & 3 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

from dataclasses import dataclass
from datetime import timezone
from functools import cache
from typing import Union

import pytest
from dirty_equals import IsStr
from inline_snapshot import snapshot

from pydantic_graph import BaseNode, End, EndEvent, Graph, GraphContext, NodeEvent
from pydantic_graph import BaseNode, End, EndEvent, Graph, GraphContext, GraphRuntimeError, GraphSetupError, NodeEvent

from .conftest import IsFloat, IsNow

Expand Down Expand Up @@ -42,7 +44,7 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq
g = Graph[None, int](nodes=(Float2String, String2Length, Double))
result, history = await g.run(None, Float2String(3.14))
# len('3.14') * 2 == 8
assert result == End(8)
assert result == 8
assert history == snapshot(
[
NodeEvent(
Expand Down Expand Up @@ -72,7 +74,7 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq
)
result, history = await g.run(None, Float2String(3.14159))
# len('3.14159') == 7, 21 * 2 == 42
assert result == End(42)
assert result == 42
assert history == snapshot(
[
NodeEvent(
Expand Down Expand Up @@ -112,3 +114,108 @@ async def run(self, ctx: GraphContext) -> Union[String2Length, End[int]]: # noq
),
]
)


def test_one_bad_node():
class Float2String(BaseNode):
async def run(self, ctx: GraphContext) -> String2Length:
return String2Length()

class String2Length(BaseNode[None, None]):
async def run(self, ctx: GraphContext) -> End[None]:
return End(None)

with pytest.raises(GraphSetupError) as exc_info:
Graph(nodes=(Float2String,))

assert exc_info.value.message == snapshot(
'`String2Length` is referenced by `Float2String` but not included in the graph.'
)


def test_two_bad_nodes():
class Float2String(BaseNode):
input_data: float

async def run(self, ctx: GraphContext) -> String2Length | Double:
raise NotImplementedError()

class String2Length(BaseNode[None, None]):
input_data: str

async def run(self, ctx: GraphContext) -> End[None]:
return End(None)

class Double(BaseNode[None, None]):
async def run(self, ctx: GraphContext) -> End[None]:
return End(None)

with pytest.raises(GraphSetupError) as exc_info:
Graph(nodes=(Float2String,))

assert exc_info.value.message == snapshot("""\
Nodes are referenced in the graph but not included in the graph:
`String2Length` is referenced by `Float2String`
`Double` is referenced by `Float2String`\
""")


def test_duplicate_id():
class Foo(BaseNode):
async def run(self, ctx: GraphContext) -> Bar:
return Bar()

class Bar(BaseNode[None, None]):
async def run(self, ctx: GraphContext) -> End[None]:
return End(None)

@classmethod
@cache
def get_id(cls) -> str:
return 'Foo'

with pytest.raises(GraphSetupError) as exc_info:
Graph(nodes=(Foo, Bar))

assert exc_info.value.message == snapshot(IsStr(regex='Node ID `Foo` is not unique — found in.+'))


async def test_run_node_not_in_graph():
@dataclass
class Foo(BaseNode):
async def run(self, ctx: GraphContext) -> Bar:
return Bar()

@dataclass
class Bar(BaseNode[None, None]):
async def run(self, ctx: GraphContext) -> End[None]:
return Spam() # type: ignore

@dataclass
class Spam(BaseNode[None, None]):
async def run(self, ctx: GraphContext) -> End[None]:
return End(None)

g = Graph(nodes=(Foo, Bar))
with pytest.raises(GraphRuntimeError) as exc_info:
await g.run(None, Foo())

assert exc_info.value.message == snapshot('Node `test_run_node_not_in_graph.<locals>.Spam()` is not in the graph.')


async def test_run_return_other():
@dataclass
class Foo(BaseNode):
async def run(self, ctx: GraphContext) -> Bar:
return Bar()

@dataclass
class Bar(BaseNode[None, None]):
async def run(self, ctx: GraphContext) -> End[None]:
return 42 # type: ignore

g = Graph(nodes=(Foo, Bar))
with pytest.raises(GraphRuntimeError) as exc_info:
await g.run(None, Foo())

assert exc_info.value.message == snapshot('Invalid node return type: `int`. Expected `BaseNode` or `End`.')
Loading
Loading