Thanks to visit codestin.com
Credit goes to github.com

Skip to content

SNS: validate cross-region behavior #12673

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 27 additions & 9 deletions localstack-core/localstack/services/sns/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def get_moto_backend(account_id: str, region_name: str) -> SNSBackend:
return sns_backends[account_id][region_name]

@staticmethod
def _get_topic(arn: str, context: RequestContext, multiregion: bool = True) -> Topic:
def _get_topic(arn: str, context: RequestContext) -> Topic:
"""
:param arn: the Topic ARN
:param context: the RequestContext of the request
Expand All @@ -145,13 +145,13 @@ def _get_topic(arn: str, context: RequestContext, multiregion: bool = True) -> T
:return: the Moto model Topic
"""
arn_data = parse_and_validate_topic_arn(arn)
if context.region != arn_data["region"]:
raise InvalidParameterException("Invalid parameter: TopicArn")

try:
return sns_backends[arn_data["account"]][context.region].topics[arn]
except KeyError:
if multiregion or context.region == arn_data["region"]:
raise NotFoundException("Topic does not exist")
else:
raise InvalidParameterException("Invalid parameter: TopicArn")
raise NotFoundException("Topic does not exist")

def get_topic_attributes(
self, context: RequestContext, topic_arn: topicARN, **kwargs
Expand Down Expand Up @@ -179,6 +179,18 @@ def get_topic_attributes(

return moto_response

def set_topic_attributes(
self,
context: RequestContext,
topic_arn: topicARN,
attribute_name: attributeName,
attribute_value: attributeValue | None = None,
**kwargs,
) -> None:
# validate the topic first
self._get_topic(topic_arn, context)
call_moto(context)

def publish_batch(
self,
context: RequestContext,
Expand All @@ -193,7 +205,7 @@ def publish_batch(

parsed_arn = parse_and_validate_topic_arn(topic_arn)
store = self.get_store(account_id=parsed_arn["account"], region_name=context.region)
moto_topic = self._get_topic(topic_arn, context, multiregion=False)
moto_topic = self._get_topic(topic_arn, context)

ids = [entry["Id"] for entry in publish_batch_request_entries]
if len(set(ids)) != len(publish_batch_request_entries):
Expand Down Expand Up @@ -561,7 +573,7 @@ def publish(
raise InvalidParameterException(
"Invalid parameter: The MessageGroupId parameter is required for FIFO topics",
)
topic_model = self._get_topic(topic_or_target_arn, context, multiregion=False)
topic_model = self._get_topic(topic_or_target_arn, context)
if topic_model.content_based_deduplication == "false":
if not message_deduplication_id:
raise InvalidParameterException(
Expand Down Expand Up @@ -608,7 +620,7 @@ def publish(
elif not platform_endpoint.enabled:
raise EndpointDisabledException("Endpoint is disabled")
else:
topic_model = self._get_topic(topic_or_target_arn, context, multiregion=False)
topic_model = self._get_topic(topic_or_target_arn, context)
else:
# use the store from the request context
store = self.get_store(account_id=context.account_id, region_name=context.region)
Expand Down Expand Up @@ -659,6 +671,9 @@ def subscribe(
) -> SubscribeResponse:
# TODO: check validation ordering
parsed_topic_arn = parse_and_validate_topic_arn(topic_arn)
if context.region != parsed_topic_arn["region"]:
raise InvalidParameterException("Invalid parameter: TopicArn")

store = self.get_store(account_id=parsed_topic_arn["account"], region_name=context.region)

if topic_arn not in store.topic_subscriptions:
Expand Down Expand Up @@ -834,8 +849,11 @@ def existing_tag_index(_item):
return TagResourceResponse()

def delete_topic(self, context: RequestContext, topic_arn: topicARN, **kwargs) -> None:
call_moto(context)
parsed_arn = parse_and_validate_topic_arn(topic_arn)
if context.region != parsed_arn["region"]:
raise InvalidParameterException("Invalid parameter: TopicArn")

call_moto(context)
store = self.get_store(account_id=parsed_arn["account"], region_name=context.region)
topic_subscriptions = store.topic_subscriptions.pop(topic_arn, [])
for topic_sub in topic_subscriptions:
Expand Down
146 changes: 143 additions & 3 deletions tests/aws/services/sns/test_sns.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,19 @@ def test_create_topic_test_arn(self, sns_create_topic, snapshot, aws_client, acc
aws_client.sns.get_topic_attributes(TopicArn=topic_arn)
snapshot.match("topic-not-exists", e.value.response)

@markers.aws.validated
def test_delete_topic_idempotency(self, sns_create_topic, aws_client, snapshot):
topic_arn = sns_create_topic()["TopicArn"]

response = aws_client.sns.delete_topic(TopicArn=topic_arn)
snapshot.match("delete-topic", response)

with pytest.raises(ClientError):
aws_client.sns.get_topic_attributes(TopicArn=topic_arn)

delete_topic = aws_client.sns.delete_topic(TopicArn=topic_arn)
snapshot.match("delete-topic-again", delete_topic)

@markers.aws.validated
def test_create_duplicate_topic_with_more_tags(self, sns_create_topic, snapshot, aws_client):
topic_name = "test-duplicated-topic-more-tags"
Expand Down Expand Up @@ -4270,7 +4283,7 @@ def sqs_secondary_client(self, secondary_aws_client):
return secondary_aws_client.sqs

@markers.aws.only_localstack
def test_cross_account_access(self, sns_primary_client, sns_secondary_client):
def test_cross_account_access(self, sns_primary_client, sns_secondary_client, sns_create_topic):
# Cross-account access is supported for below operations.
# This list is taken from ActionName param of the AddPermissions operation
#
Expand All @@ -4284,7 +4297,8 @@ def test_cross_account_access(self, sns_primary_client, sns_secondary_client):
# - DeleteTopic

topic_name = f"topic-{short_uid()}"
topic_arn = sns_primary_client.create_topic(Name=topic_name)["TopicArn"]
# sns_create_topic uses the primary client by default
topic_arn = sns_create_topic(Name=topic_name)["TopicArn"]

assert sns_secondary_client.set_topic_attributes(
TopicArn=topic_arn, AttributeName="DisplayName", AttributeValue="xenon"
Expand Down Expand Up @@ -4325,13 +4339,15 @@ def test_cross_account_access(self, sns_primary_client, sns_secondary_client):
@markers.aws.only_localstack
def test_cross_account_publish_to_sqs(
self,
sns_create_topic,
secondary_account_id,
region_name,
sns_primary_client,
sns_secondary_client,
sqs_primary_client,
sqs_secondary_client,
sqs_get_queue_arn,
cleanups,
):
"""
This test validates that we can publish to SQS queues that are not in the default account, and that another
Expand All @@ -4342,18 +4358,20 @@ def test_cross_account_publish_to_sqs(
"""

topic_name = "sample_topic"
topic_1 = sns_primary_client.create_topic(Name=topic_name)
topic_1 = sns_create_topic(Name=topic_name)
topic_1_arn = topic_1["TopicArn"]

# create a queue with the primary AccountId
queue_name = "sample_queue"
queue_1 = sqs_primary_client.create_queue(QueueName=queue_name)
queue_1_url = queue_1["QueueUrl"]
cleanups.append(lambda: sqs_primary_client.delete_queue(QueueUrl=queue_1_url))
queue_1_arn = sqs_get_queue_arn(queue_1_url)

# create a queue with the secondary AccountId
queue_2 = sqs_secondary_client.create_queue(QueueName=queue_name)
queue_2_url = queue_2["QueueUrl"]
cleanups.append(lambda: sqs_secondary_client.delete_queue(QueueUrl=queue_2_url))
# test that we get the right queue URL at the same time, even if we use the primary client
queue_2_arn = sqs_queue_arn(
queue_2_url,
Expand All @@ -4365,6 +4383,7 @@ def test_cross_account_publish_to_sqs(
queue_name_2 = "sample_queue_two"
queue_3 = sqs_secondary_client.create_queue(QueueName=queue_name_2)
queue_3_url = queue_3["QueueUrl"]
cleanups.append(lambda: sqs_secondary_client.delete_queue(QueueUrl=queue_3_url))
# test that we get the right queue URL at the same time, even if we use the primary client
queue_3_arn = sqs_queue_arn(
queue_3_url,
Expand Down Expand Up @@ -4427,6 +4446,127 @@ def get_messages_from_queues(message_content: str):
get_messages_from_queues("TestMessageSecondary")


class TestSNSMultiRegions:
@pytest.fixture
def sns_region1_client(self, aws_client):
return aws_client.sns

@pytest.fixture
def sns_region2_client(self, aws_client_factory, secondary_region_name):
return aws_client_factory(region_name=secondary_region_name).sns

@pytest.fixture
def sqs_region2_client(self, aws_client_factory, secondary_region_name):
return aws_client_factory(region_name=secondary_region_name).sqs

@markers.aws.validated
def test_cross_region_access(self, sns_region1_client, sns_region2_client, snapshot, cleanups):
# We do not have a list of supported Cross-region access for operations.
# This test is validating that Cross-account does not mean Cross-region most of the time

topic_name = f"topic-{short_uid()}"
topic_arn = sns_region1_client.create_topic(Name=topic_name)["TopicArn"]
cleanups.append(lambda: sns_region1_client.delete_topic(TopicArn=topic_arn))

with pytest.raises(ClientError) as e:
sns_region2_client.set_topic_attributes(
TopicArn=topic_arn, AttributeName="DisplayName", AttributeValue="xenon"
)
snapshot.match("set-topic-attrs", e.value.response)

with pytest.raises(ClientError) as e:
sns_region2_client.get_topic_attributes(TopicArn=topic_arn)
snapshot.match("get-topic-attrs", e.value.response)

with pytest.raises(ClientError) as e:
sns_region2_client.publish(TopicArn=topic_arn, Message="hello world")
snapshot.match("cross-region-publish-forbidden", e.value.response)

with pytest.raises(ClientError) as e:
sns_region2_client.subscribe(
TopicArn=topic_arn, Protocol="email", Endpoint="[email protected]"
)
Comment on lines +4486 to +4488
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice email 😄

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I copy pasted it from another test, I hadn't even realized 😂

snapshot.match("cross-region-subscribe", e.value.response)

with pytest.raises(ClientError) as e:
sns_region2_client.list_subscriptions_by_topic(TopicArn=topic_arn)
snapshot.match("list-subs", e.value.response)

with pytest.raises(ClientError) as e:
sns_region2_client.delete_topic(TopicArn=topic_arn)
snapshot.match("delete-topic", e.value.response)

@markers.aws.validated
def test_cross_region_delivery_sqs(
self,
sns_region1_client,
sns_region2_client,
sqs_region2_client,
sns_create_topic,
sqs_create_queue,
sns_allow_topic_sqs_queue,
cleanups,
snapshot,
):
topic_arn = sns_create_topic()["TopicArn"]

queue_url = sqs_create_queue()
response = sqs_region2_client.create_queue(QueueName=f"queue-{short_uid()}")
queue_url = response["QueueUrl"]
cleanups.append(lambda: sqs_region2_client.delete_queue(QueueUrl=queue_url))

queue_arn = sqs_region2_client.get_queue_attributes(
QueueUrl=queue_url, AttributeNames=["QueueArn"]
)["Attributes"]["QueueArn"]

# allow topic to write to sqs queue
sqs_region2_client.set_queue_attributes(
QueueUrl=queue_url,
Attributes={
"Policy": json.dumps(
{
"Statement": [
{
"Effect": "Allow",
"Principal": {"Service": "sns.amazonaws.com"},
"Action": "sqs:SendMessage",
"Resource": queue_arn,
"Condition": {"ArnEquals": {"aws:SourceArn": topic_arn}},
}
]
}
)
},
)

# connect sns topic to sqs
with pytest.raises(ClientError) as e:
sns_region2_client.subscribe(TopicArn=topic_arn, Protocol="sqs", Endpoint=queue_arn)
snapshot.match("subscribe-cross-region", e.value.response)

subscription = sns_region1_client.subscribe(
TopicArn=topic_arn, Protocol="sqs", Endpoint=queue_arn
)
snapshot.match("subscribe-same-region", subscription)

message = "This is a test message"
# we already test that publishing from another region is forbidden with `test_topic_publish_another_region`
sns_region1_client.publish(
TopicArn=topic_arn,
Message=message,
MessageAttributes={"attr1": {"DataType": "Number", "StringValue": "99.12"}},
)

# assert that message is received
response = sqs_region2_client.receive_message(
QueueUrl=queue_url,
VisibilityTimeout=0,
MessageAttributeNames=["All"],
WaitTimeSeconds=4,
)
snapshot.match("messages", response)


class TestSNSPublishDelivery:
@markers.aws.validated
@markers.snapshot.skip_snapshot_verify(
Expand Down
Loading
Loading