"""Gradio app for `srt_introspect.Trace` — adaptive-density reasoning trace.
A single tab over a frozen Qwen-2.5-7B backbone:
prompt → generated continuation, every token tinted by SRT adapter
divergence, with hover-cards showing the activation verbalizer's
best-guess narration at scheduler-picked positions.
Run locally on a GPU box:
pip install -r demo/requirements.txt
PYTHONPATH=. python demo/srt_introspect_app.py
Or deploy as an HF Space (hardware: a10g-small / zero-a100 work fine —
Qwen-7B needs ~16 GB VRAM in bf16, plus the AV decoder ~2 GB).
"""
from __future__ import annotations
import logging
import os
import pathlib
import time
import gradio as gr
import torch
from srt_introspect import Trace
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(message)s")
log = logging.getLogger("srt_introspect_app")
# Optional ZeroGPU support — same pattern as demo/app.py.
try:
import spaces # type: ignore
_ON_ZEROGPU = bool(os.environ.get("SPACES_ZERO_GPU")) or hasattr(spaces, "GPU")
def _gpu(duration: int = 180):
if _ON_ZEROGPU:
return spaces.GPU(duration=duration)
return lambda fn: fn
except ImportError: # pragma: no cover
_ON_ZEROGPU = False
def _gpu(duration: int = 180):
return lambda fn: fn
DEVICE = "cuda" if (torch.cuda.is_available() or _ON_ZEROGPU) else "cpu"
log.info("device=%s zero_gpu=%s", DEVICE, _ON_ZEROGPU)
_TRACE: Trace | None = None


def _get_trace() -> Trace:
    """Return the shared ``Trace`` instance, loading it on first use.

    The loaded object is cached in the module-level ``_TRACE``, so every call
    after the first returns the same instance without calling ``Trace.load``
    again.
    """
    global _TRACE
    trace = _TRACE
    if trace is None:
        log.info("Loading SRT adapter + activation verbalizer on %s", DEVICE)
        trace = _TRACE = Trace.load(device=DEVICE)
    return trace
# ---------- palette (matches scripts/demos/render_trace_html.py) ----------
# Background and panel shades.
BG, PANEL, PANEL_ALT = "#0a1429", "#16213d", "#1e2d4f"
# Neutral tones.
INK, DIM, RULE = "#e6ecf5", "#8a9bb8", "#26345a"
# Accents shared by the divergence ramp and the per-layer sparkline colours.
CYAN, MINT, PINK, LAVENDER = "#7ee0ff", "#7eebc0", "#ff7eb9", "#b8a4ff"
def _lerp_rgb(a: tuple[int, int, int], b: tuple[int, int, int], t: float):
return tuple(int(a[i] + (b[i] - a[i]) * t) for i in range(3))
def _hex(h: str):
h = h.lstrip("#")
return (int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16))
def _div_color(div: float, lo: float, hi: float) -> str:
    """Return a translucent CSS ``rgba(...)`` colour for a divergence value.

    ``div`` is normalised to ``t`` in [0, 1] against ``[lo, hi]`` and clamped.
    When ``hi <= lo`` the range is degenerate and ``t`` is 0. The ramp runs
    CYAN->MINT over the lower half of ``t`` and MINT->PINK over the upper half,
    with a fixed alpha of 0.30.
    """
    if hi <= lo:
        t = 0.0
    else:
        scaled = (div - lo) / (hi - lo)
        t = max(0.0, min(1.0, scaled))
    low_rgb, mid_rgb, high_rgb = _hex(CYAN), _hex(MINT), _hex(PINK)
    if t < 0.5:
        rgb = _lerp_rgb(low_rgb, mid_rgb, t * 2)
    else:
        rgb = _lerp_rgb(mid_rgb, high_rgb, (t - 0.5) * 2)
    r, g, b = rgb
    return f"rgba({r},{g},{b},0.30)"
# Per-MAH-layer sparkline colours, cycled by layer index in the order
# returned by Trace (== adapter.config.mah_layer_indices order).
# Layer ``li`` uses ``_LAYER_COLORS[li % len(_LAYER_COLORS)]``, so with more
# than four layers the colours repeat.
_LAYER_COLORS = (CYAN, LAVENDER, PINK, MINT)
def _render_aggregate_curve(steps) -> str:
    """SVG with two stacked traces — aggregate divergence (pink) and next-token
    entropy (cyan) — plus a tick row marking verbalization positions. Lets
    viewers see whether the adapter's signal lines up with raw uncertainty
    or whether it's catching something the entropy can't.

    Args:
        steps: per-token step records; ``.divergence``, ``.entropy`` and
            ``.verbalization`` are read from each.

    Returns:
        An HTML fragment, or ``""`` when there are fewer than two steps
        (the x-position formula divides by ``n - 1``).
    """
    if not steps or len(steps) < 2:
        return ""
    divs = [s.divergence for s in steps]
    ents = [s.entropy for s in steps]
    # Positions of tokens that carry a verbalization; drawn as ticks below.
    verb_idxs = [i for i, s in enumerate(steps) if s.verbalization]
    W, H, PAD_X, PAD_Y = 1000, 130, 8, 14
    n = len(steps)

    def _poly(vals, color):
        # Min/max-normalise this series on its own so it fills the plot
        # height regardless of its absolute scale. A flat series uses a unit
        # range to avoid dividing by zero.
        vmin, vmax = min(vals), max(vals)
        rng = (vmax - vmin) or 1.0
        pts = []
        for i, v in enumerate(vals):
            x = PAD_X + (W - 2 * PAD_X) * (i / (n - 1))
            # Larger values map to smaller y (nearer the top of the plot).
            y = PAD_Y + (H - 2 * PAD_Y) * (1.0 - (v - vmin) / rng)
            pts.append(f"{x:.1f},{y:.1f}")
        return (
            f''
        )

    # vrb_safe drops any index outside range(n).
    ticks = "".join(
        f''
        for i in vrb_safe(verb_idxs, n)
    )
    svg = (
        f''
    )
    legend = (
        f''
        f'aggregate div'
        f''
        f'next-tok entropy'
        f''
        f'verbalization'
    )
    return (
'
'
'
'
'aggregate signal vs. predictive uncertainty'
f'{legend}'
'
'
f'{svg}'
'
token 0'
f'token {n - 1}
'
'
'
    )
def vrb_safe(idxs, n):
    """Return the indices from ``idxs`` that satisfy ``0 <= i < n``, in order.

    A small guard that drops out-of-range indices in case the caller's index
    list and step count disagree.
    """
    def _in_bounds(idx):
        return 0 <= idx < n

    return list(filter(_in_bounds, idxs))
def _render_layer_sparkline(steps, layer_indices: list[int]) -> str:
    """Tiny inline SVG showing one polyline per MAH layer over token index.
    Makes the "tap layers 7/14/21" claim visible — viewers can see whether
    early vs. late layers are doing the work at each token.

    Args:
        steps: per-token step records; each ``.per_layer_divergence`` is read
            as an indexable sequence of per-layer values.
        layer_indices: legend labels, one per layer. Used only when its length
            equals the longest ``per_layer_divergence``; otherwise the legend
            falls back to ``0..n_layers-1``.

    Returns:
        An HTML fragment, or ``""`` when there are no steps, no per-layer
        values, or fewer than two steps.
    """
    if not steps:
        return ""
    series = [s.per_layer_divergence for s in steps]
    # Layer count = longest per-step vector; shorter vectors are padded with
    # 0.0 when the series are built below.
    n_layers = max((len(p) for p in series), default=0)
    if n_layers == 0:
        return ""
    # If config didn't surface labels, fall back to 0..n-1.
    labels = layer_indices if len(layer_indices) == n_layers else list(range(n_layers))
    W, H, PAD_X, PAD_Y = 1000, 110, 8, 14
    n = len(steps)
    # The x-position formula divides by (n - 1), so a single step is not drawn.
    if n < 2:
        return ""
    # Per-layer min/max for independent y-scaling per line (so a weak
    # layer is still readable next to a dominant one).
    polylines = []
    for li in range(n_layers):
        # Steps whose vector is shorter than n_layers contribute 0.0 here.
        vals = [(p[li] if li < len(p) else 0.0) for p in series]
        vmin, vmax = min(vals), max(vals)
        # A flat series would divide by zero; use a unit range instead.
        rng = (vmax - vmin) or 1.0
        pts = []
        for i, v in enumerate(vals):
            x = PAD_X + (W - 2 * PAD_X) * (i / (n - 1))
            # plot from top of band; lower y = larger value
            y = PAD_Y + (H - 2 * PAD_Y) * (1.0 - (v - vmin) / rng)
            pts.append(f"{x:.1f},{y:.1f}")
        # Colours repeat every len(_LAYER_COLORS) layers.
        color = _LAYER_COLORS[li % len(_LAYER_COLORS)]
        polylines.append(
            f''
        )
    legend_bits = "".join(
        f''
        f'L{labels[li]}'
        for li in range(n_layers)
    )
    svg = (
        f''
    )
    return (
'
'
f'
'
f'per-layer divergence'
f'{legend_bits}'
'
'
f'{svg}'
'
token 0'
f'token {n - 1}
'
'
'
    )
def _render_trace_html(result, prompt: str, elapsed: float,
                       layer_indices: list[int] | None = None,
                       title: str | None = None) -> str:
    """Render a generation trace as an HTML fragment.

    Each generated token becomes a tinted span:

    - The background comes from ``_div_color`` over roughly the 10th-90th
      percentile range of per-token divergence. If that window is degenerate,
      it falls back to the full min/max range.
    - Tokens with a verbalization get the ``selected`` class.
    - Regime-0 tokens get ``reg-bif``.
    - Tokens whose regime differs from the previous step's get ``reg-flip``.

    The fragment also embeds the aggregate-signal curve, the per-layer
    sparkline, and a summary of ``result.selected()`` steps (index, token,
    divergence, r-hat, regime, verbalization).

    Args:
        result: trace result exposing ``.steps`` and ``.selected()``.
        prompt: the user prompt; HTML-escaped before embedding.
        elapsed: generation time. NOTE(review): presumably seconds — confirm.
        layer_indices: sparkline legend labels; ``None`` is treated as ``[]``.
        title: optional heading. NOTE(review): how it is used is not evident
            from this body — confirm.

    Returns:
        An HTML string, or a short placeholder when ``result.steps`` is empty.
    """
    import html as _html
    steps = result.steps
    if not steps:
        return '
(no tokens generated)
'
    # Clip the colour scale to roughly the 10th-90th percentile. _div_color
    # clamps, so extreme tokens saturate at the ramp ends instead of
    # compressing every other token into a narrow band.
    divs = [s.divergence for s in steps]
    ds = sorted(divs)
    lo = ds[max(0, int(0.10 * len(ds)))]
    hi = ds[min(len(ds) - 1, int(0.90 * len(ds)))]
    # Degenerate percentile window: fall back to the full range. If that is
    # also flat, _div_color maps every token to the low end of the ramp.
    if hi <= lo:
        lo, hi = min(divs), max(divs)
    # Detect regime flips: any step whose regime differs from the previous.
    # The first step is never a flip. These get an extra outline + glyph so
    # viewers can spot BEN-state transitions without hovering each token.
    # Flips are keyed by token_idx, not by position in `steps`.
    flip_set: set[int] = set()
    prev_reg = steps[0].regime
    for s in steps[1:]:
        if s.regime != prev_reg:
            flip_set.add(s.token_idx)
        prev_reg = s.regime
    spans = []
    for s in steps:
        disp = _html.escape(s.token).replace("\n", " ")
        bg = _div_color(s.divergence, lo, hi)
        classes = ["tok"]
        if s.verbalization:
            classes.append("selected")
        if s.regime == 0:
            classes.append("reg-bif")
        if s.token_idx in flip_set:
            classes.append("reg-flip")
        klass = " ".join(classes)
        # Hover text: metrics line, plus at most 240 chars of verbalization.
        title_txt = (
            f"i={s.token_idx} · d={s.divergence:.2f} · "
            f"H={s.entropy:.2f} · r̂={s.r_hat:.2f} · reg={s.regime}"
        )
        if s.verbalization:
            title_txt += f" → {s.verbalization[:240]}"
        # onclick pins the verbalization (or the metric line) to the side panel.
        # The srtPin() JS helper is defined once per page elsewhere in this
        # module rather than per token.
        # NOTE(review): html.escape(quote=True) makes the payload safe inside a
        # quoted HTML attribute, but the browser un-escapes the attribute before
        # running it as JS. Quotes and backslashes in model output therefore
        # reach the JS string literal raw. Consider json.dumps() before
        # html.escape() — confirm against the final onclick markup.
        pin_payload = _html.escape(
            (s.verbalization or title_txt), quote=True
        ).replace("\n", " ")
        onclick = f"srtPin(this,"{pin_payload}")"
        spans.append(
            f'{disp}'
        )
    sel = result.selected()
    sel_rows = "".join(
f'
{s.token_idx}
'
f'
{_html.escape(s.token)}
'
f'
{s.divergence:.2f}
'
f'
{s.r_hat:.2f}
'
f'
{s.regime}
'
f'
{_html.escape(s.verbalization or "")}
'
        for s in sel
    )
    n_flip = len(flip_set)
    n_bif = sum(1 for s in steps if s.regime == 0)
    title_html = (
f'
divergencelow→high·box= verbalization (hover)·▲= bifurcating regime (r=0)·⇋= regime flip·click any token to pin
{_html.escape(prompt)}
{''.join(spans)}
{_render_aggregate_curve(steps)}
{_render_layer_sparkline(steps, layer_indices or [])}
pinned
click tokens above to pin their verbalization or metrics here
Selected verbalizations
idx
token
div
r̂
reg
verbalization (AV)
{sel_rows}
"""
# Gradio HTML components are sandboxed, so all styles must be inline-injected
# alongside the markup. We ship the CSS once at app boot via gr.HTML.
# This is an f-string: any literal CSS braces inside it must be written as
# `{{` / `}}`, because single braces interpolate Python expressions.
_TRACE_CSS = f"""
"""
# JS for click-to-pin. Lives outside _TRACE_CSS because Gradio strips
#