Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions src/pyrecest/stability.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,68 @@
patch_pytorch_allclose_device_contract as _patch_pytorch_allclose_device_contract,
)


def _patch_pytorch_triangular_vector_helpers() -> None:
"""Make PyTorch triangular-vector helpers accept array-like inputs."""

try:
import pyrecest._backend.pytorch as pytorch_backend # pylint: disable=import-outside-toplevel
import pyrecest.backend as backend # pylint: disable=import-outside-toplevel
import torch as torch_module # pylint: disable=import-outside-toplevel
except ModuleNotFoundError: # pragma: no cover - PyTorch backend may be unavailable
return

originals = {
"vec_to_diag": getattr(pytorch_backend, "vec_to_diag", None),
"tril_to_vec": getattr(pytorch_backend, "tril_to_vec", None),
"triu_to_vec": getattr(pytorch_backend, "triu_to_vec", None),
}
if any(original is None for original in originals.values()):
return

if all(
getattr(original, "_pyrecest_triangular_vector_array_like_contract", False)
for original in originals.values()
):
if getattr(backend, "__backend_name__", None) == "pytorch":
for name, helper in originals.items():
setattr(backend, name, helper)
return

def vec_to_diag(vec):
values = pytorch_backend.array(vec)
return torch_module.diag_embed(values, offset=0)

def _triangular_to_vec(x, k, indices_func):
values = pytorch_backend.array(x)
rows, cols = indices_func(values.shape[-1], k=k)
rows = rows.to(device=values.device)
cols = cols.to(device=values.device)
return values[..., rows, cols]

def tril_to_vec(x, k=0):
return _triangular_to_vec(x, k, pytorch_backend.tril_indices)

def triu_to_vec(x, k=0):
return _triangular_to_vec(x, k, pytorch_backend.triu_indices)

patched = {
"vec_to_diag": vec_to_diag,
"tril_to_vec": tril_to_vec,
"triu_to_vec": triu_to_vec,
}
for name, helper in patched.items():
original = originals[name]
helper.__name__ = getattr(original, "__name__", name)
helper.__doc__ = getattr(original, "__doc__", None)
helper._pyrecest_triangular_vector_array_like_contract = True
setattr(pytorch_backend, name, helper)
if getattr(backend, "__backend_name__", None) == "pytorch":
setattr(backend, name, helper)


_patch_pytorch_allclose_device_contract()
_patch_pytorch_triangular_vector_helpers()

P = ParamSpec("P")
R = TypeVar("R")
Expand Down
Loading