Thanks to visit codestin.com
Credit goes to github.com

Skip to content

fix SNS PublishBatch modifying batch context for all subscribers after filtering #7674

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions localstack/services/sns/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,8 @@ def get_subscription_attributes(
removed_attrs = ["sqs_queue_url"]
if "FilterPolicyScope" in sub and "FilterPolicy" not in sub:
removed_attrs.append("FilterPolicyScope")
elif "FilterPolicy" in sub and "FilterPolicyScope" not in sub:
sub["FilterPolicyScope"] = "MessageAttributes"

attributes = {k: v for k, v in sub.items() if k not in removed_attrs}
return GetSubscriptionAttributesResponse(Attributes=attributes)
Expand Down
18 changes: 13 additions & 5 deletions localstack/services/sns/publisher.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import abc
import ast
import base64
import copy
import datetime
import hashlib
import json
Expand Down Expand Up @@ -1101,28 +1102,33 @@ def publish_batch_to_topic(self, ctx: SnsBatchPublishContext, topic_arn: str) ->
notifier = self.batch_topic_notifiers.get(protocol)
# does the notifier supports batching natively? for now, only SQS supports it
if notifier:
subscriber_ctx = ctx
messages_amount_before_filtering = len(ctx.messages)
ctx.messages = [
filtered_messages = [
message
for message in ctx.messages
if self._should_publish(ctx.store, message, subscriber)
]
if not ctx.messages:
if not filtered_messages:
LOG.debug(
"No messages match filter policy, not publishing batch from topic '%s' to subscription '%s'",
topic_arn,
subscriber["SubscriptionArn"],
)
return
continue

messages_amount = len(ctx.messages)
messages_amount = len(filtered_messages)
if messages_amount != messages_amount_before_filtering:
LOG.debug(
"After applying subscription filter, %s out of %s message(s) to be sent to '%s'",
messages_amount,
messages_amount_before_filtering,
subscriber["SubscriptionArn"],
)
# We need to copy the context to not overwrite the messages after filtering messages, otherwise we
# would filter on the same context for different subscribers
subscriber_ctx = copy.copy(ctx)
subscriber_ctx.messages = filtered_messages

LOG.debug(
"Topic '%s' batch publishing %s messages to subscribed '%s' with protocol '%s' (subscription '%s')",
Expand All @@ -1132,7 +1138,9 @@ def publish_batch_to_topic(self, ctx: SnsBatchPublishContext, topic_arn: str) ->
subscriber["Protocol"],
subscriber["SubscriptionArn"],
)
self.executor.submit(notifier.publish, context=ctx, subscriber=subscriber)
self.executor.submit(
notifier.publish, context=subscriber_ctx, subscriber=subscriber
)
else:
# if no batch support, fall back to sending them sequentially
notifier = self.topic_notifiers[subscriber["Protocol"]]
Expand Down
114 changes: 114 additions & 0 deletions tests/integration/test_sns.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ def test_attribute_raw_subscribe(
snapshot.match("messages-response", response)

@pytest.mark.aws_validated
@pytest.mark.skip_snapshot_verify(paths=["$..Attributes.SubscriptionPrincipal"])
def test_filter_policy(
self,
sns_client,
Expand Down Expand Up @@ -308,6 +309,7 @@ def test_filter_policy(
assert num_msgs_2 == num_msgs_1

@pytest.mark.aws_validated
@pytest.mark.skip_snapshot_verify(paths=["$..Attributes.SubscriptionPrincipal"])
def test_exists_filter_policy(
self,
sns_client,
Expand Down Expand Up @@ -432,6 +434,7 @@ def get_filter_policy():
assert num_msgs_4 == num_msgs_3

@pytest.mark.aws_validated
@pytest.mark.skip_snapshot_verify(paths=["$..Attributes.SubscriptionPrincipal"])
def test_subscribe_sqs_queue(
self,
sns_client,
Expand Down Expand Up @@ -3263,3 +3266,114 @@ def test_publish_to_fifo_with_target_arn(self, sns_client, sns_create_topic):
MessageGroupId="123",
)
assert "MessageId" in response

@pytest.mark.aws_validated
@pytest.mark.skip_snapshot_verify(paths=["$..Attributes.SubscriptionPrincipal"])
def test_filter_policy_for_batch(
self,
sns_client,
sqs_client,
sqs_create_queue,
sns_create_topic,
sns_create_sqs_subscription,
snapshot,
):

topic_arn = sns_create_topic()["TopicArn"]
queue_url_with_filter = sqs_create_queue()
subscription_with_filter = sns_create_sqs_subscription(
topic_arn=topic_arn, queue_url=queue_url_with_filter
)
subscription_with_filter_arn = subscription_with_filter["SubscriptionArn"]

queue_url_no_filter = sqs_create_queue()
subscription_no_filter = sns_create_sqs_subscription(
topic_arn=topic_arn, queue_url=queue_url_no_filter
)
subscription_no_filter_arn = subscription_no_filter["SubscriptionArn"]

filter_policy = {"attr1": [{"numeric": [">", 0, "<=", 100]}]}
sns_client.set_subscription_attributes(
SubscriptionArn=subscription_with_filter_arn,
AttributeName="FilterPolicy",
AttributeValue=json.dumps(filter_policy),
)

response_attributes = sns_client.get_subscription_attributes(
SubscriptionArn=subscription_with_filter_arn
)
snapshot.match("subscription-attributes-with-filter", response_attributes)

response_attributes = sns_client.get_subscription_attributes(
SubscriptionArn=subscription_no_filter_arn
)
snapshot.match("subscription-attributes-no-filter", response_attributes)

sqs_wait_time = 4 if is_aws_cloud() else 1

response_before_publish_no_filter = sqs_client.receive_message(
QueueUrl=queue_url_with_filter, VisibilityTimeout=0, WaitTimeSeconds=sqs_wait_time
)
snapshot.match("messages-no-filter-before-publish", response_before_publish_no_filter)

response_before_publish_filter = sqs_client.receive_message(
QueueUrl=queue_url_with_filter, VisibilityTimeout=0, WaitTimeSeconds=sqs_wait_time
)
snapshot.match("messages-with-filter-before-publish", response_before_publish_filter)

# publish message that satisfies the filter policy, assert that message is received
message = "This is a test message"
message_attributes = {"attr1": {"DataType": "Number", "StringValue": "99"}}
sns_client.publish_batch(
TopicArn=topic_arn,
PublishBatchRequestEntries=[
{
"Id": "1",
"Message": message,
"MessageAttributes": message_attributes,
}
],
)

response_after_publish_no_filter = sqs_client.receive_message(
QueueUrl=queue_url_no_filter, VisibilityTimeout=0, WaitTimeSeconds=sqs_wait_time
)
snapshot.match("messages-no-filter-after-publish-ok", response_after_publish_no_filter)
sqs_client.delete_message(
QueueUrl=queue_url_no_filter,
ReceiptHandle=response_after_publish_no_filter["Messages"][0]["ReceiptHandle"],
)

response_after_publish_filter = sqs_client.receive_message(
QueueUrl=queue_url_with_filter, VisibilityTimeout=0, WaitTimeSeconds=sqs_wait_time
)
snapshot.match("messages-with-filter-after-publish-ok", response_after_publish_filter)
sqs_client.delete_message(
QueueUrl=queue_url_with_filter,
ReceiptHandle=response_after_publish_filter["Messages"][0]["ReceiptHandle"],
)

# publish message that does not satisfy the filter policy, assert that message is not received by the
# subscription with the filter and received by the other
sns_client.publish_batch(
TopicArn=topic_arn,
PublishBatchRequestEntries=[
{
"Id": "1",
"Message": "This is another test message",
"MessageAttributes": {"attr1": {"DataType": "Number", "StringValue": "111"}},
}
],
)

response_after_publish_no_filter = sqs_client.receive_message(
QueueUrl=queue_url_no_filter, VisibilityTimeout=0, WaitTimeSeconds=sqs_wait_time
)
# there should be 1 message in the queue, latest sent
snapshot.match("messages-no-filter-after-publish-ok-1", response_after_publish_no_filter)

response_after_publish_filter = sqs_client.receive_message(
QueueUrl=queue_url_with_filter, VisibilityTimeout=0, WaitTimeSeconds=sqs_wait_time
)
# there should be no messages in this queue
snapshot.match("messages-with-filter-after-publish-filtered", response_after_publish_filter)
Loading