diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 613aefc178..547a26e24a 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -47,6 +47,7 @@ from transformer_engine.jax.activation import activation from transformer_engine.jax.dense import dense, grouped_dense from transformer_engine.jax.layernorm_dense import layernorm_dense +from transformer_engine.jax.cpp_extensions.topk import topk GEMM_CASES = [ (256, 256, 512), @@ -1955,3 +1956,83 @@ def f(x): actual = load_array_dump("my_tensor_gpu0.bin", shape, dtype) assert_allclose(actual, expected, dtype=dtype) + + +@pytest.mark.parametrize("dtype", [jnp.bfloat16, jnp.float32]) +@pytest.mark.parametrize( + "problem_size", + [ + (1, 10000, 100), + (1, 50000, 200), + (4, 16384, 256), + (8, 65536, 512), + (1, 1000000, 1000), + ], +) +class TestTopK: + """Correctness tests for the TopK JAX primitive. + + Each test generates an input whose top-k entries lie in a known value range + so that correctness can be verified without a full sort, then cross-checks + against jax.lax.top_k as a reference. + """ + + def test_topk_1d(self, dtype, problem_size): + """1-D input: single row.""" + _bs, n, k = problem_size + + prng_key = jax.random.PRNGKey(0) + keys = jax.random.split(prng_key, 3) + topk_vals = jax.random.uniform(keys[0], shape=(k,), dtype=dtype, minval=1.5, maxval=2.5) + bottom_vals = jax.random.uniform( + keys[1], shape=(n - k,), dtype=dtype, minval=0.0, maxval=1.0 + ) + x = jax.random.permutation(keys[2], jnp.concatenate([topk_vals, bottom_vals])) + + ref_vals, ref_idx = jax.jit(jax.lax.top_k, static_argnums=(1,))(x, k) + prim_vals, prim_idx = jax.jit(topk, static_argnums=(1,))(x, k) + + # AIR TopK output is unordered; sort before comparing. + ref_vals, ref_idx = jax.lax.sort_key_val(ref_vals, ref_idx) + prim_vals, prim_idx = jax.lax.sort_key_val(prim_vals, prim_idx) + + assert_allclose(prim_vals, ref_vals, dtype=dtype) + + sorted_x = jax.lax.sort(x) + assert prim_vals[0] >= sorted_x[-(k + 1)] + + # Values at returned indices must match reference. + assert_allclose(x[prim_idx], x[ref_idx], dtype=dtype) + + def test_topk_2d(self, dtype, problem_size): + """2-D input: each row is an independent top-k problem.""" + bs, n, k = problem_size + + prng_key = jax.random.PRNGKey(42) + keys = jax.random.split(prng_key, 3) + topk_vals = jax.random.uniform(keys[0], shape=(bs, k), dtype=dtype, minval=1.5, maxval=2.5) + bottom_vals = jax.random.uniform( + keys[1], shape=(bs, n - k), dtype=dtype, minval=0.0, maxval=1.0 + ) + x_unsorted = jnp.concatenate([topk_vals, bottom_vals], axis=1) + # Shuffle columns independently per row. + col_perm = jax.random.permutation(keys[2], n) + x = x_unsorted[:, col_perm] + + ref_vals, ref_idx = jax.jit(jax.lax.top_k, static_argnums=(1,))(x, k) + prim_vals, prim_idx = jax.jit(topk, static_argnums=(1,))(x, k) + + # Sort each row independently for comparison. + ref_vals, ref_idx = jax.vmap(jax.lax.sort_key_val)(ref_vals, ref_idx) + prim_vals, prim_idx = jax.vmap(jax.lax.sort_key_val)(prim_vals, prim_idx) + + assert_allclose(prim_vals, ref_vals, dtype=dtype) + + # For each row, the smallest selected value must be >= the (k+1)-th largest in that row. + sorted_x = jnp.sort(x, axis=1) + assert jnp.all(prim_vals[:, 0] >= sorted_x[:, -(k + 1)]) + + # Values at returned indices must match reference values. + prim_gathered = jnp.take_along_axis(x, prim_idx, axis=1) + ref_gathered = jnp.take_along_axis(x, ref_idx, axis=1) + assert_allclose(prim_gathered, ref_gathered, dtype=dtype) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index b9e2b907e0..e578dffbfe 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -79,6 +79,11 @@ if(NOT arch_120_index EQUAL -1) endif() endif() +# If all architectures were special-cased and removed, disable CMake's automatic +# CUDA_ARCHITECTURES management — compilation flags are set via COMPILE_OPTIONS below. +if(NOT CMAKE_CUDA_ARCHITECTURES) + set(CMAKE_CUDA_ARCHITECTURES OFF) +endif() # cuDNN frontend API set(CUDNN_FRONTEND_INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/include") @@ -151,6 +156,7 @@ list(APPEND transformer_engine_cuda_sources normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu permutation/permutation.cu util/padding.cu + util/topk.cu swizzle/swizzle.cu swizzle/swizzle_block_scaling.cu fused_softmax/scaled_masked_softmax.cu diff --git a/transformer_engine/common/include/transformer_engine/topk.h b/transformer_engine/common/include/transformer_engine/topk.h new file mode 100644 index 0000000000..1149f4218a --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/topk.h @@ -0,0 +1,57 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_TOPK_H_ +#define TRANSFORMER_ENGINE_TOPK_H_ + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/*! \brief Compute the top-K (key, index) pairs using the AIR radix algorithm. + * + * Operates on a batch of rows: each row of length \p seq_len is processed + * independently and the \p k largest entries are selected. + * + * \param[in] stream CUDA stream used for the operation. + * \param[in] keys_in Input keys tensor, flat storage for + * batch_size rows of seq_len elements. + * \param[in] lengths_in Per-row lengths, shape (batch_size,); int32. + * Fill with seq_len for uniform-length batches. + * \param[in,out] keys_out Output top-k keys, flat storage for + * batch_size rows of k elements. + * \param[in,out] indices_out Output top-k indices (within each row), + * flat storage for batch_size rows of k int32 elements. + * \param[in,out] workspace Workspace tensor, shape (workspace_bytes,). + * \param[in] batch_size Number of rows. + * \param[in] seq_len Number of elements per row. + * \param[in] k Number of top-K entries to select per row. + * \param[in] workspace_bytes Workspace size in bytes; must be >= + * nvte_get_topk_workspace_bytes(batch_size, seq_len, k). + * + * Supported key dtypes: float32, bfloat16. + * Index dtype: int32. + */ +void nvte_topk(cudaStream_t stream, const NVTETensor keys_in, const NVTETensor lengths_in, + NVTETensor keys_out, NVTETensor indices_out, NVTETensor workspace, int batch_size, + int seq_len, int k, size_t workspace_bytes); + +/*! \brief Query the workspace size required by nvte_topk. + * + * \param[in] batch_size Number of rows. + * \param[in] seq_len Number of elements per row. + * \param[in] k Top-K count. + * \return Required workspace size in bytes. + */ +size_t nvte_get_topk_workspace_bytes(int batch_size, int seq_len, int k); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TRANSFORMER_ENGINE_TOPK_H_ diff --git a/transformer_engine/common/util/standalone_topk.cuh b/transformer_engine/common/util/standalone_topk.cuh new file mode 100644 index 0000000000..48f7b6e876 --- /dev/null +++ b/transformer_engine/common/util/standalone_topk.cuh @@ -0,0 +1,1281 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#pragma once + +#include +#include +#include + +#include +#include +#include +namespace cg = cooperative_groups; + +// Workspace pointer-alignment helpers. +inline size_t calc_aligned_size(const std::vector &sizes) { + const size_t ALIGN_BYTES = 256; + const size_t ALIGN_MASK = ~(ALIGN_BYTES - 1); + size_t total = 0; + for (auto sz : sizes) total += (sz + ALIGN_BYTES - 1) & ALIGN_MASK; + return total + ALIGN_BYTES - 1; +} +inline std::vector calc_aligned_pointers(const void *p, const std::vector &sizes) { + const size_t ALIGN_BYTES = 256; + const size_t ALIGN_MASK = ~(ALIGN_BYTES - 1); + char *ptr = + reinterpret_cast((reinterpret_cast(p) + ALIGN_BYTES - 1) & ALIGN_MASK); + std::vector ptrs; + ptrs.reserve(sizes.size()); + for (auto sz : sizes) { + ptrs.push_back(ptr); + ptr += (sz + ALIGN_BYTES - 1) & ALIGN_MASK; + } + return ptrs; +} + +// Helper: convert a float literal to type T without relying on implicit +// conversions (needed when __CUDA_NO_BFLOAT16_CONVERSIONS__ is defined). +namespace nv_detail { +template +__host__ __device__ inline T float_to_T(float v) { + return static_cast(v); +} +#if defined(__CUDACC__) +template <> +__host__ __device__ inline __nv_bfloat16 float_to_T<__nv_bfloat16>(float v) { + return __float2bfloat16(v); +} +#endif +} // namespace nv_detail + +namespace nv { + +constexpr int VECTORIZED_READ_SIZE = 16; +constexpr int WARP_SIZE = 32; +constexpr int WARP_BITS = 5; +constexpr unsigned FULL_WARP_MASK = 0xffffffff; + +namespace topk { +using WideT = float4; + +#ifdef __CUDA_ARCH__ +using ::atomicAdd; +inline __device__ size_t atomicAdd(size_t *address, size_t value) { + static_assert(sizeof(size_t) == sizeof(unsigned long long int)); + return atomicAdd((unsigned long long int *)address, (unsigned long long int)value); +} +#endif + +template +__host__ __device__ constexpr int calc_num_buckets() { + return 1 << BitsPerPass; +} + +/** + * @brief Provide a ceiling division operation ie. ceil(a / b) + * @tparam IntType supposed to be only integers for now! + */ +template +constexpr __host__ __device__ IntType ceildiv(IntType a, IntType b) { + return (a + b - 1) / b; +} + +/** + * @brief Provide an alignment function ie. ceil(a / b) * b + * @tparam IntType supposed to be only integers for now! + */ +template +constexpr __host__ __device__ IntType alignTo(IntType a, IntType b) { + return ceildiv(a, b) * b; +} + +template +__host__ __device__ constexpr int calc_num_passes() { + return ceildiv(sizeof(T) * 8, BitsPerPass); +} + +__host__ __device__ int round(int num, int round_value) { + return ((num - 1) / round_value + 1) * round_value; +} + +/** + * Bit 0 is the least significant (rightmost); + * this implementation processes input from the most to the least significant + * bit. This way, we can skip some passes in the end at the cost of having an + * unsorted output. + * + * NB: Use pass=-1 for calc_mask(). + */ +template +__device__ constexpr int calc_start_bit(int pass) { + int start_bit = static_cast(sizeof(T) * 8) - (pass + 1) * BitsPerPass; + if (start_bit < 0) { + start_bit = 0; + } + return start_bit; +} + +template +__device__ constexpr unsigned calc_mask(int pass) { + static_assert(BitsPerPass <= 31); + int num_bits = calc_start_bit(pass - 1) - calc_start_bit(pass); + return (1 << num_bits) - 1; +} + +/** + * Use CUB to twiddle bits - so that we can correctly compare bits of + * floating-point values as well as of integers. + */ +template +__device__ typename cub::Traits::UnsignedBits twiddle_in(T key, bool select_min) { + auto bits = reinterpret_cast::UnsignedBits &>(key); + bits = cub::Traits::TwiddleIn(bits); + if (!select_min) { + bits = ~bits; + } + return bits; +} + +template +__device__ T twiddle_out(typename cub::Traits::UnsignedBits bits, bool select_min) { + if (!select_min) { + bits = ~bits; + } + bits = cub::Traits::TwiddleOut(bits); + return reinterpret_cast(bits); +} + +template +__device__ int calc_bucket(T x, int start_bit, unsigned mask, bool select_min) { + static_assert(BitsPerPass <= sizeof(int) * 8 - 1, + "BitsPerPass is too large that the result type could not be int"); + return (twiddle_in(x, select_min) >> start_bit) & mask; +} + +template +__host__ __device__ IdxT calc_buf_len(IdxT len) { + // When writing is skipped, only read `in`(type T). + // When writing is not skipped, read `in_buf`(T) and `in_idx_buf`(IdxT), and + // write `out_buf`(T) and `out_idx_buf`(IdxT). The ratio between these cases + // determines whether to skip writing and hence the buffer size. constexpr + // float ratio = 2 + sizeof(IdxT) * 2.0 / sizeof(T); + constexpr float ratio = 128; + return len / ratio; + // return len; +} + +/** + * Map a Func over the input data, using vectorized load instructions if + * possible. + * + * NB: in future, we should move this to + * cpp/include/raft/linalg/detail/unary_op.cuh, which currently does not support + * the second lambda argument (index of an element) + * + * @tparam T element type + * @tparam IdxT indexing type + * @tparam Func void (T x, IdxT idx) + * + * @param thread_rank rank of the calling thread among all participating threads + * @param num_threads number of the threads that participate in processing + * @param in the input data + * @param len the number of elements to read + * @param f the lambda taking two arguments (T x, IdxT idx) + */ +template +__device__ void vectorized_process(size_t thread_rank, size_t num_threads, const T *in, idxT len, + Func f) { + if constexpr (sizeof(T) >= sizeof(WideT)) { + for (idxT i = thread_rank; i < len; i += num_threads) { + f(in[i], i); + } + } else { + static_assert(sizeof(WideT) % sizeof(T) == 0); + constexpr int items_per_scalar = sizeof(WideT) / sizeof(T); + // TODO: it's UB + union { + WideT scalar; + T array[items_per_scalar]; + } wide; + + int skip_cnt = + (reinterpret_cast(in) % sizeof(WideT)) + ? ((sizeof(WideT) - reinterpret_cast(in) % sizeof(WideT)) / sizeof(T)) + : 0; + if (skip_cnt > len) { + skip_cnt = len; + } + const WideT *in_cast = reinterpret_cast(in + skip_cnt); + const idxT len_cast = (len - skip_cnt) / items_per_scalar; + + for (idxT i = thread_rank; i < len_cast; i += num_threads) { + wide.scalar = in_cast[i]; + const idxT real_i = skip_cnt + i * items_per_scalar; +#pragma unroll + for (int j = 0; j < items_per_scalar; ++j) { + f(wide.array[j], real_i + j); + } + } + + static_assert(WARP_SIZE >= items_per_scalar); + // and because items_per_scalar > skip_cnt, WARP_SIZE > skip_cnt + // no need to use loop + if (thread_rank < skip_cnt) { + f(in[thread_rank], thread_rank); + } + // because len_cast = (len - skip_cnt) / items_per_scalar, + // len_cast * items_per_scalar + items_per_scalar > len - skip_cnt; + // and so + // len - (skip_cnt + len_cast * items_per_scalar) < items_per_scalar <= + // WARP_SIZE no need to use loop + const idxT remain_i = skip_cnt + len_cast * items_per_scalar + thread_rank; + if (remain_i < len) { + f(in[remain_i], remain_i); + } + } +} + +// sync_width should >= WARP_SIZE +template +__device__ void vectorized_process(const T *in, idxT len, Func f, int sync_width) { + const idxT stride = blockDim.x * gridDim.x; + const idxT tid = blockIdx.x * blockDim.x + threadIdx.x; + if constexpr (sizeof(T) >= sizeof(WideT)) { + for (idxT i = tid; i < len; i += stride) { + f(in[i], i, true); + } + } else { + static_assert(sizeof(WideT) % sizeof(T) == 0); + constexpr int items_per_scalar = sizeof(WideT) / sizeof(T); + union { + WideT scalar; + T array[items_per_scalar]; + } wide; + + int skip_cnt = + (reinterpret_cast(in) % sizeof(WideT)) + ? ((sizeof(WideT) - reinterpret_cast(in) % sizeof(WideT)) / sizeof(T)) + : 0; + if (skip_cnt > len) { + skip_cnt = len; + } + const WideT *in_cast = reinterpret_cast(in + skip_cnt); + const idxT len_cast = (len - skip_cnt) / items_per_scalar; + + const idxT len_cast_for_sync = ((len_cast - 1) / sync_width + 1) * sync_width; + for (idxT i = tid; i < len_cast_for_sync; i += stride) { + bool valid = i < len_cast; + if (valid) { + wide.scalar = in_cast[i]; + } + const idxT real_i = skip_cnt + i * items_per_scalar; +#pragma unroll + for (int j = 0; j < items_per_scalar; ++j) { + f(wide.array[j], real_i + j, valid); + } + } + + static_assert(WARP_SIZE >= items_per_scalar); + // need at most one warp for skipped and remained elements, + // and sync_width >= WARP_SIZE + if (tid < sync_width) { + bool valid = tid < skip_cnt; + T value = valid ? in[tid] : T(); + f(value, tid, valid); + + const idxT remain_i = skip_cnt + len_cast * items_per_scalar + tid; + valid = remain_i < len; + value = valid ? in[remain_i] : T(); + f(value, remain_i, valid); + } + } +} + +template +struct alignas(128) Counter { + // We are processing the values in multiple passes, from most significant to + // least significant. In each pass, we keep the length of input (`len`) and + // the `k` of current pass, and update them at the end of the pass. + IdxT k; + IdxT len; + + // `previous_len` is the length of input in previous pass. Note that + // `previous_len` rather than `len` is used for the filtering step because + // filtering is indeed for previous pass (see comments before + // `radix_kernel`). + IdxT previous_len; + + // We determine the bits of the k_th value inside the mask processed by the + // pass. The already known bits are stored in `kth_value_bits`. It's used to + // discriminate a element is a result (written to `out`), a candidate for next + // pass (written to `out_buf`), or not useful (discarded). The bits that are + // not yet processed do not matter for this purpose. + typename cub::Traits::UnsignedBits kth_value_bits; + + // Record how many elements have passed filtering. It's used to determine the + // position in the `out_buf` where an element should be written. + alignas(128) IdxT filter_cnt; + + // For a row inside a batch, we may launch multiple thread blocks. This + // counter is used to determine if the current block is the last running + // block. If so, this block will execute scan() and choose_bucket(). + alignas(128) unsigned int finished_block_cnt; + + // Record how many elements have been written to the front of `out`. Elements + // less (if select_min==true) than the k-th value are written from front to + // back. + alignas(128) IdxT out_cnt; + + // Record how many elements have been written to the back of `out`. Elements + // equal to the k-th value are written from back to front. We need to keep + // count of them separately because the number of elements that <= the k-th + // value might exceed k. + alignas(128) IdxT out_back_cnt; +}; + +/** + * Fused filtering of the current pass and building histogram for the next pass + * (see steps 4 & 1 in `radix_kernel` description). + */ +template +__device__ void filter_and_histogram(const T *in_buf, const IdxT *in_idx_buf, T *out_buf, + IdxT *out_idx_buf, T *out, IdxT *out_idx, IdxT previous_len, + Counter *counter, IdxT *histogram, bool select_min, + int pass, bool early_stop) { + constexpr int num_buckets = calc_num_buckets(); + __shared__ IdxT histogram_smem[num_buckets]; + for (IdxT i = threadIdx.x; i < num_buckets; i += blockDim.x) { + histogram_smem[i] = 0; + } + __syncthreads(); + + const int start_bit = calc_start_bit(pass); + const unsigned mask = calc_mask(pass); + + if (pass == 0) { + // Passed to vectorized_process, this function executes in all blocks in + // parallel, i.e. the work is split along the input (both, in batches and + // chunks of a single row). Later, the histograms are merged using + // atomicAdd. + auto f = [select_min, start_bit, mask](T value, IdxT) { + int bucket = calc_bucket(value, start_bit, mask, select_min); + atomicAdd(histogram_smem + bucket, static_cast(1)); + }; + vectorized_process(static_cast(blockIdx.x) * blockDim.x + threadIdx.x, + static_cast(blockDim.x) * gridDim.x, in_buf, previous_len, f); + } else { + IdxT *p_filter_cnt = &counter->filter_cnt; + IdxT *p_out_cnt = &counter->out_cnt; + const auto kth_value_bits = counter->kth_value_bits; + const int previous_start_bit = calc_start_bit(pass - 1); + + // See the remark above on the distributed execution of `f` using + // vectorized_process. + auto f = [in_idx_buf, out_buf, out_idx_buf, out, out_idx, select_min, start_bit, mask, + previous_start_bit, kth_value_bits, p_filter_cnt, p_out_cnt, + early_stop](T value, IdxT i) { + const auto previous_bits = (twiddle_in(value, select_min) >> previous_start_bit) + << previous_start_bit; + if (previous_bits == kth_value_bits) { + if (early_stop) { + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); + if constexpr (store_out) { + out[pos] = value; + } + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } else { + if (out_buf) { + IdxT pos = atomicAdd(p_filter_cnt, static_cast(1)); + out_buf[pos] = value; + out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i; + } + + int bucket = calc_bucket(value, start_bit, mask, select_min); + atomicAdd(histogram_smem + bucket, static_cast(1)); + } + } + // the condition `(out_buf || early_stop)` is a little tricky: + // If we skip writing to `out_buf` (when `out_buf` is nullptr), we should + // skip writing to `out` too. So we won't write the same value to `out` + // multiple times in different passes. And if we keep skipping the + // writing, values will be written in `last_filter_kernel()` at last. But + // when `early_stop` is true, we need to write to `out` since it's the + // last chance. + else if ((out_buf || early_stop) && previous_bits < kth_value_bits) { + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); + if constexpr (store_out) { + out[pos] = value; + } + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } + }; + vectorized_process(static_cast(blockIdx.x) * blockDim.x + threadIdx.x, + static_cast(blockDim.x) * gridDim.x, in_buf, previous_len, f); + } + if (early_stop) { + return; + } + __syncthreads(); + + // merge histograms produced by individual blocks + for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { + if (histogram_smem[i] != 0) { + atomicAdd(histogram + i, histogram_smem[i]); + } + } +} + +/** + * Replace histogram with its own prefix sum + * (step 2 in `radix_kernel` description) + */ +template +__device__ void scan(volatile IdxT *histogram) { + constexpr int num_buckets = calc_num_buckets(); + if constexpr (num_buckets >= BlockSize) { + static_assert(num_buckets % BlockSize == 0); + constexpr int items_per_thread = num_buckets / BlockSize; + typedef cub::BlockLoad BlockLoad; + typedef cub::BlockStore + BlockStore; + typedef cub::BlockScan BlockScan; + + __shared__ union { + typename BlockLoad::TempStorage load; + typename BlockScan::TempStorage scan; + typename BlockStore::TempStorage store; + } temp_storage; + IdxT thread_data[items_per_thread]; + + BlockLoad(temp_storage.load).Load(histogram, thread_data); + __syncthreads(); + + BlockScan(temp_storage.scan).InclusiveSum(thread_data, thread_data); + __syncthreads(); + + BlockStore(temp_storage.store).Store(histogram, thread_data); + } else { + typedef cub::BlockScan BlockScan; + __shared__ typename BlockScan::TempStorage temp_storage; + + IdxT thread_data = 0; + if (threadIdx.x < num_buckets) { + thread_data = histogram[threadIdx.x]; + } + + BlockScan(temp_storage).InclusiveSum(thread_data, thread_data); + __syncthreads(); + + if (threadIdx.x < num_buckets) { + histogram[threadIdx.x] = thread_data; + } + } +} + +/** + * Calculate in which bucket the k-th value will fall + * (steps 3 in `radix_kernel` description) + */ +template +__device__ void choose_bucket(Counter *counter, const IdxT *histogram, const IdxT k, + const int pass) { + constexpr int num_buckets = calc_num_buckets(); + for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { + IdxT prev = (i == 0) ? 0 : histogram[i - 1]; + IdxT cur = histogram[i]; + + // one and only one thread will satisfy this condition, so counter is + // written by only one thread + if (prev < k && cur >= k) { + counter->k = k - prev; // how many values still are there to find + counter->len = cur - prev; // number of values in next pass + typename cub::Traits::UnsignedBits bucket = i; + int start_bit = calc_start_bit(pass); + counter->kth_value_bits |= bucket << start_bit; + } + } +} + +template +__device__ void scan_warp_version(cg::thread_block_tile const &warp, + volatile IdxT *histogram, Counter *counter, const IdxT k, + const int pass) { + constexpr int num_buckets = calc_num_buckets(); + + __shared__ IdxT warp_histogram[num_buckets >> WARP_BITS]; + for (int i = threadIdx.x; i < num_buckets; i += BlockSize) { + IdxT data = histogram[i]; + IdxT warp_sum = cg::reduce(warp, data, cg::plus()); + + if (i % WARP_SIZE == 0) { + warp_histogram[i >> WARP_BITS] = warp_sum; + } + } + __syncthreads(); + + if (threadIdx.x < WARP_SIZE) { + IdxT value = warp_histogram[threadIdx.x * 2] + warp_histogram[threadIdx.x * 2 + 1]; + IdxT prefix = value; + for (int offset = 1; offset < WARP_SIZE; offset <<= 1) { + IdxT n = __shfl_up_sync(FULL_WARP_MASK, prefix, offset, WARP_SIZE); + if (threadIdx.x >= offset) prefix += n; + } + IdxT prefix_high = __shfl_sync(FULL_WARP_MASK, prefix, threadIdx.x - 1, WARP_SIZE); + if (threadIdx.x == 0) prefix_high = 0; + warp_histogram[threadIdx.x * 2] += prefix_high; + warp_histogram[threadIdx.x * 2 + 1] = value + prefix_high; + __syncwarp(); + + // Find the target warp bucket + IdxT target_warp = 2048; // invalid value + // bool is_one_in_warp=false; + for (int i = threadIdx.x; i < 64 && target_warp == 2048; i += WARP_SIZE) { + IdxT prev = (i == 0) ? 0 : warp_histogram[i - 1]; + IdxT cur = warp_histogram[i]; + bool is_selected = prev < k && cur >= k; + unsigned mask = __ballot_sync(FULL_WARP_MASK, is_selected); + if (__popc(mask) > 0) { + // target_warp = __ffs(mask) -1 + (i/WARP_SIZE)*WARP_SIZE; + target_warp = __ffs(mask) - 1 + ((i >> WARP_BITS) << WARP_BITS); + // is_one_in_warp= (target_warp==0? warp_histogram[0]: + // warp_histogram[target_warp]-warp_histogram[target_warp-1])==1?true:false; + } + } + + // Find the target bucket + // if(is_one_in_warp){ + // bool is_one=histogram[target_warp*WARP_SIZE+threadIdx.x]==1?1:0; + // unsigned mask = __ballot_sync(FULL_WARP_MASK, is_one); + // IdxT target_bucket=__ffs(mask)-1+target_warp*WARP_SIZE; + // IdxT prev=target_warp==0? 0: warp_histogram[target_warp-1]; + // IdxT cur=warp_histogram[target_warp]; + // if(threadIdx.x==0) { + // counter->k = k - prev; // how many values still are there + // to find counter->len = cur - prev; // number of values in next + // pass typename cub::Traits::UnsignedBits bucket = + // target_bucket; int start_bit = calc_start_bit(pass); counter->kth_value_bits |= bucket << + // start_bit; + // } + // }else{ + value = histogram[(target_warp << WARP_BITS) + threadIdx.x]; + for (int offset = 1; offset < WARP_SIZE; offset <<= 1) { + IdxT n = __shfl_up_sync(FULL_WARP_MASK, value, offset, WARP_SIZE); + if (threadIdx.x >= offset) value += n; + } + value += (target_warp == 0 ? 0 : warp_histogram[target_warp - 1]); + + for (int i = threadIdx.x; i < WARP_SIZE; i += WARP_SIZE) { + IdxT prev = __shfl_up_sync(FULL_WARP_MASK, value, 1, WARP_SIZE); + prev = (i == 0) ? (target_warp == 0 ? 0 : warp_histogram[target_warp - 1]) : prev; + IdxT cur = value; + if (prev < k && cur >= k) { + counter->k = k - prev; // how many values still are there to find + counter->len = cur - prev; // number of values in next pass + typename cub::Traits::UnsignedBits bucket = (target_warp << WARP_BITS) + i; + int start_bit = calc_start_bit(pass); + counter->kth_value_bits |= bucket << start_bit; + } + } + // } + } +} +// For one-block version, last_filter() could be called when pass < num_passes +// - 1. So `pass` could not be constexpr +template +__device__ void last_filter(const T *in_buf, const IdxT *in_idx_buf, T *out, IdxT *out_idx, + IdxT current_len, IdxT k, Counter *counter, + const bool select_min, const int pass) { + const auto kth_value_bits = counter->kth_value_bits; + const int start_bit = calc_start_bit(pass); + + // changed in choose_bucket(); need to reload + const IdxT needed_num_of_kth = counter->k; + IdxT *p_out_cnt = &counter->out_cnt; + IdxT *p_out_back_cnt = &counter->out_back_cnt; + for (IdxT i = threadIdx.x; i < current_len; i += blockDim.x) { + const T value = in_buf[i]; + const auto bits = (twiddle_in(value, select_min) >> start_bit) << start_bit; + if (bits < kth_value_bits) { + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); + if constexpr (store_out) { + out[pos] = value; + } + // For one-block version, `in_idx_buf` could be nullptr at pass 0. + // For non one-block version, if writing has been skipped, `in_idx_buf` + // could be nullptr if `in_buf` is `in` + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } else if (bits == kth_value_bits) { + IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast(1)); + if (back_pos < needed_num_of_kth) { + IdxT pos = k - 1 - back_pos; + if constexpr (store_out) { + out[pos] = value; + } + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } + } + } +} + +template +__global__ void last_filter_kernel(const T *in, const IdxT *in_idx, const T *in_buf, + const IdxT *in_idx_buf, T *out, IdxT *out_idx, IdxT len, IdxT k, + Counter *counters, const bool select_min) { + const size_t batch_id = blockIdx.y; // size_t to avoid multiplication overflow + + Counter *counter = counters + batch_id; + IdxT previous_len = counter->previous_len; + if (previous_len == 0) { + return; + } + const IdxT buf_len = calc_buf_len(len); + if (previous_len > buf_len || in_buf == in) { + in_buf = in + batch_id * len; + in_idx_buf = in_idx ? (in_idx + batch_id * len) : nullptr; + previous_len = len; + } else { + in_buf += batch_id * buf_len; + in_idx_buf += batch_id * buf_len; + } + if constexpr (store_out) { + out += batch_id * k; + } + out_idx += batch_id * k; + + constexpr int pass = calc_num_passes() - 1; + constexpr int start_bit = calc_start_bit(pass); + + const auto kth_value_bits = counter->kth_value_bits; + const IdxT needed_num_of_kth = counter->k; + IdxT *p_out_cnt = &counter->out_cnt; + IdxT *p_out_back_cnt = &counter->out_back_cnt; + + auto f = [k, select_min, kth_value_bits, needed_num_of_kth, p_out_cnt, p_out_back_cnt, in_idx_buf, + out, out_idx](T value, IdxT i) { + const auto bits = (twiddle_in(value, select_min) >> start_bit) << start_bit; + if (bits < kth_value_bits) { + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); + if constexpr (store_out) { + out[pos] = value; + } + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } else if (bits == kth_value_bits) { + IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast(1)); + if (back_pos < needed_num_of_kth) { + IdxT pos = k - 1 - back_pos; + if constexpr (store_out) { + out[pos] = value; + } + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } + } + }; + + vectorized_process(static_cast(blockIdx.x) * blockDim.x + threadIdx.x, + static_cast(blockDim.x) * gridDim.x, in_buf, previous_len, f); +} + +/** + * + * It is expected to call this kernel multiple times (passes), in each pass we + * process a radix, going from the most significant towards the least + * significant bits (MSD). + * + * Conceptually, each pass consists of 4 steps: + * + * 1. Calculate histogram + * First, transform bits into a digit, the value of which is in the range + * [0, 2^{BITS_PER_PASS}-1]. Then count the frequency of each digit value + * and the result is a histogram. That is, histogram[i] contains the count of + * inputs having value i. + * + * 2. Scan the histogram + * Inclusive prefix sum is computed for the histogram. After this step, + * histogram[i] contains the count of inputs having value <= i. + * + * 3. Find the bucket j of the histogram that the k-th value falls into + * + * 4. Filtering + * Input elements whose digit value +__device__ void radix_kernel_func(const T *in, const IdxT *in_idx, const T *in_buf, + const IdxT *in_idx_buf, T *out_buf, IdxT *out_idx_buf, T *out, + IdxT *out_idx, Counter *counter, IdxT *histogram, + const IdxT len, const IdxT k, const bool select_min, + const int pass) { + if (len <= k) { + if (pass == 0) { + for (int index = threadIdx.x; index < len; index += BlockSize) { + if constexpr (store_out) { + out[index] = in[index]; + } + out_idx[index] = in_idx ? in_idx[index] : index; + } + for (int index = threadIdx.x + len; index < k; index += BlockSize) { + if constexpr (store_out) { + out[index] = nv_detail::float_to_T(-1.0f); + } + out_idx[index] = -1; + } + return; + } else { + return; + } + } + + IdxT current_k; + IdxT previous_len; + IdxT current_len; + if (pass == 0) { + current_k = k; + previous_len = len; + // Need to do this so setting counter->previous_len for the next pass is + // correct. This value is meaningless for pass 0, but it's fine because pass + // 0 won't be the last pass in this implementation so pass 0 won't hit the + // "if (pass == num_passes - 1)" branch. Maybe it's better to reload + // counter->previous_len and use it rather than current_len in last_filter() + current_len = len; + } else { + current_k = counter->k; + current_len = counter->len; + previous_len = counter->previous_len; + } + if (current_len == 0) { + return; + } + + // When k=len, early_stop will be true at pass 0. It means + // filter_and_histogram() should handle correctly the case that pass=0 and + // early_stop=true. However, this special case of k=len is handled in other + // way in select_k() so such case is not possible here. + const bool early_stop = (current_len == current_k); + const IdxT buf_len = calc_buf_len(len); + constexpr int num_buckets = calc_num_buckets(); + // "previous_len > buf_len" means previous pass skips writing buffer + if (pass == 0 || pass == 1 || previous_len > buf_len) { + in_buf = in; + in_idx_buf = in_idx ? in_idx : nullptr; + previous_len = len; + } + // "current_len > buf_len" means current pass will skip writing buffer + if (pass == 0 || current_len > buf_len) { + out_buf = nullptr; + out_idx_buf = nullptr; + } + + filter_and_histogram(in_buf, in_idx_buf, out_buf, out_idx_buf, + out, out_idx, previous_len, counter, + histogram, select_min, pass, early_stop); + __threadfence(); + + bool isLastBlock = false; + if (threadIdx.x == 0) { + unsigned int finished = atomicInc(&counter->finished_block_cnt, gridDim.x - 1); + isLastBlock = (finished == (gridDim.x - 1)); + } + + if (__syncthreads_or(isLastBlock)) { + if (early_stop) { + if (threadIdx.x == 0) { + // `last_filter_kernel()` requires setting previous_len + counter->previous_len = 0; + counter->len = 0; + } + return; + } + + constexpr int num_passes = calc_num_passes(); + + scan(histogram); + __syncthreads(); + choose_bucket(counter, histogram, current_k, pass); + __syncthreads(); + + // reset for next pass + if (pass != num_passes - 1) { + for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { + histogram[i] = 0; + } + } + if (threadIdx.x == 0) { + // `last_filter_kernel()` requires setting previous_len even in the last + // pass + counter->previous_len = current_len; + // not necessary for the last pass, but put it here anyway + counter->filter_cnt = 0; + } + + if constexpr (fused_last_filter) { + if (pass == num_passes - 1) { + last_filter( + out_buf ? out_buf : in_buf, out_idx_buf ? out_idx_buf : in_idx_buf, out, out_idx, + out_buf ? current_len : len, k, counter, select_min, pass); + } + } + } +} + +template +__global__ void radix_kernel(const T *in, const IdxT *in_idx, const T *in_buf, + const IdxT *in_idx_buf, T *out_buf, IdxT *out_idx_buf, T *out, + IdxT *out_idx, Counter *counters, IdxT *histograms, + const IdxT len, const IdxT k, const bool select_min, const int pass, + IdxT *lengths) { + const size_t batch_id = blockIdx.y; + auto counter = counters + batch_id; + constexpr int num_buckets = calc_num_buckets(); + auto histogram = histograms + batch_id * num_buckets; + + in += batch_id * len; + if (in_idx) { + in_idx += batch_id * len; + } + if constexpr (store_out) { + out += batch_id * k; + } + out_idx += batch_id * k; + + const IdxT buf_len = calc_buf_len(len); + in_buf += batch_id * buf_len; + in_idx_buf += batch_id * buf_len; + + out_buf += batch_id * buf_len; + out_idx_buf += batch_id * buf_len; + + IdxT actual_len = len; + if (lengths != nullptr) { + actual_len = lengths[batch_id]; + } + radix_kernel_func( + in, in_idx, in_buf, in_idx_buf, out_buf, out_idx_buf, out, out_idx, counter, histogram, + actual_len, k, select_min, pass); +} + +template +unsigned calc_grid_dim(int batch_size, IdxT len, int sm_cnt) { + static_assert(VECTORIZED_READ_SIZE / sizeof(T) >= 1); + + int active_blocks; + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &active_blocks, radix_kernel, BlockSize, 0); + active_blocks *= sm_cnt; + + IdxT best_num_blocks = 0; + float best_tail_wave_penalty = 1.0f; + const IdxT max_num_blocks = ceildiv(len, VECTORIZED_READ_SIZE / sizeof(T) * BlockSize); + for (int num_waves = 1;; ++num_waves) { + IdxT num_blocks = std::min( + max_num_blocks, static_cast(std::max(num_waves * active_blocks / batch_size, 1))); + IdxT items_per_thread = ceildiv(len, num_blocks * BlockSize); + items_per_thread = alignTo(items_per_thread, VECTORIZED_READ_SIZE / sizeof(T)); + num_blocks = ceildiv(len, items_per_thread * BlockSize); + float actual_num_waves = static_cast(num_blocks) * batch_size / active_blocks; + float tail_wave_penalty = + (ceilf(actual_num_waves) - actual_num_waves) / ceilf(actual_num_waves); + + // 0.15 is determined experimentally. It also ensures breaking the loop + // early, e.g. when num_waves > 7, tail_wave_penalty will always <0.15 + if (tail_wave_penalty < 0.15) { + best_num_blocks = num_blocks; + break; + } else if (tail_wave_penalty < best_tail_wave_penalty) { + best_num_blocks = num_blocks; + best_tail_wave_penalty = tail_wave_penalty; + } + + if (num_blocks == max_num_blocks) { + break; + } + } + return best_num_blocks; +} + +template +__host__ __device__ void set_buf_pointers(const T *in, const IdxT *in_idx, T *buf1, IdxT *idx_buf1, + T *buf2, IdxT *idx_buf2, int pass, const T *&in_buf, + const IdxT *&in_idx_buf, T *&out_buf, + IdxT *&out_idx_buf) { + if (pass == 0) { + in_buf = in; + in_idx_buf = nullptr; + out_buf = nullptr; + out_idx_buf = nullptr; + } else if (pass == 1) { + in_buf = in; + in_idx_buf = in_idx; + out_buf = buf1; + out_idx_buf = idx_buf1; + } else if (pass % 2 == 0) { + in_buf = buf1; + in_idx_buf = idx_buf1; + out_buf = buf2; + out_idx_buf = idx_buf2; + } else { + in_buf = buf2; + in_idx_buf = idx_buf2; + out_buf = buf1; + out_idx_buf = idx_buf1; + } +} + +// The following a few functions are for the one-block version, which uses +// single thread block for each row of a batch. +template +__device__ void filter_and_histogram_for_one_block(const T *in_buf, const IdxT *in_idx_buf, + T *out_buf, IdxT *out_idx_buf, T *out, + IdxT *out_idx, Counter *counter, + IdxT *histogram, bool select_min, int pass) { + constexpr int num_buckets = calc_num_buckets(); + for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { + histogram[i] = 0; + } + IdxT *p_filter_cnt = &counter->filter_cnt; + if (threadIdx.x == 0) { + *p_filter_cnt = 0; + } + __syncthreads(); + + const int start_bit = calc_start_bit(pass); + const unsigned mask = calc_mask(pass); + const IdxT previous_len = counter->previous_len; + + if (pass == 0) { + if constexpr (is_vectorized) { + auto f = [histogram, select_min, start_bit, mask](T value, IdxT) { + int bucket = calc_bucket(value, start_bit, mask, select_min); + atomicAdd(histogram + bucket, static_cast(1)); + }; + vectorized_process(threadIdx.x, blockDim.x, in_buf, previous_len, f); + } else { + for (IdxT i = threadIdx.x; i < previous_len; i += blockDim.x) { + const T value = in_buf[i]; + int bucket = calc_bucket(value, start_bit, mask, select_min); + atomicAdd(histogram + bucket, static_cast(1)); + } + } + } else { + // not use vectorized_process here because it increases #registers a lot + IdxT *p_out_cnt = &counter->out_cnt; + const auto kth_value_bits = counter->kth_value_bits; + const int previous_start_bit = calc_start_bit(pass - 1); + + for (IdxT i = threadIdx.x; i < previous_len; i += blockDim.x) { + const T value = in_buf[i]; + const auto previous_bits = (twiddle_in(value, select_min) >> previous_start_bit) + << previous_start_bit; + if (previous_bits == kth_value_bits) { +#if CUDART_VERSION < 12000 + // Avoiding potential compiler bug in CUDA 11 + volatile +#endif + IdxT pos = atomicAdd(p_filter_cnt, static_cast(1)); + out_buf[pos] = value; + out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i; + + int bucket = calc_bucket(value, start_bit, mask, select_min); + atomicAdd(histogram + bucket, static_cast(1)); + } else if (previous_bits < kth_value_bits) { + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); + if constexpr (store_out) { + out[pos] = value; + } + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } + } + } +} + +template +__device__ void radix_topk_one_block_func(const T *in, const IdxT *in_idx, const IdxT len, + const IdxT k, T *out, IdxT *out_idx, + const bool select_min, T *buf1, IdxT *idx_buf1, T *buf2, + IdxT *idx_buf2) { + if (len <= k) { + for (int index = threadIdx.x; index < len; index += BlockSize) { + if constexpr (store_out) { + out[index] = in[index]; + } + out_idx[index] = in_idx ? in_idx[index] : index; + } + for (int index = threadIdx.x + len; index < k; index += BlockSize) { + if constexpr (store_out) { + out[index] = nv_detail::float_to_T(-1.0f); + } + out_idx[index] = -1; + } + return; + } + + constexpr int num_buckets = calc_num_buckets(); + __shared__ Counter counter; + __shared__ IdxT histogram[num_buckets]; + + if (threadIdx.x == 0) { + counter.k = k; + counter.len = len; + counter.previous_len = len; + counter.kth_value_bits = 0; + counter.out_cnt = 0; + counter.out_back_cnt = 0; + } + __syncthreads(); + + // const size_t batch_id = blockIdx.x; // size_t to avoid multiplication + // overflow + const T *in_buf = nullptr; + const IdxT *in_idx_buf = nullptr; + T *out_buf = nullptr; + IdxT *out_idx_buf = nullptr; + + constexpr int num_passes = calc_num_passes(); + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + for (int pass = 0; pass < num_passes; ++pass) { + set_buf_pointers(in, in_idx, buf1, idx_buf1, buf2, idx_buf2, pass, in_buf, in_idx_buf, out_buf, + out_idx_buf); + + IdxT current_len = counter.len; + IdxT current_k = counter.k; + + filter_and_histogram_for_one_block( + in_buf, in_idx_buf, out_buf, out_idx_buf, out, out_idx, &counter, histogram, select_min, + pass); + __syncthreads(); + + scan(histogram); + __syncthreads(); + + choose_bucket(&counter, histogram, current_k, pass); + // scan_warp_version( + // warp, histogram, &counter, current_k, pass); + if (threadIdx.x == 0) { + counter.previous_len = current_len; + } + __syncthreads(); + + if (counter.len == counter.k || pass == num_passes - 1) { + last_filter(pass == 0 ? in : out_buf, + pass == 0 ? in_idx : out_idx_buf, out, out_idx, + current_len, k, &counter, select_min, pass); + break; + } + } // end for pass +} // end kernel + +template +__global__ void radix_topk_one_block_kernel(const T *in, const IdxT *in_idx, const IdxT len, + const IdxT k, T *out, IdxT *out_idx, + const bool select_min, T *buf1, IdxT *idx_buf1, T *buf2, + IdxT *idx_buf2, IdxT *lengths) { + const size_t batch_id = blockIdx.x; // size_t to avoid multiplication overflow + IdxT actual_len = len; + if (lengths) { + actual_len = lengths[batch_id]; + } + + in += batch_id * len; + if (in_idx) { + in_idx += batch_id * len; + } + + out += batch_id * k; + out_idx += batch_id * k; + buf1 += batch_id * len; + idx_buf1 += batch_id * len; + buf2 += batch_id * len; + idx_buf2 += batch_id * len; + + radix_topk_one_block_func( + in, in_idx, actual_len, k, out, out_idx, select_min, buf1, idx_buf1, buf2, idx_buf2); +} // end kernel + +} // namespace topk + +/***************Runtime API****************/ + +template +void standalone_radix_topk_(void *buf, size_t &buf_size, const T *in, const IdxT *in_idx, + int batch_size, IdxT len, IdxT k, T *out, IdxT *out_idx, + bool select_min, bool fused_last_filter, unsigned grid_dim, + cudaStream_t stream, IdxT *lengths = nullptr) { + static_assert(topk::calc_num_passes() > 1); + constexpr int num_buckets = topk::calc_num_buckets(); + + topk::Counter *counters = nullptr; + IdxT *histograms = nullptr; + T *buf1 = nullptr; + IdxT *idx_buf1 = nullptr; + T *buf2 = nullptr; + IdxT *idx_buf2 = nullptr; + { + IdxT len_candidates = topk::calc_buf_len(len); + std::vector sizes = {sizeof(*counters) * batch_size, + sizeof(*histograms) * num_buckets * batch_size, + sizeof(*buf1) * len_candidates * batch_size, + sizeof(*idx_buf1) * len_candidates * batch_size, + sizeof(*buf2) * len_candidates * batch_size, + sizeof(*idx_buf2) * len_candidates * batch_size}; + size_t total_size = calc_aligned_size(sizes); + if (!buf) { + buf_size = total_size; + return; + } + + std::vector aligned_pointers = calc_aligned_pointers(buf, sizes); + counters = static_cast(aligned_pointers[0]); + histograms = static_cast(aligned_pointers[1]); + buf1 = static_cast(aligned_pointers[2]); + idx_buf1 = static_cast(aligned_pointers[3]); + buf2 = static_cast(aligned_pointers[4]); + idx_buf2 = static_cast(aligned_pointers[5]); + + cudaMemsetAsync( + buf, 0, static_cast(aligned_pointers[2]) - static_cast(aligned_pointers[0]), + stream); + } + + const T *in_buf = nullptr; + const IdxT *in_idx_buf = nullptr; + T *out_buf = nullptr; + IdxT *out_idx_buf = nullptr; + + dim3 blocks(grid_dim, batch_size); + + constexpr int num_passes = topk::calc_num_passes(); + + auto kernel = topk::radix_kernel; + + for (int pass = 0; pass < num_passes; ++pass) { + topk::set_buf_pointers(in, in_idx, buf1, idx_buf1, buf2, idx_buf2, pass, in_buf, in_idx_buf, + out_buf, out_idx_buf); + + if (fused_last_filter && pass == num_passes - 1 && out != nullptr) { + kernel = topk::radix_kernel; + } else if (fused_last_filter && pass == num_passes - 1 && out == nullptr) { + kernel = topk::radix_kernel; + } else if (out == nullptr) { + kernel = topk::radix_kernel; + } + + kernel<<>>(in, in_idx, in_buf, in_idx_buf, out_buf, out_idx_buf, + out, out_idx, counters, histograms, len, k, select_min, + pass, lengths); + } + + if (!fused_last_filter) { + if (out != nullptr) { + topk::last_filter_kernel<<>>( + in, in_idx, out_buf, out_idx_buf, out, out_idx, len, k, counters, select_min); + } else { + topk::last_filter_kernel<<>>( + in, in_idx, out_buf, out_idx_buf, out, out_idx, len, k, counters, select_min); + } + } +} + +template +void standalone_radix_topk_one_block_(void *buf, size_t &buf_size, const T *in, const IdxT *in_idx, + int batch_size, IdxT len, IdxT k, T *out, IdxT *out_idx, + bool select_min, cudaStream_t stream, + IdxT *lengths = nullptr) { + static_assert(topk::calc_num_passes() > 1); + + T *buf1 = nullptr; + IdxT *idx_buf1 = nullptr; + T *buf2 = nullptr; + IdxT *idx_buf2 = nullptr; + { + std::vector sizes = { + sizeof(*buf1) * len * batch_size, sizeof(*idx_buf1) * len * batch_size, + sizeof(*buf2) * len * batch_size, sizeof(*idx_buf2) * len * batch_size}; + size_t total_size = calc_aligned_size(sizes); + if (!buf) { + buf_size = total_size; + return; + } + + std::vector aligned_pointers = calc_aligned_pointers(buf, sizes); + buf1 = static_cast(aligned_pointers[0]); + idx_buf1 = static_cast(aligned_pointers[1]); + buf2 = static_cast(aligned_pointers[2]); + idx_buf2 = static_cast(aligned_pointers[3]); + } + + if (out != nullptr) { + topk::radix_topk_one_block_kernel + <<>>(in, in_idx, len, k, out, out_idx, select_min, buf1, + idx_buf1, buf2, idx_buf2, lengths); + } else { + topk::radix_topk_one_block_kernel + <<>>(in, in_idx, len, k, out, out_idx, select_min, buf1, + idx_buf1, buf2, idx_buf2, lengths); + } +} + +template +void standalone_topk(void *buf, size_t &buf_size, const T *in, int batch_size, idxT len, idxT k, + T *out, idxT *out_idx, bool greater, cudaStream_t stream = 0, + idxT *lengths = nullptr, bool is_prefill = false) { + constexpr int items_per_thread = 32; + constexpr int multi_block_dim = 256; + constexpr int single_block_dim = 1024; + constexpr bool fused_last_filter = false; + if (len <= single_block_dim * items_per_thread || is_prefill) { + standalone_radix_topk_one_block_( + buf, buf_size, in, static_cast(nullptr), batch_size, len, k, out, out_idx, !greater, + stream, lengths); + } else { + // Cache sm_cnt per device to avoid repeated host-side queries. + static int cached_dev = -1; + static int cached_sm_cnt = -1; + int sm_cnt; + { + int dev; + NVTE_CHECK_CUDA(cudaGetDevice(&dev)); + if (dev != cached_dev) { + NVTE_CHECK_CUDA( + cudaDeviceGetAttribute(&cached_sm_cnt, cudaDevAttrMultiProcessorCount, dev)); + cached_dev = dev; + } + sm_cnt = cached_sm_cnt; + } + unsigned grid_dim = topk::calc_grid_dim(batch_size, len, sm_cnt); + + if (grid_dim == 1) { + standalone_radix_topk_one_block_( + buf, buf_size, in, static_cast(nullptr), batch_size, len, k, out, out_idx, + !greater, stream, lengths); + } else { + standalone_radix_topk_( + buf, buf_size, in, static_cast(nullptr), batch_size, len, k, out, out_idx, + !greater, fused_last_filter, grid_dim, stream, lengths); + } + } +} +} // namespace nv diff --git a/transformer_engine/common/util/topk.cu b/transformer_engine/common/util/topk.cu new file mode 100644 index 0000000000..196d5d9fae --- /dev/null +++ b/transformer_engine/common/util/topk.cu @@ -0,0 +1,57 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include + +#include "../common.h" +#include "standalone_topk.cuh" + +void nvte_topk(cudaStream_t stream, const NVTETensor keys_in, const NVTETensor lengths_in, + NVTETensor keys_out, NVTETensor indices_out, NVTETensor workspace, int batch_size, + int seq_len, int k, size_t workspace_bytes) { + NVTE_API_CALL(nvte_topk); + using namespace transformer_engine; + + const Tensor *keys_in_tensor = convertNVTETensorCheck(keys_in); + const Tensor *lengths_tensor = convertNVTETensorCheck(lengths_in); + Tensor *keys_out_tensor = convertNVTETensor(keys_out); + Tensor *indices_tensor = convertNVTETensor(indices_out); + Tensor *workspace_tensor = convertNVTETensor(workspace); + + void *d_workspace = workspace_tensor->data.dptr; + const int *d_lengths = reinterpret_cast(lengths_tensor->data.dptr); + int *d_indices = reinterpret_cast(indices_tensor->data.dptr); + + auto dtype = keys_in_tensor->data.dtype; + +#define DISPATCH_TOPK(T, d_in_cast, d_out_cast) \ + do { \ + const T *d_in = reinterpret_cast(keys_in_tensor->data.dptr); \ + T *d_out = reinterpret_cast(keys_out_tensor->data.dptr); \ + nv::standalone_topk(d_workspace, workspace_bytes, d_in, batch_size, seq_len, k, d_out, \ + d_indices, /*greater=*/true, stream, const_cast(d_lengths), \ + /*is_prefill=*/false); \ + } while (0) + + if (dtype == DType::kBFloat16) { + DISPATCH_TOPK(__nv_bfloat16, , ); + } else if (dtype == DType::kFloat32) { + DISPATCH_TOPK(float, , ); + } else { + NVTE_ERROR("nvte_topk: unsupported key dtype (supported: float32, bfloat16)"); + } + +#undef DISPATCH_TOPK +} + +size_t nvte_get_topk_workspace_bytes(int batch_size, int seq_len, int k) { + // Call with buf=nullptr to perform a size query (no GPU work is launched). + size_t buf_size = 0; + nv::standalone_topk(nullptr, buf_size, nullptr, batch_size, seq_len, k, nullptr, + nullptr, /*greater=*/true, /*stream=*/nullptr, + /*lengths=*/nullptr, /*is_prefill=*/false); + return buf_size; +} diff --git a/transformer_engine/jax/cpp_extensions/__init__.py b/transformer_engine/jax/cpp_extensions/__init__.py index d203fcea9d..fe1f93dc7a 100644 --- a/transformer_engine/jax/cpp_extensions/__init__.py +++ b/transformer_engine/jax/cpp_extensions/__init__.py @@ -10,3 +10,4 @@ from .softmax import * from .gemm import * from .router import * +from .topk import * diff --git a/transformer_engine/jax/cpp_extensions/topk.py b/transformer_engine/jax/cpp_extensions/topk.py new file mode 100644 index 0000000000..235540e354 --- /dev/null +++ b/transformer_engine/jax/cpp_extensions/topk.py @@ -0,0 +1,135 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""TopK custom op""" + +import functools +from typing import Tuple + +import jax +import jax.numpy as jnp +from jax import dtypes, ffi + +from .base import BasePrimitive, register_primitive + +__all__ = ["topk"] + + +@functools.lru_cache(maxsize=512) +def get_topk_workspace_bytes(batch_size: int, seq_len: int, k: int) -> int: + """Query the workspace size required for TopK. + + The result is memoised per (batch_size, seq_len, k) tuple so that repeated + JIT compilations with the same shapes incur only one host-side CUDA call. + """ + import transformer_engine_jax as _te_jax + + return int(_te_jax.get_topk_workspace_bytes(batch_size, seq_len, k)) + + +class TopKPrimitive(BasePrimitive): + """ + TopK Primitive + + Selects the top-k entries (by value) from each row of a 2-D input using the + AIR radix-selection algorithm. Returns both the top-k key values and their + column indices within each row. + """ + + name = "te_topk_ffi" + multiple_results = True + impl_static_args = (2,) # k_value + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + in_keys_aval, + in_lengths_aval, + *, + k_value, + ): + keys_dtype = dtypes.canonicalize_dtype(in_keys_aval.dtype) + assert keys_dtype in [ + jnp.float32, + jnp.bfloat16, + ], f"topk: unsupported key dtype {keys_dtype}; supported: float32, bfloat16" + assert in_keys_aval.ndim == 2, "topk: keys input must be 2D (batch_size, seq_len)" + assert dtypes.canonicalize_dtype(in_lengths_aval.dtype) == jnp.int32 + + batch_size, seq_len = in_keys_aval.shape + workspace_bytes = get_topk_workspace_bytes(batch_size, seq_len, k_value) + + out_shape = (batch_size, k_value) + out_keys_aval = jax.core.ShapedArray(shape=out_shape, dtype=keys_dtype) + out_indices_aval = jax.core.ShapedArray(shape=out_shape, dtype=jnp.int32) + workspace_aval = jax.core.ShapedArray(shape=(workspace_bytes,), dtype=jnp.uint8) + return (out_keys_aval, out_indices_aval, workspace_aval) + + @staticmethod + def outer_abstract(*args, **kwargs): + out_keys_aval, out_indices_aval, _workspace_aval = TopKPrimitive.abstract(*args, **kwargs) + return (out_keys_aval, out_indices_aval) + + @staticmethod + def lowering(ctx, in_keys, in_lengths, k_value): + keys_aval = ctx.avals_in[0] + batch_size, seq_len = keys_aval.shape + workspace_bytes = get_topk_workspace_bytes(batch_size, seq_len, k_value) + return ffi.ffi_lowering(TopKPrimitive.name)( + ctx, + in_keys, + in_lengths, + k_value=k_value, + workbuf_bytes=workspace_bytes, + ) + + @staticmethod + def impl(in_keys, in_lengths, k_value): + assert TopKPrimitive.inner_primitive is not None + out_keys, out_indices, _workspace = TopKPrimitive.inner_primitive.bind( + in_keys, + in_lengths, + k_value=k_value, + ) + return (out_keys, out_indices) + + +register_primitive(TopKPrimitive) + + +def topk( + x: jnp.ndarray, + k_value: int, +) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Select the top-k largest entries from each row using the AIR radix algorithm. + + Args: + x: Input array of shape ``(batch_size, seq_len)`` or ``(seq_len,)``. + Supported dtypes: ``float32``, ``bfloat16``. + k_value: Number of top entries to select per row. + + Returns: + A tuple ``(values, indices)`` where both arrays have shape + ``(batch_size, k_value)`` (or ``(k_value,)`` for 1-D input). The + outputs are *unordered*: use ``jax.lax.sort_key_val`` if a sorted result + is required. ``indices`` are the column positions within the original row. + """ + squeezed = x.ndim == 1 + if squeezed: + x = x[jnp.newaxis, :] # (1, seq_len) + + batch_size, seq_len = x.shape + lengths = jnp.full((batch_size,), seq_len, dtype=jnp.int32) + + out_keys, out_indices = TopKPrimitive.outer_primitive.bind( + x, + lengths, + k_value=k_value, + ) + + if squeezed: + out_keys = out_keys[0] # (k_value,) + out_indices = out_indices[0] # (k_value,) + + return out_keys, out_indices diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 0fe4e99239..cc6d3a7e66 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -171,6 +171,10 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedTopkWithScoreFunctionBackwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossForwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossBackwardHandler); +// TopK +XLA_FFI_DECLARE_HANDLER_SYMBOL(TopkHandler); +int64_t GetTopkWorkspaceBytes(int batch_size, int seq_len, int k); + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 28cb39b5d1..3308fa6c4a 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -100,6 +100,9 @@ pybind11::dict Registrations() { dict["te_fused_moe_aux_loss_forward_ffi"] = EncapsulateFFI(FusedMoEAuxLossForwardHandler); dict["te_fused_moe_aux_loss_backward_ffi"] = EncapsulateFFI(FusedMoEAuxLossBackwardHandler); + // TopK + dict["te_topk_ffi"] = EncapsulateFFI(TopkHandler); + return dict; } @@ -117,6 +120,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("get_norm_bwd_workspace_sizes", &GetNormBackwardWorkspaceSizes); m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes); m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes); + m.def("get_topk_workspace_bytes", &GetTopkWorkspaceBytes); m.def("nvte_get_qkv_format", &nvte_get_qkv_format); m.def("is_non_nt_fp8_gemm_supported", &nvte_is_non_tn_fp8_gemm_supported); m.def("initialize_cgemm_communicator", &InitializeCgemmCommunicator); diff --git a/transformer_engine/jax/csrc/extensions/topk.cpp b/transformer_engine/jax/csrc/extensions/topk.cpp new file mode 100644 index 0000000000..c4e80db827 --- /dev/null +++ b/transformer_engine/jax/csrc/extensions/topk.cpp @@ -0,0 +1,88 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "transformer_engine/topk.h" + +#include "../extensions.h" +#include "xla/ffi/api/c_api.h" + +namespace transformer_engine { +namespace jax { + +// --------------------------------------------------------------------------- +// JAX FFI handler +// --------------------------------------------------------------------------- + +Error_Type TopkFFI(cudaStream_t stream, Buffer_Type keys_in_buf, Buffer_Type lengths_buf, + Result_Type keys_out_buf, Result_Type indices_out_buf, Result_Type workspace_buf, + int64_t k_value, int64_t workbuf_bytes) { + auto keys_in_dtype = convert_ffi_datatype_to_te_dtype(keys_in_buf.element_type()); + auto keys_out_dtype = convert_ffi_datatype_to_te_dtype(keys_out_buf->element_type()); + auto idx_out_dtype = convert_ffi_datatype_to_te_dtype(indices_out_buf->element_type()); + NVTE_CHECK(keys_in_dtype == keys_out_dtype, "TopkFFI: input and output key dtypes must match"); + NVTE_CHECK(idx_out_dtype == DType::kInt32, "TopkFFI: index output must be int32"); + + auto keys_in_shape = keys_in_buf.dimensions(); + NVTE_CHECK(keys_in_shape.size() == 2, "TopkFFI: keys input must be 2D (batch_size, seq_len)"); + + int batch_size = static_cast(keys_in_shape[0]); + int seq_len = static_cast(keys_in_shape[1]); + int k = static_cast(k_value); + + // Validate key dtype (float32 and bfloat16 only). + switch (keys_in_dtype) { + case DType::kFloat32: + case DType::kBFloat16: + break; + default: + NVTE_ERROR("TopkFFI: unsupported key dtype (float32 and bfloat16 only)"); + } + + // Build flat TensorWrappers over the full (batch_size * seq_len) / (batch_size * k) buffers. + auto flat_in_shape = + std::vector{static_cast(batch_size) * static_cast(seq_len)}; + auto flat_out_shape = + std::vector{static_cast(batch_size) * static_cast(k)}; + auto len_shape = std::vector{static_cast(batch_size)}; + auto ws_shape = std::vector{static_cast(workbuf_bytes)}; + + auto keys_in_tensor = TensorWrapper(keys_in_buf.untyped_data(), flat_in_shape, keys_in_dtype); + auto lengths_tensor = TensorWrapper(lengths_buf.untyped_data(), len_shape, DType::kInt32); + auto keys_out_tensor = + TensorWrapper(keys_out_buf->untyped_data(), flat_out_shape, keys_out_dtype); + auto idx_out_tensor = + TensorWrapper(indices_out_buf->untyped_data(), flat_out_shape, DType::kInt32); + auto workspace_tensor = TensorWrapper(workspace_buf->untyped_data(), ws_shape, DType::kByte); + + nvte_topk(stream, keys_in_tensor.data(), lengths_tensor.data(), keys_out_tensor.data(), + idx_out_tensor.data(), workspace_tensor.data(), batch_size, seq_len, k, + static_cast(workbuf_bytes)); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(TopkHandler, TopkFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // keys_in + .Arg() // lengths + .Ret() // keys_out + .Ret() // indices_out + .Ret() // workspace + .Attr("k_value") + .Attr("workbuf_bytes"), + FFI_CudaGraph_Traits); + +// --------------------------------------------------------------------------- +// Workspace-size query exposed to Python +// --------------------------------------------------------------------------- + +int64_t GetTopkWorkspaceBytes(int batch_size, int seq_len, int k) { + return static_cast(nvte_get_topk_workspace_bytes(batch_size, seq_len, k)); +} + +} // namespace jax +} // namespace transformer_engine