feat(pt): avoid grad clip sync #5519
Conversation
📝 Walkthrough
Replaces the stable-fallback clipping path with
Changes
Gradient clipping and non-finite handling
Sequence Diagram(s)
sequenceDiagram
participant Trainer
participant clip_grad_norm_
participant NonFiniteGradGuard
participant raise_nonfinite_gradient_norm
participant CheckpointWriter
Trainer->>clip_grad_norm_: call(parameters, max_norm, stable=zero_stage<2)
clip_grad_norm_->>Trainer: return total_norm (float64)
Trainer->>NonFiniteGradGuard: update(total_norm)
Note over Trainer,CheckpointWriter: training proceeds across steps
Trainer->>NonFiniteGradGuard: raise_if_nonfinite(named_parameters)
NonFiniteGradGuard->>raise_nonfinite_gradient_norm: inspect named_parameters
raise_nonfinite_gradient_norm-->>NonFiniteGradGuard: raise RuntimeError if non-finite
alt no error
Trainer->>CheckpointWriter: proceed to write checkpoint
else error
Trainer-->>CheckpointWriter: abort checkpoint write
end
🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
deepmd/pt/train/training.py (3)
1389-1422: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win
Initialize gradient diagnostics before the optimizer split.
`total_norm` and `pre_clip_named_norms` are created only in the Adam/AdamW/AdaMuon/HybridMuon branch, but the TensorBoard block later reads `total_norm` for every optimizer. With `LKF` + TensorBoard + `gradient_max_norm > 0`, this hits `UnboundLocalError` at Line 1799.
Minimal fix
 if SAMPLER_RECORD:
     print_str = f"Step {_step_id}: sample system{log_dict['sid']} frame{log_dict['fid']}\n"
     fout1.write(print_str)
     fout1.flush()
+total_norm: torch.Tensor | None = None
+pre_clip_named_norms: list[tuple[str, float]] = []
 if self.opt_type in ["Adam", "AdamW", "AdaMuon", "HybridMuon"]:
     cur_lr = self.scheduler.get_last_lr()[0]
     pref_lr = cur_lr
     model_pred, loss, more_loss = self.wrapper(
         **input_dict, cur_lr=pref_lr, label=label_dict, task_key=task_key
     )
     loss.backward()
-    # === Initialize gradient diagnostics variables ===
-    total_norm: torch.Tensor | None = None
-    pre_clip_named_norms: list[tuple[str, float]] = []
     if self.gradient_max_norm > 0.0:
         ...
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/pt/train/training.py` around lines 1389 - 1422, The variables total_norm and pre_clip_named_norms must be initialized before the optimizer-specific branch so they exist for all optimizers (e.g., LKF); move or add initialization of total_norm = None and pre_clip_named_norms = [] above the if self.opt_type ... block, and only perform the per-parameter collection and clip_grad_norm_ call inside the Adam/AdamW/AdaMuon/HybridMuon branch as before; ensure nonfinite_grad_guard.update(total_norm) is only called when total_norm has been set (or leave the update call after the branch but guarded by a non-None check) so later TensorBoard code can safely read total_norm and pre_clip_named_norms for any optimizer.
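The failure mode in this comment can be reproduced in isolation. The sketch below is a simplified stand-in, not the code in `training.py`: `step_buggy` and `step_fixed` are hypothetical names, and a plain float replaces the tensor returned by clipping.

```python
from typing import Optional


def step_buggy(opt_type: str) -> Optional[float]:
    """Binds total_norm only inside the Adam-like branch."""
    if opt_type in ("Adam", "AdamW", "AdaMuon", "HybridMuon"):
        total_norm = 1.0  # stand-in for the value returned by clipping
    # On the LKF path the name was never bound, so this read raises
    # UnboundLocalError.
    return total_norm


def step_fixed(opt_type: str) -> Optional[float]:
    """Binds total_norm before the optimizer split."""
    total_norm: Optional[float] = None
    if opt_type in ("Adam", "AdamW", "AdaMuon", "HybridMuon"):
        total_norm = 1.0
    return total_norm
```

With the initialization hoisted, later readers such as a logging block only need a `None` check instead of depending on which optimizer branch ran.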
1397-1422: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win
Track non-finite norms even when clipping is disabled.
`self.nonfinite_grad_guard.update(total_norm)` only runs inside `if self.gradient_max_norm > 0.0`. Since Line 721 defaults `gradient_max_norm` to `0.0`, the new checkpoint guard is a no-op for the default training configuration, so NaN/Inf gradients can still reach later checkpoint writes.
Suggested direction
-if self.gradient_max_norm > 0.0:
+total_norm = clip_grad_norm_(
+    self.wrapper.parameters(),
+    float("inf")
+    if self.gradient_max_norm <= 0.0
+    else self.gradient_max_norm,
+    stable=self.zero_stage < 2,
+)
+self.nonfinite_grad_guard.update(total_norm)
+if self.gradient_max_norm > 0.0:
     # Collect per-parameter gradient norms before clipping.
     # NOTE: Under FSDP2 with ZeRO stage >= 2, p.grad is a sharded DTensor,
     # so p.grad.norm() computes the shard-local L2 norm, not the full-parameter
     # norm. Skip per-param collection in this case to avoid misleading values.
     if (
         self.enable_tensorboard
         and self.zero_stage < 2
         and (
             display_step_id % self.tensorboard_freq == 0
             or display_step_id == 1
         )
     ):
         pre_clip_named_norms = [
             (name, p.grad.detach().norm().item())
             for name, p in self.wrapper.named_parameters()
             if p.grad is not None
         ]
-    total_norm = clip_grad_norm_(
-        self.wrapper.parameters(),
-        self.gradient_max_norm,
-        stable=self.zero_stage < 2,
-    )
-    self.nonfinite_grad_guard.update(total_norm)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/pt/train/training.py` around lines 1397 - 1422, The non-finite gradient guard (self.nonfinite_grad_guard.update) is only called when self.gradient_max_norm > 0.0, leaving the guard inactive when gradient clipping is disabled (gradient_max_norm == 0.0); move or duplicate the call so total_norm (or a computed norm placeholder) is passed to self.nonfinite_grad_guard.update regardless of clipping being enabled. Concretely, after computing per-parameter norms (pre_clip_named_norms) and/or computing total_norm via clip_grad_norm_ in the block using self.wrapper.parameters(), ensure that self.nonfinite_grad_guard.update(total_norm) is invoked even when self.gradient_max_norm <= 0.0 (compute total_norm with torch.norm over grads or call clip_grad_norm_ with max_norm=inf/0-check), preserving the existing behavior with zero_stage and referencing the same symbols: self.gradient_max_norm, self.nonfinite_grad_guard.update, clip_grad_norm_, and self.wrapper.named_parameters().
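The "clip with an infinite threshold" idea can be illustrated without PyTorch. This is a minimal sketch under stated assumptions: `clip_grad_norm` and `effective_max_norm` are hypothetical helpers operating on lists of floats, not the utilities in `deepmd/pt/train/utils.py`.

```python
import math
from typing import List


def clip_grad_norm(grads: List[List[float]], max_norm: float) -> float:
    """Scale grads in place so their global L2 norm is at most max_norm.

    Returns the norm measured before scaling. Hypothetical stand-in for a
    tensor-based utility.
    """
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    # With max_norm = inf the comparison is never true, so nothing is scaled.
    # A NaN or inf norm also skips scaling and is returned to the caller.
    if math.isfinite(total_norm) and total_norm > max_norm:
        coef = max_norm / total_norm
        for grad in grads:
            for i, g in enumerate(grad):
                grad[i] = g * coef
    return total_norm


def effective_max_norm(gradient_max_norm: float) -> float:
    """Map the 'clipping disabled' setting (<= 0) to an infinite threshold."""
    return math.inf if gradient_max_norm <= 0.0 else gradient_max_norm
```

Called as `clip_grad_norm(grads, effective_max_norm(0.0))`, the function leaves the gradients untouched but still returns a norm that a guard can inspect. Because clipping modifies gradients in place, any per-parameter "pre-clip" norms would have to be collected before this call rather than after it.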
1733-1767: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win
Check the non-finite guard before any validator save callback runs.
`full_validator.run()` and `ema_full_validator.run()` execute before `raise_if_nonfinite()`, and both receive checkpoint-saving callbacks. A diverged step can therefore still write best/full-validation checkpoints here, even though the regular periodic save is blocked a few lines later.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/pt/train/training.py` around lines 1733 - 1767, The non-finite gradient check must run before any validator callbacks that can write checkpoints; move the call to self.nonfinite_grad_guard.raise_if_nonfinite(self.wrapper.named_parameters) to occur before invoking self.full_validator.run(...) and self.ema_full_validator.run(...), so that full_validator.run and ema_full_validator.run (and their save_checkpoint callbacks like self.save_model_merged/self.save_model and self.save_ema_model_merged/self.save_ema_model) cannot write checkpoints when gradients are non-finite; keep the existing condition that gates the periodic save (zero_stage/rank checks) but ensure the raise_if_nonfinite call executes unconditionally (or under the same global save gating) prior to any validator.run calls.
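The ordering requirement can be sketched independently of the trainer. The function below is a hypothetical illustration: `run_checkpoint_step` and its three callable parameters are assumptions made for the example, not names from `training.py`.

```python
from typing import Callable, List


def run_checkpoint_step(
    raise_if_nonfinite: Callable[[], None],
    validators: List[Callable[[], None]],
    periodic_save: Callable[[], None],
) -> None:
    """Check for divergence before anything that can write a checkpoint."""
    # Raises on a diverged run, so neither the validators nor the save
    # below execute.
    raise_if_nonfinite()
    for run_validator in validators:
        # A validator may invoke a save callback for a "best" checkpoint.
        run_validator()
    periodic_save()
```

Placing the check first means a single raise covers every write path, instead of guarding each save callback separately.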
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: e9214d86-4472-4d81-84d6-7a2c5da60808
📒 Files selected for processing (3)
deepmd/pt/train/training.py
deepmd/pt/train/utils.py
source/tests/pt/test_train_utils.py
Pull request overview
This PR refactors PyTorch training gradient clipping to avoid per-step host synchronization, while adding a deferred non-finite gradient detection mechanism that is checked at checkpoint boundaries to prevent saving diverged models.
Changes:
- Replaced `clip_grad_norm_with_stable_fallback` with a new `clip_grad_norm_` utility that supports a stable (overflow-resistant) reduction path and a non-stable path for FSDP2/DTensor compatibility.
- Added `NonFiniteGradGuard` to accumulate non-finite norm flags on-device during training steps and raise only when saving checkpoints.
- Updated/rewrote unit tests to cover the new clipping behavior and guard semantics.
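The deferred-check pattern described in the second bullet can be sketched in plain Python. This is an illustration under stated assumptions, not the PR's `NonFiniteGradGuard`: the class name `DeferredNonFiniteGuard` is hypothetical, floats stand in for tensors, and the on-device accumulation that avoids host synchronization has no equivalent here.

```python
import math
from typing import Callable, Iterable, List, Tuple

NamedGrads = Iterable[Tuple[str, List[float]]]


class DeferredNonFiniteGuard:
    """Hypothetical sketch: record non-finite norms cheaply, report them later."""

    def __init__(self) -> None:
        self._seen_nonfinite = False

    def update(self, total_norm: float) -> None:
        # Per-step cost is a single flag update; no per-parameter inspection.
        self._seen_nonfinite |= not math.isfinite(total_norm)

    def raise_if_nonfinite(self, named_grads: Callable[[], NamedGrads]) -> None:
        # Intended to be called only at checkpoint boundaries.
        if not self._seen_nonfinite:
            return
        # The callable is evaluated lazily, only once a problem is known.
        bad = [
            name
            for name, grad in named_grads()
            if not all(math.isfinite(g) for g in grad)
        ]
        raise RuntimeError(f"non-finite gradient norm detected; offending: {bad}")
```

Passing a callable rather than the gradients themselves keeps the expensive per-parameter scan off the normal training path; it runs only when the flag has already been set.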
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| `deepmd/pt/train/utils.py` | Introduces `clip_grad_norm_`, `NonFiniteGradGuard`, and updates non-finite gradient reporting. |
| `deepmd/pt/train/training.py` | Integrates the new clipping function and defers divergence checks to checkpoint-save points. |
| `source/tests/pt/test_train_utils.py` | Replaces legacy gradient-clip tests with coverage for the new utilities and guard behavior. |
Comments suppressed due to low confidence (1)
deepmd/pt/train/utils.py:165
`raise_nonfinite_gradient_norm()` currently uses `param.grad.detach().norm()` to decide whether an individual gradient is non-finite. For very large but finite float32 gradients, the reduction inside `.norm()` can overflow to `inf`, causing the parameter to be reported as having non-finite gradients even when all gradient values are finite (contradicting the docstring’s “overflow in the norm reduction” case). To match the intended behavior, check finiteness of the gradient values (and only then compute a norm for reporting).
        if param.grad is None:
            continue
        grad_norm = param.grad.detach().norm()
        if not torch.isfinite(grad_norm):
            bad_params.append(
                f"  {name}: grad_norm={grad_norm}, shape={list(param.shape)}"
            )
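The distinction between "the norm overflowed" and "a value is non-finite" can be shown with ordinary Python floats. This sketch uses float64 values near `1e200` to trigger the same overflow that the comment describes for float32; the three helper names are hypothetical and are not part of `utils.py`.

```python
import math
from typing import List


def naive_l2_norm(values: List[float]) -> float:
    """Square-then-sum reduction; squares of large finite values overflow to inf."""
    return math.sqrt(sum(v * v for v in values))


def scaled_l2_norm(values: List[float]) -> float:
    """Overflow-resistant variant: divide by the largest magnitude before squaring."""
    scale = max((abs(v) for v in values), default=0.0)
    if scale == 0.0 or not math.isfinite(scale):
        # All zeros, or a genuinely non-finite entry: return it unchanged.
        return scale
    return scale * math.sqrt(sum((v / scale) * (v / scale) for v in values))


def has_nonfinite_values(values: List[float]) -> bool:
    """Element-wise check, independent of any reduction."""
    return any(not math.isfinite(v) for v in values)
```

For `[1e200, 1e200]`, `naive_l2_norm` returns `inf` while `has_nonfinite_values` returns `False`, so a report keyed on the naive norm would flag a parameter whose gradient values are all finite.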
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## master #5519 +/- ##
==========================================
- Coverage 82.19% 82.19% -0.01%
==========================================
Files 891 891
Lines 101599 101581 -18
Branches 4242 4242
==========================================
- Hits 83507 83492 -15
+ Misses 16789 16786 -3
Partials 1303 1303
🧹 Nitpick comments (1)
source/tests/pt/test_train_utils.py (1)
58-71: ⚡ Quick win
Bind the loop variable in the lambda to avoid late-binding issue.
The lambda on line 71 captures the loop variable `p` without binding it, which Ruff B023 flags as a potential late-binding hazard. Although the lambda is invoked immediately in this test, explicitly binding the variable improves clarity and eliminates the warning.
🔧 Proposed fix
 guard = NonFiniteGradGuard()
 guard.update(total_norm)
-with self.assertRaisesRegex(RuntimeError, "p"):
-    guard.raise_if_nonfinite(lambda: [("p", p)])
+with self.assertRaisesRegex(RuntimeError, "p"):
+    guard.raise_if_nonfinite(lambda p=p: [("p", p)])
Note: This test directly addresses the end-to-end non-finite coverage requested in the past review comment by feeding actual NaN/Inf gradients through `clip_grad_norm_` (for both `stable=True` and `stable=False`) and asserting that the returned `total_norm` is non-finite and that `NonFiniteGradGuard` accumulates and raises on it.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@source/tests/pt/test_train_utils.py` around lines 58 - 71, In test_nonfinite_grad_is_deferred_to_guard, the lambda passed to guard.raise_if_nonfinite captures loop variable p late; change it to bind p as a default argument (e.g. use lambda p=p: [("p", p)]) so the current parameter is captured immediately and Ruff B023 is resolved while preserving the test behavior.
Source: Linters/SAST tools
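The late-binding behavior that Ruff B023 warns about is easy to see in a standalone example. The two helper functions below are hypothetical and exist only to contrast the two capture styles.

```python
from typing import Callable, List


def make_late_bound() -> List[Callable[[], int]]:
    # Each lambda looks up `i` when it is called, after the loop has finished.
    return [lambda: i for i in range(3)]


def make_bound() -> List[Callable[[], int]]:
    # The default argument captures the value of `i` at definition time.
    return [lambda i=i: i for i in range(3)]


late_results = [f() for f in make_late_bound()]
bound_results = [f() for f in make_bound()]
```

Here `late_results` is `[2, 2, 2]` and `bound_results` is `[0, 1, 2]`. In the test under review the lambda is called inside the same loop iteration that creates it, so both forms behave identically there; the default-argument form mainly documents intent and silences the linter.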
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 491e40f8-94ce-42b8-b19a-9635be8040ed
📒 Files selected for processing (3)
deepmd/pt/train/training.py
deepmd/pt/train/utils.py
source/tests/pt/test_train_utils.py
🚧 Files skipped from review as they are similar to previous changes (2)
- deepmd/pt/train/training.py
- deepmd/pt/train/utils.py
njzjz left a comment
Requesting changes for the non-finite gradient handling issue below.
Reviewed by codex.
Summary by CodeRabbit
Bug Fixes
Chores
Tests