diff --git a/.github/configs/amd-master.yaml b/.github/configs/amd-master.yaml index 7e4918e09..25dd37c16 100644 --- a/.github/configs/amd-master.yaml +++ b/.github/configs/amd-master.yaml @@ -2847,10 +2847,11 @@ minimaxm3-fp8-mi355x-vllm-mtp: - { tp: 4, conc-start: 1, conc-end: 64, spec-decoding: mtp } - { tp: 8, ep: 8, dp-attn: true, conc-start: 128, conc-end: 256, spec-decoding: mtp } -# MiniMax-M3 MXFP8 MI300X day-zero recipe. Reuse the dedicated ROCm image and -# MI355X serving shape, but retain the default BF16 KV cache because this -# checkpoint lacks calibrated ROCm FP8 attention scales. Use the TP8-only H100 -# search space: TP8 for latency and TP8+EP8 (TEP) at high concurrency. +# MiniMax-M3 MXFP8 MI300X recipe. Apply the checked-in hybrid gfx94x MXFP8 MoE +# patch to the dedicated ROCm image: BF16 for small TP batches and EP, native +# compressed MXFP8 for larger TP batches and long context. Retain the default +# BF16 KV cache because this checkpoint lacks calibrated ROCm FP8 attention +# scales. Use TP8 for latency and TP8+EP8 at high concurrency. minimaxm3-fp8-mi300x-vllm: image: vllm/vllm-openai-rocm:minimax-m3 model: MiniMaxAI/MiniMax-M3-MXFP8 diff --git a/benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh b/benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh index f2cdaf284..4fbe92bcd 100755 --- a/benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh +++ b/benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh @@ -1,10 +1,12 @@ #!/usr/bin/env bash # MiniMax-M3 MXFP8 MI300X (gfx942) single-node vLLM recipe. -# Reuses the dedicated ROCm image and the MI355X serving shape. Block size 128 -# is mandatory for MSA sparse attention. Keep the default BF16 KV cache on -# gfx942: the checkpoint has no calibrated q/prob scales for ROCm FP8 -# attention, and vLLM's fallback scale of 1.0 corrupts model accuracy. +# Reuses the dedicated ROCm image and applies the checked-in hybrid gfx94x +# MXFP8 MoE patch before starting vLLM. Block size 128 is mandatory for MSA +# sparse attention. Keep the default BF16 KV cache on gfx942: the checkpoint +# has no calibrated q/prob scales for ROCm FP8 attention, and vLLM's fallback +# scale of 1.0 corrupts model accuracy. +# Target image vLLM revision: 4a560dd8db67c270f5e2afb614558271b76f2294. source "$(dirname "$0")/../../benchmark_lib.sh" @@ -24,6 +26,28 @@ if [[ -n "$SLURM_JOB_ID" ]]; then echo "JOB $SLURM_JOB_ID running on $SLURMD_NODENAME" fi +VLLM_PACKAGE_ROOT="$( + python - <<'PY' +from pathlib import Path + +import vllm + +print(Path(vllm.__file__).resolve().parent.parent) +PY +)" +MXFP8_PATCH="$(dirname "$0")/minimaxm3_mi300x_mxfp8.patch" +MXFP8_ORACLE="$VLLM_PACKAGE_ROOT/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py" +if ! grep -q "Using fused CDNA3 (gfx94x)" "$MXFP8_ORACLE"; then + if ! patch --batch --forward -d "$VLLM_PACKAGE_ROOT" -p1 < "$MXFP8_PATCH"; then + echo "Failed to apply the MI300X MXFP8 patch" >&2 + exit 1 + fi +fi +if ! grep -q "Using fused CDNA3 (gfx94x)" "$MXFP8_ORACLE"; then + echo "MI300X MXFP8 backend marker is missing after patching" >&2 + exit 1 +fi + if [[ "$MODEL" != /* ]]; then hf download "$MODEL"; fi if [ -n "$ROCR_VISIBLE_DEVICES" ]; then diff --git a/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_mxfp8.patch b/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_mxfp8.patch new file mode 100644 index 000000000..b391d59f1 --- /dev/null +++ b/benchmarks/single_node/fixed_seq_len/minimaxm3_mi300x_mxfp8.patch @@ -0,0 +1,1040 @@ +diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py +index 0755699d1..4f6046d88 100644 +--- a/vllm/model_executor/layers/fused_moe/config.py ++++ b/vllm/model_executor/layers/fused_moe/config.py +@@ -1276,7 +1276,9 @@ class FusedMoEConfig: + + moe_backend: MoEBackend = "auto" + max_num_tokens: int = SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS_FOR_BATCHED_DP ++ max_model_len: int = 0 + has_bias: bool = False ++ has_shared_experts: bool = False + is_lora_enabled: bool = False + + # SwiGLU clamp limit. When set, backends that do not implement the clamp +diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=AMD_Instinct_MI300X.json +new file mode 100644 +index 000000000..201cfad15 +--- /dev/null ++++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=AMD_Instinct_MI300X.json +@@ -0,0 +1,35 @@ ++{ ++ "1": { ++ "BLOCK_SIZE_M": 16, ++ "BLOCK_SIZE_N": 64, ++ "BLOCK_SIZE_K": 128, ++ "GROUP_SIZE_M": 1, ++ "SPLIT_K": 1, ++ "num_warps": 4, ++ "num_stages": 2 ++ }, ++ "1024": { ++ "BLOCK_SIZE_M": 128, ++ "BLOCK_SIZE_N": 128, ++ "BLOCK_SIZE_K": 64, ++ "GROUP_SIZE_M": 8, ++ "SPLIT_K": 1, ++ "num_warps": 8, ++ "num_stages": 2, ++ "waves_per_eu": 0, ++ "matrix_instr_nonkdim": 16, ++ "kpack": 2 ++ }, ++ "8192": { ++ "BLOCK_SIZE_M": 128, ++ "BLOCK_SIZE_N": 128, ++ "BLOCK_SIZE_K": 64, ++ "GROUP_SIZE_M": 8, ++ "SPLIT_K": 1, ++ "num_warps": 8, ++ "num_stages": 2, ++ "waves_per_eu": 0, ++ "matrix_instr_nonkdim": 16, ++ "kpack": 2 ++ } ++} +diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=AMD_Instinct_MI300X.json +new file mode 100644 +index 000000000..f9de47ad6 +--- /dev/null ++++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=AMD_Instinct_MI300X.json +@@ -0,0 +1,53 @@ ++{ ++ "1": { ++ "BLOCK_SIZE_M": 16, ++ "BLOCK_SIZE_N": 64, ++ "BLOCK_SIZE_K": 128, ++ "GROUP_SIZE_M": 1, ++ "SPLIT_K": 1, ++ "num_warps": 4, ++ "num_stages": 2 ++ }, ++ "128": { ++ "BLOCK_SIZE_M": 64, ++ "BLOCK_SIZE_N": 64, ++ "BLOCK_SIZE_K": 128, ++ "GROUP_SIZE_M": 8, ++ "SPLIT_K": 1, ++ "num_warps": 4, ++ "num_stages": 2, ++ "waves_per_eu": 0, ++ "matrix_instr_nonkdim": 16, ++ "kpack": 2 ++ }, ++ "256": { ++ "BLOCK_SIZE_M": 64, ++ "BLOCK_SIZE_N": 64, ++ "BLOCK_SIZE_K": 128, ++ "GROUP_SIZE_M": 8, ++ "SPLIT_K": 1, ++ "num_warps": 4, ++ "num_stages": 2, ++ "waves_per_eu": 0, ++ "matrix_instr_nonkdim": 16, ++ "kpack": 2 ++ }, ++ "1024": { ++ "BLOCK_SIZE_M": 128, ++ "BLOCK_SIZE_N": 128, ++ "BLOCK_SIZE_K": 64, ++ "GROUP_SIZE_M": 1, ++ "SPLIT_K": 1, ++ "num_warps": 8, ++ "num_stages": 2 ++ }, ++ "8192": { ++ "BLOCK_SIZE_M": 128, ++ "BLOCK_SIZE_N": 128, ++ "BLOCK_SIZE_K": 64, ++ "GROUP_SIZE_M": 16, ++ "SPLIT_K": 1, ++ "num_warps": 8, ++ "num_stages": 2 ++ } ++} +diff --git a/vllm/model_executor/layers/fused_moe/experts/mxfp8_emulation_moe.py b/vllm/model_executor/layers/fused_moe/experts/mxfp8_emulation_moe.py +index 71dd7634a..63500487d 100644 +--- a/vllm/model_executor/layers/fused_moe/experts/mxfp8_emulation_moe.py ++++ b/vllm/model_executor/layers/fused_moe/experts/mxfp8_emulation_moe.py +@@ -4,7 +4,7 @@ + + ``Mxfp8TritonExpertsBase`` stashes E8M0 weight scales for checkpoint layout. + ``Mxfp8EmulationTritonExperts`` dequantizes to BF16 and runs ``TritonExperts`` +-for devices without a native MXFP8 MoE kernel (e.g. ROCm gfx942 / MI300). ++for devices without a fused MXFP8 MoE kernel. + """ + + import torch +diff --git a/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py b/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py +index 33851fdc8..9e0145ff9 100644 +--- a/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py ++++ b/vllm/model_executor/layers/fused_moe/experts/mxfp8_native_moe.py +@@ -1,28 +1,35 @@ + # SPDX-License-Identifier: Apache-2.0 + # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +-"""Native MXFP8 (1x32 block, E8M0 scale) MoE for AMD CDNA4 (gfx950) via Triton +-``tl.dot_scaled`` (hardware microscaling matmul). ++"""Fused MXFP8 (1x32 block, E8M0 scale) MoE for AMD CDNA3/CDNA4. + + The expert GEMMs consume the FP8 E4M3 weights and their E8M0 block scales + directly (no dequant-to-BF16), and activations are MXFP8-quantized per token. +-On CDNA4 ``dot_scaled`` maps to the native MX matrix-core ops; on other archs +-Triton upcasts to BF16 (so this stays correct, just not faster) — but the +-oracle only selects this path on gfx950 and routes everything else to the +-BF16 ``Mxfp8EmulationTritonExperts`` fallback. ++CDNA4 uses ``tl.dot_scaled`` and native MX matrix-core ops. CDNA3 stores the ++weights as E4M3FNUZ, runs one native FP8 ``tl.dot`` per 32-value MX block, and ++applies the E8M0 scale products in-register. Both paths keep weights compressed ++in HBM instead of expanding them to persistent BF16. + + Structure mirrors vLLM's ``fused_moe_kernel``: tokens are sorted by expert + (``moe_align_block_size``); each program computes a ``[BLOCK_M, BLOCK_N]`` tile +-for one expert, accumulating over K with ``dot_scaled``. SwiGLU-OAI activation +-and the top-k weighted reduction run in PyTorch between/after the two GEMMs. ++for one expert, accumulating over K with the architecture-specific fused path. ++SwiGLU-OAI activation and the top-k weighted reduction run between/after the ++two GEMMs. + """ + + import torch + + import vllm.model_executor.layers.fused_moe.modular_kernel as mk ++from vllm import _custom_ops as ops + from vllm.logger import init_logger ++from vllm.model_executor.layers.fused_moe.config import ( ++ FusedMoEConfig, ++ FusedMoEQuantConfig, ++ biased_moe_quant_config, ++) + from vllm.model_executor.layers.fused_moe.experts.mxfp8_emulation_moe import ( + Mxfp8TritonExpertsBase, + ) ++from vllm.model_executor.layers.fused_moe.experts.triton_moe import TritonExperts + from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( + moe_align_block_size, + ) +@@ -34,9 +41,46 @@ + + logger = init_logger(__name__) + ++_BF16_DECODE_TOKEN_THRESHOLD = 8 ++# MiniMax-M3 eager refill shapes cross over to the retained BF16 experts ++# between 827 and 843 tokens on MI300X. Keep the cutoff tile-aligned. ++_BF16_PREFILL_TOKEN_THRESHOLD = 832 ++_LONG_CONTEXT_BF16_ONLY_LAYER_STRIDE = 5 + ++ ++def _should_use_bf16_decode_fallback(moe_config: FusedMoEConfig) -> bool: ++ """Limit BF16 fallback weights to the exact MiniMax-M3 TP shape.""" ++ return ( ++ current_platform.is_fp8_fnuz() ++ and moe_config.ep_size == 1 ++ and moe_config.has_shared_experts ++ and moe_config.num_experts == 128 ++ and moe_config.experts_per_token == 4 ++ and moe_config.hidden_dim == 6144 ++ and moe_config.intermediate_size == 3072 ++ and moe_config.max_model_len > 0 ++ ) ++ ++ ++def _should_store_bf16_only(max_model_len: int, layer_index: int) -> bool: ++ return ( ++ max_model_len > 4096 and layer_index % _LONG_CONTEXT_BF16_ONLY_LAYER_STRIDE == 0 ++ ) ++ ++ ++def _should_use_bf16_experts( ++ num_tokens: int, ++ native_weights_available: bool, ++) -> bool: ++ return ( ++ not native_weights_available ++ or num_tokens >= _BF16_PREFILL_TOKEN_THRESHOLD ++ or num_tokens <= _BF16_DECODE_TOKEN_THRESHOLD ++ ) ++ ++ + @triton.jit +-def _mxfp8_grouped_gemm_kernel( ++def _mxfp8_grouped_gemm_dot_scaled_kernel( + a_ptr, + a_scale_ptr, + b_ptr, +@@ -125,6 +169,148 @@ + ) + + ++@triton.jit ++def _mxfp8_grouped_gemm_fnuz_kernel( ++ a_ptr, ++ a_scale_ptr, ++ b_ptr, ++ b_scale_ptr, ++ c_ptr, ++ topk_weights_ptr, ++ sorted_token_ids_ptr, ++ expert_ids_ptr, ++ num_tokens_post_padded_ptr, ++ N, ++ K, ++ num_valid_tokens, ++ top_k, ++ stride_am, ++ stride_ak, ++ stride_asm, ++ stride_ask, ++ stride_be, ++ stride_bn, ++ stride_bk, ++ stride_bse, ++ stride_bsn, ++ stride_bsk, ++ stride_cm, ++ stride_cn, ++ A_DIV: tl.constexpr, ++ MUL_WEIGHT: tl.constexpr, ++ BLOCK_M: tl.constexpr, ++ BLOCK_N: tl.constexpr, ++ BLOCK_K: tl.constexpr, ++): ++ pid_m = tl.program_id(0) ++ pid_n = tl.program_id(1) ++ num_post = tl.load(num_tokens_post_padded_ptr) ++ if pid_m * BLOCK_M >= num_post: ++ return ++ ++ offs_tid = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) ++ offs_token = tl.load(sorted_token_ids_ptr + offs_tid).to(tl.int64) ++ token_mask = offs_token < num_valid_tokens ++ off_e = tl.load(expert_ids_ptr + pid_m).to(tl.int64) ++ ++ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) ++ offs_k = tl.arange(0, 32) ++ a_row = offs_token // A_DIV ++ ++ a_ptrs = a_ptr + a_row[:, None] * stride_am + offs_k[None, :] * stride_ak ++ as_ptrs = a_scale_ptr + a_row * stride_asm ++ b_ptrs = ( ++ b_ptr ++ + off_e * stride_be ++ + offs_n[:, None] * stride_bn ++ + offs_k[None, :] * stride_bk ++ ) ++ bs_ptrs = b_scale_ptr + off_e * stride_bse + offs_n * stride_bsn ++ ++ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) ++ n_mask = offs_n < N ++ for _ in range(0, tl.cdiv(K, BLOCK_K)): ++ for k_offset in tl.static_range(0, BLOCK_K, 32): ++ a = tl.load( ++ a_ptrs + k_offset * stride_ak, ++ mask=token_mask[:, None], ++ other=0.0, ++ ) ++ b = tl.load( ++ b_ptrs + k_offset * stride_bk, ++ mask=n_mask[:, None], ++ other=0.0, ++ ) ++ asc = tl.load( ++ as_ptrs + (k_offset // 32) * stride_ask, ++ mask=token_mask, ++ other=0, ++ ).to(tl.uint16) ++ bsc = tl.load( ++ bs_ptrs + (k_offset // 32) * stride_bsk, ++ mask=n_mask, ++ other=0, ++ ).to(tl.uint16) ++ ++ # E8M0 and BF16 use the same eight-bit biased exponent. Shift each ++ # scale byte into a BF16 exponent field, as Marlin does, then form ++ # the per-token/per-output scale product around the FP8 dot. ++ asc_scale = (asc << 7).to(tl.bfloat16, bitcast=True) ++ bsc_scale = (bsc << 7).to(tl.bfloat16, bitcast=True) ++ block_scale = asc_scale[:, None].to(tl.float32) * bsc_scale[None, :].to( ++ tl.float32 ++ ) ++ acc += tl.dot(a, b.T) * block_scale ++ ++ a_ptrs += BLOCK_K * stride_ak ++ b_ptrs += BLOCK_K * stride_bk ++ as_ptrs += (BLOCK_K // 32) * stride_ask ++ bs_ptrs += (BLOCK_K // 32) * stride_bsk ++ ++ if MUL_WEIGHT: ++ w = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0.0) ++ acc = acc * w[:, None] ++ ++ c_ptrs = c_ptr + offs_token[:, None] * stride_cm + offs_n[None, :] * stride_cn ++ tl.store( ++ c_ptrs, ++ acc.to(c_ptr.dtype.element_ty), ++ mask=token_mask[:, None] & n_mask[None, :], ++ ) ++ ++ ++def _gfx94x_grouped_gemm_config( ++ m_routed: int, ++ n: int, ++ k: int, ++ block_m: int, ++ is_gemm2: bool, ++) -> tuple[int, int, int]: ++ short_k_gemm2 = is_gemm2 and block_m <= 16 and k <= 512 and k % 64 == 0 ++ if short_k_gemm2: ++ # MiniMax-M3 TP GEMM2 has a wide N=6144 and short K=384. Pairing two ++ # waves over 64 columns amortizes indexing while a 64-wide K tile ++ # exposes enough independent work for the short reduction. ++ return 64, 64, 2 ++ ++ block_k = 128 if k % 128 == 0 and block_m <= 16 else 64 if k % 64 == 0 else 32 ++ if block_m <= 16: ++ # One wave per 32 output columns avoids the register pressure of the ++ # original 128-column tile. At the very smallest routed batch, pairing ++ # two waves in a 64-column program amortizes launch/indexing overhead. ++ block_n = 64 if m_routed < 32 else 32 ++ num_warps = 2 if m_routed < 32 else 1 ++ elif block_m >= 64 and n >= 2048 and k >= 2048: ++ # EP prefill GEMMs remain register-bound at a 128-column tile even with ++ # 64 rows. Two-wave 64-column programs expose more independent work. ++ block_n = 64 ++ num_warps = 2 ++ else: ++ block_n = 128 ++ num_warps = 4 if block_m <= 32 else 8 ++ return block_n, block_k, num_warps ++ ++ + def _grouped_gemm_mxfp8( + a_q: torch.Tensor, # [M, K] fp8 e4m3 + a_scale: torch.Tensor, # [M, K//32] uint8 (E8M0) +@@ -140,19 +326,81 @@ + a_div: int, + mul_weight_by: torch.Tensor | None = None, + expert_map: torch.Tensor | None = None, ++ is_gemm2: bool = False, ++ block_n_override: int = 0, ++ block_k_override: int = 0, ++ num_warps_override: int = 0, + ) -> torch.Tensor: + M_routed = num_valid_tokens + E, N, K = w.shape +- assert K % 128 == 0, f"MXFP8 native MoE requires K%128==0, got K={K}" ++ k_alignment = 32 if current_platform.is_fp8_fnuz() else 128 ++ assert K % k_alignment == 0, ( ++ f"MXFP8 native MoE requires K%{k_alignment}==0, got K={K}" ++ ) ++ if w_scale.shape == (E, N, K // 32): ++ scale_stride_e = w_scale.stride(0) ++ scale_stride_n = w_scale.stride(1) ++ scale_stride_k = w_scale.stride(2) ++ elif w_scale.shape == (E, K // 32, N): ++ scale_stride_e = w_scale.stride(0) ++ scale_stride_n = w_scale.stride(2) ++ scale_stride_k = w_scale.stride(1) ++ else: ++ raise ValueError( ++ "MXFP8 weight scales must use [E, N, K/32] or packed " ++ f"[E, K/32, N] layout, got {tuple(w_scale.shape)}." ++ ) ++ is_fnuz = current_platform.is_fp8_fnuz() ++ if is_fnuz: ++ BLOCK_N, BLOCK_K, num_warps = _gfx94x_grouped_gemm_config( ++ M_routed, ++ N, ++ K, ++ block_m, ++ is_gemm2, ++ ) ++ else: ++ BLOCK_N = 128 ++ BLOCK_K = 128 ++ num_warps = 8 ++ # moe_align_block_size allocates for the worst case where every expert is ++ # active. At small batches that can be much larger than the number of ++ # blocks that can contain valid assignments. Limit the launch to the ++ # tighter static upper bound; the device-side num_post check handles the ++ # remaining tail. ++ max_post_padded = min(sorted_token_ids.shape[0], M_routed * block_m) ++ if block_n_override: ++ BLOCK_N = block_n_override ++ if block_k_override: ++ BLOCK_K = block_k_override ++ if num_warps_override: ++ num_warps = num_warps_override ++ if BLOCK_K % 32 != 0 or K % BLOCK_K != 0: ++ raise ValueError( ++ f"MXFP8 grouped GEMM requires BLOCK_K to divide K in 32-value " ++ f"units, got BLOCK_K={BLOCK_K}, K={K}." ++ ) ++ if num_warps not in (1, 2, 4, 8): ++ raise ValueError(f"Unsupported num_warps={num_warps}.") ++ m_blocks = triton.cdiv(max_post_padded, block_m) ++ n_blocks = triton.cdiv(N, BLOCK_N) ++ + # Under expert parallelism (expert_map set) tokens routed to non-local + # experts are dropped from sorted_token_ids, so their output rows are never +- # written — zero them so the downstream reduction ignores their garbage. ++ # written. + alloc = torch.zeros if expert_map is not None else torch.empty + out = alloc((M_routed, N), dtype=out_dtype, device=a_q.device) +- BLOCK_N = 128 +- BLOCK_K = 128 +- grid = (triton.cdiv(sorted_token_ids.shape[0], block_m), triton.cdiv(N, BLOCK_N)) +- _mxfp8_grouped_gemm_kernel[grid]( ++ grid = (m_blocks, n_blocks) ++ kernel = ( ++ _mxfp8_grouped_gemm_fnuz_kernel ++ if current_platform.is_fp8_fnuz() ++ else _mxfp8_grouped_gemm_dot_scaled_kernel ++ ) ++ if current_platform.is_fp8_fnuz() and ( ++ a_q.dtype != torch.float8_e4m3fnuz or w.dtype != torch.float8_e4m3fnuz ++ ): ++ raise ValueError("gfx94x MXFP8 MoE requires E4M3FNUZ inputs.") ++ kernel[grid]( + a_q, + a_scale, + w, +@@ -173,9 +421,9 @@ + w.stride(0), + w.stride(1), + w.stride(2), +- w_scale.stride(0), +- w_scale.stride(1), +- w_scale.stride(2), ++ scale_stride_e, ++ scale_stride_n, ++ scale_stride_k, + out.stride(0), + out.stride(1), + A_DIV=a_div, +@@ -183,7 +431,7 @@ + BLOCK_M=block_m, + BLOCK_N=BLOCK_N, + BLOCK_K=BLOCK_K, +- num_warps=8, ++ num_warps=num_warps, + ) + return out + +@@ -202,12 +450,26 @@ + limit: float | None, + global_num_experts: int, + expert_map: torch.Tensor | None, ++ output: torch.Tensor | None = None, ++ g1_block_n: int = 0, ++ g1_block_k: int = 0, ++ g1_num_warps: int = 0, ++ g2_block_n: int = 0, ++ g2_block_k: int = 0, ++ g2_num_warps: int = 0, + ) -> torch.Tensor: + T, H = hidden_states.shape + top_k = topk_ids.shape[1] + M = T * top_k + +- block_m = 64 ++ if current_platform.is_fp8_fnuz(): ++ # Padding is per expert, so tile from the average expert occupancy ++ # rather than the total routed-token count. MiniMax-M3 has 128 experts; ++ # a 64-row tile wastes most of both GEMMs at low occupancy. ++ tokens_per_expert = max(1, M // global_num_experts) ++ block_m = max(16, min(1 << (tokens_per_expert - 1).bit_length(), 64)) ++ else: ++ block_m = 64 + sorted_ids, expert_ids, num_post = moe_align_block_size( + topk_ids, + block_m, +@@ -232,6 +494,9 @@ + hidden_states.dtype, + a_div=top_k, + expert_map=expert_map, ++ block_n_override=g1_block_n, ++ block_k_override=g1_block_k, ++ num_warps_override=g1_num_warps, + ) # [M, 2I] + + # SwiGLU-OAI (split layout: gate=g1[:, :I], up=g1[:, I:]) FUSED with the +@@ -258,15 +523,76 @@ + block_m, +- torch.float32, ++ hidden_states.dtype if current_platform.is_fp8_fnuz() else torch.float32, + a_div=1, + mul_weight_by=topk_weights.reshape(-1).to(torch.float32), + expert_map=expert_map, +- ) # [M, H] == [T*top_k, H] +- +- return g2.view(T, top_k, H).sum(dim=1).to(hidden_states.dtype) ++ is_gemm2=True, ++ block_n_override=g2_block_n, ++ block_k_override=g2_block_k, ++ num_warps_override=g2_num_warps, ++ ) # [M, H] == [T*top_k, H] + ++ if current_platform.is_fp8_fnuz(): ++ if output is None: ++ output = torch.empty_like(hidden_states) ++ ops.moe_sum(g2.view(T, top_k, H), output) ++ return output + ++ result = g2.view(T, top_k, H).sum(dim=1).to(hidden_states.dtype) ++ if output is not None: ++ output.copy_(result) ++ return output ++ return result ++ ++ + class Mxfp8NativeTritonExperts(Mxfp8TritonExpertsBase): +- """Native MXFP8 MoE (CDNA4 ``dot_scaled``) on gfx950.""" ++ """Fused MXFP8 MoE on gfx94x/gfx95x.""" + ++ def __init__( ++ self, ++ moe_config: FusedMoEConfig, ++ quant_config: FusedMoEQuantConfig, ++ ): ++ super().__init__(moe_config, quant_config) ++ self.w1_bf16: torch.Tensor | None = None ++ self.w2_bf16: torch.Tensor | None = None ++ self.native_weights_available = True ++ self.bf16_experts: TritonExperts | None = None ++ if _should_use_bf16_decode_fallback(moe_config): ++ bf16_config = biased_moe_quant_config( ++ None, ++ None, ++ gemm1_alpha=quant_config.gemm1_alpha, ++ gemm1_beta=quant_config.gemm1_beta, ++ gemm1_clamp_limit=quant_config.gemm1_clamp_limit, ++ ) ++ self.bf16_experts = TritonExperts(moe_config, bf16_config) ++ + @property ++ def requires_bf16_fallback_weights(self) -> bool: ++ return self.bf16_experts is not None ++ ++ def bind_bf16_weights( ++ self, ++ w1_bf16: torch.Tensor, ++ w2_bf16: torch.Tensor, ++ *, ++ native_weights_available: bool, ++ ) -> None: ++ if self.bf16_experts is None: ++ raise RuntimeError("BF16 decode experts are not enabled for this config.") ++ self.w1_bf16 = w1_bf16 ++ self.w2_bf16 = w2_bf16 ++ self.native_weights_available = native_weights_available ++ ++ def bind_packed_weight_scales( ++ self, ++ w1_scale: torch.Tensor, ++ w2_scale: torch.Tensor, ++ ) -> None: ++ if not current_platform.is_fp8_fnuz(): ++ raise RuntimeError("Packed MXFP8 scales are specific to gfx94x.") ++ self.w1_scale_val = w1_scale ++ self.w2_scale_val = w2_scale ++ ++ @property + def quant_dtype(self) -> torch.dtype | str | None: +@@ -283,7 +609,9 @@ + + @staticmethod + def _supports_current_device() -> bool: +- return current_platform.is_rocm() and current_platform.supports_mx() ++ return current_platform.is_rocm() and ( ++ current_platform.supports_mx() or current_platform.is_fp8_fnuz() ++ ) + + def apply( + self, +@@ -303,6 +631,35 @@ + expert_tokens_meta: mk.ExpertTokensMetadata | None, + apply_router_weight_on_input: bool, + ): ++ num_tokens = hidden_states.shape[0] ++ bf16_experts = self.bf16_experts ++ if bf16_experts is not None and _should_use_bf16_experts( ++ num_tokens, ++ self.native_weights_available, ++ ): ++ if self.w1_bf16 is None or self.w2_bf16 is None: ++ raise RuntimeError( ++ "BF16 fallback weights were not bound after loading." ++ ) ++ bf16_experts.apply( ++ output=output, ++ hidden_states=hidden_states, ++ w1=self.w1_bf16, ++ w2=self.w2_bf16, ++ topk_weights=topk_weights, ++ topk_ids=topk_ids, ++ activation=activation, ++ global_num_experts=global_num_experts, ++ expert_map=expert_map, ++ a1q_scale=None, ++ a2_scale=None, ++ workspace13=workspace13, ++ workspace2=workspace2, ++ expert_tokens_meta=expert_tokens_meta, ++ apply_router_weight_on_input=apply_router_weight_on_input, ++ ) ++ return ++ + alpha = self.quant_config.gemm1_alpha + alpha = 1.702 if alpha is None else float(alpha) + beta = self.quant_config.gemm1_beta +@@ -322,5 +679,6 @@ + limit=limit, + global_num_experts=global_num_experts, + expert_map=expert_map, ++ output=output, + ) +- output.copy_(out) ++ assert out is output +diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py +index 225484385..912ed6152 100644 +--- a/vllm/model_executor/layers/fused_moe/layer.py ++++ b/vllm/model_executor/layers/fused_moe/layer.py +@@ -318,7 +318,13 @@ def FusedMoE( + moe_backend=vllm_config.kernel_config.moe_backend, + router_logits_dtype=router_logits_dtype, + max_num_tokens=max_num_batched_tokens, ++ max_model_len=( ++ vllm_config.model_config.max_model_len ++ if vllm_config.model_config is not None ++ else 0 ++ ), + has_bias=has_bias, ++ has_shared_experts=shared_experts is not None, + is_lora_enabled=vllm_config.lora_config is not None, + activation=moe_activation, + device=vllm_config.device_config.device, +diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py +index acbf2cb46..1fcf67678 100644 +--- a/vllm/model_executor/layers/fused_moe/oracle/fp8.py ++++ b/vllm/model_executor/layers/fused_moe/oracle/fp8.py +@@ -55,9 +55,9 @@ class Fp8MoeBackend(Enum): + # Dequantize-to-BF16 emulation for MXFP8 on devices without a native + # MXFP8 MoE kernel (e.g. ROCm). Weights pass through unchanged here. + EMULATION = "EMULATION" +- # MXFP8 MoE via a Triton ``dot_scaled`` kernel that lowers to CDNA4 +- # (gfx950) native MX matrix-core ops. Weights stay in MXFP8 (no load-time +- # format conversion); the FP8 values + E8M0 scales are consumed directly. ++ # Fused ROCm MXFP8 MoE. CDNA4 (gfx95x) uses native ``dot_scaled`` MX ops; ++ # CDNA3 (gfx94x) uses E4M3FNUZ FP8 partial dots with in-register E8M0 scale ++ # application. Both consume compressed weights directly. + NATIVE_MXFP8 = "NATIVE_MXFP8" + + +@@ -463,6 +463,13 @@ def convert_to_fp8_moe_kernel_format( + ) + + w13, w2 = prepare_fp8_moe_layer_for_cpu(w13, w2) ++ elif fp8_backend == Fp8MoeBackend.NATIVE_MXFP8 and current_platform.is_fp8_fnuz(): ++ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( ++ normalize_mxfp8_e4m3fn_to_e4m3fnuz, ++ ) ++ ++ w13, w13_scale = normalize_mxfp8_e4m3fn_to_e4m3fnuz(w13, w13_scale) ++ w2, w2_scale = normalize_mxfp8_e4m3fn_to_e4m3fnuz(w2, w2_scale) + else: + if fp8_backend not in [ + Fp8MoeBackend.TRITON, +@@ -470,8 +477,8 @@ def convert_to_fp8_moe_kernel_format( + Fp8MoeBackend.VLLM_CUTLASS, + Fp8MoeBackend.BATCHED_VLLM_CUTLASS, + Fp8MoeBackend.XPU, +- # EMULATION dequantizes weights at runtime; NATIVE_MXFP8 consumes +- # the MXFP8 weights as-is — neither needs a load-time layout change. ++ # EMULATION consumes checkpoint layout directly. CDNA4 NATIVE_MXFP8 ++ # also needs no layout change; CDNA3 normalization is handled above. + Fp8MoeBackend.EMULATION, + Fp8MoeBackend.NATIVE_MXFP8, + ]: +diff --git a/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py b/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py +index d0d7c7648..bc00da41e 100644 +--- a/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py ++++ b/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py +@@ -76,15 +76,37 @@ def _select_kernel_cls( + ) + + +-def _select_rocm_mxfp8_backend() -> tuple[Fp8MoeBackend, type[mk.FusedMoEExperts]]: ++def _select_rocm_mxfp8_backend( ++ config: FusedMoEConfig, ++) -> tuple[Fp8MoeBackend, type[mk.FusedMoEExperts]]: + """ROCm fallback when vendor MXFP8 backends are unavailable.""" + +- if current_platform.supports_mx(): ++ if current_platform.is_fp8_fnuz() and config.ep_size > 1: ++ from vllm.model_executor.layers.fused_moe.experts.mxfp8_emulation_moe import ( ++ Mxfp8EmulationTritonExperts, ++ ) ++ ++ logger.info_once( ++ "Using BF16 MXFP8 emulation for gfx94x expert parallelism; the " ++ "native CDNA3 path is optimized for decode-sized TP workloads and " ++ "is slower for the large local batches reached during EP prefill." ++ ) ++ return Fp8MoeBackend.EMULATION, Mxfp8EmulationTritonExperts ++ ++ if current_platform.supports_mx() or current_platform.is_fp8_fnuz(): + from vllm.model_executor.layers.fused_moe.experts.mxfp8_native_moe import ( + Mxfp8NativeTritonExperts, + ) + +- logger.info_once("Using native CDNA4 (gfx950) MXFP8 dot_scaled MoE backend.") ++ if current_platform.supports_mx(): ++ logger.info_once( ++ "Using native CDNA4 (gfx95x) MXFP8 dot_scaled MoE backend." ++ ) ++ else: ++ logger.info_once( ++ "Using fused CDNA3 (gfx94x) MXFP8 FP8 MoE backend; weights " ++ "remain compressed and 1x32 scales are applied in-kernel." ++ ) + return Fp8MoeBackend.NATIVE_MXFP8, Mxfp8NativeTritonExperts + + from vllm.model_executor.layers.fused_moe.experts.mxfp8_emulation_moe import ( +@@ -134,6 +156,6 @@ def select_mxfp8_moe_backend( + + # simplify the logic for rocm, refactor later when more backends are supported + if current_platform.is_rocm(): +- return _select_rocm_mxfp8_backend() ++ return _select_rocm_mxfp8_backend(config) + + raise ValueError("No MXFP8 MoE backends available.") +diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py +index 33c7c7532..9e26aa823 100644 +--- a/vllm/model_executor/layers/quantization/modelopt.py ++++ b/vllm/model_executor/layers/quantization/modelopt.py +@@ -92,6 +92,7 @@ from vllm.model_executor.parameter import ( + PerTensorScaleParameter, + ) + from vllm.model_executor.utils import replace_parameter, set_weight_attrs ++from vllm.platforms import current_platform + + if TYPE_CHECKING: + from vllm.model_executor.models.utils import WeightsMapper +@@ -2086,7 +2087,7 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): + def _dequant_mxfp8_weights_to_bf16(self, layer: RoutedExperts) -> None: + """One-time MXFP8->BF16 weight dequant for the emulation path. + +- On devices without a native MXFP8 MoE kernel (e.g. gfx942 / MI300), ++ On devices without a fused MXFP8 MoE kernel, + ``Mxfp8EmulationTritonExperts`` otherwise dequantizes every expert + weight to BF16 on *every* forward step -- the dominant cost (conc1 + ~1.3 tok/s). Doing the dequant once here and replacing the MXFP8 +@@ -2121,6 +2122,90 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): + num_experts, + ) + ++ def _retain_bf16_fallback_weights(self, layer: RoutedExperts) -> None: ++ """Keep the BF16 weights selected by the gfx94x TP dispatch policy.""" ++ from vllm.model_executor.layers.fused_moe.experts.mxfp8_native_moe import ( ++ Mxfp8NativeTritonExperts, ++ _should_store_bf16_only, ++ ) ++ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( ++ dequant_mxfp8_to_bf16, ++ ) ++ from vllm.model_executor.models.utils import extract_layer_index ++ ++ if self.moe_kernel is None: ++ raise RuntimeError("MXFP8 MoE kernel was not initialized.") ++ experts = self.moe_kernel.fused_experts ++ if not isinstance(experts, Mxfp8NativeTritonExperts): ++ raise TypeError( ++ "Expected Mxfp8NativeTritonExperts for the gfx94x native backend." ++ ) ++ ++ target_dtype = getattr(layer, "orig_dtype", torch.bfloat16) ++ w13_bf16 = dequant_mxfp8_to_bf16(layer.w13_weight, layer.w13_weight_scale).to( ++ target_dtype ++ ) ++ w2_bf16 = dequant_mxfp8_to_bf16(layer.w2_weight, layer.w2_weight_scale).to( ++ target_dtype ++ ) ++ layer_index = extract_layer_index(layer.layer_name) ++ store_bf16_only = _should_store_bf16_only( ++ self.moe.max_model_len, ++ layer_index, ++ ) ++ ++ if store_bf16_only: ++ replace_parameter(layer, "w13_weight", w13_bf16) ++ replace_parameter(layer, "w2_weight", w2_bf16) ++ else: ++ layer.register_buffer("_mxfp8_w13_bf16", w13_bf16, persistent=False) ++ layer.register_buffer("_mxfp8_w2_bf16", w2_bf16, persistent=False) ++ experts.bind_bf16_weights( ++ w13_bf16, ++ w2_bf16, ++ native_weights_available=not store_bf16_only, ++ ) ++ ++ if self.moe.max_model_len <= 4096: ++ logger.info_once( ++ "Retaining BF16 MXFP8 MoE weights for gfx94x TP decode and " ++ "prefill dispatch." ++ ) ++ else: ++ logger.info_once( ++ "Using BF16-only storage for one-fifth of gfx94x TP MoE " ++ "layers and retaining both MXFP8 and BF16 weights for the " ++ "remaining layers." ++ ) ++ ++ def _pack_mxfp8_weight_scales(self, layer: RoutedExperts, experts) -> None: ++ """Pack gfx94x E8M0 scales so consecutive output columns are contiguous.""" ++ if not experts.native_weights_available: ++ return ++ ++ def pack(scale: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: ++ E, N, K = weight.shape ++ if scale.shape == (E, K // 32, N): ++ return scale ++ if scale.shape != (E, N, K // 32): ++ raise ValueError( ++ "Unexpected MXFP8 weight-scale shape " ++ f"{tuple(scale.shape)} for weight {tuple(weight.shape)}." ++ ) ++ return scale.transpose(1, 2).contiguous() ++ ++ w13_scale = pack(layer.w13_weight_scale, layer.w13_weight) ++ w2_scale = pack(layer.w2_weight_scale, layer.w2_weight) ++ replace_parameter(layer, "w13_weight_scale", w13_scale) ++ replace_parameter(layer, "w2_weight_scale", w2_scale) ++ experts.bind_packed_weight_scales( ++ layer.w13_weight_scale, ++ layer.w2_weight_scale, ++ ) ++ logger.info_once( ++ "Packed gfx94x MXFP8 MoE weight scales as [expert, K/32, N]." ++ ) ++ + def process_weights_after_loading(self, layer: RoutedExperts) -> None: + # TODO(bnell): why is this required only for mxfp8? + if getattr(layer, "_already_called_process_weights_after_loading", False): +@@ -2158,7 +2243,21 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): + routing_tables=layer._expert_routing_tables(), + ) + +- # No native MXFP8 MoE kernel on this device (e.g. gfx942): the emulation ++ from vllm.model_executor.layers.fused_moe.experts.mxfp8_native_moe import ( ++ Mxfp8NativeTritonExperts, ++ ) ++ ++ experts = self.moe_kernel.fused_experts ++ if ( ++ self.mxfp8_backend == Fp8MoeBackend.NATIVE_MXFP8 ++ and current_platform.is_fp8_fnuz() ++ and isinstance(experts, Mxfp8NativeTritonExperts) ++ ): ++ if experts.requires_bf16_fallback_weights: ++ self._retain_bf16_fallback_weights(layer) ++ self._pack_mxfp8_weight_scales(layer, experts) ++ ++ # No fused MXFP8 MoE kernel on this device: the emulation + # experts would dequant MXFP8->BF16 every forward step. Convert the + # weights to BF16 once, here, so the MoE runs like a BF16 checkpoint. + # Opt out (VLLM_MXFP8_EMULATION_DEQUANT_AT_LOAD=0) to keep the 1-byte +diff --git a/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py +index e6063b463..fa5b01615 100644 +--- a/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py ++++ b/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py +@@ -11,6 +11,32 @@ MXFP8_SCALE_DTYPE = torch.uint8 + MXFP8_BLOCK_SIZE = 32 + + ++def normalize_mxfp8_e4m3fn_to_e4m3fnuz( ++ values: torch.Tensor, ++ scales: torch.Tensor, ++) -> tuple[torch.Tensor, torch.Tensor]: ++ """Convert OCP E4M3 MXFP8 storage to AMD E4M3FNUZ in place. ++ ++ For an identical byte pattern, E4M3FNUZ represents half the E4M3FN value. ++ Incrementing the E8M0 exponent preserves the dequantized value without ++ expanding the one-byte weights. OCP negative zero (0x80) is NaN in FNUZ, ++ so it must be canonicalized to positive zero before reinterpreting. ++ """ ++ if values.dtype == torch.float8_e4m3fnuz: ++ return values, scales ++ if values.dtype != torch.float8_e4m3fn: ++ raise ValueError(f"Expected E4M3FN or E4M3FNUZ values, got {values.dtype}.") ++ if scales.dtype != MXFP8_SCALE_DTYPE: ++ raise ValueError(f"Expected {MXFP8_SCALE_DTYPE} scales, got {scales.dtype}.") ++ if int(scales.max().item()) >= 254: ++ raise ValueError("Cannot convert MXFP8 scale exponent 254 to E4M3FNUZ.") ++ ++ value_bits = values.view(torch.int8) ++ value_bits.masked_fill_(value_bits == -128, 0) ++ scales.add_(1) ++ return value_bits.view(torch.float8_e4m3fnuz), scales ++ ++ + def swizzle_mxfp8_scale(sf: torch.Tensor, M: int, K: int) -> torch.Tensor: + """Swizzle MXFP8 scales from row-major 2D to F8_128x4 layout.""" + scaling_vector_size = MXFP8_BLOCK_SIZE # 32 for MXFP8 +@@ -38,6 +64,7 @@ def swizzle_mxfp8_scale(sf: torch.Tensor, M: int, K: int) -> torch.Tensor: + def _mxfp8_e4m3_quantize_torch( + x: torch.Tensor, + is_sf_swizzled_layout: bool = False, ++ value_dtype: torch.dtype = MXFP8_VALUE_DTYPE, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Naive MXFP8 quantization. + For each block of 32 elements along the last dimension, compute a +@@ -65,7 +92,7 @@ def _mxfp8_e4m3_quantize_torch( + descale = torch.exp2(scale_biased - 127.0) + x_scaled = x_blocked / descale.unsqueeze(-1) + +- x_fp8 = x_scaled.view(orig_shape).to(MXFP8_VALUE_DTYPE) ++ x_fp8 = x_scaled.view(orig_shape).to(value_dtype) + + if x.ndim == 2: + M, K = x.shape +@@ -139,6 +166,7 @@ def _mxfp8_e4m3_quantize_triton( + x: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Fused 2D MXFP8 quant (non-swizzled, row-major [M, K//32] scales).""" ++ from vllm.platforms import current_platform + from vllm.triton_utils import triton + + global _MXFP8_QUANT_KERNEL +@@ -147,7 +175,10 @@ def _mxfp8_e4m3_quantize_triton( + + M, K = x.shape + x = x.contiguous() +- xq = torch.empty((M, K), dtype=MXFP8_VALUE_DTYPE, device=x.device) ++ value_dtype = ( ++ torch.float8_e4m3fnuz if current_platform.is_fp8_fnuz() else MXFP8_VALUE_DTYPE ++ ) ++ xq = torch.empty((M, K), dtype=value_dtype, device=x.device) + scales = torch.empty( + (M, K // MXFP8_BLOCK_SIZE), dtype=MXFP8_SCALE_DTYPE, device=x.device + ) +@@ -233,7 +264,19 @@ def mxfp8_e4m3_quantize_fake( + alignment: int = 0, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Fake implementation for torch.compile tracing.""" +- fp_data = torch.empty_like(x, dtype=MXFP8_VALUE_DTYPE) ++ from vllm.platforms import current_platform ++ ++ value_dtype = ( ++ torch.float8_e4m3fnuz ++ if ( ++ current_platform.is_fp8_fnuz() ++ and not is_sf_swizzled_layout ++ and x.ndim == 2 ++ and x.shape[-1] % MXFP8_BLOCK_SIZE == 0 ++ ) ++ else MXFP8_VALUE_DTYPE ++ ) ++ fp_data = torch.empty_like(x, dtype=value_dtype) + + block_size = MXFP8_BLOCK_SIZE + +diff --git a/vllm/models/minimax_m3/amd/ops/swiglu_oai.py b/vllm/models/minimax_m3/amd/ops/swiglu_oai.py +index 836649b72..9572c5109 100644 +--- a/vllm/models/minimax_m3/amd/ops/swiglu_oai.py ++++ b/vllm/models/minimax_m3/amd/ops/swiglu_oai.py +@@ -24,6 +24,7 @@ HIP graphs already eliminate — measured end-to-end throughput is identical + + import torch + ++from vllm.platforms import current_platform + from vllm.triton_utils import tl, triton + + +@@ -132,11 +133,10 @@ def swiglu_oai_quantize_mxfp8( + ) -> tuple[torch.Tensor, torch.Tensor]: + """SwiGLU-OAI on split-layout ``[M, 2I]`` fused with MXFP8 activation-quant. + +- Returns ``(act_q [M, I] float8_e4m3fn, act_scale [M, I//32] uint8 E8M0)``, +- identical to ``mxfp8_e4m3_quantize(swiglu_oai_split(gate_up))`` but in a +- single Triton pass (no bf16 intermediate). Used between the two GEMMs of the +- native MXFP8 MoE. Numerically equivalent to the unfused chain (bit-exact on +- measured MoE shapes); marginally more accurate (fp32 act, no bf16 round-trip). ++ Returns platform-native E4M3 values plus ``[M, I//32]`` uint8 E8M0 scales, ++ equivalent to ``mxfp8_e4m3_quantize(swiglu_oai_split(gate_up))`` but in a ++ single Triton pass (no bf16 intermediate). gfx94x emits E4M3FNUZ so ++ ``tl.dot`` lowers to the native CDNA3 FP8 matrix cores. + """ + from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( + MXFP8_BLOCK_SIZE, +@@ -151,7 +151,10 @@ def swiglu_oai_quantize_mxfp8( + ) + g1 = gate_up.reshape(-1, two_i).contiguous() + M = g1.shape[0] +- aq = torch.empty((M, n_inter), dtype=MXFP8_VALUE_DTYPE, device=g1.device) ++ value_dtype = ( ++ torch.float8_e4m3fnuz if current_platform.is_fp8_fnuz() else MXFP8_VALUE_DTYPE ++ ) ++ aq = torch.empty((M, n_inter), dtype=value_dtype, device=g1.device) + asc = torch.empty( + (M, n_inter // MXFP8_BLOCK_SIZE), dtype=MXFP8_SCALE_DTYPE, device=g1.device + ) diff --git a/perf-changelog.yaml b/perf-changelog.yaml index 15cc9c94f..81610d9d0 100644 --- a/perf-changelog.yaml +++ b/perf-changelog.yaml @@ -3830,3 +3830,12 @@ description: - "Extend MiniMax-M3 MXFP8 H100/H200 non-MTP sweeps to concurrency 1 on the latency rows (H100: TP8; H200: TP4 and TP8) and add full TEP coverage from conc 1 to 256 (H100: TP8+EP8; H200: TP4+EP4 and TP8+EP8, incl. a new TP4+EP4 row for 8k1k). H200 TP8+EP8 upper bound moves 512->256 (high concurrency stays covered by the TP8+EP8 dp-attn DEP rows). DEP rows unchanged" pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/1761 + +- config-keys: + - minimaxm3-fp8-mi300x-vllm + description: + - "Add a fused gfx942 MXFP8 MoE kernel, normalize OCP E4M3FN weights to AMD E4M3FNUZ, and reconstruct E8M0 scales with the Marlin-style BF16 exponent bitcast" + - "Use BF16 experts for TP decode batches up to 8 tokens and eager refill batches from 832 tokens, native W8A8 MXFP8 between those thresholds, and load-time BF16 dequantization under expert parallelism" + - "Pack gfx942 weight scales as [expert, K/32, N] for contiguous output-column scale loads; exact InferenceX graph replay shows 9.5-17.0% lower native MoE latency" + - "For long-context TP, store one-fifth of MoE layers in BF16 only and retain both BF16 and compressed MXFP8 weights for the remaining layers without changing the TP8 and TP8+EP8 parallelism matrix" + pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/1753