feat(qwen3_5_mtp): Qwen3.5 / Qwen3.5-MoE MTP speculative decoding#1338
feat(qwen3_5_mtp): Qwen3.5 / Qwen3.5-MoE MTP speculative decoding#1338sufubao wants to merge 15 commits into
Conversation
There was a problem hiding this comment.
Code Review
This pull request implements comprehensive multi-token prediction (MTP) and speculative decoding support for Qwen3.5 and Qwen3Next models, including updates to attention backends, CUDA graph warmup layouts, and Triton kernels for spec-decode updates. Feedback on the changes highlights a performance concern regarding synchronous device-to-host transfers when validating b_num_accepted_tokens on the GPU, as well as a potential division-by-zero error in the causal conv1d update kernel if the batch size is empty.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| assert int(b_num_accepted_tokens.min()) >= 1 and int(b_num_accepted_tokens.max()) <= mtp_step + 1, ( | ||
| f"b_num_accepted_tokens out of range [1, {mtp_step + 1}]: " | ||
| f"min={int(b_num_accepted_tokens.min())} max={int(b_num_accepted_tokens.max())}" | ||
| ) |
There was a problem hiding this comment.
Calling .min() and .max() on the GPU tensor b_num_accepted_tokens and casting to int causes a synchronous device-to-host (D2H) transfer. Since this function is called on the eager decode hot path, this synchronization will stall the CPU and degrade inference performance. Consider performing this validation on the CPU side (e.g., on the list of mtp_accept_len in infer_batch.py before moving it to CUDA) to keep the execution fully asynchronous.
| assert conv_state_indices is not None | ||
| batch = conv_state_indices.size(0) | ||
| dim = x.size(1) |
There was a problem hiding this comment.
If batch (the size of conv_state_indices) is 0, calculating seqlen = x.size(0) // batch will raise a ZeroDivisionError. To make the function robust against empty batches (which can occur in edge cases or testing), add a defensive guard to return x immediately if batch == 0.
assert conv_state_indices is not None\n batch = conv_state_indices.size(0)\n if batch == 0:\n return x\n dim = x.size(1)59705de to
d3279c8
Compare
Model-agnostic verify-decode machinery: MTP-verify dispatch in TpPartBaseModel, dedicated decode CUDA-graph capture/replay for the (mtp_step+1)-expanded verify layout, a shared mtp_verify_extra_state block on infer_struct/batch_objs, fa3 decode attention narrowed to the verify layout (b_att_seq_len + causal) for fp/fp8/mla, and env/kv-cache helpers for MTP added-layer accounting.
Self-contained dense (qwen3_5_mtp) and MoE (qwen3_5_moe_mtp) MTP draft packages: each carries its own draft wiring (reuse the main model's req/mem managers + rope caches, is_mtp_draft_model marker) and shares a weight-retarget mixin (mtp.* head, embeddings shared with the main model) plus the MTP pre-layer fuse. No shared model base class.
Gated-delta-net (linear attention) speculative-decode verify path for qwen3next: a per-sequence spec causal_conv1d kernel; a widened conv working slot split from the committed (narrow) persisted slot; MTP draft full-attn KV-slot accounting across the linear-att cache config, mem operator and req manager; and removal of the dead gen_b_req_mtp_start_loc kernel.
Wire the verify path through the inference backends: a single draft-model factory keyed on (model_type, mtp_mode); build the (mtp_step+1)-expanded verify decode batch; run the eagle + vanilla draft decode; verify accepted tokens; and thread per-request accept-lengths (b_num_accepted_tokens) from the chunked-prefill and dp backends into the model verify forward.
Behavioural/CUDA coverage for the subtle MTP paths: verify-extra-state metadata, decode CUDA-graph verify layouts, fa3 fp8 verify narrowing, GDN verify equivalence, the spec causal_conv1d kernel and its prefill->decode roundtrip, and the linear-att conv/SSM widened-slot split + snapshot + CPU-cache persistence. Also extends the static-inference MTP benchmark and anchors the .gitignore benchmark-output rule to /benchmark.
Restore blank lines that were stripped from pre-existing definitions (black-induced reformatting of upstream code that this PR didn't functionally change). Keeps the diff focused on the MTP feature; fixing historical formatting is out of scope for this PR.
Scope this branch to Qwen3.5 MTP support only by rolling back the EAGLE-mode draft optimization. The draft model again runs the full (mtp_step+1)-expanded verify layout instead of being narrowed to the single accepted row per request. - dp/chunked _draft_decode_eagle: restore full-layout draft (copy.copy + b_num_accepted_tokens=None so it routes to the (bs, False) graph); drop the per-rank padding helpers and accepted-row narrowing. - base_backend: remove _build_eagle_accepted_draft_input / _scatter_accepted_next_token_ids. - cuda_graph: the draft runs at multiples of (mtp_step+1) again, so collapse the dual batch-size sets to one and delete the now-redundant _get_graph_batch_sizes routing. Keep the (bs, is_mtp_verify_decode) graph key + verify-layout warmup (core GDN verify support, not the optimization). - static benchmark: eagle path now measures the full-layout draft cost. - tests: drop the two narrowed-draft tests; rewrite the dual-set tests to the single-set model (still cover the verify/normal key distinction).
Drop the remaining draft-side divergence from upstream so this branch is scoped to Qwen3.5 MTP support only. The draft decode no longer clears b_num_accepted_tokens to force a flat/normal layout; it reuses the main model_input (still copy.copy'd to isolate per-step input_ids/b_seq_len/ mem_indexes mutations) and runs the same (mtp_step+1)-grouped verify decode layout as the main model — exactly as upstream does. For the pure-full-attention draft (qwen3_5_mtp: full_attention_interval=1, no GDN) grouped and flat are numerically identical: each position k sees KV [0, s+k) either way, same page-table entries, same RoPE positions; the main verify forward already uses this geometry and is the validated path. The earlier flat-draft only added an unnecessary (bs, False) cudagraph layout + b_num_accepted_tokens gating; nothing the draft computes needs it. - chunked_prefill/dp_backend: 6 draft fns (vanilla/eagle + dp overlap variants) stop clearing b_num_accepted_tokens. - cuda_graph: draft warms up the verify graph key too (mtp_step>0 -> verify for both main and draft); delete the now-dead _is_mtp_draft_model. - tests: rewrite the warmup-layout test (main+draft both verify; mtp_step==0 -> normal) and drop the stale "draft uses normal layout" framing. Keep is_mtp_verify_decode (main-model GDN verify still needs it) and the committed fp8.py causal=True fix. Verified live (QW35-122B-A10B, eagle_with_att, mtp_step=1, tp4): GSM8K acc 0.964 / Invalid 0.000, accept 1.956/2.0 — matches pre-revert baseline (no regression). Codex independent pass concurred (high confidence).
…e plumbing - is_mtp_verify: drop the redundant `b_num_accepted_tokens is not None` clause (post grouped-revert it's implied by mtp_step>0 ∧ ¬prefill). - Replace the per-step host round-trip for b_num_accepted_tokens with a GPU-resident ReqManager.req_to_accept_len: a triton scatter_mtp_accept_len after verify + a GDN-only gather in init_mtp_verify_extra_state. Removes the gen_from_list H2D rebuild, the phase-2 req.mtp_accept_len writeback, and the host attr (linear-att offload + resets now read/write the buffer). - Drop the redundant `if mtp_step>0` guard inside decode_mtp/decode_overlap_mtp. - config_objs: inline the mtp draft-layer count, dropping the _mtp_added_layer_num helper (kept get_added_mtp_kv_layer_num inlined in envs_utils). - cpu_cache_meta: don't bump layer_num for linear-att models (the draft full-att slots are already in LinearAttCacheConfig.get_cpu_cache_big_page_bytes()). Static checks pass (ast, flake8). The req_to_accept_len refactor is not yet runtime-verified; pending a hybrid GSM8K + cudagraph-ON parity run.
7aeeafb to
0d15236
Compare
What
Adds Qwen3.5 / Qwen3.5-MoE Multi-Token Prediction (MTP / speculative decoding) end-to-end, together with the model-agnostic verify-decode machinery it needs. Builds on the
BaseMTPModelrefactor in #1337 (merged).Commits
TpPartBaseModel, dedicated decode CUDA-graph capture/replay for the(mtp_step+1)-expanded verify layout, a sharedmtp_verify_extra_stateblock, and fa3 decode attention narrowed to the verify layout (b_att_seq_len+ causal) for fp / fp8 / mla.mtp.*head, embeddings shared with the main model) and the MTP pre-layer fuse.causal_conv1dkernel, a widened conv working slot split from the committed (narrow) persisted slot, and MTP draft full-attn KV-slot accounting across the linear-att cache config / mem operator / req manager.(model_type, mtp_mode), the verify decode batch, eagle + vanilla draft decode, and per-request accept-length (b_num_accepted_tokens) transport through the chunked-prefill and dp backends..gitignorebenchmark-output rule anchored to/benchmark.Testing
pre-commit(black 21.12b0 + flake8 6.1.0) clean.unit_tests/cover verify-extra-state, decode CUDA-graph layouts, fa3 narrowing, GDN verify equivalence, linear-att conv/SSM split + CPU-cache persistence, and the draft-model factory.