
[DONOT MERGE] Wgrad cute dsl v2 #2872

Draft
vthumbe1503 wants to merge 11 commits into NVIDIA:main from vthumbe1503:wgrad_cute_dsl_v2

Conversation

@vthumbe1503
Collaborator

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

vthumbe1503 and others added 9 commits April 13, 2026 02:46
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…/TransformerEngine into users/vthumbe/wgrad_cute_dsl
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 changed the title Wgrad cute dsl v2 [DONOT MERGE] Wgrad cute dsl v2 Apr 13, 2026
vthumbe1503 and others added 2 commits April 13, 2026 15:33
@greptile-apps
Contributor

greptile-apps bot commented Apr 13, 2026

Greptile Summary

This PR wires up an experimental CuTe DSL grouped GEMM wgrad kernel from cuDNN's SM100 front-end into the BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8 fused backward pass, replacing the cublas fallback path when the kernel is available. A new _compute_grad_params helper centralises weight-gradient allocation and dispatch across both the cuDNN and cublas code paths.

  • P1: grouped_gemm_wgrad_kernel() raises ImportError instead of returning None when grouped_gemm_wgrad_wrapper_sm100 is absent — the None-check fallback inside _compute_grad_params is therefore unreachable, and the backward pass will crash on any setup where the dGLU/quant kernels exist but the wgrad API does not. _nvidia_cudnn_frontend_supports_wgrad (added in _common.py and imported into the backward file) was clearly intended as the version gate but was never wired in.

Confidence Score: 4/5

Not safe to merge until grouped_gemm_wgrad_kernel() is fixed to return None on ImportError instead of propagating the exception.

One P1 defect: on any system where grouped_gemm_dglu_kernel is available but grouped_gemm_wgrad_wrapper_sm100 is not, the backward pass crashes at runtime rather than silently falling back to cublas. The fix is a small try/except plus wiring in _nvidia_cudnn_frontend_supports_wgrad.

transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py — specifically grouped_gemm_wgrad_kernel() (lines 300–307) and its call-sites at lines 660 and 757.

Important Files Changed

  • transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py — Adds the CuTe DSL grouped GEMM wgrad kernel path and the helper _compute_grad_params; grouped_gemm_wgrad_kernel() raises ImportError instead of returning None as documented, breaking the cublas fallback in _compute_grad_params.
  • transformer_engine/pytorch/ops/_common.py — Adds the _nvidia_cudnn_frontend_supports_wgrad() version check (mirrors _nvidia_cudnn_frontend_supports_scaled_clamped_qgeglu); no issues found in this file.

Comments Outside Diff (1)

  1. transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py, line 26-31 (link)

    P2 Unused import _nvidia_cudnn_frontend_supports_wgrad

    _nvidia_cudnn_frontend_supports_wgrad is imported here but never referenced anywhere in this file. It was clearly intended to be the version guard inside grouped_gemm_wgrad_kernel() (see the P1 finding above); once that fix is applied the import becomes necessary and the issue resolves itself. If the fix is not applied, the import should be removed to keep the namespace clean.

Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment on lines 300 to 307
```python
def grouped_gemm_wgrad_kernel(cls) -> Optional[Callable]:
    """CuTe DSL kernel for grouped GEMM wgrad on SM100+.

    Returns ``None`` when the cuDNN front-end package is older than
    1.23.0.

    Returns ``None`` when the cuDNN front-end wgrad API is not
    available or lacks the required wgrad_tensor/wgrad_ptrs params.
    """
    if not _nvidia_cudnn_frontend_supports_wgrad():
        return None
    from cudnn import grouped_gemm_wgrad_wrapper_sm100  # pylint: disable=no-name-in-module

    return grouped_gemm_wgrad_wrapper_sm100
```
Contributor


P1 grouped_gemm_wgrad_kernel() never returns None as documented

The docstring promises None when the wgrad API is unavailable, but a bare from cudnn import grouped_gemm_wgrad_wrapper_sm100 raises ImportError instead of returning None. _compute_grad_params (called at lines 660 and 757) branches on if cudnn_wgrad_kernel_fn is not None specifically to fall back to cublas — that fallback is unreachable because the exception propagates and crashes the backward pass on any setup where grouped_gemm_dglu_kernel is present but grouped_gemm_wgrad_wrapper_sm100 is not. Note that _nvidia_cudnn_frontend_supports_wgrad is imported but never used; it was clearly intended as the version gate here.

Suggested change

Before:

```python
def grouped_gemm_wgrad_kernel(cls) -> Optional[Callable]:
    """CuTe DSL kernel for grouped GEMM wgrad on SM100+.

    Returns ``None`` when the cuDNN front-end package is older than
    1.23.0.

    Returns ``None`` when the cuDNN front-end wgrad API is not
    available or lacks the required wgrad_tensor/wgrad_ptrs params.
    """
    if not _nvidia_cudnn_frontend_supports_wgrad():
        return None
    from cudnn import grouped_gemm_wgrad_wrapper_sm100  # pylint: disable=no-name-in-module
    return grouped_gemm_wgrad_wrapper_sm100
```

After:

```python
@classmethod
@functools.lru_cache(maxsize=None)
def grouped_gemm_wgrad_kernel(cls) -> Optional[Callable]:
    """CuTe DSL kernel for grouped GEMM wgrad on SM100+.

    Returns ``None`` when the cuDNN front-end wgrad API is not
    available or lacks the required wgrad_tensor/wgrad_ptrs params.
    """
    if not _nvidia_cudnn_frontend_supports_wgrad():
        return None
    try:
        from cudnn import grouped_gemm_wgrad_wrapper_sm100  # pylint: disable=no-name-in-module
    except ImportError:
        return None
    return grouped_gemm_wgrad_wrapper_sm100
```

@vthumbe1503 vthumbe1503 marked this pull request as draft April 13, 2026 22:42
