Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions src/pyrecest/_backend/pytorch/_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,16 @@ def _dtype_as_str(dtype):
return str(dtype).split(".")[-1]


def _preferred_tensor_device(*values):
tensor_devices = [value.device for value in values if _torch.is_tensor(value)]
if not tensor_devices:
return None
return next(
(device for device in tensor_devices if device.type != "cpu"),
tensor_devices[0],
)


def set_default_dtype(value):
"""Set backend default dtype.

Expand Down Expand Up @@ -281,10 +291,15 @@ def _box_binary_scalar(target=None, box_x1=True, box_x2=True):
def _decorator(func):
@functools.wraps(func)
def _wrapped(x1, x2, *args, **kwargs):
device = _preferred_tensor_device(x1, x2)
if box_x1 and not _torch.is_tensor(x1):
x1 = _torch.tensor(x1)
x1 = _torch.tensor(x1, device=device)
elif device is not None and _torch.is_tensor(x1) and x1.device != device:
x1 = x1.to(device=device)
if box_x2 and not _torch.is_tensor(x2):
x2 = _torch.tensor(x2)
x2 = _torch.tensor(x2, device=device)
elif device is not None and _torch.is_tensor(x2) and x2.device != device:
x2 = x2.to(device=device)

return func(x1, x2, *args, **kwargs)

Expand Down
23 changes: 23 additions & 0 deletions tests/backend_support/test_pytorch_numpy_view_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,26 @@ def test_pytorch_array_copies_numpy_inputs():
)

assert result.returncode == 0, result.stderr


@pytest.mark.backend_portable
def test_pytorch_boxed_binary_scalar_prefers_existing_cuda_tensor_device():
torch = pytest.importorskip("torch")
if not torch.cuda.is_available():
pytest.skip("CUDA is not available")

result = run_backend_code(
"pytorch",
"""
import pyrecest.backend as backend
import torch

tensor_operand = torch.ones(2, device="cuda")
result = backend.arctan2([1.0, 2.0], tensor_operand)

assert result.device.type == "cuda"
assert tuple(result.shape) == (2,)
""",
)

assert result.returncode == 0, result.stderr
Loading