Conversation
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
…/TransformerEngine into users/vthumbe/wgrad_cute_dsl
Greptile Summary

This PR wires up an experimental CuTe DSL grouped GEMM wgrad kernel from …

Confidence Score: 4/5 (not safe to merge until the P1 defect below is fixed)

One P1 defect: on any system where the cuDNN wgrad wrapper cannot be imported, `ImportError` propagates out of `transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py` (specifically `grouped_gemm_wgrad_kernel()`) instead of falling back to cuBLAS.

Important Files Changed
```python
def grouped_gemm_wgrad_kernel(cls) -> Optional[Callable]:
    """CuTe DSL kernel for grouped GEMM wgrad on SM100+.

    Returns ``None`` when the cuDNN front-end package is older than
    1.23.0.
    Returns ``None`` when the cuDNN front-end wgrad API is not
    available or lacks the required wgrad_tensor/wgrad_ptrs params.
    """
    if not _nvidia_cudnn_frontend_supports_wgrad():
        return None
    from cudnn import grouped_gemm_wgrad_wrapper_sm100  # pylint: disable=no-name-in-module

    return grouped_gemm_wgrad_wrapper_sm100
```
**`grouped_gemm_wgrad_kernel()` never returns `None` as documented**

The docstring promises `None` when the wgrad API is unavailable, but a bare `from cudnn import grouped_gemm_wgrad_wrapper_sm100` raises `ImportError` instead of returning `None`. `_compute_grad_params` (called at lines 660 and 757) branches on `if cudnn_wgrad_kernel_fn is not None` specifically to fall back to cuBLAS; that fallback is unreachable because the exception propagates and crashes the backward pass on any setup where `grouped_gemm_dglu_kernel` is present but `grouped_gemm_wgrad_wrapper_sm100` is not. Note that `_nvidia_cudnn_frontend_supports_wgrad` is imported but never used; it was clearly intended as the version gate here.
Suggested change:

Current:

```python
def grouped_gemm_wgrad_kernel(cls) -> Optional[Callable]:
    """CuTe DSL kernel for grouped GEMM wgrad on SM100+.

    Returns ``None`` when the cuDNN front-end package is older than
    1.23.0.
    Returns ``None`` when the cuDNN front-end wgrad API is not
    available or lacks the required wgrad_tensor/wgrad_ptrs params.
    """
    if not _nvidia_cudnn_frontend_supports_wgrad():
        return None
    from cudnn import grouped_gemm_wgrad_wrapper_sm100  # pylint: disable=no-name-in-module
    return grouped_gemm_wgrad_wrapper_sm100
```

Suggested:

```python
@classmethod
@functools.lru_cache(maxsize=None)
def grouped_gemm_wgrad_kernel(cls) -> Optional[Callable]:
    """CuTe DSL kernel for grouped GEMM wgrad on SM100+.

    Returns ``None`` when the cuDNN front-end wgrad API is not
    available or lacks the required wgrad_tensor/wgrad_ptrs params.
    """
    if not _nvidia_cudnn_frontend_supports_wgrad():
        return None
    try:
        from cudnn import grouped_gemm_wgrad_wrapper_sm100  # pylint: disable=no-name-in-module
    except ImportError:
        return None
    return grouped_gemm_wgrad_wrapper_sm100
```
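The caller-side fallback that this suggestion restores can be sketched as follows. This is a minimal illustration, not the PR's code: `get_wgrad_kernel`, `compute_grad_params`, and the module name `cudnn_missing_module` are hypothetical stand-ins for `grouped_gemm_wgrad_kernel`, `_compute_grad_params`, and the optional cuDNN import.

```python
from typing import Callable, Optional

def get_wgrad_kernel() -> Optional[Callable]:
    """Return the optional kernel, or None when the import fails.

    Mirrors the suggested try/except: without it, a missing symbol
    raises ImportError and the caller's None-check never runs.
    """
    try:
        # Stand-in for `from cudnn import grouped_gemm_wgrad_wrapper_sm100`;
        # `cudnn_missing_module` is a deliberately absent hypothetical module.
        from cudnn_missing_module import grouped_gemm_wgrad_wrapper_sm100
    except ImportError:
        return None
    return grouped_gemm_wgrad_wrapper_sm100

def compute_grad_params() -> str:
    """Sketch of the branch in _compute_grad_params: use the cuDNN
    kernel when available, otherwise fall back to cuBLAS."""
    kernel = get_wgrad_kernel()
    if kernel is not None:
        return "cudnn"
    # Reachable only because get_wgrad_kernel swallows ImportError.
    return "cublas"

print(compute_grad_params())  # prints "cublas" when the stand-in module is absent
```

The design point is that optional-dependency probes should convert `ImportError` into `None` at the probe boundary, so every caller can use a uniform `is not None` check instead of wrapping each call site in its own try/except.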
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: