feat: intracard cp for sm90#86

Open
Hyaloid wants to merge 2 commits into
inclusionAI:mainfrom
Hyaloid:intcd-cp

@Hyaloid Hyaloid commented Jun 4, 2026

📌 Description

The serial bottleneck

kda_prefill_hopper (cuLA's SM90 KDA prefill) launches one CTA per (seq, head) and runs a strictly
sequential chunk recurrence inside each sequence: h_t = decay(g_t) · h_{t-1} + k_t^T @ (u_t − w_t·h_{t-1}).
Within one sequence, work cannot parallelize across chunks — only across the (raw_batch × H) grid.
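
To make the dependency concrete, here is a minimal pure-Python reference of the recurrence above for one (seq, head) pair. It is only a sketch: the shapes (h as K×V, k_t and w_t as C×K, u_t as C×V), the reading of decay(g_t) as a per-key-channel scale, and the helper names are assumptions for illustration, not the kernel's actual layout.

```python
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain dense matrix product a @ b."""
    return [[sum(a[i][p] * b[p][j] for p in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def transpose(a: Matrix) -> Matrix:
    return [list(row) for row in zip(*a)]


def chunk_step(h: Matrix, decay: Sequence[float],
               k: Matrix, u: Matrix, w: Matrix) -> Matrix:
    """One chunk update: h_new = diag(decay) @ h + k^T @ (u - w @ h).

    Assumed shapes: h [K][V], decay [K], k and w [C][K], u [C][V].
    """
    wh = matmul(w, h)                                   # [C][V]
    resid = [[u[c][v] - wh[c][v] for v in range(len(u[0]))]
             for c in range(len(u))]                    # u - w @ h
    update = matmul(transpose(k), resid)                # [K][V]
    return [[decay[i] * h[i][j] + update[i][j] for j in range(len(h[0]))]
            for i in range(len(h))]


def run_sequence(h0: Matrix,
                 chunks: Sequence[Tuple[Sequence[float], Matrix, Matrix, Matrix]]) -> Matrix:
    """Strictly sequential: chunk t cannot start before chunk t-1 has produced h."""
    h = h0
    for decay, k, u, w in chunks:
        h = chunk_step(h, decay, k, u, w)
    return h
```

Each call to chunk_step needs the h produced by the previous call, which is why the loop in run_sequence cannot be spread across CTAs unless the state at the split points is known in advance.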

This becomes a bottleneck when both:

  1. raw_batch × H is small, so the baseline grid under‑utilizes the SMs. A single long sequence at
    H=8 launches only 8 CTAs on a 132‑SM H100, so at most 8 of 132 SMs (~6%) are busy and the rest of the card sits idle while those 8 serial chains run.
  2. The shape has a long‑tail sequence (e.g. 128K+1K packed) — the long seq's serial recurrence
    dominates wall time while short seqs finish in microseconds and leave SMs idle.
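
The utilization figure in the first condition is simple arithmetic, sketched below. The function name and the default of 132 SMs (the H100 figure quoted above) are illustrative; the bound assumes exactly one CTA per (seq, head) and at most one such CTA per SM.

```python
def sm_utilization(raw_batch: int, num_heads: int, num_sms: int = 132) -> float:
    """Upper bound on the fraction of SMs kept busy by a (raw_batch x H) grid."""
    ctas = raw_batch * num_heads
    return min(ctas, num_sms) / num_sms


# Single long sequence at H=8: 8 CTAs on 132 SMs.
print(f"{sm_utilization(1, 8):.1%}")  # 6.1%
```

Once raw_batch × H reaches the SM count, the bound saturates at 1.0 and splitting sequences further would not add parallelism by this measure.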

Approach

Mirroring FLA's intra‑card CP design (and the SM100 cuLA path in cula/ops/cp/chunk_delta_h.py),
this PR splits long sequences into CP‑chunks on the same card and produces per‑CP‑chunk initial
states so the main C++ kernel can run all CP‑chunks in parallel.
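
A hedged sketch of why per-CP-chunk initial states can be produced cheaply: the recurrence is affine in h, so each CP-chunk can be summarized independently as a single map h -> a·h + b, and a short serial pass over those summaries yields the state at every split point. The code below uses a scalar stand-in for the matrix-valued state, and all names are hypothetical. It illustrates the general technique and is not taken from this PR or from FLA.

```python
from typing import List, Sequence, Tuple

Affine = Tuple[float, float]  # (a, b) represents h -> a * h + b


def compose(first: Affine, second: Affine) -> Affine:
    """Affine map equivalent to applying `first`, then `second`."""
    a1, b1 = first
    a2, b2 = second
    return (a2 * a1, a2 * b1 + b2)


def summarize(steps: Sequence[Affine]) -> Affine:
    """Collapse all steps of one CP-chunk into one affine map."""
    total: Affine = (1.0, 0.0)  # identity
    for step in steps:
        total = compose(total, step)
    return total


def cp_initial_states(steps: Sequence[Affine], num_cp_chunks: int,
                      h0: float = 0.0) -> Tuple[List[float], float]:
    """Return (initial state of each CP-chunk, final state of the sequence)."""
    size = -(-len(steps) // num_cp_chunks)  # ceiling division
    parts = [steps[i:i + size] for i in range(0, len(steps), size)]
    # Each summary depends only on its own part, so these can run in parallel.
    summaries = [summarize(part) for part in parts]
    # Short serial scan: one step per CP-chunk instead of one per chunk.
    inits: List[float] = []
    h = h0
    for a, b in summaries:
        inits.append(h)
        h = a * h + b
    return inits, h
```

With the initial states known, every CP-chunk can then be processed as if it were its own short sequence, which is the parallelism the main kernel exploits.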

🔍 Related Issues

Similar to issue #20, but for SM90.

🚀 Pull Request Checklist

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.
clang-format.............................................................Passed
ruff (legacy alias)......................................................Passed
ruff format..............................................................Passed

🧪 Tests

python -m pytest tests/test_intracard_cp_sm90.py -v

platform linux -- Python 3.12.3, pytest-9.1.0, pluggy-1.6.0 -- /opt/torch/bin/python
cachedir: .pytest_cache
configfile: pyproject.toml
plugins: anyio-4.12.1
collected 63 items                                                                                                                                               

tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens0-4-False] PASSED
tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens1-4-True] PASSED
tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens2-4-True] PASSED
tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens3-8-False] PASSED
tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens4-8-True] PASSED
tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens5-4-True] PASSED
tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens6-4-True] PASSED
tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens7-4-False] PASSED
tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens0-4-False] PASSED
tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens1-4-True] PASSED
tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens2-4-True] PASSED
tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens3-8-False] PASSED
tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens4-8-True] PASSED
tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens5-4-True] PASSED
tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens6-4-True] PASSED
tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens7-4-False] PASSED
tests/test_intracard_cp_sm90.py::test_cp_off_matches_basic_baseline PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens0-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens1-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens2-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens3-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens4-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens5-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens6-8] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens7-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens8-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens9-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens10-8] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens11-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens12-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens13-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens0-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens1-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens2-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens3-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens4-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens5-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens6-8] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens7-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens8-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens9-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens10-8] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens11-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens12-4] PASSED
tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens13-4] PASSED
tests/test_intracard_cp_sm90.py::test_cp_vs_fla[seq_lens0-4] PASSED
tests/test_intracard_cp_sm90.py::test_cp_vs_fla[seq_lens1-4] PASSED
tests/test_intracard_cp_sm90.py::test_cp_vs_fla[seq_lens2-8] PASSED
tests/test_intracard_cp_sm90.py::test_cp_vs_fla[seq_lens3-4] PASSED
tests/test_intracard_cp_sm90.py::test_cp_vs_fla[seq_lens4-4] PASSED
tests/test_intracard_cp_sm90.py::test_cp_vs_fla[seq_lens5-4] PASSED
tests/test_intracard_cp_sm90.py::test_cp_final_state_per_seq[seq_lens0-4-False] PASSED
tests/test_intracard_cp_sm90.py::test_cp_final_state_per_seq[seq_lens1-4-True] PASSED
tests/test_intracard_cp_sm90.py::test_cp_final_state_per_seq[seq_lens2-4-True] PASSED
tests/test_intracard_cp_sm90.py::test_cp_final_state_per_seq[seq_lens3-4-False] PASSED
tests/test_intracard_cp_sm90.py::test_cp_final_state_per_seq[seq_lens4-4-True] PASSED
tests/test_intracard_cp_sm90.py::test_cp_stress_repeat[single-64K-H4-h0] PASSED
tests/test_intracard_cp_sm90.py::test_cp_stress_repeat[multi-64K+4K-H4-h0] PASSED
tests/test_intracard_cp_sm90.py::test_cp_h0_none_equiv_h0_zeros PASSED
tests/test_intracard_cp_sm90.py::test_cp_bypass_matches_basic[seq_lens0-8] PASSED
tests/test_intracard_cp_sm90.py::test_cp_bypass_matches_basic[seq_lens1-64] PASSED
tests/test_intracard_cp_sm90.py::test_cp_bypass_matches_basic[seq_lens2-8] PASSED
tests/test_intracard_cp_sm90.py::test_cp_bypass_matches_basic[seq_lens3-8] PASSED

============================================================================= PASSES =============================================================================
====================================================================== slowest 15 durations ======================================================================
10.53s call     github/cuLA/tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens0-4-False]
9.77s call     github/cuLA/tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens3-8-False]
7.66s call     github/cuLA/tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens5-4-True]
0.20s call     github/cuLA/tests/test_intracard_cp_sm90.py::test_cp_stress_repeat[multi-64K+4K-H4-h0]
0.17s call     github/cuLA/tests/test_intracard_cp_sm90.py::test_cp_stress_repeat[single-64K-H4-h0]
0.12s call     github/cuLA/tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens0-4]
0.06s call     github/cuLA/tests/test_intracard_cp_sm90.py::test_cp_bypass_matches_basic[seq_lens3-8]
0.04s call     github/cuLA/tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens7-4-False]
0.03s call     github/cuLA/tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens7-4-False]
0.03s call     github/cuLA/tests/test_intracard_cp_sm90.py::test_cp_final_state_per_seq[seq_lens4-4-True]
0.02s call     github/cuLA/tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens4-8-True]
0.02s call     github/cuLA/tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens4-8-True]
0.02s call     github/cuLA/tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens2-4-True]
0.02s call     github/cuLA/tests/test_intracard_cp_sm90.py::test_cp_final_state_per_seq[seq_lens2-4-True]
0.02s call     github/cuLA/tests/test_intracard_cp_sm90.py::test_cp_bypass_matches_basic[seq_lens1-64]
==================================================================== short test summary info =====================================================================
PASSED tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens0-4-False]
PASSED tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens1-4-True]
PASSED tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens2-4-True]
PASSED tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens3-8-False]
PASSED tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens4-8-True]
PASSED tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens5-4-True]
PASSED tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens6-4-True]
PASSED tests/test_intracard_cp_sm90.py::test_cp_matches_basic_baseline[seq_lens7-4-False]
PASSED tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens0-4-False]
PASSED tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens1-4-True]
PASSED tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens2-4-True]
PASSED tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens3-8-False]
PASSED tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens4-8-True]
PASSED tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens5-4-True]
PASSED tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens6-4-True]
PASSED tests/test_intracard_cp_sm90.py::test_auto_router_matches_basic_baseline[seq_lens7-4-False]
PASSED tests/test_intracard_cp_sm90.py::test_cp_off_matches_basic_baseline
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens0-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens1-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens2-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens3-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens4-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens5-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens6-8]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens7-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens8-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens9-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens10-8]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens11-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens12-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_vs_fla[seq_lens13-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens0-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens1-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens2-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens3-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens4-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens5-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens6-8]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens7-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens8-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens9-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens10-8]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens11-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens12-4]
PASSED tests/test_intracard_cp_sm90.py::test_irregular_varlen_opt_matches_basic[seq_lens13-4]
PASSED tests/test_intracard_cp_sm90.py::test_cp_vs_fla[seq_lens0-4]
PASSED tests/test_intracard_cp_sm90.py::test_cp_vs_fla[seq_lens1-4]
PASSED tests/test_intracard_cp_sm90.py::test_cp_vs_fla[seq_lens2-8]
PASSED tests/test_intracard_cp_sm90.py::test_cp_vs_fla[seq_lens3-4]
PASSED tests/test_intracard_cp_sm90.py::test_cp_vs_fla[seq_lens4-4]
PASSED tests/test_intracard_cp_sm90.py::test_cp_vs_fla[seq_lens5-4]
PASSED tests/test_intracard_cp_sm90.py::test_cp_final_state_per_seq[seq_lens0-4-False]
PASSED tests/test_intracard_cp_sm90.py::test_cp_final_state_per_seq[seq_lens1-4-True]
PASSED tests/test_intracard_cp_sm90.py::test_cp_final_state_per_seq[seq_lens2-4-True]
PASSED tests/test_intracard_cp_sm90.py::test_cp_final_state_per_seq[seq_lens3-4-False]
PASSED tests/test_intracard_cp_sm90.py::test_cp_final_state_per_seq[seq_lens4-4-True]
PASSED tests/test_intracard_cp_sm90.py::test_cp_stress_repeat[single-64K-H4-h0]
PASSED tests/test_intracard_cp_sm90.py::test_cp_stress_repeat[multi-64K+4K-H4-h0]
PASSED tests/test_intracard_cp_sm90.py::test_cp_h0_none_equiv_h0_zeros
PASSED tests/test_intracard_cp_sm90.py::test_cp_bypass_matches_basic[seq_lens0-8]
PASSED tests/test_intracard_cp_sm90.py::test_cp_bypass_matches_basic[seq_lens1-64]
PASSED tests/test_intracard_cp_sm90.py::test_cp_bypass_matches_basic[seq_lens2-8]
PASSED tests/test_intracard_cp_sm90.py::test_cp_bypass_matches_basic[seq_lens3-8]
====================================================================== 63 passed in 35.27s =======================================================================


  • Tests have been added or updated as needed.
  • All tests are passing.

⚡ Performance

python benchmarks/bench_intracard_cp_sm90.py

==============================================================================================================
                       BENCHMARK REPORT: Intracard CP (SM90)
                       CP-on (kda_prefill_hopper_auto) vs CP-off (kda_prefill_hopper) vs FLA (chunk_kda)
                       D=128  dtype=bf16  safe_gate=True
                       Warmup=10  Iters=10
==============================================================================================================

  [H=4]
  ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
  config                         T  pred  sub fla_cp fla_sub fused_pre  │         o max/mean        ht max/mean  │    FLA(ms)  CP_off(ms)   CP_on(ms)  CP_on/off  CP_on/FLA
  ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
  4x256                       1024     N     0      N       0     Y  │  2.4e-04/2.4e-07  2.0e-03/3.5e-07  │     0.9339      0.2983      0.1474     2.02x      6.34x
  8x256                       2048     N     0      N       0     Y  │  2.4e-04/2.4e-07  2.0e-03/2.7e-07  │     0.9344      0.2855      0.1437     1.99x      6.50x
  16x256                      4096     N     0      N       0     Y  │  2.4e-04/2.1e-07  9.8e-04/2.5e-07  │     0.9261      0.2938      0.1514     1.94x      6.12x
  4x1K                        4096     N     0      N       0     N  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │     1.0088      0.3029      0.2818     1.08x      3.58x
  8x1K                        8192     N     0      N       0     N  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │     0.9339      0.4129      0.3125     1.32x      2.99x
  4x2K                        8192     N     0      N       0     N  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │     0.9280      0.4379      0.4387     1.00x      2.12x
  1K+512+256+128              1920     N     0      N       0     N  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │     0.9152      0.3111      0.2319     1.34x      3.95x
  2K+1K+512+256               3840     N     0      N       0     N  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │     0.9297      0.4199      0.4183     1.00x      2.22x
  1K+1+63+65+129              1282     N     0      N       0     N  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │     0.9270      0.3058      0.2307     1.33x      4.02x
  T=4K                        4096     Y    16      N       0     Y  │  2.4e-04/7.1e-07  1.5e-04/9.8e-08  │     0.9471      0.7886      0.4698     1.68x      2.02x
  T=8K                        8192     Y    32      N       0     Y  │  2.4e-04/2.0e-07  1.8e-07/6.4e-12  │     0.9569      1.5316      0.5290     2.90x      1.81x
  T=32K                      32768     Y    32      Y       4     N  │  2.4e-04/7.7e-08  0.0e+00/0.0e+00  │     2.1121      5.9842      1.3410     4.46x      1.58x
  T=64K                      65536     Y    32      Y       8     N  │  3.1e-04/4.1e-07  5.6e-06/5.0e-10  │     2.5598     11.9121      2.4417     4.88x      1.05x
  T=128K                    131072     Y    32      Y      16     N  │  2.4e-04/7.1e-09  0.0e+00/0.0e+00  │     4.1462     23.7834      4.7113     5.05x      0.88x
  8x4K                       32768     N     0      N       0     N  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │     1.2089      0.8655      0.8605     1.01x      1.40x
  4x8K                       32768     Y    32      N       0     N  │  2.4e-04/6.9e-08  0.0e+00/0.0e+00  │     1.2390      1.6120      1.4469     1.11x      0.86x
  2x16K                      32768     Y    32      N       0     N  │  2.4e-04/7.5e-08  0.0e+00/0.0e+00  │     1.2859      3.1102      1.4208     2.19x      0.91x
  16K+16K                    32768     Y    32      N       0     N  │  2.4e-04/7.5e-08  0.0e+00/0.0e+00  │     1.2724      3.1129      1.4376     2.17x      0.89x
  24K+8K                     32768     Y    32      Y       4     N  │  2.4e-04/7.4e-08  0.0e+00/0.0e+00  │     2.1055      4.5110      1.4370     3.14x      1.47x
  28K+4K                     32768     Y    32      Y       5     N  │  2.4e-04/7.6e-08  0.0e+00/0.0e+00  │     2.0744      5.2414      1.4349     3.65x      1.45x
  32K+256+256                33280     Y    34      Y       6     N  │  2.4e-04/7.6e-08  0.0e+00/0.0e+00  │     2.1354      5.9829      1.5029     3.98x      1.42x
  40K+1K+8K                  50176     Y    25      Y       7     N  │  3.7e-04/2.4e-07  0.0e+00/0.0e+00  │     2.3356      7.4879      2.1500     3.48x      1.09x
  64K+512+256+128            66432     Y    35      Y      11     N  │  3.1e-04/4.1e-07  5.6e-06/1.2e-10  │     2.5818     11.9398      2.7388     4.36x      0.94x
  128K+1K                   132096     Y    33      Y      17     N  │  2.4e-04/7.1e-09  0.0e+00/0.0e+00  │     4.1705     23.8024      5.2243     4.56x      0.80x
  128K+2x1K                 133120     Y    34      Y      18     N  │  1.2e-04/4.0e-10  0.0e+00/0.0e+00  │     4.1919     23.8588      5.2893     4.51x      0.79x
  128K+5x1K                 136192     Y    37      Y      21     N  │  1.2e-04/2.3e-10  0.0e+00/0.0e+00  │     4.3027     23.8261      5.6640     4.21x      0.76x
  128K+10x1K                141312     Y    42      Y      26     N  │  2.4e-04/2.9e-09  0.0e+00/0.0e+00  │     4.3997     23.6041      5.6571     4.17x      0.78x
  ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────

  [H=8]
  ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
  config                         T  pred  sub fla_cp fla_sub fused_pre  │         o max/mean        ht max/mean  │    FLA(ms)  CP_off(ms)   CP_on(ms)  CP_on/off  CP_on/FLA
  ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
  4x256                       1024     N     0      N       0     Y  │  1.2e-04/1.9e-07  8.5e-04/1.6e-07  │     0.9461      0.3052      0.1458     2.09x      6.49x
  8x256                       2048     N     0      N       0     Y  │  4.9e-04/2.2e-07  1.2e-03/1.7e-07  │     0.9395      0.2966      0.1483     2.00x      6.33x
  16x256                      4096     N     0      N       0     Y  │  2.4e-04/1.8e-07  9.5e-04/1.0e-07  │     0.9337      0.2955      0.1552     1.90x      6.02x
  4x1K                        4096     N     0      N       0     N  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │     0.9443      0.2985      0.2800     1.07x      3.37x
  8x1K                        8192     N     0      N       0     N  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │     0.9998      0.3002      0.3128     0.96x      3.20x
  4x2K                        8192     N     0      N       0     N  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │     0.9795      0.4533      0.4808     0.94x      2.04x
  1K+512+256+128              1920     N     0      N       0     N  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │     0.9264      0.3098      0.2799     1.11x      3.31x
  2K+1K+512+256               3840     N     0      N       0     N  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │     0.9277      0.4353      0.4335     1.00x      2.14x
  1K+1+63+65+129              1282     N     0      N       0     N  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │     0.9242      0.3049      0.2815     1.08x      3.28x
  T=4K                        4096     Y    16      N       0     Y  │  2.4e-04/2.4e-07  1.3e-04/5.5e-08  │     0.9551      0.8067      0.4879     1.65x      1.96x
  T=8K                        8192     Y    16      N       0     Y  │  2.4e-04/1.9e-07  6.3e-05/1.7e-08  │     1.0110      1.5632      0.5976     2.62x      1.69x
  T=32K                      32768     Y    16      Y       4     N  │  4.9e-04/2.2e-07  3.4e-05/3.4e-10  │     2.5276      6.1382      1.8114     3.39x      1.40x
  T=64K                      65536     Y    16      Y       8     N  │  1.2e-04/4.2e-09  0.0e+00/0.0e+00  │     4.1150     12.3703      3.4319     3.60x      1.20x
  T=128K                    131072     Y    16      Y       8     N  │  2.4e-04/6.0e-08  1.5e-06/1.3e-11  │     7.3358     24.4276      6.7129     3.64x      1.09x
  8x4K                       32768     N     0      N       0     N  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │     1.6364      0.9413      0.9507     0.99x      1.72x
  4x8K                       32768     N     0      N       0     N  │  0.0e+00/0.0e+00  0.0e+00/0.0e+00  │     1.6504      1.6711      1.6569     1.01x      1.00x
  2x16K                      32768     Y    16      N       0     N  │  4.9e-04/1.9e-07  3.4e-05/1.7e-10  │     1.8061      3.1588      1.9250     1.64x      0.94x
  16K+16K                    32768     Y    16      N       0     N  │  4.9e-04/1.9e-07  3.4e-05/1.7e-10  │     1.8062      3.1599      1.9220     1.64x      0.94x
  24K+8K                     32768     Y    16      Y       4     N  │  4.9e-04/2.0e-07  3.4e-05/1.7e-10  │     2.5239      4.7337      1.9077     2.48x      1.32x
  28K+4K                     32768     Y    16      Y       5     N  │  4.9e-04/2.0e-07  3.4e-05/3.8e-10  │     2.5095      5.5002      1.9093     2.88x      1.31x
  32K+256+256                33280     Y    18      Y       6     N  │  4.9e-04/2.1e-07  3.4e-05/1.1e-10  │     2.5407      6.1238      2.1504     2.85x      1.18x
  40K+1K+8K                  50176     Y    13      Y       7     N  │  2.4e-04/1.9e-08  0.0e+00/0.0e+00  │     3.4652      7.7144      3.1803     2.43x      1.09x
  64K+512+256+128            66432     Y    19      Y      11     N  │  1.2e-04/4.2e-09  0.0e+00/0.0e+00  │     4.1633     12.2319      4.0994     2.98x      1.02x
  128K+1K                   132096     Y    17      Y       9     N  │  2.4e-04/5.2e-09  0.0e+00/0.0e+00  │     7.3930     24.4269      7.5871     3.22x      0.97x
  128K+2x1K                 133120     Y    18      Y      10     N  │  2.4e-04/5.4e-09  0.0e+00/0.0e+00  │     7.4272     24.4403      7.5973     3.22x      0.98x
  128K+5x1K                 136192     Y    21      Y      13     N  │  1.2e-04/4.7e-10  0.0e+00/0.0e+00  │     7.5785     24.4636      7.5939     3.22x      1.00x
  128K+10x1K                141312     Y    26      Y      18     N  │  2.4e-04/5.4e-09  0.0e+00/0.0e+00  │     7.8325     24.4907      7.7068     3.18x      1.02x
  ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────

==============================================================================================================
  Summary
==============================================================================================================
  CP triggered (33 configs): geo-mean=3.00x  best=5.05x  worst=1.11x
  CP bypassed  (21 configs): mean overhead=0.812x  max=1.060x  (1.00 = no regression)
  cuLA (CP-on) vs FLA  (54 configs): geo-mean=1.70x  best=6.50x  worst=0.76x
    └─ CP-triggered subset (33 configs): geo-mean=1.13x  best=2.02x  worst=0.76x
  Accuracy (CP-on vs CP-off): o  max=4.88e-04 avg=1.99e-04   ht max=2.01e-03 avg=1.58e-04
==============================================================================================================
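
For readers checking the table, the column labeled CP_on/off is consistent with CP_off time divided by CP_on time, so values above 1.00x mean CP-on is faster. The sketch below recomputes that ratio for the five H=4 single-sequence rows (T=4K to T=128K) and takes their geometric mean. The helper names are illustrative, and the result covers only this subset, so it is not expected to match the report's 33-config geo-mean.

```python
import math
from typing import Iterable


def speedup(cp_off_ms: float, cp_on_ms: float) -> float:
    """Speedup of CP-on over CP-off; above 1.0 means CP-on is faster."""
    return cp_off_ms / cp_on_ms


def geo_mean(values: Iterable[float]) -> float:
    """Geometric mean, the usual way to average ratios."""
    vals = list(values)
    return math.exp(sum(math.log(v) for v in vals) / len(vals))


# (CP_off ms, CP_on ms) copied from the H=4 rows T=4K, 8K, 32K, 64K, 128K.
rows = [(0.7886, 0.4698), (1.5316, 0.5290), (5.9842, 1.3410),
        (11.9121, 2.4417), (23.7834, 4.7113)]
ratios = [speedup(off, on) for off, on in rows]
print([round(r, 2) for r in ratios])
print(round(geo_mean(ratios), 2))
```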

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces an optimized Hopper (SM90) KDA prefill path featuring fused gate and L2-norm preprocessing, along with intra-card CP (chunk-parallel) scheduling. Key feedback includes optimizing cp_context.py to avoid a synchronous D2H copy by computing sequence mappings on the CPU, adding device validation checks in the C++ API to prevent illegal memory accesses, and using an if/else block in the fused L2-norm Triton kernel to eliminate redundant load instructions.
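
One way to read the D2H suggestion, sketched under assumptions: if the sequence lengths are already available on the host as Python ints, a chunk-to-sequence mapping can be built on the CPU without reading anything back from the GPU. The function name and the (seq_idx, start) layout are hypothetical; the actual code in cp_context.py is not shown in this thread.

```python
from typing import List, Tuple


def chunk_to_seq_map(seq_lens: List[int], chunk_size: int) -> List[Tuple[int, int]]:
    """For each chunk, return (sequence index, start offset within that sequence).

    Runs entirely on the host, so no device-to-host copy or sync is needed.
    """
    mapping: List[Tuple[int, int]] = []
    for seq_idx, length in enumerate(seq_lens):
        for start in range(0, length, chunk_size):
            mapping.append((seq_idx, start))
    return mapping


# Two packed sequences of 5 and 3 tokens with 4-token chunks.
print(chunk_to_seq_map([5, 3], 4))  # [(0, 0), (0, 4), (1, 0)]
```

The resulting list can be uploaded once as a small index tensor, which replaces a blocking copy with an asynchronous host-to-device transfer.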

Comment thread cula/kda/cp_context.py Outdated
Comment thread cula/kda/cp_context.py Outdated
Comment thread cula/kda/cp_context.py Outdated
Comment thread csrc/api/kda_sm90.cu
Comment thread cula/kda/l2norm_qk_fused.py Outdated
@Hyaloid Hyaloid mentioned this pull request Jun 8, 2026
5 tasks
@icavan icavan requested review from cherhh and icavan June 12, 2026 16:17

@icavan icavan left a comment

Collaborator

@Hyaloid Thanks for your contribution. The main idea LGTM. Could you add more test cases for varlen settings?

Comment thread benchmarks/bench_intracard_cp_sm90.py Outdated
Comment thread cula/kda/hopper_fused_fwd_opt.py
Comment thread tests/test_intracard_cp_sm90.py
pre-commit

adopt cr suggestions

support varlen fuse l2norm+gate cumsum & fix irregular input
@cherhh

cherhh commented Jun 15, 2026

Collaborator

@Hyaloid Thanks a lot for this contribution! Could you also add some performance numbers comparing this SM90 intra-card CP path against FLA with intra-card CP enabled?

@icavan icavan left a comment

Collaborator

LGTM

@Hyaloid

Hyaloid commented Jun 30, 2026

Author

Hi @cherhh, I've added the comparison with the FLA intracard-CP implementation and updated the performance table above. The results show that cuLA isn't always faster than FLA's intracard-CP implementation.

One likely reason is the relatively high register pressure in cuLA. The KDA forward kernel keeps q/k/v/g/beta/A_log/dt_bias in registers, and safe_gate=True increases the register usage even further. On Hopper, WGMMA accumulators are also held in the ordinary per-thread register file, so they add to that budget. Even when nothing spills to local memory, the high register usage can still reduce occupancy (sometimes down to one CTA per SM), making memory latency harder to hide. To check for actual spilling, compile with nvcc -Xptxas -v and look at the reported spill stores and spill loads: non-zero values confirm a spill, whereas a non-zero stack frame alone can also come from local arrays.
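
As a rough illustration of the occupancy point, the register file alone caps how many CTAs can be resident on one SM. The sketch assumes 65,536 32-bit registers per SM (the H100 figure) and ignores allocation granularity, shared memory, and other limits. The register and thread counts in the examples are hypothetical, not measurements of the cuLA kernel.

```python
def ctas_per_sm_by_registers(regs_per_thread: int, threads_per_cta: int,
                             regs_per_sm: int = 65536) -> int:
    """Register-only upper bound on resident CTAs per SM (0 means it cannot launch)."""
    if regs_per_thread <= 0 or threads_per_cta <= 0:
        raise ValueError("register and thread counts must be positive")
    return regs_per_sm // (regs_per_thread * threads_per_cta)


# Hypothetical kernels: 3 warpgroups (384 threads) at 168 registers per thread,
# versus 2 warpgroups (256 threads) at 128 registers per thread.
print(ctas_per_sm_by_registers(168, 384))  # 1
print(ctas_per_sm_by_registers(128, 256))  # 2
```

At one resident CTA per SM there is no second CTA to switch to while the first waits on memory, which matches the latency-hiding concern above.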

To consistently outperform FLA across all workloads, I think we'd probably need to rethink parts of cuLA's algorithm rather than just keep tuning the current implementation. That's a much bigger piece of work. And honestly, I don't think it's realistic to expect a single library to be the best choice for every workload and every shape.

3 participants