Conversation
Adds a custom AIR TopK implementation (header-only, vendored into transformer_engine/common/util/) exposed as a JAX FFI custom call via the TE JAX extension. Key changes: - transformer_engine/common/util/air_topk.cu: AIR TopK CUDA kernel - transformer_engine/common/util/standalone_air_topk.cuh: vendored header - transformer_engine/common/include/transformer_engine/air_topk.h: C API - transformer_engine/jax/csrc/extensions/air_topk.cpp: JAX FFI binding - transformer_engine/jax/cpp_extensions/air_topk.py: Python wrapper - CMakeLists.txt: compile new kernel; use CCCL from CUDA toolkit - CMakeLists.txt: fix SM100 arch handling when all arches are special-cased Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: dcampora <961215+dcampora@users.noreply.github.com>
for more information, see https://pre-commit.ci
Greptile SummaryThis PR vendored the AIR radix-selection top-K algorithm as a new CUDA kernel ( Confidence Score: 5/5Safe to merge; all findings are P2 style/hygiene issues that do not affect correctness or runtime behaviour. All P0/P1 concerns from prior review rounds have been addressed. The four remaining comments are P2: two style issues in the vendored header (global-namespace helpers, dead code with magic constants), one unresolved UB TODO in a union that nvcc handles correctly in practice, and one missing Python-level guard for k > seq_len that the kernel already handles gracefully. transformer_engine/common/util/standalone_topk.cuh (three minor P2 issues); transformer_engine/jax/cpp_extensions/topk.py (missing k <= seq_len guard) Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["topk(x, k_value)\nPython API"] --> B{"x.ndim == 1?"}
B -->|yes| C["unsqueeze → (1, seq_len)"]
B -->|no| D["(batch_size, seq_len)"]
C --> D
D --> E["TopKPrimitive.outer_primitive.bind()\nlengths = full(batch_size, seq_len, int32)"]
E --> F["TopkFFI (C++)\nJAX FFI handler"]
F --> G["nvte_topk (C API)\ntopk.cu"]
G --> H{"len ≤ 32768?"}
H -->|yes – one-block| I["radix_topk_one_block_kernel\n<<<batch_size, 1024>>>"]
H -->|no – multi-block| J["calc_grid_dim → grid_dim\n(cached sm_cnt)"]
J --> K["radix_kernel loop\n<<<grid_dim × batch, 256>>>"]
K --> L["last_filter_kernel"]
I --> M["out_keys (batch, k)\nout_indices (batch, k)"]
L --> M
M --> N{"squeezed?"}
N -->|yes| O["squeeze → (k,)"]
N -->|no| P["return (values, indices)"]
O --> P
Reviews (3): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile |
…ing export, cache sm_cnt - Move WARP_SIZE/WARP_BITS/FULL_WARP_MASK/VECTORIZED_READ_SIZE into namespace nv - Remove unused keys_element_bytes variable in AirTopkFFI; collapse switch to dtype validation - Add missing `from .air_topk import *` export in jax/cpp_extensions/__init__.py - Cache sm_cnt per device with static vars to avoid repeated cudaGetDevice/cudaDeviceGetAttribute calls - Add CMAKE_BUILD_WITH_INSTALL_RPATH=ON to build_ext.py Signed-off-by: dcampora <961215+dcampora@users.noreply.github.com> Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
36e8405 to
1e6c976
Compare
for more information, see https://pre-commit.ci
Remove the `air_` prefix from all TopK-related identifiers: file names, C API functions (nvte_air_topk -> nvte_topk), FFI handler/primitive names (te_air_topk_ffi -> te_topk_ffi), Python symbols, and the internal `air_topk` namespace in standalone_topk.cuh. No functional changes. Signed-off-by: Diego Campora <dcampora@nvidia.com> Signed-off-by: dcampora <961215+dcampora@users.noreply.github.com>
for more information, see https://pre-commit.ci
Description
Adds a custom AIR TopK implementation (header-only, vendored into
transformer_engine/common/util/) exposed as a JAX FFI custom callvia the TE JAX extension.
Type of change
Changes
transformer_engine/common/util/air_topk.cu: AIR TopK CUDA kerneltransformer_engine/common/util/standalone_air_topk.cuh: vendored header (AIR TopK, header-only)transformer_engine/common/include/transformer_engine/air_topk.h: C APItransformer_engine/jax/csrc/extensions/air_topk.cpp: JAX FFI bindingtransformer_engine/jax/cpp_extensions/air_topk.py: Python wrappertransformer_engine/common/CMakeLists.txt: compile new kernel; use CCCL from CUDA toolkit; fix SM100 arch handling when all arches are special-casedChecklist: