diff --git a/src/xarray_einstats/__init__.py b/src/xarray_einstats/__init__.py index 1b6da71..09b18dc 100644 --- a/src/xarray_einstats/__init__.py +++ b/src/xarray_einstats/__init__.py @@ -1,6 +1,8 @@ """Stats, linear algebra and einops for xarray.""" from __future__ import annotations +from contextlib import contextmanager +from collections.abc import Iterable import numpy as np import xarray as xr @@ -9,6 +11,7 @@ from .accessors import LinAlgAccessor, EinopsAccessor __all__ = [ + "default_linalg_dims", "einsum", "einsum_path", "matmul", @@ -188,3 +191,38 @@ def ones_ref(*args, dims, dtype=None): empty_ref, zeros_ref """ return _create_ref(*args, dims=dims, np_creator=np.ones, dtype=dtype) + + +@contextmanager +def default_linalg_dims(func_or_dims): + """Context manager to temporarily set the default dimensions for linalg functions. + + Safer alternative to monkey patching the `get_default_dims` function in `linalg` module, + as it ensures that the original function is restored even if an error occurs within the context. + + Parameters + ---------- + func_or_dims : callable or iterable + If a callable is provided, it should take the same arguments as `get_default_dims` + and return the default dimensions based on those arguments. + If an iterable is provided, it will be used as the default dimensions + regardless of the input arguments. + + Yields + ------ + None + """ + from xarray_einstats import linalg + + original_get_default_dims = linalg.get_default_dims + + def func(*args): + if isinstance(func_or_dims, Iterable): + return func_or_dims + return func_or_dims(*args) + + linalg.get_default_dims = func + try: + yield + finally: + linalg.get_default_dims = original_get_default_dims diff --git a/src/xarray_einstats/__init__.pyi b/src/xarray_einstats/__init__.pyi index 870c3c1..0645295 100644 --- a/src/xarray_einstats/__init__.pyi +++ b/src/xarray_einstats/__init__.pyi @@ -3,6 +3,8 @@ from __future__ import annotations from collections.abc import Hashable, Iterable, Sequence +from contextlib import contextmanager +from typing import Any, Callable, Generator import numpy as np import xarray @@ -13,6 +15,7 @@ from .accessors import EinopsAccessor, LinAlgAccessor from .linalg import einsum, einsum_path, matmul __all__ = [ + "default_linalg_dims", "einsum", "einsum_path", "matmul", @@ -52,3 +55,7 @@ def ones_ref( dims: Sequence[Hashable], dtype: np.typing.DTypeLike | None = ..., ) -> xarray.DataArray: ... +@contextmanager +def default_linalg_dims( + func_or_dims: Callable | Iterable, +) -> Generator[None, Any, None]: ...