diff --git a/src/pyrecest/backend_support/_pytorch_pad_mode_contract.py b/src/pyrecest/backend_support/_pytorch_pad_mode_contract.py new file mode 100644 index 000000000..8d5a8ff9f --- /dev/null +++ b/src/pyrecest/backend_support/_pytorch_pad_mode_contract.py @@ -0,0 +1,47 @@ +"""PyTorch ``pad`` NumPy mode-name compatibility hook.""" + +from __future__ import annotations + +_PAD_MODE_ALIASES = { + "edge": "replicate", + "wrap": "circular", +} + + +def _torch_pad_mode(mode): + """Return the PyTorch padding mode corresponding to a NumPy mode name.""" + + return _PAD_MODE_ALIASES.get(mode, mode) + + +def patch_pytorch_pad_mode_contract() -> None: + """Patch raw/public PyTorch ``pad`` to accept NumPy mode names.""" + + try: + import pyrecest._backend.pytorch as pytorch_backend # pylint: disable=import-outside-toplevel + import pyrecest.backend as backend # pylint: disable=import-outside-toplevel + except ModuleNotFoundError: # pragma: no cover - PyTorch backend may be unavailable + return + + original_pad = getattr(pytorch_backend, "pad", None) + if original_pad is None: + return + if getattr(original_pad, "_pyrecest_numpy_pad_mode_contract", False): + if getattr(backend, "__backend_name__", None) == "pytorch": + backend.pad = original_pad + return + + def pad(a, pad_width, mode="constant", constant_values=0.0): + return original_pad( + a, + pad_width, + mode=_torch_pad_mode(mode), + constant_values=constant_values, + ) + + pad.__name__ = getattr(original_pad, "__name__", "pad") + pad.__doc__ = getattr(original_pad, "__doc__", None) + pad._pyrecest_numpy_pad_mode_contract = True + pytorch_backend.pad = pad + if getattr(backend, "__backend_name__", None) == "pytorch": + backend.pad = pad diff --git a/src/pyrecest/stability.py b/src/pyrecest/stability.py index aa56e78f3..25e45074b 100644 --- a/src/pyrecest/stability.py +++ b/src/pyrecest/stability.py @@ -9,8 +9,12 @@ from pyrecest.backend_support._pytorch_allclose_device_contract import ( patch_pytorch_allclose_device_contract as _patch_pytorch_allclose_device_contract, ) +from pyrecest.backend_support._pytorch_pad_mode_contract import ( + patch_pytorch_pad_mode_contract as _patch_pytorch_pad_mode_contract, +) _patch_pytorch_allclose_device_contract() +_patch_pytorch_pad_mode_contract() P = ParamSpec("P") R = TypeVar("R") diff --git a/tests/backend/test_pytorch_pad_mode_contract.py b/tests/backend/test_pytorch_pad_mode_contract.py new file mode 100644 index 000000000..a1ecc6463 --- /dev/null +++ b/tests/backend/test_pytorch_pad_mode_contract.py @@ -0,0 +1,40 @@ +import pytest +from tests.support.backend_runner import run_backend_code + + +torch = pytest.importorskip("torch") + +import pyrecest.backend_tools # noqa: E402,F401 +import pyrecest._backend.pytorch as pytorch_backend # noqa: E402 + + +def _to_list(value): + return value.detach().cpu().tolist() + + +def test_raw_pytorch_pad_accepts_numpy_edge_and_wrap_mode_names(): + values = torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float64) + + edge = pytorch_backend.pad(values, ((0, 0), (2, 1)), mode="edge") + wrap = pytorch_backend.pad(values, ((0, 0), (1, 2)), mode="wrap") + + assert _to_list(edge) == [[1.0, 1.0, 1.0, 2.0, 3.0, 3.0]] + assert _to_list(wrap) == [[3.0, 1.0, 2.0, 3.0, 1.0, 2.0]] + + +def test_public_pytorch_pad_accepts_numpy_edge_and_wrap_mode_names(): + result = run_backend_code( + "pytorch", + r''' +import pyrecest.backend as backend + +values = backend.asarray([[1.0, 2.0, 3.0]]) +edge = backend.pad(values, ((0, 0), (2, 1)), mode="edge") +wrap = backend.pad(values, ((0, 0), (1, 2)), mode="wrap") + +assert backend.to_numpy(edge).tolist() == [[1.0, 1.0, 1.0, 2.0, 3.0, 3.0]] +assert backend.to_numpy(wrap).tolist() == [[3.0, 1.0, 2.0, 3.0, 1.0, 2.0]] +''', + ) + + assert result.returncode == 0, result.stderr