Added support for AITER JIT native splitkv kernel #631

Draft
Micky774 wants to merge 3 commits into dev from zain/aiter/split-kv-py

Micky774 commented Jun 16, 2026

Description

Added support for the AITER JIT native splitkv kernel by diverting eligible calls away from the CK FA backend dispatch. Note that this requires a recent AITER build that includes AITER PR 3581, which in turn requires Triton >= 3.6.

This change in behavior is not strictly beneficial for all configs, so it is gated as opt-in via NVTE_FUSED_ATTN_SPLITKV.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Micky774 added the ci-level 3 (CI test level 3) label Jun 16, 2026
Micky774 marked this pull request as draft June 16, 2026 19:54
@github-actions

Claude Walkthrough

Intent. Add an opt-in path that re-routes eligible ROCm FusedAttention forward calls to AITER's native split-K (Flash-Decoding) kernel. The benefit is for under-subscribed, long-KV shapes (small batch * num_heads, large seqlen_kv) where the standard CK kernel leaves CUs idle; for already-saturated shapes AITER's own heuristic declines to split, making the flag low-risk to leave on but only useful for specific workloads. Gated behind NVTE_FUSED_ATTN_SPLITKV=1 and requires an AITER build that includes PR #3581 (the num_splits arg on flash_attn_func) and Triton >= 3.6.
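For readers unfamiliar with the split-K (Flash-Decoding) idea named above, the following is a minimal pure-Python sketch of the general technique for a single query vector. It is not AITER's kernel, and the helper names `attend` and `split_kv_attend` are hypothetical. Each KV chunk is attended to independently (on a GPU these would run on otherwise idle CUs), and the partial results are merged using their log-sum-exp values.

```python
import math


def attend(q, keys, values):
    """Softmax attention of one query over one KV chunk.

    Returns (output vector, log-sum-exp of the scores in this chunk).
    """
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]
    return out, m + math.log(total)


def split_kv_attend(q, keys, values, num_splits):
    """Attend over `num_splits` KV chunks separately, then merge the partials."""
    if num_splits < 1:
        # This sketch needs an explicit count; it has no automatic heuristic.
        raise ValueError("num_splits must be >= 1 in this sketch")
    chunk = math.ceil(len(keys) / num_splits)
    partials = [
        attend(q, keys[i:i + chunk], values[i:i + chunk])
        for i in range(0, len(keys), chunk)
    ]
    # Each partial output is weighted by its share of the global softmax mass.
    lse_max = max(lse for _, lse in partials)
    scales = [math.exp(lse - lse_max) for _, lse in partials]
    denom = sum(scales)
    dim = len(values[0])
    return [
        sum(s * out[d] for s, (out, _) in zip(scales, partials)) / denom
        for d in range(dim)
    ]
```

The merged result matches unsplit attention up to floating-point error, which is why splitting along the KV axis is safe whenever there are too few batch * head work items to fill the device.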

Key changes.

  • New env-gated module-level import + capability check in transformer_engine/pytorch/attention/dot_product_attention/backends.py:124: imports aiter.ops.mha.flash_attn_func, verifies the num_splits parameter is present via inspect.signature, and warns + disables the divert if either piece is missing.
  • New _aiter_splitkv_eligible(...) predicate at backends.py:148 that pre-filters calls the native kernel cannot serve (gfx942 only, dense bshd/sbhd, bf16, head dim 64, no bias/ALiBi/sliding-window/dropout/sink/FP8/varlen/context-parallel/KV-cache, and seqlen_kv >= seqlen_q for causal).
  • Interception inside FusedAttention.forward at backends.py:2026: when eligible, transposes sbhd -> bshd if needed, calls flash_attn_func(..., num_splits=0) (AITER picks the split count), transposes back, and reshapes ...hd -> ...(hd) to match the existing FusedAttnFunc return convention before early-returning.
  • README documentation block at README.rst:262 explaining the flag, eligibility, and when it actually helps.

Walkthrough.

  • backends.py (module scope): the divert is set up once at import time. _use_aiter_splitkv is the runtime master switch — it only flips on if NVTE_FUSED_ATTN_SPLITKV=1, aiter.ops.mha imports, and flash_attn_func exposes num_splits. The signature check is the load-bearing safety net for older AITER builds (would otherwise raise TypeError deep in the forward).
  • _aiter_splitkv_eligible: mirrors AITER's internal can_impl_fmha_native gate so the divert only triggers when AITER will actually run the native kernel. The docstring notes AITER self-gates anyway; this is a pre-filter so non-eligible calls keep the existing CK code path intact (no needless tensor reshapes or fallback round-trip).
  • FusedAttention.forward: the interception is placed after the cu_seqlens_kv_padded normalization but before use_FAv2_bwd setup and the main CK dispatch — early-returning bypasses the whole CK path including the FusedAttnFunc autograd wrapper. AITER's flash_attn_func is its own autograd Function, so backward is handled by AITER. num_splits=0 defers split-count choice to AITER's occupancy heuristic. The sbhd transpose is .contiguous() because AITER expects bshd contiguous inputs.
  • README: documents NVTE_FUSED_ATTN_SPLITKV, the eligibility constraints (mirroring the predicate), and explicitly notes this is forward-only and only affects the FusedAttention/DotProductAttention module path — training backward and the unfused path are unchanged.
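The import-time gate described in the walkthrough can be sketched roughly as follows. This is an illustration under assumptions, not the code in the PR: `has_parameter`, `resolve_splitkv`, and the injectable `loader` argument are hypothetical names, added here so the logic can be exercised without AITER installed. Only the environment variable name, the `aiter.ops.mha.flash_attn_func` import path, and the `num_splits` parameter come from the text above.

```python
import inspect
import os
import warnings


def has_parameter(func, name):
    """Return True if `func` declares a parameter called `name`."""
    try:
        return name in inspect.signature(func).parameters
    except (TypeError, ValueError):
        # Some builtins and extension callables expose no signature.
        return False


def _load_aiter():
    # Import path as quoted in the walkthrough; requires AITER to be installed.
    from aiter.ops.mha import flash_attn_func
    return flash_attn_func


def resolve_splitkv(env=os.environ, loader=_load_aiter):
    """Return the kernel entry point, or None if the divert must stay off."""
    if env.get("NVTE_FUSED_ATTN_SPLITKV", "0") != "1":
        return None  # opt-in flag not set: default behavior is unchanged
    try:
        func = loader()
    except ImportError:
        warnings.warn("NVTE_FUSED_ATTN_SPLITKV=1 but AITER is not importable")
        return None
    if not has_parameter(func, "num_splits"):
        warnings.warn("AITER flash_attn_func lacks num_splits; divert disabled")
        return None
    return func
```

Checking the signature once at import time turns a `TypeError` that would otherwise surface inside the forward pass into a single warning at startup.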

Testing. No tests added. The change is opt-in via env var and the eligibility predicate restricts the path to a narrow shape regime; the default-off behavior leaves existing test coverage of the CK FusedAttention path untouched.

Notes for reviewers.

  • Backwards-compat: the divert is fully off unless NVTE_FUSED_ATTN_SPLITKV=1 and a sufficiently new AITER is installed; default-off behavior is byte-identical to before.
  • Autograd: AITER's flash_attn_func owns the backward — there is no fallback to TE's FusedAttnFunc.backward once the divert engages. Worth confirming AITER's backward semantics match the CK path for the eligible shape set (bf16, head dim 64, no dropout/bias/sink, dense layout) under training.
  • Eligibility drift: _aiter_splitkv_eligible is a hand-mirrored copy of AITER's can_impl_fmha_native gate. If AITER tightens or relaxes that gate, the predicate will diverge silently — non-eligible calls would still be caught by AITER's internal self-gate (falling back to CK inside AITER) but eligible calls newly supported by AITER would be missed here.
  • Causal short-KV: the max_seqlen_kv >= max_seqlen_q check for causal masking is the only constraint not obviously documented in the README block beyond the general eligibility list — worth confirming this matches AITER's actual support matrix.
  • Layout: qkv_format == "sbhd" triggers a transpose(0, 1).contiguous() per tensor on the way in and a transpose(0, 1) on the way out — a memory-copy cost paid per call when the eligibility check passes for sbhd inputs.

Generated by Claude. To request a code review, comment /claude review.

Comment on lines +2046 to +2057
):
    q, k, v = query_layer, key_layer, value_layer
    if qkv_format == "sbhd":
        q, k, v = (x.transpose(0, 1).contiguous() for x in (q, k, v))
    out = _aiter_splitkv_flash_attn_func(
        q,
        k,
        v,
        dropout_p=0.0,
        softmax_scale=self.softmax_scale,
        causal="causal" in attn_mask_type,
        num_splits=0,

The intercept path drops bottom_right_diagonal — the regular path threads it into FusedAttnFunc.apply (line 2163) and it affects causal semantics whenever seqlen_q != seqlen_kv. The eligibility check allows max_seqlen_kv > max_seqlen_q for attn_mask_type == "causal", but then calls aiter with just causal=True, which is typically top-left causal. If a caller relied on bottom-right semantics (e.g. cross-attention with attn_mask_type="causal" + bottom_right_diagonal=True), the intercepted forward will produce different outputs than the unintercepted path.

Two safe options: (a) gate eligibility on max_seqlen_q == max_seqlen_kv for causal so TL/BR coincide, or (b) also check bottom_right_diagonal in (None, False) and document the assumption. Either keeps the intercept semantically equivalent to the CK path.
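To make the top-left versus bottom-right distinction concrete, here is a small pure-Python sketch. The helper names `causal_mask` and `causal_divert_safe` are hypothetical, and which alignment AITER actually applies for `causal=True` is not established by this example; it only shows why the two conventions coincide when the sequence lengths are equal and differ otherwise.

```python
def causal_mask(seqlen_q, seqlen_kv, bottom_right):
    """Build a causal mask; mask[i][j] is True where query i may attend to key j.

    Top-left alignment anchors the diagonal at (0, 0).
    Bottom-right alignment anchors it at (seqlen_q - 1, seqlen_kv - 1).
    """
    offset = seqlen_kv - seqlen_q if bottom_right else 0
    return [[j <= i + offset for j in range(seqlen_kv)] for i in range(seqlen_q)]


def causal_divert_safe(is_causal, max_seqlen_q, max_seqlen_kv):
    """Option (a) from the comment: only divert causal calls when the two
    alignments are guaranteed to produce the same mask."""
    return (not is_causal) or max_seqlen_q == max_seqlen_kv
```

With `seqlen_q=2` and `seqlen_kv=4`, the top-left mask lets query 0 see only key 0, while the bottom-right mask lets it see keys 0 through 2, so the two paths would produce different outputs for that shape.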

Comment on lines +172 to +198
if not _use_aiter_splitkv:
    return False
if get_device_compute_capability() != (9, 4):  # gfx942 only
    return False
if q.dtype != torch.bfloat16 or v.dtype != torch.bfloat16:
    return False
if q.shape[-1] != 64 or v.shape[-1] != 64:
    return False
if qkv_format not in ("bshd", "sbhd"):  # dense only; thd/varlen unsupported
    return False
if "padding" in attn_mask_type:
    return False
if attn_mask_type not in ("no_mask", "causal"):
    return False
if core_attention_bias_type != "no_bias":  # excludes both bias and alibi
    return False
if window_size is not None and tuple(window_size) != (-1, -1):  # no sliding window
    return False
if dropout_p != 0.0:
    return False
if fp8 or context_parallel or inference_params is not None:
    return False
if softmax_offset is not None:  # no learnable sink
    return False
if "causal" in attn_mask_type and max_seqlen_kv < max_seqlen_q:
    return False
return True

Eligibility doesn't gate on num_heads_q == num_heads_kv (GQA / MQA). The CK FusedAttention path supports unequal head counts; if aiter.ops.mha.flash_attn_func's native split-K path doesn't, GQA shapes will hit the divert and then either error or fall back inside aiter (losing the perf benefit). Worth either adding a q.shape[-2] != k.shape[-2] -> False check (and taking k into the signature for that), or documenting in the README + the docstring that GQA is intentionally let through and relies on aiter's internal self-gate.

Same comment applies to dtype/head-dim: the signature takes q, v but assumes k matches — passing and checking k would close the assumption.
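A sketch of what the suggested check could look like once `k` is passed in. To keep it runnable without PyTorch, it uses a hypothetical `TensorMeta` stand-in that carries only `shape` and `dtype`; the function name `qkv_consistent` and the `allow_gqa` switch are likewise assumptions for illustration. Shapes follow the dense bshd/sbhd layouts, where the second-to-last dimension is the head count.

```python
from typing import NamedTuple, Tuple


class TensorMeta(NamedTuple):
    """Hypothetical stand-in for the tensor attributes the check reads."""
    shape: Tuple[int, ...]
    dtype: str


def qkv_consistent(q, k, v, head_dim=64, dtype="bfloat16", allow_gqa=False):
    """Check dtype, head dim, and head counts across all of q, k, and v."""
    tensors = (q, k, v)
    if any(t.dtype != dtype for t in tensors):
        return False
    if any(t.shape[-1] != head_dim for t in tensors):
        return False
    if k.shape[-2] != v.shape[-2]:
        return False  # K and V must share a head count
    if not allow_gqa and q.shape[-2] != k.shape[-2]:
        return False  # GQA/MQA: fewer KV heads than query heads
    return True
```

Setting `allow_gqa=True` corresponds to the second option in the comment, where GQA shapes are intentionally let through and the decision is left to AITER's own gate.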

):
    q, k, v = query_layer, key_layer, value_layer
    if qkv_format == "sbhd":
        q, k, v = (x.transpose(0, 1).contiguous() for x in (q, k, v))

For sbhd, this materializes a full contiguous [b, s, h, d] copy of all three tensors on every eligible call. That copy can offset (or exceed) the perceived split-K speedup on the very shapes the README highlights as the target (under-subscribed long-KV decode-style). Worth either:

  • excluding sbhd from eligibility for now (bshd is the layout where split-K most cleanly helps), or
  • calling out the extra copy in the README's "When to use it" so users don't measure regressions on sbhd and conclude the flag is broken.
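A back-of-the-envelope estimate of that copy can help decide between the two options. The helper below is a hypothetical sketch: it gives an upper bound on the bytes written by the three `.contiguous()` calls, assuming 2-byte bf16 elements and that all three tensors really are copied (`.contiguous()` only copies when the transposed view is not already contiguous).

```python
def sbhd_copy_bytes(batch, seqlen_q, seqlen_kv, heads_q, heads_kv,
                    head_dim=64, bytes_per_elem=2):
    """Upper bound on bytes written when Q, K, and V are each copied
    from sbhd into a contiguous bshd layout."""
    q_elems = batch * seqlen_q * heads_q * head_dim
    kv_elems = 2 * batch * seqlen_kv * heads_kv * head_dim  # K and V
    return (q_elems + kv_elems) * bytes_per_elem


# Example shape (illustrative, not taken from the PR):
# batch 4, 128 queries, 32768 keys, 8 heads, head dim 64.
print(sbhd_copy_bytes(4, 128, 32768, 8, 8))  # 268959744
```

For this illustrative shape the copy writes roughly 256 MiB per forward call, and reads the same amount, before the attention kernel starts. Whether that outweighs the split-K gain would need to be measured on the target hardware.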

Comment on lines +2032 to +2046
if _aiter_splitkv_eligible(
    query_layer,
    value_layer,
    qkv_format,
    attn_mask_type,
    core_attention_bias_type,
    window_size,
    self.attention_dropout if self.training else 0.0,
    fp8,
    context_parallel,
    inference_params,
    softmax_offset,
    max_seqlen_q,
    max_seqlen_kv,
):

Minor: the interception sits after lines 1949–2024 which build cu_seqlens_q/kv, cu_seqlens_*_padded, page_table, etc. — none of which the aiter path uses. Moving the eligibility check + divert above that block (right after qkv_format is computed at line 1958) would avoid the wasted work on every eligible call. Not critical, but on the hot path it's free to fix.

@github-actions

Reviewed the +148/-0 diff (3 commits: ef88e16b, cb028454, b4f8fcf6) — opt-in NVTE_FUSED_ATTN_SPLITKV divert from the CK FusedAttention forward to aiter.ops.mha.flash_attn_func, plus README documentation.

Overall the gating is conservative and the import-time + per-call eligibility filters look sound. Four inline comments, none blocking, all about robustness of the eligibility filter and a small perf nit:

  • bottom_right_diagonal is dropped at the divert — TL vs BR causal can diverge when seqlen_kv > seqlen_q.
  • GQA (num_heads_q != num_heads_kv) isn't explicitly gated; relies on aiter's internal self-gate.
  • sbhd path does a full .contiguous() copy of Q/K/V, which can offset the split-K win for the very shapes the README targets.
  • The intercept sits after a block of setup work (cu_seqlens, page_table) that the aiter path doesn't use — cheap move-up.

Copyright headers: OK (both README.rst and backends.py already carry AMD ...-2026).
