Add HipKittens based nhead=32 MLA kernel on MI35x / gfx950 #3003
Open
hubertlu-tw wants to merge 2 commits into ROCm:main from
Enable experimental gfx950 H32 HK MLA decode routing and validation scripts for the FP8 path.
Use a measured gfx950 H32 crossover to choose between HipKittens and ASM, and remove temporary env-gated metadata paths.
Motivation
This PR extends the HipKittens MLA decode path to the `nhead=32` per-rank shape used by TP4 DeepSeek-R1 on MI35x / `gfx950`.

The existing HipKittens MLA path covers the H128 layout, but the TP4 workload reaches AITER MLA as H32 (`nhead=32`, `decode_qlen=4`, FP8 Q/KV, `page_size=1`). This PR adds the missing H32 kernel support and routes each H32 shape to the faster implementation: HipKittens for the measured small-context region and the existing ASM kernel elsewhere.

Modifications
- `gfx950` dispatch.
- `gfx950`, FP8 Q/KV, and `max_seqlen_qo == 4`.
- `aiter/mla.py`:
  - `batch <= 8`: HK when average context per sequence is `<= 4096`.
  - `batch <= 16`: HK when average context per sequence is `<= 2048`.
  - `batch <= 32`: HK when average context per sequence is `<= 1024`.
  - `batch <= 64`: HK when average context per sequence is `<= 256`.
  - `batch <= 256`: HK when average context per sequence is `<= 64`.
- `AITER_HK_MLA_FORCE_NATIVE_METADATA`, `AITER_HK_MLA_ENABLE_H32`, and `AITER_HK_MLA_ENABLE_H32_LONG` code paths from tracked source.

Accuracy Tests
Environment: MI355X / `gfx950`, `HIP_VISIBLE_DEVICES=0`, `AITER_ENABLE_EXPERIMENTAL=1`.

Dispatch routing smoke test:
`use_hk=True`, `use_hk=False`, `use_hk=True`, `use_hk=False`

Benchmarking and Profiling
Command:
The following table follows the same layout as #2039 and reports p50 end-to-end MLA latency (`decode + reduce`). Positive gain means HK is faster.

This table is why the final dispatch uses a batch/context crossover instead of a single `total_kv` threshold. The code intentionally keeps ASM for the long-context rows where HK regresses.

SGLang E2E Validation
Downstream SGLang validation used DeepSeek-R1-MXFP4 TP4 EAGLE on MI355X / `gfx950`, with AR-fusion disabled and no explicit `--chunked-prefill-size 131072`. The run used `--max-running-requests 8` and `--max-total-tokens 32768` so SGLang's captured H32 qlen=4 graph lands inside the HK crossover region.

Server flags shared by both rows:
HK row adds:
Serving benchmark:
HK usage evidence from `AITER_MLA_LOG_SHAPES=1`:
- `use_hk=True`
- `use_hk=False`
- `module_hk_mla` imports

Representative HK capture line:
GSM8K on the same HK-enabled server:
Log bundle:
Checklist
Review and Merge Process
Please review the H32 kernel changes, metadata routing, and the shape-aware crossover in `aiter/mla.py`. The intended behavior is conservative: enable HK only on `gfx950` H32 shapes where the microbenchmark shows a win, and keep the existing ASM path everywhere else.
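
For reviewers, the crossover listed under Modifications can be sketched roughly as below. This is an illustrative sketch, not the code in `aiter/mla.py`: the names `H32_HK_CROSSOVER` and `prefer_hk_h32` are hypothetical, the average context is assumed to be `total_kv / batch`, batches above 256 are assumed to stay on ASM, and the `gfx950` / FP8 / `max_seqlen_qo == 4` gating is left out.

```python
# Hypothetical sketch of the H32 crossover; names are illustrative only.
# (max batch, max average context per sequence) pairs from the Modifications list.
H32_HK_CROSSOVER = (
    (8, 4096),
    (16, 2048),
    (32, 1024),
    (64, 256),
    (256, 64),
)


def prefer_hk_h32(batch: int, total_kv: int) -> bool:
    """Return True when the shape falls in the measured HK-faster region."""
    if batch <= 0:
        return False
    # Assumption: average context per sequence is total KV length over batch.
    avg_ctx = total_kv / batch
    for max_batch, max_avg_ctx in H32_HK_CROSSOVER:
        if batch <= max_batch:
            # Thresholds shrink as batch grows, so the first bucket that
            # fits the batch is also the most permissive one.
            return avg_ctx <= max_avg_ctx
    # Assumption: batches above the largest bucket keep the ASM path.
    return False


# Same total_kv, different routes: why a single total_kv threshold is not enough.
print(prefer_hk_h32(batch=8, total_kv=32768))   # avg context 4096
print(prefer_hk_h32(batch=64, total_kv=32768))  # avg context 512
```

Under these assumptions the first call selects HK and the second keeps ASM even though both shapes have `total_kv = 32768`, which matches the reasoning for using a batch/context crossover.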