"""Gradio app for `srt_introspect.Trace` — adaptive-density reasoning trace.

A single tab over a frozen Qwen-2.5-7B backbone: prompt → generated
continuation, every token tinted by SRT adapter divergence, with hover-cards
showing the activation verbalizer's best-guess narration at scheduler-picked
positions.

Run locally on a GPU box:

    pip install -r demo/requirements.txt
    PYTHONPATH=. python demo/srt_introspect_app.py

Or deploy as an HF Space (hardware: a10g-small / zero-a100 work fine —
Qwen-7B needs ~16 GB VRAM in bf16, plus the AV decoder ~2 GB).
"""
# NOTE(review): in this copy of the file most HTML/SVG markup inside the
# f-string templates is missing (bare f'' literals, string literals that span
# raw newlines, an onclick f-string with unbalanced quotes), so the module
# does not parse as shown. It looks like the tags were stripped by an
# HTML-to-text pass — restore the markup from the authoritative source
# before running.
from __future__ import annotations

import logging
import os
import pathlib
import time

import gradio as gr
import torch

from srt_introspect import Trace

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(message)s")
log = logging.getLogger("srt_introspect_app")

# Optional ZeroGPU support — same pattern as demo/app.py.
try:
    import spaces  # type: ignore

    # NOTE(review): `hasattr(spaces, "GPU")` is likely True whenever the
    # `spaces` package is importable at all, which would make _ON_ZEROGPU
    # True even off a ZeroGPU Space (and force DEVICE="cuda" below on a
    # machine without CUDA). Confirm whether the env-var check alone was
    # intended.
    _ON_ZEROGPU = bool(os.environ.get("SPACES_ZERO_GPU")) or hasattr(spaces, "GPU")

    def _gpu(duration: int = 180):
        """Return a decorator for GPU-bound handlers.

        On ZeroGPU this is ``spaces.GPU(duration=duration)``; otherwise it is
        an identity decorator that returns the function unchanged.
        """
        if _ON_ZEROGPU:
            return spaces.GPU(duration=duration)
        return lambda fn: fn

except ImportError:  # pragma: no cover
    _ON_ZEROGPU = False

    def _gpu(duration: int = 180):
        """Fallback when `spaces` is not installed: an identity decorator."""
        return lambda fn: fn

# Prefer CUDA when it is visible, or whenever ZeroGPU mode is on.
DEVICE = "cuda" if (torch.cuda.is_available() or _ON_ZEROGPU) else "cpu"
log.info("device=%s zero_gpu=%s", DEVICE, _ON_ZEROGPU)

# Lazily-populated singleton; only _get_trace() assigns it.
_TRACE: Trace | None = None


def _get_trace() -> Trace:
    """Return the module-wide Trace, loading it onto DEVICE on first call.

    Later calls reuse the cached instance. There is no lock here, so two
    concurrent first calls could each trigger a load.
    """
    global _TRACE
    if _TRACE is None:
        log.info("Loading SRT adapter + activation verbalizer on %s", DEVICE)
        _TRACE = Trace.load(device=DEVICE)
    return _TRACE


# ---------- palette (matches scripts/demos/render_trace_html.py) ----------
BG = "#0a1429"
PANEL = "#16213d"
PANEL_ALT = "#1e2d4f"
INK = "#e6ecf5"
DIM = "#8a9bb8"
RULE = "#26345a"
CYAN = "#7ee0ff"
MINT = "#7eebc0"
PINK = "#ff7eb9"
LAVENDER = "#b8a4ff"


def _lerp_rgb(a: tuple[int, int, int], b: tuple[int, int, int], t: float) -> tuple[int, ...]:
    """Linearly interpolate each RGB channel from `a` (t=0) to `b` (t=1).

    Channels are truncated with int(); `t` is not clamped here, so callers
    are responsible for keeping it in [0, 1].
    """
    return tuple(int(a[i] + (b[i] - a[i]) * t) for i in range(3))


def _hex(h: str) -> tuple[int, int, int]:
    """Parse a "#rrggbb" (or "rrggbb") colour string into an (r, g, b) tuple.

    Only the 6-digit form is handled; 3-digit shorthand is not supported.
    """
    h = h.lstrip("#")
    return (int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16))


def _div_color(div: float, lo: float, hi: float) -> str:
    """Map a divergence value to a translucent background colour.

    `div` is normalised into [0, 1] relative to [lo, hi]. The result is
    clamped, and is 0 when the range is empty or inverted. That position is
    placed on a two-segment gradient CYAN → MINT → PINK.

    Returns a CSS ``rgba(r,g,b,0.30)`` string.
    """
    if hi <= lo:
        t = 0.0
    else:
        t = max(0.0, min(1.0, (div - lo) / (hi - lo)))
    c, m, p = _hex(CYAN), _hex(MINT), _hex(PINK)
    # First half of the range blends cyan→mint, second half mint→pink; each
    # half is rescaled back to [0, 1] for _lerp_rgb.
    r, g, b = _lerp_rgb(c, m, t * 2) if t < 0.5 else _lerp_rgb(m, p, (t - 0.5) * 2)
    return f"rgba({r},{g},{b},0.30)"


# Per-MAH-layer sparkline colours, cycled by layer index in the order
# returned by Trace (== adapter.config.mah_layer_indices order).
_LAYER_COLORS = (CYAN, LAVENDER, PINK, MINT)


def _render_aggregate_curve(steps) -> str:
    """SVG with two stacked traces plus a tick row for verbalizations.

    The traces are aggregate divergence (pink) and next-token entropy
    (cyan). The tick row marks the positions of steps that carry a
    verbalization.

    This lets viewers see whether the adapter's signal lines up with raw
    uncertainty, or whether it's catching something the entropy can't.

    Returns "" when there are fewer than two steps: there is no line to
    draw, and the x-scaling below divides by ``n - 1``.
    """
    if not steps or len(steps) < 2:
        return ""
    divs = [s.divergence for s in steps]
    ents = [s.entropy for s in steps]
    # Indices of steps that carry a verbalization; rendered as ticks.
    verb_idxs = [i for i, s in enumerate(steps) if s.verbalization]
    W, H, PAD_X, PAD_Y = 1000, 130, 8, 14
    n = len(steps)

    def _poly(vals, color):
        # Each series is min-max normalised independently, so both traces
        # use the full plot height but do not share a y-scale. A flat series
        # (vmax == vmin) uses rng=1.0 to avoid dividing by zero.
        vmin, vmax = min(vals), max(vals)
        rng = (vmax - vmin) or 1.0
        pts = []
        for i, v in enumerate(vals):
            x = PAD_X + (W - 2 * PAD_X) * (i / (n - 1))
            # SVG y grows downward, so larger values map to smaller y.
            y = PAD_Y + (H - 2 * PAD_Y) * (1.0 - (v - vmin) / rng)
            pts.append(f"{x:.1f},{y:.1f}")
        return (
            f''
        )

    ticks = "".join(
        f''
        for i in vrb_safe(verb_idxs, n)
    )
    svg = (
        f''
        f'{_poly(divs, PINK)}{_poly(ents, CYAN)}{ticks}'
    )
    legend = (
        f''
        f'aggregate div'
        f''
        f'next-tok entropy'
        f''
        f'verbalization'
    )
    return (
        '
' '
' 'aggregate signal vs. predictive uncertainty' f'{legend}' '
' f'{svg}' '
token 0' f'token {n - 1}
' '
' )


def vrb_safe(idxs, n):
    """Return the indices in `idxs` that fall within ``range(n)``, in order."""
    # tiny guard: drop out-of-range indices in case caller mismatched lengths
    # NOTE(review): the only caller visible here is _render_aggregate_curve;
    # consider a leading underscore to match the other private helpers.
    return [i for i in idxs if 0 <= i < n]


def _render_layer_sparkline(steps, layer_indices: list[int]) -> str:
    """Tiny inline SVG showing one polyline per MAH layer over token index.

    Makes the "tap layers 7/14/21" claim visible — viewers can see whether
    early vs. late layers are doing the work at each token.

    `layer_indices` supplies the legend labels. It is ignored, and
    0..n_layers-1 is used instead, when its length differs from the longest
    `per_layer_divergence` list.

    Returns "" when there are no steps, no per-layer values, or fewer than
    two steps.
    """
    if not steps:
        return ""
    series = [s.per_layer_divergence for s in steps]
    n_layers = max((len(p) for p in series), default=0)
    if n_layers == 0:
        return ""
    # If config didn't surface labels, fall back to 0..n-1.
    labels = layer_indices if len(layer_indices) == n_layers else list(range(n_layers))
    W, H, PAD_X, PAD_Y = 1000, 110, 8, 14
    n = len(steps)
    if n < 2:
        return ""
    # Per-layer min/max for independent y-scaling per line (so a weak
    # layer is still readable next to a dominant one).
    polylines = []
    for li in range(n_layers):
        # Steps that report fewer layers than n_layers are padded with 0.0.
        vals = [(p[li] if li < len(p) else 0.0) for p in series]
        vmin, vmax = min(vals), max(vals)
        rng = (vmax - vmin) or 1.0
        pts = []
        for i, v in enumerate(vals):
            x = PAD_X + (W - 2 * PAD_X) * (i / (n - 1))
            # plot from top of band; lower y = larger value
            y = PAD_Y + (H - 2 * PAD_Y) * (1.0 - (v - vmin) / rng)
            pts.append(f"{x:.1f},{y:.1f}")
        color = _LAYER_COLORS[li % len(_LAYER_COLORS)]
        polylines.append(
            f''
        )
    legend_bits = "".join(
        f''
        f'L{labels[li]}'
        for li in range(n_layers)
    )
    svg = (
        f''
        f'{"".join(polylines)}'
    )
    return (
        '
' f'
' f'per-layer divergence' f'{legend_bits}' '
' f'{svg}' '
token 0' f'token {n - 1}
' '
' )


def _render_trace_html(result, prompt: str, elapsed: float,
                       layer_indices: list[int] | None = None,
                       title: str | None = None) -> str:
    """Render a generation result as a self-contained HTML fragment.

    `result` must expose:
      - ``.steps``: per-token records with ``token``, ``token_idx``,
        ``divergence``, ``entropy``, ``r_hat``, ``regime``,
        ``verbalization`` and ``per_layer_divergence``.
      - ``.selected()``: presumably the verbalized subset of steps —
        confirm in srt_introspect.

    The fragment contains:
      - an optional escaped `title`;
      - summary counts (tokens, verbalizations, regime flips, r=0 steps),
        the divergence colour range and `elapsed` seconds;
      - the escaped `prompt`;
      - one tinted span per token;
      - the aggregate and per-layer charts;
      - a pin panel;
      - a table of the selected steps.

    Returns a placeholder fragment when there are no steps.
    """
    import html as _html
    steps = result.steps
    if not steps:
        return '
(no tokens generated)
'
    divs = [s.divergence for s in steps]
    # Colour scale spans roughly the 10th–90th percentile of divergence
    # (presumably so a few outliers don't flatten the gradient). Fall back to
    # the full min/max when those two percentiles coincide.
    ds = sorted(divs)
    lo = ds[max(0, int(0.10 * len(ds)))]
    hi = ds[min(len(ds) - 1, int(0.90 * len(ds)))]
    if hi <= lo:
        lo, hi = min(divs), max(divs)
    # Detect regime flips: any step whose regime differs from the previous.
    # The first step is never a flip. These get an extra outline + glyph so
    # viewers can spot BEN-state transitions without hovering each token.
    flip_set: set[int] = set()
    prev_reg = steps[0].regime
    for s in steps[1:]:
        if s.regime != prev_reg:
            flip_set.add(s.token_idx)
        prev_reg = s.regime
    spans = []
    for s in steps:
        disp = _html.escape(s.token).replace("\n", "
")
        bg = _div_color(s.divergence, lo, hi)
        # CSS classes: "selected" = has a verbalization, "reg-bif" = regime 0,
        # "reg-flip" = regime changed at this token.
        classes = ["tok"]
        if s.verbalization:
            classes.append("selected")
        if s.regime == 0:
            classes.append("reg-bif")
        if s.token_idx in flip_set:
            classes.append("reg-flip")
        klass = " ".join(classes)
        # Hover text: per-token metrics, plus up to 240 characters of the
        # verbalization when there is one.
        title_txt = (
            f"i={s.token_idx} · d={s.divergence:.2f} · "
            f"H={s.entropy:.2f} · r̂={s.r_hat:.2f} · reg={s.regime}"
        )
        if s.verbalization:
            title_txt += f" → {s.verbalization[:240]}"
        # onclick pins the verbalization (or the metric line) to the side panel.
        # srtPin() comes from the separate click-to-pin JS (see the comment
        # after _TRACE_CSS, which says it is kept outside _TRACE_CSS).
        # NOTE(review): two issues with the onclick line below.
        #  1. As written it is not valid Python: the inner double quotes end
        #     the f-string early. Presumably they were HTML entities (&quot;)
        #     that got decoded in this copy — confirm against the source.
        #  2. pin_payload is HTML-escaped but not JS-escaped. It is presumably
        #     interpolated into an onclick="..." attribute. The browser
        #     decodes that attribute before running it, so a `"` or `\` in
        #     the payload reaches the JS string literal raw and can break it.
        #     json.dumps(...) before html.escape(...) would cover both layers.
        pin_payload = _html.escape(
            (s.verbalization or title_txt), quote=True
        ).replace("\n", " ")
        onclick = f"srtPin(this,"{pin_payload}")"
        spans.append(
            f'{disp}'
        )
    sel = result.selected()
    # One table row per selected step: index, token, divergence, r̂, regime,
    # verbalization (escaped).
    sel_rows = "".join(
        f'{s.token_idx}'
        f'{_html.escape(s.token)}'
        f'{s.divergence:.2f}'
        f'{s.r_hat:.2f}'
        f'{s.regime}'
        f'{_html.escape(s.verbalization or "")}'
        for s in sel
    )
    n_flip = len(flip_set)
    # Count of steps in the bifurcating regime (regime == 0).
    n_bif = sum(1 for s in steps if s.regime == 0)
    title_html = (
        f'
{_html.escape(title)}
' if title else ""
    )
    return f"""
{title_html}
tokens{len(steps)} verbalizations{len(sel)} regime flips{n_flip} bif (r=0){n_bif} div range{lo:.2f} → {hi:.2f} elapsed{elapsed:.1f}s
divergence lowhigh · box= verbalization (hover) · = bifurcating regime (r=0) · = regime flip · click any token to pin
{_html.escape(prompt)}
{''.join(spans)}
{_render_aggregate_curve(steps)} {_render_layer_sparkline(steps, layer_indices or [])}
pinned
    click tokens above to pin their verbalization or metrics here
    Selected verbalizations
    {sel_rows}
    idxtokendivregverbalization (AV)
    """


# Gradio HTML components are sandboxed, so all styles must be inline-injected
# alongside the markup. We ship the CSS once at app boot via gr.HTML.
# NOTE(review): in this copy the stylesheet body is just a single space.
_TRACE_CSS = f""" """

# JS for click-to-pin. Lives outside _TRACE_CSS because Gradio strips #