[Common] Add dense router output for fused router#3129

Open
harryzhou2000 wants to merge 8 commits into NVIDIA:main from harryzhou2000:hhanyu/router_dense_route

@harryzhou2000 harryzhou2000 commented Jun 15, 2026

Member

The fused router forward kernels currently emit the routing map in a sparse format, either boolean (bytemap) or bitmap. Passing it to the dispatcher may therefore require an additional conversion to the dense topk_indices format (shape [*leading_dims, top_k]). For example, NCCL EP accepts topk_indices as its routing map. To avoid an extra kernel for routing-map conversion, the fused router can write directly into that format.

For NCCL EP, the dense topk_indices row is consumed as an order-insensitive selected-expert set. The dense output therefore does not promise score-sorted or expert-sorted order; it preserves the selected experts produced by the chosen top-k kernel path.
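
As a rough illustration of the two formats described above, the following pure-Python sketch converts one boolean routing-map row into a dense index row and compares rows as order-insensitive sets. The helper names `bytemap_row_to_topk_indices` and `same_selected_set` are hypothetical and exist only for this example; the real conversion runs on GPU tensors, and this is not the kernel's implementation.

```python
# Illustrative only: plain Python lists stand in for the tensors described above.

def bytemap_row_to_topk_indices(row):
    """Return the indices of the selected experts in one boolean routing-map row."""
    return [expert for expert, selected in enumerate(row) if selected]


def same_selected_set(a, b):
    """Compare two dense index rows as order-insensitive selected-expert sets."""
    return sorted(a) == sorted(b)


# One token, 8 experts, top_k = 3: experts 1, 4 and 6 are selected.
routing_map_row = [False, True, False, False, True, False, True, False]

# Converting from the sparse row happens to give expert-sorted order.
dense_row = bytemap_row_to_topk_indices(routing_map_row)

# A top-k kernel may emit the same experts in a different order;
# an order-insensitive consumer treats both rows as equivalent.
kernel_row = [6, 1, 4]
rows_match = same_selected_set(dense_row, kernel_row)
```

Because the consumer only needs the selected set, a caller comparing the dense output against a routing map would sort (or use a set) rather than compare rows element by element.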

Summary

  • Add an optional dense topk_indices output path to fused router top-k.
  • Avoid materializing the full routing map when callers provide a [*leading_dims, topk] index buffer.
  • Add dense-index backward using the selected expert indices directly.
  • Support int16, int32, and int64 dense index buffers.
  • Preserve existing BYTEMAP/BITMAP_U8 routing-map paths and p3R radix/naive dispatch behavior.
  • Add guard checks for dense index shape/device/dtype, grouped top-k assumptions, routing-map format, and direct score-function API usage.
  • Add int16 support for TE CUDA graph weak-ref tensors.
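
The shape and dtype guards listed above could look roughly like the sketch below. It works on plain shape tuples and dtype names so that it runs without PyTorch; the function body, the error types, and the "expert id must fit in the dtype" rule are assumptions made for illustration, and the device check is omitted because there are no real tensors here.

```python
# Hypothetical guard sketch; not the checks implemented in the PR.
INDEX_DTYPE_MAX = {
    "int16": 2**15 - 1,
    "int32": 2**31 - 1,
    "int64": 2**63 - 1,
}


def check_dense_topk_indices(logits_shape, indices_shape, indices_dtype, topk):
    """Validate a caller-provided dense index buffer against the logits shape."""
    if indices_dtype not in INDEX_DTYPE_MAX:
        raise TypeError(f"unsupported index dtype: {indices_dtype}")

    # logits are [*leading_dims, num_experts]; indices must be [*leading_dims, topk].
    *leading_dims, num_experts = logits_shape
    expected = (*leading_dims, topk)
    if tuple(indices_shape) != expected:
        raise ValueError(f"expected shape {expected}, got {tuple(indices_shape)}")

    # Assumption: the largest expert id must be representable in the index dtype.
    if num_experts - 1 > INDEX_DTYPE_MAX[indices_dtype]:
        raise ValueError("num_experts does not fit in the index dtype")
```

For example, `check_dense_topk_indices((2, 4, 64), (2, 4, 8), "int16", 8)` passes silently, while a buffer of shape `(8, 8)` for the same logits raises `ValueError`.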

Testing

  • Built on B200 with:
    NVTE_BUILD_THREADS_PER_JOB=4 NVTE_CUDA_ARCHS="90;100;103a;120" NVTE_USE_CCACHE=1 pip install --no-build-isolation -e .[test] --verbose
  • Ran:
    python -m pytest -q tests/pytorch/test_fused_router.py
  • Result:
    3203 passed, 444 skipped, 3 warnings in 44.17s

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
@harryzhou2000 harryzhou2000 changed the title from "[Common] Add dense top-k index output for fused router" to "[Common] Add dense router output for fused router" Jun 15, 2026
@harryzhou2000 harryzhou2000 marked this pull request as ready for review June 15, 2026 12:48
greptile-apps Bot commented Jun 15, 2026

Contributor

Greptile Summary

This PR adds an optional dense topk_indices output path to the fused router's top-k operation, avoiding the need for a separate kernel to convert sparse routing maps into the [num_tokens, topk] format required by NCCL EP dispatchers.

  • Forward path: new fused_topk_with_score_function_forward_with_indices C++ entry point writes selected expert indices directly into a caller-provided int16/int32/int64 buffer, bypassing routing-map materialization; both the Naive and Radix kernel paths are covered.
  • Backward path: new fused_topk_backward_selected_indices_kernel uses the saved dense indices instead of the full routing map to reconstruct per-expert gradients, supporting all three score functions (sigmoid, softmax, sqrtsoftplus) with and without pre-softmax.
  • Guards and shape preservation: multi-dimensional logits are correctly propagated through 2D kernel wrappers; the topk_indices buffer must match the leading dims of logits; combining dense indices with a non-default routing_map_format is rejected at both C++ and Python layers.
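
To make the idea of a dense-index backward concrete, here is a simplified per-token sketch for the sigmoid score function. It assumes that the incoming gradient is laid out per expert and that no top-k normalization or scaling factor is applied, which the real kernel may well do; `backward_selected_indices` is a hypothetical name, not the CUDA kernel from this PR.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def backward_selected_indices(logits_row, grad_probs_row, topk_indices_row):
    """Gradient w.r.t. logits for one token, using only the selected expert ids.

    Assumptions: sigmoid scores, no normalization, no scaling factor.
    """
    grad_logits = [0.0] * len(logits_row)
    # Only selected experts receive gradient; no routing map is consulted.
    for expert in topk_indices_row:
        s = sigmoid(logits_row[expert])
        grad_logits[expert] = grad_probs_row[expert] * s * (1.0 - s)
    return grad_logits


logits = [0.0, 2.0, -1.0, 0.5]
grad_probs = [1.0, 1.0, 1.0, 1.0]
# The order of the selected experts does not affect the result.
grads = backward_selected_indices(logits, grad_probs, [3, 0])
```

The loop touches only `top_k` entries per token instead of scanning a full `num_experts`-wide routing-map row, which is the point of saving the dense indices for backward.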

Confidence Score: 5/5

Safe to merge; the new dense-index path is well-isolated behind explicit guards and does not alter behavior for existing sparse-path callers.

All new kernel paths correctly mirror existing routing-map counterparts across all three score functions and both pre/post-softmax modes. Guards prevent invalid combinations, shape preservation for multi-dimensional logits is tested, and autograd plumbing is correct.

No files require special attention.

Important Files Changed

| Filename | Overview |
|---|---|
| transformer_engine/common/fused_router/fused_topk_with_score_function.cu | Main CUDA implementation; new fused_topk_backward_selected_indices_kernel correctly handles all three score functions and both pre/post-softmax variants; bitmap+nullptr routing_map edge case falls through to BYTEMAP else-branch correctly. |
| transformer_engine/pytorch/csrc/extensions/router.cpp | Adds dense-index forward and backward dispatch; guard checks correctly reject invalid format+indices combinations; shape-2D wrapping for multi-dimensional logits is consistent with existing code. |
| transformer_engine/pytorch/router.py | Adds topk_indices parameter to FusedTopkScoreFunction.forward; autograd plumbing (mark_dirty, mark_non_differentiable, save_for_backward) is correct; redundant double-check idiom is a minor style point. |
| transformer_engine/common/fused_router/utils.h | Adds check_routing_map_format helper and TE_ROUTER_DENSE_INDEX_TYPE_SWITCH_ALL macro for int16/int32/int64 dispatch. |
| transformer_engine/common/include/transformer_engine/fused_router.h | New public C API declarations for forward/backward with dense indices are clearly documented. |
| tests/pytorch/test_fused_router.py | Parametrizes all three test functions over topk_index_dtype in {None, int16, int32, int64}; adds test_topk_preserves_leading_dims for 3D logits; checks buffer aliasing via data_ptr(). |
| transformer_engine/pytorch/utils.py | Adds torch.int16 to the numpy typestring dict for CUDA graph weak-ref tensor support. |
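
The "shape-2D wrapping" mentioned for router.cpp can be pictured with the shape arithmetic below: leading dimensions are flattened into a token count for the kernel, and the index buffer keeps the caller's leading dimensions. `as_2d_shapes` is a hypothetical helper written for this note and only manipulates shape tuples.

```python
import math


def as_2d_shapes(logits_shape, topk):
    """Return (kernel logits shape, kernel index shape, caller-facing index shape)."""
    *leading_dims, num_experts = logits_shape
    # All leading dimensions collapse into a single token dimension.
    num_tokens = math.prod(leading_dims)
    return (num_tokens, num_experts), (num_tokens, topk), (*leading_dims, topk)


# 3D logits [batch=2, seq=3, experts=16] with top_k = 4.
kernel_logits, kernel_indices, user_indices = as_2d_shapes((2, 3, 16), 4)
```

Under this sketch the kernel sees 6 tokens, while the caller still gets indices shaped `(2, 3, 4)`.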

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["fused_topk_with_score_function(logits, topk, ...)"]
    A --> B{topk_indices provided?}
    B -- No --> C["existing sparse path: fused_topk_with_score_function_forward_v2"]
    C --> D["Allocate routing_map (BYTEMAP or BITMAP)"]
    D --> E["Returns probs + routing_map"]
    B -- Yes --> F["check_dense_topk_indices (shape/dtype/device)"]
    F --> G["fused_topk_with_score_function_forward_with_indices"]
    G --> H{topk >= radix threshold?}
    H -- No --> I["fused_topk_forward_simple_kernel (Naive, IndexType, routing_map=nullptr)"]
    H -- Yes --> J["fused_topk_with_score_function_forward_kernel (Radix, ScoreFunc, IndexType, routing_map=nullptr)"]
    I --> K["Write topk_indices_output; no routing_map allocation"]
    J --> K
    K --> L["Returns probs + topk_indices (aliased)"]
    E --> M["Backward via routing_map path"]
    L --> N["Backward via fused_topk_backward_selected_indices_kernel"]

Reviews (3): Last reviewed commit: "[PyTorch] Preserve router leading dimens..."

Comment thread transformer_engine/pytorch/csrc/extensions/router.cpp
Comment thread transformer_engine/pytorch/csrc/extensions/router.cpp
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Comment thread transformer_engine/pytorch/router.py Outdated
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>