Skip to content
Merged
19 changes: 18 additions & 1 deletion deepmd/dpmodel/descriptor/dpa4.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
annotations,
)

import logging
import math
from typing import (
TYPE_CHECKING,
Expand All @@ -36,11 +37,17 @@
import array_api_compat
import numpy as np

log = logging.getLogger(__name__)

# Warn at most once per process for backend-ignored switches (keyed by name).
_WARNED_ONCE: set[str] = set()

from deepmd.dpmodel import (
NativeOP,
)
from deepmd.dpmodel.array_api import (
xp_asarray_nodetach,
xp_take_first_n,
)
from deepmd.dpmodel.common import (
PRECISION_DICT,
Expand Down Expand Up @@ -323,6 +330,12 @@ def __init__(
# pt-runtime-only switch (CUDA bfloat16 autocast during training);
# accepted for config compatibility and ignored by dpmodel.
self.use_amp = bool(use_amp)
if self.use_amp and "use_amp" not in _WARNED_ONCE:
log.warning(
"`use_amp` has no effect on the dpmodel/pt_expt backend "
"(it is a pt-runtime CUDA autocast switch); ignoring it."
)
_WARNED_ONCE.add("use_amp")
self.trainable = bool(trainable)
self.seed = seed
self.random_gamma = bool(random_gamma)
Expand Down Expand Up @@ -811,7 +824,11 @@ def call(
pair_keep_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) != 0

# === Step 2. Type embedding (l=0) ===
atype_loc = atype_ext[:, :nloc]
# Use ``xp_take_first_n`` (torch.index_select) rather than a plain
# ``[:, :nloc]`` slice: the slice makes torch.export emit a spurious
# ``Ne(nall, nloc)`` contiguity guard that breaks the ``nall == nloc``
# (NoPBC, no ghost atoms) case in the compiled .pt2 artifact.
atype_loc = xp_take_first_n(atype_ext, 1, nloc)
type_ebed = xp.reshape(
self.type_embedding(atype_loc), (n_nodes, self.channels)
) # (N, C)
Expand Down
9 changes: 7 additions & 2 deletions deepmd/dpmodel/descriptor/dpa4_nn/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,17 @@ def segment_envelope_gated_softmax(
n_edge, n_focus, n_head = logits.shape
n_channel = n_focus * n_head
eps_f = float(eps)
if n_nodes <= 0 or n_edge % int(n_nodes) != 0:
# Keep ``n_nodes`` symbolic (no ``int()``): it is the product ``nf*nloc``,
# and casting to a Python int specializes it to the trace-time sample
# shape, which breaks torch.export with a dynamic ``nloc`` dim. The
# ``Mod`` check below stays statically known (``E == n_nodes*nnei``) and
# the ``(n_nodes, nnei, ...)`` reshapes recover the layout symbolically.
if n_nodes <= 0 or n_edge % n_nodes != 0:
raise ValueError(
"padded-edge layout requires E to be a multiple of n_nodes; "
f"got E={n_edge}, n_nodes={n_nodes}"
)
nnei = n_edge // int(n_nodes)
nnei = n_edge // n_nodes
device = array_api_compat.device(logits)

# === Step 1. Flatten (F, H) and build the effective per-edge weight ===
Expand Down
21 changes: 15 additions & 6 deletions deepmd/dpmodel/descriptor/dpa4_nn/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,8 +290,14 @@ def call(
return xp.zeros(
(n_nodes, self.ebed_dim, self.channels), dtype=dtype, device=device
)
n_edge = int(edge_cache.dst.shape[0])
nnei = _edge_layout(n_edge, int(n_nodes))
# Keep ``n_edge``/``n_nodes`` symbolic (no ``int()``): they are the
# products ``nf*nloc*nnei`` / ``nf*nloc``. Casting to a Python int
# specializes them to the trace-time sample shape (e.g. nf*nloc==14),
# which breaks torch.export with a dynamic ``nloc`` dim. ``_edge_layout``
# returns a symbolic ``nnei`` and the masked-sum reshapes below use
# ``-1`` for the node axis to recover it symbolically.
n_edge = edge_cache.dst.shape[0]
nnei = _edge_layout(n_edge, n_nodes)

# === Step 2. Gather all m=0 columns (l >= 1) in one shot ===
# pt embedding.py:235-241 pairs one packed non-scalar row with the
Expand Down Expand Up @@ -345,7 +351,7 @@ def call(
non_scalar_out = xp.sum(
xp.reshape(
non_scalar_message,
(n_nodes, nnei, self.ebed_dim - 1, self.channels),
(-1, nnei, self.ebed_dim - 1, self.channels),
),
axis=1,
) # (N, D-1, C)
Expand Down Expand Up @@ -592,8 +598,11 @@ def call(
edge_vec = edge_cache.edge_vec # (E, 3)
edge_rbf = edge_cache.edge_rbf # (E, n_radial)
edge_env = edge_cache.edge_env # (E, 1)
n_edge = int(dst.shape[0])
nnei = _edge_layout(n_edge, int(n_nodes))
# Keep ``n_edge``/``n_nodes`` symbolic (no ``int()``); see the matching
# comment in ``GeometricInitialEmbedding.call`` for why casting to a
# Python int breaks torch.export with a dynamic ``nloc`` dim.
n_edge = dst.shape[0]
nnei = _edge_layout(n_edge, n_nodes)

# === Step 1. Construct r_tilde = [s, s*r_hat] ===
# s = edge_env * (1/r), r_hat = edge_vec / r (pt embedding.py:489-495)
Expand Down Expand Up @@ -641,7 +650,7 @@ def call(
xp.reshape(edge_mask, (n_edge, 1)), outer_flat.dtype
)
env_agg = xp.sum(
xp.reshape(outer_flat, (n_nodes, nnei, 4 * self.embed_dim)),
xp.reshape(outer_flat, (-1, nnei, 4 * self.embed_dim)),
axis=1,
) # (N, 4*embed_dim)
env_agg = xp.reshape(env_agg, (n_nodes, 4, self.embed_dim))
Expand Down
8 changes: 7 additions & 1 deletion deepmd/dpmodel/descriptor/dpa4_nn/so2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1183,7 +1183,13 @@ def call(
device = array_api_compat.device(x)
src, dst = edge_cache.src, edge_cache.dst
n_node = x.shape[0]
n_edge = int(src.shape[0])
# Keep ``n_edge``/``n_node`` symbolic (no ``int()``): they are the
# products ``nf*nloc*nnei`` / ``nf*nloc``. Casting to a Python int
# specializes them to the trace-time sample shape (breaking
# torch.export with a dynamic ``nloc`` dim); the ``Mod`` check stays
# statically known and the ``(n_node, nnei, ...)`` reshape below
# recovers the layout symbolically.
n_edge = src.shape[0]
if n_node <= 0 or n_edge % n_node != 0:
raise ValueError(
"padded-edge layout requires E to be a multiple of N; "
Expand Down
2 changes: 2 additions & 0 deletions deepmd/dpmodel/model/ener_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@


@BaseModel.register("ener")
@BaseModel.register("sezm_ener")
@BaseModel.register("dpa4_ener")
class EnergyModel(DPModelCommon, DPEnergyModel_):
r"""Energy model that predicts total energy and derived quantities.

Expand Down
2 changes: 2 additions & 0 deletions deepmd/pt_expt/model/ener_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@


@BaseModel.register("ener")
@BaseModel.register("sezm_ener")
@BaseModel.register("dpa4_ener")
class EnergyModel(DPModelCommon, DPEnergyModel_):
def __init__(
self,
Expand Down
12 changes: 12 additions & 0 deletions deepmd/pt_expt/model/get_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""

import copy
import logging
from typing import (
Any,
)
Expand Down Expand Up @@ -44,6 +45,11 @@
Spin,
)

log = logging.getLogger(__name__)

# Warn at most once per process for backend-ignored switches (keyed by name).
_WARNED_ONCE: set[str] = set()


def _get_standard_model_components(
data: dict[str, Any],
Expand Down Expand Up @@ -128,6 +134,12 @@ def get_sezm_model(data: dict) -> EnergyModel:
("highest") matmul precision, which is numerically conservative.
"""
data = copy.deepcopy(data)
if bool(data.get("enable_tf32", True)) and "enable_tf32" not in _WARNED_ONCE:
log.warning(
"`enable_tf32` has no effect on the pt_expt backend, which "
"always runs at full ('highest') matmul precision; ignoring it."
)
_WARNED_ONCE.add("enable_tf32")
if "spin" in data:
raise NotImplementedError(
"Spin DPA4/SeZM models are not supported in the pt_expt backend."
Expand Down
99 changes: 98 additions & 1 deletion deepmd/pt_expt/model/model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
)

from deepmd.dpmodel.model.base_model import (
make_base_model,
)
from deepmd.utils.version import (
check_version_compatibility,
)


class BaseModel(make_base_model()):
Expand All @@ -16,4 +23,94 @@ class BaseModel(make_base_model()):
Backend-independent BaseModel class.
"""

pass
# The pt backend's ``SeZMModel`` (model_type "SeZM", aliases dpa4/sezm)
# serialises with a *model-level wrapper*: ``{type: "SeZM",
# atomic_model: <sezm_atomic dict>, bridging_method, bridging_r_*, lora}``,
# and its atomic model uses ``type: "sezm_atomic"`` carrying pt-only
# extras (``dens_fitting``/``active_mode`` plus a ``dens_force_rmsd``
# @variable). pt_expt builds the equivalent DPA4 model via the generic
# ``make_model`` path, whose ``serialize()`` emits the standard atomic
# dict directly (``type: "standard"``). To load a pt-trained checkpoint
# into pt_expt (the serialization-compat / checkpoint-interop
# requirement), recognise the wrapper, reject the pt-only features pt_expt
# does not implement (when they are non-default), strip the rest, and
# delegate to the standard path. The nested descriptor/fitting dicts are
# already backend-agnostic dpmodel serializations and pass through intact.
#
# Both sets hold lower-cased tags; ``deserialize`` lower-cases the incoming
# ``type`` before the membership test, so matching is case-insensitive.
# Model-level wrapper tags: routed through ``_unwrap_pt_sezm_model``.
_SEZM_MODEL_TYPES = frozenset({"sezm", "dpa4"})
# Atomic-model tags: routed through ``_normalize_pt_sezm_atomic``.
_SEZM_ATOMIC_TYPES = frozenset({"sezm_atomic"})

@classmethod
def deserialize(cls, data: dict[str, Any]) -> "BaseModel":
    """Deserialize a model, translating pt SeZM/DPA4 layouts first.

    The ``type`` entry of ``data`` (``"standard"`` when absent) is compared
    case-insensitively against the SeZM tag sets. A model-level wrapper is
    unwrapped to its atomic dict, a ``sezm_atomic`` dict is rewritten to
    the standard atomic layout, and the converted dict is fed back through
    this method. Any other type goes straight to the parent implementation.

    Parameters
    ----------
    data : dict[str, Any]
        Serialized model data.

    Returns
    -------
    BaseModel
        Whatever the parent ``deserialize`` builds from the (possibly
        converted) data.
    """
    type_tag = str(data.get("type", "standard")).lower()
    if type_tag in cls._SEZM_MODEL_TYPES:
        converted = cls._unwrap_pt_sezm_model(data)
    elif type_tag in cls._SEZM_ATOMIC_TYPES:
        converted = cls._normalize_pt_sezm_atomic(data)
    else:
        # Not a pt SeZM layout: nothing to translate.
        return super().deserialize(data)
    # Re-dispatch: an unwrapped model yields a ``sezm_atomic`` dict that
    # still needs normalising on the next pass.
    return cls.deserialize(converted)

@staticmethod
def _unwrap_pt_sezm_model(data: dict[str, Any]) -> dict[str, Any]:
    """Return the atomic-model dict nested in a pt ``SeZMModel`` wrapper.

    The wrapper is validated before it is discarded: its ``@version``
    (default 1) is passed to ``check_version_compatibility`` with bounds
    (1, 1), and the pt-only extras must be at their inert values.

    Parameters
    ----------
    data : dict[str, Any]
        Model-level serialization holding an ``atomic_model`` entry.

    Returns
    -------
    dict[str, Any]
        The object stored under ``atomic_model`` (not copied).

    Raises
    ------
    NotImplementedError
        If ``bridging_method`` is anything other than ``"none"``/``""``
        (case-insensitive), or if ``lora`` is present and not ``None``.
    ValueError
        If ``atomic_model`` is missing or ``None``.
    """
    wrapper = data.copy()

    # Check the wrapper schema version first so an unknown future layout
    # is not silently reduced to its atomic payload.
    wrapper_version = int(wrapper.get("@version", 1))
    check_version_compatibility(wrapper_version, 1, 1)

    bridging_tag = str(wrapper.get("bridging_method", "none")).lower()
    bridging_disabled = bridging_tag in ("none", "")
    if not bridging_disabled:
        raise NotImplementedError(
            "Deserializing a pt SeZM/DPA4 checkpoint with "
            f"`bridging_method`={data.get('bridging_method')!r} is not "
            "supported in pt_expt."
        )

    lora_state = wrapper.get("lora")
    if lora_state is not None:
        raise NotImplementedError(
            "Deserializing a pt SeZM/DPA4 checkpoint with `lora` is "
            "not supported in pt_expt."
        )

    inner_atomic = wrapper.get("atomic_model")
    if inner_atomic is not None:
        return inner_atomic
    raise ValueError(
        "SeZM/DPA4 model data is missing the 'atomic_model' entry."
    )

@staticmethod
def _normalize_pt_sezm_atomic(data: dict[str, Any]) -> dict[str, Any]:
    """Rewrite a pt ``sezm_atomic`` dict as a standard atomic-model dict.

    Works on a shallow copy of ``data``. The ``dens_fitting`` and
    ``active_mode`` entries are removed, ``@variables`` (when it is a dict)
    is reduced to its ``out_bias``/``out_std`` entries, and ``type`` /
    ``@version`` are set to ``"standard"`` / ``2``.

    Parameters
    ----------
    data : dict[str, Any]
        Serialized ``sezm_atomic`` data; the caller's dict is not modified.

    Returns
    -------
    dict[str, Any]
        The normalized copy.

    Raises
    ------
    NotImplementedError
        If ``dens_fitting`` is present and not ``None``, or ``active_mode``
        is neither ``None`` nor ``"ener"``; only the energy path is
        accepted here.
    """
    atomic = data.copy()

    # Validate the incoming version BEFORE coercing it below, so only the
    # range accepted by the check (bounds 3 and 2) is ever rewritten to 2.
    incoming_version = int(atomic.get("@version", 2))
    check_version_compatibility(incoming_version, 3, 2)

    dens_head = atomic.pop("dens_fitting", None)
    if dens_head is not None:
        raise NotImplementedError(
            "Deserializing a pt SeZM/DPA4 checkpoint with a `dens` "
            "fitting head is not supported in pt_expt."
        )

    accepted_modes = (None, "ener")
    active_mode = atomic.pop("active_mode", None)
    if active_mode not in accepted_modes:
        raise NotImplementedError(
            f"Deserializing a pt SeZM/DPA4 checkpoint in active_mode "
            f"{active_mode!r} is not supported in pt_expt (energy only)."
        )

    # Keep only the output statistics; any other stored variable is dropped.
    retained_names = ("out_bias", "out_std")
    stored_vars = atomic.get("@variables")
    if isinstance(stored_vars, dict):
        atomic["@variables"] = {
            name: value
            for name, value in stored_vars.items()
            if name in retained_names
        }

    # Present the result as a version-2 ``standard`` atomic model, which is
    # what the generic deserialize path expects.
    atomic.update({"@version": 2, "type": "standard"})
    return atomic
Loading
Loading