diff --git a/.github/configs/amd-master.yaml b/.github/configs/amd-master.yaml index 7e4918e09..5774d06ed 100644 --- a/.github/configs/amd-master.yaml +++ b/.github/configs/amd-master.yaml @@ -2966,3 +2966,24 @@ minimaxm3-fp8-mi325x-vllm-mtp: - { tp: 8, conc-start: 1, conc-end: 128, spec-decoding: mtp } - { tp: 8, ep: 8, conc-start: 256, conc-end: 256, spec-decoding: mtp } - { tp: 8, ep: 8, dp-attn: true, conc-start: 256, conc-end: 256, spec-decoding: mtp } + +# [DO NOT MERGE — experimental] MI325X (gfx942) counterpart of +# minimaxm3arf-fp8-mi355x-vllm: validates vllm-project/vllm#45639 (AITER fused +# all-reduce + Gemma-RMSNorm for MiniMax-M3) on MI325X by applying that PR's diff +# in-place to the shipped minimax-m3 image before serving (recipe +# benchmarks/single_node/fixed_seq_len/minimaxm3arf_fp8_mi325x.sh; BF16 KV on +# gfx942). Smoke test at conc 4 and 8, TP8 (the AR+RMS fusion needs TP>1). +minimaxm3arf-fp8-mi325x-vllm: + image: vllm/vllm-openai-rocm:minimax-m3 + model: MiniMaxAI/MiniMax-M3-MXFP8 + model-prefix: minimaxm3arf + runner: mi325x + precision: fp8 + framework: vllm + multinode: false + scenarios: + fixed-seq-len: + - isl: 1024 + osl: 1024 + search-space: + - { tp: 8, conc-list: [ 4, 8 ] } diff --git a/.github/profile-target.env b/.github/profile-target.env new file mode 100644 index 000000000..3942c7906 --- /dev/null +++ b/.github/profile-target.env @@ -0,0 +1,11 @@ +# Target for the label-triggered Profile workflow (.github/workflows/profile.yml). +# When a PR carries the 'profile-enabled' label, that workflow profiles this +# config-key and emits a Perfetto trace per (concurrency, tp) (artifact + relay +# link in the run summary). CONC is space-separated; the workflow runs one job +# per (conc, tp) — i.e. every TP in the config's search space (only TP8 here) at each conc. +# +# Experiment: MiniMax-M3 MXFP8 on MI325X (gfx942) with vllm-project/vllm#45639 +# (AITER AR + Gemma-RMS fusion) applied in-place; single-node vLLM, TP8, conc 4 and 8.
+CONFIG_KEY=minimaxm3arf-fp8-mi325x-vllm +CONFIG_FILE=.github/configs/amd-master.yaml +CONC=4 8 diff --git a/.github/workflows/profile.yml b/.github/workflows/profile.yml index 2b9679f55..166ea34bd 100644 --- a/.github/workflows/profile.yml +++ b/.github/workflows/profile.yml @@ -26,6 +26,12 @@ on: description: "Ref (branch/sha) to checkout" required: false type: string + # Label-triggered profiling: add the 'profile-enabled' label to a PR and this + # workflow profiles the target declared in .github/profile-target.env + # (CONFIG_KEY / CONFIG_FILE / CONC), emitting a Perfetto trace as an artifact + # plus a relay link. Gated in the get-jobs `if` below so only labelled PRs run. + pull_request: + types: [labeled, synchronize, reopened] permissions: contents: read @@ -40,26 +46,68 @@ env: jobs: get-jobs: + # Run for manual dispatch, or for a PR carrying the 'profile-enabled' label. + if: >- + ${{ github.event_name == 'workflow_dispatch' || + (github.event_name == 'pull_request' && + contains(github.event.pull_request.labels.*.name, 'profile-enabled')) }} runs-on: ubuntu-latest outputs: filtered-matrix: ${{ steps.filter.outputs.filtered }} count: ${{ steps.filter.outputs.count }} + ref: ${{ steps.preref.outputs.ref }} + moe-debug: ${{ steps.target.outputs.moe_debug }} steps: + - name: Resolve checkout ref + id: preref + run: | + if [ "${{ github.event_name }}" = "pull_request" ]; then + echo "ref=${{ github.event.pull_request.head.sha }}" >> "$GITHUB_OUTPUT" + else + echo "ref=${{ inputs.ref || github.sha }}" >> "$GITHUB_OUTPUT" + fi + - name: Checkout code uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: - ref: ${{ inputs.ref || github.sha }} + ref: ${{ steps.preref.outputs.ref }} + + - name: Resolve profile target (dispatch inputs or PR target file) + id: target + run: | + set -euo pipefail + if [ "${{ github.event_name }}" = "pull_request" ]; then + f=.github/profile-target.env + if [ ! 
-f "$f" ]; then + echo "::error::$f is required for label-triggered profiling" >&2 + exit 1 + fi + ck=$(grep -E '^CONFIG_KEY=' "$f" | head -1 | cut -d= -f2- || true) + cf=$(grep -E '^CONFIG_FILE=' "$f" | head -1 | cut -d= -f2- || true) + cc=$(grep -E '^CONC=' "$f" | head -1 | cut -d= -f2- || true) + md=$(grep -E '^MOE_DEBUG=' "$f" | head -1 | cut -d= -f2- || true) + if [ -z "$ck" ]; then echo "::error::CONFIG_KEY missing in $f" >&2; exit 1; fi + echo "config_key=${ck}" >> "$GITHUB_OUTPUT" + echo "config_file=${cf:-.github/configs/nvidia-master.yaml}" >> "$GITHUB_OUTPUT" + echo "conc=${cc:-64}" >> "$GITHUB_OUTPUT" + echo "moe_debug=${md:-false}" >> "$GITHUB_OUTPUT" + else + echo "config_key=${{ inputs.config-key }}" >> "$GITHUB_OUTPUT" + echo "config_file=${{ inputs.config-file }}" >> "$GITHUB_OUTPUT" + echo "conc=${{ inputs.conc }}" >> "$GITHUB_OUTPUT" + echo "moe_debug=${{ inputs.moe-debug }}" >> "$GITHUB_OUTPUT" + fi - id: gen name: Generate matrix via script run: | pip install pydantic - CLI_ARGS="test-config --config-files ${{ inputs.config-file }} --config-keys ${{ inputs.config-key }} --conc ${{ inputs.conc }}" + CLI_ARGS="test-config --config-files ${{ steps.target.outputs.config_file }} --config-keys ${{ steps.target.outputs.config_key }} --conc ${{ steps.target.outputs.conc }}" CONFIG_JSON=$(python3 ${GITHUB_WORKSPACE}/utils/matrix_logic/generate_sweep_configs.py $CLI_ARGS) echo "raw=$CONFIG_JSON" >> $GITHUB_OUTPUT - id: filter - name: Take first generated job + name: Select one job per concurrency shell: python run: | import json, os, sys @@ -78,7 +126,16 @@ jobs: f.write("filtered=[]\ncount=0\n") raise SystemExit(1) - filt = data[:1] + # One job per (concurrency, tp): the first config generated for each + # (conc, tp) pair — i.e. the leading 1k1k search-space row of each TP — + # ordered by (conc, tp). This profiles every TP in the search space at every conc. + # A single conc/tp still yields one job (backward compatible).
+ by_key = {} + for job in data: + k = (job.get("conc"), job.get("tp")) + if k not in by_key: + by_key[k] = job + filt = [by_key[k] for k in sorted(by_key, key=lambda pair: tuple((v is None, v) for v in pair))] out = json.dumps(filt) print(out) @@ -115,8 +172,14 @@ jobs: CONC: ${{ matrix.config.conc }} SPEC_DECODING: ${{ matrix.config.spec-decoding }} DISAGG: ${{ matrix.config.disagg }} + # The single-node launchers resolve the recipe path as + # benchmarks/single_node/${SCENARIO_SUBDIR}{model-prefix}_{precision}_{runner}.sh and run + # under `set -u`; the sweep's benchmark-tmpl.yml sets this, so profile.yml + # must too. Profiling is fixed-seq-len (mi325x/mi300x launchers don't + # default it the way mi355x does). + SCENARIO_SUBDIR: fixed_seq_len/ MOE_DEBUG: '0' - MOE_DEBUG_LOG: ${{ (inputs.moe-debug) && '/workspace/moe_debug.tp0.log' || '' }} + MOE_DEBUG_LOG: ${{ needs.get-jobs.outputs.moe-debug == 'true' && '/workspace/moe_debug.tp0.log' || '' }} steps: - name: Resource cleanup run: | @@ -145,7 +208,7 @@ jobs: uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: fetch-depth: 0 - ref: ${{ inputs.ref || github.sha }} + ref: ${{ needs.get-jobs.outputs.ref }} clean: false - name: Launch + Profile (single-node sglang/vllm) @@ -261,7 +324,17 @@ jobs: git config user.email "github-actions@github.com" git add -A git commit -m "Add profile: ${GITHUB_SHA} ${{ matrix.config['exp-name'] }} tp${{ matrix.config.tp }} ep${{ matrix.config.ep || 1 }} conc${{ matrix.config.conc }}" || echo "Nothing to commit" - git push + # Parallel matrix jobs (one per (conc, tp)) all push to this same repo, so a + # plain push races and is rejected non-fast-forward. Rebase onto the + # latest remote and retry with jitter until it lands.
+ for attempt in 1 2 3 4 5 6 7 8; do + if git push; then break; fi + if [ "$attempt" = 8 ]; then echo "push failed after ${attempt} attempts" >&2; exit 1; fi + echo "push rejected (attempt ${attempt}); rebasing on origin/master" + sleep $(( (RANDOM % 6) + 2 )) + git fetch origin master --quiet + git rebase origin/master + done STORAGE_SHA="$(git rev-parse HEAD)" popd >/dev/null diff --git a/benchmarks/single_node/fixed_seq_len/minimaxm3arf_fp8_mi325x.sh b/benchmarks/single_node/fixed_seq_len/minimaxm3arf_fp8_mi325x.sh new file mode 100644 index 000000000..d489cbe46 --- /dev/null +++ b/benchmarks/single_node/fixed_seq_len/minimaxm3arf_fp8_mi325x.sh @@ -0,0 +1,129 @@ +#!/usr/bin/env bash + +# [DO NOT MERGE — experimental] MiniMax-M3 MXFP8 MI325X (gfx942) single-node vLLM +# recipe that validates vllm-project/vllm#45639 ("[ROCm][M3] Enable AITER AR + +# Gemma-RMS fusion for MiniMax-M3") on real MI325X hardware before an image +# rebuild. It applies #45639 in-place to the shipped vllm/vllm-openai-rocm:minimax-m3 +# image, then serves with the AITER fused all-reduce + RMSNorm path enabled. +# +# Mirrors minimaxm3_fp8_mi325x.sh otherwise (--block-size 128, --language-model-only, +# TRITON_ATTN, BF16 KV — gfx942 has no calibrated FP8 attention scales). The +# #45639-specific knobs: +# VLLM_ROCM_USE_AITER=1 (AITER kernels) +# --compilation-config custom_ops=["-minimax_gemma_rms_norm"] (allow IR lowering) +# --compilation-config pass_config.fuse_allreduce_rms=true (the fusion pass) +# The fusion needs TP>1; this recipe is swept at TP8.
+ +source "$(dirname "$0")/../../benchmark_lib.sh" + +check_env_vars \ + MODEL \ + TP \ + EP_SIZE \ + DP_ATTENTION \ + CONC \ + ISL \ + OSL \ + MAX_MODEL_LEN \ + RANDOM_RANGE_RATIO \ + RESULT_FILENAME + +if [[ -n "$SLURM_JOB_ID" ]]; then + echo "JOB $SLURM_JOB_ID running on $SLURMD_NODENAME" +fi + +if [[ "$MODEL" != /* ]]; then hf download "$MODEL"; fi + +if [ -n "$ROCR_VISIBLE_DEVICES" ]; then + export HIP_VISIBLE_DEVICES="$ROCR_VISIBLE_DEVICES" +fi + +# ---- Apply vllm-project/vllm#45639 in-place ------------------------------- +# The shipped minimax-m3 image predates #45639 (base m3_release). Apply the +# vendored diff to the installed vllm. Idempotent: if it is already applied +# (reverse-applies cleanly) we proceed; if it neither applies cleanly nor is +# already applied, the image has drifted from the PR base — hard-fail so we never +# silently benchmark an unpatched server. +PATCH_FILE="$(cd "$(dirname "$0")/patches" && pwd)/vllm-45639-aiter-ar-gemma-rms.diff" +command -v patch >/dev/null 2>&1 || { apt-get update -q -y && apt-get install -q -y patch; } +VLLM_SP="$(python3 -c 'import os, vllm; print(os.path.dirname(os.path.dirname(vllm.__file__)))')" +if ( cd "$VLLM_SP" && patch -p1 -R --dry-run < "$PATCH_FILE" >/dev/null 2>&1 ); then + echo "[vllm#45639] already applied to $VLLM_SP/vllm" +elif ( cd "$VLLM_SP" && patch -p1 --dry-run < "$PATCH_FILE" >/dev/null 2>&1 ); then + ( cd "$VLLM_SP" && patch -p1 < "$PATCH_FILE" ) + echo "[vllm#45639] applied to $VLLM_SP/vllm" +else + echo "FATAL: vllm#45639 patch neither applies cleanly nor is already applied" >&2 + echo " ($VLLM_SP/vllm has drifted from the PR's m3_release base)" >&2 + exit 1 +fi + +SERVER_LOG=/workspace/server.log +export VLLM_ENGINE_READY_TIMEOUT_S=3600 +export VLLM_USE_BREAKABLE_CUDAGRAPH=0 +# #45639: AITER fused all-reduce + Gemma-RMSNorm. 
+export VLLM_ROCM_USE_AITER=1 + +if [ "${EVAL_ONLY}" = "true" ]; then + setup_eval_context +fi + +PARALLEL_ARGS=(--tensor-parallel-size "$TP") +if [ "${DP_ATTENTION}" = "true" ]; then + PARALLEL_ARGS=( + --tensor-parallel-size 1 + --data-parallel-size "$TP" + --enable-expert-parallel + ) +elif [ "$EP_SIZE" -gt 1 ]; then + PARALLEL_ARGS+=(--enable-expert-parallel) +fi + +start_gpu_monitor + +# When PROFILE=1 (profile.yml), arm vLLM's torch profiler via --profiler-config. +# This minimax-m3 image's vLLM does NOT honour the VLLM_TORCH_PROFILER_DIR env +# var, so the serve flag is what makes /start_profile emit a trace. Write to the +# dir benchmark_lib's relay scans (VLLM_TORCH_PROFILER_DIR, default /workspace/). +PROFILE_ARGS=() +if [ "${PROFILE:-}" = "1" ]; then + PROFILE_ARGS=(--profiler-config "{\"profiler\": \"torch\", \"torch_profiler_dir\": \"${VLLM_TORCH_PROFILER_DIR:-/workspace/}\"}") +fi + +set -x +vllm serve "$MODEL" --port "$PORT" \ + "${PARALLEL_ARGS[@]}" \ + "${PROFILE_ARGS[@]}" \ + --block-size 128 \ + --language-model-only \ + --max-model-len "$MAX_MODEL_LEN" \ + --attention-backend TRITON_ATTN \ + --no-enable-prefix-caching \ + --compilation-config '{"custom_ops": ["-minimax_gemma_rms_norm"], "pass_config": {"fuse_allreduce_rms": true}}' \ + --tool-call-parser minimax_m3 \ + --reasoning-parser minimax_m3 \ + --enable-auto-tool-choice > "$SERVER_LOG" 2>&1 & + +SERVER_PID=$! 
+wait_for_server_ready --port "$PORT" --server-log "$SERVER_LOG" --server-pid "$SERVER_PID" + +run_benchmark_serving \ + --model "$MODEL" \ + --port "$PORT" \ + --backend vllm \ + --input-len "$ISL" \ + --output-len "$OSL" \ + --random-range-ratio "$RANDOM_RANGE_RATIO" \ + --num-prompts "$((CONC * 10))" \ + --max-concurrency "$CONC" \ + --result-filename "$RESULT_FILENAME" \ + --result-dir /workspace/ \ + --trust-remote-code + +if [ "${RUN_EVAL}" = "true" ]; then + run_eval --framework lm-eval --port "$PORT" + append_lm_eval_summary +fi + +stop_gpu_monitor +set +x diff --git a/benchmarks/single_node/fixed_seq_len/patches/vllm-45639-aiter-ar-gemma-rms.diff b/benchmarks/single_node/fixed_seq_len/patches/vllm-45639-aiter-ar-gemma-rms.diff new file mode 100644 index 000000000..a946949ed --- /dev/null +++ b/benchmarks/single_node/fixed_seq_len/patches/vllm-45639-aiter-ar-gemma-rms.diff @@ -0,0 +1,203 @@ +diff --git a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py +index 4de5c6cf7ae5..0fd5cb830f5e 100644 +--- a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py ++++ b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py +@@ -1439,6 +1439,51 @@ def _replacement( + return _replacement + + ++class AiterAllreduceFusedAddRMSNormWithCopyPattern(BasePattern, VllmPatternReplacement): ++ """Non-quant AR+RMS fusion for all_reduce with 2 users (copy_). ++ ++ In GemmaRMSNorm models, the post-attention all_reduce has a copy_ ++ node for cross-chunk residual state, giving it 2 users and preventing ++ the standard pattern from matching. This pattern returns ar_out as ++ an explicit output so the pattern matcher rewires external users ++ (the copy_) to the fused kernel's residual output. 
++ """ ++ ++ def __init__(self, epsilon, dtype, device): ++ super().__init__(dtype, device) ++ self.epsilon = epsilon ++ self.FUSED_OP = rocm_aiter_ops.get_fused_allreduce_rmsnorm_op() ++ ++ def get_inputs(self): ++ return [self.empty(5, 16), self.empty(5, 16), self.empty(16)] ++ ++ @property ++ def pattern(self): ++ eps = self.epsilon ++ ++ def _pattern(residual, input_, weight): ++ ar_out = tensor_model_parallel_all_reduce(input_) ++ rms, res_out = vllm.ir.ops.fused_add_rms_norm(ar_out, residual, weight, eps) ++ return rms, res_out, ar_out ++ ++ return _pattern ++ ++ @property ++ def replacement(self): ++ eps = self.epsilon ++ ++ def _replacement(residual, input_, weight): ++ fused = self.FUSED_OP( ++ input_=input_, ++ residual=residual, ++ weight=weight.to(input_.dtype), ++ epsilon=eps, ++ ) ++ return fused[0], fused[1], fused[1] ++ ++ return _replacement ++ ++ + class RocmAiterAllReduceFusionPass(VllmFusionPatternMatcherPass): + def __init__(self, config: VllmConfig) -> None: + super().__init__(config, "rocm_aiter_allreduce_fusion_pass") +@@ -1503,9 +1548,13 @@ def __init__(self, config: VllmConfig) -> None: + return + + max_token_num = max_size // (hidden_dim * element_size) ++ # Cap at max_cudagraph_capture_size so fusion only fires ++ # for decode. Prefill uses quickreduce + triton rmsnorm. ++ max_cg = config.compilation_config.max_cudagraph_capture_size or 512 + self.max_token_num = min( + max_token_num, + config.scheduler_config.max_num_batched_tokens, ++ max_cg, + ) + + # Only register the AR+RMS+per-group-FP8-quant patterns when the +@@ -1524,6 +1573,15 @@ def __init__(self, config: VllmConfig) -> None: + "FP8 quant fusion." 
+ ) + ++ # Non-quant copy-aware pattern for post-attention allreduce ++ for epsilon in [1e-5, 1e-6]: ++ self.register( ++ AiterAllreduceFusedAddRMSNormWithCopyPattern( ++ epsilon, self.model_dtype, self.device ++ ) ++ ) ++ torch._inductor.pattern_matcher._seen_patterns.clear() ++ + for epsilon in [1e-5, 1e-6]: + # Quant-fused variants must register first so the pattern matcher + # tries them before the AR+RMS-only variants. Otherwise the +diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py +index 13b0ae781314..8d0336622249 100644 +--- a/vllm/model_executor/layers/layernorm.py ++++ b/vllm/model_executor/layers/layernorm.py +@@ -159,7 +159,7 @@ def forward_native( + residual: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + """PyTorch-native implementation equivalent to forward().""" +- weight = self.weight.float() + 1.0 ++ weight = self.weight.data.to(x.dtype) + 1.0 + if residual is None: + return ir.ops.rms_norm(x, weight, self.variance_epsilon) + return ir.ops.fused_add_rms_norm(x, residual, weight, self.variance_epsilon) +diff --git a/vllm/models/minimax_m3/amd/model.py b/vllm/models/minimax_m3/amd/model.py +index b80d3b8b3b8c..eecd6d335fdd 100644 +--- a/vllm/models/minimax_m3/amd/model.py ++++ b/vllm/models/minimax_m3/amd/model.py +@@ -23,7 +23,7 @@ + from torch import nn + from transformers import PretrainedConfig + +-from vllm import _custom_ops as ops ++from vllm import _custom_ops as ops, ir + from vllm.compilation.breakable_cudagraph import eager_break_during_capture + from vllm.config import ( + CacheConfig, +@@ -32,6 +32,7 @@ + ) + from vllm.distributed import get_tensor_model_parallel_world_size + from vllm.forward_context import get_forward_context ++from vllm.model_executor.custom_op import CustomOp + from vllm.model_executor.layers.attention import Attention + from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase + from 
vllm.model_executor.layers.fused_allreduce_gemma_rms_norm import ( +@@ -160,17 +161,24 @@ def _build_rotary_emb(config: PretrainedConfig, head_dim: int): + ) + + +-class MiniMAXGemmaRMSNorm(nn.Module): +- """Gemma-style RMS normalization (native ROCm implementation). +- +- Normalizes in fp32 and scales by ``(1 + weight)`` — numerically equivalent +- to the FlashInfer ``gemma_rmsnorm`` / ``gemma_fused_add_rmsnorm`` kernels +- used in the NVIDIA path, which are unavailable on ROCm. When ``residual`` is +- given, the fused add + norm returns the updated ``(normed, residual)`` pair. +- +- The fp32 normalize + scale + (optional) residual-add run in a single fused +- Triton pass (``amd.ops.gemma_rmsnorm`` / ``gemma_fused_add_rmsnorm``) instead +- of a chain of elementwise PyTorch kernels. ++@CustomOp.register("minimax_gemma_rms_norm") ++class MiniMAXGemmaRMSNorm(CustomOp): ++ """Gemma-style RMS normalization for the M3 ROCm path. ++ ++ Default (custom op enabled, ``forward_hip``): an fp32 Triton pass ++ (``amd.ops.gemma_rmsnorm`` / ``gemma_fused_add_rmsnorm``) — numerically ++ equivalent to the FlashInfer ``gemma_rmsnorm`` kernels used in the NVIDIA ++ path, which are unavailable on ROCm. This is the unchanged M3 default. ++ ++ When the custom op is disabled (``--compilation-config ++ '{"custom_ops":["-minimax_gemma_rms_norm"]}'``), ``forward_native`` instead ++ emits the plain ``ir.ops.rms_norm`` / ``ir.ops.fused_add_rms_norm`` IR ops, ++ with the Gemma ``1 + weight`` offset folded into the weight. That exposes the ++ post-attention ``all_reduce -> fused_add_rms_norm`` sequence to the AITER ++ AR+RMS fusion pass (``RocmAiterAllReduceFusionPass``), letting it fuse the ++ decode-time allreduce + residual-add + rmsnorm into a single AITER kernel at ++ TP>1. Opt-in because it swaps M3's accuracy-tuned fp32 norm path for the IR ++ lowering: validate gsm8k parity before enabling by default. 
+ """ + + def __init__( +@@ -182,11 +190,24 @@ def __init__( + self.weight = nn.Parameter(torch.zeros(hidden_size)) + self.variance_epsilon = eps + +- def forward( ++ def forward_native( ++ self, ++ x: torch.Tensor, ++ residual: torch.Tensor | None = None, ++ ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ++ # Fusable IR path, matched by the AITER AR+RMS fusion pass. Fold the ++ # Gemma ``1 + weight`` offset into the weight in x's dtype. ++ weight = self.weight.data.to(x.dtype) + 1.0 ++ if residual is None: ++ return ir.ops.rms_norm(x, weight, self.variance_epsilon) ++ return ir.ops.fused_add_rms_norm(x, residual, weight, self.variance_epsilon) ++ ++ def forward_hip( + self, + x: torch.Tensor, + residual: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ++ # Default ROCm path: fp32 Triton gemma kernels (unchanged M3 behavior). + if residual is None: + return gemma_rmsnorm(x, self.weight, self.variance_epsilon) + return gemma_fused_add_rmsnorm(x, residual, self.weight, self.variance_epsilon) +diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py +index aaf1fdce36bb..b8d49b6efc68 100644 +--- a/vllm/platforms/rocm.py ++++ b/vllm/platforms/rocm.py +@@ -976,9 +976,13 @@ def get_default_ir_op_priority( + using_inductor = cc.backend == "inductor" and cc.mode != CompilationMode.NONE + default = ["native"] if using_inductor else ["vllm_c", "native"] + +- # Aiter rms norm perform best when CUDA Graph capture is enabled. +- # TODO(luka/TJ) remove env vars completely +- if ( ++ # When allreduce+rmsnorm fusion is enabled (default on ROCm TP>1), ++ # use native priority so triton can fuse rmsnorm with adjacent ops. ++ # The aiter CK rmsnorm is opaque to triton and blocks fusion. 
++ fuse_ar_rms = cc.pass_config.fuse_allreduce_rms ++ if using_inductor and fuse_ar_rms is not False: ++ rms_norm = default # ["native"] ++ elif ( + cc.cudagraph_mode != CUDAGraphMode.NONE + and envs.VLLM_ROCM_USE_AITER + and envs.VLLM_ROCM_USE_AITER_RMSNORM diff --git a/perf-changelog.yaml b/perf-changelog.yaml index bee038a7a..a0124828f 100644 --- a/perf-changelog.yaml +++ b/perf-changelog.yaml @@ -3842,3 +3842,9 @@ - "Recipes sourced from NVIDIA/srt-slurm branch sa-submission-q2-2026" - "Runner script updated to support dsv4 model prefix with dynamo-trt framework on GB300" pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/1689 + +- config-keys: + - minimaxm3arf-fp8-mi325x-vllm + description: + - "[DO NOT MERGE — experimental] MiniMax-M3 MXFP8 MI325X (gfx942) smoke test (conc 4, 8; TP8) validating vllm-project/vllm#45639 (AITER fused all-reduce + Gemma-RMSNorm for M3): applies the PR diff in-place to the shipped minimax-m3 image before serving (BF16 KV on gfx942), then enables it via VLLM_ROCM_USE_AITER=1 + --compilation-config (custom_ops -minimax_gemma_rms_norm, pass_config.fuse_allreduce_rms). Hard-fails if the patch neither applies cleanly nor is already applied (image drifted)" + pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/1772