from __future__ import annotations import logging import threading import time from dataclasses import dataclass from fastapi import Request from fastapi.responses import JSONResponse from backend.app.core.config import settings from backend.app.db.redis_client import get_redis_client logger = logging.getLogger(__name__) @dataclass class RateLimitExceeded(Exception): retry_after_seconds: int _MEMORY_COUNTS: dict[str, tuple[int, float]] = {} _MEMORY_LOCK = threading.RLock() def _client_identifier(request: Request) -> str: forwarded_for = request.headers.get("x-forwarded-for", "") if forwarded_for: return forwarded_for.split(",", 1)[0].strip() or "unknown" if request.client: return request.client.host return "unknown" def _bucket_key(scope: str, identifier: str, window_seconds: int, now: float) -> tuple[str, int]: bucket = int(now // window_seconds) return f"rate_limit:{scope}:{identifier}:{bucket}", bucket def _memory_increment(key: str, window_seconds: int, now: float) -> int: expires_at = now + window_seconds with _MEMORY_LOCK: count, existing_expiry = _MEMORY_COUNTS.get(key, (0, expires_at)) if existing_expiry <= now: count = 0 existing_expiry = expires_at count += 1 _MEMORY_COUNTS[key] = (count, existing_expiry) stale_keys = [name for name, (_, expiry) in _MEMORY_COUNTS.items() if expiry <= now] for name in stale_keys: _MEMORY_COUNTS.pop(name, None) return count def clear_rate_limit_state() -> None: """Clear rate limiter counters for tests and local smoke checks.""" with _MEMORY_LOCK: _MEMORY_COUNTS.clear() try: redis = get_redis_client() for key in list(redis.scan_iter(match="rate_limit:*")): redis.delete(key) except Exception: logger.debug("Unable to clear Redis rate limit keys", exc_info=True) async def rate_limit_exception_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse: return JSONResponse( status_code=429, content={ "error": "rate_limited", "retry_after_seconds": exc.retry_after_seconds, }, ) def rate_limit_dependency(scope: str, limit_attr: str, window_seconds: int = 60): async def _check(request: Request) -> None: if not settings.rate_limit_enabled: return limit = max(0, int(getattr(settings, limit_attr, 0) or 0)) if limit <= 0: return now = time.time() identifier = _client_identifier(request) key, bucket = _bucket_key(scope, identifier, window_seconds, now) retry_after = max(1, int(((bucket + 1) * window_seconds) - now)) try: redis = get_redis_client() if redis.__class__.__name__ == "_InMemoryRedis": count = _memory_increment(key, window_seconds, now) else: count = int(redis.incr(key)) if count == 1: redis.expire(key, window_seconds) except Exception: logger.warning("Redis rate limiter unavailable; using in-memory limiter", exc_info=True) count = _memory_increment(key, window_seconds, now) if count > limit: raise RateLimitExceeded(retry_after_seconds=retry_after) return _check