Enable NVFP4 RHT amax for grouped SReLU MLP#3133

Open
sraman-rgb wants to merge 1 commit into
NVIDIA:mainfrom
sraman-rgb:te-nvfp4-srelu-rht-hadamard

Conversation

@sraman-rgb

Contributor

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@sraman-rgb sraman-rgb requested a review from timmoon10 as a code owner June 16, 2026 18:42
@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label Jun 16, 2026
Signed-off-by: Siddhartha Raman <sraman@nvidia.com>
@sraman-rgb sraman-rgb force-pushed the te-nvfp4-srelu-rht-hadamard branch from fa32e3b to 79def34 Compare June 16, 2026 18:45
@greptile-apps

greptile-apps Bot commented Jun 16, 2026

Contributor

Greptile Summary

This PR extends the NVFP4 RHT amax path, previously gated on a SwiGLU activation, to also cover GroupedMLP_CuTeGEMMUnary (SReLU). The key changes are: extracting activation_op from self.basic_ops in fuser_forward, adding an activation_is_srelu flag to gate the new code path, renaming glu_hadamard to act_hadamard throughout, and adding grouped_gemm_act_hadamard_kernel to GroupedMLP_CuTeGEMMUnary (which reuses grouped_gemm_glu_hadamard_wrapper_sm100 dispatched with act_func="srelu").

  • grouped_mlp.py: Adds activation_is_srelu detection, wires SReLU into the hadamard kernel path with act_func="srelu", and gives GroupedMLP_CuTeGEMMUnary its own grouped_gemm_act_hadamard_kernel that reuses the existing GLU hadamard cuDNN wrapper.
  • test_fusible_ops.py: Parameterises the existing grouped-MLP test over activation type, widens tolerances for nvfp4_rht, skips bias=True + SReLU + NVFP4 RHT (coverage limitation), and adds a dedicated test_grouped_mlp_nvfp4_rht_srelu entry point.

Confidence Score: 4/5

The change is narrowly scoped to the NVFP4 RHT amax code path. Existing SwiGLU and non-quantized flows are structurally unchanged, so regressions there are unlikely. The new SReLU hadamard path is exercised by a dedicated test, though the bias+SReLU+NVFP4 combination has no test and no production guard.

The logic is straightforward and the rename from glu_hadamard to act_hadamard is applied consistently. The one area that warrants a second look is the undocumented reuse of grouped_gemm_glu_hadamard_wrapper_sm100 for SReLU and the absence of any production-side guard for the bias+SReLU+NVFP4 combination that tests explicitly skip.

transformer_engine/pytorch/ops/fused/grouped_mlp.py — specifically GroupedMLP_CuTeGEMMUnary.grouped_gemm_act_hadamard_kernel and the fuser_forward hadamard path around the bias+SReLU case.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/ops/fused/grouped_mlp.py | Enables NVFP4 RHT amax for grouped SReLU MLP by extracting activation_op from basic_ops, adding activation_is_srelu detection, renaming glu_hadamard to act_hadamard throughout, and adding grouped_gemm_act_hadamard_kernel to GroupedMLP_CuTeGEMMUnary (which reuses grouped_gemm_glu_hadamard_wrapper_sm100 with act_func="srelu"). |
| tests/pytorch/test_fusible_ops.py | Extends test_grouped_mlp to accept an activation parameter (scaled_swiglu or scaled_srelu), adds wider tolerances for nvfp4_rht, skips bias+SReLU+nvfp4_rht, and adds a dedicated test_grouped_mlp_nvfp4_rht_srelu entry point. |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[fuser_forward] --> B{use_nvfp4_rht_amax?}
    B -- No --> C[grouped_gemm_activation_kernel\nnorm_const_tensor path]
    B -- Yes --> D{_cudnn_act_func == swiglu\nOR activation_is_srelu?}
    D -- No --> C
    D -- Yes --> E[grouped_gemm_act_hadamard_kernel available?]
    E -- No --> C
    E -- Yes --> F{activation_is_srelu?}
    F -- Yes --> G[act_func = srelu\nGroupedMLP_CuTeGEMMUnary path]
    F -- No --> H[act_func = swiglu / geglu\nGroupedMLP_CuTeGEMMGLU path]
    G --> I[grouped_gemm_glu_hadamard_wrapper_sm100\nact_func=srelu]
    H --> I
    I --> J[_group_quantize_with_amax_for_grouped_mlp\nuses amax_tensor + post_rht_amax_tensor]
    C --> K[_group_quantize_for_grouped_mlp]
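
The branch order in the flowchart can be sketched as plain Python. This is only an illustration of the decision logic as the flowchart describes it, not the code in grouped_mlp.py: the helper name select_fc1_kernel_path, its parameters, and the returned dictionary are hypothetical, and the strings are just the labels used in the flowchart nodes.

```python
from typing import Optional


def select_fc1_kernel_path(
    use_nvfp4_rht_amax: bool,
    cudnn_act_func: Optional[str],
    activation_is_srelu: bool,
    act_hadamard_kernel_available: bool,
) -> dict:
    """Mirror the flowchart's branch order (hypothetical helper, not TE code)."""
    # Fallback: the activation kernel with the norm_const_tensor path (node C),
    # followed by the regular group quantize (node K).
    fallback = {
        "kernel": "grouped_gemm_activation_kernel",
        "quantize": "_group_quantize_for_grouped_mlp",
        "act_func": None,
    }
    if not use_nvfp4_rht_amax:  # node B
        return fallback
    if not (cudnn_act_func == "swiglu" or activation_is_srelu):  # node D
        return fallback
    if not act_hadamard_kernel_available:  # node E
        return fallback
    # Nodes F through J: both activations end up in the same Hadamard wrapper,
    # and only the act_func string differs.
    return {
        "kernel": "grouped_gemm_act_hadamard_kernel",
        "quantize": "_group_quantize_with_amax_for_grouped_mlp",
        "act_func": "srelu" if activation_is_srelu else cudnn_act_func,
    }
```

Under these assumptions, any failed gate lands on the same fallback, which matches the three arrows into node C.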

Comments Outside Diff (1)

  1. transformer_engine/pytorch/ops/fused/grouped_mlp.py, line 2154-2170 (link)

    P2 Both GroupedMLP_CuTeGEMMGLU and GroupedMLP_CuTeGEMMUnary now define grouped_gemm_act_hadamard_kernel with identical bodies — both import and return grouped_gemm_glu_hadamard_wrapper_sm100. The GLU class returning a kernel named glu_hadamard is self-explanatory, but the Unary/SReLU class silently reusing that same kernel (relying on act_func="srelu" to switch behaviour) is non-obvious and the body gives no hint of this. A docstring clarification would make the intent clear to future readers.

    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
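
One possible shape for the docstring suggested in this comment is sketched below. It is an assumption-laden illustration: the class GroupedMLPCuTeGEMMUnarySketch is a hypothetical stand-in for GroupedMLP_CuTeGEMMUnary, the local grouped_gemm_glu_hadamard_wrapper_sm100 is a dummy so the snippet runs without cuDNN, and the use of a classmethod is a guess rather than something the PR confirms.

```python
def grouped_gemm_glu_hadamard_wrapper_sm100(*args, act_func="swiglu", **kwargs):
    """Dummy stand-in for the real cuDNN wrapper so this sketch is runnable."""
    return {"act_func": act_func}


class GroupedMLPCuTeGEMMUnarySketch:
    """Hypothetical stand-in for GroupedMLP_CuTeGEMMUnary (not the TE class)."""

    @classmethod
    def grouped_gemm_act_hadamard_kernel(cls):
        """Return the fused grouped GEMM + activation + Hadamard kernel.

        The unary (SReLU) path intentionally reuses the wrapper that the GLU
        class returns. The wrapper is named after GLU, but the activation is
        chosen by the caller, which passes act_func="srelu" on this path.
        """
        return grouped_gemm_glu_hadamard_wrapper_sm100
```

With a docstring along these lines, a reader of the unary class can see that the shared wrapper is deliberate and where the activation is actually selected.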

Reviews (1): Last reviewed commit: "Enable NVFP4 RHT amax for grouped SReLU ..."

if with_quantization and dtype not in (torch.bfloat16, torch.float16):
    pytest.skip("Quantized group GEMM is only supported with BF16/FP16")
if activation == "scaled_srelu" and quantization == "nvfp4_rht" and bias:
    pytest.skip("NVFP4 RHT SReLU grouped MLP coverage is limited to no-bias")

Contributor


P2 Skipped combination has no production guard

The test skips activation="scaled_srelu" + quantization="nvfp4_rht" + bias=True with the message "coverage is limited to no-bias", but grouped_mlp.py has no corresponding runtime check. If the underlying grouped_gemm_glu_hadamard_wrapper_sm100 kernel does not support a bias tensor when called with act_func="srelu", a production caller that uses GroupedMLP_CuTeGEMMUnary with NVFP4 RHT and a bias will silently reach the hadamard kernel path and either crash or produce wrong results with no helpful error message. If this combination is genuinely unsupported, a ValueError/RuntimeError should be raised in fuser_forward when use_fc1_act_hadamard_srelu and fc1_bias_packed is not None.
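
If the combination is genuinely unsupported, the guard could look roughly like the sketch below. The helper name check_srelu_hadamard_bias and the error text are hypothetical; only the two condition names, use_fc1_act_hadamard_srelu and fc1_bias_packed, come from the comment above, and whether the kernel really rejects a bias is not established here.

```python
from typing import Optional


def check_srelu_hadamard_bias(
    use_fc1_act_hadamard_srelu: bool,
    fc1_bias_packed: Optional[object],
) -> None:
    """Fail early instead of reaching the Hadamard kernel with a bias.

    Hypothetical helper: assumes the SReLU + NVFP4 RHT path cannot take a
    packed FC1 bias, which mirrors the combination the test skips.
    """
    if use_fc1_act_hadamard_srelu and fc1_bias_packed is not None:
        raise ValueError(
            "NVFP4 RHT amax with SReLU does not support an FC1 bias "
            "in the grouped MLP Hadamard kernel path"
        )
```

Calling such a check near the top of the hadamard branch in fuser_forward would turn a silent wrong result or a crash inside the kernel into an immediate, readable error.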


Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.
