[Common] Support scaled & clamped swiglu, srelu for BF16 #3132
[Common] Support scaled & clamped swiglu, srelu for BF16 #3132zhongbozhu wants to merge 6 commits into
Conversation
Signed-off-by: zhongboz <zhongboz@nvidia.com>
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
for more information, see https://pre-commit.ci
Greptile SummaryThis PR adds six new CUDA kernels (
Confidence Score: 4/5Safe to merge; the new kernels are mathematically consistent with the existing utility functions and the test suite covers the primary code paths for both contiguous and interleaved GLU layouts. The core kernel math, alignment dispatch, and block reduction are correct. The only items worth addressing before shipping are: FP16 is absent from the test dtype sweep even though the dispatch macro includes it, the one-block-per-row launch casts
Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
API_FWD["nvte_scaled_swiglu /\nnvte_scaled_clamped_swiglu /\nnvte_scaled_srelu"]
API_BWD["nvte_scaled_dswiglu /\nnvte_scaled_clamped_dswiglu /\nnvte_scaled_dsrelu"]
API_FWD --> CHK_GATED_FWD{Gated?}
CHK_GATED_FWD -- "SwiGLU / ClampedSwiGLU" --> ALIGN_FWD[check alignment & segment layout]
CHK_GATED_FWD -- "SReLU" --> ALIGN_SRELU_FWD[check alignment]
ALIGN_FWD -- "aligned" --> KFG_VEC["scaled_gated_forward_kernel nvec>1"]
ALIGN_FWD -- "unaligned" --> KFG_SCAL["scaled_gated_forward_kernel nvec=1"]
ALIGN_SRELU_FWD -- "aligned" --> KSF_VEC["scaled_srelu_forward_kernel nvec>1"]
ALIGN_SRELU_FWD -- "unaligned" --> KSF_SCAL["scaled_srelu_forward_kernel nvec=1"]
API_BWD --> CHK_GATED_BWD{Gated?}
CHK_GATED_BWD -- "SwiGLU / ClampedSwiGLU" --> CHK_SCALE_G[grad_act_scales?]
CHK_GATED_BWD -- "SReLU" --> CHK_SCALE_S[grad_act_scales?]
CHK_SCALE_G -- "null" --> KGB_FLAT["scaled_gated_backward_kernel flat grid"]
CHK_SCALE_G -- "present" --> KGB_RED["scaled_gated_backward_with_scale_grad_kernel one block per row + warp reduction"]
CHK_SCALE_S -- "null" --> KSB_FLAT["scaled_srelu_backward_kernel flat grid"]
CHK_SCALE_S -- "present" --> KSB_RED["scaled_srelu_backward_with_scale_grad_kernel one block per row + warp reduction"]
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
flowchart TD
API_FWD["nvte_scaled_swiglu /\nnvte_scaled_clamped_swiglu /\nnvte_scaled_srelu"]
API_BWD["nvte_scaled_dswiglu /\nnvte_scaled_clamped_dswiglu /\nnvte_scaled_dsrelu"]
API_FWD --> CHK_GATED_FWD{Gated?}
CHK_GATED_FWD -- "SwiGLU / ClampedSwiGLU" --> ALIGN_FWD[check alignment & segment layout]
CHK_GATED_FWD -- "SReLU" --> ALIGN_SRELU_FWD[check alignment]
ALIGN_FWD -- "aligned" --> KFG_VEC["scaled_gated_forward_kernel nvec>1"]
ALIGN_FWD -- "unaligned" --> KFG_SCAL["scaled_gated_forward_kernel nvec=1"]
ALIGN_SRELU_FWD -- "aligned" --> KSF_VEC["scaled_srelu_forward_kernel nvec>1"]
ALIGN_SRELU_FWD -- "unaligned" --> KSF_SCAL["scaled_srelu_forward_kernel nvec=1"]
API_BWD --> CHK_GATED_BWD{Gated?}
CHK_GATED_BWD -- "SwiGLU / ClampedSwiGLU" --> CHK_SCALE_G[grad_act_scales?]
CHK_GATED_BWD -- "SReLU" --> CHK_SCALE_S[grad_act_scales?]
CHK_SCALE_G -- "null" --> KGB_FLAT["scaled_gated_backward_kernel flat grid"]
CHK_SCALE_G -- "present" --> KGB_RED["scaled_gated_backward_with_scale_grad_kernel one block per row + warp reduction"]
CHK_SCALE_S -- "null" --> KSB_FLAT["scaled_srelu_backward_kernel flat grid"]
CHK_SCALE_S -- "present" --> KSB_RED["scaled_srelu_backward_with_scale_grad_kernel one block per row + warp reduction"]
|
| } | ||
| } |
There was a problem hiding this comment.
gated_unscaled computes unscaled on line 170, but gated_grads unconditionally writes *unscaled on line 171, overwriting it. The first call is dead code — every gated_grads case sets *unscaled before returning, so the result of gated_unscaled is never observed. This should simply be removed.
|
/te-ci pytorch |
Description
Support Mega-C++ with Cublas BF16 Grouped GEMM backend: #3099
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: