Skip to content

refactor(jax): auto-convert dpmodel modules#5527

Merged
wanghan-iapcm merged 2 commits into
deepmodeling:masterfrom
njzjz:jax-reg
Jun 14, 2026
Merged

refactor(jax): auto-convert dpmodel modules#5527
wanghan-iapcm merged 2 commits into
deepmodeling:masterfrom
njzjz:jax-reg

Conversation

@njzjz

@njzjz njzjz commented Jun 13, 2026

Copy link
Copy Markdown
Member

Summary:

  • add a dpmodel-to-JAX wrapper registry and lazy registration path
  • replace many hand-written JAX setattr conversions with shared conversion logic
  • handle Flax NNX data-list attributes for DPA2/DPA3 residual parameters

Tests:

  • ruff check .
  • ruff format --check .
  • pytest source/tests/consistent/descriptor/test_dpa3.py -k jax -q
  • pytest source/tests/consistent/descriptor/test_dpa2.py -k jax -q

Summary by CodeRabbit

  • Refactor
    • Simplified JAX/Flax wrapper layers by removing scattered per-class attribute conversion logic and centralizing model-to-JAX conversions and registrations into shared utilities.
    • Streamlined descriptor, fitting, atomic-model, and network wrappers to be thinner subclasses; attribute assignment now follows unified conversion behavior.
    • No public API signatures changed; runtime behavior for assigning model components is more consistent and predictable.

@github-advanced-security github-advanced-security AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CodeQL found more than 20 potential problems in the proposed changes. Check the Files changed tab for more details.

@coderabbitai

coderabbitai Bot commented Jun 13, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 7f158ef0-14fd-42b6-b0a0-5c3419666e7b

📥 Commits

Reviewing files that changed from the base of the PR and between 5b8e3c7 and e7bf953.

📒 Files selected for processing (3)
  • deepmd/jax/common.py
  • deepmd/jax/model/__init__.py
  • deepmd/jax/model/dp_zbl_model.py
✅ Files skipped from review due to trivial changes (1)
  • deepmd/jax/model/init.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • deepmd/jax/model/dp_zbl_model.py
  • deepmd/jax/common.py

📝 Walkthrough

Walkthrough

This PR centralizes JAX-side attribute conversion in deepmd/jax/common.py, removes many per-class __setattr__ overrides, simplifies JAX/Flax wrapper classes to thin subclasses or declarative forms, and registers dpmodel→JAX mapping converters and class-level conversion controls.

Changes

JAX Attribute-Setting Consolidation

Layer / File(s) Summary
Central conversion infrastructure
deepmd/jax/common.py
Establishes registry API (register_dpmodel_mapping, try_convert_module, dpmodel_setattr) and integrates conversion into flax_module-generated classes' __setattr__, handling ndarray/list conversion and dpmodel auto-wrapping.
Utility class simplification
deepmd/jax/utils/exclude_mask.py, deepmd/jax/utils/network.py, deepmd/jax/utils/type_embed.py
Removes custom __setattr__ implementations, adds declarative flags (_jax_skip_auto_convert_attrs, _jax_data_list_attrs), and registers dpmodel mappings for exclude-mask and network types.
Descriptor class simplification
deepmd/jax/descriptor/* (dpa1/dpa2/dpa3/hybrid/repflows/repformers/se_*)
Converts descriptor wrappers to thin pass subclasses, deletes per-field deserialization code, adds residual-list declarative attrs where applicable, and registers conversion mappings (e.g., se_atten_v2).
Atomic model and fitting class simplification
deepmd/jax/atomic_model/*, deepmd/jax/fitting/fitting.py
Removes __setattr__ overrides and the fitting helper; updates imports to rely on centralized conversion and side-effect registrations.
Model-level cleanup
deepmd/jax/model/dp_model.py, dp_zbl_model.py, deepmd/jax/atomic_model/base_atomic_model.py
Deletes model-level __setattr__ interception of atomic_model assignments and removes the now-unused base_atomic_model_set_attr helper and related imports.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

  • deepmodeling/deepmd-kit#5214: Modifies deepmd/jax/common.py's flax_module decorator implementation with metadata-preserving functools.wraps, affecting the same code path as this PR's flax_module/__setattr__ conversion integration.
  • deepmodeling/deepmd-kit#5067: Overlaps on Flax 0.12 compatibility changes and handling of nnx/nnx.data conversions that this PR centralizes and simplifies.
  • deepmodeling/deepmd-kit#5204: Introduces a parallel registry-driven conversion mechanism for another backend, showing a similar architectural pattern.

Suggested reviewers

  • iProzd
  • wanghan-iapcm
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 23.08% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title 'refactor(jax): auto-convert dpmodel modules' clearly and accurately describes the main change: a refactoring that introduces automatic conversion of dpmodel modules to JAX wrappers via a registry system.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (1)
deepmd/jax/model/dp_zbl_model.py (1)

3-5: ⚡ Quick win

If this import is registration-only, make that side effect explicit.

After removing DPZBLModel.__setattr__, this file depends on the ZBL atomic-model wrapper being imported before atomic_model assignments hit the centralized conversion path. The current symbol import hides that dependency; mirroring the underscore-prefixed # noqa: F401 side-effect import style used in the other JAX wrapper files would make the contract much clearer. Based on the import patterns in the other JAX wrapper files in this PR and the model-level wiring summary.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/jax/model/dp_zbl_model.py` around lines 3 - 5, This import is
registration-only and should make that side effect explicit: change the
statement that brings in DPZBLLinearEnergyAtomicModel to an underscore-prefixed,
noop import and add a noqa comment (e.g., import DPZBLLinearEnergyAtomicModel as
_DPZBLLinearEnergyAtomicModel  # noqa: F401) so the file documents the
registration side effect required by DPZBLModel.__setattr__ removal and
suppresses unused-symbol warnings.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/jax/common.py`:
- Around line 107-134: The registration routine can be raced: modify
_ensure_registrations to use a threading.Condition (or Lock+Condition) to
serialize the first registration and block concurrent callers until the
registration pass completes; specifically, protect _REGISTRATIONS_IN_PROGRESS
and _REGISTRATIONS_READY with a module-level Condition, then in
_ensure_registrations acquire the condition and if _REGISTRATIONS_READY return,
if _REGISTRATIONS_IN_PROGRESS call cond.wait() in a loop until
_REGISTRATIONS_READY or in_progress is false, otherwise set
_REGISTRATIONS_IN_PROGRESS = True and release the condition while performing the
import_module loop, then in a finally re-acquire the condition, set
_REGISTRATIONS_IN_PROGRESS = False, set _REGISTRATIONS_READY = True only on
successful registration, and call cond.notify_all() so waiting callers resume
and see the final registry state used by try_convert_module/_DPMODEL_TO_JAX.

---

Nitpick comments:
In `@deepmd/jax/model/dp_zbl_model.py`:
- Around line 3-5: This import is registration-only and should make that side
effect explicit: change the statement that brings in
DPZBLLinearEnergyAtomicModel to an underscore-prefixed, noop import and add a
noqa comment (e.g., import DPZBLLinearEnergyAtomicModel as
_DPZBLLinearEnergyAtomicModel  # noqa: F401) so the file documents the
registration side effect required by DPZBLModel.__setattr__ removal and
suppresses unused-symbol warnings.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: b701da7a-cf1f-4f7a-9929-0bd77e43855c

📥 Commits

Reviewing files that changed from the base of the PR and between 5d94bd6 and 5b8e3c7.

📒 Files selected for processing (22)
  • deepmd/jax/atomic_model/base_atomic_model.py
  • deepmd/jax/atomic_model/dp_atomic_model.py
  • deepmd/jax/atomic_model/linear_atomic_model.py
  • deepmd/jax/atomic_model/pairtab_atomic_model.py
  • deepmd/jax/common.py
  • deepmd/jax/descriptor/dpa1.py
  • deepmd/jax/descriptor/dpa2.py
  • deepmd/jax/descriptor/dpa3.py
  • deepmd/jax/descriptor/hybrid.py
  • deepmd/jax/descriptor/repflows.py
  • deepmd/jax/descriptor/repformers.py
  • deepmd/jax/descriptor/se_atten_v2.py
  • deepmd/jax/descriptor/se_e2_a.py
  • deepmd/jax/descriptor/se_e2_r.py
  • deepmd/jax/descriptor/se_t.py
  • deepmd/jax/descriptor/se_t_tebd.py
  • deepmd/jax/fitting/fitting.py
  • deepmd/jax/model/dp_model.py
  • deepmd/jax/model/dp_zbl_model.py
  • deepmd/jax/utils/exclude_mask.py
  • deepmd/jax/utils/network.py
  • deepmd/jax/utils/type_embed.py
💤 Files with no reviewable changes (2)
  • deepmd/jax/atomic_model/base_atomic_model.py
  • deepmd/jax/model/dp_model.py

Comment thread deepmd/jax/common.py
@njzjz njzjz requested a review from wanghan-iapcm June 13, 2026 16:41
@codecov

codecov Bot commented Jun 13, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 94.78261% with 12 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.18%. Comparing base (5d94bd6) to head (e7bf953).

Files with missing lines Patch % Lines
deepmd/jax/common.py 91.11% 12 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5527      +/-   ##
==========================================
- Coverage   82.19%   82.18%   -0.01%     
==========================================
  Files         891      890       -1     
  Lines      101599   101357     -242     
  Branches     4242     4242              
==========================================
- Hits        83507    83299     -208     
+ Misses      16789    16754      -35     
- Partials     1303     1304       +1     

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

@wanghan-iapcm wanghan-iapcm added this pull request to the merge queue Jun 14, 2026
Merged via the queue into deepmodeling:master with commit c0b0319 Jun 14, 2026
70 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants