""" backend_router.py — OmniBiMol Backend Router (v2) =================================================== Improvements over v1 -------------------- * **Retry + exponential back-off** — each probe attempt is retried up to ``probe_max_retries`` times with ``backoff_base * 2^attempt`` seconds between attempts. A backend is only declared unhealthy after every retry is exhausted, preventing false failures due to transient cold-start latency. * **Minimum fallback duration** — once the router switches to the fallback it will not consider switching back until at least ``min_fallback_duration_seconds`` have elapsed. This prevents a recovering primary (still warming JVM/Python workers) from immediately pulling traffic back. * **N-consecutive-success recovery gate** — even after the min-duration has passed, ``recovery_success_threshold`` consecutive successful health probes of the primary are required before traffic returns. This eliminates flip-flopping caused by intermittent 200 responses during partial recovery. * **Rich status dict** — ``force_health_check()`` now returns probe latency, retry counts, fallback age, and recovery progress so the Streamlit badge can display meaningful context without an extra network call. Unchanged from v1 ----------------- * Thread-safe via ``threading.Lock``. * Sync-safe — no ``asyncio.run`` nesting. * No import-time side effects. * Streamlit session-state isolation (per-user router). * Module-level singleton for FastAPI workers / CLI. """ from __future__ import annotations import logging import os import threading import time from dataclasses import dataclass, field from typing import Optional import httpx logger = logging.getLogger(__name__) # Lazy import to avoid circular dependencies def _get_environment(): """Lazily import and get environment config.""" try: from app_environment import get_environment return get_environment() except ImportError: return None # --------------------------------------------------------------------------- # Configuration loader # --------------------------------------------------------------------------- @dataclass(frozen=True) class BackendConfig: """Validated, immutable snapshot of backend routing configuration. Loaded once per process from ``backend.core.config.Settings`` (pydantic- settings). Falls back to raw ``os.getenv()`` calls so the router works even without the full FastAPI import tree (tests, CLI). """ environment: str primary_url: str fallback_url: str health_timeout: float failure_threshold: int recheck_interval: int api_prefix: str # Probe resilience probe_max_retries: int = 2 # extra attempts (0 = single probe, no retries) probe_backoff_base: float = 0.5 # seconds; delay = base * 2^attempt → 0.5, 1, 2… # Fallback stability min_fallback_duration: int = 30 # seconds to hold on fallback before allowing recovery recovery_success_threshold: int = 2 # consecutive OK probes needed to switch back # Computed — not settable at construction time primary_health_url: str = field(init=False) fallback_health_url: str = field(init=False) def __post_init__(self) -> None: object.__setattr__( self, "primary_health_url", f"{self.primary_url.rstrip('/')}{self.api_prefix}/healthz", ) object.__setattr__( self, "fallback_health_url", f"{self.fallback_url.rstrip('/')}{self.api_prefix}/healthz" if self.fallback_url else "", ) @classmethod def load(cls) -> "BackendConfig": """Build config from pydantic-settings, falling back to raw env vars.""" try: from backend.core.config import get_settings # lazy; avoids circular import s = get_settings() return cls( environment=s.environment, primary_url=s.backend_api_url.rstrip("/"), fallback_url=s.backend_api_fallback_url.rstrip("/"), health_timeout=s.backend_health_timeout_seconds, failure_threshold=s.backend_health_failure_threshold, recheck_interval=s.backend_health_recheck_interval_seconds, api_prefix=s.api_prefix, probe_max_retries=s.backend_health_probe_max_retries, probe_backoff_base=s.backend_health_probe_backoff_base_seconds, min_fallback_duration=s.backend_health_min_fallback_duration_seconds, recovery_success_threshold=s.backend_health_recovery_success_threshold, ) except Exception as exc: logger.warning( "Could not load settings via pydantic-settings (%s); " "falling back to raw os.getenv().", exc, ) return cls( environment=os.getenv("ENVIRONMENT", "development"), primary_url=os.getenv("BACKEND_API_URL", "http://localhost:8000").rstrip("/"), fallback_url=os.getenv("BACKEND_API_FALLBACK_URL", "").rstrip("/"), health_timeout=float(os.getenv("BACKEND_HEALTH_TIMEOUT_SECONDS", "5")), failure_threshold=int(os.getenv("BACKEND_HEALTH_FAILURE_THRESHOLD", "2")), recheck_interval=int( os.getenv("BACKEND_HEALTH_RECHECK_INTERVAL_SECONDS", "60") ), api_prefix=os.getenv("API_PREFIX", "/api/v1"), probe_max_retries=int(os.getenv("BACKEND_HEALTH_PROBE_MAX_RETRIES", "2")), probe_backoff_base=float( os.getenv("BACKEND_HEALTH_PROBE_BACKOFF_BASE_SECONDS", "0.5") ), min_fallback_duration=int( os.getenv("BACKEND_HEALTH_MIN_FALLBACK_DURATION_SECONDS", "30") ), recovery_success_threshold=int( os.getenv("BACKEND_HEALTH_RECOVERY_SUCCESS_THRESHOLD", "2") ), ) # --------------------------------------------------------------------------- # Health probe (retry + exponential back-off) # --------------------------------------------------------------------------- def probe_health( url: str, timeout: float = 5.0, max_retries: int = 2, backoff_base: float = 0.5, ) -> tuple[bool, int, float]: """Probe *url* with retry + exponential back-off. Parameters ---------- url: Fully-qualified health endpoint, e.g. ``http://localhost:8000/api/v1/healthz``. timeout: Per-attempt timeout in seconds. max_retries: Number of *extra* attempts after the first. ``0`` = single attempt. backoff_base: Base delay in seconds; wait ``base * 2^attempt`` between retries. Sequence: 0.5 s → 1 s → 2 s → … Returns ------- tuple[bool, int, float] ``(healthy, attempts_used, total_elapsed_seconds)`` Notes ----- A backend is only declared unhealthy after **all** retries are exhausted, preventing false failures caused by cold-start latency or transient blips. """ if not url: return False, 0, 0.0 t_start = time.monotonic() attempts = 0 last_exc: Optional[Exception] = None for attempt in range(max_retries + 1): # attempt 0, 1, …, max_retries if attempt > 0: delay = backoff_base * (2 ** (attempt - 1)) # 0.5, 1, 2, … logger.debug( "Health probe retry %d/%d for %s (backoff %.2fs)", attempt, max_retries, url, delay, ) time.sleep(delay) attempts += 1 try: resp = httpx.get(url, timeout=timeout, follow_redirects=True) elapsed = time.monotonic() - t_start if resp.is_success: if attempt > 0: logger.info( "Health probe succeeded on attempt %d/%d for %s (%.2fs total)", attempt + 1, max_retries + 1, url, elapsed, ) return True, attempts, elapsed logger.warning( "Health probe HTTP %s for %s (attempt %d/%d)", resp.status_code, url, attempt + 1, max_retries + 1, ) last_exc = None # non-2xx is not a connection error; still retry except httpx.TimeoutException as exc: logger.warning( "Health probe timeout (%.1fs) for %s (attempt %d/%d)", timeout, url, attempt + 1, max_retries + 1, ) last_exc = exc except httpx.RequestError as exc: logger.warning( "Health probe connection error for %s (attempt %d/%d): %s", url, attempt + 1, max_retries + 1, exc, ) last_exc = exc elapsed = time.monotonic() - t_start logger.error( "Health probe FAILED for %s after %d attempt(s) (%.2fs) — last error: %s", url, attempts, elapsed, last_exc, ) return False, attempts, elapsed # --------------------------------------------------------------------------- # Backend router (circuit-breaker + stable fallback + recovery gate) # --------------------------------------------------------------------------- class BackendRouter: """Thread-safe circuit-breaker with stabilised fallback switching. State machine ------------- NORMAL (``_using_fallback=False``) All traffic → primary. Probe failures accumulate; after ``failure_threshold`` the router transitions to FALLBACK. FALLBACK (``_using_fallback=True``) All traffic → fallback. The primary is re-probed every ``recheck_interval`` seconds. Two gates must both be satisfied before switching back to NORMAL: 1. ``min_fallback_duration`` seconds have elapsed since fallback was engaged (prevents premature recovery during cold-starts). 2. ``recovery_success_threshold`` consecutive successful probes of the primary have been recorded (prevents flip-flopping caused by intermittent 200s). Thread safety ------------- All state mutations go through ``self._lock``. Status reads (active_url, is_using_fallback) are done outside the lock for performance — CPython's GIL makes these reads atomic in practice. """ def __init__(self, config: Optional[BackendConfig] = None) -> None: self._cfg = config or BackendConfig.load() self._lock = threading.Lock() # Failure accumulation (primary → fallback) self._failure_count: int = 0 # Fallback state self._using_fallback: bool = False self._fallback_engaged_at: float = 0.0 # monotonic timestamp # Recovery gate (fallback → primary) self._consecutive_recovery_successes: int = 0 self._last_primary_check: float = 0.0 # monotonic timestamp logger.info( "BackendRouter v2 initialised | env=%s | primary=%s | fallback=%s | " "retries=%d | backoff=%.1fs | min_fallback=%ds | recovery_gate=%d", self._cfg.environment, self._cfg.primary_url, self._cfg.fallback_url or "", self._cfg.probe_max_retries, self._cfg.probe_backoff_base, self._cfg.min_fallback_duration, self._cfg.recovery_success_threshold, ) # ------------------------------------------------------------------ # Public read properties # ------------------------------------------------------------------ @property def active_url(self) -> str: """Current active backend base URL (no trailing slash).""" if self._using_fallback and self._cfg.fallback_url: return self._cfg.fallback_url return self._cfg.primary_url @property def is_using_fallback(self) -> bool: """``True`` when traffic is routed to the fallback.""" return self._using_fallback @property def config(self) -> BackendConfig: """Read-only access to the resolved configuration.""" return self._cfg # ------------------------------------------------------------------ # Primary entry-points # ------------------------------------------------------------------ def get_active_url(self) -> str: """Return the active backend URL, triggering a recheck when due. Intentionally cheap: the full probe (with retries) only runs when on fallback AND ``recheck_interval`` seconds have elapsed. """ if self._using_fallback: self._maybe_recheck_primary() return self.active_url def record_success(self) -> None: """Signal that the last call to the *primary* succeeded. Resets the consecutive failure counter; no-op when on fallback. """ with self._lock: if not self._using_fallback: self._failure_count = 0 def record_failure(self) -> None: """Signal that the last call to the *primary* failed. Increments the failure counter; switches to fallback once the threshold is reached. No-op when already on fallback. """ with self._lock: if self._using_fallback: return self._failure_count += 1 logger.warning( "Backend failure #%d/%d for %s", self._failure_count, self._cfg.failure_threshold, self._cfg.primary_url, ) if self._failure_count >= self._cfg.failure_threshold: self._engage_fallback() def force_health_check(self) -> dict: """Run full probes (with retries) against both backends. Returns a rich status dict for the Streamlit sidebar badge and diagnostics panels. This call is *blocking* — use sparingly. """ primary_ok, p_attempts, p_elapsed = probe_health( self._cfg.primary_health_url, self._cfg.health_timeout, self._cfg.probe_max_retries, self._cfg.probe_backoff_base, ) if self._cfg.fallback_health_url: fallback_ok, f_attempts, f_elapsed = probe_health( self._cfg.fallback_health_url, self._cfg.health_timeout, self._cfg.probe_max_retries, self._cfg.probe_backoff_base, ) else: fallback_ok, f_attempts, f_elapsed = None, 0, 0.0 # Update internal state based on probe results with self._lock: if primary_ok and not self._using_fallback: self._failure_count = 0 elif primary_ok and self._using_fallback: self._consecutive_recovery_successes += 1 self._failure_count = 0 if self._can_recover(): self._recover_primary_locked() elif not primary_ok and not self._using_fallback: # Live probe failure — count toward threshold self._failure_count += 1 self._consecutive_recovery_successes = 0 if self._failure_count >= self._cfg.failure_threshold: self._engage_fallback_locked() else: self._consecutive_recovery_successes = 0 # Compute human-readable fallback age fallback_age: Optional[float] = None if self._using_fallback and self._fallback_engaged_at: fallback_age = time.monotonic() - self._fallback_engaged_at min_hold_remaining: Optional[float] = None if self._using_fallback and fallback_age is not None: remaining = self._cfg.min_fallback_duration - fallback_age min_hold_remaining = max(0.0, remaining) return { "environment": self._cfg.environment, "primary_url": self._cfg.primary_url, "primary_healthy": primary_ok, "primary_probe_attempts": p_attempts, "primary_probe_elapsed_s": round(p_elapsed, 3), "fallback_url": self._cfg.fallback_url or None, "fallback_healthy": fallback_ok, "fallback_probe_attempts": f_attempts, "fallback_probe_elapsed_s": round(f_elapsed, 3), "active_url": self.active_url, "using_fallback": self._using_fallback, "consecutive_failures": self._failure_count, "consecutive_recovery_successes": self._consecutive_recovery_successes, "recovery_threshold": self._cfg.recovery_success_threshold, "fallback_age_seconds": round(fallback_age, 1) if fallback_age else None, "min_hold_remaining_seconds": ( round(min_hold_remaining, 1) if min_hold_remaining is not None else None ), } def status_snapshot(self) -> dict: """Cheap status snapshot — NO network I/O. Safe to call on every Streamlit render cycle for the sidebar badge. """ fallback_age: Optional[float] = None if self._using_fallback and self._fallback_engaged_at: fallback_age = time.monotonic() - self._fallback_engaged_at min_hold_remaining: Optional[float] = None if self._using_fallback and fallback_age is not None: remaining = self._cfg.min_fallback_duration - fallback_age min_hold_remaining = max(0.0, remaining) return { "active_url": self.active_url, "using_fallback": self._using_fallback, "environment": self._cfg.environment, "consecutive_failures": self._failure_count, "consecutive_recovery_successes": self._consecutive_recovery_successes, "recovery_threshold": self._cfg.recovery_success_threshold, "fallback_age_seconds": round(fallback_age, 1) if fallback_age else None, "min_hold_remaining_seconds": ( round(min_hold_remaining, 1) if min_hold_remaining is not None else None ), "min_fallback_duration": self._cfg.min_fallback_duration, } # ------------------------------------------------------------------ # Internal helpers # ------------------------------------------------------------------ def _engage_fallback(self) -> None: """Switch to fallback. Acquires lock internally; safe to call outside lock.""" with self._lock: self._engage_fallback_locked() def _engage_fallback_locked(self) -> None: """Switch to fallback — must be called **inside** ``self._lock``.""" if not self._cfg.fallback_url: logger.error( "Primary backend unhealthy but BACKEND_API_FALLBACK_URL is not set. " "Requests will continue hitting the unhealthy primary." ) return logger.error( "Primary backend unhealthy after %d failures. Engaging fallback: %s", self._failure_count, self._cfg.fallback_url, ) self._using_fallback = True self._fallback_engaged_at = time.monotonic() self._last_primary_check = time.monotonic() self._consecutive_recovery_successes = 0 # reset gate def _can_recover(self) -> bool: """Both stability gates satisfied — must be called inside lock.""" # Gate 1: minimum time on fallback if self._fallback_engaged_at: age = time.monotonic() - self._fallback_engaged_at if age < self._cfg.min_fallback_duration: logger.debug( "Recovery gate 1: min_fallback_duration not met (%.1fs / %ds)", age, self._cfg.min_fallback_duration, ) return False # Gate 2: N consecutive successes if self._consecutive_recovery_successes < self._cfg.recovery_success_threshold: logger.debug( "Recovery gate 2: need %d more consecutive success(es) (have %d/%d)", self._cfg.recovery_success_threshold - self._consecutive_recovery_successes, self._consecutive_recovery_successes, self._cfg.recovery_success_threshold, ) return False return True def _recover_primary_locked(self) -> None: """Restore traffic to primary — must be called **inside** ``self._lock``.""" logger.info( "Both recovery gates satisfied — switching back to primary: %s", self._cfg.primary_url, ) self._using_fallback = False self._failure_count = 0 self._consecutive_recovery_successes = 0 self._fallback_engaged_at = 0.0 def _maybe_recheck_primary(self) -> None: """Re-probe the primary (with retries) if the recheck interval has elapsed.""" now = time.monotonic() if now - self._last_primary_check < self._cfg.recheck_interval: return # too soon; skip probe # Stamp timestamp inside lock to prevent concurrent probes with self._lock: if now - self._last_primary_check < self._cfg.recheck_interval: return # double-checked; another thread already ran the probe self._last_primary_check = now primary_ok, attempts, elapsed = probe_health( self._cfg.primary_health_url, self._cfg.health_timeout, self._cfg.probe_max_retries, self._cfg.probe_backoff_base, ) with self._lock: if primary_ok: self._consecutive_recovery_successes += 1 logger.info( "Primary recheck OK (%d/%d consecutive successes, %.2fs, %d attempt(s))", self._consecutive_recovery_successes, self._cfg.recovery_success_threshold, elapsed, attempts, ) if self._can_recover(): self._recover_primary_locked() else: # Reset recovery counter — must see N *consecutive* successes if self._consecutive_recovery_successes > 0: logger.info( "Primary recheck failed; resetting recovery counter " "(%d → 0, %.2fs, %d attempt(s))", self._consecutive_recovery_successes, elapsed, attempts, ) self._consecutive_recovery_successes = 0 else: logger.info( "Primary still unhealthy during recheck (%.2fs, %d attempt(s))", elapsed, attempts, ) # --------------------------------------------------------------------------- # Module-level singleton (FastAPI workers / CLI) # --------------------------------------------------------------------------- _module_router: Optional[BackendRouter] = None _module_router_lock = threading.Lock() def get_router(config: Optional[BackendConfig] = None) -> BackendRouter: """Return the process-level ``BackendRouter`` singleton. In Streamlit prefer ``get_router_from_session()`` so each user session owns its own failure counters independently. """ global _module_router if _module_router is None: with _module_router_lock: if _module_router is None: _module_router = BackendRouter(config) return _module_router def get_active_backend_url() -> str: """Convenience wrapper — primary entry-point for non-Streamlit callers.""" return get_router().get_active_url() # --------------------------------------------------------------------------- # Streamlit session-state helpers # --------------------------------------------------------------------------- _SESSION_KEY = "__omnibimol_backend_router__" def get_router_from_session() -> BackendRouter: """Return the ``BackendRouter`` stored in ``st.session_state``. Creates a new router on the first call per Streamlit user session. Falls back to the module-level singleton outside Streamlit. """ try: import streamlit as st if _SESSION_KEY not in st.session_state: st.session_state[_SESSION_KEY] = BackendRouter() return st.session_state[_SESSION_KEY] except Exception: return get_router() def get_active_backend_url_for_session() -> str: """Streamlit-session-aware version of ``get_active_backend_url()``. Call this from ``app.py`` when (re)creating ``ProteinAPIClient``:: st.session_state.api_client = ProteinAPIClient( st.session_state.cache_manager, backend_api_url=get_active_backend_url_for_session(), ) """ return get_router_from_session().get_active_url() # --------------------------------------------------------------------------- # Sidebar badge renderer (Streamlit helper) # --------------------------------------------------------------------------- def render_backend_status_badge() -> None: """Render a compact backend-status badge inside ``st.sidebar``. Uses only ``status_snapshot()`` — zero network I/O — so it is safe to call on every Streamlit render cycle without performance impact. **PRODUCTION MODE**: This badge is hidden in production. Only shown in development mode (APP_ENV=development). Visual states (development only) -------------------------------- ✅ Primary healthy → green badge, primary URL ⚠️ On fallback → amber badge, fallback URL, age, recovery progress ❌ No fallback → red badge if on fallback without a configured URL """ try: import streamlit as st except ImportError: return # non-Streamlit environment; silently skip # Check environment - only show in development mode env = _get_environment() if env and env.is_production(): return # Silent no-op in production snap = get_router_from_session().status_snapshot() using_fallback: bool = snap["using_fallback"] active_url: str = snap["active_url"] env_label: str = snap["environment"] # ── colour tokens ────────────────────────────────────────────────── if not using_fallback: dot = "🟢" label = "Primary" colour = "#16a34a" # green-600 bg = "#f0fdf4" # green-50 border = "#86efac" # green-300 else: dot = "🟡" label = "Fallback" colour = "#b45309" # amber-700 bg = "#fffbeb" # amber-50 border = "#fcd34d" # amber-300 # Short URL for display (strip scheme + trailing slash) display_url = active_url.replace("https://", "").replace("http://", "").rstrip("/") if len(display_url) > 34: display_url = display_url[:31] + "…" # ── build extra lines for fallback state ─────────────────────────── extra_lines = "" if using_fallback: age = snap.get("fallback_age_seconds") hold = snap.get("min_hold_remaining_seconds") rec = snap.get("consecutive_recovery_successes", 0) rec_thr = snap.get("recovery_threshold", 2) min_dur = snap.get("min_fallback_duration", 30) age_str = f"{int(age)}s" if age is not None else "—" if hold and hold > 0: gate_str = f"min-hold: {int(hold)}s left" else: gate_str = f"recovery: {rec}/{rec_thr} probes" extra_lines = ( f"
" f"Active {age_str} · {gate_str}" f"
" ) badge_html = f"""
{dot} {label} · {env_label}
{display_url}
{extra_lines}
""".strip() st.sidebar.markdown(badge_html, unsafe_allow_html=True)