Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions src/pyrecest/backend_support/_pytorch_one_hot_scalar_contract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""PyTorch ``one_hot`` scalar-label compatibility hook."""

from __future__ import annotations

from operator import index as _operator_index


def _as_one_hot_labels(torch_module, labels):
"""Return integer labels without treating scalar labels as tensor sizes."""

if torch_module.is_tensor(labels):
if (
labels.dtype == torch_module.bool
or labels.dtype.is_floating_point
or labels.dtype.is_complex
):
return labels
return labels.to(dtype=torch_module.long)

if isinstance(labels, bool):
# Preserve PyTorch's previous rejection of boolean scalar labels rather
# than silently interpreting ``True`` as class index 1.
return torch_module.LongTensor(labels)

try:
scalar_label = _operator_index(labels)
except TypeError:
return torch_module.LongTensor(labels)
return torch_module.as_tensor(scalar_label, dtype=torch_module.long)


def patch_pytorch_one_hot_scalar_contract() -> None:
"""Patch raw/public PyTorch ``one_hot`` to treat scalar labels as values."""

try:
import pyrecest._backend.pytorch as pytorch_backend # pylint: disable=import-outside-toplevel
import pyrecest.backend as backend # pylint: disable=import-outside-toplevel
import torch as torch_module # pylint: disable=import-outside-toplevel
except ModuleNotFoundError: # pragma: no cover - PyTorch backend may be unavailable
return

original_one_hot = getattr(pytorch_backend, "one_hot", None)
if original_one_hot is None:
return
if getattr(original_one_hot, "_pyrecest_scalar_label_contract", False):
if getattr(backend, "__backend_name__", None) == "pytorch":
backend.one_hot = original_one_hot
return

def one_hot(labels, num_classes):
labels = _as_one_hot_labels(torch_module, labels)
result = torch_module.nn.functional.one_hot(
labels,
_operator_index(num_classes),
)
return result.to(dtype=torch_module.uint8)

one_hot.__name__ = getattr(original_one_hot, "__name__", "one_hot")
one_hot.__doc__ = getattr(original_one_hot, "__doc__", None)
one_hot._pyrecest_scalar_label_contract = True
pytorch_backend.one_hot = one_hot
if getattr(backend, "__backend_name__", None) == "pytorch":
backend.one_hot = one_hot
4 changes: 4 additions & 0 deletions src/pyrecest/stability.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,12 @@
from pyrecest.backend_support._pytorch_allclose_device_contract import (
patch_pytorch_allclose_device_contract as _patch_pytorch_allclose_device_contract,
)
from pyrecest.backend_support._pytorch_one_hot_scalar_contract import (
patch_pytorch_one_hot_scalar_contract as _patch_pytorch_one_hot_scalar_contract,
)

_patch_pytorch_allclose_device_contract()
_patch_pytorch_one_hot_scalar_contract()

P = ParamSpec("P")
R = TypeVar("R")
Expand Down
34 changes: 34 additions & 0 deletions tests/backend/test_pytorch_one_hot_scalar_contract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import pytest

import pyrecest.backend as backend


def _to_python(value):
value = backend.to_numpy(value)
if hasattr(value, "tolist"):
return value.tolist()
return value


def test_public_pytorch_one_hot_accepts_scalar_label():
if backend.__backend_name__ != "pytorch":
pytest.skip("PyTorch-specific one_hot scalar-label contract")

result = backend.one_hot(1, 3)

assert result.shape == (3,)
assert str(backend.to_numpy(result).dtype) == "uint8"
assert _to_python(result) == [0, 1, 0]


def test_raw_pytorch_one_hot_accepts_scalar_label_after_package_import():
if backend.__backend_name__ != "pytorch":
pytest.skip("PyTorch-specific raw-backend one_hot scalar-label contract")

import pyrecest._backend.pytorch as pytorch_backend # pylint: disable=import-outside-toplevel

result = pytorch_backend.one_hot(1, 3)

assert result.shape == (3,)
assert pytorch_backend.to_numpy(result).dtype.name == "uint8"
assert pytorch_backend.to_numpy(result).tolist() == [0, 1, 0]
Loading