diff --git a/benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh b/benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh index 4fbe92bcd..3f7f75117 100755 --- a/benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh +++ b/benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh @@ -2,10 +2,11 @@ # MiniMax-M3 MXFP8 MI300X (gfx942) single-node vLLM recipe. # Reuses the dedicated ROCm image and applies the checked-in hybrid gfx94x -# MXFP8 MoE patch before starting vLLM. Block size 128 is mandatory for MSA -# sparse attention. Keep the default BF16 KV cache on gfx942: the checkpoint -# has no calibrated q/prob scales for ROCm FP8 attention, and vLLM's fallback -# scale of 1.0 corrupts model accuracy. +# MXFP8 MoE patch. Short-context EP8 uses the measured native/BF16 policy; +# long-context EP8 uses sparse local-route BF16 GEMMs with fused SwiGLU. +# Block size 128 is mandatory for MSA sparse attention. Keep the default BF16 +# KV cache on gfx942: the checkpoint has no calibrated q/prob scales for ROCm +# FP8 attention, and vLLM's fallback scale of 1.0 corrupts model accuracy. # Target image vLLM revision: 4a560dd8db67c270f5e2afb614558271b76f2294. source "$(dirname "$0")/../../benchmark_lib.sh" @@ -36,6 +37,7 @@ print(Path(vllm.__file__).resolve().parent.parent) PY )" MXFP8_PATCH="$(dirname "$0")/minimaxm3_mi300x_mxfp8.patch" +MXFP8_EP_PATCH="$(dirname "$0")/minimaxm3_mi300x_ep_mxfp8.patch" MXFP8_ORACLE="$VLLM_PACKAGE_ROOT/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py" if ! grep -q "Using fused CDNA3 (gfx94x)" "$MXFP8_ORACLE"; then if ! patch --batch --forward -d "$VLLM_PACKAGE_ROOT" -p1 < "$MXFP8_PATCH"; then @@ -47,6 +49,16 @@ if ! grep -q "Using fused CDNA3 (gfx94x)" "$MXFP8_ORACLE"; then echo "MI300X MXFP8 backend marker is missing after patching" >&2 exit 1 fi +if ! grep -q "profiled gfx94x MiniMax-M3 EP8" "$MXFP8_ORACLE"; then + if ! patch --batch --forward -d "$VLLM_PACKAGE_ROOT" -p1 < "$MXFP8_EP_PATCH"; then + echo "Failed to apply the MI300X EP8 MXFP8 optimization patch" >&2 + exit 1 + fi +fi +if ! grep -q "profiled gfx94x MiniMax-M3 EP8" "$MXFP8_ORACLE"; then + echo "MI300X EP8 MXFP8 optimization marker is missing after patching" >&2 + exit 1 +fi if [[ "$MODEL" != /* ]]; then hf download "$MODEL"; fi diff --git a/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_ep_mxfp8.patch b/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_ep_mxfp8.patch new file mode 100644 index 000000000..3557a793a --- /dev/null +++ b/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_ep_mxfp8.patch @@ -0,0 +1,1028 @@ +diff --git a/vllm/model_executor/layers/fused_moe/experts/mxfp8_emulation_moe.py b/vllm/model_executor/layers/fused_moe/experts/mxfp8_emulation_moe.py +index 63500487d..a6002b696 100644 +--- a/vllm/model_executor/layers/fused_moe/experts/mxfp8_emulation_moe.py ++++ b/vllm/model_executor/layers/fused_moe/experts/mxfp8_emulation_moe.py +@@ -11,11 +11,21 @@ import torch + + import vllm.model_executor.layers.fused_moe.modular_kernel as mk + from vllm.logger import init_logger ++from vllm.model_executor.layers.fused_moe.activation import MoEActivation + from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEQuantConfig, + ) + from vllm.model_executor.layers.fused_moe.experts.triton_moe import TritonExperts ++from vllm.model_executor.layers.fused_moe.fused_moe import ( ++ _prepare_expert_assignment, ++ invoke_fused_moe_gated_triton_kernel, ++ invoke_fused_moe_triton_kernel, ++) ++from vllm.model_executor.layers.fused_moe.moe_fused_mul_sum import ( ++ moe_fused_mul_sum, ++) ++from vllm.model_executor.layers.fused_moe.utils import _resize_cache + from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( + dequant_mxfp8_to_bf16, + ) +@@ -24,9 +34,35 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( + kMxfp8Dynamic, + kMxfp8Static, + ) ++from vllm.platforms import current_platform ++from vllm.triton_utils import tl + + logger = init_logger(__name__) + ++_MINIMAX_M3_MI300X_EP_BF16_CONFIG = { ++ "BLOCK_SIZE_M": 16, ++ "BLOCK_SIZE_N": 64, ++ "BLOCK_SIZE_K": 128, ++ "GROUP_SIZE_M": 1, ++ "SPLIT_K": 1, ++ "num_warps": 4, ++ "num_stages": 2, ++} ++ ++ ++def _is_minimax_m3_mi300x_ep8(moe_config: FusedMoEConfig) -> bool: ++ """Match the profiled MiniMax-M3 EP8 shape on gfx94x.""" ++ return ( ++ current_platform.is_fp8_fnuz() ++ and moe_config.ep_size == 8 ++ and moe_config.has_shared_experts ++ and moe_config.num_experts == 128 ++ and moe_config.experts_per_token == 4 ++ and moe_config.hidden_dim == 6144 ++ and moe_config.intermediate_size == 3072 ++ and moe_config.max_model_len > 0 ++ ) ++ + + class Mxfp8TritonExpertsBase(TritonExperts): + """Shared MXFP8 MoE setup: stash E8M0 scales, clear scales on ``quant_config``.""" +@@ -67,6 +103,7 @@ class Mxfp8EmulationTritonExperts(Mxfp8TritonExpertsBase): + quant_config: FusedMoEQuantConfig, + ): + super().__init__(moe_config, quant_config) ++ self.use_sparse_mi300x_ep = _is_minimax_m3_mi300x_ep8(moe_config) + logger.warning_once( + "Using Mxfp8EmulationTritonExperts MoE backend. Weights are " + "dequantized to BF16 on the fly; this is slower than a native " +@@ -122,6 +159,99 @@ class Mxfp8EmulationTritonExperts(Mxfp8TritonExpertsBase): + return + super().activation(activation, output, input) + ++ def _apply_sparse_mi300x_ep( ++ self, ++ output: torch.Tensor, ++ hidden_states: torch.Tensor, ++ w1: torch.Tensor, ++ w2: torch.Tensor, ++ topk_weights: torch.Tensor, ++ topk_ids: torch.Tensor, ++ global_num_experts: int, ++ expert_map: torch.Tensor, ++ workspace13: torch.Tensor, ++ workspace2: torch.Tensor, ++ ) -> None: ++ """Run only local EP routes with a fused BF16 GEMM1 SwiGLU epilogue.""" ++ E, num_tokens, N, K, top_k_num = self.moe_problem_size( ++ hidden_states, w1, w2, topk_ids ++ ) ++ if global_num_experts == -1: ++ global_num_experts = expert_map.numel() ++ config = _MINIMAX_M3_MI300X_EP_BF16_CONFIG ++ sorted_token_ids, expert_ids, num_tokens_post_padded = ( ++ _prepare_expert_assignment( ++ topk_ids, ++ config, ++ num_tokens, ++ top_k_num, ++ global_num_experts, ++ expert_map, ++ ignore_invalid_experts=True, ++ num_local_experts=E, ++ ) ++ ) ++ assert sorted_token_ids is not None ++ ++ activation_dim = N // 2 ++ intermediate_activation = _resize_cache( ++ workspace13, ++ (num_tokens * top_k_num, activation_dim), ++ ) ++ intermediate_output = _resize_cache( ++ workspace2, ++ (num_tokens, top_k_num, K), ++ ) ++ ++ alpha = self.quant_config.gemm1_alpha ++ alpha = 1.702 if alpha is None else float(alpha) ++ beta = self.quant_config.gemm1_beta ++ beta = 1.0 if beta is None else float(beta) ++ limit = self.quant_config.gemm1_clamp_limit ++ limit = None if limit is None else float(limit) ++ ++ invoke_fused_moe_gated_triton_kernel( ++ hidden_states, ++ w1, ++ intermediate_activation, ++ sorted_token_ids, ++ expert_ids, ++ num_tokens_post_padded, ++ top_k_num, ++ config, ++ alpha, ++ beta, ++ limit, ++ ) ++ invoke_fused_moe_triton_kernel( ++ intermediate_activation, ++ w2, ++ intermediate_output, ++ None, ++ None, ++ topk_weights, ++ sorted_token_ids, ++ expert_ids, ++ num_tokens_post_padded, ++ True, ++ 1, ++ config, ++ compute_type=tl.bfloat16, ++ use_fp8_w8a8=False, ++ use_int8_w8a8=False, ++ use_int8_w8a16=False, ++ use_int4_w4a16=False, ++ per_channel_quant=False, ++ ) ++ moe_fused_mul_sum( ++ intermediate_output, ++ topk_weights, ++ outputs=output, ++ topk_ids=topk_ids, ++ expert_map=expert_map, ++ apply_weights=False, ++ ) ++ + def apply( + self, + output: torch.Tensor, +@@ -157,6 +287,29 @@ class Mxfp8EmulationTritonExperts(Mxfp8TritonExpertsBase): + hidden_states.dtype + ) + ++ use_sparse_ep = ( ++ self.use_sparse_mi300x_ep ++ and hidden_states.dtype == torch.bfloat16 ++ and activation == MoEActivation.SWIGLUOAI_UNINTERLEAVE ++ and expert_map is not None ++ and not apply_router_weight_on_input ++ and getattr(self, "_lora_context", None) is None ++ ) ++ if use_sparse_ep: ++ self._apply_sparse_mi300x_ep( ++ output=output, ++ hidden_states=hidden_states, ++ w1=w1_bf16, ++ w2=w2_bf16, ++ topk_weights=topk_weights, ++ topk_ids=topk_ids, ++ global_num_experts=global_num_experts, ++ expert_map=expert_map, ++ workspace13=workspace13, ++ workspace2=workspace2, ++ ) ++ return ++ + super().apply( + output=output, + hidden_states=hidden_states, +diff --git a/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py b/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py +index 9e0145ff9..010b53d5f 100644 +--- a/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py ++++ b/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py +@@ -48,11 +48,10 @@ _BF16_PREFILL_TOKEN_THRESHOLD = 832 + _LONG_CONTEXT_BF16_ONLY_LAYER_STRIDE = 5 + + +-def _should_use_bf16_decode_fallback(moe_config: FusedMoEConfig) -> bool: +- """Limit BF16 fallback weights to the exact MiniMax-M3 TP shape.""" ++def _is_profiled_minimax_m3_config(moe_config: FusedMoEConfig) -> bool: + return ( + current_platform.is_fp8_fnuz() +- and moe_config.ep_size == 1 ++ and moe_config.ep_size in (1, 8) + and moe_config.has_shared_experts + and moe_config.num_experts == 128 + and moe_config.experts_per_token == 4 +@@ -62,6 +61,23 @@ def _should_use_bf16_decode_fallback(moe_config: FusedMoEConfig) -> bool: + ) + + ++def _should_use_bf16_decode_fallback(moe_config: FusedMoEConfig) -> bool: ++ """Retain mixed weights where the profiled native/BF16 dispatch wins.""" ++ return _is_profiled_minimax_m3_config(moe_config) and ( ++ moe_config.ep_size == 1 ++ or (moe_config.ep_size == 8 and moe_config.max_model_len <= 4096) ++ ) ++ ++ ++def _should_use_native_ep(moe_config: FusedMoEConfig) -> bool: ++ """Use mixed native/BF16 experts for profiled short-context EP8.""" ++ return ( ++ _is_profiled_minimax_m3_config(moe_config) ++ and moe_config.ep_size == 8 ++ and moe_config.max_model_len <= 4096 ++ ) ++ ++ + def _should_store_bf16_only(max_model_len: int, layer_index: int) -> bool: + return ( + max_model_len > 4096 and layer_index % _LONG_CONTEXT_BF16_ONLY_LAYER_STRIDE == 0 +@@ -71,14 +87,30 @@ def _should_store_bf16_only(max_model_len: int, layer_index: int) -> bool: + def _should_use_bf16_experts( + num_tokens: int, + native_weights_available: bool, ++ bf16_weights_available: bool, + ) -> bool: +- return ( ++ return bf16_weights_available and ( + not native_weights_available +- or num_tokens >= _BF16_PREFILL_TOKEN_THRESHOLD + or num_tokens <= _BF16_DECODE_TOKEN_THRESHOLD ++ or num_tokens >= _BF16_PREFILL_TOKEN_THRESHOLD + ) + + ++def _max_post_padded( ++ num_valid_tokens: int, ++ num_local_experts: int, ++ block_m: int, ++ allocation_size: int, ++) -> int: ++ """Static upper bound for a block-aligned local-expert route list.""" ++ max_padded = min( ++ allocation_size, ++ num_valid_tokens * block_m, ++ num_valid_tokens + num_local_experts * (block_m - 1), ++ ) ++ return (max_padded // block_m) * block_m ++ ++ + @triton.jit + def _mxfp8_grouped_gemm_dot_scaled_kernel( + a_ptr, +@@ -108,6 +140,7 @@ def _mxfp8_grouped_gemm_dot_scaled_kernel( + stride_cn, + A_DIV: tl.constexpr, + MUL_WEIGHT: tl.constexpr, ++ FUSE_TOPK: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +@@ -161,12 +194,14 @@ def _mxfp8_grouped_gemm_dot_scaled_kernel( + w = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0.0) + acc = acc * w[:, None] + +- c_ptrs = c_ptr + offs_token[:, None] * stride_cm + offs_n[None, :] * stride_cn +- tl.store( +- c_ptrs, +- acc.to(c_ptr.dtype.element_ty), +- mask=token_mask[:, None] & n_mask[None, :], +- ) ++ c_row = offs_token // top_k if FUSE_TOPK else offs_token ++ c_ptrs = c_ptr + c_row[:, None] * stride_cm + offs_n[None, :] * stride_cn ++ c_mask = token_mask[:, None] & n_mask[None, :] ++ result = acc.to(c_ptr.dtype.element_ty) ++ if FUSE_TOPK: ++ tl.atomic_add(c_ptrs, result, mask=c_mask, sem="relaxed") ++ else: ++ tl.store(c_ptrs, result, mask=c_mask) + + + @triton.jit +@@ -198,6 +233,7 @@ def _mxfp8_grouped_gemm_fnuz_kernel( + stride_cn, + A_DIV: tl.constexpr, + MUL_WEIGHT: tl.constexpr, ++ FUSE_TOPK: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +@@ -271,12 +307,14 @@ def _mxfp8_grouped_gemm_fnuz_kernel( + w = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0.0) + acc = acc * w[:, None] + +- c_ptrs = c_ptr + offs_token[:, None] * stride_cm + offs_n[None, :] * stride_cn +- tl.store( +- c_ptrs, +- acc.to(c_ptr.dtype.element_ty), +- mask=token_mask[:, None] & n_mask[None, :], +- ) ++ c_row = offs_token // top_k if FUSE_TOPK else offs_token ++ c_ptrs = c_ptr + c_row[:, None] * stride_cm + offs_n[None, :] * stride_cn ++ c_mask = token_mask[:, None] & n_mask[None, :] ++ result = acc.to(c_ptr.dtype.element_ty) ++ if FUSE_TOPK: ++ tl.atomic_add(c_ptrs, result, mask=c_mask, sem="relaxed") ++ else: ++ tl.store(c_ptrs, result, mask=c_mask) + + + def _gfx94x_grouped_gemm_config( +@@ -330,6 +368,9 @@ def _grouped_gemm_mxfp8( + block_n_override: int = 0, + block_k_override: int = 0, + num_warps_override: int = 0, ++ output: torch.Tensor | None = None, ++ fuse_topk: bool = False, ++ zero_nonlocal_output: bool = True, + ) -> torch.Tensor: + M_routed = num_valid_tokens + E, N, K = w.shape +@@ -363,12 +404,15 @@ def _grouped_gemm_mxfp8( + BLOCK_N = 128 + BLOCK_K = 128 + num_warps = 8 +- # moe_align_block_size allocates for the worst case where every expert is +- # active. At small batches that can be much larger than the number of +- # blocks that can contain valid assignments. Limit the launch to the +- # tighter static upper bound; the device-side num_post check handles the +- # remaining tail. +- max_post_padded = min(sorted_token_ids.shape[0], M_routed * block_m) ++ # With EP, the align buffer is sized from the global expert count even ++ # though only local assignments survive. Bound the launch by the local ++ # expert count; the device-side num_post check handles the remaining tail. ++ max_post_padded = _max_post_padded( ++ M_routed, ++ E, ++ block_m, ++ sorted_token_ids.shape[0], ++ ) + if block_n_override: + BLOCK_N = block_n_override + if block_k_override: +@@ -385,11 +429,32 @@ def _grouped_gemm_mxfp8( + m_blocks = triton.cdiv(max_post_padded, block_m) + n_blocks = triton.cdiv(N, BLOCK_N) + +- # Under expert parallelism (expert_map set) tokens routed to non-local +- # experts are dropped from sorted_token_ids, so their output rows are never +- # written. +- alloc = torch.zeros if expert_map is not None else torch.empty +- out = alloc((M_routed, N), dtype=out_dtype, device=a_q.device) ++ if fuse_topk: ++ if M_routed % top_k != 0: ++ raise ValueError( ++ f"Routed rows ({M_routed}) must be divisible by top_k ({top_k})." ++ ) ++ expected_shape = (M_routed // top_k, N) ++ if output is None: ++ out = torch.zeros(expected_shape, dtype=out_dtype, device=a_q.device) ++ else: ++ if output.shape != expected_shape or output.dtype != out_dtype: ++ raise ValueError( ++ "Fused top-k output must have shape/dtype " ++ f"{expected_shape}/{out_dtype}, got " ++ f"{tuple(output.shape)}/{output.dtype}." ++ ) ++ out = output ++ out.zero_() ++ else: ++ # EP rows that route to remote experts are unwritten. Dense consumers ++ # need zeros; route-aware consumers can skip the memset. ++ alloc = ( ++ torch.zeros ++ if expert_map is not None and zero_nonlocal_output ++ else torch.empty ++ ) ++ out = alloc((M_routed, N), dtype=out_dtype, device=a_q.device) + grid = (m_blocks, n_blocks) + kernel = ( + _mxfp8_grouped_gemm_fnuz_kernel +@@ -428,6 +493,7 @@ def _grouped_gemm_mxfp8( + out.stride(1), + A_DIV=a_div, + MUL_WEIGHT=mul_weight_by is not None, ++ FUSE_TOPK=fuse_topk, + BLOCK_M=block_m, + BLOCK_N=BLOCK_N, + BLOCK_K=BLOCK_K, +@@ -476,7 +542,9 @@ def fused_moe_mxfp8_native( + global_num_experts, + expert_map, + ignore_invalid_experts=expert_map is not None, ++ num_local_experts=w13.shape[0] if expert_map is not None else None, + ) ++ use_sparse_ep_path = expert_map is not None and current_platform.is_fp8_fnuz() + + # GEMM1: x (mxfp8) @ w13^T -> [M, 2I] + a_q, a_s = mxfp8_e4m3_quantize(hidden_states) +@@ -497,6 +565,7 @@ def fused_moe_mxfp8_native( + block_n_override=g1_block_n, + block_k_override=g1_block_k, + num_warps_override=g1_num_warps, ++ zero_nonlocal_output=not use_sparse_ep_path, + ) # [M, 2I] + + # SwiGLU-OAI (split layout: gate=g1[:, :I], up=g1[:, I:]) FUSED with the +@@ -506,10 +575,62 @@ def fused_moe_mxfp8_native( + # ``silu_and_mul_with_clamp`` op: it rounds intermediates to bf16, rel ~3e-3.) + # Lazy import: the amd.ops package pulls in the minimax_m3 platform dispatch, + # only resolvable after the model module finishes loading. +- from vllm.models.minimax_m3.amd.ops import swiglu_oai_quantize_mxfp8 ++ from vllm.models.minimax_m3.amd.ops import ( ++ swiglu_oai_quantize_mxfp8, ++ swiglu_oai_quantize_mxfp8_routed, ++ ) + + # GEMM2: act (mxfp8) @ w2^T -> [M, H], weighted by topk_weights, then reduce. +- act_q, act_s = swiglu_oai_quantize_mxfp8(g1, alpha=alpha, beta=beta, limit=limit) ++ if use_sparse_ep_path: ++ max_post_padded = _max_post_padded( ++ M, ++ w13.shape[0], ++ block_m, ++ sorted_ids.shape[0], ++ ) ++ act_q, act_s = swiglu_oai_quantize_mxfp8_routed( ++ g1, ++ sorted_ids, ++ num_post, ++ num_valid_tokens=M, ++ max_num_tokens_post_padded=max_post_padded, ++ alpha=alpha, ++ beta=beta, ++ limit=limit, ++ block_m=block_m, ++ ) ++ else: ++ act_q, act_s = swiglu_oai_quantize_mxfp8( ++ g1, ++ alpha=alpha, ++ beta=beta, ++ limit=limit, ++ ) ++ ++ if use_sparse_ep_path: ++ return _grouped_gemm_mxfp8( ++ act_q, ++ act_s, ++ w2, ++ w2_scale, ++ sorted_ids, ++ expert_ids, ++ num_post, ++ M, ++ top_k, ++ block_m, ++ hidden_states.dtype, ++ a_div=1, ++ mul_weight_by=topk_weights.reshape(-1).to(torch.float32), ++ expert_map=expert_map, ++ is_gemm2=True, ++ block_n_override=g2_block_n, ++ block_k_override=g2_block_k, ++ num_warps_override=g2_num_warps, ++ output=output, ++ fuse_topk=True, ++ ) ++ + g2 = _grouped_gemm_mxfp8( + act_q, + act_s, +@@ -556,6 +677,7 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): + self.w1_bf16: torch.Tensor | None = None + self.w2_bf16: torch.Tensor | None = None + self.native_weights_available = True ++ self.bf16_weights_available = False + self.bf16_experts: TritonExperts | None = None + if _should_use_bf16_decode_fallback(moe_config): + bf16_config = biased_moe_quant_config( +@@ -583,6 +705,7 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): + self.w1_bf16 = w1_bf16 + self.w2_bf16 = w2_bf16 + self.native_weights_available = native_weights_available ++ self.bf16_weights_available = True + + def bind_packed_weight_scales( + self, +@@ -636,6 +759,7 @@ class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): + if bf16_experts is not None and _should_use_bf16_experts( + num_tokens, + self.native_weights_available, ++ self.bf16_weights_available, + ): + if self.w1_bf16 is None or self.w2_bf16 is None: + raise RuntimeError( +diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py +index 49957c8f5..26b3c6f74 100644 +--- a/vllm/model_executor/layers/fused_moe/fused_moe.py ++++ b/vllm/model_executor/layers/fused_moe/fused_moe.py +@@ -553,6 +553,168 @@ def fused_moe_kernel( + tl.store(c_ptrs, accumulator, mask=c_mask) + + ++@triton.jit ++def fused_moe_gated_kernel( ++ a_ptr, ++ b_ptr, ++ c_ptr, ++ sorted_token_ids_ptr, ++ expert_ids_ptr, ++ num_tokens_post_padded_ptr, ++ N, ++ K, ++ EM, ++ num_valid_tokens, ++ stride_am, ++ stride_ak, ++ stride_be, ++ stride_bk, ++ stride_bn, ++ stride_cm, ++ stride_cn, ++ alpha, ++ beta, ++ limit, ++ BLOCK_SIZE_M: tl.constexpr, ++ BLOCK_SIZE_N: tl.constexpr, ++ BLOCK_SIZE_K: tl.constexpr, ++ GROUP_SIZE_M: tl.constexpr, ++ top_k: tl.constexpr, ++ compute_type: tl.constexpr, ++ HAS_LIMIT: tl.constexpr, ++): ++ """BF16 grouped GEMM1 with a split SwiGLU-OAI epilogue.""" ++ pid = tl.program_id(axis=0) ++ num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) ++ num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) ++ num_pid_in_group = GROUP_SIZE_M * num_pid_n ++ group_id = pid // num_pid_in_group ++ first_pid_m = group_id * GROUP_SIZE_M ++ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) ++ pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) ++ pid_n = (pid % num_pid_in_group) // group_size_m ++ ++ offs_m = tl.arange(0, BLOCK_SIZE_M).to(tl.int64) ++ num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) ++ if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: ++ return ++ ++ offs_token_id = pid_m * BLOCK_SIZE_M + offs_m ++ offs_token = tl.load(sorted_token_ids_ptr + offs_token_id).to(tl.int64) ++ token_mask = offs_token < num_valid_tokens ++ off_expert = tl.load(expert_ids_ptr + pid_m).to(tl.int64) ++ if off_expert < 0: ++ return ++ ++ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64) ++ offs_k = tl.arange(0, BLOCK_SIZE_K) ++ n_mask = offs_n < N ++ a_ptrs = a_ptr + ( ++ offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak ++ ) ++ b_gate_ptrs = ( ++ b_ptr ++ + off_expert * stride_be ++ + offs_k[:, None] * stride_bk ++ + offs_n[None, :] * stride_bn ++ ) ++ b_up_ptrs = b_gate_ptrs + N * stride_bn ++ ++ gate_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) ++ up_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) ++ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): ++ k_mask = offs_k < K - k * BLOCK_SIZE_K ++ a = tl.load( ++ a_ptrs, ++ mask=token_mask[:, None] & k_mask[None, :], ++ other=0.0, ++ ) ++ gate_w = tl.load( ++ b_gate_ptrs, ++ mask=k_mask[:, None] & n_mask[None, :], ++ other=0.0, ++ ) ++ up_w = tl.load( ++ b_up_ptrs, ++ mask=k_mask[:, None] & n_mask[None, :], ++ other=0.0, ++ ) ++ gate_acc += tl.dot(a, gate_w) ++ up_acc += tl.dot(a, up_w) ++ a_ptrs += BLOCK_SIZE_K * stride_ak ++ b_gate_ptrs += BLOCK_SIZE_K * stride_bk ++ b_up_ptrs += BLOCK_SIZE_K * stride_bk ++ ++ # Preserve the BF16 GEMM1 store/reload boundary of the unfused path before ++ # applying SwiGLU in FP32. ++ gate = gate_acc.to(compute_type).to(tl.float32) ++ up = up_acc.to(compute_type).to(tl.float32) ++ if HAS_LIMIT: ++ gate = tl.minimum(gate, limit) ++ up = tl.minimum(tl.maximum(up, -limit), limit) ++ activated = gate * tl.sigmoid(alpha * gate) * (up + beta) ++ ++ c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_n[None, :] ++ c_mask = token_mask[:, None] & n_mask[None, :] ++ tl.store(c_ptrs, activated.to(compute_type), mask=c_mask) ++ ++ ++def invoke_fused_moe_gated_triton_kernel( ++ A: torch.Tensor, ++ B: torch.Tensor, ++ C: torch.Tensor, ++ sorted_token_ids: torch.Tensor, ++ expert_ids: torch.Tensor, ++ num_tokens_post_padded: torch.Tensor, ++ top_k: int, ++ config: dict[str, Any], ++ alpha: float, ++ beta: float, ++ limit: float | None, ++) -> None: ++ """Run BF16 GEMM1 and emit activated split-SwiGLU rows directly.""" ++ assert A.dtype == B.dtype == C.dtype == torch.bfloat16 ++ assert B.size(1) == C.size(1) * 2 ++ assert A.size(1) == B.size(2) ++ assert sorted_token_ids.stride(0) == 1 ++ ++ EM = sorted_token_ids.size(0) ++ grid = lambda META: ( ++ triton.cdiv(EM, META["BLOCK_SIZE_M"]) ++ * triton.cdiv(C.size(1), META["BLOCK_SIZE_N"]), ++ ) ++ launch_config = config.copy() ++ launch_config.pop("SPLIT_K", None) ++ block_size_k = launch_config.pop("BLOCK_SIZE_K") ++ fused_moe_gated_kernel[grid]( ++ A, ++ B, ++ C, ++ sorted_token_ids, ++ expert_ids, ++ num_tokens_post_padded, ++ C.size(1), ++ A.size(1), ++ EM, ++ A.size(0) * top_k, ++ A.stride(0), ++ A.stride(1), ++ B.stride(0), ++ B.stride(2), ++ B.stride(1), ++ C.stride(0), ++ C.stride(1), ++ alpha, ++ beta, ++ 0.0 if limit is None else limit, ++ top_k=top_k, ++ compute_type=tl.bfloat16, ++ HAS_LIMIT=limit is not None, ++ BLOCK_SIZE_K=block_size_k, ++ **launch_config, ++ ) ++ ++ + # NOTE(zyongye): we can remove all the wna16 kernel + # once we drop off sm75 support + def invoke_fused_moe_wna16_cuda_kernel( +@@ -1434,6 +1596,7 @@ def _prepare_expert_assignment( + use_int4_w4a16: bool = False, + block_shape: list[int] | None = None, + ignore_invalid_experts: bool = False, ++ num_local_experts: int | None = None, + ) -> tuple[torch.Tensor | None, torch.Tensor, torch.Tensor]: + """Prepare expert assignments for the aligned and low-latency Triton paths.""" + # SPARSITY_FACTOR is a heuristic margin ensuring tokens_in_chunk * top_k +@@ -1468,6 +1631,7 @@ def _prepare_expert_assignment( + global_num_experts, + expert_map, + ignore_invalid_experts=ignore_invalid_experts, ++ num_local_experts=num_local_experts, + ) + + +diff --git a/vllm/model_executor/layers/fused_moe/moe_align_block_size.py b/vllm/model_executor/layers/fused_moe/moe_align_block_size.py +index 7fc8bfcf8..99404ed95 100644 +--- a/vllm/model_executor/layers/fused_moe/moe_align_block_size.py ++++ b/vllm/model_executor/layers/fused_moe/moe_align_block_size.py +@@ -15,6 +15,7 @@ def moe_align_block_size( + expert_map: torch.Tensor | None = None, + pad_sorted_ids: bool = False, + ignore_invalid_experts: bool = False, ++ num_local_experts: int | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Aligns the token distribution across experts to be compatible with block +@@ -43,6 +44,9 @@ def moe_align_block_size( + as -1. When True, all invalid expert_ids in topk_ids will be ignored + and will not participate in counting or ranking, and there will be no + -1 in expert_ids. ++ - num_local_experts: The number of experts retained by ``expert_map``. ++ When invalid experts are ignored, this tightens the output allocation ++ from the global expert count to the local expert count. + + Returns: + - sorted_token_ids: A tensor containing the sorted token indices according +@@ -71,7 +75,20 @@ def moe_align_block_size( + - The padding ensures that the total number of tokens is now divisible + by block_size for proper block matrix operations. + """ +- max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) ++ padding_experts = num_experts ++ if ( ++ ignore_invalid_experts ++ and expert_map is not None ++ and num_local_experts is not None ++ ): ++ if not 0 < num_local_experts <= num_experts: ++ raise ValueError( ++ "num_local_experts must be in (0, num_experts], got " ++ f"{num_local_experts} for num_experts={num_experts}." ++ ) ++ padding_experts = num_local_experts ++ ++ max_num_tokens_padded = topk_ids.numel() + padding_experts * (block_size - 1) + if pad_sorted_ids: + max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) + if topk_ids.numel() < num_experts: +diff --git a/vllm/model_executor/layers/fused_moe/moe_fused_mul_sum.py b/vllm/model_executor/layers/fused_moe/moe_fused_mul_sum.py +index 768f41db8..0465112ee 100644 +--- a/vllm/model_executor/layers/fused_moe/moe_fused_mul_sum.py ++++ b/vllm/model_executor/layers/fused_moe/moe_fused_mul_sum.py +@@ -17,6 +17,7 @@ def moe_fused_mul_sum_kernel( + num_tokens, + stride_m, + has_expert_map: tl.constexpr, ++ apply_weights: tl.constexpr, + top_k: tl.constexpr, + size: tl.constexpr, + BLOCK_M: tl.constexpr, +@@ -38,7 +39,10 @@ def moe_fused_mul_sum_kernel( + acc = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32) + + for n in tl.static_range(top_k): +- b_val = tl.load(b_base + n, mask=m_mask, other=0.0).to(tl.float32) ++ if apply_weights: ++ b_val = tl.load(b_base + n, mask=m_mask, other=0.0).to(tl.float32) ++ else: ++ b_val = 1.0 + if has_expert_map: + id_val = tl.load(top_ids_ptr + offs_m * top_k + n, mask=m_mask, other=0) + expert_mask = tl.load(expert_map_ptr + id_val) >= 0 +@@ -138,6 +142,7 @@ def moe_fused_mul_sum( + outputs: torch.Tensor | None = None, + topk_ids: torch.Tensor | None = None, + expert_map: torch.Tensor | None = None, ++ apply_weights: bool = True, + ) -> torch.Tensor: + """ + Fused kernel for MoE (Mixture of Experts) to perform weighted summation +@@ -154,6 +159,8 @@ def moe_fused_mul_sum( + `expert_map` is provided. Shape: (num_tokens, top_k). + expert_map: Optional mapping for Expert Parallelism. A value < 0 + indicates an invalid token/expert pair that will be skipped. ++ apply_weights: Multiply each route by ``topk_weights`` before summing. ++ Set to false when the expert GEMM already applied router weights. + + Returns: + The fused weighted sum of expert outputs. +@@ -191,6 +198,7 @@ def moe_fused_mul_sum( + num_tokens, + top_k * size, + expert_map is not None, ++ apply_weights, + top_k, + size, + BLOCK_M, +diff --git a/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py b/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py +index bc00da41e..378b897a7 100644 +--- a/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py ++++ b/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py +@@ -82,14 +82,25 @@ def _select_rocm_mxfp8_backend( + """ROCm fallback when vendor MXFP8 backends are unavailable.""" + + if current_platform.is_fp8_fnuz() and config.ep_size > 1: ++ from vllm.model_executor.layers.fused_moe.experts.mxfp8_native_moe import ( ++ Mxfp8NativeTritonExperts, ++ _should_use_native_ep, ++ ) ++ ++ if _should_use_native_ep(config): ++ logger.info_once( ++ "Using the profiled gfx94x MiniMax-M3 EP8 MXFP8 backend: " ++ "native local-route kernels with retained BF16 decode experts." ++ ) ++ return Fp8MoeBackend.NATIVE_MXFP8, Mxfp8NativeTritonExperts ++ + from vllm.model_executor.layers.fused_moe.experts.mxfp8_emulation_moe import ( + Mxfp8EmulationTritonExperts, + ) + + logger.info_once( +- "Using BF16 MXFP8 emulation for gfx94x expert parallelism; the " +- "native CDNA3 path is optimized for decode-sized TP workloads and " +- "is slower for the large local batches reached during EP prefill." ++ "Using the profiled sparse BF16 MXFP8 path for long-context gfx94x " ++ "expert parallelism." + ) + return Fp8MoeBackend.EMULATION, Mxfp8EmulationTritonExperts + +diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py +index 1a4049d3f..8ea843e7d 100644 +--- a/vllm/model_executor/layers/quantization/modelopt.py ++++ b/vllm/model_executor/layers/quantization/modelopt.py +@@ -2177,8 +2177,8 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): + + if self.moe.max_model_len <= 4096: + logger.info_once( +- "Retaining BF16 MXFP8 MoE weights for gfx94x TP decode and " +- "prefill dispatch." ++ "Retaining BF16 MXFP8 MoE weights for profiled gfx94x " ++ "native/BF16 dispatch." + ) + else: + logger.info_once( +diff --git a/vllm/models/minimax_m3/amd/ops/__init__.py b/vllm/models/minimax_m3/amd/ops/__init__.py +index 22d96f9de..a2672b05d 100644 +--- a/vllm/models/minimax_m3/amd/ops/__init__.py ++++ b/vllm/models/minimax_m3/amd/ops/__init__.py +@@ -13,6 +13,7 @@ from vllm.models.minimax_m3.amd.ops.gemma_rmsnorm import ( + ) + from vllm.models.minimax_m3.amd.ops.swiglu_oai import ( + swiglu_oai_quantize_mxfp8, ++ swiglu_oai_quantize_mxfp8_routed, + swiglu_oai_split, + ) + +@@ -21,4 +22,5 @@ __all__ = [ + "gemma_fused_add_rmsnorm", + "swiglu_oai_split", + "swiglu_oai_quantize_mxfp8", ++ "swiglu_oai_quantize_mxfp8_routed", + ] +diff --git a/vllm/models/minimax_m3/amd/ops/swiglu_oai.py b/vllm/models/minimax_m3/amd/ops/swiglu_oai.py +index 9572c5109..7fca053fb 100644 +--- a/vllm/models/minimax_m3/amd/ops/swiglu_oai.py ++++ b/vllm/models/minimax_m3/amd/ops/swiglu_oai.py +@@ -124,6 +124,77 @@ def _swiglu_oai_quant_kernel( + ) + + ++@triton.jit ++def _swiglu_oai_quant_routed_kernel( ++ g_ptr, ++ aq_ptr, ++ as_ptr, ++ sorted_token_ids_ptr, ++ num_tokens_post_padded_ptr, ++ num_valid_tokens, ++ n_inter, ++ stride_gm, ++ stride_gn, ++ stride_qm, ++ stride_qn, ++ stride_sm, ++ stride_sk, ++ alpha, ++ beta, ++ limit, ++ HAS_LIMIT: tl.constexpr, ++ BLOCK_M: tl.constexpr, ++): ++ """Route-aware SwiGLU-OAI + MXFP8 quantization for expert parallelism.""" ++ pid_m = tl.program_id(0) ++ pid_b = tl.program_id(1) ++ num_post = tl.load(num_tokens_post_padded_ptr) ++ if pid_m * BLOCK_M >= num_post: ++ return ++ ++ offs_tid = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) ++ route_mask = offs_tid < num_post ++ route_ids = tl.load( ++ sorted_token_ids_ptr + offs_tid, ++ mask=route_mask, ++ other=num_valid_tokens, ++ ).to(tl.int64) ++ route_mask = route_mask & (route_ids < num_valid_tokens) ++ safe_route_ids = tl.where(route_mask, route_ids, 0) ++ offs_c = pid_b * 32 + tl.arange(0, 32) ++ ++ gate = tl.load( ++ g_ptr + safe_route_ids[:, None] * stride_gm + offs_c[None, :] * stride_gn, ++ mask=route_mask[:, None], ++ other=0.0, ++ ).to(tl.float32) ++ up = tl.load( ++ g_ptr ++ + safe_route_ids[:, None] * stride_gm ++ + (n_inter + offs_c)[None, :] * stride_gn, ++ mask=route_mask[:, None], ++ other=0.0, ++ ).to(tl.float32) ++ if HAS_LIMIT: ++ gate = tl.minimum(gate, limit) ++ up = tl.minimum(tl.maximum(up, -limit), limit) ++ act = gate * tl.sigmoid(alpha * gate) * (up + beta) ++ amax = tl.maximum(tl.max(tl.abs(act), axis=1), 1e-30) ++ sb = tl.minimum(tl.maximum(tl.floor(tl.log2(amax)) + 127.0, 0.0), 254.0) ++ descale = tl.exp2(sb - 127.0) ++ aq = (act / descale[:, None]).to(aq_ptr.dtype.element_ty) ++ tl.store( ++ aq_ptr + safe_route_ids[:, None] * stride_qm + offs_c[None, :] * stride_qn, ++ aq, ++ mask=route_mask[:, None], ++ ) ++ tl.store( ++ as_ptr + safe_route_ids * stride_sm + pid_b * stride_sk, ++ sb.to(tl.uint8), ++ mask=route_mask, ++ ) ++ ++ + def swiglu_oai_quantize_mxfp8( + gate_up: torch.Tensor, + alpha: float, +@@ -181,6 +252,79 @@ def swiglu_oai_quantize_mxfp8( + return aq, asc + + ++def swiglu_oai_quantize_mxfp8_routed( ++ gate_up: torch.Tensor, ++ sorted_token_ids: torch.Tensor, ++ num_tokens_post_padded: torch.Tensor, ++ *, ++ num_valid_tokens: int, ++ max_num_tokens_post_padded: int, ++ alpha: float, ++ beta: float, ++ limit: float | None, ++ block_m: int, ++) -> tuple[torch.Tensor, torch.Tensor]: ++ """Quantize only the locally routed rows produced by an EP GEMM1.""" ++ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( ++ MXFP8_BLOCK_SIZE, ++ MXFP8_SCALE_DTYPE, ++ MXFP8_VALUE_DTYPE, ++ ) ++ ++ two_i = gate_up.shape[-1] ++ n_inter = two_i // 2 ++ if n_inter % MXFP8_BLOCK_SIZE != 0: ++ raise ValueError( ++ f"fused swiglu+quant needs I % {MXFP8_BLOCK_SIZE} == 0, got I={n_inter}" ++ ) ++ if max_num_tokens_post_padded % block_m != 0: ++ raise ValueError( ++ "max_num_tokens_post_padded must be block aligned, got " ++ f"{max_num_tokens_post_padded} for block_m={block_m}." ++ ) ++ ++ g1 = gate_up.reshape(-1, two_i).contiguous() ++ value_dtype = ( ++ torch.float8_e4m3fnuz if current_platform.is_fp8_fnuz() else MXFP8_VALUE_DTYPE ++ ) ++ aq = torch.empty( ++ (num_valid_tokens, n_inter), ++ dtype=value_dtype, ++ device=g1.device, ++ ) ++ asc = torch.empty( ++ (num_valid_tokens, n_inter // MXFP8_BLOCK_SIZE), ++ dtype=MXFP8_SCALE_DTYPE, ++ device=g1.device, ++ ) ++ grid = ( ++ max_num_tokens_post_padded // block_m, ++ n_inter // MXFP8_BLOCK_SIZE, ++ ) ++ _swiglu_oai_quant_routed_kernel[grid]( ++ g1, ++ aq, ++ asc, ++ sorted_token_ids, ++ num_tokens_post_padded, ++ num_valid_tokens, ++ n_inter, ++ g1.stride(0), ++ g1.stride(1), ++ aq.stride(0), ++ aq.stride(1), ++ asc.stride(0), ++ asc.stride(1), ++ float(alpha), ++ float(beta), ++ 0.0 if limit is None else float(limit), ++ HAS_LIMIT=limit is not None, ++ BLOCK_M=block_m, ++ num_warps=4, ++ ) ++ return aq, asc ++ ++ + def swiglu_oai_split( + gate_up: torch.Tensor, + alpha: float,