From 1c0e647791ee558e84f594979ea6acc0e0529849 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Wed, 1 Jul 2026 22:31:07 +0200 Subject: [PATCH 1/3] Add PyTorch min/max device contract hook --- .../_pytorch_minmax_device_contract.py | 63 +++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 src/pyrecest/backend_support/_pytorch_minmax_device_contract.py diff --git a/src/pyrecest/backend_support/_pytorch_minmax_device_contract.py b/src/pyrecest/backend_support/_pytorch_minmax_device_contract.py new file mode 100644 index 000000000..daffbf942 --- /dev/null +++ b/src/pyrecest/backend_support/_pytorch_minmax_device_contract.py @@ -0,0 +1,63 @@ +"""PyTorch ``maximum``/``minimum`` device compatibility hook.""" + +from __future__ import annotations + + +def _preferred_pytorch_device(torch_module, *values): + """Return a non-CPU tensor device when mixed-device operands are present.""" + for value in values: + if torch_module.is_tensor(value) and value.device.type != "cpu": + return value.device + for value in values: + if torch_module.is_tensor(value): + return value.device + return None + + +def _minmax_operands(raw_pytorch, torch_module, left, right): + """Return operands on a common dtype and an existing preferred device.""" + device = _preferred_pytorch_device(torch_module, left, right) + left = raw_pytorch.array(left) + right = raw_pytorch.array(right) + dtype = torch_module.promote_types(left.dtype, right.dtype) + if device is None: + return left.to(dtype=dtype), right.to(dtype=dtype) + return left.to(device=device, dtype=dtype), right.to(device=device, dtype=dtype) + + +def patch_pytorch_minmax_device_contract() -> None: + """Patch raw/public PyTorch ``maximum`` and ``minimum`` to preserve device.""" + try: + import pyrecest._backend.pytorch as raw_pytorch # pylint: disable=import-outside-toplevel + import pyrecest.backend as backend # pylint: disable=import-outside-toplevel + import torch # pylint: disable=import-outside-toplevel + except ModuleNotFoundError: # pragma: no cover - PyTorch backend may be unavailable + return + + helpers = { + "maximum": torch.maximum, + "minimum": torch.minimum, + } + if all( + getattr(getattr(raw_pytorch, helper_name, None), "_pyrecest_minmax_device_contract", False) + for helper_name in helpers + ): + if getattr(backend, "__backend_name__", None) == "pytorch": + for helper_name in helpers: + setattr(backend, helper_name, getattr(raw_pytorch, helper_name)) + return + + for helper_name, torch_helper in helpers.items(): + original_helper = getattr(raw_pytorch, helper_name) + + def minmax(left, right, _torch_helper=torch_helper): + left, right = _minmax_operands(raw_pytorch, torch, left, right) + return _torch_helper(left, right) + + minmax.__name__ = getattr(original_helper, "__name__", helper_name) + minmax.__doc__ = getattr(original_helper, "__doc__", None) + minmax._pyrecest_minmax_device_contract = True + minmax._pyrecest_device_contract = True + setattr(raw_pytorch, helper_name, minmax) + if getattr(backend, "__backend_name__", None) == "pytorch": + setattr(backend, helper_name, minmax) From b0be381d30cf98701ebc9e9d9327ec6346aa77c5 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Wed, 1 Jul 2026 22:32:02 +0200 Subject: [PATCH 2/3] Update stability helpers --- src/pyrecest/stability.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/pyrecest/stability.py b/src/pyrecest/stability.py index 1d362a318..6ecab0479 100644 --- a/src/pyrecest/stability.py +++ b/src/pyrecest/stability.py @@ -6,6 +6,13 @@ from dataclasses import asdict, dataclass from typing import Final, Literal, ParamSpec, TypeVar +from pyrecest.backend_support._pytorch_minmax_device_contract import ( + patch_pytorch_minmax_device_contract as _patch_pytorch_minmax_device_contract, +) + +_pytorch_minmax_device_contract = _patch_pytorch_minmax_device_contract +_pytorch_minmax_device_contract() + P = ParamSpec("P") R = TypeVar("R") From 91283ff89a03674157e69760ddb83228b4d863dc Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Wed, 1 Jul 2026 22:32:15 +0200 Subject: [PATCH 3/3] Test PyTorch min/max device preservation --- .../test_pytorch_minmax_device_contract.py | 72 +++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 tests/backend_support/test_pytorch_minmax_device_contract.py diff --git a/tests/backend_support/test_pytorch_minmax_device_contract.py b/tests/backend_support/test_pytorch_minmax_device_contract.py new file mode 100644 index 000000000..b1dc3097a --- /dev/null +++ b/tests/backend_support/test_pytorch_minmax_device_contract.py @@ -0,0 +1,72 @@ +"""Regression tests for PyTorch maximum/minimum device preservation.""" + +from __future__ import annotations + +import importlib.util + +import pytest + +from tests.support.backend_runner import run_backend_code + +pytestmark = pytest.mark.backend_portable + + +def _device_contract_code(target_module: str) -> str: + return f""" +import torch +import pyrecest # noqa: F401 # triggers backend-support compatibility patches +import pyrecest.backend as backend +import pyrecest._backend.pytorch as raw_pytorch + + +def _non_cpu_device(): + if torch.cuda.is_available(): + return torch.device("cuda") + return torch.device("meta") + + +target = {target_module} +device = _non_cpu_device() +right = torch.tensor([1.0, 4.0], device=device) + +maximum_result = target.maximum([2.0, 3.0], right) +assert maximum_result.device.type == device.type +assert tuple(maximum_result.shape) == (2,) +if device.type != "meta": + assert torch.allclose(maximum_result.cpu(), torch.tensor([2.0, 4.0])) + +minimum_result = target.minimum(torch.tensor([2.0, 3.0]), right) +assert minimum_result.device.type == device.type +assert tuple(minimum_result.shape) == (2,) +if device.type != "meta": + assert torch.allclose(minimum_result.cpu(), torch.tensor([1.0, 3.0])) + +left = torch.tensor([2.0, 3.0], device=device) +minimum_arraylike_result = target.minimum(left, [1.0, 4.0]) +assert minimum_arraylike_result.device.type == device.type +assert tuple(minimum_arraylike_result.shape) == (2,) +if device.type != "meta": + assert torch.allclose(minimum_arraylike_result.cpu(), torch.tensor([1.0, 3.0])) + +print("ok") +""" + + +def test_raw_pytorch_maximum_minimum_prefer_existing_non_cpu_device_after_import(): + if importlib.util.find_spec("torch") is None: + pytest.skip("PyTorch is not installed") + + result = run_backend_code("numpy", _device_contract_code("raw_pytorch")) + + assert result.returncode == 0, result.stderr + assert "ok" in result.stdout + + +def test_public_pytorch_maximum_minimum_prefer_existing_non_cpu_device(): + if importlib.util.find_spec("torch") is None: + pytest.skip("PyTorch is not installed") + + result = run_backend_code("pytorch", _device_contract_code("backend")) + + assert result.returncode == 0, result.stderr + assert "ok" in result.stdout