from __future__ import annotations

import datetime
import json
import logging
from typing import Final

from localstack.aws.api.events import PutEventsRequestEntry
from localstack.aws.api.stepfunctions import (
    Arn,
    CloudWatchEventsExecutionDataDetails,
    DescribeExecutionOutput,
    DescribeStateMachineForExecutionOutput,
    ExecutionListItem,
    ExecutionStatus,
    GetExecutionHistoryOutput,
    HistoryEventList,
    InvalidName,
    SensitiveCause,
    SensitiveError,
    StartExecutionOutput,
    StartSyncExecutionOutput,
    StateMachineType,
    SyncExecutionStatus,
    Timestamp,
    TraceHeader,
    VariableReferences,
)
from localstack.aws.connect import connect_to
from localstack.services.stepfunctions.asl.eval.evaluation_details import (
    AWSExecutionDetails,
    EvaluationDetails,
    ExecutionDetails,
    StateMachineDetails,
)
from localstack.services.stepfunctions.asl.eval.event.logging import (
    CloudWatchLoggingSession,
)
from localstack.services.stepfunctions.asl.eval.program_state import (
    ProgramEnded,
    ProgramError,
    ProgramState,
    ProgramStopped,
    ProgramTimedOut,
)
from localstack.services.stepfunctions.asl.static_analyser.variable_references_static_analyser import (
    VariableReferencesStaticAnalyser,
)
from localstack.services.stepfunctions.asl.utils.encoding import to_json_str
from localstack.services.stepfunctions.backend.activity import Activity
from localstack.services.stepfunctions.backend.execution_worker import (
    ExecutionWorker,
    SyncExecutionWorker,
)
from localstack.services.stepfunctions.backend.execution_worker_comm import (
    ExecutionWorkerCommunication,
)
from localstack.services.stepfunctions.backend.state_machine import (
    StateMachineInstance,
    StateMachineVersion,
)
from localstack.services.stepfunctions.mocking.mock_config import MockTestCase

LOG = logging.getLogger(__name__)


class BaseExecutionWorkerCommunication(ExecutionWorkerCommunication):
    execution: Final[Execution]

    def __init__(self, execution: Execution):
        self.execution = execution

    def _reflect_execution_status(self):
        exit_program_state: ProgramState = self.execution.exec_worker.env.program_state()
        self.execution.stop_date = datetime.datetime.now(tz=datetime.UTC)
        if isinstance(exit_program_state, ProgramEnded):
            self.execution.exec_status = ExecutionStatus.SUCCEEDED
            self.execution.output = self.execution.exec_worker.env.states.get_input()
        elif isinstance(exit_program_state, ProgramStopped):
            self.execution.exec_status = ExecutionStatus.ABORTED
        elif isinstance(exit_program_state, ProgramError):
            self.execution.exec_status = ExecutionStatus.FAILED
            self.execution.error = exit_program_state.error.get("error")
            self.execution.cause = exit_program_state.error.get("cause")
        elif isinstance(exit_program_state, ProgramTimedOut):
            self.execution.exec_status = ExecutionStatus.TIMED_OUT
        else:
            raise RuntimeWarning(
                f"Execution ended with unsupported ProgramState type '{type(exit_program_state)}'."
            )

    def terminated(self) -> None:
        self._reflect_execution_status()
        self.execution.publish_execution_status_change_event()


class Execution:
    name: Final[str]
    sm_type: Final[StateMachineType]
    role_arn: Final[Arn]
    exec_arn: Final[Arn]

    account_id: str
    region_name: str

    state_machine: Final[StateMachineInstance]
    state_machine_arn: Final[Arn]
    state_machine_version_arn: Final[Arn | None]
    state_machine_alias_arn: Final[Arn | None]

    mock_test_case: Final[MockTestCase | None]

    start_date: Final[Timestamp]
    input_data: Final[json | None]
    input_details: Final[CloudWatchEventsExecutionDataDetails | None]
    trace_header: Final[TraceHeader | None]
    _cloud_watch_logging_session: Final[CloudWatchLoggingSession | None]

    exec_status: ExecutionStatus | None
    stop_date: Timestamp | None

    output: json | None
    output_details: CloudWatchEventsExecutionDataDetails | None

    error: SensitiveError | None
    cause: SensitiveCause | None

    exec_worker: ExecutionWorker | None

    _activity_store: dict[Arn, Activity]

    def __init__(
        self,
        name: str,
        sm_type: StateMachineType,
        role_arn: Arn,
        exec_arn: Arn,
        account_id: str,
        region_name: str,
        state_machine: StateMachineInstance,
        start_date: Timestamp,
        cloud_watch_logging_session: CloudWatchLoggingSession | None,
        activity_store: dict[Arn, Activity],
        input_data: json | None = None,
        trace_header: TraceHeader | None = None,
        state_machine_alias_arn: Arn | None = None,
        mock_test_case: MockTestCase | None = None,
    ):
        self.name = name
        self.sm_type = sm_type
        self.role_arn = role_arn
        self.exec_arn = exec_arn
        self.account_id = account_id
        self.region_name = region_name
        self.state_machine = state_machine
        if isinstance(state_machine, StateMachineVersion):
            self.state_machine_arn = state_machine.source_arn
            self.state_machine_version_arn = state_machine.arn
        else:
            self.state_machine_arn = state_machine.arn
            self.state_machine_version_arn = None
        self.state_machine_alias_arn = state_machine_alias_arn
        self.start_date = start_date
        self._cloud_watch_logging_session = cloud_watch_logging_session
        self.input_data = input_data
        self.input_details = CloudWatchEventsExecutionDataDetails(included=True)
        self.trace_header = trace_header
        self.exec_status = None
        self.stop_date = None
        self.output = None
        self.output_details = CloudWatchEventsExecutionDataDetails(included=True)
        self.exec_worker = None
        self.error = None
        self.cause = None
        self._activity_store = activity_store
        self.mock_test_case = mock_test_case

    def _get_events_client(self):
        return connect_to(aws_access_key_id=self.account_id, region_name=self.region_name).events

    def to_start_output(self) -> StartExecutionOutput:
        return StartExecutionOutput(executionArn=self.exec_arn, startDate=self.start_date)

    def to_describe_output(self) -> DescribeExecutionOutput:
        describe_output = DescribeExecutionOutput(
            executionArn=self.exec_arn,
            stateMachineArn=self.state_machine_arn,
            name=self.name,
            status=self.exec_status,
            startDate=self.start_date,
            stopDate=self.stop_date,
            input=to_json_str(self.input_data, separators=(",", ":")),
            inputDetails=self.input_details,
            traceHeader=self.trace_header,
        )
        if describe_output["status"] == ExecutionStatus.SUCCEEDED:
            describe_output["output"] = to_json_str(self.output, separators=(",", ":"))
            describe_output["outputDetails"] = self.output_details
        if self.error is not None:
            describe_output["error"] = self.error
        if self.cause is not None:
            describe_output["cause"] = self.cause
        if self.state_machine_version_arn is not None:
            describe_output["stateMachineVersionArn"] = self.state_machine_version_arn
        if self.state_machine_alias_arn is not None:
            describe_output["stateMachineAliasArn"] = self.state_machine_alias_arn
        return describe_output

    def to_describe_state_machine_for_execution_output(
        self,
    ) -> DescribeStateMachineForExecutionOutput:
        state_machine: StateMachineInstance = self.state_machine
        state_machine_arn = (
            state_machine.source_arn
            if isinstance(state_machine, StateMachineVersion)
            else state_machine.arn
        )
        out = DescribeStateMachineForExecutionOutput(
            stateMachineArn=state_machine_arn,
            name=state_machine.name,
            definition=state_machine.definition,
            roleArn=self.role_arn,
            # The date and time the state machine associated with an execution was updated.
            updateDate=state_machine.create_date,
            loggingConfiguration=state_machine.logging_config,
        )
        revision_id = self.state_machine.revision_id
        if self.state_machine.revision_id:
            out["revisionId"] = revision_id
        variable_references: VariableReferences = VariableReferencesStaticAnalyser.process_and_get(
            definition=self.state_machine.definition
        )
        if variable_references:
            out["variableReferences"] = variable_references
        return out

    def to_execution_list_item(self) -> ExecutionListItem:
        if isinstance(self.state_machine, StateMachineVersion):
            state_machine_arn = self.state_machine.source_arn
            state_machine_version_arn = self.state_machine.arn
        else:
            state_machine_arn = self.state_machine.arn
            state_machine_version_arn = None

        item = ExecutionListItem(
            executionArn=self.exec_arn,
            stateMachineArn=state_machine_arn,
            name=self.name,
            status=self.exec_status,
            startDate=self.start_date,
            stopDate=self.stop_date,
        )
        if state_machine_version_arn is not None:
            item["stateMachineVersionArn"] = state_machine_version_arn
        if self.state_machine_alias_arn is not None:
            item["stateMachineAliasArn"] = self.state_machine_alias_arn
        return item

    def to_history_output(self) -> GetExecutionHistoryOutput:
        env = self.exec_worker.env
        event_history: HistoryEventList = []
        if env is not None:
            # The execution has not started yet.
            event_history: HistoryEventList = env.event_manager.get_event_history()
        return GetExecutionHistoryOutput(events=event_history)

    @staticmethod
    def _to_serialized_date(timestamp: datetime.datetime) -> str:
        """See test in tests.aws.services.stepfunctions.v2.base.test_base.TestSnfBase.test_execution_dateformat"""
        return f"{timestamp.astimezone(datetime.UTC).strftime('%Y-%m-%dT%H:%M:%S.%f')[:-3]}Z"

    def _get_start_execution_worker_comm(self) -> BaseExecutionWorkerCommunication:
        return BaseExecutionWorkerCommunication(self)

    def _get_start_aws_execution_details(self) -> AWSExecutionDetails:
        return AWSExecutionDetails(
            account=self.account_id, region=self.region_name, role_arn=self.role_arn
        )

    def get_start_execution_details(self) -> ExecutionDetails:
        return ExecutionDetails(
            arn=self.exec_arn,
            name=self.name,
            role_arn=self.role_arn,
            inpt=self.input_data,
            start_time=self._to_serialized_date(self.start_date),
        )

    def get_start_state_machine_details(self) -> StateMachineDetails:
        return StateMachineDetails(
            arn=self.state_machine.arn,
            name=self.state_machine.name,
            typ=self.state_machine.sm_type,
            definition=self.state_machine.definition,
        )

    def _get_start_execution_worker(self) -> ExecutionWorker:
        return ExecutionWorker(
            evaluation_details=EvaluationDetails(
                aws_execution_details=self._get_start_aws_execution_details(),
                execution_details=self.get_start_execution_details(),
                state_machine_details=self.get_start_state_machine_details(),
            ),
            exec_comm=self._get_start_execution_worker_comm(),
            cloud_watch_logging_session=self._cloud_watch_logging_session,
            activity_store=self._activity_store,
            mock_test_case=self.mock_test_case,
        )

    def start(self) -> None:
        # TODO: checks exec_worker does not exists already?
        if self.exec_worker:
            raise InvalidName()  # TODO.
        self.exec_worker = self._get_start_execution_worker()
        self.exec_status = ExecutionStatus.RUNNING
        self.publish_execution_status_change_event()
        self.exec_worker.start()

    def stop(self, stop_date: datetime.datetime, error: str | None, cause: str | None):
        exec_worker: ExecutionWorker | None = self.exec_worker
        if exec_worker:
            exec_worker.stop(stop_date=stop_date, cause=cause, error=error)

    def publish_execution_status_change_event(self):
        input_value = (
            {} if not self.input_data else to_json_str(self.input_data, separators=(",", ":"))
        )
        output_value = (
            None if self.output is None else to_json_str(self.output, separators=(",", ":"))
        )
        output_details = None if output_value is None else self.output_details
        entry = PutEventsRequestEntry(
            Source="aws.states",
            Resources=[self.exec_arn],
            DetailType="Step Functions Execution Status Change",
            Detail=to_json_str(
                # Note: this operation carries significant changes from a describe_execution request.
                DescribeExecutionOutput(
                    executionArn=self.exec_arn,
                    stateMachineArn=self.state_machine.arn,
                    stateMachineAliasArn=None,
                    stateMachineVersionArn=None,
                    name=self.name,
                    status=self.exec_status,
                    startDate=self.start_date,
                    stopDate=self.stop_date,
                    input=input_value,
                    inputDetails=self.input_details,
                    output=output_value,
                    outputDetails=output_details,
                    error=self.error,
                    cause=self.cause,
                )
            ),
        )
        try:
            self._get_events_client().put_events(Entries=[entry])
        except Exception:
            LOG.error(
                "Unable to send notification of Entry='%s' for Step Function execution with Arn='%s' to EventBridge.",
                entry,
                self.exec_arn,
                exc_info=LOG.isEnabledFor(logging.DEBUG),
            )


class SyncExecutionWorkerCommunication(BaseExecutionWorkerCommunication):
    execution: Final[SyncExecution]

    def _reflect_execution_status(self) -> None:
        super()._reflect_execution_status()
        exit_status: ExecutionStatus = self.execution.exec_status
        if exit_status == ExecutionStatus.SUCCEEDED:
            self.execution.sync_execution_status = SyncExecutionStatus.SUCCEEDED
        elif exit_status == ExecutionStatus.TIMED_OUT:
            self.execution.sync_execution_status = SyncExecutionStatus.TIMED_OUT
        else:
            self.execution.sync_execution_status = SyncExecutionStatus.FAILED


class SyncExecution(Execution):
    """An execution started through StartSyncExecution.

    Runs on a ``SyncExecutionWorker`` and reports termination through
    ``SyncExecutionWorkerCommunication``, which fills ``sync_execution_status``.
    Renders the StartSyncExecution response from the final state.
    """

    # Filled in by SyncExecutionWorkerCommunication when the execution terminates.
    sync_execution_status: SyncExecutionStatus | None = None

    def _get_start_execution_worker(self) -> SyncExecutionWorker:
        """Create the synchronous worker that will evaluate this execution."""
        details = EvaluationDetails(
            aws_execution_details=self._get_start_aws_execution_details(),
            execution_details=self.get_start_execution_details(),
            state_machine_details=self.get_start_state_machine_details(),
        )
        return SyncExecutionWorker(
            evaluation_details=details,
            exec_comm=self._get_start_execution_worker_comm(),
            cloud_watch_logging_session=self._cloud_watch_logging_session,
            activity_store=self._activity_store,
            mock_test_case=self.mock_test_case,
        )

    def _get_start_execution_worker_comm(self) -> BaseExecutionWorkerCommunication:
        """Use the sync-aware communication channel so ``sync_execution_status`` gets set."""
        return SyncExecutionWorkerCommunication(self)

    def to_start_sync_execution_output(self) -> StartSyncExecutionOutput:
        """Build the StartSyncExecution response.

        ``output`` is included only for SUCCEEDED executions; ``error`` and ``cause``
        only when set.
        """
        compact = (",", ":")
        response = StartSyncExecutionOutput(
            executionArn=self.exec_arn,
            stateMachineArn=self.state_machine.arn,
            name=self.name,
            status=self.sync_execution_status,
            startDate=self.start_date,
            stopDate=self.stop_date,
            input=to_json_str(self.input_data, separators=compact),
            inputDetails=self.input_details,
            traceHeader=self.trace_header,
        )
        if self.sync_execution_status == SyncExecutionStatus.SUCCEEDED:
            response["output"] = to_json_str(self.output, separators=compact)
        # NOTE(review): output_details is initialised to a non-empty details object, so this
        # presumably always includes outputDetails, even for failed executions.
        if self.output_details:
            response["outputDetails"] = self.output_details
        for field_name, field_value in (("error", self.error), ("cause", self.cause)):
            if field_value is not None:
                response[field_name] = field_value
        return response
