[JAX] Fix grouped quant checkpointing#2889
jberchtold-nvidia wants to merge 4 commits into NVIDIA:main
Conversation
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
/te-ci
Greptile Summary

This PR fixes a checkpointing gap in grouped quantization.

Confidence Score: 5/5 — Safe to merge. Targeted two-line fix with no new logic paths or API surface changes. The change exactly mirrors the existing colwise checkpoint pattern, is guarded correctly for non-ScaledTensor paths, and the explanatory comment clearly documents the remat DCE intent. No regressions introduced; all remaining prior review concerns (shared checkpoint_name) were addressed in the previous thread. No files require special attention.
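The guarded tagging the summary describes (checkpoint-name quantized outputs, leave plain arrays untouched) can be sketched as follows. `checkpoint_name` is JAX's real tagging API; the `ScaledTensor` class and `tag_for_checkpoint` helper here are hypothetical stand-ins for TransformerEngine's internals, shown only to illustrate the guard.

```python
from dataclasses import dataclass
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

# Hypothetical stand-in for TransformerEngine's ScaledTensor wrapper;
# the real class carries quantized data plus scale metadata.
@dataclass
class ScaledTensor:
    data: jnp.ndarray
    scale: jnp.ndarray

def tag_for_checkpoint(t, name):
    # Guarded tagging: only quantized (ScaledTensor) outputs get a
    # checkpoint name; plain arrays on the non-quantized path pass
    # through unchanged.
    if isinstance(t, ScaledTensor):
        return ScaledTensor(checkpoint_name(t.data, name), t.scale)
    return t

x = ScaledTensor(jnp.ones((2, 2)), jnp.float32(0.5))
y = tag_for_checkpoint(x, "grouped_gemm_x")   # tagged
z = tag_for_checkpoint(jnp.zeros(3), "plain")  # returned as-is
```

The guard is what keeps the non-ScaledTensor code path safe: tagging is a no-op for unquantized inputs, so no new logic branches are introduced.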
Sequence Diagram

```mermaid
sequenceDiagram
    participant Fwd as Forward Pass
    participant Remat as JAX Remat
    participant Bwd as Backward Pass
    Fwd->>Fwd: grouped_quantize(x) → casted_x
    Fwd->>Fwd: grouped_gemm_x = casted_x.get_tensor(LHS)
    Note over Fwd: NEW: checkpoint_name(grouped_gemm_x)
    Fwd->>Fwd: ctx_x = casted_x.get_tensor(LHS_TRANS)
    Fwd->>Fwd: grouped_quantize(kernel) → casted_kernel
    Fwd->>Fwd: grouped_gemm_kernel = casted_kernel.get_tensor(RHS)
    Note over Fwd: NEW: checkpoint_name(grouped_gemm_kernel)
    Fwd->>Fwd: ctx_kernel = casted_kernel.get_tensor(RHS_TRANS)
    Fwd->>Fwd: output = grouped_gemm(grouped_gemm_x, grouped_gemm_kernel)
    Fwd-->>Remat: save ctx (group_sizes, ckpt(ctx_x), ckpt(ctx_kernel), …)
    Note over Remat: All 4 quantize outputs now checkpointed → DCE quantize kernel
    Remat-->>Bwd: ctx residuals (no re-quantize needed)
    Bwd->>Bwd: dgrad = grouped_gemm(dgrad_grad, ctx_kernel)
    Bwd->>Bwd: wgrad = grouped_gemm(ctx_x, wgrad_grad)
```
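The mechanism in the diagram — tag every quantize output with `checkpoint_name` so remat's dead-code elimination can drop the quantize kernel from the backward recompute — can be sketched in plain JAX. `checkpoint_name` and `save_only_these_names` are real JAX APIs; the quantize function, checkpoint names, and shapes below are hypothetical stand-ins, not TE's implementation.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def fake_quantize(x):
    # Toy stand-in for TE's grouped_quantize: returns int-like data
    # plus a scale, i.e. an op we'd rather save than recompute.
    scale = jnp.max(jnp.abs(x)) / 127.0
    return jnp.round(x / scale), scale

def forward(x, w):
    qx, sx = fake_quantize(x)
    # Tag each quantize output; with a matching save policy, remat
    # keeps these as residuals and DCEs the quantize from the bwd pass.
    qx = checkpoint_name(qx, "ffn1_x")
    qw, sw = fake_quantize(w)
    qw = checkpoint_name(qw, "ffn1_kernel")
    return jnp.sum((qx * sx) @ (qw * sw))

policy = jax.checkpoint_policies.save_only_these_names(
    "ffn1_x", "ffn1_kernel"
)
f = jax.checkpoint(forward, policy=policy)
grads = jax.grad(f)(jnp.ones((4, 8)), jnp.ones((8, 2)))
```

The PR's point is that all quantize outputs must be tagged: if even one output of `grouped_quantize` is left unnamed, remat must re-run the whole quantize kernel in the backward pass to materialize it.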
Reviews (2): Last reviewed commit: "Merge branch 'main' into jberchtold/grou..."
/te-ci
Force-pushed da44cd2 to 1d2328a
/te-ci
Description
Fixes an issue with grouped quant checkpointing where not all quantized values were checkpointed properly, so the quantize kernel could not be dead-code-eliminated from the backward (remat) pass.
Type of change
Changes
Checklist: