|
| 1 | +"""Fetch and cache the HF Inference Router model catalog. |
| 2 | +
|
| 3 | +The router exposes an OpenAI-compatible listing at |
| 4 | +``https://router.huggingface.co/v1/models`` with per-provider availability, |
| 5 | +pricing, context length, and tool-use support. We use it to: |
| 6 | +
|
| 7 | + • Validate ``/model`` switches with live data instead of a hard-coded allowlist. |
| 8 | + • Show the user which providers serve a model, at what price, and whether they |
| 9 | + support tool calls. |
| 10 | + • Derive a reasonable context-window limit for any routed model. |
| 11 | +
|
| 12 | +The listing is cached in-memory for a few minutes so repeated lookups during a |
| 13 | +session are free. On fetch failure we return stale data if we have it, or an |
| 14 | +empty catalog otherwise. |
| 15 | +""" |
| 16 | + |
| 17 | +import logging |
| 18 | +import time |
| 19 | +from dataclasses import dataclass |
| 20 | +from difflib import get_close_matches |
| 21 | +from typing import Optional |
| 22 | + |
| 23 | +import httpx |
| 24 | + |
| 25 | +logger = logging.getLogger(__name__) |
| 26 | + |
_CATALOG_URL = "https://router.huggingface.co/v1/models"  # OpenAI-compatible listing endpoint
_CACHE_TTL_SECONDS = 300  # how long a fetched catalog stays fresh (5 minutes)
_HTTP_TIMEOUT_SECONDS = 5.0  # keep lookups snappy; on timeout we fall back to stale/empty data

# Module-level in-memory cache: the last catalog payload returned by the
# router (or the empty fallback) and the time.time() at which it was stored.
_cache: Optional[dict] = None
_cache_time: float = 0.0
| 33 | + |
| 34 | + |
@dataclass
class ProviderInfo:
    """Availability and capability details for one provider serving a model.

    Mirrors a single element of the ``providers`` array in the router's
    ``/v1/models`` listing; built by ``_parse_entry``. Fields absent from the
    listing come through as ``None`` / ``""`` / ``False``.
    """

    # Provider slug as reported by the router (e.g. "together"); "" if missing.
    provider: str
    # Router-reported availability; only entries with status "live" are usable.
    status: str
    # Maximum context window on this provider, if reported.
    context_length: Optional[int]
    # Prices from the listing's "pricing" object. Units are whatever the
    # router reports — NOTE(review): presumably USD per token-unit; confirm
    # against the API before displaying.
    input_price: Optional[float]
    output_price: Optional[float]
    # Whether this provider supports tool/function calling.
    supports_tools: bool
    # Whether this provider supports structured (schema-constrained) output.
    supports_structured_output: bool
| 44 | + |
| 45 | + |
@dataclass
class ModelInfo:
    """A routed model id together with every provider entry listed for it."""

    id: str
    providers: list[ProviderInfo]

    @property
    def live_providers(self) -> list[ProviderInfo]:
        """Providers whose router status is exactly ``"live"``."""
        live = []
        for candidate in self.providers:
            if candidate.status == "live":
                live.append(candidate)
        return live

    @property
    def max_context_length(self) -> Optional[int]:
        """Largest reported context window among live providers, or ``None``.

        Unreported (``None``) and zero context lengths are ignored.
        """
        best: Optional[int] = None
        for candidate in self.live_providers:
            length = candidate.context_length
            if length and (best is None or length > best):
                best = length
        return best

    @property
    def any_supports_tools(self) -> bool:
        """True when at least one live provider advertises tool-call support."""
        for candidate in self.live_providers:
            if candidate.supports_tools:
                return True
        return False
| 63 | + |
| 64 | + |
def _fetch_catalog(force: bool = False) -> dict:
    """Return the raw router catalog JSON, fetching at most once per TTL.

    Args:
        force: When ``True``, bypass the freshness check and re-fetch now.

    Returns:
        The parsed listing (a dict with a ``data`` list). On fetch failure the
        previous (stale) payload is returned if one exists — its timestamp is
        left untouched so the next call retries. With no prior data an empty
        ``{"data": []}`` placeholder is returned instead.
    """
    global _cache, _cache_time
    # Retry window for the empty placeholder: caching it for the full TTL
    # would blank the catalog for 5 minutes after one transient outage.
    failure_retry_seconds = 30.0
    now = time.time()
    if not force and _cache is not None and now - _cache_time < _CACHE_TTL_SECONDS:
        return _cache
    try:
        resp = httpx.get(_CATALOG_URL, timeout=_HTTP_TIMEOUT_SECONDS)
        resp.raise_for_status()
        # ValueError in the except below covers json.JSONDecodeError here.
        payload = resp.json()
    except (httpx.HTTPError, ValueError) as e:
        logger.warning("Failed to fetch HF router catalog: %s", e)
        if _cache is None:
            # No stale data to fall back on: store the empty placeholder but
            # backdate it so the next lookup after a short delay retries.
            _cache = {"data": []}
            _cache_time = now - _CACHE_TTL_SECONDS + failure_retry_seconds
        return _cache
    _cache = payload
    _cache_time = now
    return _cache
| 81 | + |
| 82 | + |
def _parse_entry(entry: dict) -> ModelInfo:
    """Convert one raw catalog entry (a dict from ``data``) into a ModelInfo.

    Missing fields degrade to empty/False defaults; a null ``providers`` or
    ``pricing`` value is treated the same as an absent one.
    """
    raw_providers = entry.get("providers", []) or []
    parsed = [
        ProviderInfo(
            provider=raw.get("provider", ""),
            status=raw.get("status", ""),
            context_length=raw.get("context_length"),
            input_price=(raw.get("pricing") or {}).get("input"),
            output_price=(raw.get("pricing") or {}).get("output"),
            supports_tools=bool(raw.get("supports_tools", False)),
            supports_structured_output=bool(raw.get("supports_structured_output", False)),
        )
        for raw in raw_providers
    ]
    return ModelInfo(id=entry.get("id", ""), providers=parsed)
| 99 | + |
| 100 | + |
def lookup(model_id: str) -> Optional[ModelInfo]:
    """Find a model in the router catalog.

    Accepts ``<org>/<model>`` or ``<org>/<model>:<tag>`` — the tag is stripped
    for lookup. Returns ``None`` if the model isn't listed.
    """
    wanted, _, _tag = model_id.partition(":")
    entries = _fetch_catalog().get("data", [])
    match = next((e for e in entries if e.get("id") == wanted), None)
    return _parse_entry(match) if match is not None else None
| 113 | + |
| 114 | + |
def fuzzy_suggest(model_id: str, limit: int = 3) -> list[str]:
    """Return the closest model ids from the catalog."""
    # Strip any ":<tag>" suffix before matching, mirroring lookup().
    wanted, _, _tag = model_id.partition(":")
    catalog_ids = []
    for entry in _fetch_catalog().get("data", []):
        entry_id = entry.get("id", "")
        if entry_id:
            catalog_ids.append(entry_id)
    return get_close_matches(wanted, catalog_ids, n=limit, cutoff=0.4)
| 121 | + |
| 122 | + |
def prewarm() -> None:
    """Warm the in-memory catalog cache.

    Intended for fire-and-forget use from a background task: any exception is
    swallowed so a startup fetch failure never surfaces to the caller.
    """
    try:
        # Default (non-forced) fetch: a fresh cache entry is reused as-is.
        _fetch_catalog()
    except Exception:
        # _fetch_catalog already degrades gracefully; this is belt-and-braces
        # so a background prewarm can never raise.
        pass