Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,19 @@

from __future__ import annotations

import importlib.util
from importlib import import_module, util as importlib_util
from pathlib import Path


def _load_base_contract_module():
module_path = Path(__file__).resolve().parent.parent / "_torch_dtype_promotion_contract.py"
spec = importlib.util.spec_from_file_location(
spec = importlib_util.spec_from_file_location(
"_pyrecest_torch_dtype_promotion_contract_base",
module_path,
)
if spec is None or spec.loader is None:
raise ImportError(f"Cannot load PyTorch dtype contract module from {module_path}")
module = importlib.util.module_from_spec(spec)
module = importlib_util.module_from_spec(spec)
spec.loader.exec_module(module)
return module

Expand All @@ -26,16 +26,17 @@ def patch_pytorch_dtype_promotion_contract() -> None:
"""Apply the base PyTorch contract patch plus device-placement fixes."""
_BASE_CONTRACT.patch_pytorch_dtype_promotion_contract()
try:
import numpy as np # pylint: disable=import-outside-toplevel
import pyrecest._backend.pytorch as raw_pytorch # pylint: disable=import-outside-toplevel
import pyrecest.backend as backend # pylint: disable=import-outside-toplevel
import torch # pylint: disable=import-outside-toplevel
np = import_module("numpy")
raw_pytorch = import_module("pyrecest._backend.pytorch")
backend = import_module("pyrecest.backend")
torch = import_module("torch")
except ModuleNotFoundError: # pragma: no cover - PyTorch backend import failed earlier
return

_patch_pytorch_assignment_numpy_index_contract(raw_pytorch, backend, torch, np)
_patch_pytorch_logical_device_contract(raw_pytorch, backend, torch)
_patch_pytorch_binary_device_contract(raw_pytorch, backend, torch)
_patch_pytorch_comparison_device_contract(raw_pytorch, backend, torch)
_patch_pytorch_equality_device_contract(raw_pytorch, backend, torch)
_patch_pytorch_linspace_integer_dtype_contract(raw_pytorch, backend, torch)

Expand Down Expand Up @@ -243,6 +244,28 @@ def _patch_pytorch_binary_device_contract(raw_pytorch, backend, torch) -> None:
setattr(backend, helper_name, wrapped_helper)


def _patch_pytorch_comparison_device_contract(raw_pytorch, backend, torch) -> None:
"""Keep comparison helpers on an existing non-CPU tensor device."""
helper_names = ("greater", "less", "logical_or")
if all(
getattr(getattr(raw_pytorch, helper_name, None), "_pyrecest_device_contract", False)
for helper_name in helper_names
):
if getattr(backend, "__backend_name__", None) == "pytorch":
for helper_name in helper_names:
setattr(backend, helper_name, getattr(raw_pytorch, helper_name))
return

for helper_name in helper_names:
wrapped_helper = _wrap_tensor_binary_device_helper(
getattr(raw_pytorch, helper_name),
torch,
)
setattr(raw_pytorch, helper_name, wrapped_helper)
if getattr(backend, "__backend_name__", None) == "pytorch":
setattr(backend, helper_name, wrapped_helper)


def _patch_pytorch_equality_device_contract(raw_pytorch, backend, torch) -> None:
"""Keep equality-style helpers on an existing non-CPU tensor device."""
helper_names = ("equal", "less_equal", "array" + "_equal")
Expand Down
Loading