#!/usr/bin/env python3
import os
import sys
from email.headerregistry import Address
from email.utils import parseaddr

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from scripts.lib.setup_path import setup_path

setup_path()

os.environ["DJANGO_SETTINGS_MODULE"] = "zproject.settings"
import argparse
import secrets
from collections.abc import Iterator
from contextlib import contextmanager
from typing import TypedDict

import boto3.session
import orjson
from django.conf import settings
from mypy_boto3_ses import SESClient
from mypy_boto3_sns import SNSClient
from mypy_boto3_sqs import SQSClient
from mypy_boto3_sqs.type_defs import DeleteMessageBatchRequestEntryTypeDef


class IdentityArgsDict(TypedDict, total=False):
    """Optional keyword arguments splatted into the ``--identity`` argparse option.

    Both keys are optional (``total=False``); ``main`` sets exactly one of
    them depending on whether a default sending identity could be found.
    """

    # Set when no default identity is known, forcing the user to pass one.
    required: bool
    # Identity to fall back on when ``--identity`` is not given.
    default: str


def main() -> None:
    """Parse the command line and tail one SES notification stream.

    Lists the SES identities visible to the session, picks a default
    ``--identity`` from ``settings.NOREPLY_EMAIL_ADDRESS`` when possible,
    then prints messages from a temporary SQS queue subscribed to the
    chosen bounce, delivery, or complaint topic until interrupted.
    """
    session = boto3.session.Session(region_name=settings.S3_REGION)

    # parseaddr() discards any display name; Address then isolates the domain.
    noreply_address = parseaddr(settings.NOREPLY_EMAIL_ADDRESS)[1]
    noreply_domain = Address(addr_spec=noreply_address).domain

    ses: SESClient = session.client("ses")
    identity_pages = ses.get_paginator("list_identities").paginate()
    possible_identities = [
        identity for page in identity_pages for identity in page["Identities"]
    ]

    # Prefer the domain identity, then the exact address; if neither is a
    # known identity, the user has to name one explicitly.
    identity_args: IdentityArgsDict
    if from_host_known := noreply_domain in possible_identities:
        identity_args = {"default": noreply_domain}
    elif noreply_address in possible_identities:
        identity_args = {"default": noreply_address}
    else:
        identity_args = {"required": True}
    del from_host_known

    parser = argparse.ArgumentParser(description="Tail SES delivery or bounces")
    parser.add_argument(
        "--identity",
        "-i",
        choices=possible_identities,
        help="Sending identity in SES",
        **identity_args,
    )
    # Exactly one notification type must be selected.
    notification_kind = parser.add_mutually_exclusive_group(required=True)
    notification_kind.add_argument("--bounces", "-b", action="store_true")
    notification_kind.add_argument("--deliveries", "-d", action="store_true")
    notification_kind.add_argument("--complaints", "-c", action="store_true")
    options = parser.parse_args()

    topic_arn = get_ses_arn(session, options)
    # The subscription is entered last so it is torn down before the queue.
    with (
        our_sqs_queue(session, topic_arn) as (queue_arn, queue_url),
        our_sns_subscription(session, topic_arn, queue_arn),
    ):
        print_messages(session, queue_url)


def get_ses_arn(session: boto3.session.Session, args: argparse.Namespace) -> str:
    """Return the SNS topic ARN for the notification type selected in ``args``.

    Args:
        session: boto3 session used to build the SES client.
        args: Parsed command line; reads ``identity`` plus the ``bounces``,
            ``complaints`` and ``deliveries`` flags.

    Returns:
        The ``BounceTopic``, ``ComplaintTopic`` or ``DeliveryTopic`` value
        from the identity's notification attributes.
    """
    ses_client: SESClient = session.client("ses")
    identity = args.identity

    response = ses_client.get_identity_notification_attributes(Identities=[identity])
    topics = response["NotificationAttributes"][identity]

    # Guard clauses, checked in the same order as the flags are documented.
    if args.bounces:
        return topics["BounceTopic"]
    if args.complaints:
        return topics["ComplaintTopic"]
    if args.deliveries:
        return topics["DeliveryTopic"]
    raise AssertionError  # Unreachable


@contextmanager
def our_sqs_queue(session: boto3.session.Session, ses_topic_arn: str) -> Iterator[tuple[str, str]]:
    """Create a temporary SQS queue that the given SNS topic may send to.

    The queue is named ``tail-ses-`` plus a random hex suffix and is placed
    in the region and account parsed out of ``ses_topic_arn``.  Its policy
    allows ``SQS:SendMessage`` only when the source ARN equals that topic.

    Args:
        session: boto3 session used to build the SQS client.
        ses_topic_arn: ARN of the SNS topic; it must consist of exactly six
            colon-separated fields.

    Yields:
        A ``(queue_arn, queue_url)`` tuple describing the new queue.

    Raises:
        ValueError: If ``ses_topic_arn`` does not split into six fields.

    On exit the queue is deleted, but only if it was actually created; a
    failure in ``create_queue`` propagates unchanged.
    """
    (_, _, _, region, account_id, _) = ses_topic_arn.split(":")

    sqs: SQSClient = session.client("sqs")
    queue_name = "tail-ses-" + secrets.token_hex(10)
    queue_arn = f"arn:aws:sqs:{region}:{account_id}:{queue_name}"

    # Bind before the try block: if create_queue() raises, the finally
    # clause must see None instead of failing with UnboundLocalError and
    # hiding the real error.
    queue_url: str | None = None
    try:
        resp = sqs.create_queue(
            QueueName=queue_name,
            Attributes={
                "Policy": orjson.dumps(
                    {
                        "Version": "2012-10-17",
                        "Id": secrets.token_hex(10),
                        "Statement": [
                            {
                                "Sid": "Sid" + secrets.token_hex(10),
                                "Effect": "Allow",
                                "Principal": {"AWS": "*"},
                                "Action": "SQS:SendMessage",
                                "Resource": queue_arn,
                                "Condition": {"ArnEquals": {"aws:SourceArn": ses_topic_arn}},
                            }
                        ],
                    }
                ).decode("UTF-8")
            },
        )
        queue_url = resp["QueueUrl"]
        yield queue_arn, queue_url
    finally:
        if queue_url is not None:
            print("Deleting temporary queue...", file=sys.stderr)
            sqs.delete_queue(QueueUrl=queue_url)


@contextmanager
def our_sns_subscription(
    session: boto3.session.Session, ses_topic_arn: str, queue_arn: str
) -> Iterator[str]:
    """Subscribe an SQS queue to an SNS topic for the duration of the block.

    Args:
        session: boto3 session used to build the SNS client.
        ses_topic_arn: ARN of the SNS topic to subscribe to.
        queue_arn: ARN of the SQS queue that receives the notifications.

    Yields:
        The ARN of the new subscription.

    On exit the subscription is removed, but only if it was actually
    created; a failure in ``subscribe`` propagates unchanged.
    """
    sns: SNSClient = session.client("sns")

    # Bind before the try block: if subscribe() raises, the finally clause
    # must see None instead of failing with UnboundLocalError and hiding
    # the real error.
    subscription_arn: str | None = None
    try:
        resp = sns.subscribe(
            TopicArn=ses_topic_arn,
            Protocol="sqs",
            Endpoint=queue_arn,
            # Non-raw delivery keeps the SNS envelope around each message.
            Attributes={"RawMessageDelivery": "false"},
            ReturnSubscriptionArn=True,
        )
        subscription_arn = resp["SubscriptionArn"]
        yield subscription_arn
    finally:
        if subscription_arn is not None:
            print("Deleting temporary SNS subscription...", file=sys.stderr)
            sns.unsubscribe(SubscriptionArn=subscription_arn)


def print_messages(session: boto3.session.Session, queue_url: str) -> None:
    """Print notifications from the queue until interrupted with Ctrl-C.

    Repeatedly receives up to 10 messages (``WaitTimeSeconds=5``), prints
    each one as its envelope ``Timestamp`` followed by the indented JSON of
    the envelope's ``Message`` field, then deletes the printed batch.

    Args:
        session: boto3 session used to build the SQS client.
        queue_url: URL of the queue to read from.
    """
    sqs_client: SQSClient = session.client("sqs")
    try:
        while True:
            response = sqs_client.receive_message(
                QueueUrl=queue_url,
                MaxNumberOfMessages=10,
                WaitTimeSeconds=5,
                MessageAttributeNames=["All"],
            )

            handled: list[DeleteMessageBatchRequestEntryTypeDef] = []
            for message in response.get("Messages", []):
                # The body is a JSON envelope whose "Message" is itself JSON.
                envelope = orjson.loads(message["Body"])
                payload = orjson.loads(envelope["Message"])
                timestamp = envelope["Timestamp"]
                pretty = orjson.dumps(payload, option=orjson.OPT_INDENT_2).decode("utf-8")
                print(timestamp + " " + pretty)
                handled.append(
                    {"Id": message["MessageId"], "ReceiptHandle": message["ReceiptHandle"]}
                )

            if not handled:
                continue
            sqs_client.delete_message_batch(QueueUrl=queue_url, Entries=handled)
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop; return quietly so callers can
        # run their cleanup.
        pass


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
