Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import importlib.util
from operator import index as _operator_index
from pathlib import Path


Expand All @@ -15,7 +16,7 @@ def _load_base_contract_module():
if spec is None or spec.loader is None:
raise ImportError(f"Cannot load PyTorch dtype contract module from {module_path}")
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
getattr(spec.loader, "exec_" + "module")(module)
return module


Expand All @@ -39,6 +40,7 @@ def patch_pytorch_dtype_promotion_contract() -> None:
_patch_pytorch_equality_device_contract(raw_pytorch, backend, torch)
_patch_pytorch_linspace_integer_dtype_contract(raw_pytorch, backend, torch)
_patch_pytorch_arraylike_helper_contract(raw_pytorch, backend, torch)
_patch_pytorch_concatenate_axis_none_contract(raw_pytorch, backend, torch)


def _pytorch_numpy_index_array(index, numpy_module, torch_module):
Expand Down Expand Up @@ -71,7 +73,11 @@ def _patch_pytorch_assignment_numpy_index_contract(raw_pytorch, backend, torch,
"""Make PyTorch assignment helpers accept NumPy integer and boolean indices."""
helper_names = ("assignment", "assignment_by_sum")
if all(
getattr(getattr(raw_pytorch, helper_name, None), "_pyrecest_numpy_index_contract", False)
getattr(
getattr(raw_pytorch, helper_name, None),
"_pyrecest_numpy_index_contract",
False,
)
for helper_name in helper_names
):
if getattr(backend, "__backend_name__", None) == "pytorch":
Expand Down Expand Up @@ -127,7 +133,6 @@ def _patch_pytorch_logical_device_contract(raw_pytorch, backend, torch) -> None:
for helper_name in helper_names:
setattr(backend, helper_name, getattr(raw_pytorch, helper_name))
return

original_logical_and = raw_pytorch.logical_and
original_where = raw_pytorch.where

Expand All @@ -146,12 +151,10 @@ def where(condition, x=None, y=None):
device=device,
dtype=torch.bool,
)

if x is None and y is None:
return torch.where(condition)
if x is None or y is None:
raise ValueError("either both or neither of x and y should be given")

x = _as_pytorch_tensor_on_device(x, torch, device=device)
y = _as_pytorch_tensor_on_device(y, torch, device=device)
result_dtype = torch.result_type(x, y)
Expand Down Expand Up @@ -399,4 +402,30 @@ def _patch_pytorch_arraylike_helper_contract(raw_pytorch, backend, torch) -> Non
backend.argsort = wrapped_argsort


def _patch_pytorch_concatenate_axis_none_contract(raw_pytorch, backend, torch) -> None:
"""Make PyTorch concatenate flatten inputs when ``axis=None`` like NumPy."""
original_concatenate = raw_pytorch.concatenate
if getattr(original_concatenate, "_pyrecest_axis_none_contract", False):
if getattr(backend, "__backend_name__", None) == "pytorch":
backend.concatenate = original_concatenate
return

def concatenate(seq, axis=0, out=None):
tensors = [raw_pytorch.array(item) for item in seq]
if axis is None:
tensors = [tensor.reshape(-1) for tensor in tensors]
axis_arg = 0
else:
axis_arg = _operator_index(axis)
tensors = raw_pytorch.convert_to_wider_dtype(tensors)
return torch.cat(tensors, dim=axis_arg, out=out)

concatenate.__name__ = getattr(original_concatenate, "__name__", "concatenate")
concatenate.__doc__ = getattr(original_concatenate, "__doc__", None)
concatenate._pyrecest_axis_none_contract = True
raw_pytorch.concatenate = concatenate
if getattr(backend, "__backend_name__", None) == "pytorch":
backend.concatenate = concatenate


__all__ = ["patch_pytorch_dtype_promotion_contract"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import numpy as np
import pyrecest.backend as backend
import pytest


def _as_numpy(value):
return backend.to_numpy(value)


def test_pytorch_concatenate_axis_none_flattens_inputs():
if backend.__backend_name__ != "pytorch":
pytest.skip("PyTorch-specific concatenate contract")

first = backend.asarray([[1, 2], [3, 4]])
second = backend.asarray([[5], [6]])

actual = _as_numpy(backend.concatenate((first, second), axis=None))
expected = np.concatenate((_as_numpy(first), _as_numpy(second)), axis=None)

assert actual.shape == expected.shape
assert np.array_equal(actual, expected)


def test_raw_pytorch_concatenate_axis_none_is_patched_under_numpy_backend():
import pyrecest._backend.pytorch as raw_pytorch

torch = pytest.importorskip("torch")
first = torch.tensor([[1, 2], [3, 4]])
second = torch.tensor([[5], [6]])

actual = raw_pytorch.to_numpy(raw_pytorch.concatenate((first, second), axis=None))
expected = np.concatenate((first.numpy(), second.numpy()), axis=None)

assert actual.shape == expected.shape
assert np.array_equal(actual, expected)
Loading