
perf(qwen3next): fuse the three MoE gate GEMMs into one (decode)#1358

Closed
sufubao wants to merge 2 commits into
ModelTC:mainfrom
sufubao:perf/fused_gate_gemm

@sufubao sufubao commented Jun 16, 2026

Collaborator

What

Two stacked decode optimizations for Qwen3-Next-architecture MoE layers. In every MoE FFN the shared-expert gate_up (TP-sharded), the router gate (replicated), and the shared-expert gate_logit (replicated) all read the same input.

Commit 1 — fuse the three gate GEMMs into one. Concatenate their weights lazily into a single GEMM whose output is zero-padded to a multiple of 8 (odd widths force a slow cuBLAS align1 kernel). silu_and_mul_fwd accepts row-strided views, so the gate_up slice feeds it directly (no copy); the shared-expert sigmoid gate is applied inline (col 0 of the padded tail), matching _compute_shared_expert.
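The fusion in commit 1 can be sketched in NumPy as follows. This is a minimal illustration, not the PR's code: the helper name `build_fused_weight`, the toy shapes, the (hidden, out_cols) weight layout, and the use of float64 NumPy arrays in place of bf16 CUDA tensors are all assumptions. Only the column order [gate_up | router | shared gate logit | zero-pad] and the multiple-of-8 padding come from the description above.

```python
import numpy as np

def build_fused_weight(w_gate_up, w_router, w_shared_gate, align=8):
    """Concatenate the three gate weights column-wise and zero-pad the width.

    Each weight is (hidden, out_cols); this layout is an assumption made for
    the sketch, not necessarily the layout used in the PR.
    """
    fused = np.concatenate([w_gate_up, w_router, w_shared_gate], axis=1)
    pad = (-fused.shape[1]) % align  # columns needed to reach a multiple of `align`
    if pad:
        fused = np.pad(fused, ((0, 0), (0, pad)))  # zero-filled padding columns
    return fused

rng = np.random.default_rng(0)
hidden, gate_up_cols, n_experts, num_tokens = 16, 12, 5, 3  # toy sizes
w_gate_up = rng.standard_normal((hidden, gate_up_cols))
w_router = rng.standard_normal((hidden, n_experts))
w_shared_gate = rng.standard_normal((hidden, 1))
x = rng.standard_normal((num_tokens, hidden))

w_fused = build_fused_weight(w_gate_up, w_router, w_shared_gate)
fused_out = x @ w_fused  # one GEMM instead of three

# Slices of the fused output are views into the same buffer, not copies.
gate_up = fused_out[:, :gate_up_cols]
router_logits = fused_out[:, gate_up_cols:gate_up_cols + n_experts]
shared_gate_logit = fused_out[:, gate_up_cols + n_experts:gate_up_cols + n_experts + 1]
```

With these toy sizes the unpadded width is 12 + 5 + 1 = 18, so six zero columns are appended to reach 24. Each slice matches the result of the corresponding separate GEMM up to floating-point rounding.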

Commit 2 — drop the per-layer router-logits copy. softmax_topk casts to fp32 inside the kernel and reads per-row strides, and fused_topk routes the non-contiguous router slice of the fused output to it (instead of sgl topk_softmax, which needs contiguous). The router slice is passed straight to the experts with no contiguous copy. The triton path is bit-identical to sgl for these shapes (ids match, weight maxdiff 0.0).
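A reference version of what commit 2 relies on can be sketched in NumPy: a softmax top-k that accepts a column slice of the wider fused buffer and produces the same result as it would for a contiguous copy. This is not the Triton kernel. The function name `softmax_topk_reference`, the toy shapes, the absence of weight renormalization, and float16 standing in for bf16 are assumptions; NumPy's `astype` also materializes an fp32 copy, whereas the description says the kernel casts internally.

```python
import numpy as np

def softmax_topk_reference(logits, k):
    """Softmax over the expert axis in fp32, then pick the k largest weights.

    Works on any 2-D array addressed by strides, so a column slice of a
    wider buffer can be passed without first making it contiguous.
    """
    x = logits.astype(np.float32)             # cast up before the softmax
    x = x - x.max(axis=1, keepdims=True)      # stabilize the exponentials
    probs = np.exp(x)
    probs /= probs.sum(axis=1, keepdims=True)
    ids = np.argsort(-probs, axis=1, kind="stable")[:, :k]
    weights = np.take_along_axis(probs, ids, axis=1)
    return weights, ids

rng = np.random.default_rng(0)
gate_up_cols, n_experts, padded = 12, 5, 24   # toy sizes
fused_out = rng.standard_normal((4, padded)).astype(np.float16)  # fp16 stands in for bf16

router_view = fused_out[:, gate_up_cols:gate_up_cols + n_experts]  # non-contiguous view
router_copy = np.ascontiguousarray(router_view)                    # the copy commit 2 removes

w_view, ids_view = softmax_topk_reference(router_view, k=2)
w_copy, ids_copy = softmax_topk_reference(router_copy, k=2)
max_diff = float(np.abs(w_view - w_copy).max())
```

Comparing `ids_view` with `ids_copy` and inspecting `max_diff` mirrors the "ids match, weight maxdiff" check quoted above, here applied to the reference rather than to the real kernels.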

Files (5): moe_silu_and_mul.py, softmax_topk.py, topk_select.py, qwen3next/layer_infer/transformer_layer_infer.py, qwen3next/layer_weights/transformer_layer_weight.py.

Activation guard (no behavior change otherwise)

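The fused weight is built lazily and cached per layer. The fused path is taken only when all of the following hold: quantization is NoQuantization, the dtype is bf16, EP is disabled, no stream capture is in progress, and none of the three gates has a bias. Otherwise the layer transparently falls back to the original three-GEMM path.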
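The guard can be read as a single predicate over the listed conditions. The sketch below is illustrative only: `GateFusionContext`, its field names, `can_use_fused_gate`, and `select_path` are hypothetical and do not correspond to identifiers in the PR, and the per-layer caching is omitted.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class GateFusionContext:
    """Hypothetical snapshot of the conditions listed in the activation guard."""
    quantized: bool = False
    dtype: str = "bfloat16"
    expert_parallel: bool = False
    capturing_cuda_graph: bool = False
    any_gate_has_bias: bool = False

def can_use_fused_gate(ctx: GateFusionContext) -> bool:
    """Return True only when every guard condition from the PR text holds."""
    return (
        not ctx.quantized
        and ctx.dtype == "bfloat16"
        and not ctx.expert_parallel
        and not ctx.capturing_cuda_graph
        and not ctx.any_gate_has_bias
    )

def select_path(ctx: GateFusionContext) -> str:
    # Any failed condition falls back to the original three-GEMM path.
    return "fused" if can_use_fused_gate(ctx) else "three_gemm"
```

For example, `select_path(GateFusionContext())` selects the fused path, while `select_path(GateFusionContext(expert_parallel=True))` falls back.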

Performance verification

Model: Qwen3.5-122B-A10B · TP8 (8×H200) · bf16 · fa3.
Harness: static benchmark (no server/scheduler) test/benchmark/static_inference/test_model.py.
Method: A/B = HEAD (both commits) vs parent d471c212 (the 5 changed files reverted), identical config, 3 reps, CUDA graph ON, input_len=256 / output_len=128.

Reproduce

docker run --rm --gpus all --privileged --ipc=host \
  --ulimit memlock=-1 --ulimit stack=67108864 \
  -v /dev/shm/:/dev/shm/ -v $MODEL_DIR:/model:ro -v $REPO:/lightllm \
  -e CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 -e LOADWORKER=18 \
  -e FLASHINFER_DISABLE_VERSION_CHECK=1 -e PYTHONUNBUFFERED=1 \
  ghcr.io/modeltc/lightllm:main \
  bash -lc "cd /lightllm && python -u test/benchmark/static_inference/test_model.py \
    --model_dir /model --tp 8 --data_type bfloat16 \
    --nccl_host 127.0.0.1 --nccl_port 28901 --max_req_total_len 8192 \
    --llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 \
    --input_len 256 --output_len 128"
# BEFORE = `git checkout d471c212 -- <the 5 files>`, re-run, then restore.
# Omit --batch_size to sweep [2,8,16,32,64,128]. isl=256 avoids a bs128×isl512 mrope grid-Y overflow.

Decode throughput (tokens/s, mean ± std over 3 reps, steady-state step)

| batch | before (d471c212) | after (HEAD) | Δ |
|---|---|---|---|
| 2 | 222.1 ± 6.5 | 242.3 ± 0.7 | +9.1 % |
| 8 | 814.7 ± 3.8 | 873.9 ± 5.5 | +7.3 % |
| 16 | 1501.6 ± 22.6 | 1588.4 ± 17.9 | +5.8 % |
| 32 | 2617.8 ± 27.4 | 2809.9 ± 32.4 | +7.3 % |
| 64 | 4476.6 ± 52.5 | 4676.8 ± 48.9 | +4.5 % |
| 128 | 7517.0 ± 68.9 | 7906.2 ± 108.2 | +5.2 % |
  • All batches: +4.5–9.1 %, statistically clear (Δ > 2σ pooled SE).
  • Commit 2 matters most at bs64. With commit 1 alone, the bs64 gain was ~0 % (compute-bound: the cost of the router copy_ plus sgl topk canceled the fusion gain). Dropping the copy in commit 2 recovered it to +4.5 %; across the sweep, commit 2 added +1.3–4.6 %.
  • Prefill: flat (within noise) — the optimization targets decode.
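The Δ column and the "Δ > 2σ" statement can be re-derived from the reported means and standard deviations. The sketch below assumes that "pooled SE" means the standard error of the difference of two means with 3 reps each, sqrt(s_before²/3 + s_after²/3); the PR does not spell out the formula, so treat that definition as an assumption.

```python
import math

# (batch, before_mean, before_std, after_mean, after_std, reported_delta_pct)
ROWS = [
    (2, 222.1, 6.5, 242.3, 0.7, 9.1),
    (8, 814.7, 3.8, 873.9, 5.5, 7.3),
    (16, 1501.6, 22.6, 1588.4, 17.9, 5.8),
    (32, 2617.8, 27.4, 2809.9, 32.4, 7.3),
    (64, 4476.6, 52.5, 4676.8, 48.9, 4.5),
    (128, 7517.0, 68.9, 7906.2, 108.2, 5.2),
]
N_REPS = 3

def delta_percent(before, after):
    """Relative throughput change in percent."""
    return 100.0 * (after / before - 1.0)

def se_of_difference(std_before, std_after, n=N_REPS):
    """Standard error of the difference of two means (assumed definition)."""
    return math.sqrt(std_before ** 2 / n + std_after ** 2 / n)

def summarize(rows):
    """Return (batch, delta_pct, delta_over_se, reported_pct) per row."""
    out = []
    for batch, b_mean, b_std, a_mean, a_std, reported in rows:
        delta = a_mean - b_mean
        se = se_of_difference(b_std, a_std)
        out.append((batch, delta_percent(b_mean, a_mean), delta / se, reported))
    return out

for batch, pct, ratio, reported in summarize(ROWS):
    print(f"bs{batch:<4d} {pct:+5.1f} % (reported {reported:+.1f} %)  delta/SE = {ratio:.1f}")
```

Under that definition every row has a Δ/SE ratio above 2, and the recomputed percentages agree with the table to one decimal place.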

Kernel-level (bs64 decode step, --torch_profile --disable_cudagraph)

aten::mm calls 301 → 205 (−96 = 2 × 48 layers) — 3 gate GEMMs collapsed to 1 per layer. The strided-view silu_and_mul is free (same 96 calls, same time). The router copy_ (+48 calls/step in commit 1) is gone in commit 2.

Accuracy verification

GSM8K via test/acc/test_gsmk.py (5-shot /generate, 300 questions, max_tokens=4096, parallel 128), 122B server TP8 --disable_cudagraph.

| config | accuracy | invalid |
|---|---|---|
| after (HEAD, both commits) | 0.977 | 0.000 |
| before (d471c212) | 0.973 | 0.000 |

Δ = +0.4 pp ≈ 1 question out of 300, i.e. parity with no regression (even at T=0, batching and scheduling cause ±1–2 questions of run-to-run variation). This is consistent with commit 2's bit-identical claim.
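The "1 question out of 300" reading follows from converting the rounded accuracies back to integer counts. A small sketch of that arithmetic, where the helper name `correct_count` is illustrative:

```python
N_QUESTIONS = 300

def correct_count(accuracy, n=N_QUESTIONS):
    """Recover the integer number of correct answers from a rounded accuracy."""
    return round(accuracy * n)

after_correct = correct_count(0.977)    # HEAD, both commits
before_correct = correct_count(0.973)   # parent d471c212
question_gap = after_correct - before_correct
gap_pp = 100.0 * question_gap / N_QUESTIONS  # gap in percentage points
```

A single question is worth about 0.33 pp at this sample size; the reported +0.4 pp is the difference of the two accuracies after each was rounded to three decimals.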

Correctness

Identical behavior on the fallback path (guard returns the original three-GEMM code). The fused path applies the shared-expert sigmoid gate inline to match _compute_shared_expert. Activation confirmed empirically: the only code difference between the two sides is this PR, and decode improves at every batch with no accuracy change.

The shared-expert gate_up (TP-sharded), router gate (replicated) and
shared-expert gate logit (replicated) all read the same input. Concatenate
their weights lazily (bf16/no-quant only, EP excluded, never during graph
capture) into one GEMM whose output width is zero-padded to a multiple of
8 (odd widths force a slow cuBLAS align1 kernel), replacing three
same-input small GEMMs. silu_and_mul_fwd now accepts row-strided views so
the gate_up slice feeds it directly; the router slice is copied contiguous
for sgl topk_softmax.

Extracted standalone onto main from the split_decode_speed stack
(cherry-pick of a58449e9). The original sat atop a decode refactor that
returned shared_gate_logits to a downstream fused combine kernel; here the
shared-expert sigmoid gate is applied inline (column 0 of the
[shared gate logit | zero-pad] tail) to match main's _compute_shared_expert.

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces fused MoE gate weight optimization for the Qwen3Next model by fusing the weights of gate_up_proj, moe_gate, and ffn_gate into a single tensor when conditions permit. The inference pipeline is updated to leverage these fused weights, reducing the number of matrix multiplications. Additionally, the assertion in silu_and_mul_fwd is relaxed to only require a stride of 1 on the last dimension. Feedback suggests explicitly releasing references to large intermediate tensors (shared_gate and fused_out) by setting them to None after use to optimize memory usage and reduce peak memory overhead.
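The relaxed stride requirement mentioned in the review summary can be illustrated with a NumPy reference for silu_and_mul. This is a sketch, not the Triton kernel: the function name `silu_and_mul_reference`, the toy shapes, and the assumption that the first half of the columns is the gate and the second half is the up projection are illustrative only.

```python
import numpy as np

def silu_and_mul_reference(gate_up):
    """Reference silu(gate) * up on a 2-D input whose last dimension is dense.

    Assumes the first half of the columns is `gate` and the second half is
    `up`; that split is an assumption of this sketch.
    """
    assert gate_up.ndim == 2
    # Relaxed check: only the last dimension must be unit-stride.
    assert gate_up.strides[-1] == gate_up.itemsize
    half = gate_up.shape[1] // 2
    gate = gate_up[:, :half].astype(np.float32)
    up = gate_up[:, half:].astype(np.float32)
    return gate / (1.0 + np.exp(-gate)) * up  # silu(gate) * up

rng = np.random.default_rng(0)
fused_out = rng.standard_normal((4, 24)).astype(np.float32)
gate_up_view = fused_out[:, :12]  # row stride is 24 elements, not 12
out_view = silu_and_mul_reference(gate_up_view)
out_copy = silu_and_mul_reference(np.ascontiguousarray(gate_up_view))
```

The view is not contiguous because its row stride still spans the full fused width, yet the check passes and the result equals that of the contiguous copy.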

Comment on lines +136 to +139
shared_gate = fused_out[:, gate_up_cols + n_experts : gate_up_cols + n_experts + 1].sigmoid_()
shared_expert_out.mul_(shared_gate)
router_logits = self.alloc_tensor((num_tokens, n_experts), hidden_states.dtype)
router_logits.copy_(fused_out[:, gate_up_cols : gate_up_cols + n_experts])

medium

To optimize memory usage, we should explicitly release the references to shared_gate and fused_out by setting them to None after router_logits.copy_. Since fused_out is a large intermediate tensor of shape (num_tokens, padded_size), keeping it in scope during the execution of layer_weight.experts.experts (which is a heavy operation that allocates its own large tensors) unnecessarily increases peak memory usage. Setting them to None allows the allocator to reclaim or reuse this memory immediately.

Suggested change
- shared_gate = fused_out[:, gate_up_cols + n_experts : gate_up_cols + n_experts + 1].sigmoid_()
- shared_expert_out.mul_(shared_gate)
- router_logits = self.alloc_tensor((num_tokens, n_experts), hidden_states.dtype)
- router_logits.copy_(fused_out[:, gate_up_cols : gate_up_cols + n_experts])
+ shared_gate = fused_out[:, gate_up_cols + n_experts : gate_up_cols + n_experts + 1].sigmoid_()
+ shared_expert_out.mul_(shared_gate)
+ router_logits = self.alloc_tensor((num_tokens, n_experts), hidden_states.dtype)
+ router_logits.copy_(fused_out[:, gate_up_cols : gate_up_cols + n_experts])
+ shared_gate = None
+ fused_out = None

softmax_topk now casts to fp32 inside the kernel and reads per-row strides,
and fused_topk routes a non-contiguous gating tensor to it instead of the
sgl topk_softmax (which needs contiguous). So the router slice of the
fused gate GEMM output is passed straight to the experts with no contiguous
copy. The triton path is bit-identical to sgl for these shapes (ids match,
weight maxdiff 0.0).

Extracted standalone onto main from the split_decode_speed stack
(cherry-pick of 09603789), stacked on the fused-gate-GEMM commit. The two
topk kernel files apply verbatim; the qwen3next _moe_ffn_tp hunk is adapted
to this branch's inline shared-expert sigmoid gate (the in-place sigmoid
touches a disjoint column, so the strided router-slice view stays valid).
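The disjoint-column argument can be checked with a small NumPy stand-in: an in-place sigmoid on the shared-gate column leaves a previously taken router-slice view untouched, even though both are views into the same buffer. The toy sizes are assumptions, and the slice assignment here only imitates the in-place `.sigmoid_()` call of the real code.

```python
import numpy as np

gate_up_cols, n_experts, padded = 12, 5, 24  # toy sizes
rng = np.random.default_rng(0)
fused_out = rng.standard_normal((4, padded)).astype(np.float32)

router_view = fused_out[:, gate_up_cols:gate_up_cols + n_experts]
router_before = router_view.copy()  # snapshot, used only for the comparison

gate_col = gate_up_cols + n_experts
shared_gate = fused_out[:, gate_col:gate_col + 1]
# In-place sigmoid on the single shared-gate column (stand-in for .sigmoid_()).
shared_gate[...] = 1.0 / (1.0 + np.exp(-shared_gate))

shares_buffer = np.shares_memory(router_view, fused_out)      # still a view
router_unchanged = np.array_equal(router_view, router_before)  # values intact
```

Because the two views cover different columns, they share the underlying buffer but no individual elements, which is why the write to one cannot disturb the other.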
@sufubao sufubao closed this Jun 16, 2026