#!/usr/bin/env python3

"""Nagios plugin to check the difference between the primary and
replica PostgreSQL servers' xlog location.  Requires that the user this
connects to PostgreSQL as has been granted the `pg_monitor` role.

This can only use stdlib modules from python.
"""

import configparser
import re
import subprocess
import sys


def get_config(
    config_file: configparser.RawConfigParser,
    section: str,
    key: str,
    default_value: str = "",
) -> str:
    """Look up ``key`` in ``section`` of ``config_file``.

    Returns the stored value when the option exists, and ``default_value``
    otherwise.
    """
    option_present = config_file.has_option(section, key)
    return config_file.get(section, key) if option_present else default_value


# Zulip's configuration file, parsed once at startup; consulted via get_config().
config_file = configparser.RawConfigParser()
config_file.read("/etc/zulip/zulip.conf")

# Exit status associated with each state label that may be reported.
states = dict(OK=0, WARNING=1, CRITICAL=2, UNKNOWN=3)

# Highest exit status reported so far; the script exits with this value.
MAXSTATE = 0


def report(state: str, msg: str) -> None:
    """Print a ``STATE: message`` line and remember the worst state seen.

    ``state`` must be one of the keys of ``states``; the module-level
    ``MAXSTATE`` is raised to that state's exit status if it is higher than
    the current value.
    """
    global MAXSTATE
    print(state, msg, sep=": ")
    severity = states[state]
    if severity > MAXSTATE:
        MAXSTATE = severity


def run_sql_query(query: str) -> list[list[str]]:
    """Run ``SELECT <query>`` through ``psql`` and return the resulting rows.

    Args:
        query: Everything that follows the ``SELECT`` keyword.

    Returns:
        One list of column strings per output row; an empty list when psql
        printed nothing.

    If psql cannot be started or exits with a non-zero status, a CRITICAL
    line is reported and the process exits with the current ``MAXSTATE``.
    """
    command = [
        "psql",
        "-t",  # Omit header line
        "-A",  # Don't pad with spaces
        "-z",  # Separate columns with NUL bytes
        "-v",
        "ON_ERROR_STOP=1",
        "-d",
        get_config(config_file, "postgresql", "database_name", "zulip"),
        # No -U; nagios connects as itself
        "-c",
        f"SELECT {query}",
    ]
    try:
        # Keep stderr out of stdout: anything psql writes to stderr on an
        # otherwise successful run must not be parsed as result rows.
        output = subprocess.check_output(command, stderr=subprocess.PIPE, text=True).strip()
    except subprocess.CalledProcessError as e:
        report("CRITICAL", f"psql failed: {e}: {e.output}{e.stderr}")
        sys.exit(MAXSTATE)
    except OSError as e:
        # psql could not be executed at all (e.g. it is not installed).
        report("CRITICAL", f"psql failed: {e}")
        sys.exit(MAXSTATE)

    if not output:
        return []
    return [x.split("\0") for x in output.split("\n")]


def loc_to_abs_offset(loc_str: str) -> int:
    """Convert a textual WAL location such as ``16/B374D848`` to a byte offset.

    Args:
        loc_str: Two hexadecimal numbers separated by ``/``; surrounding
            whitespace is ignored.

    Returns:
        The absolute position in the WAL stream, in bytes.

    Raises:
        ValueError: If ``loc_str`` is not of the form ``<hex>/<hex>``.
    """
    m = re.match(r"^\s*([0-9a-fA-F]+)/([0-9a-fA-F]+)\s*$", loc_str)
    if not m:
        raise ValueError("Unknown xlog location format: " + loc_str)
    (high_word, low_word) = (m.group(1), m.group(2))

    # Since PostgreSQL 9.3 the WAL is one continuous 64-bit byte stream, and
    # an LSN is printed as the high and low 32-bit halves of that position.
    # (PostgreSQL 9.2 and earlier skipped the last 16MB segment of every 4GB
    # "file", which called for a 0xFF000000 multiplier; the queries in this
    # script use functions that only exist in far newer releases, so that
    # layout never applies here.)
    return (int(high_word, 16) << 32) + int(low_word, 16)


# Lag thresholds in bytes, as multiples of the usual 16MB WAL segment size:
# more than one segment behind is a warning, more than five is critical.
_WARNING_LAG_BYTES = 16 * 1024**2
_CRITICAL_LAG_BYTES = 5 * 16 * 1024**2


def _lag_state(lag_bytes: int) -> str:
    """Return the state label to report for a lag of ``lag_bytes`` bytes."""
    if lag_bytes > _CRITICAL_LAG_BYTES:
        return "CRITICAL"
    if lag_bytes > _WARNING_LAG_BYTES:
        return "WARNING"
    return "OK"


is_in_recovery = run_sql_query("pg_is_in_recovery()")

if is_in_recovery[0][0] == "t":
    # This server is in recovery, i.e. acting as a replica: compare what it
    # has received with what it has replayed.
    replication_info = run_sql_query(
        "sender_host, status, pg_last_wal_replay_lsn(), pg_last_wal_receive_lsn()"
        " from pg_stat_wal_receiver"
    )
    if not replication_info:
        report("CRITICAL", "Replaying WAL logs from backups")
    else:
        (primary_server, state, replay_loc, recv_loc) = replication_info[0]

        if state != "streaming":
            report("CRITICAL", f"replica is in state {state}, not streaming")

        try:
            replay_lag = loc_to_abs_offset(recv_loc) - loc_to_abs_offset(replay_loc)
        except ValueError as e:
            # A location that cannot be parsed (for example an empty column)
            # must produce a reported state, not an uncaught traceback.
            report("UNKNOWN", f"cannot compute replay lag of WAL logs from {primary_server}: {e}")
        else:
            msg = f"replica is {replay_lag} bytes behind in replay of WAL logs from {primary_server}"
            report(_lag_state(replay_lag), msg)

else:
    # This server is not in recovery, i.e. acting as the primary: check every
    # connection listed in pg_stat_replication.
    replication_info = run_sql_query(
        "client_addr, state, sent_lsn, write_lsn, flush_lsn, replay_lsn from pg_stat_replication"
    )
    if not replication_info:
        report("CRITICAL", "No replicas!")
    elif len(replication_info) == 1:
        report("WARNING", "Only one replica!")
    else:
        report("OK", f"Found {len(replication_info)} replicas")

    for replica in replication_info:
        (client_addr, state, sent_lsn, write_lsn, flush_lsn, replay_lsn) = replica
        if state != "streaming":
            report("CRITICAL", f"replica {client_addr} is in state {state}, not streaming")

        try:
            sent_offset = loc_to_abs_offset(sent_lsn)
        except ValueError as e:
            # Without a sent location none of the lags can be computed.
            report("UNKNOWN", f"replica {client_addr} has no usable sent WAL location: {e}")
            continue

        stage_locations = (
            ("write", write_lsn),
            ("flush", flush_lsn),
            ("replay", replay_lsn),
        )
        for lag_type, stage_lsn in stage_locations:
            try:
                lag_bytes = sent_offset - loc_to_abs_offset(stage_lsn)
            except ValueError as e:
                report(
                    "UNKNOWN",
                    f"replica {client_addr} has no usable {lag_type} WAL location: {e}",
                )
                continue
            msg = f"replica {client_addr} is {lag_bytes} bytes behind in {lag_type} of WAL logs"
            report(_lag_state(lag_bytes), msg)

sys.exit(MAXSTATE)
