Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions src/pyrecest/backend_support/_pytorch_pad_mode_contract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""PyTorch ``pad`` NumPy mode-name compatibility hook."""

from __future__ import annotations

_PAD_MODE_ALIASES = {
"edge": "replicate",
"wrap": "circular",
}


def _torch_pad_mode(mode):
"""Return the PyTorch padding mode corresponding to a NumPy mode name."""

return _PAD_MODE_ALIASES.get(mode, mode)


def patch_pytorch_pad_mode_contract() -> None:
"""Patch raw/public PyTorch ``pad`` to accept NumPy mode names."""

try:
import pyrecest._backend.pytorch as pytorch_backend # pylint: disable=import-outside-toplevel
import pyrecest.backend as backend # pylint: disable=import-outside-toplevel
except ModuleNotFoundError: # pragma: no cover - PyTorch backend may be unavailable
return

original_pad = getattr(pytorch_backend, "pad", None)
if original_pad is None:
return
if getattr(original_pad, "_pyrecest_numpy_pad_mode_contract", False):
if getattr(backend, "__backend_name__", None) == "pytorch":
backend.pad = original_pad
return

def pad(a, pad_width, mode="constant", constant_values=0.0):
return original_pad(
a,
pad_width,
mode=_torch_pad_mode(mode),
constant_values=constant_values,
)

pad.__name__ = getattr(original_pad, "__name__", "pad")
pad.__doc__ = getattr(original_pad, "__doc__", None)
pad._pyrecest_numpy_pad_mode_contract = True
pytorch_backend.pad = pad
if getattr(backend, "__backend_name__", None) == "pytorch":
backend.pad = pad
4 changes: 4 additions & 0 deletions src/pyrecest/stability.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,12 @@
from pyrecest.backend_support._pytorch_allclose_device_contract import (
patch_pytorch_allclose_device_contract as _patch_pytorch_allclose_device_contract,
)
from pyrecest.backend_support._pytorch_pad_mode_contract import (
patch_pytorch_pad_mode_contract as _patch_pytorch_pad_mode_contract,
)

_patch_pytorch_allclose_device_contract()
_patch_pytorch_pad_mode_contract()

P = ParamSpec("P")
R = TypeVar("R")
Expand Down
40 changes: 40 additions & 0 deletions tests/backend/test_pytorch_pad_mode_contract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import pytest
from tests.support.backend_runner import run_backend_code


torch = pytest.importorskip("torch")

import pyrecest.backend_tools # noqa: E402,F401
import pyrecest._backend.pytorch as pytorch_backend # noqa: E402


def _to_list(value):
return value.detach().cpu().tolist()


def test_raw_pytorch_pad_accepts_numpy_edge_and_wrap_mode_names():
values = torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float64)

edge = pytorch_backend.pad(values, ((0, 0), (2, 1)), mode="edge")
wrap = pytorch_backend.pad(values, ((0, 0), (1, 2)), mode="wrap")

assert _to_list(edge) == [[1.0, 1.0, 1.0, 2.0, 3.0, 3.0]]
assert _to_list(wrap) == [[3.0, 1.0, 2.0, 3.0, 1.0, 2.0]]


def test_public_pytorch_pad_accepts_numpy_edge_and_wrap_mode_names():
result = run_backend_code(
"pytorch",
r'''
import pyrecest.backend as backend

values = backend.asarray([[1.0, 2.0, 3.0]])
edge = backend.pad(values, ((0, 0), (2, 1)), mode="edge")
wrap = backend.pad(values, ((0, 0), (1, 2)), mode="wrap")

assert backend.to_numpy(edge).tolist() == [[1.0, 1.0, 1.0, 2.0, 3.0, 3.0]]
assert backend.to_numpy(wrap).tolist() == [[3.0, 1.0, 2.0, 3.0, 1.0, 2.0]]
''',
)

assert result.returncode == 0, result.stderr
Loading