From 49ee68aac844a236f2b49237a6387ae271cc1a3b Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Thu, 26 Jun 2025 17:13:10 +0100 Subject: [PATCH] Fix decoding headers signed numbers --- aws_lambda_powertools/shared/functions.py | 16 ++++++++++++++ .../utilities/data_classes/kafka_event.py | 3 ++- .../utilities/kafka/consumer_records.py | 5 ++++- .../utilities/parser/models/kafka.py | 10 ++++----- tests/events/kafkaEventMsk.json | 22 +++++++++++++++++++ .../required_dependencies/test_kafka_event.py | 7 +++--- tests/unit/parser/_pydantic/test_kafka.py | 6 ++--- 7 files changed, 55 insertions(+), 14 deletions(-) diff --git a/aws_lambda_powertools/shared/functions.py b/aws_lambda_powertools/shared/functions.py index 2d92af54360..16f51da1cb9 100644 --- a/aws_lambda_powertools/shared/functions.py +++ b/aws_lambda_powertools/shared/functions.py @@ -291,3 +291,19 @@ def sanitize_xray_segment_name(name: str) -> str: def get_tracer_id() -> str | None: xray_trace_id = os.getenv(constants.XRAY_TRACE_ID_ENV) return xray_trace_id.split(";")[0].replace("Root=", "") if xray_trace_id else None + + +def decode_header_bytes(byte_list): + """ + Decode a list of byte values that might be signed. + If any negative values exist, handle them as signed bytes. + Otherwise use the normal bytes construction. + """ + has_negative = any(b < 0 for b in byte_list) + + if not has_negative: + # Use normal bytes construction if all values are positive + return bytes(byte_list) + # Convert signed bytes to unsigned (0-255 range) + unsigned_bytes = [(b & 0xFF) for b in byte_list] + return bytes(unsigned_bytes) diff --git a/aws_lambda_powertools/utilities/data_classes/kafka_event.py b/aws_lambda_powertools/utilities/data_classes/kafka_event.py index 094bd4bed6f..53d23530cec 100644 --- a/aws_lambda_powertools/utilities/data_classes/kafka_event.py +++ b/aws_lambda_powertools/utilities/data_classes/kafka_event.py @@ -4,6 +4,7 @@ from functools import cached_property from typing import TYPE_CHECKING, Any +from aws_lambda_powertools.shared.functions import decode_header_bytes from aws_lambda_powertools.utilities.data_classes.common import CaseInsensitiveDict, DictWrapper if TYPE_CHECKING: @@ -110,7 +111,7 @@ def headers(self) -> list[dict[str, list[int]]]: @cached_property def decoded_headers(self) -> dict[str, bytes]: """Decodes the headers as a single dictionary.""" - return CaseInsensitiveDict((k, bytes(v)) for chunk in self.headers for k, v in chunk.items()) + return CaseInsensitiveDict((k, decode_header_bytes(v)) for chunk in self.headers for k, v in chunk.items()) class KafkaEventBase(DictWrapper): diff --git a/aws_lambda_powertools/utilities/kafka/consumer_records.py b/aws_lambda_powertools/utilities/kafka/consumer_records.py index 6da8f9fa1fa..1fa6afba15c 100644 --- a/aws_lambda_powertools/utilities/kafka/consumer_records.py +++ b/aws_lambda_powertools/utilities/kafka/consumer_records.py @@ -4,6 +4,7 @@ from functools import cached_property from typing import TYPE_CHECKING, Any +from aws_lambda_powertools.shared.functions import decode_header_bytes from aws_lambda_powertools.utilities.data_classes.common import CaseInsensitiveDict from aws_lambda_powertools.utilities.data_classes.kafka_event import KafkaEventBase, KafkaEventRecordBase from aws_lambda_powertools.utilities.kafka.deserializer.deserializer import get_deserializer @@ -115,7 +116,9 @@ def original_headers(self) -> list[dict[str, list[int]]]: @cached_property def headers(self) -> dict[str, bytes]: """Decodes the headers as a single dictionary.""" - return 
CaseInsensitiveDict((k, bytes(v)) for chunk in self.original_headers for k, v in chunk.items()) + return CaseInsensitiveDict( + (k, decode_header_bytes(v)) for chunk in self.original_headers for k, v in chunk.items() + ) class ConsumerRecords(KafkaEventBase): diff --git a/aws_lambda_powertools/utilities/parser/models/kafka.py b/aws_lambda_powertools/utilities/parser/models/kafka.py index 717d47ff26c..b22c3a2613a 100644 --- a/aws_lambda_powertools/utilities/parser/models/kafka.py +++ b/aws_lambda_powertools/utilities/parser/models/kafka.py @@ -3,7 +3,7 @@ from pydantic import BaseModel, field_validator -from aws_lambda_powertools.shared.functions import base64_decode, bytes_to_string +from aws_lambda_powertools.shared.functions import base64_decode, bytes_to_string, decode_header_bytes SERVERS_DELIMITER = "," @@ -28,9 +28,7 @@ class KafkaRecordModel(BaseModel): # key is optional; only decode if not None @field_validator("key", mode="before") def decode_key(cls, value): - if value is not None: - return base64_decode(value) - return value + return base64_decode(value) if value is not None else value @field_validator("value", mode="before") def data_base64_decode(cls, value): @@ -41,7 +39,7 @@ def data_base64_decode(cls, value): def decode_headers_list(cls, value): for header in value: for key, values in header.items(): - header[key] = bytes(values) + header[key] = decode_header_bytes(values) return value @@ -51,7 +49,7 @@ class KafkaBaseEventModel(BaseModel): @field_validator("bootstrapServers", mode="before") def split_servers(cls, value): - return None if not value else value.split(SERVERS_DELIMITER) + return value.split(SERVERS_DELIMITER) if value else None class KafkaSelfManagedEventModel(KafkaBaseEventModel): diff --git a/tests/events/kafkaEventMsk.json b/tests/events/kafkaEventMsk.json index 6c27594460c..a91980b8ecc 100644 --- a/tests/events/kafkaEventMsk.json +++ b/tests/events/kafkaEventMsk.json @@ -104,6 +104,28 @@ "dataFormat": "AVRO", "schemaId": "1234" } + }, + { + "topic":"mymessage-with-unsigned", + "partition":0, + "offset":15, + "timestamp":1545084650987, + "timestampType":"CREATE_TIME", + "key": null, + "value":"eyJrZXkiOiJ2YWx1ZSJ9", + "headers":[ + { + "headerKey":[104, 101, 108, 108, 111, 45, 119, 111, 114, 108, 100, 45, -61, -85] + } + ], + "valueSchemaMetadata": { + "dataFormat": "AVRO", + "schemaId": "1234" + }, + "keySchemaMetadata": { + "dataFormat": "AVRO", + "schemaId": "1234" + } } ] } diff --git a/tests/unit/data_classes/required_dependencies/test_kafka_event.py b/tests/unit/data_classes/required_dependencies/test_kafka_event.py index fc7bbf12a1a..98e933ab94a 100644 --- a/tests/unit/data_classes/required_dependencies/test_kafka_event.py +++ b/tests/unit/data_classes/required_dependencies/test_kafka_event.py @@ -21,7 +21,7 @@ def test_kafka_msk_event(): assert parsed_event.decoded_bootstrap_servers == bootstrap_servers_list records = list(parsed_event.records) - assert len(records) == 3 + assert len(records) == 4 record = records[0] raw_record = raw_event["records"]["mytopic-0"][0] assert record.topic == raw_record["topic"] @@ -40,9 +40,10 @@ def test_kafka_msk_event(): assert record.value_schema_metadata.schema_id == raw_record["valueSchemaMetadata"]["schemaId"] assert parsed_event.record == records[0] - for i in range(1, 3): + for i in range(1, 4): record = records[i] assert record.key is None + assert record.decoded_headers is not None def test_kafka_self_managed_event(): @@ -90,5 +91,5 @@ def test_kafka_record_property_with_stopiteration_error(): # WHEN calling 
record property thrice # THEN raise StopIteration with pytest.raises(StopIteration): - for _ in range(4): + for _ in range(5): assert parsed_event.record.topic is not None diff --git a/tests/unit/parser/_pydantic/test_kafka.py b/tests/unit/parser/_pydantic/test_kafka.py index 779756831a9..4a49bac1fce 100644 --- a/tests/unit/parser/_pydantic/test_kafka.py +++ b/tests/unit/parser/_pydantic/test_kafka.py @@ -17,7 +17,7 @@ def test_kafka_msk_event_with_envelope(): ) for i in range(3): assert parsed_event[i].key == "value" - assert len(parsed_event) == 3 + assert len(parsed_event) == 4 def test_kafka_self_managed_event_with_envelope(): @@ -70,7 +70,7 @@ def test_kafka_msk_event(): assert parsed_event.eventSourceArn == raw_event["eventSourceArn"] records = list(parsed_event.records["mytopic-0"]) - assert len(records) == 3 + assert len(records) == 4 record: KafkaRecordModel = records[0] raw_record = raw_event["records"]["mytopic-0"][0] assert record.topic == raw_record["topic"] @@ -88,6 +88,6 @@ def test_kafka_msk_event(): assert record.keySchemaMetadata.schemaId == "1234" assert record.valueSchemaMetadata.dataFormat == "AVRO" assert record.valueSchemaMetadata.schemaId == "1234" - for i in range(1, 3): + for i in range(1, 4): record: KafkaRecordModel = records[i] assert record.key is None
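
For reference, the new shared helper exists because MSK events produced by Java clients can report header bytes as signed integers (Java's byte type is signed, -128..127), and bytes() rejects negative values. The snippet below is a standalone sketch of the same logic (a local copy, not an import from the package), exercised with the header value added to tests/events/kafkaEventMsk.json:

    def decode_header_bytes(byte_list):
        # Negative entries are signed bytes; mask them back into the 0-255
        # range before building the bytes object.
        if any(b < 0 for b in byte_list):
            return bytes(b & 0xFF for b in byte_list)
        return bytes(byte_list)

    # Header value from the new "mymessage-with-unsigned" record: the trailing
    # -61, -85 are the signed form of the UTF-8 bytes 0xC3 0xAB ("ë").
    raw_header = [104, 101, 108, 108, 111, 45, 119, 111, 114, 108, 100, 45, -61, -85]

    decoded = decode_header_bytes(raw_header)
    assert decoded == b"hello-world-\xc3\xab"
    assert decoded.decode("utf-8") == "hello-world-ë"

    # Before this change the decoders built bytes(raw_header) directly, which
    # raises: ValueError: bytes must be in range(0, 256)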
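
Nothing changes for callers: decoded values still surface through decoded_headers on the KafkaEvent data class, through the headers property in consumer_records.py, and through the pydantic KafkaRecordModel. A minimal handler sketch using the data class; the "headerKey" name is simply the one used in the test fixture:

    from aws_lambda_powertools.utilities.data_classes import KafkaEvent

    def lambda_handler(event, context):
        parsed = KafkaEvent(event)
        for record in parsed.records:
            # decoded_headers is case-insensitive; with this fix it also copes
            # with header values reported as signed byte lists.
            header = record.decoded_headers.get("headerKey")
            if header is not None:
                print(header.decode("utf-8", errors="replace"))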