-
Notifications
You must be signed in to change notification settings - Fork 195
[DO NOT MERGE] [Klaud Cold] experimental: MiniMax-M3 MI325X conc 4/8 — apply vllm#45639 (AITER AR + Gemma-RMS fusion) #1772
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
397f637
65be443
eb422fe
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,155 @@ | ||
| #!/usr/bin/env bash | ||
|
|
||
| # [DO NOT MERGE — experimental] MiniMax-M3 MXFP8 MI325X (gfx942) single-node vLLM | ||
| # recipe that validates vllm-project/vllm#45639 ("[ROCm][M3] Enable AITER AR + | ||
| # Gemma-RMS fusion for MiniMax-M3") on real MI325X hardware before an image | ||
| # rebuild. It applies #45639 in-place to the shipped vllm/vllm-openai-rocm:minimax-m3 | ||
| # image, then serves with the AITER fused all-reduce + RMSNorm path enabled. | ||
| # | ||
| # Mirrors minimaxm3_fp8_mi325x.sh otherwise (--block-size 128, --language-model-only, | ||
| # TRITON_ATTN, BF16 KV — gfx942 has no calibrated FP8 attention scales). The | ||
| # #45639-specific knobs: | ||
| # VLLM_ROCM_USE_AITER=1 (AITER kernels) | ||
| # --compilation-config custom_ops=["-minimax_gemma_rms_norm"] (allow IR lowering) | ||
| # --compilation-config pass_config.fuse_allreduce_rms=true (the fusion pass) | ||
| # The fusion needs TP>1; this recipe is swept at TP8. | ||
|
|
||
| source "$(dirname "$0")/../../benchmark_lib.sh" | ||
|
|
||
| check_env_vars \ | ||
| MODEL \ | ||
| TP \ | ||
| EP_SIZE \ | ||
| DP_ATTENTION \ | ||
| CONC \ | ||
| ISL \ | ||
| OSL \ | ||
| MAX_MODEL_LEN \ | ||
| RANDOM_RANGE_RATIO \ | ||
| RESULT_FILENAME | ||
|
|
||
| if [[ -n "$SLURM_JOB_ID" ]]; then | ||
| echo "JOB $SLURM_JOB_ID running on $SLURMD_NODENAME" | ||
| fi | ||
|
|
||
| if [[ "$MODEL" != /* ]]; then hf download "$MODEL"; fi | ||
|
|
||
| if [ -n "$ROCR_VISIBLE_DEVICES" ]; then | ||
| export HIP_VISIBLE_DEVICES="$ROCR_VISIBLE_DEVICES" | ||
| fi | ||
|
|
||
| # ---- Apply vllm-project/vllm#45639 in-place ------------------------------- | ||
| # The shipped minimax-m3 image predates #45639 (base m3_release). Apply the | ||
| # vendored diff to the installed vllm. Idempotent: if it is already applied | ||
| # (reverse-applies cleanly) we proceed; if it neither applies cleanly nor is | ||
| # already applied, the image has drifted from the PR base — hard-fail so we never | ||
| # silently benchmark an unpatched server. | ||
| PATCH_FILE="$(cd "$(dirname "$0")/patches" && pwd)/vllm-45639-aiter-ar-gemma-rms.diff" | ||
| command -v patch >/dev/null 2>&1 || { apt-get update -q -y && apt-get install -q -y patch; } | ||
| VLLM_SP="$(python3 -c 'import os, vllm; print(os.path.dirname(os.path.dirname(vllm.__file__)))')" | ||
| if ( cd "$VLLM_SP" && patch -p1 -R --dry-run < "$PATCH_FILE" >/dev/null 2>&1 ); then | ||
| echo "[vllm#45639] already applied to $VLLM_SP/vllm" | ||
| elif ( cd "$VLLM_SP" && patch -p1 --dry-run < "$PATCH_FILE" >/dev/null 2>&1 ); then | ||
| ( cd "$VLLM_SP" && patch -p1 < "$PATCH_FILE" ) | ||
| echo "[vllm#45639] applied to $VLLM_SP/vllm" | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Patch apply errors ignoredHigh Severity After a successful Reviewed by Cursor Bugbot for commit 65be443. Configure here. |
||
| else | ||
| echo "FATAL: vllm#45639 patch neither applies cleanly nor is already applied" >&2 | ||
| echo " ($VLLM_SP/vllm has drifted from the PR's m3_release base)" >&2 | ||
| exit 1 | ||
| fi | ||
|
|
||
| SERVER_LOG=/workspace/server.log | ||
| export VLLM_ENGINE_READY_TIMEOUT_S=3600 | ||
| export VLLM_USE_BREAKABLE_CUDAGRAPH=0 | ||
| # #45639: AITER fused all-reduce + Gemma-RMSNorm. | ||
| export VLLM_ROCM_USE_AITER=1 | ||
| # DEBUG so the server log carries the fusion-pass match/replace counts | ||
| # ("RocmAiterAllReduceFusionPass Replaced N patterns", "fusion pass matches: {}") | ||
| # in addition to the (default-level) registration bail warnings. | ||
| export VLLM_LOGGING_LEVEL=DEBUG | ||
|
|
||
| if [ "${EVAL_ONLY}" = "true" ]; then | ||
| setup_eval_context | ||
| fi | ||
|
|
||
| PARALLEL_ARGS=(--tensor-parallel-size "$TP") | ||
| if [ "${DP_ATTENTION}" = "true" ]; then | ||
| PARALLEL_ARGS=( | ||
| --tensor-parallel-size 1 | ||
| --data-parallel-size "$TP" | ||
| --enable-expert-parallel | ||
| ) | ||
| elif [ "$EP_SIZE" -gt 1 ]; then | ||
| PARALLEL_ARGS+=(--enable-expert-parallel) | ||
| fi | ||
|
|
||
| start_gpu_monitor | ||
|
|
||
| # When PROFILE=1 (profile.yml), arm vLLM's torch profiler via --profiler-config. | ||
| # This minimax-m3 image's vLLM does NOT honour the VLLM_TORCH_PROFILER_DIR env | ||
| # var, so the serve flag is what makes /start_profile emit a trace. Write to the | ||
| # dir benchmark_lib's relay scans (VLLM_TORCH_PROFILER_DIR, default /workspace/). | ||
| PROFILE_ARGS=() | ||
| if [ "${PROFILE:-}" = "1" ]; then | ||
| PROFILE_ARGS=(--profiler-config "{\"profiler\": \"torch\", \"torch_profiler_dir\": \"${VLLM_TORCH_PROFILER_DIR:-/workspace/}\"}") | ||
| fi | ||
|
|
||
| set -x | ||
| vllm serve "$MODEL" --port "$PORT" \ | ||
| "${PARALLEL_ARGS[@]}" \ | ||
| "${PROFILE_ARGS[@]}" \ | ||
| --block-size 128 \ | ||
| --language-model-only \ | ||
| --max-model-len "$MAX_MODEL_LEN" \ | ||
| --attention-backend TRITON_ATTN \ | ||
| --no-enable-prefix-caching \ | ||
| --compilation-config '{"custom_ops": ["-minimax_gemma_rms_norm"], "pass_config": {"fuse_allreduce_rms": true}}' \ | ||
| --tool-call-parser minimax_m3 \ | ||
| --reasoning-parser minimax_m3 \ | ||
| --enable-auto-tool-choice > "$SERVER_LOG" 2>&1 & | ||
|
|
||
| SERVER_PID=$! | ||
| wait_for_server_ready --port "$PORT" --server-log "$SERVER_LOG" --server-pid "$SERVER_PID" | ||
|
|
||
| # ---- #45639 AITER AR + Gemma-RMS fusion diagnostics (definitive) ---------- | ||
| # Engine init (incl. torch.compile fusion passes) has finished by now, so the | ||
| # fusion-pass logging is in the server log. Two questions, answered from the log: | ||
| # 1) Did the pass REGISTER? Any of these warning_once strings => it registered | ||
| # ZERO patterns (match count is 0 by construction): | ||
| # "AllReduce fusion pass is disabled", "AITER allreduce fusion must be | ||
| # initialized", "AITER allreduce-rmsnorm fusion disabled: aiter<0.1.12" | ||
| # (the M3/6144 one), "Custom Allreduce is required". | ||
| # 2) Did it MATCH+REPLACE? "RocmAiterAllReduceFusionPass Replaced N patterns" | ||
| # (N>0 => matched & replaced; N==0 => matched nothing) and the per-pass | ||
| # "fusion pass matches: {...}" table. | ||
| set +x | ||
| echo "================ #45639 fusion-pass verdict ================" | ||
| echo "--- [1] registration bail warnings (presence => registered 0 patterns) ---" | ||
| grep -nE "AllReduce fusion pass is disabled|AITER allreduce fusion must be initialized|AITER allreduce-rmsnorm fusion disabled|Custom Allreduce is required" "$SERVER_LOG" \ | ||
| || echo " (none — no registration bail)" | ||
| echo "--- [2] match / replace counts ---" | ||
| grep -nE "RocmAiterAllReduceFusionPass Replaced [0-9]+ patterns|fusion pass matches:" "$SERVER_LOG" \ | ||
| || echo " (no 'Replaced N patterns' / 'fusion pass matches' line found)" | ||
| echo "===========================================================" | ||
| set -x | ||
|
|
||
| run_benchmark_serving \ | ||
| --model "$MODEL" \ | ||
| --port "$PORT" \ | ||
| --backend vllm \ | ||
| --input-len "$ISL" \ | ||
| --output-len "$OSL" \ | ||
| --random-range-ratio "$RANDOM_RANGE_RATIO" \ | ||
| --num-prompts "$((CONC * 10))" \ | ||
| --max-concurrency "$CONC" \ | ||
| --result-filename "$RESULT_FILENAME" \ | ||
| --result-dir /workspace/ \ | ||
| --trust-remote-code | ||
|
|
||
| if [ "${RUN_EVAL}" = "true" ]; then | ||
| run_eval --framework lm-eval --port "$PORT" | ||
| append_lm_eval_summary | ||
| fi | ||
|
|
||
| stop_gpu_monitor | ||
| set +x | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,203 @@ | ||
| diff --git a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py | ||
| index 4de5c6cf7ae5..0fd5cb830f5e 100644 | ||
| --- a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py | ||
| +++ b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py | ||
| @@ -1439,6 +1439,51 @@ def _replacement( | ||
| return _replacement | ||
|
|
||
|
|
||
| +class AiterAllreduceFusedAddRMSNormWithCopyPattern(BasePattern, VllmPatternReplacement): | ||
| + """Non-quant AR+RMS fusion for all_reduce with 2 users (copy_). | ||
| + | ||
| + In GemmaRMSNorm models, the post-attention all_reduce has a copy_ | ||
| + node for cross-chunk residual state, giving it 2 users and preventing | ||
| + the standard pattern from matching. This pattern returns ar_out as | ||
| + an explicit output so the pattern matcher rewires external users | ||
| + (the copy_) to the fused kernel's residual output. | ||
| + """ | ||
| + | ||
| + def __init__(self, epsilon, dtype, device): | ||
| + super().__init__(dtype, device) | ||
| + self.epsilon = epsilon | ||
| + self.FUSED_OP = rocm_aiter_ops.get_fused_allreduce_rmsnorm_op() | ||
| + | ||
| + def get_inputs(self): | ||
| + return [self.empty(5, 16), self.empty(5, 16), self.empty(16)] | ||
| + | ||
| + @property | ||
| + def pattern(self): | ||
| + eps = self.epsilon | ||
| + | ||
| + def _pattern(residual, input_, weight): | ||
| + ar_out = tensor_model_parallel_all_reduce(input_) | ||
| + rms, res_out = vllm.ir.ops.fused_add_rms_norm(ar_out, residual, weight, eps) | ||
| + return rms, res_out, ar_out | ||
| + | ||
| + return _pattern | ||
| + | ||
| + @property | ||
| + def replacement(self): | ||
| + eps = self.epsilon | ||
| + | ||
| + def _replacement(residual, input_, weight): | ||
| + fused = self.FUSED_OP( | ||
| + input_=input_, | ||
| + residual=residual, | ||
| + weight=weight.to(input_.dtype), | ||
| + epsilon=eps, | ||
| + ) | ||
| + return fused[0], fused[1], fused[1] | ||
| + | ||
| + return _replacement | ||
| + | ||
| + | ||
| class RocmAiterAllReduceFusionPass(VllmFusionPatternMatcherPass): | ||
| def __init__(self, config: VllmConfig) -> None: | ||
| super().__init__(config, "rocm_aiter_allreduce_fusion_pass") | ||
| @@ -1503,9 +1548,13 @@ def __init__(self, config: VllmConfig) -> None: | ||
| return | ||
|
|
||
| max_token_num = max_size // (hidden_dim * element_size) | ||
| + # Cap at max_cudagraph_capture_size so fusion only fires | ||
| + # for decode. Prefill uses quickreduce + triton rmsnorm. | ||
| + max_cg = config.compilation_config.max_cudagraph_capture_size or 512 | ||
| self.max_token_num = min( | ||
| max_token_num, | ||
| config.scheduler_config.max_num_batched_tokens, | ||
| + max_cg, | ||
| ) | ||
|
|
||
| # Only register the AR+RMS+per-group-FP8-quant patterns when the | ||
| @@ -1524,6 +1573,15 @@ def __init__(self, config: VllmConfig) -> None: | ||
| "FP8 quant fusion." | ||
| ) | ||
|
|
||
| + # Non-quant copy-aware pattern for post-attention allreduce | ||
| + for epsilon in [1e-5, 1e-6]: | ||
| + self.register( | ||
| + AiterAllreduceFusedAddRMSNormWithCopyPattern( | ||
| + epsilon, self.model_dtype, self.device | ||
| + ) | ||
| + ) | ||
| + torch._inductor.pattern_matcher._seen_patterns.clear() | ||
| + | ||
| for epsilon in [1e-5, 1e-6]: | ||
| # Quant-fused variants must register first so the pattern matcher | ||
| # tries them before the AR+RMS-only variants. Otherwise the | ||
| diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py | ||
| index 13b0ae781314..8d0336622249 100644 | ||
| --- a/vllm/model_executor/layers/layernorm.py | ||
| +++ b/vllm/model_executor/layers/layernorm.py | ||
| @@ -159,7 +159,7 @@ def forward_native( | ||
| residual: torch.Tensor | None = None, | ||
| ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: | ||
| """PyTorch-native implementation equivalent to forward().""" | ||
| - weight = self.weight.float() + 1.0 | ||
| + weight = self.weight.data.to(x.dtype) + 1.0 | ||
| if residual is None: | ||
| return ir.ops.rms_norm(x, weight, self.variance_epsilon) | ||
| return ir.ops.fused_add_rms_norm(x, residual, weight, self.variance_epsilon) | ||
| diff --git a/vllm/models/minimax_m3/amd/model.py b/vllm/models/minimax_m3/amd/model.py | ||
| index b80d3b8b3b8c..eecd6d335fdd 100644 | ||
| --- a/vllm/models/minimax_m3/amd/model.py | ||
| +++ b/vllm/models/minimax_m3/amd/model.py | ||
| @@ -23,7 +23,7 @@ | ||
| from torch import nn | ||
| from transformers import PretrainedConfig | ||
|
|
||
| -from vllm import _custom_ops as ops | ||
| +from vllm import _custom_ops as ops, ir | ||
| from vllm.compilation.breakable_cudagraph import eager_break_during_capture | ||
| from vllm.config import ( | ||
| CacheConfig, | ||
| @@ -32,6 +32,7 @@ | ||
| ) | ||
| from vllm.distributed import get_tensor_model_parallel_world_size | ||
| from vllm.forward_context import get_forward_context | ||
| +from vllm.model_executor.custom_op import CustomOp | ||
| from vllm.model_executor.layers.attention import Attention | ||
| from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase | ||
| from vllm.model_executor.layers.fused_allreduce_gemma_rms_norm import ( | ||
| @@ -160,17 +161,24 @@ def _build_rotary_emb(config: PretrainedConfig, head_dim: int): | ||
| ) | ||
|
|
||
|
|
||
| -class MiniMAXGemmaRMSNorm(nn.Module): | ||
| - """Gemma-style RMS normalization (native ROCm implementation). | ||
| - | ||
| - Normalizes in fp32 and scales by ``(1 + weight)`` — numerically equivalent | ||
| - to the FlashInfer ``gemma_rmsnorm`` / ``gemma_fused_add_rmsnorm`` kernels | ||
| - used in the NVIDIA path, which are unavailable on ROCm. When ``residual`` is | ||
| - given, the fused add + norm returns the updated ``(normed, residual)`` pair. | ||
| - | ||
| - The fp32 normalize + scale + (optional) residual-add run in a single fused | ||
| - Triton pass (``amd.ops.gemma_rmsnorm`` / ``gemma_fused_add_rmsnorm``) instead | ||
| - of a chain of elementwise PyTorch kernels. | ||
| +@CustomOp.register("minimax_gemma_rms_norm") | ||
| +class MiniMAXGemmaRMSNorm(CustomOp): | ||
| + """Gemma-style RMS normalization for the M3 ROCm path. | ||
| + | ||
| + Default (custom op enabled, ``forward_hip``): an fp32 Triton pass | ||
| + (``amd.ops.gemma_rmsnorm`` / ``gemma_fused_add_rmsnorm``) — numerically | ||
| + equivalent to the FlashInfer ``gemma_rmsnorm`` kernels used in the NVIDIA | ||
| + path, which are unavailable on ROCm. This is the unchanged M3 default. | ||
| + | ||
| + When the custom op is disabled (``--compilation-config | ||
| + '{"custom_ops":["-minimax_gemma_rms_norm"]}'``), ``forward_native`` instead | ||
| + emits the plain ``ir.ops.rms_norm`` / ``ir.ops.fused_add_rms_norm`` IR ops, | ||
| + with the Gemma ``1 + weight`` offset folded into the weight. That exposes the | ||
| + post-attention ``all_reduce -> fused_add_rms_norm`` sequence to the AITER | ||
| + AR+RMS fusion pass (``RocmAiterAllReduceFusionPass``), letting it fuse the | ||
| + decode-time allreduce + residual-add + rmsnorm into a single AITER kernel at | ||
| + TP>1. Opt-in because it swaps M3's accuracy-tuned fp32 norm path for the IR | ||
| + lowering: validate gsm8k parity before enabling by default. | ||
| """ | ||
|
|
||
| def __init__( | ||
| @@ -182,11 +190,24 @@ def __init__( | ||
| self.weight = nn.Parameter(torch.zeros(hidden_size)) | ||
| self.variance_epsilon = eps | ||
|
|
||
| - def forward( | ||
| + def forward_native( | ||
| + self, | ||
| + x: torch.Tensor, | ||
| + residual: torch.Tensor | None = None, | ||
| + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: | ||
| + # Fusable IR path, matched by the AITER AR+RMS fusion pass. Fold the | ||
| + # Gemma ``1 + weight`` offset into the weight in x's dtype. | ||
| + weight = self.weight.data.to(x.dtype) + 1.0 | ||
| + if residual is None: | ||
| + return ir.ops.rms_norm(x, weight, self.variance_epsilon) | ||
| + return ir.ops.fused_add_rms_norm(x, residual, weight, self.variance_epsilon) | ||
| + | ||
| + def forward_hip( | ||
| self, | ||
| x: torch.Tensor, | ||
| residual: torch.Tensor | None = None, | ||
| ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: | ||
| + # Default ROCm path: fp32 Triton gemma kernels (unchanged M3 behavior). | ||
| if residual is None: | ||
| return gemma_rmsnorm(x, self.weight, self.variance_epsilon) | ||
| return gemma_fused_add_rmsnorm(x, residual, self.weight, self.variance_epsilon) | ||
| diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py | ||
| index aaf1fdce36bb..b8d49b6efc68 100644 | ||
| --- a/vllm/platforms/rocm.py | ||
| +++ b/vllm/platforms/rocm.py | ||
| @@ -976,9 +976,13 @@ def get_default_ir_op_priority( | ||
| using_inductor = cc.backend == "inductor" and cc.mode != CompilationMode.NONE | ||
| default = ["native"] if using_inductor else ["vllm_c", "native"] | ||
|
|
||
| - # Aiter rms norm perform best when CUDA Graph capture is enabled. | ||
| - # TODO(luka/TJ) remove env vars completely | ||
| - if ( | ||
| + # When allreduce+rmsnorm fusion is enabled (default on ROCm TP>1), | ||
| + # use native priority so triton can fuse rmsnorm with adjacent ops. | ||
| + # The aiter CK rmsnorm is opaque to triton and blocks fusion. | ||
| + fuse_ar_rms = cc.pass_config.fuse_allreduce_rms | ||
| + if using_inductor and fuse_ar_rms is not False: | ||
| + rms_norm = default # ["native"] | ||
| + elif ( | ||
| cc.cudagraph_mode != CUDAGraphMode.NONE | ||
| and envs.VLLM_ROCM_USE_AITER | ||
| and envs.VLLM_ROCM_USE_AITER_RMSNORM |


There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
conc-list breaks full-sweep
Medium Severity
The new single-node
fixed-seq-lenrows use onlyconc-list, butgenerate_full_sweepreadsconc-startandconc-endfor single-node entries and will raiseKeyErrorwhen it hits this config during an unfiltered amd-master full sweep.Reviewed by Cursor Bugbot for commit eb422fe. Configure here.