Added support for AITER JIT native splitkv kernel#631
Conversation
Claude WalkthroughIntent. Add an opt-in path that re-routes eligible ROCm Key changes.
Walkthrough.
Testing. No tests added. The change is opt-in via env var and the eligibility predicate restricts the path to a narrow shape regime; the default-off behavior leaves existing test coverage of the CK FusedAttention path untouched. Notes for reviewers.
Generated by Claude. To request a code review, comment |
| ): | ||
| q, k, v = query_layer, key_layer, value_layer | ||
| if qkv_format == "sbhd": | ||
| q, k, v = (x.transpose(0, 1).contiguous() for x in (q, k, v)) | ||
| out = _aiter_splitkv_flash_attn_func( | ||
| q, | ||
| k, | ||
| v, | ||
| dropout_p=0.0, | ||
| softmax_scale=self.softmax_scale, | ||
| causal="causal" in attn_mask_type, | ||
| num_splits=0, |
There was a problem hiding this comment.
The intercept path drops bottom_right_diagonal — the regular path threads it into FusedAttnFunc.apply (line 2163) and it affects causal semantics whenever seqlen_q != seqlen_kv. The eligibility check allows max_seqlen_kv > max_seqlen_q for attn_mask_type == "causal", but then calls aiter with just causal=True, which is typically top-left causal. If a caller relied on bottom-right semantics (e.g. cross-attention with attn_mask_type="causal" + bottom_right_diagonal=True), the intercepted forward will produce different outputs than the unintercepted path.
Two safe options: (a) gate eligibility on max_seqlen_q == max_seqlen_kv for causal so TL/BR coincide, or (b) also check bottom_right_diagonal in (None, False) and document the assumption. Either keeps the intercept semantically equivalent to the CK path.
| if not _use_aiter_splitkv: | ||
| return False | ||
| if get_device_compute_capability() != (9, 4): # gfx942 only | ||
| return False | ||
| if q.dtype != torch.bfloat16 or v.dtype != torch.bfloat16: | ||
| return False | ||
| if q.shape[-1] != 64 or v.shape[-1] != 64: | ||
| return False | ||
| if qkv_format not in ("bshd", "sbhd"): # dense only; thd/varlen unsupported | ||
| return False | ||
| if "padding" in attn_mask_type: | ||
| return False | ||
| if attn_mask_type not in ("no_mask", "causal"): | ||
| return False | ||
| if core_attention_bias_type != "no_bias": # excludes both bias and alibi | ||
| return False | ||
| if window_size is not None and tuple(window_size) != (-1, -1): # no sliding window | ||
| return False | ||
| if dropout_p != 0.0: | ||
| return False | ||
| if fp8 or context_parallel or inference_params is not None: | ||
| return False | ||
| if softmax_offset is not None: # no learnable sink | ||
| return False | ||
| if "causal" in attn_mask_type and max_seqlen_kv < max_seqlen_q: | ||
| return False | ||
| return True |
There was a problem hiding this comment.
Eligibility doesn't gate on num_heads_q == num_heads_kv (GQA / MQA). The CK FusedAttention path supports unequal head counts; if aiter.ops.mha.flash_attn_func's native split-K path doesn't, GQA shapes will hit the divert and then either error or fall back inside aiter (losing the perf benefit). Worth either adding a q.shape[-2] != k.shape[-2] -> False check (and taking k into the signature for that), or documenting in the README + the docstring that GQA is intentionally let through and relies on aiter's internal self-gate.
Same comment applies to dtype/head-dim: the signature takes q, v but assumes k matches — passing and checking k would close the assumption.
| ): | ||
| q, k, v = query_layer, key_layer, value_layer | ||
| if qkv_format == "sbhd": | ||
| q, k, v = (x.transpose(0, 1).contiguous() for x in (q, k, v)) |
There was a problem hiding this comment.
For sbhd, this materializes a full contiguous [b, s, h, d] copy of all three tensors on every eligible call. That copy can offset (or exceed) the perceived split-K speedup on the very shapes the README highlights as the target (under-subscribed long-KV decode-style). Worth either:
- excluding
sbhdfrom eligibility for now (bshdis the layout where split-K most cleanly helps), or - calling out the extra copy in the README's "When to use it" so users don't measure regressions on
sbhdand conclude the flag is broken.
| if _aiter_splitkv_eligible( | ||
| query_layer, | ||
| value_layer, | ||
| qkv_format, | ||
| attn_mask_type, | ||
| core_attention_bias_type, | ||
| window_size, | ||
| self.attention_dropout if self.training else 0.0, | ||
| fp8, | ||
| context_parallel, | ||
| inference_params, | ||
| softmax_offset, | ||
| max_seqlen_q, | ||
| max_seqlen_kv, | ||
| ): |
There was a problem hiding this comment.
Minor: the interception sits after lines 1949–2024 which build cu_seqlens_q/kv, cu_seqlens_*_padded, page_table, etc. — none of which the aiter path uses. Moving the eligibility check + divert above that block (right after qkv_format is computed at line 1958) would avoid the wasted work on every eligible call. Not critical, but on the hot path it's free to fix.
|
Reviewed the +148/-0 diff (3 commits: Overall the gating is conservative and the import-time + per-call eligibility filters look sound. Four inline comments, none blocking, all about robustness of the eligibility filter and a small perf nit:
Copyright headers: OK (both |
Description
Added support for AITER JIT native splitkv kernel via diversion from CK FA backend dispatch. Note that this requires a recent AITER w/ AITER PR 3581, which entails requiring Triton >= 3.6
This change in behavior is not strictly beneficial for all configs, so it is gated as opt-in.via
NVTE_FUSED_ATTN_SPLITKV.Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: