From 4af3cee9adca352ed4058eb78a67659be3838d85 Mon Sep 17 00:00:00 2001 From: Harry Zhou Date: Mon, 1 Jun 2026 12:55:56 +0800 Subject: [PATCH 1/8] feat(common): add dense topk index router output Signed-off-by: Harry Zhou --- .../fused_topk_with_score_function.cu | 390 ++++++++++++++++-- .../common/fused_router/utils.h | 6 +- .../include/transformer_engine/fused_router.h | 20 + transformer_engine/pytorch/csrc/extensions.h | 4 +- .../pytorch/csrc/extensions/pybind.cpp | 4 +- .../pytorch/csrc/extensions/router.cpp | 69 +++- transformer_engine/pytorch/router.py | 40 +- 7 files changed, 462 insertions(+), 71 deletions(-) diff --git a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu index 57868266f2..effaef320a 100644 --- a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu +++ b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu @@ -9,6 +9,7 @@ #include #include +#include #include #include "../common.h" @@ -26,14 +27,12 @@ namespace fused_router { // ============================================================================= template -__global__ void fused_topk_forward_simple_kernel(const DataType *logits, int num_tokens, - int num_experts, int topk, bool use_pre_softmax, - int num_groups, int group_topk, - float scaling_factor, int score_function, - const BiasType *expert_bias, DataType *probs, - uint8_t *routing_map, - CompType *intermediate_output) { + TopkFuncType TopkFunc = TopkFuncType::Naive, typename IndexType = int32_t> +__global__ void fused_topk_forward_simple_kernel( + const DataType *logits, int num_tokens, int num_experts, int topk, bool use_pre_softmax, + int num_groups, int group_topk, float scaling_factor, int score_function, + const BiasType *expert_bias, DataType *probs, uint8_t *routing_map, + CompType *intermediate_output, IndexType *topk_indices_output) { constexpr bool kIsBitmap = (RoutingMapFormat == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8); int num_token_per_block = blockDim.x / kThreadsPerWarp; int warp_id = threadIdx.x / kThreadsPerWarp; @@ -77,13 +76,15 @@ __global__ void fused_topk_forward_simple_kernel(const DataType *logits, int num intermediate_output[pos_offset + i] = -std::numeric_limits::infinity(); } } - if constexpr (!kIsBitmap) { - for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - routing_map[pos_offset + i] = 0; - } - } else { - for (int i = lane_id; i < bitmap_words_per_warp; i += kThreadsPerWarp) { - local_bitmap_words[i] = 0u; + if (routing_map != nullptr) { + if constexpr (!kIsBitmap) { + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + routing_map[pos_offset + i] = 0; + } + } else { + for (int i = lane_id; i < bitmap_words_per_warp; i += kThreadsPerWarp) { + local_bitmap_words[i] = 0u; + } } } // Load the logits to shmem @@ -192,16 +193,14 @@ __global__ void fused_topk_forward_simple_kernel(const DataType *logits, int num } // Write outputs - if constexpr (!kIsBitmap) { - for (int i = lane_id; i < topk; i += kThreadsPerWarp) { - routing_map[pos_offset + topk_indices[i]] = 1; - probs[pos_offset + topk_indices[i]] = scaling_factor * topk_scores[i]; - } - } else { + if (routing_map != nullptr && kIsBitmap) { for (int i = lane_id; i < topk; i += kThreadsPerWarp) { int e = topk_indices[i]; atomicOr(&local_bitmap_words[e / 32], 1u << (e % 32)); probs[pos_offset + e] = scaling_factor * topk_scores[i]; + if (topk_indices_output != nullptr) { + topk_indices_output[token_offset_cur_warp * topk + i] = static_cast(e); + } } __syncwarp(); uint8_t *bitmap_row = @@ -210,6 +209,17 @@ __global__ void fused_topk_forward_simple_kernel(const DataType *logits, int num for (int i = lane_id; i < bitmap_row_bytes; i += kThreadsPerWarp) { bitmap_row[i] = local_bitmap_bytes[i]; } + } else { + for (int i = lane_id; i < topk; i += kThreadsPerWarp) { + int e = topk_indices[i]; + if (routing_map != nullptr) { + routing_map[pos_offset + e] = 1; + } + if (topk_indices_output != nullptr) { + topk_indices_output[token_offset_cur_warp * topk + i] = static_cast(e); + } + probs[pos_offset + e] = scaling_factor * topk_scores[i]; + } } __threadfence_block(); __syncwarp(); @@ -222,11 +232,13 @@ __global__ void fused_topk_forward_simple_kernel(const DataType *logits, int num // ============================================================================= template + TopkFuncType TopkFunc = TopkFuncType::Naive, int ScoreFunc = 0, + typename IndexType = int32_t> __global__ void fused_topk_with_score_function_forward_kernel( const DataType *logits, int num_tokens, int num_experts, int topk, bool use_pre_softmax, int num_groups, int group_topk, float scaling_factor, const BiasType *expert_bias, - DataType *probs, uint8_t *routing_map, CompType *intermediate_output, int num_buffers) { + DataType *probs, uint8_t *routing_map, CompType *intermediate_output, + IndexType *topk_indices_output, int num_buffers) { constexpr bool kIsBitmap = (RoutingMapFormat == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8); /*** * Section: Global Variables/Addresses init @@ -333,11 +345,13 @@ __global__ void fused_topk_with_score_function_forward_kernel( // Clear the probs/routing_map (num_experts) vec_fill_global(probs + pos_offset, static_cast(0.0f), num_experts, lane_id); - if constexpr (!kIsBitmap) { - vec_fill_global(routing_map + pos_offset, static_cast(0), num_experts, lane_id); - } else { - for (int i = lane_id; i < bitmap_words_per_warp; i += kThreadsPerWarp) { - local_bitmap_words[i] = 0u; + if (routing_map != nullptr) { + if constexpr (!kIsBitmap) { + vec_fill_global(routing_map + pos_offset, static_cast(0), num_experts, lane_id); + } else { + for (int i = lane_id; i < bitmap_words_per_warp; i += kThreadsPerWarp) { + local_bitmap_words[i] = 0u; + } } } @@ -491,16 +505,14 @@ __global__ void fused_topk_with_score_function_forward_kernel( } // Write the probs/routing_map to the output tensor - if constexpr (!kIsBitmap) { - for (int i = lane_id; i < topk; i += kThreadsPerWarp) { - routing_map[pos_offset + topk_indices[i]] = 1; - probs[pos_offset + topk_indices[i]] = scaling_factor * topk_scores[i]; - } - } else { + if (routing_map != nullptr && kIsBitmap) { for (int i = lane_id; i < topk; i += kThreadsPerWarp) { int e = topk_indices[i]; atomicOr(&local_bitmap_words[e / 32], 1u << (e % 32)); probs[pos_offset + e] = scaling_factor * topk_scores[i]; + if (topk_indices_output != nullptr) { + topk_indices_output[token_offset_cur_warp * topk + i] = static_cast(e); + } } __syncwarp(); uint8_t *bitmap_row = @@ -509,6 +521,17 @@ __global__ void fused_topk_with_score_function_forward_kernel( for (int i = lane_id; i < bitmap_row_bytes; i += kThreadsPerWarp) { bitmap_row[i] = local_bitmap_bytes[i]; } + } else { + for (int i = lane_id; i < topk; i += kThreadsPerWarp) { + int e = topk_indices[i]; + if (routing_map != nullptr) { + routing_map[pos_offset + e] = 1; + } + if (topk_indices_output != nullptr) { + topk_indices_output[token_offset_cur_warp * topk + i] = static_cast(e); + } + probs[pos_offset + e] = scaling_factor * topk_scores[i]; + } } __syncwarp(); @@ -562,7 +585,8 @@ void fused_topk_with_score_function_forward_kernel_launcher( compute_persistent_grid(kernel, kThreadsPerBlock, shared_memory_size, total_blocks); kernel<<>>( logits, num_tokens, num_experts, topk, use_pre_softmax, num_groups, group_topk, - scaling_factor, expert_bias, probs, routing_map, intermediate_output, num_buffers); + scaling_factor, expert_bias, probs, routing_map, intermediate_output, + static_cast(nullptr), num_buffers); NVTE_CHECK_CUDA(cudaGetLastError()); }; @@ -579,7 +603,8 @@ void fused_topk_with_score_function_forward_kernel_launcher( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, other_shmem)); kernel<<>>( logits, num_tokens, num_experts, topk, use_pre_softmax, num_groups, group_topk, - scaling_factor, score_function, expert_bias, probs, routing_map, intermediate_output); + scaling_factor, score_function, expert_bias, probs, routing_map, intermediate_output, + static_cast(nullptr)); NVTE_CHECK_CUDA(cudaGetLastError()); }; @@ -606,6 +631,99 @@ void fused_topk_with_score_function_forward_kernel_launcher( } } +template +void fused_topk_with_score_function_forward_with_indices_kernel_launcher( + const DataType *logits, int num_tokens, int num_experts, int topk, bool use_pre_softmax, + int num_groups, int group_topk, float scaling_factor, int score_function, + const BiasType *expert_bias, DataType *probs, IndexType *topk_indices, + CompType *intermediate_output, cudaStream_t stream) { + NVTE_CHECK(num_experts > 0, "num_experts must be positive, got ", num_experts); + NVTE_CHECK(topk > 0 && topk <= num_experts, "topk must be in [1, num_experts], got topk=", topk, + " num_experts=", num_experts); + NVTE_CHECK(static_cast(num_tokens) * num_experts <= INT_MAX, + "num_tokens * num_experts exceeds INT_MAX (kernel uses int offsets), got ", + static_cast(num_tokens) * num_experts); + NVTE_CHECK(score_function >= 0 && score_function <= 2, + "Unsupported score_function: ", score_function); + if (group_topk > 0) { + NVTE_CHECK(topk % group_topk == 0, "topk must be divisible by group_topk, got topk=", topk, + " group_topk=", group_topk); + } + if constexpr (std::is_same_v) { + NVTE_CHECK(num_experts <= INT16_MAX, "int16 topk indices require num_experts <= ", INT16_MAX, + ", got ", num_experts); + } + + size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp; + size_t total_blocks = (num_tokens + num_token_per_block - 1) / num_token_per_block; + size_t scores_shmem = num_experts * num_token_per_block * sizeof(CompType); + size_t scratch_shmem = + topk * num_token_per_block * sizeof(CompType) + topk * num_token_per_block * sizeof(int); + if (group_topk > 0) { + scratch_shmem += num_groups * num_token_per_block * sizeof(CompType); + scratch_shmem += num_experts * num_token_per_block * sizeof(CompType); + } + size_t other_shmem = scores_shmem + scratch_shmem; + size_t logits_single_buf = + RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, 1); + int num_buffers = choose_num_buffers(logits_single_buf, other_shmem); + size_t logits_raw_shmem = + RawAsyncLoader::shmem_bytes(num_experts, num_token_per_block, num_buffers); + size_t shared_memory_size = logits_raw_shmem + other_shmem; + + auto launch = [&](auto kernel) { + check_shared_memory_capacity_num_experts(shared_memory_size, num_experts); + NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + shared_memory_size)); + size_t grid_size = + compute_persistent_grid(kernel, kThreadsPerBlock, shared_memory_size, total_blocks); + kernel<<>>( + logits, num_tokens, num_experts, topk, use_pre_softmax, num_groups, group_topk, + scaling_factor, expert_bias, probs, static_cast(nullptr), intermediate_output, + topk_indices, num_buffers); + NVTE_CHECK_CUDA(cudaGetLastError()); + }; + + const bool use_radix = topk >= get_radix_topk_threshold() && num_experts <= kMaxExpertsRadixTopk; + if (!use_radix) { + check_shared_memory_capacity_num_experts(other_shmem, num_experts); + + auto launch_simple = [&](auto kernel) { + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, other_shmem)); + kernel<<>>( + logits, num_tokens, num_experts, topk, use_pre_softmax, num_groups, group_topk, + scaling_factor, score_function, expert_bias, probs, static_cast(nullptr), + intermediate_output, topk_indices); + NVTE_CHECK_CUDA(cudaGetLastError()); + }; + + launch_simple( + fused_topk_forward_simple_kernel); + } else { + switch (score_function) { + case 0: + launch(fused_topk_with_score_function_forward_kernel); + break; + case 1: + launch(fused_topk_with_score_function_forward_kernel); + break; + case 2: + launch(fused_topk_with_score_function_forward_kernel); + break; + default: + NVTE_ERROR("Unsupported score_function: " + std::to_string(score_function)); + } + } +} + // Build the expected routing_map shape for a given NVTERoutingMapFormat. // BYTEMAP -> [num_tokens, num_experts] // BITMAP_U8 -> [num_tokens, ceil(num_experts/8)] @@ -679,6 +797,41 @@ void fused_topk_with_score_function_forward(const Tensor logits, int num_tokens, #undef ROUTER_FORWARD_DISPATCH } +void fused_topk_with_score_function_forward_with_indices( + const Tensor logits, int num_tokens, int num_experts, int topk, bool use_pre_softmax, + int num_groups, int group_topk, float scaling_factor, int score_function, + const Tensor expert_bias, Tensor probs, Tensor topk_indices, Tensor intermediate_output, + cudaStream_t stream) { + // Dispatch logits dtype and output-index dtype first; expert-bias dtype is only + // dispatched when an expert-bias tensor exists, otherwise the kernel receives nullptr. +#define ROUTER_FORWARD_WITH_INDICES_DISPATCH(DataType, IndexType) \ + if (expert_bias.has_data()) { \ + TE_ROUTER_PROBS_TYPE_SWITCH_ALL( \ + expert_bias.data.dtype, BiasType, \ + fused_topk_with_score_function_forward_with_indices_kernel_launcher( \ + reinterpret_cast(logits.data.dptr), num_tokens, num_experts, topk, \ + use_pre_softmax, num_groups, group_topk, scaling_factor, score_function, \ + reinterpret_cast(expert_bias.data.dptr), \ + reinterpret_cast(probs.data.dptr), \ + reinterpret_cast(topk_indices.data.dptr), \ + reinterpret_cast(intermediate_output.data.dptr), stream);); \ + } else { \ + fused_topk_with_score_function_forward_with_indices_kernel_launcher( \ + reinterpret_cast(logits.data.dptr), num_tokens, num_experts, topk, \ + use_pre_softmax, num_groups, group_topk, scaling_factor, score_function, nullptr, \ + reinterpret_cast(probs.data.dptr), \ + reinterpret_cast(topk_indices.data.dptr), \ + reinterpret_cast(intermediate_output.data.dptr), stream); \ + } + TE_ROUTER_PROBS_TYPE_SWITCH_ALL( + logits.data.dtype, DataType, + TE_ROUTER_INDEX_TYPE_SWITCH_ALL(topk_indices.data.dtype, IndexType, + ROUTER_FORWARD_WITH_INDICES_DISPATCH(DataType, IndexType););); +#undef ROUTER_FORWARD_WITH_INDICES_DISPATCH +} + // Backward: grad_probs + intermediate_output + routing_map → grad_logits. // // Double-buffered cp.async loads all 3 inputs in original types. Two-pass @@ -994,6 +1147,144 @@ void fused_topk_with_score_function_backward(const Tensor &routing_map, #undef ROUTER_BACKWARD_DISPATCH } +template +__global__ void fused_topk_backward_selected_indices_kernel( + const IndexType *topk_indices, const CompType *intermediate_output, const DataType *grad_probs, + int num_tokens, int num_experts, int topk, bool use_pre_softmax, float scaling_factor, + DataType *grad_logits) { + int num_token_per_block = blockDim.x / kThreadsPerWarp; + int warp_id = threadIdx.x / kThreadsPerWarp; + int lane_id = threadIdx.x % kThreadsPerWarp; + int total_round = (num_tokens + num_token_per_block - 1) / num_token_per_block; + + for (int round = blockIdx.x; round < total_round; round += gridDim.x) { + int token_idx = round * num_token_per_block + warp_id; + if (token_idx >= num_tokens) break; + + int pos = token_idx * num_experts; + const IndexType *token_topk_indices = topk_indices + token_idx * topk; + + CompType sum_act = 0.0f; + CompType sum_grad_act = 0.0f; + CompType sum_output_x_grad = 0.0f; + + if (topk > 1 || ScoreFunc == 1) { + for (int k = lane_id; k < topk; k += kThreadsPerWarp) { + int expert_idx = static_cast(token_topk_indices[k]); + CompType g = static_cast(grad_probs[pos + expert_idx]) * scaling_factor; + CompType act = intermediate_output[pos + expert_idx]; + + if constexpr (ScoreFunc == 0) { + sum_act += act; + sum_grad_act += g * act; + } else if constexpr (ScoreFunc == 2) { + CompType v = sqrtsoftplus_scalar(act); + sum_act += v; + sum_grad_act += g * v; + } else if constexpr (ScoreFunc == 1) { + sum_output_x_grad += g * act; + } + } + if constexpr (ScoreFunc == 0 || ScoreFunc == 2) { + sum_act = warp_allreduce_sum(sum_act); + sum_grad_act = warp_allreduce_sum(sum_grad_act); + } else if constexpr (ScoreFunc == 1) { + sum_output_x_grad = warp_allreduce_sum(sum_output_x_grad); + } + } + + if constexpr (ScoreFunc == 1) { + if (use_pre_softmax) { + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + CompType act = intermediate_output[pos + i]; + grad_logits[pos + i] = + static_cast(softmax_bwd_scalar(0.0f, act, sum_output_x_grad)); + } + } else { + vec_fill_global(grad_logits + pos, static_cast(0.0f), num_experts, lane_id); + } + } else { + vec_fill_global(grad_logits + pos, static_cast(0.0f), num_experts, lane_id); + } + __syncwarp(); + + for (int k = lane_id; k < topk; k += kThreadsPerWarp) { + int expert_idx = static_cast(token_topk_indices[k]); + CompType g = static_cast(grad_probs[pos + expert_idx]) * scaling_factor; + CompType act = intermediate_output[pos + expert_idx]; + + if constexpr (ScoreFunc == 0) { + if (topk > 1) { + g = normalize_bwd_scalar(g, true, sum_act, sum_grad_act); + } + g = sigmoid_bwd_scalar(g, act); + } else if constexpr (ScoreFunc == 2) { + CompType v = sqrtsoftplus_scalar(act); + if (topk > 1) { + g = normalize_bwd_scalar(g, true, sum_act, sum_grad_act); + } + g = sqrtsoftplus_bwd_scalar(g, act, v); + } else if constexpr (ScoreFunc == 1) { + g = softmax_bwd_scalar(g, act, sum_output_x_grad); + } + + grad_logits[pos + expert_idx] = static_cast(g); + } + } +} + +template +void fused_topk_with_score_function_backward_with_indices_kernel_launcher( + const IndexType *topk_indices, const CompType *intermediate_output, const DataType *grad_probs, + int num_tokens, int num_experts, int topk, bool use_pre_softmax, float scaling_factor, + int score_function, DataType *grad_logits, cudaStream_t stream) { + NVTE_CHECK(num_experts > 0, "num_experts must be positive, got ", num_experts); + NVTE_CHECK(static_cast(num_tokens) * num_experts <= INT_MAX, + "num_tokens * num_experts exceeds INT_MAX (kernel uses int offsets), got ", + static_cast(num_tokens) * num_experts); + size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp; + size_t total_blocks = (num_tokens + num_token_per_block - 1) / num_token_per_block; + + auto launch_selected_indices = [&](auto kernel) { + size_t grid_size = compute_persistent_grid(kernel, kThreadsPerBlock, 0, total_blocks); + kernel<<>>( + topk_indices, intermediate_output, grad_probs, num_tokens, num_experts, topk, + use_pre_softmax, scaling_factor, grad_logits); + NVTE_CHECK_CUDA(cudaGetLastError()); + }; + + if (score_function == 0) { + launch_selected_indices(fused_topk_backward_selected_indices_kernel); + return; + } + if (score_function == 2) { + launch_selected_indices(fused_topk_backward_selected_indices_kernel); + return; + } + if (score_function == 1) { + launch_selected_indices(fused_topk_backward_selected_indices_kernel); + return; + } + + NVTE_ERROR("Unsupported score_function: " + std::to_string(score_function)); +} + +void fused_topk_with_score_function_backward_with_indices( + const Tensor &topk_indices, const Tensor &intermediate_output, const Tensor &grad_probs, + int num_tokens, int num_experts, int topk, bool use_pre_softmax, float scaling_factor, + int score_function, Tensor &grad_logits, cudaStream_t stream) { + TE_ROUTER_PROBS_TYPE_SWITCH_ALL( + grad_logits.data.dtype, DataType, + TE_ROUTER_INDEX_TYPE_SWITCH_ALL( + topk_indices.data.dtype, IndexType, + fused_topk_with_score_function_backward_with_indices_kernel_launcher( + reinterpret_cast(topk_indices.data.dptr), + reinterpret_cast(intermediate_output.data.dptr), + reinterpret_cast(grad_probs.data.dptr), num_tokens, num_experts, topk, + use_pre_softmax, scaling_factor, score_function, + reinterpret_cast(grad_logits.data.dptr), stream););); +} + } // namespace fused_router } // namespace transformer_engine @@ -1024,6 +1315,20 @@ void nvte_fused_topk_with_score_function_forward( NVTE_ROUTING_MAP_FORMAT_BYTEMAP, intermediate_output, stream); } +void nvte_fused_topk_with_score_function_forward_with_indices( + const NVTETensor logits, int num_tokens, int num_experts, int topk, int use_pre_softmax, + int num_groups, int group_topk, float scaling_factor, int score_function, + const NVTETensor expert_bias, NVTETensor probs, NVTETensor topk_indices, + NVTETensor intermediate_output, cudaStream_t stream) { + NVTE_API_CALL(nvte_fused_topk_with_score_function_forward_with_indices); + using namespace transformer_engine; + fused_router::fused_topk_with_score_function_forward_with_indices( + *convertNVTETensorCheck(logits), num_tokens, num_experts, topk, + static_cast(use_pre_softmax), num_groups, group_topk, scaling_factor, score_function, + *convertNVTETensorCheck(expert_bias), *convertNVTETensorCheck(probs), + *convertNVTETensorCheck(topk_indices), *convertNVTETensorCheck(intermediate_output), stream); +} + void nvte_fused_topk_with_score_function_backward_v2(const NVTETensor routing_map, NVTERoutingMapFormat routing_map_format, const NVTETensor intermediate_output, @@ -1051,3 +1356,16 @@ void nvte_fused_topk_with_score_function_backward(const NVTETensor routing_map, routing_map, NVTE_ROUTING_MAP_FORMAT_BYTEMAP, intermediate_output, grad_probs, num_tokens, num_experts, topk, use_pre_softmax, scaling_factor, score_function, grad_logits, stream); } + +void nvte_fused_topk_with_score_function_backward_with_indices( + const NVTETensor topk_indices, const NVTETensor intermediate_output, + const NVTETensor grad_probs, int num_tokens, int num_experts, int topk, int use_pre_softmax, + float scaling_factor, int score_function, NVTETensor grad_logits, cudaStream_t stream) { + NVTE_API_CALL(nvte_fused_topk_with_score_function_backward_with_indices); + using namespace transformer_engine; + fused_router::fused_topk_with_score_function_backward_with_indices( + *convertNVTETensorCheck(topk_indices), *convertNVTETensorCheck(intermediate_output), + *convertNVTETensorCheck(grad_probs), num_tokens, num_experts, topk, + static_cast(use_pre_softmax), scaling_factor, score_function, + *convertNVTETensorCheck(grad_logits), stream); +} diff --git a/transformer_engine/common/fused_router/utils.h b/transformer_engine/common/fused_router/utils.h index 85881915a6..f6f569014a 100644 --- a/transformer_engine/common/fused_router/utils.h +++ b/transformer_engine/common/fused_router/utils.h @@ -569,6 +569,10 @@ __device__ __forceinline__ void topk_and_mask(CompType *scores, int data_size, i #define TE_ROUTER_INDEX_TYPE_SWITCH_ALL(dtype, type, ...) \ switch (dtype) { \ using namespace transformer_engine; \ + case DType::kInt16: { \ + using type = int16_t; \ + { __VA_ARGS__ } \ + } break; \ case DType::kInt32: { \ using type = int32_t; \ { __VA_ARGS__ } \ @@ -587,7 +591,7 @@ __device__ __forceinline__ void topk_and_mask(CompType *scores, int data_size, i } break; \ default: \ NVTE_ERROR("Unsupported router index dtype ", to_string(static_cast(dtype)), \ - ". Expected one of: Int32, Int64, BFloat16, " \ + ". Expected one of: Int16, Int32, Int64, BFloat16, " \ "Float32."); \ } } // namespace fused_router diff --git a/transformer_engine/common/include/transformer_engine/fused_router.h b/transformer_engine/common/include/transformer_engine/fused_router.h index 6cee10bd39..08f347c616 100644 --- a/transformer_engine/common/include/transformer_engine/fused_router.h +++ b/transformer_engine/common/include/transformer_engine/fused_router.h @@ -82,6 +82,17 @@ void nvte_fused_topk_with_score_function_forward_v2( const NVTETensor expert_bias, NVTETensor probs, NVTETensor routing_map, NVTERoutingMapFormat routing_map_format, NVTETensor intermediate_output, cudaStream_t stream); +/*! \brief Apply topk + softmax/sigmoid and output dense top-k indices. + * + * This entry point does not materialize routing_map. Instead, it writes the + * selected expert ids to topk_indices with shape [num_tokens, topk]. + */ +void nvte_fused_topk_with_score_function_forward_with_indices( + const NVTETensor logits, int num_tokens, int num_experts, int topk, int use_pre_softmax, + int num_groups, int group_topk, float scaling_factor, int score_function, + const NVTETensor expert_bias, NVTETensor probs, NVTETensor topk_indices, + NVTETensor intermediate_output, cudaStream_t stream); + /*! \brief Backward pass for fused topk + softmax/sigmoid (deprecated). * * \deprecated This function has been deprecated in favor of @@ -130,6 +141,15 @@ void nvte_fused_topk_with_score_function_backward_v2(const NVTETensor routing_ma float scaling_factor, int score_function, NVTETensor grad_logits, cudaStream_t stream); +/*! \brief Backward pass for fused topk + score function with dense top-k indices. + * + * \param[in] topk_indices Dense [num_tokens, topk] selected expert indices. + */ +void nvte_fused_topk_with_score_function_backward_with_indices( + const NVTETensor topk_indices, const NVTETensor intermediate_output, + const NVTETensor grad_probs, int num_tokens, int num_experts, int topk, int use_pre_softmax, + float scaling_factor, int score_function, NVTETensor grad_logits, cudaStream_t stream); + /*! \brief Forward pass for computing scores/routing map for auxiliary loss (deprecated). * * \deprecated This function has been deprecated in favor of diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 2b4f899e1d..b4eb0253fe 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -33,13 +33,13 @@ namespace transformer_engine::pytorch { std::tuple fused_topk_with_score_function_fwd( at::Tensor logits, int topk, bool use_pre_softmax, std::optional num_groups, std::optional group_topk, std::optional scaling_factor, std::string score_function, - std::optional expert_bias, + std::optional expert_bias, std::optional topk_indices = std::nullopt, int routing_map_format = static_cast(NVTE_ROUTING_MAP_FORMAT_BYTEMAP)); void fused_topk_with_score_function_bwd( at::Tensor routing_map, at::Tensor intermediate_output, at::Tensor grad_probs, at::Tensor grad_logits, int topk, bool use_pre_softmax, std::optional scaling_factor, - std::string score_function, + std::string score_function, bool use_dense_indices = false, int routing_map_format = static_cast(NVTE_ROUTING_MAP_FORMAT_BYTEMAP)); std::tuple fused_score_for_moe_aux_loss_fwd( diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index d6089b1e01..e36835dc2a 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -139,13 +139,13 @@ void init_router_bindings(pybind11::module &m) { m.def("fused_topk_with_score_function_fwd", &fused_topk_with_score_function_fwd, py::arg("logits"), py::arg("topk"), py::arg("use_pre_softmax"), py::arg("num_groups"), py::arg("group_topk"), py::arg("scaling_factor"), py::arg("score_function"), - py::arg("expert_bias"), + py::arg("expert_bias"), py::arg("topk_indices") = std::nullopt, py::arg("routing_map_format") = static_cast(NVTE_ROUTING_MAP_FORMAT_BYTEMAP), "Fused topk with score function fwd"); m.def("fused_topk_with_score_function_bwd", &fused_topk_with_score_function_bwd, py::arg("routing_map"), py::arg("intermediate_output"), py::arg("grad_probs"), py::arg("grad_logits"), py::arg("topk"), py::arg("use_pre_softmax"), - py::arg("scaling_factor"), py::arg("score_function"), + py::arg("scaling_factor"), py::arg("score_function"), py::arg("use_dense_indices") = false, py::arg("routing_map_format") = static_cast(NVTE_ROUTING_MAP_FORMAT_BYTEMAP), "Fused topk with score function bwd"); m.def("fused_score_for_moe_aux_loss_fwd", &fused_score_for_moe_aux_loss_fwd, py::arg("logits"), diff --git a/transformer_engine/pytorch/csrc/extensions/router.cpp b/transformer_engine/pytorch/csrc/extensions/router.cpp index 5dc5c7fe86..23214fd3f1 100644 --- a/transformer_engine/pytorch/csrc/extensions/router.cpp +++ b/transformer_engine/pytorch/csrc/extensions/router.cpp @@ -31,7 +31,8 @@ static at::Tensor allocate_routing_map(c10::IntArrayRef leading_dims, int64_t nu std::tuple fused_topk_with_score_function_fwd( at::Tensor logits, int topk, bool use_pre_softmax, std::optional num_groups, std::optional group_topk, std::optional scaling_factor, std::string score_function, - std::optional expert_bias, int routing_map_format) { + std::optional expert_bias, std::optional topk_indices, + int routing_map_format) { TORCH_CHECK(logits.dim() >= 1, "logits must have at least 1 dim"); TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous"); auto sizes = logits.sizes(); @@ -62,17 +63,22 @@ std::tuple fused_topk_with_score_function_fw at::Tensor probs = at::empty(sizes, at::dtype(logits.scalar_type()).device(at::kCUDA)); at::Tensor routing_map = - allocate_routing_map(sizes.slice(0, sizes.size() - 1), num_experts, routing_map_format); + topk_indices.has_value() + ? topk_indices.value() + : allocate_routing_map(sizes.slice(0, sizes.size() - 1), num_experts, routing_map_format); at::Tensor intermediate_output = at::empty(sizes, at::dtype(at::kFloat).device(at::kCUDA)); // 2D shape for the kernel (common-layer NVTE_CHECKs require {num_tokens, trailing_dim}). const std::vector shape_2d = {static_cast(num_tokens), static_cast(num_experts)}; - const std::vector routing_map_shape_2d = { - static_cast(num_tokens), - static_cast(routing_map_format == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8 - ? (num_experts + 7) / 8 - : num_experts)}; + const std::vector routing_map_shape_2d = + topk_indices.has_value() + ? std::vector{static_cast(num_tokens), static_cast(topk)} + : std::vector{ + static_cast(num_tokens), + static_cast(routing_map_format == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8 + ? (num_experts + 7) / 8 + : num_experts)}; auto logits_dtype = GetTransformerEngineDType(logits.scalar_type()); auto routing_map_dtype = GetTransformerEngineDType(routing_map.scalar_type()); @@ -87,12 +93,20 @@ std::tuple fused_topk_with_score_function_fw expert_bias_cu = makeTransformerEngineTensor(expert_bias.value()); } - nvte_fused_topk_with_score_function_forward_v2( - logits_cu.data(), static_cast(num_tokens), static_cast(num_experts), topk, - use_pre_softmax, num_groups_value, group_topk_value, scaling_factor_value, - score_function_map[score_function], expert_bias_cu.data(), probs_cu.data(), - routing_map_cu.data(), static_cast(routing_map_format), - intermediate_output_cu.data(), at::cuda::getCurrentCUDAStream()); + if (topk_indices.has_value()) { + nvte_fused_topk_with_score_function_forward_with_indices( + logits_cu.data(), static_cast(num_tokens), static_cast(num_experts), topk, + use_pre_softmax, num_groups_value, group_topk_value, scaling_factor_value, + score_function_map[score_function], expert_bias_cu.data(), probs_cu.data(), + routing_map_cu.data(), intermediate_output_cu.data(), at::cuda::getCurrentCUDAStream()); + } else { + nvte_fused_topk_with_score_function_forward_v2( + logits_cu.data(), static_cast(num_tokens), static_cast(num_experts), topk, + use_pre_softmax, num_groups_value, group_topk_value, scaling_factor_value, + score_function_map[score_function], expert_bias_cu.data(), probs_cu.data(), + routing_map_cu.data(), static_cast(routing_map_format), + intermediate_output_cu.data(), at::cuda::getCurrentCUDAStream()); + } return std::make_tuple(probs, routing_map, intermediate_output); } @@ -100,7 +114,8 @@ std::tuple fused_topk_with_score_function_fw void fused_topk_with_score_function_bwd(at::Tensor routing_map, at::Tensor intermediate_output, at::Tensor grad_probs, at::Tensor grad_logits, int topk, bool use_pre_softmax, std::optional scaling_factor, - std::string score_function, int routing_map_format) { + std::string score_function, bool use_dense_indices, + int routing_map_format) { TORCH_CHECK(grad_probs.dim() >= 1, "grad_probs must have at least 1 dim"); TORCH_CHECK(grad_probs.is_contiguous(), "grad_probs must be contiguous"); TORCH_CHECK(grad_logits.is_contiguous(), "grad_logits must be contiguous"); @@ -116,9 +131,11 @@ void fused_topk_with_score_function_bwd(at::Tensor routing_map, at::Tensor inter static_cast(num_experts)}; const std::vector routing_map_shape_2d = { static_cast(num_tokens), - static_cast(routing_map_format == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8 - ? (num_experts + 7) / 8 - : num_experts)}; + static_cast(use_dense_indices + ? topk + : (routing_map_format == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8 + ? (num_experts + 7) / 8 + : num_experts))}; auto grad_dtype = GetTransformerEngineDType(grad_probs.scalar_type()); auto routing_map_dtype = GetTransformerEngineDType(routing_map.scalar_type()); @@ -129,11 +146,19 @@ void fused_topk_with_score_function_bwd(at::Tensor routing_map, at::Tensor inter auto grad_probs_cu = makeTransformerEngineTensor(grad_probs.data_ptr(), shape_2d, grad_dtype); auto grad_logits_cu = makeTransformerEngineTensor(grad_logits.data_ptr(), shape_2d, grad_dtype); - nvte_fused_topk_with_score_function_backward_v2( - routing_map_cu.data(), static_cast(routing_map_format), - intermediate_output_cu.data(), grad_probs_cu.data(), static_cast(num_tokens), - static_cast(num_experts), topk, use_pre_softmax, scaling_factor_value, - score_function_value, grad_logits_cu.data(), at::cuda::getCurrentCUDAStream()); + if (use_dense_indices) { + nvte_fused_topk_with_score_function_backward_with_indices( + routing_map_cu.data(), intermediate_output_cu.data(), grad_probs_cu.data(), + static_cast(num_tokens), static_cast(num_experts), topk, use_pre_softmax, + scaling_factor_value, score_function_value, grad_logits_cu.data(), + at::cuda::getCurrentCUDAStream()); + } else { + nvte_fused_topk_with_score_function_backward_v2( + routing_map_cu.data(), static_cast(routing_map_format), + intermediate_output_cu.data(), grad_probs_cu.data(), static_cast(num_tokens), + static_cast(num_experts), topk, use_pre_softmax, scaling_factor_value, + score_function_value, grad_logits_cu.data(), at::cuda::getCurrentCUDAStream()); + } } std::tuple fused_score_for_moe_aux_loss_fwd( diff --git a/transformer_engine/pytorch/router.py b/transformer_engine/pytorch/router.py index 667dee6872..53ffaadbdf 100644 --- a/transformer_engine/pytorch/router.py +++ b/transformer_engine/pytorch/router.py @@ -83,9 +83,14 @@ def forward( score_function: str, expert_bias: Optional[torch.Tensor], routing_map_format: int, + topk_indices: Optional[torch.Tensor], ): # pylint: disable=missing-function-docstring - probs, routing_map, intermediate_output = tex.fused_topk_with_score_function_fwd( + tensor_shape = logits.shape + logits = logits.view(-1, tensor_shape[-1]) + num_tokens = logits.size(0) + num_experts = logits.size(1) + probs, routing_output, intermediate_output = tex.fused_topk_with_score_function_fwd( logits, topk, use_pre_softmax, @@ -94,23 +99,35 @@ def forward( scaling_factor, score_function, expert_bias, + topk_indices, routing_map_format, ) - ctx.save_for_backward(routing_map, intermediate_output) + if topk_indices is not None: + routing_output = topk_indices + probs = probs.view(tensor_shape) + if topk_indices is not None: + ctx.mark_dirty(topk_indices) + ctx.save_for_backward(routing_output, intermediate_output) + ctx.num_tokens = num_tokens + ctx.num_experts = num_experts + ctx.tensor_shape = tensor_shape ctx.use_pre_softmax = use_pre_softmax ctx.topk = topk ctx.scaling_factor = scaling_factor ctx.score_function = score_function ctx.routing_map_format = routing_map_format - return probs, routing_map + ctx.logits_dtype = logits.dtype + ctx.use_dense_indices = topk_indices is not None + return probs, routing_output @staticmethod def backward(ctx, grad_probs, _): # pylint: disable=missing-function-docstring routing_map, intermediate_output = ctx.saved_tensors - if not grad_probs.is_contiguous(): - grad_probs = grad_probs.contiguous() - grad_logits = torch.empty_like(grad_probs) + grad_probs = grad_probs.contiguous().view(-1, ctx.tensor_shape[-1]) + grad_logits = torch.empty( + (ctx.num_tokens, ctx.num_experts), dtype=ctx.logits_dtype, device=grad_probs.device + ) tex.fused_topk_with_score_function_bwd( routing_map, intermediate_output, @@ -120,9 +137,11 @@ def backward(ctx, grad_probs, _): ctx.use_pre_softmax, ctx.scaling_factor, ctx.score_function, + ctx.use_dense_indices, ctx.routing_map_format, ) - return grad_logits, None, None, None, None, None, None, None, None + grad_logits = grad_logits.view(ctx.tensor_shape) + return grad_logits, None, None, None, None, None, None, None, None, None def fused_topk_with_score_function( @@ -135,6 +154,7 @@ def fused_topk_with_score_function( score_function: str, expert_bias: Optional[torch.Tensor], routing_map_format: Union[str, RoutingMapFormat, int] = RoutingMapFormat.BYTEMAP, + topk_indices: Optional[torch.Tensor] = None, ): """ Fused topk with score function router. @@ -159,6 +179,9 @@ def fused_topk_with_score_function( ``RoutingMapFormat.BITMAP_U8`` returns a uint8[T, ceil(E/8)] tensor with bit ``(e % 8)`` of byte ``(e / 8)`` set when token ``t`` routes to expert ``e`` (LSB-first / little-endian packing along the expert axis). + topk_indices : torch.Tensor, optional + Optional output buffer with shape [num_tokens, topk]. When provided, its dtype + controls the dense index output dtype and the routing map is not materialized. Returns ------- @@ -166,7 +189,7 @@ def fused_topk_with_score_function( Same shape as ``logits``. routing_map : torch.Tensor Same leading dims as ``logits``; trailing dim and dtype depend on - routing_map_format: + routing_map_format, or dense top-k indices when topk_indices is provided: - BYTEMAP: bool[*logits.shape[:-1], num_experts] - BITMAP_U8: uint8[*logits.shape[:-1], ceil(num_experts/8)] LSB-first bit-packed. @@ -184,6 +207,7 @@ def fused_topk_with_score_function( score_function, expert_bias, routing_map_format, + topk_indices, ) From ceb445b34b405e5d117daf89f915f21096c6fd8f Mon Sep 17 00:00:00 2001 From: Harry Zhou Date: Mon, 1 Jun 2026 12:55:56 +0800 Subject: [PATCH 2/8] [Common] Optimize dense fused router backward Signed-off-by: Harry Zhou --- tests/pytorch/test_fused_router.py | 36 +++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/test_fused_router.py b/tests/pytorch/test_fused_router.py index 974ccee19c..ede44d19ce 100644 --- a/tests/pytorch/test_fused_router.py +++ b/tests/pytorch/test_fused_router.py @@ -166,6 +166,15 @@ def aux_loss_pytorch( return aux_loss +def topk_indices_to_routing_map(topk_indices: torch.Tensor, num_experts: int) -> torch.Tensor: + """Convert dense [num_tokens, topk] top-k indices to a bool routing map.""" + routing_map = torch.zeros( + topk_indices.size(0), num_experts, dtype=torch.bool, device=topk_indices.device + ) + routing_map.scatter_(1, topk_indices.long(), True) + return routing_map + + def run_comparison( dtype, num_tokens, @@ -177,6 +186,8 @@ def run_comparison( scaling_factor, score_function, enable_bias, + topk_output_mode="sparse", + topk_index_dtype=torch.int16, ): if topk >= num_experts: pytest.skip(f"topk ({topk}) >= num_experts ({num_experts})") @@ -235,8 +246,12 @@ def run_comparison( expert_bias=expert_bias, ) + topk_indices = None + if topk_output_mode == "dense": + topk_indices = torch.empty((num_tokens, topk), device="cuda", dtype=topk_index_dtype) + # Run the fused implementation - probs_fused, routing_map_fused = fused_topk_with_score_function( + probs_fused, routing_output_fused = fused_topk_with_score_function( logits=logits_clone, topk=topk, use_pre_softmax=use_pre_softmax, @@ -245,7 +260,14 @@ def run_comparison( scaling_factor=scaling_factor, score_function=score_function, expert_bias=expert_bias_clone, + topk_indices=topk_indices, ) + if topk_output_mode == "dense": + assert routing_output_fused.data_ptr() == topk_indices.data_ptr() + assert routing_output_fused.dtype == topk_index_dtype + routing_map_fused = topk_indices_to_routing_map(routing_output_fused, num_experts) + else: + routing_map_fused = routing_output_fused atol, rtol = _get_tolerances(dtype, num_experts) torch.testing.assert_close(probs, probs_fused, atol=atol, rtol=rtol) @@ -270,6 +292,7 @@ def run_comparison( @pytest.mark.parametrize("group_topk", [None, 4]) @pytest.mark.parametrize("scaling_factor", [None, 1.2]) @pytest.mark.parametrize("enable_bias", [True, False]) +@pytest.mark.parametrize("topk_index_dtype", [None, torch.int16, torch.int32, torch.int64]) def test_topk_sigmoid( dtype, num_tokens, @@ -278,6 +301,7 @@ def test_topk_sigmoid( group_topk, scaling_factor, enable_bias, + topk_index_dtype, ): num_groups = 8 if group_topk else None run_comparison( @@ -291,6 +315,8 @@ def test_topk_sigmoid( scaling_factor=scaling_factor, score_function="sigmoid", enable_bias=enable_bias, + topk_output_mode="dense" if topk_index_dtype is not None else "sparse", + topk_index_dtype=topk_index_dtype or torch.int16, ) @@ -301,6 +327,7 @@ def test_topk_sigmoid( @pytest.mark.parametrize("group_topk", [None, 4]) @pytest.mark.parametrize("scaling_factor", [None, 1.2]) @pytest.mark.parametrize("enable_bias", [True, False]) +@pytest.mark.parametrize("topk_index_dtype", [None, torch.int16, torch.int32, torch.int64]) def test_topk_sqrtsoftplus( dtype, num_tokens, @@ -309,6 +336,7 @@ def test_topk_sqrtsoftplus( group_topk, scaling_factor, enable_bias, + topk_index_dtype, ): num_groups = 8 if group_topk else None run_comparison( @@ -322,6 +350,8 @@ def test_topk_sqrtsoftplus( scaling_factor=scaling_factor, score_function="sqrtsoftplus", enable_bias=enable_bias, + topk_output_mode="dense" if topk_index_dtype is not None else "sparse", + topk_index_dtype=topk_index_dtype or torch.int16, ) @@ -332,6 +362,7 @@ def test_topk_sqrtsoftplus( @pytest.mark.parametrize("use_pre_softmax", [True, False]) @pytest.mark.parametrize("group_topk", [None, 4]) @pytest.mark.parametrize("scaling_factor", [None, 1.2]) +@pytest.mark.parametrize("topk_index_dtype", [None, torch.int16, torch.int32, torch.int64]) def test_topk_softmax( dtype, num_tokens, @@ -340,6 +371,7 @@ def test_topk_softmax( use_pre_softmax, group_topk, scaling_factor, + topk_index_dtype, ): num_groups = 8 if group_topk else None run_comparison( @@ -353,6 +385,8 @@ def test_topk_softmax( scaling_factor=scaling_factor, score_function="softmax", enable_bias=False, + topk_output_mode="dense" if topk_index_dtype is not None else "sparse", + topk_index_dtype=topk_index_dtype or torch.int16, ) From 0b39ede948d048dfbd9078e5bdf5f70f1a1cbc1c Mon Sep 17 00:00:00 2001 From: Harry Zhou Date: Mon, 1 Jun 2026 12:55:56 +0800 Subject: [PATCH 3/8] [Common] Align dense router fallback with p3R Signed-off-by: Harry Zhou --- .../common/fused_router/fused_topk_with_score_function.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu index effaef320a..8ac60a0864 100644 --- a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu +++ b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu @@ -686,6 +686,7 @@ void fused_topk_with_score_function_forward_with_indices_kernel_launcher( const bool use_radix = topk >= get_radix_topk_threshold() && num_experts <= kMaxExpertsRadixTopk; if (!use_radix) { + // Simple path: no async loader, no persistent grid. check_shared_memory_capacity_num_experts(other_shmem, num_experts); auto launch_simple = [&](auto kernel) { From 3dd72556d50fe1c6cafd40360b06e4c353a3c3f7 Mon Sep 17 00:00:00 2001 From: Harry Zhou Date: Wed, 3 Jun 2026 18:30:07 +0800 Subject: [PATCH 4/8] [PyTorch] Support int16 weak refs for CUDA graph reuse Signed-off-by: Harry Zhou --- transformer_engine/pytorch/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index cfb21e7bff..ffbfbc1fdd 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -864,6 +864,7 @@ def torch_get_autocast_gpu_dtype() -> torch.dtype: torch.float32: " Date: Thu, 4 Jun 2026 07:31:01 -0700 Subject: [PATCH 5/8] [Common] Guard dense router topk index APIs Signed-off-by: Harry Zhou --- .../fused_topk_with_score_function.cu | 63 +++++++++++++++++-- .../common/fused_router/utils.h | 20 ++++++ transformer_engine/pytorch/csrc/extensions.h | 5 +- .../pytorch/csrc/extensions/pybind.cpp | 4 +- .../pytorch/csrc/extensions/router.cpp | 46 +++++++++++++- transformer_engine/pytorch/router.py | 3 +- 6 files changed, 130 insertions(+), 11 deletions(-) diff --git a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu index 8ac60a0864..7f9d370220 100644 --- a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu +++ b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu @@ -803,6 +803,31 @@ void fused_topk_with_score_function_forward_with_indices( int num_groups, int group_topk, float scaling_factor, int score_function, const Tensor expert_bias, Tensor probs, Tensor topk_indices, Tensor intermediate_output, cudaStream_t stream) { + NVTE_CHECK(num_tokens > 0 && num_experts > 0, + "num_tokens and num_experts must be positive; got num_tokens=", num_tokens, + ", num_experts=", num_experts); + NVTE_CHECK(topk > 0 && topk <= num_experts, "topk must be in [1, num_experts], got topk=", topk, + " num_experts=", num_experts); + const std::vector dense_shape{static_cast(num_tokens), + static_cast(num_experts)}; + const std::vector indices_shape{static_cast(num_tokens), + static_cast(topk)}; + NVTE_CHECK(logits.data.shape == dense_shape, "logits shape must be [num_tokens, num_experts]=[", + num_tokens, ", ", num_experts, "], got ", logits.data.shape); + NVTE_CHECK(probs.data.shape == dense_shape, "probs shape must be [num_tokens, num_experts]=[", + num_tokens, ", ", num_experts, "], got ", probs.data.shape); + NVTE_CHECK(intermediate_output.data.shape == dense_shape, + "intermediate_output shape must be [num_tokens, num_experts]=[", num_tokens, ", ", + num_experts, "], got ", intermediate_output.data.shape); + NVTE_CHECK(topk_indices.data.shape == indices_shape, + "topk_indices shape must be [num_tokens, " + "topk]=[", + num_tokens, ", ", topk, "], got ", topk_indices.data.shape); + if (expert_bias.has_data()) { + NVTE_CHECK(expert_bias.data.shape == std::vector{static_cast(num_experts)}, + "expert_bias shape must be [num_experts]=[", num_experts, "], got ", + expert_bias.data.shape); + } // Dispatch logits dtype and output-index dtype first; expert-bias dtype is only // dispatched when an expert-bias tensor exists, otherwise the kernel receives nullptr. #define ROUTER_FORWARD_WITH_INDICES_DISPATCH(DataType, IndexType) \ @@ -826,9 +851,9 @@ void fused_topk_with_score_function_forward_with_indices( reinterpret_cast(topk_indices.data.dptr), \ reinterpret_cast(intermediate_output.data.dptr), stream); \ } - TE_ROUTER_PROBS_TYPE_SWITCH_ALL( - logits.data.dtype, DataType, - TE_ROUTER_INDEX_TYPE_SWITCH_ALL(topk_indices.data.dtype, IndexType, + TE_ROUTER_PROBS_TYPE_SWITCH_ALL(logits.data.dtype, DataType, + TE_ROUTER_DENSE_INDEX_TYPE_SWITCH_ALL( + topk_indices.data.dtype, IndexType, ROUTER_FORWARD_WITH_INDICES_DISPATCH(DataType, IndexType););); #undef ROUTER_FORWARD_WITH_INDICES_DISPATCH } @@ -1063,6 +1088,8 @@ void fused_topk_with_score_function_backward_kernel_launcher( int num_tokens, int num_experts, int topk, bool use_pre_softmax, float scaling_factor, int score_function, DataType *grad_logits, cudaStream_t stream) { NVTE_CHECK(num_experts > 0, "num_experts must be positive, got ", num_experts); + NVTE_CHECK(topk > 0 && topk <= num_experts, "topk must be in [1, num_experts], got topk=", topk, + " num_experts=", num_experts); NVTE_CHECK(static_cast(num_tokens) * num_experts <= INT_MAX, "num_tokens * num_experts exceeds INT_MAX (kernel uses int offsets), got ", static_cast(num_tokens) * num_experts); @@ -1240,9 +1267,15 @@ void fused_topk_with_score_function_backward_with_indices_kernel_launcher( int num_tokens, int num_experts, int topk, bool use_pre_softmax, float scaling_factor, int score_function, DataType *grad_logits, cudaStream_t stream) { NVTE_CHECK(num_experts > 0, "num_experts must be positive, got ", num_experts); + NVTE_CHECK(topk > 0 && topk <= num_experts, "topk must be in [1, num_experts], got topk=", topk, + " num_experts=", num_experts); NVTE_CHECK(static_cast(num_tokens) * num_experts <= INT_MAX, "num_tokens * num_experts exceeds INT_MAX (kernel uses int offsets), got ", static_cast(num_tokens) * num_experts); + if constexpr (std::is_same_v) { + NVTE_CHECK(num_experts <= INT16_MAX, "int16 topk indices require num_experts <= ", INT16_MAX, + ", got ", num_experts); + } size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp; size_t total_blocks = (num_tokens + num_token_per_block - 1) / num_token_per_block; @@ -1274,9 +1307,31 @@ void fused_topk_with_score_function_backward_with_indices( const Tensor &topk_indices, const Tensor &intermediate_output, const Tensor &grad_probs, int num_tokens, int num_experts, int topk, bool use_pre_softmax, float scaling_factor, int score_function, Tensor &grad_logits, cudaStream_t stream) { + NVTE_CHECK(num_tokens > 0 && num_experts > 0, + "num_tokens and num_experts must be positive; got num_tokens=", num_tokens, + ", num_experts=", num_experts); + NVTE_CHECK(topk > 0 && topk <= num_experts, "topk must be in [1, num_experts], got topk=", topk, + " num_experts=", num_experts); + const std::vector dense_shape{static_cast(num_tokens), + static_cast(num_experts)}; + const std::vector indices_shape{static_cast(num_tokens), + static_cast(topk)}; + NVTE_CHECK(topk_indices.data.shape == indices_shape, + "topk_indices shape must be [num_tokens, " + "topk]=[", + num_tokens, ", ", topk, "], got ", topk_indices.data.shape); + NVTE_CHECK(intermediate_output.data.shape == dense_shape, + "intermediate_output shape must be [num_tokens, num_experts]=[", num_tokens, ", ", + num_experts, "], got ", intermediate_output.data.shape); + NVTE_CHECK(grad_probs.data.shape == dense_shape, + "grad_probs shape must be [num_tokens, num_experts]=[", num_tokens, ", ", num_experts, + "], got ", grad_probs.data.shape); + NVTE_CHECK(grad_logits.data.shape == dense_shape, + "grad_logits shape must be [num_tokens, num_experts]=[", num_tokens, ", ", num_experts, + "], got ", grad_logits.data.shape); TE_ROUTER_PROBS_TYPE_SWITCH_ALL( grad_logits.data.dtype, DataType, - TE_ROUTER_INDEX_TYPE_SWITCH_ALL( + TE_ROUTER_DENSE_INDEX_TYPE_SWITCH_ALL( topk_indices.data.dtype, IndexType, fused_topk_with_score_function_backward_with_indices_kernel_launcher( reinterpret_cast(topk_indices.data.dptr), diff --git a/transformer_engine/common/fused_router/utils.h b/transformer_engine/common/fused_router/utils.h index f6f569014a..182e71913f 100644 --- a/transformer_engine/common/fused_router/utils.h +++ b/transformer_engine/common/fused_router/utils.h @@ -594,6 +594,26 @@ __device__ __forceinline__ void topk_and_mask(CompType *scores, int data_size, i ". Expected one of: Int16, Int32, Int64, BFloat16, " \ "Float32."); \ } + +#define TE_ROUTER_DENSE_INDEX_TYPE_SWITCH_ALL(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kInt16: { \ + using type = int16_t; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kInt32: { \ + using type = int32_t; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kInt64: { \ + using type = int64_t; \ + { __VA_ARGS__ } \ + } break; \ + default: \ + NVTE_ERROR("Unsupported dense router index dtype ", to_string(static_cast(dtype)), \ + ". Expected one of: Int16, Int32, Int64."); \ + } } // namespace fused_router } // namespace transformer_engine diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index b4eb0253fe..13d872392d 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -33,8 +33,9 @@ namespace transformer_engine::pytorch { std::tuple fused_topk_with_score_function_fwd( at::Tensor logits, int topk, bool use_pre_softmax, std::optional num_groups, std::optional group_topk, std::optional scaling_factor, std::string score_function, - std::optional expert_bias, std::optional topk_indices = std::nullopt, - int routing_map_format = static_cast(NVTE_ROUTING_MAP_FORMAT_BYTEMAP)); + std::optional expert_bias, + int routing_map_format = static_cast(NVTE_ROUTING_MAP_FORMAT_BYTEMAP), + std::optional topk_indices = std::nullopt); void fused_topk_with_score_function_bwd( at::Tensor routing_map, at::Tensor intermediate_output, at::Tensor grad_probs, diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index e36835dc2a..d1890872c0 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -139,9 +139,9 @@ void init_router_bindings(pybind11::module &m) { m.def("fused_topk_with_score_function_fwd", &fused_topk_with_score_function_fwd, py::arg("logits"), py::arg("topk"), py::arg("use_pre_softmax"), py::arg("num_groups"), py::arg("group_topk"), py::arg("scaling_factor"), py::arg("score_function"), - py::arg("expert_bias"), py::arg("topk_indices") = std::nullopt, + py::arg("expert_bias"), py::arg("routing_map_format") = static_cast(NVTE_ROUTING_MAP_FORMAT_BYTEMAP), - "Fused topk with score function fwd"); + py::arg("topk_indices") = std::nullopt, "Fused topk with score function fwd"); m.def("fused_topk_with_score_function_bwd", &fused_topk_with_score_function_bwd, py::arg("routing_map"), py::arg("intermediate_output"), py::arg("grad_probs"), py::arg("grad_logits"), py::arg("topk"), py::arg("use_pre_softmax"), diff --git a/transformer_engine/pytorch/csrc/extensions/router.cpp b/transformer_engine/pytorch/csrc/extensions/router.cpp index 23214fd3f1..4cbea1ec76 100644 --- a/transformer_engine/pytorch/csrc/extensions/router.cpp +++ b/transformer_engine/pytorch/csrc/extensions/router.cpp @@ -28,11 +28,39 @@ static at::Tensor allocate_routing_map(c10::IntArrayRef leading_dims, int64_t nu return at::empty(shape, at::dtype(at::kBool).device(at::kCUDA)); } +static void check_routing_map_format(int routing_map_format) { + TORCH_CHECK(routing_map_format == NVTE_ROUTING_MAP_FORMAT_BYTEMAP || + routing_map_format == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8, + "routing_map_format must be BYTEMAP (0) or BITMAP_U8 (1), got ", routing_map_format); +} + +static bool is_supported_dense_index_dtype(at::ScalarType dtype) { + return dtype == at::kShort || dtype == at::kInt || dtype == at::kLong; +} + +static void check_dense_topk_indices(const at::Tensor &topk_indices, const at::Tensor &ref, + int64_t num_tokens, int topk) { + TORCH_CHECK(topk_indices.is_cuda(), "topk_indices must be a CUDA tensor"); + TORCH_CHECK(topk_indices.device() == ref.device(), "topk_indices must be on the same device as ", + "the logits/grad tensor"); + TORCH_CHECK(topk_indices.is_contiguous(), "topk_indices must be contiguous"); + TORCH_CHECK(is_supported_dense_index_dtype(topk_indices.scalar_type()), + "topk_indices dtype must be int16, int32, or int64, got ", + topk_indices.scalar_type()); + TORCH_CHECK(topk_indices.numel() == num_tokens * static_cast(topk), + "topk_indices must contain num_tokens * topk elements, got ", topk_indices.numel(), + " but expected ", num_tokens * static_cast(topk)); + TORCH_CHECK(topk_indices.dim() >= 1 && topk_indices.size(-1) == topk, + "topk_indices last dimension must be topk=", topk, ", got shape ", + topk_indices.sizes()); +} + std::tuple fused_topk_with_score_function_fwd( at::Tensor logits, int topk, bool use_pre_softmax, std::optional num_groups, std::optional group_topk, std::optional scaling_factor, std::string score_function, - std::optional expert_bias, std::optional topk_indices, - int routing_map_format) { + std::optional expert_bias, int routing_map_format, + std::optional topk_indices) { + check_routing_map_format(routing_map_format); TORCH_CHECK(logits.dim() >= 1, "logits must have at least 1 dim"); TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous"); auto sizes = logits.sizes(); @@ -41,6 +69,8 @@ std::tuple fused_topk_with_score_function_fw std::accumulate(sizes.begin(), sizes.end() - 1, int64_t{1}, std::multiplies()); TORCH_CHECK(num_tokens > 0 && num_experts > 0, "num_tokens and num_experts must be greater than 0"); + TORCH_CHECK(topk > 0 && topk <= num_experts, "topk must be in [1, num_experts], got topk=", topk, + " num_experts=", num_experts); // Expert bias only happens at the sigmoid case if (expert_bias.has_value()) { TORCH_CHECK(score_function == "sigmoid" || score_function == "sqrtsoftplus", @@ -55,6 +85,9 @@ std::tuple fused_topk_with_score_function_fw if (score_function == "sigmoid" || score_function == "sqrtsoftplus") { use_pre_softmax = false; // Pre-softmax only happens at the softmax case } + if (topk_indices.has_value()) { + check_dense_topk_indices(topk_indices.value(), logits, num_tokens, topk); + } // Reformat the input to make it compatible with the kernel int group_topk_value = group_topk.has_value() ? group_topk.value() : -1; @@ -116,6 +149,7 @@ void fused_topk_with_score_function_bwd(at::Tensor routing_map, at::Tensor inter bool use_pre_softmax, std::optional scaling_factor, std::string score_function, bool use_dense_indices, int routing_map_format) { + check_routing_map_format(routing_map_format); TORCH_CHECK(grad_probs.dim() >= 1, "grad_probs must have at least 1 dim"); TORCH_CHECK(grad_probs.is_contiguous(), "grad_probs must be contiguous"); TORCH_CHECK(grad_logits.is_contiguous(), "grad_logits must be contiguous"); @@ -123,6 +157,13 @@ void fused_topk_with_score_function_bwd(at::Tensor routing_map, at::Tensor inter int64_t num_experts = sizes.back(); int64_t num_tokens = std::accumulate(sizes.begin(), sizes.end() - 1, int64_t{1}, std::multiplies()); + TORCH_CHECK(num_tokens > 0 && num_experts > 0, + "num_tokens and num_experts must be greater than 0"); + TORCH_CHECK(topk > 0 && topk <= num_experts, "topk must be in [1, num_experts], got topk=", topk, + " num_experts=", num_experts); + if (use_dense_indices) { + check_dense_topk_indices(routing_map, grad_probs, num_tokens, topk); + } auto scaling_factor_value = scaling_factor.has_value() ? scaling_factor.value() : 1.0f; auto score_function_value = score_function_map[score_function]; @@ -163,6 +204,7 @@ void fused_topk_with_score_function_bwd(at::Tensor routing_map, at::Tensor inter std::tuple fused_score_for_moe_aux_loss_fwd( at::Tensor logits, int topk, std::string score_function, int routing_map_format) { + check_routing_map_format(routing_map_format); TORCH_CHECK(logits.dim() >= 1, "logits must have at least 1 dim"); TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous"); auto sizes = logits.sizes(); diff --git a/transformer_engine/pytorch/router.py b/transformer_engine/pytorch/router.py index 53ffaadbdf..004451ea7f 100644 --- a/transformer_engine/pytorch/router.py +++ b/transformer_engine/pytorch/router.py @@ -99,14 +99,15 @@ def forward( scaling_factor, score_function, expert_bias, - topk_indices, routing_map_format, + topk_indices, ) if topk_indices is not None: routing_output = topk_indices probs = probs.view(tensor_shape) if topk_indices is not None: ctx.mark_dirty(topk_indices) + ctx.mark_non_differentiable(routing_output) ctx.save_for_backward(routing_output, intermediate_output) ctx.num_tokens = num_tokens ctx.num_experts = num_experts From 8c6bd7326558119618ec50ed6863e00fe67bf700 Mon Sep 17 00:00:00 2001 From: Harry Zhou Date: Mon, 15 Jun 2026 05:27:28 -0700 Subject: [PATCH 6/8] [Common] Harden dense router API guards Signed-off-by: Harry Zhou --- .../fused_score_for_moe_aux_loss.cu | 1 + .../fused_topk_with_score_function.cu | 22 +++++++++++++++++++ .../common/fused_router/utils.h | 8 +++++++ .../pytorch/csrc/extensions/router.cpp | 18 ++++++++++----- 4 files changed, 44 insertions(+), 5 deletions(-) diff --git a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu index 70dc1faa10..4156a5dfc9 100644 --- a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu @@ -416,6 +416,7 @@ void fused_score_for_moe_aux_loss_forward(const Tensor &logits, int num_tokens, Tensor &routing_map, NVTERoutingMapFormat routing_map_format, Tensor &intermediate_output, cudaStream_t stream) { + check_routing_map_format(routing_map_format); NVTE_CHECK(num_tokens > 0 && num_experts > 0, "num_tokens and num_experts must be positive; got num_tokens=", num_tokens, ", num_experts=", num_experts); diff --git a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu index 7f9d370220..cd8eca05e0 100644 --- a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu +++ b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu @@ -556,6 +556,16 @@ void fused_topk_with_score_function_forward_kernel_launcher( if (group_topk > 0) { NVTE_CHECK(topk % group_topk == 0, "topk must be divisible by group_topk, got topk=", topk, " group_topk=", group_topk); + NVTE_CHECK(num_groups > 0, "num_groups must be positive when group_topk > 0, got ", num_groups); + NVTE_CHECK(group_topk <= num_groups, + "group_topk must be <= num_groups, got group_topk=", group_topk, + " num_groups=", num_groups); + NVTE_CHECK(num_experts % num_groups == 0, + "num_experts must be divisible by num_groups, got num_experts=", num_experts, + " num_groups=", num_groups); + NVTE_CHECK(topk / group_topk <= num_experts / num_groups, + "per-group topk must be <= group size, got per_group_topk=", topk / group_topk, + " group_size=", num_experts / num_groups); } size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp; size_t total_blocks = (num_tokens + num_token_per_block - 1) / num_token_per_block; @@ -648,6 +658,16 @@ void fused_topk_with_score_function_forward_with_indices_kernel_launcher( if (group_topk > 0) { NVTE_CHECK(topk % group_topk == 0, "topk must be divisible by group_topk, got topk=", topk, " group_topk=", group_topk); + NVTE_CHECK(num_groups > 0, "num_groups must be positive when group_topk > 0, got ", num_groups); + NVTE_CHECK(group_topk <= num_groups, + "group_topk must be <= num_groups, got group_topk=", group_topk, + " num_groups=", num_groups); + NVTE_CHECK(num_experts % num_groups == 0, + "num_experts must be divisible by num_groups, got num_experts=", num_experts, + " num_groups=", num_groups); + NVTE_CHECK(topk / group_topk <= num_experts / num_groups, + "per-group topk must be <= group size, got per_group_topk=", topk / group_topk, + " group_size=", num_experts / num_groups); } if constexpr (std::is_same_v) { NVTE_CHECK(num_experts <= INT16_MAX, "int16 topk indices require num_experts <= ", INT16_MAX, @@ -745,6 +765,7 @@ void fused_topk_with_score_function_forward(const Tensor logits, int num_tokens, Tensor probs, Tensor routing_map, NVTERoutingMapFormat routing_map_format, Tensor intermediate_output, cudaStream_t stream) { + check_routing_map_format(routing_map_format); NVTE_CHECK(num_tokens > 0 && num_experts > 0, "num_tokens and num_experts must be positive; got num_tokens=", num_tokens, ", num_experts=", num_experts); @@ -1139,6 +1160,7 @@ void fused_topk_with_score_function_backward(const Tensor &routing_map, int num_experts, int topk, bool use_pre_softmax, float scaling_factor, int score_function, Tensor &grad_logits, cudaStream_t stream) { + check_routing_map_format(routing_map_format); NVTE_CHECK(num_tokens > 0 && num_experts > 0, "num_tokens and num_experts must be positive; got num_tokens=", num_tokens, ", num_experts=", num_experts); diff --git a/transformer_engine/common/fused_router/utils.h b/transformer_engine/common/fused_router/utils.h index 182e71913f..117828a704 100644 --- a/transformer_engine/common/fused_router/utils.h +++ b/transformer_engine/common/fused_router/utils.h @@ -12,11 +12,19 @@ #include "../util/logging.h" #include "../util/system.h" #include "../utils.cuh" +#include "transformer_engine/fused_router.h" #include "transformer_engine/transformer_engine.h" namespace transformer_engine { namespace fused_router { +inline void check_routing_map_format(NVTERoutingMapFormat routing_map_format) { + NVTE_CHECK(routing_map_format == NVTE_ROUTING_MAP_FORMAT_BYTEMAP || + routing_map_format == NVTE_ROUTING_MAP_FORMAT_BITMAP_U8, + "routing_map_format must be BYTEMAP (0) or BITMAP_U8 (1), got ", + static_cast(routing_map_format)); +} + // Topk values below this threshold use naive O(K*E) selection; // at or above it, use radix O(E) selection. Configurable via // NVTE_RADIX_TOPK_THRESHOLD (default 16: matches the upstream naive/radix switch, diff --git a/transformer_engine/pytorch/csrc/extensions/router.cpp b/transformer_engine/pytorch/csrc/extensions/router.cpp index 4cbea1ec76..d2c1205054 100644 --- a/transformer_engine/pytorch/csrc/extensions/router.cpp +++ b/transformer_engine/pytorch/csrc/extensions/router.cpp @@ -14,6 +14,14 @@ namespace transformer_engine::pytorch { static std::map score_function_map = { {"sigmoid", 0}, {"softmax", 1}, {"sqrtsoftplus", 2}}; +static int get_score_function_value(const std::string &score_function) { + auto it = score_function_map.find(score_function); + TORCH_CHECK(it != score_function_map.end(), + "score_function must be softmax, sigmoid or sqrtsoftplus for router fusion, got ", + score_function); + return it->second; +} + // Allocate a routing_map output tensor: // BYTEMAP -> bool [*leading_dims, num_experts] // BITMAP_U8 -> uint8[*leading_dims, ceil(num_experts/8)], LSB-first @@ -130,13 +138,13 @@ std::tuple fused_topk_with_score_function_fw nvte_fused_topk_with_score_function_forward_with_indices( logits_cu.data(), static_cast(num_tokens), static_cast(num_experts), topk, use_pre_softmax, num_groups_value, group_topk_value, scaling_factor_value, - score_function_map[score_function], expert_bias_cu.data(), probs_cu.data(), + get_score_function_value(score_function), expert_bias_cu.data(), probs_cu.data(), routing_map_cu.data(), intermediate_output_cu.data(), at::cuda::getCurrentCUDAStream()); } else { nvte_fused_topk_with_score_function_forward_v2( logits_cu.data(), static_cast(num_tokens), static_cast(num_experts), topk, use_pre_softmax, num_groups_value, group_topk_value, scaling_factor_value, - score_function_map[score_function], expert_bias_cu.data(), probs_cu.data(), + get_score_function_value(score_function), expert_bias_cu.data(), probs_cu.data(), routing_map_cu.data(), static_cast(routing_map_format), intermediate_output_cu.data(), at::cuda::getCurrentCUDAStream()); } @@ -166,7 +174,7 @@ void fused_topk_with_score_function_bwd(at::Tensor routing_map, at::Tensor inter } auto scaling_factor_value = scaling_factor.has_value() ? scaling_factor.value() : 1.0f; - auto score_function_value = score_function_map[score_function]; + auto score_function_value = get_score_function_value(score_function); const std::vector shape_2d = {static_cast(num_tokens), static_cast(num_experts)}; @@ -217,7 +225,7 @@ std::tuple fused_score_for_moe_aux_loss_fwd( TORCH_CHECK(score_function == "softmax" || score_function == "sigmoid" || score_function == "sqrtsoftplus", "score_function must be softmax, sigmoid or sqrtsoftplus for router fusion"); - int score_function_value = score_function_map[score_function]; + int score_function_value = get_score_function_value(score_function); at::Tensor scores = at::empty(sizes, at::dtype(at::kFloat).device(at::kCUDA)); at::Tensor routing_map = @@ -261,7 +269,7 @@ void fused_score_for_moe_aux_loss_bwd(at::Tensor intermediate_output, at::Tensor int64_t num_tokens = std::accumulate(sizes.begin(), sizes.end() - 1, int64_t{1}, std::multiplies()); - int score_function_value = score_function_map[score_function]; + int score_function_value = get_score_function_value(score_function); const std::vector shape_2d = {static_cast(num_tokens), static_cast(num_experts)}; From 5ad179b4dfc5ac3f388e64479bf9b9106d0b6198 Mon Sep 17 00:00:00 2001 From: Harry Zhou Date: Mon, 15 Jun 2026 19:00:20 -0700 Subject: [PATCH 7/8] [PyTorch] Clarify dense router format guards Signed-off-by: Harry Zhou --- transformer_engine/pytorch/csrc/extensions/router.cpp | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/csrc/extensions/router.cpp b/transformer_engine/pytorch/csrc/extensions/router.cpp index d2c1205054..a24a56e390 100644 --- a/transformer_engine/pytorch/csrc/extensions/router.cpp +++ b/transformer_engine/pytorch/csrc/extensions/router.cpp @@ -94,6 +94,9 @@ std::tuple fused_topk_with_score_function_fw use_pre_softmax = false; // Pre-softmax only happens at the softmax case } if (topk_indices.has_value()) { + TORCH_CHECK(routing_map_format == NVTE_ROUTING_MAP_FORMAT_BYTEMAP, + "topk_indices output cannot be combined with non-default routing_map_format; " + "dense top-k indices are returned instead of a routing map."); check_dense_topk_indices(topk_indices.value(), logits, num_tokens, topk); } @@ -157,7 +160,13 @@ void fused_topk_with_score_function_bwd(at::Tensor routing_map, at::Tensor inter bool use_pre_softmax, std::optional scaling_factor, std::string score_function, bool use_dense_indices, int routing_map_format) { - check_routing_map_format(routing_map_format); + if (use_dense_indices) { + TORCH_CHECK(routing_map_format == NVTE_ROUTING_MAP_FORMAT_BYTEMAP, + "use_dense_indices cannot be combined with non-default routing_map_format; " + "dense top-k indices are consumed instead of a routing map."); + } else { + check_routing_map_format(routing_map_format); + } TORCH_CHECK(grad_probs.dim() >= 1, "grad_probs must have at least 1 dim"); TORCH_CHECK(grad_probs.is_contiguous(), "grad_probs must be contiguous"); TORCH_CHECK(grad_logits.is_contiguous(), "grad_logits must be contiguous"); From 49c6553641e60071f0a8df625e85daae6d8b8974 Mon Sep 17 00:00:00 2001 From: Harry Zhou Date: Mon, 15 Jun 2026 20:31:07 -0700 Subject: [PATCH 8/8] [PyTorch] Preserve router leading dimensions Signed-off-by: Harry Zhou --- tests/pytorch/test_fused_router.py | 27 +++++++++++++++++++ .../pytorch/csrc/extensions/router.cpp | 15 +++++------ transformer_engine/pytorch/router.py | 17 +++--------- 3 files changed, 37 insertions(+), 22 deletions(-) diff --git a/tests/pytorch/test_fused_router.py b/tests/pytorch/test_fused_router.py index ede44d19ce..ab12216df8 100644 --- a/tests/pytorch/test_fused_router.py +++ b/tests/pytorch/test_fused_router.py @@ -390,6 +390,33 @@ def test_topk_softmax( ) +@pytest.mark.parametrize("topk_index_dtype", [None, torch.int16]) +def test_topk_preserves_leading_dims(topk_index_dtype): + num_tokens = 128 + num_experts = 32 + topk = 4 + logits = torch.randn(num_tokens, 2, num_experts, device="cuda", dtype=torch.float32) + topk_indices = None + if topk_index_dtype is not None: + topk_indices = torch.empty(num_tokens, 2, topk, device="cuda", dtype=topk_index_dtype) + + probs, routing_output = fused_topk_with_score_function( + logits=logits, + topk=topk, + use_pre_softmax=False, + num_groups=None, + group_topk=None, + scaling_factor=None, + score_function="softmax", + expert_bias=None, + topk_indices=topk_indices, + ) + + assert probs.shape == logits.shape + expected_routing_shape = topk_indices.shape if topk_indices is not None else logits.shape + assert routing_output.shape == expected_routing_shape + + @pytest.mark.parametrize("dtype", [torch.float32]) @pytest.mark.parametrize("num_tokens", [2048, 7168]) @pytest.mark.parametrize("num_experts", [1024, 256, 128, 32]) diff --git a/transformer_engine/pytorch/csrc/extensions/router.cpp b/transformer_engine/pytorch/csrc/extensions/router.cpp index a24a56e390..70762e3729 100644 --- a/transformer_engine/pytorch/csrc/extensions/router.cpp +++ b/transformer_engine/pytorch/csrc/extensions/router.cpp @@ -47,7 +47,7 @@ static bool is_supported_dense_index_dtype(at::ScalarType dtype) { } static void check_dense_topk_indices(const at::Tensor &topk_indices, const at::Tensor &ref, - int64_t num_tokens, int topk) { + c10::IntArrayRef leading_dims, int topk) { TORCH_CHECK(topk_indices.is_cuda(), "topk_indices must be a CUDA tensor"); TORCH_CHECK(topk_indices.device() == ref.device(), "topk_indices must be on the same device as ", "the logits/grad tensor"); @@ -55,11 +55,10 @@ static void check_dense_topk_indices(const at::Tensor &topk_indices, const at::T TORCH_CHECK(is_supported_dense_index_dtype(topk_indices.scalar_type()), "topk_indices dtype must be int16, int32, or int64, got ", topk_indices.scalar_type()); - TORCH_CHECK(topk_indices.numel() == num_tokens * static_cast(topk), - "topk_indices must contain num_tokens * topk elements, got ", topk_indices.numel(), - " but expected ", num_tokens * static_cast(topk)); - TORCH_CHECK(topk_indices.dim() >= 1 && topk_indices.size(-1) == topk, - "topk_indices last dimension must be topk=", topk, ", got shape ", + std::vector expected_shape(leading_dims.begin(), leading_dims.end()); + expected_shape.push_back(static_cast(topk)); + TORCH_CHECK(topk_indices.sizes() == expected_shape, + "topk_indices shape must be [*leading_dims, topk]=", expected_shape, ", got ", topk_indices.sizes()); } @@ -97,7 +96,7 @@ std::tuple fused_topk_with_score_function_fw TORCH_CHECK(routing_map_format == NVTE_ROUTING_MAP_FORMAT_BYTEMAP, "topk_indices output cannot be combined with non-default routing_map_format; " "dense top-k indices are returned instead of a routing map."); - check_dense_topk_indices(topk_indices.value(), logits, num_tokens, topk); + check_dense_topk_indices(topk_indices.value(), logits, sizes.slice(0, sizes.size() - 1), topk); } // Reformat the input to make it compatible with the kernel @@ -179,7 +178,7 @@ void fused_topk_with_score_function_bwd(at::Tensor routing_map, at::Tensor inter TORCH_CHECK(topk > 0 && topk <= num_experts, "topk must be in [1, num_experts], got topk=", topk, " num_experts=", num_experts); if (use_dense_indices) { - check_dense_topk_indices(routing_map, grad_probs, num_tokens, topk); + check_dense_topk_indices(routing_map, grad_probs, sizes.slice(0, sizes.size() - 1), topk); } auto scaling_factor_value = scaling_factor.has_value() ? scaling_factor.value() : 1.0f; diff --git a/transformer_engine/pytorch/router.py b/transformer_engine/pytorch/router.py index 004451ea7f..519d06ce23 100644 --- a/transformer_engine/pytorch/router.py +++ b/transformer_engine/pytorch/router.py @@ -86,10 +86,6 @@ def forward( topk_indices: Optional[torch.Tensor], ): # pylint: disable=missing-function-docstring - tensor_shape = logits.shape - logits = logits.view(-1, tensor_shape[-1]) - num_tokens = logits.size(0) - num_experts = logits.size(1) probs, routing_output, intermediate_output = tex.fused_topk_with_score_function_fwd( logits, topk, @@ -104,20 +100,15 @@ def forward( ) if topk_indices is not None: routing_output = topk_indices - probs = probs.view(tensor_shape) if topk_indices is not None: ctx.mark_dirty(topk_indices) ctx.mark_non_differentiable(routing_output) ctx.save_for_backward(routing_output, intermediate_output) - ctx.num_tokens = num_tokens - ctx.num_experts = num_experts - ctx.tensor_shape = tensor_shape ctx.use_pre_softmax = use_pre_softmax ctx.topk = topk ctx.scaling_factor = scaling_factor ctx.score_function = score_function ctx.routing_map_format = routing_map_format - ctx.logits_dtype = logits.dtype ctx.use_dense_indices = topk_indices is not None return probs, routing_output @@ -125,10 +116,9 @@ def forward( def backward(ctx, grad_probs, _): # pylint: disable=missing-function-docstring routing_map, intermediate_output = ctx.saved_tensors - grad_probs = grad_probs.contiguous().view(-1, ctx.tensor_shape[-1]) - grad_logits = torch.empty( - (ctx.num_tokens, ctx.num_experts), dtype=ctx.logits_dtype, device=grad_probs.device - ) + if not grad_probs.is_contiguous(): + grad_probs = grad_probs.contiguous() + grad_logits = torch.empty_like(grad_probs) tex.fused_topk_with_score_function_bwd( routing_map, intermediate_output, @@ -141,7 +131,6 @@ def backward(ctx, grad_probs, _): ctx.use_dense_indices, ctx.routing_map_format, ) - grad_logits = grad_logits.view(ctx.tensor_shape) return grad_logits, None, None, None, None, None, None, None, None, None