
gemma-4-12B-it — FP8 + DSpark speculative decoding

A load-and-go FP8 quantization of google/gemma-4-12B-it, plus a reproducible recipe for running it with DeepSeek DSpark speculative decoding (draft head: deepseek-ai/dspark_gemma4_12b_block7) — on a single 32 GB Blackwell GPU (validated on an RTX 5090).

Two recipes are documented here, both starting from the same FP8 target:

| recipe | speed | context | notes |
|---|---|---|---|
| Fast (torch.compile max-autotune) | 150 tok/s on code (2× a plain bf16 12B) | ~32 k | short/medium chat + code |
| Long-context (windowed KV cache) | ~40–55 tok/s | 128 k = 26.6 GB, 256 k = 28.7 GB, in-VRAM | full 256 k on 32 GB |

The weights in this repo are the FP8 target. The DSpark draft, DeepSpec loop, and the two recipes are described below (code in recipe/).

The FP8 quantization (this repo's weights)

torchao dynamic-activation / dynamic-weight float8, per-row, forced onto torch's native _scaled_mm kernel. On Blackwell (sm_120) _scaled_mm fp8 matmul is ~2.5× a bf16 matmul; the default KernelPreference.AUTO instead tries a cutlass kernel that doesn't load on sm_120/py3.12 and silently falls back to a slow dequant path — so KernelPreference.TORCH is essential.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load-and-go: the fp8 quantization_config is baked into config.json — no config needed at load.
model = AutoModelForCausalLM.from_pretrained(
    "skibare87/gemma-4-12B-it-FP8-DSpark",
    dtype=torch.bfloat16, device_map="cuda", attn_implementation="sdpa",
).eval()
tok = AutoTokenizer.from_pretrained("skibare87/gemma-4-12B-it-FP8-DSpark")

To re-quantize google/gemma-4-12B-it yourself instead of using these weights:

import torch  # needed for torch.bfloat16 below
from transformers import AutoModelForCausalLM, TorchAoConfig
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow
from torchao.quantization.quantize_.common.kernel_preference import KernelPreference

cfg = Float8DynamicActivationFloat8WeightConfig(
    granularity=PerRow(), kernel_preference=KernelPreference.TORCH,  # native _scaled_mm, not AUTO
)
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-4-12B-it", quantization_config=TorchAoConfig(cfg),
    dtype=torch.bfloat16, device_map="cuda", attn_implementation="sdpa",
).eval()
# model.save_pretrained("gemma-4-12B-it-FP8")  # <- produces the checkpoint in this repo

FP8 target ≈ 13 GB (vs ~24 GB bf16). Draft head ≈ 7 GB.
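The bf16 figure follows directly from the parameter count; a quick back-of-envelope check, assuming a round 12e9 parameters (the exact count is not stated here). The gap between the 12 GB this gives for fp8 and the ~13 GB observed is presumably per-row scales plus any modules left unquantized, but that breakdown is an assumption, not a measurement.

```python
# Weight-only memory estimate. PARAMS is an assumption: "12B" taken literally.
PARAMS = 12e9


def weight_gb(params: float, bytes_per_param: float) -> float:
    """Decimal gigabytes needed to hold the weights alone (no KV cache, no activations)."""
    return params * bytes_per_param / 1e9


bf16_gb = weight_gb(PARAMS, 2)  # bfloat16 stores 2 bytes per parameter
fp8_gb = weight_gb(PARAMS, 1)   # float8 stores 1 byte per parameter

print(f"bf16 weights: {bf16_gb:.0f} GB, fp8 weights: {fp8_gb:.0f} GB (before scales and unquantized modules)")
```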

Recipe 1 — Fast (torch.compile max-autotune), ~150 tok/s on code

DSpark verifies the draft's proposals against the FP8 target. Speculation shines where the draft is predictable: on code, accept-length ≈ 5 and ~150 tok/s (136–171 measured, ~2× a plain 12B); on open-ended prose, accept-length ≈ 2.5 (90 tok/s). The speed comes from torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True) on the target, which fuses the fp8 activation quantization and _scaled_mm into Triton kernels.

  • Compile is slow (30 min cold). Persist it: set TORCHINDUCTOR_CACHE_DIR off /tmp, and use torch 2.11's mega-cache (torch.compiler.save_cache_artifacts() / load_cache_artifacts()) so restarts are a cache hit (4 min) instead of a recompile. torch._dynamo.config.caching_precompile does not work with torchao fp8 (it can't serialize Float8Tensor guards).
  • (AOTInductor gives ~21 s startup / ~173 tok/s but its static cache caps context at ~32 k on 32 GB — fine for short context, superseded by Recipe 2 for long context.)
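A minimal sketch of the cache persistence described above, not the code shipped in recipe/. The directory layout and the mega_cache.bin file name are hypothetical; TORCHINDUCTOR_CACHE_DIR, torch.compiler.save_cache_artifacts() and torch.compiler.load_cache_artifacts() are the real knobs. torch is imported lazily so the path helper can be used on its own.

```python
import os
from pathlib import Path


def configure_compile_cache(root: Path) -> Path:
    """Point Inductor's on-disk cache at a persistent directory.

    Returns the path of the mega-cache blob (file name is hypothetical).
    Call this before the first torch.compile in the process.
    """
    root.mkdir(parents=True, exist_ok=True)
    os.environ["TORCHINDUCTOR_CACHE_DIR"] = str(root / "inductor")
    return root / "mega_cache.bin"


def load_mega_cache(blob: Path) -> bool:
    """Pre-populate the compile caches from a previous run; False on a cold start."""
    import torch  # lazy import: only needed when a cache is actually loaded

    if not blob.exists():
        return False
    torch.compiler.load_cache_artifacts(blob.read_bytes())
    return True


def save_mega_cache(blob: Path) -> bool:
    """Serialize what this process compiled; False if nothing was compiled."""
    import torch

    artifacts = torch.compiler.save_cache_artifacts()
    if artifacts is None:
        return False
    data, _info = artifacts  # (bytes, CacheInfo)
    blob.write_bytes(data)
    return True
```

The intended order would be configure_compile_cache, then load_mega_cache, then torch.compile and a warm-up generation, then save_mega_cache, so that the next start finds the blob.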

Recipe 2 — Long-context (128 k–256 k in-VRAM on 32 GB)

gemma-4-12B has 40 sliding-attention layers (window 1024) + 8 full-attention layers. A plain DynamicCache stores the sliding layers full-length, so 256 k KV would be 90 GB. Windowing the sliding layers brings 256 k KV down to 5 GB. The pieces (patches in recipe/, applied to a DeepSpec checkout):

  1. windowed_cache.py / SpecSlidingLayer: a crop-safe sliding cache (stores window + pad so a speculative reject never eats into the real window; get_mask_sizes reports the true stored length so gemma's sliding mask stays aligned). Validated logit-exact vs the full forward past 1024 tokens, through crop cycles.
  2. Chunked prefill (base_evaluator.patch): prefill long prompts in chunks and keep only the draft's target hidden-state layers, with a rolling window so target_hidden_states never materializes full-length. This was the real memory lever (128 k: 38.7 → 26.6 GB).
  3. Windowed draft context (evaluator.patch): the draft only proposes, so its context is windowed (DSPARK_DRAFT_CTX_WINDOW, default 16384); a cumulative-offset trick keeps absolute positions correct.
  4. Efficient SDPA backend: the 8 full-attn layers have head_dim=512; flash-attn caps at 256, so force torch.nn.attention.sdpa_kernel([EFFICIENT_ATTENTION, MATH]) — the math backend uses 32 GB for one such attention, efficient uses 4.4 GB.
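To make the "window + pad" idea from step 1 concrete, here is a toy model of it in plain Python. It is an illustration under stated assumptions, not SpecSlidingLayer itself: the class name CropSafeWindow is hypothetical, and integers stand in for per-token K/V tensors.

```python
class CropSafeWindow:
    """Keeps the most recent `window + pad` items, so cropping up to `pad`
    rejected items still leaves a full `window` of real history."""

    def __init__(self, window: int, pad: int):
        self.window = window
        self.pad = pad
        self.items: list[int] = []  # stand-in for cached per-token K/V
        self.total = 0              # absolute count of committed tokens

    def append(self, new_items: list[int]) -> None:
        self.items.extend(new_items)
        self.total += len(new_items)
        # Evict only what exceeds window + pad, never down to the bare window.
        overflow = len(self.items) - (self.window + self.pad)
        if overflow > 0:
            del self.items[:overflow]

    def crop(self, n_rejected: int) -> None:
        """Drop the last n_rejected speculative tokens after verification."""
        if n_rejected > min(self.pad, len(self.items)):
            raise ValueError("cannot reject more than pad (or more than is stored)")
        if n_rejected:
            del self.items[-n_rejected:]
            self.total -= n_rejected

    def stored_length(self) -> int:
        """Number of positions actually stored (what a mask should be sized from)."""
        return len(self.items)


cache = CropSafeWindow(window=4, pad=2)
cache.append(list(range(10)))  # tokens 0..9, the last two are speculative
cache.crop(2)                  # the verifier rejects tokens 8 and 9
print(cache.items[-cache.window:])  # [4, 5, 6, 7]
```

With a bare window of 4 and no pad, the same sequence would leave only tokens 6 and 7 after the reject, which is the failure the pad is there to prevent.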

Measured on the RTX 5090: 128 k = 26.6 GB / 56 s, 256 k = 28.7 GB / 181 s, both fully in-VRAM.
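The memory saving from chunked prefill (step 2) comes from never holding one hidden state per prompt token. A toy sketch of that control flow, assuming a stateless forward for simplicity; in the real loop the target's KV cache carries context between chunks, and the function and parameter names here are hypothetical.

```python
from typing import Callable, Iterator


def iter_chunks(n_tokens: int, chunk: int) -> Iterator[tuple[int, int]]:
    """Yield [start, end) ranges that cover the prompt in order."""
    for start in range(0, n_tokens, chunk):
        yield start, min(start + chunk, n_tokens)


def chunked_prefill(
    tokens: list[int],
    chunk: int,
    keep_last: int,
    forward: Callable[[list[int]], list[float]],
) -> list[float]:
    """Run `forward` chunk by chunk, retaining only the last `keep_last`
    per-token outputs instead of one output per prompt token."""
    rolling: list[float] = []
    for start, end in iter_chunks(len(tokens), chunk):
        rolling.extend(forward(tokens[start:end]))
        # Trim after every chunk so the buffer peaks at chunk + keep_last.
        overflow = len(rolling) - keep_last
        if overflow > 0:
            del rolling[:overflow]
    return rolling


# Toy forward: one float "hidden state" per token.
hidden = chunked_prefill(list(range(10)), chunk=4, keep_last=3,
                         forward=lambda xs: [float(x) for x in xs])
print(hidden)  # [7.0, 8.0, 9.0]
```

The trim uses an explicit overflow count rather than a negative slice, so keep_last=0 behaves as expected.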

Heads-up: a transformers bug you'll hit at long context

DynamicCache(config=...) (and get_head_shapes) do layer_types[:-num_kv_shared_layers]; gemma-4-12B has num_kv_shared_layers = 0, so [:-0] is an empty list and the sliding-window cache layers are silently never created (everything becomes full-storage). The workaround (build the cache layers manually) is in windowed_cache.py; details + a minimal repro in recipe/transformers-num-kv-shared-layers-bug.md.
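The slicing pitfall itself can be reproduced without transformers. The layer pattern below is illustrative only: it matches the 40 sliding / 8 full split described above, but the 5:1 interleaving and the layer-type strings are assumptions, and the guarded slice is one possible fix rather than the upstream one.

```python
# Hypothetical 48-layer pattern: 40 sliding-attention and 8 full-attention layers.
layer_types = (["sliding_attention"] * 5 + ["full_attention"]) * 8
num_kv_shared_layers = 0

# -0 == 0, so [:-0] is [:0]: an empty list, not "everything".
buggy = layer_types[:-num_kv_shared_layers]

# Computing the end index explicitly is safe for any count, including 0.
fixed = layer_types[: len(layer_types) - num_kv_shared_layers]

print(len(buggy), len(fixed))  # 0 48
```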

Serving it

recipe/server.py is a self-contained OpenAI-compatible /v1/chat/completions shim wrapping DeepSpec's Gemma4DSparkEvaluator, with env knobs for both recipes (DSPARK_COMPILE=1 → fast path; DSPARK_DRAFT_CTX_WINDOW / DSPARK_PREFILL_CHUNK → long context).

It supports real token-by-token streaming (a stream_callback added to DSpark's generate loop pushes accepted tokens as speculation commits them — see the patches) and thinking: gemma4 reasons in a <|channel>thought … <channel|> channel, which the shim exposes as OpenAI-style reasoning_content (streamed separately from the answer content; toggle with DSPARK_THINKING=0). Works through a LiteLLM gateway into Open WebUI.

Vision — gemma-4-12B is a VLM, and speculation carries images. Send OpenAI multimodal content (image_url) and the shim runs it through the processor, then passes pixel_values (+ mm_token_type_ids, image_position_ids) into the DSpark prefill via a prefill_mm param threaded through generate_decoding_sample. The target embeds the image, the KV cache carries it, and the draft speculates over image-aware hidden states — vision on the same speculative loop as text (accept-len ~3.2, ~32 tok/s), and it reasons about the image when thinking is on.
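A request body for the vision path might be built as follows. The message layout is the standard OpenAI multimodal format; whether the shim accepts base64 data URLs (as opposed to http URLs) and which model name it expects are assumptions here, so treat both as placeholders.

```python
import base64


def image_chat_request(
    prompt: str,
    image_bytes: bytes,
    mime: str = "image/png",
    model: str = "skibare87/gemma-4-12B-it-FP8-DSpark",  # assumed model name
) -> dict:
    """Build an OpenAI-style chat request carrying one text part and one image part."""
    encoded = base64.b64encode(image_bytes).decode("ascii")
    data_url = f"data:{mime};base64,{encoded}"
    return {
        "model": model,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image_url", "image_url": {"url": data_url}},
                ],
            }
        ],
    }


request = image_chat_request("What is in this picture?", b"\x89PNG placeholder bytes")
print(request["messages"][0]["content"][1]["type"])  # image_url
```

The resulting dict would be JSON-encoded and POSTed to /v1/chat/completions on whatever host and port the shim is bound to.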

Gotchas

  • Use the -it (instruct) variant. The base pattern-completes and never stops; -it ships its chat template and eos_token_id: [1, 106, 50]. gemma-4 uses a <|channel|>/<|think|> (harmony-style) format, not <start_of_turn>.
  • Validated stack: torch 2.11 + torchao 0.17 + transformers 5.x, CUDA 12.8, RTX 5090 (sm_120), WSL2.
  • attn_implementation="sdpa" (not flash — head_dim 512).

Attribution & license

Recipe assembled while getting DSpark + gemma-4-12B running at long context on a single 5090; shared so others don't have to rediscover the Blackwell fp8 kernel choice, the mega-cache persistence, or the sliding-cache/num_kv_shared_layers interactions.
