#!/usr/bin/env python3

import argparse
import calendar
import gzip
import logging
import os
import re
import signal
import sys
from datetime import date, datetime, timedelta, timezone
from enum import Enum, auto
from re import Match
from typing import TextIO

# Put the directory two levels above this file on sys.path so that the
# `scripts.lib` imports below can be resolved.
ZULIP_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(ZULIP_PATH)

from scripts.lib.setup_path import setup_path

# NOTE(review): presumably adjusts sys.path for the Zulip environment so the
# Django import further down works -- confirm in scripts.lib.setup_path.
setup_path()

# Set before `django.conf.settings` is imported and read (print_line reads
# settings.EXTERNAL_HOST).
os.environ["DJANGO_SETTINGS_MODULE"] = "zproject.settings"

from typing import Protocol

from django.conf import settings

from scripts.lib.zulip_tools import (
    BOLD,
    CYAN,
    ENDC,
    FAIL,
    GRAY,
    OKBLUE,
    get_config,
    get_config_file,
)


def parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the log search tool.

    Options fall into three groups: which logfiles to read ("File
    selection"), which requests to keep ("Filtering"), and how to print
    them ("Output").  The upper bound for ``--log-files`` is derived from
    the ``access_log_retention_days`` setting (default 14) plus one for the
    live logfile.

    Returns:
        A freshly constructed parser; the configuration file is re-read on
        every call.
    """
    arg_parser = argparse.ArgumentParser(
        description="Search logfiles, ignoring commonly-fetched URLs."
    )
    log_selection = arg_parser.add_argument_group("File selection")
    # At most one of --log-files / --all-logs / --min-hours may be given.
    log_selection_options = log_selection.add_mutually_exclusive_group()
    access_log_retention_days = int(
        get_config(get_config_file(), "application_server", "access_log_retention_days", "14")
    )
    log_selection_options.add_argument(
        "--log-files",
        "-n",
        help="Number of log files to search",
        choices=range(1, access_log_retention_days + 2),
        metavar=f"[1-{access_log_retention_days + 1}]",
        type=int,
    )
    log_selection_options.add_argument(
        "--all-logs",
        "-A",
        help="Parse all logfiles, not just most recent",
        action="store_true",
    )
    log_selection_options.add_argument(
        "--min-hours",
        "-H",
        help="Estimated minimum number of hours; includes previous log file, if estimated less than this",
        type=int,
        choices=range(24),
        default=3,
    )
    log_selection.add_argument(
        "--nginx",
        "-N",
        help="Parse from NGINX logs, not server.log",
        action="store_true",
    )

    filtering = arg_parser.add_argument_group("Filtering")
    filtering.add_argument(
        "filter_terms",
        help="IP address, hostname, user-id, HTTP method, path, datetime prefix, or status code to search for; multiple are AND'ed together",
        nargs="+",
    )
    filtering.add_argument(
        "--all-lines",
        "-L",
        # passes_filters lets every path through when this is set, which
        # includes avatar paths (-a) as well.
        help="Show all matching lines; equivalent to -suaemtpr",
        action="store_true",
    )
    filtering.add_argument("--static", "-s", help="Include static file paths", action="store_true")
    filtering.add_argument("--uploads", "-u", help="Include file upload paths", action="store_true")
    filtering.add_argument("--avatars", "-a", help="Include avatar paths", action="store_true")
    filtering.add_argument("--events", "-e", help="Include event fetch paths", action="store_true")
    filtering.add_argument("--messages", "-m", help="Include message paths", action="store_true")
    filtering.add_argument(
        "--typing",
        "-t",
        help="Include typing notification path",
        action="store_true",
    )
    filtering.add_argument("--presence", "-p", help="Include presence paths", action="store_true")
    filtering.add_argument(
        "--report", "-r", help="Include Sentry reporting paths", action="store_true"
    )
    filtering.add_argument(
        "--no-other", "-O", help="Exclude paths not explicitly included", action="store_true"
    )
    filtering.add_argument(
        "--client",
        "--user-agent",
        "-C",
        help="Only include requests whose client/user-agent contains this string",
    )

    output = arg_parser.add_argument_group("Output")
    output.add_argument("--full-line", "-F", help="Show full matching line", action="store_true")
    output.add_argument("--timeline", "-T", help="Show start, end, and gaps", action="store_true")
    return arg_parser


def maybe_gzip(logfile_name: str) -> TextIO:
    """Open a logfile for reading as text, decompressing ``.gz`` files.

    The file is decoded as UTF-8 regardless of the process locale, and any
    undecodable byte is rendered as a ``\\xNN`` escape instead of raising, so
    a single malformed line cannot abort a whole search.

    Args:
        logfile_name: Path to the logfile; a ``.gz`` suffix selects gzip.

    Returns:
        An open text-mode handle.  The caller is responsible for closing it
        (hence the SIM115 suppressions below).

    Raises:
        OSError: If the file cannot be opened (e.g. FileNotFoundError).
    """
    if logfile_name.endswith(".gz"):
        return gzip.open(  # noqa: SIM115
            logfile_name, "rt", encoding="utf-8", errors="backslashreplace"
        )
    return open(logfile_name, encoding="utf-8", errors="backslashreplace")  # noqa: SIM115


# Verbose-mode pattern for one nginx access-log line.  In order, it captures:
# client `ip`, a literal "-", `user`, a bracketed `date`:`time` followed by a
# numeric UTC offset, the quoted request line (`method`, `full_path` -- whose
# `path` sub-group excludes any "?query" -- and `http_version`), the status
# `code`, response `bytes`, quoted `referer` and `user_agent`, and finally
# `hostname` and `duration`.
# NOTE(review): the two trailing fields imply a customized nginx log_format;
# confirm against the deployed nginx configuration.
NGINX_LOG_LINE_RE = re.compile(
    r"""
      (?P<ip> \S+ ) \s+
      - \s+
      (?P<user> \S+ ) \s+
      \[
         (?P<date> \d+/\w+/\d+ )
         :
         (?P<time> \d+:\d+:\d+ )
         \s+ [+-]\d+
      \] \s+
      "
         (?P<method> \S+ )
         \s+
         (?P<full_path> (?P<path> [^"?]+ ) (\?[^"]*)? )
         \s+
         (?P<http_version> HTTP/[^"]+ )
      " \s+
      (?P<code> \d+ ) \s+
      (?P<bytes> \d+ ) \s+
      "(?P<referer> [^"]* )" \s+
      "(?P<user_agent> [^"]* )" \s+
      (?P<hostname> \S+ ) \s+
      (?P<duration> \S+ )
    """,
    re.VERBOSE,
)

# Verbose-mode pattern for one request line in the Python server log.  It
# captures `date` and fractional-second `time`, requires level INFO, allows an
# optional "pid:N", then a bracketed `source` ("zr" or "zr:N"), the client
# `ip` (dotted IPv4 or a loose colon-separated IPv6 form), `method`, status
# `code`, `duration`, any number of parenthesized annotations, and the request
# path (`full_path` and `path` are the same text here).  The final
# parenthesized section yields `user` and `user_agent`; `user_id` and
# `hostname` are only set for the "<id-or-unauth>@<hostname>" form and are
# None for the zulip-server / scim-client / internal forms.
PYTHON_LOG_LINE_RE = re.compile(
    r"""
      (?P<date> \d+-\d+-\d+ ) \s+
      (?P<time> \d+:\d+:\d+\.\d+ ) \s+
      INFO \s+  # All access log lines are INFO
      (pid:\d+ \s+) ?
      \[ (?P<source> zr(:\d+)?) \] \s+
      (?P<ip>
        \d{1,3}(\.\d{1,3}){3}
      | ([a-f0-9:]+:+){1,7}[a-f0-9]*
      ) \s+
      (?P<method> [A-Z]+ ) \s+
      (?P<code> \d+ ) \s+
      (?P<duration> \S+ ) \s+ # This can be "217ms" or "1.7s"
      ( \( [^)]+ \) \s+ )*
      (?P<full_path> (?P<path> /\S* ) ) \s+
      .*   # Multiple extra things can go here
      \(
        (?P<user>
           ( (?P<user_id> \d+ ) | unauth )
           @
           (?P<hostname> \S+ )
         | zulip-server:\S+
         | scim-client:\S+
         | internal
        ) \s+ via \s+ (?P<user_agent> .* )
      \)
    """,
    re.VERBOSE,
)


class FilterType(Enum):
    """Kinds of search term that parse_filters can recognize.

    parse_filters rejects two terms of the same kind, and print_line omits
    the hostname, user-id, or client-IP column when that kind was filtered on.
    """

    HOSTNAME = auto()
    CLIENT_IP = auto()
    USER_ID = auto()
    METHOD = auto()
    PATH = auto()
    STATUS = auto()
    DATETIME = auto()


class FilterFunc(Protocol):
    """Call signature of a per-line filter predicate.

    ``m`` is the regex match for one log line.  ``t`` is the search term; the
    lambdas built in parse_filters bind it as a default argument, so callers
    invoke the predicate with the match alone.
    """

    def __call__(self, m: Match[str], t: str = ...) -> bool: ...


def main() -> None:
    """Parse arguments, scan the selected logfiles, and print matching requests.

    Logfiles are read oldest-first so output is in chronological order.  A
    logfile that does not exist is reported on stderr and skipped, rather
    than aborting the whole search with a traceback.

    Exits with status 1 if stdout is closed early (e.g. piped to ``head``),
    and with 128 + SIGINT on Ctrl-C.
    """
    args = parser().parse_args()

    # parse_filters may switch args.nginx on, so it must run before the
    # logfile names and the line regex are chosen.
    (filter_types, filter_funcs, substr_terms) = parse_filters(args)
    logfile_names = parse_logfile_names(args)
    if args.timeline and args.nginx:
        print("! nginx logs not suggested for timeline, due to imprecision", file=sys.stderr)

    use_color = sys.stdout.isatty()
    lowered_terms = [term.lower() for term in substr_terms]
    # args.nginx does not change once filtering starts; pick the regex once.
    line_re = NGINX_LOG_LINE_RE if args.nginx else PYTHON_LOG_LINE_RE
    try:
        for logfile_name in reversed(logfile_names):
            try:
                logfile = maybe_gzip(logfile_name)
            except FileNotFoundError:
                # Fewer rotated logs may exist than were asked for.
                print(f"! Skipping missing logfile: {logfile_name}", file=sys.stderr)
                continue
            with logfile:
                for logline in logfile:
                    # As a performance optimization, just do a substring
                    # check before we parse the line fully
                    lowered = logline.lower()
                    if not all(f in lowered for f in lowered_terms):
                        continue

                    match = line_re.match(logline)
                    if match is None:
                        # We expect other types of loglines in the Python logfiles
                        if args.nginx:
                            print(f"! Failed to parse:\n{logline}", file=sys.stderr)
                        continue
                    if passes_filters(filter_funcs, match, args):
                        print_line(
                            match,
                            args,
                            filter_types=filter_types,
                            use_color=use_color,
                        )
    except BrokenPipeError:
        # Python flushes standard streams on exit; redirect remaining output
        # to devnull to avoid another BrokenPipeError at shutdown
        devnull = os.open(os.devnull, os.O_WRONLY)
        os.dup2(devnull, sys.stdout.fileno())
        sys.exit(1)
    except KeyboardInterrupt:
        sys.exit(signal.SIGINT + 128)


def parse_logfile_names(args: argparse.Namespace) -> list[str]:
    """Decide which logfiles to search, newest first.

    If any filter term starts with a ``YYYY-MM-DD`` date, exactly one file is
    returned: the rotation expected to hold that day's lines.  Otherwise the
    live logfile is returned, followed by as many rotated files as
    ``--all-logs``, ``--log-files``, or the ``--min-hours`` size heuristic
    call for.  Rotation ``.1`` is uncompressed; older ones end in ``.gz``.

    Args:
        args: Parsed arguments; reads ``nginx``, ``filter_terms``,
            ``all_logs``, ``log_files`` and ``min_hours``.

    Returns:
        Logfile paths ordered from most recent to oldest.

    Raises:
        RuntimeError: If a date term is older than the retention window.
        ValueError: If a date term matches the pattern but is not a real date.
    """
    if args.nginx:
        base_path = "/var/log/nginx/access.log"
    else:
        base_path = "/var/log/zulip/server.log"

    def retention_days() -> int:
        # Looked up lazily: the common no-date, single-file path never needs it.
        return int(
            get_config(get_config_file(), "application_server", "access_log_retention_days", "14")
        )

    for term in args.filter_terms:
        date_term = re.match(r"2\d\d\d-\d\d-\d\d", term)
        if not date_term:
            continue
        # They're limiting to a specific day; find the right logfile
        # which is going to have any matches
        today = datetime.now(tz=timezone.utc).date()
        rotations = (today - date.fromisoformat(date_term[0])).days
        if rotations <= 0:
            # Today -- or a date that is "tomorrow" relative to UTC -- can
            # only be in the live logfile; never build a negative rotation.
            return [base_path]
        access_log_retention_days = retention_days()
        if rotations > access_log_retention_days:
            raise RuntimeError(f"Date is too old (more than {access_log_retention_days} days ago)")
        if rotations == 1:
            return [f"{base_path}.1"]
        return [f"{base_path}.{rotations}.gz"]

    logfile_names = [base_path]
    if args.all_logs:
        # The live file plus one rotation per retained day; this matches the
        # upper bound that --log-files accepts.
        logfile_count = retention_days() + 1
    elif args.log_files is not None:
        logfile_count = args.log_files
    else:
        # Detect if there was a logfile rotation in the last
        # (min-hours)-ish hours, and if so include the previous
        # logfile as well.
        logfile_count = 1
        try:
            current_size = os.path.getsize(base_path)
            past_size = os.path.getsize(base_path + ".1")
            if current_size < (args.min_hours / 24.0) * past_size:
                logfile_count = 2
        except FileNotFoundError:
            pass
    for n in range(1, logfile_count):
        logname = f"{base_path}.{n}"
        if n > 1:
            logname += ".gz"
        logfile_names.append(logname)
    return logfile_names


# Two-digit month number ("01".."12") -> abbreviated month name, and the
# inverse.  calendar.month_abbr has an empty entry at index 0, so both tables
# also carry a "00" <-> "" pair.
month_no_to_name_lookup = {
    f"{number:02d}": abbreviation for number, abbreviation in enumerate(calendar.month_abbr)
}
month_name_to_no_lookup = dict(
    zip(month_no_to_name_lookup.values(), month_no_to_name_lookup.keys())
)


def convert_from_nginx_date(date: str) -> str:
    """Rewrite a ``DD/Mon/YYYY`` date (as captured from nginx lines) as ``YYYY-MM-DD``.

    Raises ValueError if ``date`` does not have exactly three "/"-separated
    parts, and KeyError if the month abbreviation is unknown.
    """
    day, abbreviation, year = date.split("/")
    return "-".join((year, month_name_to_no_lookup[abbreviation], day))


def convert_to_nginx_date(date: str) -> str:
    """Rewrite a ``YYYY-MM-DD`` date as the ``DD/Mon/YYYY`` form used in nginx lines.

    Raises ValueError if ``date`` does not have exactly three "-"-separated
    parts, and KeyError if the month is not a known two-digit number.
    """
    year, month_digits, day = date.split("-")
    return "/".join((day, month_no_to_name_lookup[month_digits], year))


def parse_filters(
    args: argparse.Namespace,
) -> tuple[set[FilterType], list[FilterFunc], list[str]]:
    """Classify each search term and build a predicate for it.

    Every positional term is matched against a series of heuristics (status
    code, user-id, IP, HTTP method, datetime prefix, hostname, path) and
    turned into a function that tests a parsed log line.  ``--client`` adds
    one more predicate on the user agent.

    Side effects on ``args``: ``nginx`` is switched on when ``--events`` or
    a ``502`` term is given, and ``all_lines`` is switched on when a path
    term is given.  The nginx decision is made before any term is compiled,
    so the result does not depend on the order of the terms.

    Args:
        args: Parsed arguments; reads ``filter_terms``, ``events``,
            ``nginx`` and ``client``.

    Returns:
        A tuple of (kinds of filter present, predicates to apply to each
        match, raw substrings usable as a cheap pre-filter on the line).

    Raises:
        RuntimeError: If a term matches none of the heuristics.
        SystemExit: Via ``parser().error`` for unsupported or duplicate kinds.
    """
    # The heuristics below are not intended to be precise -- they
    # certainly count things as "IPv4" or "IPv6" addresses that are
    # invalid.  However, we expect the input here to already be
    # reasonably well-formed.

    filter_types = set()
    filter_funcs = []
    filter_terms = []

    if args.events and not args.nginx:
        logging.warning("Adding --nginx -- /events requests do not appear in Django logs.")
        args.nginx = True

    # Decide this up front: datetime and user-id terms are compiled
    # differently for nginx, so flipping the mode mid-loop would leave
    # earlier terms built for the wrong log format.
    if not args.nginx and "502" in args.filter_terms:
        logging.warning("Adding --nginx -- 502's do not appear in Django logs.")
        args.nginx = True

    for filter_term in args.filter_terms:
        if re.match(r"[1-5][0-9][0-9]$", filter_term):
            filter_func = lambda m, t=filter_term: m["code"] == t
            filter_type = FilterType.STATUS
        elif re.match(r"[1-5]xx$", filter_term):
            # "4xx" -> any status starting with "4"
            filter_term = filter_term[0]
            filter_func = lambda m, t=filter_term: m["code"].startswith(t)
            filter_type = FilterType.STATUS
        elif re.match(r"\d+$", filter_term):
            if args.nginx:
                # error() prints usage and exits; it does not return.
                parser().error("Cannot parse user-ids with nginx logs; try without --nginx")
            filter_func = lambda m, t=filter_term: m["user_id"] == t
            filter_type = FilterType.USER_ID
        elif re.match(r"\d{1,3}(\.\d{1,3}){3}$", filter_term):
            filter_func = lambda m, t=filter_term: m["ip"] == t
            filter_type = FilterType.CLIENT_IP
        elif re.match(r"([a-f0-9:]+:+){1,7}[a-f0-9]+$", filter_term):
            filter_func = lambda m, t=filter_term: m["ip"] == t
            filter_type = FilterType.CLIENT_IP
        elif re.match(r"(DELETE|GET|HEAD|OPTIONS|PATCH|POST|PUT)$", filter_term):
            # Anchored so that e.g. "GETfoo" is not mistaken for a method.
            filter_func = lambda m, t=filter_term: m["method"].upper() == t
            filter_type = FilterType.METHOD
        elif re.match(r"(2\d\d\d-\d\d-\d\d)( \d(\d(:(\d(\d(:(\d\d?)?)?)?)?)?)?)?", filter_term):
            if args.nginx:
                # nginx lines carry "DD/Mon/YYYY:HH:MM:SS"; rewrite the term
                # so that it is a prefix of that form.
                datetime_parts = filter_term.split(" ")
                filter_term = ":".join(
                    [convert_to_nginx_date(datetime_parts[0]), *datetime_parts[1:]]
                )
                filter_func = lambda m, t=filter_term: f"{m['date']}:{m['time']}".startswith(t)
            else:
                filter_func = lambda m, t=filter_term: f"{m['date']} {m['time']}".startswith(t)
            filter_type = FilterType.DATETIME
        elif re.match(r"[a-z0-9]([a-z0-9-]*[a-z0-9])?$", filter_term.lower()):
            # A bare label: the whole hostname field in Python logs, or the
            # first label of the full hostname in nginx logs.
            filter_term = filter_term.lower()
            if args.nginx:
                filter_func = lambda m, t=filter_term: m["hostname"].startswith(t + ".")
            else:
                filter_func = lambda m, t=filter_term: m["hostname"] == t
            filter_type = FilterType.HOSTNAME
        elif re.match(r"[a-z0-9-]+(\.[a-z0-9-]+)+$", filter_term.lower()) and re.search(
            r"[a-z-]", filter_term.lower()
        ):
            if not args.nginx:
                parser().error("Cannot parse full domains with Python logs; try --nginx")
            filter_term = filter_term.lower()
            filter_func = lambda m, t=filter_term: m["hostname"] == t
            filter_type = FilterType.HOSTNAME
        elif re.match(r"/\S*$", filter_term):
            filter_func = lambda m, t=filter_term: m["path"] == t
            filter_type = FilterType.PATH
            # An explicit path should be shown even if its category is
            # normally hidden.
            args.all_lines = True
        else:
            raise RuntimeError(
                f"Can't parse {filter_term} as an IP, hostname, user-id, HTTP method, path, or status code."
            )
        if filter_type in filter_types:
            parser().error("Supplied the same type of value more than once, which cannot match!")
        filter_types.add(filter_type)
        filter_funcs.append(filter_func)
        filter_terms.append(filter_term)

    if args.client:
        filter_funcs.append(lambda m, t=args.client: t in m["user_agent"])
        filter_terms.append(args.client)

    return (filter_types, filter_funcs, filter_terms)


def passes_filters(
    string_filters: list[FilterFunc],
    match: Match[str],
    args: argparse.Namespace,
) -> bool:
    """Decide whether a parsed log line should be printed.

    The line must satisfy every predicate in ``string_filters``.  Beyond
    that, ``--all-lines`` accepts everything; otherwise the request path is
    assigned to the first category it matches and the corresponding
    command-line flag decides.  Paths in no category are shown unless
    ``--no-other`` was given.
    """
    for predicate in string_filters:
        if not predicate(match):
            return False

    if args.all_lines:
        return True

    request_path = match["path"]
    # (membership test, name of the args flag that enables the category);
    # checked in order, first hit wins.
    categories = (
        (lambda p: p.startswith("/static/"), "static"),
        (lambda p: p.startswith("/user_uploads/"), "uploads"),
        (lambda p: p.startswith(("/user_avatars/", "/avatar/")), "avatars"),
        (lambda p: re.match(r"/(json|api/v1)/events($|\?|/)", p) is not None, "events"),
        (lambda p: p in ("/api/v1/typing", "/json/typing"), "typing"),
        (lambda p: re.match(r"/(json|api/v1)/messages($|\?|/)", p) is not None, "messages"),
        (lambda p: p in ("/api/v1/users/me/presence", "/json/users/me/presence"), "presence"),
        (lambda p: p == "/error_tracing", "report"),
    )
    for belongs_to_category, flag_name in categories:
        if belongs_to_category(request_path):
            return getattr(args, flag_name)
    return not args.no_other


# End time of the most recently printed request; print_line updates it in
# --timeline mode to compute the gap (or overlap) before the next request.
last_match_end: datetime | None = None


def _duration_to_ms(duration: str) -> int:
    """Convert a logged request duration to whole milliseconds.

    Handles "217ms" and "1.7s", as well as a bare number of seconds with no
    unit such as "0.003" (presumably what nginx's request-time field looks
    like -- confirm against the configured log_format).
    """
    if duration.endswith("ms"):
        return int(duration.removesuffix("ms"))
    # Strip the unit only if it is actually there, and round rather than
    # truncate so that e.g. 1.001s is 1001ms, not 1000ms.
    return round(float(duration.removesuffix("s")) * 1000)


def print_line(
    match: Match[str],
    args: argparse.Namespace,
    filter_types: set[FilterType],
    use_color: bool,
) -> None:
    """Print one matching request, either verbatim or as a compact summary.

    With ``--full-line`` the original log line is echoed unchanged.
    Otherwise a single line is printed with timestamp, duration, optional
    user-id and client IP, a status indicator and code, the method, and the
    URL.  Columns whose value was itself a search term (hostname, user-id,
    client IP) are left out.  With ``--timeline``, the gap or overlap since
    the previously printed request and this request's start time are
    printed first.

    Args:
        match: Match of one log line against the nginx or Python line regex.
        args: Parsed arguments; reads ``full_line``, ``nginx``, ``all_logs``,
            ``log_files`` and ``timeline``.
        filter_types: Kinds of filter in effect.
        use_color: Whether to emit terminal color/bold escape sequences.

    Side effects:
        Writes to stdout; updates the module-level ``last_match_end`` when
        ``--timeline`` is set.
    """
    global last_match_end

    if args.full_line:
        print(match.group(0))
        return

    if args.nginx:
        date = convert_from_nginx_date(match["date"])
    else:
        date = match["date"]
    # Only prefix the date when several logfiles were explicitly requested.
    if args.all_logs or (args.log_files is not None and args.log_files > 1):
        ts = date + " " + match["time"]
    else:
        ts = match["time"]

    duration_ms = _duration_to_ms(match["duration"])

    # One-character marker and color per status class; 2xx/3xx stay plain.
    code = int(match["code"])
    indicator = " "
    color = ""
    if code == 401:
        indicator = ":"
        color = CYAN
    elif code == 499:
        indicator = "-"
        color = GRAY
    elif 400 <= code < 499:
        indicator = ">"
        color = OKBLUE
    elif 500 <= code <= 599:
        indicator = "!"
        color = FAIL

    if use_color:
        url = f"{BOLD}{match['full_path']}"
    else:
        url = match["full_path"]
        color = ""

    if FilterType.HOSTNAME not in filter_types:
        hostname = match["hostname"]
        if hostname is None:
            # The Python-log regex leaves hostname unset for some user forms.
            hostname = "???." + settings.EXTERNAL_HOST
        elif not args.nginx:
            # Python logs carry only a short name; expand it to a full host.
            if hostname != "root":
                hostname += "." + settings.EXTERNAL_HOST
            elif settings.EXTERNAL_HOST == "zulipchat.com":
                hostname = "zulip.com"
            else:
                hostname = settings.EXTERNAL_HOST
        url = "https://" + hostname + url

    user_id = ""
    if not args.nginx and match["user_id"] is not None:
        user_id = match["user_id"] + "@"

    if args.timeline:
        # The logged timestamp is treated as the end of the request; the
        # start is derived by subtracting the duration.
        logline_end = datetime.fromisoformat(date + " " + match["time"])
        logline_start = logline_end - timedelta(milliseconds=duration_ms)
        if last_match_end is not None:
            gap_ms = int((logline_start - last_match_end) / timedelta(milliseconds=1))
            if gap_ms > 5000:
                print()
                print(f"========== {int(gap_ms / 1000):>4} second gap ==========")
                print()
            elif gap_ms > 1000:
                print(f"============ {gap_ms:>5}ms gap ============")
            elif gap_ms > 0:
                print(f"------------ {gap_ms:>5}ms gap ------------")
            else:
                print(f"!!!!!!!!!! {abs(gap_ms):>5}ms overlap !!!!!!!!!!")
        if args.all_logs or (args.log_files is not None and args.log_files > 1):
            print(logline_start.isoformat(" ", timespec="milliseconds") + " (start)")
        else:
            print(logline_start.time().isoformat(timespec="milliseconds") + " (start)")
        last_match_end = logline_end

    # None entries are columns suppressed for this invocation.
    parts = [
        ts,
        f"{duration_ms:>5}ms",
        f"{user_id:7}" if not args.nginx and FilterType.USER_ID not in filter_types else None,
        f"{match['ip']:39}" if FilterType.CLIENT_IP not in filter_types else None,
        indicator + match["code"],
        f"{match['method']:6}",
        url,
    ]

    print(color + " ".join(p for p in parts if p is not None) + (ENDC if use_color else ""))


# Run the search only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
