Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 540437a

Browse files
authored
feat(quota): daily Opus cap + HF-org gate + cap dialog (#72)
1 parent 5d357ba commit 540437a

12 files changed

Lines changed: 794 additions & 39 deletions

File tree

‎backend/dependencies.py‎

Lines changed: 109 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL", "https://huggingface.co")
1818
AUTH_ENABLED = bool(os.environ.get("OAUTH_CLIENT_ID", ""))
19+
HF_EMPLOYEE_ORG = os.environ.get("HF_EMPLOYEE_ORG", "huggingface")
1920

2021
# Simple in-memory token cache: token -> (user_info, expiry_time)
2122
_token_cache: dict[str, tuple[dict[str, Any], float]] = {}
@@ -28,8 +29,13 @@
2829
"user_id": "dev",
2930
"username": "dev",
3031
"authenticated": True,
32+
"plan": "org", # Dev runs at the Pro/Org quota tier so local testing isn't capped.
3133
}
3234

35+
# Plan field discovery — log the whoami-v2 shape once at DEBUG so we can
36+
# confirm the actual key in production without hammering the HF API.
37+
_WHOAMI_SHAPE_LOGGED = False
38+
3339

3440
async def _validate_token(token: str) -> dict[str, Any] | None:
3541
"""Validate a token against HF OAuth userinfo endpoint.
@@ -74,12 +80,86 @@ def _user_from_info(user_info: dict[str, Any]) -> dict[str, Any]:
7480
}
7581

7682

83+
def _normalize_plan(whoami: dict[str, Any]) -> str:
84+
"""Map an HF /api/whoami-v2 payload to one of: 'free' | 'pro' | 'org'.
85+
86+
The exact field shape in whoami-v2 isn't documented for our purposes,
87+
so we try a handful of likely keys and fall back to 'free'. The first
88+
call logs the raw shape at DEBUG (see `_fetch_user_plan`) so we can
89+
pin the real key post-deploy.
90+
"""
91+
plan_str = ""
92+
for key in ("plan", "type", "accountType"):
93+
val = whoami.get(key)
94+
if isinstance(val, str) and val:
95+
plan_str = val.lower()
96+
break
97+
98+
if not plan_str:
99+
if whoami.get("isPro") is True or whoami.get("is_pro") is True:
100+
return "pro"
101+
102+
if "pro" in plan_str or "enterprise" in plan_str or "team" in plan_str:
103+
return "pro"
104+
105+
# Org tier: anyone in a paid / enterprise org. We don't pay for this
106+
# right now, but the "pro" cap applies identically.
107+
orgs = whoami.get("orgs") or []
108+
if isinstance(orgs, list):
109+
for org in orgs:
110+
if isinstance(org, dict):
111+
org_plan = str(org.get("plan") or org.get("type") or "").lower()
112+
if "pro" in org_plan or "enterprise" in org_plan or "team" in org_plan:
113+
return "org"
114+
115+
return "free"
116+
117+
118+
async def _fetch_user_plan(token: str) -> str:
119+
"""Look up the user's HF plan via /api/whoami-v2.
120+
121+
Returns 'free' | 'pro' | 'org'. Non-200, network errors, or an unknown
122+
payload shape all collapse to 'free' — safe default; we'd rather under-
123+
grant the Pro cap than over-grant it on bad data.
124+
"""
125+
global _WHOAMI_SHAPE_LOGGED
126+
async with httpx.AsyncClient(timeout=5.0) as client:
127+
try:
128+
resp = await client.get(
129+
f"{OPENID_PROVIDER_URL}/api/whoami-v2",
130+
headers={"Authorization": f"Bearer {token}"},
131+
)
132+
if resp.status_code != 200:
133+
return "free"
134+
whoami = resp.json()
135+
except httpx.HTTPError:
136+
return "free"
137+
except ValueError:
138+
return "free"
139+
140+
if not _WHOAMI_SHAPE_LOGGED:
141+
_WHOAMI_SHAPE_LOGGED = True
142+
logger.debug(
143+
"whoami-v2 payload keys: %s (sample values: plan=%r type=%r isPro=%r)",
144+
sorted(whoami.keys()) if isinstance(whoami, dict) else type(whoami).__name__,
145+
whoami.get("plan") if isinstance(whoami, dict) else None,
146+
whoami.get("type") if isinstance(whoami, dict) else None,
147+
whoami.get("isPro") if isinstance(whoami, dict) else None,
148+
)
149+
150+
if not isinstance(whoami, dict):
151+
return "free"
152+
return _normalize_plan(whoami)
153+
154+
77155
async def _extract_user_from_token(token: str) -> dict[str, Any] | None:
78156
"""Validate a token and return a user dict, or None."""
79157
user_info = await _validate_token(token)
80-
if user_info:
81-
return _user_from_info(user_info)
82-
return None
158+
if user_info is None:
159+
return None
160+
user = _user_from_info(user_info)
161+
user["plan"] = await _fetch_user_plan(token)
162+
return user
83163

84164

85165
async def check_org_membership(token: str, org_name: str) -> bool:
@@ -141,3 +221,29 @@ async def get_current_user(request: Request) -> dict[str, Any]:
141221
)
142222

143223

224+
def _extract_token(request: Request) -> str | None:
225+
"""Pull the HF access token from the Authorization header or cookie.
226+
227+
Mirrors the lookup order used by ``get_current_user``.
228+
"""
229+
auth_header = request.headers.get("Authorization", "")
230+
if auth_header.startswith("Bearer "):
231+
return auth_header[7:]
232+
return request.cookies.get("hf_access_token")
233+
234+
235+
async def require_huggingface_org_member(request: Request) -> bool:
236+
"""Return True if the caller is a member of the ``huggingface`` org.
237+
238+
Used to gate endpoints that can push a session onto an Anthropic model
239+
billed to the Space's ``ANTHROPIC_API_KEY``. Returns True unconditionally
240+
in dev mode so local testing isn't blocked.
241+
"""
242+
if not AUTH_ENABLED:
243+
return True
244+
token = _extract_token(request)
245+
if not token:
246+
return False
247+
return await check_org_membership(token, HF_EMPLOYEE_ORG)
248+
249+

0 commit comments

Comments
 (0)