feat: intracard cp for sm90 #86
Code Review
This pull request introduces an optimized Hopper (SM90) KDA prefill path featuring fused gate and L2-norm preprocessing, along with intra-card CP (chunk-parallel) scheduling. Key feedback includes optimizing cp_context.py to avoid a synchronous D2H copy by computing sequence mappings on the CPU, adding device validation checks in the C++ API to prevent illegal memory accesses, and using an if/else block in the fused L2-norm Triton kernel to eliminate redundant load instructions.
pre-commit adopt cr suggestions support varlen fuse l2norm+gate cumsum & fix irregular input
@Hyaloid Thanks a lot for this contribution! Could you also add some performance numbers comparing this SM90 intra-card CP path against FLA with intra-card CP enabled?
Hi @cherhh, I've added the comparison with the FLA intracard-CP implementation and updated the performance table above. The results show that cuLA isn't always faster than FLA's intracard-CP implementation. One likely reason is the relatively high register pressure in cuLA. The KDA forward kernel keeps q/k/v/g/beta/A_log/dt_bias in registers, and safe_gate=True increases register usage even further. On Hopper, WGMMA accumulators also live in the regular register file, so even when nothing spills, the high register usage can still reduce occupancy (sometimes down to one CTA per SM), making memory latency harder to hide. If nvcc -Xptxas -v reports non-zero spill stores or loads (usually alongside a non-zero stack frame), that confirms actual register spilling. To consistently outperform FLA across all workloads, I think we'd probably need to rethink parts of cuLA's algorithm rather than just keep tuning the current implementation. That's a much bigger piece of work. And honestly, I don't think it's realistic to expect a single library to be the best choice for every workload and every shape.
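A rough way to see how register usage alone can cap occupancy is sketched below. The 65,536-register file per SM is the H100 figure; the thread count and per-thread register counts are hypothetical values chosen for illustration, not measurements of the cuLA kernel, and the sketch ignores register allocation granularity and shared-memory limits.

```python
def ctas_per_sm_by_registers(regs_per_thread: int,
                             threads_per_cta: int,
                             regs_per_sm: int = 65536) -> int:
    """Upper bound on resident CTAs per SM imposed by the register file only."""
    regs_per_cta = regs_per_thread * threads_per_cta
    return regs_per_sm // regs_per_cta


# Hypothetical kernel: 3 warpgroups (384 threads) per CTA.
heavy = ctas_per_sm_by_registers(regs_per_thread=168, threads_per_cta=384)
light = ctas_per_sm_by_registers(regs_per_thread=80, threads_per_cta=384)

print(f"168 regs/thread -> {heavy} CTA(s) per SM")
print(f" 80 regs/thread -> {light} CTA(s) per SM")
```

With a single resident CTA per SM there is no second CTA to switch to while the first waits on memory, which is the latency-hiding problem described in the comment.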
📌 Description
The serial bottleneck
kda_prefill_hopper (cuLA's SM90 KDA prefill) launches one CTA per (seq, head) and runs a strictly
sequential chunk recurrence inside each sequence: h_t = decay(g_t) · h_{t-1} + k_t^T @ (u_t − w_t·h_{t-1}).
Within one sequence, work cannot parallelize across chunks — only across the (raw_batch × H) grid.
This becomes a bottleneck when both of the following hold:
With a single sequence, H=8 launches only 8 CTAs on a 132-SM H100 (~6% of the SMs). The grid is so small that most of the card sits idle, waiting on 8 serial chains.
A long sequence dominates wall time while short seqs finish in microseconds and leave SMs idle.
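A back-of-the-envelope sketch of the two effects described above. The 132-SM count and H=8 come from the description; the chunk size of 64, the one-CTA-per-SM simplification, and the sequence lengths are assumptions made only for illustration.

```python
import math


def sm_utilization(raw_batch: int, num_heads: int, num_sms: int = 132) -> float:
    """Fraction of SMs that receive a CTA for a (raw_batch x num_heads) grid,
    assuming at most one CTA per SM (a simplification of this sketch)."""
    ctas = raw_batch * num_heads
    return min(ctas, num_sms) / num_sms


def serial_steps(seq_lens, chunk_size: int = 64):
    """Sequential chunk steps each CTA must run, one entry per sequence."""
    return [math.ceil(n / chunk_size) for n in seq_lens]


util = sm_utilization(raw_batch=1, num_heads=8)

# Hypothetical varlen batch: one long sequence next to two short ones.
steps = serial_steps([65536, 128, 256])

# Chunks within a sequence cannot overlap, so the longest chain sets wall time.
critical_path = max(steps)

print(f"SM utilization: {util:.1%}")
print(f"steps per sequence: {steps}, critical path: {critical_path}")
```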
Approach
Mirroring FLA's intra‑card CP design (and the SM100 cuLA path in cula/ops/cp/chunk_delta_h.py),
this PR splits long sequences into CP‑chunks on the same card and produces per‑CP‑chunk initial
states so the main C++ kernel can run all CP‑chunks in parallel.
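One way to see why per-CP-chunk initial states can be produced without running the full serial chain: the recurrence above is affine in the state, h_t = (decay(g_t) − k_t^T w_t) · h_{t-1} + k_t^T u_t, so consecutive chunks fold into a single affine map. The sketch below is a pure-Python toy under that observation, not the PR's kernel code. All function names, the diagonal per-channel decay, the tiny shapes, and the random inputs are assumptions for illustration.

```python
import random


def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def add(a, b):
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def chunk_affine(d, k, w, u):
    """Return (M, b) such that h_new = M @ h_old + b for one chunk.
    d: per-key-channel decay (length K); k, w: C x K; u: C x V."""
    kt = [list(col) for col in zip(*k)]          # K x C
    ktw = matmul(kt, w)                          # K x K
    n = len(d)
    M = [[(d[i] if i == j else 0.0) - ktw[i][j] for j in range(n)] for i in range(n)]
    return M, matmul(kt, u)                      # b is K x V


def apply(step, h):
    M, b = step
    return add(matmul(M, h), b)


def compose(second, first):
    """Affine map equal to applying `first`, then `second`."""
    (M2, b2), (M1, b1) = second, first
    return matmul(M2, M1), add(matmul(M2, b1), b2)


def serial_states(steps, h0):
    """Reference: run chunks strictly in order, keeping the state after each."""
    out, h = [], h0
    for s in steps:
        h = apply(s, h)
        out.append(h)
    return out


def cp_initial_states(steps, h0, cp_size):
    groups = [steps[i:i + cp_size] for i in range(0, len(steps), cp_size)]
    # Phase 1 (independent per group): fold each group into one affine map.
    folded = []
    for g in groups:
        acc = g[0]
        for s in g[1:]:
            acc = compose(s, acc)
        folded.append(acc)
    # Phase 2 (short serial scan): one step per group instead of one per chunk.
    inits, h = [], h0
    for f in folded:
        inits.append(h)
        h = apply(f, h)
    return groups, inits


rng = random.Random(0)
K, V, C, N = 2, 2, 2, 8


def rand(rows, cols):
    return [[rng.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]


steps = [
    chunk_affine([rng.uniform(0.5, 1.0) for _ in range(K)],
                 rand(C, K), rand(C, K), rand(C, V))
    for _ in range(N)
]
h0 = [[0.0] * V for _ in range(K)]

reference = serial_states(steps, h0)
groups, inits = cp_initial_states(steps, h0, cp_size=3)
# Phase 3 (independent per group): each group starts from its own initial state.
cp_states = [h for g, h_init in zip(groups, inits) for h in serial_states(g, h_init)]

max_err = max(abs(a - b)
              for x, y in zip(reference, cp_states)
              for ra, rb in zip(x, y)
              for a, b in zip(ra, rb))
print(f"max abs difference vs. serial reference: {max_err:.2e}")
```

Phases 1 and 3 have no dependence between groups, which is what lets CP-chunks run on separate CTAs; only the short scan in phase 2 stays serial.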
🔍 Related Issues
Similar to issue #20, but for SM90.
🚀 Pull Request Checklist
✅ Pre-commit Checks
Installed pre-commit by running pip install pre-commit (or used your preferred method).
Installed the hooks with pre-commit install.
Ran pre-commit run --all-files and fixed any reported issues.
🧪 Tests
python -m pytest tests/test_intracard_cp_sm90.py -v
⚡ Performance
python benchmarks/bench_intracard_cp_sm90.py