Skip to content

perf(dpa4): opt so3grid#5517

Open
OutisLi wants to merge 2 commits into
deepmodeling:masterfrom
OutisLi:pr/so3grid
Open

perf(dpa4): opt so3grid#5517
OutisLi wants to merge 2 commits into
deepmodeling:masterfrom
OutisLi:pr/so3grid

Conversation

@OutisLi

@OutisLi OutisLi commented Jun 12, 2026

Copy link
Copy Markdown
Collaborator

Summary by CodeRabbit

Release Notes

  • Refactor

    • Reworked S2 grid-net processing to apply channel projections in coefficient space, keeping grid projection confined to the selected grid operation.
    • Updated grid-net APIs to accept to_grid/from_grid callables and generalized grid-op handling for quadratic/product, polynomial MLP, and branch routing.
    • Broadened grid-net construction and serialization support for mlp and branch operations.
  • Tests

    • Extended parity and serialize/deserialization roundtrip coverage for grid-net variants, including mlp, with updated forward wiring.
  • Documentation

    • Clarified guarded/unsupported grid flag behavior and option precedence.

Copilot AI review requested due to automatic review settings June 12, 2026 05:00
@dosubot dosubot Bot added the enhancement label Jun 12, 2026

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR optimizes the SO(3) grid-net quadratic operations in the SeZM NN descriptor by moving channel-only linear projections from grid resolution back to coefficient resolution (where possible), reducing work proportional to the grid size while preserving equivariant behavior.

Changes:

  • Introduced a shared _project_frames() helper to apply per-frame ChannelLinear projections directly on packed coefficient tensors.
  • Refactored GridMLP and GridBranch to operate on coefficient operands and use injected to_grid/from_grid projectors only for the unavoidable point-wise grid product step.
  • Replaced the implicit GLU “identity” op path with an explicit GridProduct module and removed _apply_grid_op, unifying the grid-op call interface.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

@coderabbitai

coderabbitai Bot commented Jun 12, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: c4141cc4-d35e-4adf-a0d4-1919b55697b2

📥 Commits

Reviewing files that changed from the base of the PR and between 9ffd80f and b5471ee.

📒 Files selected for processing (5)
  • deepmd/dpmodel/descriptor/dpa4_nn/block.py
  • deepmd/dpmodel/descriptor/dpa4_nn/ffn.py
  • deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py
  • deepmd/pt/model/descriptor/sezm_nn/grid_net.py
  • source/tests/pt/model/test_dpa4_dpmodel_parity.py
✅ Files skipped from review due to trivial changes (1)
  • deepmd/dpmodel/descriptor/dpa4_nn/block.py
🚧 Files skipped from review as they are similar to previous changes (4)
  • deepmd/dpmodel/descriptor/dpa4_nn/ffn.py
  • source/tests/pt/model/test_dpa4_dpmodel_parity.py
  • deepmd/pt/model/descriptor/sezm_nn/grid_net.py
  • deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py

📝 Walkthrough

Walkthrough

This PR refactors grid-net computation in both PyTorch and DPModel implementations to operate at coefficient resolution. A new _project_frames helper applies per-frame channel transformations. GridProduct, GridMLP, and GridBranch now accept injected to_grid/from_grid projection callables. GridMLP and GridBranch store n_frames and manage projections internally. BaseGridNet wiring updated to construct grid operations with frame counts and call them with projection callbacks. Grid helpers generalized to infer channel widths from runtime tensor shapes. Documentation updated to reflect mlp operation now being ported. Tests expanded to cover all three grid operation types with identity projector injection.

Changes

PyTorch Grid-Net Coefficient-Space Refactoring

Layer / File(s) Summary
Type imports and setup
deepmd/pt/model/descriptor/sezm_nn/grid_net.py
TYPE_CHECKING guard added for Callable import from collections.abc supporting new grid operation call signatures with projection callbacks.
Frame projection and GridProduct refactoring
deepmd/pt/model/descriptor/sezm_nn/grid_net.py
_project_frames(...) helper applies ChannelLinear transformations per Wigner-D frame within coefficient tensors. GridProduct.forward updated to accept injected to_grid/from_grid parameters and perform pointwise grid product.
GridMLP coefficient-space computation
deepmd/pt/model/descriptor/sezm_nn/grid_net.py
Constructor accepts and stores n_frames. forward refactored to project operands at coefficient resolution via _project_frames, compute quadratic interaction through grid projections, and project output channels back.
GridBranch coefficient-space computation
deepmd/pt/model/descriptor/sezm_nn/grid_net.py
Constructor accepts and stores n_frames. forward refactored to project operands at coefficient resolution, compute quadratic branch values on grid, apply softmax routing, and project routed result back to coefficients.
BaseGridNet grid-op wiring and orchestration
deepmd/pt/model/descriptor/sezm_nn/grid_net.py
Grid-op construction for MLP and branch types updated to pass n_frames. forward changed to call self.grid_op(left, right, scalar_pair, to_grid=..., from_grid=...) producing coefficient-space output. _apply_grid_op method removed.
Grid helper generalizations
deepmd/pt/model/descriptor/sezm_nn/grid_net.py
_to_grid and _from_grid refactored to infer per-channel width from runtime tensor shapes using -1 reshape rather than assuming fixed self.channels/self.expanded_channels.

DPModel Grid-Net Port and Grid Operation Classes

Layer / File(s) Summary
Type imports and setup
deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py
TYPE_CHECKING guard and Callable import added for new grid operation call signatures.
Module documentation updates
deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py
Module docstring updated to reflect newly ported grid operation classes (GridProduct, GridMLP) and BaseGridNet op_type values including mlp.
GridProduct and GridMLP implementations
deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py
New GridProduct class performs pointwise quadratic product via injected projectors. New GridMLP class with mode-dependent tensor handling and explicit serialize()/deserialize() for checkpoint support.
GridBranch coefficient-space refactoring
deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py
call method refactored to operate on coefficient-space operands with injected projectors, apply branch-wise softmax routing, and return coefficient-space outputs.
BaseGridNet grid-op routing and helpers
deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py
Removes hard NotImplementedError for op_type='mlp'. Implements three-way grid-op selection: GridMLP for mlp, GridBranch for branch, GridProduct fallback. Refactors _to_grid/_from_grid to infer channel widths at runtime.
S2GridNet serialization updates
deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py
serialize and deserialize methods broadened to handle op_type in {"mlp","branch"} for grid-op parameter storage/loading, extending prior branch-only logic.

Documentation and Guard Clarifications

Layer / File(s) Summary
Documentation and guard updates
deepmd/dpmodel/descriptor/dpa4_nn/block.py, deepmd/dpmodel/descriptor/dpa4_nn/ffn.py
Removes references to unported grid_mlp/mlp from docstrings, clarifies grid_branch precedence over grid_mlp, updates guards to reflect ffn_so3_grid as the guarded flag.

Test Parity and Coverage Updates

Layer / File(s) Summary
Grid-net parity and parametrization
source/tests/pt/model/test_dpa4_dpmodel_parity.py
Generalizes grid-op weight-copy logic to handle all op_type values. Expands parametrization to include op_type="mlp" alongside glu and branch.
GridBranch and GridMLP forward parity
source/tests/pt/model/test_dpa4_dpmodel_parity.py
Updates test_grid_branch with n_frames=1 and identity projector injection. Adds comprehensive GridMLP forward-parity and serialize roundtrip test with explicit weight transfer and projector injection.
Guard and NotImplementedError updates
source/tests/pt/model/test_dpa4_dpmodel_parity.py
Removes expectation that op_type="mlp" raises NotImplementedError. Updates FFN guard test to expect NotImplementedError for ffn_so3_grid=True.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

  • deepmodeling/deepmd-kit#5503: Both PRs refactor the DPA4/SeZM grid-net grid nonlinearity by adding/rewiring GridMLP/GridBranch/GridProduct to execute via coefficient↔grid projection callbacks (to_grid/from_grid), aligning grid-op computation across PT and DPModel implementations.
  • deepmodeling/deepmd-kit#5515: Both PRs modify the DPA4/SeZM grid-net stack so grid projections are handled via injected to_grid/from_grid with coefficient-space computation, addressing grid MLP and product operations via the same projection callback pattern.

Suggested reviewers

  • wanghan-iapcm
  • iProzd
  • njzjz
🚥 Pre-merge checks | ✅ 3 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 37.78% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
Title check ❓ Inconclusive The title 'perf(dpa4): opt so3grid' is vague and uses non-descriptive abbreviations that don't clearly convey the actual scope of changes, which include refactoring grid operations to coefficient space, adding GridMLP/GridProduct support, and updating APIs across multiple grid operation classes. Consider a more descriptive title that better reflects the actual refactoring scope, such as 'refactor(dpa4): move grid ops to coefficient space and add GridMLP/GridProduct support' or similar.
✅ Passed checks (3 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@OutisLi OutisLi requested a review from wanghan-iapcm June 12, 2026 05:26
@codecov

codecov Bot commented Jun 12, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 96.90722% with 3 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.19%. Comparing base (5d94bd6) to head (9ffd80f).

Files with missing lines Patch % Lines
deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py 95.65% 3 Missing ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##           master    #5517   +/-   ##
=======================================
  Coverage   82.19%   82.19%           
=======================================
  Files         891      891           
  Lines      101599   101647   +48     
  Branches     4242     4240    -2     
=======================================
+ Hits        83507    83552   +45     
- Misses      16789    16792    +3     
  Partials     1303     1303           

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

@wanghan-iapcm wanghan-iapcm added this pull request to the merge queue Jun 12, 2026
@github-merge-queue github-merge-queue Bot removed this pull request from the merge queue due to failed status checks Jun 12, 2026
@OutisLi OutisLi added this pull request to the merge queue Jun 13, 2026
@github-merge-queue github-merge-queue Bot removed this pull request from the merge queue due to failed status checks Jun 13, 2026
@wanghan-iapcm

Copy link
Copy Markdown
Collaborator

This PR passes its own CI but fails in the merge queue. The merge_group CUDA run errors with:

TypeError: GridBranch.__init__() missing 1 required keyword-only argument: 'n_frames'
FAILED source/tests/pt/model/test_dpa4_dpmodel_parity.py::TestS2GridParity::test_grid_branch[1]
FAILED source/tests/pt/model/test_dpa4_dpmodel_parity.py::TestS2GridParity::test_grid_branch[2]

(run: https://github.com/deepmodeling/deepmd-kit/actions/runs/27460273106)

Root cause — stale base + a semantic merge conflict. This PR makes n_frames a required keyword-only arg of GridBranch.__init__ in grid_net.py, but its base predates #5515. #5515 added test_grid_branch, which constructs GridBranch without n_frames:

pt_mod = PTGridBranch(
channels=self.channels,
n_branches=n_branches,
dtype=torch.float64,
trainable=True,
seed=9,
)
rng = np.random.default_rng(2084)
with torch.no_grad():
for p in pt_mod.parameters():
p += to_pt(0.1 * rng.normal(size=tuple(p.shape)))
state = pt_state_to_numpy(pt_mod)
assert set(state) == {
"left_proj.weight",
"right_proj.weight",
"router.weight",
"out_proj.weight",
}
dp_mod = DPGridBranch(
channels=self.channels,
n_branches=n_branches,
precision="float64",
seed=9,
)

Each side is fine alone, but the merge of this PR with current master (which already contains #5515) constructs the new required-n_frames signature with the old call → TypeError. The PR's own pull_request CI ran on the pre-#5515 base, so it never saw the conflicting test; only the merge queue runs the up-to-date merge.

Fix: rebase onto current master and update the now-stale call sites to pass n_frames= (the two constructions above, plus any others), or give n_frames a default to keep the signature backward-compatible. Then re-queue.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py`:
- Around line 692-695: The `op_type` parameter docstring in the GridNet class
(around lines 692-695) incorrectly states that `"mlp"` is not ported, but this
contradicts the actual implementation which supports all three operation types
including `"mlp"`. Update the docstring for the `op_type` parameter to remove
the note claiming `"mlp"` is not ported, and ensure it accurately reflects that
`"mlp"` is a supported option alongside `"glu"` and `"branch"`, consistent with
the module docstring, BaseGridNet implementation, and serialize/deserialize
logic.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 4227543e-5cfe-4584-b0ec-2e8d377dd576

📥 Commits

Reviewing files that changed from the base of the PR and between 6a4df8d and 9ffd80f.

📒 Files selected for processing (5)
  • deepmd/dpmodel/descriptor/dpa4_nn/block.py
  • deepmd/dpmodel/descriptor/dpa4_nn/ffn.py
  • deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py
  • deepmd/pt/model/descriptor/sezm_nn/grid_net.py
  • source/tests/pt/model/test_dpa4_dpmodel_parity.py
💤 Files with no reviewable changes (1)
  • source/tests/pt/model/test_dpa4_dpmodel_parity.py
✅ Files skipped from review due to trivial changes (2)
  • deepmd/dpmodel/descriptor/dpa4_nn/block.py
  • deepmd/dpmodel/descriptor/dpa4_nn/ffn.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • deepmd/pt/model/descriptor/sezm_nn/grid_net.py

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Caution

Inline review comments failed to post. This is likely due to GitHub's internal server error or limits when posting large numbers of comments. If you are seeing this consistently it is likely a permissions issue. Please check "Moderation" -> "Code review limits" under your organization settings.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py`:
- Around line 692-695: The `op_type` parameter docstring in the GridNet class
(around lines 692-695) incorrectly states that `"mlp"` is not ported, but this
contradicts the actual implementation which supports all three operation types
including `"mlp"`. Update the docstring for the `op_type` parameter to remove
the note claiming `"mlp"` is not ported, and ensure it accurately reflects that
`"mlp"` is a supported option alongside `"glu"` and `"branch"`, consistent with
the module docstring, BaseGridNet implementation, and serialize/deserialize
logic.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 4227543e-5cfe-4584-b0ec-2e8d377dd576

📥 Commits

Reviewing files that changed from the base of the PR and between 6a4df8d and 9ffd80f.

📒 Files selected for processing (5)
  • deepmd/dpmodel/descriptor/dpa4_nn/block.py
  • deepmd/dpmodel/descriptor/dpa4_nn/ffn.py
  • deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py
  • deepmd/pt/model/descriptor/sezm_nn/grid_net.py
  • source/tests/pt/model/test_dpa4_dpmodel_parity.py
💤 Files with no reviewable changes (1)
  • source/tests/pt/model/test_dpa4_dpmodel_parity.py
✅ Files skipped from review due to trivial changes (2)
  • deepmd/dpmodel/descriptor/dpa4_nn/block.py
  • deepmd/dpmodel/descriptor/dpa4_nn/ffn.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • deepmd/pt/model/descriptor/sezm_nn/grid_net.py
🛑 Comments failed to post (1)
deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py (1)

692-695: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Stale docstring: "mlp" is now ported.

The docstring claims "mlp" is not ported, but this contradicts the module docstring (lines 15-17), the BaseGridNet implementation (lines 539-546), and the serialize/deserialize logic (lines 777, 844) which all support op_type="mlp".

📝 Proposed fix
     op_type : str
-        Point-wise grid operation; ``"glu"`` or ``"branch"`` (``"mlp"`` is
-        not ported).
+        Point-wise grid operation: ``"glu"``, ``"mlp"``, or ``"branch"``.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py` around lines 692 - 695, The
`op_type` parameter docstring in the GridNet class (around lines 692-695)
incorrectly states that `"mlp"` is not ported, but this contradicts the actual
implementation which supports all three operation types including `"mlp"`.
Update the docstring for the `op_type` parameter to remove the note claiming
`"mlp"` is not ported, and ensure it accurately reflects that `"mlp"` is a
supported option alongside `"glu"` and `"branch"`, consistent with the module
docstring, BaseGridNet implementation, and serialize/deserialize logic.

@OutisLi OutisLi enabled auto-merge June 14, 2026 03:41
@OutisLi OutisLi added this pull request to the merge queue Jun 14, 2026
@njzjz njzjz removed this pull request from the merge queue due to a manual request Jun 15, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants