from __future__ import annotations import contextvars import json import logging import time from collections.abc import Mapping, Sequence from typing import Any from uuid import uuid4 from fastapi import FastAPI, HTTPException, Request from fastapi.exception_handlers import http_exception_handler, request_validation_exception_handler from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse from backend.app.core.pii_scrubber import scrub_pii _REQUEST_ID: contextvars.ContextVar[str] = contextvars.ContextVar("request_id", default="-") _LOGGING_CONFIGURED = False _SENSITIVE_FIELD_MARKERS = ( "password", "secret", "token", "authorization", "cookie", "api_key", "apikey", "audio", "base64", "binary", ) class _RequestContextFilter(logging.Filter): def filter(self, record: logging.LogRecord) -> bool: record.request_id = _REQUEST_ID.get() return True def configure_logging(level: str = "INFO") -> None: global _LOGGING_CONFIGURED numeric_level = getattr(logging, level.upper(), logging.INFO) formatter = logging.Formatter( fmt="%(asctime)s | %(levelname)-8s | req=%(request_id)s | %(name)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S", ) handler = logging.StreamHandler() handler.setLevel(numeric_level) handler.setFormatter(formatter) handler.addFilter(_RequestContextFilter()) root_logger = logging.getLogger() root_logger.handlers.clear() root_logger.addHandler(handler) root_logger.setLevel(numeric_level) for logger_name in ("uvicorn", "uvicorn.error", "uvicorn.access", "fastapi"): logger = logging.getLogger(logger_name) logger.handlers.clear() logger.propagate = True logger.setLevel(numeric_level) for logger_name in ("httpx", "httpcore", "urllib3"): logger = logging.getLogger(logger_name) logger.handlers.clear() logger.propagate = True logger.setLevel(logging.WARNING) logging.captureWarnings(True) _LOGGING_CONFIGURED = True def get_request_id() -> str: return _REQUEST_ID.get() def install_application_logging( app: FastAPI, *, capture_request_bodies: bool = True, body_max_chars: int = 1200, ) -> None: request_logger = logging.getLogger("backend.request") error_logger = logging.getLogger("backend.error") @app.middleware("http") async def log_requests(request: Request, call_next): incoming_request_id = request.headers.get("X-Request-ID") request_id = incoming_request_id.strip() if incoming_request_id else uuid4().hex[:12] token = _REQUEST_ID.set(request_id) request.state.request_id = request_id body_preview = await _prepare_body_preview( request, capture_request_bodies=capture_request_bodies, body_max_chars=body_max_chars, ) start = time.perf_counter() method = request.method.upper() path = request.url.path client = _describe_client(request) query = _summarize_query(request, body_max_chars) request_logger.info( "Request started | method=%s | path=%s | client=%s | query=%s%s", method, path, client, query, f" | body={body_preview}" if body_preview else "", ) try: response = await call_next(request) except Exception: duration_ms = (time.perf_counter() - start) * 1000 request_logger.exception( "Request crashed | method=%s | path=%s | client=%s | duration_ms=%.2f", method, path, client, duration_ms, ) raise else: duration_ms = (time.perf_counter() - start) * 1000 route = request.scope.get("route") endpoint = request.scope.get("endpoint") route_path = getattr(route, "path", path) endpoint_name = getattr(endpoint, "__name__", "unknown") response.headers["X-Request-ID"] = request_id level = logging.INFO if response.status_code < 500 else logging.ERROR request_logger.log( level, "Request completed | method=%s | path=%s | route=%s | endpoint=%s | status=%s | duration_ms=%.2f", method, path, route_path, endpoint_name, response.status_code, duration_ms, ) return response finally: _REQUEST_ID.reset(token) @app.exception_handler(RequestValidationError) async def handle_validation_error(request: Request, exc: RequestValidationError): error_logger.warning( "Validation failed | method=%s | path=%s | errors=%s%s", request.method.upper(), request.url.path, _safe_json(_sanitize_value(exc.errors(), body_max_chars)), _format_body_suffix(request), ) response = await request_validation_exception_handler(request, exc) response.headers["X-Request-ID"] = getattr(request.state, "request_id", get_request_id()) return response @app.exception_handler(HTTPException) async def handle_http_exception(request: Request, exc: HTTPException): log_level = logging.ERROR if exc.status_code >= 500 else logging.WARNING error_logger.log( log_level, "HTTP exception | method=%s | path=%s | status=%s | detail=%s%s", request.method.upper(), request.url.path, exc.status_code, _safe_json(_sanitize_value(exc.detail, body_max_chars)), _format_body_suffix(request), ) response = await http_exception_handler(request, exc) response.headers["X-Request-ID"] = getattr(request.state, "request_id", get_request_id()) return response @app.exception_handler(Exception) async def handle_unexpected_exception(request: Request, exc: Exception): error_logger.exception( "Unhandled exception | method=%s | path=%s%s", request.method.upper(), request.url.path, _format_body_suffix(request), ) return JSONResponse( status_code=500, content={ "detail": "Internal server error", "request_id": getattr(request.state, "request_id", get_request_id()), }, headers={"X-Request-ID": getattr(request.state, "request_id", get_request_id())}, ) def log_startup_summary(*, app_name: str, environment: str, port: int, cors_origins: list[str], ollama_model: str, embedding_model: str, neo4j_uri: str, redis_url: str, chroma_host: str, chroma_port: int, infrastructure: dict[str, str] | None = None) -> None: infra = infrastructure or { "neo4j_uri": neo4j_uri, "redis_endpoint": redis_url, "chroma_endpoint": f"{chroma_host}:{chroma_port}", "neo4j_mode": "unknown", "redis_mode": "unknown", "chroma_mode": "unknown", } logging.getLogger("backend.lifecycle").info( "Application boot | app=%s | env=%s | port=%s | cors=%s | ollama_model=%s | embedding_model=%s | neo4j=%s (%s) | redis=%s (%s) | chroma=%s (%s)", app_name, environment, port, ",".join(cors_origins) if cors_origins else "", ollama_model, embedding_model, infra["neo4j_uri"], infra["neo4j_mode"], infra["redis_endpoint"], infra["redis_mode"], infra["chroma_endpoint"], infra["chroma_mode"], ) async def _prepare_body_preview(request: Request, *, capture_request_bodies: bool, body_max_chars: int) -> str | None: if not capture_request_bodies: request.state.request_body_preview = None return None try: body = await request.body() except Exception: request.state.request_body_preview = "" return request.state.request_body_preview if not body: request.state.request_body_preview = None return None request.state.request_body_preview = _summarize_body( body, content_type=request.headers.get("content-type", ""), max_chars=body_max_chars, ) async def receive() -> dict[str, Any]: return {"type": "http.request", "body": body, "more_body": False} request._receive = receive return request.state.request_body_preview def _summarize_body(body: bytes, *, content_type: str, max_chars: int) -> str: if not body: return "" decoded = body.decode("utf-8", errors="replace") if "application/json" in content_type: try: payload = json.loads(decoded) except json.JSONDecodeError: return _truncate(scrub_pii(decoded), max_chars) return _truncate(_safe_json(_sanitize_value(payload, max_chars)), max_chars) return _truncate(scrub_pii(decoded), max_chars) def _sanitize_value(value: Any, max_chars: int) -> Any: if isinstance(value, Mapping): sanitized: dict[str, Any] = {} for key, item in value.items(): key_text = str(key) if _looks_sensitive(key_text): sanitized[key_text] = "[REDACTED]" else: sanitized[key_text] = _sanitize_value(item, max_chars) return sanitized if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)): return [_sanitize_value(item, max_chars) for item in value] if isinstance(value, bytes): return "[BINARY_PAYLOAD_REDACTED]" if isinstance(value, str): return _truncate(scrub_pii(value), max_chars) return value def _looks_sensitive(key: str) -> bool: normalized = key.strip().lower() return any(marker in normalized for marker in _SENSITIVE_FIELD_MARKERS) def _truncate(text: str, max_chars: int) -> str: if len(text) <= max_chars: return text return f"{text[:max_chars]}... " def _safe_json(value: Any) -> str: try: return json.dumps(value, ensure_ascii=True, default=str) except TypeError: return scrub_pii(str(value)) def _describe_client(request: Request) -> str: if request.client is None: return "unknown" return f"{request.client.host}:{request.client.port}" def _summarize_query(request: Request, max_chars: int) -> str: if not request.query_params: return "" query_dict = {key: _sanitize_value(value, max_chars) for key, value in request.query_params.items()} return _safe_json(query_dict) def _format_body_suffix(request: Request) -> str: body_preview = getattr(request.state, "request_body_preview", None) if not body_preview: return "" return f" | body={body_preview}"