
[JAX] Fix grouped quant checkpointing #2889

Open

jberchtold-nvidia wants to merge 4 commits into NVIDIA:main from jberchtold-nvidia:jberchtold/grouped-quant-checkpointing

Conversation

@jberchtold-nvidia (Collaborator) commented Apr 16, 2026:

Description

Fixes issues with grouped quant checkpointing where not all values were checkpointed properly

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Ensure all saved values are checkpointed properly in grouped_dense VJP
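The change can be sketched as follows. This is a toy stand-in, not the actual Transformer Engine code: quantization is faked with a 0.5 scale, and the tensor names are borrowed from the PR's forward rule for illustration only.

```python
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def grouped_dense_fwd_sketch(x, kernel):
    """Toy mirror of the fixed forward rule (hypothetical, not TE code)."""
    # Every usage of each quantized input is tagged with checkpoint_name,
    # so a remat policy can save all four tensors and dead-code-eliminate
    # the quantize step in the backward trace.
    grouped_gemm_x = checkpoint_name(x * 0.5, "quantized")            # newly tagged in the PR
    ctx_x = checkpoint_name(x.T * 0.5, "quantized")                   # already tagged before
    grouped_gemm_kernel = checkpoint_name(kernel * 0.5, "quantized")  # newly tagged in the PR
    ctx_kernel = checkpoint_name(kernel.T * 0.5, "quantized")         # already tagged before

    output = grouped_gemm_x @ grouped_gemm_kernel
    residuals = (ctx_x, ctx_kernel)  # saved for the backward pass
    return output, residuals

out, res = grouped_dense_fwd_sketch(jnp.ones((3, 3)), jnp.ones((3, 3)))
```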

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia (Collaborator, Author) commented:

/te-ci

greptile-apps bot (Contributor) commented Apr 16, 2026:

Greptile Summary

This PR fixes a checkpointing gap in _grouped_dense_fwd_rule: the colwise residuals (ctx_x, ctx_kernel) were already being saved via checkpoint_name, but the rowwise tensors (grouped_gemm_x, grouped_gemm_kernel) were not. Without those markers, JAX's remat trace could not DCE the te_grouped_quantize_ffi custom call, so the quantize kernel would re-execute unnecessarily during backward rematerialisation. The fix mirrors the existing pattern used for ctx_x/ctx_kernel and adds symmetrical ScaledTensor guards for both inputs.
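For readers unfamiliar with the mechanism, here is a minimal runnable sketch of how checkpoint_name interacts with a remat policy. The math is a toy (quantization is faked with a 0.5 scale); only the checkpoint_name and save_only_these_names machinery reflects what the PR relies on.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def f(x, w):
    # Tag the "quantized" value so a remat policy can choose to save it
    # rather than recompute it during the backward pass.
    q = checkpoint_name(x * 0.5, "quant_x")  # stand-in for a quantize step
    return jnp.sum(q @ w)

# Save only the tagged value; everything else is rematerialised. The
# quantize computation is then dead code in the backward trace and can
# be eliminated instead of re-executed.
policy = jax.checkpoint_policies.save_only_these_names("quant_x")
f_remat = jax.checkpoint(f, policy=policy)

g = jax.grad(f_remat)(jnp.ones((4, 4)), jnp.ones((4, 4)))
```

The gradient is computed from the saved tagged value, so the `x * 0.5` "quantize" is not re-run on the backward pass.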

Confidence Score: 5/5

Safe to merge — targeted two-line fix with no new logic paths or API surface changes.

The change exactly mirrors the existing colwise checkpoint pattern, is guarded correctly for non-ScaledTensor paths, and the explanatory comment clearly documents the remat DCE intent. No regressions introduced; all remaining prior review concerns (shared checkpoint_name) were addressed in the previous thread.

No files require special attention.

Important Files Changed

Filename: transformer_engine/jax/dense.py
Overview: Adds checkpoint_name markers for rowwise tensors grouped_gemm_x and grouped_gemm_kernel to allow DCE of the quantize kernel during backward-pass rematerialisation; consistent with the existing colwise checkpoint pattern.

Sequence Diagram

sequenceDiagram
    participant Fwd as Forward Pass
    participant Remat as JAX Remat
    participant Bwd as Backward Pass

    Fwd->>Fwd: grouped_quantize(x) → casted_x
    Fwd->>Fwd: grouped_gemm_x = casted_x.get_tensor(LHS)
    Note over Fwd: NEW: checkpoint_name(grouped_gemm_x)
    Fwd->>Fwd: ctx_x = casted_x.get_tensor(LHS_TRANS)
    Fwd->>Fwd: grouped_quantize(kernel) → casted_kernel
    Fwd->>Fwd: grouped_gemm_kernel = casted_kernel.get_tensor(RHS)
    Note over Fwd: NEW: checkpoint_name(grouped_gemm_kernel)
    Fwd->>Fwd: ctx_kernel = casted_kernel.get_tensor(RHS_TRANS)
    Fwd->>Fwd: output = grouped_gemm(grouped_gemm_x, grouped_gemm_kernel)
    Fwd-->>Remat: save ctx (group_sizes, ckpt(ctx_x), ckpt(ctx_kernel), …)
    Note over Remat: All 4 quantize outputs now checkpointed → DCE quantize kernel
    Remat-->>Bwd: ctx residuals (no re-quantize needed)
    Bwd->>Bwd: dgrad = grouped_gemm(dgrad_grad, ctx_kernel)
    Bwd->>Bwd: wgrad = grouped_gemm(ctx_x, wgrad_grad)
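The diagram's "no re-quantize needed" step can be verified with jax.ad_checkpoint.print_saved_residuals, which lists the tensors a remat policy actually saves for the backward pass. The sketch below is a toy (not the TE code), under the assumption that both quantize outputs share one checkpoint_name:

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def fwd(x, w):
    qx = checkpoint_name(x * 0.5, "quant")  # stand-in for grouped_quantize(x)
    qw = checkpoint_name(w * 0.5, "quant")  # stand-in for grouped_quantize(kernel)
    return jnp.sum(qx @ qw)

fwd_remat = jax.checkpoint(
    fwd, policy=jax.checkpoint_policies.save_only_these_names("quant")
)

# Prints the residuals saved for the backward pass; with the policy above,
# both tagged quantize outputs appear, so the inputs need not be re-quantized.
jax.ad_checkpoint.print_saved_residuals(fwd_remat, jnp.ones((2, 2)), jnp.ones((2, 2)))
```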

Reviews (2): Last reviewed commit: "Merge branch 'main' into jberchtold/grou..."

Comment thread: transformer_engine/jax/dense.py
@jberchtold-nvidia (Collaborator, Author) commented:

/te-ci

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia force-pushed the jberchtold/grouped-quant-checkpointing branch from da44cd2 to 1d2328a on April 16, 2026 at 16:13.
@jberchtold-nvidia (Collaborator, Author) commented:

/te-ci
