Enable NVFP4 RHT amax for grouped SReLU MLP#3133
Conversation
Signed-off-by: Siddhartha Raman <sraman@nvidia.com>
fa32e3b to
79def34
Compare
Greptile SummaryThis PR extends the NVFP4 RHT amax path — previously gated on a SwiGLU activation — to also cover
Confidence Score: 4/5The change is narrowly scoped to the NVFP4 RHT amax code path. Existing SwiGLU and non-quantized flows are structurally unchanged, so regressions there are unlikely. The new SReLU hadamard path is exercised by a dedicated test, though the bias+SReLU+NVFP4 combination has no test and no production guard. The logic is straightforward and the rename from glu_hadamard to act_hadamard is applied consistently. The one area that warrants a second look is the undocumented reuse of grouped_gemm_glu_hadamard_wrapper_sm100 for SReLU and the absence of any production-side guard for the bias+SReLU+NVFP4 combination that tests explicitly skip. transformer_engine/pytorch/ops/fused/grouped_mlp.py — specifically GroupedMLP_CuTeGEMMUnary.grouped_gemm_act_hadamard_kernel and the fuser_forward hadamard path around the bias+SReLU case. Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[fuser_forward] --> B{use_nvfp4_rht_amax?}
B -- No --> C[grouped_gemm_activation_kernel\nnorm_const_tensor path]
B -- Yes --> D{_cudnn_act_func == swiglu\nOR activation_is_srelu?}
D -- No --> C
D -- Yes --> E[grouped_gemm_act_hadamard_kernel available?]
E -- No --> C
E -- Yes --> F{activation_is_srelu?}
F -- Yes --> G[act_func = srelu\nGroupedMLP_CuTeGEMMUnary path]
F -- No --> H[act_func = swiglu / geglu\nGroupedMLP_CuTeGEMMGLU path]
G --> I[grouped_gemm_glu_hadamard_wrapper_sm100\nact_func=srelu]
H --> I
I --> J[_group_quantize_with_amax_for_grouped_mlp\nuses amax_tensor + post_rht_amax_tensor]
C --> K[_group_quantize_for_grouped_mlp]
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
flowchart TD
A[fuser_forward] --> B{use_nvfp4_rht_amax?}
B -- No --> C[grouped_gemm_activation_kernel\nnorm_const_tensor path]
B -- Yes --> D{_cudnn_act_func == swiglu\nOR activation_is_srelu?}
D -- No --> C
D -- Yes --> E[grouped_gemm_act_hadamard_kernel available?]
E -- No --> C
E -- Yes --> F{activation_is_srelu?}
F -- Yes --> G[act_func = srelu\nGroupedMLP_CuTeGEMMUnary path]
F -- No --> H[act_func = swiglu / geglu\nGroupedMLP_CuTeGEMMGLU path]
G --> I[grouped_gemm_glu_hadamard_wrapper_sm100\nact_func=srelu]
H --> I
I --> J[_group_quantize_with_amax_for_grouped_mlp\nuses amax_tensor + post_rht_amax_tensor]
C --> K[_group_quantize_for_grouped_mlp]
|
| if with_quantization and dtype not in (torch.bfloat16, torch.float16): | ||
| pytest.skip("Quantized group GEMM is only supported with BF16/FP16") | ||
| if activation == "scaled_srelu" and quantization == "nvfp4_rht" and bias: | ||
| pytest.skip("NVFP4 RHT SReLU grouped MLP coverage is limited to no-bias") |
There was a problem hiding this comment.
Skipped combination has no production guard
The test skips activation="scaled_srelu" + quantization="nvfp4_rht" + bias=True with the message "coverage is limited to no-bias", but grouped_mlp.py has no corresponding runtime check. If the underlying grouped_gemm_glu_hadamard_wrapper_sm100 kernel does not support a bias tensor when called with act_func="srelu", a production caller that uses GroupedMLP_CuTeGEMMUnary with NVFP4 RHT and a bias will silently reach the hadamard kernel path and either crash or produce wrong results with no helpful error message. If this combination is genuinely unsupported, a ValueError/RuntimeError should be raised in fuser_forward when use_fc1_act_hadamard_srelu and fc1_bias_packed is not None.
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: