Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions src/pyrecest/backend_support/_pytorch_matmul_device_contract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""PyTorch ``matmul`` device compatibility hook."""

from __future__ import annotations


def _preferred_pytorch_device(torch_module, *values):
"""Return an existing non-CPU tensor device, falling back to any tensor."""
for value in values:
if torch_module.is_tensor(value) and value.device.type != "cpu":
return value.device
for value in values:
if torch_module.is_tensor(value):
return value.device
return None


def patch_pytorch_matmul_device_contract() -> None:
    """Patch raw/public PyTorch ``matmul`` to keep operands on one device.

    The wrapper installed on ``pyrecest._backend.pytorch`` converts both
    operands with the raw backend's ``array``, promotes them to a common
    dtype via ``torch.promote_types`` and moves them to a single device
    before delegating to ``torch.matmul``.  Device selection:

    * if ``out`` is a tensor, its device is used, because the result is
      written into ``out`` and the operands must live on that device;
    * otherwise the first non-CPU tensor operand wins, falling back to the
      device of any tensor operand;
    * if neither operand is a tensor, no device move is performed.

    ``pyrecest.backend.matmul`` is replaced as well, but only when
    ``pyrecest.backend.__backend_name__`` equals ``"pytorch"``.

    The call does nothing when the required modules cannot be imported or
    when the raw backend exposes no ``matmul``.  When the raw ``matmul``
    already carries the ``_pyrecest_device_contract`` marker it is not
    wrapped again; it is only re-assigned to the public backend.
    """
    try:
        import pyrecest._backend.pytorch as raw_pytorch  # pylint: disable=import-outside-toplevel
        import pyrecest.backend as backend  # pylint: disable=import-outside-toplevel
        import torch  # pylint: disable=import-outside-toplevel
    except ModuleNotFoundError:  # pragma: no cover - PyTorch backend may be unavailable
        return

    original_matmul = getattr(raw_pytorch, "matmul", None)
    if original_matmul is None:
        return

    public_backend_is_pytorch = (
        getattr(backend, "__backend_name__", None) == "pytorch"
    )

    if getattr(original_matmul, "_pyrecest_device_contract", False):
        # Already patched: never wrap twice, just make sure the public
        # backend points at the patched function.
        if public_backend_is_pytorch:
            backend.matmul = original_matmul
        return

    def matmul(x, y, out=None):
        if torch.is_tensor(out):
            # The product is written into ``out``; moving the operands to any
            # other device would make ``torch.matmul`` reject the call.
            device = out.device
        else:
            device = _preferred_pytorch_device(torch, x, y)

        x = raw_pytorch.array(x)
        y = raw_pytorch.array(y)
        dtype = torch.promote_types(x.dtype, y.dtype)

        if device is not None:
            x = x.to(device=device, dtype=dtype)
            y = y.to(device=device, dtype=dtype)
        else:
            x = x.to(dtype=dtype)
            y = y.to(dtype=dtype)

        if out is not None:
            return torch.matmul(x, y, out=out)
        return torch.matmul(x, y)

    # Keep the replaced function's name and docstring for introspection.
    matmul.__name__ = getattr(original_matmul, "__name__", "matmul")
    matmul.__doc__ = getattr(original_matmul, "__doc__", None)
    # Marker checked above so that repeated calls do not wrap again.
    matmul._pyrecest_device_contract = True
    # NOTE(review): presumably read by another compatibility patch to detect
    # an already-wrapped ``matmul`` -- confirm against that patch.
    matmul._pyrecest_numpy_contract = True

    raw_pytorch.matmul = matmul
    if public_backend_is_pytorch:
        backend.matmul = matmul
6 changes: 6 additions & 0 deletions src/pyrecest/stability.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@
from dataclasses import asdict, dataclass
from typing import Final, Literal, ParamSpec, TypeVar

from pyrecest.backend_support._pytorch_matmul_device_contract import (
patch_pytorch_matmul_device_contract as _patch_pytorch_matmul_device_contract,
)

# Module-level call: the PyTorch ``matmul`` device-compatibility patch is
# applied as a side effect of importing this module.
_patch_pytorch_matmul_device_contract()

P = ParamSpec("P")
R = TypeVar("R")

Expand Down
67 changes: 67 additions & 0 deletions tests/backend_support/test_pytorch_matmul_device_contract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import importlib.util

import pytest

from tests.support.backend_runner import run_backend_code

# Module-level pytest marker: tags every test in this file as ``backend_portable``.
pytestmark = pytest.mark.backend_portable


# Build the source text of a self-checking script for the patched ``matmul``.
#
# ``target_module`` is interpolated verbatim into the script as the expression
# bound to ``target`` (the tests below pass "raw_pytorch" or "backend").
#
# The generated script:
#   * picks CUDA when ``torch.cuda.is_available()``, otherwise the "meta" device;
#   * calls ``target.matmul`` with a left operand that is a CPU tensor, a nested
#     list, or a 1-D CPU tensor, and a right operand on the chosen device;
#   * asserts the result's device type and shape, and compares values only
#     when the device is not "meta";
#   * prints "ok" once every assertion has passed.
def _matmul_device_contract_code(target_module):
return f"""
import torch
import pyrecest # noqa: F401 # triggers backend-support compatibility patches
import pyrecest.backend as backend
import pyrecest._backend.pytorch as raw_pytorch


def _non_cpu_device():
if torch.cuda.is_available():
return torch.device("cuda")
return torch.device("meta")


target = {target_module}
device = _non_cpu_device()

right_matrix = torch.eye(2, device=device)
matrix_result = target.matmul(torch.eye(2), right_matrix)
assert matrix_result.device.type == device.type
assert tuple(matrix_result.shape) == (2, 2)
if device.type != "meta":
assert torch.allclose(matrix_result.cpu(), torch.eye(2))

array_like_result = target.matmul([[1.0, 2.0]], torch.ones((2, 1), device=device))
assert array_like_result.device.type == device.type
assert tuple(array_like_result.shape) == (1, 1)
if device.type != "meta":
assert torch.allclose(array_like_result.cpu(), torch.tensor([[3.0]]))

vector_result = target.matmul(torch.tensor([1.0, 2.0]), torch.ones(2, device=device))
assert vector_result.device.type == device.type
assert tuple(vector_result.shape) == ()
if device.type != "meta":
assert torch.allclose(vector_result.cpu(), torch.tensor(3.0))

print("ok")
"""


def test_raw_pytorch_matmul_prefers_existing_non_cpu_device_after_import():
    """Check the raw PyTorch ``matmul`` via a script run with "numpy" passed to the runner.

    Skipped when no ``torch`` distribution can be located.
    """
    torch_spec = importlib.util.find_spec("torch")
    if torch_spec is None:
        pytest.skip("PyTorch is not installed")

    script = _matmul_device_contract_code("raw_pytorch")
    completed = run_backend_code("numpy", script)

    # Surface the script's stderr when it exits with a failure code.
    assert completed.returncode == 0, completed.stderr
    assert "ok" in completed.stdout


def test_public_pytorch_matmul_prefers_existing_non_cpu_device():
    """Check the public ``backend.matmul`` via a script run with "pytorch" passed to the runner.

    Skipped when no ``torch`` distribution can be located.
    """
    torch_spec = importlib.util.find_spec("torch")
    if torch_spec is None:
        pytest.skip("PyTorch is not installed")

    script = _matmul_device_contract_code("backend")
    completed = run_backend_code("pytorch", script)

    # Surface the script's stderr when it exits with a failure code.
    assert completed.returncode == 0, completed.stderr
    assert "ok" in completed.stdout
Loading