Skip to content

feat: Qwen3.5 / Qwen3.5-MoE MTP speculative decoding#1330

Open
sufubao wants to merge 8 commits into
ModelTC:mainfrom
sufubao:qw35_mtp
Open

feat: Qwen3.5 / Qwen3.5-MoE MTP speculative decoding#1330
sufubao wants to merge 8 commits into
ModelTC:mainfrom
sufubao:qw35_mtp

Conversation

@sufubao
Copy link
Copy Markdown
Collaborator

@sufubao sufubao commented Jun 4, 2026

Summary

Adds Multi-Token Prediction (MTP) speculative decoding for Qwen3.5 and Qwen3.5-MoE (hybrid full-attention + Gated Delta Net / qwen3next linear-attention models). The draft head proposes mtp_step tokens per step and the base model verifies them in a single fused forward, giving exact greedy output with single-stream/low-concurrency latency speedups.

Implemented in 5 logical layers:

  • Linear-attn cache state split — separate widened-GPU vs narrow-persisted conv/SSM state so the verify path can snapshot/restore committed state.
  • qwen3next GDN spec-decode verify path_gdn_verify_kernel + hybrid dispatch and a vendored spec-decode causal_conv1d_spec Triton kernel.
  • Basemodel MTP decode CUDA graphs + verify dispatch — decode graph captured with the MTP verify layout; padding grouped by lcm(unit, mtp_step+1); capture-safe (no D2H syncs).
  • Scheduler MTP verify backend + accept-len transport — canonical accept_len pointer plumbing; accept-count carry committed in phase 2 (pre-forward-release) to avoid a one-step-stale read under the overlap scheduler.
  • Draft modelsqwen3_5_mtp / qwen3_5_moe_mtp packages (full-attn draft layer, mrope, inline mtp.* weights), non-colliding draft KV slots, backend registration.

Test Plan

  • 40 new unit tests pass (cache-state split, GDN verify equivalence, spec conv kernel, draft layer/slots, MTP decode CUDA graph): pytest unit_tests/common unit_tests/models/qwen3_5 unit_tests/models/qwen3next
  • E2E greedy parity: MTP output token-identical to MTP-off baseline (spec decode is exact)
  • GSM8K parity gate on Qwen3.5-27B (TP4, conc-16): off 0.905 / mtp1 0.895 / mtp2 0.915
  • black clean

Notes

  • MTP is a low-concurrency/latency accelerator: single-stream decode speedup up to ~1.57x (mtp2, accept_len 2.43) on 27B; crossover ~conc-8 where the GPU saturates.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request adds support for Qwen3.5 and Qwen3.5-Moe MTP models, integrating speculative decoding with CUDA graph support and linear attention. It introduces a spec-decode capable causal_conv1d update kernel, refactors CUDA graph capture and warmup logic to support both normal and MTP verify decode layouts, and adjusts memory management to handle dedicated draft KV slots and widened GPU conv buffers. Feedback on the changes suggests correcting a typo in the function name _nomarl_prefill_att to _normal_prefill_att in fp.py to improve code readability and maintainability.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines 86 to 93
def _nomarl_prefill_att(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, att_control: AttControl, alloc_func=torch.empty
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
att_control: AttControl,
alloc_func=torch.empty,
) -> torch.Tensor:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There appears to be a typo in the function name. _nomarl_prefill_att should likely be _normal_prefill_att. This should be corrected for better code readability and maintainability. You'll also need to update the call site for this function.

Suggested change
def _nomarl_prefill_att(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, att_control: AttControl, alloc_func=torch.empty
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
att_control: AttControl,
alloc_func=torch.empty,
) -> torch.Tensor:
def _normal_prefill_att(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
att_control: AttControl,
alloc_func=torch.empty,
) -> torch.Tensor:

sufubao added 8 commits June 5, 2026 14:50
Make the linear-attention (GDN) cache able to serve a speculative
verify pass over multiple draft tokens without corrupting the canonical
per-request state:

- conv-state shape splits into a widened GPU slot (holds the in-flight
  verify window) vs the narrow slot that is persisted/restored, while
  the SSM state keeps an (S+1) block so each draft position has a slot.
- snapshot/restore helpers read the committed conv window + SSM block
  slot and reset the carried accept_len, so the next step reads from the
  canonical offset-0 / block-0 pointer.
- relax the ReqManagerForMamba / CPU-cache MTP gates for hybrid models
  (draft KV is not persisted) and enforce the S<=7 bound.

Covered by conv-state shape-split, snapshot-split, and mamba
req-manager gate unit tests.
Add the Gated DeltaNet (qwen3next) verify forward used by MTP:

- vendor a spec-decode causal_conv1d_update kernel (causal_conv1d_spec)
  so multiple draft positions can advance the conv state in one launch.
- add the _gdn_verify kernel + MTP-verify dispatch branch, building the
  verify cu_seqlens, SSM index rows, conv indices and is_mtp_verify flag
  in infer_struct, and allocate non-colliding GPU draft full-attn slots.
- run the hybrid MTP decode eagerly so the GDN verify path is honored.

Unit tests assert the GDN verify state equals sequential T=1 decode,
cover prefill conv indices, the spec conv kernel, and draft-slot layout.
Wire MTP into the base model decode path:

- capture/replay decode CUDA graphs for the MTP verify step and thread
  b_num_accepted_tokens through ModelInput / InferStateInfo.
- add the MTP-verify dispatch in basemodel and pass the per-position
  draft index into the FA3 attention backends (fp / fp8 / mla).

Covered by the MTP decode CUDA-graph unit test.
Drive the draft/verify loop from the scheduler:

- carry a canonical InferReq.mtp_accept_len pointer and persist the
  per-request accept_len across steps; build per-req
  b_num_accepted_tokens in decode_mtp and commit it in phase 2 so the
  next step reads a fresh count.
- extend the chunked_prefill backend / base_backend with the MTP verify
  dispatch and the partial-accept read offset.
Add the MTP draft model packages and register them:

- qwen3_5_mtp: a forced single full-attn-layer draft model, with the
  MTP pre-layer infer (embed/hidden norm + fc fusion) and pre/post +
  transformer-layer weight loaders reading the mtp.* namespace.
- qwen3_5_moe_mtp: the MoE variant draft weight loaders + model.
- register qwen3_5 / qwen3_5_moe MTP draft models with per-block
  draft_idx, plus the qwen3_5 verify infer_struct.

Unit tests scaffold the MTP draft layer and the hybrid verify forward.
The write-only layer_infer._draft_kv_slot was never read anywhere; the
KV-slot mapping is fully expressed via layer_num_ = draft_kv_slot * interval.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant