Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions tests/jax/test_custom_call_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from transformer_engine.jax.activation import activation
from transformer_engine.jax.dense import dense, grouped_dense
from transformer_engine.jax.layernorm_dense import layernorm_dense
from transformer_engine.jax.cpp_extensions.topk import topk

GEMM_CASES = [
(256, 256, 512),
Expand Down Expand Up @@ -1955,3 +1956,83 @@ def f(x):
actual = load_array_dump("my_tensor_gpu0.bin", shape, dtype)

assert_allclose(actual, expected, dtype=dtype)


@pytest.mark.parametrize("dtype", [jnp.bfloat16, jnp.float32])
@pytest.mark.parametrize(
"problem_size",
[
(1, 10000, 100),
(1, 50000, 200),
(4, 16384, 256),
(8, 65536, 512),
(1, 1000000, 1000),
Comment thread
jberchtold-nvidia marked this conversation as resolved.
],
)
class TestTopK:
"""Correctness tests for the TopK JAX primitive.

Each test generates an input whose top-k entries lie in a known value range
so that correctness can be verified without a full sort, then cross-checks
against jax.lax.top_k as a reference.
"""

def test_topk_1d(self, dtype, problem_size):
"""1-D input: single row."""
_bs, n, k = problem_size

prng_key = jax.random.PRNGKey(0)
keys = jax.random.split(prng_key, 3)
topk_vals = jax.random.uniform(keys[0], shape=(k,), dtype=dtype, minval=1.5, maxval=2.5)
bottom_vals = jax.random.uniform(
keys[1], shape=(n - k,), dtype=dtype, minval=0.0, maxval=1.0
)
x = jax.random.permutation(keys[2], jnp.concatenate([topk_vals, bottom_vals]))

ref_vals, ref_idx = jax.jit(jax.lax.top_k, static_argnums=(1,))(x, k)
prim_vals, prim_idx = jax.jit(topk, static_argnums=(1,))(x, k)

# AIR TopK output is unordered; sort before comparing.
ref_vals, ref_idx = jax.lax.sort_key_val(ref_vals, ref_idx)
prim_vals, prim_idx = jax.lax.sort_key_val(prim_vals, prim_idx)

assert_allclose(prim_vals, ref_vals, dtype=dtype)

sorted_x = jax.lax.sort(x)
assert prim_vals[0] >= sorted_x[-(k + 1)]

# Values at returned indices must match reference.
assert_allclose(x[prim_idx], x[ref_idx], dtype=dtype)

def test_topk_2d(self, dtype, problem_size):
"""2-D input: each row is an independent top-k problem."""
bs, n, k = problem_size

prng_key = jax.random.PRNGKey(42)
keys = jax.random.split(prng_key, 3)
topk_vals = jax.random.uniform(keys[0], shape=(bs, k), dtype=dtype, minval=1.5, maxval=2.5)
bottom_vals = jax.random.uniform(
keys[1], shape=(bs, n - k), dtype=dtype, minval=0.0, maxval=1.0
)
x_unsorted = jnp.concatenate([topk_vals, bottom_vals], axis=1)
# Shuffle columns independently per row.
col_perm = jax.random.permutation(keys[2], n)
x = x_unsorted[:, col_perm]

ref_vals, ref_idx = jax.jit(jax.lax.top_k, static_argnums=(1,))(x, k)
prim_vals, prim_idx = jax.jit(topk, static_argnums=(1,))(x, k)

# Sort each row independently for comparison.
ref_vals, ref_idx = jax.vmap(jax.lax.sort_key_val)(ref_vals, ref_idx)
prim_vals, prim_idx = jax.vmap(jax.lax.sort_key_val)(prim_vals, prim_idx)

assert_allclose(prim_vals, ref_vals, dtype=dtype)

# For each row, the smallest selected value must be >= the (k+1)-th largest in that row.
sorted_x = jnp.sort(x, axis=1)
assert jnp.all(prim_vals[:, 0] >= sorted_x[:, -(k + 1)])

# Values at returned indices must match reference values.
prim_gathered = jnp.take_along_axis(x, prim_idx, axis=1)
ref_gathered = jnp.take_along_axis(x, ref_idx, axis=1)
assert_allclose(prim_gathered, ref_gathered, dtype=dtype)
6 changes: 6 additions & 0 deletions transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ if(NOT arch_120_index EQUAL -1)
endif()
endif()

# If all architectures were special-cased and removed above, disable CMake's
# automatic CUDA_ARCHITECTURES management — per-architecture compilation flags
# are instead attached via COMPILE_OPTIONS below.
# NOTE(review): a PR reviewer flagged this guard as possibly unnecessary for
# this change — confirm it is actually required before keeping it.
if(NOT CMAKE_CUDA_ARCHITECTURES)
set(CMAKE_CUDA_ARCHITECTURES OFF)
endif()
Comment on lines +82 to +86
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is not needed for this PR.

# cuDNN frontend API
set(CUDNN_FRONTEND_INCLUDE_DIR
"${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/include")
Expand Down Expand Up @@ -151,6 +156,7 @@ list(APPEND transformer_engine_cuda_sources
normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu
permutation/permutation.cu
util/padding.cu
util/topk.cu
swizzle/swizzle.cu
swizzle/swizzle_block_scaling.cu
fused_softmax/scaled_masked_softmax.cu
Expand Down
57 changes: 57 additions & 0 deletions transformer_engine/common/include/transformer_engine/topk.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*************************************************************************
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

/*! \file topk.h
 *  \brief Batched top-K selection (AIR radix algorithm).
 */

#ifndef TRANSFORMER_ENGINE_TOPK_H_
#define TRANSFORMER_ENGINE_TOPK_H_

#include "transformer_engine.h"

#ifdef __cplusplus
extern "C" {
#endif

/*! \brief Compute the top-K (key, index) pairs using the AIR radix algorithm.
 *
 * Operates on a batch of rows: each row of length \p seq_len is processed
 * independently and the \p k largest entries are selected. The per-row output
 * is unordered.
 *
 * \param[in]     stream          CUDA stream used for the operation.
 * \param[in]     keys_in         Input keys tensor, flat storage for
 *                                batch_size rows of seq_len elements.
 * \param[in]     lengths_in      Per-row lengths, shape (batch_size,); int32.
 *                                Fill with seq_len for uniform-length batches.
 * \param[in,out] keys_out        Output top-k keys, flat storage for
 *                                batch_size rows of k elements.
 * \param[in,out] indices_out     Output top-k indices (within each row),
 *                                flat storage for batch_size rows of k int32 elements.
 * \param[in,out] workspace       Workspace tensor, shape (workspace_bytes,).
 * \param[in]     batch_size      Number of rows.
 * \param[in]     seq_len         Number of elements per row.
 * \param[in]     k               Number of top-K entries to select per row.
 * \param[in]     workspace_bytes Workspace size in bytes; must be >=
 *                                nvte_get_topk_workspace_bytes(batch_size, seq_len, k).
 *
 * Supported key dtypes: float32, bfloat16.
 * Index dtype: int32.
 */
void nvte_topk(cudaStream_t stream, const NVTETensor keys_in, const NVTETensor lengths_in,
               NVTETensor keys_out, NVTETensor indices_out, NVTETensor workspace, int batch_size,
               int seq_len, int k, size_t workspace_bytes);

/*! \brief Query the workspace size required by nvte_topk.
 *
 * NOTE(review): other TE APIs (e.g. the layernorm functions) report the
 * required workspace size by calling the main function with an empty
 * workspace tensor rather than via a dedicated query function — consider
 * aligning this API with that convention.
 *
 * \param[in] batch_size Number of rows.
 * \param[in] seq_len    Number of elements per row.
 * \param[in] k          Top-K count.
 * \return Required workspace size in bytes.
 */
size_t nvte_get_topk_workspace_bytes(int batch_size, int seq_len, int k);

#ifdef __cplusplus
}  // extern "C"
#endif

#endif  // TRANSFORMER_ENGINE_TOPK_H_
Loading
Loading