diff --git a/src/pyrecest/_backend/pytorch/_dtype.py b/src/pyrecest/_backend/pytorch/_dtype.py index fdd0fdea1..caa8f4343 100644 --- a/src/pyrecest/_backend/pytorch/_dtype.py +++ b/src/pyrecest/_backend/pytorch/_dtype.py @@ -107,6 +107,16 @@ def _dtype_as_str(dtype): return str(dtype).split(".")[-1] +def _preferred_tensor_device(*values): + tensor_devices = [value.device for value in values if _torch.is_tensor(value)] + if not tensor_devices: + return None + return next( + (device for device in tensor_devices if device.type != "cpu"), + tensor_devices[0], + ) + + def set_default_dtype(value): """Set backend default dtype. @@ -281,10 +291,15 @@ def _box_binary_scalar(target=None, box_x1=True, box_x2=True): def _decorator(func): @functools.wraps(func) def _wrapped(x1, x2, *args, **kwargs): + device = _preferred_tensor_device(x1, x2) if box_x1 and not _torch.is_tensor(x1): - x1 = _torch.tensor(x1) + x1 = _torch.tensor(x1, device=device) + elif device is not None and _torch.is_tensor(x1) and x1.device != device: + x1 = x1.to(device=device) if box_x2 and not _torch.is_tensor(x2): - x2 = _torch.tensor(x2) + x2 = _torch.tensor(x2, device=device) + elif device is not None and _torch.is_tensor(x2) and x2.device != device: + x2 = x2.to(device=device) return func(x1, x2, *args, **kwargs) diff --git a/tests/backend_support/test_pytorch_numpy_view_conversion.py b/tests/backend_support/test_pytorch_numpy_view_conversion.py index e063f500a..d56b476d8 100644 --- a/tests/backend_support/test_pytorch_numpy_view_conversion.py +++ b/tests/backend_support/test_pytorch_numpy_view_conversion.py @@ -93,3 +93,26 @@ def test_pytorch_array_copies_numpy_inputs(): ) assert result.returncode == 0, result.stderr + + +@pytest.mark.backend_portable +def test_pytorch_boxed_binary_scalar_prefers_existing_cuda_tensor_device(): + torch = pytest.importorskip("torch") + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + result = run_backend_code( + "pytorch", + """ +import pyrecest.backend as backend +import torch + +tensor_operand = torch.ones(2, device="cuda") +result = backend.arctan2([1.0, 2.0], tensor_operand) + +assert result.device.type == "cuda" +assert tuple(result.shape) == (2,) +""", + ) + + assert result.returncode == 0, result.stderr