#!/usr/bin/env python3

import contextlib
import sys
import time
from collections.abc import Iterable, Sequence
from typing import Any

# Make the deployed Zulip tree importable so that `scripts` and `zproject`
# can be found when this exporter is run as a standalone script.
sys.path.append("/home/zulip/deployments/current")
from scripts.lib.setup_path import setup_path

# Must run before the third-party imports below.
# NOTE(review): presumably this adjusts sys.path so those packages resolve
# from the deployment's environment -- confirm against scripts.lib.setup_path.
setup_path()

import bmemcached
from prometheus_client import start_http_server
from prometheus_client.core import REGISTRY, CounterMetricFamily, GaugeMetricFamily
from prometheus_client.metrics_core import Metric
from prometheus_client.registry import Collector
from prometheus_client.samples import Sample
from typing_extensions import override

from zproject import settings


class MemcachedCollector(Collector):
    """Prometheus collector that queries memcached's ``stats`` on every scrape.

    The server location and client options are read from
    ``settings.CACHES["default"]``.  Each call to :meth:`collect` creates a
    new client, emits ``memcached_*`` metric families built from the general,
    ``settings``, ``slabs``, ``items`` and ``sizes`` stats, and disconnects
    the client again.
    """

    @override
    def collect(self) -> Iterable[Metric]:
        """Yield the ``memcached_up`` gauge followed by all stats-derived metrics.

        If constructing the client raises, only ``memcached_up`` (with value
        0) is yielded.  Otherwise the client is disconnected when the
        generator finishes, including when reading or parsing stats raises.
        """
        cache: dict[str, Any] = settings.CACHES["default"]
        client = None
        with contextlib.suppress(Exception):
            client = bmemcached.Client((cache["LOCATION"],), **cache["OPTIONS"])
        # NOTE(review): this only detects a failure to construct the client;
        # confirm whether bmemcached connects eagerly, otherwise an
        # unreachable server may still be reported as up.
        yield self._gauge("up", "If memcached is up", value=client is not None)

        if client is None:
            return

        try:
            yield from self._collect_stats(client)
        finally:
            # Always release the connection, even if a stat was missing or
            # malformed, so that failed scrapes do not leak sockets.
            client.disconnect_all()

    @staticmethod
    def _gauge(
        name: str,
        doc: str,
        value: float | bytes | None = None,
        labels: Sequence[str] | None = None,
    ) -> GaugeMetricFamily:
        """Build a ``memcached_``-prefixed gauge family.

        Pass ``value`` for a single unlabelled sample, or ``labels`` to get an
        empty family whose samples are added by the caller.
        """
        # Compare against None rather than relying on truthiness, so that a
        # legitimate zero (e.g. up=False) is exported as 0 instead of dropped.
        return GaugeMetricFamily(
            f"memcached_{name}", doc, float(value) if value is not None else None, labels
        )

    @staticmethod
    def _counter(
        name: str,
        doc: str,
        labels: Iterable[str] | None = None,
    ) -> CounterMetricFamily:
        """Build an empty ``memcached_``-prefixed counter family."""
        return CounterMetricFamily(
            f"memcached_{name}", doc, labels=list(labels) if labels is not None else None
        )

    @staticmethod
    def _counter_value(
        name: str,
        doc: str,
        value: bytes | float,
        labels: dict[str, str] | None = None,
    ) -> CounterMetricFamily:
        """Build a counter family holding exactly one sample.

        ``labels`` maps label names to the label values of that sample.
        """
        if labels is None:
            labels = {}
        metric = MemcachedCollector._counter(name, doc, labels=labels.keys())
        # CounterMetricFamily strips off a trailing "_total" from
        # the metric's .name, and force-appends "_total" when you
        # use .add_metric.  Since we have counters that don't end
        # in _total, manually add samples, re-appending _total
        # only if we originally ended with it.
        append_total = name.endswith("_total")
        metric.samples.append(
            Sample(metric.name + ("_total" if append_total else ""), labels, float(value), None)
        )
        return metric

    def _collect_stats(self, client: bmemcached.Client) -> Iterable[Metric]:
        """Yield every metric derived from the server's ``stats`` output.

        Only the stats of the first server returned by ``client.stats()`` are
        used.  The caller is responsible for disconnecting ``client``.
        """
        gauge = self._gauge
        counter = self._counter
        counter_value = self._counter_value

        # client.stats() returns a mapping with one entry per server; take
        # the first (and, with a single LOCATION, only) entry.
        raw_stats = client.stats()
        stats: dict[str, bytes] = next(iter(raw_stats.values()))

        version_gauge = gauge(
            "version", "The version of this memcached server.", labels=["version"]
        )
        version_gauge.add_metric(value=1, labels=[stats["version"].decode()])
        yield version_gauge
        yield counter_value(
            "uptime_seconds", "Number of seconds since the server started.", value=stats["uptime"]
        )

        commands_counter = counter(
            "commands_total",
            "Total number of all requests broken down by command (get, set, etc.) and status.",
            labels=["command", "status"],
        )
        for op in ("get", "delete", "incr", "decr", "cas", "touch"):
            commands_counter.add_metric(value=float(stats[f"{op}_hits"]), labels=[op, "hit"])
            commands_counter.add_metric(value=float(stats[f"{op}_misses"]), labels=[op, "miss"])
        commands_counter.add_metric(value=float(stats["cas_badval"]), labels=["cas", "badval"])
        commands_counter.add_metric(value=float(stats["cmd_flush"]), labels=["flush", "hit"])

        # memcached includes cas operations again in cmd_set, so subtract
        # every cas outcome (miss, hit and badval) to get plain sets.
        commands_counter.add_metric(
            value=int(stats["cmd_set"])
            - (int(stats["cas_misses"]) + int(stats["cas_hits"]) + int(stats["cas_badval"])),
            labels=["set", "hit"],
        )
        yield commands_counter

        yield counter_value(
            "process_user_cpu_seconds_total",
            "Accumulated user time for this process.",
            value=stats["rusage_user"],
        )

        yield counter_value(
            "process_system_cpu_seconds_total",
            "Accumulated system time for this process.",
            value=stats["rusage_system"],
        )

        yield gauge(
            "current_bytes", "Current number of bytes used to store items.", value=stats["bytes"]
        )
        yield gauge(
            "limit_bytes",
            "Number of bytes this server is allowed to use for storage.",
            value=stats["limit_maxbytes"],
        )
        yield gauge(
            "current_items",
            "Current number of items stored by this instance.",
            value=stats["curr_items"],
        )
        yield counter_value(
            "items_total",
            "Total number of items stored during the life of this instance.",
            value=stats["total_items"],
        )

        yield counter_value(
            "read_bytes_total",
            "Total number of bytes read by this server from network.",
            value=stats["bytes_read"],
        )
        yield counter_value(
            "written_bytes_total",
            "Total number of bytes sent by this server to network.",
            value=stats["bytes_written"],
        )

        yield gauge(
            "current_connections",
            "Current number of open connections.",
            value=stats["curr_connections"],
        )
        yield counter_value(
            "connections_total",
            "Total number of connections opened since the server started running.",
            value=stats["total_connections"],
        )
        yield counter_value(
            "connections_rejected_total",
            "Total number of connections rejected due to hitting the memcached's -c limit in maxconns_fast mode.",
            value=stats["rejected_connections"],
        )
        yield counter_value(
            "connections_yielded_total",
            "Total number of connections yielded running due to hitting the memcached's -R limit.",
            value=stats["conn_yields"],
        )
        yield counter_value(
            "connections_listener_disabled_total",
            "Number of times that memcached has hit its connections limit and disabled its listener.",
            value=stats["listen_disabled_num"],
        )

        yield counter_value(
            "items_evicted_total",
            "Total number of valid items removed from cache to free memory for new items.",
            value=stats["evictions"],
        )
        yield counter_value(
            "items_reclaimed_total",
            "Total number of times an entry was stored using memory from an expired entry.",
            value=stats["reclaimed"],
        )
        # These two stats are only exported when the server reports them.
        if "store_too_large" in stats:
            yield counter_value(
                "item_too_large_total",
                "The number of times an item exceeded the max-item-size when being stored.",
                value=stats["store_too_large"],
            )
        if "store_no_memory" in stats:
            yield counter_value(
                "item_no_memory_total",
                "The number of times an item could not be stored due to no more memory.",
                value=stats["store_no_memory"],
            )

        raw_stats = client.stats("settings")
        settings_stats = next(iter(raw_stats.values()))
        yield counter_value(
            "max_connections",
            "Maximum number of clients allowed.",
            value=settings_stats["maxconns"],
        )
        yield counter_value(
            "max_item_size_bytes",
            "Maximum item size.",
            value=settings_stats["item_size_max"],
        )

        raw_stats = client.stats("slabs")
        slab_stats = next(iter(raw_stats.values()))
        yield counter_value(
            "malloced_bytes",
            "Number of bytes of memory allocated to slab pages.",
            value=slab_stats["total_malloced"],
        )

        # Per-slab keys look like "<slab>:<stat>"; keys without a colon are
        # totals and do not name a slab.
        slabs = {key.split(":", 1)[0] for key in slab_stats if ":" in key}
        slab_commands = counter(
            "slab_commands_total",
            "Total number of all requests broken down by command (get, set, etc.) and status per slab.",
            labels=["slab", "command", "status"],
        )
        for slab_no in slabs:
            for op in ("get", "delete", "incr", "decr", "cas", "touch"):
                slab_commands.add_metric(
                    labels=[slab_no, op, "hit"],
                    value=float(slab_stats[f"{slab_no}:{op}_hits"]),
                )
            slab_commands.add_metric(
                labels=[slab_no, "cas", "badval"],
                value=float(slab_stats[f"{slab_no}:cas_badval"]),
            )
            slab_commands.add_metric(
                labels=[slab_no, "set", "hit"],
                value=float(slab_stats[f"{slab_no}:cmd_set"])
                - (
                    float(slab_stats[f"{slab_no}:cas_hits"])
                    + float(slab_stats[f"{slab_no}:cas_badval"])
                ),
            )
        yield slab_commands

        def slab_counter(name: str, doc: str) -> CounterMetricFamily:
            return counter(f"slab_{name}", doc, labels=["slab"])

        def slab_gauge(name: str, doc: str) -> GaugeMetricFamily:
            return gauge(f"slab_{name}", doc, labels=["slab"])

        # Maps the per-slab stat suffix to the metric family it feeds.
        slab_metrics = {
            "chunk_size": slab_gauge("chunk_size_bytes", "The amount of space each chunk uses."),
            "chunks_per_page": slab_gauge(
                "chunks_per_page", "How many chunks exist within one page."
            ),
            "total_pages": slab_gauge(
                "current_pages", "Total number of pages allocated to the slab class."
            ),
            "total_chunks": slab_gauge(
                "current_chunks", "Total number of chunks allocated to the slab class."
            ),
            "used_chunks": slab_gauge(
                "chunks_used", "How many chunks have been allocated to items."
            ),
            "free_chunks": slab_gauge(
                "chunks_free", "Chunks not yet allocated to items, or freed via delete."
            ),
            "free_chunks_end": slab_gauge(
                "chunks_free_end", "Number of free chunks at the end of the last allocated page."
            ),
        }
        for slab_no in slabs:
            for key, slab_metric in slab_metrics.items():
                # A stat missing for this slab is reported as 0.
                slab_metric.samples.append(
                    Sample(
                        slab_metric.name,
                        {"slab": slab_no},
                        float(slab_stats.get(f"{slab_no}:{key}", b"0")),
                    )
                )
        for slab_metric in slab_metrics.values():
            yield slab_metric

        raw_stats = client.stats("items")
        item_stats = next(iter(raw_stats.values()))
        item_hits_counter = counter(
            "slab_lru_hits_total", "Number of get_hits to the LRU.", labels=["slab", "lru"]
        )
        for slab_no in slabs:
            for lru in ("hot", "warm", "cold", "temp"):
                item_hits_counter.add_metric(
                    labels=[slab_no, lru],
                    value=float(item_stats.get(f"items:{slab_no}:hits_to_{lru}", b"0")),
                )
        yield item_hits_counter

        # Maps the "items:<slab>:<stat>" suffix to the metric family it feeds.
        item_metrics = {
            "number": slab_gauge(
                "current_items", "Number of items presently stored in this class."
            ),
            "number_hot": slab_gauge(
                "hot_items", "Number of items presently stored in the HOT LRU."
            ),
            "number_warm": slab_gauge(
                "warm_items", "Number of items presently stored in the WARM LRU."
            ),
            "number_cold": slab_gauge(
                "cold_items", "Number of items presently stored in the COLD LRU."
            ),
            "number_temp": slab_gauge(
                "temporary_items", "Number of items presently stored in the TEMPORARY LRU."
            ),
            "age_hot": slab_gauge("hot_age_seconds", "Age of the oldest item in HOT LRU."),
            "age_warm": slab_gauge("warm_age_seconds", "Age of the oldest item in WARM LRU."),
            "age": slab_gauge("items_age_seconds", "Age of the oldest item in the LRU."),
            "mem_requested": slab_gauge(
                "mem_requested_bytes", "Number of bytes requested to be stored in this LRU."
            ),
            "evicted": slab_counter(
                "items_evicted_total",
                "Total number of times an item had to be evicted from the LRU before it expired.",
            ),
            "evicted_nonzero": slab_counter(
                "items_evicted_nonzero_total",
                "Number of times an item which had an explicit expire time set had to be evicted from the LRU before it expired.",
            ),
            "evicted_time": slab_gauge(
                "items_evicted_time_seconds",
                "Seconds since the last access for the most recent item evicted from this class.",
            ),
            "outofmemory": slab_counter(
                "items_outofmemory_total",
                " Number of times the underlying slab class was unable to store a new item.",
            ),
            "tailrepairs": slab_counter(
                "items_tailrepairs_total",
                "Number of times we self-healed a slab with a refcount leak.",
            ),
            "reclaimed": slab_counter(
                "items_reclaimed_total",
                "Number of times an entry was stored using memory from an expired entry.",
            ),
            "expired_unfetched": slab_counter(
                "items_expired_unfetched_total",
                "Number of expired items reclaimed from the LRU which were never touched after being set.",
            ),
            "evicted_unfetched": slab_counter(
                "items_evicted_unfetched_total",
                "Number of valid items evicted from the LRU which were never touched after being set.",
            ),
            "evicted_active": slab_counter(
                "items_evicted_active_total",
                "Number of valid items evicted from the LRU which were recently touched but were evicted before being moved to the top of the LRU again.",
            ),
            "crawler_reclaimed": slab_counter(
                "items_crawler_reclaimed_total", "Number of items freed by the LRU Crawler."
            ),
            "lrutail_reflocked": slab_counter(
                "items_lrutail_reflocked_total",
                "Number of items found to be refcount locked in the LRU tail.",
            ),
            "moves_to_cold": slab_counter(
                "moves_to_cold_total", "Number of items moved from HOT or WARM into COLD."
            ),
            "moves_to_warm": slab_counter(
                "moves_to_warm_total", "Number of items moved from COLD to WARM."
            ),
            "moves_within_lru": slab_counter(
                "moves_within_lru_total",
                "Number of times active items were bumped within HOT or WARM.",
            ),
        }
        for slab_no in slabs:
            for key, item_metric in item_metrics.items():
                # NOTE(review): for the counter families, .name has had its
                # "_total" suffix stripped (see _counter_value), so these
                # samples are named without it.  Left unchanged because
                # renaming would alter the exported series; confirm intent.
                item_metric.samples.append(
                    Sample(
                        item_metric.name,
                        {"slab": slab_no},
                        float(item_stats.get(f"items:{slab_no}:{key}", b"0")),
                    )
                )
        for item_metric in item_metrics.values():
            yield item_metric

        raw_stats = client.stats("sizes")
        sizes_stats = next(iter(raw_stats.values()))
        # The keys of the sizes stats are item sizes; skip the metric when the
        # server reports the feature as disabled or returns nothing.
        if sizes_stats.get("sizes_status") != b"disabled" and sizes_stats != {}:
            yield gauge(
                "item_max_bytes",
                "Largest item (rounded to 32 bytes) in bytes.",
                value=max(int(size) for size in sizes_stats),
            )


if __name__ == "__main__":
    # Register the memcached collector with the default registry and expose
    # it over HTTP on port 11212.
    collector = MemcachedCollector()
    REGISTRY.register(collector)
    start_http_server(11212)

    # Nothing else to do on the main thread: sleep indefinitely so the
    # process stays alive after start_http_server() has returned.
    while True:
        time.sleep(60)
