From 30634179650fdd52e4d00b09c5b9758468473124 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 14 Apr 2026 16:17:46 -0700 Subject: [PATCH 1/3] Fix grouped quant checkpointing Signed-off-by: Jeremy Berchtold --- .../jax/cpp_extensions/quantization.py | 6 +++++- transformer_engine/jax/dense.py | 12 ++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 7138cfcf40..3d56d57ada 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -1182,7 +1182,11 @@ def lowering( # V2: CUDA-graph safe; scale is passed but ignored by the C++ handler. # Requires total_first_dim % 128 == 0 (checked above) and all individual # group sizes % 128 == 0 (dynamic constraint, enforced by the kernel). - return ffi.ffi_lowering(GroupedQuantizePrimitive.name_v2)( + # has_side_effect=False: V2 has no observable side effects beyond its output + # buffers (no D2H copy, no global-state mutation). This lets XLA DCE the + # call in the backward-scan remat block when both rowwise and colwise outputs + # have been replaced by JAX checkpointed residuals. + return ffi.ffi_lowering(GroupedQuantizePrimitive.name_v2, has_side_effect=False)( ctx, x, scale, diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index dbd7bbb1ff..4de13ebf6e 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -429,10 +429,22 @@ def _grouped_dense_fwd_rule( # rowwise_casted_x.original_shape == (M, K) # colwise_casted_kernel.original_shape == (G, N, K) grouped_gemm_x = casted_x.get_tensor(usage=TensorUsage.LHS) + # Checkpoint the rowwise inputs so that te_grouped_quantize_ffi can be DCE'd in the + # backward-scan remat block. Without this, JAX would re-run the quantize kernel to + # obtain grouped_gemm_x / grouped_gemm_kernel for the forward-GEMM recomputation even + # though the colwise residuals (ctx_x / ctx_kernel) are already saved. With both + # orientations checkpointed, all outputs of the custom-call become dead in the remat + # trace and XLA can eliminate it (provided has_side_effect=False on the lowering). + grouped_gemm_x = grouped_gemm_x.checkpoint(quantizer_set.x) if isinstance( + grouped_gemm_x, ScaledTensor + ) else grouped_gemm_x ctx_x = casted_x.get_tensor(usage=TensorUsage.LHS_TRANS) ctx_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS_TRANS) grouped_gemm_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS) + grouped_gemm_kernel = grouped_gemm_kernel.checkpoint(quantizer_set.kernel) if isinstance( + grouped_gemm_kernel, ScaledTensor + ) else grouped_gemm_kernel output = tex.grouped_gemm( grouped_gemm_x, grouped_gemm_kernel, From 1d2328a891c7104bef95851bd347d8b4cc273ee3 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Thu, 16 Apr 2026 09:13:18 -0700 Subject: [PATCH 2/3] Cleanup Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/cpp_extensions/quantization.py | 6 +----- transformer_engine/jax/dense.py | 3 +-- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 3d56d57ada..7138cfcf40 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -1182,11 +1182,7 @@ def lowering( # V2: CUDA-graph safe; scale is passed but ignored by the C++ handler. # Requires total_first_dim % 128 == 0 (checked above) and all individual # group sizes % 128 == 0 (dynamic constraint, enforced by the kernel). - # has_side_effect=False: V2 has no observable side effects beyond its output - # buffers (no D2H copy, no global-state mutation). This lets XLA DCE the - # call in the backward-scan remat block when both rowwise and colwise outputs - # have been replaced by JAX checkpointed residuals. - return ffi.ffi_lowering(GroupedQuantizePrimitive.name_v2, has_side_effect=False)( + return ffi.ffi_lowering(GroupedQuantizePrimitive.name_v2)( ctx, x, scale, diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 4de13ebf6e..86506381d0 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -433,8 +433,7 @@ def _grouped_dense_fwd_rule( # backward-scan remat block. Without this, JAX would re-run the quantize kernel to # obtain grouped_gemm_x / grouped_gemm_kernel for the forward-GEMM recomputation even # though the colwise residuals (ctx_x / ctx_kernel) are already saved. With both - # orientations checkpointed, all outputs of the custom-call become dead in the remat - # trace and XLA can eliminate it (provided has_side_effect=False on the lowering). + # orientations checkpointed, all outputs of the custom-call become dead in the remat trace. grouped_gemm_x = grouped_gemm_x.checkpoint(quantizer_set.x) if isinstance( grouped_gemm_x, ScaledTensor ) else grouped_gemm_x From 0de178e4ddb38a6a87b6f3def0a38570a378ee68 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 16 Apr 2026 16:14:23 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/dense.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 86506381d0..f8c30ffccb 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -434,16 +434,20 @@ def _grouped_dense_fwd_rule( # obtain grouped_gemm_x / grouped_gemm_kernel for the forward-GEMM recomputation even # though the colwise residuals (ctx_x / ctx_kernel) are already saved. With both # orientations checkpointed, all outputs of the custom-call become dead in the remat trace. - grouped_gemm_x = grouped_gemm_x.checkpoint(quantizer_set.x) if isinstance( - grouped_gemm_x, ScaledTensor - ) else grouped_gemm_x + grouped_gemm_x = ( + grouped_gemm_x.checkpoint(quantizer_set.x) + if isinstance(grouped_gemm_x, ScaledTensor) + else grouped_gemm_x + ) ctx_x = casted_x.get_tensor(usage=TensorUsage.LHS_TRANS) ctx_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS_TRANS) grouped_gemm_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS) - grouped_gemm_kernel = grouped_gemm_kernel.checkpoint(quantizer_set.kernel) if isinstance( - grouped_gemm_kernel, ScaledTensor - ) else grouped_gemm_kernel + grouped_gemm_kernel = ( + grouped_gemm_kernel.checkpoint(quantizer_set.kernel) + if isinstance(grouped_gemm_kernel, ScaledTensor) + else grouped_gemm_kernel + ) output = tex.grouped_gemm( grouped_gemm_x, grouped_gemm_kernel,