From 906c59ae5622f866dbc029a270aa155effc78cc9 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Wed, 1 Jul 2026 22:07:13 +0200 Subject: [PATCH 1/2] Add PyTorch dot outer device regression --- .../test_pytorch_dot_outer_device_contract.py | 70 ++++++++++++++++++- 1 file changed, 69 insertions(+), 1 deletion(-) diff --git a/tests/backend_support/test_pytorch_dot_outer_device_contract.py b/tests/backend_support/test_pytorch_dot_outer_device_contract.py index 48cdce852..c793f7a8a 100644 --- a/tests/backend_support/test_pytorch_dot_outer_device_contract.py +++ b/tests/backend_support/test_pytorch_dot_outer_device_contract.py @@ -1 +1,69 @@ -placeholder +import importlib.util + +import pytest + +from tests.support.backend_runner import run_backend_code + +pytestmark = pytest.mark.backend_portable + + +def _device_contract_code(target_module): + return f""" +import torch +import pyrecest # noqa: F401 # triggers backend-support compatibility patches +import pyrecest.backend as backend +import pyrecest._backend.pytorch as raw_pytorch + + +def _non_cpu_device(): + if torch.cuda.is_available(): + return torch.device("cuda") + return torch.device("meta") + + +target = {target_module} +device = _non_cpu_device() +right_vector = torch.ones(2, device=device) + +dot_result = target.dot(torch.tensor([1.0, 2.0]), right_vector) +assert dot_result.device.type == device.type +assert tuple(dot_result.shape) == () +if device.type != "meta": + assert torch.allclose(dot_result.cpu(), torch.tensor(3.0)) + +outer_result = target.outer(torch.tensor([1.0, 2.0]), right_vector) +assert outer_result.device.type == device.type +assert tuple(outer_result.shape) == (2, 2) +if device.type != "meta": + expected = torch.tensor([[1.0, 1.0], [2.0, 2.0]]) + assert torch.allclose(outer_result.cpu(), expected) + +dot_arraylike_result = target.dot([1.0, 2.0], right_vector) +assert dot_arraylike_result.device.type == device.type +assert tuple(dot_arraylike_result.shape) == () +if device.type != "meta": + assert torch.allclose(dot_arraylike_result.cpu(), torch.tensor(3.0)) + +print("ok") +""" + + + +def test_raw_pytorch_dot_outer_prefer_existing_non_cpu_device_after_import(): + if importlib.util.find_spec("torch") is None: + pytest.skip("PyTorch is not installed") + + result = run_backend_code("numpy", _device_contract_code("raw_pytorch")) + + assert result.returncode == 0, result.stderr + assert "ok" in result.stdout + + +def test_public_pytorch_dot_outer_prefer_existing_non_cpu_device(): + if importlib.util.find_spec("torch") is None: + pytest.skip("PyTorch is not installed") + + result = run_backend_code("pytorch", _device_contract_code("backend")) + + assert result.returncode == 0, result.stderr + assert "ok" in result.stdout From f01d57c3abe96dffea8d54e2bd8b206ce7a46f69 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Wed, 1 Jul 2026 22:10:07 +0200 Subject: [PATCH 2/2] PyTorch dot outer device hook --- .../_pytorch_dot_outer_device_contract.py | 80 +++++++++++++++++++ src/pyrecest/stability.py | 6 ++ 2 files changed, 86 insertions(+) create mode 100644 src/pyrecest/backend_support/_pytorch_dot_outer_device_contract.py diff --git a/src/pyrecest/backend_support/_pytorch_dot_outer_device_contract.py b/src/pyrecest/backend_support/_pytorch_dot_outer_device_contract.py new file mode 100644 index 000000000..f9a6d1da3 --- /dev/null +++ b/src/pyrecest/backend_support/_pytorch_dot_outer_device_contract.py @@ -0,0 +1,80 @@ +"""PyTorch ``dot``/``outer`` device compatibility hook.""" + +from __future__ import annotations + + +def _preferred_pytorch_device(torch_module, *values): + """Return a non-CPU tensor device when mixed-device operands are present.""" + for value in values: + if torch_module.is_tensor(value) and value.device.type != "cpu": + return value.device + for value in values: + if torch_module.is_tensor(value): + return value.device + return None + + +def _promoted_pair(raw_pytorch, torch_module, left, right): + """Return PyTorch operands on a common dtype and preferred existing device.""" + device = _preferred_pytorch_device(torch_module, left, right) + left = raw_pytorch.array(left) + right = raw_pytorch.array(right) + dtype = torch_module.promote_types(left.dtype, right.dtype) + if device is None: + return left.to(dtype=dtype), right.to(dtype=dtype) + return left.to(device=device, dtype=dtype), right.to(device=device, dtype=dtype) + + +def patch_pytorch_dot_outer_device_contract() -> None: + """Patch raw/public PyTorch ``dot`` and ``outer`` to preserve non-CPU operands.""" + try: + import pyrecest._backend.pytorch as raw_pytorch # pylint: disable=import-outside-toplevel + import pyrecest.backend as backend # pylint: disable=import-outside-toplevel + import torch # pylint: disable=import-outside-toplevel + except ModuleNotFoundError: # pragma: no cover - PyTorch backend may be unavailable + return + + original_dot = getattr(raw_pytorch, "dot", None) + original_outer = getattr(raw_pytorch, "outer", None) + if original_dot is None or original_outer is None: + return + if getattr(original_dot, "_pyrecest_dot_outer_device_contract", False) and getattr( + original_outer, + "_pyrecest_dot_outer_device_contract", + False, + ): + if getattr(backend, "__backend_name__", None) == "pytorch": + backend.dot = original_dot + backend.outer = original_outer + return + + def dot(a, b): + a, b = _promoted_pair(raw_pytorch, torch, a, b) + if a.ndim == 0 or b.ndim == 0: + return torch.multiply(a, b) + if a.ndim == 1 and b.ndim == 1: + return torch.dot(a, b) + if b.ndim == 1: + return torch.einsum("...i,i->...", a, b) + if a.ndim == 1: + return torch.einsum("i,...i->...", a, b) + return torch.einsum("...i,...i->...", a, b) + + def outer(a, b): + a, b = _promoted_pair(raw_pytorch, torch, a, b) + if a.ndim == 0 or b.ndim == 0: + return torch.multiply(a, b) + return a[..., :, None] * b[..., None, :] + + for helper_name, helper, original_helper in ( + ("dot", dot, original_dot), + ("outer", outer, original_outer), + ): + helper.__name__ = getattr(original_helper, "__name__", helper_name) + helper.__doc__ = getattr(original_helper, "__doc__", None) + helper._pyrecest_dot_outer_device_contract = True + helper._pyrecest_device_contract = True + helper._pyrecest_numpy_contract = True + setattr(raw_pytorch, helper_name, helper) + if getattr(backend, "__backend_name__", None) == "pytorch": + setattr(backend, helper_name, helper) diff --git a/src/pyrecest/stability.py b/src/pyrecest/stability.py index 1d362a318..51e20ce90 100644 --- a/src/pyrecest/stability.py +++ b/src/pyrecest/stability.py @@ -6,6 +6,12 @@ from dataclasses import asdict, dataclass from typing import Final, Literal, ParamSpec, TypeVar +from pyrecest.backend_support._pytorch_dot_outer_device_contract import ( + patch_pytorch_dot_outer_device_contract as _patch_pytorch_dot_outer_device_contract, +) + +_patch_pytorch_dot_outer_device_contract() + P = ParamSpec("P") R = TypeVar("R")