From c4ad9db3e55e40a75082a1a9b39997076a66d0e2 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Wed, 1 Jul 2026 21:43:25 +0200 Subject: [PATCH 1/2] Normalize PyTorch linspace dtype aliases --- .../__init__.py | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/src/pyrecest/backend_support/_torch_dtype_promotion_contract/__init__.py b/src/pyrecest/backend_support/_torch_dtype_promotion_contract/__init__.py index c174212ff..d6be4f7fd 100644 --- a/src/pyrecest/backend_support/_torch_dtype_promotion_contract/__init__.py +++ b/src/pyrecest/backend_support/_torch_dtype_promotion_contract/__init__.py @@ -265,13 +265,19 @@ def _patch_pytorch_equality_device_contract(raw_pytorch, backend, torch) -> None setattr(backend, helper_name, wrapped_helper) -def _integer_torch_dtype(dtype, raw_pytorch, torch): - """Return an explicit integer torch dtype, or ``None`` for non-integers.""" +def _linspace_torch_dtype(dtype, raw_pytorch): + """Return a torch dtype for supported NumPy-style linspace dtype aliases.""" if dtype is None: return None try: - torch_dtype = raw_pytorch.as_dtype(dtype) + return raw_pytorch.as_dtype(dtype) except (KeyError, TypeError, ValueError): + return dtype + + +def _integer_torch_dtype(torch_dtype, torch): + """Return an explicit integer torch dtype, or ``None`` for non-integers.""" + if torch_dtype is None: return None integer_dtypes = { torch.uint8, @@ -284,7 +290,7 @@ def _integer_torch_dtype(dtype, raw_pytorch, torch): def _patch_pytorch_linspace_integer_dtype_contract(raw_pytorch, backend, torch) -> None: - """Make PyTorch linspace match NumPy flooring for explicit integer dtypes.""" + """Make PyTorch linspace match NumPy dtype-alias and integer flooring semantics.""" original_linspace = raw_pytorch.linspace if getattr(original_linspace, "_pyrecest_integer_dtype_contract", False): if getattr(backend, "__backend_name__", None) == "pytorch": @@ -292,14 +298,15 @@ def _patch_pytorch_linspace_integer_dtype_contract(raw_pytorch, backend, torch) return def linspace(start, stop, num=50, endpoint=True, dtype=None): - integer_dtype = _integer_torch_dtype(dtype, raw_pytorch, torch) + torch_dtype = _linspace_torch_dtype(dtype, raw_pytorch) + integer_dtype = _integer_torch_dtype(torch_dtype, torch) if integer_dtype is None: return original_linspace( start, stop, num=num, endpoint=endpoint, - dtype=dtype, + dtype=torch_dtype, ) values = original_linspace(start, stop, num=num, endpoint=endpoint, dtype=None) return torch.floor(values).to(dtype=integer_dtype) @@ -312,4 +319,4 @@ def linspace(start, stop, num=50, endpoint=True, dtype=None): backend.linspace = linspace -__all__ = ["patch_pytorch_dtype_promotion_contract"] +__all__ = ["patch_pytorch_dtype_promotion_contract"] \ No newline at end of file From 22eafee4d3e2fdfde1b70095fed66cd47efc68e7 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Wed, 1 Jul 2026 21:43:38 +0200 Subject: [PATCH 2/2] Add PyTorch linspace dtype regression tests --- .../test_pytorch_linspace_dtype_contract.py | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 tests/backend_support/test_pytorch_linspace_dtype_contract.py diff --git a/tests/backend_support/test_pytorch_linspace_dtype_contract.py b/tests/backend_support/test_pytorch_linspace_dtype_contract.py new file mode 100644 index 000000000..3c2cd0747 --- /dev/null +++ b/tests/backend_support/test_pytorch_linspace_dtype_contract.py @@ -0,0 +1,60 @@ +"""Regression tests for PyTorch linspace dtype normalization.""" + +from __future__ import annotations + +import importlib.util + +import pytest +from tests.support.backend_runner import run_backend_code + + +@pytest.mark.backend_portable +def test_public_pytorch_linspace_accepts_numpy_dtype_aliases(): + if importlib.util.find_spec("torch") is None: + pytest.skip("PyTorch is not installed") + + result = run_backend_code( + "pytorch", + """ +import numpy as np +import pyrecest.backend as backend + +from_numpy_type = backend.linspace(0, 1, num=3, dtype=np.float32) +from_numpy_dtype = backend.linspace(0, 1, num=3, dtype=np.dtype("float64")) +from_torch_string = backend.linspace(0, 1, num=3, dtype="torch.float64") +from_integer_alias = backend.linspace(-1.5, 1.5, num=4, dtype=np.int64) + +assert from_numpy_type.dtype == backend.float32 +assert from_numpy_dtype.dtype == backend.float64 +assert from_torch_string.dtype == backend.float64 +assert from_integer_alias.dtype == backend.int64 +assert backend.to_numpy(from_numpy_type).tolist() == [0.0, 0.5, 1.0] +assert backend.to_numpy(from_numpy_dtype).tolist() == [0.0, 0.5, 1.0] +assert backend.to_numpy(from_torch_string).tolist() == [0.0, 0.5, 1.0] +assert backend.to_numpy(from_integer_alias).tolist() == [-2, -1, 0, 1] +""", + ) + + assert result.returncode == 0, result.stderr + + +@pytest.mark.backend_portable +def test_raw_pytorch_linspace_is_patched_under_non_pytorch_backend(): + if importlib.util.find_spec("torch") is None: + pytest.skip("PyTorch is not installed") + + result = run_backend_code( + "numpy", + """ +import numpy as np +import pyrecest # noqa: F401 - triggers import-time backend compatibility patches +import pyrecest._backend.pytorch as raw_pytorch + +values = raw_pytorch.linspace(0, 1, num=3, dtype=np.float32) + +assert values.dtype == raw_pytorch.float32 +assert values.tolist() == [0.0, 0.5, 1.0] +""", + ) + + assert result.returncode == 0, result.stderr