From d0a638b2f054290c2e413901a1c33b21135838cb Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Wed, 1 Jul 2026 22:31:31 +0200 Subject: [PATCH 1/3] Add PyTorch matmul device hook --- .../_pytorch_matmul_device_contract.py | 57 +++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 src/pyrecest/backend_support/_pytorch_matmul_device_contract.py diff --git a/src/pyrecest/backend_support/_pytorch_matmul_device_contract.py b/src/pyrecest/backend_support/_pytorch_matmul_device_contract.py new file mode 100644 index 000000000..e8d3d2895 --- /dev/null +++ b/src/pyrecest/backend_support/_pytorch_matmul_device_contract.py @@ -0,0 +1,57 @@ +"""PyTorch ``matmul`` device compatibility hook.""" + +from __future__ import annotations + + +def _preferred_pytorch_device(torch_module, *values): + """Return an existing non-CPU tensor device, falling back to any tensor.""" + for value in values: + if torch_module.is_tensor(value) and value.device.type != "cpu": + return value.device + for value in values: + if torch_module.is_tensor(value): + return value.device + return None + + +def patch_pytorch_matmul_device_contract() -> None: + """Patch raw/public PyTorch ``matmul`` to keep operands on one device.""" + try: + import pyrecest._backend.pytorch as raw_pytorch # pylint: disable=import-outside-toplevel + import pyrecest.backend as backend # pylint: disable=import-outside-toplevel + import torch # pylint: disable=import-outside-toplevel + except ModuleNotFoundError: # pragma: no cover - PyTorch backend may be unavailable + return + + original_matmul = getattr(raw_pytorch, "matmul", None) + if original_matmul is None: + return + if getattr(original_matmul, "_pyrecest_device_contract", False): + if getattr(backend, "__backend_name__", None) == "pytorch": + backend.matmul = original_matmul + return + + def matmul(x, y, out=None): + device = _preferred_pytorch_device(torch, x, y, out) + x = raw_pytorch.array(x) + y = raw_pytorch.array(y) + dtype = torch.promote_types(x.dtype, y.dtype) + + if device is not None: + x = x.to(device=device, dtype=dtype) + y = y.to(device=device, dtype=dtype) + else: + x = x.to(dtype=dtype) + y = y.to(dtype=dtype) + + if out is not None: + return torch.matmul(x, y, out=out) + return torch.matmul(x, y) + + matmul.__name__ = getattr(original_matmul, "__name__", "matmul") + matmul.__doc__ = getattr(original_matmul, "__doc__", None) + matmul._pyrecest_device_contract = True + matmul._pyrecest_numpy_contract = True + raw_pytorch.matmul = matmul + if getattr(backend, "__backend_name__", None) == "pytorch": + backend.matmul = matmul From 03a8ceecc613eba35e6aa604e7f997f89b6be0ec Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Wed, 1 Jul 2026 22:31:54 +0200 Subject: [PATCH 2/3] Load PyTorch matmul device hook --- src/pyrecest/stability.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/pyrecest/stability.py b/src/pyrecest/stability.py index 1d362a318..57d9ad695 100644 --- a/src/pyrecest/stability.py +++ b/src/pyrecest/stability.py @@ -6,6 +6,12 @@ from dataclasses import asdict, dataclass from typing import Final, Literal, ParamSpec, TypeVar +from pyrecest.backend_support._pytorch_matmul_device_contract import ( + patch_pytorch_matmul_device_contract as _patch_pytorch_matmul_device_contract, +) + +_patch_pytorch_matmul_device_contract() + P = ParamSpec("P") R = TypeVar("R") From 71d8d3b325b9b9bf8bbfc968fdc6448a473fe4b7 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Wed, 1 Jul 2026 22:32:07 +0200 Subject: [PATCH 3/3] Test PyTorch matmul device contract --- .../test_pytorch_matmul_device_contract.py | 67 +++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 tests/backend_support/test_pytorch_matmul_device_contract.py diff --git a/tests/backend_support/test_pytorch_matmul_device_contract.py b/tests/backend_support/test_pytorch_matmul_device_contract.py new file mode 100644 index 000000000..0232750c6 --- /dev/null +++ b/tests/backend_support/test_pytorch_matmul_device_contract.py @@ -0,0 +1,67 @@ +import importlib.util + +import pytest + +from tests.support.backend_runner import run_backend_code + +pytestmark = pytest.mark.backend_portable + + +def _matmul_device_contract_code(target_module): + return f""" +import torch +import pyrecest # noqa: F401 # triggers backend-support compatibility patches +import pyrecest.backend as backend +import pyrecest._backend.pytorch as raw_pytorch + + +def _non_cpu_device(): + if torch.cuda.is_available(): + return torch.device("cuda") + return torch.device("meta") + + +target = {target_module} +device = _non_cpu_device() + +right_matrix = torch.eye(2, device=device) +matrix_result = target.matmul(torch.eye(2), right_matrix) +assert matrix_result.device.type == device.type +assert tuple(matrix_result.shape) == (2, 2) +if device.type != "meta": + assert torch.allclose(matrix_result.cpu(), torch.eye(2)) + +array_like_result = target.matmul([[1.0, 2.0]], torch.ones((2, 1), device=device)) +assert array_like_result.device.type == device.type +assert tuple(array_like_result.shape) == (1, 1) +if device.type != "meta": + assert torch.allclose(array_like_result.cpu(), torch.tensor([[3.0]])) + +vector_result = target.matmul(torch.tensor([1.0, 2.0]), torch.ones(2, device=device)) +assert vector_result.device.type == device.type +assert tuple(vector_result.shape) == () +if device.type != "meta": + assert torch.allclose(vector_result.cpu(), torch.tensor(3.0)) + +print("ok") +""" + + +def test_raw_pytorch_matmul_prefers_existing_non_cpu_device_after_import(): + if importlib.util.find_spec("torch") is None: + pytest.skip("PyTorch is not installed") + + result = run_backend_code("numpy", _matmul_device_contract_code("raw_pytorch")) + + assert result.returncode == 0, result.stderr + assert "ok" in result.stdout + + +def test_public_pytorch_matmul_prefers_existing_non_cpu_device(): + if importlib.util.find_spec("torch") is None: + pytest.skip("PyTorch is not installed") + + result = run_backend_code("pytorch", _matmul_device_contract_code("backend")) + + assert result.returncode == 0, result.stderr + assert "ok" in result.stdout