From 449811ee9fc38fa54a49816d809cc6896f4295e4 Mon Sep 17 00:00:00 2001 From: dcampora <961215+dcampora@users.noreply.github.com> Date: Thu, 16 Apr 2026 03:29:42 +0000 Subject: [PATCH 1/6] Add AIR TopK support to TE JAX extension Adds a custom AIR TopK implementation (header-only, vendored into transformer_engine/common/util/) exposed as a JAX FFI custom call via the TE JAX extension. Key changes: - transformer_engine/common/util/air_topk.cu: AIR TopK CUDA kernel - transformer_engine/common/util/standalone_air_topk.cuh: vendored header - transformer_engine/common/include/transformer_engine/air_topk.h: C API - transformer_engine/jax/csrc/extensions/air_topk.cpp: JAX FFI binding - transformer_engine/jax/cpp_extensions/air_topk.py: Python wrapper - CMakeLists.txt: compile new kernel; use CCCL from CUDA toolkit - CMakeLists.txt: fix SM100 arch handling when all arches are special-cased Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: dcampora <961215+dcampora@users.noreply.github.com> --- tests/jax/test_custom_call_compute.py | 77 + transformer_engine/common/CMakeLists.txt | 6 + .../include/transformer_engine/air_topk.h | 57 + transformer_engine/common/util/air_topk.cu | 57 + .../common/util/standalone_air_topk.cuh | 1274 +++++++++++++++++ .../jax/cpp_extensions/air_topk.py | 136 ++ transformer_engine/jax/csrc/extensions.h | 4 + .../jax/csrc/extensions/air_topk.cpp | 92 ++ .../jax/csrc/extensions/pybind.cpp | 4 + 9 files changed, 1707 insertions(+) create mode 100644 transformer_engine/common/include/transformer_engine/air_topk.h create mode 100644 transformer_engine/common/util/air_topk.cu create mode 100644 transformer_engine/common/util/standalone_air_topk.cuh create mode 100644 transformer_engine/jax/cpp_extensions/air_topk.py create mode 100644 transformer_engine/jax/csrc/extensions/air_topk.cpp diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 613aefc178..0d35058a1a 100644 --- a/tests/jax/test_custom_call_compute.py 
+++ b/tests/jax/test_custom_call_compute.py @@ -47,6 +47,7 @@ from transformer_engine.jax.activation import activation from transformer_engine.jax.dense import dense, grouped_dense from transformer_engine.jax.layernorm_dense import layernorm_dense +from transformer_engine.jax.cpp_extensions.air_topk import air_topk GEMM_CASES = [ (256, 256, 512), @@ -1955,3 +1956,79 @@ def f(x): actual = load_array_dump("my_tensor_gpu0.bin", shape, dtype) assert_allclose(actual, expected, dtype=dtype) + + +@pytest.mark.parametrize("dtype", [jnp.bfloat16, jnp.float32]) +@pytest.mark.parametrize( + "problem_size", + [ + (1, 10000, 100), + (1, 50000, 200), + (4, 16384, 256), + (8, 65536, 512), + (1, 1000000, 1000), + ], +) +class TestAirTopK: + """Correctness tests for the AIR TopK JAX primitive. + + Each test generates an input whose top-k entries lie in a known value range + so that correctness can be verified without a full sort, then cross-checks + against jax.lax.top_k as a reference. + """ + + def test_air_topk_1d(self, dtype, problem_size): + """1-D input: single row.""" + _bs, n, k = problem_size + + prng_key = jax.random.PRNGKey(0) + keys = jax.random.split(prng_key, 3) + topk_vals = jax.random.uniform(keys[0], shape=(k,), dtype=dtype, minval=1.5, maxval=2.5) + bottom_vals = jax.random.uniform(keys[1], shape=(n - k,), dtype=dtype, minval=0.0, maxval=1.0) + x = jax.random.permutation(keys[2], jnp.concatenate([topk_vals, bottom_vals])) + + ref_vals, ref_idx = jax.jit(jax.lax.top_k, static_argnums=(1,))(x, k) + prim_vals, prim_idx = jax.jit(air_topk, static_argnums=(1,))(x, k) + + # AIR TopK output is unordered; sort before comparing. + ref_vals, ref_idx = jax.lax.sort_key_val(ref_vals, ref_idx) + prim_vals, prim_idx = jax.lax.sort_key_val(prim_vals, prim_idx) + + assert_allclose(prim_vals, ref_vals, dtype=dtype) + + sorted_x = jax.lax.sort(x) + assert prim_vals[0] >= sorted_x[-(k + 1)] + + # Values at returned indices must match reference. 
+ assert_allclose(x[prim_idx], x[ref_idx], dtype=dtype) + + def test_air_topk_2d(self, dtype, problem_size): + """2-D input: each row is an independent top-k problem.""" + bs, n, k = problem_size + + prng_key = jax.random.PRNGKey(42) + keys = jax.random.split(prng_key, 3) + topk_vals = jax.random.uniform(keys[0], shape=(bs, k), dtype=dtype, minval=1.5, maxval=2.5) + bottom_vals = jax.random.uniform(keys[1], shape=(bs, n - k), dtype=dtype, minval=0.0, maxval=1.0) + x_unsorted = jnp.concatenate([topk_vals, bottom_vals], axis=1) + # Apply one shared random column permutation to every row. + col_perm = jax.random.permutation(keys[2], n) + x = x_unsorted[:, col_perm] + + ref_vals, ref_idx = jax.jit(jax.lax.top_k, static_argnums=(1,))(x, k) + prim_vals, prim_idx = jax.jit(air_topk, static_argnums=(1,))(x, k) + + # Sort each row independently for comparison. + ref_vals, ref_idx = jax.vmap(jax.lax.sort_key_val)(ref_vals, ref_idx) + prim_vals, prim_idx = jax.vmap(jax.lax.sort_key_val)(prim_vals, prim_idx) + + assert_allclose(prim_vals, ref_vals, dtype=dtype) + + # For each row, the smallest selected value must be >= the (k+1)-th largest in that row. + sorted_x = jnp.sort(x, axis=1) + assert jnp.all(prim_vals[:, 0] >= sorted_x[:, -(k + 1)]) + + # Values at returned indices must match reference values. + prim_gathered = jnp.take_along_axis(x, prim_idx, axis=1) + ref_gathered = jnp.take_along_axis(x, ref_idx, axis=1) + assert_allclose(prim_gathered, ref_gathered, dtype=dtype) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index b9e2b907e0..761c0eae8a 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -79,6 +79,11 @@ if(NOT arch_120_index EQUAL -1) endif() endif() +# If all architectures were special-cased and removed, disable CMake's automatic +# CUDA_ARCHITECTURES management — compilation flags are set via COMPILE_OPTIONS below. 
+if(NOT CMAKE_CUDA_ARCHITECTURES) + set(CMAKE_CUDA_ARCHITECTURES OFF) +endif() # cuDNN frontend API set(CUDNN_FRONTEND_INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/include") @@ -151,6 +156,7 @@ list(APPEND transformer_engine_cuda_sources normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu permutation/permutation.cu util/padding.cu + util/air_topk.cu swizzle/swizzle.cu swizzle/swizzle_block_scaling.cu fused_softmax/scaled_masked_softmax.cu diff --git a/transformer_engine/common/include/transformer_engine/air_topk.h b/transformer_engine/common/include/transformer_engine/air_topk.h new file mode 100644 index 0000000000..8ff48f1854 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/air_topk.h @@ -0,0 +1,57 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_AIR_TOPK_H_ +#define TRANSFORMER_ENGINE_AIR_TOPK_H_ + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/*! \brief Compute the top-K (key, index) pairs using the AIR radix algorithm. + * + * Operates on a batch of rows: each row of length \p seq_len is processed + * independently and the \p k largest entries are selected. + * + * \param[in] stream CUDA stream used for the operation. + * \param[in] keys_in Input keys tensor, flat storage for + * batch_size rows of seq_len elements. + * \param[in] lengths_in Per-row lengths, shape (batch_size,); int32. + * Fill with seq_len for uniform-length batches. + * \param[in,out] keys_out Output top-k keys, flat storage for + * batch_size rows of k elements. + * \param[in,out] indices_out Output top-k indices (within each row), + * flat storage for batch_size rows of k int32 elements. 
+ * \param[in,out] workspace Workspace tensor, shape (workspace_bytes,). + * \param[in] batch_size Number of rows. + * \param[in] seq_len Number of elements per row. + * \param[in] k Number of top-K entries to select per row. + * \param[in] workspace_bytes Workspace size in bytes; must be >= + * nvte_get_air_topk_workspace_bytes(batch_size, seq_len, k). + * + * Supported key dtypes: float32, bfloat16. + * Index dtype: int32. + */ +void nvte_air_topk(cudaStream_t stream, const NVTETensor keys_in, const NVTETensor lengths_in, + NVTETensor keys_out, NVTETensor indices_out, NVTETensor workspace, + int batch_size, int seq_len, int k, size_t workspace_bytes); + +/*! \brief Query the workspace size required by nvte_air_topk. + * + * \param[in] batch_size Number of rows. + * \param[in] seq_len Number of elements per row. + * \param[in] k Top-K count. + * \return Required workspace size in bytes. + */ +size_t nvte_get_air_topk_workspace_bytes(int batch_size, int seq_len, int k); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TRANSFORMER_ENGINE_AIR_TOPK_H_ diff --git a/transformer_engine/common/util/air_topk.cu b/transformer_engine/common/util/air_topk.cu new file mode 100644 index 0000000000..aaae7af795 --- /dev/null +++ b/transformer_engine/common/util/air_topk.cu @@ -0,0 +1,57 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include + +#include "../common.h" +#include "standalone_air_topk.cuh" + +void nvte_air_topk(cudaStream_t stream, const NVTETensor keys_in, const NVTETensor lengths_in, + NVTETensor keys_out, NVTETensor indices_out, NVTETensor workspace, + int batch_size, int seq_len, int k, size_t workspace_bytes) { + NVTE_API_CALL(nvte_air_topk); + using namespace transformer_engine; + + const Tensor *keys_in_tensor = convertNVTETensorCheck(keys_in); + const Tensor *lengths_tensor = convertNVTETensorCheck(lengths_in); + Tensor *keys_out_tensor = convertNVTETensor(keys_out); + Tensor *indices_tensor = convertNVTETensor(indices_out); + Tensor *workspace_tensor = convertNVTETensor(workspace); + + void *d_workspace = workspace_tensor->data.dptr; + const int *d_lengths = reinterpret_cast(lengths_tensor->data.dptr); + int *d_indices = reinterpret_cast(indices_tensor->data.dptr); + + auto dtype = keys_in_tensor->data.dtype; + +#define DISPATCH_AIR_TOPK(T, d_in_cast, d_out_cast) \ + do { \ + const T *d_in = reinterpret_cast(keys_in_tensor->data.dptr); \ + T *d_out = reinterpret_cast(keys_out_tensor->data.dptr); \ + nv::standalone_air_topk(d_workspace, workspace_bytes, d_in, batch_size, seq_len, k, \ + d_out, d_indices, /*greater=*/true, stream, \ + const_cast(d_lengths), /*is_prefill=*/false); \ + } while (0) + + if (dtype == DType::kBFloat16) { + DISPATCH_AIR_TOPK(__nv_bfloat16, , ); + } else if (dtype == DType::kFloat32) { + DISPATCH_AIR_TOPK(float, , ); + } else { + NVTE_ERROR("nvte_air_topk: unsupported key dtype (supported: float32, bfloat16)"); + } + +#undef DISPATCH_AIR_TOPK +} + +size_t nvte_get_air_topk_workspace_bytes(int batch_size, int seq_len, int k) { + // Call with buf=nullptr to perform a size query (no GPU work is launched). 
+ size_t buf_size = 0; + nv::standalone_air_topk(nullptr, buf_size, nullptr, batch_size, seq_len, k, nullptr, + nullptr, /*greater=*/true, /*stream=*/nullptr, + /*lengths=*/nullptr, /*is_prefill=*/false); + return buf_size; +} diff --git a/transformer_engine/common/util/standalone_air_topk.cuh b/transformer_engine/common/util/standalone_air_topk.cuh new file mode 100644 index 0000000000..0a78342273 --- /dev/null +++ b/transformer_engine/common/util/standalone_air_topk.cuh @@ -0,0 +1,1274 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#pragma once + +constexpr int VECTORIZED_READ_SIZE = 16; +constexpr int WARP_SIZE = 32; +constexpr int WARP_BITS = 5; +constexpr unsigned FULL_WARP_MASK = 0xffffffff; + +#include +#include +#include + +#include +#include +#include +namespace cg = cooperative_groups; + +// Workspace pointer-alignment helpers. +inline size_t calc_aligned_size(const std::vector &sizes) { + const size_t ALIGN_BYTES = 256; + const size_t ALIGN_MASK = ~(ALIGN_BYTES - 1); + size_t total = 0; + for (auto sz : sizes) total += (sz + ALIGN_BYTES - 1) & ALIGN_MASK; + return total + ALIGN_BYTES - 1; +} +inline std::vector calc_aligned_pointers(const void *p, const std::vector &sizes) { + const size_t ALIGN_BYTES = 256; + const size_t ALIGN_MASK = ~(ALIGN_BYTES - 1); + char *ptr = + reinterpret_cast((reinterpret_cast(p) + ALIGN_BYTES - 1) & ALIGN_MASK); + std::vector ptrs; + ptrs.reserve(sizes.size()); + for (auto sz : sizes) { + ptrs.push_back(ptr); + ptr += (sz + ALIGN_BYTES - 1) & ALIGN_MASK; + } + return ptrs; +} + +// Helper: convert a float literal to type T without relying on implicit +// conversions (needed when __CUDA_NO_BFLOAT16_CONVERSIONS__ is defined). 
+namespace nv_detail { +template +__host__ __device__ inline T float_to_T(float v) { + return static_cast(v); +} +#if defined(__CUDACC__) +template <> +__host__ __device__ inline __nv_bfloat16 float_to_T<__nv_bfloat16>(float v) { + return __float2bfloat16(v); +} +#endif +} // namespace nv_detail + +namespace nv { + +namespace air_topk { +using WideT = float4; + +#ifdef __CUDA_ARCH__ +using ::atomicAdd; +inline __device__ size_t atomicAdd(size_t *address, size_t value) { + static_assert(sizeof(size_t) == sizeof(unsigned long long int)); + return atomicAdd((unsigned long long int *)address, (unsigned long long int)value); +} +#endif + +template +__host__ __device__ constexpr int calc_num_buckets() { + return 1 << BitsPerPass; +} + +/** + * @brief Provide a ceiling division operation ie. ceil(a / b) + * @tparam IntType supposed to be only integers for now! + */ +template +constexpr __host__ __device__ IntType ceildiv(IntType a, IntType b) { + return (a + b - 1) / b; +} + +/** + * @brief Provide an alignment function ie. ceil(a / b) * b + * @tparam IntType supposed to be only integers for now! + */ +template +constexpr __host__ __device__ IntType alignTo(IntType a, IntType b) { + return ceildiv(a, b) * b; +} + +template +__host__ __device__ constexpr int calc_num_passes() { + return ceildiv(sizeof(T) * 8, BitsPerPass); +} + +__host__ __device__ int round(int num, int round_value) { + return ((num - 1) / round_value + 1) * round_value; +} + +/** + * Bit 0 is the least significant (rightmost); + * this implementation processes input from the most to the least significant + * bit. This way, we can skip some passes in the end at the cost of having an + * unsorted output. + * + * NB: Use pass=-1 for calc_mask(). 
+ */ +template +__device__ constexpr int calc_start_bit(int pass) { + int start_bit = static_cast(sizeof(T) * 8) - (pass + 1) * BitsPerPass; + if (start_bit < 0) { + start_bit = 0; + } + return start_bit; +} + +template +__device__ constexpr unsigned calc_mask(int pass) { + static_assert(BitsPerPass <= 31); + int num_bits = calc_start_bit(pass - 1) - calc_start_bit(pass); + return (1 << num_bits) - 1; +} + +/** + * Use CUB to twiddle bits - so that we can correctly compare bits of + * floating-point values as well as of integers. + */ +template +__device__ typename cub::Traits::UnsignedBits twiddle_in(T key, bool select_min) { + auto bits = reinterpret_cast::UnsignedBits &>(key); + bits = cub::Traits::TwiddleIn(bits); + if (!select_min) { + bits = ~bits; + } + return bits; +} + +template +__device__ T twiddle_out(typename cub::Traits::UnsignedBits bits, bool select_min) { + if (!select_min) { + bits = ~bits; + } + bits = cub::Traits::TwiddleOut(bits); + return reinterpret_cast(bits); +} + +template +__device__ int calc_bucket(T x, int start_bit, unsigned mask, bool select_min) { + static_assert(BitsPerPass <= sizeof(int) * 8 - 1, + "BitsPerPass is too large that the result type could not be int"); + return (twiddle_in(x, select_min) >> start_bit) & mask; +} + +template +__host__ __device__ IdxT calc_buf_len(IdxT len) { + // When writing is skipped, only read `in`(type T). + // When writing is not skipped, read `in_buf`(T) and `in_idx_buf`(IdxT), and + // write `out_buf`(T) and `out_idx_buf`(IdxT). The ratio between these cases + // determines whether to skip writing and hence the buffer size. constexpr + // float ratio = 2 + sizeof(IdxT) * 2.0 / sizeof(T); + constexpr float ratio = 128; + return len / ratio; + // return len; +} + +/** + * Map a Func over the input data, using vectorized load instructions if + * possible. 
+ * + * NB: in future, we should move this to + * cpp/include/raft/linalg/detail/unary_op.cuh, which currently does not support + * the second lambda argument (index of an element) + * + * @tparam T element type + * @tparam IdxT indexing type + * @tparam Func void (T x, IdxT idx) + * + * @param thread_rank rank of the calling thread among all participating threads + * @param num_threads number of the threads that participate in processing + * @param in the input data + * @param len the number of elements to read + * @param f the lambda taking two arguments (T x, IdxT idx) + */ +template +__device__ void vectorized_process(size_t thread_rank, size_t num_threads, const T *in, idxT len, + Func f) { + if constexpr (sizeof(T) >= sizeof(WideT)) { + for (idxT i = thread_rank; i < len; i += num_threads) { + f(in[i], i); + } + } else { + static_assert(sizeof(WideT) % sizeof(T) == 0); + constexpr int items_per_scalar = sizeof(WideT) / sizeof(T); + // TODO: it's UB + union { + WideT scalar; + T array[items_per_scalar]; + } wide; + + int skip_cnt = + (reinterpret_cast(in) % sizeof(WideT)) + ? 
((sizeof(WideT) - reinterpret_cast(in) % sizeof(WideT)) / sizeof(T)) + : 0; + if (skip_cnt > len) { + skip_cnt = len; + } + const WideT *in_cast = reinterpret_cast(in + skip_cnt); + const idxT len_cast = (len - skip_cnt) / items_per_scalar; + + for (idxT i = thread_rank; i < len_cast; i += num_threads) { + wide.scalar = in_cast[i]; + const idxT real_i = skip_cnt + i * items_per_scalar; +#pragma unroll + for (int j = 0; j < items_per_scalar; ++j) { + f(wide.array[j], real_i + j); + } + } + + static_assert(WARP_SIZE >= items_per_scalar); + // and because items_per_scalar > skip_cnt, WARP_SIZE > skip_cnt + // no need to use loop + if (thread_rank < skip_cnt) { + f(in[thread_rank], thread_rank); + } + // because len_cast = (len - skip_cnt) / items_per_scalar, + // len_cast * items_per_scalar + items_per_scalar > len - skip_cnt; + // and so + // len - (skip_cnt + len_cast * items_per_scalar) < items_per_scalar <= + // WARP_SIZE no need to use loop + const idxT remain_i = skip_cnt + len_cast * items_per_scalar + thread_rank; + if (remain_i < len) { + f(in[remain_i], remain_i); + } + } +} + +// sync_width should >= WARP_SIZE +template +__device__ void vectorized_process(const T *in, idxT len, Func f, int sync_width) { + const idxT stride = blockDim.x * gridDim.x; + const idxT tid = blockIdx.x * blockDim.x + threadIdx.x; + if constexpr (sizeof(T) >= sizeof(WideT)) { + for (idxT i = tid; i < len; i += stride) { + f(in[i], i, true); + } + } else { + static_assert(sizeof(WideT) % sizeof(T) == 0); + constexpr int items_per_scalar = sizeof(WideT) / sizeof(T); + union { + WideT scalar; + T array[items_per_scalar]; + } wide; + + int skip_cnt = + (reinterpret_cast(in) % sizeof(WideT)) + ? 
((sizeof(WideT) - reinterpret_cast(in) % sizeof(WideT)) / sizeof(T)) + : 0; + if (skip_cnt > len) { + skip_cnt = len; + } + const WideT *in_cast = reinterpret_cast(in + skip_cnt); + const idxT len_cast = (len - skip_cnt) / items_per_scalar; + + const idxT len_cast_for_sync = ((len_cast - 1) / sync_width + 1) * sync_width; + for (idxT i = tid; i < len_cast_for_sync; i += stride) { + bool valid = i < len_cast; + if (valid) { + wide.scalar = in_cast[i]; + } + const idxT real_i = skip_cnt + i * items_per_scalar; +#pragma unroll + for (int j = 0; j < items_per_scalar; ++j) { + f(wide.array[j], real_i + j, valid); + } + } + + static_assert(WARP_SIZE >= items_per_scalar); + // need at most one warp for skipped and remained elements, + // and sync_width >= WARP_SIZE + if (tid < sync_width) { + bool valid = tid < skip_cnt; + T value = valid ? in[tid] : T(); + f(value, tid, valid); + + const idxT remain_i = skip_cnt + len_cast * items_per_scalar + tid; + valid = remain_i < len; + value = valid ? in[remain_i] : T(); + f(value, remain_i, valid); + } + } +} + +template +struct alignas(128) Counter { + // We are processing the values in multiple passes, from most significant to + // least significant. In each pass, we keep the length of input (`len`) and + // the `k` of current pass, and update them at the end of the pass. + IdxT k; + IdxT len; + + // `previous_len` is the length of input in previous pass. Note that + // `previous_len` rather than `len` is used for the filtering step because + // filtering is indeed for previous pass (see comments before + // `radix_kernel`). + IdxT previous_len; + + // We determine the bits of the k_th value inside the mask processed by the + // pass. The already known bits are stored in `kth_value_bits`. It's used to + // decide whether an element is a result (written to `out`), a candidate for next + // pass (written to `out_buf`), or not useful (discarded). The bits that are + // not yet processed do not matter for this purpose. 
+ typename cub::Traits::UnsignedBits kth_value_bits; + + // Record how many elements have passed filtering. It's used to determine the + // position in the `out_buf` where an element should be written. + alignas(128) IdxT filter_cnt; + + // For a row inside a batch, we may launch multiple thread blocks. This + // counter is used to determine if the current block is the last running + // block. If so, this block will execute scan() and choose_bucket(). + alignas(128) unsigned int finished_block_cnt; + + // Record how many elements have been written to the front of `out`. Elements + // less (if select_min==true) than the k-th value are written from front to + // back. + alignas(128) IdxT out_cnt; + + // Record how many elements have been written to the back of `out`. Elements + // equal to the k-th value are written from back to front. We need to keep + // count of them separately because the number of elements that <= the k-th + // value might exceed k. + alignas(128) IdxT out_back_cnt; +}; + +/** + * Fused filtering of the current pass and building histogram for the next pass + * (see steps 4 & 1 in `radix_kernel` description). + */ +template +__device__ void filter_and_histogram(const T *in_buf, const IdxT *in_idx_buf, T *out_buf, + IdxT *out_idx_buf, T *out, IdxT *out_idx, IdxT previous_len, + Counter *counter, IdxT *histogram, bool select_min, + int pass, bool early_stop) { + constexpr int num_buckets = calc_num_buckets(); + __shared__ IdxT histogram_smem[num_buckets]; + for (IdxT i = threadIdx.x; i < num_buckets; i += blockDim.x) { + histogram_smem[i] = 0; + } + __syncthreads(); + + const int start_bit = calc_start_bit(pass); + const unsigned mask = calc_mask(pass); + + if (pass == 0) { + // Passed to vectorized_process, this function executes in all blocks in + // parallel, i.e. the work is split along the input (both, in batches and + // chunks of a single row). Later, the histograms are merged using + // atomicAdd. 
+ auto f = [select_min, start_bit, mask](T value, IdxT) { + int bucket = calc_bucket(value, start_bit, mask, select_min); + atomicAdd(histogram_smem + bucket, static_cast(1)); + }; + vectorized_process(static_cast(blockIdx.x) * blockDim.x + threadIdx.x, + static_cast(blockDim.x) * gridDim.x, in_buf, previous_len, f); + } else { + IdxT *p_filter_cnt = &counter->filter_cnt; + IdxT *p_out_cnt = &counter->out_cnt; + const auto kth_value_bits = counter->kth_value_bits; + const int previous_start_bit = calc_start_bit(pass - 1); + + // See the remark above on the distributed execution of `f` using + // vectorized_process. + auto f = [in_idx_buf, out_buf, out_idx_buf, out, out_idx, select_min, start_bit, mask, + previous_start_bit, kth_value_bits, p_filter_cnt, p_out_cnt, + early_stop](T value, IdxT i) { + const auto previous_bits = (twiddle_in(value, select_min) >> previous_start_bit) + << previous_start_bit; + if (previous_bits == kth_value_bits) { + if (early_stop) { + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); + if constexpr (store_out) { + out[pos] = value; + } + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } else { + if (out_buf) { + IdxT pos = atomicAdd(p_filter_cnt, static_cast(1)); + out_buf[pos] = value; + out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i; + } + + int bucket = calc_bucket(value, start_bit, mask, select_min); + atomicAdd(histogram_smem + bucket, static_cast(1)); + } + } + // the condition `(out_buf || early_stop)` is a little tricky: + // If we skip writing to `out_buf` (when `out_buf` is nullptr), we should + // skip writing to `out` too. So we won't write the same value to `out` + // multiple times in different passes. And if we keep skipping the + // writing, values will be written in `last_filter_kernel()` at last. But + // when `early_stop` is true, we need to write to `out` since it's the + // last chance. 
+ else if ((out_buf || early_stop) && previous_bits < kth_value_bits) { + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); + if constexpr (store_out) { + out[pos] = value; + } + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } + }; + vectorized_process(static_cast(blockIdx.x) * blockDim.x + threadIdx.x, + static_cast(blockDim.x) * gridDim.x, in_buf, previous_len, f); + } + if (early_stop) { + return; + } + __syncthreads(); + + // merge histograms produced by individual blocks + for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { + if (histogram_smem[i] != 0) { + atomicAdd(histogram + i, histogram_smem[i]); + } + } +} + +/** + * Replace histogram with its own prefix sum + * (step 2 in `radix_kernel` description) + */ +template +__device__ void scan(volatile IdxT *histogram) { + constexpr int num_buckets = calc_num_buckets(); + if constexpr (num_buckets >= BlockSize) { + static_assert(num_buckets % BlockSize == 0); + constexpr int items_per_thread = num_buckets / BlockSize; + typedef cub::BlockLoad BlockLoad; + typedef cub::BlockStore + BlockStore; + typedef cub::BlockScan BlockScan; + + __shared__ union { + typename BlockLoad::TempStorage load; + typename BlockScan::TempStorage scan; + typename BlockStore::TempStorage store; + } temp_storage; + IdxT thread_data[items_per_thread]; + + BlockLoad(temp_storage.load).Load(histogram, thread_data); + __syncthreads(); + + BlockScan(temp_storage.scan).InclusiveSum(thread_data, thread_data); + __syncthreads(); + + BlockStore(temp_storage.store).Store(histogram, thread_data); + } else { + typedef cub::BlockScan BlockScan; + __shared__ typename BlockScan::TempStorage temp_storage; + + IdxT thread_data = 0; + if (threadIdx.x < num_buckets) { + thread_data = histogram[threadIdx.x]; + } + + BlockScan(temp_storage).InclusiveSum(thread_data, thread_data); + __syncthreads(); + + if (threadIdx.x < num_buckets) { + histogram[threadIdx.x] = thread_data; + } + } +} + +/** + * Calculate in which bucket the k-th value will fall 
+ * (step 3 in `radix_kernel` description) + */ +template +__device__ void choose_bucket(Counter *counter, const IdxT *histogram, const IdxT k, + const int pass) { + constexpr int num_buckets = calc_num_buckets(); + for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { + IdxT prev = (i == 0) ? 0 : histogram[i - 1]; + IdxT cur = histogram[i]; + + // one and only one thread will satisfy this condition, so counter is + // written by only one thread + if (prev < k && cur >= k) { + counter->k = k - prev; // how many values still are there to find + counter->len = cur - prev; // number of values in next pass + typename cub::Traits::UnsignedBits bucket = i; + int start_bit = calc_start_bit(pass); + counter->kth_value_bits |= bucket << start_bit; + } + } +} + +template +__device__ void scan_warp_version(cg::thread_block_tile const &warp, + volatile IdxT *histogram, Counter *counter, const IdxT k, + const int pass) { + constexpr int num_buckets = calc_num_buckets(); + + __shared__ IdxT warp_histogram[num_buckets >> WARP_BITS]; + for (int i = threadIdx.x; i < num_buckets; i += BlockSize) { + IdxT data = histogram[i]; + IdxT warp_sum = cg::reduce(warp, data, cg::plus()); + + if (i % WARP_SIZE == 0) { + warp_histogram[i >> WARP_BITS] = warp_sum; + } + } + __syncthreads(); + + if (threadIdx.x < WARP_SIZE) { + IdxT value = warp_histogram[threadIdx.x * 2] + warp_histogram[threadIdx.x * 2 + 1]; + IdxT prefix = value; + for (int offset = 1; offset < WARP_SIZE; offset <<= 1) { + IdxT n = __shfl_up_sync(FULL_WARP_MASK, prefix, offset, WARP_SIZE); + if (threadIdx.x >= offset) prefix += n; + } + IdxT prefix_high = __shfl_sync(FULL_WARP_MASK, prefix, threadIdx.x - 1, WARP_SIZE); + if (threadIdx.x == 0) prefix_high = 0; + warp_histogram[threadIdx.x * 2] += prefix_high; + warp_histogram[threadIdx.x * 2 + 1] = value + prefix_high; + __syncwarp(); + + // Find the target warp bucket + IdxT target_warp = 2048; // invalid value + // bool is_one_in_warp=false; + for (int i = 
threadIdx.x; i < 64 && target_warp == 2048; i += WARP_SIZE) { + IdxT prev = (i == 0) ? 0 : warp_histogram[i - 1]; + IdxT cur = warp_histogram[i]; + bool is_selected = prev < k && cur >= k; + unsigned mask = __ballot_sync(FULL_WARP_MASK, is_selected); + if (__popc(mask) > 0) { + // target_warp = __ffs(mask) -1 + (i/WARP_SIZE)*WARP_SIZE; + target_warp = __ffs(mask) - 1 + ((i >> WARP_BITS) << WARP_BITS); + // is_one_in_warp= (target_warp==0? warp_histogram[0]: + // warp_histogram[target_warp]-warp_histogram[target_warp-1])==1?true:false; + } + } + + // Find the target bucket + // if(is_one_in_warp){ + // bool is_one=histogram[target_warp*WARP_SIZE+threadIdx.x]==1?1:0; + // unsigned mask = __ballot_sync(FULL_WARP_MASK, is_one); + // IdxT target_bucket=__ffs(mask)-1+target_warp*WARP_SIZE; + // IdxT prev=target_warp==0? 0: warp_histogram[target_warp-1]; + // IdxT cur=warp_histogram[target_warp]; + // if(threadIdx.x==0) { + // counter->k = k - prev; // how many values still are there + // to find counter->len = cur - prev; // number of values in next + // pass typename cub::Traits::UnsignedBits bucket = + // target_bucket; int start_bit = calc_start_bit(pass); counter->kth_value_bits |= bucket << + // start_bit; + // } + // }else{ + value = histogram[(target_warp << WARP_BITS) + threadIdx.x]; + for (int offset = 1; offset < WARP_SIZE; offset <<= 1) { + IdxT n = __shfl_up_sync(FULL_WARP_MASK, value, offset, WARP_SIZE); + if (threadIdx.x >= offset) value += n; + } + value += (target_warp == 0 ? 0 : warp_histogram[target_warp - 1]); + + for (int i = threadIdx.x; i < WARP_SIZE; i += WARP_SIZE) { + IdxT prev = __shfl_up_sync(FULL_WARP_MASK, value, 1, WARP_SIZE); + prev = (i == 0) ? (target_warp == 0 ? 
0 : warp_histogram[target_warp - 1]) : prev; + IdxT cur = value; + if (prev < k && cur >= k) { + counter->k = k - prev; // how many values still are there to find + counter->len = cur - prev; // number of values in next pass + typename cub::Traits::UnsignedBits bucket = (target_warp << WARP_BITS) + i; + int start_bit = calc_start_bit(pass); + counter->kth_value_bits |= bucket << start_bit; + } + } + // } + } +} +// For one-block version, last_filter() could be called when pass < num_passes +// - 1. So `pass` could not be constexpr +template +__device__ void last_filter(const T *in_buf, const IdxT *in_idx_buf, T *out, IdxT *out_idx, + IdxT current_len, IdxT k, Counter *counter, + const bool select_min, const int pass) { + const auto kth_value_bits = counter->kth_value_bits; + const int start_bit = calc_start_bit(pass); + + // changed in choose_bucket(); need to reload + const IdxT needed_num_of_kth = counter->k; + IdxT *p_out_cnt = &counter->out_cnt; + IdxT *p_out_back_cnt = &counter->out_back_cnt; + for (IdxT i = threadIdx.x; i < current_len; i += blockDim.x) { + const T value = in_buf[i]; + const auto bits = (twiddle_in(value, select_min) >> start_bit) << start_bit; + if (bits < kth_value_bits) { + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); + if constexpr (store_out) { + out[pos] = value; + } + // For one-block version, `in_idx_buf` could be nullptr at pass 0. + // For non one-block version, if writing has been skipped, `in_idx_buf` + // could be nullptr if `in_buf` is `in` + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } else if (bits == kth_value_bits) { + IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast(1)); + if (back_pos < needed_num_of_kth) { + IdxT pos = k - 1 - back_pos; + if constexpr (store_out) { + out[pos] = value; + } + out_idx[pos] = in_idx_buf ? 
in_idx_buf[i] : i; + } + } + } +} + +template +__global__ void last_filter_kernel(const T *in, const IdxT *in_idx, const T *in_buf, + const IdxT *in_idx_buf, T *out, IdxT *out_idx, IdxT len, IdxT k, + Counter *counters, const bool select_min) { + const size_t batch_id = blockIdx.y; // size_t to avoid multiplication overflow + + Counter *counter = counters + batch_id; + IdxT previous_len = counter->previous_len; + if (previous_len == 0) { + return; + } + const IdxT buf_len = calc_buf_len(len); + if (previous_len > buf_len || in_buf == in) { + in_buf = in + batch_id * len; + in_idx_buf = in_idx ? (in_idx + batch_id * len) : nullptr; + previous_len = len; + } else { + in_buf += batch_id * buf_len; + in_idx_buf += batch_id * buf_len; + } + if constexpr (store_out) { + out += batch_id * k; + } + out_idx += batch_id * k; + + constexpr int pass = calc_num_passes() - 1; + constexpr int start_bit = calc_start_bit(pass); + + const auto kth_value_bits = counter->kth_value_bits; + const IdxT needed_num_of_kth = counter->k; + IdxT *p_out_cnt = &counter->out_cnt; + IdxT *p_out_back_cnt = &counter->out_back_cnt; + + auto f = [k, select_min, kth_value_bits, needed_num_of_kth, p_out_cnt, p_out_back_cnt, in_idx_buf, + out, out_idx](T value, IdxT i) { + const auto bits = (twiddle_in(value, select_min) >> start_bit) << start_bit; + if (bits < kth_value_bits) { + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); + if constexpr (store_out) { + out[pos] = value; + } + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } else if (bits == kth_value_bits) { + IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast(1)); + if (back_pos < needed_num_of_kth) { + IdxT pos = k - 1 - back_pos; + if constexpr (store_out) { + out[pos] = value; + } + out_idx[pos] = in_idx_buf ? 
in_idx_buf[i] : i; + } + } + }; + + vectorized_process(static_cast(blockIdx.x) * blockDim.x + threadIdx.x, + static_cast(blockDim.x) * gridDim.x, in_buf, previous_len, f); +} + +/** + * + * It is expected to call this kernel multiple times (passes), in each pass we + * process a radix, going from the most significant towards the least + * significant bits (MSD). + * + * Conceptually, each pass consists of 4 steps: + * + * 1. Calculate histogram + * First, transform bits into a digit, the value of which is in the range + * [0, 2^{BITS_PER_PASS}-1]. Then count the frequency of each digit value + * and the result is a histogram. That is, histogram[i] contains the count of + * inputs having value i. + * + * 2. Scan the histogram + * Inclusive prefix sum is computed for the histogram. After this step, + * histogram[i] contains the count of inputs having value <= i. + * + * 3. Find the bucket j of the histogram that the k-th value falls into + * + * 4. Filtering + * Input elements whose digit value +__device__ void radix_kernel_func(const T *in, const IdxT *in_idx, const T *in_buf, + const IdxT *in_idx_buf, T *out_buf, IdxT *out_idx_buf, T *out, + IdxT *out_idx, Counter *counter, IdxT *histogram, + const IdxT len, const IdxT k, const bool select_min, + const int pass) { + if (len <= k) { + if (pass == 0) { + for (int index = threadIdx.x; index < len; index += BlockSize) { + if constexpr (store_out) { + out[index] = in[index]; + } + out_idx[index] = in_idx ? in_idx[index] : index; + } + for (int index = threadIdx.x + len; index < k; index += BlockSize) { + if constexpr (store_out) { + out[index] = nv_detail::float_to_T(-1.0f); + } + out_idx[index] = -1; + } + return; + } else { + return; + } + } + + IdxT current_k; + IdxT previous_len; + IdxT current_len; + if (pass == 0) { + current_k = k; + previous_len = len; + // Need to do this so setting counter->previous_len for the next pass is + // correct. 
This value is meaningless for pass 0, but it's fine because pass + // 0 won't be the last pass in this implementation so pass 0 won't hit the + // "if (pass == num_passes - 1)" branch. Maybe it's better to reload + // counter->previous_len and use it rather than current_len in last_filter() + current_len = len; + } else { + current_k = counter->k; + current_len = counter->len; + previous_len = counter->previous_len; + } + if (current_len == 0) { + return; + } + + // When k=len, early_stop will be true at pass 0. It means + // filter_and_histogram() should handle correctly the case that pass=0 and + // early_stop=true. However, this special case of k=len is handled in other + // way in select_k() so such case is not possible here. + const bool early_stop = (current_len == current_k); + const IdxT buf_len = calc_buf_len(len); + constexpr int num_buckets = calc_num_buckets(); + // "previous_len > buf_len" means previous pass skips writing buffer + if (pass == 0 || pass == 1 || previous_len > buf_len) { + in_buf = in; + in_idx_buf = in_idx ? 
in_idx : nullptr; + previous_len = len; + } + // "current_len > buf_len" means current pass will skip writing buffer + if (pass == 0 || current_len > buf_len) { + out_buf = nullptr; + out_idx_buf = nullptr; + } + + filter_and_histogram(in_buf, in_idx_buf, out_buf, out_idx_buf, + out, out_idx, previous_len, counter, + histogram, select_min, pass, early_stop); + __threadfence(); + + bool isLastBlock = false; + if (threadIdx.x == 0) { + unsigned int finished = atomicInc(&counter->finished_block_cnt, gridDim.x - 1); + isLastBlock = (finished == (gridDim.x - 1)); + } + + if (__syncthreads_or(isLastBlock)) { + if (early_stop) { + if (threadIdx.x == 0) { + // `last_filter_kernel()` requires setting previous_len + counter->previous_len = 0; + counter->len = 0; + } + return; + } + + constexpr int num_passes = calc_num_passes(); + + scan(histogram); + __syncthreads(); + choose_bucket(counter, histogram, current_k, pass); + __syncthreads(); + + // reset for next pass + if (pass != num_passes - 1) { + for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { + histogram[i] = 0; + } + } + if (threadIdx.x == 0) { + // `last_filter_kernel()` requires setting previous_len even in the last + // pass + counter->previous_len = current_len; + // not necessary for the last pass, but put it here anyway + counter->filter_cnt = 0; + } + + if constexpr (fused_last_filter) { + if (pass == num_passes - 1) { + last_filter( + out_buf ? out_buf : in_buf, out_idx_buf ? out_idx_buf : in_idx_buf, out, out_idx, + out_buf ? 
current_len : len, k, counter, select_min, pass); + } + } + } +} + +template +__global__ void radix_kernel(const T *in, const IdxT *in_idx, const T *in_buf, + const IdxT *in_idx_buf, T *out_buf, IdxT *out_idx_buf, T *out, + IdxT *out_idx, Counter *counters, IdxT *histograms, + const IdxT len, const IdxT k, const bool select_min, const int pass, + IdxT *lengths) { + const size_t batch_id = blockIdx.y; + auto counter = counters + batch_id; + constexpr int num_buckets = calc_num_buckets(); + auto histogram = histograms + batch_id * num_buckets; + + in += batch_id * len; + if (in_idx) { + in_idx += batch_id * len; + } + if constexpr (store_out) { + out += batch_id * k; + } + out_idx += batch_id * k; + + const IdxT buf_len = calc_buf_len(len); + in_buf += batch_id * buf_len; + in_idx_buf += batch_id * buf_len; + + out_buf += batch_id * buf_len; + out_idx_buf += batch_id * buf_len; + + IdxT actual_len = len; + if (lengths != nullptr) { + actual_len = lengths[batch_id]; + } + radix_kernel_func( + in, in_idx, in_buf, in_idx_buf, out_buf, out_idx_buf, out, out_idx, counter, histogram, + actual_len, k, select_min, pass); +} + +template +unsigned calc_grid_dim(int batch_size, IdxT len, int sm_cnt) { + static_assert(VECTORIZED_READ_SIZE / sizeof(T) >= 1); + + int active_blocks; + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &active_blocks, radix_kernel, BlockSize, 0); + active_blocks *= sm_cnt; + + IdxT best_num_blocks = 0; + float best_tail_wave_penalty = 1.0f; + const IdxT max_num_blocks = ceildiv(len, VECTORIZED_READ_SIZE / sizeof(T) * BlockSize); + for (int num_waves = 1;; ++num_waves) { + IdxT num_blocks = std::min( + max_num_blocks, static_cast(std::max(num_waves * active_blocks / batch_size, 1))); + IdxT items_per_thread = ceildiv(len, num_blocks * BlockSize); + items_per_thread = alignTo(items_per_thread, VECTORIZED_READ_SIZE / sizeof(T)); + num_blocks = ceildiv(len, items_per_thread * BlockSize); + float actual_num_waves = static_cast(num_blocks) * batch_size / 
active_blocks; + float tail_wave_penalty = + (ceilf(actual_num_waves) - actual_num_waves) / ceilf(actual_num_waves); + + // 0.15 is determined experimentally. It also ensures breaking the loop + // early, e.g. when num_waves > 7, tail_wave_penalty will always <0.15 + if (tail_wave_penalty < 0.15) { + best_num_blocks = num_blocks; + break; + } else if (tail_wave_penalty < best_tail_wave_penalty) { + best_num_blocks = num_blocks; + best_tail_wave_penalty = tail_wave_penalty; + } + + if (num_blocks == max_num_blocks) { + break; + } + } + return best_num_blocks; +} + +template +__host__ __device__ void set_buf_pointers(const T *in, const IdxT *in_idx, T *buf1, IdxT *idx_buf1, + T *buf2, IdxT *idx_buf2, int pass, const T *&in_buf, + const IdxT *&in_idx_buf, T *&out_buf, + IdxT *&out_idx_buf) { + if (pass == 0) { + in_buf = in; + in_idx_buf = nullptr; + out_buf = nullptr; + out_idx_buf = nullptr; + } else if (pass == 1) { + in_buf = in; + in_idx_buf = in_idx; + out_buf = buf1; + out_idx_buf = idx_buf1; + } else if (pass % 2 == 0) { + in_buf = buf1; + in_idx_buf = idx_buf1; + out_buf = buf2; + out_idx_buf = idx_buf2; + } else { + in_buf = buf2; + in_idx_buf = idx_buf2; + out_buf = buf1; + out_idx_buf = idx_buf1; + } +} + +// The following a few functions are for the one-block version, which uses +// single thread block for each row of a batch. 
+template +__device__ void filter_and_histogram_for_one_block(const T *in_buf, const IdxT *in_idx_buf, + T *out_buf, IdxT *out_idx_buf, T *out, + IdxT *out_idx, Counter *counter, + IdxT *histogram, bool select_min, int pass) { + constexpr int num_buckets = calc_num_buckets(); + for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { + histogram[i] = 0; + } + IdxT *p_filter_cnt = &counter->filter_cnt; + if (threadIdx.x == 0) { + *p_filter_cnt = 0; + } + __syncthreads(); + + const int start_bit = calc_start_bit(pass); + const unsigned mask = calc_mask(pass); + const IdxT previous_len = counter->previous_len; + + if (pass == 0) { + if constexpr (is_vectorized) { + auto f = [histogram, select_min, start_bit, mask](T value, IdxT) { + int bucket = calc_bucket(value, start_bit, mask, select_min); + atomicAdd(histogram + bucket, static_cast(1)); + }; + vectorized_process(threadIdx.x, blockDim.x, in_buf, previous_len, f); + } else { + for (IdxT i = threadIdx.x; i < previous_len; i += blockDim.x) { + const T value = in_buf[i]; + int bucket = calc_bucket(value, start_bit, mask, select_min); + atomicAdd(histogram + bucket, static_cast(1)); + } + } + } else { + // not use vectorized_process here because it increases #registers a lot + IdxT *p_out_cnt = &counter->out_cnt; + const auto kth_value_bits = counter->kth_value_bits; + const int previous_start_bit = calc_start_bit(pass - 1); + + for (IdxT i = threadIdx.x; i < previous_len; i += blockDim.x) { + const T value = in_buf[i]; + const auto previous_bits = (twiddle_in(value, select_min) >> previous_start_bit) + << previous_start_bit; + if (previous_bits == kth_value_bits) { +#if CUDART_VERSION < 12000 + // Avoiding potential compiler bug in CUDA 11 + volatile +#endif + IdxT pos = atomicAdd(p_filter_cnt, static_cast(1)); + out_buf[pos] = value; + out_idx_buf[pos] = in_idx_buf ? 
in_idx_buf[i] : i; + + int bucket = calc_bucket(value, start_bit, mask, select_min); + atomicAdd(histogram + bucket, static_cast(1)); + } else if (previous_bits < kth_value_bits) { + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); + if constexpr (store_out) { + out[pos] = value; + } + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } + } + } +} + +template +__device__ void radix_topk_one_block_func(const T *in, const IdxT *in_idx, const IdxT len, + const IdxT k, T *out, IdxT *out_idx, + const bool select_min, T *buf1, IdxT *idx_buf1, T *buf2, + IdxT *idx_buf2) { + if (len <= k) { + for (int index = threadIdx.x; index < len; index += BlockSize) { + if constexpr (store_out) { + out[index] = in[index]; + } + out_idx[index] = in_idx ? in_idx[index] : index; + } + for (int index = threadIdx.x + len; index < k; index += BlockSize) { + if constexpr (store_out) { + out[index] = nv_detail::float_to_T(-1.0f); + } + out_idx[index] = -1; + } + return; + } + + constexpr int num_buckets = calc_num_buckets(); + __shared__ Counter counter; + __shared__ IdxT histogram[num_buckets]; + + if (threadIdx.x == 0) { + counter.k = k; + counter.len = len; + counter.previous_len = len; + counter.kth_value_bits = 0; + counter.out_cnt = 0; + counter.out_back_cnt = 0; + } + __syncthreads(); + + // const size_t batch_id = blockIdx.x; // size_t to avoid multiplication + // overflow + const T *in_buf = nullptr; + const IdxT *in_idx_buf = nullptr; + T *out_buf = nullptr; + IdxT *out_idx_buf = nullptr; + + constexpr int num_passes = calc_num_passes(); + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + for (int pass = 0; pass < num_passes; ++pass) { + set_buf_pointers(in, in_idx, buf1, idx_buf1, buf2, idx_buf2, pass, in_buf, in_idx_buf, out_buf, + out_idx_buf); + + IdxT current_len = counter.len; + IdxT current_k = counter.k; + + filter_and_histogram_for_one_block( + in_buf, in_idx_buf, out_buf, out_idx_buf, out, out_idx, &counter, histogram, select_min, + pass); + 
__syncthreads(); + + scan(histogram); + __syncthreads(); + + choose_bucket(&counter, histogram, current_k, pass); + // scan_warp_version( + // warp, histogram, &counter, current_k, pass); + if (threadIdx.x == 0) { + counter.previous_len = current_len; + } + __syncthreads(); + + if (counter.len == counter.k || pass == num_passes - 1) { + last_filter(pass == 0 ? in : out_buf, + pass == 0 ? in_idx : out_idx_buf, out, out_idx, + current_len, k, &counter, select_min, pass); + break; + } + } // end for pass +} // end kernel + +template +__global__ void radix_topk_one_block_kernel(const T *in, const IdxT *in_idx, const IdxT len, + const IdxT k, T *out, IdxT *out_idx, + const bool select_min, T *buf1, IdxT *idx_buf1, T *buf2, + IdxT *idx_buf2, IdxT *lengths) { + const size_t batch_id = blockIdx.x; // size_t to avoid multiplication overflow + IdxT actual_len = len; + if (lengths) { + actual_len = lengths[batch_id]; + } + + in += batch_id * len; + if (in_idx) { + in_idx += batch_id * len; + } + + out += batch_id * k; + out_idx += batch_id * k; + buf1 += batch_id * len; + idx_buf1 += batch_id * len; + buf2 += batch_id * len; + idx_buf2 += batch_id * len; + + radix_topk_one_block_func( + in, in_idx, actual_len, k, out, out_idx, select_min, buf1, idx_buf1, buf2, idx_buf2); +} // end kernel + +} // namespace air_topk + +/***************Runtime API****************/ + +template +void standalone_radix_topk_(void *buf, size_t &buf_size, const T *in, const IdxT *in_idx, + int batch_size, IdxT len, IdxT k, T *out, IdxT *out_idx, + bool select_min, bool fused_last_filter, unsigned grid_dim, + cudaStream_t stream, IdxT *lengths = nullptr) { + static_assert(air_topk::calc_num_passes() > 1); + constexpr int num_buckets = air_topk::calc_num_buckets(); + + air_topk::Counter *counters = nullptr; + IdxT *histograms = nullptr; + T *buf1 = nullptr; + IdxT *idx_buf1 = nullptr; + T *buf2 = nullptr; + IdxT *idx_buf2 = nullptr; + { + IdxT len_candidates = air_topk::calc_buf_len(len); + std::vector 
sizes = {sizeof(*counters) * batch_size, + sizeof(*histograms) * num_buckets * batch_size, + sizeof(*buf1) * len_candidates * batch_size, + sizeof(*idx_buf1) * len_candidates * batch_size, + sizeof(*buf2) * len_candidates * batch_size, + sizeof(*idx_buf2) * len_candidates * batch_size}; + size_t total_size = calc_aligned_size(sizes); + if (!buf) { + buf_size = total_size; + return; + } + + std::vector aligned_pointers = calc_aligned_pointers(buf, sizes); + counters = static_cast(aligned_pointers[0]); + histograms = static_cast(aligned_pointers[1]); + buf1 = static_cast(aligned_pointers[2]); + idx_buf1 = static_cast(aligned_pointers[3]); + buf2 = static_cast(aligned_pointers[4]); + idx_buf2 = static_cast(aligned_pointers[5]); + + cudaMemsetAsync( + buf, 0, static_cast(aligned_pointers[2]) - static_cast(aligned_pointers[0]), + stream); + } + + const T *in_buf = nullptr; + const IdxT *in_idx_buf = nullptr; + T *out_buf = nullptr; + IdxT *out_idx_buf = nullptr; + + dim3 blocks(grid_dim, batch_size); + + constexpr int num_passes = air_topk::calc_num_passes(); + + auto kernel = air_topk::radix_kernel; + + for (int pass = 0; pass < num_passes; ++pass) { + air_topk::set_buf_pointers(in, in_idx, buf1, idx_buf1, buf2, idx_buf2, pass, in_buf, in_idx_buf, + out_buf, out_idx_buf); + + if (fused_last_filter && pass == num_passes - 1 && out != nullptr) { + kernel = air_topk::radix_kernel; + } else if (fused_last_filter && pass == num_passes - 1 && out == nullptr) { + kernel = air_topk::radix_kernel; + } else if (out == nullptr) { + kernel = air_topk::radix_kernel; + } + + kernel<<>>(in, in_idx, in_buf, in_idx_buf, out_buf, out_idx_buf, + out, out_idx, counters, histograms, len, k, select_min, + pass, lengths); + } + + if (!fused_last_filter) { + if (out != nullptr) { + air_topk::last_filter_kernel<<>>( + in, in_idx, out_buf, out_idx_buf, out, out_idx, len, k, counters, select_min); + } else { + air_topk::last_filter_kernel<<>>( + in, in_idx, out_buf, out_idx_buf, out, out_idx, 
len, k, counters, select_min); + } + } +} + +template +void standalone_radix_topk_one_block_(void *buf, size_t &buf_size, const T *in, const IdxT *in_idx, + int batch_size, IdxT len, IdxT k, T *out, IdxT *out_idx, + bool select_min, cudaStream_t stream, + IdxT *lengths = nullptr) { + static_assert(air_topk::calc_num_passes() > 1); + + T *buf1 = nullptr; + IdxT *idx_buf1 = nullptr; + T *buf2 = nullptr; + IdxT *idx_buf2 = nullptr; + { + std::vector sizes = { + sizeof(*buf1) * len * batch_size, sizeof(*idx_buf1) * len * batch_size, + sizeof(*buf2) * len * batch_size, sizeof(*idx_buf2) * len * batch_size}; + size_t total_size = calc_aligned_size(sizes); + if (!buf) { + buf_size = total_size; + return; + } + + std::vector aligned_pointers = calc_aligned_pointers(buf, sizes); + buf1 = static_cast(aligned_pointers[0]); + idx_buf1 = static_cast(aligned_pointers[1]); + buf2 = static_cast(aligned_pointers[2]); + idx_buf2 = static_cast(aligned_pointers[3]); + } + + if (out != nullptr) { + air_topk::radix_topk_one_block_kernel + <<>>(in, in_idx, len, k, out, out_idx, select_min, buf1, + idx_buf1, buf2, idx_buf2, lengths); + } else { + air_topk::radix_topk_one_block_kernel + <<>>(in, in_idx, len, k, out, out_idx, select_min, buf1, + idx_buf1, buf2, idx_buf2, lengths); + } +} + +template +void standalone_air_topk(void *buf, size_t &buf_size, const T *in, int batch_size, idxT len, idxT k, + T *out, idxT *out_idx, bool greater, cudaStream_t stream = 0, + idxT *lengths = nullptr, bool is_prefill = false) { + constexpr int items_per_thread = 32; + constexpr int multi_block_dim = 256; + constexpr int single_block_dim = 1024; + constexpr bool fused_last_filter = false; + if (len <= single_block_dim * items_per_thread || is_prefill) { + standalone_radix_topk_one_block_( + buf, buf_size, in, static_cast(nullptr), batch_size, len, k, out, out_idx, !greater, + stream, lengths); + } else { + int sm_cnt; + { + int dev; + NVTE_CHECK_CUDA(cudaGetDevice(&dev)); + 
NVTE_CHECK_CUDA(cudaDeviceGetAttribute(&sm_cnt, cudaDevAttrMultiProcessorCount, dev)); + } + unsigned grid_dim = + air_topk::calc_grid_dim(batch_size, len, sm_cnt); + + if (grid_dim == 1) { + standalone_radix_topk_one_block_( + buf, buf_size, in, static_cast(nullptr), batch_size, len, k, out, out_idx, + !greater, stream, lengths); + } else { + standalone_radix_topk_( + buf, buf_size, in, static_cast(nullptr), batch_size, len, k, out, out_idx, + !greater, fused_last_filter, grid_dim, stream, lengths); + } + } +} +} // namespace nv diff --git a/transformer_engine/jax/cpp_extensions/air_topk.py b/transformer_engine/jax/cpp_extensions/air_topk.py new file mode 100644 index 0000000000..bdc084ee8d --- /dev/null +++ b/transformer_engine/jax/cpp_extensions/air_topk.py @@ -0,0 +1,136 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""AIR TopK custom op""" + +import functools +from typing import Tuple + +import jax +import jax.numpy as jnp +from jax import dtypes, ffi + +from .base import BasePrimitive, register_primitive + +__all__ = ["air_topk"] + + +@functools.lru_cache(maxsize=512) +def get_air_topk_workspace_bytes(batch_size: int, seq_len: int, k: int) -> int: + """Query the workspace size required for AIR TopK. + + The result is memoised per (batch_size, seq_len, k) tuple so that repeated + JIT compilations with the same shapes incur only one host-side CUDA call. + """ + import transformer_engine_jax as _te_jax + + return int(_te_jax.get_air_topk_workspace_bytes(batch_size, seq_len, k)) + + +class AirTopKPrimitive(BasePrimitive): + """ + AIR TopK Primitive + + Selects the top-k entries (by value) from each row of a 2-D input using the + AIR radix-selection algorithm. Returns both the top-k key values and their + column indices within each row. 
+ """ + + name = "te_air_topk_ffi" + multiple_results = True + impl_static_args = (2,) # k_value + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + in_keys_aval, + in_lengths_aval, + *, + k_value, + ): + keys_dtype = dtypes.canonicalize_dtype(in_keys_aval.dtype) + assert keys_dtype in [jnp.float32, jnp.bfloat16], ( + f"air_topk: unsupported key dtype {keys_dtype}; supported: float32, bfloat16" + ) + assert in_keys_aval.ndim == 2, "air_topk: keys input must be 2D (batch_size, seq_len)" + assert dtypes.canonicalize_dtype(in_lengths_aval.dtype) == jnp.int32 + + batch_size, seq_len = in_keys_aval.shape + workspace_bytes = get_air_topk_workspace_bytes(batch_size, seq_len, k_value) + + out_shape = (batch_size, k_value) + out_keys_aval = jax.core.ShapedArray(shape=out_shape, dtype=keys_dtype) + out_indices_aval = jax.core.ShapedArray(shape=out_shape, dtype=jnp.int32) + workspace_aval = jax.core.ShapedArray(shape=(workspace_bytes,), dtype=jnp.uint8) + return (out_keys_aval, out_indices_aval, workspace_aval) + + @staticmethod + def outer_abstract(*args, **kwargs): + out_keys_aval, out_indices_aval, _workspace_aval = AirTopKPrimitive.abstract( + *args, **kwargs + ) + return (out_keys_aval, out_indices_aval) + + @staticmethod + def lowering(ctx, in_keys, in_lengths, k_value): + keys_aval = ctx.avals_in[0] + batch_size, seq_len = keys_aval.shape + workspace_bytes = get_air_topk_workspace_bytes(batch_size, seq_len, k_value) + return ffi.ffi_lowering(AirTopKPrimitive.name)( + ctx, + in_keys, + in_lengths, + k_value=k_value, + workbuf_bytes=workspace_bytes, + ) + + @staticmethod + def impl(in_keys, in_lengths, k_value): + assert AirTopKPrimitive.inner_primitive is not None + out_keys, out_indices, _workspace = AirTopKPrimitive.inner_primitive.bind( + in_keys, + in_lengths, + k_value=k_value, + ) + return (out_keys, out_indices) + + +register_primitive(AirTopKPrimitive) + + +def air_topk( + x: jnp.ndarray, + k_value: int, +) -> Tuple[jnp.ndarray, 
jnp.ndarray]: + """Select the top-k largest entries from each row using the AIR radix algorithm. + + Args: + x: Input array of shape ``(batch_size, seq_len)`` or ``(seq_len,)``. + Supported dtypes: ``float32``, ``bfloat16``. + k_value: Number of top entries to select per row. + + Returns: + A tuple ``(values, indices)`` where both arrays have shape + ``(batch_size, k_value)`` (or ``(k_value,)`` for 1-D input). The + outputs are *unordered*: use ``jax.lax.sort_key_val`` if a sorted result + is required. ``indices`` are the column positions within the original row. + """ + squeezed = x.ndim == 1 + if squeezed: + x = x[jnp.newaxis, :] # (1, seq_len) + + batch_size, seq_len = x.shape + lengths = jnp.full((batch_size,), seq_len, dtype=jnp.int32) + + out_keys, out_indices = AirTopKPrimitive.outer_primitive.bind( + x, + lengths, + k_value=k_value, + ) + + if squeezed: + out_keys = out_keys[0] # (k_value,) + out_indices = out_indices[0] # (k_value,) + + return out_keys, out_indices diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 0fe4e99239..a7960b6882 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -171,6 +171,10 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedTopkWithScoreFunctionBackwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossForwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossBackwardHandler); +// AIR TopK +XLA_FFI_DECLARE_HANDLER_SYMBOL(AirTopkHandler); +int64_t GetAirTopkWorkspaceBytes(int batch_size, int seq_len, int k); + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/air_topk.cpp b/transformer_engine/jax/csrc/extensions/air_topk.cpp new file mode 100644 index 0000000000..10c3ae7748 --- /dev/null +++ b/transformer_engine/jax/csrc/extensions/air_topk.cpp @@ -0,0 +1,92 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA 
CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "transformer_engine/air_topk.h" + +#include "../extensions.h" +#include "xla/ffi/api/c_api.h" + +namespace transformer_engine { +namespace jax { + +// --------------------------------------------------------------------------- +// JAX FFI handler +// --------------------------------------------------------------------------- + +Error_Type AirTopkFFI(cudaStream_t stream, Buffer_Type keys_in_buf, Buffer_Type lengths_buf, + Result_Type keys_out_buf, Result_Type indices_out_buf, + Result_Type workspace_buf, int64_t k_value, int64_t workbuf_bytes) { + auto keys_in_dtype = convert_ffi_datatype_to_te_dtype(keys_in_buf.element_type()); + auto keys_out_dtype = convert_ffi_datatype_to_te_dtype(keys_out_buf->element_type()); + auto idx_out_dtype = convert_ffi_datatype_to_te_dtype(indices_out_buf->element_type()); + NVTE_CHECK(keys_in_dtype == keys_out_dtype, "AirTopkFFI: input and output key dtypes must match"); + NVTE_CHECK(idx_out_dtype == DType::kInt32, "AirTopkFFI: index output must be int32"); + + auto keys_in_shape = keys_in_buf.dimensions(); + NVTE_CHECK(keys_in_shape.size() == 2, "AirTopkFFI: keys input must be 2D (batch_size, seq_len)"); + + int batch_size = static_cast(keys_in_shape[0]); + int seq_len = static_cast(keys_in_shape[1]); + int k = static_cast(k_value); + + // Element byte widths for computing flat buffer sizes. + size_t keys_element_bytes; + switch (keys_in_dtype) { + case DType::kFloat32: + keys_element_bytes = 4; + break; + case DType::kBFloat16: + keys_element_bytes = 2; + break; + default: + NVTE_ERROR("AirTopkFFI: unsupported key dtype (float32 and bfloat16 only)"); + } + + // Build flat TensorWrappers over the full (batch_size * seq_len) / (batch_size * k) buffers. 
+ auto flat_in_shape = + std::vector{static_cast(batch_size) * static_cast(seq_len)}; + auto flat_out_shape = + std::vector{static_cast(batch_size) * static_cast(k)}; + auto len_shape = std::vector{static_cast(batch_size)}; + auto ws_shape = std::vector{static_cast(workbuf_bytes)}; + + auto keys_in_tensor = TensorWrapper(keys_in_buf.untyped_data(), flat_in_shape, keys_in_dtype); + auto lengths_tensor = TensorWrapper(lengths_buf.untyped_data(), len_shape, DType::kInt32); + auto keys_out_tensor = + TensorWrapper(keys_out_buf->untyped_data(), flat_out_shape, keys_out_dtype); + auto idx_out_tensor = + TensorWrapper(indices_out_buf->untyped_data(), flat_out_shape, DType::kInt32); + auto workspace_tensor = TensorWrapper(workspace_buf->untyped_data(), ws_shape, DType::kByte); + + nvte_air_topk(stream, keys_in_tensor.data(), lengths_tensor.data(), keys_out_tensor.data(), + idx_out_tensor.data(), workspace_tensor.data(), batch_size, seq_len, k, + static_cast(workbuf_bytes)); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(AirTopkHandler, AirTopkFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // keys_in + .Arg() // lengths + .Ret() // keys_out + .Ret() // indices_out + .Ret() // workspace + .Attr("k_value") + .Attr("workbuf_bytes"), + FFI_CudaGraph_Traits); + +// --------------------------------------------------------------------------- +// Workspace-size query exposed to Python +// --------------------------------------------------------------------------- + +int64_t GetAirTopkWorkspaceBytes(int batch_size, int seq_len, int k) { + return static_cast(nvte_get_air_topk_workspace_bytes(batch_size, seq_len, k)); +} + +} // namespace jax +} // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 28cb39b5d1..f1be141dc1 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -100,6 +100,9 @@ 
pybind11::dict Registrations() { dict["te_fused_moe_aux_loss_forward_ffi"] = EncapsulateFFI(FusedMoEAuxLossForwardHandler); dict["te_fused_moe_aux_loss_backward_ffi"] = EncapsulateFFI(FusedMoEAuxLossBackwardHandler); + // AIR TopK + dict["te_air_topk_ffi"] = EncapsulateFFI(AirTopkHandler); + return dict; } @@ -117,6 +120,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("get_norm_bwd_workspace_sizes", &GetNormBackwardWorkspaceSizes); m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes); m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes); + m.def("get_air_topk_workspace_bytes", &GetAirTopkWorkspaceBytes); m.def("nvte_get_qkv_format", &nvte_get_qkv_format); m.def("is_non_nt_fp8_gemm_supported", &nvte_is_non_tn_fp8_gemm_supported); m.def("initialize_cgemm_communicator", &InitializeCgemmCommunicator); From 470be3c9228ccd30ea23ea2ee05b13c22e736ec7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 16 Apr 2026 04:20:05 +0000 Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_custom_call_compute.py | 22 +++++++++++-------- .../jax/cpp_extensions/air_topk.py | 15 +++++++------ 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 0d35058a1a..860623f3e3 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1983,15 +1983,17 @@ def test_air_topk_1d(self, dtype, problem_size): prng_key = jax.random.PRNGKey(0) keys = jax.random.split(prng_key, 3) - topk_vals = jax.random.uniform(keys[0], shape=(k,), dtype=dtype, minval=1.5, maxval=2.5) - bottom_vals = jax.random.uniform(keys[1], shape=(n - k,), dtype=dtype, minval=0.0, maxval=1.0) + topk_vals = jax.random.uniform(keys[0], shape=(k,), dtype=dtype, minval=1.5, maxval=2.5) + 
bottom_vals = jax.random.uniform( + keys[1], shape=(n - k,), dtype=dtype, minval=0.0, maxval=1.0 + ) x = jax.random.permutation(keys[2], jnp.concatenate([topk_vals, bottom_vals])) ref_vals, ref_idx = jax.jit(jax.lax.top_k, static_argnums=(1,))(x, k) prim_vals, prim_idx = jax.jit(air_topk, static_argnums=(1,))(x, k) # AIR TopK output is unordered; sort before comparing. - ref_vals, ref_idx = jax.lax.sort_key_val(ref_vals, ref_idx) + ref_vals, ref_idx = jax.lax.sort_key_val(ref_vals, ref_idx) prim_vals, prim_idx = jax.lax.sort_key_val(prim_vals, prim_idx) assert_allclose(prim_vals, ref_vals, dtype=dtype) @@ -2008,18 +2010,20 @@ def test_air_topk_2d(self, dtype, problem_size): prng_key = jax.random.PRNGKey(42) keys = jax.random.split(prng_key, 3) - topk_vals = jax.random.uniform(keys[0], shape=(bs, k), dtype=dtype, minval=1.5, maxval=2.5) - bottom_vals = jax.random.uniform(keys[1], shape=(bs, n - k), dtype=dtype, minval=0.0, maxval=1.0) - x_unsorted = jnp.concatenate([topk_vals, bottom_vals], axis=1) + topk_vals = jax.random.uniform(keys[0], shape=(bs, k), dtype=dtype, minval=1.5, maxval=2.5) + bottom_vals = jax.random.uniform( + keys[1], shape=(bs, n - k), dtype=dtype, minval=0.0, maxval=1.0 + ) + x_unsorted = jnp.concatenate([topk_vals, bottom_vals], axis=1) # Shuffle columns independently per row. col_perm = jax.random.permutation(keys[2], n) x = x_unsorted[:, col_perm] - ref_vals, ref_idx = jax.jit(jax.lax.top_k, static_argnums=(1,))(x, k) + ref_vals, ref_idx = jax.jit(jax.lax.top_k, static_argnums=(1,))(x, k) prim_vals, prim_idx = jax.jit(air_topk, static_argnums=(1,))(x, k) # Sort each row independently for comparison. 
- ref_vals, ref_idx = jax.vmap(jax.lax.sort_key_val)(ref_vals, ref_idx) + ref_vals, ref_idx = jax.vmap(jax.lax.sort_key_val)(ref_vals, ref_idx) prim_vals, prim_idx = jax.vmap(jax.lax.sort_key_val)(prim_vals, prim_idx) assert_allclose(prim_vals, ref_vals, dtype=dtype) @@ -2030,5 +2034,5 @@ def test_air_topk_2d(self, dtype, problem_size): # Values at returned indices must match reference values. prim_gathered = jnp.take_along_axis(x, prim_idx, axis=1) - ref_gathered = jnp.take_along_axis(x, ref_idx, axis=1) + ref_gathered = jnp.take_along_axis(x, ref_idx, axis=1) assert_allclose(prim_gathered, ref_gathered, dtype=dtype) diff --git a/transformer_engine/jax/cpp_extensions/air_topk.py b/transformer_engine/jax/cpp_extensions/air_topk.py index bdc084ee8d..f9189faaa1 100644 --- a/transformer_engine/jax/cpp_extensions/air_topk.py +++ b/transformer_engine/jax/cpp_extensions/air_topk.py @@ -50,9 +50,10 @@ def abstract( k_value, ): keys_dtype = dtypes.canonicalize_dtype(in_keys_aval.dtype) - assert keys_dtype in [jnp.float32, jnp.bfloat16], ( - f"air_topk: unsupported key dtype {keys_dtype}; supported: float32, bfloat16" - ) + assert keys_dtype in [ + jnp.float32, + jnp.bfloat16, + ], f"air_topk: unsupported key dtype {keys_dtype}; supported: float32, bfloat16" assert in_keys_aval.ndim == 2, "air_topk: keys input must be 2D (batch_size, seq_len)" assert dtypes.canonicalize_dtype(in_lengths_aval.dtype) == jnp.int32 @@ -60,9 +61,9 @@ def abstract( workspace_bytes = get_air_topk_workspace_bytes(batch_size, seq_len, k_value) out_shape = (batch_size, k_value) - out_keys_aval = jax.core.ShapedArray(shape=out_shape, dtype=keys_dtype) - out_indices_aval = jax.core.ShapedArray(shape=out_shape, dtype=jnp.int32) - workspace_aval = jax.core.ShapedArray(shape=(workspace_bytes,), dtype=jnp.uint8) + out_keys_aval = jax.core.ShapedArray(shape=out_shape, dtype=keys_dtype) + out_indices_aval = jax.core.ShapedArray(shape=out_shape, dtype=jnp.int32) + workspace_aval = 
jax.core.ShapedArray(shape=(workspace_bytes,), dtype=jnp.uint8) return (out_keys_aval, out_indices_aval, workspace_aval) @staticmethod @@ -130,7 +131,7 @@ def air_topk( ) if squeezed: - out_keys = out_keys[0] # (k_value,) + out_keys = out_keys[0] # (k_value,) out_indices = out_indices[0] # (k_value,) return out_keys, out_indices From 1e6c976db681a36eecc12fd9e36d6bf2b3ff1be0 Mon Sep 17 00:00:00 2001 From: dcampora <961215+dcampora@users.noreply.github.com> Date: Thu, 16 Apr 2026 05:15:39 +0000 Subject: [PATCH 3/6] Address PR review comments: fix namespace pollution, unused var, missing export, cache sm_cnt - Move WARP_SIZE/WARP_BITS/FULL_WARP_MASK/VECTORIZED_READ_SIZE into namespace nv - Remove unused keys_element_bytes variable in AirTopkFFI; collapse switch to dtype validation - Add missing `from .air_topk import *` export in jax/cpp_extensions/__init__.py - Cache sm_cnt per device with static vars to avoid repeated cudaGetDevice/cudaDeviceGetAttribute calls - Add CMAKE_BUILD_WITH_INSTALL_RPATH=ON to build_ext.py Signed-off-by: dcampora <961215+dcampora@users.noreply.github.com> Co-Authored-By: Claude Sonnet 4.6 --- build_tools/build_ext.py | 1 + .../common/util/standalone_air_topk.cuh | 19 +++++++++++++------ .../jax/cpp_extensions/__init__.py | 1 + .../jax/csrc/extensions/air_topk.cpp | 6 +----- 4 files changed, 16 insertions(+), 11 deletions(-) diff --git a/build_tools/build_ext.py b/build_tools/build_ext.py index cbb8838b00..2fb7562a61 100644 --- a/build_tools/build_ext.py +++ b/build_tools/build_ext.py @@ -60,6 +60,7 @@ def _build_cmake(self, build_dir: Path, install_dir: Path) -> None: f"-DPython_SITEARCH={sysconfig.get_path('platlib')}", f"-DCMAKE_BUILD_TYPE={build_type}", f"-DCMAKE_INSTALL_PREFIX={install_dir}", + "-DCMAKE_BUILD_WITH_INSTALL_RPATH=ON", ] if bool(int(os.getenv("NVTE_USE_CCACHE", "0"))): ccache_bin = os.getenv("NVTE_CCACHE_BIN", "ccache") diff --git a/transformer_engine/common/util/standalone_air_topk.cuh 
b/transformer_engine/common/util/standalone_air_topk.cuh index 0a78342273..7ec02df8f4 100644 --- a/transformer_engine/common/util/standalone_air_topk.cuh +++ b/transformer_engine/common/util/standalone_air_topk.cuh @@ -6,11 +6,6 @@ #pragma once -constexpr int VECTORIZED_READ_SIZE = 16; -constexpr int WARP_SIZE = 32; -constexpr int WARP_BITS = 5; -constexpr unsigned FULL_WARP_MASK = 0xffffffff; - #include #include #include @@ -59,6 +54,11 @@ __host__ __device__ inline __nv_bfloat16 float_to_T<__nv_bfloat16>(float v) { namespace nv { +constexpr int VECTORIZED_READ_SIZE = 16; +constexpr int WARP_SIZE = 32; +constexpr int WARP_BITS = 5; +constexpr unsigned FULL_WARP_MASK = 0xffffffff; + namespace air_topk { using WideT = float4; @@ -1251,11 +1251,18 @@ void standalone_air_topk(void *buf, size_t &buf_size, const T *in, int batch_siz buf, buf_size, in, static_cast(nullptr), batch_size, len, k, out, out_idx, !greater, stream, lengths); } else { + // Cache sm_cnt per device to avoid repeated host-side queries. 
+ static int cached_dev = -1; + static int cached_sm_cnt = -1; int sm_cnt; { int dev; NVTE_CHECK_CUDA(cudaGetDevice(&dev)); - NVTE_CHECK_CUDA(cudaDeviceGetAttribute(&sm_cnt, cudaDevAttrMultiProcessorCount, dev)); + if (dev != cached_dev) { + NVTE_CHECK_CUDA(cudaDeviceGetAttribute(&cached_sm_cnt, cudaDevAttrMultiProcessorCount, dev)); + cached_dev = dev; + } + sm_cnt = cached_sm_cnt; } unsigned grid_dim = air_topk::calc_grid_dim(batch_size, len, sm_cnt); diff --git a/transformer_engine/jax/cpp_extensions/__init__.py b/transformer_engine/jax/cpp_extensions/__init__.py index d203fcea9d..bbea60e345 100644 --- a/transformer_engine/jax/cpp_extensions/__init__.py +++ b/transformer_engine/jax/cpp_extensions/__init__.py @@ -10,3 +10,4 @@ from .softmax import * from .gemm import * from .router import * +from .air_topk import * diff --git a/transformer_engine/jax/csrc/extensions/air_topk.cpp b/transformer_engine/jax/csrc/extensions/air_topk.cpp index 10c3ae7748..1fef1434c1 100644 --- a/transformer_engine/jax/csrc/extensions/air_topk.cpp +++ b/transformer_engine/jax/csrc/extensions/air_topk.cpp @@ -32,14 +32,10 @@ Error_Type AirTopkFFI(cudaStream_t stream, Buffer_Type keys_in_buf, Buffer_Type int seq_len = static_cast(keys_in_shape[1]); int k = static_cast(k_value); - // Element byte widths for computing flat buffer sizes. - size_t keys_element_bytes; + // Validate key dtype (float32 and bfloat16 only). 
switch (keys_in_dtype) { case DType::kFloat32: - keys_element_bytes = 4; - break; case DType::kBFloat16: - keys_element_bytes = 2; break; default: NVTE_ERROR("AirTopkFFI: unsupported key dtype (float32 and bfloat16 only)"); From 1b328a21b59babc4e0093e7c732ed8b40882a8ca Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 16 Apr 2026 05:18:14 +0000 Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/util/standalone_air_topk.cuh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/util/standalone_air_topk.cuh b/transformer_engine/common/util/standalone_air_topk.cuh index 7ec02df8f4..6ef8dbafe1 100644 --- a/transformer_engine/common/util/standalone_air_topk.cuh +++ b/transformer_engine/common/util/standalone_air_topk.cuh @@ -1259,7 +1259,8 @@ void standalone_air_topk(void *buf, size_t &buf_size, const T *in, int batch_siz int dev; NVTE_CHECK_CUDA(cudaGetDevice(&dev)); if (dev != cached_dev) { - NVTE_CHECK_CUDA(cudaDeviceGetAttribute(&cached_sm_cnt, cudaDevAttrMultiProcessorCount, dev)); + NVTE_CHECK_CUDA( + cudaDeviceGetAttribute(&cached_sm_cnt, cudaDevAttrMultiProcessorCount, dev)); cached_dev = dev; } sm_cnt = cached_sm_cnt; From 897156e91350fe36dc51431cb307e41e9ac3dc26 Mon Sep 17 00:00:00 2001 From: dcampora <961215+dcampora@users.noreply.github.com> Date: Thu, 16 Apr 2026 05:33:54 +0000 Subject: [PATCH 5/6] Rename air_topk -> topk throughout JAX extension Remove the `air_` prefix from all TopK-related identifiers: file names, C API functions (nvte_air_topk -> nvte_topk), FFI handler/primitive names (te_air_topk_ffi -> te_topk_ffi), Python symbols, and the internal `air_topk` namespace in standalone_topk.cuh. No functional changes. 
Signed-off-by: Diego Campora Signed-off-by: dcampora <961215+dcampora@users.noreply.github.com> --- build_tools/build_ext.py | 1 - tests/jax/test_custom_call_compute.py | 14 ++--- transformer_engine/common/CMakeLists.txt | 2 +- .../transformer_engine/{air_topk.h => topk.h} | 18 +++--- transformer_engine/common/util/air_topk.cu | 57 ------------------- ...alone_air_topk.cuh => standalone_topk.cuh} | 38 ++++++------- transformer_engine/common/util/topk.cu | 57 +++++++++++++++++++ .../jax/cpp_extensions/__init__.py | 2 +- .../cpp_extensions/{air_topk.py => topk.py} | 38 ++++++------- transformer_engine/jax/csrc/extensions.h | 6 +- .../jax/csrc/extensions/pybind.cpp | 6 +- .../extensions/{air_topk.cpp => topk.cpp} | 28 ++++----- 12 files changed, 133 insertions(+), 134 deletions(-) rename transformer_engine/common/include/transformer_engine/{air_topk.h => topk.h} (75%) delete mode 100644 transformer_engine/common/util/air_topk.cu rename transformer_engine/common/util/{standalone_air_topk.cuh => standalone_topk.cuh} (97%) create mode 100644 transformer_engine/common/util/topk.cu rename transformer_engine/jax/cpp_extensions/{air_topk.py => topk.py} (75%) rename transformer_engine/jax/csrc/extensions/{air_topk.cpp => topk.cpp} (73%) diff --git a/build_tools/build_ext.py b/build_tools/build_ext.py index 2fb7562a61..cbb8838b00 100644 --- a/build_tools/build_ext.py +++ b/build_tools/build_ext.py @@ -60,7 +60,6 @@ def _build_cmake(self, build_dir: Path, install_dir: Path) -> None: f"-DPython_SITEARCH={sysconfig.get_path('platlib')}", f"-DCMAKE_BUILD_TYPE={build_type}", f"-DCMAKE_INSTALL_PREFIX={install_dir}", - "-DCMAKE_BUILD_WITH_INSTALL_RPATH=ON", ] if bool(int(os.getenv("NVTE_USE_CCACHE", "0"))): ccache_bin = os.getenv("NVTE_CCACHE_BIN", "ccache") diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 860623f3e3..547a26e24a 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -47,7 +47,7 
@@ from transformer_engine.jax.activation import activation from transformer_engine.jax.dense import dense, grouped_dense from transformer_engine.jax.layernorm_dense import layernorm_dense -from transformer_engine.jax.cpp_extensions.air_topk import air_topk +from transformer_engine.jax.cpp_extensions.topk import topk GEMM_CASES = [ (256, 256, 512), @@ -1969,15 +1969,15 @@ def f(x): (1, 1000000, 1000), ], ) -class TestAirTopK: - """Correctness tests for the AIR TopK JAX primitive. +class TestTopK: + """Correctness tests for the TopK JAX primitive. Each test generates an input whose top-k entries lie in a known value range so that correctness can be verified without a full sort, then cross-checks against jax.lax.top_k as a reference. """ - def test_air_topk_1d(self, dtype, problem_size): + def test_topk_1d(self, dtype, problem_size): """1-D input: single row.""" _bs, n, k = problem_size @@ -1990,7 +1990,7 @@ def test_air_topk_1d(self, dtype, problem_size): x = jax.random.permutation(keys[2], jnp.concatenate([topk_vals, bottom_vals])) ref_vals, ref_idx = jax.jit(jax.lax.top_k, static_argnums=(1,))(x, k) - prim_vals, prim_idx = jax.jit(air_topk, static_argnums=(1,))(x, k) + prim_vals, prim_idx = jax.jit(topk, static_argnums=(1,))(x, k) # AIR TopK output is unordered; sort before comparing. ref_vals, ref_idx = jax.lax.sort_key_val(ref_vals, ref_idx) @@ -2004,7 +2004,7 @@ def test_air_topk_1d(self, dtype, problem_size): # Values at returned indices must match reference. 
assert_allclose(x[prim_idx], x[ref_idx], dtype=dtype) - def test_air_topk_2d(self, dtype, problem_size): + def test_topk_2d(self, dtype, problem_size): """2-D input: each row is an independent top-k problem.""" bs, n, k = problem_size @@ -2020,7 +2020,7 @@ def test_air_topk_2d(self, dtype, problem_size): x = x_unsorted[:, col_perm] ref_vals, ref_idx = jax.jit(jax.lax.top_k, static_argnums=(1,))(x, k) - prim_vals, prim_idx = jax.jit(air_topk, static_argnums=(1,))(x, k) + prim_vals, prim_idx = jax.jit(topk, static_argnums=(1,))(x, k) # Sort each row independently for comparison. ref_vals, ref_idx = jax.vmap(jax.lax.sort_key_val)(ref_vals, ref_idx) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 761c0eae8a..e578dffbfe 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -156,7 +156,7 @@ list(APPEND transformer_engine_cuda_sources normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu permutation/permutation.cu util/padding.cu - util/air_topk.cu + util/topk.cu swizzle/swizzle.cu swizzle/swizzle_block_scaling.cu fused_softmax/scaled_masked_softmax.cu diff --git a/transformer_engine/common/include/transformer_engine/air_topk.h b/transformer_engine/common/include/transformer_engine/topk.h similarity index 75% rename from transformer_engine/common/include/transformer_engine/air_topk.h rename to transformer_engine/common/include/transformer_engine/topk.h index 8ff48f1854..6dfca850ec 100644 --- a/transformer_engine/common/include/transformer_engine/air_topk.h +++ b/transformer_engine/common/include/transformer_engine/topk.h @@ -4,8 +4,8 @@ * See LICENSE for license information. 
************************************************************************/ -#ifndef TRANSFORMER_ENGINE_AIR_TOPK_H_ -#define TRANSFORMER_ENGINE_AIR_TOPK_H_ +#ifndef TRANSFORMER_ENGINE_TOPK_H_ +#define TRANSFORMER_ENGINE_TOPK_H_ #include "transformer_engine.h" @@ -32,26 +32,26 @@ extern "C" { * \param[in] seq_len Number of elements per row. * \param[in] k Number of top-K entries to select per row. * \param[in] workspace_bytes Workspace size in bytes; must be >= - * nvte_get_air_topk_workspace_bytes(batch_size, seq_len, k). + * nvte_get_topk_workspace_bytes(batch_size, seq_len, k). * * Supported key dtypes: float32, bfloat16. * Index dtype: int32. */ -void nvte_air_topk(cudaStream_t stream, const NVTETensor keys_in, const NVTETensor lengths_in, - NVTETensor keys_out, NVTETensor indices_out, NVTETensor workspace, - int batch_size, int seq_len, int k, size_t workspace_bytes); +void nvte_topk(cudaStream_t stream, const NVTETensor keys_in, const NVTETensor lengths_in, + NVTETensor keys_out, NVTETensor indices_out, NVTETensor workspace, + int batch_size, int seq_len, int k, size_t workspace_bytes); -/*! \brief Query the workspace size required by nvte_air_topk. +/*! \brief Query the workspace size required by nvte_topk. * * \param[in] batch_size Number of rows. * \param[in] seq_len Number of elements per row. * \param[in] k Top-K count. * \return Required workspace size in bytes. 
*/ -size_t nvte_get_air_topk_workspace_bytes(int batch_size, int seq_len, int k); +size_t nvte_get_topk_workspace_bytes(int batch_size, int seq_len, int k); #ifdef __cplusplus } // extern "C" #endif -#endif // TRANSFORMER_ENGINE_AIR_TOPK_H_ +#endif // TRANSFORMER_ENGINE_TOPK_H_ diff --git a/transformer_engine/common/util/air_topk.cu b/transformer_engine/common/util/air_topk.cu deleted file mode 100644 index aaae7af795..0000000000 --- a/transformer_engine/common/util/air_topk.cu +++ /dev/null @@ -1,57 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include - -#include "../common.h" -#include "standalone_air_topk.cuh" - -void nvte_air_topk(cudaStream_t stream, const NVTETensor keys_in, const NVTETensor lengths_in, - NVTETensor keys_out, NVTETensor indices_out, NVTETensor workspace, - int batch_size, int seq_len, int k, size_t workspace_bytes) { - NVTE_API_CALL(nvte_air_topk); - using namespace transformer_engine; - - const Tensor *keys_in_tensor = convertNVTETensorCheck(keys_in); - const Tensor *lengths_tensor = convertNVTETensorCheck(lengths_in); - Tensor *keys_out_tensor = convertNVTETensor(keys_out); - Tensor *indices_tensor = convertNVTETensor(indices_out); - Tensor *workspace_tensor = convertNVTETensor(workspace); - - void *d_workspace = workspace_tensor->data.dptr; - const int *d_lengths = reinterpret_cast(lengths_tensor->data.dptr); - int *d_indices = reinterpret_cast(indices_tensor->data.dptr); - - auto dtype = keys_in_tensor->data.dtype; - -#define DISPATCH_AIR_TOPK(T, d_in_cast, d_out_cast) \ - do { \ - const T *d_in = reinterpret_cast(keys_in_tensor->data.dptr); \ - T *d_out = reinterpret_cast(keys_out_tensor->data.dptr); \ - nv::standalone_air_topk(d_workspace, workspace_bytes, d_in, batch_size, seq_len, k, \ - 
d_out, d_indices, /*greater=*/true, stream, \ - const_cast(d_lengths), /*is_prefill=*/false); \ - } while (0) - - if (dtype == DType::kBFloat16) { - DISPATCH_AIR_TOPK(__nv_bfloat16, , ); - } else if (dtype == DType::kFloat32) { - DISPATCH_AIR_TOPK(float, , ); - } else { - NVTE_ERROR("nvte_air_topk: unsupported key dtype (supported: float32, bfloat16)"); - } - -#undef DISPATCH_AIR_TOPK -} - -size_t nvte_get_air_topk_workspace_bytes(int batch_size, int seq_len, int k) { - // Call with buf=nullptr to perform a size query (no GPU work is launched). - size_t buf_size = 0; - nv::standalone_air_topk(nullptr, buf_size, nullptr, batch_size, seq_len, k, nullptr, - nullptr, /*greater=*/true, /*stream=*/nullptr, - /*lengths=*/nullptr, /*is_prefill=*/false); - return buf_size; -} diff --git a/transformer_engine/common/util/standalone_air_topk.cuh b/transformer_engine/common/util/standalone_topk.cuh similarity index 97% rename from transformer_engine/common/util/standalone_air_topk.cuh rename to transformer_engine/common/util/standalone_topk.cuh index 6ef8dbafe1..1f1b5a07bf 100644 --- a/transformer_engine/common/util/standalone_air_topk.cuh +++ b/transformer_engine/common/util/standalone_topk.cuh @@ -59,7 +59,7 @@ constexpr int WARP_SIZE = 32; constexpr int WARP_BITS = 5; constexpr unsigned FULL_WARP_MASK = 0xffffffff; -namespace air_topk { +namespace topk { using WideT = float4; #ifdef __CUDA_ARCH__ @@ -1115,7 +1115,7 @@ __global__ void radix_topk_one_block_kernel(const T *in, const IdxT *in_idx, con in, in_idx, actual_len, k, out, out_idx, select_min, buf1, idx_buf1, buf2, idx_buf2); } // end kernel -} // namespace air_topk +} // namespace topk /***************Runtime API****************/ @@ -1124,17 +1124,17 @@ void standalone_radix_topk_(void *buf, size_t &buf_size, const T *in, const IdxT int batch_size, IdxT len, IdxT k, T *out, IdxT *out_idx, bool select_min, bool fused_last_filter, unsigned grid_dim, cudaStream_t stream, IdxT *lengths = nullptr) { - 
static_assert(air_topk::calc_num_passes() > 1); - constexpr int num_buckets = air_topk::calc_num_buckets(); + static_assert(topk::calc_num_passes() > 1); + constexpr int num_buckets = topk::calc_num_buckets(); - air_topk::Counter *counters = nullptr; + topk::Counter *counters = nullptr; IdxT *histograms = nullptr; T *buf1 = nullptr; IdxT *idx_buf1 = nullptr; T *buf2 = nullptr; IdxT *idx_buf2 = nullptr; { - IdxT len_candidates = air_topk::calc_buf_len(len); + IdxT len_candidates = topk::calc_buf_len(len); std::vector sizes = {sizeof(*counters) * batch_size, sizeof(*histograms) * num_buckets * batch_size, sizeof(*buf1) * len_candidates * batch_size, @@ -1167,20 +1167,20 @@ void standalone_radix_topk_(void *buf, size_t &buf_size, const T *in, const IdxT dim3 blocks(grid_dim, batch_size); - constexpr int num_passes = air_topk::calc_num_passes(); + constexpr int num_passes = topk::calc_num_passes(); - auto kernel = air_topk::radix_kernel; + auto kernel = topk::radix_kernel; for (int pass = 0; pass < num_passes; ++pass) { - air_topk::set_buf_pointers(in, in_idx, buf1, idx_buf1, buf2, idx_buf2, pass, in_buf, in_idx_buf, + topk::set_buf_pointers(in, in_idx, buf1, idx_buf1, buf2, idx_buf2, pass, in_buf, in_idx_buf, out_buf, out_idx_buf); if (fused_last_filter && pass == num_passes - 1 && out != nullptr) { - kernel = air_topk::radix_kernel; + kernel = topk::radix_kernel; } else if (fused_last_filter && pass == num_passes - 1 && out == nullptr) { - kernel = air_topk::radix_kernel; + kernel = topk::radix_kernel; } else if (out == nullptr) { - kernel = air_topk::radix_kernel; + kernel = topk::radix_kernel; } kernel<<>>(in, in_idx, in_buf, in_idx_buf, out_buf, out_idx_buf, @@ -1190,10 +1190,10 @@ void standalone_radix_topk_(void *buf, size_t &buf_size, const T *in, const IdxT if (!fused_last_filter) { if (out != nullptr) { - air_topk::last_filter_kernel<<>>( + topk::last_filter_kernel<<>>( in, in_idx, out_buf, out_idx_buf, out, out_idx, len, k, counters, select_min); } else { - 
air_topk::last_filter_kernel<<>>( + topk::last_filter_kernel<<>>( in, in_idx, out_buf, out_idx_buf, out, out_idx, len, k, counters, select_min); } } @@ -1204,7 +1204,7 @@ void standalone_radix_topk_one_block_(void *buf, size_t &buf_size, const T *in, int batch_size, IdxT len, IdxT k, T *out, IdxT *out_idx, bool select_min, cudaStream_t stream, IdxT *lengths = nullptr) { - static_assert(air_topk::calc_num_passes() > 1); + static_assert(topk::calc_num_passes() > 1); T *buf1 = nullptr; IdxT *idx_buf1 = nullptr; @@ -1228,18 +1228,18 @@ void standalone_radix_topk_one_block_(void *buf, size_t &buf_size, const T *in, } if (out != nullptr) { - air_topk::radix_topk_one_block_kernel + topk::radix_topk_one_block_kernel <<>>(in, in_idx, len, k, out, out_idx, select_min, buf1, idx_buf1, buf2, idx_buf2, lengths); } else { - air_topk::radix_topk_one_block_kernel + topk::radix_topk_one_block_kernel <<>>(in, in_idx, len, k, out, out_idx, select_min, buf1, idx_buf1, buf2, idx_buf2, lengths); } } template -void standalone_air_topk(void *buf, size_t &buf_size, const T *in, int batch_size, idxT len, idxT k, +void standalone_topk(void *buf, size_t &buf_size, const T *in, int batch_size, idxT len, idxT k, T *out, idxT *out_idx, bool greater, cudaStream_t stream = 0, idxT *lengths = nullptr, bool is_prefill = false) { constexpr int items_per_thread = 32; @@ -1266,7 +1266,7 @@ void standalone_air_topk(void *buf, size_t &buf_size, const T *in, int batch_siz sm_cnt = cached_sm_cnt; } unsigned grid_dim = - air_topk::calc_grid_dim(batch_size, len, sm_cnt); + topk::calc_grid_dim(batch_size, len, sm_cnt); if (grid_dim == 1) { standalone_radix_topk_one_block_( diff --git a/transformer_engine/common/util/topk.cu b/transformer_engine/common/util/topk.cu new file mode 100644 index 0000000000..2591917ee7 --- /dev/null +++ b/transformer_engine/common/util/topk.cu @@ -0,0 +1,57 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION 
& AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include + +#include "../common.h" +#include "standalone_topk.cuh" + +void nvte_topk(cudaStream_t stream, const NVTETensor keys_in, const NVTETensor lengths_in, + NVTETensor keys_out, NVTETensor indices_out, NVTETensor workspace, + int batch_size, int seq_len, int k, size_t workspace_bytes) { + NVTE_API_CALL(nvte_topk); + using namespace transformer_engine; + + const Tensor *keys_in_tensor = convertNVTETensorCheck(keys_in); + const Tensor *lengths_tensor = convertNVTETensorCheck(lengths_in); + Tensor *keys_out_tensor = convertNVTETensor(keys_out); + Tensor *indices_tensor = convertNVTETensor(indices_out); + Tensor *workspace_tensor = convertNVTETensor(workspace); + + void *d_workspace = workspace_tensor->data.dptr; + const int *d_lengths = reinterpret_cast(lengths_tensor->data.dptr); + int *d_indices = reinterpret_cast(indices_tensor->data.dptr); + + auto dtype = keys_in_tensor->data.dtype; + +#define DISPATCH_TOPK(T, d_in_cast, d_out_cast) \ + do { \ + const T *d_in = reinterpret_cast(keys_in_tensor->data.dptr); \ + T *d_out = reinterpret_cast(keys_out_tensor->data.dptr); \ + nv::standalone_topk(d_workspace, workspace_bytes, d_in, batch_size, seq_len, k, \ + d_out, d_indices, /*greater=*/true, stream, \ + const_cast(d_lengths), /*is_prefill=*/false); \ + } while (0) + + if (dtype == DType::kBFloat16) { + DISPATCH_TOPK(__nv_bfloat16, , ); + } else if (dtype == DType::kFloat32) { + DISPATCH_TOPK(float, , ); + } else { + NVTE_ERROR("nvte_topk: unsupported key dtype (supported: float32, bfloat16)"); + } + +#undef DISPATCH_TOPK +} + +size_t nvte_get_topk_workspace_bytes(int batch_size, int seq_len, int k) { + // Call with buf=nullptr to perform a size query (no GPU work is launched). 
+ size_t buf_size = 0; + nv::standalone_topk(nullptr, buf_size, nullptr, batch_size, seq_len, k, nullptr, + nullptr, /*greater=*/true, /*stream=*/nullptr, + /*lengths=*/nullptr, /*is_prefill=*/false); + return buf_size; +} diff --git a/transformer_engine/jax/cpp_extensions/__init__.py b/transformer_engine/jax/cpp_extensions/__init__.py index bbea60e345..fe1f93dc7a 100644 --- a/transformer_engine/jax/cpp_extensions/__init__.py +++ b/transformer_engine/jax/cpp_extensions/__init__.py @@ -10,4 +10,4 @@ from .softmax import * from .gemm import * from .router import * -from .air_topk import * +from .topk import * diff --git a/transformer_engine/jax/cpp_extensions/air_topk.py b/transformer_engine/jax/cpp_extensions/topk.py similarity index 75% rename from transformer_engine/jax/cpp_extensions/air_topk.py rename to transformer_engine/jax/cpp_extensions/topk.py index f9189faaa1..120997f10d 100644 --- a/transformer_engine/jax/cpp_extensions/air_topk.py +++ b/transformer_engine/jax/cpp_extensions/topk.py @@ -1,7 +1,7 @@ # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. -"""AIR TopK custom op""" +"""TopK custom op""" import functools from typing import Tuple @@ -12,31 +12,31 @@ from .base import BasePrimitive, register_primitive -__all__ = ["air_topk"] +__all__ = ["topk"] @functools.lru_cache(maxsize=512) -def get_air_topk_workspace_bytes(batch_size: int, seq_len: int, k: int) -> int: - """Query the workspace size required for AIR TopK. +def get_topk_workspace_bytes(batch_size: int, seq_len: int, k: int) -> int: + """Query the workspace size required for TopK. The result is memoised per (batch_size, seq_len, k) tuple so that repeated JIT compilations with the same shapes incur only one host-side CUDA call. 
""" import transformer_engine_jax as _te_jax - return int(_te_jax.get_air_topk_workspace_bytes(batch_size, seq_len, k)) + return int(_te_jax.get_topk_workspace_bytes(batch_size, seq_len, k)) -class AirTopKPrimitive(BasePrimitive): +class TopKPrimitive(BasePrimitive): """ - AIR TopK Primitive + TopK Primitive Selects the top-k entries (by value) from each row of a 2-D input using the AIR radix-selection algorithm. Returns both the top-k key values and their column indices within each row. """ - name = "te_air_topk_ffi" + name = "te_topk_ffi" multiple_results = True impl_static_args = (2,) # k_value inner_primitive = None @@ -53,12 +53,12 @@ def abstract( assert keys_dtype in [ jnp.float32, jnp.bfloat16, - ], f"air_topk: unsupported key dtype {keys_dtype}; supported: float32, bfloat16" - assert in_keys_aval.ndim == 2, "air_topk: keys input must be 2D (batch_size, seq_len)" + ], f"topk: unsupported key dtype {keys_dtype}; supported: float32, bfloat16" + assert in_keys_aval.ndim == 2, "topk: keys input must be 2D (batch_size, seq_len)" assert dtypes.canonicalize_dtype(in_lengths_aval.dtype) == jnp.int32 batch_size, seq_len = in_keys_aval.shape - workspace_bytes = get_air_topk_workspace_bytes(batch_size, seq_len, k_value) + workspace_bytes = get_topk_workspace_bytes(batch_size, seq_len, k_value) out_shape = (batch_size, k_value) out_keys_aval = jax.core.ShapedArray(shape=out_shape, dtype=keys_dtype) @@ -68,7 +68,7 @@ def abstract( @staticmethod def outer_abstract(*args, **kwargs): - out_keys_aval, out_indices_aval, _workspace_aval = AirTopKPrimitive.abstract( + out_keys_aval, out_indices_aval, _workspace_aval = TopKPrimitive.abstract( *args, **kwargs ) return (out_keys_aval, out_indices_aval) @@ -77,8 +77,8 @@ def outer_abstract(*args, **kwargs): def lowering(ctx, in_keys, in_lengths, k_value): keys_aval = ctx.avals_in[0] batch_size, seq_len = keys_aval.shape - workspace_bytes = get_air_topk_workspace_bytes(batch_size, seq_len, k_value) - return 
ffi.ffi_lowering(AirTopKPrimitive.name)( + workspace_bytes = get_topk_workspace_bytes(batch_size, seq_len, k_value) + return ffi.ffi_lowering(TopKPrimitive.name)( ctx, in_keys, in_lengths, @@ -88,8 +88,8 @@ def lowering(ctx, in_keys, in_lengths, k_value): @staticmethod def impl(in_keys, in_lengths, k_value): - assert AirTopKPrimitive.inner_primitive is not None - out_keys, out_indices, _workspace = AirTopKPrimitive.inner_primitive.bind( + assert TopKPrimitive.inner_primitive is not None + out_keys, out_indices, _workspace = TopKPrimitive.inner_primitive.bind( in_keys, in_lengths, k_value=k_value, @@ -97,10 +97,10 @@ def impl(in_keys, in_lengths, k_value): return (out_keys, out_indices) -register_primitive(AirTopKPrimitive) +register_primitive(TopKPrimitive) -def air_topk( +def topk( x: jnp.ndarray, k_value: int, ) -> Tuple[jnp.ndarray, jnp.ndarray]: @@ -124,7 +124,7 @@ def air_topk( batch_size, seq_len = x.shape lengths = jnp.full((batch_size,), seq_len, dtype=jnp.int32) - out_keys, out_indices = AirTopKPrimitive.outer_primitive.bind( + out_keys, out_indices = TopKPrimitive.outer_primitive.bind( x, lengths, k_value=k_value, diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index a7960b6882..cc6d3a7e66 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -171,9 +171,9 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedTopkWithScoreFunctionBackwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossForwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossBackwardHandler); -// AIR TopK -XLA_FFI_DECLARE_HANDLER_SYMBOL(AirTopkHandler); -int64_t GetAirTopkWorkspaceBytes(int batch_size, int seq_len, int k); +// TopK +XLA_FFI_DECLARE_HANDLER_SYMBOL(TopkHandler); +int64_t GetTopkWorkspaceBytes(int batch_size, int seq_len, int k); } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp 
b/transformer_engine/jax/csrc/extensions/pybind.cpp index f1be141dc1..3308fa6c4a 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -100,8 +100,8 @@ pybind11::dict Registrations() { dict["te_fused_moe_aux_loss_forward_ffi"] = EncapsulateFFI(FusedMoEAuxLossForwardHandler); dict["te_fused_moe_aux_loss_backward_ffi"] = EncapsulateFFI(FusedMoEAuxLossBackwardHandler); - // AIR TopK - dict["te_air_topk_ffi"] = EncapsulateFFI(AirTopkHandler); + // TopK + dict["te_topk_ffi"] = EncapsulateFFI(TopkHandler); return dict; } @@ -120,7 +120,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("get_norm_bwd_workspace_sizes", &GetNormBackwardWorkspaceSizes); m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes); m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes); - m.def("get_air_topk_workspace_bytes", &GetAirTopkWorkspaceBytes); + m.def("get_topk_workspace_bytes", &GetTopkWorkspaceBytes); m.def("nvte_get_qkv_format", &nvte_get_qkv_format); m.def("is_non_nt_fp8_gemm_supported", &nvte_is_non_tn_fp8_gemm_supported); m.def("initialize_cgemm_communicator", &InitializeCgemmCommunicator); diff --git a/transformer_engine/jax/csrc/extensions/air_topk.cpp b/transformer_engine/jax/csrc/extensions/topk.cpp similarity index 73% rename from transformer_engine/jax/csrc/extensions/air_topk.cpp rename to transformer_engine/jax/csrc/extensions/topk.cpp index 1fef1434c1..90d544441b 100644 --- a/transformer_engine/jax/csrc/extensions/air_topk.cpp +++ b/transformer_engine/jax/csrc/extensions/topk.cpp @@ -4,7 +4,7 @@ * See LICENSE for license information. 
************************************************************************/ -#include "transformer_engine/air_topk.h" +#include "transformer_engine/topk.h" #include "../extensions.h" #include "xla/ffi/api/c_api.h" @@ -16,17 +16,17 @@ namespace jax { // JAX FFI handler // --------------------------------------------------------------------------- -Error_Type AirTopkFFI(cudaStream_t stream, Buffer_Type keys_in_buf, Buffer_Type lengths_buf, - Result_Type keys_out_buf, Result_Type indices_out_buf, - Result_Type workspace_buf, int64_t k_value, int64_t workbuf_bytes) { +Error_Type TopkFFI(cudaStream_t stream, Buffer_Type keys_in_buf, Buffer_Type lengths_buf, + Result_Type keys_out_buf, Result_Type indices_out_buf, + Result_Type workspace_buf, int64_t k_value, int64_t workbuf_bytes) { auto keys_in_dtype = convert_ffi_datatype_to_te_dtype(keys_in_buf.element_type()); auto keys_out_dtype = convert_ffi_datatype_to_te_dtype(keys_out_buf->element_type()); auto idx_out_dtype = convert_ffi_datatype_to_te_dtype(indices_out_buf->element_type()); - NVTE_CHECK(keys_in_dtype == keys_out_dtype, "AirTopkFFI: input and output key dtypes must match"); - NVTE_CHECK(idx_out_dtype == DType::kInt32, "AirTopkFFI: index output must be int32"); + NVTE_CHECK(keys_in_dtype == keys_out_dtype, "TopkFFI: input and output key dtypes must match"); + NVTE_CHECK(idx_out_dtype == DType::kInt32, "TopkFFI: index output must be int32"); auto keys_in_shape = keys_in_buf.dimensions(); - NVTE_CHECK(keys_in_shape.size() == 2, "AirTopkFFI: keys input must be 2D (batch_size, seq_len)"); + NVTE_CHECK(keys_in_shape.size() == 2, "TopkFFI: keys input must be 2D (batch_size, seq_len)"); int batch_size = static_cast(keys_in_shape[0]); int seq_len = static_cast(keys_in_shape[1]); @@ -38,7 +38,7 @@ Error_Type AirTopkFFI(cudaStream_t stream, Buffer_Type keys_in_buf, Buffer_Type case DType::kBFloat16: break; default: - NVTE_ERROR("AirTopkFFI: unsupported key dtype (float32 and bfloat16 only)"); + NVTE_ERROR("TopkFFI: 
unsupported key dtype (float32 and bfloat16 only)"); } // Build flat TensorWrappers over the full (batch_size * seq_len) / (batch_size * k) buffers. @@ -57,14 +57,14 @@ Error_Type AirTopkFFI(cudaStream_t stream, Buffer_Type keys_in_buf, Buffer_Type TensorWrapper(indices_out_buf->untyped_data(), flat_out_shape, DType::kInt32); auto workspace_tensor = TensorWrapper(workspace_buf->untyped_data(), ws_shape, DType::kByte); - nvte_air_topk(stream, keys_in_tensor.data(), lengths_tensor.data(), keys_out_tensor.data(), - idx_out_tensor.data(), workspace_tensor.data(), batch_size, seq_len, k, - static_cast(workbuf_bytes)); + nvte_topk(stream, keys_in_tensor.data(), lengths_tensor.data(), keys_out_tensor.data(), + idx_out_tensor.data(), workspace_tensor.data(), batch_size, seq_len, k, + static_cast(workbuf_bytes)); return ffi_with_cuda_error_check(); } -XLA_FFI_DEFINE_HANDLER_SYMBOL(AirTopkHandler, AirTopkFFI, +XLA_FFI_DEFINE_HANDLER_SYMBOL(TopkHandler, TopkFFI, FFI::Bind() .Ctx() // stream .Arg() // keys_in @@ -80,8 +80,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(AirTopkHandler, AirTopkFFI, // Workspace-size query exposed to Python // --------------------------------------------------------------------------- -int64_t GetAirTopkWorkspaceBytes(int batch_size, int seq_len, int k) { - return static_cast(nvte_get_air_topk_workspace_bytes(batch_size, seq_len, k)); +int64_t GetTopkWorkspaceBytes(int batch_size, int seq_len, int k) { + return static_cast(nvte_get_topk_workspace_bytes(batch_size, seq_len, k)); } } // namespace jax From 0862bca024093796a53680b78f7cf25a0d5243e0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 16 Apr 2026 05:35:26 +0000 Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/include/transformer_engine/topk.h | 4 ++-- .../common/util/standalone_topk.cuh | 9 ++++----- transformer_engine/common/util/topk.cu | 16 
++++++++-------- transformer_engine/jax/cpp_extensions/topk.py | 4 +--- transformer_engine/jax/csrc/extensions/topk.cpp | 4 ++-- 5 files changed, 17 insertions(+), 20 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/topk.h b/transformer_engine/common/include/transformer_engine/topk.h index 6dfca850ec..1149f4218a 100644 --- a/transformer_engine/common/include/transformer_engine/topk.h +++ b/transformer_engine/common/include/transformer_engine/topk.h @@ -38,8 +38,8 @@ extern "C" { * Index dtype: int32. */ void nvte_topk(cudaStream_t stream, const NVTETensor keys_in, const NVTETensor lengths_in, - NVTETensor keys_out, NVTETensor indices_out, NVTETensor workspace, - int batch_size, int seq_len, int k, size_t workspace_bytes); + NVTETensor keys_out, NVTETensor indices_out, NVTETensor workspace, int batch_size, + int seq_len, int k, size_t workspace_bytes); /*! \brief Query the workspace size required by nvte_topk. * diff --git a/transformer_engine/common/util/standalone_topk.cuh b/transformer_engine/common/util/standalone_topk.cuh index 1f1b5a07bf..48f7b6e876 100644 --- a/transformer_engine/common/util/standalone_topk.cuh +++ b/transformer_engine/common/util/standalone_topk.cuh @@ -1173,7 +1173,7 @@ void standalone_radix_topk_(void *buf, size_t &buf_size, const T *in, const IdxT for (int pass = 0; pass < num_passes; ++pass) { topk::set_buf_pointers(in, in_idx, buf1, idx_buf1, buf2, idx_buf2, pass, in_buf, in_idx_buf, - out_buf, out_idx_buf); + out_buf, out_idx_buf); if (fused_last_filter && pass == num_passes - 1 && out != nullptr) { kernel = topk::radix_kernel; @@ -1240,8 +1240,8 @@ void standalone_radix_topk_one_block_(void *buf, size_t &buf_size, const T *in, template void standalone_topk(void *buf, size_t &buf_size, const T *in, int batch_size, idxT len, idxT k, - T *out, idxT *out_idx, bool greater, cudaStream_t stream = 0, - idxT *lengths = nullptr, bool is_prefill = false) { + T *out, idxT *out_idx, bool greater, cudaStream_t stream 
= 0, + idxT *lengths = nullptr, bool is_prefill = false) { constexpr int items_per_thread = 32; constexpr int multi_block_dim = 256; constexpr int single_block_dim = 1024; @@ -1265,8 +1265,7 @@ void standalone_topk(void *buf, size_t &buf_size, const T *in, int batch_size, i } sm_cnt = cached_sm_cnt; } - unsigned grid_dim = - topk::calc_grid_dim(batch_size, len, sm_cnt); + unsigned grid_dim = topk::calc_grid_dim(batch_size, len, sm_cnt); if (grid_dim == 1) { standalone_radix_topk_one_block_( diff --git a/transformer_engine/common/util/topk.cu b/transformer_engine/common/util/topk.cu index 2591917ee7..196d5d9fae 100644 --- a/transformer_engine/common/util/topk.cu +++ b/transformer_engine/common/util/topk.cu @@ -10,8 +10,8 @@ #include "standalone_topk.cuh" void nvte_topk(cudaStream_t stream, const NVTETensor keys_in, const NVTETensor lengths_in, - NVTETensor keys_out, NVTETensor indices_out, NVTETensor workspace, - int batch_size, int seq_len, int k, size_t workspace_bytes) { + NVTETensor keys_out, NVTETensor indices_out, NVTETensor workspace, int batch_size, + int seq_len, int k, size_t workspace_bytes) { NVTE_API_CALL(nvte_topk); using namespace transformer_engine; @@ -27,13 +27,13 @@ void nvte_topk(cudaStream_t stream, const NVTETensor keys_in, const NVTETensor l auto dtype = keys_in_tensor->data.dtype; -#define DISPATCH_TOPK(T, d_in_cast, d_out_cast) \ +#define DISPATCH_TOPK(T, d_in_cast, d_out_cast) \ do { \ - const T *d_in = reinterpret_cast(keys_in_tensor->data.dptr); \ - T *d_out = reinterpret_cast(keys_out_tensor->data.dptr); \ - nv::standalone_topk(d_workspace, workspace_bytes, d_in, batch_size, seq_len, k, \ - d_out, d_indices, /*greater=*/true, stream, \ - const_cast(d_lengths), /*is_prefill=*/false); \ + const T *d_in = reinterpret_cast(keys_in_tensor->data.dptr); \ + T *d_out = reinterpret_cast(keys_out_tensor->data.dptr); \ + nv::standalone_topk(d_workspace, workspace_bytes, d_in, batch_size, seq_len, k, d_out, \ + d_indices, /*greater=*/true, stream, 
const_cast(d_lengths), \ + /*is_prefill=*/false); \ } while (0) if (dtype == DType::kBFloat16) { diff --git a/transformer_engine/jax/cpp_extensions/topk.py b/transformer_engine/jax/cpp_extensions/topk.py index 120997f10d..235540e354 100644 --- a/transformer_engine/jax/cpp_extensions/topk.py +++ b/transformer_engine/jax/cpp_extensions/topk.py @@ -68,9 +68,7 @@ def abstract( @staticmethod def outer_abstract(*args, **kwargs): - out_keys_aval, out_indices_aval, _workspace_aval = TopKPrimitive.abstract( - *args, **kwargs - ) + out_keys_aval, out_indices_aval, _workspace_aval = TopKPrimitive.abstract(*args, **kwargs) return (out_keys_aval, out_indices_aval) @staticmethod diff --git a/transformer_engine/jax/csrc/extensions/topk.cpp b/transformer_engine/jax/csrc/extensions/topk.cpp index 90d544441b..c4e80db827 100644 --- a/transformer_engine/jax/csrc/extensions/topk.cpp +++ b/transformer_engine/jax/csrc/extensions/topk.cpp @@ -17,8 +17,8 @@ namespace jax { // --------------------------------------------------------------------------- Error_Type TopkFFI(cudaStream_t stream, Buffer_Type keys_in_buf, Buffer_Type lengths_buf, - Result_Type keys_out_buf, Result_Type indices_out_buf, - Result_Type workspace_buf, int64_t k_value, int64_t workbuf_bytes) { + Result_Type keys_out_buf, Result_Type indices_out_buf, Result_Type workspace_buf, + int64_t k_value, int64_t workbuf_bytes) { auto keys_in_dtype = convert_ffi_datatype_to_te_dtype(keys_in_buf.element_type()); auto keys_out_dtype = convert_ffi_datatype_to_te_dtype(keys_out_buf->element_type()); auto idx_out_dtype = convert_ffi_datatype_to_te_dtype(indices_out_buf->element_type());