Skip to content
1 change: 1 addition & 0 deletions transformer_engine/pytorch/ops/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from __future__ import annotations
import functools
import inspect
from importlib.metadata import PackageNotFoundError, version as get_pkg_version
from typing import Optional

Expand Down
10 changes: 6 additions & 4 deletions transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def _cudnn_compute_wgrad(
accumulate: bool,
wgrad_kernel_fn,
single_grouped_weight: bool,
current_stream=None,
):
"""Compute wgrad using the cuDNN CuTe DSL grouped GEMM wgrad kernel.

Expand Down Expand Up @@ -88,6 +89,7 @@ def _cudnn_compute_wgrad(
wgrad_dtype=wgrad_tensor.dtype,
sf_vec_size=MXFP8_BLOCK_SCALING_SIZE,
accumulate_on_output=accumulate,
current_stream=current_stream,
)
else:
# Discrete mode: per-expert wgrad device pointers
Expand All @@ -104,6 +106,7 @@ def _cudnn_compute_wgrad(
wgrad_dtype=wgrad_output[0].dtype,
sf_vec_size=MXFP8_BLOCK_SCALING_SIZE,
accumulate_on_output=accumulate,
current_stream=current_stream,
)


Expand Down Expand Up @@ -214,6 +217,7 @@ def _compute_grad_params(
accumulate=accumulate_into_main_grad,
wgrad_kernel_fn=cudnn_wgrad_kernel_fn,
single_grouped_weight=fc_op.single_grouped_weight,
current_stream=torch.cuda.current_stream().cuda_stream,
)
else:
gemm_fn = functools.partial(
Expand Down Expand Up @@ -295,11 +299,9 @@ def grouped_gemm_quant_kernel(cls) -> Callable:
@functools.lru_cache(maxsize=None)
def grouped_gemm_wgrad_kernel(cls) -> Optional[Callable]:
"""CuTe DSL kernel for grouped GEMM wgrad on SM100+.
Returns ``None`` when the cuDNN front-end package is older than
1.23.0.
Returns ``None`` when the cuDNN front-end wgrad API is not
available or lacks the required wgrad_tensor/wgrad_ptrs params.
"""
if not _nvidia_cudnn_frontend_supports_wgrad():
return None
from cudnn import grouped_gemm_wgrad_wrapper_sm100 # pylint: disable=no-name-in-module

return grouped_gemm_wgrad_wrapper_sm100
Comment on lines 300 to 307
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 grouped_gemm_wgrad_kernel() never returns None as documented

The docstring promises None when the wgrad API is unavailable, but a bare from cudnn import grouped_gemm_wgrad_wrapper_sm100 raises ImportError instead of returning None. _compute_grad_params (called at lines 660 and 757) branches on if cudnn_wgrad_kernel_fn is not None specifically to fall back to cublas — that fallback is unreachable because the exception propagates and crashes the backward pass on any setup where grouped_gemm_dglu_kernel is present but grouped_gemm_wgrad_wrapper_sm100 is not. Note that _nvidia_cudnn_frontend_supports_wgrad is imported but never used; it was clearly intended as the version gate here.

Suggested change
def grouped_gemm_wgrad_kernel(cls) -> Optional[Callable]:
"""CuTe DSL kernel for grouped GEMM wgrad on SM100+.
Returns ``None`` when the cuDNN front-end package is older than
1.23.0.
Returns ``None`` when the cuDNN front-end wgrad API is not
available or lacks the required wgrad_tensor/wgrad_ptrs params.
"""
if not _nvidia_cudnn_frontend_supports_wgrad():
return None
from cudnn import grouped_gemm_wgrad_wrapper_sm100 # pylint: disable=no-name-in-module
return grouped_gemm_wgrad_wrapper_sm100
@classmethod
@functools.lru_cache(maxsize=None)
def grouped_gemm_wgrad_kernel(cls) -> Optional[Callable]:
    """CuTe DSL kernel for grouped GEMM wgrad on SM100+.
    Returns ``None`` when the cuDNN front-end wgrad API is not
    available or lacks the required wgrad_tensor/wgrad_ptrs params.
    """
    # Both gates must pass: the version/signature probe, and the symbol
    # actually being importable from the installed cudnn package.
    if _nvidia_cudnn_frontend_supports_wgrad():
        try:
            from cudnn import grouped_gemm_wgrad_wrapper_sm100  # pylint: disable=no-name-in-module
        except ImportError:
            # Symbol missing from this cudnn build — caller falls back to cublas.
            return None
        return grouped_gemm_wgrad_wrapper_sm100
    return None

Expand Down