fix(jax): restore batched neighbor statistics #5545

Open
njzjz wants to merge 1 commit into deepmodeling:master from njzjz:fix-jax-update-sel-oom

Conversation

@njzjz

@njzjz njzjz commented Jun 16, 2026

Member

Summary

  • restore JAX training neighbor statistics instead of silently skipping update_sel
  • route descriptor update_sel through the JAX NeighborStat implementation with AutoBatchSize
  • persist min_nbor_dist on the JAX model and add a regression test for sel=auto
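
The second bullet relies on automatic batch sizing to keep the neighbor statistics pass within memory. The sketch below illustrates that general idea only and is not the deepmd-kit AutoBatchSize implementation: `run_in_batches`, `fake_kernel`, and the use of `MemoryError` as the out-of-memory signal are assumptions made for illustration.

```python
from typing import Callable, List, Sequence


def run_in_batches(
    fn: Callable[[Sequence[int]], List[int]],
    items: Sequence[int],
    initial_batch: int = 1024,
) -> List[int]:
    """Apply fn to items chunk by chunk, halving the chunk size on MemoryError."""
    batch = max(1, initial_batch)
    results: List[int] = []
    start = 0
    while start < len(items):
        chunk = items[start : start + batch]
        try:
            results.extend(fn(chunk))
        except MemoryError:
            if batch == 1:
                raise  # a single item already fails, nothing left to shrink
            batch = max(1, batch // 2)
            continue  # retry the same position with a smaller chunk
        start += len(chunk)
    return results


def fake_kernel(chunk: Sequence[int]) -> List[int]:
    """Hypothetical workload that 'runs out of memory' above 4 items."""
    if len(chunk) > 4:
        raise MemoryError("chunk too large")
    return [x * x for x in chunk]


squares = run_in_batches(fake_kernel, list(range(10)), initial_batch=16)
```

Starting from a batch size that is too large, the loop shrinks to a size the workload accepts and then processes the remaining items at that size, so the caller never has to pick a safe batch size up front.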

Tests

  • ruff check .
  • pytest source/tests/jax/test_training.py -v
  • timeout 180 srun --gres=gpu:1 pytest source/tests/jax/test_training.py::TestJAXTraining::test_train_entrypoint_runs_one_step_from_scratch -v

Summary by CodeRabbit

Release Notes

  • New Features
    • JAX training now computes neighbor statistics and uses them to update the descriptor sel, instead of skipping that step
    • Minimum neighbor distance is automatically determined and configured on the model during training
    • Optional setting available to skip neighbor statistics computation if preferred
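
For readers unfamiliar with the quantities involved, the sketch below shows what neighbor statistics amount to for a single non-periodic frame: the smallest interatomic distance and the largest neighbor count within a cutoff. It is pure Python for illustration and is not the JAX NeighborStat code; `neighbor_stat`, its arguments, and the omission of periodic images and per-type counts are assumptions.

```python
import math
from typing import List, Sequence, Tuple


def neighbor_stat(
    coords: Sequence[Tuple[float, float, float]], rcut: float
) -> Tuple[float, int]:
    """Return (minimum pair distance, maximum neighbor count within rcut)."""
    min_dist = math.inf
    counts: List[int] = [0] * len(coords)
    for i in range(len(coords)):
        for j in range(i + 1, len(coords)):
            d = math.dist(coords[i], coords[j])
            min_dist = min(min_dist, d)
            if d < rcut:
                # Each pair inside the cutoff counts for both atoms.
                counts[i] += 1
                counts[j] += 1
    return min_dist, max(counts, default=0)


frame = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 2.0, 0.0), (5.0, 5.0, 5.0)]
min_nbor_dist, max_nbor_size = neighbor_stat(frame, rcut=2.5)
```

The maximum neighbor count is the kind of value a sel=auto setting needs, and the minimum distance is the value the PR stores on the model as min_nbor_dist.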

@njzjz njzjz requested a review from wanghan-iapcm June 16, 2026 19:25
@coderabbitai

coderabbitai Bot commented Jun 16, 2026

Contributor

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: f463a3ab-9352-4696-b92a-9e49e3ecc6da

📥 Commits

Reviewing files that changed from the base of the PR and between 69409c5 and 309ff37.

📒 Files selected for processing (3)
  • deepmd/jax/entrypoints/train.py
  • deepmd/jax/utils/update_sel.py
  • source/tests/jax/test_training.py

Walkthrough

Adds a new deepmd/jax/utils/update_sel.py module with a JAX-specific UpdateSel class and use_jax_update_sel() context manager that temporarily patches descriptor plugin classes. The JAX training entrypoint's update_sel function is replaced with a real implementation that loads data, calls BaseModel.update_sel under the context manager, and returns both updated jdata and min_nbor_dist, which train() now propagates onto the model.
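
The save, patch, and restore behavior described for use_jax_update_sel() can be sketched with a plain context manager. This is a minimal illustration under stated assumptions, not the code in this PR: the classes `DefaultUpdateSel`, `JaxUpdateSel`, `DescriptorA`, `DescriptorB` and the helper `use_patched_update_sel` are hypothetical stand-ins, and only the attribute name `_update_sel_cls` comes from the walkthrough.

```python
from contextlib import contextmanager
from typing import Dict, Iterator, List


class DefaultUpdateSel:
    """Stand-in for the backend-agnostic helper (hypothetical)."""


class JaxUpdateSel(DefaultUpdateSel):
    """Stand-in for the JAX-specific helper (hypothetical)."""


class DescriptorA:
    _update_sel_cls = DefaultUpdateSel


class DescriptorB:
    _update_sel_cls = DefaultUpdateSel


@contextmanager
def use_patched_update_sel(classes: List[type], replacement: type) -> Iterator[None]:
    """Temporarily point each class's _update_sel_cls at replacement."""
    # Save the current values before touching anything.
    saved: Dict[type, type] = {cls: cls._update_sel_cls for cls in classes}
    try:
        for cls in classes:
            cls._update_sel_cls = replacement
        yield
    finally:
        # Restore even if the body raised, so later callers see the originals.
        for cls, original in saved.items():
            cls._update_sel_cls = original


descriptors = [DescriptorA, DescriptorB]
with use_patched_update_sel(descriptors, JaxUpdateSel):
    inside = [cls._update_sel_cls for cls in descriptors]
after = [cls._update_sel_cls for cls in descriptors]
```

Doing the restore in a `finally` block is the point of the pattern: the patch stays scoped to the `with` body even when the wrapped call fails.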

Changes

JAX Neighbor Stat Selection Update

| Layer | File(s) | Summary |
| --- | --- | --- |
| JAX UpdateSel class and context manager | deepmd/jax/utils/update_sel.py | New module defining UpdateSel (backed by JAX NeighborStat), _get_update_sel_descriptors() for plugin discovery, and use_jax_update_sel() context manager that saves, patches, and restores _update_sel_cls on all relevant descriptor plugin classes. |
| Training entrypoint wiring | deepmd/jax/entrypoints/train.py | Adds BaseModel and use_jax_update_sel imports; replaces the no-op update_sel with a real implementation returning (jdata, min_nbor_dist); train() unpacks both values and conditionally assigns min_nbor_dist to model.model.min_nbor_dist. |
| Regression test | source/tests/jax/test_training.py | Adds test_update_sel_uses_jax_neighbor_stat that mocks get_data and UpdateSel.get_nbor_stat, invokes update_sel, and asserts updated sel, returned distance value, and mock call arguments. |
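
The wiring and the regression test summarized above can be sketched together. This is a simplified illustration, not the code under review: the `UpdateSel` class here is a local stand-in with an assumed `get_nbor_stat(train_data, rcut)` signature, `configure_model` is a hypothetical name for the relevant part of train(), and the sketch writes the raw neighbor count into sel without any safety margin or rounding.

```python
from types import SimpleNamespace
from typing import Any, Dict, Optional, Tuple
from unittest import mock


class UpdateSel:
    """Hypothetical stand-in; the statistics call is mocked below."""

    def get_nbor_stat(self, train_data: Any, rcut: float) -> Tuple[float, int]:
        raise RuntimeError("expensive call, expected to be mocked")


def update_sel(
    train_data: Any, jdata: Dict[str, Any]
) -> Tuple[Dict[str, Any], Optional[float]]:
    """Resolve sel='auto' from neighbor statistics; return (jdata, min_nbor_dist)."""
    descriptor = jdata["model"]["descriptor"]
    if descriptor["sel"] != "auto":
        return jdata, None
    min_nbor_dist, max_nbor_size = UpdateSel().get_nbor_stat(
        train_data, descriptor["rcut"]
    )
    descriptor["sel"] = max_nbor_size  # simplified: no margin or rounding
    return jdata, min_nbor_dist


def configure_model(model: Any, train_data: Any, jdata: Dict[str, Any]) -> Dict[str, Any]:
    """Unpack both return values; set the distance only when one was computed."""
    jdata, min_nbor_dist = update_sel(train_data, jdata)
    if min_nbor_dist is not None:
        model.min_nbor_dist = min_nbor_dist
    return jdata


config = {"model": {"descriptor": {"sel": "auto", "rcut": 6.0}}}
model = SimpleNamespace(min_nbor_dist=None)
with mock.patch.object(UpdateSel, "get_nbor_stat", return_value=(0.9, 24)) as stat:
    config = configure_model(model, "fake-data", config)
    stat.assert_called_once_with("fake-data", 6.0)
```

Mocking the statistics call keeps the test fast and independent of real training data while still checking the three things the summary lists: the updated sel, the returned distance, and the arguments passed to the mocked call.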

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 75.00% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title 'fix(jax): restore batched neighbor statistics' clearly and concisely describes the main change: restoring neighbor statistics functionality in JAX training. It matches the core objective stated in the PR objectives. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |


@codecov

codecov Bot commented Jun 16, 2026


Codecov Report

❌ Patch coverage is 87.17949% with 5 lines in your changes missing coverage. Please review.
✅ Project coverage is 80.96%. Comparing base (69409c5) to head (309ff37).

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| deepmd/jax/entrypoints/train.py | 69.23% | 4 Missing ⚠️ |
| deepmd/jax/utils/update_sel.py | 96.15% | 1 Missing ⚠️ |
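
The reported percentages are consistent with 39 patch lines, of which 34 are covered. Those line counts are inferred from the percentages and missing-line counts above, not stated by Codecov, so the check below is a sanity calculation under that assumption.

```python
from fractions import Fraction

# Inferred (covered, total) patch lines per file. Only the percentages and
# the missing-line counts appear in the report; the totals are an assumption.
inferred = {
    "deepmd/jax/entrypoints/train.py": (9, 13),
    "deepmd/jax/utils/update_sel.py": (25, 26),
}


def percent(covered: int, total: int) -> float:
    """Coverage percentage rounded to five decimals, as in the report."""
    return round(float(Fraction(covered, total) * 100), 5)


per_file = {name: percent(c, t) for name, (c, t) in inferred.items()}
covered = sum(c for c, _ in inferred.values())
total = sum(t for _, t in inferred.values())
patch = percent(covered, total)
missing = total - covered
```

With these counts, `patch` and `missing` match the headline patch coverage and the five missing lines, and the per-file values round to the two table entries.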
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5545      +/-   ##
==========================================
- Coverage   82.22%   80.96%   -1.26%     
==========================================
  Files         892      893       +1     
  Lines      101548   101585      +37     
  Branches     4242     4243       +1     
==========================================
- Hits        83493    82251    -1242     
- Misses      16754    18034    +1280     
+ Partials     1301     1300       -1     
