Thanks to visit codestin.com
Credit goes to github.com

Skip to content

[SFN] Support for global Timeouts, Misc of Enhancements #9009

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Sep 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions localstack/services/stepfunctions/asl/antlr/ASLParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ top_layer_stmt
: comment_decl
| startat_decl
| states_decl
| timeout_seconds_decl
;

startat_decl
Expand Down

Large diffs are not rendered by default.

1,312 changes: 661 additions & 651 deletions localstack/services/stepfunctions/asl/antlr/runtime/ASLParser.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -1,9 +1,22 @@
from typing import Any, Final

from localstack.aws.api.stepfunctions import HistoryEventType, TaskFailedEventDetails
from localstack.services.stepfunctions.asl.component.common.error_name.failure_event import (
FailureEvent,
FailureEventException,
)
from localstack.services.stepfunctions.asl.component.common.error_name.states_error_name import (
StatesErrorName,
)
from localstack.services.stepfunctions.asl.component.common.error_name.states_error_name_type import (
StatesErrorNameType,
)
from localstack.services.stepfunctions.asl.component.common.payload.payloadvalue.payloadbinding.payload_binding import (
PayloadBinding,
)
from localstack.services.stepfunctions.asl.eval.environment import Environment
from localstack.services.stepfunctions.asl.eval.event.event_detail import EventDetails
from localstack.services.stepfunctions.asl.utils.encoding import to_json_str
from localstack.services.stepfunctions.asl.utils.json_path import JSONPathUtils


Expand All @@ -18,5 +31,18 @@ def from_raw(cls, string_dollar: str, string_path: str):
return cls(field=field, path=string_path)

def _eval_val(self, env: Environment) -> Any:
value = JSONPathUtils.extract_json(self.path, env.inp)
try:
value = JSONPathUtils.extract_json(self.path, env.inp)
except RuntimeError:
failure_event = FailureEvent(
error_name=StatesErrorName(typ=StatesErrorNameType.StatesRuntime),
event_type=HistoryEventType.TaskFailed,
event_details=EventDetails(
taskFailedEventDetails=TaskFailedEventDetails(
error=StatesErrorNameType.StatesRuntime.to_name(),
cause=f"The JSONPath '{self.path}' specified for the field '$.{self.field}' could not be found in the input '{to_json_str(env.inp)}'",
)
),
)
raise FailureEventException(failure_event=failure_event)
return value
41 changes: 36 additions & 5 deletions localstack/services/stepfunctions/asl/component/program/program.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import logging
import threading
from typing import Final, Optional

from localstack.aws.api.stepfunctions import (
ExecutionAbortedEventDetails,
ExecutionFailedEventDetails,
ExecutionSucceededEventDetails,
ExecutionTimedOutEventDetails,
HistoryEventExecutionDataDetails,
HistoryEventType,
)
Expand All @@ -13,6 +15,7 @@
FailureEventException,
)
from localstack.services.stepfunctions.asl.component.common.flow.start_at import StartAt
from localstack.services.stepfunctions.asl.component.common.timeouts.timeout import TimeoutSeconds
from localstack.services.stepfunctions.asl.component.eval_component import EvalComponent
from localstack.services.stepfunctions.asl.component.state.state import CommonStateField
from localstack.services.stepfunctions.asl.component.states import States
Expand All @@ -23,18 +26,32 @@
ProgramError,
ProgramState,
ProgramStopped,
ProgramTimedOut,
)
from localstack.services.stepfunctions.asl.utils.encoding import to_json_str
from localstack.utils.collections import select_from_typed_dict
from localstack.utils.threads import TMP_THREADS

LOG = logging.getLogger(__name__)


class Program(EvalComponent):
def __init__(self, start_at: StartAt, states: States, comment: Optional[Comment] = None):
self.start_at: Final[StartAt] = start_at
self.states: Final[States] = states
self.comment: Final[Optional[Comment]] = comment
start_at: Final[StartAt]
states: Final[States]
timeout_seconds: Final[Optional[TimeoutSeconds]]
comment: Final[Optional[Comment]]

def __init__(
self,
start_at: StartAt,
states: States,
timeout_seconds: Optional[TimeoutSeconds],
comment: Optional[Comment] = None,
):
self.start_at = start_at
self.states = states
self.timeout_seconds = timeout_seconds
self.comment = comment

def _get_state(self, state_name: str) -> CommonStateField:
state: Optional[CommonStateField] = self.states.states.get(state_name, None)
Expand All @@ -43,8 +60,15 @@ def _get_state(self, state_name: str) -> CommonStateField:
return state

def eval(self, env: Environment) -> None:
timeout = self.timeout_seconds.timeout_seconds if self.timeout_seconds else None
env.next_state_name = self.start_at.start_at_name
super().eval(env=env)
worker_thread = threading.Thread(target=super().eval, args=(env,))
TMP_THREADS.append(worker_thread)
worker_thread.start()
worker_thread.join(timeout=timeout)
is_timeout = worker_thread.is_alive()
if is_timeout:
env.set_timed_out()

def _eval_body(self, env: Environment) -> None:
try:
Expand Down Expand Up @@ -84,6 +108,13 @@ def _eval_body(self, env: Environment) -> None:
)
),
)
elif isinstance(program_state, ProgramTimedOut):
env.event_history.add_event(
hist_type_event=HistoryEventType.ExecutionTimedOut,
event_detail=EventDetails(
executionTimedOutEventDetails=ExecutionTimedOutEventDetails()
),
)
elif isinstance(program_state, ProgramEnded):
env.event_history.add_event(
hist_type_event=HistoryEventType.ExecutionSucceeded,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,10 @@ def _eval_body(self, env: Environment) -> None:
input_items: list[json] = self._eval_input.input_items

input_item_prog: Final[Program] = Program(
start_at=self._start_at, states=self._states, comment=self._comment
start_at=self._start_at,
states=self._states,
timeout_seconds=None,
comment=self._comment,
)
self._job_pool = JobPool(
job_program=input_item_prog, job_inputs=self._eval_input.input_items
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,11 @@ def next_job(self) -> Optional[Any]:
except IndexError:
return None

def _is_terminated(self) -> bool:
return len(self._closed_jobs) == self._jobs_number or self._worker_exception is not None

def _notify_on_termination(self):
all_cosed = len(self._closed_jobs) == self._jobs_number
if all_cosed or self._worker_exception is not None:
if self._is_terminated():
self._termination_event.set()

def get_worker_exception(self) -> Optional[Exception]:
Expand All @@ -84,4 +86,5 @@ def get_closed_jobs(self) -> list[Job]:
return sorted(closed_jobs, key=lambda closed_job: closed_job.job_index)

def await_jobs(self):
self._termination_event.wait()
if not self._is_terminated():
self._termination_event.wait()
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def exec_lambda_function(env: Environment, parameters: dict) -> None:
resp_payload = invocation_resp["Payload"].read()
resp_payload_str = to_str(resp_payload)
resp_payload_json: json = json.loads(resp_payload_str)
resp_payload_value = resp_payload_json if resp_payload_json is not None else dict()
invocation_resp["Payload"] = resp_payload_value
# resp_payload_value = resp_payload_json if resp_payload_json is not None else dict()
invocation_resp["Payload"] = resp_payload_json

response = select_from_typed_dict(typed_dict=InvocationResponse, obj=invocation_resp)
env.stack.append(response)
Expand All @@ -45,7 +45,10 @@ def exec_lambda_function(env: Environment, parameters: dict) -> None:
def to_payload_type(payload: Any) -> Optional[bytes]:
if isinstance(payload, bytes):
return payload
if isinstance(payload, str):

if payload is None:
str_value = to_json_str(dict())
elif isinstance(payload, str):
try:
json.loads(payload)
str_value = payload
Expand Down
9 changes: 9 additions & 0 deletions localstack/services/stepfunctions/asl/eval/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
ProgramRunning,
ProgramState,
ProgramStopped,
ProgramTimedOut,
)

LOG = logging.getLogger(__name__)
Expand Down Expand Up @@ -106,6 +107,14 @@ def set_error(self, error: ExecutionFailedEventDetails) -> None:
self.program_state_event.set()
self.program_state_event.clear()

def set_timed_out(self) -> None:
with self._state_mutex:
self._program_state = ProgramTimedOut()
for frame in self._frames:
frame.set_timed_out()
self.program_state_event.set()
self.program_state_event.clear()

def set_stop(self, stop_date: Timestamp, cause: Optional[str], error: Optional[str]) -> None:
with self._state_mutex:
if isinstance(self._program_state, ProgramRunning):
Expand Down
4 changes: 4 additions & 0 deletions localstack/services/stepfunctions/asl/eval/program_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,7 @@ class ProgramError(ProgramState):
def __init__(self, error: Optional[ExecutionFailedEventDetails]):
super().__init__()
self.error = error


class ProgramTimedOut(ProgramState):
pass
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,7 @@ def visitProgram_decl(self, ctx: ASLParser.Program_declContext) -> Program:
f"No '{States}' definition for Program in context: '{ctx.getText()}'."
),
),
timeout_seconds=props.get(TimeoutSeconds),
comment=props.get(typ=Comment),
)
return program
12 changes: 9 additions & 3 deletions localstack/services/stepfunctions/asl/utils/json_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,20 @@

from jsonpath_ng import parse

from localstack.services.stepfunctions.asl.utils.encoding import to_json_str


class JSONPathUtils:
@staticmethod
def extract_json(path: str, data: json) -> json:
input_expr = parse(path)
find_res = input_expr.find(data)
if isinstance(find_res, list):
value = find_res[0].value
find_res = [match.value for match in input_expr.find(data)]
if find_res == list():
raise RuntimeError(
f"The JSONPath '{path}' could not be found in the input '{to_json_str(data)}'"
)
if len(find_res) == 1:
value = find_res[0]
else:
value = find_res
return value
7 changes: 5 additions & 2 deletions localstack/services/stepfunctions/backend/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
ProgramError,
ProgramState,
ProgramStopped,
ProgramTimedOut,
)
from localstack.services.stepfunctions.asl.utils.encoding import to_json_str
from localstack.services.stepfunctions.backend.execution_worker import ExecutionWorker
Expand Down Expand Up @@ -68,6 +69,8 @@ def terminated(self) -> None:
self.execution.exec_status = ExecutionStatus.FAILED
self.execution.error = exit_program_state.error["error"]
self.execution.cause = exit_program_state.error["cause"]
elif isinstance(exit_program_state, ProgramTimedOut):
self.execution.exec_status = ExecutionStatus.TIMED_OUT
else:
raise RuntimeWarning(
f"Execution ended with unsupported ProgramState type '{type(exit_program_state)}'."
Expand Down Expand Up @@ -203,10 +206,10 @@ def start(self) -> None:
exec_comm=Execution.BaseExecutionWorkerComm(self),
context_object_init=ContextObjectInitData(
Execution=ContextObjectExecution(
Id="TODO",
Id=self.exec_arn,
Input=self.input_data,
Name=self.state_machine.name,
RoleArn="TODO",
RoleArn=self.role_arn,
StartTime=self.start_date.time().isoformat(),
),
StateMachine=ContextObjectStateMachine(
Expand Down
10 changes: 3 additions & 7 deletions tests/aws/scenario/test_loan_broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@
import aws_cdk.aws_stepfunctions_tasks as tasks
import pytest

from localstack.testing.aws.util import is_aws_cloud
from localstack.testing.pytest import markers
from localstack.testing.scenario.provisioning import InfraProvisioner
from localstack.utils.files import load_file
from localstack.utils.strings import short_uid
from localstack.utils.sync import retry
from tests.aws.services.stepfunctions.utils import await_execution_terminated

RECIPIENT_LIST_STACK_NAME = "LoanBroker-RecipientList"
PROJECT_NAME = "CDK Loan Broker"
Expand Down Expand Up @@ -191,12 +190,9 @@ def test_stepfunctions_input_recipient_list(
)
execution_arn = result["executionArn"]

def _execution_finished():
res = aws_client.stepfunctions.describe_execution(executionArn=execution_arn)
assert res["status"] == expected_result
return res
await_execution_terminated(aws_client.stepfunctions, execution_arn)

result = retry(_execution_finished, sleep=2, retries=100 if is_aws_cloud() else 10)
result = aws_client.stepfunctions.describe_execution(executionArn=execution_arn)

snapshot.match("describe-execution-finished", result)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
import os
import urllib.parse

import pytest

from localstack.constants import PATH_USER_REQUEST
from localstack.testing.pytest import markers
from localstack.utils.sync import wait_until
from tests.aws.services.stepfunctions.utils import await_execution_terminated


@markers.aws.unknown
Expand Down Expand Up @@ -216,7 +219,8 @@ def _sfn_finished_running():
assert "hello_with_path from stepfunctions" in execution_result["output"]


@markers.aws.unknown
@pytest.mark.skip("Terminates with FAILED on cloud; convert to SFN v2 snapshot lambda test.")
@markers.aws.needs_fixing
def test_retry_and_catch(deploy_cfn_template, aws_client):
"""
Scenario:
Expand All @@ -242,13 +246,7 @@ def test_retry_and_catch(deploy_cfn_template, aws_client):
execution = aws_client.stepfunctions.start_execution(stateMachineArn=statemachine_arn)
execution_arn = execution["executionArn"]

def _sfn_finished_running():
return (
aws_client.stepfunctions.describe_execution(executionArn=execution_arn)["status"]
!= "RUNNING"
)

assert wait_until(_sfn_finished_running)
await_execution_terminated(aws_client.stepfunctions, execution_arn)

execution_result = aws_client.stepfunctions.describe_execution(executionArn=execution_arn)
assert execution_result["status"] == "SUCCEEDED"
Expand Down
Loading