feat(pt_expt): DPA4 descriptor and fitting (training + export) #5522
📝 Walkthrough
This PR adds a gradient-preserving array-conversion helper, migrates DPA4 descriptor and NN modules to use it, introduces PyTorch wrappers with trainable-buffer promotion, extends model-loading for DPA4/SeZM, and adds comprehensive tests covering parity, tracing, export, and training.
Changes
DPA4/SeZM Integration for PT Expt
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Suggested reviewers
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 1f26ba1359
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
deepmd/dpmodel/descriptor/dpa4.py (1)
768-790: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win
Reject non-`None` `charge_spin` instead of silently ignoring it.
The new argument is documented as unsupported, but `call()` currently accepts any non-`None` value and computes the same descriptor anyway. That can hide an upstream wiring bug and silently return the wrong behavior for callers that think charge/spin conditioning is active.
💡 Proposed fix
```diff
     def call(
         self,
         coord_ext: Array,
         atype_ext: Array,
         nlist: Array,
         mapping: Array | None = None,
         fparam: Array | None = None,
         comm_dict: dict | None = None,
         charge_spin: Array | None = None,
     ) -> tuple[Array, Any, Any, Any, Any]:
         """Compute the DPA4 descriptor.
@@
         rot_mat, g2, h2, sw
             ``None`` placeholders (pt returns empty tensors for these).
         """
+        if charge_spin is not None:
+            raise NotImplementedError(
+                "`charge_spin` is unsupported for dpmodel DescrptDPA4; "
+                "`add_chg_spin_ebd=True` is rejected at construction"
+            )
         xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist)
```
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/dpmodel/descriptor/dpa4.py` around lines 768 - 790, The call method in dpa4.py currently ignores a non-None charge_spin; update the beginning of the call(...) function to explicitly reject any non-None charge_spin by raising a clear exception (e.g., ValueError) with a message like "charge_spin must be None for DPA4 descriptor" so callers are not silently misled; place this check in the compute DPA4 descriptor function (the call function) before any descriptor computation or branching that would use other inputs.
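The fail-fast guard discussed in this comment can be sketched in isolation. The class below is a hypothetical stand-in (`ToyDescriptor`), not the real `DescrptDPA4`; only the `charge_spin` parameter name and the use of `NotImplementedError` are taken from the proposed fix, and the placeholder computation is an assumption made purely for illustration.

```python
from __future__ import annotations

from typing import Any


class ToyDescriptor:
    """Hypothetical stand-in for a descriptor without charge/spin support."""

    def call(self, coord_ext: list[float], charge_spin: Any | None = None) -> float:
        # Reject the unsupported input up front instead of silently ignoring it.
        if charge_spin is not None:
            raise NotImplementedError("charge_spin is unsupported by this descriptor")
        # Placeholder computation standing in for the real descriptor math.
        return sum(coord_ext)


def guard_fires(descriptor: ToyDescriptor) -> bool:
    """Return True if a non-None charge_spin raises NotImplementedError."""
    try:
        descriptor.call([0.0, 1.0], charge_spin=[0, 1])
    except NotImplementedError:
        return True
    return False
```

A regression test along the lines of `guard_fires` would catch the case where the argument is later accepted without being used.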
🧹 Nitpick comments (1)
source/tests/pt_expt/descriptor/test_dpa4.py (1)
63-66: Downgrade dtype/device concerns in `test_dpa4.py`
`torch.tensor(..., dtype=int)` is accepted by PyTorch (mapped to the default integer dtype, typically `torch.int64`), so it’s not an API-breaking issue before the descriptor is exercised.
`DescrptDPA4.deserialize(dd0.serialize())` uses the pt_expt conversion path (`dpmodel_setattr`) to materialize numpy arrays onto `pt_expt.utils.env.DEVICE`, matching `self.device`, so the round-trip doesn’t require an extra `.to(self.device)`.
Optional: switch `dtype=int` → `dtype=torch.int64` for explicitness/consistency with the rest of the codebase.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@source/tests/pt_expt/descriptor/test_dpa4.py` around lines 63 - 66, The tensors atype_ext, nlist, and mapping are created with dtype=int which relies on PyTorch's default integer mapping; update their creation to use an explicit integer dtype (dtype=torch.int64) for consistency with the codebase and keep using the existing device handling (do not add extra .to(self.device) calls) because DescrptDPA4.deserialize(...) and dpmodel_setattr already materialize arrays onto pt_expt.utils.env.DEVICE; modify the tensor constructions for atype_ext, nlist, and mapping to use dtype=torch.int64 while leaving device logic unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/pt_expt/model/get_model.py`:
- Around line 152-155: Replace the use of setdefault for the nested type fields
in get_sezm_model/get_model logic: when model type is the fixed DPA4/SeZM
("dpa4" or the canonical name you expect) explicitly set
data["descriptor"]["type"] = "dpa4" and data["fitting_net"]["type"] =
"dpa4_ener" (or, if you prefer strictness, check if those keys exist and raise a
ValueError if their values differ), rather than leaving preexisting values in
place; update the code that currently calls data.setdefault("descriptor", {}) /
data.setdefault("fitting_net", {}) and data["..."].setdefault("type", ...) to
either assign the canonical strings or validate-and-reject so the SeZM path
always uses the expected descriptor and fitting types.
In `@source/tests/pt_expt/descriptor/test_dpa4.py`:
- Around line 72-73: The deserialized descriptor dd1 is left on the default
device causing device-mismatch errors when inputs are on self.device; after
calling DescrptDPA4.deserialize(...) and before invoking dd1(coord_ext,
atype_ext, nlist, mapping), move/transfer dd1 to the test device (self.device)
by calling the descriptor's device-transfer method (e.g., dd1.to(self.device) or
equivalent) so the round-trip check runs on the same CUDA device as the inputs.
In `@source/tests/pt_expt/test_training.py`:
- Around line 314-321: Add a 60-second timeout to the new
test_training_loop_dpa4 by decorating the test with `@pytest.mark.timeout`(60)
(add "import pytest" if not already present) so the test is capped like other
training validations; update the function definition for test_training_loop_dpa4
and ensure the pytest timeout marker is available in the test file and test
runner.
In `@source/tests/pt/model/test_dpa4_ptexpt_grad_parity.py`:
- Around line 129-137: The constructed models are left on the default device
while test tensors use PT_DEVICE; move both models to PT_DEVICE after creation
to avoid device-mismatch failures. Specifically, after creating pt_mod
(DescrptSeZM(**kwargs).double()) call .to(PT_DEVICE) before perturbing
parameters, and after deserializing expt_mod (DescrptDPA4.deserialize(...)) also
call .to(PT_DEVICE); ensure you reference PT_DEVICE, pt_mod, expt_mod,
DescrptSeZM and DescrptDPA4.deserialize when making these changes.
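The `get_model.py` comment above contrasts `setdefault` with explicit assignment or validate-and-reject. A minimal sketch of the latter, assuming a plain nested config dict: `pin_component_types` and `CANONICAL_TYPES` are hypothetical names, not part of deepmd, and the sketch does not account for the other aliases (`sezm`, `SeZM`, `DPA4`) that the real code may need to accept.

```python
from __future__ import annotations

from typing import Any

# Canonical component types named in the review comment above.
CANONICAL_TYPES = {"descriptor": "dpa4", "fitting_net": "dpa4_ener"}


def pin_component_types(data: dict[str, Any], strict: bool = True) -> dict[str, Any]:
    """Hypothetical helper: make the nested ``type`` fields canonical.

    With ``strict=True`` a conflicting preexisting value raises ``ValueError``;
    with ``strict=False`` it is overwritten. The input dict is modified in place.
    """
    for section, expected in CANONICAL_TYPES.items():
        sub = data.setdefault(section, {})
        current = sub.get("type")
        if strict and current is not None and current != expected:
            raise ValueError(
                f"{section}.type must be {expected!r} on this path, got {current!r}"
            )
        # Explicit assignment: unlike setdefault, a stale value cannot survive.
        sub["type"] = expected
    return data
```

By contrast, `data["descriptor"].setdefault("type", "dpa4")` would leave a preexisting value such as `"se_e2_a"` untouched, which is the behavior the comment flags.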
📒 Files selected for processing (27)
deepmd/dpmodel/array_api.py
deepmd/dpmodel/descriptor/dpa4.py
deepmd/dpmodel/descriptor/dpa4_nn/activation.py
deepmd/dpmodel/descriptor/dpa4_nn/edge_cache.py
deepmd/dpmodel/descriptor/dpa4_nn/embedding.py
deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py
deepmd/dpmodel/descriptor/dpa4_nn/indexing.py
deepmd/dpmodel/descriptor/dpa4_nn/norm.py
deepmd/dpmodel/descriptor/dpa4_nn/projection.py
deepmd/dpmodel/descriptor/dpa4_nn/radial.py
deepmd/dpmodel/descriptor/dpa4_nn/so2.py
deepmd/dpmodel/descriptor/dpa4_nn/so3.py
deepmd/dpmodel/descriptor/dpa4_nn/wignerd.py
deepmd/dpmodel/fitting/dpa4_ener.py
deepmd/pt_expt/common.py
deepmd/pt_expt/descriptor/__init__.py
deepmd/pt_expt/descriptor/dpa4.py
deepmd/pt_expt/fitting/__init__.py
deepmd/pt_expt/fitting/dpa4_ener.py
deepmd/pt_expt/model/get_model.py
source/tests/consistent/descriptor/test_dpa4.py
source/tests/consistent/fitting/test_dpa4_ener.py
source/tests/pt/model/test_dpa4_ptexpt_grad_parity.py
source/tests/pt_expt/descriptor/test_dpa4.py
source/tests/pt_expt/fitting/test_dpa4_ener.py
source/tests/pt_expt/model/test_get_model_dpa4.py
source/tests/pt_expt/test_training.py
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## master #5522 +/- ##
==========================================
+ Coverage 82.19% 82.21% +0.02%
==========================================
Files 891 893 +2
Lines 101599 101746 +147
Branches 4241 4240 -1
==========================================
+ Hits 83507 83654 +147
Misses 16789 16789
Partials 1303 1303
…eze, component type validation)
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
deepmd/dpmodel/descriptor/dpa4.py (1)
768-790: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win
Reject non-`None` `charge_spin` inputs instead of silently ignoring them.
Line 768 adds the compatibility argument, but the implementation never uses it. Since `add_chg_spin_ebd=True` is already rejected in `__init__`, accepting a non-`None` `charge_spin` here returns the unconditioned descriptor with no signal to the caller. A fail-fast guard would keep the runtime contract aligned with the rest of the unsupported charge/spin path.
💡 Proposed fix
```diff
     def call(
         self,
         coord_ext: Array,
         atype_ext: Array,
         nlist: Array,
         mapping: Array | None = None,
         fparam: Array | None = None,
         comm_dict: dict | None = None,
         charge_spin: Array | None = None,
     ) -> tuple[Array, Any, Any, Any, Any]:
         """Compute the DPA4 descriptor.
@@
+        if charge_spin is not None:
+            raise NotImplementedError(
+                "charge_spin is unsupported for DescrptDPA4 when "
+                "add_chg_spin_ebd=False"
+            )
         xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist)
```
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/dpmodel/descriptor/dpa4.py` around lines 768 - 790, The `call` method in the DPA4 descriptor class accepts charge_spin as a parameter but silently ignores non-None values, which is inconsistent with the documented behavior that it must be None. Add a validation check at the beginning of the `call` method that raises an error (such as ValueError) if charge_spin is not None, ensuring fail-fast behavior that aligns with the rejection of add_chg_spin_ebd=True in __init__ and the stated interface contract in the docstring.
📒 Files selected for processing (1)
deepmd/dpmodel/descriptor/dpa4.py
…eepmodeling#5540)

PR-3 (final) of the DPA4/SeZM porting series, pt_expt **inference**: freeze to `.pt2`, Python `DeepEval`, pt→pt_expt checkpoint interop, C++ single-rank, and LAMMPS single-rank. Follows PR-1 (deepmodeling#5515, dpmodel core) and PR-2 (deepmodeling#5522, pt_expt training/export).

## What's included

- **Model freeze to `.pt2`** (`deepmd/pt_expt/model/ener_model.py`, `deepmd/dpmodel/.../ener_model.py`, `deepmd/dpmodel/descriptor/dpa4.py`): register `EnergyModel` under `sezm_ener`/`dpa4_ener` model-type aliases so `BaseModel.deserialize` resolves a standard DPA4 energy model (whose fitting type is `sezm_ener`). Fixed a `torch.export` specialization where `int()` on symbolic shapes baked `nf*nloc` (embedding/so2/attention).
- **NoPBC export fix** (`deepmd/dpmodel/descriptor/dpa4.py`): the `atype_ext[:, :nloc]` slice emitted a spurious `Ne(nall, nloc)` shape guard that crashed the compiled artifact when `nall==nloc` (no ghosts); replaced with `xp_take_first_n` (index_select). NoPBC now matches PBC.
- **pt→pt_expt checkpoint interop** (`deepmd/pt_expt/model/model.py`): `BaseModel.deserialize` unwraps pt's bespoke `SeZMModel` serialization (`type:"SeZM"`, nested `sezm_atomic` atomic model with the pt-only dens head), validates versions, rejects unsupported features (bridging/lora/dens/active_mode) with `NotImplementedError`, and delegates to the standard path.
- **Warn on silently-ignored flags** (`use_amp` descriptor, `enable_tf32` model): warn-once instead of silent drop.

## Tests

- **Model freeze** `source/tests/pt_expt/model/test_dpa4_export.py`: dual-artifact `.pt2`, metadata, AOTI load, artifact-vs-eager parity (1e-10). *(CI-skipped: AOTI is slow; run locally.)*
- **DeepEval parity vs pt** `source/tests/pt_expt/infer/test_dpa4_deep_eval.py`: pt `.pt` vs pt_expt `.pt2` energy/force/global-virial/atom-energy at fp64 1e-10, **PBC and NoPBC**; doubles as the checkpoint-interop proof. Per-atom virial compared by sum (pt's edge-scatter from deepmodeling#5518 redistributes it; global virial matches). *(CI-skipped: AOTI.)*
- **Interop unit tests** `source/tests/pt_expt/model/test_dpa4_interop.py` (CI-runnable, no AOTI): happy-path pt-checkpoint→pt_expt round-trip + every guard branch + version validation + `@variables` filtering.
- **Alias deserialize guard** + **use_amp/enable_tf32 warn-once** tests (CI-runnable).
- **Fixture generator** `source/tests/infer/gen_dpa4.py` (+ wired into `source/install/test_cc_local.sh`).
- **C++ single-rank** `source/api_cc/tests/test_deeppot_dpa4_ptexpt.cc`: 20 tests (double+float), dpa3-matched tolerances. Validated locally.
- **LAMMPS single-rank** `source/lmp/tests/test_lammps_dpa4_pt2.py`: parity + `atom_modify map yes` + the deepmodeling#5450 no-atom-map fail-fast. **Validated on a GPU box (7 passed).**

PR-1 parity suites stay green; the small dpmodel edits are parity-revalidated.

## Known limitations

- **Single-rank only.** Multi-rank/MPI LAMMPS for DPA4 is deferred (no live multi-rank cell; the with-comm artifact compiles but its runtime is not exercised). DPA4 is a message-passing descriptor, so multi-rank follows the existing deepmodeling#5450/deepmodeling#5430 machinery in a later PR.
- **No `.pth` (torch.jit) DPA4**: the pt backend has no `sezm_ener` *model* registration, so `.pth` freeze of a standard DPA4 energy model isn't available; not needed for the pt_expt inference path.
- **Per-atom virial** is not compared element-wise pt-vs-pt_expt (only its global sum). deepmodeling#5518 changed pt's edge-scatter distribution; both are correct, the distribution differs.
- **AOTI tests are CI-skipped** (multi-minute compile): the freeze/DeepEval paths are validated locally, not in CI; the interop/alias/warn tests give CI coverage of the non-AOTI logic.
- **fp64 only**; fp32 freeze untested. CUDA validated at LAMMPS level on a GPU box; the AOTI parity numbers are from CPU fp64.
- **`use_amp`/`enable_tf32`** remain functionally ignored (now warned), by design for this series.
- pt SeZM features out of scope (guarded `NotImplementedError`): spin, ZBL bridging, LoRA, dens/direct-force/denoising heads, SO3 grid projection, GridMLP, SO(2) attention extensions.

## Summary by CodeRabbit

# Release Notes

* **New Features**
  * Enabled DPA4 model inference via the pt_expt backend using dual-artifact compilation.
  * Registered the EnergyModel under additional aliases: `sezm_ener` and `dpa4_ener`.
* **Improvements**
  * Improved dynamic/symbolic shape handling across DPA4 components for export/tracing stability.
  * Enhanced pt SeZM/DPA4 checkpoint deserialization and normalization for interoperability.
  * Added one-time warnings when `use_amp` or `enable_tf32` settings are ineffective.
* **Tests**
  * Added C++ and Python coverage for pt2 inference, LAMMPS integration, model export/freeze, parity, interop, and warning behavior.

---------

Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
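The "warn-once instead of silent drop" behavior mentioned in the commit message above can be illustrated with the standard `warnings` module. This is a minimal sketch, not the project's implementation: `warn_once_ignored` and `_WARNED_FLAGS` are hypothetical names, and the message text is assumed.

```python
from __future__ import annotations

import warnings

# Flags that have already triggered a warning in this process.
_WARNED_FLAGS: set[str] = set()


def warn_once_ignored(flag: str, value: object) -> bool:
    """Hypothetical helper: warn the first time an ignored flag is set truthy.

    Returns True if this call emitted a warning, False otherwise.
    """
    if not value or flag in _WARNED_FLAGS:
        return False
    _WARNED_FLAGS.add(flag)
    warnings.warn(
        f"`{flag}` is accepted but has no effect on this backend",
        UserWarning,
        stacklevel=2,
    )
    return True
```

Tracking emitted flags explicitly, rather than relying on the default warning filter, keeps the once-per-flag behavior independent of how the caller has configured `warnings`.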
PR-2 of the DPA4/SeZM porting series (PR-1: #5515, dpmodel core). Makes DPA4 trainable and torch.export-compatible on the pt_expt backend.
What's included
- `deepmd/pt_expt/descriptor/dpa4.py`: thin `@torch_module` `DescrptDPA4` registered as `SeZM`/`sezm`/`DPA4`/`dpa4`, plus explicit converter registrations for the two dpmodel classes that cannot be auto-wrapped via serialize round-trip: `WignerDCalculator` (deserialize raises by design; rebuilt from `lmax`/`eps`/`precision`) and `SwiGLU` (parameter-free, no serialize).
- `_TRAINABLE_ATTRS` table re-registers exactly the attributes that are `nn.Parameter` in the pt SeZM implementation, after `__init__`/`deserialize`. Two supporting infra fixes: `_try_convert_list` now handles Optional-module lists (`SO2Convolution.non_linearities` ends with `None`), and `xp_asarray_nodetach` (`deepmd/dpmodel/array_api.py`) replaces `xp.asarray` on weight attributes in dpa4 dpmodel code so torch autograd is not silently detached.
- `deepmd/pt_expt/fitting/dpa4_ener.py`: `SeZMEnergyFittingNet` (`dpa4_ener`/`sezm_ener`), `GLUFittingNet`, and a `SeZMNetworkCollection` wrapper mirroring the `NetworkCollection` ModuleDict pattern so GLU weights are optimizer-visible.
- `deepmd/pt_expt/model/get_model.py`: `get_sezm_model` dispatched for model type `dpa4`/`sezm`/`SeZM`/`DPA4`, mirroring pt's config handling (descriptor/fitting type defaults, `pair_exclude_types` reconciliation). Fail-fast `NotImplementedError` guards for pt-only extensions: spin, `bridging_method != none`, `lora`, `use_compile`, `preset_out_bias`. `enable_tf32` is accepted and ignored (pt_expt always runs "highest" matmul precision; documented).
Tests
- `source/tests/pt_expt/descriptor/test_dpa4.py`: consistency (serde round-trip + dpmodel reference), `torch.export.export`, `make_fx` through forward + `autograd.grad`; fitting analogs in `source/tests/pt_expt/fitting/test_dpa4_ener.py`.
- `source/tests/pt/model/test_dpa4_ptexpt_grad_parity.py`: weight-copied pt SeZM vs pt_expt, quadratic loss, every parameter's gradient compared 1:1 via serialize-tree alignment (83/83 parameters, no dead parameters, max deviation ~3e-11 rel from fp64 summation order).
- `source/tests/pt_expt/model/test_get_model_dpa4.py`: aliases, defaults, exclude-types branches, all NIE guards.
- `source/tests/pt_expt/test_training.py::test_training_loop_dpa4` on the water fixture through the full trainer.
- `source/tests/consistent/`: 20 new comparisons pass at existing harness tolerances.
- `to_numpy_array`, explicit-dtype zeros, ParameterList-safe `_load_variables`, `charge_spin` kwarg for atomic-model interface compatibility.
Known limitations
- No `.pt2` export, no Python `DeepEval`, no C++/LAMMPS inference; that is PR-3 of the series.
- `share_params` raises `NotImplementedError`.
- `SeZMModel` checkpoints (PR-3).
- `use_compile`, `preset_out_bias`; `enable_tf32` ignored (numerically conservative).
- `_try_convert_list`'s ParameterList branch does not respect `trainable=False` for list-valued weights (pre-existing backend-wide behavior).
Summary by CodeRabbit
New Features
Improvements
`charge_spin` for compatibility.
Tests
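The gradient-parity test in the PR description compares every parameter's gradient 1:1 "via serialize-tree alignment". The idea can be sketched with plain nested dicts and lists standing in for serialized trees; `flatten_tree` and `max_rel_deviation` are hypothetical helpers written for this illustration, and the real test operates on tensors rather than Python floats.

```python
from __future__ import annotations

from typing import Any


def flatten_tree(tree: Any, prefix: str = "") -> dict[str, float]:
    """Flatten nested dicts/lists of numbers into {path: value}."""
    if isinstance(tree, dict):
        items = [(f"{prefix}.{k}" if prefix else str(k), v) for k, v in tree.items()]
    elif isinstance(tree, (list, tuple)):
        items = [(f"{prefix}[{i}]", v) for i, v in enumerate(tree)]
    else:
        # Leaf: a single number addressed by its full path.
        return {prefix: float(tree)}
    flat: dict[str, float] = {}
    for path, value in items:
        flat.update(flatten_tree(value, path))
    return flat


def max_rel_deviation(ref: Any, other: Any) -> float:
    """Compare two trees leaf by leaf; raise KeyError if their paths differ."""
    flat_ref, flat_other = flatten_tree(ref), flatten_tree(other)
    if flat_ref.keys() != flat_other.keys():
        unmatched = sorted(flat_ref.keys() ^ flat_other.keys())
        raise KeyError(f"unmatched paths: {unmatched}")
    # Relative deviation w.r.t. the reference; the floor avoids division by zero.
    return max(
        (abs(flat_ref[p] - flat_other[p]) / max(abs(flat_ref[p]), 1e-300) for p in flat_ref),
        default=0.0,
    )
```

Raising on unmatched paths is what makes the comparison 1:1: a parameter present on only one side is reported instead of being skipped.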