diff --git a/README.rst b/README.rst index 62e7d0738..fe74e0e30 100644 --- a/README.rst +++ b/README.rst @@ -259,6 +259,34 @@ ROCm TE provides the compile-time env NVTE_CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAU * 3 - standard asm, default; * 4 - rta_asm. +AITER Native Split-K Forward (gfx942 only) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +On gfx942, the CK fused attention path can optionally dispatch the *forward* pass to AITER's +hand-written native split-K (Flash-Decoding) kernel, which splits the work along the key/value +sequence dimension to keep the GPU busy when a single attention problem does not. This is +controlled by a runtime environment variable: + +* NVTE_FUSED_ATTN_SPLITKV - by default 0 (disabled). When set to 1, eligible CK FusedAttention + forward calls are routed to AITER's native split-K kernel, which picks the number of splits with + its built-in occupancy heuristic. + +When to use it: + +* The benefit comes from *under-subscribed, long-KV* shapes - small ``batch x num_heads`` with a + large ``seqlen_kv`` (e.g. long-context prefill or decode) - where the standard kernel leaves + compute units idle. Splitting the KV dimension across more workgroups fills the machine. +* For already-saturated shapes (large ``batch x num_heads``) there is little to gain; AITER's + heuristic typically declines to split, so leaving the flag on is low-risk but offers no benefit + there. +* It is forward-only and affects only the ``FusedAttention`` / ``DotProductAttention`` module path. + Training backward and the unfused path are unchanged. + +The divert engages only when a call is eligible for the native kernel: gfx942, dense ``bshd`` / +``sbhd`` layout (no ``thd``/varlen), bf16, head dim 64, no bias/ALiBi/sliding-window/dropout/ +attention-sink/FP8/context-parallel/KV-cache, and, for causal masking, ``seqlen_kv >= seqlen_q``. +Non-eligible calls fall back to the standard CK kernel unchanged. This requires an AITER build that +includes the native split-K kernel; if it is unavailable the flag is ignored with a warning. + Experimental Triton Kernels on ROCm ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Most CUDA kernels in Transformer Engine are hipified to run on ROCm. While the hipifiled CUDA kernels are functional, they are not necessarily optimal on ROCm. diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 57b619a68..bdc44d9a8 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -8,6 +8,7 @@ from contextlib import nullcontext from importlib.metadata import version as get_pkg_version from importlib.metadata import PackageNotFoundError +import inspect import os from typing import Any, Callable, Dict, List, Optional, Tuple, Union import warnings @@ -116,6 +117,87 @@ fa_utils.version = PkgVersion("2.7.1") #masqurade as FA 2.7.1 fa_utils.set_flash_attention_version() attn_log.fa_logger.info("Using AITER Triton for FlashAttn.") + +# ROCm: AITER native split-K (Flash-Decoding) forward, opt-in via NVTE_FUSED_ATTN_SPLITKV. +# When enabled, eligible dense bf16 head-dim-64 calls that would run on the CK +# FusedAttention backend are transparently routed to aiter.ops.mha.flash_attn_func +# (the native split-K forward) instead -- see FusedAttention.forward. +_aiter_splitkv_flash_attn_func = None +_use_aiter_splitkv = False +if IS_HIP_EXTENSION and os.getenv("NVTE_FUSED_ATTN_SPLITKV", "0") == "1": + try: + from aiter.ops.mha import flash_attn_func as _aiter_splitkv_flash_attn_func + except ImportError: + attn_log.fa_logger.warning( + "NVTE_FUSED_ATTN_SPLITKV is set but aiter.ops.mha is unavailable;" + " split-K interception disabled." + ) + else: + # The native split-K forward (AITER PR #3581) exposes a `num_splits` arg on + # flash_attn_func. Older AITER builds lack it, so verify before enabling to + # avoid a TypeError at call time. + if "num_splits" in inspect.signature(_aiter_splitkv_flash_attn_func).parameters: + _use_aiter_splitkv = True + attn_log.fa_logger.info("AITER native split-K forward enabled for FusedAttention.") + else: + _aiter_splitkv_flash_attn_func = None + attn_log.fa_logger.warning( + "NVTE_FUSED_ATTN_SPLITKV is set but aiter.ops.mha.flash_attn_func has no" + " num_splits arg (AITER predates PR #3581); split-K interception disabled." + ) + + +def _aiter_splitkv_eligible( + q, + v, + qkv_format, + attn_mask_type, + core_attention_bias_type, + window_size, + dropout_p, + fp8, + context_parallel, + inference_params, + softmax_offset, + max_seqlen_q, + max_seqlen_kv, +): + """Whether a FusedAttention call can be served by AITER's native split-K forward. + + Conservatively mirrors aiter.ops.mha's can_impl_fmha_native gate (gfx942, dense + bf16, head_dim 64, no bias/alibi/sliding-window/dropout/sink/fp8/varlen/context- + parallel/kvcache). aiter additionally self-gates and falls back to CK internally, + so this is a pre-filter to avoid diverting calls the native kernel cannot serve. + """ + if not _use_aiter_splitkv: + return False + if get_device_compute_capability() != (9, 4): # gfx942 only + return False + if q.dtype != torch.bfloat16 or v.dtype != torch.bfloat16: + return False + if q.shape[-1] != 64 or v.shape[-1] != 64: + return False + if qkv_format not in ("bshd", "sbhd"): # dense only; thd/varlen unsupported + return False + if "padding" in attn_mask_type: + return False + if attn_mask_type not in ("no_mask", "causal"): + return False + if core_attention_bias_type != "no_bias": # excludes both bias and alibi + return False + if window_size is not None and tuple(window_size) != (-1, -1): # no sliding window + return False + if dropout_p != 0.0: + return False + if fp8 or context_parallel or inference_params is not None: + return False + if softmax_offset is not None: # no learnable sink + return False + if "causal" in attn_mask_type and max_seqlen_kv < max_seqlen_q: + return False + return True + + try: if fa_utils.use_aiter_triton: raise PackageNotFoundError # skip version check for aiter triton @@ -1941,6 +2023,44 @@ def forward( if (kv_format == "thd" or "padding" in attn_mask_type) and cu_seqlens_kv_padded is None: cu_seqlens_kv_padded = cu_seqlens_kv + # ROCm: opt-in interception (NVTE_FUSED_ATTN_SPLITKV=1). For eligible dense bf16 + # head-dim-64 calls, route the forward to AITER's native split-K kernel instead + # of the CK FusedAttention path. aiter.ops.mha.flash_attn_func is its own + # autograd Function (handles its own backward) and self-gates / falls back to CK + # internally, so non-eligible shapes are unaffected. num_splits=0 => AITER picks + # the split count heuristically. + if _aiter_splitkv_eligible( + query_layer, + value_layer, + qkv_format, + attn_mask_type, + core_attention_bias_type, + window_size, + self.attention_dropout if self.training else 0.0, + fp8, + context_parallel, + inference_params, + softmax_offset, + max_seqlen_q, + max_seqlen_kv, + ): + q, k, v = query_layer, key_layer, value_layer + if qkv_format == "sbhd": + q, k, v = (x.transpose(0, 1).contiguous() for x in (q, k, v)) + out = _aiter_splitkv_flash_attn_func( + q, + k, + v, + dropout_p=0.0, + softmax_scale=self.softmax_scale, + causal="causal" in attn_mask_type, + num_splits=0, + ) + if qkv_format == "sbhd": + out = out.transpose(0, 1) + # ...hd -> ...(hd), matching the FusedAttnFunc return convention below. + return out.reshape(*out.shape[:-2], -1) + use_FAv2_bwd = ( self.use_FAv2_bwd and (core_attention_bias_type == "no_bias")