-
Notifications
You must be signed in to change notification settings - Fork 1.3k
State persistence #955
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
State persistence #955
Changes from 1 commit
b4e2bda
4c1d50d
b737c97
1c46cf9
3889e53
456560d
767e08d
daffea5
4562244
2bdb029
e9d8052
576d4af
09a5174
e271af5
3d93cb0
9581979
88723e3
2a9f90e
294cbd2
f1e4ca1
1edbb06
0ad49fd
27aee1a
76cae5a
f51699e
1cccf75
2186e47
894a590
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,7 @@ | |
::: pydantic_graph.nodes | ||
options: | ||
members: | ||
- StateT | ||
- GraphRunContext | ||
- BaseNode | ||
- End | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,3 +3,5 @@ | |
::: pydantic_graph.persistence | ||
|
||
::: pydantic_graph.persistence.in_mem | ||
|
||
::: pydantic_graph.persistence.file |
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,8 +14,8 @@ | |
from typing_inspection import typing_objects | ||
|
||
from . import _utils, exceptions, mermaid | ||
from .nodes import BaseNode, DepsT, End, GraphRunContext, NodeDef, RunEndT | ||
from .persistence import StatePersistence, StateT, set_nodes_type_context | ||
from .nodes import BaseNode, DepsT, End, GraphRunContext, NodeDef, RunEndT, StateT | ||
from .persistence import BaseStatePersistence, set_nodes_type_context | ||
from .persistence.in_mem import SimpleStatePersistence | ||
|
||
# while waiting for https://github.com/pydantic/logfire/issues/745 | ||
|
@@ -125,7 +125,7 @@ async def run( | |
*, | ||
state: StateT = None, | ||
deps: DepsT = None, | ||
persistence: StatePersistence[StateT, T] | None = None, | ||
persistence: BaseStatePersistence[StateT, T] | None = None, | ||
infer_name: bool = True, | ||
span: LogfireSpan | None = None, | ||
) -> GraphRunResult[StateT, T]: | ||
|
@@ -181,7 +181,7 @@ def run_sync( | |
*, | ||
state: StateT = None, | ||
deps: DepsT = None, | ||
persistence: StatePersistence[StateT, T] | None = None, | ||
persistence: BaseStatePersistence[StateT, T] | None = None, | ||
infer_name: bool = True, | ||
) -> GraphRunResult[StateT, T]: | ||
"""Synchronously run the graph. | ||
|
@@ -215,7 +215,7 @@ async def iter( | |
*, | ||
state: StateT = None, | ||
deps: DepsT = None, | ||
persistence: StatePersistence[StateT, T] | None = None, | ||
persistence: BaseStatePersistence[StateT, T] | None = None, | ||
infer_name: bool = True, | ||
span: AbstractContextManager[Any] | None = None, | ||
) -> AsyncIterator[GraphRun[StateT, DepsT, T]]: | ||
|
@@ -268,7 +268,7 @@ async def iter( | |
async def next( | ||
self: Graph[StateT, DepsT, T], | ||
node: BaseNode[StateT, DepsT, T], | ||
persistence: StatePersistence[StateT, T], | ||
persistence: BaseStatePersistence[StateT, T], | ||
*, | ||
state: StateT = None, | ||
deps: DepsT = None, | ||
|
@@ -302,17 +302,20 @@ async def next( | |
|
||
async def next_from_persistence( | ||
self: Graph[StateT, DepsT, T], | ||
persistence: StatePersistence[StateT, T], | ||
persistence: BaseStatePersistence[StateT, T], | ||
*, | ||
deps: DepsT = None, | ||
infer_name: bool = True, | ||
) -> BaseNode[StateT, DepsT, Any] | End[T]: | ||
"""Run the next node in the graph from a snapshot stored in persistence.""" | ||
if infer_name and self.name is None: | ||
self._infer_name(inspect.currentframe()) | ||
self.set_persistence_types(persistence) | ||
|
||
snapshot = await persistence.retrieve_next() | ||
if snapshot is None: | ||
raise exceptions.GraphRuntimeError('Unable to restore snapshot from state persistence.') | ||
assert snapshot.id is not None, 'Snapshot ID should be set' | ||
snapshot.node.set_snapshot_id(snapshot.id) | ||
|
||
run = GraphRun[StateT, DepsT, T]( | ||
|
@@ -325,7 +328,7 @@ async def next_from_persistence( | |
) | ||
return await run.next() | ||
|
||
def set_persistence_types(self, persistence: StatePersistence[StateT, RunEndT]) -> None: | ||
def set_persistence_types(self, persistence: BaseStatePersistence[StateT, RunEndT]) -> None: | ||
with set_nodes_type_context([node_def.node for node_def in self.node_defs.values()]): | ||
persistence.set_types(lambda: self._inferred_types) | ||
|
||
|
@@ -590,7 +593,7 @@ def __init__( | |
*, | ||
graph: Graph[StateT, DepsT, RunEndT], | ||
start_node: BaseNode[StateT, DepsT, RunEndT], | ||
persistence: StatePersistence[StateT, RunEndT], | ||
persistence: BaseStatePersistence[StateT, RunEndT], | ||
state: StateT, | ||
deps: DepsT, | ||
snapshot_id: str | None = None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems weird to me to have
Regardless of which is right I think a discussion may get us to the bottom of this fairly quickly. |
||
|
@@ -738,4 +741,4 @@ class GraphRunResult(Generic[StateT, RunEndT]): | |
|
||
output: RunEndT | ||
state: StateT | ||
persistence: StatePersistence[StateT, RunEndT] = field(repr=False) | ||
persistence: BaseStatePersistence[StateT, RunEndT] = field(repr=False) |
Uh oh!
There was an error while loading. Please reload this page.