Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions deepmd/pt_expt/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,16 @@
# as it's a stateless utility class
register_dpmodel_mapping(EnvMat, lambda v: v)

# Register fake tensor implementations for custom tabulate ops.
# comm.py (border_op fake/autograd) is NOT imported here — its
# ensure_comm_registered() is called lazily from the with_comm_dict
# export path in serialization.py to avoid eager libdeepmd_op_pt.so
# loading that breaks fake-op registration order in tests.
from deepmd.pt_expt.utils import tabulate_ops # noqa: F401
# Note: tabulate_ops (fake-op registration for the compressed tabulate path)
# and comm.py (border_op fake/autograd) are intentionally NOT imported here.
# Their ensure_*_registered() helpers are called lazily from the paths that
# actually need them (compression entry / with_comm_dict export). Eager-loading
Comment thread
wanghan-iapcm marked this conversation as resolved.
# them at package import time pulls custom-op registration onto the plain pt
# (torch.jit) inference path — `deepmd.pt.infer.deep_eval` imports the vesin
# neighbor list from this package — which crashes `dp test` when the C++ op
# library is absent (the pt descriptor fallback monkeypatches a plain Python
# function onto torch.ops.deepmd, so register_fake raises "operator does not
# exist"). See tests/pt_expt/utils/test_tabulate_ops_lazy.py.

__all__ = [
"AtomExcludeMask",
Expand Down
57 changes: 33 additions & 24 deletions deepmd/pt_expt/utils/tabulate_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@
compressed forward path, which uses C++ custom ops (tabulate_fusion_se_*).
Without fake implementations, torch.export cannot determine output shapes.

This module is imported at package init time (via utils/__init__.py) so
registrations happen before any descriptor code runs. If the C++ custom
op library hasn't been loaded yet at that point, `ensure_fake_registered()`
can be called again later (it is idempotent) — e.g. from the compression
entry point after the ops become available.
`ensure_fake_registered()` is called explicitly (and idempotently) by the paths
that need fake ops — e.g. the compression entry point — after the C++ custom op
library has been loaded. It is deliberately NOT called at package import time:
doing so would pull custom-op registration onto the plain pt (torch.jit)
inference path (which imports this package only for the vesin neighbor list) and
crash `dp test` when the C++ op library is absent, because the pt descriptor
fallbacks monkeypatch a plain Python function onto ``torch.ops.deepmd`` and
``register_fake`` then raises "operator does not exist".

When the C++ custom op library is loaded, the ops already have
implementations, and register_fake will raise RuntimeError. We silently
Expand All @@ -29,6 +32,20 @@
_registered: set[str] = set()


def _op_exists(name: str) -> bool:
"""Whether ``deepmd::<name>`` is a real (C++-registered) dispatcher op.

A bare ``hasattr(torch.ops.deepmd, name)`` is not sufficient: when the C++
custom-op library is absent, the pt descriptor fallbacks monkeypatch a plain
Python function onto the ``torch.ops.deepmd`` namespace (see e.g.
``deepmd/pt/model/descriptor/se_a.py``). That makes ``hasattr`` return True
while ``register_fake`` still raises "operator does not exist". Only a real
op resolves to an ``OpOverloadPacket``.
"""
op = getattr(torch.ops.deepmd, name, None)
return isinstance(op, torch._ops.OpOverloadPacket)


def _try_register_fake(op_name: str, fn: Callable[..., Any]) -> None:
"""Register a fake implementation, silently skipping if already registered."""
if op_name in _registered:
Expand All @@ -47,19 +64,15 @@ def _try_register_fake(op_name: str, fn: Callable[..., Any]) -> None:
def ensure_fake_registered() -> None:
"""Register fake implementations for all tabulate custom ops.

Only registers for ops that exist (i.e., the custom op library is loaded).
Idempotent — safe to call multiple times; already-registered ops are
skipped via the ``_registered`` set.

Called automatically at import time and should also be called from any
code path that needs fake ops after the C++ library has been loaded
(e.g. the compression entry point).
Only registers for ops that are actually loaded as real dispatcher ops
(i.e., the C++ custom op library is present). Idempotent — safe to call
multiple times; already-registered ops are skipped via the ``_registered``
set. Not called at import time: the paths that need fake ops (e.g. the
compression entry point) call this explicitly after the C++ library loads,
so that plain pt inference never triggers custom-op registration.
"""
if not hasattr(torch.ops, "deepmd"):
return

# --- tabulate_fusion_se_a ---
if hasattr(torch.ops.deepmd, "tabulate_fusion_se_a"):
if _op_exists("tabulate_fusion_se_a"):

def _fake_se_a(
table: torch.Tensor,
Expand All @@ -73,7 +86,7 @@ def _fake_se_a(
_try_register_fake("deepmd::tabulate_fusion_se_a", _fake_se_a)

# --- tabulate_fusion_se_r ---
if hasattr(torch.ops.deepmd, "tabulate_fusion_se_r"):
if _op_exists("tabulate_fusion_se_r"):

def _fake_se_r(
table: torch.Tensor,
Expand All @@ -86,7 +99,7 @@ def _fake_se_r(
_try_register_fake("deepmd::tabulate_fusion_se_r", _fake_se_r)

# --- tabulate_fusion_se_t ---
if hasattr(torch.ops.deepmd, "tabulate_fusion_se_t"):
if _op_exists("tabulate_fusion_se_t"):

def _fake_se_t(
table: torch.Tensor,
Expand All @@ -100,7 +113,7 @@ def _fake_se_t(
_try_register_fake("deepmd::tabulate_fusion_se_t", _fake_se_t)

# --- tabulate_fusion_se_t_tebd ---
if hasattr(torch.ops.deepmd, "tabulate_fusion_se_t_tebd"):
if _op_exists("tabulate_fusion_se_t_tebd"):

def _fake_se_t_tebd(
table: torch.Tensor,
Expand All @@ -116,7 +129,7 @@ def _fake_se_t_tebd(
_try_register_fake("deepmd::tabulate_fusion_se_t_tebd", _fake_se_t_tebd)

# --- tabulate_fusion_se_atten ---
if hasattr(torch.ops.deepmd, "tabulate_fusion_se_atten"):
if _op_exists("tabulate_fusion_se_atten"):

def _fake_se_atten(
table: torch.Tensor,
Expand All @@ -130,7 +143,3 @@ def _fake_se_atten(
return [table.new_empty([em.size(0), 4, last_layer_size])]

_try_register_fake("deepmd::tabulate_fusion_se_atten", _fake_se_atten)


# Best-effort at import time — ops may not be loaded yet.
ensure_fake_registered()
107 changes: 107 additions & 0 deletions source/tests/pt_expt/utils/test_tabulate_ops_lazy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Regression tests for lazy fake-op registration in ``pt_expt.utils``.

Two failure modes, both surfaced when running ``dp test`` on the plain pt
(torch.jit) backend in an environment WITHOUT the C++ custom op library
(``libdeepmd_op_pt.so``):

1. ``deepmd.pt.infer.deep_eval`` imports the vesin neighbor list from
``deepmd.pt_expt.utils``. If that package eagerly imported ``tabulate_ops``
(which registers fake custom ops at import time), plain pt inference would
drag custom-op registration onto its path.

2. When the C++ op library is absent, the pt descriptor fallbacks monkeypatch a
plain Python function onto ``torch.ops.deepmd.<op>`` (see e.g.
``deepmd/pt/model/descriptor/se_a.py``). A bare ``hasattr`` guard then
returns True even though no real dispatcher op exists, and
``register_fake`` raises ``RuntimeError: operator deepmd::... does not
exist``, crashing the import.
"""

import subprocess
import sys
import textwrap

import torch

from deepmd.pt_expt.utils import (
tabulate_ops,
)


def test_pt_deep_eval_does_not_eager_import_tabulate_ops() -> None:
"""Importing the plain pt inference entry must not pull in tabulate_ops.

Run in a fresh interpreter so ``sys.modules`` is not polluted by the test
session. Guards against re-introducing the eager
``from deepmd.pt_expt.utils import tabulate_ops`` in the package ``__init__``.
"""
code = textwrap.dedent(
"""
import sys
import deepmd.pt.infer.deep_eval # noqa: F401

leaked = [
m
for m in (
"deepmd.pt_expt.utils.tabulate_ops",
"deepmd.pt_expt.utils.comm",
)
if m in sys.modules
]
assert not leaked, f"eagerly imported custom-op modules: {leaked}"
print("OK")
"""
)
result = subprocess.run(
[sys.executable, "-c", code],
capture_output=True,
text=True,
)
assert result.returncode == 0, result.stdout + "\n" + result.stderr
assert "OK" in result.stdout


def test_ensure_fake_registered_skips_monkeypatched_fallback() -> None:
"""``ensure_fake_registered`` must skip a monkeypatched plain-function op.

Simulates the no-C++-op-library state by installing a plain Python function
on ``torch.ops.deepmd.tabulate_fusion_se_a`` (exactly what the pt descriptor
fallback does). With the old bare-``hasattr`` guard this raised
``RuntimeError: operator ... does not exist``; the fix must detect that it is
not a real ``OpOverloadPacket`` and skip it without raising.
"""
op_name = "tabulate_fusion_se_a"
qualname = "deepmd::" + op_name
ns = torch.ops.deepmd

# Snapshot any existing (possibly cached real op) attribute so we can restore.
had_attr = op_name in ns.__dict__
saved = ns.__dict__.get(op_name)
# ensure_fake_registered() may touch several op names; snapshot the whole set.
saved_registered = set(tabulate_ops._registered)

def _fallback(*args, **kwargs):
raise NotImplementedError

try:
# Install the plain-function fallback (mimics the no-op-lib descriptor hack).
setattr(ns, op_name, _fallback)
# It must NOT be recognised as a real dispatcher op.
assert not tabulate_ops._op_exists(op_name)

# Force a registration attempt for this op.
tabulate_ops._registered.discard(qualname)

# The crash repro: must complete without raising.
tabulate_ops.ensure_fake_registered()

# The monkeypatched fallback must have been skipped, not registered.
assert qualname not in tabulate_ops._registered
finally:
if had_attr:
setattr(ns, op_name, saved)
else:
ns.__dict__.pop(op_name, None)
tabulate_ops._registered.clear()
tabulate_ops._registered.update(saved_registered)
Loading