feat(dpmodel): complete DPA4/SeZM SO3 grid projection (mirror current pt)#5555
feat(dpmodel): complete DPA4/SeZM SO3 grid projection (mirror current pt)#5555wanghan-iapcm wants to merge 18 commits into
Conversation
…) for DPA4 SO3 grid
… for DPA4 SO3 grid
for more information, see https://pre-commit.ci
|
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: Repository UI Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (2)
🚧 Files skipped from review as they are similar to previous changes (2)
📝 WalkthroughWalkthroughPorts the previously unimplemented SO(3)/S2 cross-mode grid paths in the DPA4 dpmodel descriptor. Adds ChangesDPA4 dpmodel SO(3) grid port
Sequence Diagram(s)sequenceDiagram
participant Caller
participant EquivariantFFN
participant SO3GridNet
participant BaseGridNet
participant FrameExpand
participant GridOp as GridMLP/GridBranch
participant FrameContract
Caller->>EquivariantFFN: call(x_coeffs)
EquivariantFFN->>SO3GridNet: act(x_coeffs) [ffn_so3_grid=True]
SO3GridNet->>BaseGridNet: call(query, context=None)
BaseGridNet->>FrameExpand: expand packed frames
FrameExpand-->>BaseGridNet: expanded coefficients
BaseGridNet->>GridOp: to_grid / quadratic product / from_grid
GridOp-->>BaseGridNet: grid op output
BaseGridNet->>FrameContract: contract back to channel dim
FrameContract-->>BaseGridNet: contracted output
BaseGridNet-->>SO3GridNet: result + residual_scale
SO3GridNet-->>EquivariantFFN: activated features
EquivariantFFN-->>Caller: updated x_coeffs
sequenceDiagram
participant SO2Conv as SO2Convolution.call
participant NodeWise as node_wise_grid_product
participant AttnAgg as attention aggregation
participant MsgNode as message_node_grid_product
SO2Conv->>NodeWise: (x_dst_local, x_local) → residual [if enabled]
NodeWise-->>SO2Conv: add to x_local before SO(2) focus
SO2Conv->>AttnAgg: SO(2) convolution + edge aggregation → out
SO2Conv->>MsgNode: (out, node features) → residual [if enabled]
MsgNode-->>SO2Conv: add to out before final channel mixing
Estimated code review effort🎯 5 (Critical) | ⏱️ ~120 minutes Possibly related PRs
Suggested labels
Suggested reviewers
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 3
🧹 Nitpick comments (2)
source/tests/common/dpmodel/test_dpa4_so3_gridnet.py (1)
333-338: ⚡ Quick winAssert both frame-mixer variables are absent in self mode.
test_so3_serialize_roundtriponly checks"frame_expand.weight"is absent formode="self". Add the symmetric check for"frame_contract.weight"to prevent silent schema regressions.Suggested patch
if mode == "cross": assert "frame_expand.weight" in data["`@variables`"] assert "frame_contract.weight" in data["`@variables`"] else: assert "frame_expand.weight" not in data["`@variables`"] + assert "frame_contract.weight" not in data["`@variables`"]🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@source/tests/common/dpmodel/test_dpa4_so3_gridnet.py` around lines 333 - 338, In the test_so3_serialize_roundtrip function, the else block (which handles mode="self") only asserts that "frame_expand.weight" is absent from data["`@variables`"] but is missing the symmetric check for "frame_contract.weight". Add an additional assert statement in the else block to verify that "frame_contract.weight" is also not in data["`@variables`"], mirroring the structure of the cross mode block which checks both variables.source/tests/common/dpmodel/test_dpa4_so2_grid.py (1)
194-197: ⚡ Quick winAdd a non-trivial-output guard to avoid vacuous parity passes.
_assert_conv_paritycurrently only checks DP/PT closeness. Add a simple magnitude guard so parity can’t pass if both outputs collapse to (near) zero.Suggested patch
out_dp = dp_mod.call(x, dp_cache, radial) out_pt = pt_mod(_to_pt(x), pt_cache, _to_pt(radial_valid)) + assert np.max(np.abs(np.asarray(out_dp))) > 1e-8 _assert_parity(out_dp, out_pt, rtol=rtol, atol=atol)🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@source/tests/common/dpmodel/test_dpa4_so2_grid.py` around lines 194 - 197, The _assert_parity function currently only checks if DP and PT outputs are close to each other, which means the test can pass vacuously if both outputs collapse to near-zero values. Add a magnitude guard within the _assert_parity function that verifies at least one of the outputs has a non-trivial magnitude (e.g., check that the absolute maximum value or norm of the outputs exceeds a small threshold) before asserting their closeness, ensuring the parity test only passes when the outputs are both meaningful and close to each other.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py`:
- Around line 1532-1544: The code in the class method (around the
projector_config assignment) is accessing only the `["config"]` portion of the
projector object, bypassing validation of the nested `@class` and `@version`
schema fields. Before extracting and using projector_config, validate the full
projector object structure (the complete config["projector"] payload) to ensure
it conforms to the expected schema and version, preventing acceptance of wrong
or incompatible projector schemas when rebuilding the object.
- Around line 950-956: The operands query and context_ndfc are being passed to
self.frame_expand() without first casting them to compute_dtype, unlike
scalar_pair which is properly cast on lines 952-953. This causes dtype
mismatches when fp32 inputs flow through float64 grid nets. Cast both query and
context_ndfc to compute_dtype (using xp.astype) before passing them to the
self.frame_expand() calls in the return statement, similar to how scalar_pair is
cast to compute_dtype.
In `@deepmd/dpmodel/descriptor/dpa4_nn/so2.py`:
- Around line 1827-1841: The deserialization of `node_wise_grid_product` and
`message_node_grid_product` does not validate that the template contains only
expected keys before calling `deserialize()`. After assigning the result of
`sub_vars()` to the `@variables` key in the template, add validation to reject
any unexpected or unknown top-level keys in the template before passing it to
the `deserialize()` method for both `node_wise_grid_product` and
`message_node_grid_product`. This prevents schema drift keys from being silently
ignored during the deserialization process.
---
Nitpick comments:
In `@source/tests/common/dpmodel/test_dpa4_so2_grid.py`:
- Around line 194-197: The _assert_parity function currently only checks if DP
and PT outputs are close to each other, which means the test can pass vacuously
if both outputs collapse to near-zero values. Add a magnitude guard within the
_assert_parity function that verifies at least one of the outputs has a
non-trivial magnitude (e.g., check that the absolute maximum value or norm of
the outputs exceeds a small threshold) before asserting their closeness,
ensuring the parity test only passes when the outputs are both meaningful and
close to each other.
In `@source/tests/common/dpmodel/test_dpa4_so3_gridnet.py`:
- Around line 333-338: In the test_so3_serialize_roundtrip function, the else
block (which handles mode="self") only asserts that "frame_expand.weight" is
absent from data["`@variables`"] but is missing the symmetric check for
"frame_contract.weight". Add an additional assert statement in the else block to
verify that "frame_contract.weight" is also not in data["`@variables`"], mirroring
the structure of the cross mode block which checks both variables.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 8516e16a-0724-4162-a7a7-f1167567c8f4
📒 Files selected for processing (19)
deepmd/dpmodel/descriptor/dpa4_nn/block.pydeepmd/dpmodel/descriptor/dpa4_nn/ffn.pydeepmd/dpmodel/descriptor/dpa4_nn/grid_net.pydeepmd/dpmodel/descriptor/dpa4_nn/projection.pydeepmd/dpmodel/descriptor/dpa4_nn/so2.pysource/tests/common/dpmodel/test_descrpt_dpa4.pysource/tests/common/dpmodel/test_dpa4_basegridnet_cross.pysource/tests/common/dpmodel/test_dpa4_ffn_so3.pysource/tests/common/dpmodel/test_dpa4_frame_mixers.pysource/tests/common/dpmodel/test_dpa4_grid_descriptor.pysource/tests/common/dpmodel/test_dpa4_gridbranch_frames.pysource/tests/common/dpmodel/test_dpa4_gridmlp_frames.pysource/tests/common/dpmodel/test_dpa4_project_frames.pysource/tests/common/dpmodel/test_dpa4_so2_grid.pysource/tests/common/dpmodel/test_dpa4_so3_grid_utils.pysource/tests/common/dpmodel/test_dpa4_so3_gridnet.pysource/tests/common/dpmodel/test_dpa4_so3_projector.pysource/tests/consistent/descriptor/test_dpa4.pysource/tests/pt/model/test_dpa4_dpmodel_parity.py
…ng#5555) - BaseGridNet cross-mode: lift query/context to compute_dtype before FrameExpand so the frame expansion runs in the net's precision (mirrors pt's fp64-weight-forced FrameExpand); _FrameMixer otherwise casts weights down to the operand dtype, expanding fp32 inputs in fp32. (CodeRabbit) - SO3GridNet.deserialize: validate the nested projector @class/@Version instead of blindly reading config["projector"]["config"]. (CodeRabbit) - SO2Convolution.deserialize: reject schema-drift keys under the node_wise/message_node grid-product prefixes (loaded @variables key set must match the fresh template). (CodeRabbit) - drop unused n_valid local in test_dpa4_so2_grid.py. (CodeQL) Tests: mixed-precision cross run, deserialize rejects bad projector @class/@Version, deserialize rejects grid-product drift key.
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py`:
- Around line 1541-1548: The version key for the nested projector is being
accessed with a default value of 1 using get("`@version`", 1), which allows
malformed projector payloads without a version to silently pass validation.
Remove the default value parameter and add explicit validation to require the
`@version` key to be present in projector_data before calling
check_version_compatibility(). Raise a ValueError with a descriptive message if
the `@version` key is missing, following the same pattern as the existing `@class`
validation.
In `@source/tests/common/dpmodel/test_dpa4_so3_gridnet.py`:
- Around line 472-473: The test for invalid version handling in
DPSO3GridNet.deserialize uses overly broad exception handling with
pytest.raises(Exception), which can mask unrelated failures and weakens test
specificity. Replace Exception with the concrete exception type that the
deserialize method actually raises for invalid versions (examine the method
implementation to determine if it's ValueError, RuntimeError, or another
specific exception type), and optionally add a message parameter to match
against the error message fragment for even more precise assertion.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: ca2ea555-f9dd-49f9-8000-5426bb8cb238
📒 Files selected for processing (4)
deepmd/dpmodel/descriptor/dpa4_nn/grid_net.pydeepmd/dpmodel/descriptor/dpa4_nn/so2.pysource/tests/common/dpmodel/test_dpa4_so2_grid.pysource/tests/common/dpmodel/test_dpa4_so3_gridnet.py
🚧 Files skipped from review as they are similar to previous changes (2)
- source/tests/common/dpmodel/test_dpa4_so2_grid.py
- deepmd/dpmodel/descriptor/dpa4_nn/so2.py
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #5555 +/- ##
==========================================
+ Coverage 82.16% 82.22% +0.05%
==========================================
Files 896 896
Lines 102643 103018 +375
Branches 4340 4340
==========================================
+ Hits 84341 84708 +367
- Misses 16965 16972 +7
- Partials 1337 1338 +1 ☔ View full report in Codecov by Harness. 🚀 New features to boost your workflow:
|
…eepmodeling#5555) Follow-up to CodeRabbit re-review: - SO3GridNet.deserialize now requires the nested projector @Version key (was silently defaulting a missing version to 1). - the version test asserts ValueError(match="version") instead of a blind Exception (ruff B017), and adds a missing-@Version case.
What
Completes the DPA4/SeZM SO3 grid projection port to the dpmodel backend so it faithfully mirrors master's current pt
sezm_nn/grid_net.py. After this, the flagshipexamples/water/dpa4/input.json(which setsffn_so3_grid=true,message_node_so3=true,grid_mlp) runs on dpmodel/pt_expt.Builds on top of the S2-grid base that #5517/#5552 landed (
GridProduct/GridMLP/op_type='mlp', the_project_framesrefactor). Master's dpmodel was the S2 (n_frames==1) slice with SO3/cross-mode fail-fast guarded; this PR generalizes those ops to frame-aware (n_frames>1) + cross-mode and adds the missing SO3 pieces — matching current pt exactly (single source of truth: dpmodel == pt).Supersedes #5547 (which ported the pre-#5552 design and went structurally stale).
Changes (all mirror current pt)
grid_net.py: add_project_frames; generalizeGridMLP/GridBranchto frame-aware (n_frames); generalizeBaseGridNet(un-guardmode='cross',layout='flat',residual_scale_init,n_frames>1; frame-axis to/from-grid viaxp.matmul+reshape); addFrameContract/FrameExpand/_build_frame_degree_index; addSO3GridNet(self+cross).projection.py: addSO3GridProjector(Wigner-D quadrature) +resolve_so3_grid/_build_so3_frame_set.ffn.py: un-guardffn_so3_grid→SO3GridNet(mode='self').so2.py: un-guardnode_wise_{s2,so3}/message_node_{s2,so3}→ cross-mode grid products, applied incall+ round-tripped in serialize.Validation
_project_frames,GridMLP/GridBranch(incl. S2 byte-identical regression),BaseGridNetcross/flat/residual,FrameContract/FrameExpand,SO3GridProjectormatrices,SO3GridNetself+cross (op_type glu/mlp/branch, kmax 1&2) — all 1e-12; rotation equivariance 1e-10.DescrptDPA4.deserialize(pt.serialize())on the example config (lmax=3, mmax=1) — ~1e-14 — provingdp convert-backendschema interop.pt_expt forward works today via auto-wrap (consistency + descriptor trio green) — no explicit registration needed.
Known limitations
torch.export/AOTI grid coverage, training e2e, argcheckdoc_only_pt_supportedremoval, and freeze/DeepEval are a follow-up PR.grid_method='e3nn'(non-Lebedev product grid) stays fail-fast (Lebedev-only, per parent design).Summary by CodeRabbit