Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
b4e2bda
state persistence
samuelcolvin Feb 19, 2025
4c1d50d
fixing tests
samuelcolvin Feb 23, 2025
b737c97
tests all passing
samuelcolvin Feb 24, 2025
1c46cf9
simplify
samuelcolvin Feb 26, 2025
3889e53
Merge branch 'main' into state-persistence
samuelcolvin Mar 3, 2025
456560d
fixing tests
samuelcolvin Mar 3, 2025
767e08d
fix tests for 3.9 etc
samuelcolvin Mar 3, 2025
daffea5
refactoring state persistence
samuelcolvin Mar 6, 2025
4562244
snapshot id on node
samuelcolvin Mar 7, 2025
2bdb029
fixing snapshot id
samuelcolvin Mar 8, 2025
e9d8052
improving docs
samuelcolvin Mar 8, 2025
576d4af
Merge branch 'main' into state-persistence
samuelcolvin Mar 8, 2025
09a5174
fix spans
samuelcolvin Mar 8, 2025
e271af5
more tests, improve coverage
samuelcolvin Mar 8, 2025
3d93cb0
add file persistence
samuelcolvin Mar 9, 2025
9581979
improve coverage
samuelcolvin Mar 9, 2025
88723e3
fix for 3.9 and 3.10
samuelcolvin Mar 9, 2025
2a9f90e
improve coverage
samuelcolvin Mar 9, 2025
294cbd2
more docs
samuelcolvin Mar 10, 2025
f1e4ca1
complete docs
samuelcolvin Mar 10, 2025
1edbb06
Merge branch 'main' into state-persistence
samuelcolvin Mar 10, 2025
0ad49fd
replace human in the loop example
samuelcolvin Mar 10, 2025
27aee1a
Apply suggestions from code review
samuelcolvin Mar 12, 2025
76cae5a
Apply suggestions from code review
samuelcolvin Mar 12, 2025
f51699e
tweak docs as suggested
samuelcolvin Mar 12, 2025
1cccf75
add iter_from_persistence, deprecate next()
samuelcolvin Mar 12, 2025
2186e47
improve documentation
samuelcolvin Mar 13, 2025
894a590
fix tests
samuelcolvin Mar 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
more docs
  • Loading branch information
samuelcolvin committed Mar 10, 2025
commit 294cbd2ee2c910e9569605b38554038df8d31883
1 change: 1 addition & 0 deletions docs/api/pydantic_graph/nodes.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
::: pydantic_graph.nodes
options:
members:
- StateT
- GraphRunContext
- BaseNode
- End
Expand Down
2 changes: 2 additions & 0 deletions docs/api/pydantic_graph/persistence.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@
::: pydantic_graph.persistence

::: pydantic_graph.persistence.in_mem

::: pydantic_graph.persistence.file
370 changes: 194 additions & 176 deletions docs/graph.md

Large diffs are not rendered by default.

36 changes: 15 additions & 21 deletions examples/pydantic_ai_examples/question_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,24 +38,24 @@ class QuestionState:


@dataclass
class GenerateQuestion(BaseNode[QuestionState]):
async def run(self, ctx: GraphRunContext[QuestionState]) -> Ask:
class Ask(BaseNode[QuestionState]):
async def run(self, ctx: GraphRunContext[QuestionState]) -> Answer:
result = await ask_agent.run(
'Ask a simple question with a single correct answer.',
message_history=ctx.state.ask_agent_messages,
)
ctx.state.ask_agent_messages += result.all_messages()
ctx.state.question = result.data
return Ask(result.data)
return Answer(result.data)


@dataclass
class Ask(BaseNode[QuestionState]):
class Answer(BaseNode[QuestionState]):
question: str

async def run(self, ctx: GraphRunContext[QuestionState]) -> Answer:
async def run(self, ctx: GraphRunContext[QuestionState]) -> Evaluate:
answer = input(f'{self.question}: ')
return Answer(answer)
return Evaluate(answer)


class EvaluationResult(BaseModel, use_attribute_docstrings=True):
Expand All @@ -73,7 +73,7 @@ class EvaluationResult(BaseModel, use_attribute_docstrings=True):


@dataclass
class Answer(BaseNode[QuestionState, None, str]):
class Evaluate(BaseNode[QuestionState, None, str]):
answer: str

async def run(
Expand All @@ -96,21 +96,20 @@ async def run(
class Reprimand(BaseNode[QuestionState]):
comment: str

async def run(self, ctx: GraphRunContext[QuestionState]) -> GenerateQuestion:
async def run(self, ctx: GraphRunContext[QuestionState]) -> Ask:
print(f'Comment: {self.comment}')
# > Comment: Vichy is no longer the capital of France.
ctx.state.question = None
return GenerateQuestion()
return Ask()


question_graph = Graph(
nodes=(GenerateQuestion, Ask, Answer, Reprimand), state_type=QuestionState
nodes=(Ask, Answer, Evaluate, Reprimand), state_type=QuestionState
)


async def run_as_continuous():
state = QuestionState()
node = GenerateQuestion()
node = Ask()
end = await question_graph.run(node, state=state)
print('END:', end.output)

Expand All @@ -124,10 +123,10 @@ async def run_as_cli(answer: str | None):
assert answer is not None, (
'answer required, usage "uv run -m pydantic_ai_examples.question_graph cli <answer>"'
)
node = Answer(answer)
node = Evaluate(answer)
else:
state = QuestionState()
node = GenerateQuestion()
node = Ask()
# debug(state, node)

with logfire.span('run questions graph'):
Expand All @@ -139,7 +138,7 @@ async def run_as_cli(answer: str | None):
print('history:', '\n'.join(str(e.node) for e in history), sep='\n')
print('Finished!')
break
elif isinstance(node, Ask):
elif isinstance(node, Answer):
print(node.question)
break
# otherwise just continue
Expand All @@ -165,12 +164,7 @@ async def run_as_cli(answer: str | None):
sys.exit(1)

if sub_command == 'mermaid':
print(question_graph.mermaid_code(start_node=GenerateQuestion))
print(
question_graph.mermaid_save(
'question_graph.jpg', start_node=GenerateQuestion
)
)
print(question_graph.mermaid_code(start_node=Ask))
elif sub_command == 'continuous':
asyncio.run(run_as_continuous())
else:
Expand Down
23 changes: 13 additions & 10 deletions pydantic_graph/pydantic_graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from typing_inspection import typing_objects

from . import _utils, exceptions, mermaid
from .nodes import BaseNode, DepsT, End, GraphRunContext, NodeDef, RunEndT
from .persistence import StatePersistence, StateT, set_nodes_type_context
from .nodes import BaseNode, DepsT, End, GraphRunContext, NodeDef, RunEndT, StateT
from .persistence import BaseStatePersistence, set_nodes_type_context
from .persistence.in_mem import SimpleStatePersistence

# while waiting for https://github.com/pydantic/logfire/issues/745
Expand Down Expand Up @@ -125,7 +125,7 @@ async def run(
*,
state: StateT = None,
deps: DepsT = None,
persistence: StatePersistence[StateT, T] | None = None,
persistence: BaseStatePersistence[StateT, T] | None = None,
infer_name: bool = True,
span: LogfireSpan | None = None,
) -> GraphRunResult[StateT, T]:
Expand Down Expand Up @@ -181,7 +181,7 @@ def run_sync(
*,
state: StateT = None,
deps: DepsT = None,
persistence: StatePersistence[StateT, T] | None = None,
persistence: BaseStatePersistence[StateT, T] | None = None,
infer_name: bool = True,
) -> GraphRunResult[StateT, T]:
"""Synchronously run the graph.
Expand Down Expand Up @@ -215,7 +215,7 @@ async def iter(
*,
state: StateT = None,
deps: DepsT = None,
persistence: StatePersistence[StateT, T] | None = None,
persistence: BaseStatePersistence[StateT, T] | None = None,
infer_name: bool = True,
span: AbstractContextManager[Any] | None = None,
) -> AsyncIterator[GraphRun[StateT, DepsT, T]]:
Expand Down Expand Up @@ -268,7 +268,7 @@ async def iter(
async def next(
self: Graph[StateT, DepsT, T],
node: BaseNode[StateT, DepsT, T],
persistence: StatePersistence[StateT, T],
persistence: BaseStatePersistence[StateT, T],
*,
state: StateT = None,
deps: DepsT = None,
Expand Down Expand Up @@ -302,17 +302,20 @@ async def next(

async def next_from_persistence(
self: Graph[StateT, DepsT, T],
persistence: StatePersistence[StateT, T],
persistence: BaseStatePersistence[StateT, T],
*,
deps: DepsT = None,
infer_name: bool = True,
) -> BaseNode[StateT, DepsT, Any] | End[T]:
"""Run the next node in the graph from a snapshot stored in persistence."""
if infer_name and self.name is None:
self._infer_name(inspect.currentframe())
self.set_persistence_types(persistence)

snapshot = await persistence.retrieve_next()
if snapshot is None:
raise exceptions.GraphRuntimeError('Unable to restore snapshot from state persistence.')
assert snapshot.id is not None, 'Snapshot ID should be set'
snapshot.node.set_snapshot_id(snapshot.id)

run = GraphRun[StateT, DepsT, T](
Expand All @@ -325,7 +328,7 @@ async def next_from_persistence(
)
return await run.next()

def set_persistence_types(self, persistence: StatePersistence[StateT, RunEndT]) -> None:
def set_persistence_types(self, persistence: BaseStatePersistence[StateT, RunEndT]) -> None:
with set_nodes_type_context([node_def.node for node_def in self.node_defs.values()]):
persistence.set_types(lambda: self._inferred_types)

Expand Down Expand Up @@ -590,7 +593,7 @@ def __init__(
*,
graph: Graph[StateT, DepsT, RunEndT],
start_node: BaseNode[StateT, DepsT, RunEndT],
persistence: StatePersistence[StateT, RunEndT],
persistence: BaseStatePersistence[StateT, RunEndT],
state: StateT,
deps: DepsT,
snapshot_id: str | None = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems weird to me to have snapshot_id be a kwarg to GraphRun. Maybe you can explain this to me synchronously, but I feel like one of three things should be the case:

  • We can rework things so the GraphRun doesn't need a snapshot_id argument
  • We can rework the name of the snapshot_id argument to be more clear what its purpose is
  • My mental model about what is happening is way off

Regardless of which is right I think a discussion may get us to the bottom of this fairly quickly.

Expand Down Expand Up @@ -738,4 +741,4 @@ class GraphRunResult(Generic[StateT, RunEndT]):

output: RunEndT
state: StateT
persistence: StatePersistence[StateT, RunEndT] = field(repr=False)
persistence: BaseStatePersistence[StateT, RunEndT] = field(repr=False)
10 changes: 4 additions & 6 deletions pydantic_graph/pydantic_graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,18 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, is_dataclass
from functools import cache
from typing import TYPE_CHECKING, Any, ClassVar, Generic, get_type_hints
from typing import Any, ClassVar, Generic, get_type_hints
from uuid import uuid4

from typing_extensions import Never, Self, TypeVar, get_origin

from . import _utils, exceptions

if TYPE_CHECKING:
from .persistence import StateT
else:
StateT = TypeVar('StateT', default=None)
__all__ = 'GraphRunContext', 'BaseNode', 'End', 'Edge', 'NodeDef', 'DepsT', 'StateT', 'RunEndT'

__all__ = 'GraphRunContext', 'BaseNode', 'End', 'Edge', 'NodeDef', 'DepsT', 'RunEndT'

StateT = TypeVar('StateT', default=None)
"""Type variable for the state in a graph."""
RunEndT = TypeVar('RunEndT', covariant=True, default=None)
"""Covariant type variable for the return type of a graph [`run`][pydantic_graph.graph.Graph.run]."""
NodeRunEndT = TypeVar('NodeRunEndT', covariant=True, default=Never)
Expand Down
15 changes: 9 additions & 6 deletions pydantic_graph/pydantic_graph/persistence/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,29 +10,29 @@
import pydantic
from typing_extensions import TypeVar

from ..nodes import BaseNode, End, RunEndT
from ..nodes import BaseNode, End
from . import _utils

__all__ = (
'StateT',
'NodeSnapshot',
'EndSnapshot',
'Snapshot',
'StatePersistence',
'BaseStatePersistence',
'set_nodes_type_context',
'SnapshotStatus',
)

StateT = TypeVar('StateT', default=None)
"""Type variable for the state in a graph."""
StateT = TypeVar('StateT', default=Any)
RunEndT = TypeVar('RunEndT', covariant=True, default=Any)

UNSET_SNAPSHOT_ID = '__unset__'
SnapshotStatus = Literal['created', 'pending', 'running', 'success', 'error']
"""The status of a snapshot.

- `'created'`: The snapshot has been created but not yet run.
- `'pending'`: The snapshot has been retrieved with
[`retrieve_next`][pydantic_graph.persistence.StatePersistence.retrieve_next] but not yet run.
[`retrieve_next`][pydantic_graph.persistence.BaseStatePersistence.retrieve_next] but not yet run.
- `'running'`: The snapshot is currently running.
- `'success'`: The snapshot has been run successfully.
- `'error'`: The snapshot has been run but an error occurred.
Expand Down Expand Up @@ -101,7 +101,7 @@ def node(self) -> End[RunEndT]:
"""


class StatePersistence(ABC, Generic[StateT, RunEndT]):
class BaseStatePersistence(ABC, Generic[StateT, RunEndT]):
"""Abstract base class for storing the state of a graph."""

@abstractmethod
Expand Down Expand Up @@ -164,6 +164,9 @@ def record_run(self, snapshot_id: str) -> AbstractAsyncContextManager[None]:
async def retrieve_next(self) -> NodeSnapshot[StateT, RunEndT] | None:
"""Retrieve a node snapshot with status `'created`' and set its status to `'pending'`.

This is used by [`Graph.next_from_persistence`][pydantic_graph.graph.Graph.next_from_persistence]
to get the next node to run.

Returns: The snapshot, or `None` if no snapshot with status `'created`' exists.
"""
raise NotImplementedError
Expand Down
3 changes: 2 additions & 1 deletion pydantic_graph/pydantic_graph/persistence/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ def __get_pydantic_core_schema__(
nodes = nodes_type_context.get()
except LookupError as e:
raise RuntimeError(
'Unable to build a Pydantic schema for `BaseNode` without setting `nodes_type_context`.'
'Unable to build a Pydantic schema for `BaseNode` without setting `nodes_type_context`. '
'You should build Pydantic schemas for snapshots using `StatePersistence.set_types()`.'
) from e
if len(nodes) == 1:
nodes_type = nodes[0]
Expand Down
13 changes: 6 additions & 7 deletions pydantic_graph/pydantic_graph/persistence/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,29 +10,28 @@
from typing import Any, Callable

import pydantic
from typing_extensions import TypeVar

from .. import _utils as _graph_utils, exceptions
from ..nodes import BaseNode, End
from . import (
BaseStatePersistence,
EndSnapshot,
NodeSnapshot,
RunEndT,
Snapshot,
SnapshotStatus,
StatePersistence,
StateT,
_utils,
build_snapshot_list_type_adapter,
)

StateT = TypeVar('StateT', default=Any)
RunEndT = TypeVar('RunEndT', default=Any)


@dataclass
class FileStatePersistence(StatePersistence[StateT, RunEndT]):
"""State persistence that just hold the latest snapshot."""
class FileStatePersistence(BaseStatePersistence[StateT, RunEndT]):
"""File based state persistence that hold a list of snapshots in a JSON file."""

json_file: Path
"""Path to the JSON file where the snapshots are stored."""
_snapshots_type_adapter: pydantic.TypeAdapter[list[Snapshot[StateT, RunEndT]]] | None = field(
default=None, init=False, repr=False
)
Expand Down
14 changes: 6 additions & 8 deletions pydantic_graph/pydantic_graph/persistence/in_mem.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,23 @@
from typing import Any, Callable

import pydantic
from typing_extensions import TypeVar

from .. import exceptions
from ..nodes import BaseNode, End
from . import (
BaseStatePersistence,
EndSnapshot,
NodeSnapshot,
RunEndT,
Snapshot,
StatePersistence,
StateT,
_utils,
build_snapshot_list_type_adapter,
)

StateT = TypeVar('StateT', default=Any)
RunEndT = TypeVar('RunEndT', default=Any)


@dataclass
class SimpleStatePersistence(StatePersistence[StateT, RunEndT]):
class SimpleStatePersistence(BaseStatePersistence[StateT, RunEndT]):
"""Simple in memory state persistence that just hold the latest snapshot.

If no state persistence implementation is provided when running a graph, this is used by default.
Expand Down Expand Up @@ -85,8 +83,8 @@ async def load(self) -> list[Snapshot[StateT, RunEndT]]:


@dataclass
class FullStatePersistence(StatePersistence[StateT, RunEndT]):
"""In memory state persistence that hold a history of nodes."""
class FullStatePersistence(BaseStatePersistence[StateT, RunEndT]):
"""In memory state persistence that hold a list of snapshots."""

deep_copy: bool = True
"""Whether to deep copy the state and nodes when storing them.
Expand Down
Loading