From ef88e16b02a2a835d0f742c2ddd54b34a922b09e Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Mon, 15 Jun 2026 20:46:57 +0000 Subject: [PATCH 1/3] Wired split-kv through FlashAttention interface --- .../dot_product_attention/backends.py | 28 +++++++++++++++++++ .../attention/dot_product_attention/utils.py | 9 +++++- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 57b619a68..94012f88c 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -91,6 +91,7 @@ _flash_attn_bwd = None _flash_attn_varlen_fwd = None _flash_attn_varlen_bwd = None +aiter_flash_attn_func_splitkv = None # ROCm: AITER native split-K forward (aiter.ops.mha) if IS_HIP_EXTENSION and os.getenv("NVTE_FLASH_ATTN_AITER", "0") == "1": try: @@ -116,6 +117,19 @@ fa_utils.version = PkgVersion("2.7.1") #masqurade as FA 2.7.1 fa_utils.set_flash_attention_version() attn_log.fa_logger.info("Using AITER Triton for FlashAttn.") + # Optionally pull in AITER's native split-K forward (aiter.ops.mha). Kept + # separate from the Triton imports above so its absence never disables the + # Triton path; engaged only when num_splits != 1 (see FlashAttention.forward). + try: + from aiter.ops.mha import flash_attn_func as aiter_flash_attn_func_splitkv + except ImportError: + attn_log.fa_logger.warning( + "AITER native split-K forward (aiter.ops.mha) is unavailable;" + " num_splits will be ignored for the AITER path." + ) + else: + fa_utils.use_aiter_splitkv = True + attn_log.fa_logger.info("AITER native split-K forward is available.") try: if fa_utils.use_aiter_triton: raise PackageNotFoundError # skip version check for aiter triton @@ -1040,6 +1054,20 @@ def forward( 1 )[:batch_size] ) + if ( + fa_utils.use_aiter_splitkv + and num_splits != 1 + and inference_params is None + and not fp8 + and func is flash_attn_func + ): + # ROCm: route the dense forward to AITER's native split-K kernel. + # aiter.ops.mha.flash_attn_func self-gates (gfx942/D64/bf16) and + # otherwise falls back to the standard CK/ASM dispatch, so this is + # safe for any dense bf16 shape. Forward-only; backward unchanged. + # num_splits: 0 = AITER heuristic, >=2 = forced split count. + func = aiter_flash_attn_func_splitkv + fa_optional_forward_kwargs["num_splits"] = num_splits output = func( query_layer, key_layer, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 87992d294..175882d7d 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -144,6 +144,7 @@ class FlashAttentionUtils: (5) cp flash_attn_interface.py $python_path/flash_attn_3/flash_attn_interface.py""" v3_warning_printed = False use_aiter_triton = False #ROCm + use_aiter_splitkv = False #ROCm: AITER native split-K forward (aiter.ops.mha) available @staticmethod def set_flash_attention_version(): @@ -530,7 +531,13 @@ def get_attention_backend( # Filter: num_splits if num_splits != 1: - if use_flash_attention_2 and FlashAttentionUtils.is_installed: + # ROCm: the AITER backend masquerades as FlashAttention 2 and routes num_splits + # through to its native split-K forward (aiter.ops.mha), so keep it enabled here. + if ( + use_flash_attention_2 + and FlashAttentionUtils.is_installed + and not FlashAttentionUtils.use_aiter_splitkv + ): logger.debug("Disabling FlashAttention 2 for num_splits") use_flash_attention_2 = False if use_fused_attention: From cb028454a2d3444799e4f36d6e9cfded2771f706 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 16 Jun 2026 16:00:32 +0000 Subject: [PATCH 2/3] Updated splitkv intercept to live in CK FA path --- .../dot_product_attention/backends.py | 144 ++++++++++++++---- .../attention/dot_product_attention/utils.py | 9 +- 2 files changed, 119 insertions(+), 34 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 94012f88c..bdc44d9a8 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -8,6 +8,7 @@ from contextlib import nullcontext from importlib.metadata import version as get_pkg_version from importlib.metadata import PackageNotFoundError +import inspect import os from typing import Any, Callable, Dict, List, Optional, Tuple, Union import warnings @@ -91,7 +92,6 @@ _flash_attn_bwd = None _flash_attn_varlen_fwd = None _flash_attn_varlen_bwd = None -aiter_flash_attn_func_splitkv = None # ROCm: AITER native split-K forward (aiter.ops.mha) if IS_HIP_EXTENSION and os.getenv("NVTE_FLASH_ATTN_AITER", "0") == "1": try: @@ -117,19 +117,87 @@ fa_utils.version = PkgVersion("2.7.1") #masqurade as FA 2.7.1 fa_utils.set_flash_attention_version() attn_log.fa_logger.info("Using AITER Triton for FlashAttn.") - # Optionally pull in AITER's native split-K forward (aiter.ops.mha). Kept - # separate from the Triton imports above so its absence never disables the - # Triton path; engaged only when num_splits != 1 (see FlashAttention.forward). - try: - from aiter.ops.mha import flash_attn_func as aiter_flash_attn_func_splitkv - except ImportError: + +# ROCm: AITER native split-K (Flash-Decoding) forward, opt-in via NVTE_FUSED_ATTN_SPLITKV. +# When enabled, eligible dense bf16 head-dim-64 calls that would run on the CK +# FusedAttention backend are transparently routed to aiter.ops.mha.flash_attn_func +# (the native split-K forward) instead -- see FusedAttention.forward. +_aiter_splitkv_flash_attn_func = None +_use_aiter_splitkv = False +if IS_HIP_EXTENSION and os.getenv("NVTE_FUSED_ATTN_SPLITKV", "0") == "1": + try: + from aiter.ops.mha import flash_attn_func as _aiter_splitkv_flash_attn_func + except ImportError: + attn_log.fa_logger.warning( + "NVTE_FUSED_ATTN_SPLITKV is set but aiter.ops.mha is unavailable;" + " split-K interception disabled." + ) + else: + # The native split-K forward (AITER PR #3581) exposes a `num_splits` arg on + # flash_attn_func. Older AITER builds lack it, so verify before enabling to + # avoid a TypeError at call time. + if "num_splits" in inspect.signature(_aiter_splitkv_flash_attn_func).parameters: + _use_aiter_splitkv = True + attn_log.fa_logger.info("AITER native split-K forward enabled for FusedAttention.") + else: + _aiter_splitkv_flash_attn_func = None attn_log.fa_logger.warning( - "AITER native split-K forward (aiter.ops.mha) is unavailable;" - " num_splits will be ignored for the AITER path." + "NVTE_FUSED_ATTN_SPLITKV is set but aiter.ops.mha.flash_attn_func has no" + " num_splits arg (AITER predates PR #3581); split-K interception disabled." ) - else: - fa_utils.use_aiter_splitkv = True - attn_log.fa_logger.info("AITER native split-K forward is available.") + + +def _aiter_splitkv_eligible( + q, + v, + qkv_format, + attn_mask_type, + core_attention_bias_type, + window_size, + dropout_p, + fp8, + context_parallel, + inference_params, + softmax_offset, + max_seqlen_q, + max_seqlen_kv, +): + """Whether a FusedAttention call can be served by AITER's native split-K forward. + + Conservatively mirrors aiter.ops.mha's can_impl_fmha_native gate (gfx942, dense + bf16, head_dim 64, no bias/alibi/sliding-window/dropout/sink/fp8/varlen/context- + parallel/kvcache). aiter additionally self-gates and falls back to CK internally, + so this is a pre-filter to avoid diverting calls the native kernel cannot serve. + """ + if not _use_aiter_splitkv: + return False + if get_device_compute_capability() != (9, 4): # gfx942 only + return False + if q.dtype != torch.bfloat16 or v.dtype != torch.bfloat16: + return False + if q.shape[-1] != 64 or v.shape[-1] != 64: + return False + if qkv_format not in ("bshd", "sbhd"): # dense only; thd/varlen unsupported + return False + if "padding" in attn_mask_type: + return False + if attn_mask_type not in ("no_mask", "causal"): + return False + if core_attention_bias_type != "no_bias": # excludes both bias and alibi + return False + if window_size is not None and tuple(window_size) != (-1, -1): # no sliding window + return False + if dropout_p != 0.0: + return False + if fp8 or context_parallel or inference_params is not None: + return False + if softmax_offset is not None: # no learnable sink + return False + if "causal" in attn_mask_type and max_seqlen_kv < max_seqlen_q: + return False + return True + + try: if fa_utils.use_aiter_triton: raise PackageNotFoundError # skip version check for aiter triton @@ -1054,20 +1122,6 @@ def forward( 1 )[:batch_size] ) - if ( - fa_utils.use_aiter_splitkv - and num_splits != 1 - and inference_params is None - and not fp8 - and func is flash_attn_func - ): - # ROCm: route the dense forward to AITER's native split-K kernel. - # aiter.ops.mha.flash_attn_func self-gates (gfx942/D64/bf16) and - # otherwise falls back to the standard CK/ASM dispatch, so this is - # safe for any dense bf16 shape. Forward-only; backward unchanged. - # num_splits: 0 = AITER heuristic, >=2 = forced split count. - func = aiter_flash_attn_func_splitkv - fa_optional_forward_kwargs["num_splits"] = num_splits output = func( query_layer, key_layer, @@ -1969,6 +2023,44 @@ def forward( if (kv_format == "thd" or "padding" in attn_mask_type) and cu_seqlens_kv_padded is None: cu_seqlens_kv_padded = cu_seqlens_kv + # ROCm: opt-in interception (NVTE_FUSED_ATTN_SPLITKV=1). For eligible dense bf16 + # head-dim-64 calls, route the forward to AITER's native split-K kernel instead + # of the CK FusedAttention path. aiter.ops.mha.flash_attn_func is its own + # autograd Function (handles its own backward) and self-gates / falls back to CK + # internally, so non-eligible shapes are unaffected. num_splits=0 => AITER picks + # the split count heuristically. + if _aiter_splitkv_eligible( + query_layer, + value_layer, + qkv_format, + attn_mask_type, + core_attention_bias_type, + window_size, + self.attention_dropout if self.training else 0.0, + fp8, + context_parallel, + inference_params, + softmax_offset, + max_seqlen_q, + max_seqlen_kv, + ): + q, k, v = query_layer, key_layer, value_layer + if qkv_format == "sbhd": + q, k, v = (x.transpose(0, 1).contiguous() for x in (q, k, v)) + out = _aiter_splitkv_flash_attn_func( + q, + k, + v, + dropout_p=0.0, + softmax_scale=self.softmax_scale, + causal="causal" in attn_mask_type, + num_splits=0, + ) + if qkv_format == "sbhd": + out = out.transpose(0, 1) + # ...hd -> ...(hd), matching the FusedAttnFunc return convention below. + return out.reshape(*out.shape[:-2], -1) + use_FAv2_bwd = ( self.use_FAv2_bwd and (core_attention_bias_type == "no_bias") diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 175882d7d..87992d294 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -144,7 +144,6 @@ class FlashAttentionUtils: (5) cp flash_attn_interface.py $python_path/flash_attn_3/flash_attn_interface.py""" v3_warning_printed = False use_aiter_triton = False #ROCm - use_aiter_splitkv = False #ROCm: AITER native split-K forward (aiter.ops.mha) available @staticmethod def set_flash_attention_version(): @@ -531,13 +530,7 @@ def get_attention_backend( # Filter: num_splits if num_splits != 1: - # ROCm: the AITER backend masquerades as FlashAttention 2 and routes num_splits - # through to its native split-K forward (aiter.ops.mha), so keep it enabled here. - if ( - use_flash_attention_2 - and FlashAttentionUtils.is_installed - and not FlashAttentionUtils.use_aiter_splitkv - ): + if use_flash_attention_2 and FlashAttentionUtils.is_installed: logger.debug("Disabling FlashAttention 2 for num_splits") use_flash_attention_2 = False if use_fused_attention: From b4f8fcf6e103e669519749637d98b53f30dd71b3 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 16 Jun 2026 19:45:39 +0000 Subject: [PATCH 3/3] Added readme segment --- README.rst | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/README.rst b/README.rst index 62e7d0738..fe74e0e30 100644 --- a/README.rst +++ b/README.rst @@ -259,6 +259,34 @@ ROCm TE provides the compile-time env NVTE_CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAU * 3 - standard asm, default; * 4 - rta_asm. +AITER Native Split-K Forward (gfx942 only) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +On gfx942, the CK fused attention path can optionally dispatch the *forward* pass to AITER's +hand-written native split-K (Flash-Decoding) kernel, which splits the work along the key/value +sequence dimension to keep the GPU busy when a single attention problem does not. This is +controlled by a runtime environment variable: + +* NVTE_FUSED_ATTN_SPLITKV - by default 0 (disabled). When set to 1, eligible CK FusedAttention + forward calls are routed to AITER's native split-K kernel, which picks the number of splits with + its built-in occupancy heuristic. + +When to use it: + +* The benefit comes from *under-subscribed, long-KV* shapes - small ``batch x num_heads`` with a + large ``seqlen_kv`` (e.g. long-context prefill or decode) - where the standard kernel leaves + compute units idle. Splitting the KV dimension across more workgroups fills the machine. +* For already-saturated shapes (large ``batch x num_heads``) there is little to gain; AITER's + heuristic typically declines to split, so leaving the flag on is low-risk but offers no benefit + there. +* It is forward-only and affects only the ``FusedAttention`` / ``DotProductAttention`` module path. + Training backward and the unfused path are unchanged. + +The divert engages only when a call is eligible for the native kernel: gfx942, dense ``bshd`` / +``sbhd`` layout (no ``thd``/varlen), bf16, head dim 64, no bias/ALiBi/sliding-window/dropout/ +attention-sink/FP8/context-parallel/KV-cache, and, for causal masking, ``seqlen_kv >= seqlen_q``. +Non-eligible calls fall back to the standard CK kernel unchanged. This requires an AITER build that +includes the native split-K kernel; if it is unavailable the flag is ignored with a warning. + Experimental Triton Kernels on ROCm ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Most CUDA kernels in Transformer Engine are hipified to run on ROCm. While the hipifiled CUDA kernels are functional, they are not necessarily optimal on ROCm.