Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 78 additions & 45 deletions src/pyrecest/backend_support/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,79 @@ def clip(a, a_min=None, a_max=None, out=None, *, min=None, max=None):
backend.clip = clip


def _pytorch_preferred_device(torch_module, *values):
"""Return the existing non-CPU tensor device preferred by binary helpers."""
for value in values:
if torch_module.is_tensor(value) and value.device.type != "cpu":
return value.device
for value in values:
if torch_module.is_tensor(value):
return value.device
return None


def _pytorch_tensor_on_device(value, torch_module, *, device):
"""Return ``value`` as a tensor on the selected device."""
if torch_module.is_tensor(value):
if device is not None and value.device != device:
return value.to(device=device)
return value
return torch_module.as_tensor(value, device=device)


def _patch_pytorch_isclose_device_contract() -> None:
"""Keep PyTorch tolerance-comparison operands on one selected device."""
try:
import pyrecest.backend as backend # pylint: disable=import-outside-toplevel
except ModuleNotFoundError: # pragma: no cover - import fails before this module
return

active_pytorch_backend = getattr(backend, "__backend_name__", None) == "pytorch"

try:
import pyrecest._backend.pytorch as raw_pytorch # pylint: disable=import-outside-toplevel
import torch # pylint: disable=import-outside-toplevel
except ModuleNotFoundError: # pragma: no cover - PyTorch backend import failed earlier
return

helper_names = ("isclose", "allclose")
if all(
getattr(getattr(raw_pytorch, helper_name, None), "_pyrecest_device_contract", False)
for helper_name in helper_names
):
if active_pytorch_backend:
for helper_name in helper_names:
setattr(backend, helper_name, getattr(raw_pytorch, helper_name))
return

def _comparison_operands(a, b):
device = _pytorch_preferred_device(torch, a, b)
a = _pytorch_tensor_on_device(a, torch, device=device)
b = _pytorch_tensor_on_device(b, torch, device=device)
return raw_pytorch.convert_to_wider_dtype([a, b])

def isclose(a, b, rtol=raw_pytorch.rtol, atol=raw_pytorch.atol):
a, b = _comparison_operands(a, b)
return torch.isclose(a, b, rtol=rtol, atol=atol)

def allclose(a, b, atol=raw_pytorch.atol, rtol=raw_pytorch.rtol):
a, b = _comparison_operands(a, b)
return torch.allclose(a, b, atol=atol, rtol=rtol)

isclose.__name__ = getattr(raw_pytorch.isclose, "__name__", "isclose")
isclose.__doc__ = getattr(raw_pytorch.isclose, "__doc__", None)
isclose._pyrecest_device_contract = True
allclose.__name__ = getattr(raw_pytorch.allclose, "__name__", "allclose")
allclose.__doc__ = getattr(raw_pytorch.allclose, "__doc__", None)
allclose._pyrecest_device_contract = True

raw_pytorch.isclose = isclose
raw_pytorch.allclose = allclose
if active_pytorch_backend:
backend.isclose = isclose
backend.allclose = allclose


def _pytorch_broadcast_dimension(dimension, numpy_module) -> int:
"""Return one NumPy-style broadcast dimension as a non-boolean integer."""

Expand Down Expand Up @@ -377,7 +450,6 @@ def _patch_jax_outer_numpy_contract() -> None:
import pyrecest.backend as backend # pylint: disable=import-outside-toplevel
except ModuleNotFoundError: # pragma: no cover - import fails before this module
return

active_jax_backend = getattr(backend, "__backend_name__", None) == "jax"

try:
Expand Down Expand Up @@ -411,7 +483,6 @@ def _patch_jax_one_hot_backend_contract() -> None:
import pyrecest.backend as backend # pylint: disable=import-outside-toplevel
except ModuleNotFoundError: # pragma: no cover - import fails before this module
return

active_jax_backend = getattr(backend, "__backend_name__", None) == "jax"

try:
Expand Down Expand Up @@ -449,7 +520,6 @@ def _patch_jax_take_out_contract() -> None:
import pyrecest.backend as backend # pylint: disable=import-outside-toplevel
except ModuleNotFoundError: # pragma: no cover - import fails before this module
return

active_jax_backend = getattr(backend, "__backend_name__", None) == "jax"

try:
Expand Down Expand Up @@ -502,6 +572,7 @@ def take(
_patch_pytorch_tile_numpy_contract()
_patch_pytorch_copy_numpy_contract()
_patch_pytorch_clip_numpy_contract()
_patch_pytorch_isclose_device_contract()
_patch_pytorch_broadcast_to_numpy_contract()
_patch_jax_outer_numpy_contract()
_patch_jax_one_hot_backend_contract()
Expand All @@ -520,47 +591,9 @@ def get_backend_support(
return dict(row)


def backend_support(
api_name: str, backend: str | None = None
) -> dict[str, str] | str | None:
"""Alias for :func:`get_backend_support` for concise user code."""
return get_backend_support(api_name, backend=backend)


def _markdown_table_cell(value: object) -> str:
escape = chr(92) + chr(124)
return str(value).replace("\r", " ").replace("\n", "<br>").replace(chr(124), escape)


def _markdown_table_row(cells: list[str]) -> str:
separator = chr(124)
return f"{separator} " + f" {separator} ".join(cells) + f" {separator}"


def format_backend_support_markdown() -> str:
"""Render the public backend API matrix as a Markdown table."""
lines = [
_markdown_table_row(["API", "NumPy", "PyTorch", "JAX", "Notes"]),
_markdown_table_row(["-----", "-------", "---------", "-----", "-------"]),
]
for api_name, row in iter_api_backend_capabilities():
lines.append(
_markdown_table_row(
[
f"`{_markdown_table_cell(api_name)}`",
_markdown_table_cell(row["numpy"]),
_markdown_table_cell(row["pytorch"]),
_markdown_table_cell(row["jax"]),
_markdown_table_cell(row.get("notes", "")),
]
)
)
return "\n".join(lines)
def iter_backend_support():
"""Yield backend capability entries."""
yield from iter_api_backend_capabilities()


__all__ = [
"BACKEND_SUPPORT_LEVELS",
"backend_support",
"format_backend_support_markdown",
"get_backend_support",
]
__all__ = ["BACKEND_SUPPORT_LEVELS", "get_backend_support", "iter_backend_support"]
21 changes: 21 additions & 0 deletions tests/test_pytorch_isclose_device_contract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import pytest

from tests.support.backend_runner import run_backend_code


def test_pytorch_isclose_places_array_like_operands_on_existing_device():
pytest.importorskip("torch")

code = """
import torch
import pyrecest.backend as backend
import pyrecest._backend.pytorch as raw_pytorch

probe = torch.ones(2, device="meta")
for helper in (backend.isclose, raw_pytorch.isclose):
result = helper(probe, [1.0, 1.0])
assert result.device.type == "meta"
assert tuple(result.shape) == (2,)
"""
result = run_backend_code("pytorch", code)
assert result.returncode == 0, result.stderr
Loading