""" state_manager.py – OmniBiMol Phase-Scoped Session State Architecture ======================================================================= Public surface -------------- handle_phase_routing(display_name) → call at TOP of main(), before any reads reset_state_for_phase(phase) → explicitly wipe one phase transition_to(new_phase) → detect change + cleanup + rerun guard get(key, *, phase, default) → namespaced read set_key(key, value, *, phase) → namespaced write pop(key, *, phase, default) → namespaced delete snapshot() → debug: all keys grouped by phase audit_global_keys() → warn if globals exceed threshold clear_global_state(confirm) → dev helper: wipe global non-infra keys gc_stale_phases() → lazy-cleanup: sweep inactive-marked phases Namespacing ----------- __ e.g. genomics__blast_results protein__current_uniprot_id Phase slugs (never shown to users) genomics | protein | dti Lazy-Cleanup (opt-in, default ON) ---------------------------------- Instead of deleting every key the moment a phase is exited, keys are marked via the sentinel set ``_stale_phases`` in session_state. Actual deletion runs in ``gc_stale_phases()``, called once per rerun BEFORE any module logic executes. This separates the cleanup trigger from the deletion work, making reruns cheaper when nothing changed. 
Flow: handle_phase_routing() → detects phase change → marks old slug in _stale_phases (instant, no iteration) → updates current_phase / previous_phase → calls st.rerun() ← ensures next rerun starts clean gc_stale_phases() (called at very top of main) → iterates _stale_phases → deletes keys, clears the set """ from __future__ import annotations import logging import sys import time from typing import Any, Dict, FrozenSet, List, Optional, Set, Union import streamlit as st logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Tuning constants # --------------------------------------------------------------------------- # Warn when global session keys exceed this count GLOBAL_KEY_COUNT_THRESHOLD: int = 50 # Warn when a single global value exceeds this byte estimate GLOBAL_VALUE_SIZE_THRESHOLD_BYTES: int = 10 * 1024 * 1024 # 10 MB # Minimum milliseconds between consecutive st.rerun() calls. # Calls arriving within this window are silently suppressed. 
RERUN_DEBOUNCE_MS: int = 200

# snapshot() raw-mode: truncate sequences / dicts larger than this count
_SNAPSHOT_RAW_SEQUENCE_MAX: int = 50

# snapshot() raw-mode: truncate string values longer than this
_SNAPSHOT_RAW_STR_MAX: int = 500

# ---------------------------------------------------------------------------
# Phase slug → display-name mapping
# ---------------------------------------------------------------------------

PHASE_SLUGS: Dict[str, str] = {
    "genomics": "Genomics & Variant Discovery",
    "protein": "Protein Analysis Suite",
    "dti": "Drug-Target Interaction & Matching",
}

# Reverse lookup: display-name → slug
_DISPLAY_TO_SLUG: Dict[str, str] = {v: k for k, v in PHASE_SLUGS.items()}

# ---------------------------------------------------------------------------
# GLOBAL_PERSISTENT_KEYS – NEVER cleared regardless of phase
# (also exempt from clear_global_state() and grouped as "__global__" by
# snapshot()).
# ---------------------------------------------------------------------------

GLOBAL_PERSISTENT_KEYS: FrozenSet[str] = frozenset(
    {
        # ── Infrastructure ──────────────────────────────────────────────────
        "api_client",
        "cache_manager",
        "portfolio_engine",
        # ── Navigation / routing ────────────────────────────────────────────
        "current_page",
        "phase_selector",
        "previous_phase",
        "current_phase",
        # ── Lazy-cleanup internals ──────────────────────────────────────────
        "_stale_phases",
        "_phase_changed",
        # Runtime opt-out flag read by _phase_cleanup_enabled(); it must
        # survive clear_global_state(), otherwise wiping "unknown" globals
        # would silently re-enable phase cleanup.
        "disable_phase_cleanup",
        # ── Rerun debounce ──────────────────────────────────────────────────
        "_last_rerun_ts",
        # ── Auth ────────────────────────────────────────────────────────────
        "user",
        "user_email",
        "auth_token",
        "auth_state",
        # ── App-wide UI toggles ─────────────────────────────────────────────
        "docking_mode_choice",
        # ── Cross-phase copilot context (read-only after write) ─────────────
        "repurposing_report_data",
    }
)

# ---------------------------------------------------------------------------
# Per-phase ephemeral key registry
# Auto-detected namespaced keys (<phase>__*) are handled at runtime.
# ---------------------------------------------------------------------------

_PHASE_EPHEMERAL_KEYS: Dict[str, Set[str]] = {
    "genomics": {
        "genome_analysis_results",
        "vcf_upload_df",
        "vcf_parse_error",
        "blast_results",
        "blast_protein_id",
        "sequence_input",
        "alignment_results",
        "msa_result",
        "phylo_result",
        "domain_result",
        "motif_result",
        "wgs_job_id",
        "wgs_report",
    },
    "protein": {
        "search_results",
        "show_results",
        "trigger_search",
        "protein_input",
        "search_input_key",
        "current_data",
        "current_uniprot_id",
        "current_gene_name",
        "fetch_time",
        "docking_results",
        "docking_job_id",
        "selected_ligand",
        "repurposing_results",
        "copilot_messages",
        "ligand_binding_data",
    },
    "dti": {
        "drugs_search_query",
        "drugs_results",
        "clinical_trials_results",
    },
}


# ===========================================================================
# Internal helpers
# ===========================================================================


def _ns(phase: str, key: str) -> str:
    """Return the namespaced session key ``<phase>__<key>``."""
    return f"{phase}__{key}"


def _is_global(key: str) -> bool:
    """True if *key* must never be cleared by any phase-reset operation."""
    return key in GLOBAL_PERSISTENT_KEYS


def _owning_phase(key: str) -> Optional[str]:
    """
    Return the slug of the phase that owns *key*, or ``None``.

    A ``<slug>__`` prefix takes precedence over membership in the static
    registry, so every caller (``_is_phase_key``, ``snapshot``) classifies
    keys identically.
    """
    for slug in PHASE_SLUGS:
        if key.startswith(f"{slug}__"):
            return slug
    for slug, keys in _PHASE_EPHEMERAL_KEYS.items():
        if key in keys:
            return slug
    return None


def _is_phase_key(key: str) -> bool:
    """True if *key* belongs to any phase (namespaced or statically registered)."""
    return _owning_phase(key) is not None


def _slug_from_display(display_name: str) -> Optional[str]:
    """Convert a display-name to its slug, or None if not recognised."""
    return _DISPLAY_TO_SLUG.get(display_name)


def _resolve_slug(phase: str) -> str:
    """Map a display-name to its slug; any other value is returned unchanged."""
    return _slug_from_display(phase) or phase


def _resolve_key(key: str, phase: Optional[str]) -> str:
    """Return the session key for *key*, namespaced under *phase* if given."""
    return _ns(_resolve_slug(phase), key) if phase else key


def _phase_cleanup_enabled() -> bool:
    """Allow runtime opt-out of stale-mark/sweep cleanup for Streamlit routing."""
    return not bool(st.session_state.get("disable_phase_cleanup", False))


def _collect_ephemeral_keys_for_phase(phase: str) -> List[str]:
    """
    Return all live session-state keys that belong to *phase*.

    A key belongs to *phase* when it is either
      • a bare key listed in ``_PHASE_EPHEMERAL_KEYS[phase]`` (static
        registry), or
      • prefixed with ``<phase>__`` (dynamic namespacing).

    Global keys are always excluded. The result is a detached list, so the
    caller may delete the keys while iterating over it.
    """
    registered: Set[str] = _PHASE_EPHEMERAL_KEYS.get(phase, set())
    prefix = f"{phase}__"
    return [
        key
        for key in list(st.session_state.keys())
        if not _is_global(key) and (key in registered or key.startswith(prefix))
    ]


def _purge_phase_keys(slug: str) -> List[str]:
    """Delete every live ephemeral key of *slug*; return the keys targeted."""
    keys = _collect_ephemeral_keys_for_phase(slug)
    for key in keys:
        try:
            del st.session_state[key]
        except KeyError:
            # Already removed (e.g. by a callback) – nothing left to do.
            pass
        else:
            logger.debug("Cleared key=%s phase=%s", key, slug)
    return keys


def _estimate_size(value: Any) -> int:
    """
    Best-effort *shallow* byte estimate for a session-state value.

    ``sys.getsizeof`` does not follow references, so a container reports
    only its own overhead unless its type's ``__sizeof__`` accounts for its
    contents. Returns 0 if the size cannot be determined.
    """
    try:
        return sys.getsizeof(value)
    except Exception:
        return 0


# ===========================================================================
# Lazy-cleanup GC
# ===========================================================================


def gc_stale_phases() -> None:
    """
    Delete all keys of phases marked stale by a prior ``transition_to()``.

    MUST be called at the very top of ``main()``, before ANY state reads or
    module logic. This is the second half of the lazy-cleanup strategy:

        transition_to()      → marks old phase stale
        gc_stale_phases()    → runs at top of the *next* rerun, deletes keys

    Unknown slugs in the marker set are logged and dropped. Safe to call on
    every rerun; it's a no-op when nothing is stale. Afterwards
    ``_stale_phases`` is an empty set.
    """
    pending = st.session_state.get("_stale_phases")
    if not pending:
        return

    total_cleared = 0
    # set() copy: tolerates a list/tuple marker and lets us iterate freely.
    for slug in set(pending):
        if slug not in PHASE_SLUGS:
            logger.warning("gc_stale_phases: ignoring unknown slug '%s'", slug)
            continue

        keys = _purge_phase_keys(slug)
        total_cleared += len(keys)
        logger.info(
            "GC: swept phase=%s cleared=%d keys=%s", slug, len(keys), keys
        )

    # Every marked slug was either swept or discarded as unknown.
    st.session_state["_stale_phases"] = set()

    if total_cleared:
        logger.info("gc_stale_phases: total keys removed=%d", total_cleared)


# ===========================================================================
# Public API – core
# ===========================================================================


def reset_state_for_phase(phase: str) -> None:
    """
    Immediately delete all ephemeral session-state keys for *phase*.

    Safe to call:
      • On explicit user action ("Reset this module")
      • In tests / dev tooling
      • From ``PhaseState.reset()``

    Parameters
    ----------
    phase : str
        Phase slug (``"genomics"``) or display-name
        (``"Genomics & Variant Discovery"``). Unknown phases are logged and
        ignored.

    Notes
    -----
    ``GLOBAL_PERSISTENT_KEYS`` are NEVER deleted.
    """
    slug = _resolve_slug(phase)
    if slug not in PHASE_SLUGS:
        logger.warning(
            "reset_state_for_phase: unknown phase '%s'. Skipping.", slug
        )
        return

    keys = _purge_phase_keys(slug)
    logger.info(
        "reset_state_for_phase: phase=%s cleared=%d keys=%s",
        slug,
        len(keys),
        keys,
    )


def transition_to(new_phase: str) -> bool:
    """
    Detect a phase change, mark the old phase stale and update routing keys.

    Returns
    -------
    bool
        ``True`` if the phase actually changed from an *established* phase
        (the caller should then rerun). ``False`` on first boot (no prior
        phase) or if the phase is unchanged.

    Implementation
    --------------
    Rather than deleting keys inline, this function:

    1. Records the old slug in ``_stale_phases`` (a set in session_state).
    2. Updates ``current_phase`` / ``previous_phase``.
    3. Sets the ``_phase_changed`` flag to ``True`` (informational – the
       routing code in this module uses the return value instead).

    ``gc_stale_phases()`` later performs the actual key deletion before any
    module logic executes.

    Boot-path special case
    ----------------------
    When ``current_phase`` is ``None`` (app just started), the initial phase
    is recorded WITHOUT requesting a rerun – there is no stale state to
    flush yet.
    """
    new_slug = _resolve_slug(new_phase)
    old_slug: Optional[str] = st.session_state.get("current_phase")

    if old_slug == new_slug:
        return False  # same phase – nothing to do

    if new_slug not in PHASE_SLUGS:
        # Still routed (behaviour unchanged), but keys of an unknown phase
        # are never swept by the GC.
        logger.warning("transition_to: unknown phase '%s'", new_slug)

    # Boot path: first time a phase is established. Record it but do NOT
    # mark anything stale or request a rerun – nothing has been loaded yet.
    if old_slug is None:
        st.session_state["previous_phase"] = None
        st.session_state["current_phase"] = new_slug
        logger.info("Phase initialised on boot: %s", new_slug)
        return False

    logger.info("Phase transition: %s → %s", old_slug, new_slug)

    # Mark old phase for lazy GC
    if old_slug in PHASE_SLUGS:
        stale: Set[str] = set(st.session_state.get("_stale_phases") or ())
        stale.add(old_slug)
        st.session_state["_stale_phases"] = stale

    st.session_state["previous_phase"] = old_slug
    st.session_state["current_phase"] = new_slug
    st.session_state["_phase_changed"] = True

    return True


def handle_phase_routing(phase_display_name: str) -> None:
    """
    **Call this at the VERY TOP of main(), before any state reads.**

    Step 1 – ``gc_stale_phases()``   : delete keys left stale by a prior
                                        transition (lazy-cleanup second pass).
    Step 2 – ``audit_global_keys()`` : warn on bloat, log oversized values.
    Step 3 – ``transition_to()``     : detect phase change, mark stale.
    Step 4 – ``_safe_rerun()``       : if the phase just changed, abort the
                                        current run and start a clean one. If
                                        the debounce suppresses the rerun, the
                                        just-marked phase is swept inline.

    When ``disable_phase_cleanup`` is truthy, steps 1 and 4 are skipped and
    only ``current_phase`` / ``previous_phase`` are updated.

    Parameters
    ----------
    phase_display_name : str
        The display-name for the phase currently selected in the sidebar,
        e.g. ``"Genomics & Variant Discovery"``.

    Example (inside app.py main())
    --------------------------------
    ::

        # ── FIRST LINES of main(), before anything else ──
        selected_phase_idx = st.session_state.get("phase_selector", 0)
        selected_phase_title, phase_modules = PHASES[selected_phase_idx]
        state_manager.handle_phase_routing(selected_phase_title)
        # ── everything below runs with a clean slate ─────
    """
    cleanup_enabled = _phase_cleanup_enabled()

    # 1. Sweep phases marked stale during the previous rerun.
    if cleanup_enabled:
        gc_stale_phases()
    else:
        # Cleanup disabled: drop pending marks so nothing is swept.
        st.session_state["_stale_phases"] = set()

    # 2. Proactive global-key health check
    audit_global_keys()

    # 3. Detect transition (and, with cleanup on, mark the old phase stale).
    if cleanup_enabled:
        changed = transition_to(phase_display_name)
    else:
        new_slug = _resolve_slug(phase_display_name)
        old_slug: Optional[str] = st.session_state.get("current_phase")
        changed = old_slug != new_slug
        if changed:
            st.session_state["previous_phase"] = old_slug
            st.session_state["current_phase"] = new_slug
        st.session_state["_phase_changed"] = False

    # 4. The phase just changed: abort this run. The *next* run's
    #    gc_stale_phases() sweeps the phase marked in step 3.
    if changed and cleanup_enabled:
        _safe_rerun(trigger="phase_change", context=phase_display_name)
        # st.rerun() halts the script, so reaching this line means the
        # debounce suppressed the rerun. Sweep now so the remainder of this
        # run does not execute alongside the old phase's live keys.
        gc_stale_phases()


# ===========================================================================
# Rerun debounce helper
# ===========================================================================


def _safe_rerun(*, trigger: str, context: str = "") -> None:
    """
    Debounced wrapper around ``st.rerun()``.

    Prevents rapid back-to-back reruns (e.g. double-clicks, hot widget
    callbacks firing within ``RERUN_DEBOUNCE_MS``) that would cause a visible
    flicker or, in edge cases, an infinite rerun loop.

    Behaviour
    ---------
    * Suppresses the call – logging a WARNING and returning normally – if the
      previous *fired* rerun was less than ``RERUN_DEBOUNCE_MS`` ms ago.
    * Otherwise logs the trigger at INFO level, records the time in
      ``_last_rerun_ts`` (monotonic clock, seconds) and calls ``st.rerun()``.

    Parameters
    ----------
    trigger : str
        Short identifier for what caused this rerun, e.g. ``"phase_change"``.
    context : str
        Optional human-readable detail logged alongside the trigger.
    """
    now_s: float = time.monotonic()
    last_ts: Optional[float] = st.session_state.get("_last_rerun_ts")

    if last_ts is not None:
        elapsed_ms = (now_s - last_ts) * 1000
        if elapsed_ms < RERUN_DEBOUNCE_MS:
            logger.warning(
                "_safe_rerun: SUPPRESSED trigger=%s context=%r "
                "elapsed_ms=%.1f debounce_ms=%d",
                trigger,
                context,
                elapsed_ms,
                RERUN_DEBOUNCE_MS,
            )
            return

    logger.info(
        "_safe_rerun: FIRING trigger=%s context=%r "
        "prev_rerun_ago_ms=%s",
        trigger,
        context,
        f"{(now_s - last_ts) * 1000:.1f}" if last_ts is not None else "",
    )
    st.session_state["_last_rerun_ts"] = now_s
    st.rerun()


# ===========================================================================
# Global-key safeguards
# ===========================================================================


def audit_global_keys() -> Dict[str, Any]:
    """
    Inspect global (non-phase) session-state keys for size and count bloat.

    Logs a WARNING if:
      • Total global key count exceeds ``GLOBAL_KEY_COUNT_THRESHOLD``
      • Any single value's (shallow) byte estimate exceeds
        ``GLOBAL_VALUE_SIZE_THRESHOLD_BYTES``

    Returns
    -------
    dict with keys:
        ``global_key_count``  int
        ``oversized_keys``    list[str]  – keys whose value is too large
        ``total_bytes``       int        – sum of estimated sizes
    """
    global_keys = [k for k in st.session_state.keys() if not _is_phase_key(k)]
    count = len(global_keys)

    if count > GLOBAL_KEY_COUNT_THRESHOLD:
        logger.warning(
            "audit_global_keys: global key count=%d exceeds threshold=%d. "
            "Review GLOBAL_PERSISTENT_KEYS or call clear_global_state().",
            count,
            GLOBAL_KEY_COUNT_THRESHOLD,
        )

    oversized: List[str] = []
    total_bytes = 0
    for key in global_keys:
        size = _estimate_size(st.session_state[key])
        total_bytes += size
        if size > GLOBAL_VALUE_SIZE_THRESHOLD_BYTES:
            oversized.append(key)
            logger.warning(
                "audit_global_keys: key='%s' estimated_size=%d bytes (threshold=%d).",
                key,
                size,
                GLOBAL_VALUE_SIZE_THRESHOLD_BYTES,
            )

    return {
        "global_key_count": count,
        "oversized_keys": oversized,
        "total_bytes": total_bytes,
    }


def clear_global_state(*, confirm: bool = False) -> int:
    """
    **Debug / dev helper** – wipe non-essential global session keys.

    Removes every key in session_state that is NOT in
    ``GLOBAL_PERSISTENT_KEYS`` AND does not belong to any phase (neither
    registered in ``_PHASE_EPHEMERAL_KEYS`` nor ``<phase>__``-prefixed).
    In other words, it clears "unknown" keys that leaked into the global
    namespace without being properly registered.

    Parameters
    ----------
    confirm : bool
        Must be ``True`` to execute. Safety guard against accidental calls.

    Returns
    -------
    int
        Number of keys removed (0 when ``confirm`` is false).

    Example
    -------
    ::

        # In a dev-mode sidebar toggle:
        if st.button("⚠️ Clear Unknown Globals"):
            n = state_manager.clear_global_state(confirm=True)
            st.toast(f"Cleared {n} unknown global keys")
    """
    if not confirm:
        logger.error(
            "clear_global_state: called without confirm=True. Aborting. "
            "Pass confirm=True to actually clear."
        )
        return 0

    # Single pass: a key is "unknown" when neither the global allow-list nor
    # any phase claims it.
    unknown_globals = [
        k
        for k in list(st.session_state.keys())
        if not _is_global(k) and not _is_phase_key(k)
    ]

    for key in unknown_globals:
        try:
            del st.session_state[key]
        except KeyError:
            pass

    logger.info(
        "clear_global_state: removed %d unknown global keys: %s",
        len(unknown_globals),
        unknown_globals,
    )
    return len(unknown_globals)


# ===========================================================================
# Namespaced read / write / delete helpers
# ===========================================================================


def get(key: str, *, phase: Optional[str] = None, default: Any = None) -> Any:
    """
    Read a session-state value.

    Parameters
    ----------
    key     : str  – logical key name
    phase   : str  – optional slug or display-name; if given, the key is read
                     as ``<slug>__<key>``
    default : Any  – returned when key is absent
    """
    return st.session_state.get(_resolve_key(key, phase), default)


def set_key(key: str, value: Any, *, phase: Optional[str] = None) -> None:
    """
    Write a session-state value.

    Parameters
    ----------
    key   : str  – logical key name
    value : Any  – value to store
    phase : str  – optional slug or display-name; if given, the key is
                   written as ``<slug>__<key>``
    """
    st.session_state[_resolve_key(key, phase)] = value


def pop(key: str, *, phase: Optional[str] = None, default: Any = None) -> Any:
    """
    Delete and return a session-state value.

    Parameters
    ----------
    key     : str  – logical key name
    phase   : str  – optional slug or display-name (``<slug>__<key>``)
    default : Any  – returned if key is absent
    """
    return st.session_state.pop(_resolve_key(key, phase), default)


# ===========================================================================
# Debug: snapshot
# ===========================================================================


def _snapshot_value(value: Any, *, raw: bool) -> Union[str, Any]:
    """
    Serialise a session-state value for snapshot output.

    ``raw=False`` (default / UI-safe)
        Returns ``repr(value)[:120]`` – always a short string, safe for
        ``st.json()`` and JSON serialisation.

    ``raw=True`` (deep-debug)
        Returns the actual Python value, with large containers truncated so
        the result stays inspectable without memory pressure:

        * ``list`` / ``tuple`` → list of the first
          ``_SNAPSHOT_RAW_SEQUENCE_MAX`` items + a trailing ``"...<N total>"``
          marker if truncated.
        * ``dict``  → first ``_SNAPSHOT_RAW_SEQUENCE_MAX`` pairs + a
          ``"__truncated__"`` marker entry if truncated.
        * ``str``   → first ``_SNAPSHOT_RAW_STR_MAX`` chars + ``"...<N total>"``
          marker if truncated.
        * ``bytes`` → ``"<bytes len=N>"`` size descriptor string.
        * Everything else → returned as-is (numbers, bools, None,
          dataclass instances, etc.).
    """
    if not raw:
        return repr(value)[:120]

    # ── raw mode: type-aware truncation ──────────────────────────────────
    if isinstance(value, bytes):
        # Never dump binary payloads; report only their length.
        return f"<bytes len={len(value)}>"

    if isinstance(value, str):
        if len(value) > _SNAPSHOT_RAW_STR_MAX:
            return value[:_SNAPSHOT_RAW_STR_MAX] + f"...<{len(value)} total>"
        return value

    if isinstance(value, (list, tuple)):
        n = len(value)
        if n > _SNAPSHOT_RAW_SEQUENCE_MAX:
            truncated = list(value[:_SNAPSHOT_RAW_SEQUENCE_MAX])
            truncated.append(f"...<{n} total>")
            return truncated
        return list(value)

    if isinstance(value, dict):
        n = len(value)
        if n > _SNAPSHOT_RAW_SEQUENCE_MAX:
            items = list(value.items())[:_SNAPSHOT_RAW_SEQUENCE_MAX]
            result_dict = dict(items)
            result_dict["__truncated__"] = f"...<{n} total>"
            return result_dict
        return dict(value)

    # Primitives and everything else: return as-is
    return value


def snapshot(*, raw: bool = False) -> Dict[str, Any]:
    """
    Return every session-state key grouped by phase slug.

    Groups
    ------
    ``__global__``   – keys in GLOBAL_PERSISTENT_KEYS
    ``<slug>``       – phase-owned keys (registry or ``<slug>__`` prefix)
    ``__unknown__``  – keys that don't fit anywhere (potential leaks)
    ``__meta__``     – snapshot metadata (key count, raw mode, timestamp)

    Parameters
    ----------
    raw : bool, default ``False``
        • ``False`` – every value is serialised to a short ``repr()`` string.
          Safe for ``st.json()`` and JSON serialisation.
        • ``True``  – actual Python values are returned with type-aware
          truncation for sequences, dicts, and strings; ``bytes`` are
          replaced by a size descriptor and other objects (e.g. DataFrames)
          are returned as-is. Use this mode in a Streamlit expander or
          ``st.write()`` call, not in ``st.json()`` (which only accepts
          JSON-serialisable data).

    Examples
    --------
    ::

        # UI-safe (always works)
        st.json(state_manager.snapshot())

        # Deep debug (use st.write, not st.json)
        st.write(state_manager.snapshot(raw=True))

        # Programmatic inspection
        snap = state_manager.snapshot(raw=True)
        blast = snap["genomics"].get("blast_results")   # actual list
    """
    result: Dict[str, Any] = {
        "__global__": {},
        "__unknown__": {},
        "__meta__": {
            "raw": raw,
            "total_keys": len(st.session_state),
            "snapshot_ts": time.time(),
        },
    }
    for slug in PHASE_SLUGS:
        result[slug] = {}

    for raw_key, value in st.session_state.items():
        serialised = _snapshot_value(value, raw=raw)
        owner = _owning_phase(raw_key)
        if owner is not None:
            result[owner][raw_key] = serialised
        elif raw_key in GLOBAL_PERSISTENT_KEYS:
            result["__global__"][raw_key] = serialised
        else:
            result["__unknown__"][raw_key] = serialised

    return result


# ===========================================================================
# PhaseState – OO convenience wrapper for module renderers
# ===========================================================================


class PhaseState:
    """
    Binds a phase slug to every session-state operation.

    Usage inside a module renderer::

        state = PhaseState("genomics")
        data  = state.get("blast_results")      # reads genomics__blast_results
        state.set("blast_results", data)        # writes genomics__blast_results
        state.pop("blast_results")              # deletes
        alive = state.has("blast_results")      # existence check
        state.reset()                           # wipe all genomics ephemeral keys

    Parameters
    ----------
    phase : str
        Slug (``"genomics"``) or display-name
        (``"Genomics & Variant Discovery"``).

    Raises
    ------
    ValueError
        If *phase* does not resolve to a known slug.
    """

    def __init__(self, phase: str) -> None:
        slug = _resolve_slug(phase)
        if slug not in PHASE_SLUGS:
            raise ValueError(
                f"Unknown phase slug '{slug}'. "
                f"Valid: {list(PHASE_SLUGS.keys())}"
            )
        self._phase = slug

    def get(self, key: str, default: Any = None) -> Any:
        """Read the namespaced key, returning *default* if absent."""
        return get(key, phase=self._phase, default=default)

    def set(self, key: str, value: Any) -> None:
        """Write the namespaced key."""
        set_key(key, value, phase=self._phase)

    def pop(self, key: str, default: Any = None) -> Any:
        """Delete and return the namespaced key."""
        return pop(key, phase=self._phase, default=default)

    def has(self, key: str) -> bool:
        """True if the namespaced key exists."""
        return _ns(self._phase, key) in st.session_state

    def reset(self) -> None:
        """Immediately wipe all ephemeral keys for this phase."""
        reset_state_for_phase(self._phase)

    @property
    def phase(self) -> str:
        """The phase slug this instance is bound to."""
        return self._phase

    def __repr__(self) -> str:  # pragma: no cover
        return f"PhaseState(phase={self._phase!r})"