Add ROCm HIP small-seq fused attention via crossattn_hip_kernel #625
VeeraRajasekhar wants to merge 3 commits into
    return False


@pytest.mark.skipif(not _on_gfx942(), reason="CK small-seq is implemented for gfx942 (MI300X) only")
not (is_hip_extension() and get_device_compute_capability() == 94)
#include "../util/cuda_runtime.h"
#include "../util/system.h"
#include "fused_attn_ck.h"
#include "fused_attn_smallseq.h"
Shouldn't it be embedded at higher level, not at CK only?
}

bool is_nvte_ck_small_seq_enabled() {
#ifdef __HIP_PLATFORM_AMD__
fused_attn_rocm/ is always built on AMD platform
bool is_nvte_ck_small_seq_enabled() {
#ifdef __HIP_PLATFORM_AMD__
  const std::string& arch = cuda::sm_arch_name();
        seq_desc_format=SeqDescFormat.SegmentIDs,
        num_segments_per_seq=2,
    )
    runner.test_forward()
The implementation is added for FWD and BWD but you test FWD only
Claude Walkthrough
Intent. Adds an opt-in HIP small-seq fast path for ROCm fused attention, targeting cross-attention workloads where runtime seqlen ≤ 17 on gfx942 (MI300X). When eligible, requests are routed to the MFMA kernels in the new
Key changes.
Walkthrough.
Testing. One new test,
Notes for reviewers.
Generated by Claude.
[submodule "3rdparty/crossattn_hip_kernel"]
    path = 3rdparty/crossattn_hip_kernel
    url = https://github.com/ROCm/crossattn_hip_kernel.git
    branch = prototype/lse-aux
Pinning to branch = prototype/lse-aux ties this PR's stability to an in-development upstream branch — any force-push, rebase, or experimental change there will affect every CI run and consumer that re-initializes the submodule, and a future merge may silently move semantics under us. Consider either pinning to a stable branch/tag once available, or at minimum documenting the prototype dependency in the PR description and tracking a follow-up to switch when upstream stabilizes.
if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/crossattn_hip_kernel/include")
  target_include_directories(transformer_engine PRIVATE
    "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/crossattn_hip_kernel/include")
endif()
This if(EXISTS ...) silently no-ops when 3rdparty/crossattn_hip_kernel isn't initialized, so a build with USE_FUSED_ATTN_CK enabled will compile fused_attn_smallseq.cpp without the required attn_fwd_mfma_dispatch.h / attn_bwd_mfma_16x16.h and fail later with a cryptic missing-header error rather than a clear "please init the submodule" message. Suggest either making the include unconditional (so cmake errors out at configure time if the submodule is absent) or pairing the existence check with a message(WARNING|FATAL_ERROR ...) so the failure mode is obvious.
    return false;
  }
  using Config =
      FmhaKernelConfig<kMaxBsInst, HEAD_NUM, 17, HEAD_DIM, 256, false, CausalMaskType::DISABLE, 17>;
The 17 literals here (and again at line 201 in launch_bwd_inst) duplicate kSmallSeqMaxSeqlen from fused_attn_smallseq.h. If the cap is ever bumped in the header, these template args silently keep the old tile size and either won't compile or produce wrong results. Reference the constant directly (kSmallSeqMaxSeqlen) or add a static_assert(kSmallSeqMaxSeqlen == 17, ...) next to the template so a mismatch fails loudly.
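A minimal sketch of both options, assuming kSmallSeqMaxSeqlen is a plain integer constant equal to 17. FmhaKernelConfigSketch is a hypothetical stand-in that only mirrors where the two seqlen arguments sit; it is not the real FmhaKernelConfig from crossattn_hip_kernel, and fits_small_seq is an illustrative helper name.

```cpp
#include <cstddef>

// Assumed to mirror the constant declared in fused_attn_smallseq.h.
constexpr int kSmallSeqMaxSeqlen = 17;

// Hypothetical stand-in for the kernel config template (the real one takes more parameters).
template <int MaxSeqQ, int HeadDim, int MaxSeqKV>
struct FmhaKernelConfigSketch {
  static constexpr int kMaxSeqQ = MaxSeqQ;
  static constexpr int kHeadDim = HeadDim;
  static constexpr int kMaxSeqKV = MaxSeqKV;
};

// Option 1: reference the constant instead of repeating the literal 17.
template <int HeadDim>
using SmallSeqConfig =
    FmhaKernelConfigSketch<kSmallSeqMaxSeqlen, HeadDim, kSmallSeqMaxSeqlen>;

// Option 2: if the kernels only exist for a 17-wide tile, fail at compile time
// as soon as the header constant changes.
static_assert(kSmallSeqMaxSeqlen == 17,
              "small-seq kernels are instantiated for seqlen <= 17 only; "
              "update the template arguments together with the cap");

// Runtime eligibility derived from the same constant, so the check and the
// tile size cannot drift apart.
inline bool fits_small_seq(std::size_t max_seqlen_q, std::size_t max_seqlen_kv) {
  constexpr auto cap = static_cast<std::size_t>(kSmallSeqMaxSeqlen);
  return max_seqlen_q <= cap && max_seqlen_kv <= cap;
}
```

With this shape, bumping the cap in the header either recompiles the tile with the new size or stops the build at the static_assert, rather than leaving a stale 17 in the template arguments.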
const float sqr_dk_scale = attn_scale / std::sqrt(static_cast<float>(HEAD_DIM));
Launcher::run_attn_fwd_kernel(Q, K, V, nullptr, 0.0f, sqr_dk_scale, O, softmax_lse, cu_q, cu_qp,
                              cu_kv, cu_kvp, padded_q_to_batch, total_padded_q);
NVTE_CHECK_CUDA(hipStreamSynchronize(stream));
hipStreamSynchronize inside the launcher turns the small-seq forward into a host-blocking call, defeating the async-on-stream contract that the surrounding CK/aiter dispatch follows (see ck_fused_attn::ck_attn_fwd and the other branches in fused_attn_ck.cpp that just enqueue and return). The same pattern is repeated at line 206 in launch_bwd_inst. If the underlying CK kernel actually needs a sync (e.g. it uses pinned-host scratch with non-stream semantics), please add a short comment explaining why; otherwise drop both syncs.
      static_cast<const int*>(devPtrCuSeqlenPaddedQ), bs, devPtrPaddedQToBatch);
  NVTE_CHECK_CUDA(hipGetLastError());
}
NVTE_CHECK_CUDA(hipStreamSynchronize(hip_stream));
This hipStreamSynchronize(hip_stream) sandwiched between build_padded_q_to_batch_kernel and fused_attn_smallseq_fwd is unnecessary: both launches are submitted to the same stream/hip_stream and are already FIFO-ordered. Removing it lets the build kernel and the small-seq kernel enqueue without a host round-trip, which matters when this path is invoked per training step. (Independently, the small-seq launcher itself also calls hipStreamSynchronize — see the comment on fused_attn_smallseq.cpp:137.)
(void)workspace_next;  // Legacy layout: remainder reserved for small-seq scratch (unused here).
constexpr int k_build_padded_threads = 256;
const int bs = static_cast<int>(b);
if(bs > 0) {
When bs == 0 we skip launching build_padded_q_to_batch_kernel, but fused_attn_smallseq_fwd is still invoked with devPtrPaddedQToBatch pointing into an allocated-but-uninitialized region of the workspace. The small-seq kernel may still treat empty batches safely, but right now that's an unstated invariant. Either short-circuit when bs == 0 (skip the whole small-seq call and let the regular CK path handle it / no-op naturally) or comment that the kernel tolerates an uninitialized padded_q_to_batch when b == 0.
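The short-circuit option could look roughly like the sketch below. PaddedMapLaunchPlan and plan_padded_map_launch are hypothetical names introduced only for illustration; the 256-thread block size and the ceil-division grid formula are taken from the hunk above. The idea is that one decision governs both the map-building launch and the small-seq call.

```cpp
#include <cstdint>

// Hypothetical launch plan for build_padded_q_to_batch_kernel (illustrative only).
struct PaddedMapLaunchPlan {
  bool launch;       // false: skip the map kernel AND the small-seq call
  unsigned grid_x;   // number of blocks for the map-building kernel
  unsigned block_x;  // threads per block
};

inline PaddedMapLaunchPlan plan_padded_map_launch(std::int64_t batch,
                                                  unsigned threads_per_block = 256) {
  if (batch <= 0) {
    // Empty batch: nothing to map, so the caller should not hand an
    // uninitialized padded_q_to_batch buffer to the small-seq kernel.
    return {false, 0u, threads_per_block};
  }
  // Same ceil-division as the existing grid computation.
  const auto grid_x =
      static_cast<unsigned>((batch + threads_per_block - 1) / threads_per_block);
  return {true, grid_x, threads_per_block};
}
```

A caller would check plan.launch once and fall through to the regular CK path when it is false, which turns the unstated invariant into an explicit branch.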
bool ran_smallseq = false;
if(ck_smallseq_workspace_prefix != nullptr && is_ragged && ck_small_seq_enabled) {
  void* workspace_next = ck_smallseq_workspace_prefix;
  void* max_seqlen_workspace_q = workspace_next;
  void* max_seqlen_workspace_kv =
      static_cast<void*>(static_cast<int8_t*>(workspace_next) + sizeof(uint64_t));
  hipStream_t hip_stream = reinterpret_cast<hipStream_t>(stream);
  const size_t runtime_max_seqlen_q = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
      b, devPtrCuSeqlensQ, devPtrCuSeqlenPaddedQ, max_seqlen_workspace_q, hip_stream));
  const size_t runtime_max_seqlen_kv = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
      b, devPtrCuSeqlensKV, devPtrCuSeqlenPaddedKV, max_seqlen_workspace_kv, hip_stream));
  workspace_next =
      static_cast<void*>(static_cast<int8_t*>(workspace_next) + 2 * sizeof(uint64_t));
  if(nvte_log_ck_config) {
    std::cout << std::endl << "attn_fwd(ck small-seq): ";
    std::cout << "b: " << b << ", ";
    std::cout << "runtime_max_seqlen_q: " << runtime_max_seqlen_q << ", ";
    std::cout << "runtime_max_seqlen_kv: " << runtime_max_seqlen_kv << ", ";
    std::cout << "flow: "
              << (is_runtime_small_seq_eligible(runtime_max_seqlen_q, runtime_max_seqlen_kv)
                      ? "ck-smallseq"
                      : "regular ck/aiter")
              << std::endl;
  }
  if(is_runtime_small_seq_eligible(runtime_max_seqlen_q, runtime_max_seqlen_kv)) {
    const int total_padded_q = static_cast<int>(max_tokens_q);
    int* devPtrPaddedQToBatch = static_cast<int*>(workspace_next);
    workspace_next = static_cast<void*>(static_cast<int8_t*>(workspace_next) +
                                        static_cast<size_t>(total_padded_q) * sizeof(int));
    (void)workspace_next;  // Legacy layout: remainder reserved for small-seq scratch (unused here).
    constexpr int k_build_padded_threads = 256;
    const int bs = static_cast<int>(b);
    if(bs > 0) {
      const unsigned grid_x = static_cast<unsigned>(
          (static_cast<int64_t>(bs) + k_build_padded_threads - 1) / k_build_padded_threads);
      dim3 grid(grid_x);
      dim3 block(k_build_padded_threads);
      build_padded_q_to_batch_kernel<<<grid, block, 0, stream>>>(
          static_cast<const int*>(devPtrCuSeqlenPaddedQ), bs, devPtrPaddedQToBatch);
      NVTE_CHECK_CUDA(hipGetLastError());
    }
    NVTE_CHECK_CUDA(hipStreamSynchronize(hip_stream));
    ran_smallseq = fused_attn_smallseq_fwd(
        b, h, d_qk, max_tokens_q, max_tokens_kv, scaling_factor, devPtrQ, devPtrK, devPtrV,
        devPtrO, devPtrSoftmaxLSEWithoutPadding, devPtrCuSeqlensQ, devPtrCuSeqlenPaddedQ,
        devPtrCuSeqlensKV, devPtrCuSeqlenPaddedKV, devPtrPaddedQToBatch,
        static_cast<NVTEDType>(dtype), stream);
  }
}
Two things in this small-seq prelude that show up again almost verbatim in the bwd path (lines ~1119–1153) and could be tightened together:
- Redundant guard: ck_smallseq_workspace_prefix != nullptr && is_ragged && ck_small_seq_enabled re-checks is_ragged, which is already an && term of ck_small_seq_enabled at line 524 (and 835 in bwd). ck_smallseq_workspace_prefix != nullptr alone implies both.
- Dead pointer arithmetic: lines 717–719 advance workspace_next past padded_q_to_batch and then immediately silence it with (void)workspace_next;. The same pattern appears at lines 1145–1147 in bwd. If the remainder of the prefix is unused on this branch, drop the increment; if it's reserved for future use, a // reserved comment without the dead arithmetic reads more honestly.
Both blocks are otherwise nearly identical (workspace carve, get_runtime_max_seqlen x2, nvte_log_ck_config print, eligibility check). Worth factoring into a small helper (e.g. probe_runtime_max_seqlens(...) returning a struct) so fwd/bwd stay in sync as small-seq evolves.
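A rough shape for such a helper, as a sketch only. RuntimeMaxSeqlens and MaxSeqlenFn are hypothetical, and the callback stands in for ck_fused_attn::get_runtime_max_seqlen, whose real signature also takes the batch size and the stream. The workspace carve (two uint64_t scratch slots at the start of the prefix) follows the hunk above.

```cpp
#include <cstddef>
#include <cstdint>
#include <functional>

// Hypothetical result of the shared fwd/bwd prelude.
struct RuntimeMaxSeqlens {
  std::size_t q;
  std::size_t kv;
  void* workspace_next;  // first byte after the two uint64_t scratch slots
};

// Stand-in for ck_fused_attn::get_runtime_max_seqlen (assumed, simplified signature).
using MaxSeqlenFn = std::function<std::uint64_t(const void* cu_seqlens,
                                                const void* cu_seqlens_padded,
                                                void* scratch)>;

inline RuntimeMaxSeqlens probe_runtime_max_seqlens(void* workspace_prefix,
                                                   const void* cu_q,
                                                   const void* cu_q_padded,
                                                   const void* cu_kv,
                                                   const void* cu_kv_padded,
                                                   const MaxSeqlenFn& get_max) {
  auto* base = static_cast<std::uint8_t*>(workspace_prefix);
  void* scratch_q = base;                           // slot 0: max seqlen for Q
  void* scratch_kv = base + sizeof(std::uint64_t);  // slot 1: max seqlen for KV

  RuntimeMaxSeqlens out{};
  out.q = static_cast<std::size_t>(get_max(cu_q, cu_q_padded, scratch_q));
  out.kv = static_cast<std::size_t>(get_max(cu_kv, cu_kv_padded, scratch_kv));
  out.workspace_next = base + 2 * sizeof(std::uint64_t);
  return out;
}
```

Both fwd and bwd could then call the helper once and feed out.q and out.kv into the logging and the eligibility check, so the carve offsets live in a single place.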
Claude review: small-seq CK fused-attn integration
Reviewed the PR head (
Verdict: approach looks sound (env-gated, gfx942-only, falls back to regular CK on miss), but there are a handful of issues worth addressing before merge. I left 7 inline comments covering:
I did not duplicate the points already raised in @ipanfilo's review (smallseq dispatch belonging at a higher level than CK-only, the
Copyright headers: OK, all modified files carry AMD year
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: