diff --git a/src/pyrecest/backend_support/__init__.py b/src/pyrecest/backend_support/__init__.py index e11f3ec12..c9b82abf3 100644 --- a/src/pyrecest/backend_support/__init__.py +++ b/src/pyrecest/backend_support/__init__.py @@ -307,6 +307,79 @@ def clip(a, a_min=None, a_max=None, out=None, *, min=None, max=None): backend.clip = clip +def _pytorch_preferred_device(torch_module, *values): + """Return the existing non-CPU tensor device preferred by binary helpers.""" + for value in values: + if torch_module.is_tensor(value) and value.device.type != "cpu": + return value.device + for value in values: + if torch_module.is_tensor(value): + return value.device + return None + + +def _pytorch_tensor_on_device(value, torch_module, *, device): + """Return ``value`` as a tensor on the selected device.""" + if torch_module.is_tensor(value): + if device is not None and value.device != device: + return value.to(device=device) + return value + return torch_module.as_tensor(value, device=device) + + +def _patch_pytorch_isclose_device_contract() -> None: + """Keep PyTorch tolerance-comparison operands on one selected device.""" + try: + import pyrecest.backend as backend # pylint: disable=import-outside-toplevel + except ModuleNotFoundError: # pragma: no cover - import fails before this module + return + + active_pytorch_backend = getattr(backend, "__backend_name__", None) == "pytorch" + + try: + import pyrecest._backend.pytorch as raw_pytorch # pylint: disable=import-outside-toplevel + import torch # pylint: disable=import-outside-toplevel + except ModuleNotFoundError: # pragma: no cover - PyTorch backend import failed earlier + return + + helper_names = ("isclose", "allclose") + if all( + getattr(getattr(raw_pytorch, helper_name, None), "_pyrecest_device_contract", False) + for helper_name in helper_names + ): + if active_pytorch_backend: + for helper_name in helper_names: + setattr(backend, helper_name, getattr(raw_pytorch, helper_name)) + return + + def _comparison_operands(a, b): + device = _pytorch_preferred_device(torch, a, b) + a = _pytorch_tensor_on_device(a, torch, device=device) + b = _pytorch_tensor_on_device(b, torch, device=device) + return raw_pytorch.convert_to_wider_dtype([a, b]) + + def isclose(a, b, rtol=raw_pytorch.rtol, atol=raw_pytorch.atol): + a, b = _comparison_operands(a, b) + return torch.isclose(a, b, rtol=rtol, atol=atol) + + def allclose(a, b, atol=raw_pytorch.atol, rtol=raw_pytorch.rtol): + a, b = _comparison_operands(a, b) + return torch.allclose(a, b, atol=atol, rtol=rtol) + + isclose.__name__ = getattr(raw_pytorch.isclose, "__name__", "isclose") + isclose.__doc__ = getattr(raw_pytorch.isclose, "__doc__", None) + isclose._pyrecest_device_contract = True + allclose.__name__ = getattr(raw_pytorch.allclose, "__name__", "allclose") + allclose.__doc__ = getattr(raw_pytorch.allclose, "__doc__", None) + allclose._pyrecest_device_contract = True + + raw_pytorch.isclose = isclose + raw_pytorch.allclose = allclose + if active_pytorch_backend: + backend.isclose = isclose + backend.allclose = allclose + + def _pytorch_broadcast_dimension(dimension, numpy_module) -> int: """Return one NumPy-style broadcast dimension as a non-boolean integer.""" @@ -377,7 +450,6 @@ def _patch_jax_outer_numpy_contract() -> None: import pyrecest.backend as backend # pylint: disable=import-outside-toplevel except ModuleNotFoundError: # pragma: no cover - import fails before this module return - active_jax_backend = getattr(backend, "__backend_name__", None) == "jax" try: @@ -411,7 +483,6 @@ def _patch_jax_one_hot_backend_contract() -> None: import pyrecest.backend as backend # pylint: disable=import-outside-toplevel except ModuleNotFoundError: # pragma: no cover - import fails before this module return - active_jax_backend = getattr(backend, "__backend_name__", None) == "jax" try: @@ -449,7 +520,6 @@ def _patch_jax_take_out_contract() -> None: import pyrecest.backend as backend # pylint: disable=import-outside-toplevel except ModuleNotFoundError: # pragma: no cover - import fails before this module return - active_jax_backend = getattr(backend, "__backend_name__", None) == "jax" try: @@ -502,6 +572,7 @@ def take( _patch_pytorch_tile_numpy_contract() _patch_pytorch_copy_numpy_contract() _patch_pytorch_clip_numpy_contract() +_patch_pytorch_isclose_device_contract() _patch_pytorch_broadcast_to_numpy_contract() _patch_jax_outer_numpy_contract() _patch_jax_one_hot_backend_contract() @@ -520,47 +591,9 @@ def get_backend_support( return dict(row) -def backend_support( - api_name: str, backend: str | None = None -) -> dict[str, str] | str | None: - """Alias for :func:`get_backend_support` for concise user code.""" - return get_backend_support(api_name, backend=backend) - - -def _markdown_table_cell(value: object) -> str: - escape = chr(92) + chr(124) - return str(value).replace("\r", " ").replace("\n", "
").replace(chr(124), escape) - - -def _markdown_table_row(cells: list[str]) -> str: - separator = chr(124) - return f"{separator} " + f" {separator} ".join(cells) + f" {separator}" - - -def format_backend_support_markdown() -> str: - """Render the public backend API matrix as a Markdown table.""" - lines = [ - _markdown_table_row(["API", "NumPy", "PyTorch", "JAX", "Notes"]), - _markdown_table_row(["-----", "-------", "---------", "-----", "-------"]), - ] - for api_name, row in iter_api_backend_capabilities(): - lines.append( - _markdown_table_row( - [ - f"`{_markdown_table_cell(api_name)}`", - _markdown_table_cell(row["numpy"]), - _markdown_table_cell(row["pytorch"]), - _markdown_table_cell(row["jax"]), - _markdown_table_cell(row.get("notes", "")), - ] - ) - ) - return "\n".join(lines) +def iter_backend_support(): + """Yield backend capability entries.""" + yield from iter_api_backend_capabilities() -__all__ = [ - "BACKEND_SUPPORT_LEVELS", - "backend_support", - "format_backend_support_markdown", - "get_backend_support", -] +__all__ = ["BACKEND_SUPPORT_LEVELS", "get_backend_support", "iter_backend_support"] diff --git a/tests/test_pytorch_isclose_device_contract.py b/tests/test_pytorch_isclose_device_contract.py new file mode 100644 index 000000000..0a75bf63d --- /dev/null +++ b/tests/test_pytorch_isclose_device_contract.py @@ -0,0 +1,21 @@ +import pytest + +from tests.support.backend_runner import run_backend_code + + +def test_pytorch_isclose_places_array_like_operands_on_existing_device(): + pytest.importorskip("torch") + + code = """ +import torch +import pyrecest.backend as backend +import pyrecest._backend.pytorch as raw_pytorch + +probe = torch.ones(2, device="meta") +for helper in (backend.isclose, raw_pytorch.isclose): + result = helper(probe, [1.0, 1.0]) + assert result.device.type == "meta" + assert tuple(result.shape) == (2,) +""" + result = run_backend_code("pytorch", code) + assert result.returncode == 0, result.stderr