
Add optimised top-k kernel AIR. #2890

Open
dcampora wants to merge 6 commits into NVIDIA:main from dcampora:feature/air-topk

Conversation


dcampora commented Apr 16, 2026

Description

Adds a custom AIR TopK implementation (header-only, vendored into
transformer_engine/common/util/) exposed as a JAX FFI custom call
via the TE JAX extension.
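The PR's tests validate the kernel against jax.lax.top_k, so the contract is the usual top-k one: per-row values sorted in descending order, plus the indices those values came from. A minimal pure-Python reference of that semantics (illustrative only; the function name topk_reference is ours, not part of the PR, whose real entry point is the JAX FFI custom call):

```python
# Pure-Python reference for the top-k semantics the kernel is expected to
# match: k largest elements per row, values in descending order, together
# with their source indices. Illustrative sketch, not the PR's code.

def topk_reference(row, k):
    """Return (values, indices) of the k largest elements, values descending."""
    # Stable sort of indices by value, largest first; ties keep index order.
    order = sorted(range(len(row)), key=lambda i: row[i], reverse=True)
    idx = order[:k]
    return [row[i] for i in idx], idx
```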

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • transformer_engine/common/util/air_topk.cu: AIR TopK CUDA kernel
  • transformer_engine/common/util/standalone_air_topk.cuh: vendored header (AIR TopK, header-only)
  • transformer_engine/common/include/transformer_engine/air_topk.h: C API
  • transformer_engine/jax/csrc/extensions/air_topk.cpp: JAX FFI binding
  • transformer_engine/jax/cpp_extensions/air_topk.py: Python wrapper
  • transformer_engine/common/CMakeLists.txt: compile new kernel; use CCCL from CUDA toolkit; fix SM100 arch handling when all arches are special-cased

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

dcampora and others added 2 commits April 16, 2026 04:14
Adds a custom AIR TopK implementation (header-only, vendored into
transformer_engine/common/util/) exposed as a JAX FFI custom call
via the TE JAX extension.

Key changes:
- transformer_engine/common/util/air_topk.cu: AIR TopK CUDA kernel
- transformer_engine/common/util/standalone_air_topk.cuh: vendored header
- transformer_engine/common/include/transformer_engine/air_topk.h: C API
- transformer_engine/jax/csrc/extensions/air_topk.cpp: JAX FFI binding
- transformer_engine/jax/cpp_extensions/air_topk.py: Python wrapper
- CMakeLists.txt: compile new kernel; use CCCL from CUDA toolkit
- CMakeLists.txt: fix SM100 arch handling when all arches are special-cased

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: dcampora <961215+dcampora@users.noreply.github.com>

greptile-apps bot commented Apr 16, 2026

Greptile Summary

This PR vendors the AIR radix-selection top-K algorithm as a new CUDA kernel (standalone_topk.cuh / topk.cu), exposes it via a C API (topk.h), and wires it up as a JAX FFI custom call (topk.cpp / topk.py) with correctness tests. All previous review concerns (namespace pollution for constants, unused variable, SM-count caching) have been addressed. The remaining findings are all P2 style/hygiene issues that do not block correctness.

Confidence Score: 5/5

Safe to merge; all findings are P2 style/hygiene issues that do not affect correctness or runtime behaviour.

All P0/P1 concerns from prior review rounds have been addressed. The four remaining comments are P2: two style issues in the vendored header (global-namespace helpers, dead code with magic constants), one unresolved UB TODO in a union that nvcc handles correctly in practice, and one missing Python-level guard for k > seq_len that the kernel already handles gracefully.

transformer_engine/common/util/standalone_topk.cuh (three minor P2 issues); transformer_engine/jax/cpp_extensions/topk.py (missing k <= seq_len guard)

Important Files Changed

  • transformer_engine/common/util/standalone_topk.cuh — Vendored AIR radix top-K header; core algorithm looks correct. Three P2 issues: global-namespace helpers, dead scan_warp_version with magic constants, and an unresolved UB TODO on a union.
  • transformer_engine/common/util/topk.cu — Thin dispatch layer over standalone_topk; workspace size query always uses float, which over-allocates for bfloat16 (safe). Error handling and dtype dispatch look correct.
  • transformer_engine/jax/cpp_extensions/topk.py — Clean JAX FFI wrapper with workspace memoisation; missing k_value <= seq_len guard in abstract.
  • transformer_engine/jax/csrc/extensions/topk.cpp — Well-structured JAX FFI handler; dtype validation, shape checks, and workspace-size plumbing are all correct.
  • transformer_engine/common/include/transformer_engine/topk.h — New C API header; doc-comments are accurate and the extern "C" guards are properly applied.
  • tests/jax/test_custom_call_compute.py — Adds TestTopK with 1-D and 2-D correctness checks using jax.lax.top_k as reference; cross-validates sorted values and gathered indices.
  • transformer_engine/common/CMakeLists.txt — Adds util/topk.cu to the CUDA sources and adds a CMAKE_CUDA_ARCHITECTURES OFF guard for the all-special-cased-arch edge case.
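One of the P2 findings above is a missing k <= seq_len guard at the Python abstract level. A sketch of what such a guard could look like; the names (check_topk_args, k_value) are assumptions based on the review comments, not the PR's actual code, which handles oversized k gracefully in the kernel:

```python
# Hypothetical Python-level validation for the topk abstract rule, per the
# review's suggestion. The kernel already tolerates k > seq_len, so this
# guard would just surface the error earlier with a clearer message.

def check_topk_args(shape, k_value):
    """Raise if k_value is not a valid top-k size for the last axis of shape."""
    seq_len = shape[-1]
    if not 0 < k_value <= seq_len:
        raise ValueError(
            f"k_value must be in (0, seq_len={seq_len}], got {k_value}"
        )
```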

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["topk(x, k_value)\nPython API"] --> B{"x.ndim == 1?"}
    B -->|yes| C["unsqueeze → (1, seq_len)"]
    B -->|no| D["(batch_size, seq_len)"]
    C --> D
    D --> E["TopKPrimitive.outer_primitive.bind()\nlengths = full(batch_size, seq_len, int32)"]
    E --> F["TopkFFI (C++)\nJAX FFI handler"]
    F --> G["nvte_topk (C API)\ntopk.cu"]
    G --> H{"len ≤ 32768?"}
    H -->|yes – one-block| I["radix_topk_one_block_kernel\n<<<batch_size, 1024>>>"]
    H -->|no – multi-block| J["calc_grid_dim → grid_dim\n(cached sm_cnt)"]
    J --> K["radix_kernel loop\n<<<grid_dim × batch, 256>>>"]
    K --> L["last_filter_kernel"]
    I --> M["out_keys (batch, k)\nout_indices (batch, k)"]
    L --> M
    M --> N{"squeezed?"}
    N -->|yes| O["squeeze → (k,)"]
    N -->|no| P["return (values, indices)"]
    O --> P
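The dispatch decision in the flowchart above can be sketched in host-side pseudologic: rows of up to 32768 elements take the single-block kernel, longer rows go through the multi-block radix path with a grid sized from the cached SM count. The 32768 threshold and block dimensions come from the flowchart; the grid formula here is an illustrative assumption, not the PR's actual calc_grid_dim:

```python
# Sketch of the kernel-path selection shown in the flowchart. Only the
# threshold (32768) and block sizes (1024 / 256) are taken from the PR;
# the per-row grid computation is an assumed placeholder for calc_grid_dim.

ONE_BLOCK_MAX_LEN = 32768  # flowchart: "len <= 32768?"

def choose_topk_path(seq_len, batch_size, sm_cnt, block_dim=256):
    """Return (path, total_blocks, threads_per_block) for one top-k launch."""
    if seq_len <= ONE_BLOCK_MAX_LEN:
        # radix_topk_one_block_kernel <<<batch_size, 1024>>>
        return ("one_block", batch_size, 1024)
    # Multi-block path: spread SMs across the batch, at least one block per row.
    grid_per_row = max(1, sm_cnt // max(1, batch_size))
    return ("multi_block", grid_per_row * batch_size, block_dim)
```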

Reviews (3): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment thread transformer_engine/jax/csrc/extensions/air_topk.cpp Outdated
Comment thread transformer_engine/common/util/standalone_air_topk.cuh Outdated
Comment thread transformer_engine/common/util/standalone_topk.cuh Outdated
…ing export, cache sm_cnt

- Move WARP_SIZE/WARP_BITS/FULL_WARP_MASK/VECTORIZED_READ_SIZE into namespace nv
- Remove unused keys_element_bytes variable in AirTopkFFI; collapse switch to dtype validation
- Add missing `from .air_topk import *` export in jax/cpp_extensions/__init__.py
- Cache sm_cnt per device with static vars to avoid repeated cudaGetDevice/cudaDeviceGetAttribute calls
- Add CMAKE_BUILD_WITH_INSTALL_RPATH=ON to build_ext.py

Signed-off-by: dcampora <961215+dcampora@users.noreply.github.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
pre-commit-ci bot and others added 3 commits April 16, 2026 05:18
Remove the `air_` prefix from all TopK-related identifiers: file names,
C API functions (nvte_air_topk -> nvte_topk), FFI handler/primitive names
(te_air_topk_ffi -> te_topk_ffi), Python symbols, and the internal
`air_topk` namespace in standalone_topk.cuh.  No functional changes.

Signed-off-by: Diego Campora <dcampora@nvidia.com>
Signed-off-by: dcampora <961215+dcampora@users.noreply.github.com>