From cd5b353068f322dc2e8731b188f08f1236fcfb5b Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Wed, 1 Jul 2026 23:22:31 +0200 Subject: [PATCH 1/3] Fix PyTorch one_hot scalar labels --- .../_pytorch_one_hot_scalar_contract.py | 63 +++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 src/pyrecest/backend_support/_pytorch_one_hot_scalar_contract.py diff --git a/src/pyrecest/backend_support/_pytorch_one_hot_scalar_contract.py b/src/pyrecest/backend_support/_pytorch_one_hot_scalar_contract.py new file mode 100644 index 000000000..a0f23a20b --- /dev/null +++ b/src/pyrecest/backend_support/_pytorch_one_hot_scalar_contract.py @@ -0,0 +1,63 @@ +"""PyTorch ``one_hot`` scalar-label compatibility hook.""" + +from __future__ import annotations + +from operator import index as _operator_index + + +def _as_one_hot_labels(torch_module, labels): + """Return integer labels without treating scalar labels as tensor sizes.""" + + if torch_module.is_tensor(labels): + if ( + labels.dtype == torch_module.bool + or labels.dtype.is_floating_point + or labels.dtype.is_complex + ): + return labels + return labels.to(dtype=torch_module.long) + + if isinstance(labels, bool): + # Preserve PyTorch's previous rejection of boolean scalar labels rather + # than silently interpreting ``True`` as class index 1. + return torch_module.LongTensor(labels) + + try: + scalar_label = _operator_index(labels) + except TypeError: + return torch_module.LongTensor(labels) + return torch_module.as_tensor(scalar_label, dtype=torch_module.long) + + +def patch_pytorch_one_hot_scalar_contract() -> None: + """Patch raw/public PyTorch ``one_hot`` to treat scalar labels as values.""" + + try: + import pyrecest._backend.pytorch as pytorch_backend # pylint: disable=import-outside-toplevel + import pyrecest.backend as backend # pylint: disable=import-outside-toplevel + import torch as torch_module # pylint: disable=import-outside-toplevel + except ModuleNotFoundError: # pragma: no cover - PyTorch backend may be unavailable + return + + original_one_hot = getattr(pytorch_backend, "one_hot", None) + if original_one_hot is None: + return + if getattr(original_one_hot, "_pyrecest_scalar_label_contract", False): + if getattr(backend, "__backend_name__", None) == "pytorch": + backend.one_hot = original_one_hot + return + + def one_hot(labels, num_classes): + labels = _as_one_hot_labels(torch_module, labels) + result = torch_module.nn.functional.one_hot( + labels, + _operator_index(num_classes), + ) + return result.to(dtype=torch_module.uint8) + + one_hot.__name__ = getattr(original_one_hot, "__name__", "one_hot") + one_hot.__doc__ = getattr(original_one_hot, "__doc__", None) + one_hot._pyrecest_scalar_label_contract = True + pytorch_backend.one_hot = one_hot + if getattr(backend, "__backend_name__", None) == "pytorch": + backend.one_hot = one_hot From 48b09cc206b5396736db3271e2b71df2d4b3efd4 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Wed, 1 Jul 2026 23:22:49 +0200 Subject: [PATCH 2/3] Load PyTorch one_hot scalar label hook --- src/pyrecest/stability.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/pyrecest/stability.py b/src/pyrecest/stability.py index aa56e78f3..1d86e18cb 100644 --- a/src/pyrecest/stability.py +++ b/src/pyrecest/stability.py @@ -9,8 +9,12 @@ from pyrecest.backend_support._pytorch_allclose_device_contract import ( patch_pytorch_allclose_device_contract as _patch_pytorch_allclose_device_contract, ) +from pyrecest.backend_support._pytorch_one_hot_scalar_contract import ( + patch_pytorch_one_hot_scalar_contract as _patch_pytorch_one_hot_scalar_contract, +) _patch_pytorch_allclose_device_contract() +_patch_pytorch_one_hot_scalar_contract() P = ParamSpec("P") R = TypeVar("R") From c6112468aa82e95254b09d9322bb244101c814a1 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Wed, 1 Jul 2026 23:23:39 +0200 Subject: [PATCH 3/3] Add PyTorch one_hot scalar label regression --- .../test_pytorch_one_hot_scalar_contract.py | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 tests/backend/test_pytorch_one_hot_scalar_contract.py diff --git a/tests/backend/test_pytorch_one_hot_scalar_contract.py b/tests/backend/test_pytorch_one_hot_scalar_contract.py new file mode 100644 index 000000000..fe3564628 --- /dev/null +++ b/tests/backend/test_pytorch_one_hot_scalar_contract.py @@ -0,0 +1,34 @@ +import pytest + +import pyrecest.backend as backend + + +def _to_python(value): + value = backend.to_numpy(value) + if hasattr(value, "tolist"): + return value.tolist() + return value + + +def test_public_pytorch_one_hot_accepts_scalar_label(): + if backend.__backend_name__ != "pytorch": + pytest.skip("PyTorch-specific one_hot scalar-label contract") + + result = backend.one_hot(1, 3) + + assert result.shape == (3,) + assert str(backend.to_numpy(result).dtype) == "uint8" + assert _to_python(result) == [0, 1, 0] + + +def test_raw_pytorch_one_hot_accepts_scalar_label_after_package_import(): + if backend.__backend_name__ != "pytorch": + pytest.skip("PyTorch-specific raw-backend one_hot scalar-label contract") + + import pyrecest._backend.pytorch as pytorch_backend # pylint: disable=import-outside-toplevel + + result = pytorch_backend.one_hot(1, 3) + + assert result.shape == (3,) + assert pytorch_backend.to_numpy(result).dtype.name == "uint8" + assert pytorch_backend.to_numpy(result).tolist() == [0, 1, 0]