feat(dpmodel): port DPA4/SeZM SO3 grid projection (PR-A: dpmodel core)#5547

Closed
wanghan-iapcm wants to merge 13 commits into deepmodeling:master from wanghan-iapcm:feat-dpmodel-dpa4-so3grid

Conversation

@wanghan-iapcm wanghan-iapcm commented Jun 17, 2026

Collaborator

What

Ports the SO3 grid projection feature family of the DPA4/SeZM descriptor to the backend-agnostic dpmodel layer (and, via auto-wrapping, pt_expt). This is PR-A of a 2-PR series; it lands the dpmodel math + descriptor wiring + parity tests. PR-B will add pt_expt training Parameter-promotion, torch.export/make_fx coverage, training e2e, and argcheck doc updates.

This is the follow-up after the merged DPA4 core series (#5515 / #5522 / #5540), which deliberately scoped to "energy model, no extensions". With this PR the flagship examples/water/dpa4/input.json config (which sets ffn_so3_grid=true, message_node_so3=true, grid_mlp) becomes runnable on dpmodel/pt_expt.

Newly ported (dpmodel descriptor/dpa4_nn/)

  • projection.py: SO3GridProjector (init-time Wigner-D quadrature over a Lebedev×γ rotation grid), resolve_so3_grid, _build_so3_frame_set.
  • grid_net.py: SO3GridNet (self + cross), GridMLP (op_type='mlp'), FrameExpand/FrameContract per-degree frame mixers; filled the previously-guarded BaseGridNet branches: mode='cross', layout='flat', residual_scale_init, and n_frames>1. This also enables S2 cross-mode (node_wise_s2/message_node_s2) and GridMLP for S2.
  • ffn.py / so2.py: wired the grid nets into the FFN (ffn_so3_grid) and SO2Convolution (node_wise_* / message_node_*), removing the fail-fast guards. Out-of-scope flags (layer_scale, atten_*_proj, so2_attn_res, e3nn/non-Lebedev grid) remain guarded.
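
A one-dimensional analog may help illustrate what "init-time quadrature" buys: the projection matrices are built once, and a to-grid/from-grid roundtrip is exact for band-limited coefficients. The sketch below covers only the γ angle with uniform nodes and complex exponentials; it is not the Wigner-D/Lebedev construction used by SO3GridProjector, and `build_gamma_projection` is a hypothetical name.

```python
import numpy as np


def build_gamma_projection(kmax: int, n_gamma: int):
    """Hypothetical helper: to-grid / from-grid matrices for e^{i k gamma}, |k| <= kmax."""
    if n_gamma < 2 * kmax + 1:
        raise ValueError("need n_gamma >= 2 * kmax + 1 for an exact roundtrip")
    # Uniform nodes on [0, 2*pi) with equal quadrature weights 1 / n_gamma.
    gamma = 2.0 * np.pi * np.arange(n_gamma) / n_gamma
    k = np.arange(-kmax, kmax + 1)
    to_grid = np.exp(1j * gamma[:, None] * k[None, :])  # (n_gamma, 2*kmax+1)
    from_grid = to_grid.conj().T / n_gamma  # (2*kmax+1, n_gamma)
    return to_grid, from_grid


to_grid, from_grid = build_gamma_projection(kmax=2, n_gamma=7)
# Coefficients -> grid -> coefficients should reproduce the identity.
roundtrip = from_grid @ to_grid
```

With `n_gamma >= 2 * kmax + 1` the discrete orthogonality of the exponentials makes `roundtrip` the identity up to rounding, which is the same kind of property the projector parity tests check at 1e-12.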

All forward math is array-API (no np.einsum on tensors; frame-axis to/from-grid done via xp.matmul+reshape).
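
As a rough illustration of the matmul+reshape pattern, the sketch below projects coefficients onto a grid without `einsum`. The shapes (node, coefficient, frame, channel), the function name `to_grid`, and the use of NumPy as the `xp` namespace are assumptions for the example, not the layout used in grid_net.py.

```python
import numpy as np

xp = np  # stand-in for an array-API namespace; an assumption for this sketch


def to_grid(x, to_grid_mat):
    """Project x of shape (N, D, F, C) with a (G, D) matrix, returning (N, G, F, C)."""
    n, d, f, c = x.shape
    g = to_grid_mat.shape[0]
    # Fold the frame and channel axes together so one batched matmul suffices.
    flat = xp.reshape(x, (n, d, f * c))
    # (G, D) @ (N, D, F*C) broadcasts over the leading batch axis -> (N, G, F*C).
    grid = xp.matmul(to_grid_mat, flat)
    return xp.reshape(grid, (n, g, f, c))


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 5, 2, 3))
to_grid_mat = rng.standard_normal((6, 5))
out = to_grid(x, to_grid_mat)
```

The result agrees with `np.einsum("gd,ndfc->ngfc", to_grid_mat, x)`, but uses only `reshape` and `matmul`, which are part of the array-API standard.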

Tests

  • Component parity vs pt (weight-copied, fp64): SO3GridProjector matrices match pt to 1e-12; FrameExpand/FrameContract, GridMLP, SO3GridNet (self+cross, glu/mlp/branch, kmax 1&2) to 1e-12; rotation equivariance to ~1e-14.
  • Stale NIE guard-tests in test_dpa4_dpmodel_parity.py converted to real pt-parity tests at every level (S2Grid/SO2/FFN/Block); added kmax=2, mmax<lmax, and an fp32 grid case.
  • Full-descriptor pt→dpmodel parity via DescrptDPA4.deserialize(pt.serialize()) on the example config (lmax=3, mmax=1) — ~1e-15 — which also proves dp convert-backend schema interop.
  • Descriptor permutation-invariance + masked-edge no-op with grid flags on.
  • Cross-backend consistency rows (pt vs dpmodel and pt_expt, mixed_types) for ffn_so3_grid / message_node_so3 / both / grid_mlp.

pt_expt forward inference works today via auto-wrapping (consistency + descriptor trio green).

Known limitations

  • pt_expt training Parameter-promotion for the new weight-bearing grid classes (_TRAINABLE_ATTRS), torch.export/AOTI coverage of the grid path, training e2e, and argcheck "only pt supported" doc removal are PR-B.
  • fp32 grid paths use a loose (~1e-4) budget by design (grid reductions over many Lebedev points); fp64 is the parity reference.
  • grid_method='e3nn' (non-Lebedev product grid) stays fail-fast (Lebedev-only, per the parent design).
  • CUDA ULP nondeterminism in any scatter/index path → tolerance tests, never bit-equality.
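
The tolerance-based style described above can be sketched as follows. The grid size (590 points), the equal weights, and the random values are placeholders for illustration; only the 1e-4 budget comes from the description.

```python
import numpy as np

rng = np.random.default_rng(0)
n_points = 590  # placeholder grid size for this sketch
values = rng.standard_normal((8, n_points))
weights = np.full(n_points, 1.0 / n_points)

# fp64 reduction serves as the parity reference.
ref = values @ weights
# Same reduction carried out in fp32, then promoted for comparison.
low = (values.astype(np.float32) @ weights.astype(np.float32)).astype(np.float64)

# Compare within a budget instead of asserting bit-equality.
np.testing.assert_allclose(low, ref, rtol=0.0, atol=1e-4)
```

Asserting `low == ref` elementwise would be fragile here, since the accumulation order and precision differ between the two paths.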

Summary by CodeRabbit

  • New Features

    • Added support for SO(3) equivariant grids in DPA4 descriptor components, including FFN and convolution layers.
    • Introduced GridMLP operator type for grid-based operations.
    • Added frame channel mixing components for equivariant processing.
  • Bug Fixes

    • Enabled previously unimplemented SO(3) grid paths and cross-mode grid products.
  • Tests

    • Expanded test coverage for new grid-based features and SO(3) equivariance validation.

@dosubot dosubot Bot added the new feature label Jun 17, 2026
@wanghan-iapcm wanghan-iapcm requested a review from OutisLi June 17, 2026 16:23
@wanghan-iapcm wanghan-iapcm added the Test CUDA Trigger test CUDA workflow label Jun 17, 2026
@github-actions github-actions Bot removed the Test CUDA Trigger test CUDA workflow label Jun 17, 2026
coderabbitai Bot commented Jun 17, 2026

Contributor

Review Change Stack

Walkthrough

Ports the DPA4 SO(3) Wigner-D grid-net infrastructure from PyTorch to dpmodel. Adds SO3GridProjector, resolve_so3_grid, _build_so3_frame_set, GridMLP, FrameExpand/FrameContract, and SO3GridNet to grid_net.py/projection.py. Wires these into SO2Convolution (node-wise and message-node cross-mode grid products) and EquivariantFFN (SO3 grid activation). Removes corresponding NotImplementedError guards and adds comprehensive parity, equivariance, and serialization tests.

Changes

DPA4 SO3/S2 Grid-Net Port

| Layer | File(s) | Summary |
|---|---|---|
| SO3 grid projection utilities and SO3GridProjector | deepmd/dpmodel/descriptor/dpa4_nn/projection.py | Adds resolve_so3_grid, _build_so3_frame_set, and SO3GridProjector; builds to_grid_mat/from_grid_mat over a Lebedev×gamma quadrature grid using Wigner-D matrices; adds pt-compatible serialize/deserialize. |
| GridMLP operator and FrameExpand/FrameContract mixers | deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py | Adds _build_frame_degree_index, GridMLP (point-wise quadratic self/cross MLP grid operator), and FrameExpand/FrameContract (per-degree channel mixing via gathered weights and batched matmul), all with full serialize/deserialize support. |
| BaseGridNet generalization and SO3GridNet class | deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py | Extends BaseGridNet with optional frame_expand/frame_contract, mode=cross, op_type=mlp dispatch, and n_frames>1 scalar/projection handling; updates S2GridNet serialization; adds SO3GridNet wiring SO3GridProjector with optional FrameExpand/FrameContract. |
| SO2Convolution cross-mode grid product wiring | deepmd/dpmodel/descriptor/dpa4_nn/so2.py | Constructs node_wise_grid_product and message_node_grid_product submodules (selecting SO3GridNet over S2GridNet when both flags are set); applies them as residual updates in call(); extends _variables()/_load_variables() accordingly. |
| EquivariantFFN SO3 grid activation wiring | deepmd/dpmodel/descriptor/dpa4_nn/ffn.py | Removes ffn_so3_grid=True NotImplementedError; makes grid_n_frames conditional on ffn_so3_grid; branches activation construction between SO3GridNet and S2GridNet based on the flag. |
| Tests: SO3 grid utilities and SO3GridProjector parity | source/tests/common/dpmodel/test_dpa4_so3_grid_utils.py, source/tests/common/dpmodel/test_dpa4_so3_projector.py | Parity tests for resolve_so3_grid/_build_so3_frame_set vs pt reference; SO3GridProjector matrix parity, legal-slot roundtrip, serialize/deserialize, and kmax=0 zonal convention. |
| Tests: FrameExpand/FrameContract and GridMLP parity/equivariance | source/tests/common/dpmodel/test_dpa4_frame_mixers.py, source/tests/common/dpmodel/test_dpa4_grid_mlp.py | Parity and serialize roundtrip tests for FrameContract/FrameExpand; _build_frame_degree_index correctness; GridMLP self/cross parity and roundtrip; S2GridNet op_type=mlp parity and SO(3) equivariance. |
| Tests: S2GridNet cross-mode and SO3GridNet parity/equivariance | source/tests/common/dpmodel/test_dpa4_gridnet_cross.py, source/tests/common/dpmodel/test_dpa4_so3_gridnet.py | S2GridNet cross-mode parity across op_type/layouts/residual_scale; SO3GridNet self/cross parity, equivariance, flat-layout, serialize roundtrip, torch-namespace, and S2 regression. |
| Tests: descriptor-level wiring, invariance, parity, and guard cleanup | source/tests/common/dpmodel/test_descrpt_dpa4.py, source/tests/common/dpmodel/test_dpa4_grid_wiring.py, source/tests/consistent/descriptor/test_dpa4.py, source/tests/pt/model/test_dpa4_dpmodel_parity.py | Removes stale NotImplementedError guard entries; adds grid wiring construction/run, SO3-over-S2 precedence, permutation/masking invariance, and pt parity tests; extends consistent test matrix with ffn_so3_grid/message_node_so3/grid_mlp; updates S2GridNet/SO2/FFN/block parity helpers and guard lists. |

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs

  • deepmodeling/deepmd-kit#5503: Implements the PT-side sezm_nn EquivariantFFN SO(3) Wigner-D grid path (ffn_so3_grid, SO3GridNet, grid-net/projection support) that this PR ports to dpmodel.
  • deepmodeling/deepmd-kit#5515: The initial core dpmodel port where ffn_so3_grid=True still raised NotImplementedError; this PR directly implements that path.
  • deepmodeling/deepmd-kit#5522: Modifies the same _to_grid/_from_grid projection-matrix conversion paths in grid_net.py (switching to xp_asarray_nodetach) that this PR also refactors for SO(3) n_frames>1 support.

Suggested labels

enhancement, Core

Suggested reviewers

  • iProzd
  • njzjz
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 33.33% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The PR title accurately identifies the main feature: SO3 grid projection porting from PT to dpmodel for DPA4/SeZM, matching the comprehensive changes across projection, grid_net, and ffn modules. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |


@codecov

codecov Bot commented Jun 17, 2026


Codecov Report

❌ Patch coverage is 94.85294% with 21 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.28%. Comparing base (9aaa9e7) to head (401890f).
⚠️ Report is 3 commits behind head on master.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py | 93.63% | 18 Missing ⚠️ |
| deepmd/dpmodel/descriptor/dpa4_nn/projection.py | 96.47% | 3 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5547      +/-   ##
==========================================
+ Coverage   82.23%   82.28%   +0.05%     
==========================================
  Files         894      894              
  Lines      101992   102392     +400     
  Branches     4273     4273              
==========================================
+ Hits        83873    84256     +383     
- Misses      16817    16833      +16     
- Partials     1302     1303       +1     
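
The percentages in this report can be reproduced from the raw counts. The sketch below assumes coverage is hits divided by lines (partials counted as not hit) and that displayed figures are truncated rather than rounded to two decimals; both are assumptions about how the report is rendered, and `pct_floor` is a hypothetical helper.

```python
import math


def pct_floor(hits: int, lines: int) -> float:
    """Coverage percentage truncated to two decimals (assumed display rule)."""
    return math.floor(1e4 * hits / lines) / 100


base = pct_floor(83873, 101992)  # master
head = pct_floor(84256, 102392)  # this PR

# Patch coverage: 18 + 3 missing lines at the reported 94.85294%.
missing = 18 + 3
patch_lines = round(missing / (1 - 0.9485294))
patch_cov = (patch_lines - missing) / patch_lines

print(base, head, patch_lines, round(100 * patch_cov, 5))
```

Under these assumptions the counts give 82.23% and 82.28% for base and head, and the patch figure corresponds to 387 of 408 changed lines being covered.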

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 2

🧹 Nitpick comments (5)
source/tests/consistent/descriptor/test_dpa4.py (1)

94-102: ⚡ Quick win

Add a curated SO3 grid_mlp consistency case.

The matrix tests SO3 grid paths and grid_mlp separately, but the flagship path combines ffn_so3_grid/message_node_so3 with grid_mlp; because positive grid_branch takes precedence, add a combined case with grid_branch=[0, 0, 0].

Suggested curated case
     # grid MLP point-wise op (op_type='mlp'); needs grid_branch=0 on the path,
     # since positive grid_branch entries take precedence over grid_mlp
     dpa4_case(grid_branch=[0, 0, 0], grid_mlp=True),
+    # flagship SO(3) grid + GridMLP combination
+    dpa4_case(
+        ffn_so3_grid=True,
+        message_node_so3=True,
+        grid_branch=[0, 0, 0],
+        grid_mlp=True,
+    ),
 )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@source/tests/consistent/descriptor/test_dpa4.py` around lines 94 - 102, The
test suite needs to add a curated case that combines the SO3 grid paths with
grid_mlp to match the flagship path behavior. After the existing dpa4_case calls
in the test_dpa4.py file, add a new test case that enables both ffn_so3_grid and
message_node_so3 together with grid_branch=[0, 0, 0] and grid_mlp=True to ensure
the combined path where SO3 Wigner-D grids work together with grid MLP is
properly validated.
source/tests/common/dpmodel/test_dpa4_grid_wiring.py (1)

231-245: ⚡ Quick win

Add full-descriptor parity for node_wise_so3.

node_wise_so3=True only has a construct/run smoke test, while the PT weight-copy parity matrix covers node_wise_s2 but not the SO3 cross-mode node-wise wiring. Add the analogous parity case so frame-expand/contract and m-major node-wise wiring regressions are caught at descriptor level.

Suggested parity case
     def test_parity_node_wise_s2(self) -> None:
         pt_mod, dp_mod = self._build_pair(node_wise_s2=True)
         self._assert_parity(pt_mod, dp_mod)
+
+    def test_parity_node_wise_so3(self) -> None:
+        pt_mod, dp_mod = self._build_pair(node_wise_so3=True)
+        self._assert_parity(pt_mod, dp_mod)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@source/tests/common/dpmodel/test_dpa4_grid_wiring.py` around lines 231 - 245,
Add a new parity test method for node_wise_so3 to match the existing
test_parity_node_wise_s2 method. Create a test method called
test_parity_node_wise_so3 that follows the same pattern as
test_parity_node_wise_s2 but passes node_wise_so3=True instead of
node_wise_s2=True to the _build_pair method, then calls _assert_parity on the
resulting pt_mod and dp_mod. This ensures the SO3 cross-mode node-wise wiring is
covered by the full descriptor parity test matrix.
source/tests/pt/model/test_dpa4_dpmodel_parity.py (1)

3244-3253: ⚡ Quick win

Add an explicit ffn_so3_grid guard assertion for lebedev_quadrature=False.

This guard test currently pins only the S2 branch; adding the SO3 branch closes the not-ported-path coverage for the new FFN SO3 route.

Suggested patch
     with pytest.raises(NotImplementedError, match="lebedev_quadrature"):
         DPFFN(
             **self._ffn_kwargs(s2_activation=True, lebedev_quadrature=False),
             precision="float64",
         )
+    with pytest.raises(NotImplementedError, match="lebedev_quadrature"):
+        DPFFN(
+            **self._ffn_kwargs(ffn_so3_grid=True, lebedev_quadrature=False),
+            precision="float64",
+        )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@source/tests/pt/model/test_dpa4_dpmodel_parity.py` around lines 3244 - 3253,
The test_ffn_guards method currently only covers the S2 branch when
lebedev_quadrature is False. Add a second pytest.raises assertion block that
exercises the SO3 grid branch (ffn_so3_grid=True) with lebedev_quadrature=False,
so that the not-implemented paths of both the S2 and SO3 routes are covered. Use
the same DPFFN initialization pattern with the _ffn_kwargs helper method, and
verify in both cases that a NotImplementedError whose message mentions
lebedev_quadrature is raised.
source/tests/common/dpmodel/test_dpa4_grid_mlp.py (1)

177-203: ⚡ Quick win

Make mlp_bias explicit in S2GridNet MLP tests.

These tests currently rely on constructor defaults while _copy_s2gridnet_mlp only copies weight tensors. Setting mlp_bias=False explicitly avoids brittle, default-dependent parity behavior.

Suggested change
     pt_net = PTS2GridNet(
         lmax=lmax,
         channels=channels,
         n_focus=n_focus,
         mode="self",
         op_type="mlp",
+        mlp_bias=False,
         dtype=torch.float64,
         layout="ndfc",
         coefficient_layout="packed",
         grid_method="lebedev",
         trainable=False,
         seed=17 + lmax,
     ).to("cpu")
     dp_net = S2GridNet(
         lmax=lmax,
         channels=channels,
         n_focus=n_focus,
         mode="self",
         op_type="mlp",
+        mlp_bias=False,
         precision="float64",
         layout="ndfc",
         coefficient_layout="packed",
         grid_method="lebedev",
         trainable=False,
         seed=17 + lmax,
     )
@@
     dp_net = S2GridNet(
         lmax=lmax,
         channels=channels,
         n_focus=n_focus,
         mode="self",
         op_type="mlp",
+        mlp_bias=False,
         precision="float64",
         layout="ndfc",
         coefficient_layout="packed",
         grid_method="lebedev",
         trainable=False,
         seed=31 + lmax,
     )

Also applies to: 225-237

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@source/tests/common/dpmodel/test_dpa4_grid_mlp.py` around lines 177 - 203,
The S2GridNet constructor calls in the test are relying on default values for
the mlp_bias parameter while _copy_s2gridnet_mlp only copies weight tensors
without biases, creating brittle default-dependent behavior. Add mlp_bias=False
explicitly as a parameter to both S2GridNet constructor calls (the one around
line 177 and the one around line 225) to make the intent clear and avoid relying
on constructor defaults.
source/tests/common/dpmodel/test_dpa4_frame_mixers.py (1)

184-205: ⚡ Quick win

Cover _build_frame_degree_index’s untested branches.

test_degree_index currently validates only packed with mmax=lmax, but the new implementation also branches on m_major and truncated packed (mmax<lmax). Add those cases so regressions in the new paths are caught.

Suggested test expansion
-@pytest.mark.parametrize("lmax", [1, 2, 3])  # max angular momentum
-def test_degree_index(lmax) -> None:
+@pytest.mark.parametrize(
+    "lmax,mmax,coefficient_layout",
+    [
+        (1, 1, "packed"),
+        (2, 2, "packed"),
+        (3, 1, "packed"),   # truncated packed path
+        (3, 1, "m_major"),  # m_major path
+    ],
+)
+def test_degree_index(lmax, mmax, coefficient_layout) -> None:
     from deepmd.pt.model.descriptor.sezm_nn.grid_net import (
         _build_frame_degree_index as pt_build_frame_degree_index,
     )
 
     dp_idx = _build_frame_degree_index(
-        lmax=lmax, mmax=lmax, coefficient_layout="packed"
+        lmax=lmax, mmax=mmax, coefficient_layout=coefficient_layout
     )
     pt_idx = (
-        pt_build_frame_degree_index(lmax=lmax, mmax=lmax, coefficient_layout="packed")
+        pt_build_frame_degree_index(
+            lmax=lmax, mmax=mmax, coefficient_layout=coefficient_layout
+        )
         .detach()
         .cpu()
         .numpy()
     )
-    assert dp_idx.shape == ((lmax + 1) ** 2,)
+    assert dp_idx.shape == pt_idx.shape
     np.testing.assert_array_equal(dp_idx, pt_idx)
-    # each (l, m) row maps to degree l: row d has degree dp_idx[d]
-    expected = np.concatenate(
-        [np.full(2 * l + 1, l, dtype=np.int64) for l in range(lmax + 1)]
-    )
-    np.testing.assert_array_equal(dp_idx, expected)
+    if coefficient_layout == "packed" and mmax == lmax:
+        expected = np.concatenate(
+            [np.full(2 * l + 1, l, dtype=np.int64) for l in range(lmax + 1)]
+        )
+        np.testing.assert_array_equal(dp_idx, expected)
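
For reference, the packed full-mmax mapping that the retained `expected` check encodes can be written out on its own. This is a minimal sketch of that one case only; `packed_degree_index` is a hypothetical name, and the truncated (mmax<lmax) and m_major layouts are not modeled here.

```python
import numpy as np


def packed_degree_index(lmax: int) -> np.ndarray:
    """Degree l for each packed (l, m) row when mmax == lmax: l repeated 2*l+1 times."""
    return np.concatenate(
        [np.full(2 * l + 1, l, dtype=np.int64) for l in range(lmax + 1)]
    )


idx = packed_degree_index(2)
print(idx.tolist())  # [0, 1, 1, 1, 2, 2, 2, 2, 2]
```

The length is `(lmax + 1) ** 2`, which is why the original shape assertion only holds for the untruncated packed case and is replaced by a comparison against the pt shape in the suggested patch.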
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@source/tests/common/dpmodel/test_dpa4_frame_mixers.py` around lines 184 -
205, The test_degree_index function currently only validates the packed
coefficient layout with mmax equal to lmax, leaving untested branches for the
m_major layout and truncated packed cases where mmax is less than lmax. Extend
the test by adding parametrization for different coefficient_layout values
(packed and m_major) and varying mmax values (including cases where mmax is less
than lmax). For each combination, call both _build_frame_degree_index and
pt_build_frame_degree_index with the respective parameters, then assert that the
shapes match and that the numpy arrays are equal, ensuring regression detection
across all code paths.
ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: c8ddbcf2-2307-4c9c-aab7-09c1b7a46e8c

📥 Commits

Reviewing files that changed from the base of the PR and between 4a552e3 and 401890f.

📒 Files selected for processing (14)
  • deepmd/dpmodel/descriptor/dpa4_nn/ffn.py
  • deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py
  • deepmd/dpmodel/descriptor/dpa4_nn/projection.py
  • deepmd/dpmodel/descriptor/dpa4_nn/so2.py
  • source/tests/common/dpmodel/test_descrpt_dpa4.py
  • source/tests/common/dpmodel/test_dpa4_frame_mixers.py
  • source/tests/common/dpmodel/test_dpa4_grid_mlp.py
  • source/tests/common/dpmodel/test_dpa4_grid_wiring.py
  • source/tests/common/dpmodel/test_dpa4_gridnet_cross.py
  • source/tests/common/dpmodel/test_dpa4_so3_grid_utils.py
  • source/tests/common/dpmodel/test_dpa4_so3_gridnet.py
  • source/tests/common/dpmodel/test_dpa4_so3_projector.py
  • source/tests/consistent/descriptor/test_dpa4.py
  • source/tests/pt/model/test_dpa4_dpmodel_parity.py
💤 Files with no reviewable changes (1)
  • source/tests/common/dpmodel/test_descrpt_dpa4.py

Comment on lines +1035 to +1038
        if self.mlp_bias:
            self.scalar_gate.bias = np.asarray(
                variables["scalar_gate.bias"], dtype=prec
            ).reshape(self.scalar_gate.bias.shape)

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Validate scalar_gate.bias shape before reshape in _load_variables.

scalar_gate.bias is reshaped without an explicit shape check, so a same-size but wrong-layout bias can be accepted and silently corrupt deserialized parameters.

Proposed fix
         if self.mlp_bias:
-            self.scalar_gate.bias = np.asarray(
-                variables["scalar_gate.bias"], dtype=prec
-            ).reshape(self.scalar_gate.bias.shape)
+            bias = np.asarray(variables["scalar_gate.bias"], dtype=prec)
+            if bias.shape != self.scalar_gate.bias.shape:
+                raise ValueError(
+                    f"scalar_gate.bias shape {bias.shape} does not match "
+                    f"the expected shape {self.scalar_gate.bias.shape}"
+                )
+            self.scalar_gate.bias = bias
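
The hazard can be reproduced in isolation. In the sketch below the shapes are made up and `load_checked` is a hypothetical helper, not the `_load_variables` code: it shows that `reshape` accepts any array with the right element count, while an explicit shape comparison rejects a wrong layout.

```python
import numpy as np

expected_shape = (2, 3)
# Same number of elements as expected, but stored with a different layout.
wrong_layout = np.arange(6.0).reshape(3, 2)

# reshape only checks the element count, so this succeeds without any error.
silently_accepted = wrong_layout.reshape(expected_shape)


def load_checked(value, shape):
    """Hypothetical loader: reject arrays whose shape differs from the target."""
    arr = np.asarray(value)
    if arr.shape != tuple(shape):
        raise ValueError(f"shape {arr.shape} does not match expected {tuple(shape)}")
    return arr
```

Calling `load_checked(wrong_layout, expected_shape)` raises `ValueError`, whereas the reshape path returns an array whose elements sit in positions the producer never intended.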
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py` around lines 1035 - 1038, The
code reshapes scalar_gate.bias without validating that the loaded variable has
the correct shape, which could silently accept wrong-layout biases with the same
total size and corrupt parameters. Add an explicit shape check before the
reshape operation in the _load_variables method to ensure the loaded
variables["scalar_gate.bias"] matches the expected shape of
self.scalar_gate.bias.shape, and raise an appropriate error if the shapes do not
match.

                "lmax": self.lmax,
                "mmax": self.mmax,
                "kmax": self.kmax,
                "precision": np.dtype(PRECISION_DICT[self.precision]).name,

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Normalize precision lookup in SO3 serialization.

Initialization accepts mixed-case precision names (.lower()), but serialization uses PRECISION_DICT[self.precision] directly and can raise KeyError for inputs like "Float64".

Proposed fix
-                "precision": np.dtype(PRECISION_DICT[self.precision]).name,
+                "precision": np.dtype(PRECISION_DICT[self.precision.lower()]).name,
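
The failure mode and the fix can be shown with a small stand-in table. The `PRECISION_DICT` below is a two-entry placeholder defined for this sketch, not the real deepmd mapping, and `precision_name` is a hypothetical helper.

```python
import numpy as np

# Placeholder lookup table with lowercase keys only (an assumption for this sketch).
PRECISION_DICT = {"float32": np.float32, "float64": np.float64}


def precision_name(precision: str) -> str:
    """Normalize case before the lookup so mixed-case names serialize cleanly."""
    return np.dtype(PRECISION_DICT[precision.lower()]).name


print(precision_name("Float64"))  # float64
```

A direct `PRECISION_DICT["Float64"]` lookup on this table raises `KeyError`, which mirrors the mismatch between case-insensitive construction and case-sensitive serialization described above.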
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/dpmodel/descriptor/dpa4_nn/projection.py` at line 436, The precision
lookup in the SO3 serialization code uses PRECISION_DICT[self.precision]
directly without normalizing the case, while initialization accepts mixed-case
precision names through .lower(). This inconsistency causes KeyError for inputs
like "Float64". Fix this by normalizing self.precision to lowercase when
accessing PRECISION_DICT, changing PRECISION_DICT[self.precision] to
PRECISION_DICT[self.precision.lower()] to match the initialization behavior and
ensure consistent handling of mixed-case precision names.

@wanghan-iapcm

Collaborator Author

Superseded by #5555. This PR ported the pre-#5552 grid design; after #5517/#5552 refactored the dpmodel/pt grid ops (_project_frames/GridProduct/ops-on-coefficients), #5555 re-derives the SO3 completion on top of master's current structure to keep dpmodel a faithful mirror of pt. SO3GridProjector + tests + the CUDA device-pin fix were carried over.

njzjz pushed a commit to njzjz/deepmd-kit that referenced this pull request Jun 19, 2026
… pt) (deepmodeling#5555)

## What

Completes the DPA4/SeZM **SO3 grid projection** port to the dpmodel
backend so it faithfully **mirrors master's current pt**
`sezm_nn/grid_net.py`. After this, the flagship
`examples/water/dpa4/input.json` (which sets `ffn_so3_grid=true`,
`message_node_so3=true`, `grid_mlp`) runs on dpmodel/pt_expt.

Builds **on top of** the S2-grid base that deepmodeling#5517/deepmodeling#5552 landed
(`GridProduct`/`GridMLP`/`op_type='mlp'`, the `_project_frames`
refactor). Master's dpmodel was the S2 (`n_frames==1`) slice with
SO3/cross-mode fail-fast guarded; this PR generalizes those ops to
frame-aware (`n_frames>1`) + cross-mode and adds the missing SO3 pieces
— matching current pt exactly (single source of truth: dpmodel == pt).

Supersedes deepmodeling#5547 (which ported the *pre*-deepmodeling#5552 design and went
structurally stale).

## Changes (all mirror current pt)

- **`grid_net.py`**: add `_project_frames`; generalize
`GridMLP`/`GridBranch` to frame-aware (`n_frames`); generalize
`BaseGridNet` (un-guard `mode='cross'`, `layout='flat'`,
`residual_scale_init`, `n_frames>1`; frame-axis to/from-grid via
`xp.matmul`+reshape); add
`FrameContract`/`FrameExpand`/`_build_frame_degree_index`; add
`SO3GridNet` (self+cross).
- **`projection.py`**: add `SO3GridProjector` (Wigner-D quadrature) +
`resolve_so3_grid`/`_build_so3_frame_set`.
- **`ffn.py`**: un-guard `ffn_so3_grid` → `SO3GridNet(mode='self')`.
- **`so2.py`**: un-guard `node_wise_{s2,so3}`/`message_node_{s2,so3}` →
cross-mode grid products, applied in `call` + round-tripped in
serialize.

## Validation

- Component parity vs pt (weight-copied fp64): `_project_frames`,
`GridMLP`/`GridBranch` (incl. S2 byte-identical regression),
`BaseGridNet` cross/flat/residual, `FrameContract`/`FrameExpand`,
`SO3GridProjector` matrices, `SO3GridNet` self+cross (op_type
glu/mlp/branch, kmax 1&2) — all **1e-12**; rotation equivariance
**1e-10**.
- **fp32** grid-path parity at the computation-in-fp32 budget (actual
diffs 1e-6–1e-8 ≪ 1e-4).
- Full-descriptor pt→dpmodel via
`DescrptDPA4.deserialize(pt.serialize())` on the example config (lmax=3,
mmax=1) — **~1e-14** — proving `dp convert-backend` schema interop.
- Permutation-invariance + masked-edge no-op.
- Cross-backend consistency rows (pt vs dpmodel **and pt_expt**,
mixed_types) for ffn_so3_grid / message_node_so3 / both / grid_mlp.
- **Verified on remote GPU (Tesla T4):** 617 (grid+parity+pt_expt) + 50
(consistency) pass, no CUDA device errors.

pt_expt forward works today via auto-wrap (consistency + descriptor trio
green) — no explicit registration needed.

## Known limitations

- pt_expt **training** Parameter-promotion for the new weight-bearing
grid classes, `torch.export`/AOTI grid coverage, training e2e, argcheck
`doc_only_pt_supported` removal, and freeze/DeepEval are a **follow-up
PR**.
- `grid_method='e3nn'` (non-Lebedev product grid) stays fail-fast
(Lebedev-only, per parent design).
- fp32 grid paths use a ~1e-4 budget by design; fp64 is the parity
reference.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Added SO(3) grid projection support and frame-aware grid networks for
DPA4 descriptors, including SO(3)-based FFN and improved cross-mode
grid-product wiring.
* Extended grid modules to support multi-frame configurations with
per-degree frame mixing, and added full SO(3) projector/network
serialization.
* **Bug Fixes**
* Enabled previously disabled/unsupported DPA4 SO(2) convolution
cross-mode SO(3)/S2 grid products.
* **Documentation**
* Updated DPA4 porting-layer documentation to clarify supported
configuration flags.
* **Tests**
* Added/expanded parity, equivariance, serialization/roundtrip, and
torch-namespace compatibility tests for the new SO(3) and frame-aware
paths.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>