Thanks to visit codestin.com
Credit goes to github.com

Skip to content

[SFN] Support for sync:2 #8900

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,94 +1,117 @@
from __future__ import annotations

import abc
from typing import Final, Optional, TypedDict
from itertools import takewhile
from typing import Final, Optional

from localstack.services.stepfunctions.asl.component.component import Component
from localstack.utils.aws import aws_stack


class ResourceCondition(str):
WaitForTaskToken = "waitForTaskToken"
Sync2 = "sync:2"
Sync = "sync"


class ResourceARN(TypedDict):
class ResourceARN:
arn: str
partition: str
service: str
region: str
account: str
task_type: str
name: str
option: str


class Resource(Component, abc.ABC):
def __init__(self, resource_arn: str, partition: str, region: str, account: str):
self.resource_arn: Final[str] = resource_arn
self.partition: Final[str] = partition
self.region: Final[str] = region
self.account: Final[str] = account
def __init__(
self,
arn: str,
partition: str,
service: str,
region: str,
account: str,
task_type: str,
name: str,
option: Optional[str],
):
self.arn = arn
self.partition = partition
self.service = service
self.region = region
self.account = account
self.task_type = task_type
self.name = name
self.option = option

@staticmethod
def parse_resource_arn(arn: str) -> ResourceARN:
cmps: list[str] = arn.split(":")
return ResourceARN(
partition=cmps[1],
service=cmps[2],
region=cmps[3],
account=cmps[4],
task_type=cmps[5],
name=cmps[6],
def _consume_until(text: str, symbol: str) -> tuple[str, str]:
value = "".join(takewhile(lambda c: c != symbol, text))
tail_idx = len(value) + 1
return value, text[tail_idx:]

@classmethod
def from_arn(cls, arn: str) -> ResourceARN:
_, arn_tail = ResourceARN._consume_until(arn, ":")
partition, arn_tail = ResourceARN._consume_until(arn_tail, ":")
service, arn_tail = ResourceARN._consume_until(arn_tail, ":")
region, arn_tail = ResourceARN._consume_until(arn_tail, ":")
account, arn_tail = ResourceARN._consume_until(arn_tail, ":")
task_type, arn_tail = ResourceARN._consume_until(arn_tail, ":")
name, arn_tail = ResourceARN._consume_until(arn_tail, ".")
option = arn_tail
return cls(
arn=arn,
partition=partition,
service=service,
region=region,
account=account,
task_type=task_type,
name=name,
option=option,
)


class Resource(Component, abc.ABC):
resource_arn: Final[str]
partition: Final[str]
region: Final[str]
account: Final[str]

def __init__(self, resource_arn: ResourceARN):
self.resource_arn = resource_arn.arn
self.partition = resource_arn.partition
self.region = resource_arn.region
self.account = resource_arn.account

@staticmethod
def from_resource_arn(arn: str) -> Resource:
resource_arn: ResourceARN = Resource.parse_resource_arn(arn)
if not resource_arn["region"]:
resource_arn["region"] = aws_stack.get_region()
match resource_arn["service"], resource_arn["task_type"]:
resource_arn = ResourceARN.from_arn(arn)
if not resource_arn.region:
resource_arn.region = aws_stack.get_region()
match resource_arn.service, resource_arn.task_type:
case "lambda", "function":
return LambdaResource(
resource_arn=arn,
partition=resource_arn["partition"],
region=resource_arn["region"],
account=resource_arn["account"],
function_name=resource_arn["name"],
)
return LambdaResource(resource_arn=resource_arn)
case "states", "activity":
return ActivityResource(
resource_arn=arn,
partition=resource_arn["partition"],
region=resource_arn["region"],
account=resource_arn["account"],
name=resource_arn["name"],
)
case "states", service_name:
return ServiceResource(
resource_arn=arn,
partition=resource_arn["partition"],
region=resource_arn["region"],
account=resource_arn["account"],
service_name=service_name, # noqa
api_name=resource_arn["name"],
)
return ActivityResource(resource_arn=resource_arn)
case "states", _:
return ServiceResource(resource_arn=resource_arn)


class ActivityResource(Resource):
def __init__(self, resource_arn: str, partition: str, region: str, account: str, name: str):
super().__init__(
resource_arn=resource_arn, partition=partition, region=region, account=account
)
self.name: str = name
name: Final[str]

def __init__(self, resource_arn: ResourceARN):
super().__init__(resource_arn=resource_arn)
self.name = resource_arn.name


class LambdaResource(Resource):
def __init__(
self, resource_arn: str, partition: str, region: str, account: str, function_name: str
):
super().__init__(
resource_arn=resource_arn, partition=partition, region=region, account=account
)
self.function_name: str = function_name
function_name: Final[str]

def __init__(self, resource_arn: ResourceARN):
super().__init__(resource_arn=resource_arn)
self.function_name: str = resource_arn.name


class ServiceResource(Resource):
Expand All @@ -97,32 +120,29 @@ class ServiceResource(Resource):
api_action: Final[str]
condition: Final[Optional[str]]

def __init__(
self,
resource_arn: str,
partition: str,
region: str,
account: str,
service_name: str,
api_name: str,
):
super().__init__(
resource_arn=resource_arn, partition=partition, region=region, account=account
)
self.service_name = service_name
self.api_name = api_name
def __init__(self, resource_arn: ResourceARN):
super().__init__(resource_arn=resource_arn)
self.service_name = resource_arn.task_type

arn_parts = resource_arn.split(":")
tail_part = arn_parts[-1]
tail_parts = tail_part.split(".")
self.api_action = tail_parts[0]
name_parts = resource_arn.name.split(":")
if len(name_parts) == 1:
self.api_name = self.service_name
self.api_action = resource_arn.name
elif len(name_parts) > 1:
self.api_name = name_parts[0]
self.api_action = name_parts[1]
else:
raise RuntimeError(f"Incorrect definition of ResourceArn.name: '{resource_arn.name}'.")

self.condition = None
if len(tail_parts) > 1:
match tail_parts[-1]:
case "waitForTaskToken":
option = resource_arn.option
if option:
match option:
case ResourceCondition.WaitForTaskToken:
self.condition = ResourceCondition.WaitForTaskToken
case "sync":
self.condition = ResourceCondition.Sync
case "sync:2":
self.condition = ResourceCondition.Sync2
case unsupported:
raise RuntimeError(f"Unsupported condition '{unsupported}'.")
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ def _sync(self, env: Environment) -> None:
f"Unsupported .sync callback procedure in resource {self.resource.resource_arn}"
)

def _sync2(self, env: Environment) -> None:
raise RuntimeError(
f"Unsupported .sync:2 callback procedure in resource {self.resource.resource_arn}"
)

def _is_condition(self):
return self.resource.condition is not None

Expand Down Expand Up @@ -127,6 +132,8 @@ def _after_eval_execution(self, env: Environment) -> None:
self._wait_for_task_token(env=env)
case ResourceCondition.Sync:
self._sync(env=env)
case ResourceCondition.Sync2:
self._sync2(env=env)
case unsupported:
raise NotImplementedError(f"Unsupported callback type '{unsupported}'.")

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from typing import Any, Final, Optional

from botocore.config import Config
Expand Down Expand Up @@ -151,6 +152,18 @@ def _normalised_parameters_bindings(self, parameters: dict[str, str]) -> dict[st

return normalised_parameters

@staticmethod
def _sync2_api_output_of(typ: type, value: json) -> None:
def _replace_with_json_if_str(key: str) -> None:
inner_value = value.get(key)
if isinstance(inner_value, str):
value[key] = json.loads(inner_value)

match typ:
case DescribeExecutionOutput: # noqa
_replace_with_json_if_str("input")
_replace_with_json_if_str("output")

def _eval_service_task(self, env: Environment, parameters: dict) -> None:
api_action = camel_to_snake_case(self.resource.api_action)
sfn_client = self._get_sfn_client()
Expand All @@ -159,7 +172,7 @@ def _eval_service_task(self, env: Environment, parameters: dict) -> None:
self._normalise_botocore_response(self.resource.api_action, response)
env.stack.append(response)

def _sync_to_start_machine(self, env: Environment) -> None:
def _sync_to_start_machine(self, env: Environment, sync2_response: bool) -> None:
sfn_client = self._get_sfn_client()

submission_output: dict = env.stack.pop()
Expand All @@ -173,6 +186,10 @@ def _has_terminated() -> Optional[dict]:
execution_status: ExecutionStatus = describe_execution_output["status"]

if execution_status != ExecutionStatus.RUNNING:
if sync2_response:
self._sync2_api_output_of(
typ=DescribeExecutionOutput, value=describe_execution_output
)
self._normalise_botocore_response("describeexecution", describe_execution_output)
if execution_status == ExecutionStatus.SUCCEEDED:
return describe_execution_output
Expand Down Expand Up @@ -202,6 +219,13 @@ def _has_terminated() -> Optional[dict]:
def _sync(self, env: Environment) -> None:
match self.resource.api_action.lower():
case "startexecution":
self._sync_to_start_machine(env=env)
self._sync_to_start_machine(env=env, sync2_response=False)
case _:
super()._sync(env=env)

def _sync2(self, env: Environment) -> None:
match self.resource.api_action.lower():
case "startexecution":
self._sync_to_start_machine(env=env, sync2_response=True)
case _:
super()._sync2(env=env)
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ class CallbackTemplates(TemplateLoader):
SFN_START_EXECUTION_SYNC: Final[str] = os.path.join(
_THIS_FOLDER, "statemachines/sfn_start_execution_sync.json5"
)
SFN_START_EXECUTION_SYNC2: Final[str] = os.path.join(
_THIS_FOLDER, "statemachines/sfn_start_execution_sync2.json5"
)
SQS_SUCCESS_ON_TASK_TOKEN: Final[str] = os.path.join(
_THIS_FOLDER, "statemachines/sqs_success_on_task_token.json5"
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{
"Comment": "SFN_START_EXECUTION_SYNC:2",
"StartAt": "StartExecution",
"States": {
"StartExecution": {
"Type": "Task",
"Resource": "arn:aws:states:::states:startExecution.sync:2",
"Parameters": {
"StateMachineArn.$": "$.StateMachineArn",
"Input.$": "$.Input",
"Name.$": "$.Name"
},
"End": true,
}
}
}
47 changes: 47 additions & 0 deletions tests/aws/stepfunctions/v2/callback/test_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,53 @@ def test_start_execution_sync(
exec_input,
)

@markers.aws.unknown
def test_start_execution_sync2(
self,
aws_client,
create_iam_role_for_sfn,
create_state_machine,
sfn_snapshot,
):
sfn_snapshot.add_transformer(
JsonpathTransformer(
jsonpath="$..output.StartDate",
replacement="start-date",
replace_reference=False,
)
)
sfn_snapshot.add_transformer(
JsonpathTransformer(
jsonpath="$..output.StopDate",
replacement="stop-date",
replace_reference=False,
)
)

template_target = BT.load_sfn_template(BT.BASE_PASS_RESULT)
definition_target = json.dumps(template_target)
state_machine_arn_target = create(
create_iam_role_for_sfn,
create_state_machine,
sfn_snapshot,
definition_target,
)

template = CT.load_sfn_template(CT.SFN_START_EXECUTION_SYNC2)
definition = json.dumps(template)

exec_input = json.dumps(
{"StateMachineArn": state_machine_arn_target, "Input": None, "Name": "TestStartTarget"}
)
create_and_record_execution(
aws_client.stepfunctions,
create_iam_role_for_sfn,
create_state_machine,
sfn_snapshot,
definition,
exec_input,
)

@markers.aws.unknown
def test_start_execution_sync_delegate_failure(
self,
Expand Down
Loading