From 90faa0a4942070f11dc84b65c33aebff8188283e Mon Sep 17 00:00:00 2001 From: Han Wang Date: Tue, 16 Jun 2026 15:16:51 +0800 Subject: [PATCH 1/2] fix(pt): stop plain pt dp test from eager-loading pt_expt custom-op fakes deepmd.pt.infer.deep_eval imports the vesin neighbor list from deepmd.pt_expt.utils (added in #5491). That package __init__ eagerly imported tabulate_ops, which registers fake tensor impls for the compressed tabulate custom ops at import time. On the plain pt (torch.jit) backend without the C++ op library, the pt descriptor fallbacks monkeypatch a plain Python function onto the torch.ops.deepmd namespace (e.g. torch.ops.deepmd.tabulate_fusion_se_a), so the bare hasattr guard passes but register_fake raises "operator deepmd::tabulate_fusion_se_a does not exist", crashing `dp test`. Fix: - Drop the eager tabulate_ops import from pt_expt/utils/__init__.py. The only consumer that needs the fakes (the compression entry point) already calls ensure_fake_registered() lazily, so plain pt inference no longer triggers any custom-op registration. - Harden ensure_fake_registered(): guard each op with a real OpOverloadPacket check (_op_exists) instead of bare hasattr, so a monkeypatched plain-function fallback is skipped rather than crashing. Remove the import-time auto-call. Tests (source/tests/pt_expt/utils/test_tabulate_ops_lazy.py): - subprocess import of deepmd.pt.infer.deep_eval asserts tabulate_ops and comm are not eagerly imported. - ensure_fake_registered() with a monkeypatched plain-function op present must skip it without raising (the exact dp test crash). 
--- deepmd/pt_expt/utils/__init__.py | 16 ++- deepmd/pt_expt/utils/tabulate_ops.py | 57 +++++---- .../pt_expt/utils/test_tabulate_ops_lazy.py | 108 ++++++++++++++++++ 3 files changed, 151 insertions(+), 30 deletions(-) create mode 100644 source/tests/pt_expt/utils/test_tabulate_ops_lazy.py diff --git a/deepmd/pt_expt/utils/__init__.py b/deepmd/pt_expt/utils/__init__.py index 93170162c3..4e637e5d4f 100644 --- a/deepmd/pt_expt/utils/__init__.py +++ b/deepmd/pt_expt/utils/__init__.py @@ -22,12 +22,16 @@ # as it's a stateless utility class register_dpmodel_mapping(EnvMat, lambda v: v) -# Register fake tensor implementations for custom tabulate ops. -# comm.py (border_op fake/autograd) is NOT imported here — its -# ensure_comm_registered() is called lazily from the with_comm_dict -# export path in serialization.py to avoid eager libdeepmd_op_pt.so -# loading that breaks fake-op registration order in tests. -from deepmd.pt_expt.utils import tabulate_ops # noqa: F401 +# Note: tabulate_ops (fake-op registration for the compressed tabulate path) +# and comm.py (border_op fake/autograd) are intentionally NOT imported here. +# Their ensure_*_registered() helpers are called lazily from the paths that +# actually need them (compression entry / with_comm_dict export). Eager-loading +# them at package import time pulls custom-op registration onto the plain pt +# (torch.jit) inference path — `deepmd.pt.infer.deep_eval` imports the vesin +# neighbor list from this package — which crashes `dp test` when the C++ op +# library is absent (the pt descriptor fallback monkeypatches a plain Python +# function onto torch.ops.deepmd, so register_fake raises "operator does not +# exist"). See tests/pt_expt/utils/test_tabulate_ops_lazy.py. 
__all__ = [ "AtomExcludeMask", diff --git a/deepmd/pt_expt/utils/tabulate_ops.py b/deepmd/pt_expt/utils/tabulate_ops.py index d738d7ef3c..3e2f3db13b 100644 --- a/deepmd/pt_expt/utils/tabulate_ops.py +++ b/deepmd/pt_expt/utils/tabulate_ops.py @@ -5,11 +5,14 @@ compressed forward path, which uses C++ custom ops (tabulate_fusion_se_*). Without fake implementations, torch.export cannot determine output shapes. -This module is imported at package init time (via utils/__init__.py) so -registrations happen before any descriptor code runs. If the C++ custom -op library hasn't been loaded yet at that point, `ensure_fake_registered()` -can be called again later (it is idempotent) — e.g. from the compression -entry point after the ops become available. +`ensure_fake_registered()` is called explicitly (and idempotently) by the paths +that need fake ops — e.g. the compression entry point — after the C++ custom op +library has been loaded. It is deliberately NOT called at package import time: +doing so would pull custom-op registration onto the plain pt (torch.jit) +inference path (which imports this package only for the vesin neighbor list) and +crash `dp test` when the C++ op library is absent, because the pt descriptor +fallbacks monkeypatch a plain Python function onto ``torch.ops.deepmd`` and +``register_fake`` then raises "operator does not exist". When the C++ custom op library is loaded, the ops already have implementations, and register_fake will raise RuntimeError. We silently @@ -29,6 +32,20 @@ _registered: set[str] = set() +def _op_exists(name: str) -> bool: + """Whether ``deepmd::`` is a real (C++-registered) dispatcher op. + + A bare ``hasattr(torch.ops.deepmd, name)`` is not sufficient: when the C++ + custom-op library is absent, the pt descriptor fallbacks monkeypatch a plain + Python function onto the ``torch.ops.deepmd`` namespace (see e.g. + ``deepmd/pt/model/descriptor/se_a.py``). 
That makes ``hasattr`` return True + while ``register_fake`` still raises "operator does not exist". Only a real + op resolves to an ``OpOverloadPacket``. + """ + op = getattr(torch.ops.deepmd, name, None) + return isinstance(op, torch._ops.OpOverloadPacket) + + def _try_register_fake(op_name: str, fn: Callable[..., Any]) -> None: """Register a fake implementation, silently skipping if already registered.""" if op_name in _registered: @@ -47,19 +64,15 @@ def _try_register_fake(op_name: str, fn: Callable[..., Any]) -> None: def ensure_fake_registered() -> None: """Register fake implementations for all tabulate custom ops. - Only registers for ops that exist (i.e., the custom op library is loaded). - Idempotent — safe to call multiple times; already-registered ops are - skipped via the ``_registered`` set. - - Called automatically at import time and should also be called from any - code path that needs fake ops after the C++ library has been loaded - (e.g. the compression entry point). + Only registers for ops that are actually loaded as real dispatcher ops + (i.e., the C++ custom op library is present). Idempotent — safe to call + multiple times; already-registered ops are skipped via the ``_registered`` + set. Not called at import time: the paths that need fake ops (e.g. the + compression entry point) call this explicitly after the C++ library loads, + so that plain pt inference never triggers custom-op registration. 
""" - if not hasattr(torch.ops, "deepmd"): - return - # --- tabulate_fusion_se_a --- - if hasattr(torch.ops.deepmd, "tabulate_fusion_se_a"): + if _op_exists("tabulate_fusion_se_a"): def _fake_se_a( table: torch.Tensor, @@ -73,7 +86,7 @@ def _fake_se_a( _try_register_fake("deepmd::tabulate_fusion_se_a", _fake_se_a) # --- tabulate_fusion_se_r --- - if hasattr(torch.ops.deepmd, "tabulate_fusion_se_r"): + if _op_exists("tabulate_fusion_se_r"): def _fake_se_r( table: torch.Tensor, @@ -86,7 +99,7 @@ def _fake_se_r( _try_register_fake("deepmd::tabulate_fusion_se_r", _fake_se_r) # --- tabulate_fusion_se_t --- - if hasattr(torch.ops.deepmd, "tabulate_fusion_se_t"): + if _op_exists("tabulate_fusion_se_t"): def _fake_se_t( table: torch.Tensor, @@ -100,7 +113,7 @@ def _fake_se_t( _try_register_fake("deepmd::tabulate_fusion_se_t", _fake_se_t) # --- tabulate_fusion_se_t_tebd --- - if hasattr(torch.ops.deepmd, "tabulate_fusion_se_t_tebd"): + if _op_exists("tabulate_fusion_se_t_tebd"): def _fake_se_t_tebd( table: torch.Tensor, @@ -116,7 +129,7 @@ def _fake_se_t_tebd( _try_register_fake("deepmd::tabulate_fusion_se_t_tebd", _fake_se_t_tebd) # --- tabulate_fusion_se_atten --- - if hasattr(torch.ops.deepmd, "tabulate_fusion_se_atten"): + if _op_exists("tabulate_fusion_se_atten"): def _fake_se_atten( table: torch.Tensor, @@ -130,7 +143,3 @@ def _fake_se_atten( return [table.new_empty([em.size(0), 4, last_layer_size])] _try_register_fake("deepmd::tabulate_fusion_se_atten", _fake_se_atten) - - -# Best-effort at import time — ops may not be loaded yet. -ensure_fake_registered() diff --git a/source/tests/pt_expt/utils/test_tabulate_ops_lazy.py b/source/tests/pt_expt/utils/test_tabulate_ops_lazy.py new file mode 100644 index 0000000000..9d3062288d --- /dev/null +++ b/source/tests/pt_expt/utils/test_tabulate_ops_lazy.py @@ -0,0 +1,108 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Regression tests for lazy fake-op registration in ``pt_expt.utils``. 
+ +Two failure modes, both surfaced when running ``dp test`` on the plain pt +(torch.jit) backend in an environment WITHOUT the C++ custom op library +(``libdeepmd_op_pt.so``): + +1. ``deepmd.pt.infer.deep_eval`` imports the vesin neighbor list from + ``deepmd.pt_expt.utils``. If that package eagerly imported ``tabulate_ops`` + (which registers fake custom ops at import time), plain pt inference would + drag custom-op registration onto its path. + +2. When the C++ op library is absent, the pt descriptor fallbacks monkeypatch a + plain Python function onto ``torch.ops.deepmd.`` (see e.g. + ``deepmd/pt/model/descriptor/se_a.py``). A bare ``hasattr`` guard then + returns True even though no real dispatcher op exists, and + ``register_fake`` raises ``RuntimeError: operator deepmd::... does not + exist``, crashing the import. +""" + +import subprocess +import sys +import textwrap + +import torch + +from deepmd.pt_expt.utils import ( + tabulate_ops, +) + + +def test_pt_deep_eval_does_not_eager_import_tabulate_ops() -> None: + """Importing the plain pt inference entry must not pull in tabulate_ops. + + Run in a fresh interpreter so ``sys.modules`` is not polluted by the test + session. Guards against re-introducing the eager + ``from deepmd.pt_expt.utils import tabulate_ops`` in the package ``__init__``. + """ + code = textwrap.dedent( + """ + import sys + import deepmd.pt.infer.deep_eval # noqa: F401 + + leaked = [ + m + for m in ( + "deepmd.pt_expt.utils.tabulate_ops", + "deepmd.pt_expt.utils.comm", + ) + if m in sys.modules + ] + assert not leaked, f"eagerly imported custom-op modules: {leaked}" + print("OK") + """ + ) + result = subprocess.run( + [sys.executable, "-c", code], + capture_output=True, + text=True, + ) + assert result.returncode == 0, result.stdout + "\n" + result.stderr + assert "OK" in result.stdout + + +def test_ensure_fake_registered_skips_monkeypatched_fallback() -> None: + """``ensure_fake_registered`` must skip a monkeypatched plain-function op. 
+ + Simulates the no-C++-op-library state by installing a plain Python function + on ``torch.ops.deepmd.tabulate_fusion_se_a`` (exactly what the pt descriptor + fallback does). With the old bare-``hasattr`` guard this raised + ``RuntimeError: operator ... does not exist``; the fix must detect that it is + not a real ``OpOverloadPacket`` and skip it without raising. + """ + op_name = "tabulate_fusion_se_a" + qualname = "deepmd::" + op_name + ns = torch.ops.deepmd + + # Snapshot any existing (possibly cached real op) attribute so we can restore. + had_attr = op_name in ns.__dict__ + saved = ns.__dict__.get(op_name) + was_registered = qualname in tabulate_ops._registered + + def _fallback(*args, **kwargs): + raise NotImplementedError + + try: + # Install the plain-function fallback (mimics the no-op-lib descriptor hack). + setattr(ns, op_name, _fallback) + # It must NOT be recognised as a real dispatcher op. + assert not tabulate_ops._op_exists(op_name) + + # Force a registration attempt for this op. + tabulate_ops._registered.discard(qualname) + + # The crash repro: must complete without raising. + tabulate_ops.ensure_fake_registered() + + # The monkeypatched fallback must have been skipped, not registered. + assert qualname not in tabulate_ops._registered + finally: + if had_attr: + setattr(ns, op_name, saved) + else: + ns.__dict__.pop(op_name, None) + if was_registered: + tabulate_ops._registered.add(qualname) + else: + tabulate_ops._registered.discard(qualname) From 3eb005be526f199f5ef67667edc7a91cd5c77f7c Mon Sep 17 00:00:00 2001 From: Han Wang Date: Tue, 16 Jun 2026 15:55:07 +0800 Subject: [PATCH 2/2] test(pt_expt): snapshot full _registered set in tabulate_ops test Restore the entire tabulate_ops._registered set in the finally block rather than just the single op under test: ensure_fake_registered() may touch multiple op names, so per-op restore could leak module-global state across tests. Addresses CodeRabbit review on #5542. 
--- source/tests/pt_expt/utils/test_tabulate_ops_lazy.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/source/tests/pt_expt/utils/test_tabulate_ops_lazy.py b/source/tests/pt_expt/utils/test_tabulate_ops_lazy.py index 9d3062288d..841ada65c3 100644 --- a/source/tests/pt_expt/utils/test_tabulate_ops_lazy.py +++ b/source/tests/pt_expt/utils/test_tabulate_ops_lazy.py @@ -78,7 +78,8 @@ def test_ensure_fake_registered_skips_monkeypatched_fallback() -> None: # Snapshot any existing (possibly cached real op) attribute so we can restore. had_attr = op_name in ns.__dict__ saved = ns.__dict__.get(op_name) - was_registered = qualname in tabulate_ops._registered + # ensure_fake_registered() may touch several op names; snapshot the whole set. + saved_registered = set(tabulate_ops._registered) def _fallback(*args, **kwargs): raise NotImplementedError @@ -102,7 +103,5 @@ def _fallback(*args, **kwargs): setattr(ns, op_name, saved) else: ns.__dict__.pop(op_name, None) - if was_registered: - tabulate_ops._registered.add(qualname) - else: - tabulate_ops._registered.discard(qualname) + tabulate_ops._registered.clear() + tabulate_ops._registered.update(saved_registered)