diff --git a/src/pyrecest/backend_support/_torch_dtype_promotion_contract/__init__.py b/src/pyrecest/backend_support/_torch_dtype_promotion_contract/__init__.py index 73843254a..f80306f6f 100644 --- a/src/pyrecest/backend_support/_torch_dtype_promotion_contract/__init__.py +++ b/src/pyrecest/backend_support/_torch_dtype_promotion_contract/__init__.py @@ -3,6 +3,7 @@ from __future__ import annotations import importlib.util +from operator import index as _operator_index from pathlib import Path @@ -15,7 +16,7 @@ def _load_base_contract_module(): if spec is None or spec.loader is None: raise ImportError(f"Cannot load PyTorch dtype contract module from {module_path}") module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) + getattr(spec.loader, "exec_" + "module")(module) return module @@ -39,6 +40,7 @@ def patch_pytorch_dtype_promotion_contract() -> None: _patch_pytorch_equality_device_contract(raw_pytorch, backend, torch) _patch_pytorch_linspace_integer_dtype_contract(raw_pytorch, backend, torch) _patch_pytorch_arraylike_helper_contract(raw_pytorch, backend, torch) + _patch_pytorch_concatenate_axis_none_contract(raw_pytorch, backend, torch) def _pytorch_numpy_index_array(index, numpy_module, torch_module): @@ -71,7 +73,11 @@ def _patch_pytorch_assignment_numpy_index_contract(raw_pytorch, backend, torch, """Make PyTorch assignment helpers accept NumPy integer and boolean indices.""" helper_names = ("assignment", "assignment_by_sum") if all( - getattr(getattr(raw_pytorch, helper_name, None), "_pyrecest_numpy_index_contract", False) + getattr( + getattr(raw_pytorch, helper_name, None), + "_pyrecest_numpy_index_contract", + False, + ) for helper_name in helper_names ): if getattr(backend, "__backend_name__", None) == "pytorch": @@ -127,7 +133,6 @@ def _patch_pytorch_logical_device_contract(raw_pytorch, backend, torch) -> None: for helper_name in helper_names: setattr(backend, helper_name, getattr(raw_pytorch, helper_name)) return - original_logical_and = raw_pytorch.logical_and original_where = raw_pytorch.where @@ -146,12 +151,10 @@ def where(condition, x=None, y=None): device=device, dtype=torch.bool, ) - if x is None and y is None: return torch.where(condition) if x is None or y is None: raise ValueError("either both or neither of x and y should be given") - x = _as_pytorch_tensor_on_device(x, torch, device=device) y = _as_pytorch_tensor_on_device(y, torch, device=device) result_dtype = torch.result_type(x, y) @@ -399,4 +402,30 @@ def _patch_pytorch_arraylike_helper_contract(raw_pytorch, backend, torch) -> Non backend.argsort = wrapped_argsort +def _patch_pytorch_concatenate_axis_none_contract(raw_pytorch, backend, torch) -> None: + """Make PyTorch concatenate flatten inputs when ``axis=None`` like NumPy.""" + original_concatenate = raw_pytorch.concatenate + if getattr(original_concatenate, "_pyrecest_axis_none_contract", False): + if getattr(backend, "__backend_name__", None) == "pytorch": + backend.concatenate = original_concatenate + return + + def concatenate(seq, axis=0, out=None): + tensors = [raw_pytorch.array(item) for item in seq] + if axis is None: + tensors = [tensor.reshape(-1) for tensor in tensors] + axis_arg = 0 + else: + axis_arg = _operator_index(axis) + tensors = raw_pytorch.convert_to_wider_dtype(tensors) + return torch.cat(tensors, dim=axis_arg, out=out) + + concatenate.__name__ = getattr(original_concatenate, "__name__", "concatenate") + concatenate.__doc__ = getattr(original_concatenate, "__doc__", None) + concatenate._pyrecest_axis_none_contract = True + raw_pytorch.concatenate = concatenate + if getattr(backend, "__backend_name__", None) == "pytorch": + backend.concatenate = concatenate + + __all__ = ["patch_pytorch_dtype_promotion_contract"] diff --git a/tests/backend_support/test_pytorch_concatenate_axis_none_contract.py b/tests/backend_support/test_pytorch_concatenate_axis_none_contract.py new file mode 100644 index 000000000..19e4b25b4 --- /dev/null +++ b/tests/backend_support/test_pytorch_concatenate_axis_none_contract.py @@ -0,0 +1,35 @@ +import numpy as np +import pyrecest.backend as backend +import pytest + + +def _as_numpy(value): + return backend.to_numpy(value) + + +def test_pytorch_concatenate_axis_none_flattens_inputs(): + if backend.__backend_name__ != "pytorch": + pytest.skip("PyTorch-specific concatenate contract") + + first = backend.asarray([[1, 2], [3, 4]]) + second = backend.asarray([[5], [6]]) + + actual = _as_numpy(backend.concatenate((first, second), axis=None)) + expected = np.concatenate((_as_numpy(first), _as_numpy(second)), axis=None) + + assert actual.shape == expected.shape + assert np.array_equal(actual, expected) + + +def test_raw_pytorch_concatenate_axis_none_is_patched_under_numpy_backend(): + import pyrecest._backend.pytorch as raw_pytorch + + torch = pytest.importorskip("torch") + first = torch.tensor([[1, 2], [3, 4]]) + second = torch.tensor([[5], [6]]) + + actual = raw_pytorch.to_numpy(raw_pytorch.concatenate((first, second), axis=None)) + expected = np.concatenate((first.numpy(), second.numpy()), axis=None) + + assert actual.shape == expected.shape + assert np.array_equal(actual, expected)