Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions src/pyrecest/backend_support/_pytorch_dot_outer_device_contract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""PyTorch ``dot``/``outer`` device compatibility hook."""

from __future__ import annotations


def _preferred_pytorch_device(torch_module, *values):
"""Return a non-CPU tensor device when mixed-device operands are present."""
for value in values:
if torch_module.is_tensor(value) and value.device.type != "cpu":
return value.device
for value in values:
if torch_module.is_tensor(value):
return value.device
return None


def _promoted_pair(raw_pytorch, torch_module, left, right):
"""Return PyTorch operands on a common dtype and preferred existing device."""
device = _preferred_pytorch_device(torch_module, left, right)
left = raw_pytorch.array(left)
right = raw_pytorch.array(right)
dtype = torch_module.promote_types(left.dtype, right.dtype)
if device is None:
return left.to(dtype=dtype), right.to(dtype=dtype)
return left.to(device=device, dtype=dtype), right.to(device=device, dtype=dtype)


def patch_pytorch_dot_outer_device_contract() -> None:
"""Patch raw/public PyTorch ``dot`` and ``outer`` to preserve non-CPU operands."""
try:
import pyrecest._backend.pytorch as raw_pytorch # pylint: disable=import-outside-toplevel
import pyrecest.backend as backend # pylint: disable=import-outside-toplevel
import torch # pylint: disable=import-outside-toplevel
except ModuleNotFoundError: # pragma: no cover - PyTorch backend may be unavailable
return

original_dot = getattr(raw_pytorch, "dot", None)
original_outer = getattr(raw_pytorch, "outer", None)
if original_dot is None or original_outer is None:
return
if getattr(original_dot, "_pyrecest_dot_outer_device_contract", False) and getattr(
original_outer,
"_pyrecest_dot_outer_device_contract",
False,
):
if getattr(backend, "__backend_name__", None) == "pytorch":
backend.dot = original_dot
backend.outer = original_outer
return

def dot(a, b):
a, b = _promoted_pair(raw_pytorch, torch, a, b)
if a.ndim == 0 or b.ndim == 0:
return torch.multiply(a, b)
if a.ndim == 1 and b.ndim == 1:
return torch.dot(a, b)
if b.ndim == 1:
return torch.einsum("...i,i->...", a, b)
if a.ndim == 1:
return torch.einsum("i,...i->...", a, b)
return torch.einsum("...i,...i->...", a, b)

def outer(a, b):
a, b = _promoted_pair(raw_pytorch, torch, a, b)
if a.ndim == 0 or b.ndim == 0:
return torch.multiply(a, b)
return a[..., :, None] * b[..., None, :]

for helper_name, helper, original_helper in (
("dot", dot, original_dot),
("outer", outer, original_outer),
):
helper.__name__ = getattr(original_helper, "__name__", helper_name)
helper.__doc__ = getattr(original_helper, "__doc__", None)
helper._pyrecest_dot_outer_device_contract = True
helper._pyrecest_device_contract = True
helper._pyrecest_numpy_contract = True
setattr(raw_pytorch, helper_name, helper)
if getattr(backend, "__backend_name__", None) == "pytorch":
setattr(backend, helper_name, helper)
6 changes: 6 additions & 0 deletions src/pyrecest/stability.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@
from dataclasses import asdict, dataclass
from typing import Final, Literal, ParamSpec, TypeVar

from pyrecest.backend_support._pytorch_dot_outer_device_contract import (
patch_pytorch_dot_outer_device_contract as _patch_pytorch_dot_outer_device_contract,
)

_patch_pytorch_dot_outer_device_contract()

P = ParamSpec("P")
R = TypeVar("R")

Expand Down
70 changes: 69 additions & 1 deletion tests/backend_support/test_pytorch_dot_outer_device_contract.py
Original file line number Diff line number Diff line change
@@ -1 +1,69 @@
placeholder
import importlib.util

import pytest

from tests.support.backend_runner import run_backend_code

pytestmark = pytest.mark.backend_portable


def _device_contract_code(target_module):
return f"""
import torch
import pyrecest # noqa: F401 # triggers backend-support compatibility patches
import pyrecest.backend as backend
import pyrecest._backend.pytorch as raw_pytorch


def _non_cpu_device():
if torch.cuda.is_available():
return torch.device("cuda")
return torch.device("meta")


target = {target_module}
device = _non_cpu_device()
right_vector = torch.ones(2, device=device)

dot_result = target.dot(torch.tensor([1.0, 2.0]), right_vector)
assert dot_result.device.type == device.type
assert tuple(dot_result.shape) == ()
if device.type != "meta":
assert torch.allclose(dot_result.cpu(), torch.tensor(3.0))

outer_result = target.outer(torch.tensor([1.0, 2.0]), right_vector)
assert outer_result.device.type == device.type
assert tuple(outer_result.shape) == (2, 2)
if device.type != "meta":
expected = torch.tensor([[1.0, 1.0], [2.0, 2.0]])
assert torch.allclose(outer_result.cpu(), expected)

dot_arraylike_result = target.dot([1.0, 2.0], right_vector)
assert dot_arraylike_result.device.type == device.type
assert tuple(dot_arraylike_result.shape) == ()
if device.type != "meta":
assert torch.allclose(dot_arraylike_result.cpu(), torch.tensor(3.0))

print("ok")
"""



def test_raw_pytorch_dot_outer_prefer_existing_non_cpu_device_after_import():
if importlib.util.find_spec("torch") is None:
pytest.skip("PyTorch is not installed")

result = run_backend_code("numpy", _device_contract_code("raw_pytorch"))

assert result.returncode == 0, result.stderr
assert "ok" in result.stdout


def test_public_pytorch_dot_outer_prefer_existing_non_cpu_device():
if importlib.util.find_spec("torch") is None:
pytest.skip("PyTorch is not installed")

result = run_backend_code("pytorch", _device_contract_code("backend"))

assert result.returncode == 0, result.stderr
assert "ok" in result.stdout
Loading