diff --git a/src/pyrecest/backend_support/_torch_dtype_promotion_contract/__init__.py b/src/pyrecest/backend_support/_torch_dtype_promotion_contract/__init__.py index c174212ff..c27c95e5d 100644 --- a/src/pyrecest/backend_support/_torch_dtype_promotion_contract/__init__.py +++ b/src/pyrecest/backend_support/_torch_dtype_promotion_contract/__init__.py @@ -37,6 +37,7 @@ def patch_pytorch_dtype_promotion_contract() -> None: _patch_pytorch_logical_device_contract(raw_pytorch, backend, torch) _patch_pytorch_binary_device_contract(raw_pytorch, backend, torch) _patch_pytorch_equality_device_contract(raw_pytorch, backend, torch) + _patch_pytorch_matmul_device_contract(raw_pytorch, backend, torch) _patch_pytorch_linspace_integer_dtype_contract(raw_pytorch, backend, torch) @@ -265,6 +266,32 @@ def _patch_pytorch_equality_device_contract(raw_pytorch, backend, torch) -> None setattr(backend, helper_name, wrapped_helper) +def _patch_pytorch_matmul_device_contract(raw_pytorch, backend, torch) -> None: + """Keep matmul operands on an existing non-CPU tensor device.""" + original_matmul = raw_pytorch.matmul + if getattr(original_matmul, "_pyrecest_device_contract", False): + if getattr(backend, "__backend_name__", None) == "pytorch": + backend.matmul = original_matmul + return + + def matmul(x, y, out=None): + device = _preferred_pytorch_device(torch, x, y) + x = raw_pytorch.array(x) + y = raw_pytorch.array(y) + if device is not None: + x = x.to(device=device) + y = y.to(device=device) + x, y = raw_pytorch.convert_to_wider_dtype([x, y]) + return torch.matmul(x, y, out=out) + + matmul.__name__ = getattr(original_matmul, "__name__", "matmul") + matmul.__doc__ = getattr(original_matmul, "__doc__", None) + matmul._pyrecest_device_contract = True + raw_pytorch.matmul = matmul + if getattr(backend, "__backend_name__", None) == "pytorch": + backend.matmul = matmul + + def _integer_torch_dtype(dtype, raw_pytorch, torch): """Return an explicit integer torch dtype, or ``None`` for non-integers.""" if dtype is None: