perf(qwen3next): fuse the three MoE gate GEMMs into one (decode)#1358
perf(qwen3next): fuse the three MoE gate GEMMs into one (decode)#1358sufubao wants to merge 2 commits into
Conversation
The shared-expert gate_up (TP-sharded), router gate (replicated) and shared-expert gate logit (replicated) all read the same input. Concatenate their weights lazily (bf16/no-quant only, EP excluded, never during graph capture) into one GEMM whose output width is zero-padded to a multiple of 8 (odd widths force a slow cuBLAS align1 kernel), replacing three same-input small GEMMs. silu_and_mul_fwd now accepts row-strided views so the gate_up slice feeds it directly; the router slice is copied contiguous for sgl topk_softmax. Extracted standalone onto main from the split_decode_speed stack (cherry-pick of a58449e9). The original sat atop a decode refactor that returned shared_gate_logits to a downstream fused combine kernel; here the shared-expert sigmoid gate is applied inline (column 0 of the [shared gate logit | zero-pad] tail) to match main's _compute_shared_expert.
There was a problem hiding this comment.
Code Review
This pull request introduces fused MoE gate weight optimization for the Qwen3Next model by fusing the weights of gate_up_proj, moe_gate, and ffn_gate into a single tensor when conditions permit. The inference pipeline is updated to leverage these fused weights, reducing the number of matrix multiplications. Additionally, the assertion in silu_and_mul_fwd is relaxed to only require a stride of 1 on the last dimension. Feedback suggests explicitly releasing references to large intermediate tensors (shared_gate and fused_out) by setting them to None after use to optimize memory usage and reduce peak memory overhead.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| shared_gate = fused_out[:, gate_up_cols + n_experts : gate_up_cols + n_experts + 1].sigmoid_() | ||
| shared_expert_out.mul_(shared_gate) | ||
| router_logits = self.alloc_tensor((num_tokens, n_experts), hidden_states.dtype) | ||
| router_logits.copy_(fused_out[:, gate_up_cols : gate_up_cols + n_experts]) |
There was a problem hiding this comment.
To optimize memory usage, we should explicitly release the references to shared_gate and fused_out by setting them to None after router_logits.copy_. Since fused_out is a large intermediate tensor of shape (num_tokens, padded_size), keeping it in scope during the execution of layer_weight.experts.experts (which is a heavy operation that allocates its own large tensors) unnecessarily increases peak memory usage. Setting them to None allows the allocator to reclaim or reuse this memory immediately.
| shared_gate = fused_out[:, gate_up_cols + n_experts : gate_up_cols + n_experts + 1].sigmoid_() | |
| shared_expert_out.mul_(shared_gate) | |
| router_logits = self.alloc_tensor((num_tokens, n_experts), hidden_states.dtype) | |
| router_logits.copy_(fused_out[:, gate_up_cols : gate_up_cols + n_experts]) | |
| shared_gate = fused_out[:, gate_up_cols + n_experts : gate_up_cols + n_experts + 1].sigmoid_() | |
| shared_expert_out.mul_(shared_gate) | |
| router_logits = self.alloc_tensor((num_tokens, n_experts), hidden_states.dtype) | |
| router_logits.copy_(fused_out[:, gate_up_cols : gate_up_cols + n_experts]) | |
| shared_gate = None | |
| fused_out = None |
softmax_topk now casts to fp32 inside the kernel and reads per-row strides, and fused_topk routes a non-contiguous gating tensor to it instead of the sgl topk_softmax (which needs contiguous). So the router slice of the fused gate GEMM output is passed straight to the experts with no contiguous copy. The triton path is bit-identical to sgl for these shapes (ids match, weight maxdiff 0.0). Extracted standalone onto main from the split_decode_speed stack (cherry-pick of 09603789), stacked on the fused-gate-GEMM commit. The two topk kernel files apply verbatim; the qwen3next _moe_ffn_tp hunk is adapted to this branch's inline shared-expert sigmoid gate (the in-place sigmoid touches a disjoint column, so the strided router-slice view stays valid).
What
Two stacked decode optimizations for Qwen3-Next-architecture MoE layers. In every MoE FFN the shared-expert
gate_up(TP-sharded), the routergate(replicated), and the shared-expertgate_logit(replicated) all read the same input.Commit 1 — fuse the three gate GEMMs into one. Concatenate their weights lazily into a single GEMM whose output is zero-padded to a multiple of 8 (odd widths force a slow cuBLAS
align1kernel).silu_and_mul_fwdaccepts row-strided views, so thegate_upslice feeds it directly (no copy); the shared-expert sigmoid gate is applied inline (col 0 of the padded tail), matching_compute_shared_expert.Commit 2 — drop the per-layer router-logits copy.
softmax_topkcasts to fp32 inside the kernel and reads per-row strides, andfused_topkroutes the non-contiguous router slice of the fused output to it (instead of sgltopk_softmax, which needs contiguous). The router slice is passed straight to the experts with no contiguous copy. The triton path is bit-identical to sgl for these shapes (ids match, weight maxdiff 0.0).Files (5):
moe_silu_and_mul.py,softmax_topk.py,topk_select.py,qwen3next/layer_infer/transformer_layer_infer.py,qwen3next/layer_weights/transformer_layer_weight.py.Activation guard (no behavior change otherwise)
Lazy, cached per layer. The fused path is taken only when:
NoQuantization, bf16, EP disabled, not during stream capture, no bias on the three gates. Otherwise it transparently falls back to the original three-GEMM path.Performance verification
Model: Qwen3.5-122B-A10B · TP8 (8×H200) · bf16 · fa3.
Harness: static benchmark (no server/scheduler)
test/benchmark/static_inference/test_model.py.Method: A/B = HEAD (both commits) vs parent
d471c212(the 5 changed files reverted), identical config, 3 reps, CUDA graph ON,input_len=256 / output_len=128.Reproduce
Decode throughput (tokens/s, mean ± std over 3 reps, steady-state step)
d471c212)router copy_+ sgl topk cost canceled the fusion gain). Dropping the copy (commit 2) recovered it to +4.5 % — it added +1.3–4.6 % across the sweep.Kernel-level (bs64 decode step,
--torch_profile --disable_cudagraph)aten::mmcalls 301 → 205 (−96 = 2 × 48 layers) — 3 gate GEMMs collapsed to 1 per layer. The strided-viewsilu_and_mulis free (same 96 calls, same time). The routercopy_(+48 calls/step in commit 1) is gone in commit 2.Accuracy verification
GSM8K via
test/acc/test_gsmk.py(5-shot/generate, 300 questions,max_tokens=4096, parallel 128), 122B server TP8--disable_cudagraph.d471c212)Δ = +0.4 pp ≈ 1 question out of 300 → parity, no regression (even at T=0, batching/scheduling causes ±1–2 Q run-to-run). Consistent with commit 2's bit-identical claim.
Correctness
Identical behavior on the fallback path (guard returns the original three-GEMM code). The fused path applies the shared-expert sigmoid gate inline to match
_compute_shared_expert. Activation confirmed empirically: the only code difference between the two sides is this PR, and decode improves at every batch with no accuracy change.