🔬 Gemma 4 E4B TopK SAE · Layer 21

The first public Sparse Autoencoder for Gemma 4 E4B, a hybrid ensemble-MoE architecture with no prior published mechanistic interpretability work.

Trained on residual-stream activations at layer 21 of google/gemma-4-e4b using the TopK objective from Gao et al. (arxiv:2406.04093). Shipped together with a Qwen3.5-4B counterpart at caiovicentino1/Qwen3.5-4B-SAE-L18-topk to enable cross-architecture mechanistic comparisons between a linear-attention hybrid (Qwen3.5 Gated Delta Networks) and a hybrid ensemble-MoE (Gemma 4).


πŸ“ Specs

Base model                   google/gemma-4-e4b
Architecture                 Hybrid ensemble-MoE (Gemma4ForConditionalGeneration)
Total layers                 42
Hook                         Residual stream, post-layer 21 (50 % depth)
d_model                      2,560
d_sae                        32,768 (12.8× expansion)
k (active features / token)  128
Training tokens              1 B (~31 k tokens per feature; 6× more tokens per feature than the Qwen3.5-4B SAE)
Optimizer                    Adam, peak LR 2e-4 → cosine → floor 6e-5
Aux dead-feature loss        coef = 0.125, k_aux = 512
Decoder norm                 1.0 per feature vector (re-projected every 10 steps)
Precision                    bf16 model, fp32 SAE
Hardware                     NVIDIA RTX PRO 6000 Blackwell 96 GB
Total training time          ~20 h
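The decoder re-projection and the auxiliary dead-feature loss listed above can be sketched in NumPy as follows. This is a forward-only illustration, not the training script: it assumes the weight layout of the Usage code (W_enc of shape [d_model, d_sae], W_dec of shape [d_sae, d_model]) and an AuxK-style auxiliary term in the spirit of Gao et al., where the top-k_aux pre-activations of dead features try to reconstruct the main SAE's residual error. The function names and the exact loss normalization are assumptions.

```python
import numpy as np


def project_decoder_rows(W_dec: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Rescale each feature's decoder vector (one row of W_dec, shape [d_sae, d_model]) to unit L2 norm."""
    norms = np.linalg.norm(W_dec, axis=1, keepdims=True)
    return W_dec / np.maximum(norms, eps)


def topk_relu(pre: np.ndarray, k: int) -> np.ndarray:
    """Keep the k largest pre-activations per row, zero the rest, then apply ReLU."""
    out = np.zeros_like(pre)
    idx = np.argpartition(pre, -k, axis=-1)[:, -k:]
    rows = np.arange(pre.shape[0])[:, None]
    out[rows, idx] = pre[rows, idx]
    return np.maximum(out, 0.0)


def auxk_loss(x, x_hat, pre, dead_mask, W_dec, k_aux=512, coef=0.125):
    """AuxK-style dead-feature loss (forward pass only).

    x, x_hat  : [n, d_model] inputs and main TopK reconstruction
    pre       : [n, d_sae] encoder pre-activations
    dead_mask : [d_sae] bool, True where a feature currently counts as dead
    """
    n_dead = int(dead_mask.sum())
    if n_dead == 0:
        return 0.0
    residual = x - x_hat                            # error the live features failed to explain
    pre_dead = np.where(dead_mask, pre, -np.inf)    # live features are excluded from the aux top-k
    z_aux = topk_relu(pre_dead, min(k_aux, n_dead))
    residual_hat = z_aux @ W_dec                    # no decoder bias: this reconstructs an error term
    return coef * float(np.mean(np.sum((residual - residual_hat) ** 2, axis=-1)))
```

In a training loop, project_decoder_rows would run every 10 steps (per the table) and the auxiliary term would be added to the main reconstruction loss.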

📊 Training metrics

Metric                        Value
Variance explained            0.939
MSE                           ~0.181
L0 (active features / token)  128 / 128 (structural)
Feature coverage              0.620 (62 % of features fire on at least one token in steady state)
Dead features (steady state)  0 / 32,768
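The metrics above can be computed from a batch of activations x, reconstructions x_hat, and sparse codes f roughly as below. These are common conventions rather than the card's confirmed formulas: variance explained is 1 − SSE / total variance around the batch mean, MSE is averaged per element, coverage is the fraction of features that fire at least once in the batch, and a feature counts as dead once it has not fired for window_tokens consecutive tokens (the window length and all names are assumptions). Under definitions like these, a feature can miss a given batch, lowering coverage, without crossing the dead threshold.

```python
import numpy as np


def sae_batch_metrics(x: np.ndarray, x_hat: np.ndarray, f: np.ndarray) -> dict:
    """x, x_hat: [n_tokens, d_model]; f: [n_tokens, d_sae] sparse codes."""
    err = x - x_hat
    total_var = np.sum((x - x.mean(axis=0, keepdims=True)) ** 2)
    fired = f > 0
    return {
        "var_exp": 1.0 - float(np.sum(err ** 2) / total_var),
        "mse": float(np.mean(err ** 2)),              # per-element mean (assumed convention)
        "l0": float(fired.sum(axis=1).mean()),        # active features per token
        "coverage": float(fired.any(axis=0).mean()),  # fraction of features firing in this batch
    }


class DeadFeatureTracker:
    """Flags features that have not fired for `window_tokens` tokens (hypothetical window).

    Coarse bookkeeping: a feature that fires anywhere in a batch resets its counter to 0.
    """

    def __init__(self, d_sae: int, window_tokens: int):
        self.tokens_since_fire = np.zeros(d_sae, dtype=np.int64)
        self.window_tokens = window_tokens

    def update(self, f: np.ndarray) -> np.ndarray:
        fired = (f > 0).any(axis=0)
        self.tokens_since_fire += f.shape[0]
        self.tokens_since_fire[fired] = 0
        return self.tokens_since_fire >= self.window_tokens  # boolean dead mask
```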

Compared with the Qwen3.5-4B companion SAE (var_exp = 0.87 on 200 M tokens), this SAE reaches substantially higher reconstruction quality, most likely because it saw 6× more training tokens per feature (1 B / 32 k vs 200 M / 40 k). The two SAEs also differ in base model and hook layer, so this is not a controlled comparison. Dead-feature revival was active throughout training, and the final dead count is 0.
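The tokens-per-feature comparison is plain arithmetic, reproduced below. The card gives the Qwen SAE width only as "40 k", so the rounded value is used here (an assumption); an exact width of 40,960 would give a ratio of exactly 6.25 instead of about 6.1, still roughly 6×.

```python
def tokens_per_feature(total_tokens: float, d_sae: int) -> float:
    """Average number of training tokens available per SAE feature."""
    return total_tokens / d_sae


gemma = tokens_per_feature(1e9, 32_768)   # this SAE
qwen = tokens_per_feature(200e6, 40_000)  # companion SAE; "40 k" width taken at face value
print(f"Gemma: {gemma:,.0f} tokens/feature")  # Gemma: 30,518 tokens/feature
print(f"Qwen:  {qwen:,.0f} tokens/feature")   # Qwen:  5,000 tokens/feature
print(f"ratio: {gemma / qwen:.2f}x")          # ratio: 6.10x
```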


🧪 GSM8K experiments (cross-architecture reasoning study)

This SAE was trained as part of the mechreward project, which studies how mechanistic interpretability signals can serve as reward functions for reasoning tasks. Results on 50 held-out GSM8K questions:

Baseline and peak layer

Gemma 4 E4B's baseline GSM8K accuracy under the manual chat-format CoT prompt (see Usage) is ~45 % (50.0 %, i.e. 25/50, on the 50-question set used for the steering runs below), dramatically lower than Qwen3.5-4B at the same scale (89 %). This makes Gemma 4 E4B a much better test bed for steering experiments: 25 held-out failures to recover vs only 6 for Qwen.

A multi-layer contrastive effect-size analysis (Cohen's d of mean_correct − mean_wrong at each of the 42 layers) finds the peak at L27 (top-10 |d| = 0.896), not L21. Our SAE sits at the 50 %-depth mark; the peak is at 64 % (27/42). L27 falls near the end of the "mini-rising" contrastive zone (L23–L28) and just before the English-commitment zone identified by the logit lens (L30+).
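The layer scan can be sketched as a per-dimension Cohen's d (pooled standard deviation) between correct and wrong questions at every layer. The input layout (one pooled residual vector per question per layer, for example a mean over response tokens), the function names, and the use of the mean of the top-10 |d| as the per-layer summary are assumptions; the card does not say how the top-10 values are reduced to the single 0.896 figure.

```python
import numpy as np


def cohens_d(a: np.ndarray, b: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Per-dimension Cohen's d for groups a [n_a, d] and b [n_b, d], using the pooled SD."""
    n_a, n_b = a.shape[0], b.shape[0]
    pooled_var = ((n_a - 1) * a.var(axis=0, ddof=1) + (n_b - 1) * b.var(axis=0, ddof=1)) / (n_a + n_b - 2)
    return (a.mean(axis=0) - b.mean(axis=0)) / np.sqrt(np.maximum(pooled_var, eps))


def layer_effect_profile(acts: np.ndarray, correct: np.ndarray, top: int = 10) -> np.ndarray:
    """acts: [n_layers, n_questions, d_model]; correct: [n_questions] bool.

    Returns one score per layer: the mean of the `top` largest |d| values (assumed summary).
    """
    scores = []
    for layer_acts in acts:
        d = np.abs(cohens_d(layer_acts[correct], layer_acts[~correct]))
        scores.append(np.sort(d)[-top:].mean())
    return np.array(scores)
```

The peak layer is then int(np.argmax(layer_effect_profile(acts, correct))).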

Contrastive activation steering at L27: null aggregate with informative substructure

We applied a raw contrastive activation-steering vector (steering_vec = mean(h_L27 | correct) − mean(h_L27 | wrong), unit-normalized) during greedy generation on 50 held-out GSM8K questions. Scale is the coefficient multiplying the unit-norm vector; Perturbation expresses that scale as a fraction of the mean L27 residual norm (95.3), so +10 corresponds to 10 / 95.3 ≈ 10.5 %. Final results at n = 50:

Scale                  Perturbation  Final acc       Δ vs baseline  Recovery of baseline-wrong  Net flips
baseline               0 %           50.0 % (25/50)  n/a            n/a                         n/a
+2                     2.1 %         48.0 %          −2 pp          3 / 25 (12 %)               −1
+5                     5.2 %         42.0 %          −8 pp          3 / 25 (12 %)               −4
+10                    10.5 %        50.0 %          0 pp           5 / 25 (20 %)               0
+15                    15.7 %        38.0 %          −12 pp         3 / 25 (12 %)               −6
−5 (negative control)  −5.2 %        52.0 %          +2 pp          5 / 25 (20 %)               +1
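The steering vector and the Perturbation column can be reproduced as below, assuming h_correct and h_wrong hold one L27 residual vector per question (how tokens are pooled into that vector is not stated in the card). The 95.3 default is the mean L27 residual norm quoted above; the function names are hypothetical.

```python
import numpy as np


def contrastive_steering_vector(h_correct: np.ndarray, h_wrong: np.ndarray) -> np.ndarray:
    """Unit-normalized mean(h | correct) - mean(h | wrong); inputs are [n, d_model]."""
    diff = h_correct.mean(axis=0) - h_wrong.mean(axis=0)
    return diff / np.linalg.norm(diff)


def mean_residual_norm(h: np.ndarray) -> float:
    """Average L2 norm of residual vectors h [n, d_model]."""
    return float(np.linalg.norm(h, axis=-1).mean())


def perturbation_pct(scale: float, residual_norm: float = 95.3) -> float:
    """Steering scale as a percentage of the mean residual norm."""
    return 100.0 * scale / residual_norm


for s in (2, 5, 10, 15, -5):
    print(f"scale {s:+d}: {perturbation_pct(s):.1f} %")
```

Because the vector has unit norm, adding scale × steering_vec shifts the residual by exactly |scale| in L2 norm, which is why the scale maps directly onto the Perturbation column.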

The negative-direction control (−5) numerically outperforms every positive scale. Early checkpoints had shown +10 sustaining +8 to +13 pp through step 35 (question 35 of 50), but the final 15 questions regressed the signal to zero. Meanwhile, −5 (designed as the negative control that "should hurt" under a directional-causality hypothesis) trended positive and finished with +2 pp and +1 net flip.
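The recovery and net-flip figures follow from per-question correctness vectors for the baseline and steered runs. In this minimal sketch (the function name and list inputs are hypothetical), recovered questions are baseline-wrong but steered-correct, broken ones are the reverse, and net flips = recovered − broken, which equals the change in the number of correct answers.

```python
def flip_stats(baseline_correct: list[bool], steered_correct: list[bool]) -> dict:
    """Compare a steered run with the unsteered baseline, question by question."""
    pairs = list(zip(baseline_correct, steered_correct))
    recovered = [i for i, (b, s) in enumerate(pairs) if not b and s]
    broken = [i for i, (b, s) in enumerate(pairs) if b and not s]
    n_wrong = sum(not b for b in baseline_correct)
    n = len(pairs)
    return {
        "acc_pct": 100.0 * sum(steered_correct) / n,
        "delta_pp": 100.0 * (sum(steered_correct) - sum(baseline_correct)) / n,
        "recovered": recovered,
        "recovery_rate": len(recovered) / n_wrong if n_wrong else 0.0,
        "net_flips": len(recovered) - len(broken),
    }
```

Read this way, the table also implies how many questions each setting broke: +10 recovered 5 and broke 5 (net 0), −5 recovered 5 and broke 4 (net +1), and +15 recovered 3 but broke 9 (net −6).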

The recovery sets are structurally informative:

+10 recovered  : [5, 11, 18, 26, 35]
−5  recovered  : [3, 5, 26, 35, 44]
                     ↑  ↑   ↑
Overlap (both) : [5, 26, 35]        ← 3 questions respond to *any* moderate perturbation
+10 only       : [11, 18]
−5 only        : [3, 44]

Both directions recover 20 % of the baseline-wrong set (5 / 25 each), but only 3 of those 5 questions are shared. This is inconsistent with "the steering vector encodes a canonical correctness direction at L27": under that hypothesis, +10 should strictly dominate −5, and their recovered sets should be nested rather than partially disjoint.
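The nesting argument can be checked directly on the two recovered sets listed above. The hypothetical helper below reports the overlap, the questions unique to each direction, whether either set contains the other, and their Jaccard similarity.

```python
def compare_recovered(a: set[int], b: set[int]) -> dict:
    """Summarize how two sets of recovered question indices relate."""
    return {
        "overlap": sorted(a & b),
        "only_a": sorted(a - b),
        "only_b": sorted(b - a),
        "nested": a <= b or b <= a,  # True if one set contains the other
        "jaccard": len(a & b) / len(a | b),
    }


plus_10 = {5, 11, 18, 26, 35}
minus_5 = {3, 5, 26, 35, 44}
print(compare_recovered(plus_10, minus_5))
```

For these sets, nested is False and the Jaccard similarity is 3/7 ≈ 0.43.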

The only robust signal is a magnitude penalty: +15 (15.7 % perturbation) consistently dropped accuracy by ~12 pp from step 10 through step 50, and +5 consistently dropped it by ~5–8 pp. The penalty is not monotone in scale (+10 finished at 0 pp), and these are magnitude effects, not direction effects. The most parsimonious interpretation is that moderate perturbation in either direction acts as a decision-boundary regularizer: some borderline-wrong questions flip to correct, some borderline-correct questions flip to wrong, the net is roughly zero, and large perturbations are destructive regardless of direction.

Why this null-with-structure matters for the mechreward thesis

This is the cleanest empirical case we have for why inference-time activation steering is the wrong application for these contrastive signals:

  • The companion Qwen3.5-4B SAE shows that its analogous contrastive features correlate with GSM8K correctness at Spearman ρ = 0.540 on 100 held-out questions (Stage Gate 1 test, p < 0.0001). The signal is readable.
  • Yet uniform additive steering, both on Qwen (null aggregate after n = 100) and here on Gemma (null aggregate after n = 50, with a non-directional negative control), fails to convert the readable signal into a net accuracy lift. The signal is not controllable via uniform intervention.
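A readability check of this kind reduces to a rank correlation between a per-question feature score and binary correctness. The sketch below implements Spearman's ρ as the Pearson correlation of tie-averaged ranks, which matters here because the correctness labels are heavily tied. How the Stage Gate 1 per-question score is built from SAE features, and how its p-value was computed, are not described in this card; the toy inputs are illustrative only.

```python
def average_ranks(values: list[float]) -> list[float]:
    """1-based ranks; tied values share the mean of the ranks they span."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        for t in range(i, j + 1):
            ranks[order[t]] = (i + j) / 2 + 1
        i = j + 1
    return ranks


def pearson(a: list[float], b: list[float]) -> float:
    """Plain Pearson correlation of two equal-length sequences."""
    n = len(a)
    ma, mb = sum(a) / n, sum(b) / n
    cov = sum((x - ma) * (y - mb) for x, y in zip(a, b))
    var_a = sum((x - ma) ** 2 for x in a)
    var_b = sum((y - mb) ** 2 for y in b)
    return cov / (var_a * var_b) ** 0.5


def spearman_rho(scores: list[float], labels: list[float]) -> float:
    """Spearman's rho = Pearson correlation of tie-averaged ranks."""
    return pearson(average_ranks(scores), average_ranks(labels))


# Toy data: hypothetical feature scores vs. 0/1 correctness labels.
print(spearman_rho([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]))
```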

The gap between "readable" (correlation is strong) and "controllable" (uniform steering gives net zero) is the gap that learned policy methods aim to close. The mechreward next step, Stage Gate 2, tests exactly this by using the contrastive features as a dense reward signal in GRPO fine-tuning on the companion Qwen model: the model learns, per token, when to fire the helpful features and when to suppress the harmful ones, rather than having uniform steering imposed at every generation step.
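One way to turn contrastive features into a dense, per-token reward is a signed weighted sum of SAE activations on each generated token. This is a hypothetical sketch: the helpful/harmful index lists, the weights, and how such a term would be mixed with the GSM8K answer reward inside GRPO are assumptions, not the mechreward implementation.

```python
import numpy as np


def feature_reward_per_token(
    feats: np.ndarray,
    helpful_idx: list[int],
    harmful_idx: list[int],
    w_helpful: float = 1.0,
    w_harmful: float = 1.0,
) -> np.ndarray:
    """feats: [seq_len, d_sae] SAE codes for generated tokens -> [seq_len] reward."""
    bonus = feats[:, helpful_idx].sum(axis=1)     # activation mass on "helpful" features
    penalty = feats[:, harmful_idx].sum(axis=1)   # activation mass on "harmful" features
    return w_helpful * bonus - w_harmful * penalty


def shaped_return(outcome_reward: float, token_rewards: np.ndarray, beta: float = 0.1) -> float:
    """Hypothetical mix: task outcome reward plus a small mean feature bonus."""
    return outcome_reward + beta * float(token_rewards.mean())
```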

What does survive from this experiment:

  1. Layer mapping. The peak contrastive effect at L27/42 (64 % depth) is a real structural observation of Gemma 4 E4B's hybrid MoE architecture, independent of whether the direction at that layer is steerable.
  2. Magnitude penalty. The +15 → −12 pp effect is a clean magnitude signal and held consistently from step 10 through step 50.
  3. Logit lens: late English commitment. L30 as the first English-commitment layer is a novel cross-architecture finding.
  4. Higher-quality SAE. var_exp = 0.939 and 0 dead features set a new bar for publicly released hybrid-architecture SAEs.

Why the regression matters for the mechreward thesis

This weak/ambiguous result is not a contradiction of the project thesis; it is in fact one of the motivations for pivoting from inference-time steering to training-time reward shaping. In the regime we care about (100+ questions on 45–90 % baselines), uniform activation steering effects are small relative to sampling noise, and uniform steering is fundamentally susceptible to collateral damage on easy questions. The analogous contrastive direction on the companion Qwen3.5-4B SAE, which also failed to produce a net steering lift, does predict correctness on held-out data (Spearman ρ = 0.540, Stage Gate 1 test). That is the actual target application: use the direction as a dense reward signal during RL fine-tuning, not as a uniform inference-time injection.

The +15 collapse and the dose–response shape suggest that something real lives at L27; inference-time intervention is simply the wrong extraction tool. The right extraction tool is a learned policy (GRPO) that can decide per token and per question when to fire the direction, which is what Stage Gate 2+ targets on the Qwen companion.

Logit lens: Gemma's late English commitment

Running the logit lens (project each layer's hidden state through lm_head and decode the top token) on a sample GSM8K prompt reveals a striking pattern:

L0-L29 : mixed multilingual predictions (Chinese, Bengali, Hindi, Arabic, code)
L30    : first English-commitment tokens ('####', '<think>', '###')
L31-L38: reasoning-format tokens ('Calculations', 'Solving', 'let', 'Let')
L39-L41: response-start tokens ('Janet', 'Let', 'To', '1')

Gemma 4 E4B commits to English mathematical reasoning at L30 of 42 (71 % depth), 15 layers later than Qwen3.5-4B (L15 of 32, 47 %). Multilingual parallel processing dominates most of the early-middle stack. This is a novel observation enabled by our SAE + logit-lens infrastructure and is consistent with the MoE architecture distributing processing across experts before consolidating language choice late in the stack.
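The logit-lens readout amounts to normalizing each layer's hidden state, multiplying by the unembedding matrix, and taking the argmax. The NumPy sketch below works on arrays rather than the HF model. It assumes a Gemma-style RMSNorm with a (1 + weight) scale as the final norm, applied before the unembedding (many logit-lens implementations do this; the card itself only mentions lm_head). Final-logit soft-capping, if present, is omitted because a tanh cap is monotone and cannot change the argmax. With the real model, the per-layer vectors would come from output_hidden_states=True and the unembedding from model.get_output_embeddings(); in many HF implementations the last hidden_states entry is already post-norm.

```python
import numpy as np


def gemma_style_rms_norm(h: np.ndarray, weight: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """RMSNorm with a (1 + weight) scale (assumed to match the model's final norm)."""
    rms = np.sqrt(np.mean(h ** 2, axis=-1, keepdims=True) + eps)
    return h / rms * (1.0 + weight)


def logit_lens_top_ids(hidden_by_layer: np.ndarray, norm_weight: np.ndarray, W_U: np.ndarray) -> np.ndarray:
    """hidden_by_layer: [n_layers, d_model] states at one position; W_U: [d_model, vocab].

    Returns the top-1 token id predicted from each layer.
    """
    logits = gemma_style_rms_norm(hidden_by_layer, norm_weight) @ W_U
    return logits.argmax(axis=-1)
```

Decoding the returned ids with tok.convert_ids_to_tokens(ids.tolist()) gives per-layer token lists like the one above.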


🚀 Usage

Gemma 4 E4B's tokenizer does not include a chat template. Use the manual format below (same tokens the model was trained on):

import torch, re
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import hf_hub_download

MODEL = "google/gemma-4-e4b"
SAE_REPO = "caiovicentino1/Gemma-4-E4B-SAE-L21-topk"
LAYER = 21

tok = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForCausalLM.from_pretrained(
    MODEL, dtype=torch.bfloat16, device_map="cuda",
).eval()

# Manual chat format (tokens are in the vocab)
def build_prompt(q):
    text = (
        "<start_of_turn>user\n"
        f"Solve step by step and give final answer as '#### N'.\n\n{q}"
        "<end_of_turn>\n"
        "<start_of_turn>model\n"
    )
    return tok(text, return_tensors="pt").input_ids.to("cuda")

# Load SAE weights
ckpt = torch.load(
    hf_hub_download(SAE_REPO, "sae_final.pt"),
    map_location="cuda", weights_only=True,
)
W_enc, W_dec = ckpt["W_enc"].cuda().float(), ckpt["W_dec"].cuda().float()
b_enc, b_dec = ckpt["b_enc"].cuda().float(), ckpt["b_dec"].cuda().float()
k = int(ckpt["k"])

def sae_encode(x):
    pre = (x - b_dec) @ W_enc + b_enc
    topv, topi = torch.topk(pre, k, dim=-1)
    out = torch.zeros_like(pre)
    out.scatter_(-1, topi, topv)
    return torch.relu(out)

def sae_decode(f):
    return f @ W_dec + b_dec

# Extract layer-21 residual via output_hidden_states (no hooks needed)
q = "What is 127 + 248?"
ids = build_prompt(q)
with torch.inference_mode():
    out = model(input_ids=ids, output_hidden_states=True,
                use_cache=False, return_dict=True)
h21 = out.hidden_states[22].float().squeeze(0)   # layer 21 = index 22

features = sae_encode(h21)
print(features.shape)                          # [seq_len, 32768]
print((features > 0).sum(-1).float().mean())   # ≈ 128: structural L0 (slightly lower if the ReLU zeroes a negative top-k value)

With activation steering (CAA at L27)

# For the peak layer L27, register a hook that adds a scaled steering vector:
target = model.get_submodule("model.language_model.layers.27")
steering_vec = torch.load("steering_vec_L27.pt").to("cuda").to(torch.bfloat16)  # unit-normalized
current_scale = {"v": 0.0}

def caa_hook(module, inputs, output):
    if current_scale["v"] == 0.0:
        return output
    hidden, rest = (output[0], output[1:]) if isinstance(output, tuple) else (output, None)
    new_hidden = hidden + current_scale["v"] * steering_vec
    return (new_hidden, *rest) if rest is not None else new_hidden

handle = target.register_forward_hook(caa_hook)
try:
    current_scale["v"] = 10.0   # +10 ≈ 10.5 % of the mean L27 residual norm (finished at 0 pp vs baseline at n = 50)
    # ... run model.generate(...) as usual ...
finally:
    handle.remove()             # always detach the hook, even if generation raises

The steering_vec_L27.pt artifact will be uploaded alongside the SAE weights; see contrastive_features_L27.json (also pending).


🔧 Why a custom script?

Gemma 4 E4B's hybrid ensemble-MoE architecture is not supported by TransformerLens or sae_lens, which target dense decoder architectures. The training script (mechreward/scripts/train_sae_gemma4.py) uses only HuggingFace's built-in output_hidden_states=True and forward hooks, so it ports trivially to any HF-compatible model.


⚠️ Limitations

  • google/gemma-4-e4b is a gated model. You will need HF access to download the base model weights.
  • GSM8K baseline is low (~45 %). This is likely a consequence of our manual chat template + enable_thinking=False configuration and does not reflect the model's true capability under its production prompting setup. The low baseline is useful for steering research (more headroom) but should not be cited as a capability benchmark.
  • CAA sample size is n = 50. The +10 scale showed +6 to +13 pp across 6+ intermediate checkpoints (cleaner than any equivalent intermediate signal on Qwen) but finished at 0 pp at n = 50; a 200-question run would be needed to tighten the confidence interval enough to separate a directional effect from noise.
  • SAE is at L21, not L27. The contrastive-effect peak is at L27. Future SAE training for Gemma 4 E4B reasoning research should target L27 directly.
  • Single layer. Only L21 is released. Other depths would require retraining.
  • No feature catalog yet. Unlike the Qwen companion, this SAE does not (yet) ship a ReasonScore-based feature labeling. Contrastive discovery results will be added as a follow-up artifact.

🧪 Companion release

  • caiovicentino1/Qwen3.5-4B-SAE-L18-topk: Qwen3.5-4B TopK SAE at layer 18. First public SAE for hybrid Gated Delta Networks. Ships with a ReasonScore feature catalog, activation-steering causal validation, and a Stage Gate 1 correlation test (ρ = 0.540, p < 0.0001) showing that the contrastive SAE features predict GSM8K correctness, supporting the mechreward thesis of using SAE features as a dense reward signal for RL fine-tuning rather than for test-time steering.

Both SAEs released under Apache-2.0 for reproducible cross-architecture interpretability research.


📚 Citation

@article{gao2024scaling,
  title   = {Scaling and evaluating sparse autoencoders},
  author  = {Gao, Leo and Dupr{\'e} la Tour, Tom and Tillman, Henk and
             Goh, Gabriel and Troll, Rajan and Radford, Alec and
             Sutskever, Ilya and Leike, Jan and Wu, Jeffrey},
  journal = {arXiv preprint arXiv:2406.04093},
  year    = {2024}
}

📄 License

Apache-2.0 β€” same as the base model google/gemma-4-e4b.
