From 4c6b538a3e073fe28d95c073df05ac7fd1421948 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Thu, 11 Jun 2026 19:02:12 +0800 Subject: [PATCH 1/3] perf(dpa4): opt so3grid --- .../pt/model/descriptor/sezm_nn/grid_net.py | 213 +++++++++++++----- 1 file changed, 156 insertions(+), 57 deletions(-) diff --git a/deepmd/pt/model/descriptor/sezm_nn/grid_net.py b/deepmd/pt/model/descriptor/sezm_nn/grid_net.py index 867dd47782..b0226203fe 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/grid_net.py +++ b/deepmd/pt/model/descriptor/sezm_nn/grid_net.py @@ -20,6 +20,7 @@ ) from typing import ( + TYPE_CHECKING, Literal, ) @@ -54,6 +55,11 @@ FocusLinear, ) +if TYPE_CHECKING: + from collections.abc import ( + Callable, + ) + GridNetLayout = Literal["ndfc", "nfdc", "flat"] GridNetMode = Literal["self", "cross"] GridNetOp = Literal["glu", "mlp", "branch"] @@ -78,6 +84,70 @@ def _build_frame_degree_index( raise ValueError("`coefficient_layout` must be either 'packed' or 'm_major'") +def _project_frames( + coeff: torch.Tensor, proj: ChannelLinear, n_frames: int +) -> torch.Tensor: + """ + Apply a channel-only linear map to each Wigner-D frame independently. + + Parameters + ---------- + coeff : torch.Tensor + Frame-packed coefficients with shape ``(N, D, F, n_frames * C_in)``. + proj : ChannelLinear + Linear map acting on the per-frame channel axis (``C_in -> C_out``). + n_frames : int + Number of Wigner-D frames packed along the trailing axis. + + Returns + ------- + torch.Tensor + Projected coefficients with shape ``(N, D, F, n_frames * C_out)``. + + Notes + ----- + ``to_grid`` and ``from_grid`` are frame-wise linear and commute with any + channel map, so applying the map at coefficient resolution here is identical + to applying it on the grid field while touching ``n_frames``-fold fewer rows + than the ``G``-point grid. + """ + n_batch, coeff_dim, n_focus, _ = coeff.shape + projected = proj(coeff.reshape(n_batch, coeff_dim, n_focus, n_frames, -1)) + return projected.reshape(n_batch, coeff_dim, n_focus, -1) + + +class GridProduct(nn.Module): + """Parameter-free quadratic grid product ``u(g) * v(g)``.""" + + def forward( + self, + left: torch.Tensor, + right: torch.Tensor, + scalar_pair: torch.Tensor, + *, + to_grid: Callable[[torch.Tensor], torch.Tensor], + from_grid: Callable[[torch.Tensor], torch.Tensor], + ) -> torch.Tensor: + """ + Combine two coefficient operands by a point-wise grid product. + + Parameters + ---------- + left, right : torch.Tensor + Coefficient operands with shape ``(N, D, F, n_frames * C)``. + scalar_pair : torch.Tensor + Invariant routing signal; unused on this path. + to_grid, from_grid : Callable + Coefficient/grid projectors supplied by the owning grid net. + + Returns + ------- + torch.Tensor + Coefficient result with shape ``(N, D, F, n_frames * C)``. + """ + return from_grid(to_grid(left) * to_grid(right)) + + class GridMLP(nn.Module): """Polynomial point-wise MLP applied independently at every grid point.""" @@ -86,6 +156,7 @@ def __init__( *, channels: int, mode: GridNetMode, + n_frames: int, dtype: torch.dtype, trainable: bool, seed: int | list[int] | None = None, @@ -95,6 +166,7 @@ def __init__( self.mode = str(mode).lower() if self.mode not in {"self", "cross"}: raise ValueError("`mode` must be either 'self' or 'cross'") + self.n_frames = int(n_frames) self.input_channels = ( 2 * self.channels if self.mode == "self" else self.channels ) @@ -125,24 +197,51 @@ def __init__( ) def forward( - self, query_grid: torch.Tensor, context_grid: torch.Tensor + self, + left: torch.Tensor, + right: torch.Tensor, + scalar_pair: torch.Tensor, + *, + to_grid: Callable[[torch.Tensor], torch.Tensor], + from_grid: Callable[[torch.Tensor], torch.Tensor], ) -> torch.Tensor: """ - Apply the point-wise polynomial MLP to ``(N, G, F, C)`` grid fields. + Apply the polynomial point-wise MLP on coefficient operands. + + In self mode, both projections see the per-frame concatenation of the + two operands and can form self and cross quadratic channel terms. In + cross mode the query and context roles stay separate: + ``(W_q query) * (W_c context)``. - In self mode, both projections see ``concat(query_grid, context_grid)`` - and can form self and cross quadratic channel terms. In cross mode, - the query and context roles stay separate: - ``(W_q query_grid) * (W_c context_grid)``. + Parameters + ---------- + left, right : torch.Tensor + Coefficient operands with shape ``(N, D, F, n_frames * C)``. + scalar_pair : torch.Tensor + Invariant routing signal; unused on this path. + to_grid, from_grid : Callable + Coefficient/grid projectors supplied by the owning grid net. + + Returns + ------- + torch.Tensor + Coefficient result with shape ``(N, D, F, n_frames * C)``. """ + # === Step 1. Channel projections at coefficient resolution === if self.mode == "self": - grid = torch.cat([query_grid, context_grid], dim=-1) - left = self.left_proj(grid) - right = self.right_proj(grid) + shape = (*left.shape[:-1], self.n_frames, -1) + fused = torch.cat( + [left.reshape(shape), right.reshape(shape)], dim=-1 + ).reshape(*left.shape[:-1], -1) # per-frame concat -> (N, D, F, K*2C) + left = _project_frames(fused, self.left_proj, self.n_frames) + right = _project_frames(fused, self.right_proj, self.n_frames) else: - left = self.left_proj(query_grid) - right = self.right_proj(context_grid) - return self.out_proj(left * right) + left = _project_frames(left, self.left_proj, self.n_frames) + right = _project_frames(right, self.right_proj, self.n_frames) + + # === Step 2. Quadratic product on the grid, projected back === + coeff = from_grid(to_grid(left) * to_grid(right)) + return _project_frames(coeff, self.out_proj, self.n_frames) class GridBranch(nn.Module): @@ -159,6 +258,7 @@ def __init__( *, channels: int, n_branches: int, + n_frames: int, dtype: torch.dtype, trainable: bool, seed: int | list[int] | None = None, @@ -168,6 +268,7 @@ def __init__( self.n_branches = int(n_branches) if self.n_branches < 1: raise ValueError("`n_branches` must be positive") + self.n_frames = int(n_frames) self.left_proj = ChannelLinear( in_channels=self.channels, out_channels=self.n_branches * self.channels, @@ -203,35 +304,43 @@ def __init__( def forward( self, - query_grid: torch.Tensor, - context_grid: torch.Tensor, + left: torch.Tensor, + right: torch.Tensor, scalar_pair: torch.Tensor, + *, + to_grid: Callable[[torch.Tensor], torch.Tensor], + from_grid: Callable[[torch.Tensor], torch.Tensor], ) -> torch.Tensor: """ - Apply scalar-routed grid branch mixing. + Apply scalar-routed grid branch mixing on coefficient operands. Parameters ---------- - query_grid - First grid source with shape ``(N, G, F, C)``. - context_grid - Second grid source with shape ``(N, G, F, C)``. - scalar_pair + left, right : torch.Tensor + Coefficient operands with shape ``(N, D, F, n_frames * C)``. + scalar_pair : torch.Tensor Invariant router source with shape ``(N, F, 2*C)``. + to_grid, from_grid : Callable + Coefficient/grid projectors supplied by the owning grid net. + + Returns + ------- + torch.Tensor + Coefficient result with shape ``(N, D, F, n_frames * C)``. """ - n_batch, n_grid, n_focus, _ = query_grid.shape - left = self.left_proj(query_grid) - right = self.right_proj(context_grid) - value = (left * right).reshape( - n_batch, - n_grid, - n_focus, - self.n_branches, - self.channels, - ) # (N, G, F, N_branches, C) + # === Step 1. Branch channel projections at coefficient resolution === + left = _project_frames(left, self.left_proj, self.n_frames) + right = _project_frames(right, self.right_proj, self.n_frames) + + # === Step 2. Quadratic branches on the grid, routed by scalars === + value = to_grid(left) * to_grid(right) # (N, G, F, N_branches * C) + n_batch, n_grid, n_focus, _ = value.shape + value = value.reshape(n_batch, n_grid, n_focus, self.n_branches, self.channels) router = torch.softmax(self.router(scalar_pair), dim=-1) # (N, F, N_branches) out = torch.einsum("ngfhc,nfh->ngfc", value, router) # (N, G, F, C) - return self.out_proj(out) + + # === Step 3. Project back to coefficients and mix output channels === + return _project_frames(from_grid(out), self.out_proj, self.n_frames) class FrameContract(nn.Module): @@ -407,9 +516,10 @@ def __init__( init_std=0.01, ) if self.op_type == "mlp": - self.grid_op = GridMLP( + self.grid_op: nn.Module = GridMLP( channels=self.channels, mode=self.mode, + n_frames=self.n_frames, dtype=self.dtype, trainable=trainable, seed=child_seed(seed, 1), @@ -418,12 +528,13 @@ def __init__( self.grid_op = GridBranch( channels=self.channels, n_branches=grid_branches, + n_frames=self.n_frames, dtype=self.dtype, trainable=trainable, seed=child_seed(seed, 1), ) else: - self.grid_op = nn.Identity() + self.grid_op = GridProduct() if residual_scale_init is None: self.residual_scale = None @@ -448,8 +559,13 @@ def forward( input_dtype = query.dtype query_ndfc, shape_info = self._to_ndfc(query) left, right, scalar_pair = self._prepare_pair(query_ndfc, context) - grid_out = self._apply_grid_op(left, right, scalar_pair) - coeff_out = self._from_grid(grid_out) + coeff_out = self.grid_op( + left.to(dtype=self.dtype), + right.to(dtype=self.dtype), + scalar_pair, + to_grid=self._to_grid, + from_grid=self._from_grid, + ) coeff_out = self._apply_scalar_path(coeff_out, scalar_pair) coeff_out = self._contract_frames(coeff_out) coeff_out = self._apply_residual_scale(coeff_out) @@ -499,20 +615,6 @@ def _prepare_cross_pair( scalar_pair, ) - def _apply_grid_op( - self, - left: torch.Tensor, - right: torch.Tensor, - scalar_pair: torch.Tensor, - ) -> torch.Tensor: - left_grid = self._to_grid(left.to(dtype=self.dtype)) - right_grid = self._to_grid(right.to(dtype=self.dtype)) - if self.op_type == "glu": - return left_grid * right_grid - if self.op_type == "mlp": - return self.grid_op(left_grid, right_grid) - return self.grid_op(left_grid, right_grid, scalar_pair) - def _contract_frames(self, coeff: torch.Tensor) -> torch.Tensor: if self.frame_contract is None: return coeff @@ -576,14 +678,10 @@ def _extract_scalar(self, coeff: torch.Tensor) -> torch.Tensor: return coeff_view[:, 0, :, self.frame_zero_index, :] def _to_grid(self, coeff: torch.Tensor) -> torch.Tensor: + # The per-frame channel width is inferred so the projector also serves + # widened operands (e.g. a branch hidden width ``n_branches * C``). n_batch, coeff_dim, n_focus, _ = coeff.shape - coeff_view = coeff.reshape( - n_batch, - coeff_dim, - n_focus, - self.n_frames, - self.channels, - ) + coeff_view = coeff.reshape(n_batch, coeff_dim, n_focus, self.n_frames, -1) to_grid = self.projector.to_grid_mat.reshape( self.projector.grid_size, coeff_dim, @@ -592,6 +690,7 @@ def _to_grid(self, coeff: torch.Tensor) -> torch.Tensor: return torch.einsum("gdk,ndfkc->ngfc", to_grid, coeff_view) def _from_grid(self, grid: torch.Tensor) -> torch.Tensor: + # Channel width is inferred to match the (possibly widened) grid field. n_batch, _, n_focus, _ = grid.shape coeff_dim = self.projector.coeff_dim // self.n_frames from_grid = self.projector.from_grid_mat.reshape( @@ -600,7 +699,7 @@ def _from_grid(self, grid: torch.Tensor) -> torch.Tensor: self.projector.grid_size, ) coeff = torch.einsum("dkg,ngfc->ndfkc", from_grid, grid) - return coeff.reshape(n_batch, coeff_dim, n_focus, self.expanded_channels) + return coeff.reshape(n_batch, coeff_dim, n_focus, -1) def _to_ndfc(self, value: torch.Tensor) -> tuple[torch.Tensor, tuple[int, ...]]: if self.layout == "ndfc": From b5471eeea9200f73917b24c438a38259bfb44413 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Sun, 14 Jun 2026 11:31:47 +0800 Subject: [PATCH 2/3] fix --- deepmd/dpmodel/descriptor/dpa4_nn/block.py | 3 +- deepmd/dpmodel/descriptor/dpa4_nn/ffn.py | 11 +- deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py | 296 +++++++++++++++--- .../pt/model/test_dpa4_dpmodel_parity.py | 126 ++++++-- 4 files changed, 352 insertions(+), 84 deletions(-) diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/block.py b/deepmd/dpmodel/descriptor/dpa4_nn/block.py index ce529c3cd2..b17e2c0c4c 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/block.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/block.py @@ -21,8 +21,7 @@ not duplicated here): ``so2_attn_res``, ``so2_s2_activation``, ``node_wise_s2/so3``, ``message_node_s2/so3``, ``atten_f_mix``, ``atten_v_proj``, ``atten_o_proj`` (raised by ``SO2Convolution``) and -``ffn_so3_grid``, ``ffn_grid_mlp`` with the grid path active (raised by -``EquivariantFFN`` / ``S2GridNet``). +``ffn_so3_grid`` with the grid path active (raised by ``EquivariantFFN``). The pt eval-time activation-checkpoint / nvtx instrumentation (``DP_ACT_INFER``, ``DP_COMPILE_INFER``, ``nvtx_range``) is pt-runtime-only diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/ffn.py b/deepmd/dpmodel/descriptor/dpa4_nn/ffn.py index a54cf7c4eb..673e4ee836 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/ffn.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/ffn.py @@ -11,10 +11,6 @@ - ``ffn_so3_grid=True`` — the pt path instantiates ``SO3GridNet`` (pt ffn.py:209), which is not ported to dpmodel. -- ``grid_mlp=True`` together with the grid path active selects - ``op_type='mlp'`` for ``S2GridNet`` (pt ffn.py:206); the delegate - ``S2GridNet`` constructor raises for that op type (``GridMLP`` is not - ported), so no duplicate guard is added here. """ from __future__ import ( @@ -85,9 +81,9 @@ class EquivariantFFN(NativeOP): kmax Maximum Wigner-D frame order (|k|) used by the SO3 Wigner-D FFN grid. grid_mlp - If True, select the polynomial grid MLP operation when the - block-internal FFN grid path is enabled. Not ported: the delegate - ``S2GridNet`` raises ``NotImplementedError`` for ``op_type='mlp'``. + If True, select the polynomial grid MLP operation (``op_type='mlp'``) + when the block-internal FFN grid path is enabled. ``grid_branch`` takes + precedence when positive. grid_branch Number of scalar-routed polynomial product branches used when the block-internal FFN grid path is enabled. ``0`` disables this branch @@ -203,7 +199,6 @@ def __init__( if self.use_grid_branch else ("mlp" if self.use_grid_mlp else "glu") ) - # op_type='mlp' raises NotImplementedError inside S2GridNet self.act: NativeOP = S2GridNet( lmax=self.lmax, channels=self.hidden_channels, diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py b/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py index 70db705bb9..ff11b5ad6d 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py @@ -12,8 +12,9 @@ * ``mode='self'``: one input ``(N, D, F, 2*C)`` or ``(N, F, D, 2*C)``. * grid values: ``(N, G, F, C)`` after S2 projection. -Ported names: ``BaseGridNet`` (``mode='self'``; ``op_type`` 'glu'/'branch'), -``S2GridNet``, ``GridBranch``. +Ported names: ``BaseGridNet`` (``mode='self'``; ``op_type`` +'glu'/'mlp'/'branch'), ``S2GridNet``, ``GridProduct``, ``GridMLP``, +``GridBranch``. Skipped names, with consumer evidence from the pt sources: @@ -23,9 +24,6 @@ - ``FrameContract``, ``FrameExpand``, ``_build_frame_degree_index``: only constructed by ``SO3GridNet`` (``mode='cross'``); the S2 projector always has ``n_frames == 1``, so the frame machinery is unreachable here. -- ``GridMLP``: only selected via ``op_type='mlp'`` (``grid_mlp=True`` paths); - the core config has ``grid_mlp=[False, False, False]``. ``BaseGridNet`` - raises ``NotImplementedError`` for ``op_type='mlp'``. Guarded (routable from the shared ``S2GridNet`` entry point but only used by the disabled ``node_wise_s2``/``message_node_s2`` grid products in @@ -47,6 +45,7 @@ ) from typing import ( + TYPE_CHECKING, Any, ) @@ -85,6 +84,11 @@ FocusLinear, ) +if TYPE_CHECKING: + from collections.abc import ( + Callable, + ) + def _softmax_last_axis(x: Any) -> Any: """Numerically stable softmax on the last axis (matches torch.softmax).""" @@ -93,6 +97,194 @@ def _softmax_last_axis(x: Any) -> Any: return e_x / xp.sum(e_x, axis=-1, keepdims=True) +class GridProduct(NativeOP): + """Parameter-free quadratic grid product ``u(g) * v(g)``.""" + + def call( + self, + left: Any, + right: Any, + scalar_pair: Any, + *, + to_grid: Callable[[Any], Any], + from_grid: Callable[[Any], Any], + ) -> Any: + """ + Combine two coefficient operands by a point-wise grid product. + + Parameters + ---------- + left, right + Coefficient operands with shape ``(N, D, F, C)``. + scalar_pair + Invariant routing signal; unused on this path. + to_grid, from_grid + Coefficient/grid projectors supplied by the owning grid net. + """ + return from_grid(to_grid(left) * to_grid(right)) + + +class GridMLP(NativeOP): + """ + Polynomial point-wise MLP applied independently at every grid point. + + Specialized to the S2 ``n_frames == 1`` case, so the per-frame packing of + the pt ``GridMLP`` collapses to a plain channel concatenation in self mode. + + Parameters + ---------- + channels : int + Number of channels per grid point. + mode : str + Pairing mode, either ``"self"`` or ``"cross"``. + precision : str + Parameter precision. + trainable : bool + Whether parameters are trainable. + seed : int | list[int] | None + Random seed for weight initialization. + """ + + def __init__( + self, + *, + channels: int, + mode: str, + precision: str = DEFAULT_PRECISION, + trainable: bool = True, + seed: int | list[int] | None = None, + ) -> None: + self.channels = int(channels) + self.mode = str(mode).lower() + if self.mode not in {"self", "cross"}: + raise ValueError("`mode` must be either 'self' or 'cross'") + self.precision = precision + self.trainable = bool(trainable) + self.input_channels = ( + 2 * self.channels if self.mode == "self" else self.channels + ) + self.hidden_channels = 2 * self.channels + self.left_proj = ChannelLinear( + in_channels=self.input_channels, + out_channels=self.hidden_channels, + precision=precision, + bias=False, + trainable=trainable, + seed=child_seed(seed, 0), + ) + self.right_proj = ChannelLinear( + in_channels=self.input_channels, + out_channels=self.hidden_channels, + precision=precision, + bias=False, + trainable=trainable, + seed=child_seed(seed, 1), + ) + self.out_proj = ChannelLinear( + in_channels=self.hidden_channels, + out_channels=self.channels, + precision=precision, + bias=False, + trainable=trainable, + seed=child_seed(seed, 2), + ) + + def call( + self, + left: Any, + right: Any, + scalar_pair: Any, + *, + to_grid: Callable[[Any], Any], + from_grid: Callable[[Any], Any], + ) -> Any: + """ + Apply the polynomial point-wise MLP on coefficient operands. + + In self mode both projections see the concatenation of the two operands + and can form self and cross quadratic channel terms. In cross mode the + query and context roles stay separate: ``(W_q query) * (W_c context)``. + + Parameters + ---------- + left, right + Coefficient operands with shape ``(N, D, F, C)``. + scalar_pair + Invariant routing signal; unused on this path. + to_grid, from_grid + Coefficient/grid projectors supplied by the owning grid net. + """ + if self.mode == "self": + xp = array_api_compat.array_namespace(left) + fused = xp.concat([left, right], axis=-1) # (N, D, F, 2C) + left = self.left_proj(fused) + right = self.right_proj(fused) + else: + left = self.left_proj(left) + right = self.right_proj(right) + coeff = from_grid(to_grid(left) * to_grid(right)) # (N, D, F, 2C) + return self.out_proj(coeff) # (N, D, F, C) + + def serialize(self) -> dict[str, Any]: + """Serialize the GridMLP to a dict. + + The pt ``GridMLP`` has no ``serialize()``; the ``@variables`` keys here + match the pt ``state_dict`` key names. + """ + return { + "@class": "GridMLP", + "@version": 1, + "config": { + "channels": self.channels, + "mode": self.mode, + "precision": np.dtype(PRECISION_DICT[self.precision]).name, + "trainable": self.trainable, + "seed": None, + }, + "@variables": { + "left_proj.weight": to_numpy_array(self.left_proj.weight), + "right_proj.weight": to_numpy_array(self.right_proj.weight), + "out_proj.weight": to_numpy_array(self.out_proj.weight), + }, + } + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> GridMLP: + """Deserialize a GridMLP from a dict.""" + data = data.copy() + data_cls = data.pop("@class") + if data_cls != "GridMLP": + raise ValueError(f"Invalid class for GridMLP: {data_cls}") + version = int(data.pop("@version")) + check_version_compatibility(version, 1, 1) + config = data.pop("config") + variables = data.pop("@variables") + obj = cls( + channels=int(config["channels"]), + mode=str(config["mode"]), + precision=str(config["precision"]), + trainable=bool(config["trainable"]), + seed=config.get("seed"), + ) + obj._load_variables(variables) + return obj + + def _load_variables(self, variables: dict[str, Any]) -> None: + prec = PRECISION_DICT[self.precision.lower()] + for name, proj in ( + ("left_proj", self.left_proj), + ("right_proj", self.right_proj), + ("out_proj", self.out_proj), + ): + weight = np.asarray(variables[f"{name}.weight"], dtype=prec) + if weight.shape != proj.weight.shape: + raise ValueError( + f"{name}.weight shape {weight.shape} does not match " + f"the expected shape {proj.weight.shape}" + ) + proj.weight = weight + + class GridBranch(NativeOP): """ Scalar-routed polynomial mixer over grid product branches. @@ -165,34 +357,43 @@ def __init__( def call( self, - query_grid: Any, - context_grid: Any, + left: Any, + right: Any, scalar_pair: Any, + *, + to_grid: Callable[[Any], Any], + from_grid: Callable[[Any], Any], ) -> Any: """ - Apply scalar-routed grid branch mixing. + Apply scalar-routed grid branch mixing on coefficient operands. + + The channel maps are applied at coefficient resolution and the grid + transform is deferred to the injected ``to_grid``/``from_grid`` + callables, matching the pt ``GridBranch`` specialized to the S2 + ``n_frames == 1`` case (so no per-frame packing is needed). Parameters ---------- - query_grid - First grid source with shape ``(N, G, F, C)``. - context_grid - Second grid source with shape ``(N, G, F, C)``. + left, right + Coefficient operands with shape ``(N, D, F, C)``. scalar_pair Invariant router source with shape ``(N, F, 2*C)``. + to_grid, from_grid + Coefficient/grid projectors supplied by the owning grid net. """ - xp = array_api_compat.array_namespace(query_grid) - n_batch, n_grid, n_focus, _ = query_grid.shape - left = self.left_proj(query_grid) - right = self.right_proj(context_grid) + xp = array_api_compat.array_namespace(left) + left = self.left_proj(left) # (N, D, F, N_branches * C) + right = self.right_proj(right) # (N, D, F, N_branches * C) + value = to_grid(left) * to_grid(right) # (N, G, F, N_branches * C) + n_batch, n_grid, n_focus, _ = value.shape value = xp.reshape( - left * right, + value, (n_batch, n_grid, n_focus, self.n_branches, self.channels), ) # (N, G, F, N_branches, C) router = _softmax_last_axis(self.router(scalar_pair)) # (N, F, N_branches) # einsum "ngfhc,nfh->ngfc" as a broadcast sum over the branch axis out = xp.sum(value * router[:, None, :, :, None], axis=3) # (N, G, F, C) - return self.out_proj(out) + return self.out_proj(from_grid(out)) # (N, D, F, C) def serialize(self) -> dict[str, Any]: """Serialize the GridBranch to a dict. @@ -305,10 +506,6 @@ def __init__( self.op_type = str(op_type).lower() if self.op_type not in {"glu", "mlp", "branch"}: raise ValueError("`op_type` must be one of 'glu', 'mlp', or 'branch'") - if self.op_type == "mlp": - raise NotImplementedError( - "op_type='mlp' (grid_mlp=True paths) is not ported to dpmodel" - ) self.precision = precision self.layout = str(layout).lower() if self.layout not in {"ndfc", "nfdc", "flat"}: @@ -340,8 +537,16 @@ def __init__( seed=child_seed(seed, 0), init_std=0.01, ) - if self.op_type == "branch": - self.grid_op: GridBranch | None = GridBranch( + if self.op_type == "mlp": + self.grid_op: NativeOP = GridMLP( + channels=self.channels, + mode=self.mode, + precision=self.precision, + trainable=trainable, + seed=child_seed(seed, 1), + ) + elif self.op_type == "branch": + self.grid_op = GridBranch( channels=self.channels, n_branches=grid_branches, precision=self.precision, @@ -349,8 +554,7 @@ def __init__( seed=child_seed(seed, 1), ) else: - # pt uses nn.Identity() here (parameter-free, no state-dict keys) - self.grid_op = None + self.grid_op = GridProduct() def call(self, query: Any, context: Any = None) -> Any: """Apply the configured grid net and restore the input layout.""" @@ -360,8 +564,7 @@ def call(self, query: Any, context: Any = None) -> Any: query_ndfc = self._to_ndfc(query) left, right = self._split_self_query(query_ndfc) scalar_pair = self._make_scalar_pair(left, right, compute_dtype) - grid_out = self._apply_grid_op(left, right, scalar_pair, compute_dtype) - coeff_out = self._from_grid(grid_out) + coeff_out = self._apply_grid_op(left, right, scalar_pair, compute_dtype) coeff_out = self._apply_scalar_path(coeff_out, scalar_pair) if coeff_out.dtype != input_dtype: coeff_out = xp.astype(coeff_out, input_dtype) @@ -379,11 +582,13 @@ def _apply_grid_op( left = xp.astype(left, compute_dtype) if right.dtype != compute_dtype: right = xp.astype(right, compute_dtype) - left_grid = self._to_grid(left) - right_grid = self._to_grid(right) - if self.op_type == "glu": - return left_grid * right_grid - return self.grid_op(left_grid, right_grid, scalar_pair) + return self.grid_op( + left, + right, + scalar_pair, + to_grid=self._to_grid, + from_grid=self._from_grid, + ) def _apply_scalar_path(self, coeff: Any, scalar_pair: Any) -> Any: xp = array_api_compat.array_namespace(coeff) @@ -421,33 +626,34 @@ def _extract_scalar(self, coeff: Any) -> Any: return coeff[:, 0, :, :] def _to_grid(self, coeff: Any) -> Any: - # einsum "gd,ndfc->ngfc" (n_frames == 1) as a broadcast batched matmul + # einsum "gd,ndfc->ngfc" (n_frames == 1) as a broadcast batched matmul. + # The per-point channel width is inferred so the projector also serves + # widened operands (e.g. a branch hidden width ``n_branches * C``). xp = array_api_compat.array_namespace(coeff) - n_batch, coeff_dim, n_focus, _ = coeff.shape + n_batch, coeff_dim, n_focus, n_channels = coeff.shape to_grid_mat = xp_asarray_nodetach( xp, self.projector.to_grid_mat[...], device=array_api_compat.device(coeff) ) if to_grid_mat.dtype != coeff.dtype: to_grid_mat = xp.astype(to_grid_mat, coeff.dtype) - flat = xp.reshape(coeff, (n_batch, coeff_dim, n_focus * self.channels)) + flat = xp.reshape(coeff, (n_batch, coeff_dim, n_focus * n_channels)) out = xp.matmul(to_grid_mat[None, ...], flat) # (N, G, F*C) - return xp.reshape( - out, (n_batch, self.projector.grid_size, n_focus, self.channels) - ) + return xp.reshape(out, (n_batch, self.projector.grid_size, n_focus, n_channels)) def _from_grid(self, grid: Any) -> Any: - # einsum "dg,ngfc->ndfc" (n_frames == 1) as a broadcast batched matmul + # einsum "dg,ngfc->ndfc" (n_frames == 1) as a broadcast batched matmul. + # The channel width is inferred to match the (possibly widened) grid. xp = array_api_compat.array_namespace(grid) - n_batch, n_grid, n_focus, _ = grid.shape + n_batch, n_grid, n_focus, n_channels = grid.shape coeff_dim = self.projector.coeff_dim from_grid_mat = xp_asarray_nodetach( xp, self.projector.from_grid_mat[...], device=array_api_compat.device(grid) ) if from_grid_mat.dtype != grid.dtype: from_grid_mat = xp.astype(from_grid_mat, grid.dtype) - flat = xp.reshape(grid, (n_batch, n_grid, n_focus * self.channels)) + flat = xp.reshape(grid, (n_batch, n_grid, n_focus * n_channels)) out = xp.matmul(from_grid_mat[None, ...], flat) # (N, D, F*C) - return xp.reshape(out, (n_batch, coeff_dim, n_focus, self.expanded_channels)) + return xp.reshape(out, (n_batch, coeff_dim, n_focus, n_channels)) def _to_ndfc(self, value: Any) -> Any: if self.layout == "ndfc": @@ -569,7 +775,7 @@ def serialize(self) -> dict[str, Any]: variables = {"scalar_gate.weight": to_numpy_array(self.scalar_gate.weight)} if self.mlp_bias: variables["scalar_gate.bias"] = to_numpy_array(self.scalar_gate.bias) - if self.op_type == "branch": + if self.op_type in {"mlp", "branch"}: grid_op_data = self.grid_op.serialize()["@variables"] for key, value in grid_op_data.items(): variables[f"grid_op.{key}"] = value @@ -636,7 +842,7 @@ def deserialize(cls, data: dict[str, Any]) -> S2GridNet: obj.scalar_gate.bias = np.asarray( variables["scalar_gate.bias"], dtype=prec ).reshape(obj.scalar_gate.bias.shape) - if obj.op_type == "branch": + if obj.op_type in {"mlp", "branch"}: obj.grid_op._load_variables( { key[len("grid_op.") :]: value diff --git a/source/tests/pt/model/test_dpa4_dpmodel_parity.py b/source/tests/pt/model/test_dpa4_dpmodel_parity.py index 08d78e7320..90a986f77f 100644 --- a/source/tests/pt/model/test_dpa4_dpmodel_parity.py +++ b/source/tests/pt/model/test_dpa4_dpmodel_parity.py @@ -1483,20 +1483,17 @@ def _build_grid_nets( expected_keys = {"scalar_gate.weight"} if mlp_bias: expected_keys.add("scalar_gate.bias") - if op_type == "branch": - expected_keys |= { - "grid_op.left_proj.weight", - "grid_op.right_proj.weight", - "grid_op.router.weight", - "grid_op.out_proj.weight", - } + grid_op_params = { + "mlp": ("left_proj", "right_proj", "out_proj"), + "branch": ("left_proj", "right_proj", "router", "out_proj"), + }.get(op_type, ()) + expected_keys |= {f"grid_op.{name}.weight" for name in grid_op_params} assert set(state) == expected_keys dp_net.scalar_gate.weight = state["scalar_gate.weight"] if mlp_bias: dp_net.scalar_gate.bias = state["scalar_gate.bias"] - if op_type == "branch": - for name in ("left_proj", "right_proj", "router", "out_proj"): - getattr(dp_net.grid_op, name).weight = state[f"grid_op.{name}.weight"] + for name in grid_op_params: + getattr(dp_net.grid_op, name).weight = state[f"grid_op.{name}.weight"] return pt_net, dp_net # ------------------------------------------------- (a) projector constants @@ -1571,7 +1568,7 @@ def test_resolve_s2_grid_resolution(self, method, lmax, mmax) -> None: # ------------------------------------------ (b) S2GridNet forward parity @pytest.mark.parametrize("lmax", [2, 3]) # max degree - @pytest.mark.parametrize("op_type", ["glu", "branch"]) # grid operation + @pytest.mark.parametrize("op_type", ["glu", "mlp", "branch"]) # grid operation def test_s2_grid_net(self, lmax, op_type) -> None: # ffn-style core usage: mode="self", layout="ndfc", packed, n_focus=1 pt_net, dp_net = self._build_grid_nets( @@ -1629,6 +1626,7 @@ def test_grid_branch(self, n_branches) -> None: pt_mod = PTGridBranch( channels=self.channels, n_branches=n_branches, + n_frames=1, dtype=torch.float64, trainable=True, seed=9, @@ -1652,16 +1650,26 @@ def test_grid_branch(self, n_branches) -> None: ) for name in ("left_proj", "right_proj", "router", "out_proj"): getattr(dp_mod, name).weight = state[f"{name}.weight"] - n_batch, n_grid, n_focus = 5, 26, 2 - query = rng.normal(size=(n_batch, n_grid, n_focus, self.channels)) - context = rng.normal(size=(n_batch, n_grid, n_focus, self.channels)) + n_batch, n_coeff, n_focus = 5, 26, 2 + left = rng.normal(size=(n_batch, n_coeff, n_focus, self.channels)) + right = rng.normal(size=(n_batch, n_coeff, n_focus, self.channels)) scalar = rng.normal(size=(n_batch, n_focus, 2 * self.channels)) + + # Both backends take coefficient operands and defer the grid transform + # to injected to_grid/from_grid callables (pt's so3grid layout). The + # unit injects identity projectors; real projector behavior is covered + # by the S2GridNet parity tests above. + def identity(t): + return t + assert_parity( - dp_mod.call(query, context, scalar), + dp_mod.call(left, right, scalar, to_grid=identity, from_grid=identity), pt_mod( - to_pt(query), - to_pt(context), + to_pt(left), + to_pt(right), to_pt(scalar), + to_grid=identity, + from_grid=identity, ), ) # serialize roundtrip is exact; @variables keys match the pt state dict @@ -1669,12 +1677,81 @@ def test_grid_branch(self, n_branches) -> None: assert set(ser["@variables"]) == set(state) dp_mod2 = DPGridBranch.deserialize(ser) np.testing.assert_array_equal( - np.asarray(dp_mod.call(query, context, scalar)), - np.asarray(dp_mod2.call(query, context, scalar)), + np.asarray( + dp_mod.call(left, right, scalar, to_grid=identity, from_grid=identity) + ), + np.asarray( + dp_mod2.call(left, right, scalar, to_grid=identity, from_grid=identity) + ), + ) + + # --------------------------------------------- (c') GridMLP forward parity + @pytest.mark.parametrize("mode", ["self", "cross"]) # pairing mode + def test_grid_mlp(self, mode) -> None: + from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import GridMLP as DPGridMLP + from deepmd.pt.model.descriptor.sezm_nn.grid_net import GridMLP as PTGridMLP + + pt_mod = PTGridMLP( + channels=self.channels, + mode=mode, + n_frames=1, + dtype=torch.float64, + trainable=True, + seed=9, + ) + rng = np.random.default_rng(2087) + with torch.no_grad(): + for p in pt_mod.parameters(): + p += to_pt(0.1 * rng.normal(size=tuple(p.shape))) + state = pt_state_to_numpy(pt_mod) + assert set(state) == { + "left_proj.weight", + "right_proj.weight", + "out_proj.weight", + } + dp_mod = DPGridMLP( + channels=self.channels, + mode=mode, + precision="float64", + seed=9, + ) + for name in ("left_proj", "right_proj", "out_proj"): + getattr(dp_mod, name).weight = state[f"{name}.weight"] + n_batch, n_coeff, n_focus = 5, 26, 2 + left = rng.normal(size=(n_batch, n_coeff, n_focus, self.channels)) + right = rng.normal(size=(n_batch, n_coeff, n_focus, self.channels)) + # GridMLP ignores scalar_pair; both backends take coefficient operands + # and defer the grid transform to injected to_grid/from_grid callables. + scalar = rng.normal(size=(n_batch, n_focus, 2 * self.channels)) + + def identity(t): + return t + + assert_parity( + dp_mod.call(left, right, scalar, to_grid=identity, from_grid=identity), + pt_mod( + to_pt(left), + to_pt(right), + to_pt(scalar), + to_grid=identity, + from_grid=identity, + ), + ) + # serialize roundtrip is exact; @variables keys match the pt state dict + ser = dp_mod.serialize() + assert set(ser["@variables"]) == set(state) + dp_mod2 = DPGridMLP.deserialize(ser) + np.testing.assert_array_equal( + np.asarray( + dp_mod.call(left, right, scalar, to_grid=identity, from_grid=identity) + ), + np.asarray( + dp_mod2.call(left, right, scalar, to_grid=identity, from_grid=identity) + ), ) # ------------------------------------------------------ (e) serialization - @pytest.mark.parametrize("op_type", ["glu", "branch"]) # grid operation + @pytest.mark.parametrize("op_type", ["glu", "mlp", "branch"]) # grid operation @pytest.mark.parametrize("mlp_bias", [False, True]) # scalar gate bias def test_s2_grid_net_serialize_roundtrip(self, op_type, mlp_bias) -> None: from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import S2GridNet as DPS2GridNet @@ -1732,9 +1809,6 @@ def test_not_ported_guards(self) -> None: # "e3nn" default, which dp rejects): default construction works net = DPS2GridNet(**{k: v for k, v in common.items() if k != "grid_method"}) assert net.grid_method == "lebedev" - with pytest.raises(NotImplementedError, match="grid_mlp"): - # GridMLP (grid_mlp=True) is not ported - DPS2GridNet(**{**common, "op_type": "mlp"}) with pytest.raises(NotImplementedError, match="node_wise_s2"): # cross mode backs node_wise_s2/message_node_s2 only DPS2GridNet(**{**common, "mode": "cross"}) @@ -3114,12 +3188,6 @@ def test_ffn_guards(self) -> None: with pytest.raises(NotImplementedError, match="ffn_so3_grid"): DPFFN(**self._ffn_kwargs(ffn_so3_grid=True), precision="float64") - # grid_mlp guard is delegated to S2GridNet's op_type='mlp' NIE - with pytest.raises(NotImplementedError, match="mlp"): - DPFFN( - **self._ffn_kwargs(s2_activation=True, grid_mlp=True), - precision="float64", - ) def test_ffn_errors(self) -> None: from deepmd.dpmodel.descriptor.dpa4_nn.ffn import EquivariantFFN as DPFFN From f011367d4a444d9552bc7171bda29c13fcb8b38d Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 18 Jun 2026 11:57:07 +0800 Subject: [PATCH 3/3] fix(dpa4): make GridProduct pt_expt-wrappable The parameter-free `GridProduct` NativeOP (added for the so3grid optimization) has no `serialize`/`deserialize` and is not registered via `register_dpmodel_mapping`. The pt_expt backend auto-wraps every dpmodel NativeOP sub-component through `_auto_wrap_native_op`, which requires the op to be serializable (or registered) to build its dynamic torch wrapper; otherwise it raises: TypeError: Cannot auto-wrap GridProduct: it must implement serialize()/deserialize() or be explicitly registered via register_dpmodel_mapping(). This broke every `Test Python` shard that loads a DPA4 pt_expt model (e.g. test_get_model_dpa4.py). Add trivial `serialize`/`deserialize` (no state, mirroring the GridMLP @class/@version convention) so the op auto-wraps cleanly. --- deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py b/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py index ff11b5ad6d..dc7d2ab1a8 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py @@ -123,6 +123,23 @@ def call( """ return from_grid(to_grid(left) * to_grid(right)) + def serialize(self) -> dict[str, Any]: + """Serialize the parameter-free grid product to a dict.""" + return { + "@class": "GridProduct", + "@version": 1, + } + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> GridProduct: + """Deserialize a GridProduct from a dict.""" + data = data.copy() + data_cls = data.pop("@class", "GridProduct") + if data_cls != "GridProduct": + raise ValueError(f"Invalid class for GridProduct: {data_cls}") + check_version_compatibility(int(data.pop("@version", 1)), 1, 1) + return cls() + class GridMLP(NativeOP): """