perf(vllm): fuse MiniMax M3 BF16 EP experts on MI300X#1782
Conversation
|
Thanks for the contribution! For vLLM & SGLang, please ensure that your recipes is similar to the official vLLM recipes and/or the SGLang cookbook If it is not, please create a PR first before we can merge your single node PR into the master branch. Let's ensure that the documentation is first class such that the entire ML community can benefit from your hard work! Thank you PR authors are responsible for ensuring that after merging, all GitHub Action jobs fully pass. A lot of the time, failures are just flakes and simply re-running the failed jobs will fix it. If re-running failed jobs is attempted, PR authors are responsible for ensuring it passes. See GitHub's docs on re-running failed jobs: https://docs.github.com/en/actions/how-tos/manage-workflow-runs/re-run-workflows-and-jobs#re-running-failed-jobs-in-a-workflow As a rule of thumb, generally, PR authors should request a review & get a PR approval from the respective companies' CODEOWNERS before requesting a review from core maintainers. If additional help is needed, PR authors can reach out to core maintainers over Slack. |
Summary
long-context shape on MI300X
experts instead of 128 global experts
expert instead of the existing 64-row tile
activation kernel and the 2x-intermediate GEMM1 output
This PR is stacked on #1753 and contains only the incremental EP8 optimization.
It does not include the profiling branch, AITER allreduce/RMSNorm work,
temporary benchmark configuration, or
perf-changelog.yamlchanges.Profile basis
The six-point MI300X profile found expert GEMM1+GEMM2 at 30.31 ms for 1k/c256
and 28.10 ms for 8k/c256. After collective fusion, expert GEMMs remained the
largest classified 8k/c256 phase at 28.79 ms across 114 calls.
At c256, MiniMax M3 has about 216 active tokens and top-k 4, or 864 routed rows
globally. EP8 owns 16 of 128 experts per rank, leaving about 108 local rows,
roughly 6.75 rows per local expert. The existing BF16 config uses a 64-row M
tile, so it can execute about 1,024 padded rows per rank for roughly 108 useful
rows. Global alignment also creates blocks for remote experts that do no useful
GEMM work.
Profile report:
https://github.com/SemiAnalysisAI/InferenceX/blob/profiling/experimental/minimax_m3_mi300x_profile.md
First-principles changes
is based on 16 local experts, while the device counter remains authoritative.
BLOCK_SIZE_M=16, matching the observed route densityand reducing padded expert-row computation by up to 4x versus the 64-row
tile.
applies split SwiGLU-OAI before storing. This halves its BF16 output traffic
and removes a separate activation launch.
It avoids direct atomic accumulation, which the profile identified as a poor
fit for the c256 top-k-4 shape.
The path is gated to the exact gfx94x MiniMax M3 EP8 BF16 shape. gfx95x and
other models/configurations are unchanged.
Validation
Static and local validation:
python -m pytest utils/matrix_logic/ -q: 156 passedbash -n benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.shcompileall, andgit diff --checkexpert-map reduction correctness tests
MI300X serving validation is pending infrastructure recovery. The exact six-job
matrix (c1/c16/c256 for 1k1k and 8k1k) was dispatched four times, but every
attempt failed before GPU allocation because the Slurm controller was
unreachable:
https://github.com/SemiAnalysisAI/InferenceX/actions/runs/27569397626