Small patch to fix MTP on current vLLM (8th of June, 26)

#8
by dnhkng - opened

I tested the Canada W4A16/FP8 DeepSeek-V4-Flash-MTP checkpoint on current upstream vLLM.

The issue is that the target model is compressed/FP8, but the MTP draft block still has BF16/unquantized tensors such as
mtp.0.*. vLLM’s DeepSeek V4 NVIDIA O-proj path currently assumes wo_a.weight_scale_inv exists, so MTP startup fails for
this mixed checkpoint.

A minimal runtime fix is to add a BF16 fallback in vllm/models/deepseek_v4/nvidia/ops/o_proj.py: if wo_a has no FP8
scale tensor, run inverse RoPE in BF16, reshape the flat grouped wo_a.weight, apply grouped matmul/einsum, then continue
into wo_b.

Pseudo-patch shape:

  def get_fp8_weight_scale(layer):
      if hasattr(layer, "weight_scale_inv"):
          return layer.weight_scale_inv
      if hasattr(layer, "weight_scale"):
          return layer.weight_scale
      return None

  def deep_gemm_fp8_o_proj(...):
      weight_scale = get_fp8_weight_scale(wo_a)

      if weight_scale is None:
          # BF16/unquantized MTP fallback:
          # 1. reshape o: [T, H, D] -> [T, n_groups, heads_per_group, D]
          # 2. apply inverse RoPE to the rope slice
          # 3. reshape to grouped wo_a input
          # 4. if wo_a.weight is flat grouped BF16, reshape:
          #      [groups * o_lora_rank, input_size]
          #      -> [groups, o_lora_rank, input_size]
          # 5. grouped matmul:
          #      z = einsum("bgi,gri->bgr", wo_a_input, grouped_weight)
          # 6. flatten z and continue into wo_b
          return wo_b(z.flatten(1))

      # existing FP8 path
      fp8_einsum(
          "bhr,hdr->bhd",
          (o_fp8, o_scale),
          (wo_a.weight, weight_scale),
          z,
          recipe=einsum_recipe,
      )
      return wo_b(z.flatten(1))
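Steps 4 and 5 of the fallback can be checked on toy shapes without any GPU code. The sketch below is pure Python on nested lists, not the vLLM patch itself; the helper names reshape_flat_grouped and grouped_matmul are hypothetical, and real code would presumably use a tensor view plus torch.einsum with the same "bgi,gri->bgr" subscripts.

```python
def reshape_flat_grouped(weight, n_groups, rank):
    """[n_groups * rank, input_size] -> [n_groups, rank, input_size] (row-major)."""
    if len(weight) != n_groups * rank:
        raise ValueError("flat weight rows must equal n_groups * rank")
    return [weight[g * rank:(g + 1) * rank] for g in range(n_groups)]


def grouped_matmul(x, grouped_weight):
    """Pure-Python equivalent of einsum("bgi,gri->bgr", x, grouped_weight)."""
    return [
        [
            [sum(xi * wi for xi, wi in zip(x_bg, w_row)) for w_row in w_g]
            for x_bg, w_g in zip(x_b, grouped_weight)
        ]
        for x_b in x
    ]


# Toy shapes: 1 token, 2 groups, rank 2, input size 3.
flat_weight = [
    [1, 0, 0],  # group 0, rank row 0
    [0, 1, 0],  # group 0, rank row 1
    [0, 0, 1],  # group 1, rank row 0
    [1, 1, 1],  # group 1, rank row 1
]
x = [[[1, 2, 3], [4, 5, 6]]]  # [tokens, groups, input_size]

grouped = reshape_flat_grouped(flat_weight, n_groups=2, rank=2)
z = grouped_matmul(x, grouped)

# Equivalent of z.flatten(1): merge the group and rank axes per token.
z_flat = [[v for group in token for v in group] for token in z]
print(z_flat)  # [[1, 2, 6, 15]]
```

Each group only ever multiplies against its own slice of the flat weight, which is why the reshape has to happen before the matmul.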

Measured on a GH200-style dual-GPU system with TP2, concurrency 1, prompt length 8192, output length 1024, 5 prompts,
current upstream vLLM:

Canada DeepSeek-V4-Flash-W4A16-FP8-MTP

MTP0: 105.9 tok/s, mean TPOT 9.11 ms
MTP1: 152.7 tok/s, mean TPOT 6.20 ms, acceptance 90.8%
MTP2: 179.9 tok/s, mean TPOT 5.18 ms, acceptance 81.9%
MTP3: 193.0 tok/s, mean TPOT 4.82 ms, acceptance 71.1%
MTP4: 162.1 tok/s, mean TPOT 5.81 ms, acceptance 47.9%

So the existing Canada checkpoint can use MTP without re-quantizing the MTP tensors if vLLM supports a BF16 fallback for
the unquantized draft block. Best result in this test was MTP3 at ~193 tok/s versus ~106 tok/s without MTP.
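The relative gains follow directly from the throughput figures above. A small sketch of the arithmetic, using only the measured tok/s values (the variable names are illustrative):

```python
# Output throughput (tok/s) per MTP setting, copied from the results above.
throughput = {0: 105.9, 1: 152.7, 2: 179.9, 3: 193.0, 4: 162.1}

baseline = throughput[0]
speedup = {k: round(v / baseline, 2) for k, v in throughput.items()}
best_k = max(throughput, key=throughput.get)

for k, s in speedup.items():
    print(f"MTP{k}: {s:.2f}x")
print("best setting: MTP", best_k)
```

This gives roughly 1.44x, 1.70x, 1.82x and 1.53x for MTP1 through MTP4, so throughput peaks at MTP3 and drops again at MTP4.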

Longer-term, the cleaner artifact-side fix would probably be to quantize the BF16 mtp.0.* tensors into the same
compressed-tensors format as the rest of the checkpoint and remove the MTP ignore rule. But the runtime fallback is
enough to prove the checkpoint’s MTP weights are usable.

Canada Quant Labs org

@dnhkng — thank you, this is great. Independent confirmation + GH200 numbers + a cleaner upstream patch shape than what we'd been carrying. A few notes on adjacent work happening in parallel:

Artifact-side fix already shipped (for the MTP repo)

This week we tracked down two issues @ajdoosh and @yangsiqt2 filed against the W4A16-FP8-MTP sibling:

  1. The ignore list in quantization_config was using on-disk safetensors names (layers.X.attn.compressor.fused_wkv_wgate, …shared_experts.w1), but stock vLLM matches against runtime fused module names (model.layers.X.attn.compressor.fused_wkv_wgate, …shared_experts.gate_up_proj). The on-disk names never matched at runtime → load failure on the main model before the MTP path even runs.

  2. The MTP draft e_proj / h_proj in DeepSeekV4MultiTokenPredictorLayer.__init__ were constructed without prefix= → empty layer_name → compressed-tensors raises Unable to find matching target for in... with a blank module name.

Fixes:

Your patch is the third piece needed — once construction passes, the BF16 wo_a forward path you described in nvidia/ops/o_proj.py is what carries through. All three together are needed for end-to-end MTP on stock vLLM; missing any one of them fails at a different stage.
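The name mismatch in the first issue above can be illustrated with a toy matcher. This is a simplified sketch that assumes exact-name matching; the real matching logic in vLLM / compressed-tensors may be more elaborate, and the layer index 0 and the is_ignored helper are hypothetical.

```python
def is_ignored(module_name, ignore_list):
    """Simplified matcher: a module is skipped only on an exact name match."""
    return module_name in ignore_list


# Hypothetical layer index 0; the name shapes follow the two forms quoted above.
on_disk_ignore = ["layers.0.attn.compressor.fused_wkv_wgate"]
runtime_ignore = ["model.layers.0.attn.compressor.fused_wkv_wgate"]

# Name the module carries at runtime.
runtime_module = "model.layers.0.attn.compressor.fused_wkv_wgate"

print(is_ignored(runtime_module, on_disk_ignore))  # False: module would be treated as quantized
print(is_ignored(runtime_module, runtime_ignore))  # True: module is skipped as intended
```

Under that assumption, an ignore entry written with the on-disk name is silently inert, which matches the load failure described above.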

Hardware coverage offer

We're running an end-to-end validation right now on 2× DGX Spark (GB10, SM 12.1a, ARM64/CUDA 13.2) with TP=2 and all three fixes applied. The run is currently in the torch.compile stage, after weight load. If your nvidia/ops/o_proj.py patch lands upstream as a PR, we're happy to be a hardware co-validator on GB10/sm_12.1a to complement your GH200 (Hopper) coverage; that would give k=1..4 numbers across two NVIDIA generations.

For the canada-quant/dsv4 image lineage we're carrying the same fix as an old-layout patch in scripts/patch_sm12_full_mtp.py since the production image is still on the jasl pre-refactor file paths (vllm/model_executor/layers/deepseek_v4_attention.py). When upstream lands the o_proj fix we can drop our local patch.

Re: quantizing the BF16 MTP block

That's been discussed internally too — agreed it's the cleaner long-term shape. But it requires a re-quantization pass over an already-shipped artifact, so the runtime fallback is the right immediate move. Your numbers (k=3 at 1.82×) are also strong enough that BF16-MTP-on-FP8-main may be a permanently useful pattern, not just a stopgap.

Going to post our GB10 numbers here once the run lands. Will tag the upstream PR with a link to your comment as the original analysis.

Hi @pastapaul and @ajdoosh, thanks for sharing the benchmark results.

For context, I have been investigating the related vLLM issue here:
https://github.com/vllm-project/vllm/issues/43457

My current work mainly includes:

  • Reproduced the DeepSeek-V4 MTP path in vLLM.
  • Investigated the num_speculative_tokens=2 -> next_n=3 path.
  • Built a prototype patch to enable/validate the SM90 DeepGEMM paged MQA logits path for next_n=3.
  • Ran standalone DeepGEMM tests and vLLM end-to-end benchmarks for baseline / MTP k=1 / MTP k=2.
  • Also looked into the MTP draft layer quant_config / prefix handling issue in vLLM.

In my end-to-end benchmark, I observed lower MTP acceptance rates than the numbers shared here:

Case      Output tok/s  Acceptance  Speedup vs baseline
baseline  89.94         N/A         1.00x
mtp_k1    149.41        81.3%       1.66x
mtp_k2    155.91        63.0%       1.73x
My benchmark setup used streaming /v1/completions, sequential execution, 64 prompts, max_tokens=256, temperature=0, ignore_eos=true, and short mixed technical/code/reasoning prompts with mean prompt length around 29.8 words.

Compared with the benchmark here, one major difference seems to be prompt/output length. Your setup uses prompt length 8192 and output length 1024, while mine uses much shorter mixed prompts and shorter outputs. My guess is that long-context prompts may produce a more stable continuation distribution, which could explain the higher MTP1/MTP2 acceptance rates.

Do you think prompt/output length is the main reason for the acceptance gap, or should I also check other factors such as vLLM commit, model artifact revision, TP size, MTP draft layer quant_config handling, ignore_eos, or prompt formatting?

I would be happy to rerun a matched benchmark if you can share more details about the exact benchmark script/config.

Thanks, this is very useful context.

I agree with the three-stage breakdown:

  1. Artifact metadata fix:
    the ignore patterns need to match vLLM runtime module names, not just safetensors names.

  2. vLLM construction fix:
    MTP e_proj / h_proj need prefix= propagation so compressed-tensors can match the draft modules.

  3. vLLM execution fix:
    once the MTP draft block constructs, the NVIDIA O-proj path needs a BF16/unquantized fallback for wo_a when no FP8 scale tensor is present.

My local patch was only addressing stage 3, plus a small hf_config_path propagation fix for the speculative draft config. Good to know the
artifact-side metadata fix and the prefix= PR are already moving.

Re: acceptance rates, I agree the benchmark shape is probably a major factor. My run used long random prompts and long generations:

  • prompt length: 8192
  • output length: 1024
  • prompts: 5
  • concurrency: 1
  • TP: 2
  • temperature: 0
  • max_model_len: 32768
  • max_num_batched_tokens: 8192
  • max_num_seqs: 1

Main flags/env were:
--distributed-executor-backend mp
--tensor-parallel-size 2
--disable-custom-all-reduce
--block-size 256
--gpu-memory-utilization 0.97
--kv-cache-dtype fp8
--generation-config vllm

NCCL_P2P_DISABLE=1
VLLM_USE_FLASHINFER_SAMPLER=0
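
For anyone rerunning this, the settings above could be assembled into a single launch command roughly as follows. This is a sketch only: the model path is a placeholder because it is not spelled out in this thread, the speculative-decoding option is omitted for the same reason, and mapping max_model_len / max_num_batched_tokens / max_num_seqs onto the vllm serve flags of the same name is an assumption about how the run was launched.

```python
import shlex

model = "<model-path-or-repo>"  # placeholder: not given in the thread

# Environment variables listed above.
env = {"NCCL_P2P_DISABLE": "1", "VLLM_USE_FLASHINFER_SAMPLER": "0"}

args = [
    "vllm", "serve", model,
    "--distributed-executor-backend", "mp",
    "--tensor-parallel-size", "2",
    "--disable-custom-all-reduce",
    "--block-size", "256",
    "--gpu-memory-utilization", "0.97",
    "--kv-cache-dtype", "fp8",
    "--generation-config", "vllm",
    # Assumed flag spellings for the three limits in the bullet list.
    "--max-model-len", "32768",
    "--max-num-batched-tokens", "8192",
    "--max-num-seqs", "1",
]

# Prefix the env assignments and quote the arguments for a POSIX shell.
command = " ".join(f"{k}={v}" for k, v in env.items()) + " " + shlex.join(args)
print(command)
```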

The long-prompt/long-output shape may make the continuation distribution easier for the MTP draft to track, so I would not assume my acceptance rates
transfer directly to short mixed prompts. Your MTP1/MTP2 speedups still look directionally consistent, just with lower acceptance.
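
One way to put acceptance rates from different runs on a common scale is to convert them into expected tokens emitted per decode step. The sketch below assumes the reported acceptance is the fraction of proposed draft tokens that were accepted, which the thread does not confirm; the function name is illustrative.

```python
def expected_tokens_per_step(k, acceptance):
    """One verified token plus the accepted share of k drafted tokens."""
    return 1 + k * acceptance


# Acceptance rates quoted in this thread, keyed by number of draft tokens.
long_prompt = {1: 0.908, 2: 0.819, 3: 0.711, 4: 0.479}  # 8192-in / 1024-out run
short_prompt = {1: 0.813, 2: 0.630}                     # short mixed prompts run

for name, run in (("long", long_prompt), ("short", short_prompt)):
    for k, acc in run.items():
        print(f"{name} k={k}: {expected_tokens_per_step(k, acc):.2f} tokens/step")
```

Under that assumption the long-prompt run yields about 1.91, 2.64, 3.13 and 2.92 tokens per step for k=1..4, so k=4 emits fewer tokens per step than k=3 while drafting more, which is consistent with the throughput drop at MTP4.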

For reproducibility, I can share the benchmark harness and raw JSONs. The Canada results I measured on 2x GH200 TP=2 were:

MTP0: 105.9 tok/s, mean TPOT 9.11 ms
MTP1: 152.7 tok/s, mean TPOT 6.20 ms, acceptance 90.8%
MTP2: 179.9 tok/s, mean TPOT 5.18 ms, acceptance 81.9%
MTP3: 193.0 tok/s, mean TPOT 4.82 ms, acceptance 71.1%
MTP4: 162.1 tok/s, mean TPOT 5.81 ms, acceptance 47.9%

So I think the next useful step is a matched benchmark matrix:

  • same vLLM commit
  • same model artifact revision
  • same three fixes applied
  • same prompt/output shape
  • same num_speculative_tokens
  • then compare GH200 vs GB10

That should tell us how much of the acceptance/perf difference is workload shape versus hardware/backend behavior.
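
As a rough sketch, the matrix could be enumerated like this. The field names and placeholder strings are hypothetical; the point is only that everything except hardware and num_speculative_tokens stays fixed.

```python
from itertools import product

# Held constant across every cell of the matrix.
fixed = {
    "vllm_commit": "<same commit>",
    "artifact_revision": "<same revision>",
    "fixes_applied": ("metadata", "prefix", "o_proj_fallback"),
    "prompt_len": 8192,
    "output_len": 1024,
}

# The two axes that vary.
hardware = ["GH200", "GB10"]
num_speculative_tokens = [0, 1, 2, 3, 4]

matrix = [
    {**fixed, "hardware": hw, "num_speculative_tokens": k}
    for hw, k in product(hardware, num_speculative_tokens)
]
print(len(matrix))  # 10 runs
```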

I’ll shape the O-proj change as a narrow upstream PR candidate and reference the metadata fix plus vLLM PR #44837 / issues #44817 and #43893.

Thanks, this breakdown makes sense to me.

I agree that a matched benchmark matrix would be the cleanest way to separate workload-shape effects from hardware/backend effects. For now, I probably won’t continue with additional acceptance-rate benchmarking, mainly because my current environment is not fully matched to your setup and I don’t want to over-interpret results from different prompt/output shapes, hardware, or patch combinations.

My current contribution is more focused on the reproduction and implementation-analysis side:

  • reproduced the vLLM DeepSeek-V4 MTP path;
  • investigated the num_speculative_tokens=2 -> next_n=3 path;
  • validated the SM90 DeepGEMM paged MQA logits next_n=3 prototype path;
  • checked the MTP draft-layer prefix / quant_config handling issue;
  • collected one short-prompt end-to-end benchmark as supporting evidence, but not as a directly comparable acceptance benchmark against your long-context setup.

So I think I’ll treat my current benchmark numbers as directional only. Your long-prompt / long-output results are very helpful as a separate reference point.

For the O-proj change, I agree that shaping it as a narrow upstream PR candidate sounds like the right direction. If useful, I can help review the patch logic or compare it against the failure mode I saw in the next_n=3 path, but I may not be able to run a full matched benchmark matrix at this stage.

Thanks again for sharing the harness/config details. I’ll keep following the metadata fix, vLLM PR #44837, issues #44817 / #43893, and the O-proj fallback PR.

Quick update: I opened the upstream vLLM PR for the NVIDIA O-proj execution fallback:

https://github.com/vllm-project/vllm/pull/44847

Scope is deliberately narrow: it adds a BF16/unquantized fallback in vllm/models/deepseek_v4/nvidia/ops/o_proj.py for the case where the MTP draft
wo_a projection has no FP8 scale tensor, while preserving the existing FP8 path when weight_scale_inv / weight_scale exists.

So the three pieces are now:

  1. Artifact-side metadata fix:
    runtime-name-compatible ignore patterns on the MTP artifact branch.

  2. vLLM construction fix:
    prefix= propagation for MTP e_proj / h_proj in #44837.

  3. vLLM execution fix:
    BF16 MTP O-proj fallback in #44847.

The PR includes focused unit tests for the scale detection, BF16 inverse-RoPE path, grouped wo_a.weight reshape, and the deep_gemm_fp8_o_proj
fallback branch.
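
For illustration, a scale-detection test could look roughly like the following. This is not the test code from the PR; it reuses the get_fp8_weight_scale helper from the pseudo-patch earlier in this thread and stands in SimpleNamespace objects with string values for real layers and tensors.

```python
from types import SimpleNamespace


def get_fp8_weight_scale(layer):
    """Return the FP8 scale tensor if the layer has one, else None."""
    if hasattr(layer, "weight_scale_inv"):
        return layer.weight_scale_inv
    if hasattr(layer, "weight_scale"):
        return layer.weight_scale
    return None


def test_scale_detection():
    # Stand-in layers: strings take the place of tensors.
    fp8_inv = SimpleNamespace(weight="w", weight_scale_inv="s_inv")
    fp8 = SimpleNamespace(weight="w", weight_scale="s")
    both = SimpleNamespace(weight="w", weight_scale_inv="s_inv", weight_scale="s")
    bf16 = SimpleNamespace(weight="w")

    assert get_fp8_weight_scale(fp8_inv) == "s_inv"
    assert get_fp8_weight_scale(fp8) == "s"
    # weight_scale_inv takes precedence when both attributes exist.
    assert get_fp8_weight_scale(both) == "s_inv"
    # No scale tensor at all selects the BF16 fallback branch.
    assert get_fp8_weight_scale(bf16) is None


test_scale_detection()
print("scale detection ok")
```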

Hi @pastapaul and @ajdoosh, one thing I want to clarify: in the original vLLM issue #43457, num_speculative_tokens=2 maps to next_n=3, which triggers the DeepGEMM Hopper paged_mqa_logits assertion (next_n == 1 or next_n == 2). So if your MTP2 result means num_speculative_tokens=2, did your benchmark include a DeepGEMM-side patch or any fallback that avoids the original SM90 next_n=3 assertion path? Or does MTP2 refer to a different setting in your benchmark?
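
To make the question concrete, here is a minimal sketch of the check as described. It assumes next_n = num_speculative_tokens + 1, which generalises the single 2 -> 3 data point from the issue and may not hold for every setting; the function names are hypothetical and this is not DeepGEMM code.

```python
SUPPORTED_NEXT_N = (1, 2)  # values allowed by the assertion quoted above


def next_n_for(num_speculative_tokens):
    # Assumption: one extra position on top of the drafted tokens.
    return num_speculative_tokens + 1


def hits_assertion(num_speculative_tokens):
    """True if this setting would reach the unsupported next_n path."""
    return next_n_for(num_speculative_tokens) not in SUPPORTED_NEXT_N


for k in range(5):
    print(f"num_speculative_tokens={k}: next_n={next_n_for(k)}, "
          f"asserts={hits_assertion(k)}")
```

If that mapping holds, every setting from num_speculative_tokens=2 upward would reach the assertion on an unpatched SM90 path.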
