"""Tools for formatting localstack logs."""

import logging
import re
from functools import lru_cache
from typing import Any

from localstack.utils.numbers import format_bytes
from localstack.utils.strings import to_bytes

# Column widths of the thread-name and logger-name fields in LOG_FORMAT.
MAX_THREAD_NAME_LEN = 12
MAX_NAME_LEN = 26

# Main line layout: timestamp with milliseconds, level, thread, logger name, message.
LOG_FORMAT = (
    "%(asctime)s.%(msecs)03d %(ls_level)5s --- "
    f"[%(ls_thread){MAX_THREAD_NAME_LEN}s] "
    f"%(ls_name)-{MAX_NAME_LEN}s : %(message)s"
)
LOG_DATE_FORMAT = "%Y-%m-%dT%H:%M:%S"

# Trace-log segments: request/response payloads with their headers, and the account/region context.
LOG_INPUT_FORMAT = "%(input_type)s(%(input)s, headers=%(request_headers)s)"
LOG_OUTPUT_FORMAT = "%(output_type)s(%(output)s, headers=%(response_headers)s)"
LOG_CONTEXT_FORMAT = "%(account_id)s/%(region)s"

# Level names of at most five characters, so they fit the ``%(ls_level)5s`` column of LOG_FORMAT.
CUSTOM_LEVEL_NAMES = {
    logging.CRITICAL: "FATAL",
    logging.ERROR: "ERROR",
    logging.WARNING: "WARN",
    logging.INFO: "INFO",
    logging.DEBUG: "DEBUG",
}


class DefaultFormatter(logging.Formatter):
    """
    Log formatter preconfigured with the LocalStack line layout (``LOG_FORMAT``) and timestamp layout
    (``LOG_DATE_FORMAT``). Either one can be overridden via ``fmt`` / ``datefmt``.
    """

    def __init__(self, fmt=LOG_FORMAT, datefmt=LOG_DATE_FORMAT):
        super().__init__(datefmt=datefmt, fmt=fmt)


class AddFormattedAttributes(logging.Filter):
    """
    Filter that adds three attributes to a log record:

    - ls_level: the abbreviated level name from ``CUSTOM_LEVEL_NAMES`` (at most 5 characters), falling back to
      the record's ``levelname`` for levels not in that mapping
    - ls_name: the abbreviated name of the logger (e.g., ``l.bootstrap.install``), compressed to ``max_name_len``
    - ls_thread: the last ``max_thread_len`` characters of the thread name (e.g., with the default length of 12,
      ``SomeThread-108`` becomes ``meThread-108``); empty if the record carries no thread name

    The filter never drops records; it always returns ``True``.
    """

    max_name_len: int
    max_thread_len: int

    def __init__(self, max_name_len: int = None, max_thread_len: int = None):
        """
        :param max_name_len: maximum length of ``ls_name``; a falsy value (``None`` or ``0``) means ``MAX_NAME_LEN``
        :param max_thread_len: maximum length of ``ls_thread``; a falsy value means ``MAX_THREAD_NAME_LEN``
        """
        super().__init__()
        self.max_name_len = max_name_len if max_name_len else MAX_NAME_LEN
        self.max_thread_len = max_thread_len if max_thread_len else MAX_THREAD_NAME_LEN

        # Per-instance cache keyed on (name, length). Decorating the method with ``@lru_cache`` would create one
        # class-wide cache keyed on ``self``, which keeps filter instances alive while they have entries
        # (ruff B019). Wrapping the module-level function holds no reference back to ``self``.
        self._compressed_name_cache = lru_cache(maxsize=256)(compress_logger_name)

    def filter(self, record: logging.LogRecord) -> bool:
        record.ls_level = CUSTOM_LEVEL_NAMES.get(record.levelno, record.levelname)
        record.ls_name = self._get_compressed_logger_name(record.name)
        # ``threadName`` is None when ``logging.logThreads`` is disabled; slicing None would raise TypeError
        record.ls_thread = (record.threadName or "")[-self.max_thread_len :]
        return True

    def _get_compressed_logger_name(self, name: str) -> str:
        """Return ``name`` compressed to ``max_name_len`` characters, memoized per filter instance."""
        return self._compressed_name_cache(name, self.max_name_len)


class MaskSensitiveInputFilter(logging.Filter):
    """
    Filter that hides sensitive values in a binary JSON string stored in a record's ``input`` attribute.
    It finds the matching keys and replaces their (non-empty string) values with "******".

    For example, if initialized with `sensitive_keys=["my_key"]`, the input
    b'{"my_key": "sensitive_value"}' would become b'{"my_key": "******"}'.

    Keys are matched literally (they are regex-escaped). Records that have no ``input`` attribute, or whose
    ``input`` is not a non-empty ``bytes`` object, pass through unchanged. The filter never drops records.
    """

    patterns: list[tuple[re.Pattern[bytes], bytes]]

    def __init__(self, sensitive_keys: list[str]):
        """
        :param sensitive_keys: JSON keys whose string values should be masked
        """
        super().__init__()

        # The value part accepts JSON escape sequences (e.g. ``\"``), so a value that contains an escaped quote
        # is masked in full instead of being cut off at that quote (which would leak the rest of the value).
        self.patterns = [
            (
                re.compile(to_bytes(rf'"{re.escape(key)}":\s*"(?:[^"\\]|\\.)+"')),
                to_bytes(f'"{key}": "******"'),
            )
            for key in sensitive_keys
        ]

    def filter(self, record: logging.LogRecord) -> bool:
        # Not every record carries an ``input`` extra; use getattr so such records do not raise AttributeError.
        payload = getattr(record, "input", None)
        if payload and isinstance(payload, bytes):
            record.input = self.mask_sensitive_msg(payload)
        return True

    def mask_sensitive_msg(self, message: bytes) -> bytes:
        """
        Apply every configured mask to ``message``.

        :param message: the raw JSON payload
        :return: the payload with all sensitive values replaced
        """
        for pattern, replacement in self.patterns:
            message = pattern.sub(replacement, message)
        return message


def compress_logger_name(name: str, length: int) -> str:
    """
    Creates a short version of a logger name. For example ``my.very.long.logger.name`` with length=17 turns into
    ``m.v.l.logger.name``.

    :param name: the logger name
    :param length: the max length of the logger name
    :return: the compressed name
    """
    if len(name) <= length:
        return name

    parts = name.split(".")
    parts.reverse()

    new_parts = []

    # we start by assuming that all parts are collapsed
    # x.x.x requires 5 = 2n - 1 characters
    cur_length = (len(parts) * 2) - 1

    for i in range(len(parts)):
        # try to expand the current part and calculate the resulting length
        part = parts[i]
        next_len = cur_length + (len(part) - 1)

        if next_len > length:
            # if the resulting length would exceed the limit, add only the first letter of the parts of all remaining
            # parts
            new_parts += [p[0] for p in parts[i:]]

            # but if this is the first item, that means we would display nothing, so at least display as much of the
            # max length as possible
            if i == 0:
                remaining = length - cur_length
                if remaining > 0:
                    new_parts[0] = part[: (remaining + 1)]

            break

        # expanding the current part, i.e., instead of using just the one character, we add the entire part
        new_parts.append(part)
        cur_length = next_len

    new_parts.reverse()
    return ".".join(new_parts)


class TraceLoggingFormatter(logging.Formatter):
    """
    Formatter for request/response trace logs.

    The line layout is ``LOG_FORMAT`` followed by the input and output segments (payload plus headers), separated
    by ``"; "``. Records must therefore carry ``input``/``output`` (and the other fields referenced by those
    segments). ``bytes`` payloads longer than ``bytes_length_display_threshold`` are replaced by a short
    ``Bytes(<size>)`` placeholder. The placeholder is written back onto the record's ``input``/``output``
    attributes.
    """

    aws_trace_log_format = "; ".join([LOG_FORMAT, LOG_INPUT_FORMAT, LOG_OUTPUT_FORMAT])
    bytes_length_display_threshold = 512

    def __init__(self):
        super().__init__(fmt=self.aws_trace_log_format, datefmt=LOG_DATE_FORMAT)

    def _replace_large_payloads(self, input: Any) -> Any:
        """
        Swap an oversized bytes payload for a size placeholder so it does not flood the log.

        :param input: the ``input``/``output`` extra passed when logging
        :return: ``Bytes(<formatted size>)`` if ``input`` is bytes longer than ``bytes_length_display_threshold``,
            otherwise ``input`` itself
        """
        if not isinstance(input, bytes):
            return input
        size = len(input)
        if size <= self.bytes_length_display_threshold:
            return input
        return f"Bytes({format_bytes(size)})"

    def format(self, record: logging.LogRecord) -> str:
        for attribute in ("input", "output"):
            setattr(record, attribute, self._replace_large_payloads(getattr(record, attribute)))
        return super().format(record=record)


class AwsTraceLoggingFormatter(TraceLoggingFormatter):
    """
    Trace formatter that also prints the account/region context (``LOG_CONTEXT_FORMAT``) after the main layout.

    Before formatting, ``input``/``output`` are deep-copied through any nesting of dicts and lists, and every
    oversized ``bytes`` value found at any depth is replaced by a ``Bytes(<size>)`` placeholder. The objects
    passed by the caller are not mutated; only the record's ``input``/``output`` attributes are rebound to the
    copies.
    """

    aws_trace_log_format = "; ".join(
        [LOG_FORMAT, LOG_CONTEXT_FORMAT, LOG_INPUT_FORMAT, LOG_OUTPUT_FORMAT]
    )

    def _copy_service_dict(self, service_dict: Any) -> Any:
        """
        Recursively copy dicts and lists, replacing large ``bytes`` values at any depth.

        Earlier, bytes items inside lists were left untouched. Reusing ``_replace_large_payloads`` for every
        leaf keeps the placeholder format consistent with the base class.

        :param service_dict: a payload structure or any leaf value
        :return: a copy of the structure (dicts and lists are new objects), with large bytes replaced
        """
        if isinstance(service_dict, dict):
            return {key: self._copy_service_dict(value) for key, value in service_dict.items()}
        if isinstance(service_dict, list):
            return [self._copy_service_dict(item) for item in service_dict]
        return self._replace_large_payloads(service_dict)

    def format(self, record: logging.LogRecord) -> str:
        record.input = self._copy_service_dict(record.input)
        record.output = self._copy_service_dict(record.output)
        return super().format(record=record)
