agnes-the-ai-analyst/app/auth/jwt.py
ZdenekSrotyr e25a7aba7d fix: resolve JWT secret key test isolation issue
Replace module-level SECRET_KEY cache with lazy _get_cached_secret_key()
that re-reads env vars in test mode. This fixes 20 test failures caused
by JWT secret mismatch when test modules load in different orders.
2026-04-12 14:05:41 +02:00

74 lines
2.2 KiB
Python

"""JWT token creation and verification for API auth."""
import os
import uuid
from datetime import datetime, timedelta, timezone
from typing import Optional
import jwt
def _get_secret_key() -> str:
    """Resolve the JWT signing secret, without any caching.

    Test mode is active when ``TESTING`` is ``"1"`` or ``"true"``
    (case-insensitive). In test mode, ``JWT_SECRET_KEY`` is read straight
    from the environment, with a fixed test secret as the fallback.

    Outside test mode, the secret comes from ``app.secrets.get_jwt_secret()``.
    If it is shorter than 32 characters, a ``UserWarning`` is emitted, but
    the short key is still returned.
    """
    testing = os.environ.get("TESTING", "").lower() in ("1", "true")
    if testing:
        # NOTE(review): this fallback is a publicly visible constant; confirm
        # TESTING can never be enabled in a deployed environment, otherwise
        # tokens would be signed with a well-known key.
        return os.environ.get(
            "JWT_SECRET_KEY", "test-jwt-secret-key-minimum-32-chars!!"
        )

    from app.secrets import get_jwt_secret

    secret = get_jwt_secret()
    if len(secret) >= 32:
        return secret

    import warnings

    # Short keys are tolerated (warn, don't fail) so startup is not blocked.
    warnings.warn(
        f"JWT_SECRET_KEY is {len(secret)} chars — minimum 32 recommended",
        UserWarning,
        stacklevel=2,
    )
    return secret
# HMAC algorithm used both when signing and when verifying tokens.
ALGORITHM = "HS256"
# Default access-token lifetime, applied when no explicit expiry is given.
ACCESS_TOKEN_EXPIRE_HOURS = 24
# Memoised production secret; filled lazily by _get_cached_secret_key()
# and not consulted while TESTING is enabled.
_SECRET_KEY_CACHE: Optional[str] = None
def _get_cached_secret_key() -> str:
    """Return the JWT signing secret, memoised outside of test mode.

    In production, the secret is resolved once via ``_get_secret_key()`` and
    stored in the module-level ``_SECRET_KEY_CACHE`` for later calls.

    Test mode is active when ``TESTING`` is ``"1"`` or ``"true"``
    (case-insensitive). In test mode the cache is bypassed (it is not reset)
    and ``_get_secret_key()`` is called on every invocation. As a result,
    per-test changes to ``JWT_SECRET_KEY`` (e.g. via monkeypatch) take effect
    immediately, independent of the order in which test modules were loaded.
    """
    global _SECRET_KEY_CACHE
    if os.environ.get("TESTING", "").lower() in ("1", "true"):
        # Delegate rather than re-implementing the env lookup, so the test
        # default secret lives in exactly one place and cannot drift.
        return _get_secret_key()
    if _SECRET_KEY_CACHE is None:
        _SECRET_KEY_CACHE = _get_secret_key()
    return _SECRET_KEY_CACHE
def create_access_token(
    user_id: str,
    email: str,
    role: str = "analyst",
    expires_delta: Optional[timedelta] = None,
) -> str:
    """Create a signed access token for a user.

    Args:
        user_id: Value of the ``sub`` claim.
        email: Value of the ``email`` claim.
        role: Value of the ``role`` claim; defaults to ``"analyst"``.
        expires_delta: Token lifetime. ``None`` selects the default of
            ``ACCESS_TOKEN_EXPIRE_HOURS`` hours. Any other value, including
            ``timedelta(0)``, is used exactly as given.

    Returns:
        The encoded token, signed with the configured secret and ``ALGORITHM``.
    """
    # Compare against None explicitly: timedelta(0) is falsy, so an
    # ``expires_delta or default`` expression would silently replace a
    # zero-lifetime request with the 24h default.
    if expires_delta is None:
        expires_delta = timedelta(hours=ACCESS_TOKEN_EXPIRE_HOURS)
    # One timestamp for both claims so that ``exp - iat`` equals the
    # requested lifetime exactly.
    now = datetime.now(timezone.utc)
    payload = {
        "sub": user_id,
        "email": email,
        "role": role,
        "exp": now + expires_delta,
        "iat": now,
        # Unique per-token identifier (random UUID4, hex-encoded).
        "jti": uuid.uuid4().hex,
    }
    return jwt.encode(payload, _get_cached_secret_key(), algorithm=ALGORITHM)
def verify_token(token: str) -> Optional[dict]:
    """Decode *token* and return its claims, or ``None`` if it is rejected.

    Only ``ALGORITHM`` is accepted. Any verification failure reported by
    PyJWT yields ``None`` instead of an exception; this includes an expired
    token, a bad signature, and a malformed token.
    """
    try:
        claims = jwt.decode(
            token,
            _get_cached_secret_key(),
            algorithms=[ALGORITHM],
        )
    except jwt.InvalidTokenError:
        # ExpiredSignatureError derives from InvalidTokenError, so this one
        # handler covers expiry as well as every other decode failure.
        return None
    return claims