Thanks to visit codestin.com
Credit goes to github.com

Skip to content

[SFN] Support for Resource Tagging #8990

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from botocore.exceptions import ClientError

from localstack.aws.api.stepfunctions import HistoryEventType, TaskFailedEventDetails
from localstack.aws.connect import connect_externally_to
from localstack.services.stepfunctions.asl.component.common.error_name.custom_error_name import (
CustomErrorName,
)
Expand All @@ -16,7 +17,6 @@
from localstack.services.stepfunctions.asl.eval.environment import Environment
from localstack.services.stepfunctions.asl.eval.event.event_detail import EventDetails
from localstack.services.stepfunctions.asl.utils.encoding import to_json_str
from localstack.utils.aws import aws_stack
from localstack.utils.strings import camel_to_snake_case


Expand Down Expand Up @@ -65,7 +65,7 @@ def _eval_service_task(self, env: Environment, parameters: dict) -> None:
parameters["MessageBody"] = to_json_str(message_body)

api_action = camel_to_snake_case(self.resource.api_action)
sqs_client = aws_stack.connect_to_service("sqs", config=Config(parameter_validation=False))
sqs_client = connect_externally_to(config=Config(parameter_validation=False)).sqs
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch

response = getattr(sqs_client, api_action)(**parameters)
response.pop("ResponseMetadata", None)
env.stack.append(response)
41 changes: 41 additions & 0 deletions localstack/services/stepfunctions/backend/state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import abc
import json
from collections import OrderedDict
from datetime import datetime
from typing import Final, Optional

Expand All @@ -16,8 +17,11 @@
StateMachineStatus,
StateMachineType,
StateMachineVersionListItem,
Tag,
TagKeyList,
TagList,
TracingConfiguration,
ValidationException,
)
from localstack.utils.strings import long_uid

Expand Down Expand Up @@ -78,8 +82,44 @@ def itemise(self):


class StateMachineRevision(StateMachineInstance):
class TagManager:
_tags: Final[dict[str, Optional[str]]]

def __init__(self):
self._tags = OrderedDict()

@staticmethod
def _validate_key_value(key: str) -> None:
if not key:
raise ValidationException()

@staticmethod
def _validate_tag_value(value: str) -> None:
if value is None:
raise ValidationException()

def add_all(self, tags: TagList) -> None:
for tag in tags:
tag_key = tag["key"]
tag_value = tag["value"]
self._validate_key_value(key=tag_key)
self._validate_tag_value(value=tag_value)
self._tags[tag_key] = tag_value

def remove_all(self, keys: TagKeyList):
for key in keys:
self._validate_key_value(key=key)
self._tags.pop(key, None)

def to_tag_list(self) -> TagList:
tag_list = list()
for key, value in self._tags.items():
tag_list.append(Tag(key=key, value=value))
return tag_list

_next_version_number: int
versions: Final[dict[RevisionId, Arn]]
tag_manager: Final[TagManager]

def __init__(
self,
Expand All @@ -106,6 +146,7 @@ def __init__(
)
self.versions = dict()
self._version_number = 0
self.tag_manager = StateMachineRevision.TagManager()

def create_revision(
self, definition: Optional[str], role_arn: Optional[Arn]
Expand Down
46 changes: 46 additions & 0 deletions localstack/services/stepfunctions/provider_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
ListExecutionsPageToken,
ListStateMachinesOutput,
ListStateMachineVersionsOutput,
ListTagsForResourceOutput,
LoggingConfiguration,
LongArn,
MissingRequiredParameter,
Expand All @@ -37,6 +38,7 @@
PageToken,
Publish,
PublishStateMachineVersionOutput,
ResourceNotFound,
ReverseOrder,
RevisionId,
SendTaskFailureOutput,
Expand All @@ -52,11 +54,15 @@
StateMachineType,
StepfunctionsApi,
StopExecutionOutput,
TagKeyList,
TagList,
TagResourceOutput,
TaskDoesNotExist,
TaskTimedOut,
TaskToken,
TraceHeader,
TracingConfiguration,
UntagResourceOutput,
UpdateStateMachineOutput,
ValidationException,
VersionDescription,
Expand Down Expand Up @@ -199,6 +205,10 @@ def create_state_machine(
tracing_config=request.get("tracingConfiguration"),
)

tags = request.get("tags")
if tags:
state_machine.tag_manager.add_all(tags)

state_machines[arn] = state_machine

create_output = CreateStateMachineOutput(
Expand Down Expand Up @@ -538,3 +548,39 @@ def publish_state_machine_version(
creationDate=state_machine_version.create_date,
stateMachineVersionArn=state_machine_version.arn,
)

def tag_resource(
self, context: RequestContext, resource_arn: Arn, tags: TagList
) -> TagResourceOutput:
# TODO: add tagging for activities.
state_machines = self.get_store(context).state_machines
state_machine = state_machines.get(resource_arn)
if not isinstance(state_machine, StateMachineRevision):
raise ResourceNotFound(f"Resource not found: '{resource_arn}'")

state_machine.tag_manager.add_all(tags)
return TagResourceOutput()

def untag_resource(
self, context: RequestContext, resource_arn: Arn, tag_keys: TagKeyList
) -> UntagResourceOutput:
# TODO: add untagging for activities.
state_machines = self.get_store(context).state_machines
state_machine = state_machines.get(resource_arn)
if not isinstance(state_machine, StateMachineRevision):
raise ResourceNotFound(f"Resource not found: '{resource_arn}'")

state_machine.tag_manager.remove_all(tag_keys)
return UntagResourceOutput()

def list_tags_for_resource(
self, context: RequestContext, resource_arn: Arn
) -> ListTagsForResourceOutput:
# TODO: add untagging for activities.
state_machines = self.get_store(context).state_machines
state_machine = state_machines.get(resource_arn)
if not isinstance(state_machine, StateMachineRevision):
raise ResourceNotFound(f"Resource not found: '{resource_arn}'")

tags: TagList = state_machine.tag_manager.to_tag_list()
return ListTagsForResourceOutput(tags=tags)
190 changes: 190 additions & 0 deletions tests/aws/services/stepfunctions/v2/test_sfn_api_tagging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
import json

import pytest

from localstack.aws.api.stepfunctions import Tag
from localstack.testing.pytest import markers
from localstack.testing.snapshots.transformer import RegexTransformer
from localstack.utils.strings import short_uid
from tests.aws.services.stepfunctions.templates.base.base_templates import BaseTemplate
from tests.aws.services.stepfunctions.utils import is_old_provider

pytestmark = pytest.mark.skipif(
condition=is_old_provider(), reason="Test suite for v2 provider only."
)


@markers.snapshot.skip_snapshot_verify(
paths=["$..loggingConfiguration", "$..tracingConfiguration", "$..previousEventId"]
)
class TestSnfApiTagging:
@markers.aws.validated
@pytest.mark.parametrize(
"tag_list",
[
[],
[Tag(key="key1", value="value1")],
[Tag(key="key1", value="")],
[Tag(key="key1", value="value1"), Tag(key="key1", value="value1")],
[Tag(key="key1", value="value1"), Tag(key="key2", value="value2")],
],
)
def test_tag_state_machine(
self, create_iam_role_for_sfn, create_state_machine, sfn_snapshot, aws_client, tag_list
):
snf_role_arn = create_iam_role_for_sfn()
sfn_snapshot.add_transformer(RegexTransformer(snf_role_arn, "snf_role_arn"))

definition = BaseTemplate.load_sfn_template(BaseTemplate.BASE_PASS_RESULT)
definition_str = json.dumps(definition)

sm_name = f"statemachine_{short_uid()}"
creation_resp_1 = create_state_machine(
name=sm_name, definition=definition_str, roleArn=snf_role_arn
)
state_machine_arn = creation_resp_1["stateMachineArn"]
sfn_snapshot.add_transformer(sfn_snapshot.transform.sfn_sm_create_arn(creation_resp_1, 0))
sfn_snapshot.match("creation_resp_1", creation_resp_1)

tag_resource_resp = aws_client.stepfunctions.tag_resource(
resourceArn=state_machine_arn, tags=tag_list
)
sfn_snapshot.match("tag_resource_resp", tag_resource_resp)

list_resources_res = aws_client.stepfunctions.list_tags_for_resource(
resourceArn=state_machine_arn
)
sfn_snapshot.match("list_resources_res", list_resources_res)

@markers.aws.validated
@pytest.mark.parametrize(
"tag_list",
[
None,
[Tag(key="", value="value")],
[Tag(key=None, value="value")],
[Tag(key="key1", value=None)],
],
)
def test_tag_invalid_state_machine(
self, create_iam_role_for_sfn, create_state_machine, sfn_snapshot, aws_client, tag_list
):
snf_role_arn = create_iam_role_for_sfn()
sfn_snapshot.add_transformer(RegexTransformer(snf_role_arn, "snf_role_arn"))

definition = BaseTemplate.load_sfn_template(BaseTemplate.BASE_PASS_RESULT)
definition_str = json.dumps(definition)

sm_name = f"statemachine_{short_uid()}"
creation_resp_1 = create_state_machine(
name=sm_name, definition=definition_str, roleArn=snf_role_arn
)
state_machine_arn = creation_resp_1["stateMachineArn"]
sfn_snapshot.add_transformer(sfn_snapshot.transform.sfn_sm_create_arn(creation_resp_1, 0))
sfn_snapshot.match("creation_resp_1", creation_resp_1)

with pytest.raises(Exception) as error:
aws_client.stepfunctions.tag_resource(resourceArn=state_machine_arn, tags=tag_list)
sfn_snapshot.match("error", error.value)

@markers.aws.validated
def test_tag_state_machine_version(
self,
create_iam_role_for_sfn,
create_state_machine,
sfn_snapshot,
aws_client,
):
snf_role_arn = create_iam_role_for_sfn()
sfn_snapshot.add_transformer(RegexTransformer(snf_role_arn, "snf_role_arn"))

definition = BaseTemplate.load_sfn_template(BaseTemplate.BASE_PASS_RESULT)
definition_str = json.dumps(definition)

sm_name = f"statemachine_{short_uid()}"
creation_resp_1 = create_state_machine(
name=sm_name, definition=definition_str, roleArn=snf_role_arn
)
state_machine_arn = creation_resp_1["stateMachineArn"]
sfn_snapshot.add_transformer(sfn_snapshot.transform.sfn_sm_create_arn(creation_resp_1, 0))
sfn_snapshot.match("creation_resp_1", creation_resp_1)

publish_resp = aws_client.stepfunctions.publish_state_machine_version(
stateMachineArn=state_machine_arn
)
state_machine_version_arn = publish_resp["stateMachineVersionArn"]
sfn_snapshot.match("publish_resp", publish_resp)

with pytest.raises(Exception) as error:
aws_client.stepfunctions.tag_resource(
resourceArn=state_machine_version_arn, tags=[Tag(key="key1", value="value1")]
)
sfn_snapshot.match("error", error.value)

@markers.aws.validated
@pytest.mark.parametrize(
"tag_keys",
[
[],
["key1"],
["key1", "key1"],
["key1", "key2"],
],
)
def test_untag_state_machine(
self, create_iam_role_for_sfn, create_state_machine, sfn_snapshot, aws_client, tag_keys
):
snf_role_arn = create_iam_role_for_sfn()
sfn_snapshot.add_transformer(RegexTransformer(snf_role_arn, "snf_role_arn"))

definition = BaseTemplate.load_sfn_template(BaseTemplate.BASE_PASS_RESULT)
definition_str = json.dumps(definition)

sm_name = f"statemachine_{short_uid()}"
creation_resp_1 = create_state_machine(
name=sm_name, definition=definition_str, roleArn=snf_role_arn
)
state_machine_arn = creation_resp_1["stateMachineArn"]
sfn_snapshot.add_transformer(sfn_snapshot.transform.sfn_sm_create_arn(creation_resp_1, 0))
sfn_snapshot.match("creation_resp_1", creation_resp_1)

tag_resource_resp = aws_client.stepfunctions.tag_resource(
resourceArn=state_machine_arn, tags=[Tag(key="key1", value="value1")]
)
sfn_snapshot.match("tag_resource_resp", tag_resource_resp)

untag_resource_resp = aws_client.stepfunctions.untag_resource(
resourceArn=state_machine_arn, tagKeys=tag_keys
)
sfn_snapshot.match("untag_resource_resp", untag_resource_resp)

list_resources_res = aws_client.stepfunctions.list_tags_for_resource(
resourceArn=state_machine_arn
)
sfn_snapshot.match("list_resources_res", list_resources_res)

@markers.aws.validated
def test_create_state_machine(
self, create_iam_role_for_sfn, create_state_machine, sfn_snapshot, aws_client
):
snf_role_arn = create_iam_role_for_sfn()
sfn_snapshot.add_transformer(RegexTransformer(snf_role_arn, "snf_role_arn"))

definition = BaseTemplate.load_sfn_template(BaseTemplate.BASE_PASS_RESULT)
definition_str = json.dumps(definition)

sm_name = f"statemachine_{short_uid()}"
creation_resp_1 = create_state_machine(
name=sm_name,
definition=definition_str,
roleArn=snf_role_arn,
tags=[Tag(key="key1", value="value1"), Tag(key="key2", value="value2")],
)
state_machine_arn = creation_resp_1["stateMachineArn"]
sfn_snapshot.add_transformer(sfn_snapshot.transform.sfn_sm_create_arn(creation_resp_1, 0))
sfn_snapshot.match("creation_resp_1", creation_resp_1)

list_resources_res = aws_client.stepfunctions.list_tags_for_resource(
resourceArn=state_machine_arn
)
sfn_snapshot.match("list_resources_res", list_resources_res)
Loading