Skip to content

feat(pt_expt): DPA4 descriptor and fitting (training + export)#5522

Merged
wanghan-iapcm merged 12 commits into
deepmodeling:masterfrom
wanghan-iapcm:feat-ptexpt-dpa4
Jun 15, 2026
Merged

feat(pt_expt): DPA4 descriptor and fitting (training + export)#5522
wanghan-iapcm merged 12 commits into
deepmodeling:masterfrom
wanghan-iapcm:feat-ptexpt-dpa4

Conversation

@wanghan-iapcm

@wanghan-iapcm wanghan-iapcm commented Jun 12, 2026

Copy link
Copy Markdown
Collaborator

PR-2 of the DPA4/SeZM porting series (PR-1: #5515, dpmodel core). Makes DPA4 trainable and torch.export-compatible on the pt_expt backend.

What's included

  • Descriptor wrapper (deepmd/pt_expt/descriptor/dpa4.py): thin @torch_module DescrptDPA4 registered as SeZM/sezm/DPA4/dpa4, plus explicit converter registrations for the two dpmodel classes that cannot be auto-wrapped via serialize round-trip: WignerDCalculator (deserialize raises by design — rebuilt from lmax/eps/precision) and SwiGLU (parameter-free, no serialize).
  • Trainable-weight promotion: dpmodel stores DPA4 weights as bare numpy arrays (mirroring pt SeZM parameter names), which the generic ndarray rule registers as non-trainable buffers. A localized _TRAINABLE_ATTRS table re-registers exactly the attributes that are nn.Parameter in the pt SeZM implementation, after __init__/deserialize. Two supporting infra fixes: _try_convert_list now handles Optional-module lists (SO2Convolution.non_linearities ends with None), and xp_asarray_nodetach (deepmd/dpmodel/array_api.py) replaces xp.asarray on weight attributes in dpa4 dpmodel code so torch autograd is not silently detached.
  • Fitting wrapper (deepmd/pt_expt/fitting/dpa4_ener.py): SeZMEnergyFittingNet (dpa4_ener/sezm_ener), GLUFittingNet, and a SeZMNetworkCollection wrapper mirroring the NetworkCollection ModuleDict pattern so GLU weights are optimizer-visible.
  • Model assembly (deepmd/pt_expt/model/get_model.py): get_sezm_model dispatched for model type dpa4/sezm/SeZM/DPA4, mirroring pt's config handling (descriptor/fitting type defaults, pair_exclude_types reconciliation). Fail-fast NotImplementedError guards for pt-only extensions: spin, bridging_method != none, lora, use_compile, preset_out_bias. enable_tf32 is accepted and ignored (pt_expt always runs "highest" matmul precision; documented).

Tests

  • pt_expt trio (source/tests/pt_expt/descriptor/test_dpa4.py): consistency (serde round-trip + dpmodel reference), torch.export.export, make_fx through forward + autograd.grad; fitting analogs in source/tests/pt_expt/fitting/test_dpa4_ener.py.
  • Parameter-gradient parity vs pt (source/tests/pt/model/test_dpa4_ptexpt_grad_parity.py): weight-copied pt SeZM vs pt_expt, quadratic loss, every parameter's gradient compared 1:1 via serialize-tree alignment (83/83 parameters, no dead parameters, max deviation ~3e-11 rel from fp64 summation order).
  • Model assembly tests (source/tests/pt_expt/model/test_get_model_dpa4.py): aliases, defaults, exclude-types branches, all NIE guards.
  • Training end-to-end (source/tests/pt_expt/test_training.py::test_training_loop_dpa4) on the water fixture through the full trainer.
  • Cross-backend consistency rows enabled for pt_expt (source/tests/consistent/): 20 new comparisons pass at existing harness tolerances.
  • PR-1 parity suites stay green (381 passed) including for the small dpmodel fixes this PR needed (serialize via to_numpy_array, explicit-dtype zeros, ParameterList-safe _load_variables, charge_spin kwarg for atomic-model interface compatibility).

Known limitations

  • No freeze/.pt2 export, no Python DeepEval, no C++/LAMMPS inference — that is PR-3 of the series.
  • No multitask: share_params raises NotImplementedError.
  • No model-dict deserialization of pt SeZMModel checkpoints (PR-3).
  • pt-only SeZM extensions fail fast rather than being ported: spin, bridging/ZBL, LoRA, use_compile, preset_out_bias; enable_tf32 ignored (numerically conservative).
  • _try_convert_list's ParameterList branch does not respect trainable=False for list-valued weights (pre-existing backend-wide behavior).
  • CUDA validated at tolerance level only in CI; performance vs the pt Triton implementation is unbenchmarked.

Summary by CodeRabbit

  • New Features

    • Added PyTorch-experimental DPA4/SeZM descriptor and energy-fitting components, with public package exports.
    • Added EnergyModel factory/dispatch to build DPA4/SeZM models from configs.
  • Improvements

    • Enhanced DPA4/SeZM runtime handling to keep gradient connectivity and ensure device-consistent tensor placement.
    • Updated DPA4 descriptor interface to accept optional charge_spin for compatibility.
  • Tests

    • Expanded experimental coverage with serialization/parity checks, FX tracing/export validation, trainable-parameter behavior tests, and descriptor/fitting gradient parity verification.

@wanghan-iapcm wanghan-iapcm requested review from OutisLi and njzjz June 12, 2026 16:19
@coderabbitai

coderabbitai Bot commented Jun 12, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

📝 Walkthrough

Walkthrough

This PR adds a gradient-preserving array-conversion helper, migrates DPA4 descriptor and NN modules to use it, introduces PyTorch wrappers with trainable-buffer promotion, extends model-loading for DPA4/SeZM, and adds comprehensive tests covering parity, tracing, export, and training.

Changes

DPA4/SeZM Integration for PT Expt

Layer / File(s) Summary
Gradient-preserving array-conversion helper
deepmd/dpmodel/array_api.py
Introduces xp_asarray_nodetach to convert arrays without detaching gradients. Backend tensors are returned as-is (or with optional differentiable dtype cast), while non-backend inputs are converted via xp.asarray with device placement.
DPModel DPA4 descriptor runtime conversion migration
deepmd/dpmodel/descriptor/dpa4.py
Updates DescrptDPA4.call with charge_spin compatibility parameter; replaces xp.asarray with xp_asarray_nodetach for film-conditioning tensors and flat-index materialization; updates serialization to use to_numpy_array for mean/stddev and Wigner matrices.
DPModel NN modules parameter materialization
deepmd/dpmodel/descriptor/dpa4_nn/*.py
Replaces xp.asarray with xp_asarray_nodetach across activation, edge-cache, embedding, grid-net, indexing, norm, projection, radial, and Wigner-D modules; updates deserializers to validate index-tables via to_numpy_array and fixes parameter initialization shapes.
DPModel fitting serialization update
deepmd/dpmodel/fitting/dpa4_ener.py
Updates GLUFittingNet.serialize() to wrap layer matrices and biases with to_numpy_array for consistent serialization format.
PyTorch module conversion for mixed lists
deepmd/pt_expt/common.py
Enhances _try_convert_list to handle heterogeneous lists mixing torch.nn.Module, NativeOP, and None, converting NativeOPs while preserving None entries.
PyTorch descriptor wrapper with trainable promotion
deepmd/pt_expt/descriptor/dpa4.py, deepmd/pt_expt/descriptor/__init__.py
Adds DescrptDPA4 wrapper (registered as SeZM/sezm/DPA4/dpa4), lightweight sub-component wrappers (WignerDCalculator, SwiGLU), trainable-attribute definitions, and promotion helpers with re-promotion after deserialize; exports DescrptDPA4.
PyTorch fitting network wrapper
deepmd/pt_expt/fitting/dpa4_ener.py, deepmd/pt_expt/fitting/__init__.py
Introduces GLUFittingNet, SeZMNetworkCollection (with ModuleDict synchronization), and SeZMEnergyFittingNet; registers dpmodel mappings and BaseFitting names; exports SeZMEnergyFittingNet.
DPA4/SeZM model loading and dispatch
deepmd/pt_expt/model/get_model.py
Implements get_sezm_model to normalize configs, reject unsupported pt-only features, enforce descriptor/fitting type contracts, validate exclude_types consistency, and extends get_model dispatch for DPA4/SeZM type aliases.
Cross-backend consistency test integration
source/tests/consistent/descriptor/test_dpa4.py, source/tests/consistent/fitting/test_dpa4_ener.py
Updates existing tests to conditionally import and enable pt_expt backends, adding eval_pt_expt evaluation helpers for output comparison.
Gradient parity test suite
source/tests/pt/model/test_dpa4_ptexpt_grad_parity.py
Adds gradient parity utilities and test classes (TestDescriptorGradParity, TestFittingGradParity) to assert parameter-gradient parity between pt reference and pt_expt wrappers via serialized name-aligned arrays.
PT-expt descriptor behavior tests
source/tests/pt_expt/descriptor/test_dpa4.py
Adds test module covering descriptor factory, consistency versus dpmodel, export via torch.export.export, make_fx tracing, and trainable/freezing parameter assertions.
PT-expt fitting behavior tests
source/tests/pt_expt/fitting/test_dpa4_ener.py
Adds test module covering fitting parity, trainable-parameter validation, serialized type assertion, and make_fx gradient parity checks.
Model loading contract and validation tests
source/tests/pt_expt/model/test_get_model_dpa4.py
Adds comprehensive test suite for model normalization, type-alias dispatch, component compatibility, exclude_types reconciliation, unsupported-key rejection, and normalized defaults.
DPA4 end-to-end training smoke test
source/tests/pt_expt/test_training.py
Adds _MODEL_DPA4 configuration fixture and test_training_loop_dpa4 to execute a full training loop with DPA4 model type.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested reviewers

  • iProzd
  • njzjz
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 54.87% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title 'feat(pt_expt): DPA4 descriptor and fitting (training + export)' clearly summarizes the main changes: adding DPA4/SeZM descriptor and fitting implementations for the pt_expt backend with training and export capabilities.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Comment thread source/tests/consistent/fitting/test_dpa4_ener.py Dismissed

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 1f26ba1359

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread deepmd/pt_expt/descriptor/dpa4.py

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
deepmd/dpmodel/descriptor/dpa4.py (1)

768-790: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Reject non-None charge_spin instead of silently ignoring it.

The new argument is documented as unsupported, but call() currently accepts any non-None value and computes the same descriptor anyway. That can hide an upstream wiring bug and silently return the wrong behavior for callers that think charge/spin conditioning is active.

💡 Proposed fix
     def call(
         self,
         coord_ext: Array,
         atype_ext: Array,
         nlist: Array,
         mapping: Array | None = None,
         fparam: Array | None = None,
         comm_dict: dict | None = None,
         charge_spin: Array | None = None,
     ) -> tuple[Array, Any, Any, Any, Any]:
         """Compute the DPA4 descriptor.
@@
         rot_mat, g2, h2, sw
             ``None`` placeholders (pt returns empty tensors for these).
         """
+        if charge_spin is not None:
+            raise NotImplementedError(
+                "`charge_spin` is unsupported for dpmodel DescrptDPA4; "
+                "`add_chg_spin_ebd=True` is rejected at construction"
+            )
         xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/dpmodel/descriptor/dpa4.py` around lines 768 - 790, The call method in
dpa4.py currently ignores a non-None charge_spin; update the beginning of the
call(...) function to explicitly reject any non-None charge_spin by raising a
clear exception (e.g., ValueError) with a message like "charge_spin must be None
for DPA4 descriptor" so callers are not silently misled; place this check in the
compute DPA4 descriptor function (the call function) before any descriptor
computation or branching that would use other inputs.
🧹 Nitpick comments (1)
source/tests/pt_expt/descriptor/test_dpa4.py (1)

63-66: Downgrade dtype/device concerns in test_dpa4.py

  • torch.tensor(..., dtype=int) is accepted by PyTorch (mapped to the default integer dtype, typically torch.int64), so it’s not an API-breaking issue before the descriptor is exercised.
  • DescrptDPA4.deserialize(dd0.serialize()) uses the pt_expt conversion path (dpmodel_setattr) to materialize numpy arrays onto pt_expt.utils.env.DEVICE, matching self.device, so the round-trip doesn’t require an extra .to(self.device).

Optional: switch dtype=intdtype=torch.int64 for explicitness/consistency with the rest of the codebase.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@source/tests/pt_expt/descriptor/test_dpa4.py` around lines 63 - 66, The
tensors atype_ext, nlist, and mapping are created with dtype=int which relies on
PyTorch's default integer mapping; update their creation to use an explicit
integer dtype (dtype=torch.int64) for consistency with the codebase and keep
using the existing device handling (do not add extra .to(self.device) calls)
because DescrptDPA4.deserialize(...) and dpmodel_setattr already materialize
arrays onto pt_expt.utils.env.DEVICE; modify the tensor constructions for
atype_ext, nlist, and mapping to use dtype=torch.int64 while leaving device
logic unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/pt_expt/model/get_model.py`:
- Around line 152-155: Replace the use of setdefault for the nested type fields
in get_sezm_model/get_model logic: when model type is the fixed DPA4/SeZM
("dpa4" or the canonical name you expect) explicitly set
data["descriptor"]["type"] = "dpa4" and data["fitting_net"]["type"] =
"dpa4_ener" (or, if you prefer strictness, check if those keys exist and raise a
ValueError if their values differ), rather than leaving preexisting values in
place; update the code that currently calls data.setdefault("descriptor", {}) /
data.setdefault("fitting_net", {}) and data["..."].setdefault("type", ...) to
either assign the canonical strings or validate-and-reject so the SeZM path
always uses the expected descriptor and fitting types.

In `@source/tests/pt_expt/descriptor/test_dpa4.py`:
- Around line 72-73: The deserialized descriptor dd1 is left on the default
device causing device-mismatch errors when inputs are on self.device; after
calling DescrptDPA4.deserialize(...) and before invoking dd1(coord_ext,
atype_ext, nlist, mapping), move/transfer dd1 to the test device (self.device)
by calling the descriptor's device-transfer method (e.g., dd1.to(self.device) or
equivalent) so the round-trip check runs on the same CUDA device as the inputs.

In `@source/tests/pt_expt/test_training.py`:
- Around line 314-321: Add a 60-second timeout to the new
test_training_loop_dpa4 by decorating the test with `@pytest.mark.timeout`(60)
(add "import pytest" if not already present) so the test is capped like other
training validations; update the function definition for test_training_loop_dpa4
and ensure the pytest timeout marker is available in the test file and test
runner.

In `@source/tests/pt/model/test_dpa4_ptexpt_grad_parity.py`:
- Around line 129-137: The constructed models are left on the default device
while test tensors use PT_DEVICE; move both models to PT_DEVICE after creation
to avoid device-mismatch failures. Specifically, after creating pt_mod
(DescrptSeZM(**kwargs).double()) call .to(PT_DEVICE) before perturbing
parameters, and after deserializing expt_mod (DescrptDPA4.deserialize(...)) also
call .to(PT_DEVICE); ensure you reference PT_DEVICE, pt_mod, expt_mod,
DescrptSeZM and DescrptDPA4.deserialize when making these changes.

---

Outside diff comments:
In `@deepmd/dpmodel/descriptor/dpa4.py`:
- Around line 768-790: The call method in dpa4.py currently ignores a non-None
charge_spin; update the beginning of the call(...) function to explicitly reject
any non-None charge_spin by raising a clear exception (e.g., ValueError) with a
message like "charge_spin must be None for DPA4 descriptor" so callers are not
silently misled; place this check in the compute DPA4 descriptor function (the
call function) before any descriptor computation or branching that would use
other inputs.

---

Nitpick comments:
In `@source/tests/pt_expt/descriptor/test_dpa4.py`:
- Around line 63-66: The tensors atype_ext, nlist, and mapping are created with
dtype=int which relies on PyTorch's default integer mapping; update their
creation to use an explicit integer dtype (dtype=torch.int64) for consistency
with the codebase and keep using the existing device handling (do not add extra
.to(self.device) calls) because DescrptDPA4.deserialize(...) and dpmodel_setattr
already materialize arrays onto pt_expt.utils.env.DEVICE; modify the tensor
constructions for atype_ext, nlist, and mapping to use dtype=torch.int64 while
leaving device logic unchanged.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: c3eea030-d041-4f20-a1b8-fa4df73e938b

📥 Commits

Reviewing files that changed from the base of the PR and between 5d94bd6 and b1920ee.

📒 Files selected for processing (27)
  • deepmd/dpmodel/array_api.py
  • deepmd/dpmodel/descriptor/dpa4.py
  • deepmd/dpmodel/descriptor/dpa4_nn/activation.py
  • deepmd/dpmodel/descriptor/dpa4_nn/edge_cache.py
  • deepmd/dpmodel/descriptor/dpa4_nn/embedding.py
  • deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py
  • deepmd/dpmodel/descriptor/dpa4_nn/indexing.py
  • deepmd/dpmodel/descriptor/dpa4_nn/norm.py
  • deepmd/dpmodel/descriptor/dpa4_nn/projection.py
  • deepmd/dpmodel/descriptor/dpa4_nn/radial.py
  • deepmd/dpmodel/descriptor/dpa4_nn/so2.py
  • deepmd/dpmodel/descriptor/dpa4_nn/so3.py
  • deepmd/dpmodel/descriptor/dpa4_nn/wignerd.py
  • deepmd/dpmodel/fitting/dpa4_ener.py
  • deepmd/pt_expt/common.py
  • deepmd/pt_expt/descriptor/__init__.py
  • deepmd/pt_expt/descriptor/dpa4.py
  • deepmd/pt_expt/fitting/__init__.py
  • deepmd/pt_expt/fitting/dpa4_ener.py
  • deepmd/pt_expt/model/get_model.py
  • source/tests/consistent/descriptor/test_dpa4.py
  • source/tests/consistent/fitting/test_dpa4_ener.py
  • source/tests/pt/model/test_dpa4_ptexpt_grad_parity.py
  • source/tests/pt_expt/descriptor/test_dpa4.py
  • source/tests/pt_expt/fitting/test_dpa4_ener.py
  • source/tests/pt_expt/model/test_get_model_dpa4.py
  • source/tests/pt_expt/test_training.py

Comment thread deepmd/pt_expt/model/get_model.py
Comment thread source/tests/pt_expt/descriptor/test_dpa4.py
Comment thread source/tests/pt_expt/test_training.py
Comment thread source/tests/pt/model/test_dpa4_ptexpt_grad_parity.py
@codecov

codecov Bot commented Jun 12, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 98.54369% with 3 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.21%. Comparing base (0de53e9) to head (56793cf).
⚠️ Report is 1 commits behind head on master.

Files with missing lines Patch % Lines
deepmd/pt_expt/fitting/dpa4_ener.py 94.11% 2 Missing ⚠️
deepmd/pt_expt/descriptor/dpa4.py 98.21% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5522      +/-   ##
==========================================
+ Coverage   82.19%   82.21%   +0.02%     
==========================================
  Files         891      893       +2     
  Lines      101599   101746     +147     
  Branches     4241     4240       -1     
==========================================
+ Hits        83507    83654     +147     
  Misses      16789    16789              
  Partials     1303     1303              

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

@wanghan-iapcm wanghan-iapcm added this pull request to the merge queue Jun 13, 2026
@github-merge-queue github-merge-queue Bot removed this pull request from the merge queue due to failed status checks Jun 14, 2026
@wanghan-iapcm wanghan-iapcm enabled auto-merge June 14, 2026 03:40

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
deepmd/dpmodel/descriptor/dpa4.py (1)

768-790: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Reject non-None charge_spin inputs instead of silently ignoring them.

Line 768 adds the compatibility argument, but the implementation never uses it. Since add_chg_spin_ebd=True is already rejected in __init__, accepting a non-None charge_spin here returns the unconditioned descriptor with no signal to the caller. A fail-fast guard would keep the runtime contract aligned with the rest of the unsupported charge/spin path.

💡 Proposed fix
     def call(
         self,
         coord_ext: Array,
         atype_ext: Array,
         nlist: Array,
         mapping: Array | None = None,
         fparam: Array | None = None,
         comm_dict: dict | None = None,
         charge_spin: Array | None = None,
     ) -> tuple[Array, Any, Any, Any, Any]:
         """Compute the DPA4 descriptor.
@@
+        if charge_spin is not None:
+            raise NotImplementedError(
+                "charge_spin is unsupported for DescrptDPA4 when "
+                "add_chg_spin_ebd=False"
+            )
         xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/dpmodel/descriptor/dpa4.py` around lines 768 - 790, The forward method
in the DPA4 descriptor class accepts charge_spin as a parameter but silently
ignores non-None values, which is inconsistent with the documented behavior that
it must be None. Add a validation check at the beginning of the forward method
that raises an error (such as ValueError) if charge_spin is not None, ensuring
fail-fast behavior that aligns with the rejection of add_chg_spin_ebd=True in
__init__ and the stated interface contract in the docstring.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Outside diff comments:
In `@deepmd/dpmodel/descriptor/dpa4.py`:
- Around line 768-790: The forward method in the DPA4 descriptor class accepts
charge_spin as a parameter but silently ignores non-None values, which is
inconsistent with the documented behavior that it must be None. Add a validation
check at the beginning of the forward method that raises an error (such as
ValueError) if charge_spin is not None, ensuring fail-fast behavior that aligns
with the rejection of add_chg_spin_ebd=True in __init__ and the stated interface
contract in the docstring.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 6ebf72b8-87d1-4f68-943f-3aab3ef171d5

📥 Commits

Reviewing files that changed from the base of the PR and between 79bce06 and 56793cf.

📒 Files selected for processing (1)
  • deepmd/dpmodel/descriptor/dpa4.py

@wanghan-iapcm wanghan-iapcm added this pull request to the merge queue Jun 14, 2026
Merged via the queue into deepmodeling:master with commit 2cbef36 Jun 15, 2026
70 checks passed
@wanghan-iapcm wanghan-iapcm deleted the feat-ptexpt-dpa4 branch June 15, 2026 10:38
njzjz pushed a commit to njzjz/deepmd-kit that referenced this pull request Jun 17, 2026
…eepmodeling#5540)

PR-3 (final) of the DPA4/SeZM porting series — pt_expt **inference**:
freeze to `.pt2`, Python `DeepEval`, pt→pt_expt checkpoint interop, C++
single-rank, and LAMMPS single-rank. Follows PR-1 (deepmodeling#5515, dpmodel core)
and PR-2 (deepmodeling#5522, pt_expt training/export).

## What's included

- **Model freeze to `.pt2`** (`deepmd/pt_expt/model/ener_model.py`,
`deepmd/dpmodel/.../ener_model.py`,
`deepmd/dpmodel/descriptor/dpa4.py`): register `EnergyModel` under
`sezm_ener`/`dpa4_ener` model-type aliases so `BaseModel.deserialize`
resolves a standard DPA4 energy model (whose fitting type is
`sezm_ener`). Fixed a `torch.export` specialization where `int()` on
symbolic shapes baked `nf*nloc` (embedding/so2/attention).
- **NoPBC export fix** (`deepmd/dpmodel/descriptor/dpa4.py`): the
`atype_ext[:, :nloc]` slice emitted a spurious `Ne(nall, nloc)` shape
guard that crashed the compiled artifact when `nall==nloc` (no ghosts);
replaced with `xp_take_first_n` (index_select). NoPBC now matches PBC.
- **pt→pt_expt checkpoint interop** (`deepmd/pt_expt/model/model.py`):
`BaseModel.deserialize` unwraps pt's bespoke `SeZMModel` serialization
(`type:"SeZM"`, nested `sezm_atomic` atomic model with the pt-only dens
head), validates versions, rejects unsupported features
(bridging/lora/dens/active_mode) with `NotImplementedError`, and
delegates to the standard path.
- **Warn on silently-ignored flags** (`use_amp` descriptor,
`enable_tf32` model): warn-once instead of silent drop.

## Tests

- **Model freeze** `source/tests/pt_expt/model/test_dpa4_export.py`:
dual-artifact `.pt2`, metadata, AOTI load, artifact-vs-eager parity
(1e-10). *(CI-skipped — AOTI is slow; run locally.)*
- **DeepEval parity vs pt**
`source/tests/pt_expt/infer/test_dpa4_deep_eval.py`: pt `.pt` vs pt_expt
`.pt2` energy/force/global-virial/atom-energy at fp64 1e-10, **PBC and
NoPBC**; doubles as the checkpoint-interop proof. Per-atom virial
compared by sum (pt's edge-scatter from deepmodeling#5518 redistributes it; global
virial matches). *(CI-skipped — AOTI.)*
- **Interop unit tests**
`source/tests/pt_expt/model/test_dpa4_interop.py` (CI-runnable, no
AOTI): happy-path pt-checkpoint→pt_expt round-trip + every guard branch
+ version validation + `@variables` filtering.
- **Alias deserialize guard** + **use_amp/enable_tf32 warn-once** tests
(CI-runnable).
- **Fixture generator** `source/tests/infer/gen_dpa4.py` (+ wired into
`source/install/test_cc_local.sh`).
- **C++ single-rank** `source/api_cc/tests/test_deeppot_dpa4_ptexpt.cc`:
20 tests (double+float), dpa3-matched tolerances. Validated locally.
- **LAMMPS single-rank** `source/lmp/tests/test_lammps_dpa4_pt2.py`:
parity + `atom_modify map yes` + the deepmodeling#5450 no-atom-map fail-fast.
**Validated on a GPU box (7 passed).**

PR-1 parity suites stay green; the small dpmodel edits are
parity-revalidated.

## Known limitations

- **Single-rank only.** Multi-rank/MPI LAMMPS for DPA4 is deferred (no
live multi-rank cell; the with-comm artifact compiles but its runtime is
not exercised). DPA4 is a message-passing descriptor, so multi-rank
follows the existing deepmodeling#5450/deepmodeling#5430 machinery in a later PR.
- **No `.pth` (torch.jit) DPA4** — the pt backend has no `sezm_ener`
*model* registration, so `.pth` freeze of a standard DPA4 energy model
isn't available; not needed for the pt_expt inference path.
- **Per-atom virial** is not compared element-wise pt-vs-pt_expt (only
its global sum) — deepmodeling#5518 changed pt's edge-scatter distribution; both are
correct, the distribution differs.
- **AOTI tests are CI-skipped** (multi-minute compile) — the
freeze/DeepEval paths are validated locally, not in CI; the
interop/alias/warn tests give CI coverage of the non-AOTI logic.
- **fp64 only**; fp32 freeze untested. CUDA validated at LAMMPS level on
a GPU box; the AOTI parity numbers are from CPU fp64.
- **`use_amp`/`enable_tf32`** remain functionally ignored (now warned) —
by design for this series.
- pt SeZM features out of scope (guarded `NotImplementedError`): spin,
ZBL bridging, LoRA, dens/direct-force/denoising heads, SO3 grid
projection, GridMLP, SO(2) attention extensions.


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

# Release Notes

* **New Features**
* Enabled DPA4 model inference via the pt_expt backend using
dual-artifact compilation.
* Registered the EnergyModel under additional aliases: `sezm_ener` and
`dpa4_ener`.

* **Improvements**
* Improved dynamic/symbolic shape handling across DPA4 components for
export/tracing stability.
* Enhanced pt SeZM/DPA4 checkpoint deserialization and normalization for
interoperability.
* Added one-time warnings when `use_amp` or `enable_tf32` settings are
ineffective.

* **Tests**
* Added C++ and Python coverage for pt2 inference, LAMMPS integration,
model export/freeze, parity, interop, and warning behavior.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants