Small patch to fix MTP on current vLLM (8th of June, 26)
I tested the Canada W4A16/FP8 DeepSeek-V4-Flash-MTP checkpoint on current upstream vLLM.
The issue is that the target model is compressed/FP8, but the MTP draft block still has BF16/unquantized tensors such as
mtp.0.*. vLLM’s DeepSeek V4 NVIDIA O-proj path currently assumes wo_a.weight_scale_inv exists, so MTP startup fails for
this mixed checkpoint.
A minimal runtime fix is to add a BF16 fallback in vllm/models/deepseek_v4/nvidia/ops/o_proj.py: if wo_a has no FP8
scale tensor, run inverse RoPE in BF16, reshape the flat grouped wo_a.weight, apply grouped matmul/einsum, then continue
into wo_b.
Pseudo-patch shape:

    def get_fp8_weight_scale(layer):
        if hasattr(layer, "weight_scale_inv"):
            return layer.weight_scale_inv
        if hasattr(layer, "weight_scale"):
            return layer.weight_scale
        return None

    def deep_gemm_fp8_o_proj(...):
        weight_scale = get_fp8_weight_scale(wo_a)
        if weight_scale is None:
            # BF16/unquantized MTP fallback:
            # 1. reshape o: [T, H, D] -> [T, n_groups, heads_per_group, D]
            # 2. apply inverse RoPE to the rope slice
            # 3. reshape to grouped wo_a input
            # 4. if wo_a.weight is flat grouped BF16, reshape:
            #      [groups * o_lora_rank, input_size]
            #      -> [groups, o_lora_rank, input_size]
            # 5. grouped matmul:
            #      z = einsum("bgi,gri->bgr", wo_a_input, grouped_weight)
            # 6. return wo_b(z.flatten(1))
            return wo_b(z.flatten(1))

        # existing FP8 path
        fp8_einsum(
            "bhr,hdr->bhd",
            (o_fp8, o_scale),
            (wo_a.weight, weight_scale),
            z,
            recipe=einsum_recipe,
        )
        return wo_b(z.flatten(1))

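To make steps 4 and 5 of the fallback concrete, here is a minimal NumPy sketch of the flat-to-grouped weight reshape and the grouped matmul. It is only an illustration: float32 NumPy arrays stand in for BF16 torch tensors, the inverse-RoPE step is omitted, and the helper name `grouped_wo_a` and all sizes are assumptions, not vLLM code.

```python
import numpy as np


def grouped_wo_a(wo_a_input, flat_weight, n_groups):
    """Apply a grouped projection whose weight is stored as one flat 2-D matrix.

    wo_a_input:  [T, n_groups, input_size]
    flat_weight: [n_groups * o_lora_rank, input_size]
    returns:     [T, n_groups * o_lora_rank]
    """
    rows, input_size = flat_weight.shape
    assert rows % n_groups == 0, "flat weight rows must split evenly into groups"
    o_lora_rank = rows // n_groups
    # Step 4: [groups * o_lora_rank, input_size] -> [groups, o_lora_rank, input_size]
    grouped_weight = flat_weight.reshape(n_groups, o_lora_rank, input_size)
    # Step 5: each group multiplies only its own slice of the input.
    z = np.einsum("bgi,gri->bgr", wo_a_input, grouped_weight)
    # Flatten the group axis so the result can feed the next projection.
    return z.reshape(z.shape[0], -1)


rng = np.random.default_rng(0)
T, G, R, I = 3, 2, 4, 8  # tokens, groups, o_lora_rank, per-group input size
x = rng.standard_normal((T, G, I)).astype(np.float32)
w = rng.standard_normal((G * R, I)).astype(np.float32)
out = grouped_wo_a(x, w, n_groups=G)

# Reference: one ordinary matmul per group, concatenated along the last axis.
ref = np.concatenate(
    [x[:, g] @ w[g * R:(g + 1) * R].T for g in range(G)], axis=1
)
```

The per-group reference at the end is the check I would want in a unit test: the einsum result should equal independent matmuls over each group's slice of the flat weight.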
Measured on a GH200-style dual-GPU system with TP2, concurrency 1, prompt length 8192, output length 1024, 5 prompts,
current upstream vLLM:
Canada DeepSeek-V4-Flash-W4A16-FP8-MTP
MTP0: 105.9 tok/s, mean TPOT 9.11 ms
MTP1: 152.7 tok/s, mean TPOT 6.20 ms, acceptance 90.8%
MTP2: 179.9 tok/s, mean TPOT 5.18 ms, acceptance 81.9%
MTP3: 193.0 tok/s, mean TPOT 4.82 ms, acceptance 71.1%
MTP4: 162.1 tok/s, mean TPOT 5.81 ms, acceptance 47.9%
So the existing Canada checkpoint can use MTP without re-quantizing the MTP tensors if vLLM supports a BF16 fallback for
the unquantized draft block. Best result in this test was MTP3 at ~193 tok/s versus ~106 tok/s without MTP.
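For reference, the speedups can be recomputed from the numbers above. The tokens-per-step estimate in this sketch assumes that "acceptance" means accepted draft tokens divided by proposed draft tokens; that definition is an assumption about the harness, so treat that column as a rough guide only.

```python
# Throughput (tok/s) and acceptance per MTP depth, copied from the results above.
results = {
    0: (105.9, None),
    1: (152.7, 0.908),
    2: (179.9, 0.819),
    3: (193.0, 0.711),
    4: (162.1, 0.479),
}

baseline = results[0][0]


def speedup(k):
    """Throughput at depth k relative to the MTP0 run."""
    return results[k][0] / baseline


def mean_tokens_per_step(k):
    """ASSUMPTION: acceptance = accepted draft tokens / proposed draft tokens."""
    acceptance = results[k][1]
    return 1.0 if acceptance is None else 1.0 + k * acceptance


for k in sorted(results):
    print(f"MTP{k}: {speedup(k):.2f}x, ~{mean_tokens_per_step(k):.2f} tokens/step")

best_k = max(results, key=speedup)
```

Under that assumption, the estimated tokens per step is lower at MTP4 than at MTP3 while each step drafts one more token, which would be consistent with MTP4 regressing in throughput.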
Longer-term, the cleaner artifact-side fix would probably be to quantize the BF16 mtp.0.* tensors into the same
compressed-tensors format as the rest of the checkpoint and remove the MTP ignore rule. But the runtime fallback is
enough to prove the checkpoint’s MTP weights are usable.
@dnhkng — thank you, this is great. Independent confirmation + GH200 numbers + a cleaner upstream patch shape than what we'd been carrying. A few notes on adjacent work happening in parallel:
Artifact-side fix already shipped (for the MTP repo)
This week we tracked down two issues @ajdoosh and @yangsiqt2 filed against the W4A16-FP8-MTP sibling:
1. The `ignore` list in `quantization_config` was using on-disk safetensors names (`layers.X.attn.compressor.fused_wkv_wgate`, `…shared_experts.w1`), but stock vLLM matches against runtime fused module names (`model.layers.X.attn.compressor.fused_wkv_wgate`, `…shared_experts.gate_up_proj`). The on-disk names never matched at runtime → load failure on the main model before the MTP path even runs.
2. The MTP draft `e_proj`/`h_proj` in `DeepSeekV4MultiTokenPredictorLayer.__init__` were constructed without `prefix=` → empty `layer_name` → compressed-tensors raises `Unable to find matching target for in...` with a blank module name.
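The second failure is easy to reproduce in a toy model. The sketch below is not vLLM or compressed-tensors code: `ToyLinear`, `ToyDraftLayer`, `match_target`, the `mtp.0` prefix, and the suffix-matching rule are all hypothetical stand-ins, chosen only to show why an unpropagated prefix leaves the matcher with a blank name.

```python
class ToyLinear:
    """Stand-in for a quantizable linear layer (hypothetical, not a vLLM class)."""

    def __init__(self, prefix=""):
        self.layer_name = prefix


def match_target(layer_name, targets):
    """Toy matcher: a layer matches a target if its name ends with that target."""
    for target in targets:
        if layer_name.endswith(target):
            return target
    raise ValueError(f"Unable to find matching target for {layer_name!r}")


class ToyDraftLayer:
    """Builds e_proj/h_proj with or without passing the parent prefix down."""

    def __init__(self, prefix, propagate_prefix):
        def child(name):
            return f"{prefix}.{name}" if propagate_prefix else ""

        self.e_proj = ToyLinear(prefix=child("e_proj"))
        self.h_proj = ToyLinear(prefix=child("h_proj"))


targets = ["e_proj", "h_proj"]  # illustrative target names

fixed = ToyDraftLayer("mtp.0", propagate_prefix=True)
broken = ToyDraftLayer("mtp.0", propagate_prefix=False)

matched = match_target(fixed.e_proj.layer_name, targets)

try:
    match_target(broken.e_proj.layer_name, targets)
    broken_error = None
except ValueError as exc:
    broken_error = str(exc)
```

With the prefix passed down, the child has a usable name and matches; without it, the name is empty and the error message ends in a blank module name, which is the symptom described in item 2.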
Fixes:
- Artifact: regex-based ignore patterns on branch `fix-ignore-runtime-names` (commit `289254e7`); metadata only, no re-quantization.
- vLLM: 6-line `prefix=` propagation in `mtp.py` across nvidia/amd/xpu, open as vllm-project/vllm#44837; references #44817 and @yangsiqt2's #43893.
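To illustrate why a regex-based ignore entry survives the on-disk vs. runtime naming difference, here is a small sketch using Python's `re`. The pattern is an example written for this comment, not the contents of the `fix-ignore-runtime-names` branch, and layer index 0 is substituted for `X` purely for illustration.

```python
import re

on_disk_name = "layers.0.attn.compressor.fused_wkv_wgate"
runtime_name = "model.layers.0.attn.compressor.fused_wkv_wgate"

# Exact-name ignore list written with on-disk names: misses the runtime module.
exact_ignore = {on_disk_name}

# Hypothetical regex pattern that tolerates the optional "model." prefix
# and any layer index.
pattern = re.compile(r"(model\.)?layers\.\d+\.attn\.compressor\.fused_wkv_wgate")


def ignored_exact(name):
    return name in exact_ignore


def ignored_regex(name):
    return pattern.fullmatch(name) is not None


print(ignored_exact(runtime_name))  # exact list does not cover the runtime name
print(ignored_regex(runtime_name))  # regex covers it
```

The same idea applies to renamed fused modules such as the shared-expert projections, where the pattern would need to name the runtime module rather than the on-disk tensor.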
Your patch is the third piece needed — once construction passes, the BF16 wo_a forward path you described in nvidia/ops/o_proj.py is what carries through. All three together are needed for end-to-end MTP on stock vLLM; missing any one of them fails at a different stage.
Hardware coverage offer
We're running an end-to-end validation right now on 2× DGX Spark (GB10, SM 12.1a, ARM64/CUDA 13.2) with TP=2, with all three fixes applied. Currently in torch.compile post weight-load. If your nvidia/ops/o_proj.py patch lands upstream as a PR, happy to be a hardware co-validator on GB10/sm_12.1a to complement your GH200 (Hopper) coverage — that'd give k=1..4 numbers across two NVIDIA generations.
For the canada-quant/dsv4 image lineage we're carrying the same fix as an old-layout patch in scripts/patch_sm12_full_mtp.py since the production image is still on the jasl pre-refactor file paths (vllm/model_executor/layers/deepseek_v4_attention.py). When upstream lands the o_proj fix we can drop our local patch.
Re: quantizing the BF16 MTP block
That's been discussed internally too — agreed it's the cleaner long-term shape. But it requires a re-quantization pass over an already-shipped artifact, so the runtime fallback is the right immediate move. Your numbers (k=3 at 1.82×) are also strong enough that BF16-MTP-on-FP8-main may be a permanently useful pattern, not just a stopgap.
Going to post our GB10 numbers here once the run lands. Will tag the upstream PR with a link to your comment as the original analysis.
Hi @pastapaul and @ajdoosh , thanks for sharing the benchmark results.
For context, I have been investigating the related vLLM issue here:
https://github.com/vllm-project/vllm/issues/43457
My current work mainly includes:
- Reproduced the DeepSeek-V4 MTP path in vLLM.
- Investigated the `num_speculative_tokens=2 -> next_n=3` path.
- Built a prototype patch to enable/validate the SM90 DeepGEMM paged MQA logits path for `next_n=3`.
- Ran standalone DeepGEMM tests and vLLM end-to-end benchmarks for baseline / MTP k=1 / MTP k=2.
- Also looked into the MTP draft layer quant_config / prefix handling issue in vLLM.
In my end-to-end benchmark, I observed lower MTP acceptance rates than the numbers shared here:
| Case | Output tok/s | Acceptance | Speedup vs baseline |
|---|---|---|---|
| baseline | 89.94 | N/A | 1.00x |
| mtp_k1 | 149.41 | 81.3% | 1.66x |
| mtp_k2 | 155.91 | 63.0% | 1.73x |
My benchmark setup used streaming /v1/completions, sequential execution, 64 prompts, max_tokens=256, temperature=0, ignore_eos=true, and short mixed technical/code/reasoning prompts with mean prompt length around 29.8 words.
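For anyone trying to match this setup, the per-request settings above would look roughly like the following request body. This is a sketch, not my actual harness: the model id is a placeholder, and `ignore_eos` is a vLLM extension to the OpenAI completions schema.

```python
import json


def build_completion_request(prompt, model="MODEL_ID"):
    """Body for a streaming POST to /v1/completions with the settings above."""
    return {
        "model": model,      # placeholder, not a real model id
        "prompt": prompt,
        "max_tokens": 256,
        "temperature": 0,
        "ignore_eos": True,  # vLLM extension: keep generating up to max_tokens
        "stream": True,
    }


request = build_completion_request("Explain speculative decoding briefly.")
body = json.dumps(request)
```

Requests are sent one at a time (sequential execution), so each measurement reflects single-stream decode speed rather than batched throughput.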
Compared with the benchmark here, one major difference seems to be prompt/output length. Your setup uses prompt length 8192 and output length 1024, while mine uses much shorter mixed prompts and shorter outputs. My guess is that long-context prompts may produce a more stable continuation distribution, which could explain the higher MTP1/MTP2 acceptance rates.
Do you think prompt/output length is the main reason for the acceptance gap, or should I also check other factors such as vLLM commit, model artifact revision, TP size, MTP draft layer quant_config handling, ignore_eos, or prompt formatting?
I would be happy to rerun a matched benchmark if you can share more details about the exact benchmark script/config.
Thanks, this is very useful context.
I agree with the three-stage breakdown:
1. Artifact metadata fix: the ignore patterns need to match vLLM runtime module names, not just safetensors names.
2. vLLM construction fix: MTP `e_proj`/`h_proj` need `prefix=` propagation so compressed-tensors can match the draft modules.
3. vLLM execution fix: once the MTP draft block constructs, the NVIDIA O-proj path needs a BF16/unquantized fallback for `wo_a` when no FP8 scale tensor is present.
My local patch was only addressing stage 3, plus a small hf_config_path propagation fix for the speculative draft config. Good to know the
artifact-side metadata fix and the prefix= PR are already moving.
Re: acceptance rates, I agree the benchmark shape is probably a major factor. My run used long random prompts and long generations:
- prompt length: 8192
- output length: 1024
- prompts: 5
- concurrency: 1
- TP: 2
- temperature: 0
- max_model_len: 32768
- max_num_batched_tokens: 8192
- max_num_seqs: 1
Main flags/env were:
--distributed-executor-backend mp
--tensor-parallel-size 2
--disable-custom-all-reduce
--block-size 256
--gpu-memory-utilization 0.97
--kv-cache-dtype fp8
--generation-config vllm
NCCL_P2P_DISABLE=1
VLLM_USE_FLASHINFER_SAMPLER=0
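As a rough guide, the flags, env vars, and limits listed above can be assembled into a launch command like this. It is a sketch under assumptions: the `vllm serve` entry point and the model placeholder are mine, and the speculative-decoding option is deliberately left out because it is not spelled out here.

```python
import shlex

MODEL = "MODEL_ID_OR_PATH"  # placeholder for the checkpoint id or local path

env = {"NCCL_P2P_DISABLE": "1", "VLLM_USE_FLASHINFER_SAMPLER": "0"}

args = [
    "vllm", "serve", MODEL,
    "--distributed-executor-backend", "mp",
    "--tensor-parallel-size", "2",
    "--disable-custom-all-reduce",
    "--block-size", "256",
    "--gpu-memory-utilization", "0.97",
    "--kv-cache-dtype", "fp8",
    "--generation-config", "vllm",
    # Limits from the benchmark shape listed above.
    "--max-model-len", "32768",
    "--max-num-batched-tokens", "8192",
    "--max-num-seqs", "1",
]

# Prefix the env assignments so the line can be pasted into a shell.
command = " ".join(f"{k}={v}" for k, v in env.items()) + " " + shlex.join(args)
print(command)
```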
The long-prompt/long-output shape may make the continuation distribution easier for the MTP draft to track, so I would not assume my acceptance rates
transfer directly to short mixed prompts. Your MTP1/MTP2 speedups still look directionally consistent, just with lower acceptance.
For reproducibility, I can share the benchmark harness and raw JSONs. The Canada results I measured on 2x GH200 TP=2 were:
MTP0: 105.9 tok/s, mean TPOT 9.11 ms
MTP1: 152.7 tok/s, mean TPOT 6.20 ms, acceptance 90.8%
MTP2: 179.9 tok/s, mean TPOT 5.18 ms, acceptance 81.9%
MTP3: 193.0 tok/s, mean TPOT 4.82 ms, acceptance 71.1%
MTP4: 162.1 tok/s, mean TPOT 5.81 ms, acceptance 47.9%
So I think the next useful step is a matched benchmark matrix:
- same vLLM commit
- same model artifact revision
- same three fixes applied
- same prompt/output shape
- same num_speculative_tokens
- then compare GH200 vs GB10
That should tell us how much of the acceptance/perf difference is workload shape versus hardware/backend behavior.
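A matched matrix along those lines could be enumerated like this. It is only a sketch: the field names are illustrative, and the commit, artifact revision, and patch set are assumed to be pinned identically outside this snippet.

```python
from itertools import product

hardware = ["GH200", "GB10"]
spec_tokens = [0, 1, 2, 3, 4]  # num_speculative_tokens; 0 is the no-MTP baseline
shape = {"prompt_len": 8192, "output_len": 1024, "concurrency": 1, "tp": 2}

# One run per (hardware, num_speculative_tokens) pair, all sharing one shape.
matrix = [
    {"hardware": hw, "num_speculative_tokens": k, **shape}
    for hw, k in product(hardware, spec_tokens)
]

for run in matrix:
    print(run)
```

Holding the shape fixed across both machines means any remaining acceptance or throughput gap can be attributed to hardware or backend behavior rather than workload.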
I’ll shape the O-proj change as a narrow upstream PR candidate and reference the metadata fix plus vLLM PR #44837 / issues #44817 and #43893.
Thanks, this breakdown makes sense to me.
I agree that a matched benchmark matrix would be the cleanest way to separate workload-shape effects from hardware/backend effects. For now, I probably won’t continue with additional acceptance-rate benchmarking, mainly because my current environment is not fully matched to your setup and I don’t want to over-interpret results from different prompt/output shapes, hardware, or patch combinations.
My current contribution is more focused on the reproduction and implementation-analysis side:
- reproduced the vLLM DeepSeek-V4 MTP path;
- investigated the `num_speculative_tokens=2 -> next_n=3` path;
- validated the SM90 DeepGEMM paged MQA logits `next_n=3` prototype path;
- checked the MTP draft-layer prefix / quant_config handling issue;
- collected one short-prompt end-to-end benchmark as supporting evidence, but not as a directly comparable acceptance benchmark against your long-context setup.
So I think I’ll treat my current benchmark numbers as directional only. Your long-prompt / long-output results are very helpful as a separate reference point.
For the O-proj change, I agree that shaping it as a narrow upstream PR candidate sounds like the right direction. If useful, I can help review the patch logic or compare it against the failure mode I saw in the next_n=3 path, but I may not be able to run a full matched benchmark matrix at this stage.
Thanks again for sharing the harness/config details. I’ll keep following the metadata fix, vLLM PR #44837, issues #44817 / #43893, and the O-proj fallback PR.
Quick update: I opened the upstream vLLM PR for the NVIDIA O-proj execution fallback:
https://github.com/vllm-project/vllm/pull/44847
Scope is deliberately narrow: it adds a BF16/unquantized fallback in vllm/models/deepseek_v4/nvidia/ops/o_proj.py for the case where the MTP draft
wo_a projection has no FP8 scale tensor, while preserving the existing FP8 path when weight_scale_inv / weight_scale exists.
So the three pieces are now:
1. Artifact-side metadata fix: runtime-name-compatible ignore patterns on the MTP artifact branch.
2. vLLM construction fix: `prefix=` propagation for MTP `e_proj`/`h_proj` in #44837.
3. vLLM execution fix: BF16 MTP O-proj fallback in #44847.
The PR includes focused unit tests for the scale detection, BF16 inverse-RoPE path, grouped wo_a.weight reshape, and the deep_gemm_fp8_o_proj
fallback branch.
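For readers who want a feel for the scale-detection case, a test in that spirit might look like the sketch below. It is not copied from the PR: the layers are `SimpleNamespace` stand-ins rather than real vLLM modules, and `choose_path` is a hypothetical helper that only names which branch would run.

```python
from types import SimpleNamespace


def get_fp8_weight_scale(layer):
    """Same lookup order as the pseudo-patch: weight_scale_inv, then weight_scale."""
    if hasattr(layer, "weight_scale_inv"):
        return layer.weight_scale_inv
    if hasattr(layer, "weight_scale"):
        return layer.weight_scale
    return None


def choose_path(layer):
    """Hypothetical helper: report which O-proj branch the layer would take."""
    return "bf16_fallback" if get_fp8_weight_scale(layer) is None else "fp8"


# Stand-in layers; strings take the place of tensors.
fp8_inv = SimpleNamespace(weight="w", weight_scale_inv="s_inv")
fp8_plain = SimpleNamespace(weight="w", weight_scale="s")
both = SimpleNamespace(weight="w", weight_scale_inv="s_inv", weight_scale="s")
bf16 = SimpleNamespace(weight="w")

for name, layer in [("fp8_inv", fp8_inv), ("fp8_plain", fp8_plain), ("both", both), ("bf16", bf16)]:
    print(name, choose_path(layer))
```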
Hi @pastapaul and @ajdoosh , one thing I want to clarify: in the original vLLM issue #43457, num_speculative_tokens=2 maps to next_n=3, which triggers the DeepGEMM Hopper paged_mqa_logits assertion (next_n == 1 or next_n == 2). So if your MTP2 result means num_speculative_tokens=2, did your benchmark include a DeepGEMM-side patch or any fallback that avoids the original SM90 next_n=3 assertion path? Or does MTP2 refer to a different setting in your benchmark?
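To spell out the concern, here is a tiny sketch of the mapping. The only facts taken from the issue are that `num_speculative_tokens=2` maps to `next_n=3` and that the assertion allows `next_n` of 1 or 2; the general rule `next_n = num_speculative_tokens + 1` is my assumption.

```python
SUPPORTED_NEXT_N = (1, 2)  # from the quoted assertion: next_n == 1 or next_n == 2


def next_n_for(num_speculative_tokens):
    """ASSUMPTION: next_n = num_speculative_tokens + 1 (generalizes k=2 -> next_n=3)."""
    return num_speculative_tokens + 1


def hits_assertion(num_speculative_tokens):
    """True if this depth would fall outside the supported next_n values."""
    return next_n_for(num_speculative_tokens) not in SUPPORTED_NEXT_N


for k in range(5):
    status = "assertion" if hits_assertion(k) else "ok"
    print(f"num_speculative_tokens={k} -> next_n={next_n_for(k)} ({status})")
```

If that rule holds, every depth from 2 upward would reach the same assertion on the unpatched SM90 path, which is why I am asking how the MTP2 to MTP4 runs got past it.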