diff --git a/src/pyrecest/stability.py b/src/pyrecest/stability.py index aa56e78f3..ac2112a27 100644 --- a/src/pyrecest/stability.py +++ b/src/pyrecest/stability.py @@ -10,7 +10,68 @@ patch_pytorch_allclose_device_contract as _patch_pytorch_allclose_device_contract, ) + +def _patch_pytorch_triangular_vector_helpers() -> None: + """Make PyTorch triangular-vector helpers accept array-like inputs.""" + + try: + import pyrecest._backend.pytorch as pytorch_backend # pylint: disable=import-outside-toplevel + import pyrecest.backend as backend # pylint: disable=import-outside-toplevel + import torch as torch_module # pylint: disable=import-outside-toplevel + except ModuleNotFoundError: # pragma: no cover - PyTorch backend may be unavailable + return + + originals = { + "vec_to_diag": getattr(pytorch_backend, "vec_to_diag", None), + "tril_to_vec": getattr(pytorch_backend, "tril_to_vec", None), + "triu_to_vec": getattr(pytorch_backend, "triu_to_vec", None), + } + if any(original is None for original in originals.values()): + return + + if all( + getattr(original, "_pyrecest_triangular_vector_array_like_contract", False) + for original in originals.values() + ): + if getattr(backend, "__backend_name__", None) == "pytorch": + for name, helper in originals.items(): + setattr(backend, name, helper) + return + + def vec_to_diag(vec): + values = pytorch_backend.array(vec) + return torch_module.diag_embed(values, offset=0) + + def _triangular_to_vec(x, k, indices_func): + values = pytorch_backend.array(x) + rows, cols = indices_func(values.shape[-1], k=k) + rows = rows.to(device=values.device) + cols = cols.to(device=values.device) + return values[..., rows, cols] + + def tril_to_vec(x, k=0): + return _triangular_to_vec(x, k, pytorch_backend.tril_indices) + + def triu_to_vec(x, k=0): + return _triangular_to_vec(x, k, pytorch_backend.triu_indices) + + patched = { + "vec_to_diag": vec_to_diag, + "tril_to_vec": tril_to_vec, + "triu_to_vec": triu_to_vec, + } + for name, helper in patched.items(): + original = originals[name] + helper.__name__ = getattr(original, "__name__", name) + helper.__doc__ = getattr(original, "__doc__", None) + helper._pyrecest_triangular_vector_array_like_contract = True + setattr(pytorch_backend, name, helper) + if getattr(backend, "__backend_name__", None) == "pytorch": + setattr(backend, name, helper) + + _patch_pytorch_allclose_device_contract() +_patch_pytorch_triangular_vector_helpers() P = ParamSpec("P") R = TypeVar("R")