From 4b00da948f83636d3ffa64094c6a127ef029048e Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Wed, 1 Jul 2026 22:08:49 +0200 Subject: [PATCH] Patch comparison device contract --- .../__init__.py | 37 +++++++++++++++---- 1 file changed, 30 insertions(+), 7 deletions(-) diff --git a/src/pyrecest/backend_support/_torch_dtype_promotion_contract/__init__.py b/src/pyrecest/backend_support/_torch_dtype_promotion_contract/__init__.py index c174212ff..d5e2854fb 100644 --- a/src/pyrecest/backend_support/_torch_dtype_promotion_contract/__init__.py +++ b/src/pyrecest/backend_support/_torch_dtype_promotion_contract/__init__.py @@ -2,19 +2,19 @@ from __future__ import annotations -import importlib.util +from importlib import import_module, util as importlib_util from pathlib import Path def _load_base_contract_module(): module_path = Path(__file__).resolve().parent.parent / "_torch_dtype_promotion_contract.py" - spec = importlib.util.spec_from_file_location( + spec = importlib_util.spec_from_file_location( "_pyrecest_torch_dtype_promotion_contract_base", module_path, ) if spec is None or spec.loader is None: raise ImportError(f"Cannot load PyTorch dtype contract module from {module_path}") - module = importlib.util.module_from_spec(spec) + module = importlib_util.module_from_spec(spec) spec.loader.exec_module(module) return module @@ -26,16 +26,17 @@ def patch_pytorch_dtype_promotion_contract() -> None: """Apply the base PyTorch contract patch plus device-placement fixes.""" _BASE_CONTRACT.patch_pytorch_dtype_promotion_contract() try: - import numpy as np # pylint: disable=import-outside-toplevel - import pyrecest._backend.pytorch as raw_pytorch # pylint: disable=import-outside-toplevel - import pyrecest.backend as backend # pylint: disable=import-outside-toplevel - import torch # pylint: disable=import-outside-toplevel + np = import_module("numpy") + raw_pytorch = import_module("pyrecest._backend.pytorch") + backend = import_module("pyrecest.backend") + torch = import_module("torch") except ModuleNotFoundError: # pragma: no cover - PyTorch backend import failed earlier return _patch_pytorch_assignment_numpy_index_contract(raw_pytorch, backend, torch, np) _patch_pytorch_logical_device_contract(raw_pytorch, backend, torch) _patch_pytorch_binary_device_contract(raw_pytorch, backend, torch) + _patch_pytorch_comparison_device_contract(raw_pytorch, backend, torch) _patch_pytorch_equality_device_contract(raw_pytorch, backend, torch) _patch_pytorch_linspace_integer_dtype_contract(raw_pytorch, backend, torch) @@ -243,6 +244,28 @@ def _patch_pytorch_binary_device_contract(raw_pytorch, backend, torch) -> None: setattr(backend, helper_name, wrapped_helper) +def _patch_pytorch_comparison_device_contract(raw_pytorch, backend, torch) -> None: + """Keep comparison helpers on an existing non-CPU tensor device.""" + helper_names = ("greater", "less", "logical_or") + if all( + getattr(getattr(raw_pytorch, helper_name, None), "_pyrecest_device_contract", False) + for helper_name in helper_names + ): + if getattr(backend, "__backend_name__", None) == "pytorch": + for helper_name in helper_names: + setattr(backend, helper_name, getattr(raw_pytorch, helper_name)) + return + + for helper_name in helper_names: + wrapped_helper = _wrap_tensor_binary_device_helper( + getattr(raw_pytorch, helper_name), + torch, + ) + setattr(raw_pytorch, helper_name, wrapped_helper) + if getattr(backend, "__backend_name__", None) == "pytorch": + setattr(backend, helper_name, wrapped_helper) + + def _patch_pytorch_equality_device_contract(raw_pytorch, backend, torch) -> None: """Keep equality-style helpers on an existing non-CPU tensor device.""" helper_names = ("equal", "less_equal", "array" + "_equal")