diff --git a/localstack-core/localstack/services/stepfunctions/backend/execution.py b/localstack-core/localstack/services/stepfunctions/backend/execution.py index 2798bf7d479ee..24f9a3746dca9 100644 --- a/localstack-core/localstack/services/stepfunctions/backend/execution.py +++ b/localstack-core/localstack/services/stepfunctions/backend/execution.py @@ -15,7 +15,6 @@ ExecutionStatus, GetExecutionHistoryOutput, HistoryEventList, - InvalidName, SensitiveCause, SensitiveError, StartExecutionOutput, @@ -71,11 +70,12 @@ def __init__(self, execution: Execution): self.execution = execution def _reflect_execution_status(self): - exit_program_state: ProgramState = self.execution.exec_worker.env.program_state() + exec_worker = get_exec_worker(self.execution.exec_arn) + exit_program_state: ProgramState = exec_worker.env.program_state() self.execution.stop_date = datetime.datetime.now(tz=datetime.UTC) if isinstance(exit_program_state, ProgramEnded): self.execution.exec_status = ExecutionStatus.SUCCEEDED - self.execution.output = self.execution.exec_worker.env.states.get_input() + self.execution.output = exec_worker.env.states.get_input() elif isinstance(exit_program_state, ProgramStopped): self.execution.exec_status = ExecutionStatus.ABORTED elif isinstance(exit_program_state, ProgramError): @@ -94,6 +94,13 @@ def terminated(self) -> None: self.execution.publish_execution_status_change_event() +EXEC_ARN_TO_WORKER: dict[str, ExecutionWorker] = {} + + +def get_exec_worker(arn: str) -> ExecutionWorker | None: + return EXEC_ARN_TO_WORKER.get(arn) + + class Execution: name: Final[str] sm_type: Final[StateMachineType] @@ -125,8 +132,6 @@ class Execution: error: SensitiveError | None cause: SensitiveCause | None - exec_worker: ExecutionWorker | None - _activity_store: dict[Arn, Activity] def __init__( @@ -169,7 +174,6 @@ def __init__( self.stop_date = None self.output = None self.output_details = CloudWatchEventsExecutionDataDetails(included=True) - self.exec_worker = None self.error = None self.cause = None self._activity_store = activity_store @@ -257,7 +261,8 @@ def to_execution_list_item(self) -> ExecutionListItem: return item def to_history_output(self) -> GetExecutionHistoryOutput: - env = self.exec_worker.env + exec_worker = get_exec_worker(self.exec_arn) + env = exec_worker.env event_history: HistoryEventList = [] if env is not None: # The execution has not started yet. @@ -307,20 +312,6 @@ def _get_start_execution_worker(self) -> ExecutionWorker: mock_test_case=self.mock_test_case, ) - def start(self) -> None: - # TODO: checks exec_worker does not exists already? - if self.exec_worker: - raise InvalidName() # TODO. - self.exec_worker = self._get_start_execution_worker() - self.exec_status = ExecutionStatus.RUNNING - self.publish_execution_status_change_event() - self.exec_worker.start() - - def stop(self, stop_date: datetime.datetime, error: str | None, cause: str | None): - exec_worker: ExecutionWorker | None = self.exec_worker - if exec_worker: - exec_worker.stop(stop_date=stop_date, cause=cause, error=error) - def publish_execution_status_change_event(self): input_value = ( {} if not self.input_data else to_json_str(self.input_data, separators=(",", ":")) diff --git a/localstack-core/localstack/services/stepfunctions/backend/test_state/execution.py b/localstack-core/localstack/services/stepfunctions/backend/test_state/execution.py index 92fcb8546c0d6..d7b28691b86c8 100644 --- a/localstack-core/localstack/services/stepfunctions/backend/test_state/execution.py +++ b/localstack-core/localstack/services/stepfunctions/backend/test_state/execution.py @@ -26,6 +26,7 @@ from localstack.services.stepfunctions.backend.execution import ( BaseExecutionWorkerCommunication, Execution, + get_exec_worker, ) from localstack.services.stepfunctions.backend.state_machine import StateMachineInstance from localstack.services.stepfunctions.backend.test_state.execution_worker import ( @@ -36,17 +37,17 @@ class TestStateExecution(Execution): - exec_worker: TestStateExecutionWorker | None next_state: str | None class TestCaseExecutionWorkerCommunication(BaseExecutionWorkerCommunication): _execution: TestStateExecution def terminated(self) -> None: - exit_program_state: ProgramState = self.execution.exec_worker.env.program_state() + exec_worker = get_exec_worker(self.execution.exec_arn) + exit_program_state: ProgramState = exec_worker.env.program_state() if isinstance(exit_program_state, ProgramChoiceSelected): self.execution.exec_status = ExecutionStatus.SUCCEEDED - self.execution.output = self.execution.exec_worker.env.states.get_input() + self.execution.output = exec_worker.env.states.get_input() self.execution.next_state = exit_program_state.next_state_name else: self._reflect_execution_status() diff --git a/localstack-core/localstack/services/stepfunctions/provider.py b/localstack-core/localstack/services/stepfunctions/provider.py index 19e3e68e07603..595c39a029dc8 100644 --- a/localstack-core/localstack/services/stepfunctions/provider.py +++ b/localstack-core/localstack/services/stepfunctions/provider.py @@ -139,7 +139,12 @@ ) from localstack.services.stepfunctions.backend.activity import Activity, ActivityTask from localstack.services.stepfunctions.backend.alias import Alias -from localstack.services.stepfunctions.backend.execution import Execution, SyncExecution +from localstack.services.stepfunctions.backend.execution import ( + EXEC_ARN_TO_WORKER, + Execution, + SyncExecution, + get_exec_worker, +) from localstack.services.stepfunctions.backend.state_machine import ( StateMachineInstance, StateMachineRevision, @@ -710,7 +715,7 @@ def send_task_heartbeat( running_executions: list[Execution] = self._get_executions(context, ExecutionStatus.RUNNING) for execution in running_executions: try: - if execution.exec_worker.env.callback_pool_manager.heartbeat( + if get_exec_worker(execution.exec_arn).env.callback_pool_manager.heartbeat( callback_id=task_token ): return SendTaskHeartbeatOutput() @@ -732,7 +737,7 @@ def send_task_success( running_executions: list[Execution] = self._get_executions(context, ExecutionStatus.RUNNING) for execution in running_executions: try: - if execution.exec_worker.env.callback_pool_manager.notify( + if get_exec_worker(execution.exec_arn).env.callback_pool_manager.notify( callback_id=task_token, outcome=outcome ): return SendTaskSuccessOutput() @@ -755,7 +760,7 @@ def send_task_failure( store = self.get_store(context) for execution in store.executions.values(): try: - if execution.exec_worker.env.callback_pool_manager.notify( + if get_exec_worker(execution.exec_arn).env.callback_pool_manager.notify( callback_id=task_token, outcome=outcome ): return SendTaskFailureOutput() @@ -877,7 +882,15 @@ def start_execution( store.executions[exec_arn] = execution - execution.start() + if get_exec_worker(exec_arn): + raise InvalidName() # TODO + + exec_worker = execution._get_start_execution_worker() + execution.exec_status = ExecutionStatus.RUNNING + execution.publish_execution_status_change_event() + exec_worker.start() + EXEC_ARN_TO_WORKER[exec_arn] = exec_worker + return execution.to_start_output() def start_sync_execution( @@ -951,7 +964,15 @@ def start_sync_execution( ) self.get_store(context).executions[exec_arn] = execution - execution.start() + if get_exec_worker(exec_arn): + raise InvalidName() # TODO + + exec_worker = execution._get_start_execution_worker() + execution.exec_status = ExecutionStatus.RUNNING + execution.publish_execution_status_change_event() + exec_worker.start() + EXEC_ARN_TO_WORKER[exec_arn] = exec_worker + return execution.to_start_sync_execution_output() def describe_execution( @@ -1268,7 +1289,9 @@ def stop_execution( self._raise_resource_type_not_in_context(resource_type=execution.sm_type) stop_date = datetime.datetime.now(tz=datetime.UTC) - execution.stop(stop_date=stop_date, cause=cause, error=error) + + if exec_worker := get_exec_worker(execution.exec_arn): + exec_worker.stop(stop_date=stop_date, cause=cause, error=error) return StopExecutionOutput(stopDate=stop_date) def update_state_machine( @@ -1427,9 +1450,9 @@ def describe_map_run( ) -> DescribeMapRunOutput: store = self.get_store(context) for execution in store.executions.values(): - map_run_record: MapRunRecord | None = ( - execution.exec_worker.env.map_run_record_pool_manager.get(map_run_arn) - ) + map_run_record: MapRunRecord | None = get_exec_worker( + execution.exec_arn + ).env.map_run_record_pool_manager.get(map_run_arn) if map_run_record is not None: return map_run_record.describe() raise ResourceNotFound() @@ -1444,9 +1467,9 @@ def list_map_runs( ) -> ListMapRunsOutput: # TODO: add support for paging. execution = self._get_execution(context=context, execution_arn=execution_arn) - map_run_records: list[MapRunRecord] = ( - execution.exec_worker.env.map_run_record_pool_manager.get_all() - ) + map_run_records: list[MapRunRecord] = get_exec_worker( + execution.exec_arn + ).env.map_run_record_pool_manager.get_all() return ListMapRunsOutput( mapRuns=[map_run_record.list_item() for map_run_record in map_run_records] ) @@ -1467,9 +1490,9 @@ def update_map_run( # TODO: investigate behaviour of empty requests. store = self.get_store(context) for execution in store.executions.values(): - map_run_record: MapRunRecord | None = ( - execution.exec_worker.env.map_run_record_pool_manager.get(map_run_arn) - ) + map_run_record: MapRunRecord | None = get_exec_worker( + execution.exec_arn + ).env.map_run_record_pool_manager.get(map_run_arn) if map_run_record is not None: map_run_record.update( max_concurrency=max_concurrency, @@ -1521,7 +1544,15 @@ def test_state( input_data=input_json, activity_store=self.get_store(context).activities, ) - execution.start() + + if get_exec_worker(exec_arn): + raise InvalidName() # TODO + + exec_worker = execution._get_start_execution_worker() + execution.exec_status = ExecutionStatus.RUNNING + execution.publish_execution_status_change_event() + exec_worker.start() + EXEC_ARN_TO_WORKER[exec_arn] = exec_worker test_state_output = execution.to_test_state_output( inspection_level=inspection_level or InspectionLevel.INFO @@ -1585,7 +1616,7 @@ def _send_activity_task_started( ) -> None: executions: list[Execution] = self._get_executions(context) for execution in executions: - callback_endpoint = execution.exec_worker.env.callback_pool_manager.get( + callback_endpoint = get_exec_worker(execution.exec_arn).env.callback_pool_manager.get( callback_id=task_token ) if isinstance(callback_endpoint, ActivityCallbackEndpoint):