
Add HipKittens based nhead=32 MLA kernel on MI35x / gfx950#3003

Open
hubertlu-tw wants to merge 2 commits into ROCm:main from hubertlu-tw:hk_h32_mla

Conversation


hubertlu-tw (Contributor) commented May 1, 2026

Motivation

This PR extends the HipKittens MLA decode path to the nhead=32 per-rank shape used by TP4 DeepSeek-R1 on MI35x / gfx950.

The existing HipKittens MLA path covers the H128 layout, but the TP4 workload reaches AITER MLA as H32 (nhead=32, decode_qlen=4, FP8 Q/KV, page_size=1). This PR adds the missing H32 kernel support and routes each H32 shape to the faster implementation: HipKittens for the measured small-context region and the existing ASM kernel elsewhere.

Modifications

  • Add H32 support to the existing HipKittens MLA decode kernel template while keeping the kernel restricted to MI35x / gfx950 dispatch.
  • Keep H32 native metadata support for gfx950, FP8 Q/KV, and max_seqlen_qo == 4.
  • Replace temporary env-driven routing with a predefined shape-aware crossover in aiter/mla.py:
    • batch <= 8: HK when average context per sequence is <= 4096.
    • batch <= 16: HK when average context per sequence is <= 2048.
    • batch <= 32: HK when average context per sequence is <= 1024.
    • batch <= 64: HK when average context per sequence is <= 256.
    • batch <= 256: HK when average context per sequence is <= 64.
    • Larger batches or contexts fall back to ASM.
  • Remove the temporary AITER_HK_MLA_FORCE_NATIVE_METADATA, AITER_HK_MLA_ENABLE_H32, and AITER_HK_MLA_ENABLE_H32_LONG code paths from tracked source.
  • Update H32 correctness, graph replay, and benchmark scripts to use the final H32 path directly.
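The batch/context crossover above can be sketched as a small dispatch predicate. This is illustrative only: the real logic lives in aiter/mla.py, and the function name here is an assumption, not the actual API.

```python
def use_hk_for_h32(batch: int, avg_ctx_per_seq: float) -> bool:
    """Illustrative sketch: True when the HipKittens H32 kernel should be used.

    The (max batch, max average context per sequence) crossover points are
    taken directly from the routing rules listed in this PR.
    """
    crossover = [(8, 4096), (16, 2048), (32, 1024), (64, 256), (256, 64)]
    for max_batch, max_ctx in crossover:
        if batch <= max_batch:
            # First matching batch bucket decides; larger contexts fall to ASM.
            return avg_ctx_per_seq <= max_ctx
    return False  # batches above 256 always use the ASM kernel

# Boundary shapes from the dispatch routing smoke test below:
assert use_hk_for_h32(8, 4096) is True    # HK
assert use_hk_for_h32(8, 8192) is False   # ASM
assert use_hk_for_h32(32, 1024) is True   # HK
assert use_hk_for_h32(64, 512) is False   # ASM
```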

Accuracy Tests

Environment: MI355X / gfx950, HIP_VISIBLE_DEVICES=0, AITER_ENABLE_EXPERIMENTAL=1.

python3 op_tests/op_benchmarks/hip/compare_hk_mla_h32_reference.py \
  --shapes 1x512,2x2048,4x8192,4x16384 \
  --csv /tmp/hk_h32_cleanup_correctness.csv

| batch | ctx | total_kv | HK/ASM cosine distance | HK/causal-reference cosine distance |
|---|---|---|---|---|
| 1 | 512 | 512 | 0.000000 | 0.000135 |
| 2 | 2048 | 4096 | 0.001556 | 0.000089 |
| 4 | 8192 | 32768 | 0.001690 | 0.000041 |
| 4 | 16384 | 65536 | 0.001757 | 0.000026 |
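The cosine-distance metric in the table is 1 minus the cosine similarity between flattened kernel outputs. A minimal sketch (the actual compare_hk_mla_h32_reference.py operates on GPU tensors and may differ in detail):

```python
import math

def cosine_distance(a, b):
    """1 - cosine similarity between two flat numeric vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return 1.0 - dot / (norm_a * norm_b)

# Identical outputs give a distance of ~0; orthogonal outputs give 1.
assert cosine_distance([1.0, 2.0], [1.0, 2.0]) < 1e-12
assert abs(cosine_distance([1.0, 0.0], [0.0, 1.0]) - 1.0) < 1e-12
```

A distance near zero (as in the table) means the HK output closely matches both the ASM kernel and the causal reference.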

Dispatch routing smoke test:

| batch | ctx | expected route | observed route |
|---|---|---|---|
| 8 | 4096 | HK | use_hk=True |
| 8 | 8192 | ASM | use_hk=False |
| 32 | 1024 | HK | use_hk=True |
| 64 | 512 | ASM | use_hk=False |

Benchmarking and Profiling

Command:

AITER_ENABLE_EXPERIMENTAL=1 HIP_VISIBLE_DEVICES=0 \
python3 op_tests/op_benchmarks/hip/bench_hk_mla_decode.py \
  --nhead 32 \
  --batch-sizes 1,3,5,16,32,64,128,256 \
  --decode-qlen 4 \
  --page-sizes 1 \
  --ctx-lens 21,64,256,512,1200,3200,5200,8192 \
  --warmup 20 --iters 100 \
  --csv /tmp/hk_h32_pr2039_style_microbench.csv

The following table follows the same layout as #2039 and reports p50 end-to-end MLA latency (decode + reduce). Positive gain means HK is faster.

MI355X / gfx950
| batch | ctx | ASM (us) | HK (us) | gain (%) |
|---|---|---|---|---|
| 1 | 21 | 61.441 | 31.520 | 48.70% |
| 1 | 64 | 59.081 | 33.100 | 43.97% |
| 1 | 256 | 62.521 | 36.361 | 41.84% |
| 1 | 512 | 62.241 | 35.160 | 43.51% |
| 1 | 1200 | 64.140 | 44.320 | 30.90% |
| 1 | 3200 | 68.280 | 59.321 | 13.12% |
| 1 | 5200 | 70.860 | 73.141 | -3.22% |
| 1 | 8192 | 75.900 | 93.121 | -22.69% |
| 3 | 21 | 59.441 | 32.880 | 44.68% |
| 3 | 64 | 60.521 | 34.041 | 43.75% |
| 3 | 256 | 66.960 | 40.801 | 39.07% |
| 3 | 512 | 66.680 | 39.140 | 41.30% |
| 3 | 1200 | 70.981 | 50.440 | 28.94% |
| 3 | 3200 | 75.021 | 65.041 | 13.30% |
| 3 | 5200 | 77.620 | 78.180 | -0.72% |
| 3 | 8192 | 81.001 | 97.321 | -20.15% |
| 5 | 21 | 59.361 | 33.440 | 43.67% |
| 5 | 64 | 61.801 | 35.041 | 43.30% |
| 5 | 256 | 71.561 | 44.360 | 38.01% |
| 5 | 512 | 70.581 | 42.821 | 39.33% |
| 5 | 1200 | 75.321 | 52.620 | 30.14% |
| 5 | 3200 | 75.661 | 65.101 | 13.96% |
| 5 | 5200 | 77.261 | 79.320 | -2.67% |
| 5 | 8192 | 82.041 | 101.081 | -23.21% |
| 16 | 21 | 57.941 | 30.480 | 47.39% |
| 16 | 64 | 65.720 | 37.121 | 43.52% |
| 16 | 256 | 78.221 | 49.600 | 36.59% |
| 16 | 512 | 79.000 | 59.220 | 25.04% |
| 16 | 1200 | 83.801 | 66.381 | 20.79% |
| 16 | 3200 | 91.441 | 100.441 | -9.84% |
| 16 | 5200 | 97.661 | 122.261 | -25.19% |
| 16 | 8192 | 107.081 | 166.402 | -55.40% |
| 32 | 21 | 58.100 | 29.240 | 49.67% |
| 32 | 64 | 71.041 | 47.960 | 32.49% |
| 32 | 256 | 74.561 | 53.801 | 27.84% |
| 32 | 512 | 79.361 | 62.081 | 21.77% |
| 32 | 1200 | 83.441 | 87.621 | -5.01% |
| 32 | 3200 | 99.001 | 142.642 | -44.08% |
| 32 | 5200 | 110.801 | 196.102 | -76.99% |
| 32 | 8192 | 122.121 | 268.423 | -119.80% |
| 64 | 21 | 58.941 | 30.920 | 47.54% |
| 64 | 64 | 59.360 | 36.460 | 38.58% |
| 64 | 256 | 72.721 | 63.441 | 12.76% |
| 64 | 512 | 78.401 | 78.360 | 0.05% |
| 64 | 1200 | 96.601 | 127.562 | -32.05% |
| 64 | 3200 | 113.621 | 225.102 | -98.12% |
| 64 | 5200 | 134.861 | 337.123 | -149.98% |
| 64 | 8192 | 164.242 | 542.365 | -230.22% |
| 128 | 21 | 60.740 | 31.101 | 48.80% |
| 128 | 64 | 59.561 | 37.540 | 36.97% |
| 128 | 256 | 66.761 | 78.740 | -17.94% |
| 128 | 512 | 91.261 | 125.401 | -37.41% |
| 128 | 1200 | 112.121 | 200.862 | -79.15% |
| 128 | 3200 | 142.181 | 413.524 | -190.84% |
| 128 | 5200 | 192.302 | 672.766 | -249.85% |
| 128 | 8192 | 236.882 | 1023.450 | -332.05% |
| 256 | 21 | 63.601 | 33.101 | 47.96% |
| 256 | 64 | 61.920 | 40.261 | 34.98% |
| 256 | 256 | 70.940 | 82.261 | -15.96% |
| 256 | 512 | 79.761 | 136.362 | -70.96% |
| 256 | 1200 | 103.401 | 289.083 | -179.57% |
| 256 | 3200 | 182.082 | 781.467 | -329.18% |
| 256 | 5200 | 264.782 | 1262.792 | -376.92% |
| 256 | 8192 | 383.963 | 1981.220 | -415.99% |

This table is why the final dispatch uses a batch/context crossover instead of a single total_kv threshold. The code intentionally keeps ASM for the long-context rows where HK regresses.
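For reference, the gain column is derived from the two p50 latencies as the relative speedup of HK over ASM. A one-line sketch (the benchmark script's own reporting may compute it equivalently):

```python
def gain_pct(asm_us: float, hk_us: float) -> float:
    """Relative gain of HK over ASM; positive means HK is faster."""
    return (asm_us - hk_us) / asm_us * 100.0

# First row of the MI355X table: batch=1, ctx=21.
print(f"{gain_pct(61.441, 31.520):.2f}%")  # ~48.70%
```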

SGLang E2E Validation

Downstream SGLang validation used DeepSeek-R1-MXFP4 TP4 EAGLE on MI355X / gfx950, with AR-fusion disabled and without an explicit --chunked-prefill-size 131072. The run used --max-running-requests 8 and --max-total-tokens 32768 so that SGLang's captured H32 qlen=4 graph lands inside the HK crossover region.

Server flags shared by both rows:

SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1 \
SGLANG_ENABLE_SPEC_V2=1 \
ROCM_QUICK_REDUCE_QUANTIZATION=NONE \
SGLANG_AITER_FP8_PREFILL_ATTN=1 \
SGLANG_AITER_MLA_PERSIST=1 \
AITER_MXFP4_MOE_SF=1 \
SGLANG_USE_AITER=1 \
SGLANG_INT4_WEIGHT=0 \
SGLANG_MOE_PADDING=1 \
SGLANG_SET_CPU_AFFINITY=1 \
SGLANG_ROCM_FUSED_DECODE_MLA=1 \
SGLANG_USE_ROCM700A=1 \
AITER_MLA_LOG_SHAPES=1 \
PYTHONPATH=/sgl-workspace/aiter:/sgl-workspace/sglang/python:${PYTHONPATH} \
python3 -m sglang.launch_server \
  --model-path amd/DeepSeek-R1-MXFP4 \
  --tensor-parallel-size 4 \
  --trust-remote-code \
  --host 0.0.0.0 \
  --port 8000 \
  --mem-fraction-static 0.9 \
  --attention-backend aiter \
  --speculative-algorithm EAGLE \
  --speculative-num-steps 3 \
  --speculative-eagle-topk 1 \
  --speculative-num-draft-tokens 4 \
  --max-running-requests 8 \
  --max-total-tokens 32768 \
  --context-length 4096 \
  --kv-cache-dtype fp8_e4m3 \
  --model-loader-extra-config '{"enable_multithread_load": true, "num_threads": 8}' \
  --watchdog-timeout 10000

HK row adds:

AITER_ENABLE_EXPERIMENTAL=1

Serving benchmark:

python3 -m sglang.bench_serving \
  --host 0.0.0.0 --port 8000 \
  --model amd/DeepSeek-R1-MXFP4 \
  --dataset-name random \
  --random-input-len 1000 \
  --random-output-len 1000 \
  --random-range-ratio 1.0 \
  --num-prompts 16 \
  --max-concurrency 8

| case | total tok/s | median TTFT (ms) | median TPOT (ms) | median E2E (ms) | accept len | HK H32 used |
|---|---|---|---|---|---|---|
| ASM baseline | 1626.00 | 1133.48 | 8.29 | 9488.58 | 2.76 | no |
| HK H32 | 1778.54 (+9.38%) | 654.28 (-42.28%) | 7.56 (-8.81%) | 8034.48 (-15.32%) | 2.79 (+1.09%) | yes |

HK usage evidence from AITER_MLA_LOG_SHAPES=1:

| log | use_hk=True | use_hk=False | module_hk_mla imports |
|---|---|---|---|
| ASM baseline | 0 | 96 | 0 |
| HK H32 | 4 | 92 | 4 |

Representative HK capture line:

[AITER_MLA_DECODE_SHAPE] q_shape=(32, 32, 576) kv_shape=(32769, 1, 1, 576) o_shape=(32, 32, 512) q_dtype=torch.float8_e4m3fn kv_dtype=torch.float8_e4m3fn nhead=32 max_seqlen_q=4 page_size=1 nhead_kv=1 total_kv=32768 persistent=True hk_supported_head_shape=True hk_avg_ctx_per_seq=4096 hk_max_ctx_per_seq=4096 use_hk=True
[aiter] import [module_hk_mla] under /sgl-workspace/aiter/aiter/jit/module_hk_mla.so
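The routing counts above can be recovered from a server log with AITER_MLA_LOG_SHAPES=1 enabled by scanning for the use_hk field and the module import line. The log format matches the capture above; the parsing helper itself is illustrative, not part of this PR.

```python
import re
from collections import Counter

def count_hk_routing(log_text: str) -> Counter:
    """Tally use_hk decisions and module_hk_mla imports in a server log."""
    counts = Counter()
    for line in log_text.splitlines():
        m = re.search(r"use_hk=(True|False)", line)
        if m:
            counts[f"use_hk={m.group(1)}"] += 1
        if "module_hk_mla" in line:
            counts["module_hk_mla imports"] += 1
    return counts

sample = (
    "[AITER_MLA_DECODE_SHAPE] nhead=32 use_hk=True\n"
    "[AITER_MLA_DECODE_SHAPE] nhead=32 use_hk=False\n"
    "[aiter] import [module_hk_mla] under /tmp/module_hk_mla.so\n"
)
print(count_hk_routing(sample))
```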

GSM8K on the same HK-enabled server:

python3 benchmark/gsm8k/bench_sglang.py \
  --num-questions 1319 \
  --parallel 16 \
  --num-shots 5 \
  --port 8000

| case | accuracy | invalid | latency (s) | output tok/s |
|---|---|---|---|---|
| HK H32 GSM8K | 0.944 | 0.000 | 228.780 | 580.957 |

Log bundle:

/tmp/hk_h32_sglang_e2e_20260501_2125/bench_asm_mrt8_1k_1k.log
/tmp/hk_h32_sglang_e2e_20260501_2125/bench_hk_mrt8_1k_1k.log
/tmp/hk_h32_sglang_e2e_20260501_2125/server_asm_mrt8.log
/tmp/hk_h32_sglang_e2e_20260501_2125/server_hk_mrt8.log
/tmp/hk_h32_sglang_e2e_20260501_2125/gsm8k_hk_mrt8.log

Checklist

  • Added H32 HK kernel support and Python dispatch.
  • Removed temporary H32/native-metadata env knobs from tracked source.
  • Added focused H32 correctness and benchmark scripts.
  • Ran H32 correctness comparison.
  • Ran H32 microbenchmark crossover sweep.
  • Verified dispatch selects HK and ASM on representative boundary shapes.
  • Ran SGLang 1k/1k E2E comparison and confirmed HK H32 graph capture.
  • Ran SGLang GSM8K on the HK-enabled server.
  • Run CI pre-checks in the project CI environment.

Review and Merge Process

Please review the H32 kernel changes, metadata routing, and the shape-aware crossover in aiter/mla.py. The intended behavior is conservative: enable HK only on gfx950 H32 shapes where the microbenchmark shows a win, and keep the existing ASM path everywhere else.

Commits:

  • Enable experimental gfx950 H32 HK MLA decode routing and validation scripts for the FP8 path.
  • Use a measured gfx950 H32 crossover to choose between HipKittens and ASM, and remove temporary env-gated metadata paths.
@hubertlu-tw hubertlu-tw requested review from a team, HaiShaw, ruanjm and valarLip May 1, 2026 21:55

github-actions Bot commented May 1, 2026

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
|---|---|
| ci:triton-300x | Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X |
| ci:sglang | SGLang integration tests |
| ci:atom | ATOM benchmark (DeepSeek-R1 + GPT-OSS) |
| ci:vllm | vLLM benchmark |
| ci:all | All of the above |

Add labels via the sidebar or gh pr edit 3003 --add-label <label>
