- Qwen3.5-4B TopK SAE · Layer 18
- Before vs. after this release
- 30-second quickstart
- Feature catalog: 25 reasoning directions, pre-labeled
- Causal validation (feature #11424)
- GSM8K reasoning experiments (honest findings)
- Specs
- Training + eval metrics
- Files
- Full usage
- Training details
- Why a custom script?
- Limitations
- Citation
- Companion releases
- License
Qwen3.5-4B TopK SAE · Layer 18
The first public TopK residual-stream Sparse Autoencoder for Qwen3.5, with a labeled catalog of reasoning features and downstream RL-reward validation.
The nearest prior public SAE on the Qwen3.5 hybrid-GDN family is
kroonen-ai/sae-qwen3.5-9b
(ReLU SAE on layer-16 MLP output of Qwen3.5-9B, d_sae=16,384, ~50 M
training tokens, March 2026). Ours is distinct on every axis relevant for
downstream use: TopK vs ReLU, residual stream vs MLP output, 16× vs 4×
expansion, 200 M vs 50 M training tokens, Qwen3.5-4B vs 9B, and the only
one of the two with a validated feature-pack + RL-reward downstream pipeline
(see Stage Gates below).
Trained on residual-stream activations at layer 18 of Qwen/Qwen3.5-4B
using the TopK objective from Gao et al.
(arxiv:2406.04093), and shipped with a
ReasonScore feature catalog
(arxiv:2503.18878) that identifies the
directions in Qwen3.5's latent space that encode meta-cognitive reasoning.
Before vs. after this release
| | Before | After (this SAE) |
|---|---|---|
| TopK residual-stream SAE for Qwen3.5 | None (kroonen-ai/sae-qwen3.5-9b, Mar 2026, is ReLU on layer-16 MLP: different hook + activation) | 40,960 features, TopK k=128, residual stream layer 18 |
| TransformerLens / sae_lens support | No | Bypassed with plain HF forward hooks |
| Reasoning features identified | Unknown | Top-25 labeled via ReasonScore |
| Feature catalog | None | reasoning_pack.json: per-word distributions + entropy |
| Causally validated features | None | #11424 verified via activation steering (see steering_validation.json) |
| Reconstruction on out-of-distribution prompts | Unmeasured | var_exp 0.866 on code + reasoning prompts |
| Usage | Custom training stack needed | 3 lines with mechreward, or raw PyTorch |
Qwen3.5 uses a hybrid Gated Delta Networks architecture that
TransformerLens and sae_lens don't yet support. This release adds a
TopK residual-stream SAE to the Qwen3.5 interpretability stack (kroonen-ai's
prior ReLU SAE targets layer-16 MLP output on Qwen3.5-9B, a complementary rather
than competing hook), and ships with the downstream pipeline (feature catalog,
causal-steering validation, and per-token RL-reward demonstration) that lets
it drop into GRPO fine-tuning without needing any custom hook infrastructure.
30-second quickstart
pip install mechreward
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from mechreward.sae.topk_sae import load_topk_sae

tok = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-4B")
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3.5-4B", dtype=torch.bfloat16, device_map="cuda"
).eval()
sae = load_topk_sae(
    "caiovicentino1/Qwen3.5-4B-SAE-L18-topk",
    layer=18,
    model_name="Qwen/Qwen3.5-4B",
)
ids = tok("Wait, let me think about this more carefully.",
          return_tensors="pt").input_ids.cuda()
with torch.inference_mode():
    out = model(input_ids=ids, output_hidden_states=True,
                use_cache=False, return_dict=True)
    h = out.hidden_states[19].float().squeeze(0)  # layer 18 residual
    features = sae.encode(h)  # [seq_len, 40960]
active = (features > 0).sum(-1).float().mean()
print(f"L0 = {active.item():.1f}/128")  # ≈ 128, structural
print("Top feature per token:", features.argmax(-1).tolist())
That's it. No forward-hook gymnastics, no TransformerLens, no training
stack. Use hidden_states[19] because hidden_states[0] is the embedding
output, so the output of decoder block 18 sits at index 19.
Feature catalog: 25 reasoning directions, pre-labeled
Identified by the ReasonScore metric (Galichin et al., AIRI, arxiv:2503.18878):
$$\mathrm{ReasonScore}_i = p(i \mid D_{RW}) \cdot H_i^{\alpha} - p(i \mid D_{nRW})$$
The score is contrastive between a reasoning corpus (300 samples from
open-thoughts/OpenThoughts-114k, the $D_{RW}$ term) and a baseline corpus (300 from
HuggingFaceFW/fineweb-edu sample-10BT, the $D_{nRW}$ term), restricted to ±(2,3) token windows
around 10 meta-cognitive trigger words: alternatively, hmm, maybe, wait,
perhaps, let me, therefore, however, but, another. The entropy factor
$H_i^{\alpha}$ (with $\alpha = 0.7$) rewards features that fire across multiple
trigger words, filtering out polysemantic or word-specific features.
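The scoring rule above can be sketched in a few lines of NumPy. This is a minimal illustration, not the script used to build reasoning_pack.json: the function name `reason_score`, the array layout, and the choice to define p(i | D) as a feature's share of total mean activation are all assumptions made for the example.

```python
import numpy as np

def reason_score(acts_rw, acts_nrw, word_ids, n_words, alpha=0.7):
    """Hypothetical helper: score every SAE feature against two corpora.

    acts_rw  : [n_rw_tokens, n_features]  activations inside trigger-word windows
    acts_nrw : [n_nrw_tokens, n_features] activations on the baseline corpus
    word_ids : [n_rw_tokens] index of the trigger word each window token belongs to
    """
    eps = 1e-12
    # Assumed definition of p(i | D): feature i's share of mean activation on D.
    mean_rw, mean_nrw = acts_rw.mean(axis=0), acts_nrw.mean(axis=0)
    p_rw = mean_rw / (mean_rw.sum() + eps)
    p_nrw = mean_nrw / (mean_nrw.sum() + eps)
    # Per-word mean activation, normalised into a distribution over trigger words.
    # (Every trigger word is assumed to have at least one token.)
    per_word = np.stack([acts_rw[word_ids == w].mean(axis=0) for w in range(n_words)])
    dist = per_word / (per_word.sum(axis=0, keepdims=True) + eps)
    # Natural-log entropy, so the maximum is log(n_words); clip tiny negatives.
    entropy = np.clip(-(dist * np.log(dist + eps)).sum(axis=0), 0.0, None)
    return p_rw * entropy**alpha - p_nrw, entropy

# Toy data: 3 features, 2 trigger words.
acts_rw = np.array([[1.0, 1.0, 0.0],   # token near word 0
                    [1.0, 0.0, 0.0]])  # token near word 1
acts_nrw = np.array([[0.0, 0.0, 1.0],
                     [0.0, 0.0, 1.0]])
scores, entropy = reason_score(acts_rw, acts_nrw, np.array([0, 1]), n_words=2)
print(scores.round(3), entropy.round(3))
```

In the toy data, feature 0 fires on both trigger words and never on the baseline, so it ranks first; feature 1 fires on a single word and gets zero entropy credit; feature 2 fires only on the baseline and scores negative.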
Top-10 reasoning features (full top-25 with per-word breakdowns in
reasoning_pack.json):
| Rank | Feature | Score | Entropy (H) | Fires strongest on |
|---|---|---|---|---|
| 1 | #11424 | 0.0171 | 2.19 | however / therefore / another / but |
| 2 | #22188 | 0.0147 | 1.90 | let me / perhaps / maybe / wait |
| 3 | #13281 | 0.0131 | 2.13 | perhaps / maybe / therefore / however |
| 4 | #33609 | 0.0095 | 2.00 | however / let me / therefore / hmm |
| 5 | #23501 | 0.0094 | 1.91 | however / another / therefore / hmm |
| 6 | #9639 | 0.0093 | 1.82 | (see reasoning_pack.json) |
| 7 | #16924 | 0.0077 | 1.66 | (see reasoning_pack.json) |
| 8 | #2015 | 0.0076 | 1.71 | (see reasoning_pack.json) |
| 9 | #8097 | 0.0073 | 1.70 | (see reasoning_pack.json) |
| 10 | #39743 | 0.0067 | 1.67 | (see reasoning_pack.json) |
Max possible entropy is log(10) ≈ 2.30; ranks 1, 3, and 4 reach H ≥ 2.0 and ranks 2 and 5 sit at ≈ 1.9, indicating they fire across the reasoning vocabulary rather than latching onto a single word.
Feature health: 24,956 / 40,960 features are alive (60.9%), 15,154 fire on ≥ 2 reasoning words (H > 0.1), 7,259 on ≥ 3 (H > 1.0). Plenty of signal left to explore beyond the top-25.
Use the catalog
import json
from huggingface_hub import hf_hub_download

path = hf_hub_download(
    "caiovicentino1/Qwen3.5-4B-SAE-L18-topk", "reasoning_pack.json")
with open(path) as fh:
    pack = json.load(fh)
for f in pack["top_25_features"][:5]:
    # Trigger word on which this feature fires most strongly.
    top_word = max(f["per_word_distribution"].items(), key=lambda kv: kv[1])
    print(f"#{f['feature_id']:>5} score={f['reason_score']:.4f} "
          f"H={f['entropy']:.2f} top={top_word[0]} ({top_word[1]:.2f})")
Causal validation (feature #11424)
Top-25 ReasonScore gives you correlational evidence that a feature tracks reasoning. To get causal evidence, we intervene on the feature during generation via activation steering:
hidden_steered = hidden + (sae_decode(feats_modified) - sae_decode(feats))
This is a delta injection: we preserve the SAE reconstruction error and only perturb the direction corresponding to one feature. Scale = 1.0 gives zero delta (exact baseline); scale = 0.0 ablates the feature; scale = 10.0 amplifies it 10×.
Hook diagnostic: delta norms scale linearly
| Phase | Scale 0.0 (ablate) | Scale 1.0 (baseline) | Scale 3.0 | Scale 10.0 |
|---|---|---|---|---|
| Prefill (prompt pass) | 0.333 | 0.000 | 0.665 | 2.994 |
| Decode (per new token) | 0.045 | 0.000 | 0.083 | 0.428 |
Baseline is exactly zero (correct: scale = 1.0 cancels). The prefill perturbation grows linearly in |scale − 1| (0.333, 0.665, 2.994 for |scale − 1| = 1, 2, 9), confirming the steering hook works.
Text divergence vs baseline (greedy decoding, 3 prompts × 4 scales)
| Prompt | Ablate | Amp × 3 | Amp × 10 |
|---|---|---|---|
| "Let me analyze the tradeoffs between speed and accuracy…" | 16.7 % | 0 % | 8.0 % |
| "Critics argue that AI regulation slows innovation." | 15.6 % | 0 % | 0 % |
| "The experiment produced mixed results. On one hand, accuracy improved." | 23.2 % | 0 % | 0 % |
Ablation produces 15.6–23.2 % text divergence on 3 / 3 prompts: removing #11424 reliably changes generation. Amplification shows a threshold effect under greedy decoding: perturbations below the argmax margin don't flip tokens, so only × 10 crosses the boundary (and only on one prompt).
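The release does not state how the divergence percentages were computed. One plausible character-level definition, shown here purely as an assumption, is 100 × (1 − similarity ratio) using Python's standard `difflib`. The helper name `text_divergence` and both example strings are illustrative and are not taken from steering_validation.json.

```python
from difflib import SequenceMatcher

def text_divergence(baseline: str, steered: str) -> float:
    """Character-level divergence, defined here as 100 * (1 - SequenceMatcher ratio)."""
    ratio = SequenceMatcher(None, baseline, steered, autojunk=False).ratio()
    return 100.0 * (1.0 - ratio)

# Illustrative strings only (not actual generation traces).
baseline = "The key insight is that we can use a hash set to store elements."
steered = "The key insight is that instead of checking all pairs, we can use a hash set."
print(f"identical: {text_divergence(baseline, baseline):.1f} %")
print(f"steered:   {text_divergence(baseline, steered):.1f} %")
```

A metric like this reads 0 % whenever greedy decoding produces the same tokens, which matches the threshold behaviour in the table: small perturbations that never flip an argmax leave the text untouched.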
The smoking gun: semantic diff on prompt 1
On "Let me analyze the tradeoffs…", the character-level divergence looks small but the semantic change is unambiguous:
| Intervention | Continuation snippet |
|---|---|
| Ablate (× 0) | "The key insight is… We can use a hash set to store elements…" |
| Amp × 10 | "The key insight is… Instead of checking all pairs, we can use a hash set…" |
Amplification injects the contrastive construction "Instead of X, Y", exactly the kind of alternative-enumeration scaffold that the ReasonScore profile predicted (however / therefore / another / but). Ablation suppresses it. The feature's label is:
#11424: alternative enumeration / contrastive structure
Full 3-prompt × 4-scale traces + delta-norm diagnostics are in
steering_validation.json.
Try it yourself
target = model.get_submodule("model.language_model.layers.18")
current_scale = {"v": 1.0}

def steering_hook(module, inputs, output):
    # Decoder blocks may return a tuple; the hidden state is element 0.
    hidden, rest = (output[0], output[1:]) if isinstance(output, tuple) else (output, None)
    flat = hidden.float().reshape(-1, hidden.shape[-1])
    feats = sae.encode(flat)
    feats_mod = feats.clone()
    feats_mod[:, 11424] = feats[:, 11424] * current_scale["v"]
    # Delta injection: the SAE reconstruction error cancels in the difference.
    delta = sae.decode(feats_mod) - sae.decode(feats)
    new = (flat + delta).view_as(hidden).to(hidden.dtype)
    return (new, *rest) if rest is not None else new

handle = target.register_forward_hook(steering_hook)
for scale in [1.0, 0.0, 10.0]:
    current_scale["v"] = scale
    # model.generate(...) as usual; output changes with scale
handle.remove()
GSM8K reasoning experiments (honest findings)
We ran a series of interventional experiments on GSM8K (openai/gsm8k:main:test, 50
questions, Qwen3.5-4B chat-template CoT, greedy decoding) to test whether the SAE
features could boost math accuracy beyond the 88 % baseline. Results are mixed and
instructive: the mechanism is causally real but the net aggregate lift is
sample-size-dominated at n = 50.
Experiment summary
| Experiment | Features tested | Method | Net accuracy Δ | Mechanism confirmed? |
|---|---|---|---|---|
| ReasonScore top-5 amp | 5 features by ReasonScore rank | Feature amp × 3 | −8.0 pp (bundle) | No: ReasonScore features are reasoning rhetoric, not computation |
| Contrastive discovery | Top-5 features by mean(correct) − mean(wrong) on 50 baseline responses | Bundle amp × 5–8 + suppress × −1 to −2 | −6.0 pp to −8.0 pp | Yes: recovered 3–4 of 6 baseline failures, collateral damage on 5–6 others |
| CAA (no SAE) at L13 | Single direction = mean(h_correct) − mean(h_wrong) at layer 13 | Raw steering vector, scale +0.3 | 0 pp (net), +2 recovered / −2 lost | Yes: same recoverable questions as SAE bundle, lower collateral |
Key findings
1. ReasonScore features do not boost GSM8K.
The top-25 ReasonScore features (in reasoning_pack.json)
encode discourse/hedging markers (however, let me, wait, another). They are
causally validated for text-level structure (see the
causal validation section above) but not for
arithmetic reasoning. Amplifying them during GSM8K generation neither helps nor
hurts individual questions meaningfully; a bundle of 5 at amp × 3 costs −8 pp.
2. Contrastive feature discovery finds the right features.
Running mean(activations | correct) − mean(activations | wrong) on 50 baseline GSM8K
responses surfaces an entirely different set of features than ReasonScore:
# Top-5 helpful (robust: high absolute mean + ratio):
[35842, 6998, 39695, 15360, 8773]
# Top-5 harmful:
[30967, 14912, 33246, 23272, 18654]
These features recover 4 of the 6 baseline failures (indices 5, 13, 14, 39
out of the full wrong set [5, 8, 12, 13, 14, 39]) when amplified/suppressed
through the SAE. All 5 ReasonScore top-5 features appear as
effect = neutral in this ranking (|Cohen's d| < 1, while the contrastive top-5
have |d| > 1.8). ReasonScore and GSM8K correctness select for almost disjoint
feature populations.
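The contrastive ranking described in finding 2 can be sketched as a mean difference normalised into Cohen's d. This is a minimal NumPy illustration on synthetic data; the function name `contrastive_ranking`, the per-response pooling of activations, and the pooled-variance form of d are assumptions rather than the exact procedure used for the lists above.

```python
import numpy as np

def contrastive_ranking(feats, correct, top_n=5):
    """Rank features by Cohen's d between correct and wrong responses.

    feats   : [n_responses, n_features] mean SAE activation per response (assumed pooling)
    correct : [n_responses] boolean mask of baseline-correct responses
    """
    pos, neg = feats[correct], feats[~correct]
    n1, n2 = len(pos), len(neg)
    diff = pos.mean(axis=0) - neg.mean(axis=0)
    # Pooled variance with Bessel's correction (needs at least 2 rows per group).
    pooled_var = ((n1 - 1) * pos.var(axis=0, ddof=1)
                  + (n2 - 1) * neg.var(axis=0, ddof=1)) / (n1 + n2 - 2)
    d = diff / np.sqrt(pooled_var + 1e-12)
    order = np.argsort(d)
    return d, order[::-1][:top_n], order[:top_n]  # d, helpful ids, harmful ids

# Synthetic stand-in: 50 responses (6 wrong), 20 features.
rng = np.random.default_rng(0)
feats = rng.normal(size=(50, 20))
correct = np.arange(50) >= 6
feats[correct, 3] += 5.0    # planted "helpful" feature
feats[~correct, 7] += 5.0   # planted "harmful" feature
d, helpful, harmful = contrastive_ranking(feats, correct)
print("helpful:", helpful.tolist(), "harmful:", harmful.tolist())
```

With only 6 wrong responses, the variance estimate for the "wrong" group is noisy, which is one reason effect sizes from n = 50 should be read cautiously.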
3. Layer-wise contrastive mapping puts the peak at L13, not L18.
Multi-layer effect-size analysis (top-10 directions per layer) of mean residual activations across all 32 decoder layers of Qwen3.5-4B:
effect peak @ L13 (top-10 |d| = 1.391)
effect @ L18 (our SAE layer) = 1.289
L13 sits exactly at the transition where the logit lens
first shows commitment tokens (Step, ####, <think>) in Qwen3.5-4B. The L18 effect size is ~7 %
below the peak, in the post-commitment consolidation zone. This layer is not "wrong"
(it still carries recoverable causal signal), but future SAE training for math-reasoning
work should target L13–L15 on Qwen3.5-4B.
4. CAA at L13 matches SAE bundle without training new SAEs.
We built a single steering_vec = mean(h_13 | correct) − mean(h_13 | wrong) (norm
0.612 at L13) and applied hidden += scale · steering_vec during generation:
| Scale | Accuracy | Δ | Gained / Lost |
|---|---|---|---|
| baseline | 88.0 % | – | – |
| +0.3 | 88.0 % | 0 pp | +2 (q5, q14) / −2 (q20, q43) |
| +0.5 | 84.0 % | −4 pp | +2 (q5, q39) / −4 |
| +1.0 | 86.0 % | −2 pp | +1 (q5) / −2 |
| −0.5 (negative control) | 80.0 % | −8 pp | confirms direction is causal |
The mechanism is real: positive scales recover the same baseline failures that the SAE bundle recovers, and the negative scale confirms directional causality by consistently hurting performance. The net aggregate is null at n = 50 because gains and losses cancel. At amp × 5 and above, the steering vector pushes the residual off-manifold and destroys generation entirely.
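The CAA recipe above amounts to one mean difference and one addition. The sketch below uses small random arrays in place of real layer-13 residuals; `build_steering_vec` and `steer` are hypothetical helper names, and the per-response pooling of hidden states is an assumption.

```python
import numpy as np

def build_steering_vec(h, correct):
    """mean(h | correct) - mean(h | wrong) over per-response hidden states."""
    return h[correct].mean(axis=0) - h[~correct].mean(axis=0)

def steer(hidden, steering_vec, scale):
    """Add the same scaled direction to every token position."""
    return hidden + scale * steering_vec

rng = np.random.default_rng(0)
h13 = rng.normal(size=(50, 8))      # toy stand-in for [n_responses, d_model]
correct = np.arange(50) >= 6        # 6 wrong, 44 correct
vec = build_steering_vec(h13, correct)

hidden = rng.normal(size=(3, 8))    # toy stand-in for [seq_len, d_model]
for scale in (0.3, 0.5, 1.0, -0.5):
    shift = np.linalg.norm(steer(hidden, vec, scale) - hidden, axis=-1)
    print(f"scale {scale:+.1f}: per-token shift = {shift[0]:.3f}")
```

Because the shift is identical at every token and for every prompt, this form of steering cannot distinguish generations that need the push from those that do not, which is the collateral-damage issue discussed next.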
Honest limitations
- n = 50 is too small. With 88 % baseline (6 wrong), one flipped question is 2 pp of noise. We saw +13 pp at step 15 for CAA +0.3 and watched it regress to 0 pp at step 50; that's textbook regression-to-mean. The cleanest next step is to run n = 200+ to quantify whether any scale sustains > 2 pp at statistical significance.
- Collateral damage is structural. Every intervention method we tried (ReasonScore, SAE contrastive, raw CAA) broke 2–6 questions that the baseline got right. This is the central unsolved problem: uniform magnitude steering perturbs all generations, not just the ones that needed it.
- The 4 recoverable questions are heterogeneous. Question 5 is recovered by almost every method; question 39 is recovered only by +0.5 CAA; question 13 only by SAE bundle strong amp. Different interventions find different sub-mechanisms.
Update (2026-04-16): CAA n = 100 + 55 % recovery rate
We extended the CAA scan to n = 100 held-out GSM8K questions with a fine scale sweep around the previously observed sweet spot:
| Scale | Accuracy | Δ vs baseline | Baseline-failure recovery |
|---|---|---|---|
| baseline | 89.0 % | – | – |
| +0.15 | 87.0 % | −2.0 pp | 4/11 (36 %) |
| +0.20 | 89.0 % | 0 pp | 5/11 (45 %) |
| +0.25 | 90.0 % | +1.0 pp | 6/11 (55 %) |
| +0.30 | 90.0 % | +1.0 pp | 5/11 (45 %) |
The aggregate lift is small, but the recovery rate is the real story. At scale
+0.25, the contrastive direction at L13 recovers 55 % of the questions the
baseline got wrong: 6 of 11 failures flipped to correct. The collateral
damage falls on a consistent set of 5 questions [7, 20, 43, 58, 93] across
all positive scales, which the baseline already answers with high confidence.
This is the first clean evidence in our experiments that the mechanistic signal is structured, not random: hard questions can be recovered by pushing the residual in the discovered direction, while easy questions get perturbed without benefit. The obvious follow-up, conditional / confidence-gated steering à la CAST (arxiv:2409.05907, ICLR 2025 spotlight), is on the roadmap.
Stage Gate 1: mech-reward viability check (passed)
The mechreward repo exists because the project's original thesis is not
activation steering at inference time; it is using contrastive SAE features
as a dense reward signal for RL fine-tuning. Before committing GPU hours to
RL, we ran a passive correlation pre-test (n = 100 held-out GSM8K questions, no
steering, just measuring whether the features predict baseline correctness):
| Signal | Spearman ρ | Pearson r | p-value |
|---|---|---|---|
| R_mech_features (top-10 helpful − top-10 harmful, L18 SAE) | +0.540 | +0.726 | < 0.0001 |
| R_mech_direction (cos-sim with L13 contrastive direction) | +0.508 | +0.624 | < 0.0001 |
Both pass the "strong correlation" threshold (ρ > 0.5). The SAE feature decomposition slightly outperforms the raw residual direction, which is additional evidence that the SAE is extracting signal beyond what is visible in the raw L13 hidden state.
Interpretation: the mechanism responsible for GSM8K correctness is readable from the SAE features; it is just not controllable via uniform additive steering (as we showed above with the collateral-damage experiments). This gap between "readable" and "controllable" is exactly what reward-based RL fine-tuning can close: let the model learn when to fire the helpful features and when to suppress the harmful ones, rather than imposing the decision uniformly at every token.
Top-10 helpful features (fire more on correct answers, Cohen's d > 1.8):
[36405, 27873, 35818, 12399, 2643, 6998, 15360, 25868, 21154, 35842]
Top-10 harmful features (fire more on wrong answers):
[18654, 36412, 15686, 23272, 13672, 14912, 5863, 29516, 30690, 40589]
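A rough sketch of how a signal like R_mech_features and its rank correlation with correctness could be computed from the two feature lists above. The aggregation (mean helpful activation minus mean harmful activation per response) is an assumption, and `r_mech_features`, `average_ranks`, and `spearman` are hypothetical helpers written with NumPy only.

```python
import numpy as np

HELPFUL = [36405, 27873, 35818, 12399, 2643, 6998, 15360, 25868, 21154, 35842]
HARMFUL = [18654, 36412, 15686, 23272, 13672, 14912, 5863, 29516, 30690, 40589]

def r_mech_features(feats):
    """feats: [n_responses, 40960] per-response SAE activations (assumed pooling)."""
    return feats[:, HELPFUL].mean(axis=1) - feats[:, HARMFUL].mean(axis=1)

def average_ranks(x):
    """1-based ranks with ties replaced by their average rank."""
    x = np.asarray(x, dtype=float)
    ranks = np.empty(len(x))
    ranks[np.argsort(x, kind="mergesort")] = np.arange(1, len(x) + 1)
    for v in np.unique(x):
        mask = x == v
        ranks[mask] = ranks[mask].mean()
    return ranks

def spearman(x, y):
    """Spearman rho = Pearson correlation of the tie-averaged ranks."""
    return float(np.corrcoef(average_ranks(x), average_ranks(y))[0, 1])

# Toy check: 4 responses, one dominated by helpful features, one by harmful ones.
feats = np.zeros((4, 40960))
feats[0, HELPFUL] = 1.0
feats[1, HARMFUL] = 1.0
signal = r_mech_features(feats)
print(signal, spearman(signal, [1.0, 0.0, 0.5, 0.5]))
```

Tie handling matters here because correctness is binary, so most of the ranks on that side are shared.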
Stage Gate 2: RL training with mech-reward (passed)
We ran GRPO fine-tuning on 500 GSM8K training questions (100 steps, 4 rollouts per question, LR = 1e-6, KL penalty β = 0.05) with three reward configurations on the same base model, same seed, same eval set (100 held-out questions):
| Run | Reward | Final eval acc | Δ vs step 0 | Δ vs R0 |
|---|---|---|---|---|
| R0 | Outcome only (binary) | 74 % | +10 pp | – (baseline) |
| R1 | Outcome + SAE features (λ = 0.1) | 76 % | +12 pp | +2 pp (pass) |
| R2 | Outcome + raw L13 direction (λ = 0.1) | 65 % | +1 pp | −9 pp (fail) |
R1 passes the +2 pp threshold over outcome-only. The convergence profile is even more striking than the final gap:
Step  0     20    40    60    80    100
R0:   64 %  68 %  70 %  73 %  76 %  74 %   ← drops at end
R1:   64 %  70 %  74 %  75 %  75 %  76 %   ← stable at end
R1 reaches R0's final accuracy (74 %) at step 40, i.e. 2.5× faster convergence. The mech-reward signal provides denser per-response gradient that accelerates early learning. At convergence, R0 drops from its step-80 peak (76 % → 74 %) while R1 holds or climbs (75 % → 76 %), suggesting the SAE features act as a regularizer that prevents late-training degradation.
R2 (raw direction without SAE decomposition) is actively harmful: R2 finished at 65 % (+1 pp from its step-0 accuracy), a full 9 pp below outcome-only R0 and 11 pp below SAE-based R1. The raw contrastive direction, which passively correlates with correctness at ρ = 0.508, produced gradient signals that actively interfered with the outcome-driven learning. R2's eval trajectory (64 → 61 → 64 → 65 → 66 → 65) shows a model that barely learns, spending most of its budget recovering from the early damage caused by the noisy reward signal.
The 11 pp gap between R1 and R2 (76 % vs 65 %), using the same underlying contrastive signal, same model, same training data, and same seed, is the strongest evidence in our pipeline that SAE feature decomposition is the necessary processing step. The raw L13 direction is polysemantic: it correlates with correctness on average but also encodes task-irrelevant variation (difficulty, length, style) that corrupts the policy gradient. The SAE's sparse decomposition into 10 helpful + 10 harmful features acts as a denoiser, removing the polysemantic components and leaving a clean, sparse optimization target that the GRPO policy can reliably learn from.
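To make the "denser gradient" point concrete, here is a minimal sketch of how an outcome reward and a mech reward might be combined and turned into GRPO-style group-normalised advantages. The additive form outcome + λ · R_mech is inferred from the table's "Outcome + SAE features (λ = 0.1)" wording; the helper names and the toy reward values are assumptions, not the training code.

```python
from statistics import mean, pstdev

LAMBDA_MECH = 0.1  # weight on the mech term, as listed in the table

def combined_reward(outcome, r_mech, lam=LAMBDA_MECH):
    """Assumed additive combination: binary outcome plus a small mech bonus."""
    return outcome + lam * r_mech

def group_advantages(rewards, eps=1e-6):
    """GRPO-style advantage: standardise rewards within one rollout group."""
    mu, sigma = mean(rewards), pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]

# One question, 4 rollouts: two correct, two wrong (toy mech values).
outcomes = [1.0, 1.0, 0.0, 0.0]
mech = [0.6, -0.2, 0.3, -0.4]

adv_outcome_only = group_advantages(outcomes)
adv_with_mech = group_advantages(
    [combined_reward(o, m) for o, m in zip(outcomes, mech)])
print("outcome only:", [round(a, 3) for a in adv_outcome_only])
print("with mech:   ", [round(a, 3) for a in adv_with_mech])
```

With outcome-only reward the two correct rollouts receive identical advantages; adding the mech term separates them, so the policy gets a preference signal even inside a group where the binary outcome ties.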
Reproducible artifacts
All raw results are saved as JSON in this repo:
| File | Content |
|---|---|
| reasoning_pack.json | ReasonScore top-25 + per-word distributions |
| steering_validation.json | Causal validation of feature #11424 (delta norms, traces) |
Pending upload: contrastive discovery JSON, CAA sweep n = 100, Stage Gate 1 correlation, Stage Gate 2 GRPO traces.
Stage Gate 3 Phase A: ceiling broken at scale (passed)
Full write-up: Per-token SAE features as online RL reward: breaking the G2 76% ceiling on GSM8K (LessWrong, 2026-04-17)
Building on G2's C2 validation, we scaled up to test whether per-token
mech-reward could break the 76 % G2 R1 ceiling. GRPO trained on 7500 GSM8K
questions with Qwen3.5-4B + LoRA r=32, raw prompt Q:/A:, 4 rollouts × 4
questions/step (16 rollouts/step), max_gen_len=256, LR=3e-6, KL β=0.05,
λ_mech=0.1, seed=42. Trained LoRA adapter:
caiovicentino1/Qwen3.5-4B-mechreward-G3-phaseA-step400.
Step-400 eval (vs baseline with LoRA disabled via model.disable_adapter()):
| Metric | Baseline | G3 @ step 400 | Δ |
|---|---|---|---|
| GSM8K (500Q greedy, raw prompt) | 64.00 % | 83.00 % | +19 pp |
| MMLU (200Q raw zeroshot) | 50.00 % | 54.50 % | +4.50 pp |
| MATH-500 (500Q greedy, not trained on) | – | 18.20 % | – (transfer) |
| Hack rate (canaries n=50) | 4.0 % (2/50) | 8.0 % (4/50) | +4 pp (within 95 % CI) |
| Correct under canary (n=50) | 18.0 % | 28.0 % | +10 pp |
| Ambiguous under canary (n=50) | 78.0 % | 64.0 % | −14 pp |
Ceiling-break claim: G2 R1 final was 76 % in 100 steps with
trajectory-level mech-reward. G3 Phase A reaches 83 % at nominal step 400,
but the first 232 steps ran at the G2-documented LR=1e-6 and produced zero
lift (KL stuck at 0.018, quick_gsm8k@200 = 64 % == baseline). Only after
raising LR to 3e-6 at step 232 did training actually move, so the effective
training budget that broke the ceiling was 168 steps at LR=3e-6 (step
232 → 400), roughly 1.68× G2 R1's 100-step budget for a +7 pp gain on
GSM8K and using the same 20 contrastive features (10 helpful + 10 harmful).
We stopped at step 400 of 2000 planned because the C2-extended target
(≥ 80 %) was already achieved with margin.
Mech signal trajectory during training (key evidence that features drive the lift, not just outcome-only GRPO):
- Steps 1–200 at LR=1e-6 (G2's documented value): mech oscillated around −0.02, KL stuck at 0.018, quick_gsm8k = 64 % (same as baseline, no lift).
- Step 232 diagnostic: gnorm was always < 0.5 throughout, so clip_grad_norm_=1.0 was inert; clipping was never the bottleneck. LR was. Raising LR to 3e-6 produced immediate learning.
- Steps 232–400 at LR=3e-6: mech rose −0.02 → +0.076 (step 250) → +0.382 (step 270) → +0.58 peak (step 330) → plateau ~+0.4–0.5. KL climbed 0.018 → 0.11. Training outcome at temp-0.9 went 0.55 → 0.75 avg.
Anti-Goodhart validation (n=50 canaries): with 10-canary measurement the hack rate read 30 % (statistical noise in a 10-point test). Expanding to 50 canaries across 5 hack patterns (repetition demands, reasoning-rhetoric abuse, forced templates, instruction bypass, chain-of-reasoning abuse), the trained policy showed hack_rate = 8 % (vs 4 % baseline, within 95 % CI), while correct-under-adversarial-prompt rose +10 pp (18 → 28 %). The policy is both more confident and more resilient to distraction, the opposite of the Goodhart failure mode.
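The "within 95 % CI" statement can be checked with a binomial confidence interval. The write-up does not say which interval was used; the sketch below uses the Wilson score interval as one reasonable choice for small counts, with `wilson_interval` as a hypothetical helper name.

```python
import math

def wilson_interval(k, n, z=1.96):
    """Wilson score interval for a binomial proportion k/n (z=1.96 for ~95 %)."""
    p = k / n
    denom = 1 + z * z / n
    centre = (p + z * z / (2 * n)) / denom
    half = z * math.sqrt(p * (1 - p) / n + z * z / (4 * n * n)) / denom
    return centre - half, centre + half

base_lo, base_hi = wilson_interval(2, 50)   # baseline hack rate: 2/50
g3_lo, g3_hi = wilson_interval(4, 50)       # trained policy: 4/50
print(f"baseline 4 %: [{base_lo:.3f}, {base_hi:.3f}]")
print(f"G3       8 %: [{g3_lo:.3f}, {g3_hi:.3f}]")
```

Under this choice of interval the trained policy's 8 % falls inside the baseline's range and the two intervals overlap, so 50 canaries cannot separate the two hack rates.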
Key lessons applied to future mechreward RL runs:
- Verify gnorm before blaming clipping. Our gnorm ≈ 0.3 proved clip=1.0 was inert for 200 steps while we wasted time on LR-unrelated hypotheses.
- LR in prior work may not transfer across batch configurations. G2's 1e-6 with 4×1 rollouts needed 3e-6 here for 4×4 rollouts + per-token reward + bf16 log_softmax.
- Prompt format must match the SAE contrastive-discovery distribution. Using a chat template when the SAE pack was discovered on raw Q:/A: outputs kept mech negative for 200 steps before we reverted.
Specs
| | |
|---|---|
| Base model | Qwen/Qwen3.5-4B |
| Hook | Residual stream, post-layer 18 |
| d_model | 2,560 |
| d_sae | 40,960 (16× expansion) |
| k (active features / token) | 128 |
| Training tokens | 200 M (FineWeb-Edu sample-10BT, streamed) |
| Optimizer | Adam, peak LR 2e-4 → cosine → floor 6e-5 |
| Aux dead-feature loss | coef = 0.125, k_aux = 512 |
| Decoder column norm | 1.0 (re-projected every 10 steps) |
| b_dec init | Geometric median over 16,384 samples |
| Precision | bf16 model, fp32 SAE |
| Hardware | NVIDIA RTX PRO 6000 Blackwell 96 GB |
Training + eval metrics
| Metric | Training (FineWeb-Edu) | Eval (code + reasoning prompts) |
|---|---|---|
| Variance explained | ≥ 0.87 | 0.866 |
| MSE | ~0.010 | 0.0178 |
| L0 (active features / token) | 128 / 128 | 127.9 / 128 |
| Alive features | 40,950 / 40,960 | 24,956 / 40,960 (on eval prompts) |
| Features with H > 1.0 | – | 7,259 |
The eval prompts are Python snippets and meta-cognitive reasoning fragments, a different distribution from the educational articles the SAE was trained on. Variance explained holds (≥ 0.87 → 0.866) even though MSE rises from ~0.010 to 0.0178, and L0 stays structural (127.9/128), confirming the TopK sparsity constraint survives distribution shift.
Dead-feature revival worked as designed: the auxiliary loss kicked in at step 5,225, pulled the peak of 9,171 dead features back to 1 by step 6,400, and the SAE stayed in healthy churn (0–10 dead) for the remaining 43k steps.
Reproduce the eval numbers
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import hf_hub_download

ckpt = torch.load(
    hf_hub_download("caiovicentino1/Qwen3.5-4B-SAE-L18-topk", "sae_final.pt"),
    map_location="cuda", weights_only=True,
)
W_enc, W_dec = ckpt["W_enc"].cuda().float(), ckpt["W_dec"].cuda().float()
b_enc, b_dec = ckpt["b_enc"].cuda().float(), ckpt["b_dec"].cuda().float()
k = int(ckpt["k"])

def encode(x):
    # TopK encoder: keep the k largest pre-activations, zero the rest.
    pre = (x - b_dec) @ W_enc + b_enc
    topv, topi = torch.topk(pre, k, dim=-1)
    out = torch.zeros_like(pre)
    out.scatter_(-1, topi, topv)
    return torch.relu(out)

def decode(f):
    return f @ W_dec + b_dec

tok = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-4B")
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3.5-4B", dtype=torch.bfloat16, device_map="cuda"
).eval()
prompts = [
    "The key insight behind transformers is that attention lets every token",
    "def fibonacci(n):\n if n < 2:\n return n\n return",
    "Wait, let me think about this more carefully. The problem is that",
    "import torch\nimport torch.nn as nn\n\nclass Attention(nn.Module):",
    "A sparse autoencoder decomposes neural network activations into",
]
total_var, total_mse, total_l0 = 0.0, 0.0, 0.0
with torch.inference_mode():
    for p in prompts:
        ids = tok(p, return_tensors="pt").input_ids.cuda()
        out = model(input_ids=ids, output_hidden_states=True,
                    use_cache=False, return_dict=True)
        x = out.hidden_states[19].float().squeeze(0)
        feats = encode(x)
        recon = decode(feats)
        mse = (x - recon).pow(2).mean().item()
        total_var += 1 - mse / x.var().item()  # variance explained
        total_mse += mse
        total_l0 += (feats > 0).float().sum(-1).mean().item()
n = len(prompts)
print(f"var_exp={total_var/n:.4f} mse={total_mse/n:.6f} L0={total_l0/n:.1f}/{k}")
# var_exp=0.8655 mse=0.017728 L0=127.9/128
Files
| File | Size | Purpose |
|---|---|---|
| sae_final.pt | 839 MB | Trained weights (W_enc, W_dec, b_enc, b_dec + scalar meta) |
| sae_final.json | 256 B | Structured metadata (d_model, d_sae, k, layer, tokens, timestamp) |
| reasoning_pack.json | ~15 KB | Top-25 reasoning features, per-word distributions, entropy + stats |
| steering_validation.json | ~7 KB | Causal validation of #11424: activation steering traces, delta norms, semantic diff |
Full usage
With mechreward (recommended)
pip install mechreward
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from mechreward.sae.topk_sae import load_topk_sae

tok = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-4B")
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3.5-4B", dtype=torch.bfloat16, device_map="cuda"
).eval()
sae = load_topk_sae(
    "caiovicentino1/Qwen3.5-4B-SAE-L18-topk",
    layer=18,
    model_name="Qwen/Qwen3.5-4B",
)
# Extract layer-18 residual and encode
ids = tok("The sparse features behind reasoning", return_tensors="pt").input_ids.cuda()
with torch.inference_mode():
    out = model(input_ids=ids, output_hidden_states=True,
                use_cache=False, return_dict=True)
    h = out.hidden_states[19].float().squeeze(0)
    features = sae.encode(h)
print(features.shape)  # [seq_len, 40960]
print((features > 0).sum(-1).float().mean())  # ≈ 128, structural L0
Raw (no mechreward dependency)
import torch
from huggingface_hub import hf_hub_download

ckpt = torch.load(
    hf_hub_download("caiovicentino1/Qwen3.5-4B-SAE-L18-topk", "sae_final.pt"),
    map_location="cuda",
    weights_only=True,
)
W_enc, W_dec = ckpt["W_enc"].cuda(), ckpt["W_dec"].cuda()
b_enc, b_dec = ckpt["b_enc"].cuda(), ckpt["b_dec"].cuda()
k = int(ckpt["k"])

def encode(x):  # x: [..., 2560]
    pre = (x - b_dec) @ W_enc + b_enc
    topv, topi = torch.topk(pre, k, dim=-1)
    out = torch.zeros_like(pre)
    out.scatter_(-1, topi, topv)
    return torch.relu(out)

def decode(f):  # f: [..., 40960]
    return f @ W_dec + b_dec
Training details
- Objective: MSE reconstruction of the layer-18 residual stream.
- Sparsity: hard TopK (Gao et al. 2024); L0 = k = 128 by construction. No L1 tuning.
- Dead-feature revival: auxiliary loss on the top-512 pre-activations of dormant features, scaled by 1/8. Dead threshold: 5,000 steps without firing.
- Decoder norm constraint: each decoder column re-projected to unit norm every 10 steps, preventing the weight explosion that plagues L1 SAEs.
- b_dec initialization: geometric median of 16,384 sampled residual-stream activations, more robust than the mean for the heavy-tailed Qwen3.5 activation distribution (max|x| ≈ 38 vs std ≈ 0.28).
- LR schedule: 5k-step warmup → cosine decay from 2e-4 to 6e-5 floor. The non-zero floor sustains dead-feature revival throughout training.
- Data: HuggingFaceFW/fineweb-edu sample-10BT, streamed in 2M-token buffers. No tokenizer-specific filtering.
- Activation extraction: plain HF output_hidden_states=True; no TransformerLens, no custom hooks, ports trivially to any HF-compatible model.
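The LR schedule in the list above can be written as a small pure-Python function. This is a sketch under stated assumptions: the warmup is taken to be linear, and `total_steps` is a hypothetical parameter because the exact step count is not given here; `lr_at` is not a function from the training script.

```python
import math

def lr_at(step, total_steps, peak=2e-4, floor=6e-5, warmup=5_000):
    """Linear warmup to `peak`, then cosine decay down to a non-zero `floor`."""
    if step < warmup:
        return peak * (step + 1) / warmup
    progress = min(1.0, (step - warmup) / max(1, total_steps - warmup))
    return floor + 0.5 * (peak - floor) * (1 + math.cos(math.pi * progress))

# total_steps=50_000 is an illustrative value, not the run's actual length.
for step in (0, 2_500, 5_000, 27_500, 50_000):
    print(f"step {step:>6}: lr = {lr_at(step, total_steps=50_000):.2e}")
```

Keeping the floor at 6e-5 rather than decaying to zero means late-training updates stay large enough for revived features to keep moving, which is the rationale given in the LR schedule bullet.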
Why a custom script?
Qwen3.5 uses a hybrid Gated Delta Networks (GDN) architecture that
TransformerLens and sae_lens don't yet support: their
hook infrastructure targets Gemma, Llama, and GPT-2 style attention, not
hybrid state-space layers. The training script
(mechreward/scripts/train_sae_qwen35.py) uses only HuggingFace's built-in
output_hidden_states=True, so it ports trivially to any HF-compatible model.
Limitations
- Token budget: 200 M tokens is small relative to Gemma Scope (8 B tokens on 9 B models) and to our Gemma 4 E4B companion (1 B tokens, var_exp = 0.939 vs our 0.87). Feature quality is publishable, but re-training this SAE with 1 B tokens would likely lift var_exp to the 0.92–0.94 range and tighten the feature catalog, at a cost of ~36 h on an RTX PRO 6000.
- Single layer: only layer 18 (mid-depth residual). Other depths would need separate runs.
- Feature catalog is reasoning-focused: the top-25 pack targets meta-cognitive features specifically. Math, code, multilingual, and factual feature catalogs are future work.
- Causal validation is demo-scale: only feature #11424 has been validated by activation steering so far (3 prompts × 4 scales). Scaling this to the full top-25 and testing under sampled decoding is the next step.
- Layer choice is suboptimal for math reasoning. Multi-layer effect-size analysis shows contrastive signal peaks at L13 (1.391 top-10 Cohen's d) while our SAE is at L18 (1.289). L18 still carries recoverable causal features (it's in the post-commitment consolidation zone), but a future SAE trained on L13–L15 activations would likely give cleaner GSM8K steering. Re-training is expected to cost ~9 h on an RTX PRO 6000 Blackwell.
- No net GSM8K accuracy boost demonstrated at n = 50. See the experiments section: individual features and CAA directions recover specific baseline failures, but net aggregate lift is sample-size-dominated at 50 questions. n ≥ 200 needed to quantify.
Citation
If you use this SAE, please cite the TopK SAE paper:
@article{gao2024scaling,
title = {Scaling and evaluating sparse autoencoders},
author = {Gao, Leo and Dupr{\'e} la Tour, Tom and Tillman, Henk and
Goh, Gabriel and Troll, Rajan and Radford, Alec and
Sutskever, Ilya and Leike, Jan and Wu, Jeffrey},
journal = {arXiv preprint arXiv:2406.04093},
year = {2024}
}
And the ReasonScore metric used to build the feature catalog:
@article{gao2025reasoning,
title = {I Have Covered All the Bases Here: Interpreting Reasoning
Features in Large Language Models via Sparse Autoencoders},
author = {AIRI Institute},
journal = {arXiv preprint arXiv:2503.18878},
year = {2025}
}
And the base model:
@misc{qwen35,
title = {Qwen3.5 Technical Report},
author = {Qwen Team},
year = {2026}
}
Companion releases
caiovicentino1/Gemma-4-E4B-SAE-L21-topk: Gemma 4 E4B TopK SAE at layer 21 of 42. Trained on 1 B tokens, d_sae = 32,768, var_exp = 0.939. First public SAE for hybrid MoE reasoning research. Final n = 50 CAA sweep at L27 (peak contrastive layer) produced a null aggregate: +10 scale finished at 0 pp vs baseline while the −5 negative control finished at +2 pp, which is inconsistent with a directional-causality hypothesis. Both +10 and −5 recovered the same share (20 %) of baseline-wrong questions, but on partially disjoint subsets, indicating that moderate perturbation acts as a decision-boundary regularizer in either direction rather than encoding a "correctness direction". The monotone magnitude penalty (+15 → −12 pp) is the only directional signal that survived. This gap between readable (ρ = 0.54 correlation on the Qwen companion) and controllable (null aggregate steering on both models) is the central motivator for pivoting from inference-time steering to learned-policy reward shaping in later mechreward stages. Full honest breakdown in the Gemma README. Released 2026-04-15, analysis 2026-04-16.
Cross-architecture observations (Qwen3.5-4B GDN vs Gemma 4 E4B MoE)
| Metric | Qwen3.5-4B (GDN) | Gemma 4 E4B (MoE) |
|---|---|---|
| Total layers | 32 | 42 |
| Peak contrastive effect layer | L13 (41 % depth) | L27 (64 % depth) |
| First English-commitment layer (logit lens) | L15 (47 %) | L30 (71 %) |
| GSM8K baseline accuracy (chat template CoT) | 89 % (n = 100) | 50 % (n = 50) |
| CAA best positive scale lift | +1 pp / 55 % recovery (@ +0.25) | 0 pp / 20 % recovery (@ +10) |
| CAA negative control lift | 0 pp (as expected) | +2 pp (beats positive: non-directional) |
| CAA overdose penalty (largest scale) | – (not tested) | −12 pp (real magnitude effect) |
| Top-10 feature effect size (Cohen's d) | 1.391 | 0.896 |
| Stage Gate 1 correlation (ρ, held-out) | 0.540 (p < 0.0001) | not yet measured |
Gemma 4 E4B commits to English mathematical reasoning 15 layers later than Qwen3.5-4B, consistent with extensive multilingual parallel processing in the early-middle stack of the ensemble-MoE architecture. The logit lens trajectory for Gemma shows dominant non-English token predictions across L0–L29 (mixed Chinese, Bengali, Hindi, Arabic, code-adjacent tokens) before converging on English reasoning formatting at L30+. To our knowledge this is a novel observation: no prior published SAE work has been performed on hybrid MoE architectures.
Hybrid architectures break the dense-transformer consensus on reasoning-feature localization. Published work on dense models (AIRI ReasonScore on DeepSeek-R1-Distill, Rimsky CAA on Llama-2, RouteSAE) converges on 50–60 % depth. Qwen3.5 GDN peaks at 41 % (earlier, consistent with linear-attention collapsing information earlier in the stack); Gemma 4 E4B MoE peaks at 64 % (later, consistent with longer distributed processing before consolidation).
License
Apache-2.0, same as the base model Qwen/Qwen3.5-4B.


