Thanks to visit codestin.com
Credit goes to github.com

Skip to content

DDB Global table: add logic for streams #12641

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
May 27, 2025
39 changes: 38 additions & 1 deletion localstack-core/localstack/services/dynamodb/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,18 @@
TableName,
Update,
)
from localstack.aws.api.dynamodbstreams import (
ResourceNotFoundException as DynamoDBStreamsResourceNotFoundException,
)
from localstack.aws.connect import connect_to
from localstack.constants import INTERNAL_AWS_SECRET_ACCESS_KEY
from localstack.http import Response
from localstack.utils.aws.arns import dynamodb_table_arn, get_partition
from localstack.utils.aws.arns import (
dynamodb_stream_arn,
dynamodb_table_arn,
get_partition,
parse_arn,
)
from localstack.utils.json import canonical_json
from localstack.utils.testutil import list_all_resources

Expand Down Expand Up @@ -348,3 +356,32 @@ def _convert_arn(matchobj):

# update x-amz-crc32 header required by some clients
response.headers["x-amz-crc32"] = crc32(response.data) & 0xFFFFFFFF


def change_region_in_ddb_stream_arn(arn: str, region: str) -> str:
"""
Modify the ARN or a DynamoDB Stream by changing its region.
We need this logic when dealing with global tables, as we create a stream only in the originating region, and we
need to modify the ARN to mimic the stream of the replica regions.
"""
arn_data = parse_arn(arn)
if arn_data["region"] == region:
return arn

if arn_data["service"] != "dynamodb":
raise Exception(f"{arn} is not a DynamoDB Streams ARN")

# Note: a DynamoDB Streams ARN has the following pattern:
# arn:aws:dynamodb:<region>:<account>:table/<table_name>/stream/<latest_stream_label>
resource_splits = arn_data["resource"].split("/")
if len(resource_splits) != 4:
raise DynamoDBStreamsResourceNotFoundException(
f"The format of the '{arn}' ARN is not valid"
)

return dynamodb_stream_arn(
table_name=resource_splits[1],
latest_stream_label=resource_splits[-1],
account_id=arn_data["account"],
region_name=region,
)
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
from bson.json_util import dumps

from localstack import config
from localstack.aws.api import RequestContext
from localstack.aws.api.dynamodbstreams import StreamStatus, StreamViewType, TableName
from localstack.aws.connect import connect_to
from localstack.services.dynamodb.v2.provider import DynamoDBProvider
from localstack.services.dynamodbstreams.models import DynamoDbStreamsStore, dynamodbstreams_stores
from localstack.utils.aws import arns, resources
from localstack.utils.common import now_utc
Expand Down Expand Up @@ -211,3 +213,23 @@ def get_shard_id(stream: Dict, kinesis_shard_id: str) -> str:
stream["shards_id_map"][kinesis_shard_id] = ddb_stream_shard_id

return ddb_stream_shard_id


def get_original_region(
context: RequestContext, stream_arn: str | None = None, table_name: str | None = None
) -> str:
"""
In DDB Global tables, we forward all the requests to the original region, instead of really replicating the data.
Since each table has a separate stream associated, we need to have a similar forwarding logic for DDB Streams.
To determine the original region, we need the table name, that can be either provided here or determined from the
ARN of the stream.
"""
if not stream_arn and not table_name:
LOG.debug(
"No Stream ARN or table name provided. Returning region '%s' from the request",
context.region,
)
return context.region

table_name = table_name or table_name_from_stream_arn(stream_arn)
return DynamoDBProvider.get_global_table_region(context=context, table_name=table_name)
58 changes: 49 additions & 9 deletions localstack-core/localstack/services/dynamodbstreams/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@
TableName,
)
from localstack.aws.connect import connect_to
from localstack.services.dynamodb.utils import change_region_in_ddb_stream_arn
from localstack.services.dynamodbstreams.dynamodbstreams_api import (
get_dynamodbstreams_store,
get_kinesis_client,
get_kinesis_stream_name,
get_original_region,
get_shard_id,
kinesis_shard_id,
stream_name_from_stream_arn,
Expand All @@ -47,6 +49,13 @@


class DynamoDBStreamsProvider(DynamodbstreamsApi, ServiceLifecycleHook):
shard_to_region: dict[str, str]
"""Map a shard iterator to the originating region. This is used in case of replica tables, as LocalStack keeps the
data in one region only, redirecting all the requests from replica regions."""

def __init__(self):
self.shard_to_region = {}

def describe_stream(
self,
context: RequestContext,
Expand All @@ -55,13 +64,17 @@ def describe_stream(
exclusive_start_shard_id: ShardId = None,
**kwargs,
) -> DescribeStreamOutput:
store = get_dynamodbstreams_store(context.account_id, context.region)
kinesis = get_kinesis_client(account_id=context.account_id, region_name=context.region)
og_region = get_original_region(context=context, stream_arn=stream_arn)
store = get_dynamodbstreams_store(context.account_id, og_region)
kinesis = get_kinesis_client(account_id=context.account_id, region_name=og_region)
for stream in store.ddb_streams.values():
if stream["StreamArn"] == stream_arn:
_stream_arn = stream_arn
if context.region != og_region:
_stream_arn = change_region_in_ddb_stream_arn(_stream_arn, og_region)
if stream["StreamArn"] == _stream_arn:
# get stream details
dynamodb = connect_to(
aws_access_key_id=context.account_id, region_name=context.region
aws_access_key_id=context.account_id, region_name=og_region
).dynamodb
table_name = table_name_from_stream_arn(stream["StreamArn"])
stream_name = get_kinesis_stream_name(table_name)
Expand Down Expand Up @@ -90,6 +103,7 @@ def describe_stream(

stream["Shards"] = stream_shards
stream_description = select_from_typed_dict(StreamDescription, stream)
stream_description["StreamArn"] = _stream_arn
return DescribeStreamOutput(StreamDescription=stream_description)

raise ResourceNotFoundException(
Expand All @@ -98,11 +112,17 @@ def describe_stream(

@handler("GetRecords", expand=False)
def get_records(self, context: RequestContext, payload: GetRecordsInput) -> GetRecordsOutput:
kinesis = get_kinesis_client(account_id=context.account_id, region_name=context.region)
prefix, _, payload["ShardIterator"] = payload["ShardIterator"].rpartition("|")
_shard_iterator = payload["ShardIterator"]
region_name = context.region
if payload["ShardIterator"] in self.shard_to_region:
region_name = self.shard_to_region[_shard_iterator]

kinesis = get_kinesis_client(account_id=context.account_id, region_name=region_name)
prefix, _, payload["ShardIterator"] = _shard_iterator.rpartition("|")
try:
kinesis_records = kinesis.get_records(**payload)
except kinesis.exceptions.ExpiredIteratorException:
self.shard_to_region.pop(_shard_iterator, None)
LOG.debug("Shard iterator for underlying kinesis stream expired")
raise ExpiredIteratorException("Shard iterator has expired")
result = {
Expand All @@ -113,6 +133,11 @@ def get_records(self, context: RequestContext, payload: GetRecordsInput) -> GetR
record_data = loads(record["Data"])
record_data["dynamodb"]["SequenceNumber"] = record["SequenceNumber"]
result["Records"].append(record_data)

# Similar as the logic in GetShardIterator, we need to track the originating region when we get the
# NextShardIterator in the results.
if region_name != context.region and "NextShardIterator" in result:
self.shard_to_region[result["NextShardIterator"]] = region_name
return GetRecordsOutput(**result)

def get_shard_iterator(
Expand All @@ -125,8 +150,9 @@ def get_shard_iterator(
**kwargs,
) -> GetShardIteratorOutput:
stream_name = stream_name_from_stream_arn(stream_arn)
og_region = get_original_region(context=context, stream_arn=stream_arn)
stream_shard_id = kinesis_shard_id(shard_id)
kinesis = get_kinesis_client(account_id=context.account_id, region_name=context.region)
kinesis = get_kinesis_client(account_id=context.account_id, region_name=og_region)

kwargs = {"StartingSequenceNumber": sequence_number} if sequence_number else {}
result = kinesis.get_shard_iterator(
Expand All @@ -138,6 +164,11 @@ def get_shard_iterator(
del result["ResponseMetadata"]
# TODO not quite clear what the |1| exactly denotes, because at AWS it's sometimes other numbers
result["ShardIterator"] = f"{stream_arn}|1|{result['ShardIterator']}"

# In case of a replica table, we need to keep track of the real region originating the shard iterator.
# This region will be later used in GetRecords to redirect to the originating region, holding the data.
if og_region != context.region:
self.shard_to_region[result["ShardIterator"]] = og_region
return GetShardIteratorOutput(**result)

def list_streams(
Expand All @@ -148,8 +179,17 @@ def list_streams(
exclusive_start_stream_arn: StreamArn = None,
**kwargs,
) -> ListStreamsOutput:
store = get_dynamodbstreams_store(context.account_id, context.region)
og_region = get_original_region(context=context, table_name=table_name)
store = get_dynamodbstreams_store(context.account_id, og_region)
result = [select_from_typed_dict(Stream, res) for res in store.ddb_streams.values()]
if table_name:
result = [res for res in result if res["TableName"] == table_name]
result: list[Stream] = [res for res in result if res["TableName"] == table_name]
# If this is a stream from a table replica, we need to change the region in the stream ARN, as LocalStack
# keeps a stream only in the originating region.
if context.region != og_region:
for stream in result:
stream["StreamArn"] = change_region_in_ddb_stream_arn(
stream["StreamArn"], context.region
)

return ListStreamsOutput(Streams=result)
59 changes: 54 additions & 5 deletions localstack-core/localstack/services/dynamodbstreams/v2/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,22 @@
)
from localstack.services.dynamodb.server import DynamodbServer
from localstack.services.dynamodb.utils import modify_ddblocal_arns
from localstack.services.dynamodb.v2.provider import DynamoDBProvider
from localstack.services.dynamodb.v2.provider import DynamoDBProvider, modify_context_region
from localstack.services.dynamodbstreams.dynamodbstreams_api import get_original_region
from localstack.services.plugins import ServiceLifecycleHook
from localstack.utils.aws.arns import parse_arn

LOG = logging.getLogger(__name__)


class DynamoDBStreamsProvider(DynamodbstreamsApi, ServiceLifecycleHook):
shard_to_region: dict[str, str]
"""Map a shard iterator to the originating region. This is used in case of replica tables, as LocalStack keeps the
data in one region only, redirecting all the requests from replica regions."""

def __init__(self):
self.server = DynamodbServer.get()
self.shard_to_region = {}

def on_after_init(self):
# add response processor specific to ddblocal
Expand All @@ -33,6 +39,20 @@ def on_after_init(self):
def on_before_start(self):
self.server.start_dynamodb()

def _forward_request(
self, context: RequestContext, region: str | None, service_request: ServiceRequest
) -> ServiceResponse:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to have a docstring explaining how this is different from self.forward_request()

"""
Modify the context region and then forward request to DynamoDB Local.

This is used for operations impacted by global tables. In LocalStack, a single copy of global table
is kept, and any requests to replicated tables are forwarded to this original table.
"""
if region:
with modify_context_region(context, region):
return self.forward_request(context, service_request=service_request)
return self.forward_request(context, service_request=service_request)

def forward_request(
self, context: RequestContext, service_request: ServiceRequest = None
) -> ServiceResponse:
Expand All @@ -55,27 +75,56 @@ def describe_stream(
context: RequestContext,
payload: DescribeStreamInput,
) -> DescribeStreamOutput:
global_table_region = get_original_region(context=context, stream_arn=payload["StreamArn"])
request = payload.copy()
request["StreamArn"] = self.modify_stream_arn_for_ddb_local(request.get("StreamArn", ""))
return self.forward_request(context, request)
return self._forward_request(
context=context, service_request=request, region=global_table_region
)

@handler("GetRecords", expand=False)
def get_records(self, context: RequestContext, payload: GetRecordsInput) -> GetRecordsOutput:
request = payload.copy()
request["ShardIterator"] = self.modify_stream_arn_for_ddb_local(
request.get("ShardIterator", "")
)
return self.forward_request(context, request)
region = self.shard_to_region.pop(request["ShardIterator"], None)
response = self._forward_request(context=context, region=region, service_request=request)
# Similar as the logic in GetShardIterator, we need to track the originating region when we get the
# NextShardIterator in the results.
if (
region
and region != context.region
and (next_shard := response.get("NextShardIterator"))
):
self.shard_to_region[next_shard] = region
return response

@handler("GetShardIterator", expand=False)
def get_shard_iterator(
self, context: RequestContext, payload: GetShardIteratorInput
) -> GetShardIteratorOutput:
global_table_region = get_original_region(context=context, stream_arn=payload["StreamArn"])
request = payload.copy()
request["StreamArn"] = self.modify_stream_arn_for_ddb_local(request.get("StreamArn", ""))
return self.forward_request(context, request)
response = self._forward_request(
context=context, service_request=request, region=global_table_region
)

# In case of a replica table, we need to keep track of the real region originating the shard iterator.
# This region will be later used in GetRecords to redirect to the originating region, holding the data.
if global_table_region != context.region and (
shard_iterator := response.get("ShardIterator")
):
self.shard_to_region[shard_iterator] = global_table_region
return response

@handler("ListStreams", expand=False)
def list_streams(self, context: RequestContext, payload: ListStreamsInput) -> ListStreamsOutput:
global_table_region = get_original_region(
context=context, stream_arn=payload.get("TableName")
)
# TODO: look into `ExclusiveStartStreamArn` param
return self.forward_request(context, payload)
return self._forward_request(
context=context, service_request=payload, region=global_table_region
)
7 changes: 3 additions & 4 deletions localstack-core/localstack/testing/pytest/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,11 +792,10 @@ def is_stream_ready():

@pytest.fixture
def wait_for_dynamodb_stream_ready(aws_client):
def _wait_for_stream_ready(stream_arn: str):
def _wait_for_stream_ready(stream_arn: str, client=None):
def is_stream_ready():
describe_stream_response = aws_client.dynamodbstreams.describe_stream(
StreamArn=stream_arn
)
ddb_client = client or aws_client.dynamodbstreams
describe_stream_response = ddb_client.describe_stream(StreamArn=stream_arn)
return describe_stream_response["StreamDescription"]["StreamStatus"] == "ENABLED"

return poll_condition(is_stream_ready)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,9 @@ def dynamodb_api():
@staticmethod
def dynamodb_streams_api():
return [
RegexTransformer(
r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}$", replacement="<stream-label>"
),
TransformerUtility.key_value("TableName"),
TransformerUtility.key_value("TableStatus"),
TransformerUtility.key_value("LatestStreamLabel"),
Expand Down
Loading
Loading