100 lines
3.4 KiB
Python
100 lines
3.4 KiB
Python
"""JWT token creation and verification for API auth."""
|
|
|
|
import os
|
|
import uuid
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import Optional
|
|
|
|
import jwt
|
|
|
|
def _get_secret_key() -> str:
|
|
"""Load JWT secret - from env, file, or auto-generated."""
|
|
if os.environ.get("TESTING", "").lower() in ("1", "true"):
|
|
return os.environ.get("JWT_SECRET_KEY", "test-jwt-secret-key-minimum-32-chars!!")
|
|
from app.secrets import get_jwt_secret
|
|
key = get_jwt_secret()
|
|
if len(key) < 32:
|
|
import warnings as _warnings
|
|
_warnings.warn(
|
|
f"JWT_SECRET_KEY is {len(key)} chars — minimum 32 recommended",
|
|
UserWarning, stacklevel=2,
|
|
)
|
|
return key
|
|
|
|
|
|
_SECRET_KEY_CACHE: Optional[str] = None
|
|
|
|
ALGORITHM = "HS256"
|
|
ACCESS_TOKEN_EXPIRE_HOURS = 24 # 24 hours
|
|
|
|
|
|
def _get_cached_secret_key() -> str:
|
|
"""Return the JWT secret, caching after first call.
|
|
|
|
The cache is reset when TESTING env var is set so that each test
|
|
module picks up the correct JWT_SECRET_KEY from monkeypatch/env.
|
|
"""
|
|
global _SECRET_KEY_CACHE
|
|
# In test mode, always re-read from env to respect monkeypatch
|
|
if os.environ.get("TESTING", "").lower() in ("1", "true"):
|
|
return os.environ.get("JWT_SECRET_KEY", "test-jwt-secret-key-minimum-32-chars!!")
|
|
if _SECRET_KEY_CACHE is None:
|
|
_SECRET_KEY_CACHE = _get_secret_key()
|
|
return _SECRET_KEY_CACHE
|
|
|
|
|
|
def create_access_token(
    user_id: str,
    email: str,
    expires_delta: Optional[timedelta] = None,
    token_id: Optional[str] = None,
    typ: str = "session",
    omit_exp: bool = False,
    extra_claims: Optional[dict] = None,
) -> str:
    """Create a signed JWT. `typ` is "session" (interactive login) or "pat" (long-lived).

    If `omit_exp=True`, no `exp` claim is embedded. This is used by PATs with
    "no expiry" -- the authoritative expiry check is the DB row in
    `personal_access_tokens.expires_at`, and a claim-less JWT avoids the
    misleading ~100y horizon that previously pretended to be "never".

    `extra_claims` merges arbitrary key/value pairs into the JWT payload
    after the reserved identity/metadata claims. Reserved keys (sub, email,
    typ, iat, jti, exp) are protected: they are silently dropped from
    `extra_claims` and cannot override the values set here. `exp` stays
    reserved even when `omit_exp=True`, so an expiry cannot be injected that
    way.

    No ``role`` claim -- authorization is derived from
    ``user_group_members`` at request time via ``app.auth.access.is_user_admin``.
    The JWT carries only identity (``sub``, ``email``) and token metadata.

    Args:
        user_id: Stored as the ``sub`` claim.
        email: Stored as the ``email`` claim.
        expires_delta: Token lifetime measured from the issue time. ``None``
            selects the default of ``ACCESS_TOKEN_EXPIRE_HOURS`` hours. An
            explicit ``timedelta(0)`` (or a negative delta) is honoured and
            yields an already-expired token. Ignored when ``omit_exp`` is true.
        token_id: Value for the ``jti`` claim. When falsy, a random
            ``uuid4().hex`` is used instead.
        typ: Stored verbatim as the ``typ`` claim; not validated here.
        omit_exp: When true, no ``exp`` claim is written.
        extra_claims: Additional non-reserved claims to include.

    Returns:
        The token produced by ``jwt.encode``, signed with ``ALGORITHM`` and
        the cached secret.
    """
    # A single timestamp for both claims, so that exp - iat equals the
    # requested lifetime exactly.
    issued_at = datetime.now(timezone.utc)
    payload: dict = {
        "sub": user_id,
        "email": email,
        "typ": typ,
        "iat": issued_at,
        "jti": token_id or uuid.uuid4().hex,
    }

    if not omit_exp:
        # Compare against None, not truthiness: timedelta(0) is falsy and
        # must not silently turn into the 24h default.
        lifetime = (
            timedelta(hours=ACCESS_TOKEN_EXPIRE_HOURS)
            if expires_delta is None
            else expires_delta
        )
        payload["exp"] = issued_at + lifetime

    if extra_claims:
        reserved = {"sub", "email", "typ", "iat", "jti", "exp"}
        payload.update(
            {key: value for key, value in extra_claims.items() if key not in reserved}
        )

    return jwt.encode(payload, _get_cached_secret_key(), algorithm=ALGORITHM)
|
|
|
|
|
|
def verify_token(token: str) -> Optional[dict]:
    """Decode *token*, check its signature, and return its claims or ``None``.

    The token is verified with the cached secret, and only ``ALGORITHM`` is
    accepted. Every PyJWT ``InvalidTokenError`` maps to a ``None`` return
    instead of raising. This includes a bad signature, a malformed token, and
    an ``exp`` in the past (``ExpiredSignatureError`` is a subclass). Any
    other exception propagates to the caller.
    """
    try:
        claims = jwt.decode(token, _get_cached_secret_key(), algorithms=[ALGORITHM])
    except jwt.InvalidTokenError:
        # Also covers jwt.ExpiredSignatureError: expired tokens are treated
        # exactly like any other invalid token.
        claims = None
    return claims
|