-
Notifications
You must be signed in to change notification settings - Fork 700
Add optimised top-k kernel AIR. #2890
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
449811e
470be3c
1e6c976
1b328a2
897156e
0862bca
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -79,6 +79,11 @@ if(NOT arch_120_index EQUAL -1) | |
| endif() | ||
| endif() | ||
|
|
||
| # If all architectures were special-cased and removed, disable CMake's automatic | ||
| # CUDA_ARCHITECTURES management — compilation flags are set via COMPILE_OPTIONS below. | ||
| if(NOT CMAKE_CUDA_ARCHITECTURES) | ||
| set(CMAKE_CUDA_ARCHITECTURES OFF) | ||
| endif() | ||
|
Comment on lines
+82
to
+86
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This change is not needed for this PR. |
||
| # cuDNN frontend API | ||
| set(CUDNN_FRONTEND_INCLUDE_DIR | ||
| "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/include") | ||
|
|
@@ -151,6 +156,7 @@ list(APPEND transformer_engine_cuda_sources | |
| normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu | ||
| permutation/permutation.cu | ||
| util/padding.cu | ||
| util/topk.cu | ||
| swizzle/swizzle.cu | ||
| swizzle/swizzle_block_scaling.cu | ||
| fused_softmax/scaled_masked_softmax.cu | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,57 @@ | ||
| /************************************************************************* | ||
| * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| * | ||
| * See LICENSE for license information. | ||
| ************************************************************************/ | ||
|
|
||
| #ifndef TRANSFORMER_ENGINE_TOPK_H_ | ||
| #define TRANSFORMER_ENGINE_TOPK_H_ | ||
|
|
||
| #include "transformer_engine.h" | ||
|
|
||
| #ifdef __cplusplus | ||
| extern "C" { | ||
| #endif | ||
|
|
||
| /*! \brief Compute the top-K (key, index) pairs using the AIR radix algorithm. | ||
| * | ||
| * Operates on a batch of rows: each row of length \p seq_len is processed | ||
| * independently and the \p k largest entries are selected. | ||
| * | ||
| * \param[in] stream CUDA stream used for the operation. | ||
| * \param[in] keys_in Input keys tensor, flat storage for | ||
| * batch_size rows of seq_len elements. | ||
| * \param[in] lengths_in Per-row lengths, shape (batch_size,); int32. | ||
| * Fill with seq_len for uniform-length batches. | ||
| * \param[in,out] keys_out Output top-k keys, flat storage for | ||
| * batch_size rows of k elements. | ||
| * \param[in,out] indices_out Output top-k indices (within each row), | ||
| * flat storage for batch_size rows of k int32 elements. | ||
| * \param[in,out] workspace Workspace tensor, shape (workspace_bytes,). | ||
| * \param[in] batch_size Number of rows. | ||
| * \param[in] seq_len Number of elements per row. | ||
| * \param[in] k Number of top-K entries to select per row. | ||
| * \param[in] workspace_bytes Workspace size in bytes; must be >= | ||
| * nvte_get_topk_workspace_bytes(batch_size, seq_len, k). | ||
| * | ||
| * Supported key dtypes: float32, bfloat16. | ||
| * Index dtype: int32. | ||
| */ | ||
| void nvte_topk(cudaStream_t stream, const NVTETensor keys_in, const NVTETensor lengths_in, | ||
| NVTETensor keys_out, NVTETensor indices_out, NVTETensor workspace, int batch_size, | ||
| int seq_len, int k, size_t workspace_bytes); | ||
|
|
||
| /*! \brief Query the workspace size required by nvte_topk. | ||
| * | ||
| * \param[in] batch_size Number of rows. | ||
| * \param[in] seq_len Number of elements per row. | ||
| * \param[in] k Top-K count. | ||
| * \return Required workspace size in bytes. | ||
| */ | ||
| size_t nvte_get_topk_workspace_bytes(int batch_size, int seq_len, int k); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In the other parts of TE we follow the convention of running the main function with empty workspace to get the size, rather than a specialized function, see e.g. the layernorm functions. Could we make that consistent? |
||
|
|
||
| #ifdef __cplusplus | ||
| } // extern "C" | ||
| #endif | ||
|
|
||
| #endif // TRANSFORMER_ENGINE_TOPK_H_ | ||
Uh oh!
There was an error while loading. Please reload this page.