#!/usr/bin/env python3
# Processes updates to PostgreSQL full-text search for new/edited messages.
#
# Zulip manages its PostgreSQL full-text search as follows.  When the
# content of a message is modified, a PostgreSQL trigger logs the
# message ID to the `fts_update_log` table.  In the background, this
# program processes `fts_update_log`, updating the PostgreSQL full-text
# search column search_tsvector in the main zerver_message table.
#
# There are four cases this has to cover:
#
# 1. Running in development, with a venv but no
#    /home/zulip/deployments/current nor /etc/zulip/zulip.conf
#
# 2. Running in production, with postgresql on the same machine as the
#    frontend, with a venv and both /home/zulip/deployments/current
#    and /etc/zulip/zulip.conf
#
# 3. Running in production, on a postgresql machine which is not the
#    frontend, with a /etc/zulip/zulip.conf but no venv nor
#    /home/zulip/deployments/current
#
# 4. Running in production, on an application frontend server connected
#    to a remote postgresql server, because we cannot run code _on_ the
#    PostgreSQL server, such as in docker-zulip.
#
# Because of case (3), we cannot rely on functions from outside this
# file (e.g. provided by scripts.lib.zulip_tools).  For case (1),
# however, we wish to import `zproject.settings` since we have no
# /etc/zulip/zulip.conf to read, and that case _requires_ loading the
# venv.
#
# We also must handle the cases where we are run as the `nagios` user,
# which may not have permission to read all configuration files, and
# thus (2) will look like (3).

import argparse
import configparser
import logging
import os
import select
import sys
import time

import psycopg2
import psycopg2.extensions

# Upper bound on the number of fts_update_log rows handled per batch.
BATCH_SIZE = 1000

parser = argparse.ArgumentParser()
for _boolean_flag in ("--quiet", "--nagios-check"):
    parser.add_argument(_boolean_flag, action="store_true")
parser.add_argument("--nagios-user")
options = parser.parse_args()

# Timestamps in log lines are rendered in UTC.
logging.Formatter.converter = time.gmtime
logging.basicConfig(format="%(asctime)s %(levelname)s: %(message)s")
logger = logging.getLogger("process_fts_updates")
# --quiet suppresses DEBUG output; INFO and above are always emitted.
logger.setLevel(logging.INFO if options.quiet else logging.DEBUG)


def update_fts_columns(cursor: psycopg2.extensions.cursor) -> int:
    """Reindex one batch of messages queued in ``fts_update_log``.

    Reads up to ``BATCH_SIZE`` rows from ``fts_update_log``, recomputes
    ``search_tsvector`` (and, when ``USING_PGROONGA`` is set,
    ``search_pgroonga``) for every referenced message, and finally
    deletes the log rows that were processed.

    Args:
        cursor: Cursor on a connection to the database that holds
            ``fts_update_log`` and ``zerver_message``.

    Returns:
        The number of log rows processed; 0 when the log was empty.
    """
    cursor.execute(
        "SELECT id, message_id FROM fts_update_log LIMIT %s;",
        [BATCH_SIZE],
    )
    log_ids = []
    for log_id, message_id in cursor.fetchall():
        if USING_PGROONGA:
            cursor.execute(
                "UPDATE zerver_message SET "
                "search_pgroonga = "
                "escape_html(subject) || ' ' || rendered_content "
                "WHERE id = %s",
                (message_id,),
            )
        cursor.execute(
            "UPDATE zerver_message SET "
            "search_tsvector = to_tsvector('zulip.english_us_search', "
            "subject || rendered_content) "
            "WHERE id = %s",
            (message_id,),
        )
        log_ids.append(log_id)

    if not log_ids:
        # This function is invoked once per received notification, so an
        # empty batch is routine (an earlier call may already have drained
        # the log).  Skip the DELETE, which would match nothing anyway.
        return 0

    cursor.execute("DELETE FROM fts_update_log WHERE id = ANY(%s)", (log_ids,))
    return len(log_ids)


def am_master(cursor: psycopg2.extensions.cursor) -> bool:
    """Return True when the connected server is not in recovery.

    Runs ``SELECT pg_is_in_recovery()`` on ``cursor`` and returns the
    negation of the value in the first column of the first row.
    """
    cursor.execute("SELECT pg_is_in_recovery()")
    first_row = cursor.fetchall()[0]
    in_recovery = first_row[0]
    return not in_recovery


def get_config(
    config_file: configparser.RawConfigParser,
    section: str,
    key: str,
    default_value: str = "",
) -> str:
    """Look up ``key`` in ``section`` of an already-parsed config file.

    Args:
        config_file: Parser to query.
        section: Name of the section to look in.
        key: Name of the option within that section.
        default_value: Value returned when the option is not present.

    Returns:
        The configured value, or ``default_value`` when
        ``config_file.has_option(section, key)`` is false.
    """
    if not config_file.has_option(section, key):
        return default_value
    return config_file.get(section, key)


# Keyword arguments later passed to psycopg2.connect(); filled in by
# whichever of the two configuration paths below is taken.
pg_args = {}

# Whether update_fts_columns should also refresh the search_pgroonga
# column; off unless the settings / zulip.conf say otherwise.
USING_PGROONGA = False
try:
    # Case (1); we insert the path to the development root.
    sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../..")))

    # Cases (2) and (4); we insert the path to the production root.
    # This likely works out the same as the above path.
    #
    # We insert this path after the above, so that if running this
    # command from a specific non-current Zulip deployment, we prefer
    # that deployment's libraries.
    sys.path.insert(1, "/home/zulip/deployments/current")

    # For cases (2) and (4), we also need to set up the virtualenv, so we
    # can read the Django settings.
    from scripts.lib.setup_path import setup_path

    setup_path()

    # Must be set before the Django settings object is first used.
    os.environ["DJANGO_SETTINGS_MODULE"] = "zproject.settings"
    from django.conf import settings

    # Connect with the same parameters as Django's "default" database.
    # PORT, PASSWORD and sslmode are read with .get(), so they may be None.
    pg_args["host"] = settings.DATABASES["default"]["HOST"]
    pg_args["port"] = settings.DATABASES["default"].get("PORT")
    pg_args["password"] = settings.DATABASES["default"].get("PASSWORD")
    pg_args["user"] = settings.DATABASES["default"]["USER"]
    pg_args["dbname"] = settings.DATABASES["default"]["NAME"]
    pg_args["sslmode"] = settings.DATABASES["default"]["OPTIONS"].get("sslmode")
    pg_args["connect_timeout"] = "600"
    USING_PGROONGA = settings.USING_PGROONGA
except (ImportError, PermissionError):
    # Case (3): we know that the PostgreSQL server is on this host.
    #
    # We get here when the Zulip / Django code cannot be imported, or when
    # something above raised PermissionError (presumably unreadable
    # configuration when run as the `nagios` user -- confirm).
    #
    # NOTE(review): this assignment is unconditionally overwritten by the
    # database_user lookup further down.
    pg_args["user"] = "zulip"

    # RawConfigParser.read() silently skips files it cannot open, in which
    # case every get_config() call below returns its default.
    config_file = configparser.RawConfigParser()
    config_file.read("/etc/zulip/zulip.conf")

    # Accept the usual spellings of "true" for machine.pgroonga.
    if get_config(config_file, "machine", "pgroonga", "false").lower() in [
        "1",
        "y",
        "t",
        "yes",
        "true",
        "enable",
        "enabled",
    ]:
        USING_PGROONGA = True

    pg_args["user"] = get_config(config_file, "postgresql", "database_user", "zulip")
    if pg_args["user"] != "zulip":
        # A non-default database user connects to localhost with the
        # password from the secrets file.  For the default `zulip` user no
        # host or password is set (presumably relying on local socket
        # authentication -- confirm).
        secrets_file = configparser.RawConfigParser()
        secrets_file.read("/etc/zulip/zulip-secrets.conf")
        pg_args["password"] = get_config(secrets_file, "secrets", "postgres_password")
        pg_args["host"] = "localhost"
    pg_args["dbname"] = get_config(config_file, "postgresql", "database_name", "zulip")

conn: psycopg2.extensions.connection | None

if options.nagios_check:
    # Nagios mode: report how many rows are waiting in fts_update_log, then
    # exit with the matching Nagios status code.
    #
    # Nagios connects as itself, unless you specify otherwise
    if not options.nagios_user:
        pg_args.pop("user")
    else:
        pg_args["user"] = options.nagios_user

    # connection_factory=None lets mypy understand the return type
    conn = psycopg2.connect(connection_factory=None, **pg_args)
    cursor = conn.cursor()
    cursor.execute("SELECT count(*) FROM fts_update_log")
    count_row = cursor.fetchall()[0]
    num = count_row[0]

    # nagios exit codes
    states = {"OK": 0, "WARNING": 1, "CRITICAL": 2, "UNKNOWN": 3}

    # More than 5 pending rows is reported as CRITICAL.
    state = "CRITICAL" if num > 5 else "OK"
    print(f"{state}: {num} rows in fts_update_log table")
    sys.exit(states[state])


# No connection yet; the loop below (re)connects whenever this is None.
conn = None

# Number of OperationalErrors tolerated before giving up.  It starts at 1 so
# that a failure of the very first connection attempt is fatal, and is
# raised to 30 each time a connection is successfully established.
retries = 1

while True:
    try:
        if conn is None:
            # connection_factory=None lets mypy understand the return type
            conn = psycopg2.connect(connection_factory=None, **pg_args)
            cursor = conn.cursor()
            retries = 30

            conn.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)

            # Wait (polling every 5 seconds) until the server is out of
            # recovery; warn only once per connection.
            first_check = True
            while not am_master(cursor):
                if first_check:
                    first_check = False
                    logger.warning("In recovery; sleeping")
                time.sleep(5)

            logger.debug("process_fts_updates: listening for search index updates")

            # Subscribe before draining the backlog, so that rows logged
            # while we catch up still produce a notification.
            cursor.execute("LISTEN fts_update_log;")
            # Catch up on any rows already waiting in fts_update_log
            while True:
                rows_updated = update_fts_columns(cursor)
                logger.log(
                    logging.INFO if rows_updated > 0 else logging.DEBUG,
                    "process_fts_updates: Processed %d rows catching up",
                    rows_updated,
                )

                if rows_updated != BATCH_SIZE:
                    # We're caught up, so proceed to the listening for updates phase.
                    break

        # Block for up to 30 seconds waiting for the connection's socket to
        # become readable, then process one batch per queued notification.
        # TODO: If we go back into recovery, we should stop processing updates
        if select.select([conn], [], [], 30) != ([], [], []):
            conn.poll()
            while conn.notifies:
                conn.notifies.pop()
                update_fts_columns(cursor)
    except psycopg2.OperationalError as e:
        retries -= 1
        if retries <= 0:
            raise
        # e.pgerror is None when no backend error message is available;
        # fall back to the exception itself so the log line is never a
        # bare "None".
        logger.info("%s", e.pgerror or e, exc_info=True)
        logger.info("Sleeping and reconnecting")
        time.sleep(5)
        if conn is not None:
            conn.close()
            conn = None
    except KeyboardInterrupt:
        print(sys.argv[0], "exited after receiving KeyboardInterrupt")
        break
