RiverRider commited on
Commit
dfd612c
·
verified ·
1 Parent(s): 1aa9513

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +843 -21
app.py CHANGED
@@ -1,32 +1,854 @@
1
- """HF Space entrypoint shim.
2
 
3
- Hugging Face Spaces expects an `app.py` at the repo root. The actual
4
- implementation lives in `demo/srt_introspect_app.py`; this file just
5
- calls `build_app()` and launches.
6
- """
 
7
 
 
 
 
 
 
 
 
8
  from __future__ import annotations
9
 
 
10
  import os
 
 
11
 
12
- from demo.srt_introspect_app import build_app, _get_trace, _ON_ZEROGPU, DEVICE # noqa: E402
 
13
 
14
- demo = build_app()
15
- demo.queue(default_concurrency_limit=1, max_size=20)
16
 
17
- if DEVICE == "cuda" and not _ON_ZEROGPU:
18
- try:
19
- _get_trace()
20
- except Exception as e: # pragma: no cover
21
- print(f"warmup skipped: {e}")
22
 
23
- if __name__ == "__main__":
24
- # On HF Spaces we let the platform supply server_name/server_port via env;
25
- # locally we default to 0.0.0.0:7860.
26
- if _ON_ZEROGPU or os.environ.get("SPACE_ID"):
27
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  else:
29
- demo.launch(
30
- server_name="0.0.0.0",
31
- server_port=int(os.environ.get("PORT", 7860)),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Gradio app for `srt_introspect.Trace` — adaptive-density reasoning trace.
2
 
3
+ A single tab over a frozen Qwen-2.5-7B backbone:
4
+
5
+ prompt generated continuation, every token tinted by SRT adapter
6
+ divergence, with hover-cards showing the activation verbalizer's
7
+ best-guess narration at scheduler-picked positions.
8
 
9
+ Run locally on a GPU box:
10
+ pip install -r demo/requirements.txt
11
+ PYTHONPATH=. python demo/srt_introspect_app.py
12
+
13
+ Or deploy as an HF Space (hardware: a10g-small / zero-a100 work fine —
14
+ Qwen-7B needs ~16 GB VRAM in bf16, plus the AV decoder ~2 GB).
15
+ """
16
  from __future__ import annotations
17
 
18
+ import logging
19
  import os
20
+ import pathlib
21
+ import time
22
 
23
+ import gradio as gr
24
+ import torch
25
 
26
+ from srt_introspect import Trace
 
27
 
28
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(message)s")
29
+ log = logging.getLogger("srt_introspect_app")
 
 
 
30
 
31
+ # Optional ZeroGPU support — same pattern as demo/app.py.
32
+ try:
33
+ import spaces # type: ignore
34
+
35
+ _ON_ZEROGPU = bool(os.environ.get("SPACES_ZERO_GPU")) or hasattr(spaces, "GPU")
36
+
37
+ def _gpu(duration: int = 180):
38
+ if _ON_ZEROGPU:
39
+ return spaces.GPU(duration=duration)
40
+ return lambda fn: fn
41
+ except ImportError: # pragma: no cover
42
+ _ON_ZEROGPU = False
43
+
44
+ def _gpu(duration: int = 180):
45
+ return lambda fn: fn
46
+
47
+
48
+ DEVICE = "cuda" if (torch.cuda.is_available() or _ON_ZEROGPU) else "cpu"
49
+ log.info("device=%s zero_gpu=%s", DEVICE, _ON_ZEROGPU)
50
+
51
+ _TRACE: Trace | None = None
52
+
53
+
54
+ def _get_trace() -> Trace:
55
+ global _TRACE
56
+ if _TRACE is None:
57
+ log.info("Loading SRT adapter + activation verbalizer on %s", DEVICE)
58
+ _TRACE = Trace.load(device=DEVICE)
59
+ return _TRACE
60
+
61
+
62
+ # ---------- palette (matches scripts/demos/render_trace_html.py) ----------
63
+ BG = "#0a1429"
64
+ PANEL = "#16213d"
65
+ PANEL_ALT = "#1e2d4f"
66
+ INK = "#e6ecf5"
67
+ DIM = "#8a9bb8"
68
+ RULE = "#26345a"
69
+ CYAN = "#7ee0ff"
70
+ MINT = "#7eebc0"
71
+ PINK = "#ff7eb9"
72
+ LAVENDER = "#b8a4ff"
73
+
74
+
75
+ def _lerp_rgb(a: tuple[int, int, int], b: tuple[int, int, int], t: float):
76
+ return tuple(int(a[i] + (b[i] - a[i]) * t) for i in range(3))
77
+
78
+
79
+ def _hex(h: str):
80
+ h = h.lstrip("#")
81
+ return (int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16))
82
+
83
+
84
+ def _div_color(div: float, lo: float, hi: float) -> str:
85
+ if hi <= lo:
86
+ t = 0.0
87
  else:
88
+ t = max(0.0, min(1.0, (div - lo) / (hi - lo)))
89
+ c, m, p = _hex(CYAN), _hex(MINT), _hex(PINK)
90
+ r, g, b = _lerp_rgb(c, m, t * 2) if t < 0.5 else _lerp_rgb(m, p, (t - 0.5) * 2)
91
+ return f"rgba({r},{g},{b},0.30)"
92
+
93
+
94
+ # Per-MAH-layer sparkline colours, cycled by layer index in the order
95
+ # returned by Trace (== adapter.config.mah_layer_indices order).
96
+ _LAYER_COLORS = (CYAN, LAVENDER, PINK, MINT)
97
+
98
+
99
+ def _render_aggregate_curve(steps) -> str:
100
+ """SVG with two stacked traces — aggregate divergence (pink) and next-token
101
+ entropy (cyan) — plus a tick row marking verbalization positions. Lets
102
+ viewers see whether the adapter's signal lines up with raw uncertainty
103
+ or whether it's catching something the entropy can't.
104
+ """
105
+ if not steps or len(steps) < 2:
106
+ return ""
107
+ divs = [s.divergence for s in steps]
108
+ ents = [s.entropy for s in steps]
109
+ verb_idxs = [i for i, s in enumerate(steps) if s.verbalization]
110
+
111
+ W, H, PAD_X, PAD_Y = 1000, 130, 8, 14
112
+ n = len(steps)
113
+
114
+ def _poly(vals, color):
115
+ vmin, vmax = min(vals), max(vals)
116
+ rng = (vmax - vmin) or 1.0
117
+ pts = []
118
+ for i, v in enumerate(vals):
119
+ x = PAD_X + (W - 2 * PAD_X) * (i / (n - 1))
120
+ y = PAD_Y + (H - 2 * PAD_Y) * (1.0 - (v - vmin) / rng)
121
+ pts.append(f"{x:.1f},{y:.1f}")
122
+ return (
123
+ f'<polyline fill="none" stroke="{color}" stroke-width="1.6" '
124
+ f'stroke-linejoin="round" stroke-linecap="round" opacity="0.95" '
125
+ f'points="{" ".join(pts)}"/>'
126
+ )
127
+
128
+ ticks = "".join(
129
+ f'<line x1="{PAD_X + (W - 2 * PAD_X) * (i / (n - 1)):.1f}" '
130
+ f'x2="{PAD_X + (W - 2 * PAD_X) * (i / (n - 1)):.1f}" '
131
+ f'y1="{H - 4}" y2="{H - 1}" stroke="{LAVENDER}" '
132
+ f'stroke-width="1.2" opacity="0.85"/>'
133
+ for i in vrb_safe(verb_idxs, n)
134
+ )
135
+
136
+ svg = (
137
+ f'<svg viewBox="0 0 {W} {H}" preserveAspectRatio="none" '
138
+ f'class="srt-spk-svg" style="height:130px" '
139
+ f'aria-label="aggregate divergence and next-token entropy">'
140
+ f'{_poly(divs, PINK)}{_poly(ents, CYAN)}{ticks}</svg>'
141
+ )
142
+ legend = (
143
+ f'<span class="srt-spk-key"><span class="dot" style="background:{PINK}"></span>'
144
+ f'aggregate div</span>'
145
+ f'<span class="srt-spk-key"><span class="dot" style="background:{CYAN}"></span>'
146
+ f'next-tok entropy</span>'
147
+ f'<span class="srt-spk-key"><span class="dot" '
148
+ f'style="background:{LAVENDER};width:2px;height:10px;border-radius:1px"></span>'
149
+ f'verbalization</span>'
150
+ )
151
+ return (
152
+ '<div class="srt-spk">'
153
+ '<div class="srt-spk-head">'
154
+ '<span class="srt-spk-title">aggregate signal vs. predictive uncertainty</span>'
155
+ f'<span class="srt-spk-legend">{legend}</span>'
156
+ '</div>'
157
+ f'{svg}'
158
+ '<div class="srt-spk-axis"><span>token 0</span>'
159
+ f'<span>token {n - 1}</span></div>'
160
+ '</div>'
161
+ )
162
+
163
+
164
+ def vrb_safe(idxs, n):
165
+ # tiny guard: drop out-of-range indices in case caller mismatched lengths
166
+ return [i for i in idxs if 0 <= i < n]
167
+
168
+
169
+ def _render_layer_sparkline(steps, layer_indices: list[int]) -> str:
170
+ """Tiny inline SVG showing one polyline per MAH layer over token index.
171
+
172
+ Makes the "tap layers 7/14/21" claim visible — viewers can see whether
173
+ early vs. late layers are doing the work at each token.
174
+ """
175
+ if not steps:
176
+ return ""
177
+ series = [s.per_layer_divergence for s in steps]
178
+ n_layers = max((len(p) for p in series), default=0)
179
+ if n_layers == 0:
180
+ return ""
181
+ # If config didn't surface labels, fall back to 0..n-1.
182
+ labels = layer_indices if len(layer_indices) == n_layers else list(range(n_layers))
183
+
184
+ W, H, PAD_X, PAD_Y = 1000, 110, 8, 14
185
+ n = len(steps)
186
+ if n < 2:
187
+ return ""
188
+
189
+ # Per-layer min/max for independent y-scaling per line (so a weak
190
+ # layer is still readable next to a dominant one).
191
+ polylines = []
192
+ for li in range(n_layers):
193
+ vals = [(p[li] if li < len(p) else 0.0) for p in series]
194
+ vmin, vmax = min(vals), max(vals)
195
+ rng = (vmax - vmin) or 1.0
196
+ pts = []
197
+ for i, v in enumerate(vals):
198
+ x = PAD_X + (W - 2 * PAD_X) * (i / (n - 1))
199
+ # plot from top of band; lower y = larger value
200
+ y = PAD_Y + (H - 2 * PAD_Y) * (1.0 - (v - vmin) / rng)
201
+ pts.append(f"{x:.1f},{y:.1f}")
202
+ color = _LAYER_COLORS[li % len(_LAYER_COLORS)]
203
+ polylines.append(
204
+ f'<polyline fill="none" stroke="{color}" stroke-width="1.4" '
205
+ f'stroke-linejoin="round" stroke-linecap="round" '
206
+ f'opacity="0.95" points="{" ".join(pts)}"/>'
207
  )
208
+
209
+ legend_bits = "".join(
210
+ f'<span class="srt-spk-key"><span class="dot" '
211
+ f'style="background:{_LAYER_COLORS[li % len(_LAYER_COLORS)]}"></span>'
212
+ f'L{labels[li]}</span>'
213
+ for li in range(n_layers)
214
+ )
215
+ svg = (
216
+ f'<svg viewBox="0 0 {W} {H}" preserveAspectRatio="none" '
217
+ f'class="srt-spk-svg" aria-label="per-layer divergence over tokens">'
218
+ f'{"".join(polylines)}</svg>'
219
+ )
220
+ return (
221
+ '<div class="srt-spk">'
222
+ f'<div class="srt-spk-head">'
223
+ f'<span class="srt-spk-title">per-layer divergence</span>'
224
+ f'<span class="srt-spk-legend">{legend_bits}</span>'
225
+ '</div>'
226
+ f'{svg}'
227
+ '<div class="srt-spk-axis"><span>token 0</span>'
228
+ f'<span>token {n - 1}</span></div>'
229
+ '</div>'
230
+ )
231
+
232
+
233
+ def _render_trace_html(result, prompt: str, elapsed: float,
234
+ layer_indices: list[int] | None = None,
235
+ title: str | None = None) -> str:
236
+ import html as _html
237
+
238
+ steps = result.steps
239
+ if not steps:
240
+ return '<div style="color:#8a9bb8;padding:1rem">(no tokens generated)</div>'
241
+
242
+ divs = [s.divergence for s in steps]
243
+ ds = sorted(divs)
244
+ lo = ds[max(0, int(0.10 * len(ds)))]
245
+ hi = ds[min(len(ds) - 1, int(0.90 * len(ds)))]
246
+ if hi <= lo:
247
+ lo, hi = min(divs), max(divs)
248
+
249
+ # Detect regime flips: any step whose regime differs from the previous.
250
+ # The first step is never a flip. These get an extra outline + glyph so
251
+ # viewers can spot BEN-state transitions without hovering each token.
252
+ flip_set: set[int] = set()
253
+ prev_reg = steps[0].regime
254
+ for s in steps[1:]:
255
+ if s.regime != prev_reg:
256
+ flip_set.add(s.token_idx)
257
+ prev_reg = s.regime
258
+
259
+ spans = []
260
+ for s in steps:
261
+ disp = _html.escape(s.token).replace("\n", "<br>")
262
+ bg = _div_color(s.divergence, lo, hi)
263
+ classes = ["tok"]
264
+ if s.verbalization:
265
+ classes.append("selected")
266
+ if s.regime == 0:
267
+ classes.append("reg-bif")
268
+ if s.token_idx in flip_set:
269
+ classes.append("reg-flip")
270
+ klass = " ".join(classes)
271
+ title_txt = (
272
+ f"i={s.token_idx} · d={s.divergence:.2f} · "
273
+ f"H={s.entropy:.2f} · r̂={s.r_hat:.2f} · reg={s.regime}"
274
+ )
275
+ if s.verbalization:
276
+ title_txt += f" → {s.verbalization[:240]}"
277
+ # onclick pins the verbalization (or the metric line) to the side panel.
278
+ # JS lives in _TRACE_CSS so it's defined once per page load.
279
+ pin_payload = _html.escape(
280
+ (s.verbalization or title_txt), quote=True
281
+ ).replace("\n", " ")
282
+ onclick = f"srtPin(this,&quot;{pin_payload}&quot;)"
283
+ spans.append(
284
+ f'<span class="{klass}" style="background:{bg}" '
285
+ f'data-title="{_html.escape(title_txt)}" '
286
+ f'onclick="{onclick}">{disp}</span>'
287
+ )
288
+
289
+ sel = result.selected()
290
+ sel_rows = "".join(
291
+ f'<tr><td class="num">{s.token_idx}</td>'
292
+ f'<td><code>{_html.escape(s.token)}</code></td>'
293
+ f'<td class="num">{s.divergence:.2f}</td>'
294
+ f'<td class="num">{s.r_hat:.2f}</td>'
295
+ f'<td class="num">{s.regime}</td>'
296
+ f'<td class="verb">{_html.escape(s.verbalization or "")}</td></tr>'
297
+ for s in sel
298
+ )
299
+
300
+ n_flip = len(flip_set)
301
+ n_bif = sum(1 for s in steps if s.regime == 0)
302
+ title_html = (
303
+ f'<div class="srt-trace-title">{_html.escape(title)}</div>' if title else ""
304
+ )
305
+
306
+ return f"""
307
+ <div class="srt-trace">
308
+ {title_html}
309
+ <div class="srt-meta">
310
+ <span class="chip"><span class="lbl">tokens</span>{len(steps)}</span>
311
+ <span class="chip"><span class="lbl">verbalizations</span>{len(sel)}</span>
312
+ <span class="chip"><span class="lbl">regime flips</span>{n_flip}</span>
313
+ <span class="chip"><span class="lbl">bif (r=0)</span>{n_bif}</span>
314
+ <span class="chip"><span class="lbl">div range</span>{lo:.2f} → {hi:.2f}</span>
315
+ <span class="chip"><span class="lbl">elapsed</span>{elapsed:.1f}s</span>
316
+ </div>
317
+ <div class="srt-legend">
318
+ <span>divergence</span>
319
+ <span class="grad"></span>
320
+ <span style="color:{CYAN}">low</span><span>→</span><span style="color:{PINK}">high</span>
321
+ <span style="opacity:.5">·</span>
322
+ <span class="box">box</span><span>= verbalization (hover)</span>
323
+ <span style="opacity:.5">·</span>
324
+ <span class="bif-key">▲</span><span>= bifurcating regime (r=0)</span>
325
+ <span style="opacity:.5">·</span>
326
+ <span class="flip-key">⇋</span><span>= regime flip</span>
327
+ <span style="opacity:.5">·</span>
328
+ <span style="color:{LAVENDER}">click any token to pin</span>
329
+ </div>
330
+ <div class="srt-prompt">{_html.escape(prompt)}</div>
331
+ <div class="srt-response">{''.join(spans)}</div>
332
+ {_render_aggregate_curve(steps)}
333
+ {_render_layer_sparkline(steps, layer_indices or [])}
334
+ <div class="srt-pinboard" id="srt-pinboard">
335
+ <div class="srt-pin-head">
336
+ <span class="srt-spk-title">pinned</span>
337
+ <button class="srt-pin-clear" onclick="srtPinClear()">clear</button>
338
+ </div>
339
+ <ol class="srt-pin-list"></ol>
340
+ <div class="srt-pin-empty">click tokens above to pin their verbalization or metrics here</div>
341
+ </div>
342
+ <div class="srt-label">Selected verbalizations</div>
343
+ <table class="srt-table">
344
+ <thead><tr><th>idx</th><th>token</th><th>div</th><th>r̂</th><th>reg</th><th>verbalization (AV)</th></tr></thead>
345
+ <tbody>{sel_rows}</tbody>
346
+ </table>
347
+ </div>
348
+ """
349
+
350
+
351
+ # Gradio HTML components are sandboxed, so all styles must be inline-injected
352
+ # alongside the markup. We ship the CSS once at app boot via gr.HTML.
353
+ _TRACE_CSS = f"""
354
+ <style>
355
+ .srt-trace {{
356
+ font-family: 'Inter', -apple-system, system-ui, sans-serif;
357
+ color: {INK};
358
+ }}
359
+ .srt-meta {{
360
+ display: flex; flex-wrap: wrap; gap: 0.5rem 0.75rem;
361
+ margin-bottom: 0.8rem;
362
+ }}
363
+ .srt-meta .chip {{
364
+ display: inline-flex; align-items: center; gap: 0.35rem;
365
+ background: {PANEL}; border: 1px solid {RULE};
366
+ padding: 0.22rem 0.6rem; border-radius: 999px;
367
+ font-family: 'JetBrains Mono', ui-monospace, monospace; font-size: 0.78rem;
368
+ color: {INK};
369
+ }}
370
+ .srt-meta .chip .lbl {{ color: {DIM}; }}
371
+ .srt-legend {{
372
+ display: flex; flex-wrap: wrap; align-items: center; gap: 0.4rem 0.75rem;
373
+ font-size: 0.78rem; color: {DIM}; margin-bottom: 0.8rem;
374
+ }}
375
+ .srt-legend .grad {{
376
+ display: inline-block; width: 180px; height: 0.5rem; border-radius: 999px;
377
+ background: linear-gradient(90deg, {CYAN}, {MINT}, {PINK});
378
+ box-shadow: 0 0 10px rgba(126,235,192,0.25);
379
+ }}
380
+ .srt-legend .box {{
381
+ display: inline-block; padding: 0.05rem 0.45rem; border-radius: 4px;
382
+ outline: 1.5px solid {LAVENDER}; outline-offset: 1px; color: {INK};
383
+ }}
384
+ .srt-prompt {{
385
+ background: {PANEL}; border-left: 3px solid {CYAN};
386
+ padding: 0.7rem 0.9rem; border-radius: 6px;
387
+ font-family: 'JetBrains Mono', ui-monospace, monospace;
388
+ white-space: pre-wrap; font-size: 0.82rem; color: {INK};
389
+ margin-bottom: 0.8rem;
390
+ }}
391
+ .srt-response {{
392
+ background: {PANEL}; border: 1px solid {RULE}; border-radius: 10px;
393
+ padding: 1.2rem 1.4rem; line-height: 2.0; font-size: 1.0rem;
394
+ color: {INK};
395
+ white-space: pre-wrap;
396
+ overflow-wrap: anywhere;
397
+ word-break: normal;
398
+ }}
399
+ .srt-response .tok {{
400
+ display: inline; padding: 1px 1px; border-radius: 3px; position: relative;
401
+ }}
402
+ .srt-response .tok.selected {{
403
+ outline: 1.5px solid {LAVENDER}; outline-offset: 1px;
404
+ cursor: help; box-shadow: 0 0 10px rgba(184,164,255,0.35);
405
+ }}
406
+ .srt-response .tok:hover {{ background: rgba(184,164,255,0.35) !important; }}
407
+ .srt-response .tok:hover::after {{
408
+ content: attr(data-title);
409
+ position: absolute; left: 50%; transform: translateX(-50%);
410
+ top: 1.7em; z-index: 20;
411
+ background: {PANEL_ALT}; color: {INK};
412
+ padding: 0.55rem 0.75rem; border-radius: 6px;
413
+ border: 1px solid {LAVENDER};
414
+ font-family: 'JetBrains Mono', ui-monospace, monospace;
415
+ font-size: 0.74rem; line-height: 1.45;
416
+ white-space: pre-wrap; max-width: 460px; min-width: 220px;
417
+ box-shadow: 0 6px 24px rgba(0,0,0,0.45);
418
+ pointer-events: none;
419
+ }}
420
+ .srt-label {{
421
+ font-size: 0.72rem; letter-spacing: 0.18em; text-transform: uppercase;
422
+ color: {DIM}; margin: 1.4rem 0 0.5rem; font-weight: 600;
423
+ }}
424
+ .srt-spk {{
425
+ background: {PANEL}; border: 1px solid {RULE}; border-radius: 8px;
426
+ padding: 0.6rem 0.8rem 0.5rem; margin: 0.6rem 0 0;
427
+ }}
428
+ .srt-spk-head {{
429
+ display: flex; align-items: center; justify-content: space-between;
430
+ gap: 0.8rem; margin-bottom: 0.3rem;
431
+ }}
432
+ .srt-spk-title {{
433
+ font-size: 0.66rem; letter-spacing: 0.18em; text-transform: uppercase;
434
+ color: {DIM}; font-weight: 600;
435
+ }}
436
+ .srt-spk-legend {{
437
+ display: flex; gap: 0.8rem; font-size: 0.72rem; color: {INK};
438
+ font-family: 'JetBrains Mono', ui-monospace, monospace;
439
+ }}
440
+ .srt-spk-key {{ display: inline-flex; align-items: center; gap: 0.3rem; }}
441
+ .srt-spk-key .dot {{
442
+ width: 8px; height: 8px; border-radius: 50%; display: inline-block;
443
+ }}
444
+ .srt-spk-svg {{
445
+ display: block; width: 100%; height: 110px;
446
+ background: {PANEL_ALT}; border-radius: 6px;
447
+ }}
448
+ .srt-spk-axis {{
449
+ display: flex; justify-content: space-between;
450
+ font-size: 0.62rem; color: {DIM}; margin-top: 0.25rem;
451
+ font-family: 'JetBrains Mono', ui-monospace, monospace;
452
+ }}
453
+ .srt-table {{
454
+ border-collapse: collapse; width: 100%; font-size: 0.83rem;
455
+ background: {PANEL}; border: 1px solid {RULE}; border-radius: 10px;
456
+ overflow: hidden;
457
+ }}
458
+ .srt-table th, .srt-table td {{
459
+ text-align: left; padding: 0.45rem 0.7rem;
460
+ border-bottom: 1px solid {RULE}; vertical-align: top;
461
+ }}
462
+ .srt-table tr:last-child td {{ border-bottom: none; }}
463
+ .srt-table th {{
464
+ background: {PANEL_ALT}; font-weight: 600; color: {DIM};
465
+ font-size: 0.7rem; letter-spacing: 0.1em; text-transform: uppercase;
466
+ }}
467
+ .srt-table td.num {{
468
+ font-family: 'JetBrains Mono', ui-monospace, monospace; color: {MINT};
469
+ }}
470
+ .srt-table td code {{
471
+ background: {BG}; padding: 1px 6px; border-radius: 3px;
472
+ font-family: 'JetBrains Mono', ui-monospace, monospace; color: {CYAN};
473
+ font-size: 0.8rem;
474
+ }}
475
+ .srt-table td.verb {{ color: {INK}; }}
476
+ .srt-trace-title {{
477
+ font-size: 0.74rem; letter-spacing: 0.18em; text-transform: uppercase;
478
+ color: {LAVENDER}; font-weight: 600; margin-bottom: 0.5rem;
479
+ }}
480
+ .srt-legend .bif-key {{
481
+ display: inline-block; color: {MINT}; font-weight: 700;
482
+ }}
483
+ .srt-legend .flip-key {{
484
+ display: inline-block; color: {PINK}; font-weight: 700;
485
+ }}
486
+ /* Component A: bifurcating regime (r=0) and regime-flip markers.
487
+ Both render as small unicode glyphs above the token without consuming
488
+ line height, so the prose still reads naturally. */
489
+ .srt-response .tok.reg-bif {{
490
+ box-shadow: inset 0 -1.5px 0 0 {MINT};
491
+ }}
492
+ .srt-response .tok.reg-bif::before {{
493
+ content: "▲"; position: absolute; top: -0.85em; left: 50%;
494
+ transform: translateX(-50%); font-size: 0.55em; color: {MINT};
495
+ pointer-events: none; opacity: 0.9;
496
+ }}
497
+ .srt-response .tok.reg-flip {{
498
+ outline: 1px dashed {PINK}; outline-offset: 1px;
499
+ }}
500
+ .srt-response .tok.reg-flip::before {{
501
+ content: "⇋"; position: absolute; top: -0.85em; left: 50%;
502
+ transform: translateX(-50%); font-size: 0.6em; color: {PINK};
503
+ pointer-events: none; opacity: 0.95;
504
+ }}
505
+ /* When a token is both bif and flip, the flip glyph wins; use a combined
506
+ visual cue via a top-bar. */
507
+ .srt-response .tok.reg-bif.reg-flip {{
508
+ box-shadow: inset 0 -1.5px 0 0 {MINT};
509
+ outline: 1px dashed {PINK}; outline-offset: 1px;
510
+ }}
511
+ .srt-response .tok.srt-pinned {{
512
+ outline: 2px solid {CYAN} !important; outline-offset: 1px;
513
+ box-shadow: 0 0 12px rgba(126,224,255,0.55);
514
+ }}
515
+ /* Component D: pinboard side panel. */
516
+ .srt-pinboard {{
517
+ margin: 0.8rem 0 0; padding: 0.55rem 0.8rem 0.6rem;
518
+ background: {PANEL}; border: 1px solid {RULE}; border-radius: 8px;
519
+ }}
520
+ .srt-pin-head {{
521
+ display: flex; align-items: center; justify-content: space-between;
522
+ gap: 0.6rem; margin-bottom: 0.35rem;
523
+ }}
524
+ .srt-pin-clear {{
525
+ background: transparent; color: {DIM}; border: 1px solid {RULE};
526
+ border-radius: 4px; padding: 0.1rem 0.55rem; font-size: 0.7rem;
527
+ cursor: pointer; font-family: 'JetBrains Mono', ui-monospace, monospace;
528
+ }}
529
+ .srt-pin-clear:hover {{ color: {CYAN}; border-color: {CYAN}; }}
530
+ .srt-pin-list {{
531
+ list-style: decimal inside; margin: 0; padding: 0;
532
+ font-size: 0.78rem; color: {INK};
533
+ font-family: 'JetBrains Mono', ui-monospace, monospace; line-height: 1.55;
534
+ }}
535
+ .srt-pin-list li {{ padding: 0.15rem 0; }}
536
+ .srt-pin-empty {{ font-size: 0.74rem; color: {DIM}; }}
537
+ .srt-pinboard.has-pins .srt-pin-empty {{ display: none; }}
538
+ </style>
539
+ """
540
+
541
+
542
+ # JS for click-to-pin. Lives outside _TRACE_CSS because Gradio strips
543
+ # <script> tags from gr.HTML *updates* (sanitization on innerHTML
544
+ # replacement) and even when a <script> survives in an initial render,
545
+ # scripts inserted via innerHTML do not execute per the HTML5 spec.
546
+ # We wire this once at page load via gr.Blocks(js=...) so the symbols
547
+ # are defined on `window` before any onclick fires.
548
+ _PIN_JS = r"""
549
+ () => {
550
+ if (window.__srtPinInstalled) return;
551
+ window.__srtPinInstalled = true;
552
+ window.srtPin = function(el, text) {
553
+ var trace = el.closest('.srt-trace');
554
+ if (!trace) return;
555
+ var board = trace.querySelector('.srt-pinboard');
556
+ if (!board) return;
557
+ var list = board.querySelector('.srt-pin-list');
558
+ if (el.classList.contains('srt-pinned')) {
559
+ el.classList.remove('srt-pinned');
560
+ var key = el.getAttribute('data-pin-key');
561
+ if (key) {
562
+ var existing = list.querySelector('li[data-pin-key="' + key + '"]');
563
+ if (existing) existing.remove();
564
+ el.removeAttribute('data-pin-key');
565
+ }
566
+ } else {
567
+ el.classList.add('srt-pinned');
568
+ var key = 'p' + Math.random().toString(36).slice(2, 9);
569
+ el.setAttribute('data-pin-key', key);
570
+ var li = document.createElement('li');
571
+ li.setAttribute('data-pin-key', key);
572
+ li.textContent = text;
573
+ list.appendChild(li);
574
+ }
575
+ if (list.children.length > 0) board.classList.add('has-pins');
576
+ else board.classList.remove('has-pins');
577
+ };
578
+ window.srtPinClear = function() {
579
+ document.querySelectorAll('.srt-trace .tok.srt-pinned').forEach(function(el) {
580
+ el.classList.remove('srt-pinned');
581
+ el.removeAttribute('data-pin-key');
582
+ });
583
+ document.querySelectorAll('.srt-pinboard').forEach(function(b) {
584
+ var list = b.querySelector('.srt-pin-list');
585
+ if (list) list.innerHTML = '';
586
+ b.classList.remove('has-pins');
587
+ });
588
+ };
589
+ }
590
+ """
591
+
592
+
593
+ # ---------- callbacks ----------
594
+
595
+ MAX_PROMPT_CHARS = 1500
596
+
597
+ # Qwen-2.5-7B is a *base* completion model, not Instruct. To give visitors
598
+ # a chat-style UX without retraining or swapping the backbone (which would
599
+ # invalidate the adapter's calibration on base activations), we wrap the
600
+ # user's message in a minimal User/Assistant scaffold that the base model
601
+ # completes in-context. The trace shows the user's original message; the
602
+ # scaffold is implementation detail.
603
+ _CHAT_PREFIX = "User: "
604
+ _CHAT_SUFFIX = "\nAssistant:"
605
+
606
+
607
+ def _wrap_chat(user_text: str) -> str:
608
+ return f"{_CHAT_PREFIX}{user_text.strip()}{_CHAT_SUFFIX}"
609
+
610
+ # Pre-rendered HTML for the "before-first-generate" placeholder. We ship a
611
+ # cached trace so visitors land on a populated demo instead of an empty
612
+ # panel + 60-90s wait on a cold ZeroGPU slice. The cache is produced by
613
+ # `scripts/cache_demo_traces.py` and committed to the repo / Space.
614
+ _CACHE_DIR = pathlib.Path(__file__).resolve().parent / "cached_traces"
615
+
616
+
617
+ def _initial_trace_html() -> str:
618
+ candidates = [_CACHE_DIR / "default.html"]
619
+ for p in candidates:
620
+ try:
621
+ if p.exists():
622
+ body = p.read_text(encoding="utf-8")
623
+ return (
624
+ '<div style="color:#8a9bb8;font-size:0.78rem;padding:0 0 .6rem">'
625
+ 'Cached trace from a previous run \u2014 click '
626
+ '<b>Generate trace</b> for a fresh one. '
627
+ 'First request on a cold ZeroGPU slice loads ~17&nbsp;GB '
628
+ 'of weights and takes 60\u201390 s; subsequent requests '
629
+ 'are ~7\u201310 s. Prompts and outputs are not logged.'
630
+ '</div>'
631
+ ) + body
632
+ except Exception as e: # pragma: no cover
633
+ log.warning("cached trace load failed for %s: %s", p, e)
634
+ return (
635
+ '<div style="color:#8a9bb8;padding:1rem">'
636
+ 'Click <b>Generate trace</b> to start. First request on a fresh '
637
+ 'ZeroGPU slice loads ~17&nbsp;GB of weights and may take '
638
+ '60&ndash;90 s; subsequent requests are ~7&ndash;10 s. '
639
+ 'Prompts and outputs are not logged.'
640
+ '</div>'
641
+ )
642
+
643
+
644
+ @_gpu(duration=300)
645
+ def cb_generate(prompt: str, mode: str, max_new: int, budget: int, k: int,
646
+ temperature: float, top_p: float, repetition_penalty: float):
647
+ if not prompt.strip():
648
+ return '<div style="color:#8a9bb8;padding:1rem">(enter a prompt above)</div>'
649
+ # Server-side bounds: prompts are O(N²) since the adapter has no KV cache,
650
+ # and we don't want a single user to pin the GPU for minutes.
651
+ user_prompt = prompt[:MAX_PROMPT_CHARS]
652
+ chat_mode = (mode or "").lower().startswith("chat")
653
+ model_prompt = _wrap_chat(user_prompt) if chat_mode else user_prompt
654
+ display_prompt = user_prompt
655
+ max_new = max(8, min(int(max_new), 512))
656
+ budget = max(1, min(int(budget), 20))
657
+ k = max(1, min(int(k), 8))
658
+ t = _get_trace()
659
+ layer_indices = list(getattr(t.adapter.config, "mah_layer_indices", []) or [])
660
+
661
+ chat_note = (
662
+ " · chat shim: User:/Assistant: wrapper applied" if chat_mode else ""
663
+ )
664
+ t0 = time.perf_counter()
665
+ result = t.generate(
666
+ model_prompt,
667
+ max_new_tokens=int(max_new),
668
+ budget=int(budget),
669
+ k=int(k),
670
+ temperature=float(temperature),
671
+ top_p=float(top_p),
672
+ repetition_penalty=float(repetition_penalty),
673
+ )
674
+ elapsed = time.perf_counter() - t0
675
+ return _render_trace_html(
676
+ result, display_prompt, elapsed, layer_indices=layer_indices,
677
+ title=(None if not chat_mode else f"chat mode{chat_note}"),
678
+ )
679
+
680
+
681
+ # ---------- UI ----------
682
+
683
+ INTRO_MD = """\
684
+ # SRT&nbsp;·&nbsp;introspect — adaptive-density reasoning trace
685
+
686
+ **SRT** ([paper](https://github.com/space-bacon/SRT/blob/main/paper.md)) is a ~12 M-parameter adapter that bolts onto a frozen LLM and watches its own residual stream. The backbone here is **Qwen-2.5-7B**, fully frozen, with two sidecars attached:
687
+
688
+ - **SRT-Adapter** (Stage 3) — three metapragmatic-attention heads tap layers 7/14/21 and feed a GRU. For every token it emits a **divergence** scalar (how hard the heads are working), a **reflexivity** estimate r̂, and a **regime** label.
689
+ - **Activation Verbalizer** (Stage 4) — a small decoder trained to translate a single layer-20 hidden state into one English sentence describing what the model is processing.
690
+
691
+ What this app does:
692
+ 1. Generates a continuation token-by-token through the frozen backbone, with the SRT adapter scoring each step.
693
+ 2. A **density scheduler** picks `budget` positions whose divergence values carry equal mass — so verbalizations cluster where the meta-state is *actually* moving instead of being evenly spaced.
694
+ 3. At each pick, the AV is run K times for a consensus narration.
695
+
696
+ In the trace below, token background colour encodes divergence (<span style="color:#7ee0ff">**cyan**</span> = coasting, <span style="color:#7eebc0">**mint**</span> = mid, <span style="color:#ff7eb9">**pink**</span> = forking). Boxed tokens have a verbalization — **hover them** to read it.
697
+ """
698
+
699
+
700
+ def build_app() -> gr.Blocks:
701
+ theme = gr.themes.Base(
702
+ primary_hue="cyan",
703
+ secondary_hue="pink",
704
+ neutral_hue="slate",
705
+ font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui"],
706
+ font_mono=[gr.themes.GoogleFont("JetBrains Mono"), "ui-monospace"],
707
+ ).set(
708
+ body_background_fill=BG,
709
+ body_text_color=INK,
710
+ background_fill_primary=PANEL,
711
+ background_fill_secondary=PANEL_ALT,
712
+ border_color_primary=RULE,
713
+ block_background_fill=PANEL,
714
+ block_border_color=RULE,
715
+ block_label_text_color=DIM,
716
+ block_title_text_color=INK,
717
+ input_background_fill=PANEL_ALT,
718
+ input_border_color=RULE,
719
+ button_primary_background_fill=PINK,
720
+ button_primary_background_fill_hover=LAVENDER,
721
+ button_primary_text_color=BG,
722
+ )
723
+
724
+ with gr.Blocks(theme=theme, title="SRT · introspect", js=_PIN_JS, css=f"""
725
+ body, .gradio-container {{ background: {BG} !important; }}
726
+ .gradio-container {{ max-width: 1080px !important; margin: 0 auto; }}
727
+ h1, h2, h3 {{ color: {INK}; }}
728
+ a {{ color: {CYAN}; }}
729
+ """) as app:
730
+ gr.HTML(_TRACE_CSS)
731
+ gr.Markdown(INTRO_MD)
732
+
733
+ with gr.Row():
734
+ with gr.Column(scale=3):
735
+ mode = gr.Radio(
736
+ choices=["Completion", "Chat"],
737
+ value="Completion",
738
+ label="Input mode",
739
+ info=(
740
+ "Completion: feed the prompt raw (good for code, "
741
+ "narrative, or mid-sentence continuations). "
742
+ "Chat: type a question or instruction naturally; "
743
+ "we wrap it as User:/Assistant: for the base model."
744
+ ),
745
+ )
746
+ prompt = gr.Textbox(
747
+ label="Prompt",
748
+ value=(
749
+ "def quicksort(arr):\n"
750
+ " if len(arr) <= 1:\n"
751
+ " return arr\n"
752
+ " pivot = arr[len(arr) // 2]\n"
753
+ ),
754
+ lines=6,
755
+ )
756
+ with gr.Column(scale=2):
757
+ max_new = gr.Slider(32, 512, value=160, step=8, label="max_new_tokens (capped at 512 on the public demo)")
758
+ budget = gr.Slider(2, 20, value=10, step=1, label="verbalization budget (adaptive slots)")
759
+ k = gr.Slider(1, 8, value=6, step=1, label="AV samples per slot (K)")
760
+ with gr.Accordion("Sampling", open=False):
761
+ temperature = gr.Slider(0.1, 1.5, value=0.7, step=0.05, label="temperature")
762
+ top_p = gr.Slider(0.5, 1.0, value=0.95, step=0.01, label="top_p")
763
+ rep_pen = gr.Slider(1.0, 1.5, value=1.15, step=0.01, label="repetition_penalty")
764
+ with gr.Row():
765
+ go = gr.Button("Generate trace", variant="primary")
766
+ stop = gr.Button("Stop", variant="secondary")
767
+
768
+ gr.Markdown("### Trace")
769
+ out = gr.HTML(_initial_trace_html())
770
+
771
+ gr.Examples(
772
+ examples=[
773
+ # ---- Chat mode (natural-language) ----
774
+ ["Explain why warm water sometimes freezes faster than cold water.", "Chat"],
775
+ ["What is the capital of Australia, and why isn't it Sydney?", "Chat"],
776
+ ["A patient has fever, joint pain, and a rash. What should I consider?", "Chat"],
777
+ ["Write the first paragraph of a short story about a lighthouse keeper.", "Chat"],
778
+ ["Is consciousness computable? Argue both sides briefly.", "Chat"],
779
+ # ---- Completion mode (raw continuation) ----
780
+ ["def quicksort(arr):\n if len(arr) <= 1:\n return arr\n pivot = arr[len(arr) // 2]\n", "Completion"],
781
+ ['<title>The Bell Tower</title>\n<chapter id="1">', "Completion"],
782
+ ["The capital of Australia is", "Completion"],
783
+ ["For the first half of the essay she defended free trade, "
784
+ "but in the second half she", "Completion"],
785
+ ],
786
+ inputs=[prompt, mode],
787
+ label="Try one",
788
+ )
789
+
790
+ gen_event = go.click(
791
+ cb_generate,
792
+ inputs=[prompt, mode, max_new, budget, k, temperature, top_p, rep_pen],
793
+ outputs=[out],
794
+ )
795
+ stop.click(fn=None, cancels=[gen_event])
796
+
797
+ # Default code prompt only makes sense in Completion mode. When the
798
+ # user switches to Chat, blank the textbox so they don't try to chat
799
+ # with `def quicksort(...)`; when they switch back to Completion,
800
+ # restore the code seed so the textbox isn't stranded empty.
801
+ _COMPLETION_DEFAULT = (
802
+ "def quicksort(arr):\n"
803
+ " if len(arr) <= 1:\n"
804
+ " return arr\n"
805
+ " pivot = arr[len(arr) // 2]\n"
806
+ )
807
+
808
+ def _on_mode_change(new_mode: str, current: str):
809
+ if (new_mode or "").lower().startswith("chat"):
810
+ # Only clear if the user hasn't edited the default code seed.
811
+ if current.strip() == _COMPLETION_DEFAULT.strip():
812
+ return gr.update(value="", placeholder="Ask anything…")
813
+ return gr.update(placeholder="Ask anything…")
814
+ # Switched to Completion: restore code seed only if textbox is empty.
815
+ if not current.strip():
816
+ return gr.update(value=_COMPLETION_DEFAULT, placeholder=None)
817
+ return gr.update(placeholder=None)
818
+
819
+ mode.change(_on_mode_change, inputs=[mode, prompt], outputs=[prompt])
820
+
821
+ gr.Markdown(f"""
822
+ ---
823
+
824
+ ### How to read the trace
825
+
826
+ | signal | what it means | where it comes from |
827
+ |---|---|---|
828
+ | <span style="color:{PINK}">**divergence**</span> (token tint) | how far the metapragmatic-attention heads are pulling the meta-state at this step — peaks correspond to entity transitions, hedging, or topic shifts | sum of L2 norms of three MAH delta vectors |
829
+ | <span style="color:{LAVENDER}">**boxed tokens**</span> | scheduler-picked positions where the AV ran — placed at equal-mass quantiles of the divergence curve | `srt_introspect.scheduler.quantile_by_density` |
830
+ | **r̂** (in hover-card) | bifurcation network's reflexivity estimate (0=automatic, 1=actively reflecting) | BEN head on the GRU meta-state |
831
+ | **reg** (in hover-card) | discrete regime label, 0 or 1 — often stuck at 1, take with salt | BEN regime classifier |
832
+ | **verbalization** | AV decoder's best-guess English summary of the layer-20 hidden state at that position; the same hidden state the model would have continued from | `RiverRider/srt-nla-av-v1` decoder |
833
+
834
+ The AV is a paraphraser, not a mind-reader. Given one of the model's hidden states, its single best guess matches a known-good description about a quarter of the way from "random text" to "perfect paraphrase." Let it propose 64 candidates and keep the closest one, and it gets ~90% of the way there. This demo samples a handful and picks by consensus, so quality sits in between. Read each verbalization as **roughly what neighborhood of meaning the model is in at that step** — not a transcript of its thoughts.
835
+ """)
836
+
837
+ return app
838
+
839
+
840
+ if __name__ == "__main__":
841
+ app = build_app()
842
+ # Pre-warm weights when we have a persistent GPU. On ZeroGPU we can't
843
+ # touch CUDA outside an @spaces.GPU function, so skip the warmup there
844
+ # — the first user request takes the load hit instead.
845
+ if not _ON_ZEROGPU and DEVICE == "cuda":
846
+ try:
847
+ _get_trace()
848
+ except Exception as e: # pragma: no cover
849
+ log.warning("warmup skipped: %s", e)
850
+ app.queue(default_concurrency_limit=1, max_size=20).launch(
851
+ server_name="0.0.0.0",
852
+ server_port=int(os.environ.get("PORT", 7860)),
853
+ share=False,
854
+ )