[Common] Support scaled & clamped swiglu, srelu for BF16 #3132

Open
zhongbozhu wants to merge 6 commits into NVIDIA:main from zhongbozhu:add_support_fused_swiglu

@zhongbozhu

Collaborator

Description

Support Mega-C++ with the cuBLAS BF16 Grouped GEMM backend: #3099

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: zhongboz <zhongboz@nvidia.com>
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
github-actions Bot added the community-contribution label (PRs from external contributor outside the core maintainers, representing community-driven work) on Jun 16, 2026
zhongbozhu marked this pull request as ready for review on Jun 16, 2026, 07:32
zhongbozhu requested a review from ptrendx as a code owner on Jun 16, 2026, 07:32
@greptile-apps

greptile-apps Bot commented Jun 16, 2026

Contributor

Greptile Summary

This PR adds six new CUDA kernels (scaled_activation.cu) that fuse a per-row activation scale into SwiGLU, ClampedSwiGLU, and SReLU forward and backward passes, along with a corresponding public C API in activation.h and a parametrized GTest suite.

  • Kernel design: Forward kernels read act/gate streams in vectorized row segments (with GLU-interleave support), multiply the activation output by act_scales[row], and store in the target dtype in a single pass. Backward kernels have two code paths — a flat element-wise grid when grad_act_scales is null, and a one-block-per-row warp-reduction path when the per-row scale gradient must be accumulated.
  • Math correctness: The SiLU, ClampedSiLU, and SReLU forward/backward formulas in the kernels match their reference implementations in util/math.h and the test reference functions; the block reduction logic is correct.
  • Minor cleanup items: A redundant gated_unscaled call in the test reference and a dead Empty variable in nvte_scaled_swiglu are left in; FP16 is absent from the test dtype sweep despite being covered by the dispatch macro; and the one-block-per-row kernel launch casts rows (a size_t) to int for the grid dimension.
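The forward semantics summarized above can be sketched as a host-side reference. This is only an illustration, not the PR's kernel or its test reference: the function names `scaled_swiglu_ref` and `scaled_srelu_ref` are hypothetical, the gated input is assumed to store each row as contiguous activation and gate halves (the interleaved layout is not modeled), SReLU is taken to be squared ReLU, and the clamped variant is omitted.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// SiLU (swish): x * sigmoid(x).
inline float silu(float x) { return x / (1.0f + std::exp(-x)); }

// Squared ReLU: max(x, 0)^2.
inline float srelu(float x) { return x > 0.0f ? x * x : 0.0f; }

// Hypothetical reference for scaled SwiGLU.
// Assumed layout: each input row holds `cols` activation values followed by
// `cols` gate values (contiguous halves, no interleaving).
std::vector<float> scaled_swiglu_ref(const std::vector<float>& input,
                                     const std::vector<float>& act_scales,
                                     std::size_t rows, std::size_t cols) {
  assert(input.size() == rows * 2 * cols);
  assert(act_scales.size() == rows);
  std::vector<float> out(rows * cols);
  for (std::size_t r = 0; r < rows; ++r) {
    const float* act = input.data() + r * 2 * cols;
    const float* gate = act + cols;
    for (std::size_t c = 0; c < cols; ++c) {
      // Activation, gate multiply, and per-row scale applied in one pass.
      out[r * cols + c] = silu(act[c]) * gate[c] * act_scales[r];
    }
  }
  return out;
}

// Hypothetical reference for scaled SReLU (non-gated): one input per output.
std::vector<float> scaled_srelu_ref(const std::vector<float>& input,
                                    const std::vector<float>& act_scales,
                                    std::size_t rows, std::size_t cols) {
  assert(input.size() == rows * cols);
  assert(act_scales.size() == rows);
  std::vector<float> out(rows * cols);
  for (std::size_t r = 0; r < rows; ++r) {
    for (std::size_t c = 0; c < cols; ++c) {
      out[r * cols + c] = srelu(input[r * cols + c]) * act_scales[r];
    }
  }
  return out;
}
```

A reference of this shape computes in float and leaves the cast to the target dtype to the caller, which keeps the comparison against the fused kernel output independent of the storage type.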

Confidence Score: 4/5

Safe to merge; the new kernels are mathematically consistent with the existing utility functions and the test suite covers the primary code paths for both contiguous and interleaved GLU layouts.

The core kernel math, alignment dispatch, and block reduction are correct. The only items worth addressing before shipping are: FP16 is absent from the test dtype sweep even though the dispatch macro includes it, the one-block-per-row launch casts size_t rows to int (wrap-around for multi-billion-token batches), one dead call in the test reference, and one dead variable in the API wrapper. None of these affect correctness for typical workloads.

Files that deserve the most attention: scaled_activation.cu (the rows cast in the reduction kernel launch) and test_scaled_activation.cu (missing FP16 coverage and dead reference call).

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/common/activation/scaled_activation.cu | New 781-line CUDA file implementing 6 kernels (scaled forward/backward for SwiGLU, ClampedSwiGLU, SReLU) with vectorized loads, interleaved GLU layout support, and optional per-row scale-gradient reduction; includes dead Empty variable and a size_t → int cast for the grid-launch block count. |
| tests/cpp/operator/test_scaled_activation.cu | New 321-line test file with a parametrized GTest suite covering forward+backward for all three activations and both interleave modes; has a redundant gated_unscaled call in the reference and is missing kFloat16 from the dtype sweep. |
| transformer_engine/common/include/transformer_engine/activation.h | Adds public C API declarations for 6 new scaled-activation functions with well-documented Doxygen comments; no issues found. |
| transformer_engine/common/CMakeLists.txt | Registers scaled_activation.cu in both the standard and fast-math source lists; straightforward and correct. |
| tests/cpp/operator/CMakeLists.txt | Adds test_scaled_activation.cu to the test_operator executable; no issues. |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    API_FWD["nvte_scaled_swiglu /\nnvte_scaled_clamped_swiglu /\nnvte_scaled_srelu"]
    API_BWD["nvte_scaled_dswiglu /\nnvte_scaled_clamped_dswiglu /\nnvte_scaled_dsrelu"]

    API_FWD --> CHK_GATED_FWD{Gated?}
    CHK_GATED_FWD -- "SwiGLU / ClampedSwiGLU" --> ALIGN_FWD[check alignment & segment layout]
    CHK_GATED_FWD -- "SReLU" --> ALIGN_SRELU_FWD[check alignment]

    ALIGN_FWD -- "aligned" --> KFG_VEC["scaled_gated_forward_kernel nvec>1"]
    ALIGN_FWD -- "unaligned" --> KFG_SCAL["scaled_gated_forward_kernel nvec=1"]
    ALIGN_SRELU_FWD -- "aligned" --> KSF_VEC["scaled_srelu_forward_kernel nvec>1"]
    ALIGN_SRELU_FWD -- "unaligned" --> KSF_SCAL["scaled_srelu_forward_kernel nvec=1"]

    API_BWD --> CHK_GATED_BWD{Gated?}
    CHK_GATED_BWD -- "SwiGLU / ClampedSwiGLU" --> CHK_SCALE_G[grad_act_scales?]
    CHK_GATED_BWD -- "SReLU" --> CHK_SCALE_S[grad_act_scales?]

    CHK_SCALE_G -- "null" --> KGB_FLAT["scaled_gated_backward_kernel flat grid"]
    CHK_SCALE_G -- "present" --> KGB_RED["scaled_gated_backward_with_scale_grad_kernel one block per row + warp reduction"]

    CHK_SCALE_S -- "null" --> KSB_FLAT["scaled_srelu_backward_kernel flat grid"]
    CHK_SCALE_S -- "present" --> KSB_RED["scaled_srelu_backward_with_scale_grad_kernel one block per row + warp reduction"]
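
The two decisions in the flowchart (vector width in the forward pass, kernel path in the backward pass) can be sketched on the host as below. The names `select_nvec`, `select_backward_path`, and `BackwardPath` are hypothetical, and the alignment rule (16-byte vectors, base pointer aligned and row length divisible by the vector width) is an assumption for illustration; the PR's actual criteria may differ.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

enum class BackwardPath { kFlatGrid, kRowReduction };

// The backward pass only needs the per-row reduction when the caller asks
// for the scale gradient, i.e. when grad_act_scales is non-null.
inline BackwardPath select_backward_path(const void* grad_act_scales) {
  return grad_act_scales == nullptr ? BackwardPath::kFlatGrid
                                    : BackwardPath::kRowReduction;
}

// Assumed rule: use 16-byte vector accesses when the base pointer is 16-byte
// aligned and every row is a whole number of vectors; otherwise fall back to
// scalar accesses (nvec = 1).
inline std::size_t select_nvec(const void* ptr, std::size_t row_elems,
                               std::size_t elem_bytes) {
  constexpr std::size_t kVecBytes = 16;
  assert(elem_bytes > 0 && kVecBytes % elem_bytes == 0);
  const std::size_t nvec = kVecBytes / elem_bytes;
  const bool ptr_aligned =
      reinterpret_cast<std::uintptr_t>(ptr) % kVecBytes == 0;
  const bool row_aligned = row_elems % nvec == 0;
  return (ptr_aligned && row_aligned) ? nvec : 1;
}
```

With 2-byte elements such as BF16 this rule yields 8 elements per access on the aligned path, which is why both the pointer and the row length have to be checked before choosing it.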

Comments Outside Diff (3)

  1. tests/cpp/operator/test_scaled_activation.cu, lines 326-327

    P2 Missing FP16 dtype in test matrix

    The implementation dispatches through TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY, which covers float32, float16, and bfloat16. The test matrix only exercises kFloat32 and kBFloat16, leaving kFloat16 untested for both data and scale tensors. An off-by-one in the vector-width calculation or a narrow-type saturation edge case specific to FP16 would pass the current suite undetected.


  2. transformer_engine/common/activation/scaled_activation.cu, lines 1087-1093

    P2 Dead Empty variable in nvte_scaled_swiglu

    Empty empty = {}; (void)empty; is never passed to any function in this API wrapper — the Empty type is used only inside the CUDA kernels via silu<float, float>(act_in, empty). The declaration and the suppression cast are both dead code and can be removed.

  3. transformer_engine/common/activation/scaled_activation.cu, lines 994-1005

    P2 static_cast<int>(rows) may overflow for very large grid launches

    The "with scale grad" kernels are launched with <<<static_cast<int>(rows), kReductionThreads, ...>>>. rows is size_t; casting it directly to int wraps for values above INT_MAX (~2.1 billion), yielding a negative or otherwise wrong block count, so the launch either fails or covers the wrong rows and grad_act_scales is left incorrect. Because the x-dimension of a CUDA grid is itself limited to 2^31 - 1 blocks, guarding with NVTE_CHECK(rows <= INT32_MAX, ...) before the launch is the robust fix; dim3(rows) removes the signed cast but still narrows rows to unsigned int. The same pattern appears in launch_scaled_srelu_backward.
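
A guard of the kind suggested in item 3 could look like the sketch below. `checked_grid_dim_x` is a hypothetical helper written in plain C++ so it can be compiled without CUDA; in the PR the same check would more likely be expressed with NVTE_CHECK right before the launch.

```cpp
#include <cstddef>
#include <limits>
#include <stdexcept>

// Upper bound for the x-dimension of a CUDA grid (2^31 - 1 blocks).
constexpr std::size_t kMaxGridDimX =
    static_cast<std::size_t>(std::numeric_limits<int>::max());

// Convert a row count to a grid dimension, failing loudly instead of
// wrapping when the count does not fit.
inline unsigned int checked_grid_dim_x(std::size_t rows) {
  if (rows > kMaxGridDimX) {
    throw std::out_of_range("row count exceeds the maximum grid x-dimension");
  }
  return static_cast<unsigned int>(rows);
}
```

Row counts beyond this limit would need a different launch scheme (for example, a loop over rows inside the kernel), which is outside the scope of this sketch.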

Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment on lines +170 to +171
}
}

Contributor


P2 Dead gated_unscaled call

gated_unscaled computes unscaled on line 170, but gated_grads unconditionally writes *unscaled on line 171, overwriting it. The first call is dead code — every gated_grads case sets *unscaled before returning, so the result of gated_unscaled is never observed. This should simply be removed.
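
To illustrate why a separate gated_unscaled call is redundant, here is a minimal sketch of a gradient helper that already returns the unscaled output alongside the input gradients. The struct `GatedGrads` and the functions `swiglu_grads` and `scale_grad_row` are hypothetical and cover only the plain SwiGLU case; they are not the signatures used in the PR's test file. With y = s * silu(a) * g, the scale gradient for a row is the sum over columns of grad_out times the unscaled value.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>

struct GatedGrads {
  float dact;      // gradient w.r.t. the activation input
  float dgate;     // gradient w.r.t. the gate input
  float unscaled;  // silu(act) * gate, needed for the scale gradient
};

// Gradients of y = scale * silu(act) * gate for one element.
inline GatedGrads swiglu_grads(float act, float gate, float grad_out,
                               float scale) {
  const float sig = 1.0f / (1.0f + std::exp(-act));
  const float silu = act * sig;
  // d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
  const float dsilu = sig * (1.0f + act * (1.0f - sig));
  return {grad_out * scale * dsilu * gate,  // dact
          grad_out * scale * silu,          // dgate
          silu * gate};                     // unscaled
}

// Per-row scale gradient: sum_c grad_out[c] * unscaled[c].
inline float scale_grad_row(const float* act, const float* gate,
                            const float* grad_out, std::size_t cols,
                            float scale) {
  assert(act != nullptr && gate != nullptr && grad_out != nullptr);
  float sum = 0.0f;
  for (std::size_t c = 0; c < cols; ++c) {
    sum += grad_out[c] * swiglu_grads(act[c], gate[c], grad_out[c], scale).unscaled;
  }
  return sum;
}
```

Because the unscaled value falls out of the same computation as the input gradients, a single call per element is enough for both the flat path and the scale-gradient path.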

@zhongbozhu

Collaborator Author

/te-ci pytorch
