From 28e3f756a933466f900f3f016674f567f11cd72d Mon Sep 17 00:00:00 2001 From: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> Date: Mon, 15 Jun 2026 09:43:16 -0700 Subject: [PATCH 1/4] perf(vllm): optimize MiniMax M3 MXFP8 EP routes --- .../fixed_seq_len/minimaxm3_fp8_mi300x.sh | 19 +- .../minimaxm3_mi300x_ep_mxfp8.patch | 842 ++++++++++++++++++ 2 files changed, 857 insertions(+), 4 deletions(-) create mode 100644 benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_ep_mxfp8.patch diff --git a/benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh b/benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh index 4fbe92bcd..ff5a2b1cd 100755 --- a/benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh +++ b/benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh @@ -2,10 +2,10 @@ # MiniMax-M3 MXFP8 MI300X (gfx942) single-node vLLM recipe. # Reuses the dedicated ROCm image and applies the checked-in hybrid gfx94x -# MXFP8 MoE patch before starting vLLM. Block size 128 is mandatory for MSA -# sparse attention. Keep the default BF16 KV cache on gfx942: the checkpoint -# has no calibrated q/prob scales for ROCm FP8 attention, and vLLM's fallback -# scale of 1.0 corrupts model accuracy. +# MXFP8 MoE patch plus the EP8 local-route optimization before starting vLLM. +# Block size 128 is mandatory for MSA sparse attention. Keep the default BF16 +# KV cache on gfx942: the checkpoint has no calibrated q/prob scales for ROCm +# FP8 attention, and vLLM's fallback scale of 1.0 corrupts model accuracy. # Target image vLLM revision: 4a560dd8db67c270f5e2afb614558271b76f2294. source "$(dirname "$0")/../../benchmark_lib.sh" @@ -36,6 +36,7 @@ print(Path(vllm.__file__).resolve().parent.parent) PY )" MXFP8_PATCH="$(dirname "$0")/minimaxm3_mi300x_mxfp8.patch" +MXFP8_EP_PATCH="$(dirname "$0")/minimaxm3_mi300x_ep_mxfp8.patch" MXFP8_ORACLE="$VLLM_PACKAGE_ROOT/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py" if ! grep -q "Using fused CDNA3 (gfx94x)" "$MXFP8_ORACLE"; then if ! patch --batch --forward -d "$VLLM_PACKAGE_ROOT" -p1 < "$MXFP8_PATCH"; then @@ -47,6 +48,16 @@ if ! grep -q "Using fused CDNA3 (gfx94x)" "$MXFP8_ORACLE"; then echo "MI300X MXFP8 backend marker is missing after patching" >&2 exit 1 fi +if ! grep -q "profiled gfx94x MiniMax-M3 EP8" "$MXFP8_ORACLE"; then + if ! patch --batch --forward -d "$VLLM_PACKAGE_ROOT" -p1 < "$MXFP8_EP_PATCH"; then + echo "Failed to apply the MI300X EP8 MXFP8 optimization patch" >&2 + exit 1 + fi +fi +if ! grep -q "profiled gfx94x MiniMax-M3 EP8" "$MXFP8_ORACLE"; then + echo "MI300X EP8 MXFP8 optimization marker is missing after patching" >&2 + exit 1 +fi if [[ "$MODEL" != /* ]]; then hf download "$MODEL"; fi diff --git a/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_ep_mxfp8.patch b/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_ep_mxfp8.patch new file mode 100644 index 000000000..e990a0121 --- /dev/null +++ b/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_ep_mxfp8.patch @@ -0,0 +1,842 @@ +diff --git a/tests/kernels/test_minimax_m3_amd_ops.py b/tests/kernels/test_minimax_m3_amd_ops.py +index 4119d7f5c..fe9c524f3 100644 +--- a/tests/kernels/test_minimax_m3_amd_ops.py ++++ b/tests/kernels/test_minimax_m3_amd_ops.py +@@ -112,7 +112,8 @@ def test_mxfp8_gfx94x_grouped_gemm_config( + ({}, True), + ({"max_model_len": 10240}, True), + ({"max_model_len": 0}, False), +- ({"ep_size": 8}, False), ++ ({"ep_size": 8}, True), ++ ({"ep_size": 4}, False), + ({"has_shared_experts": False}, False), + ({"experts_per_token": 8}, False), + ({"hidden_dim": 4096}, False), +@@ -165,45 +166,92 @@ def test_mxfp8_bf16_decode_fallback_disabled_on_gfx950(monkeypatch): + + + @pytest.mark.parametrize( +- ("num_tokens", "native_weights_available", "expected"), ++ ( ++ "num_tokens", ++ "native_weights_available", ++ "bf16_weights_available", ++ "expected", ++ ), + [ +- (1, True, True), +- (8, True, True), +- (9, True, False), +- (128, True, False), +- (831, True, False), +- (832, True, True), +- (1023, True, True), +- (128, False, True), ++ (1, True, True, True), ++ (8, True, True, True), ++ (9, True, True, False), ++ (128, True, True, False), ++ (831, True, True, False), ++ (832, True, True, True), ++ (1023, True, True, True), ++ (128, False, True, True), ++ (1, True, False, False), ++ (1023, True, False, False), + ], + ) + def test_mxfp8_bf16_expert_dispatch( + num_tokens, + native_weights_available, ++ bf16_weights_available, + expected, + ): + from vllm.model_executor.layers.fused_moe.experts.mxfp8_native_moe import ( + _should_use_bf16_experts, + ) + +- assert _should_use_bf16_experts(num_tokens, native_weights_available) is expected ++ assert ( ++ _should_use_bf16_experts( ++ num_tokens, ++ native_weights_available, ++ bf16_weights_available, ++ ) ++ is expected ++ ) ++ ++ ++@pytest.mark.parametrize( ++ ("max_model_len", "layer_index", "ep_size", "expected"), ++ [ ++ (2304, 0, 1, "dual"), ++ (2304, 0, 8, "dual"), ++ (9472, 0, 1, "bf16_only"), ++ (9472, 0, 8, "native_only"), ++ (9472, 1, 8, "dual"), ++ (9472, 5, 8, "native_only"), ++ ], ++) ++def test_mxfp8_weight_storage_policy(max_model_len, layer_index, ep_size, expected): ++ from vllm.model_executor.layers.fused_moe.experts.mxfp8_native_moe import ( ++ _mxfp8_weight_storage_policy, ++ ) ++ ++ assert _mxfp8_weight_storage_policy(max_model_len, layer_index, ep_size) == expected + + + @pytest.mark.parametrize( +- ("max_model_len", "layer_index", "expected"), ++ ("num_valid_tokens", "num_local_experts", "block_m", "allocation", "expected"), + [ +- (2304, 0, False), +- (9472, 0, True), +- (9472, 1, False), +- (9472, 5, True), ++ (864, 16, 16, 2784, 1104), ++ (64, 8, 16, 1024, 176), ++ (4, 128, 16, 64, 64), + ], + ) +-def test_mxfp8_bf16_only_storage_policy(max_model_len, layer_index, expected): ++def test_mxfp8_max_post_padded( ++ num_valid_tokens, ++ num_local_experts, ++ block_m, ++ allocation, ++ expected, ++): + from vllm.model_executor.layers.fused_moe.experts.mxfp8_native_moe import ( +- _should_store_bf16_only, ++ _max_post_padded, + ) + +- assert _should_store_bf16_only(max_model_len, layer_index) is expected ++ assert ( ++ _max_post_padded( ++ num_valid_tokens, ++ num_local_experts, ++ block_m, ++ allocation, ++ ) ++ == expected ++ ) + + + @pytest.mark.skipif( +@@ -380,7 +428,17 @@ def test_mxfp8_native_linear(m, n, k): + # --------------------------------------------------------------------------- # + # Fused MXFP8 MoE grouped GEMM vs dequant-to-bf16 MoE math + # --------------------------------------------------------------------------- # +-def _ref_moe(x, w13, w2, topk_weights, topk_ids, alpha, beta, limit): ++def _ref_moe( ++ x, ++ w13, ++ w2, ++ topk_weights, ++ topk_ids, ++ alpha, ++ beta, ++ limit, ++ expert_map=None, ++): + T, H = x.shape + inter = w2.shape[-1] + top_k = topk_ids.shape[1] +@@ -388,6 +446,10 @@ def _ref_moe(x, w13, w2, topk_weights, topk_ids, alpha, beta, limit): + for t in range(T): + for j in range(top_k): + e = int(topk_ids[t, j].item()) ++ if expert_map is not None: ++ e = int(expert_map[e].item()) ++ if e < 0: ++ continue + g1 = x[t].float() @ w13[e].float().T # [2I] + gate = g1[:inter] + up = g1[inter:] +@@ -459,6 +521,98 @@ def test_mxfp8_native_moe(T, H, inter, E, top_k, pack_scales): + assert _relerr(got, ref) < 5e-2 + + ++@pytest.mark.skipif( ++ not current_platform.is_fp8_fnuz(), ++ reason="route-aware native EP optimization is specific to gfx94x.", ++) ++@torch.inference_mode() ++def test_mxfp8_native_moe_expert_parallel(): ++ from vllm.model_executor.layers.fused_moe.experts.mxfp8_native_moe import ( ++ fused_moe_mxfp8_native, ++ ) ++ ++ torch.manual_seed(0) ++ T, H, inter, global_e, local_e, top_k = 8, 256, 512, 8, 4, 2 ++ alpha, beta, limit = 1.702, 1.0, 7.0 ++ w13_bf16 = ( ++ torch.randn(local_e, 2 * inter, H, device=DEVICE, dtype=torch.bfloat16) * 0.1 ++ ) ++ w2_bf16 = torch.randn(local_e, H, inter, device=DEVICE, dtype=torch.bfloat16) * 0.1 ++ w13_fp8, w13_scale = _mxfp8_e4m3_quantize_torch( ++ w13_bf16, ++ is_sf_swizzled_layout=False, ++ ) ++ w2_fp8, w2_scale = _mxfp8_e4m3_quantize_torch( ++ w2_bf16, ++ is_sf_swizzled_layout=False, ++ ) ++ w13_fp8, w13_scale = normalize_mxfp8_e4m3fn_to_e4m3fnuz( ++ w13_fp8, ++ w13_scale, ++ ) ++ w2_fp8, w2_scale = normalize_mxfp8_e4m3fn_to_e4m3fnuz( ++ w2_fp8, ++ w2_scale, ++ ) ++ w13_deq = dequant_mxfp8_to_bf16(w13_fp8, w13_scale) ++ w2_deq = dequant_mxfp8_to_bf16(w2_fp8, w2_scale) ++ w13_scale = w13_scale.transpose(1, 2).contiguous() ++ w2_scale = w2_scale.transpose(1, 2).contiguous() ++ ++ expert_map = torch.tensor( ++ [0, -1, 1, -1, 2, -1, 3, -1], ++ device=DEVICE, ++ dtype=torch.int32, ++ ) ++ topk_ids = torch.tensor( ++ [ ++ [0, 1], ++ [2, 3], ++ [4, 6], ++ [1, 3], ++ [6, 7], ++ [0, 4], ++ [5, 7], ++ [2, 6], ++ ], ++ device=DEVICE, ++ dtype=torch.int32, ++ ) ++ topk_weights = torch.rand(T, top_k, device=DEVICE, dtype=torch.float32) ++ topk_weights /= topk_weights.sum(dim=-1, keepdim=True) ++ x = torch.randn(T, H, device=DEVICE, dtype=torch.bfloat16) * 0.5 ++ output = torch.empty_like(x) ++ ++ got = fused_moe_mxfp8_native( ++ x, ++ w13_fp8, ++ w13_scale, ++ w2_fp8, ++ w2_scale, ++ topk_weights, ++ topk_ids, ++ alpha=alpha, ++ beta=beta, ++ limit=limit, ++ global_num_experts=global_e, ++ expert_map=expert_map, ++ output=output, ++ ) ++ ref = _ref_moe( ++ x, ++ w13_deq, ++ w2_deq, ++ topk_weights, ++ topk_ids, ++ alpha, ++ beta, ++ limit, ++ expert_map=expert_map, ++ ) ++ assert got is output ++ assert _relerr(got, ref) < 5e-2 ++ ++ + # --------------------------------------------------------------------------- # + # MXFP8 linear emulation: BF16-at-load (default) vs per-step dequant + switch + # --------------------------------------------------------------------------- # +diff --git a/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py b/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py +index 9e0145ff9..244065006 100644 +--- a/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py ++++ b/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py +@@ -46,13 +46,16 @@ _BF16_DECODE_TOKEN_THRESHOLD = 8 + # between 827 and 843 tokens on MI300X. Keep the cutoff tile-aligned. + _BF16_PREFILL_TOKEN_THRESHOLD = 832 + _LONG_CONTEXT_BF16_ONLY_LAYER_STRIDE = 5 ++_WEIGHT_STORAGE_DUAL = "dual" ++_WEIGHT_STORAGE_BF16_ONLY = "bf16_only" ++_WEIGHT_STORAGE_NATIVE_ONLY = "native_only" + + + def _should_use_bf16_decode_fallback(moe_config: FusedMoEConfig) -> bool: +- """Limit BF16 fallback weights to the exact MiniMax-M3 TP shape.""" ++ """Limit mixed BF16/native storage to profiled MiniMax-M3 shapes.""" + return ( + current_platform.is_fp8_fnuz() +- and moe_config.ep_size == 1 ++ and moe_config.ep_size in (1, 8) + and moe_config.has_shared_experts + and moe_config.num_experts == 128 + and moe_config.experts_per_token == 4 +@@ -62,23 +65,45 @@ def _should_use_bf16_decode_fallback(moe_config: FusedMoEConfig) -> bool: + ) + + +-def _should_store_bf16_only(max_model_len: int, layer_index: int) -> bool: +- return ( +- max_model_len > 4096 and layer_index % _LONG_CONTEXT_BF16_ONLY_LAYER_STRIDE == 0 +- ) ++def _mxfp8_weight_storage_policy( ++ max_model_len: int, ++ layer_index: int, ++ ep_size: int, ++) -> str: ++ if max_model_len <= 4096 or layer_index % _LONG_CONTEXT_BF16_ONLY_LAYER_STRIDE != 0: ++ return _WEIGHT_STORAGE_DUAL ++ if ep_size > 1: ++ return _WEIGHT_STORAGE_NATIVE_ONLY ++ return _WEIGHT_STORAGE_BF16_ONLY + + + def _should_use_bf16_experts( + num_tokens: int, + native_weights_available: bool, ++ bf16_weights_available: bool, + ) -> bool: +- return ( ++ return bf16_weights_available and ( + not native_weights_available +- or num_tokens >= _BF16_PREFILL_TOKEN_THRESHOLD + or num_tokens <= _BF16_DECODE_TOKEN_THRESHOLD ++ or num_tokens >= _BF16_PREFILL_TOKEN_THRESHOLD + ) + + ++def _max_post_padded( ++ num_valid_tokens: int, ++ num_local_experts: int, ++ block_m: int, ++ allocation_size: int, ++) -> int: ++ """Static upper bound for a block-aligned local-expert route list.""" ++ max_padded = min( ++ allocation_size, ++ num_valid_tokens * block_m, ++ num_valid_tokens + num_local_experts * (block_m - 1), ++ ) ++ return (max_padded // block_m) * block_m ++ ++ + @triton.jit + def _mxfp8_grouped_gemm_dot_scaled_kernel( + a_ptr, +@@ -108,6 +133,7 @@ def _mxfp8_grouped_gemm_dot_scaled_kernel( + stride_cn, + A_DIV: tl.constexpr, + MUL_WEIGHT: tl.constexpr, ++ FUSE_TOPK: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +@@ -161,12 +187,14 @@ def _mxfp8_grouped_gemm_dot_scaled_kernel( + w = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0.0) + acc = acc * w[:, None] + +- c_ptrs = c_ptr + offs_token[:, None] * stride_cm + offs_n[None, :] * stride_cn +- tl.store( +- c_ptrs, +- acc.to(c_ptr.dtype.element_ty), +- mask=token_mask[:, None] & n_mask[None, :], +- ) ++ c_row = offs_token // top_k if FUSE_TOPK else offs_token ++ c_ptrs = c_ptr + c_row[:, None] * stride_cm + offs_n[None, :] * stride_cn ++ c_mask = token_mask[:, None] & n_mask[None, :] ++ result = acc.to(c_ptr.dtype.element_ty) ++ if FUSE_TOPK: ++ tl.atomic_add(c_ptrs, result, mask=c_mask, sem="relaxed") ++ else: ++ tl.store(c_ptrs, result, mask=c_mask) + + + @triton.jit +@@ -198,6 +226,7 @@ def _mxfp8_grouped_gemm_fnuz_kernel( + stride_cn, + A_DIV: tl.constexpr, + MUL_WEIGHT: tl.constexpr, ++ FUSE_TOPK: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +@@ -271,12 +300,14 @@ def _mxfp8_grouped_gemm_fnuz_kernel( + w = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0.0) + acc = acc * w[:, None] + +- c_ptrs = c_ptr + offs_token[:, None] * stride_cm + offs_n[None, :] * stride_cn +- tl.store( +- c_ptrs, +- acc.to(c_ptr.dtype.element_ty), +- mask=token_mask[:, None] & n_mask[None, :], +- ) ++ c_row = offs_token // top_k if FUSE_TOPK else offs_token ++ c_ptrs = c_ptr + c_row[:, None] * stride_cm + offs_n[None, :] * stride_cn ++ c_mask = token_mask[:, None] & n_mask[None, :] ++ result = acc.to(c_ptr.dtype.element_ty) ++ if FUSE_TOPK: ++ tl.atomic_add(c_ptrs, result, mask=c_mask, sem="relaxed") ++ else: ++ tl.store(c_ptrs, result, mask=c_mask) + + + def _gfx94x_grouped_gemm_config( +@@ -330,6 +361,9 @@ def _grouped_gemm_mxfp8( + block_n_override: int = 0, + block_k_override: int = 0, + num_warps_override: int = 0, ++ output: torch.Tensor | None = None, ++ fuse_topk: bool = False, ++ zero_nonlocal_output: bool = True, + ) -> torch.Tensor: + M_routed = num_valid_tokens + E, N, K = w.shape +@@ -363,12 +397,15 @@ def _grouped_gemm_mxfp8( + BLOCK_N = 128 + BLOCK_K = 128 + num_warps = 8 +- # moe_align_block_size allocates for the worst case where every expert is +- # active. At small batches that can be much larger than the number of +- # blocks that can contain valid assignments. Limit the launch to the +- # tighter static upper bound; the device-side num_post check handles the +- # remaining tail. +- max_post_padded = min(sorted_token_ids.shape[0], M_routed * block_m) ++ # With EP, the align buffer is sized from the global expert count even ++ # though only local assignments survive. Bound the launch by the local ++ # expert count; the device-side num_post check handles the remaining tail. ++ max_post_padded = _max_post_padded( ++ M_routed, ++ E, ++ block_m, ++ sorted_token_ids.shape[0], ++ ) + if block_n_override: + BLOCK_N = block_n_override + if block_k_override: +@@ -385,11 +422,32 @@ def _grouped_gemm_mxfp8( + m_blocks = triton.cdiv(max_post_padded, block_m) + n_blocks = triton.cdiv(N, BLOCK_N) + +- # Under expert parallelism (expert_map set) tokens routed to non-local +- # experts are dropped from sorted_token_ids, so their output rows are never +- # written. +- alloc = torch.zeros if expert_map is not None else torch.empty +- out = alloc((M_routed, N), dtype=out_dtype, device=a_q.device) ++ if fuse_topk: ++ if M_routed % top_k != 0: ++ raise ValueError( ++ f"Routed rows ({M_routed}) must be divisible by top_k ({top_k})." ++ ) ++ expected_shape = (M_routed // top_k, N) ++ if output is None: ++ out = torch.zeros(expected_shape, dtype=out_dtype, device=a_q.device) ++ else: ++ if output.shape != expected_shape or output.dtype != out_dtype: ++ raise ValueError( ++ "Fused top-k output must have shape/dtype " ++ f"{expected_shape}/{out_dtype}, got " ++ f"{tuple(output.shape)}/{output.dtype}." ++ ) ++ out = output ++ out.zero_() ++ else: ++ # EP rows that route to remote experts are unwritten. Dense consumers ++ # need zeros; route-aware consumers can skip the memset. ++ alloc = ( ++ torch.zeros ++ if expert_map is not None and zero_nonlocal_output ++ else torch.empty ++ ) ++ out = alloc((M_routed, N), dtype=out_dtype, device=a_q.device) + grid = (m_blocks, n_blocks) + kernel = ( + _mxfp8_grouped_gemm_fnuz_kernel +@@ -428,6 +486,7 @@ def _grouped_gemm_mxfp8( + out.stride(1), + A_DIV=a_div, + MUL_WEIGHT=mul_weight_by is not None, ++ FUSE_TOPK=fuse_topk, + BLOCK_M=block_m, + BLOCK_N=BLOCK_N, + BLOCK_K=BLOCK_K, +@@ -477,6 +536,7 @@ def fused_moe_mxfp8_native( + expert_map, + ignore_invalid_experts=expert_map is not None, + ) ++ use_sparse_ep_path = expert_map is not None and current_platform.is_fp8_fnuz() + + # GEMM1: x (mxfp8) @ w13^T -> [M, 2I] + a_q, a_s = mxfp8_e4m3_quantize(hidden_states) +@@ -497,6 +557,7 @@ def fused_moe_mxfp8_native( + block_n_override=g1_block_n, + block_k_override=g1_block_k, + num_warps_override=g1_num_warps, ++ zero_nonlocal_output=not use_sparse_ep_path, + ) # [M, 2I] + + # SwiGLU-OAI (split layout: gate=g1[:, :I], up=g1[:, I:]) FUSED with the +@@ -506,10 +567,62 @@ def fused_moe_mxfp8_native( + # ``silu_and_mul_with_clamp`` op: it rounds intermediates to bf16, rel ~3e-3.) + # Lazy import: the amd.ops package pulls in the minimax_m3 platform dispatch, + # only resolvable after the model module finishes loading. +- from vllm.models.minimax_m3.amd.ops import swiglu_oai_quantize_mxfp8 ++ from vllm.models.minimax_m3.amd.ops import ( ++ swiglu_oai_quantize_mxfp8, ++ swiglu_oai_quantize_mxfp8_routed, ++ ) + + # GEMM2: act (mxfp8) @ w2^T -> [M, H], weighted by topk_weights, then reduce. +- act_q, act_s = swiglu_oai_quantize_mxfp8(g1, alpha=alpha, beta=beta, limit=limit) ++ if use_sparse_ep_path: ++ max_post_padded = _max_post_padded( ++ M, ++ w13.shape[0], ++ block_m, ++ sorted_ids.shape[0], ++ ) ++ act_q, act_s = swiglu_oai_quantize_mxfp8_routed( ++ g1, ++ sorted_ids, ++ num_post, ++ num_valid_tokens=M, ++ max_num_tokens_post_padded=max_post_padded, ++ alpha=alpha, ++ beta=beta, ++ limit=limit, ++ block_m=block_m, ++ ) ++ else: ++ act_q, act_s = swiglu_oai_quantize_mxfp8( ++ g1, ++ alpha=alpha, ++ beta=beta, ++ limit=limit, ++ ) ++ ++ if use_sparse_ep_path: ++ return _grouped_gemm_mxfp8( ++ act_q, ++ act_s, ++ w2, ++ w2_scale, ++ sorted_ids, ++ expert_ids, ++ num_post, ++ M, ++ top_k, ++ block_m, ++ hidden_states.dtype, ++ a_div=1, ++ mul_weight_by=topk_weights.reshape(-1).to(torch.float32), ++ expert_map=expert_map, ++ is_gemm2=True, ++ block_n_override=g2_block_n, ++ block_k_override=g2_block_k, ++ num_warps_override=g2_num_warps, ++ output=output, ++ fuse_topk=True, ++ ) ++ + g2 = _grouped_gemm_mxfp8( + act_q, + act_s, +@@ -556,6 +669,7 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): + self.w1_bf16: torch.Tensor | None = None + self.w2_bf16: torch.Tensor | None = None + self.native_weights_available = True ++ self.bf16_weights_available = False + self.bf16_experts: TritonExperts | None = None + if _should_use_bf16_decode_fallback(moe_config): + bf16_config = biased_moe_quant_config( +@@ -583,6 +697,7 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): + self.w1_bf16 = w1_bf16 + self.w2_bf16 = w2_bf16 + self.native_weights_available = native_weights_available ++ self.bf16_weights_available = True + + def bind_packed_weight_scales( + self, +@@ -636,6 +751,7 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): + if bf16_experts is not None and _should_use_bf16_experts( + num_tokens, + self.native_weights_available, ++ self.bf16_weights_available, + ): + if self.w1_bf16 is None or self.w2_bf16 is None: + raise RuntimeError( +diff --git a/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py b/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py +index bc00da41e..e7c2c26b0 100644 +--- a/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py ++++ b/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py +@@ -82,6 +82,19 @@ def _select_rocm_mxfp8_backend( + """ROCm fallback when vendor MXFP8 backends are unavailable.""" + + if current_platform.is_fp8_fnuz() and config.ep_size > 1: ++ from vllm.model_executor.layers.fused_moe.experts.mxfp8_native_moe import ( ++ Mxfp8NativeTritonExperts, ++ _should_use_bf16_decode_fallback, ++ ) ++ ++ if _should_use_bf16_decode_fallback(config): ++ logger.info_once( ++ "Using the profiled gfx94x MiniMax-M3 EP8 MXFP8 backend: " ++ "native local-route kernels for decode and retained BF16 " ++ "experts for large prefill." ++ ) ++ return Fp8MoeBackend.NATIVE_MXFP8, Mxfp8NativeTritonExperts ++ + from vllm.model_executor.layers.fused_moe.experts.mxfp8_emulation_moe import ( + Mxfp8EmulationTritonExperts, + ) +diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py +index 1a4049d3f..9417f2ebd 100644 +--- a/vllm/model_executor/layers/quantization/modelopt.py ++++ b/vllm/model_executor/layers/quantization/modelopt.py +@@ -2132,10 +2132,12 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): + ) + + def _retain_bf16_fallback_weights(self, layer: RoutedExperts) -> None: +- """Keep the BF16 weights selected by the gfx94x TP dispatch policy.""" ++ """Configure profiled gfx94x MiniMax-M3 native/BF16 weight storage.""" + from vllm.model_executor.layers.fused_moe.experts.mxfp8_native_moe import ( ++ _WEIGHT_STORAGE_BF16_ONLY, ++ _WEIGHT_STORAGE_NATIVE_ONLY, + Mxfp8NativeTritonExperts, +- _should_store_bf16_only, ++ _mxfp8_weight_storage_policy, + ) + from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( + dequant_mxfp8_to_bf16, +@@ -2150,6 +2152,20 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): + "Expected Mxfp8NativeTritonExperts for the gfx94x native backend." + ) + ++ layer_index = extract_layer_index(layer.layer_name) ++ storage_policy = _mxfp8_weight_storage_policy( ++ self.moe.max_model_len, ++ layer_index, ++ self.moe.ep_size, ++ ) ++ if storage_policy == _WEIGHT_STORAGE_NATIVE_ONLY: ++ logger.info_once( ++ "Using native-only MXFP8 storage for one-fifth of long-context " ++ "gfx94x EP MoE layers; the remaining layers retain BF16 for " ++ "large-prefill dispatch." ++ ) ++ return ++ + target_dtype = getattr(layer, "orig_dtype", torch.bfloat16) + w13_bf16 = dequant_mxfp8_to_bf16(layer.w13_weight, layer.w13_weight_scale).to( + target_dtype +@@ -2157,12 +2173,8 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): + w2_bf16 = dequant_mxfp8_to_bf16(layer.w2_weight, layer.w2_weight_scale).to( + target_dtype + ) +- layer_index = extract_layer_index(layer.layer_name) +- store_bf16_only = _should_store_bf16_only( +- self.moe.max_model_len, +- layer_index, +- ) + ++ store_bf16_only = storage_policy == _WEIGHT_STORAGE_BF16_ONLY + if store_bf16_only: + replace_parameter(layer, "w13_weight", w13_bf16) + replace_parameter(layer, "w2_weight", w2_bf16) +@@ -2177,10 +2189,10 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): + + if self.moe.max_model_len <= 4096: + logger.info_once( +- "Retaining BF16 MXFP8 MoE weights for gfx94x TP decode and " +- "prefill dispatch." ++ "Retaining native MXFP8 and BF16 MoE weights for profiled " ++ "gfx94x MiniMax-M3 dispatch." + ) +- else: ++ elif self.moe.ep_size == 1: + logger.info_once( + "Using BF16-only storage for one-fifth of gfx94x TP MoE " + "layers and retaining both MXFP8 and BF16 weights for the " +diff --git a/vllm/models/minimax_m3/amd/ops/__init__.py b/vllm/models/minimax_m3/amd/ops/__init__.py +index 22d96f9de..a2672b05d 100644 +--- a/vllm/models/minimax_m3/amd/ops/__init__.py ++++ b/vllm/models/minimax_m3/amd/ops/__init__.py +@@ -13,6 +13,7 @@ from vllm.models.minimax_m3.amd.ops.gemma_rmsnorm import ( + ) + from vllm.models.minimax_m3.amd.ops.swiglu_oai import ( + swiglu_oai_quantize_mxfp8, ++ swiglu_oai_quantize_mxfp8_routed, + swiglu_oai_split, + ) + +@@ -21,4 +22,5 @@ __all__ = [ + "gemma_fused_add_rmsnorm", + "swiglu_oai_split", + "swiglu_oai_quantize_mxfp8", ++ "swiglu_oai_quantize_mxfp8_routed", + ] +diff --git a/vllm/models/minimax_m3/amd/ops/swiglu_oai.py b/vllm/models/minimax_m3/amd/ops/swiglu_oai.py +index 9572c5109..7fca053fb 100644 +--- a/vllm/models/minimax_m3/amd/ops/swiglu_oai.py ++++ b/vllm/models/minimax_m3/amd/ops/swiglu_oai.py +@@ -124,6 +124,77 @@ def _swiglu_oai_quant_kernel( + ) + + ++@triton.jit ++def _swiglu_oai_quant_routed_kernel( ++ g_ptr, ++ aq_ptr, ++ as_ptr, ++ sorted_token_ids_ptr, ++ num_tokens_post_padded_ptr, ++ num_valid_tokens, ++ n_inter, ++ stride_gm, ++ stride_gn, ++ stride_qm, ++ stride_qn, ++ stride_sm, ++ stride_sk, ++ alpha, ++ beta, ++ limit, ++ HAS_LIMIT: tl.constexpr, ++ BLOCK_M: tl.constexpr, ++): ++ """Route-aware SwiGLU-OAI + MXFP8 quantization for expert parallelism.""" ++ pid_m = tl.program_id(0) ++ pid_b = tl.program_id(1) ++ num_post = tl.load(num_tokens_post_padded_ptr) ++ if pid_m * BLOCK_M >= num_post: ++ return ++ ++ offs_tid = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) ++ route_mask = offs_tid < num_post ++ route_ids = tl.load( ++ sorted_token_ids_ptr + offs_tid, ++ mask=route_mask, ++ other=num_valid_tokens, ++ ).to(tl.int64) ++ route_mask = route_mask & (route_ids < num_valid_tokens) ++ safe_route_ids = tl.where(route_mask, route_ids, 0) ++ offs_c = pid_b * 32 + tl.arange(0, 32) ++ ++ gate = tl.load( ++ g_ptr + safe_route_ids[:, None] * stride_gm + offs_c[None, :] * stride_gn, ++ mask=route_mask[:, None], ++ other=0.0, ++ ).to(tl.float32) ++ up = tl.load( ++ g_ptr ++ + safe_route_ids[:, None] * stride_gm ++ + (n_inter + offs_c)[None, :] * stride_gn, ++ mask=route_mask[:, None], ++ other=0.0, ++ ).to(tl.float32) ++ if HAS_LIMIT: ++ gate = tl.minimum(gate, limit) ++ up = tl.minimum(tl.maximum(up, -limit), limit) ++ act = gate * tl.sigmoid(alpha * gate) * (up + beta) ++ amax = tl.maximum(tl.max(tl.abs(act), axis=1), 1e-30) ++ sb = tl.minimum(tl.maximum(tl.floor(tl.log2(amax)) + 127.0, 0.0), 254.0) ++ descale = tl.exp2(sb - 127.0) ++ aq = (act / descale[:, None]).to(aq_ptr.dtype.element_ty) ++ tl.store( ++ aq_ptr + safe_route_ids[:, None] * stride_qm + offs_c[None, :] * stride_qn, ++ aq, ++ mask=route_mask[:, None], ++ ) ++ tl.store( ++ as_ptr + safe_route_ids * stride_sm + pid_b * stride_sk, ++ sb.to(tl.uint8), ++ mask=route_mask, ++ ) ++ ++ + def swiglu_oai_quantize_mxfp8( + gate_up: torch.Tensor, + alpha: float, +@@ -181,6 +252,79 @@ def swiglu_oai_quantize_mxfp8( + return aq, asc + + ++def swiglu_oai_quantize_mxfp8_routed( ++ gate_up: torch.Tensor, ++ sorted_token_ids: torch.Tensor, ++ num_tokens_post_padded: torch.Tensor, ++ *, ++ num_valid_tokens: int, ++ max_num_tokens_post_padded: int, ++ alpha: float, ++ beta: float, ++ limit: float | None, ++ block_m: int, ++) -> tuple[torch.Tensor, torch.Tensor]: ++ """Quantize only the locally routed rows produced by an EP GEMM1.""" ++ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( ++ MXFP8_BLOCK_SIZE, ++ MXFP8_SCALE_DTYPE, ++ MXFP8_VALUE_DTYPE, ++ ) ++ ++ two_i = gate_up.shape[-1] ++ n_inter = two_i // 2 ++ if n_inter % MXFP8_BLOCK_SIZE != 0: ++ raise ValueError( ++ f"fused swiglu+quant needs I % {MXFP8_BLOCK_SIZE} == 0, got I={n_inter}" ++ ) ++ if max_num_tokens_post_padded % block_m != 0: ++ raise ValueError( ++ "max_num_tokens_post_padded must be block aligned, got " ++ f"{max_num_tokens_post_padded} for block_m={block_m}." ++ ) ++ ++ g1 = gate_up.reshape(-1, two_i).contiguous() ++ value_dtype = ( ++ torch.float8_e4m3fnuz if current_platform.is_fp8_fnuz() else MXFP8_VALUE_DTYPE ++ ) ++ aq = torch.empty( ++ (num_valid_tokens, n_inter), ++ dtype=value_dtype, ++ device=g1.device, ++ ) ++ asc = torch.empty( ++ (num_valid_tokens, n_inter // MXFP8_BLOCK_SIZE), ++ dtype=MXFP8_SCALE_DTYPE, ++ device=g1.device, ++ ) ++ grid = ( ++ max_num_tokens_post_padded // block_m, ++ n_inter // MXFP8_BLOCK_SIZE, ++ ) ++ _swiglu_oai_quant_routed_kernel[grid]( ++ g1, ++ aq, ++ asc, ++ sorted_token_ids, ++ num_tokens_post_padded, ++ num_valid_tokens, ++ n_inter, ++ g1.stride(0), ++ g1.stride(1), ++ aq.stride(0), ++ aq.stride(1), ++ asc.stride(0), ++ asc.stride(1), ++ float(alpha), ++ float(beta), ++ 0.0 if limit is None else float(limit), ++ HAS_LIMIT=limit is not None, ++ BLOCK_M=block_m, ++ num_warps=4, ++ ) ++ return aq, asc ++ ++ + def swiglu_oai_split( + gate_up: torch.Tensor, + alpha: float, From 8279f5020c568450ffeb93dbb4db008b8342fbb0 Mon Sep 17 00:00:00 2001 From: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> Date: Mon, 15 Jun 2026 09:49:20 -0700 Subject: [PATCH 2/4] fix(vllm): exclude tests from runtime patch --- .../minimaxm3_mi300x_ep_mxfp8.patch | 254 ------------------ 1 file changed, 254 deletions(-) diff --git a/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_ep_mxfp8.patch b/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_ep_mxfp8.patch index e990a0121..fb65eb059 100644 --- a/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_ep_mxfp8.patch +++ b/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_ep_mxfp8.patch @@ -1,257 +1,3 @@ -diff --git a/tests/kernels/test_minimax_m3_amd_ops.py b/tests/kernels/test_minimax_m3_amd_ops.py -index 4119d7f5c..fe9c524f3 100644 ---- a/tests/kernels/test_minimax_m3_amd_ops.py -+++ b/tests/kernels/test_minimax_m3_amd_ops.py -@@ -112,7 +112,8 @@ def test_mxfp8_gfx94x_grouped_gemm_config( - ({}, True), - ({"max_model_len": 10240}, True), - ({"max_model_len": 0}, False), -- ({"ep_size": 8}, False), -+ ({"ep_size": 8}, True), -+ ({"ep_size": 4}, False), - ({"has_shared_experts": False}, False), - ({"experts_per_token": 8}, False), - ({"hidden_dim": 4096}, False), -@@ -165,45 +166,92 @@ def test_mxfp8_bf16_decode_fallback_disabled_on_gfx950(monkeypatch): - - - @pytest.mark.parametrize( -- ("num_tokens", "native_weights_available", "expected"), -+ ( -+ "num_tokens", -+ "native_weights_available", -+ "bf16_weights_available", -+ "expected", -+ ), - [ -- (1, True, True), -- (8, True, True), -- (9, True, False), -- (128, True, False), -- (831, True, False), -- (832, True, True), -- (1023, True, True), -- (128, False, True), -+ (1, True, True, True), -+ (8, True, True, True), -+ (9, True, True, False), -+ (128, True, True, False), -+ (831, True, True, False), -+ (832, True, True, True), -+ (1023, True, True, True), -+ (128, False, True, True), -+ (1, True, False, False), -+ (1023, True, False, False), - ], - ) - def test_mxfp8_bf16_expert_dispatch( - num_tokens, - native_weights_available, -+ bf16_weights_available, - expected, - ): - from vllm.model_executor.layers.fused_moe.experts.mxfp8_native_moe import ( - _should_use_bf16_experts, - ) - -- assert _should_use_bf16_experts(num_tokens, native_weights_available) is expected -+ assert ( -+ _should_use_bf16_experts( -+ num_tokens, -+ native_weights_available, -+ bf16_weights_available, -+ ) -+ is expected -+ ) -+ -+ -+@pytest.mark.parametrize( -+ ("max_model_len", "layer_index", "ep_size", "expected"), -+ [ -+ (2304, 0, 1, "dual"), -+ (2304, 0, 8, "dual"), -+ (9472, 0, 1, "bf16_only"), -+ (9472, 0, 8, "native_only"), -+ (9472, 1, 8, "dual"), -+ (9472, 5, 8, "native_only"), -+ ], -+) -+def test_mxfp8_weight_storage_policy(max_model_len, layer_index, ep_size, expected): -+ from vllm.model_executor.layers.fused_moe.experts.mxfp8_native_moe import ( -+ _mxfp8_weight_storage_policy, -+ ) -+ -+ assert _mxfp8_weight_storage_policy(max_model_len, layer_index, ep_size) == expected - - - @pytest.mark.parametrize( -- ("max_model_len", "layer_index", "expected"), -+ ("num_valid_tokens", "num_local_experts", "block_m", "allocation", "expected"), - [ -- (2304, 0, False), -- (9472, 0, True), -- (9472, 1, False), -- (9472, 5, True), -+ (864, 16, 16, 2784, 1104), -+ (64, 8, 16, 1024, 176), -+ (4, 128, 16, 64, 64), - ], - ) --def test_mxfp8_bf16_only_storage_policy(max_model_len, layer_index, expected): -+def test_mxfp8_max_post_padded( -+ num_valid_tokens, -+ num_local_experts, -+ block_m, -+ allocation, -+ expected, -+): - from vllm.model_executor.layers.fused_moe.experts.mxfp8_native_moe import ( -- _should_store_bf16_only, -+ _max_post_padded, - ) - -- assert _should_store_bf16_only(max_model_len, layer_index) is expected -+ assert ( -+ _max_post_padded( -+ num_valid_tokens, -+ num_local_experts, -+ block_m, -+ allocation, -+ ) -+ == expected -+ ) - - - @pytest.mark.skipif( -@@ -380,7 +428,17 @@ def test_mxfp8_native_linear(m, n, k): - # --------------------------------------------------------------------------- # - # Fused MXFP8 MoE grouped GEMM vs dequant-to-bf16 MoE math - # --------------------------------------------------------------------------- # --def _ref_moe(x, w13, w2, topk_weights, topk_ids, alpha, beta, limit): -+def _ref_moe( -+ x, -+ w13, -+ w2, -+ topk_weights, -+ topk_ids, -+ alpha, -+ beta, -+ limit, -+ expert_map=None, -+): - T, H = x.shape - inter = w2.shape[-1] - top_k = topk_ids.shape[1] -@@ -388,6 +446,10 @@ def _ref_moe(x, w13, w2, topk_weights, topk_ids, alpha, beta, limit): - for t in range(T): - for j in range(top_k): - e = int(topk_ids[t, j].item()) -+ if expert_map is not None: -+ e = int(expert_map[e].item()) -+ if e < 0: -+ continue - g1 = x[t].float() @ w13[e].float().T # [2I] - gate = g1[:inter] - up = g1[inter:] -@@ -459,6 +521,98 @@ def test_mxfp8_native_moe(T, H, inter, E, top_k, pack_scales): - assert _relerr(got, ref) < 5e-2 - - -+@pytest.mark.skipif( -+ not current_platform.is_fp8_fnuz(), -+ reason="route-aware native EP optimization is specific to gfx94x.", -+) -+@torch.inference_mode() -+def test_mxfp8_native_moe_expert_parallel(): -+ from vllm.model_executor.layers.fused_moe.experts.mxfp8_native_moe import ( -+ fused_moe_mxfp8_native, -+ ) -+ -+ torch.manual_seed(0) -+ T, H, inter, global_e, local_e, top_k = 8, 256, 512, 8, 4, 2 -+ alpha, beta, limit = 1.702, 1.0, 7.0 -+ w13_bf16 = ( -+ torch.randn(local_e, 2 * inter, H, device=DEVICE, dtype=torch.bfloat16) * 0.1 -+ ) -+ w2_bf16 = torch.randn(local_e, H, inter, device=DEVICE, dtype=torch.bfloat16) * 0.1 -+ w13_fp8, w13_scale = _mxfp8_e4m3_quantize_torch( -+ w13_bf16, -+ is_sf_swizzled_layout=False, -+ ) -+ w2_fp8, w2_scale = _mxfp8_e4m3_quantize_torch( -+ w2_bf16, -+ is_sf_swizzled_layout=False, -+ ) -+ w13_fp8, w13_scale = normalize_mxfp8_e4m3fn_to_e4m3fnuz( -+ w13_fp8, -+ w13_scale, -+ ) -+ w2_fp8, w2_scale = normalize_mxfp8_e4m3fn_to_e4m3fnuz( -+ w2_fp8, -+ w2_scale, -+ ) -+ w13_deq = dequant_mxfp8_to_bf16(w13_fp8, w13_scale) -+ w2_deq = dequant_mxfp8_to_bf16(w2_fp8, w2_scale) -+ w13_scale = w13_scale.transpose(1, 2).contiguous() -+ w2_scale = w2_scale.transpose(1, 2).contiguous() -+ -+ expert_map = torch.tensor( -+ [0, -1, 1, -1, 2, -1, 3, -1], -+ device=DEVICE, -+ dtype=torch.int32, -+ ) -+ topk_ids = torch.tensor( -+ [ -+ [0, 1], -+ [2, 3], -+ [4, 6], -+ [1, 3], -+ [6, 7], -+ [0, 4], -+ [5, 7], -+ [2, 6], -+ ], -+ device=DEVICE, -+ dtype=torch.int32, -+ ) -+ topk_weights = torch.rand(T, top_k, device=DEVICE, dtype=torch.float32) -+ topk_weights /= topk_weights.sum(dim=-1, keepdim=True) -+ x = torch.randn(T, H, device=DEVICE, dtype=torch.bfloat16) * 0.5 -+ output = torch.empty_like(x) -+ -+ got = fused_moe_mxfp8_native( -+ x, -+ w13_fp8, -+ w13_scale, -+ w2_fp8, -+ w2_scale, -+ topk_weights, -+ topk_ids, -+ alpha=alpha, -+ beta=beta, -+ limit=limit, -+ global_num_experts=global_e, -+ expert_map=expert_map, -+ output=output, -+ ) -+ ref = _ref_moe( -+ x, -+ w13_deq, -+ w2_deq, -+ topk_weights, -+ topk_ids, -+ alpha, -+ beta, -+ limit, -+ expert_map=expert_map, -+ ) -+ assert got is output -+ assert _relerr(got, ref) < 5e-2 -+ -+ - # --------------------------------------------------------------------------- # - # MXFP8 linear emulation: BF16-at-load (default) vs per-step dequant + switch - # --------------------------------------------------------------------------- # diff --git a/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py b/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py index 9e0145ff9..244065006 100644 --- a/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py From b25eff5c035c4f7b40caadb372b7c58b7fe2aef3 Mon Sep 17 00:00:00 2001 From: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> Date: Mon, 15 Jun 2026 10:53:54 -0700 Subject: [PATCH 3/4] perf(vllm): keep MiniMax M3 EP weights compressed --- .../minimaxm3_mi300x_ep_mxfp8.patch | 124 ++++-------------- 1 file changed, 24 insertions(+), 100 deletions(-) diff --git a/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_ep_mxfp8.patch b/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_ep_mxfp8.patch index fb65eb059..ed22709e6 100644 --- a/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_ep_mxfp8.patch +++ b/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_ep_mxfp8.patch @@ -1,19 +1,14 @@ diff --git a/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py b/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py -index 9e0145ff9..244065006 100644 +index 9e0145ff9..21187a2ac 100644 --- a/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py -@@ -46,13 +46,16 @@ _BF16_DECODE_TOKEN_THRESHOLD = 8 - # between 827 and 843 tokens on MI300X. Keep the cutoff tile-aligned. - _BF16_PREFILL_TOKEN_THRESHOLD = 832 +@@ -48,11 +48,10 @@ _BF16_PREFILL_TOKEN_THRESHOLD = 832 _LONG_CONTEXT_BF16_ONLY_LAYER_STRIDE = 5 -+_WEIGHT_STORAGE_DUAL = "dual" -+_WEIGHT_STORAGE_BF16_ONLY = "bf16_only" -+_WEIGHT_STORAGE_NATIVE_ONLY = "native_only" - def _should_use_bf16_decode_fallback(moe_config: FusedMoEConfig) -> bool: +-def _should_use_bf16_decode_fallback(moe_config: FusedMoEConfig) -> bool: - """Limit BF16 fallback weights to the exact MiniMax-M3 TP shape.""" -+ """Limit mixed BF16/native storage to profiled MiniMax-M3 shapes.""" ++def _is_profiled_minimax_m3_config(moe_config: FusedMoEConfig) -> bool: return ( current_platform.is_fp8_fnuz() - and moe_config.ep_size == 1 @@ -21,26 +16,24 @@ index 9e0145ff9..244065006 100644 and moe_config.has_shared_experts and moe_config.num_experts == 128 and moe_config.experts_per_token == 4 -@@ -62,23 +65,45 @@ def _should_use_bf16_decode_fallback(moe_config: FusedMoEConfig) -> bool: +@@ -62,6 +61,16 @@ def _should_use_bf16_decode_fallback(moe_config: FusedMoEConfig) -> bool: ) --def _should_store_bf16_only(max_model_len: int, layer_index: int) -> bool: -- return ( -- max_model_len > 4096 and layer_index % _LONG_CONTEXT_BF16_ONLY_LAYER_STRIDE == 0 -- ) -+def _mxfp8_weight_storage_policy( -+ max_model_len: int, -+ layer_index: int, -+ ep_size: int, -+) -> str: -+ if max_model_len <= 4096 or layer_index % _LONG_CONTEXT_BF16_ONLY_LAYER_STRIDE != 0: -+ return _WEIGHT_STORAGE_DUAL -+ if ep_size > 1: -+ return _WEIGHT_STORAGE_NATIVE_ONLY -+ return _WEIGHT_STORAGE_BF16_ONLY - - ++def _should_use_bf16_decode_fallback(moe_config: FusedMoEConfig) -> bool: ++ """Retain mixed BF16/native weights only for MiniMax-M3 tensor parallelism.""" ++ return _is_profiled_minimax_m3_config(moe_config) and moe_config.ep_size == 1 ++ ++ ++def _should_use_native_ep(moe_config: FusedMoEConfig) -> bool: ++ """Use compressed native weights for the profiled MiniMax-M3 EP8 shape.""" ++ return _is_profiled_minimax_m3_config(moe_config) and moe_config.ep_size == 8 ++ ++ + def _should_store_bf16_only(max_model_len: int, layer_index: int) -> bool: + return ( + max_model_len > 4096 and layer_index % _LONG_CONTEXT_BF16_ONLY_LAYER_STRIDE == 0 +@@ -71,14 +80,30 @@ def _should_store_bf16_only(max_model_len: int, layer_index: int) -> bool: def _should_use_bf16_experts( num_tokens: int, native_weights_available: bool, @@ -315,97 +308,28 @@ index 9e0145ff9..244065006 100644 if self.w1_bf16 is None or self.w2_bf16 is None: raise RuntimeError( diff --git a/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py b/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py -index bc00da41e..e7c2c26b0 100644 +index bc00da41e..13db06d37 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py +++ b/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py -@@ -82,6 +82,19 @@ def _select_rocm_mxfp8_backend( +@@ -82,6 +82,18 @@ def _select_rocm_mxfp8_backend( """ROCm fallback when vendor MXFP8 backends are unavailable.""" if current_platform.is_fp8_fnuz() and config.ep_size > 1: + from vllm.model_executor.layers.fused_moe.experts.mxfp8_native_moe import ( + Mxfp8NativeTritonExperts, -+ _should_use_bf16_decode_fallback, ++ _should_use_native_ep, + ) + -+ if _should_use_bf16_decode_fallback(config): ++ if _should_use_native_ep(config): + logger.info_once( + "Using the profiled gfx94x MiniMax-M3 EP8 MXFP8 backend: " -+ "native local-route kernels for decode and retained BF16 " -+ "experts for large prefill." ++ "native local-route kernels with compressed-only expert weights." + ) + return Fp8MoeBackend.NATIVE_MXFP8, Mxfp8NativeTritonExperts + from vllm.model_executor.layers.fused_moe.experts.mxfp8_emulation_moe import ( Mxfp8EmulationTritonExperts, ) -diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py -index 1a4049d3f..9417f2ebd 100644 ---- a/vllm/model_executor/layers/quantization/modelopt.py -+++ b/vllm/model_executor/layers/quantization/modelopt.py -@@ -2132,10 +2132,12 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): - ) - - def _retain_bf16_fallback_weights(self, layer: RoutedExperts) -> None: -- """Keep the BF16 weights selected by the gfx94x TP dispatch policy.""" -+ """Configure profiled gfx94x MiniMax-M3 native/BF16 weight storage.""" - from vllm.model_executor.layers.fused_moe.experts.mxfp8_native_moe import ( -+ _WEIGHT_STORAGE_BF16_ONLY, -+ _WEIGHT_STORAGE_NATIVE_ONLY, - Mxfp8NativeTritonExperts, -- _should_store_bf16_only, -+ _mxfp8_weight_storage_policy, - ) - from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( - dequant_mxfp8_to_bf16, -@@ -2150,6 +2152,20 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): - "Expected Mxfp8NativeTritonExperts for the gfx94x native backend." - ) - -+ layer_index = extract_layer_index(layer.layer_name) -+ storage_policy = _mxfp8_weight_storage_policy( -+ self.moe.max_model_len, -+ layer_index, -+ self.moe.ep_size, -+ ) -+ if storage_policy == _WEIGHT_STORAGE_NATIVE_ONLY: -+ logger.info_once( -+ "Using native-only MXFP8 storage for one-fifth of long-context " -+ "gfx94x EP MoE layers; the remaining layers retain BF16 for " -+ "large-prefill dispatch." -+ ) -+ return -+ - target_dtype = getattr(layer, "orig_dtype", torch.bfloat16) - w13_bf16 = dequant_mxfp8_to_bf16(layer.w13_weight, layer.w13_weight_scale).to( - target_dtype -@@ -2157,12 +2173,8 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): - w2_bf16 = dequant_mxfp8_to_bf16(layer.w2_weight, layer.w2_weight_scale).to( - target_dtype - ) -- layer_index = extract_layer_index(layer.layer_name) -- store_bf16_only = _should_store_bf16_only( -- self.moe.max_model_len, -- layer_index, -- ) - -+ store_bf16_only = storage_policy == _WEIGHT_STORAGE_BF16_ONLY - if store_bf16_only: - replace_parameter(layer, "w13_weight", w13_bf16) - replace_parameter(layer, "w2_weight", w2_bf16) -@@ -2177,10 +2189,10 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): - - if self.moe.max_model_len <= 4096: - logger.info_once( -- "Retaining BF16 MXFP8 MoE weights for gfx94x TP decode and " -- "prefill dispatch." -+ "Retaining native MXFP8 and BF16 MoE weights for profiled " -+ "gfx94x MiniMax-M3 dispatch." - ) -- else: -+ elif self.moe.ep_size == 1: - logger.info_once( - "Using BF16-only storage for one-fifth of gfx94x TP MoE " - "layers and retaining both MXFP8 and BF16 weights for the " diff --git a/vllm/models/minimax_m3/amd/ops/__init__.py b/vllm/models/minimax_m3/amd/ops/__init__.py index 22d96f9de..a2672b05d 100644 --- a/vllm/models/minimax_m3/amd/ops/__init__.py From 16c596a2ee42b43dbc08c330dac404caeeb6fb2b Mon Sep 17 00:00:00 2001 From: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> Date: Mon, 15 Jun 2026 12:00:21 -0700 Subject: [PATCH 4/4] perf(vllm): fuse MiniMax M3 BF16 EP experts --- .../fixed_seq_len/minimaxm3_fp8_mi300x.sh | 3 +- .../minimaxm3_mi300x_ep_mxfp8.patch | 564 +++++++++++++++++- 2 files changed, 542 insertions(+), 25 deletions(-) diff --git a/benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh b/benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh index ff5a2b1cd..3f7f75117 100755 --- a/benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh +++ b/benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh @@ -2,7 +2,8 @@ # MiniMax-M3 MXFP8 MI300X (gfx942) single-node vLLM recipe. # Reuses the dedicated ROCm image and applies the checked-in hybrid gfx94x -# MXFP8 MoE patch plus the EP8 local-route optimization before starting vLLM. +# MXFP8 MoE patch. Short-context EP8 uses the measured native/BF16 policy; +# long-context EP8 uses sparse local-route BF16 GEMMs with fused SwiGLU. # Block size 128 is mandatory for MSA sparse attention. Keep the default BF16 # KV cache on gfx942: the checkpoint has no calibrated q/prob scales for ROCm # FP8 attention, and vLLM's fallback scale of 1.0 corrupts model accuracy. diff --git a/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_ep_mxfp8.patch b/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_ep_mxfp8.patch index ed22709e6..3557a793a 100644 --- a/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_ep_mxfp8.patch +++ b/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_ep_mxfp8.patch @@ -1,5 +1,205 @@ +diff --git a/vllm/model_executor/layers/fused_moe/experts/mxfp8_emulation_moe.py b/vllm/model_executor/layers/fused_moe/experts/mxfp8_emulation_moe.py +index 63500487d..a6002b696 100644 +--- a/vllm/model_executor/layers/fused_moe/experts/mxfp8_emulation_moe.py ++++ b/vllm/model_executor/layers/fused_moe/experts/mxfp8_emulation_moe.py +@@ -11,11 +11,21 @@ import torch + + import vllm.model_executor.layers.fused_moe.modular_kernel as mk + from vllm.logger import init_logger ++from vllm.model_executor.layers.fused_moe.activation import MoEActivation + from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEQuantConfig, + ) + from vllm.model_executor.layers.fused_moe.experts.triton_moe import TritonExperts ++from vllm.model_executor.layers.fused_moe.fused_moe import ( ++ _prepare_expert_assignment, ++ invoke_fused_moe_gated_triton_kernel, ++ invoke_fused_moe_triton_kernel, ++) ++from vllm.model_executor.layers.fused_moe.moe_fused_mul_sum import ( ++ moe_fused_mul_sum, ++) ++from vllm.model_executor.layers.fused_moe.utils import _resize_cache + from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( + dequant_mxfp8_to_bf16, + ) +@@ -24,9 +34,35 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( + kMxfp8Dynamic, + kMxfp8Static, + ) ++from vllm.platforms import current_platform ++from vllm.triton_utils import tl + + logger = init_logger(__name__) + ++_MINIMAX_M3_MI300X_EP_BF16_CONFIG = { ++ "BLOCK_SIZE_M": 16, ++ "BLOCK_SIZE_N": 64, ++ "BLOCK_SIZE_K": 128, ++ "GROUP_SIZE_M": 1, ++ "SPLIT_K": 1, ++ "num_warps": 4, ++ "num_stages": 2, ++} ++ ++ ++def _is_minimax_m3_mi300x_ep8(moe_config: FusedMoEConfig) -> bool: ++ """Match the profiled MiniMax-M3 EP8 shape on gfx94x.""" ++ return ( ++ current_platform.is_fp8_fnuz() ++ and moe_config.ep_size == 8 ++ and moe_config.has_shared_experts ++ and moe_config.num_experts == 128 ++ and moe_config.experts_per_token == 4 ++ and moe_config.hidden_dim == 6144 ++ and moe_config.intermediate_size == 3072 ++ and moe_config.max_model_len > 0 ++ ) ++ + + class Mxfp8TritonExpertsBase(TritonExperts): + """Shared MXFP8 MoE setup: stash E8M0 scales, clear scales on ``quant_config``.""" +@@ -67,6 +103,7 @@ class Mxfp8EmulationTritonExperts(Mxfp8TritonExpertsBase): + quant_config: FusedMoEQuantConfig, + ): + super().__init__(moe_config, quant_config) ++ self.use_sparse_mi300x_ep = _is_minimax_m3_mi300x_ep8(moe_config) + logger.warning_once( + "Using Mxfp8EmulationTritonExperts MoE backend. Weights are " + "dequantized to BF16 on the fly; this is slower than a native " +@@ -122,6 +159,99 @@ class Mxfp8EmulationTritonExperts(Mxfp8TritonExpertsBase): + return + super().activation(activation, output, input) + ++ def _apply_sparse_mi300x_ep( ++ self, ++ output: torch.Tensor, ++ hidden_states: torch.Tensor, ++ w1: torch.Tensor, ++ w2: torch.Tensor, ++ topk_weights: torch.Tensor, ++ topk_ids: torch.Tensor, ++ global_num_experts: int, ++ expert_map: torch.Tensor, ++ workspace13: torch.Tensor, ++ workspace2: torch.Tensor, ++ ) -> None: ++ """Run only local EP routes with a fused BF16 GEMM1 SwiGLU epilogue.""" ++ E, num_tokens, N, K, top_k_num = self.moe_problem_size( ++ hidden_states, w1, w2, topk_ids ++ ) ++ if global_num_experts == -1: ++ global_num_experts = expert_map.numel() ++ config = _MINIMAX_M3_MI300X_EP_BF16_CONFIG ++ sorted_token_ids, expert_ids, num_tokens_post_padded = ( ++ _prepare_expert_assignment( ++ topk_ids, ++ config, ++ num_tokens, ++ top_k_num, ++ global_num_experts, ++ expert_map, ++ ignore_invalid_experts=True, ++ num_local_experts=E, ++ ) ++ ) ++ assert sorted_token_ids is not None ++ ++ activation_dim = N // 2 ++ intermediate_activation = _resize_cache( ++ workspace13, ++ (num_tokens * top_k_num, activation_dim), ++ ) ++ intermediate_output = _resize_cache( ++ workspace2, ++ (num_tokens, top_k_num, K), ++ ) ++ ++ alpha = self.quant_config.gemm1_alpha ++ alpha = 1.702 if alpha is None else float(alpha) ++ beta = self.quant_config.gemm1_beta ++ beta = 1.0 if beta is None else float(beta) ++ limit = self.quant_config.gemm1_clamp_limit ++ limit = None if limit is None else float(limit) ++ ++ invoke_fused_moe_gated_triton_kernel( ++ hidden_states, ++ w1, ++ intermediate_activation, ++ sorted_token_ids, ++ expert_ids, ++ num_tokens_post_padded, ++ top_k_num, ++ config, ++ alpha, ++ beta, ++ limit, ++ ) ++ invoke_fused_moe_triton_kernel( ++ intermediate_activation, ++ w2, ++ intermediate_output, ++ None, ++ None, ++ topk_weights, ++ sorted_token_ids, ++ expert_ids, ++ num_tokens_post_padded, ++ True, ++ 1, ++ config, ++ compute_type=tl.bfloat16, ++ use_fp8_w8a8=False, ++ use_int8_w8a8=False, ++ use_int8_w8a16=False, ++ use_int4_w4a16=False, ++ per_channel_quant=False, ++ ) ++ moe_fused_mul_sum( ++ intermediate_output, ++ topk_weights, ++ outputs=output, ++ topk_ids=topk_ids, ++ expert_map=expert_map, ++ apply_weights=False, ++ ) ++ + def apply( + self, + output: torch.Tensor, +@@ -157,6 +287,29 @@ class Mxfp8EmulationTritonExperts(Mxfp8TritonExpertsBase): + hidden_states.dtype + ) + ++ use_sparse_ep = ( ++ self.use_sparse_mi300x_ep ++ and hidden_states.dtype == torch.bfloat16 ++ and activation == MoEActivation.SWIGLUOAI_UNINTERLEAVE ++ and expert_map is not None ++ and not apply_router_weight_on_input ++ and getattr(self, "_lora_context", None) is None ++ ) ++ if use_sparse_ep: ++ self._apply_sparse_mi300x_ep( ++ output=output, ++ hidden_states=hidden_states, ++ w1=w1_bf16, ++ w2=w2_bf16, ++ topk_weights=topk_weights, ++ topk_ids=topk_ids, ++ global_num_experts=global_num_experts, ++ expert_map=expert_map, ++ workspace13=workspace13, ++ workspace2=workspace2, ++ ) ++ return ++ + super().apply( + output=output, + hidden_states=hidden_states, diff --git a/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py b/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py -index 9e0145ff9..21187a2ac 100644 +index 9e0145ff9..010b53d5f 100644 --- a/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py @@ -48,11 +48,10 @@ _BF16_PREFILL_TOKEN_THRESHOLD = 832 @@ -16,24 +216,31 @@ index 9e0145ff9..21187a2ac 100644 and moe_config.has_shared_experts and moe_config.num_experts == 128 and moe_config.experts_per_token == 4 -@@ -62,6 +61,16 @@ def _should_use_bf16_decode_fallback(moe_config: FusedMoEConfig) -> bool: +@@ -62,6 +61,23 @@ def _should_use_bf16_decode_fallback(moe_config: FusedMoEConfig) -> bool: ) +def _should_use_bf16_decode_fallback(moe_config: FusedMoEConfig) -> bool: -+ """Retain mixed BF16/native weights only for MiniMax-M3 tensor parallelism.""" -+ return _is_profiled_minimax_m3_config(moe_config) and moe_config.ep_size == 1 ++ """Retain mixed weights where the profiled native/BF16 dispatch wins.""" ++ return _is_profiled_minimax_m3_config(moe_config) and ( ++ moe_config.ep_size == 1 ++ or (moe_config.ep_size == 8 and moe_config.max_model_len <= 4096) ++ ) + + +def _should_use_native_ep(moe_config: FusedMoEConfig) -> bool: -+ """Use compressed native weights for the profiled MiniMax-M3 EP8 shape.""" -+ return _is_profiled_minimax_m3_config(moe_config) and moe_config.ep_size == 8 ++ """Use mixed native/BF16 experts for profiled short-context EP8.""" ++ return ( ++ _is_profiled_minimax_m3_config(moe_config) ++ and moe_config.ep_size == 8 ++ and moe_config.max_model_len <= 4096 ++ ) + + def _should_store_bf16_only(max_model_len: int, layer_index: int) -> bool: return ( max_model_len > 4096 and layer_index % _LONG_CONTEXT_BF16_ONLY_LAYER_STRIDE == 0 -@@ -71,14 +80,30 @@ def _should_store_bf16_only(max_model_len: int, layer_index: int) -> bool: +@@ -71,14 +87,30 @@ def _should_store_bf16_only(max_model_len: int, layer_index: int) -> bool: def _should_use_bf16_experts( num_tokens: int, native_weights_available: bool, @@ -66,7 +273,7 @@ index 9e0145ff9..21187a2ac 100644 @triton.jit def _mxfp8_grouped_gemm_dot_scaled_kernel( a_ptr, -@@ -108,6 +133,7 @@ def _mxfp8_grouped_gemm_dot_scaled_kernel( +@@ -108,6 +140,7 @@ def _mxfp8_grouped_gemm_dot_scaled_kernel( stride_cn, A_DIV: tl.constexpr, MUL_WEIGHT: tl.constexpr, @@ -74,7 +281,7 @@ index 9e0145ff9..21187a2ac 100644 BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, -@@ -161,12 +187,14 @@ def _mxfp8_grouped_gemm_dot_scaled_kernel( +@@ -161,12 +194,14 @@ def _mxfp8_grouped_gemm_dot_scaled_kernel( w = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0.0) acc = acc * w[:, None] @@ -95,7 +302,7 @@ index 9e0145ff9..21187a2ac 100644 @triton.jit -@@ -198,6 +226,7 @@ def _mxfp8_grouped_gemm_fnuz_kernel( +@@ -198,6 +233,7 @@ def _mxfp8_grouped_gemm_fnuz_kernel( stride_cn, A_DIV: tl.constexpr, MUL_WEIGHT: tl.constexpr, @@ -103,7 +310,7 @@ index 9e0145ff9..21187a2ac 100644 BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, -@@ -271,12 +300,14 @@ def _mxfp8_grouped_gemm_fnuz_kernel( +@@ -271,12 +307,14 @@ def _mxfp8_grouped_gemm_fnuz_kernel( w = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0.0) acc = acc * w[:, None] @@ -124,7 +331,7 @@ index 9e0145ff9..21187a2ac 100644 def _gfx94x_grouped_gemm_config( -@@ -330,6 +361,9 @@ def _grouped_gemm_mxfp8( +@@ -330,6 +368,9 @@ def _grouped_gemm_mxfp8( block_n_override: int = 0, block_k_override: int = 0, num_warps_override: int = 0, @@ -134,7 +341,7 @@ index 9e0145ff9..21187a2ac 100644 ) -> torch.Tensor: M_routed = num_valid_tokens E, N, K = w.shape -@@ -363,12 +397,15 @@ def _grouped_gemm_mxfp8( +@@ -363,12 +404,15 @@ def _grouped_gemm_mxfp8( BLOCK_N = 128 BLOCK_K = 128 num_warps = 8 @@ -156,7 +363,7 @@ index 9e0145ff9..21187a2ac 100644 if block_n_override: BLOCK_N = block_n_override if block_k_override: -@@ -385,11 +422,32 @@ def _grouped_gemm_mxfp8( +@@ -385,11 +429,32 @@ def _grouped_gemm_mxfp8( m_blocks = triton.cdiv(max_post_padded, block_m) n_blocks = triton.cdiv(N, BLOCK_N) @@ -194,7 +401,7 @@ index 9e0145ff9..21187a2ac 100644 grid = (m_blocks, n_blocks) kernel = ( _mxfp8_grouped_gemm_fnuz_kernel -@@ -428,6 +486,7 @@ def _grouped_gemm_mxfp8( +@@ -428,6 +493,7 @@ def _grouped_gemm_mxfp8( out.stride(1), A_DIV=a_div, MUL_WEIGHT=mul_weight_by is not None, @@ -202,15 +409,17 @@ index 9e0145ff9..21187a2ac 100644 BLOCK_M=block_m, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, -@@ -477,6 +536,7 @@ def fused_moe_mxfp8_native( +@@ -476,7 +542,9 @@ def fused_moe_mxfp8_native( + global_num_experts, expert_map, ignore_invalid_experts=expert_map is not None, ++ num_local_experts=w13.shape[0] if expert_map is not None else None, ) + use_sparse_ep_path = expert_map is not None and current_platform.is_fp8_fnuz() # GEMM1: x (mxfp8) @ w13^T -> [M, 2I] a_q, a_s = mxfp8_e4m3_quantize(hidden_states) -@@ -497,6 +557,7 @@ def fused_moe_mxfp8_native( +@@ -497,6 +565,7 @@ def fused_moe_mxfp8_native( block_n_override=g1_block_n, block_k_override=g1_block_k, num_warps_override=g1_num_warps, @@ -218,7 +427,7 @@ index 9e0145ff9..21187a2ac 100644 ) # [M, 2I] # SwiGLU-OAI (split layout: gate=g1[:, :I], up=g1[:, I:]) FUSED with the -@@ -506,10 +567,62 @@ def fused_moe_mxfp8_native( +@@ -506,10 +575,62 @@ def fused_moe_mxfp8_native( # ``silu_and_mul_with_clamp`` op: it rounds intermediates to bf16, rel ~3e-3.) # Lazy import: the amd.ops package pulls in the minimax_m3 platform dispatch, # only resolvable after the model module finishes loading. @@ -283,7 +492,7 @@ index 9e0145ff9..21187a2ac 100644 g2 = _grouped_gemm_mxfp8( act_q, act_s, -@@ -556,6 +669,7 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): +@@ -556,6 +677,7 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): self.w1_bf16: torch.Tensor | None = None self.w2_bf16: torch.Tensor | None = None self.native_weights_available = True @@ -291,7 +500,7 @@ index 9e0145ff9..21187a2ac 100644 self.bf16_experts: TritonExperts | None = None if _should_use_bf16_decode_fallback(moe_config): bf16_config = biased_moe_quant_config( -@@ -583,6 +697,7 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): +@@ -583,6 +705,7 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): self.w1_bf16 = w1_bf16 self.w2_bf16 = w2_bf16 self.native_weights_available = native_weights_available @@ -299,7 +508,7 @@ index 9e0145ff9..21187a2ac 100644 def bind_packed_weight_scales( self, -@@ -636,6 +751,7 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): +@@ -636,6 +759,7 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): if bf16_experts is not None and _should_use_bf16_experts( num_tokens, self.native_weights_available, @@ -307,11 +516,293 @@ index 9e0145ff9..21187a2ac 100644 ): if self.w1_bf16 is None or self.w2_bf16 is None: raise RuntimeError( +diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py +index 49957c8f5..26b3c6f74 100644 +--- a/vllm/model_executor/layers/fused_moe/fused_moe.py ++++ b/vllm/model_executor/layers/fused_moe/fused_moe.py +@@ -553,6 +553,168 @@ def fused_moe_kernel( + tl.store(c_ptrs, accumulator, mask=c_mask) + + ++@triton.jit ++def fused_moe_gated_kernel( ++ a_ptr, ++ b_ptr, ++ c_ptr, ++ sorted_token_ids_ptr, ++ expert_ids_ptr, ++ num_tokens_post_padded_ptr, ++ N, ++ K, ++ EM, ++ num_valid_tokens, ++ stride_am, ++ stride_ak, ++ stride_be, ++ stride_bk, ++ stride_bn, ++ stride_cm, ++ stride_cn, ++ alpha, ++ beta, ++ limit, ++ BLOCK_SIZE_M: tl.constexpr, ++ BLOCK_SIZE_N: tl.constexpr, ++ BLOCK_SIZE_K: tl.constexpr, ++ GROUP_SIZE_M: tl.constexpr, ++ top_k: tl.constexpr, ++ compute_type: tl.constexpr, ++ HAS_LIMIT: tl.constexpr, ++): ++ """BF16 grouped GEMM1 with a split SwiGLU-OAI epilogue.""" ++ pid = tl.program_id(axis=0) ++ num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) ++ num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) ++ num_pid_in_group = GROUP_SIZE_M * num_pid_n ++ group_id = pid // num_pid_in_group ++ first_pid_m = group_id * GROUP_SIZE_M ++ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) ++ pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) ++ pid_n = (pid % num_pid_in_group) // group_size_m ++ ++ offs_m = tl.arange(0, BLOCK_SIZE_M).to(tl.int64) ++ num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) ++ if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: ++ return ++ ++ offs_token_id = pid_m * BLOCK_SIZE_M + offs_m ++ offs_token = tl.load(sorted_token_ids_ptr + offs_token_id).to(tl.int64) ++ token_mask = offs_token < num_valid_tokens ++ off_expert = tl.load(expert_ids_ptr + pid_m).to(tl.int64) ++ if off_expert < 0: ++ return ++ ++ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64) ++ offs_k = tl.arange(0, BLOCK_SIZE_K) ++ n_mask = offs_n < N ++ a_ptrs = a_ptr + ( ++ offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak ++ ) ++ b_gate_ptrs = ( ++ b_ptr ++ + off_expert * stride_be ++ + offs_k[:, None] * stride_bk ++ + offs_n[None, :] * stride_bn ++ ) ++ b_up_ptrs = b_gate_ptrs + N * stride_bn ++ ++ gate_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) ++ up_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) ++ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): ++ k_mask = offs_k < K - k * BLOCK_SIZE_K ++ a = tl.load( ++ a_ptrs, ++ mask=token_mask[:, None] & k_mask[None, :], ++ other=0.0, ++ ) ++ gate_w = tl.load( ++ b_gate_ptrs, ++ mask=k_mask[:, None] & n_mask[None, :], ++ other=0.0, ++ ) ++ up_w = tl.load( ++ b_up_ptrs, ++ mask=k_mask[:, None] & n_mask[None, :], ++ other=0.0, ++ ) ++ gate_acc += tl.dot(a, gate_w) ++ up_acc += tl.dot(a, up_w) ++ a_ptrs += BLOCK_SIZE_K * stride_ak ++ b_gate_ptrs += BLOCK_SIZE_K * stride_bk ++ b_up_ptrs += BLOCK_SIZE_K * stride_bk ++ ++ # Preserve the BF16 GEMM1 store/reload boundary of the unfused path before ++ # applying SwiGLU in FP32. ++ gate = gate_acc.to(compute_type).to(tl.float32) ++ up = up_acc.to(compute_type).to(tl.float32) ++ if HAS_LIMIT: ++ gate = tl.minimum(gate, limit) ++ up = tl.minimum(tl.maximum(up, -limit), limit) ++ activated = gate * tl.sigmoid(alpha * gate) * (up + beta) ++ ++ c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_n[None, :] ++ c_mask = token_mask[:, None] & n_mask[None, :] ++ tl.store(c_ptrs, activated.to(compute_type), mask=c_mask) ++ ++ ++def invoke_fused_moe_gated_triton_kernel( ++ A: torch.Tensor, ++ B: torch.Tensor, ++ C: torch.Tensor, ++ sorted_token_ids: torch.Tensor, ++ expert_ids: torch.Tensor, ++ num_tokens_post_padded: torch.Tensor, ++ top_k: int, ++ config: dict[str, Any], ++ alpha: float, ++ beta: float, ++ limit: float | None, ++) -> None: ++ """Run BF16 GEMM1 and emit activated split-SwiGLU rows directly.""" ++ assert A.dtype == B.dtype == C.dtype == torch.bfloat16 ++ assert B.size(1) == C.size(1) * 2 ++ assert A.size(1) == B.size(2) ++ assert sorted_token_ids.stride(0) == 1 ++ ++ EM = sorted_token_ids.size(0) ++ grid = lambda META: ( ++ triton.cdiv(EM, META["BLOCK_SIZE_M"]) ++ * triton.cdiv(C.size(1), META["BLOCK_SIZE_N"]), ++ ) ++ launch_config = config.copy() ++ launch_config.pop("SPLIT_K", None) ++ block_size_k = launch_config.pop("BLOCK_SIZE_K") ++ fused_moe_gated_kernel[grid]( ++ A, ++ B, ++ C, ++ sorted_token_ids, ++ expert_ids, ++ num_tokens_post_padded, ++ C.size(1), ++ A.size(1), ++ EM, ++ A.size(0) * top_k, ++ A.stride(0), ++ A.stride(1), ++ B.stride(0), ++ B.stride(2), ++ B.stride(1), ++ C.stride(0), ++ C.stride(1), ++ alpha, ++ beta, ++ 0.0 if limit is None else limit, ++ top_k=top_k, ++ compute_type=tl.bfloat16, ++ HAS_LIMIT=limit is not None, ++ BLOCK_SIZE_K=block_size_k, ++ **launch_config, ++ ) ++ ++ + # NOTE(zyongye): we can remove all the wna16 kernel + # once we drop off sm75 support + def invoke_fused_moe_wna16_cuda_kernel( +@@ -1434,6 +1596,7 @@ def _prepare_expert_assignment( + use_int4_w4a16: bool = False, + block_shape: list[int] | None = None, + ignore_invalid_experts: bool = False, ++ num_local_experts: int | None = None, + ) -> tuple[torch.Tensor | None, torch.Tensor, torch.Tensor]: + """Prepare expert assignments for the aligned and low-latency Triton paths.""" + # SPARSITY_FACTOR is a heuristic margin ensuring tokens_in_chunk * top_k +@@ -1468,6 +1631,7 @@ def _prepare_expert_assignment( + global_num_experts, + expert_map, + ignore_invalid_experts=ignore_invalid_experts, ++ num_local_experts=num_local_experts, + ) + + +diff --git a/vllm/model_executor/layers/fused_moe/moe_align_block_size.py b/vllm/model_executor/layers/fused_moe/moe_align_block_size.py +index 7fc8bfcf8..99404ed95 100644 +--- a/vllm/model_executor/layers/fused_moe/moe_align_block_size.py ++++ b/vllm/model_executor/layers/fused_moe/moe_align_block_size.py +@@ -15,6 +15,7 @@ def moe_align_block_size( + expert_map: torch.Tensor | None = None, + pad_sorted_ids: bool = False, + ignore_invalid_experts: bool = False, ++ num_local_experts: int | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Aligns the token distribution across experts to be compatible with block +@@ -43,6 +44,9 @@ def moe_align_block_size( + as -1. When True, all invalid expert_ids in topk_ids will be ignored + and will not participate in counting or ranking, and there will be no + -1 in expert_ids. ++ - num_local_experts: The number of experts retained by ``expert_map``. ++ When invalid experts are ignored, this tightens the output allocation ++ from the global expert count to the local expert count. + + Returns: + - sorted_token_ids: A tensor containing the sorted token indices according +@@ -71,7 +75,20 @@ def moe_align_block_size( + - The padding ensures that the total number of tokens is now divisible + by block_size for proper block matrix operations. + """ +- max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) ++ padding_experts = num_experts ++ if ( ++ ignore_invalid_experts ++ and expert_map is not None ++ and num_local_experts is not None ++ ): ++ if not 0 < num_local_experts <= num_experts: ++ raise ValueError( ++ "num_local_experts must be in (0, num_experts], got " ++ f"{num_local_experts} for num_experts={num_experts}." ++ ) ++ padding_experts = num_local_experts ++ ++ max_num_tokens_padded = topk_ids.numel() + padding_experts * (block_size - 1) + if pad_sorted_ids: + max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) + if topk_ids.numel() < num_experts: +diff --git a/vllm/model_executor/layers/fused_moe/moe_fused_mul_sum.py b/vllm/model_executor/layers/fused_moe/moe_fused_mul_sum.py +index 768f41db8..0465112ee 100644 +--- a/vllm/model_executor/layers/fused_moe/moe_fused_mul_sum.py ++++ b/vllm/model_executor/layers/fused_moe/moe_fused_mul_sum.py +@@ -17,6 +17,7 @@ def moe_fused_mul_sum_kernel( + num_tokens, + stride_m, + has_expert_map: tl.constexpr, ++ apply_weights: tl.constexpr, + top_k: tl.constexpr, + size: tl.constexpr, + BLOCK_M: tl.constexpr, +@@ -38,7 +39,10 @@ def moe_fused_mul_sum_kernel( + acc = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32) + + for n in tl.static_range(top_k): +- b_val = tl.load(b_base + n, mask=m_mask, other=0.0).to(tl.float32) ++ if apply_weights: ++ b_val = tl.load(b_base + n, mask=m_mask, other=0.0).to(tl.float32) ++ else: ++ b_val = 1.0 + if has_expert_map: + id_val = tl.load(top_ids_ptr + offs_m * top_k + n, mask=m_mask, other=0) + expert_mask = tl.load(expert_map_ptr + id_val) >= 0 +@@ -138,6 +142,7 @@ def moe_fused_mul_sum( + outputs: torch.Tensor | None = None, + topk_ids: torch.Tensor | None = None, + expert_map: torch.Tensor | None = None, ++ apply_weights: bool = True, + ) -> torch.Tensor: + """ + Fused kernel for MoE (Mixture of Experts) to perform weighted summation +@@ -154,6 +159,8 @@ def moe_fused_mul_sum( + `expert_map` is provided. Shape: (num_tokens, top_k). + expert_map: Optional mapping for Expert Parallelism. A value < 0 + indicates an invalid token/expert pair that will be skipped. ++ apply_weights: Multiply each route by ``topk_weights`` before summing. ++ Set to false when the expert GEMM already applied router weights. + + Returns: + The fused weighted sum of expert outputs. +@@ -191,6 +198,7 @@ def moe_fused_mul_sum( + num_tokens, + top_k * size, + expert_map is not None, ++ apply_weights, + top_k, + size, + BLOCK_M, diff --git a/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py b/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py -index bc00da41e..13db06d37 100644 +index bc00da41e..378b897a7 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py +++ b/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py -@@ -82,6 +82,18 @@ def _select_rocm_mxfp8_backend( +@@ -82,14 +82,25 @@ def _select_rocm_mxfp8_backend( """ROCm fallback when vendor MXFP8 backends are unavailable.""" if current_platform.is_fp8_fnuz() and config.ep_size > 1: @@ -323,13 +814,38 @@ index bc00da41e..13db06d37 100644 + if _should_use_native_ep(config): + logger.info_once( + "Using the profiled gfx94x MiniMax-M3 EP8 MXFP8 backend: " -+ "native local-route kernels with compressed-only expert weights." ++ "native local-route kernels with retained BF16 decode experts." + ) + return Fp8MoeBackend.NATIVE_MXFP8, Mxfp8NativeTritonExperts + from vllm.model_executor.layers.fused_moe.experts.mxfp8_emulation_moe import ( Mxfp8EmulationTritonExperts, ) + + logger.info_once( +- "Using BF16 MXFP8 emulation for gfx94x expert parallelism; the " +- "native CDNA3 path is optimized for decode-sized TP workloads and " +- "is slower for the large local batches reached during EP prefill." ++ "Using the profiled sparse BF16 MXFP8 path for long-context gfx94x " ++ "expert parallelism." + ) + return Fp8MoeBackend.EMULATION, Mxfp8EmulationTritonExperts + +diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py +index 1a4049d3f..8ea843e7d 100644 +--- a/vllm/model_executor/layers/quantization/modelopt.py ++++ b/vllm/model_executor/layers/quantization/modelopt.py +@@ -2177,8 +2177,8 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): + + if self.moe.max_model_len <= 4096: + logger.info_once( +- "Retaining BF16 MXFP8 MoE weights for gfx94x TP decode and " +- "prefill dispatch." ++ "Retaining BF16 MXFP8 MoE weights for profiled gfx94x " ++ "native/BF16 dispatch." + ) + else: + logger.info_once( diff --git a/vllm/models/minimax_m3/amd/ops/__init__.py b/vllm/models/minimax_m3/amd/ops/__init__.py index 22d96f9de..a2672b05d 100644 --- a/vllm/models/minimax_m3/amd/ops/__init__.py