Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 21 additions & 10 deletions src/pdl/pdl_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,8 @@ class ClosureBlock(FunctionBlock):
def __call__(self, **kwds):
state = self.pdl__state.with_yield_result(False).with_yield_background(False)
current_context = state.current_pdl_context.ref
result, _, _ = execute_call(
state, current_context, self, kwds, empty_block_location
result, _ = execute_call(
state, self, kwds, current_context, empty_block_location
)
return result

Expand Down Expand Up @@ -1919,8 +1919,6 @@ def call_pdl(code: str, scope: ScopeType) -> PdlLazy[Any]:
def process_call(
state: InterpreterState, scope: ScopeType, block: CallBlock, loc: PdlLocationType
) -> tuple[Any, LazyMessages, ScopeType, CallBlock]:
result = None
background: LazyMessages = DependentContext([])
args, block = process_expr_of(block, "args", scope, loc)
closure, _ = process_expr_of(block, "call", scope, loc)
if not isinstance(closure, ClosureBlock):
Expand All @@ -1943,20 +1941,33 @@ def process_call(
)
current_context = scope.data["pdl_context"]
try:
result, background, call_trace = execute_call(
state, current_context, closure, args, loc
)
result, call_trace = execute_call(state, closure, args, current_context, loc)
except PDLRuntimeError as exc:
raise PDLRuntimeError(
exc.message,
loc=exc.loc or closure.pdl__location,
trace=block.model_copy(update={"pdl__trace": exc.pdl__trace}),
) from exc
trace = block.model_copy(update={"pdl__trace": call_trace})
background = SingletonContext(
PdlDict(
{
"role": state.role,
"content": result,
"pdl__defsite": ".".join(state.id_stack),
}
)
)
return result, background, scope, trace


def execute_call(state, current_context, closure, args, loc):
def execute_call(
state: InterpreterState,
closure: ClosureBlock,
args: dict[str, Any],
current_context: LazyMessages,
loc: PdlLocationType,
) -> tuple[Any, BlockType]:
if "pdl_context" in args:
args = args | {"pdl_context": deserialize(args["pdl_context"])}
f_body = closure.return_
Expand All @@ -1973,7 +1984,7 @@ def execute_call(state, current_context, closure, args, loc):
)
else:
fun_loc = empty_block_location
result, background, _, f_trace = process_block(state, f_scope, f_body, fun_loc)
result, _, _, f_trace = process_block(state, f_scope, f_body, fun_loc)
if closure.spec is not None:
result = lazy_apply(
lambda r: result_with_type_checking(
Expand All @@ -1985,7 +1996,7 @@ def execute_call(state, current_context, closure, args, loc):
),
result,
)
return result, background, f_trace
return result, f_trace


def process_input(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def test_call_from_code_04():
"""
result = exec_str(prog)
assert result == [
"Bye[{'role': 'user', 'content': 'Hello', 'pdl__defsite': 'lastOf.0'},{'role': 'user', 'content': 'How are you?', 'pdl__defsite': 'lastOf.1.array.0.text.0.call.lastOf.0'},{'role': 'user', 'content': 'Bye', 'pdl__defsite': 'lastOf.1.array.0.text.0.call.lastOf.1'}]",
"Bye[{'role': 'user', 'content': 'Hello', 'pdl__defsite': 'lastOf.0'},{'role': 'user', 'content': 'Bye', 'pdl__defsite': 'lastOf.1.array.0.text.0.call'}]",
"Bye[{'role': 'user', 'content': 'Hello', 'pdl__defsite': 'lastOf.0'},{'role': 'user', 'content': 'Bye', 'pdl__defsite': 'lastOf.1.array.1.text.0'}]",
"Bye[{'role': 'user', 'content': 'Hello', 'pdl__defsite': 'lastOf.0'},{'role': 'user', 'content': 'Bye', 'pdl__defsite': 'lastOf.1.array.2.text.0.code'}]",
]
Loading