From 4512258674236f9d261f6d26f79a972bc624cbef Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Mon, 13 Apr 2026 02:46:59 +0000 Subject: [PATCH 1/9] integrate cudnn wgrad kernel Signed-off-by: Varun Thumbe --- .../pytorch/ops/fused/backward_grouped_mlp.py | 136 +++++++++++++++++- 1 file changed, 129 insertions(+), 7 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py index 357e8b3695..d52d8dab0b 100644 --- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py @@ -35,6 +35,86 @@ from ...triton.grouped_dbias_dscales import _compute_grouped_dbias_dscales +def _cudnn_compute_wgrad( + grouped_x: GroupedTensor, + grouped_dy: GroupedTensor, + wgrad_output, + weight_shape: tuple, + offsets: torch.Tensor, + dtype: torch.dtype, + accumulate: bool, + wgrad_kernel_fn, + single_grouped_weight: bool, +): + """Compute wgrad using the cuDNN CuTe DSL grouped GEMM wgrad kernel. + + The cuDNN wgrad kernel computes: + wgrad[e] = a[:, tok_start:tok_end] @ b[tok_start:tok_end, :] + where a = DY^T = (out_features, total_tokens) row-major and + b = X = (total_tokens, in_features) column-major. + + TE's columnwise_data is (total_tokens, features) row-major with + columnwise block scaling. We transpose data and reorder scale + super-blocks from TE's global layout to cuDNN's per-group layout. + """ + out_features, in_features = weight_shape + total_tokens = grouped_dy.logical_shape[0] + + fp8_dtype = torch.float8_e4m3fn + + # a_tensor = DY^T = (out_features, total_tokens) row-major + a_tensor = ( + grouped_dy.columnwise_data + .view(dtype=fp8_dtype) + .view(total_tokens, out_features) + .T + ) + # b_tensor = X = (total_tokens, in_features) column-major + b_tensor = ( + grouped_x.columnwise_data + .view(dtype=fp8_dtype) + .view(total_tokens, in_features) + ) + + sfa_tensor = grouped_dy.columnwise_scale_inv.view(out_features, -1).view(dtype=torch.float8_e8m0fnu) + sfb_tensor = grouped_x.columnwise_scale_inv.view(in_features, -1).view(dtype=torch.float8_e8m0fnu) + offsets_tensor = offsets.to(dtype=torch.int32) + + # Prepare wgrad output + if single_grouped_weight: + # Dense mode: single (num_groups, out_features, in_features) tensor + wgrad_tensor = wgrad_output.rowwise_data.view(offsets_tensor.shape[0], out_features, in_features) + wgrad_kernel_fn( + a_tensor=a_tensor, + b_tensor=b_tensor, + sfa_tensor=sfa_tensor, + sfb_tensor=sfb_tensor, + offsets_tensor=offsets_tensor, + output_mode="dense", + wgrad_tensor=wgrad_tensor, + acc_dtype=torch.float32, + wgrad_dtype=wgrad_tensor.dtype, + sf_vec_size=MXFP8_BLOCK_SCALING_SIZE, + accumulate_on_output=accumulate, + ) + else: + # Discrete mode: per-expert wgrad device pointers + (wgrad_ptrs,) = tex.convert_host_pointers_to_tensor([wgrad_output]) + wgrad_kernel_fn( + a_tensor=a_tensor, + b_tensor=b_tensor, + sfa_tensor=sfa_tensor, + sfb_tensor=sfb_tensor, + offsets_tensor=offsets_tensor, + output_mode="discrete", + wgrad_ptrs=wgrad_ptrs, + acc_dtype=torch.float32, + wgrad_dtype=wgrad_output[0].dtype, + sf_vec_size=MXFP8_BLOCK_SCALING_SIZE, + accumulate_on_output=accumulate, + ) + + @functools.lru_cache(maxsize=1) def _dglu_wrapper_has_generate_dbias_arg() -> bool: """True if cudnn-frontend SM100 dGLU wrapper accepts ``generate_dbias``.""" @@ -61,6 +141,8 @@ def _compute_grad_params( bias_grads, bias_grad_packed, label="", + cudnn_wgrad_kernel_fn=None, + offsets=None, ): """Compute weight gradients and build 
grad_params for a GroupedLinear layer. Returns the grad_params list in parameter registration order. @@ -131,11 +213,24 @@ def _compute_grad_params( if ctx.weight_requires_grad: # Launch or defer the GEMM delay_wgrad = fc_op.wgrad_store is not None and fc_op.wgrad_store.delay_wgrad_compute() - gemm_fn = functools.partial( - general_grouped_gemm_for_grouped_tensor, - layout="NT", - accumulate=accumulate_into_main_grad, - ) + + if cudnn_wgrad_kernel_fn is not None and offsets is not None: + gemm_fn = functools.partial( + _cudnn_compute_wgrad, + weight_shape=weight_shape, + offsets=offsets, + dtype=dtype, + accumulate=accumulate_into_main_grad, + wgrad_kernel_fn=cudnn_wgrad_kernel_fn, + single_grouped_weight=fc_op.single_grouped_weight, + ) + else: + gemm_fn = functools.partial( + general_grouped_gemm_for_grouped_tensor, + layout="NT", + accumulate=accumulate_into_main_grad, + ) + if delay_wgrad: fc_op.wgrad_store.put([grouped_x, grouped_dy, wgrad_output], gemm_fn) else: @@ -204,6 +299,13 @@ def grouped_gemm_quant_kernel(cls) -> Callable: return grouped_gemm_quant_wrapper_sm100 + @classmethod + @functools.lru_cache(maxsize=None) + def grouped_gemm_wgrad_kernel(cls) -> Callable: + """CuTe DSL kernel for grouped GEMM wgrad on SM100+.""" + from cudnn import grouped_gemm_wgrad_wrapper_sm100 # pylint: disable=no-name-in-module + return grouped_gemm_wgrad_wrapper_sm100 + @classmethod @functools.lru_cache(maxsize=None) def is_supported(cls) -> bool: @@ -215,6 +317,19 @@ def is_supported(cls) -> bool: try: cls.grouped_gemm_dglu_kernel() cls.grouped_gemm_quant_kernel() + cls.grouped_gemm_wgrad_kernel() + except ImportError: + return False + return True + + @classmethod + @functools.lru_cache(maxsize=None) + def is_wgrad_supported(cls) -> bool: + """Whether the CuTe DSL wgrad kernel is available.""" + if not cls.is_supported(): + return False + try: + cls.grouped_gemm_wgrad_kernel() except ImportError: return False return True @@ -477,10 +592,12 @@ def fuser_backward( fc1_dy_row_data = fc2_dgrad_kernel_out["d_row_tensor"] fc1_dy_row_data = fc1_dy_row_data.view(out_shape[0], fc1_weight_shape[0]) - fc1_dy_row_scale = fc2_dgrad_kernel_out["sfd_row_tensor"] + # View scale in their actual swizzled shape + fc1_dy_row_scale = fc2_dgrad_kernel_out["sfd_row_tensor"].permute(5, 2, 4, 0, 1, 3).view(-1) fc1_dy_col_data = fc2_dgrad_kernel_out["d_col_tensor"] fc1_dy_col_data = fc1_dy_col_data.view(out_shape[0], fc1_weight_shape[0]) - fc1_dy_col_scale = fc2_dgrad_kernel_out["sfd_col_tensor"] + # View scale in their actual swizzled shape + fc1_dy_col_scale = fc2_dgrad_kernel_out["sfd_col_tensor"].permute(5, 2, 4, 0, 1, 3).view(-1) grad_scales = fc2_dgrad_kernel_out["dprob_tensor"].view(-1) fc2_bias_grads: Optional[list[Optional[torch.Tensor]]] = None @@ -541,6 +658,7 @@ def fuser_backward( ) # FC2 wgrad GEMM + _use_cudnn_wgrad = self.is_wgrad_supported() fc2_grad_params = _compute_grad_params( fc_op=fc2_op, ctx=fc2_ctx, @@ -553,6 +671,8 @@ def fuser_backward( bias_grads=fc2_bias_grads, bias_grad_packed=fc2_bias_grad_packed, label="FC2", + cudnn_wgrad_kernel_fn=self.grouped_gemm_wgrad_kernel() if _use_cudnn_wgrad else None, + offsets=split_points, ) # Clear FC2 input tensor if possible @@ -648,6 +768,8 @@ def fuser_backward( bias_grads=fc1_bias_grads, bias_grad_packed=fc1_bias_grad_packed, label="FC1", + cudnn_wgrad_kernel_fn=self.grouped_gemm_wgrad_kernel() if _use_cudnn_wgrad else None, + offsets=split_points, ) # Clear FC1 input tensor if possible From 18fc3afbeb5a960a29fd32b5b899eea9c1168857 Mon Sep 17 00:00:00 
2001 From: Varun Thumbe Date: Mon, 13 Apr 2026 02:55:42 +0000 Subject: [PATCH 2/9] have only cute dsl for wgrad Signed-off-by: Varun Thumbe --- 3rdparty/cudnn-frontend | 2 +- .../pytorch/ops/fused/backward_grouped_mlp.py | 51 +++++-------------- 2 files changed, 15 insertions(+), 38 deletions(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 7b9b711c22..088e28d993 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 7b9b711c22b6823e87150213ecd8449260db8610 +Subproject commit 088e28d993d16489e213baa638949a6a13211406 diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py index d52d8dab0b..81308c08bf 100644 --- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py @@ -15,9 +15,6 @@ import torch import transformer_engine_torch as tex -from ...cpp_extensions import ( - general_grouped_gemm_for_grouped_tensor, -) from ...module.base import get_dummy_wgrad from ...quantization import Recipe from ...tensor.grouped_tensor import GroupedTensor @@ -141,8 +138,9 @@ def _compute_grad_params( bias_grads, bias_grad_packed, label="", - cudnn_wgrad_kernel_fn=None, - offsets=None, + *, + cudnn_wgrad_kernel_fn, + offsets, ): """Compute weight gradients and build grad_params for a GroupedLinear layer. Returns the grad_params list in parameter registration order. @@ -213,23 +211,15 @@ def _compute_grad_params( if ctx.weight_requires_grad: # Launch or defer the GEMM delay_wgrad = fc_op.wgrad_store is not None and fc_op.wgrad_store.delay_wgrad_compute() - - if cudnn_wgrad_kernel_fn is not None and offsets is not None: - gemm_fn = functools.partial( - _cudnn_compute_wgrad, - weight_shape=weight_shape, - offsets=offsets, - dtype=dtype, - accumulate=accumulate_into_main_grad, - wgrad_kernel_fn=cudnn_wgrad_kernel_fn, - single_grouped_weight=fc_op.single_grouped_weight, - ) - else: - gemm_fn = functools.partial( - general_grouped_gemm_for_grouped_tensor, - layout="NT", - accumulate=accumulate_into_main_grad, - ) + gemm_fn = functools.partial( + _cudnn_compute_wgrad, + weight_shape=weight_shape, + offsets=offsets, + dtype=dtype, + accumulate=accumulate_into_main_grad, + wgrad_kernel_fn=cudnn_wgrad_kernel_fn, + single_grouped_weight=fc_op.single_grouped_weight, + ) if delay_wgrad: fc_op.wgrad_store.put([grouped_x, grouped_dy, wgrad_output], gemm_fn) @@ -322,18 +312,6 @@ def is_supported(cls) -> bool: return False return True - @classmethod - @functools.lru_cache(maxsize=None) - def is_wgrad_supported(cls) -> bool: - """Whether the CuTe DSL wgrad kernel is available.""" - if not cls.is_supported(): - return False - try: - cls.grouped_gemm_wgrad_kernel() - except ImportError: - return False - return True - @classmethod def is_fc1_bias_supported(cls) -> bool: """Whether cudnn-frontend exposes ``generate_dbias`` on the dGLU SM100 wrapper (FC1 bias grad only).""" @@ -658,7 +636,6 @@ def fuser_backward( ) # FC2 wgrad GEMM - _use_cudnn_wgrad = self.is_wgrad_supported() fc2_grad_params = _compute_grad_params( fc_op=fc2_op, ctx=fc2_ctx, @@ -671,7 +648,7 @@ def fuser_backward( bias_grads=fc2_bias_grads, bias_grad_packed=fc2_bias_grad_packed, label="FC2", - cudnn_wgrad_kernel_fn=self.grouped_gemm_wgrad_kernel() if _use_cudnn_wgrad else None, + cudnn_wgrad_kernel_fn=self.grouped_gemm_wgrad_kernel(), offsets=split_points, ) @@ -768,7 +745,7 @@ def fuser_backward( bias_grads=fc1_bias_grads, 
bias_grad_packed=fc1_bias_grad_packed, label="FC1", - cudnn_wgrad_kernel_fn=self.grouped_gemm_wgrad_kernel() if _use_cudnn_wgrad else None, + cudnn_wgrad_kernel_fn=self.grouped_gemm_wgrad_kernel(), offsets=split_points, ) From 5d1a077cbe219886f04efa27b9091a02d265b068 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 13 Apr 2026 04:08:06 +0000 Subject: [PATCH 3/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/ops/fused/backward_grouped_mlp.py | 26 +++++++++---------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py index 81308c08bf..93f5e2e3d2 100644 --- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py @@ -60,27 +60,24 @@ def _cudnn_compute_wgrad( fp8_dtype = torch.float8_e4m3fn # a_tensor = DY^T = (out_features, total_tokens) row-major - a_tensor = ( - grouped_dy.columnwise_data - .view(dtype=fp8_dtype) - .view(total_tokens, out_features) - .T - ) + a_tensor = grouped_dy.columnwise_data.view(dtype=fp8_dtype).view(total_tokens, out_features).T # b_tensor = X = (total_tokens, in_features) column-major - b_tensor = ( - grouped_x.columnwise_data - .view(dtype=fp8_dtype) - .view(total_tokens, in_features) - ) + b_tensor = grouped_x.columnwise_data.view(dtype=fp8_dtype).view(total_tokens, in_features) - sfa_tensor = grouped_dy.columnwise_scale_inv.view(out_features, -1).view(dtype=torch.float8_e8m0fnu) - sfb_tensor = grouped_x.columnwise_scale_inv.view(in_features, -1).view(dtype=torch.float8_e8m0fnu) + sfa_tensor = grouped_dy.columnwise_scale_inv.view(out_features, -1).view( + dtype=torch.float8_e8m0fnu + ) + sfb_tensor = grouped_x.columnwise_scale_inv.view(in_features, -1).view( + dtype=torch.float8_e8m0fnu + ) offsets_tensor = offsets.to(dtype=torch.int32) # Prepare wgrad output if single_grouped_weight: # Dense mode: single (num_groups, out_features, in_features) tensor - wgrad_tensor = wgrad_output.rowwise_data.view(offsets_tensor.shape[0], out_features, in_features) + wgrad_tensor = wgrad_output.rowwise_data.view( + offsets_tensor.shape[0], out_features, in_features + ) wgrad_kernel_fn( a_tensor=a_tensor, b_tensor=b_tensor, @@ -294,6 +291,7 @@ def grouped_gemm_quant_kernel(cls) -> Callable: def grouped_gemm_wgrad_kernel(cls) -> Callable: """CuTe DSL kernel for grouped GEMM wgrad on SM100+.""" from cudnn import grouped_gemm_wgrad_wrapper_sm100 # pylint: disable=no-name-in-module + return grouped_gemm_wgrad_wrapper_sm100 @classmethod From 95988478bd1834da4d9c6ea14b675024d60006be Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Mon, 13 Apr 2026 16:22:30 +0000 Subject: [PATCH 4/9] revert the change for cudnn Signed-off-by: Varun Thumbe --- 3rdparty/cudnn-frontend | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 088e28d993..7b9b711c22 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 088e28d993d16489e213baa638949a6a13211406 +Subproject commit 7b9b711c22b6823e87150213ecd8449260db8610 From 6af122765b20c9dada9ddd18f20b2906c71843f4 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Mon, 13 Apr 2026 16:27:04 +0000 Subject: [PATCH 5/9] remove dtype Signed-off-by: Varun Thumbe --- 
transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py index 93f5e2e3d2..ebccfdf4d1 100644 --- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py @@ -38,7 +38,6 @@ def _cudnn_compute_wgrad( wgrad_output, weight_shape: tuple, offsets: torch.Tensor, - dtype: torch.dtype, accumulate: bool, wgrad_kernel_fn, single_grouped_weight: bool, @@ -212,7 +211,6 @@ def _compute_grad_params( _cudnn_compute_wgrad, weight_shape=weight_shape, offsets=offsets, - dtype=dtype, accumulate=accumulate_into_main_grad, wgrad_kernel_fn=cudnn_wgrad_kernel_fn, single_grouped_weight=fc_op.single_grouped_weight, From 99f9853d2763a107f36c8ade1d57a50b0aa3f4e0 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Mon, 13 Apr 2026 16:31:39 +0000 Subject: [PATCH 6/9] fix comment: Signed-off-by: Varun Thumbe --- transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py index ebccfdf4d1..22855b2cc8 100644 --- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py @@ -48,10 +48,6 @@ def _cudnn_compute_wgrad( wgrad[e] = a[:, tok_start:tok_end] @ b[tok_start:tok_end, :] where a = DY^T = (out_features, total_tokens) row-major and b = X = (total_tokens, in_features) column-major. - - TE's columnwise_data is (total_tokens, features) row-major with - columnwise block scaling. We transpose data and reorder scale - super-blocks from TE's global layout to cuDNN's per-group layout. 
""" out_features, in_features = weight_shape total_tokens = grouped_dy.logical_shape[0] From b829352d502332a6aa75306a3461917552f67fb7 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Mon, 13 Apr 2026 20:06:13 +0000 Subject: [PATCH 7/9] go to cublas if needed Signed-off-by: Varun Thumbe --- transformer_engine/pytorch/ops/_common.py | 9 +++++ .../pytorch/ops/fused/backward_grouped_mlp.py | 39 ++++++++++++------- 2 files changed, 35 insertions(+), 13 deletions(-) diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index ae8b48a90d..f01a3481d1 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -29,6 +29,15 @@ def _nvidia_cudnn_frontend_supports_scaled_clamped_qgeglu() -> bool: return False +@functools.lru_cache(maxsize=1) +def _nvidia_cudnn_frontend_supports_wgrad() -> bool: + """Check cuDNN FE min version for grouped GEMM wgrad kernel.""" + try: + return PkgVersion(get_pkg_version("nvidia-cudnn-frontend")) >= PkgVersion("1.23.0") + except PackageNotFoundError: + return False + + def is_quantized_tensor(tensor: torch.Tensor | QuantizedTensorStorage) -> bool: """Check if tensor is a quantized tensor""" return isinstance(tensor, QuantizedTensorStorage) diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py index 22855b2cc8..3134405133 100644 --- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py @@ -25,10 +25,13 @@ from ..fuser import register_backward_fusion from ..op import FusedOperation, FusibleOperation, OperationContext from .._common import ( + _nvidia_cudnn_frontend_supports_wgrad, fuse_grouped_mlp_ops, maybe_dequantize, validate_grouped_mlp_dims, ) +from ...cpp_extensions import general_grouped_gemm_for_grouped_tensor +from ...module.base import _2X_ACC_WGRAD from ...triton.grouped_dbias_dscales import _compute_grouped_dbias_dscales @@ -203,14 +206,22 @@ def _compute_grad_params( if ctx.weight_requires_grad: # Launch or defer the GEMM delay_wgrad = fc_op.wgrad_store is not None and fc_op.wgrad_store.delay_wgrad_compute() - gemm_fn = functools.partial( - _cudnn_compute_wgrad, - weight_shape=weight_shape, - offsets=offsets, - accumulate=accumulate_into_main_grad, - wgrad_kernel_fn=cudnn_wgrad_kernel_fn, - single_grouped_weight=fc_op.single_grouped_weight, - ) + if cudnn_wgrad_kernel_fn is not None: + gemm_fn = functools.partial( + _cudnn_compute_wgrad, + weight_shape=weight_shape, + offsets=offsets, + accumulate=accumulate_into_main_grad, + wgrad_kernel_fn=cudnn_wgrad_kernel_fn, + single_grouped_weight=fc_op.single_grouped_weight, + ) + else: + gemm_fn = functools.partial( + general_grouped_gemm_for_grouped_tensor, + layout="NT", + accumulate=accumulate_into_main_grad, + use_split_accumulator=_2X_ACC_WGRAD, + ) if delay_wgrad: fc_op.wgrad_store.put([grouped_x, grouped_dy, wgrad_output], gemm_fn) @@ -277,15 +288,18 @@ def grouped_gemm_dglu_kernel(cls) -> Callable: def grouped_gemm_quant_kernel(cls) -> Callable: """Grouped GEMM quant kernel for block-scaled inputs.""" from cudnn import grouped_gemm_quant_wrapper_sm100 # pylint: disable=no-name-in-module - return grouped_gemm_quant_wrapper_sm100 @classmethod @functools.lru_cache(maxsize=None) - def grouped_gemm_wgrad_kernel(cls) -> Callable: - """CuTe DSL kernel for grouped GEMM wgrad on SM100+.""" + def grouped_gemm_wgrad_kernel(cls) -> Optional[Callable]: + """CuTe DSL kernel for 
grouped GEMM wgrad on SM100+. + Returns ``None`` when the cuDNN front-end package is older than + 1.23.0. + """ + if not _nvidia_cudnn_frontend_supports_wgrad(): + return None from cudnn import grouped_gemm_wgrad_wrapper_sm100 # pylint: disable=no-name-in-module - return grouped_gemm_wgrad_wrapper_sm100 @classmethod @@ -299,7 +313,6 @@ def is_supported(cls) -> bool: try: cls.grouped_gemm_dglu_kernel() cls.grouped_gemm_quant_kernel() - cls.grouped_gemm_wgrad_kernel() except ImportError: return False return True From a8c285db5b3983a0e2c332f1c04727cd5f2e22bf Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Mon, 13 Apr 2026 22:32:14 +0000 Subject: [PATCH 8/9] changes to unblock testing Signed-off-by: Varun Thumbe --- transformer_engine/pytorch/ops/_common.py | 6 +----- .../pytorch/ops/fused/backward_grouped_mlp.py | 10 ++++++---- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index f01a3481d1..99a622f357 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -6,6 +6,7 @@ from __future__ import annotations import functools +import inspect from importlib.metadata import PackageNotFoundError, version as get_pkg_version from typing import Optional @@ -38,11 +39,6 @@ def _nvidia_cudnn_frontend_supports_wgrad() -> bool: return False -def is_quantized_tensor(tensor: torch.Tensor | QuantizedTensorStorage) -> bool: - """Check if tensor is a quantized tensor""" - return isinstance(tensor, QuantizedTensorStorage) - - def maybe_dequantize( tensor: torch.Tensor | QuantizedTensorStorage, dtype: torch.dtype | None = None ) -> torch.Tensor: diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py index 3134405133..8c14b218ee 100644 --- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py @@ -44,6 +44,7 @@ def _cudnn_compute_wgrad( accumulate: bool, wgrad_kernel_fn, single_grouped_weight: bool, + current_stream=None, ): """Compute wgrad using the cuDNN CuTe DSL grouped GEMM wgrad kernel. @@ -88,6 +89,7 @@ def _cudnn_compute_wgrad( wgrad_dtype=wgrad_tensor.dtype, sf_vec_size=MXFP8_BLOCK_SCALING_SIZE, accumulate_on_output=accumulate, + current_stream=current_stream, ) else: # Discrete mode: per-expert wgrad device pointers @@ -104,6 +106,7 @@ def _cudnn_compute_wgrad( wgrad_dtype=wgrad_output[0].dtype, sf_vec_size=MXFP8_BLOCK_SCALING_SIZE, accumulate_on_output=accumulate, + current_stream=current_stream, ) @@ -214,6 +217,7 @@ def _compute_grad_params( accumulate=accumulate_into_main_grad, wgrad_kernel_fn=cudnn_wgrad_kernel_fn, single_grouped_weight=fc_op.single_grouped_weight, + current_stream=torch.cuda.current_stream().cuda_stream, ) else: gemm_fn = functools.partial( @@ -294,11 +298,9 @@ def grouped_gemm_quant_kernel(cls) -> Callable: @functools.lru_cache(maxsize=None) def grouped_gemm_wgrad_kernel(cls) -> Optional[Callable]: """CuTe DSL kernel for grouped GEMM wgrad on SM100+. - Returns ``None`` when the cuDNN front-end package is older than - 1.23.0. + Returns ``None`` when the cuDNN front-end wgrad API is not + available or lacks the required wgrad_tensor/wgrad_ptrs params. 
""" - if not _nvidia_cudnn_frontend_supports_wgrad(): - return None from cudnn import grouped_gemm_wgrad_wrapper_sm100 # pylint: disable=no-name-in-module return grouped_gemm_wgrad_wrapper_sm100 From 699aad15c5d4f7111e92efd7b16a2eb118bcf467 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 13 Apr 2026 22:34:28 +0000 Subject: [PATCH 9/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py index bd1eb96699..4cbc5cb620 100644 --- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py @@ -292,6 +292,7 @@ def grouped_gemm_dglu_kernel(cls) -> Callable: def grouped_gemm_quant_kernel(cls) -> Callable: """Grouped GEMM quant kernel for block-scaled inputs.""" from cudnn import grouped_gemm_quant_wrapper_sm100 # pylint: disable=no-name-in-module + return grouped_gemm_quant_wrapper_sm100 @classmethod @@ -302,6 +303,7 @@ def grouped_gemm_wgrad_kernel(cls) -> Optional[Callable]: available or lacks the required wgrad_tensor/wgrad_ptrs params. """ from cudnn import grouped_gemm_wgrad_wrapper_sm100 # pylint: disable=no-name-in-module + return grouped_gemm_wgrad_wrapper_sm100 @classmethod