Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions src/pyrecest/backend_support/_pytorch_minmax_device_contract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""PyTorch ``maximum``/``minimum`` device compatibility hook."""

from __future__ import annotations


def _preferred_pytorch_device(torch_module, *values):
"""Return a non-CPU tensor device when mixed-device operands are present."""
for value in values:
if torch_module.is_tensor(value) and value.device.type != "cpu":
return value.device
for value in values:
if torch_module.is_tensor(value):
return value.device
return None


def _minmax_operands(raw_pytorch, torch_module, left, right):
"""Return operands on a common dtype and an existing preferred device."""
device = _preferred_pytorch_device(torch_module, left, right)
left = raw_pytorch.array(left)
right = raw_pytorch.array(right)
dtype = torch_module.promote_types(left.dtype, right.dtype)
if device is None:
return left.to(dtype=dtype), right.to(dtype=dtype)
return left.to(device=device, dtype=dtype), right.to(device=device, dtype=dtype)


def patch_pytorch_minmax_device_contract() -> None:
"""Patch raw/public PyTorch ``maximum`` and ``minimum`` to preserve device."""
try:
import pyrecest._backend.pytorch as raw_pytorch # pylint: disable=import-outside-toplevel
import pyrecest.backend as backend # pylint: disable=import-outside-toplevel
import torch # pylint: disable=import-outside-toplevel
except ModuleNotFoundError: # pragma: no cover - PyTorch backend may be unavailable
return

helpers = {
"maximum": torch.maximum,
"minimum": torch.minimum,
}
if all(
getattr(getattr(raw_pytorch, helper_name, None), "_pyrecest_minmax_device_contract", False)
for helper_name in helpers
):
if getattr(backend, "__backend_name__", None) == "pytorch":
for helper_name in helpers:
setattr(backend, helper_name, getattr(raw_pytorch, helper_name))
return

for helper_name, torch_helper in helpers.items():
original_helper = getattr(raw_pytorch, helper_name)

def minmax(left, right, _torch_helper=torch_helper):
left, right = _minmax_operands(raw_pytorch, torch, left, right)
return _torch_helper(left, right)

minmax.__name__ = getattr(original_helper, "__name__", helper_name)
minmax.__doc__ = getattr(original_helper, "__doc__", None)
minmax._pyrecest_minmax_device_contract = True
minmax._pyrecest_device_contract = True
setattr(raw_pytorch, helper_name, minmax)
if getattr(backend, "__backend_name__", None) == "pytorch":
setattr(backend, helper_name, minmax)
7 changes: 7 additions & 0 deletions src/pyrecest/stability.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@
from dataclasses import asdict, dataclass
from typing import Final, Literal, ParamSpec, TypeVar

from pyrecest.backend_support._pytorch_minmax_device_contract import (
patch_pytorch_minmax_device_contract as _patch_pytorch_minmax_device_contract,
)

_pytorch_minmax_device_contract = _patch_pytorch_minmax_device_contract
_pytorch_minmax_device_contract()

P = ParamSpec("P")
R = TypeVar("R")

Expand Down
72 changes: 72 additions & 0 deletions tests/backend_support/test_pytorch_minmax_device_contract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""Regression tests for PyTorch maximum/minimum device preservation."""

from __future__ import annotations

import importlib.util

import pytest

from tests.support.backend_runner import run_backend_code

pytestmark = pytest.mark.backend_portable


def _device_contract_code(target_module: str) -> str:
return f"""
import torch
import pyrecest # noqa: F401 # triggers backend-support compatibility patches
import pyrecest.backend as backend
import pyrecest._backend.pytorch as raw_pytorch


def _non_cpu_device():
if torch.cuda.is_available():
return torch.device("cuda")
return torch.device("meta")


target = {target_module}
device = _non_cpu_device()
right = torch.tensor([1.0, 4.0], device=device)

maximum_result = target.maximum([2.0, 3.0], right)
assert maximum_result.device.type == device.type
assert tuple(maximum_result.shape) == (2,)
if device.type != "meta":
assert torch.allclose(maximum_result.cpu(), torch.tensor([2.0, 4.0]))

minimum_result = target.minimum(torch.tensor([2.0, 3.0]), right)
assert minimum_result.device.type == device.type
assert tuple(minimum_result.shape) == (2,)
if device.type != "meta":
assert torch.allclose(minimum_result.cpu(), torch.tensor([1.0, 3.0]))

left = torch.tensor([2.0, 3.0], device=device)
minimum_arraylike_result = target.minimum(left, [1.0, 4.0])
assert minimum_arraylike_result.device.type == device.type
assert tuple(minimum_arraylike_result.shape) == (2,)
if device.type != "meta":
assert torch.allclose(minimum_arraylike_result.cpu(), torch.tensor([1.0, 3.0]))

print("ok")
"""


def test_raw_pytorch_maximum_minimum_prefer_existing_non_cpu_device_after_import():
if importlib.util.find_spec("torch") is None:
pytest.skip("PyTorch is not installed")

result = run_backend_code("numpy", _device_contract_code("raw_pytorch"))

assert result.returncode == 0, result.stderr
assert "ok" in result.stdout


def test_public_pytorch_maximum_minimum_prefer_existing_non_cpu_device():
if importlib.util.find_spec("torch") is None:
pytest.skip("PyTorch is not installed")

result = run_backend_code("pytorch", _device_contract_code("backend"))

assert result.returncode == 0, result.stderr
assert "ok" in result.stdout
Loading