feat(dpmodel): port DPA4/SeZM SO3 grid projection (PR-A: dpmodel core)#5547
feat(dpmodel): port DPA4/SeZM SO3 grid projection (PR-A: dpmodel core)#5547wanghan-iapcm wants to merge 13 commits into
Conversation
pt's GridMLP hard-codes bias=False on all three channel-linear projections; the net-level mlp_bias only affects the scalar gate. Drop the threaded mlp_bias from GridMLP so dpmodel matches pt exactly, avoiding a divergent untested branch and a pt->dpmodel deserialize mismatch.
…ual_scale for DPA4
…riance/fp32 coverage
for more information, see https://pre-commit.ci
📝 WalkthroughWalkthroughPorts the DPA4 SO(3) Wigner-D grid-net infrastructure from PyTorch to dpmodel. Adds ChangesDPA4 SO3/S2 Grid-Net Port
Estimated code review effort🎯 5 (Critical) | ⏱️ ~120 minutes Possibly related PRs
Suggested labels
Suggested reviewers
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
|
Caution Failed to replace (edit) comment. This is likely due to insufficient permissions or the comment being deleted. Error details |
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #5547 +/- ##
==========================================
+ Coverage 82.23% 82.28% +0.05%
==========================================
Files 894 894
Lines 101992 102392 +400
Branches 4273 4273
==========================================
+ Hits 83873 84256 +383
- Misses 16817 16833 +16
- Partials 1302 1303 +1 ☔ View full report in Codecov by Harness. 🚀 New features to boost your workflow:
|
…sts (CUDA CI fix)
There was a problem hiding this comment.
Actionable comments posted: 2
🧹 Nitpick comments (5)
source/tests/consistent/descriptor/test_dpa4.py (1)
94-102: ⚡ Quick winAdd a curated SO3
grid_mlpconsistency case.The matrix tests SO3 grid paths and
grid_mlpseparately, but the flagship path combinesffn_so3_grid/message_node_so3withgrid_mlp; because positivegrid_branchtakes precedence, add a combined case withgrid_branch=[0, 0, 0].Suggested curated case
# grid MLP point-wise op (op_type='mlp'); needs grid_branch=0 on the path, # since positive grid_branch entries take precedence over grid_mlp dpa4_case(grid_branch=[0, 0, 0], grid_mlp=True), + # flagship SO(3) grid + GridMLP combination + dpa4_case( + ffn_so3_grid=True, + message_node_so3=True, + grid_branch=[0, 0, 0], + grid_mlp=True, + ), )🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@source/tests/consistent/descriptor/test_dpa4.py` around lines 94 - 102, The test suite needs to add a curated case that combines the SO3 grid paths with grid_mlp to match the flagship path behavior. After the existing dpa4_case calls in the test_dpa4.py file, add a new test case that enables both ffn_so3_grid and message_node_so3 together with grid_branch=[0, 0, 0] and grid_mlp=True to ensure the combined path where SO3 Wigner-D grids work together with grid MLP is properly validated.source/tests/common/dpmodel/test_dpa4_grid_wiring.py (1)
231-245: ⚡ Quick winAdd full-descriptor parity for
node_wise_so3.
node_wise_so3=Trueonly has a construct/run smoke test, while the PT weight-copy parity matrix coversnode_wise_s2but not the SO3 cross-mode node-wise wiring. Add the analogous parity case so frame-expand/contract and m-major node-wise wiring regressions are caught at descriptor level.Suggested parity case
def test_parity_node_wise_s2(self) -> None: pt_mod, dp_mod = self._build_pair(node_wise_s2=True) self._assert_parity(pt_mod, dp_mod) + + def test_parity_node_wise_so3(self) -> None: + pt_mod, dp_mod = self._build_pair(node_wise_so3=True) + self._assert_parity(pt_mod, dp_mod)🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@source/tests/common/dpmodel/test_dpa4_grid_wiring.py` around lines 231 - 245, Add a new parity test method for node_wise_so3 to match the existing test_parity_node_wise_s2 method. Create a test method called test_parity_node_wise_so3 that follows the same pattern as test_parity_node_wise_s2 but passes node_wise_so3=True instead of node_wise_s2=True to the _build_pair method, then calls _assert_parity on the resulting pt_mod and dp_mod. This ensures the SO3 cross-mode node-wise wiring is covered by the full descriptor parity test matrix.source/tests/pt/model/test_dpa4_dpmodel_parity.py (1)
3244-3253: ⚡ Quick winAdd an explicit
ffn_so3_gridguard assertion forlebedev_quadrature=False.This guard test currently pins only the S2 branch; adding the SO3 branch closes the not-ported-path coverage for the new FFN SO3 route.
Suggested patch
with pytest.raises(NotImplementedError, match="lebedev_quadrature"): DPFFN( **self._ffn_kwargs(s2_activation=True, lebedev_quadrature=False), precision="float64", ) + with pytest.raises(NotImplementedError, match="lebedev_quadrature"): + DPFFN( + **self._ffn_kwargs(ffn_so3_grid=True, lebedev_quadrature=False), + precision="float64", + )🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@source/tests/pt/model/test_dpa4_dpmodel_parity.py` around lines 3244 - 3253, The test_ffn_guards method currently only covers the S2 branch when lebedev_quadrature is False. Add an additional pytest.raises assertion block that tests the SO3 grid branch (ffn_so3_grid enabled) with lebedev_quadrature=False set to False to ensure complete coverage of the not-implemented paths for both the S2 and SO3 routes. Use the same DPFFN class initialization pattern with the _ffn_kwargs helper method, ensuring both test cases verify that the NotImplementedError with lebedev_quadrature in the error message is properly raised.source/tests/common/dpmodel/test_dpa4_grid_mlp.py (1)
177-203: ⚡ Quick winMake
mlp_biasexplicit in S2GridNet MLP tests.These tests currently rely on constructor defaults while
_copy_s2gridnet_mlponly copies weight tensors. Settingmlp_bias=Falseexplicitly avoids brittle, default-dependent parity behavior.Suggested change
pt_net = PTS2GridNet( lmax=lmax, channels=channels, n_focus=n_focus, mode="self", op_type="mlp", + mlp_bias=False, dtype=torch.float64, layout="ndfc", coefficient_layout="packed", grid_method="lebedev", trainable=False, seed=17 + lmax, ).to("cpu") dp_net = S2GridNet( lmax=lmax, channels=channels, n_focus=n_focus, mode="self", op_type="mlp", + mlp_bias=False, precision="float64", layout="ndfc", coefficient_layout="packed", grid_method="lebedev", trainable=False, seed=17 + lmax, ) @@ dp_net = S2GridNet( lmax=lmax, channels=channels, n_focus=n_focus, mode="self", op_type="mlp", + mlp_bias=False, precision="float64", layout="ndfc", coefficient_layout="packed", grid_method="lebedev", trainable=False, seed=31 + lmax, )Also applies to: 225-237
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@source/tests/common/dpmodel/test_dpa4_grid_mlp.py` around lines 177 - 203, The S2GridNet constructor calls in the test are relying on default values for the mlp_bias parameter while _copy_s2gridnet_mlp only copies weight tensors without biases, creating brittle default-dependent behavior. Add mlp_bias=False explicitly as a parameter to both S2GridNet constructor calls (the one around line 177 and the one around line 225) to make the intent clear and avoid relying on constructor defaults.source/tests/common/dpmodel/test_dpa4_frame_mixers.py (1)
184-205: ⚡ Quick winCover
_build_frame_degree_index’s untested branches.
test_degree_indexcurrently validates onlypackedwithmmax=lmax, but the new implementation also branches onm_majorand truncatedpacked(mmax<lmax). Add those cases so regressions in the new paths are caught.Suggested test expansion
-@pytest.mark.parametrize("lmax", [1, 2, 3]) # max angular momentum -def test_degree_index(lmax) -> None: +@pytest.mark.parametrize( + "lmax,mmax,coefficient_layout", + [ + (1, 1, "packed"), + (2, 2, "packed"), + (3, 1, "packed"), # truncated packed path + (3, 1, "m_major"), # m_major path + ], +) +def test_degree_index(lmax, mmax, coefficient_layout) -> None: from deepmd.pt.model.descriptor.sezm_nn.grid_net import ( _build_frame_degree_index as pt_build_frame_degree_index, ) dp_idx = _build_frame_degree_index( - lmax=lmax, mmax=lmax, coefficient_layout="packed" + lmax=lmax, mmax=mmax, coefficient_layout=coefficient_layout ) pt_idx = ( - pt_build_frame_degree_index(lmax=lmax, mmax=lmax, coefficient_layout="packed") + pt_build_frame_degree_index( + lmax=lmax, mmax=mmax, coefficient_layout=coefficient_layout + ) .detach() .cpu() .numpy() ) - assert dp_idx.shape == ((lmax + 1) ** 2,) + assert dp_idx.shape == pt_idx.shape np.testing.assert_array_equal(dp_idx, pt_idx) - # each (l, m) row maps to degree l: row d has degree dp_idx[d] - expected = np.concatenate( - [np.full(2 * l + 1, l, dtype=np.int64) for l in range(lmax + 1)] - ) - np.testing.assert_array_equal(dp_idx, expected) + if coefficient_layout == "packed" and mmax == lmax: + expected = np.concatenate( + [np.full(2 * l + 1, l, dtype=np.int64) for l in range(lmax + 1)] + ) + np.testing.assert_array_equal(dp_idx, expected)🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@source/tests/common/dpmodel/test_dpa4_frame_mixers.py` around lines 184 - 205, The test_degree_index function currently only validates the packed coefficient layout with mmax equal to lmax, leaving untested branches for the m_major layout and truncated packed cases where mmax is less than lmax. Extend the test by adding parametrization for different coefficient_layout values (packed and m_major) and varying mmax values (including cases where mmax is less than lmax). For each combination, call both _build_frame_degree_index and pt_build_frame_degree_index with the respective parameters, then assert that the shapes match and that the numpy arrays are equal, ensuring regression detection across all code paths.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py`:
- Around line 1035-1038: The code reshapes scalar_gate.bias without validating
that the loaded variable has the correct shape, which could silently accept
wrong-layout biases with the same total size and corrupt parameters. Add an
explicit shape check before the reshape operation in the _load_variables method
to ensure the loaded variables["scalar_gate.bias"] matches the expected shape of
self.scalar_gate.bias.shape, and raise an appropriate error if the shapes do not
match.
In `@deepmd/dpmodel/descriptor/dpa4_nn/projection.py`:
- Line 436: The precision lookup in the SO3 serialization code uses
PRECISION_DICT[self.precision] directly without normalizing the case, while
initialization accepts mixed-case precision names through .lower(). This
inconsistency causes KeyError for inputs like "Float64". Fix this by normalizing
self.precision to lowercase when accessing PRECISION_DICT, changing
PRECISION_DICT[self.precision] to PRECISION_DICT[self.precision.lower()] to
match the initialization behavior and ensure consistent handling of mixed-case
precision names.
---
Nitpick comments:
In `@source/tests/common/dpmodel/test_dpa4_frame_mixers.py`:
- Around line 184-205: The test_degree_index function currently only validates
the packed coefficient layout with mmax equal to lmax, leaving untested branches
for the m_major layout and truncated packed cases where mmax is less than lmax.
Extend the test by adding parametrization for different coefficient_layout
values (packed and m_major) and varying mmax values (including cases where mmax
is less than lmax). For each combination, call both _build_frame_degree_index
and pt_build_frame_degree_index with the respective parameters, then assert that
the shapes match and that the numpy arrays are equal, ensuring regression
detection across all code paths.
In `@source/tests/common/dpmodel/test_dpa4_grid_mlp.py`:
- Around line 177-203: The S2GridNet constructor calls in the test are relying
on default values for the mlp_bias parameter while _copy_s2gridnet_mlp only
copies weight tensors without biases, creating brittle default-dependent
behavior. Add mlp_bias=False explicitly as a parameter to both S2GridNet
constructor calls (the one around line 177 and the one around line 225) to make
the intent clear and avoid relying on constructor defaults.
In `@source/tests/common/dpmodel/test_dpa4_grid_wiring.py`:
- Around line 231-245: Add a new parity test method for node_wise_so3 to match
the existing test_parity_node_wise_s2 method. Create a test method called
test_parity_node_wise_so3 that follows the same pattern as
test_parity_node_wise_s2 but passes node_wise_so3=True instead of
node_wise_s2=True to the _build_pair method, then calls _assert_parity on the
resulting pt_mod and dp_mod. This ensures the SO3 cross-mode node-wise wiring is
covered by the full descriptor parity test matrix.
In `@source/tests/consistent/descriptor/test_dpa4.py`:
- Around line 94-102: The test suite needs to add a curated case that combines
the SO3 grid paths with grid_mlp to match the flagship path behavior. After the
existing dpa4_case calls in the test_dpa4.py file, add a new test case that
enables both ffn_so3_grid and message_node_so3 together with grid_branch=[0, 0,
0] and grid_mlp=True to ensure the combined path where SO3 Wigner-D grids work
together with grid MLP is properly validated.
In `@source/tests/pt/model/test_dpa4_dpmodel_parity.py`:
- Around line 3244-3253: The test_ffn_guards method currently only covers the S2
branch when lebedev_quadrature is False. Add an additional pytest.raises
assertion block that tests the SO3 grid branch (ffn_so3_grid enabled) with
lebedev_quadrature=False set to False to ensure complete coverage of the
not-implemented paths for both the S2 and SO3 routes. Use the same DPFFN class
initialization pattern with the _ffn_kwargs helper method, ensuring both test
cases verify that the NotImplementedError with lebedev_quadrature in the error
message is properly raised.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: c8ddbcf2-2307-4c9c-aab7-09c1b7a46e8c
📒 Files selected for processing (14)
deepmd/dpmodel/descriptor/dpa4_nn/ffn.pydeepmd/dpmodel/descriptor/dpa4_nn/grid_net.pydeepmd/dpmodel/descriptor/dpa4_nn/projection.pydeepmd/dpmodel/descriptor/dpa4_nn/so2.pysource/tests/common/dpmodel/test_descrpt_dpa4.pysource/tests/common/dpmodel/test_dpa4_frame_mixers.pysource/tests/common/dpmodel/test_dpa4_grid_mlp.pysource/tests/common/dpmodel/test_dpa4_grid_wiring.pysource/tests/common/dpmodel/test_dpa4_gridnet_cross.pysource/tests/common/dpmodel/test_dpa4_so3_grid_utils.pysource/tests/common/dpmodel/test_dpa4_so3_gridnet.pysource/tests/common/dpmodel/test_dpa4_so3_projector.pysource/tests/consistent/descriptor/test_dpa4.pysource/tests/pt/model/test_dpa4_dpmodel_parity.py
💤 Files with no reviewable changes (1)
- source/tests/common/dpmodel/test_descrpt_dpa4.py
| if self.mlp_bias: | ||
| self.scalar_gate.bias = np.asarray( | ||
| variables["scalar_gate.bias"], dtype=prec | ||
| ).reshape(self.scalar_gate.bias.shape) |
There was a problem hiding this comment.
Validate scalar_gate.bias shape before reshape in _load_variables.
scalar_gate.bias is reshaped without an explicit shape check, so a same-size but wrong-layout bias can be accepted and silently corrupt deserialized parameters.
Proposed fix
if self.mlp_bias:
- self.scalar_gate.bias = np.asarray(
- variables["scalar_gate.bias"], dtype=prec
- ).reshape(self.scalar_gate.bias.shape)
+ bias = np.asarray(variables["scalar_gate.bias"], dtype=prec)
+ if bias.shape != self.scalar_gate.bias.shape:
+ raise ValueError(
+ f"scalar_gate.bias shape {bias.shape} does not match "
+ f"the expected shape {self.scalar_gate.bias.shape}"
+ )
+ self.scalar_gate.bias = bias🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py` around lines 1035 - 1038, The
code reshapes scalar_gate.bias without validating that the loaded variable has
the correct shape, which could silently accept wrong-layout biases with the same
total size and corrupt parameters. Add an explicit shape check before the
reshape operation in the _load_variables method to ensure the loaded
variables["scalar_gate.bias"] matches the expected shape of
self.scalar_gate.bias.shape, and raise an appropriate error if the shapes do not
match.
| "lmax": self.lmax, | ||
| "mmax": self.mmax, | ||
| "kmax": self.kmax, | ||
| "precision": np.dtype(PRECISION_DICT[self.precision]).name, |
There was a problem hiding this comment.
Normalize precision lookup in SO3 serialization.
Initialization accepts mixed-case precision names (.lower()), but serialization uses PRECISION_DICT[self.precision] directly and can raise KeyError for inputs like "Float64".
Proposed fix
- "precision": np.dtype(PRECISION_DICT[self.precision]).name,
+ "precision": np.dtype(PRECISION_DICT[self.precision.lower()]).name,📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| "precision": np.dtype(PRECISION_DICT[self.precision]).name, | |
| "precision": np.dtype(PRECISION_DICT[self.precision.lower()]).name, |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@deepmd/dpmodel/descriptor/dpa4_nn/projection.py` at line 436, The precision
lookup in the SO3 serialization code uses PRECISION_DICT[self.precision]
directly without normalizing the case, while initialization accepts mixed-case
precision names through .lower(). This inconsistency causes KeyError for inputs
like "Float64". Fix this by normalizing self.precision to lowercase when
accessing PRECISION_DICT, changing PRECISION_DICT[self.precision] to
PRECISION_DICT[self.precision.lower()] to match the initialization behavior and
ensure consistent handling of mixed-case precision names.
|
Superseded by #5555. This PR ported the pre-#5552 grid design; after #5517/#5552 refactored the dpmodel/pt grid ops ( |
… pt) (deepmodeling#5555) ## What Completes the DPA4/SeZM **SO3 grid projection** port to the dpmodel backend so it faithfully **mirrors master's current pt** `sezm_nn/grid_net.py`. After this, the flagship `examples/water/dpa4/input.json` (which sets `ffn_so3_grid=true`, `message_node_so3=true`, `grid_mlp`) runs on dpmodel/pt_expt. Builds **on top of** the S2-grid base that deepmodeling#5517/deepmodeling#5552 landed (`GridProduct`/`GridMLP`/`op_type='mlp'`, the `_project_frames` refactor). Master's dpmodel was the S2 (`n_frames==1`) slice with SO3/cross-mode fail-fast guarded; this PR generalizes those ops to frame-aware (`n_frames>1`) + cross-mode and adds the missing SO3 pieces — matching current pt exactly (single source of truth: dpmodel == pt). Supersedes deepmodeling#5547 (which ported the *pre*-deepmodeling#5552 design and went structurally stale). ## Changes (all mirror current pt) - **`grid_net.py`**: add `_project_frames`; generalize `GridMLP`/`GridBranch` to frame-aware (`n_frames`); generalize `BaseGridNet` (un-guard `mode='cross'`, `layout='flat'`, `residual_scale_init`, `n_frames>1`; frame-axis to/from-grid via `xp.matmul`+reshape); add `FrameContract`/`FrameExpand`/`_build_frame_degree_index`; add `SO3GridNet` (self+cross). - **`projection.py`**: add `SO3GridProjector` (Wigner-D quadrature) + `resolve_so3_grid`/`_build_so3_frame_set`. - **`ffn.py`**: un-guard `ffn_so3_grid` → `SO3GridNet(mode='self')`. - **`so2.py`**: un-guard `node_wise_{s2,so3}`/`message_node_{s2,so3}` → cross-mode grid products, applied in `call` + round-tripped in serialize. ## Validation - Component parity vs pt (weight-copied fp64): `_project_frames`, `GridMLP`/`GridBranch` (incl. S2 byte-identical regression), `BaseGridNet` cross/flat/residual, `FrameContract`/`FrameExpand`, `SO3GridProjector` matrices, `SO3GridNet` self+cross (op_type glu/mlp/branch, kmax 1&2) — all **1e-12**; rotation equivariance **1e-10**. - **fp32** grid-path parity at the computation-in-fp32 budget (actual diffs 1e-6–1e-8 ≪ 1e-4). - Full-descriptor pt→dpmodel via `DescrptDPA4.deserialize(pt.serialize())` on the example config (lmax=3, mmax=1) — **~1e-14** — proving `dp convert-backend` schema interop. - Permutation-invariance + masked-edge no-op. - Cross-backend consistency rows (pt vs dpmodel **and pt_expt**, mixed_types) for ffn_so3_grid / message_node_so3 / both / grid_mlp. - **Verified on remote GPU (Tesla T4):** 617 (grid+parity+pt_expt) + 50 (consistency) pass, no CUDA device errors. pt_expt forward works today via auto-wrap (consistency + descriptor trio green) — no explicit registration needed. ## Known limitations - pt_expt **training** Parameter-promotion for the new weight-bearing grid classes, `torch.export`/AOTI grid coverage, training e2e, argcheck `doc_only_pt_supported` removal, and freeze/DeepEval are a **follow-up PR**. - `grid_method='e3nn'` (non-Lebedev product grid) stays fail-fast (Lebedev-only, per parent design). - fp32 grid paths use a ~1e-4 budget by design; fp64 is the parity reference. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added SO(3) grid projection support and frame-aware grid networks for DPA4 descriptors, including SO(3)-based FFN and improved cross-mode grid-product wiring. * Extended grid modules to support multi-frame configurations with per-degree frame mixing, and added full SO(3) projector/network serialization. * **Bug Fixes** * Enabled previously disabled/unsupported DPA4 SO(2) convolution cross-mode SO(3)/S2 grid products. * **Documentation** * Updated DPA4 porting-layer documentation to clarify supported configuration flags. * **Tests** * Added/expanded parity, equivariance, serialization/roundtrip, and torch-namespace compatibility tests for the new SO(3) and frame-aware paths. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Han Wang <wang_han@iapcm.ac.cn> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
What
Ports the SO3 grid projection feature family of the DPA4/SeZM descriptor to the backend-agnostic dpmodel layer (and, via auto-wrapping, pt_expt). This is PR-A of a 2-PR series; it lands the dpmodel math + descriptor wiring + parity tests. PR-B will add pt_expt training Parameter-promotion, torch.export/make_fx coverage, training e2e, and argcheck doc updates.
This is the follow-up after the merged DPA4 core series (#5515 / #5522 / #5540), which deliberately scoped to "energy model, no extensions". With this PR the flagship
examples/water/dpa4/input.jsonconfig (which setsffn_so3_grid=true,message_node_so3=true,grid_mlp) becomes runnable on dpmodel/pt_expt.Newly ported (dpmodel
descriptor/dpa4_nn/)projection.py:SO3GridProjector(init-time Wigner-D quadrature over a Lebedev×γ rotation grid),resolve_so3_grid,_build_so3_frame_set.grid_net.py:SO3GridNet(self + cross),GridMLP(op_type='mlp'),FrameExpand/FrameContractper-degree frame mixers; filled the previously-guardedBaseGridNetbranches:mode='cross',layout='flat',residual_scale_init, andn_frames>1. This also enables S2 cross-mode (node_wise_s2/message_node_s2) andGridMLPfor S2.ffn.py/so2.py: wired the grid nets into the FFN (ffn_so3_grid) andSO2Convolution(node_wise_*/message_node_*), removing the fail-fast guards. Out-of-scope flags (layer_scale, atten_*_proj, so2_attn_res, e3nn/non-Lebedev grid) remain guarded.All forward math is array-API (no
np.einsumon tensors; frame-axis to/from-grid done viaxp.matmul+reshape).Tests
SO3GridProjectormatrices match pt to 1e-12;FrameExpand/FrameContract,GridMLP,SO3GridNet(self+cross, glu/mlp/branch, kmax 1&2) to 1e-12; rotation equivariance to ~1e-14.test_dpa4_dpmodel_parity.pyconverted to real pt-parity tests at every level (S2Grid/SO2/FFN/Block); added kmax=2, mmax<lmax, and an fp32 grid case.DescrptDPA4.deserialize(pt.serialize())on the example config (lmax=3, mmax=1) — ~1e-15 — which also provesdp convert-backendschema interop.pt_expt forward inference works today via auto-wrapping (consistency + descriptor trio green).
Known limitations
_TRAINABLE_ATTRS), torch.export/AOTI coverage of the grid path, training e2e, and argcheck "only pt supported" doc removal are PR-B.grid_method='e3nn'(non-Lebedev product grid) stays fail-fast (Lebedev-only, per the parent design).Summary by CodeRabbit
New Features
Bug Fixes
Tests