diff --git a/src/pyrecest/backend_support/_pytorch_one_hot_scalar_contract.py b/src/pyrecest/backend_support/_pytorch_one_hot_scalar_contract.py
new file mode 100644
index 000000000..a0f23a20b
--- /dev/null
+++ b/src/pyrecest/backend_support/_pytorch_one_hot_scalar_contract.py
@@ -0,0 +1,63 @@
+"""PyTorch ``one_hot`` scalar-label compatibility hook."""
+
+from __future__ import annotations
+
+from operator import index as _operator_index
+
+
+def _as_one_hot_labels(torch_module, labels):
+    """Return integer labels without treating scalar labels as tensor sizes."""
+
+    if torch_module.is_tensor(labels):
+        if (
+            labels.dtype == torch_module.bool
+            or labels.dtype.is_floating_point
+            or labels.dtype.is_complex
+        ):
+            return labels
+        return labels.to(dtype=torch_module.long)
+
+    if isinstance(labels, bool):
+        # Preserve PyTorch's previous rejection of boolean scalar labels rather
+        # than silently interpreting ``True`` as class index 1.
+        return torch_module.LongTensor(labels)
+
+    try:
+        scalar_label = _operator_index(labels)
+    except TypeError:
+        return torch_module.LongTensor(labels)
+    return torch_module.as_tensor(scalar_label, dtype=torch_module.long)
+
+
+def patch_pytorch_one_hot_scalar_contract() -> None:
+    """Patch raw/public PyTorch ``one_hot`` to treat scalar labels as values."""
+
+    try:
+        import pyrecest._backend.pytorch as pytorch_backend  # pylint: disable=import-outside-toplevel
+        import pyrecest.backend as backend  # pylint: disable=import-outside-toplevel
+        import torch as torch_module  # pylint: disable=import-outside-toplevel
+    except ModuleNotFoundError:  # pragma: no cover - PyTorch backend may be unavailable
+        return
+
+    original_one_hot = getattr(pytorch_backend, "one_hot", None)
+    if original_one_hot is None:
+        return
+    if getattr(original_one_hot, "_pyrecest_scalar_label_contract", False):
+        if getattr(backend, "__backend_name__", None) == "pytorch":
+            backend.one_hot = original_one_hot
+        return
+
+    def one_hot(labels, num_classes):
+        labels = _as_one_hot_labels(torch_module, labels)
+        result = torch_module.nn.functional.one_hot(
+            labels,
+            _operator_index(num_classes),
+        )
+        return result.to(dtype=torch_module.uint8)
+
+    one_hot.__name__ = getattr(original_one_hot, "__name__", "one_hot")
+    one_hot.__doc__ = getattr(original_one_hot, "__doc__", None)
+    one_hot._pyrecest_scalar_label_contract = True
+    pytorch_backend.one_hot = one_hot
+    if getattr(backend, "__backend_name__", None) == "pytorch":
+        backend.one_hot = one_hot
diff --git a/src/pyrecest/stability.py b/src/pyrecest/stability.py
index aa56e78f3..1d86e18cb 100644
--- a/src/pyrecest/stability.py
+++ b/src/pyrecest/stability.py
@@ -9,8 +9,12 @@
 from pyrecest.backend_support._pytorch_allclose_device_contract import (
     patch_pytorch_allclose_device_contract as _patch_pytorch_allclose_device_contract,
 )
+from pyrecest.backend_support._pytorch_one_hot_scalar_contract import (
+    patch_pytorch_one_hot_scalar_contract as _patch_pytorch_one_hot_scalar_contract,
+)
 
 _patch_pytorch_allclose_device_contract()
+_patch_pytorch_one_hot_scalar_contract()
 
 P = ParamSpec("P")
 R = TypeVar("R")
diff --git a/tests/backend/test_pytorch_one_hot_scalar_contract.py b/tests/backend/test_pytorch_one_hot_scalar_contract.py
new file mode 100644
index 000000000..fe3564628
--- /dev/null
+++ b/tests/backend/test_pytorch_one_hot_scalar_contract.py
@@ -0,0 +1,34 @@
+import pytest
+
+import pyrecest.backend as backend
+
+
+def _to_python(value):
+    value = backend.to_numpy(value)
+    if hasattr(value, "tolist"):
+        return value.tolist()
+    return value
+
+
+def test_public_pytorch_one_hot_accepts_scalar_label():
+    if backend.__backend_name__ != "pytorch":
+        pytest.skip("PyTorch-specific one_hot scalar-label contract")
+
+    result = backend.one_hot(1, 3)
+
+    assert result.shape == (3,)
+    assert str(backend.to_numpy(result).dtype) == "uint8"
+    assert _to_python(result) == [0, 1, 0]
+
+
+def test_raw_pytorch_one_hot_accepts_scalar_label_after_package_import():
+    if backend.__backend_name__ != "pytorch":
+        pytest.skip("PyTorch-specific raw-backend one_hot scalar-label contract")
+
+    import pyrecest._backend.pytorch as pytorch_backend  # pylint: disable=import-outside-toplevel
+
+    result = pytorch_backend.one_hot(1, 3)
+
+    assert result.shape == (3,)
+    assert pytorch_backend.to_numpy(result).dtype.name == "uint8"
+    assert pytorch_backend.to_numpy(result).tolist() == [0, 1, 0]