diff --git a/localstack-core/localstack/services/dynamodb/utils.py b/localstack-core/localstack/services/dynamodb/utils.py index 995458b2deed7..4ff065440abec 100644 --- a/localstack-core/localstack/services/dynamodb/utils.py +++ b/localstack-core/localstack/services/dynamodb/utils.py @@ -20,10 +20,18 @@ TableName, Update, ) +from localstack.aws.api.dynamodbstreams import ( + ResourceNotFoundException as DynamoDBStreamsResourceNotFoundException, +) from localstack.aws.connect import connect_to from localstack.constants import INTERNAL_AWS_SECRET_ACCESS_KEY from localstack.http import Response -from localstack.utils.aws.arns import dynamodb_table_arn, get_partition +from localstack.utils.aws.arns import ( + dynamodb_stream_arn, + dynamodb_table_arn, + get_partition, + parse_arn, +) from localstack.utils.json import canonical_json from localstack.utils.testutil import list_all_resources @@ -348,3 +356,32 @@ def _convert_arn(matchobj): # update x-amz-crc32 header required by some clients response.headers["x-amz-crc32"] = crc32(response.data) & 0xFFFFFFFF + + +def change_region_in_ddb_stream_arn(arn: str, region: str) -> str: + """ + Modify the ARN or a DynamoDB Stream by changing its region. + We need this logic when dealing with global tables, as we create a stream only in the originating region, and we + need to modify the ARN to mimic the stream of the replica regions. + """ + arn_data = parse_arn(arn) + if arn_data["region"] == region: + return arn + + if arn_data["service"] != "dynamodb": + raise Exception(f"{arn} is not a DynamoDB Streams ARN") + + # Note: a DynamoDB Streams ARN has the following pattern: + # arn:aws:dynamodb:::table//stream/ + resource_splits = arn_data["resource"].split("/") + if len(resource_splits) != 4: + raise DynamoDBStreamsResourceNotFoundException( + f"The format of the '{arn}' ARN is not valid" + ) + + return dynamodb_stream_arn( + table_name=resource_splits[1], + latest_stream_label=resource_splits[-1], + account_id=arn_data["account"], + region_name=region, + ) diff --git a/localstack-core/localstack/services/dynamodbstreams/dynamodbstreams_api.py b/localstack-core/localstack/services/dynamodbstreams/dynamodbstreams_api.py index 84079dbbf3d6f..e9164465fdd57 100644 --- a/localstack-core/localstack/services/dynamodbstreams/dynamodbstreams_api.py +++ b/localstack-core/localstack/services/dynamodbstreams/dynamodbstreams_api.py @@ -5,8 +5,10 @@ from bson.json_util import dumps from localstack import config +from localstack.aws.api import RequestContext from localstack.aws.api.dynamodbstreams import StreamStatus, StreamViewType, TableName from localstack.aws.connect import connect_to +from localstack.services.dynamodb.v2.provider import DynamoDBProvider from localstack.services.dynamodbstreams.models import DynamoDbStreamsStore, dynamodbstreams_stores from localstack.utils.aws import arns, resources from localstack.utils.common import now_utc @@ -211,3 +213,23 @@ def get_shard_id(stream: Dict, kinesis_shard_id: str) -> str: stream["shards_id_map"][kinesis_shard_id] = ddb_stream_shard_id return ddb_stream_shard_id + + +def get_original_region( + context: RequestContext, stream_arn: str | None = None, table_name: str | None = None +) -> str: + """ + In DDB Global tables, we forward all the requests to the original region, instead of really replicating the data. + Since each table has a separate stream associated, we need to have a similar forwarding logic for DDB Streams. + To determine the original region, we need the table name, that can be either provided here or determined from the + ARN of the stream. + """ + if not stream_arn and not table_name: + LOG.debug( + "No Stream ARN or table name provided. Returning region '%s' from the request", + context.region, + ) + return context.region + + table_name = table_name or table_name_from_stream_arn(stream_arn) + return DynamoDBProvider.get_global_table_region(context=context, table_name=table_name) diff --git a/localstack-core/localstack/services/dynamodbstreams/provider.py b/localstack-core/localstack/services/dynamodbstreams/provider.py index fc8d0050c4ea6..6c9548bb81ebf 100644 --- a/localstack-core/localstack/services/dynamodbstreams/provider.py +++ b/localstack-core/localstack/services/dynamodbstreams/provider.py @@ -24,10 +24,12 @@ TableName, ) from localstack.aws.connect import connect_to +from localstack.services.dynamodb.utils import change_region_in_ddb_stream_arn from localstack.services.dynamodbstreams.dynamodbstreams_api import ( get_dynamodbstreams_store, get_kinesis_client, get_kinesis_stream_name, + get_original_region, get_shard_id, kinesis_shard_id, stream_name_from_stream_arn, @@ -47,6 +49,13 @@ class DynamoDBStreamsProvider(DynamodbstreamsApi, ServiceLifecycleHook): + shard_to_region: dict[str, str] + """Map a shard iterator to the originating region. This is used in case of replica tables, as LocalStack keeps the + data in one region only, redirecting all the requests from replica regions.""" + + def __init__(self): + self.shard_to_region = {} + def describe_stream( self, context: RequestContext, @@ -55,13 +64,17 @@ def describe_stream( exclusive_start_shard_id: ShardId = None, **kwargs, ) -> DescribeStreamOutput: - store = get_dynamodbstreams_store(context.account_id, context.region) - kinesis = get_kinesis_client(account_id=context.account_id, region_name=context.region) + og_region = get_original_region(context=context, stream_arn=stream_arn) + store = get_dynamodbstreams_store(context.account_id, og_region) + kinesis = get_kinesis_client(account_id=context.account_id, region_name=og_region) for stream in store.ddb_streams.values(): - if stream["StreamArn"] == stream_arn: + _stream_arn = stream_arn + if context.region != og_region: + _stream_arn = change_region_in_ddb_stream_arn(_stream_arn, og_region) + if stream["StreamArn"] == _stream_arn: # get stream details dynamodb = connect_to( - aws_access_key_id=context.account_id, region_name=context.region + aws_access_key_id=context.account_id, region_name=og_region ).dynamodb table_name = table_name_from_stream_arn(stream["StreamArn"]) stream_name = get_kinesis_stream_name(table_name) @@ -90,6 +103,7 @@ def describe_stream( stream["Shards"] = stream_shards stream_description = select_from_typed_dict(StreamDescription, stream) + stream_description["StreamArn"] = _stream_arn return DescribeStreamOutput(StreamDescription=stream_description) raise ResourceNotFoundException( @@ -98,11 +112,17 @@ def describe_stream( @handler("GetRecords", expand=False) def get_records(self, context: RequestContext, payload: GetRecordsInput) -> GetRecordsOutput: - kinesis = get_kinesis_client(account_id=context.account_id, region_name=context.region) - prefix, _, payload["ShardIterator"] = payload["ShardIterator"].rpartition("|") + _shard_iterator = payload["ShardIterator"] + region_name = context.region + if payload["ShardIterator"] in self.shard_to_region: + region_name = self.shard_to_region[_shard_iterator] + + kinesis = get_kinesis_client(account_id=context.account_id, region_name=region_name) + prefix, _, payload["ShardIterator"] = _shard_iterator.rpartition("|") try: kinesis_records = kinesis.get_records(**payload) except kinesis.exceptions.ExpiredIteratorException: + self.shard_to_region.pop(_shard_iterator, None) LOG.debug("Shard iterator for underlying kinesis stream expired") raise ExpiredIteratorException("Shard iterator has expired") result = { @@ -113,6 +133,11 @@ def get_records(self, context: RequestContext, payload: GetRecordsInput) -> GetR record_data = loads(record["Data"]) record_data["dynamodb"]["SequenceNumber"] = record["SequenceNumber"] result["Records"].append(record_data) + + # Similar as the logic in GetShardIterator, we need to track the originating region when we get the + # NextShardIterator in the results. + if region_name != context.region and "NextShardIterator" in result: + self.shard_to_region[result["NextShardIterator"]] = region_name return GetRecordsOutput(**result) def get_shard_iterator( @@ -125,8 +150,9 @@ def get_shard_iterator( **kwargs, ) -> GetShardIteratorOutput: stream_name = stream_name_from_stream_arn(stream_arn) + og_region = get_original_region(context=context, stream_arn=stream_arn) stream_shard_id = kinesis_shard_id(shard_id) - kinesis = get_kinesis_client(account_id=context.account_id, region_name=context.region) + kinesis = get_kinesis_client(account_id=context.account_id, region_name=og_region) kwargs = {"StartingSequenceNumber": sequence_number} if sequence_number else {} result = kinesis.get_shard_iterator( @@ -138,6 +164,11 @@ def get_shard_iterator( del result["ResponseMetadata"] # TODO not quite clear what the |1| exactly denotes, because at AWS it's sometimes other numbers result["ShardIterator"] = f"{stream_arn}|1|{result['ShardIterator']}" + + # In case of a replica table, we need to keep track of the real region originating the shard iterator. + # This region will be later used in GetRecords to redirect to the originating region, holding the data. + if og_region != context.region: + self.shard_to_region[result["ShardIterator"]] = og_region return GetShardIteratorOutput(**result) def list_streams( @@ -148,8 +179,17 @@ def list_streams( exclusive_start_stream_arn: StreamArn = None, **kwargs, ) -> ListStreamsOutput: - store = get_dynamodbstreams_store(context.account_id, context.region) + og_region = get_original_region(context=context, table_name=table_name) + store = get_dynamodbstreams_store(context.account_id, og_region) result = [select_from_typed_dict(Stream, res) for res in store.ddb_streams.values()] if table_name: - result = [res for res in result if res["TableName"] == table_name] + result: list[Stream] = [res for res in result if res["TableName"] == table_name] + # If this is a stream from a table replica, we need to change the region in the stream ARN, as LocalStack + # keeps a stream only in the originating region. + if context.region != og_region: + for stream in result: + stream["StreamArn"] = change_region_in_ddb_stream_arn( + stream["StreamArn"], context.region + ) + return ListStreamsOutput(Streams=result) diff --git a/localstack-core/localstack/services/dynamodbstreams/v2/provider.py b/localstack-core/localstack/services/dynamodbstreams/v2/provider.py index 5f6a86150b315..a91fbc592a992 100644 --- a/localstack-core/localstack/services/dynamodbstreams/v2/provider.py +++ b/localstack-core/localstack/services/dynamodbstreams/v2/provider.py @@ -15,7 +15,8 @@ ) from localstack.services.dynamodb.server import DynamodbServer from localstack.services.dynamodb.utils import modify_ddblocal_arns -from localstack.services.dynamodb.v2.provider import DynamoDBProvider +from localstack.services.dynamodb.v2.provider import DynamoDBProvider, modify_context_region +from localstack.services.dynamodbstreams.dynamodbstreams_api import get_original_region from localstack.services.plugins import ServiceLifecycleHook from localstack.utils.aws.arns import parse_arn @@ -23,8 +24,13 @@ class DynamoDBStreamsProvider(DynamodbstreamsApi, ServiceLifecycleHook): + shard_to_region: dict[str, str] + """Map a shard iterator to the originating region. This is used in case of replica tables, as LocalStack keeps the + data in one region only, redirecting all the requests from replica regions.""" + def __init__(self): self.server = DynamodbServer.get() + self.shard_to_region = {} def on_after_init(self): # add response processor specific to ddblocal @@ -33,6 +39,20 @@ def on_after_init(self): def on_before_start(self): self.server.start_dynamodb() + def _forward_request( + self, context: RequestContext, region: str | None, service_request: ServiceRequest + ) -> ServiceResponse: + """ + Modify the context region and then forward request to DynamoDB Local. + + This is used for operations impacted by global tables. In LocalStack, a single copy of global table + is kept, and any requests to replicated tables are forwarded to this original table. + """ + if region: + with modify_context_region(context, region): + return self.forward_request(context, service_request=service_request) + return self.forward_request(context, service_request=service_request) + def forward_request( self, context: RequestContext, service_request: ServiceRequest = None ) -> ServiceResponse: @@ -55,9 +75,12 @@ def describe_stream( context: RequestContext, payload: DescribeStreamInput, ) -> DescribeStreamOutput: + global_table_region = get_original_region(context=context, stream_arn=payload["StreamArn"]) request = payload.copy() request["StreamArn"] = self.modify_stream_arn_for_ddb_local(request.get("StreamArn", "")) - return self.forward_request(context, request) + return self._forward_request( + context=context, service_request=request, region=global_table_region + ) @handler("GetRecords", expand=False) def get_records(self, context: RequestContext, payload: GetRecordsInput) -> GetRecordsOutput: @@ -65,17 +88,43 @@ def get_records(self, context: RequestContext, payload: GetRecordsInput) -> GetR request["ShardIterator"] = self.modify_stream_arn_for_ddb_local( request.get("ShardIterator", "") ) - return self.forward_request(context, request) + region = self.shard_to_region.pop(request["ShardIterator"], None) + response = self._forward_request(context=context, region=region, service_request=request) + # Similar as the logic in GetShardIterator, we need to track the originating region when we get the + # NextShardIterator in the results. + if ( + region + and region != context.region + and (next_shard := response.get("NextShardIterator")) + ): + self.shard_to_region[next_shard] = region + return response @handler("GetShardIterator", expand=False) def get_shard_iterator( self, context: RequestContext, payload: GetShardIteratorInput ) -> GetShardIteratorOutput: + global_table_region = get_original_region(context=context, stream_arn=payload["StreamArn"]) request = payload.copy() request["StreamArn"] = self.modify_stream_arn_for_ddb_local(request.get("StreamArn", "")) - return self.forward_request(context, request) + response = self._forward_request( + context=context, service_request=request, region=global_table_region + ) + + # In case of a replica table, we need to keep track of the real region originating the shard iterator. + # This region will be later used in GetRecords to redirect to the originating region, holding the data. + if global_table_region != context.region and ( + shard_iterator := response.get("ShardIterator") + ): + self.shard_to_region[shard_iterator] = global_table_region + return response @handler("ListStreams", expand=False) def list_streams(self, context: RequestContext, payload: ListStreamsInput) -> ListStreamsOutput: + global_table_region = get_original_region( + context=context, stream_arn=payload.get("TableName") + ) # TODO: look into `ExclusiveStartStreamArn` param - return self.forward_request(context, payload) + return self._forward_request( + context=context, service_request=payload, region=global_table_region + ) diff --git a/localstack-core/localstack/testing/pytest/fixtures.py b/localstack-core/localstack/testing/pytest/fixtures.py index b89d5aedf2a87..5c282ea8fcbc5 100644 --- a/localstack-core/localstack/testing/pytest/fixtures.py +++ b/localstack-core/localstack/testing/pytest/fixtures.py @@ -792,11 +792,10 @@ def is_stream_ready(): @pytest.fixture def wait_for_dynamodb_stream_ready(aws_client): - def _wait_for_stream_ready(stream_arn: str): + def _wait_for_stream_ready(stream_arn: str, client=None): def is_stream_ready(): - describe_stream_response = aws_client.dynamodbstreams.describe_stream( - StreamArn=stream_arn - ) + ddb_client = client or aws_client.dynamodbstreams + describe_stream_response = ddb_client.describe_stream(StreamArn=stream_arn) return describe_stream_response["StreamDescription"]["StreamStatus"] == "ENABLED" return poll_condition(is_stream_ready) diff --git a/localstack-core/localstack/testing/snapshots/transformer_utility.py b/localstack-core/localstack/testing/snapshots/transformer_utility.py index 7d2d73c844dbb..562cc9e097646 100644 --- a/localstack-core/localstack/testing/snapshots/transformer_utility.py +++ b/localstack-core/localstack/testing/snapshots/transformer_utility.py @@ -327,6 +327,9 @@ def dynamodb_api(): @staticmethod def dynamodb_streams_api(): return [ + RegexTransformer( + r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}$", replacement="" + ), TransformerUtility.key_value("TableName"), TransformerUtility.key_value("TableStatus"), TransformerUtility.key_value("LatestStreamLabel"), diff --git a/tests/aws/services/dynamodb/test_dynamodb.py b/tests/aws/services/dynamodb/test_dynamodb.py index c4a2efc227618..2c0ab3e50b42f 100644 --- a/tests/aws/services/dynamodb/test_dynamodb.py +++ b/tests/aws/services/dynamodb/test_dynamodb.py @@ -1138,17 +1138,23 @@ def test_global_tables_version_2019( assert "Replicas" not in response["Table"] @markers.aws.validated - @pytest.mark.skipif( - condition=not is_aws_cloud(), reason="Streams do not work on the regional replica" + # An ARN stream has a stream label as suffix. In AWS, such a label differs between the stream of the original table + # and the ones of the replicas. In LocalStack, it does not differ. The only difference in the stream ARNs is the + # region. Therefore, we skip the following paths from the snapshots. + # However, we run plain assertions to make sure that the region changes in the ARNs, i.e., the replica have their + # own stream. + @markers.snapshot.skip_snapshot_verify( + paths=["$..Streams..StreamArn", "$..Streams..StreamLabel"] ) def test_streams_on_global_tables( self, aws_client_factory, - dynamodb_wait_for_table_active, + wait_for_dynamodb_stream_ready, cleanups, snapshot, region_name, secondary_region_name, + dynamodbstreams_snapshot_transformers, ): """ This test exposes an issue in LocalStack with Global tables and streams. In AWS, each regional replica should @@ -1158,9 +1164,6 @@ def test_streams_on_global_tables( region_1_factory = aws_client_factory(region_name=region_name) region_2_factory = aws_client_factory(region_name=secondary_region_name) snapshot.add_transformer(snapshot.transform.regex(secondary_region_name, "")) - snapshot.add_transformer( - snapshot.transform.jsonpath("$..Streams..StreamLabel", "stream-label") - ) # Create table in the original region table_name = f"table-{short_uid()}" @@ -1193,11 +1196,96 @@ def test_streams_on_global_tables( waiter = region_2_factory.dynamodb.get_waiter("table_exists") waiter.wait(TableName=table_name, WaiterConfig={"Delay": WAIT_SEC, "MaxAttempts": 20}) - us_streams = region_1_factory.dynamodbstreams.list_streams(TableName=table_name) - snapshot.match("region-streams", us_streams) - # FIXME: LS doesn't have a stream on the replica region - eu_streams = region_2_factory.dynamodbstreams.list_streams(TableName=table_name) - snapshot.match("secondary-region-streams", eu_streams) + stream_arn_region = region_1_factory.dynamodb.describe_table(TableName=table_name)["Table"][ + "LatestStreamArn" + ] + assert region_name in stream_arn_region + wait_for_dynamodb_stream_ready(stream_arn_region) + stream_arn_secondary_region = region_2_factory.dynamodb.describe_table( + TableName=table_name + )["Table"]["LatestStreamArn"] + assert secondary_region_name in stream_arn_secondary_region + wait_for_dynamodb_stream_ready( + stream_arn_secondary_region, region_2_factory.dynamodbstreams + ) + + # Verify that we can list streams on both regions + streams_region_1 = region_1_factory.dynamodbstreams.list_streams(TableName=table_name) + snapshot.match("region-streams", streams_region_1) + assert region_name in streams_region_1["Streams"][0]["StreamArn"] + streams_region_2 = region_2_factory.dynamodbstreams.list_streams(TableName=table_name) + snapshot.match("secondary-region-streams", streams_region_2) + assert secondary_region_name in streams_region_2["Streams"][0]["StreamArn"] + + region_1_factory.dynamodb.batch_write_item( + RequestItems={ + table_name: [ + { + "PutRequest": { + "Item": { + "Artist": {"S": "The Queen"}, + "SongTitle": {"S": "Bohemian Rhapsody"}, + } + } + }, + { + "PutRequest": { + "Item": {"Artist": {"S": "Oasis"}, "SongTitle": {"S": "Live Forever"}} + } + }, + ] + } + ) + + def _read_records_from_shards(_stream_arn, _expected_record_count, _client) -> int: + describe_stream_result = _client.describe_stream(StreamArn=_stream_arn) + shard_id_to_iterator: dict[str, str] = {} + fetched_records = [] + # Records can be spread over multiple shards. We need to read all over them + for stream_info in describe_stream_result["StreamDescription"]["Shards"]: + _shard_id = stream_info["ShardId"] + shard_iterator = _client.get_shard_iterator( + StreamArn=_stream_arn, ShardId=_shard_id, ShardIteratorType="TRIM_HORIZON" + )["ShardIterator"] + shard_id_to_iterator[_shard_id] = shard_iterator + + while len(fetched_records) < _expected_record_count and shard_id_to_iterator: + for _shard_id, _shard_iterator in list(shard_id_to_iterator.items()): + _resp = _client.get_records(ShardIterator=_shard_iterator) + fetched_records.extend(_resp["Records"]) + if next_shard_iterator := _resp.get("NextShardIterator"): + shard_id_to_iterator[_shard_id] = next_shard_iterator + continue + shard_id_to_iterator.pop(_shard_id, None) + return fetched_records + + def _assert_records(_stream_arn, _expected_count, _client) -> None: + records = _read_records_from_shards( + _stream_arn, + _expected_count, + _client, + ) + assert len(records) == _expected_count, ( + f"Expected {_expected_count} records, got {len(records)}" + ) + + retry( + _assert_records, + sleep=WAIT_SEC, + retries=20, + _stream_arn=stream_arn_region, + _expected_count=2, + _client=region_1_factory.dynamodbstreams, + ) + + retry( + _assert_records, + sleep=WAIT_SEC, + retries=20, + _stream_arn=stream_arn_secondary_region, + _expected_count=2, + _client=region_2_factory.dynamodbstreams, + ) @markers.aws.only_localstack def test_global_tables(self, aws_client, ddb_test_table): diff --git a/tests/aws/services/dynamodb/test_dynamodb.snapshot.json b/tests/aws/services/dynamodb/test_dynamodb.snapshot.json index ad40bf18e7c05..4842ef3f2406b 100644 --- a/tests/aws/services/dynamodb/test_dynamodb.snapshot.json +++ b/tests/aws/services/dynamodb/test_dynamodb.snapshot.json @@ -1730,14 +1730,14 @@ } }, "tests/aws/services/dynamodb/test_dynamodb.py::TestDynamoDB::test_streams_on_global_tables": { - "recorded-date": "15-05-2025, 13:42:48", + "recorded-date": "22-05-2025, 12:44:58", "recorded-content": { "region-streams": { "Streams": [ { - "StreamArn": "arn::dynamodb::111111111111:table//stream/", + "StreamArn": "arn::dynamodb::111111111111:table//stream/", "StreamLabel": "", - "TableName": "" + "TableName": "" } ], "ResponseMetadata": { @@ -1748,9 +1748,9 @@ "secondary-region-streams": { "Streams": [ { - "StreamArn": "arn::dynamodb::111111111111:table//stream/", + "StreamArn": "arn::dynamodb::111111111111:table//stream/", "StreamLabel": "", - "TableName": "" + "TableName": "" } ], "ResponseMetadata": { diff --git a/tests/aws/services/dynamodb/test_dynamodb.validation.json b/tests/aws/services/dynamodb/test_dynamodb.validation.json index d56f13b218112..6a2220f1f2937 100644 --- a/tests/aws/services/dynamodb/test_dynamodb.validation.json +++ b/tests/aws/services/dynamodb/test_dynamodb.validation.json @@ -75,7 +75,7 @@ "last_validated_date": "2024-01-03T17:52:19+00:00" }, "tests/aws/services/dynamodb/test_dynamodb.py::TestDynamoDB::test_streams_on_global_tables": { - "last_validated_date": "2025-05-15T13:42:45+00:00" + "last_validated_date": "2025-05-22T12:44:55+00:00" }, "tests/aws/services/dynamodb/test_dynamodb.py::TestDynamoDB::test_transact_get_items": { "last_validated_date": "2023-08-23T14:33:37+00:00"