Add ROCm HIP small-seq fused attention via crossattn_hip_kernel #625

Open
VeeraRajasekhar wants to merge 3 commits into dev from veergopu/integrate-small-seq-mfma

Conversation

@VeeraRajasekhar

Contributor

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Comment thread tests/jax/test_fused_attn.py Outdated
return False


@pytest.mark.skipif(not _on_gfx942(), reason="CK small-seq is implemented for gfx942 (MI300X) only")

Collaborator

Suggest using not (is_hip_extension() and get_device_compute_capability() == 94) as the skip condition instead of the custom _on_gfx942() helper.

#include "../util/cuda_runtime.h"
#include "../util/system.h"
#include "fused_attn_ck.h"
#include "fused_attn_smallseq.h"

Collaborator

Shouldn't the small-seq path be integrated at a higher level rather than inside the CK backend only?

}

bool is_nvte_ck_small_seq_enabled() {
#ifdef __HIP_PLATFORM_AMD__

Collaborator

fused_attn_rocm/ is always built on the AMD platform, so the #ifdef __HIP_PLATFORM_AMD__ guard is unnecessary here.


bool is_nvte_ck_small_seq_enabled() {
#ifdef __HIP_PLATFORM_AMD__
const std::string& arch = cuda::sm_arch_name();

Collaborator

Use sm_arch instead of sm_arch_name.

Comment thread tests/jax/test_fused_attn.py Outdated
seq_desc_format=SeqDescFormat.SegmentIDs,
num_segments_per_seq=2,
)
runner.test_forward()

Collaborator

The implementation covers both FWD and BWD, but the test exercises FWD only.

@github-actions

Claude Walkthrough

Intent. Adds an opt-in HIP small-seq fast path for ROCm fused attention, targeting cross-attention workloads where runtime seqlen ≤ 17 on gfx942 (MI300X). When eligible, requests are routed to the MFMA kernels in the new crossattn_hip_kernel submodule; otherwise the existing CK/AITER path runs unchanged.

Key changes.

  • New submodule 3rdparty/crossattn_hip_kernel (branch prototype/lse-aux) wired into both CI workflows and the build (.gitmodules, .github/workflows/rocm-ci.yml, .github/workflows/rocm-wheels-build.yml, transformer_engine/common/CMakeLists.txt:293).
  • New eligibility/dispatch layer transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.{h,cpp} — static config checks, runtime seqlen check (kSmallSeqMaxSeqlen = 17), env gate NVTE_FUSED_ATTN_CK_SMALLSEQ, and templated launchers for fp16/bf16 × heads∈{16,32} × head_dim∈{128,256}.
  • New CK helper declaration ck_fused_attn::get_runtime_max_seqlen in transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp:146 used to probe actual max seqlen from cu_seqlens on device.
  • fused_attn_ck.cpp fwd/bwd impls reserve a small-seq workspace prefix, probe runtime seqlens, and branch to the small-seq launchers when eligible — falling through to the existing CK args path otherwise.
  • JAX workspace-size queries grow the reported workspace to include the small-seq scratch (transformer_engine/jax/csrc/extensions/attention.cpp:220, :523).
  • New pytest test_fused_attn_ck_smallseq_thd_gfx942 in tests/jax/test_fused_attn.py:1920.

Walkthrough.

  • fused_attn_smallseq.h/.cpp defines the small-seq surface. small_seq_static_config_ok rejects anything other than fp16/bf16, no dropout, no bias, equal Q/KV head dim ∈ {128, 256}, num_heads == num_gqa_groups ∈ {16, 32}, and padding-or-no-mask. is_runtime_small_seq_eligible requires both runtime max-seqlens in (0, 17]. is_nvte_ck_small_seq_enabled additionally requires gfx942 and the env var NVTE_FUSED_ATTN_CK_SMALLSEQ=1 — so this path stays off by default. The .cpp provides no-op stubs when USE_FUSED_ATTN_CK is not defined; when defined, it pulls in attn_fwd_mfma_dispatch.h / attn_bwd_mfma_16x16.h from the new submodule and instantiates FmhaKernelConfig<16384, HEAD_NUM, 17, HEAD_DIM, 256, false, CausalMaskType::DISABLE, 17> per (heads, head_dim) combo. A comment notes head_dim=512 is intentionally excluded due to LDS limits on CDNA.

  • fused_attn_ck.cpp adds a build_padded_q_to_batch_kernel (one-thread-per-batch that fills padded_q_to_batch[i] = b over each batch's padded range) and, in both fused_attn_ck_fwd_impl and fused_attn_ck_bwd_impl, reserves a workspace prefix when ck_small_seq_enabled && is_ragged. The prefix layout is [uint64 max_seqlen_q probe][uint64 max_seqlen_kv probe][int32 padded_q_to_batch]. At runtime it calls ck_fused_attn::get_runtime_max_seqlen twice, logs the chosen flow when nvte_log_ck_config is on, and only dispatches to fused_attn_smallseq_{fwd,bwd} when runtime seqlens fit. If the small-seq attempt returns false (e.g. unsupported batch size), execution falls through to the original CK args path — the existing logic is preserved unchanged, just guarded behind if(!ran_smallseq).

  • JAX bindings (attention.cpp) widen the reported workspace size for both fwd and bwd query helpers when small-seq is enabled and the static config matches. They use input_batch * q_max_seqlen as a conservative upper bound for max_tokens_q since the runtime layout isn't known yet at workspace-query time.
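
The eligibility rules described in the walkthrough can be sketched as plain host-side predicates. This is a minimal sketch under stated assumptions, not the PR's code: the SmallSeqStaticConfig struct, the DType and MaskType enums, and the env_gate_enabled helper are hypothetical stand-ins. Only the names small_seq_static_config_ok, is_runtime_small_seq_eligible, kSmallSeqMaxSeqlen, and NVTE_FUSED_ATTN_CK_SMALLSEQ come from the walkthrough, and their real signatures and env parsing may differ.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <cstring>

// Cap named in the walkthrough; the integer type used here is an assumption.
constexpr std::size_t kSmallSeqMaxSeqlen = 17;

// Hypothetical stand-ins for the real dtype and mask enums.
enum class DType { kFloat16, kBFloat16, kFloat32 };
enum class MaskType { kNoMask, kPadding, kCausal };

// Hypothetical bundle of the static attention configuration.
struct SmallSeqStaticConfig {
  DType dtype;
  float dropout;
  bool has_bias;
  std::size_t head_dim_qk;
  std::size_t head_dim_v;
  std::size_t num_heads;
  std::size_t num_gqa_groups;
  MaskType mask;
};

// Static checks as listed in the walkthrough: fp16/bf16, no dropout, no bias,
// equal Q/KV head dim in {128, 256}, heads == GQA groups in {16, 32},
// and either a padding mask or no mask.
bool small_seq_static_config_ok(const SmallSeqStaticConfig& c) {
  const bool dtype_ok = c.dtype == DType::kFloat16 || c.dtype == DType::kBFloat16;
  const bool head_dim_ok =
      c.head_dim_qk == c.head_dim_v && (c.head_dim_qk == 128 || c.head_dim_qk == 256);
  const bool heads_ok =
      c.num_heads == c.num_gqa_groups && (c.num_heads == 16 || c.num_heads == 32);
  const bool mask_ok = c.mask == MaskType::kNoMask || c.mask == MaskType::kPadding;
  return dtype_ok && c.dropout == 0.0f && !c.has_bias && head_dim_ok && heads_ok && mask_ok;
}

// Both runtime max sequence lengths must lie in (0, kSmallSeqMaxSeqlen].
bool is_runtime_small_seq_eligible(std::size_t max_seqlen_q, std::size_t max_seqlen_kv) {
  return max_seqlen_q > 0 && max_seqlen_q <= kSmallSeqMaxSeqlen &&
         max_seqlen_kv > 0 && max_seqlen_kv <= kSmallSeqMaxSeqlen;
}

// Assumed env parsing: only the exact string "1" enables the gate.
bool env_gate_enabled(const char* name) {
  assert(name != nullptr);
  const char* value = std::getenv(name);
  return value != nullptr && std::strcmp(value, "1") == 0;
}
```

Under these assumptions a request would take the small-seq path only when env_gate_enabled("NVTE_FUSED_ATTN_CK_SMALLSEQ"), the gfx942 check, small_seq_static_config_ok, and is_runtime_small_seq_eligible all hold. The gfx942 check is omitted here because it depends on the runtime.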

Testing. One new test, test_fused_attn_ck_smallseq_thd_gfx942, skipped on non-gfx942 devices. It sets NVTE_FUSED_ATTN_CK_SMALLSEQ=1, runs a THD forward with batch=2, seqlen=8, heads=16, head_dim=128, bf16, padding mask, segment IDs (2 segments/seq), and compares against the existing reference via FusedAttnRunner.test_forward(). No backward coverage and no new C++ tests.

Notes for reviewers.

  • Path is fully opt-in: gated by NVTE_FUSED_ATTN_CK_SMALLSEQ=1 and gfx942 and a matching static config and runtime seqlens ≤ 17. Default behavior is unchanged.
  • The small-seq path performs hipStreamSynchronize after each launch (in launch_{fwd,bwd}_inst), which is heavier than the async CK path — fine for correctness/prototyping but worth noting if this gets enabled in latency-sensitive flows.
  • JAX workspace size uses input_batch * q_max_seqlen as an upper bound for max_tokens_q; the C++ side later allocates from the planner against the actual max_tokens_q, so as long as the upper bound ≥ actual the carve fits.
  • The bwd path threads the padded_q_to_batch carve in the workspace layout but does not consume it (comment in source); kept for layout symmetry with the legacy CK small-seq prefix.
  • New submodule pin lives on a prototype/lse-aux branch — reviewers may want to confirm whether that's intentional or expected to move to a stable branch before merge.
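
The workspace bound in the third note can be checked with a little arithmetic. The sketch below assumes the prefix is exactly the layout quoted in the walkthrough (two uint64 probe slots followed by one int32 per padded Q token) with no extra alignment padding. The helper names are hypothetical and do not appear in the PR.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>

// Bytes for [uint64 max_seqlen_q probe][uint64 max_seqlen_kv probe][int32 padded_q_to_batch].
// Assumes no alignment padding between or after the regions.
std::size_t small_seq_prefix_bytes(std::size_t max_tokens_q) {
  return 2 * sizeof(std::uint64_t) + max_tokens_q * sizeof(std::int32_t);
}

// Size reported at workspace-query time, using input_batch * q_max_seqlen
// as the conservative upper bound for max_tokens_q.
std::size_t small_seq_prefix_bytes_upper_bound(std::size_t input_batch,
                                               std::size_t q_max_seqlen) {
  // Guard against overflow in the product below.
  assert(q_max_seqlen == 0 ||
         input_batch <= std::numeric_limits<std::size_t>::max() / q_max_seqlen);
  return small_seq_prefix_bytes(input_batch * q_max_seqlen);
}

// The runtime carve fits when the actual token count needs no more than was reserved.
bool carve_fits(std::size_t reserved_bytes, std::size_t actual_max_tokens_q) {
  return small_seq_prefix_bytes(actual_max_tokens_q) <= reserved_bytes;
}
```

For the test configuration (batch=2, seqlen=8) this reserves 16 + 16 * 4 = 80 bytes, which covers any actual max_tokens_q up to 16.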

Generated by Claude. To request a code review, comment /claude review.

Comment thread .gitmodules Outdated
[submodule "3rdparty/crossattn_hip_kernel"]
	path = 3rdparty/crossattn_hip_kernel
	url = https://github.com/ROCm/crossattn_hip_kernel.git
	branch = prototype/lse-aux

Pinning to branch = prototype/lse-aux ties this PR's stability to an in-development upstream branch — any force-push, rebase, or experimental change there will affect every CI run and consumer that re-initializes the submodule, and a future merge may silently move semantics under us. Consider either pinning to a stable branch/tag once available, or at minimum documenting the prototype dependency in the PR description and tracking a follow-up to switch when upstream stabilizes.

Comment on lines +359 to +362
if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/crossattn_hip_kernel/include")
  target_include_directories(transformer_engine PRIVATE
    "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/crossattn_hip_kernel/include")
endif()

This if(EXISTS ...) silently no-ops when 3rdparty/crossattn_hip_kernel isn't initialized, so a build with USE_FUSED_ATTN_CK enabled will compile fused_attn_smallseq.cpp without the required attn_fwd_mfma_dispatch.h / attn_bwd_mfma_16x16.h and fail later with a cryptic missing-header error rather than a clear "please init the submodule" message. Suggest either making the include unconditional (so cmake errors out at configure time if the submodule is absent) or pairing the existence check with a message(WARNING|FATAL_ERROR ...) so the failure mode is obvious.

return false;
}
using Config =
FmhaKernelConfig<kMaxBsInst, HEAD_NUM, 17, HEAD_DIM, 256, false, CausalMaskType::DISABLE, 17>;

The 17 literals here (and again at line 201 in launch_bwd_inst) duplicate kSmallSeqMaxSeqlen from fused_attn_smallseq.h. If the cap is ever bumped in the header, these template args silently keep the old tile size and either won't compile or produce wrong results. Reference the constant directly (kSmallSeqMaxSeqlen) or add a static_assert(kSmallSeqMaxSeqlen == 17, ...) next to the template so a mismatch fails loudly.
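
One way to keep the cap and the template arguments in sync is sketched below. KernelConfigSketch is a hypothetical, simplified stand-in for the submodule's FmhaKernelConfig (the real template takes more parameters, and the source does not state what its two 17 arguments mean), and the int type of kSmallSeqMaxSeqlen is an assumption.

```cpp
// Cap from fused_attn_smallseq.h; the int type is an assumption.
constexpr int kSmallSeqMaxSeqlen = 17;

// Hypothetical, reduced stand-in for FmhaKernelConfig. FirstCap and SecondCap
// model the two positions that currently carry the literal 17.
template <int HeadNum, int FirstCap, int HeadDim, int SecondCap>
struct KernelConfigSketch {
  static constexpr int kHeadNum = HeadNum;
  static constexpr int kFirstCap = FirstCap;
  static constexpr int kHeadDim = HeadDim;
  static constexpr int kSecondCap = SecondCap;
};

// Option 1: reference the constant directly so the template follows the header.
template <int HeadNum, int HeadDim>
using SmallSeqKernelConfig =
    KernelConfigSketch<HeadNum, kSmallSeqMaxSeqlen, HeadDim, kSmallSeqMaxSeqlen>;

// Option 2: keep the literals but fail the build if the header constant changes.
static_assert(kSmallSeqMaxSeqlen == 17,
              "small-seq kernels assume a cap of 17; revisit the kernel config before changing it");
```

Option 1 only helps if the kernels stay correct for other cap values. If they are tuned specifically for 17, the static_assert is the safer of the two.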

const float sqr_dk_scale = attn_scale / std::sqrt(static_cast<float>(HEAD_DIM));
Launcher::run_attn_fwd_kernel(Q, K, V, nullptr, 0.0f, sqr_dk_scale, O, softmax_lse, cu_q, cu_qp,
cu_kv, cu_kvp, padded_q_to_batch, total_padded_q);
NVTE_CHECK_CUDA(hipStreamSynchronize(stream));

hipStreamSynchronize inside the launcher turns the small-seq forward into a host-blocking call, defeating the async-on-stream contract that the surrounding CK/aiter dispatch follows (see ck_fused_attn::ck_attn_fwd and the other branches in fused_attn_ck.cpp that just enqueue and return). The same pattern is repeated at line 206 in launch_bwd_inst. If the underlying small-seq kernel actually needs a sync (e.g. it uses pinned-host scratch with non-stream semantics), please add a short comment explaining why; otherwise drop both syncs.

static_cast<const int*>(devPtrCuSeqlenPaddedQ), bs, devPtrPaddedQToBatch);
NVTE_CHECK_CUDA(hipGetLastError());
}
NVTE_CHECK_CUDA(hipStreamSynchronize(hip_stream));

This hipStreamSynchronize(hip_stream) sandwiched between build_padded_q_to_batch_kernel and fused_attn_smallseq_fwd is unnecessary: both launches are submitted to the same stream/hip_stream and are already FIFO-ordered. Removing it lets the build kernel and the small-seq kernel enqueue without a host round-trip, which matters when this path is invoked per training step. (Independently, the small-seq launcher itself also calls hipStreamSynchronize — see the comment on fused_attn_smallseq.cpp:137.)

(void)workspace_next; // Legacy layout: remainder reserved for small-seq scratch (unused here).
constexpr int k_build_padded_threads = 256;
const int bs = static_cast<int>(b);
if(bs > 0) {

When bs == 0 we skip launching build_padded_q_to_batch_kernel, but fused_attn_smallseq_fwd is still invoked with devPtrPaddedQToBatch pointing into an allocated-but-uninitialized region of the workspace. The small-seq kernel may still treat empty batches safely, but right now that's an unstated invariant. Either short-circuit when bs == 0 (skip the whole small-seq call and let the regular CK path handle it / no-op naturally) or comment that the kernel tolerates an uninitialized padded_q_to_batch when b == 0.
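
For reference, the mapping that build_padded_q_to_batch_kernel is described as producing can be modeled on the host, which makes the bs == 0 case concrete. This is a sketch, not the kernel itself: the function name is hypothetical, and it assumes the usual convention that the padded cu_seqlens array holds bs + 1 non-decreasing offsets.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Host-side model of the mapping: for each batch b, write b into every slot of
// its padded token range [cu[b], cu[b + 1]). One outer iteration corresponds
// to one GPU thread in the one-thread-per-batch kernel.
void build_padded_q_to_batch_reference(const std::vector<int>& cu_seqlens_padded_q,
                                       std::vector<int>& padded_q_to_batch) {
  if (cu_seqlens_padded_q.size() < 2) {
    return;  // bs == 0: nothing is written, the output keeps its previous contents
  }
  const std::size_t bs = cu_seqlens_padded_q.size() - 1;
  for (std::size_t b = 0; b < bs; ++b) {
    const int begin = cu_seqlens_padded_q[b];
    const int end = cu_seqlens_padded_q[b + 1];
    assert(begin >= 0 && begin <= end);
    assert(static_cast<std::size_t>(end) <= padded_q_to_batch.size());
    for (int i = begin; i < end; ++i) {
      padded_q_to_batch[static_cast<std::size_t>(i)] = static_cast<int>(b);
    }
  }
}
```

With offsets {0, 3, 5} the result is {0, 0, 0, 1, 1}. With bs == 0 the buffer is left untouched, which is exactly the uninitialized state the comment above asks the caller either to avoid or to document.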

Comment on lines +690 to +738
bool ran_smallseq = false;
if(ck_smallseq_workspace_prefix != nullptr && is_ragged && ck_small_seq_enabled) {
  void* workspace_next = ck_smallseq_workspace_prefix;
  void* max_seqlen_workspace_q = workspace_next;
  void* max_seqlen_workspace_kv =
      static_cast<void*>(static_cast<int8_t*>(workspace_next) + sizeof(uint64_t));
  hipStream_t hip_stream = reinterpret_cast<hipStream_t>(stream);
  const size_t runtime_max_seqlen_q = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
      b, devPtrCuSeqlensQ, devPtrCuSeqlenPaddedQ, max_seqlen_workspace_q, hip_stream));
  const size_t runtime_max_seqlen_kv = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
      b, devPtrCuSeqlensKV, devPtrCuSeqlenPaddedKV, max_seqlen_workspace_kv, hip_stream));
  workspace_next =
      static_cast<void*>(static_cast<int8_t*>(workspace_next) + 2 * sizeof(uint64_t));
  if(nvte_log_ck_config) {
    std::cout << std::endl << "attn_fwd(ck small-seq): ";
    std::cout << "b: " << b << ", ";
    std::cout << "runtime_max_seqlen_q: " << runtime_max_seqlen_q << ", ";
    std::cout << "runtime_max_seqlen_kv: " << runtime_max_seqlen_kv << ", ";
    std::cout << "flow: "
              << (is_runtime_small_seq_eligible(runtime_max_seqlen_q, runtime_max_seqlen_kv)
                      ? "ck-smallseq"
                      : "regular ck/aiter")
              << std::endl;
  }
  if(is_runtime_small_seq_eligible(runtime_max_seqlen_q, runtime_max_seqlen_kv)) {
    const int total_padded_q = static_cast<int>(max_tokens_q);
    int* devPtrPaddedQToBatch = static_cast<int*>(workspace_next);
    workspace_next = static_cast<void*>(static_cast<int8_t*>(workspace_next) +
                                        static_cast<size_t>(total_padded_q) * sizeof(int));
    (void)workspace_next; // Legacy layout: remainder reserved for small-seq scratch (unused here).
    constexpr int k_build_padded_threads = 256;
    const int bs = static_cast<int>(b);
    if(bs > 0) {
      const unsigned grid_x = static_cast<unsigned>(
          (static_cast<int64_t>(bs) + k_build_padded_threads - 1) / k_build_padded_threads);
      dim3 grid(grid_x);
      dim3 block(k_build_padded_threads);
      build_padded_q_to_batch_kernel<<<grid, block, 0, stream>>>(
          static_cast<const int*>(devPtrCuSeqlenPaddedQ), bs, devPtrPaddedQToBatch);
      NVTE_CHECK_CUDA(hipGetLastError());
    }
    NVTE_CHECK_CUDA(hipStreamSynchronize(hip_stream));
    ran_smallseq = fused_attn_smallseq_fwd(
        b, h, d_qk, max_tokens_q, max_tokens_kv, scaling_factor, devPtrQ, devPtrK, devPtrV,
        devPtrO, devPtrSoftmaxLSEWithoutPadding, devPtrCuSeqlensQ, devPtrCuSeqlenPaddedQ,
        devPtrCuSeqlensKV, devPtrCuSeqlenPaddedKV, devPtrPaddedQToBatch,
        static_cast<NVTEDType>(dtype), stream);
  }
}

Two things in this small-seq prelude that show up again almost verbatim in the bwd path (lines ~1119–1153) and could be tightened together:

  1. Redundant guard: ck_smallseq_workspace_prefix != nullptr && is_ragged && ck_small_seq_enabled re-checks is_ragged, which is already an && term of ck_small_seq_enabled at line 524 (and 835 in bwd). ck_smallseq_workspace_prefix != nullptr alone implies both.
  2. Dead pointer arithmetic: lines 717–719 advance workspace_next past padded_q_to_batch and then immediately silence it with (void)workspace_next;. The same pattern appears at lines 1145–1147 in bwd. If the remainder of the prefix is unused on this branch, drop the increment; if it's reserved for future use, a // reserved comment without the dead arithmetic reads more honestly.

Both blocks are otherwise nearly identical (workspace carve, get_runtime_max_seqlen x2, nvte_log_ck_config print, eligibility check). Worth factoring into a small helper (e.g. probe_runtime_max_seqlens(...) returning a struct) so fwd/bwd stay in sync as small-seq evolves.
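
A possible shape for such a helper is sketched below, assuming the prefix layout quoted in the walkthrough. RuntimeSeqlenProbe and the probe callable are hypothetical. In the real code the callable would wrap ck_fused_attn::get_runtime_max_seqlen with the batch size, cu_seqlens pointers, and stream already bound, and this sketch makes no claim about that function's exact signature.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

constexpr std::size_t kSmallSeqMaxSeqlen = 17;

// Hypothetical result bundle shared by the fwd and bwd call sites.
struct RuntimeSeqlenProbe {
  std::size_t max_seqlen_q;
  std::size_t max_seqlen_kv;
  int* padded_q_to_batch;  // start of the int32 region after the two probe slots
  bool eligible;           // both runtime seqlens in (0, kSmallSeqMaxSeqlen]
};

constexpr bool in_small_seq_range(std::size_t n) {
  return n > 0 && n <= kSmallSeqMaxSeqlen;
}

// Carves [uint64 q probe][uint64 kv probe][int32 padded_q_to_batch...] out of
// workspace_prefix and runs the probe once per side.
// probe(is_q, slot) is a hypothetical callable returning the runtime max seqlen.
template <typename ProbeFn>
RuntimeSeqlenProbe probe_runtime_max_seqlens(void* workspace_prefix, ProbeFn&& probe) {
  assert(workspace_prefix != nullptr);
  auto* base = static_cast<std::uint8_t*>(workspace_prefix);
  RuntimeSeqlenProbe result{};
  result.max_seqlen_q = static_cast<std::size_t>(probe(true, static_cast<void*>(base)));
  result.max_seqlen_kv = static_cast<std::size_t>(
      probe(false, static_cast<void*>(base + sizeof(std::uint64_t))));
  result.padded_q_to_batch = reinterpret_cast<int*>(base + 2 * sizeof(std::uint64_t));
  result.eligible =
      in_small_seq_range(result.max_seqlen_q) && in_small_seq_range(result.max_seqlen_kv);
  return result;
}
```

Both fwd and bwd could then branch on result.eligible and log from the same struct, so the carve offsets live in one place.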

@github-actions

Claude review — small-seq CK fused-attn integration

Reviewed the PR head (a5d3e26) against dev (20d2aae). Scope: 11 files, +717/-42 — new HIP small-seq dispatch (fused_attn_smallseq.{h,cpp}), CK fwd/bwd wiring in fused_attn_ck.cpp, JAX workspace-size plumbing, new crossattn_hip_kernel submodule, CI/wheel-build updates, and one JAX test.

Verdict: approach looks sound (env-gated, gfx942-only, falls back to regular CK on miss), but there are a handful of issues worth addressing before merge. I left 7 inline comments covering:

  • Submodule pinned to a prototype/ upstream branch — stability risk for CI/consumers.
  • if(EXISTS ...) include guard in CMake masks missing-submodule failures behind cryptic header errors.
  • Hard-coded 17 tile size in the kernel templates duplicates kSmallSeqMaxSeqlen from the header (no static_assert to keep them in sync).
  • Two hipStreamSynchronize calls (inside the smallseq launchers and between build_padded_q_to_batch_kernel and the smallseq call) defeat the async-on-stream contract used elsewhere in the file.
  • bs == 0 corner case in fwd leaves devPtrPaddedQToBatch uninitialized while still calling fused_attn_smallseq_fwd.
  • Duplicate fwd/bwd workspace-prefix carving in fused_attn_ck.cpp with dead (void)workspace_next; arithmetic — worth factoring into a helper.

I did not duplicate the points already raised in @ipanfilo's review (smallseq dispatch belonging at a higher level than CK-only, the __HIP_PLATFORM_AMD__ guard, sm_arch vs sm_arch_name, the _on_gfx942 helper, and missing BWD coverage in the JAX test).

Copyright headers: OK — all modified files carry AMD year 2026; new files have current-year AMD headers consistent with the existing fused_attn_rocm / ck_fused_attn precedent.

3 participants