From 1a5f808c1e431f4350d000a9321bae2f0dddbb8e Mon Sep 17 00:00:00 2001 From: krokosik Date: Wed, 15 Apr 2026 13:28:18 +0200 Subject: [PATCH 1/6] Add a context manager for default dims --- src/xarray_einstats/__init__.py | 12 ++++++++++++ src/xarray_einstats/linalg.py | 2 ++ 2 files changed, 14 insertions(+) diff --git a/src/xarray_einstats/__init__.py b/src/xarray_einstats/__init__.py index 1b6da71..08f096e 100644 --- a/src/xarray_einstats/__init__.py +++ b/src/xarray_einstats/__init__.py @@ -1,6 +1,7 @@ """Stats, linear algebra and einops for xarray.""" from __future__ import annotations +from contextlib import contextmanager import numpy as np import xarray as xr @@ -188,3 +189,14 @@ def ones_ref(*args, dims, dtype=None): empty_ref, zeros_ref """ return _create_ref(*args, dims=dims, np_creator=np.ones, dtype=dtype) + + +@contextmanager +def default_linalg_dims(func: callable): + original_get_default_dims = linalg.get_default_dims + + linalg.get_default_dims = func + try: + yield + finally: + linalg.get_default_dims = original_get_default_dims diff --git a/src/xarray_einstats/linalg.py b/src/xarray_einstats/linalg.py index 0bead5b..a29823d 100644 --- a/src/xarray_einstats/linalg.py +++ b/src/xarray_einstats/linalg.py @@ -12,6 +12,8 @@ """ +from contextlib import contextmanager + import numpy as np import xarray as xr From d0e0fdf76d1e42e363ce74e017cc2d1a97055aef Mon Sep 17 00:00:00 2001 From: krokosik Date: Wed, 15 Apr 2026 13:34:21 +0200 Subject: [PATCH 2/6] Add a docstring and fix lint issues --- src/xarray_einstats/__init__.py | 26 +++++++++++++++++++++++++- src/xarray_einstats/linalg.py | 2 -- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/src/xarray_einstats/__init__.py b/src/xarray_einstats/__init__.py index 08f096e..3f953e5 100644 --- a/src/xarray_einstats/__init__.py +++ b/src/xarray_einstats/__init__.py @@ -192,9 +192,33 @@ def ones_ref(*args, dims, dtype=None): @contextmanager -def default_linalg_dims(func: callable): +def default_linalg_dims(func_or_dims: callable | list): + """Context manager to temporarily set the default dimensions for linalg functions. + + Safer alternative to monkey patching the `get_default_dims` function in `linalg` module, + as it ensures that the original function is restored even if an error occurs within the context. + + Parameters + ---------- + func_or_dims : callable or list + If a callable is provided, it should take the same arguments as `get_default_dims` + and return the default dimensions based on those arguments. + If a list is provided, it will be used as the default dimensions + regardless of the input arguments. + + Yields + ------ + None + """ + from xarray_einstats import linalg + original_get_default_dims = linalg.get_default_dims + def func(*args): + if isinstance(func_or_dims, list): + return func_or_dims + return func_or_dims(*args) + linalg.get_default_dims = func try: yield diff --git a/src/xarray_einstats/linalg.py b/src/xarray_einstats/linalg.py index a29823d..0bead5b 100644 --- a/src/xarray_einstats/linalg.py +++ b/src/xarray_einstats/linalg.py @@ -12,8 +12,6 @@ """ -from contextlib import contextmanager - import numpy as np import xarray as xr From 3d338bb152a70940f79bb9e45f025328f789a691 Mon Sep 17 00:00:00 2001 From: krokosik Date: Wed, 15 Apr 2026 13:48:20 +0200 Subject: [PATCH 3/6] Add default linalg dims to package exports --- src/xarray_einstats/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/xarray_einstats/__init__.py b/src/xarray_einstats/__init__.py index 3f953e5..955767f 100644 --- a/src/xarray_einstats/__init__.py +++ b/src/xarray_einstats/__init__.py @@ -10,6 +10,7 @@ from .accessors import LinAlgAccessor, EinopsAccessor __all__ = [ + "default_linalg_dims", "einsum", "einsum_path", "matmul", From 39b4bf96d13beea72e65d3ddce1453e1537cdbc9 Mon Sep 17 00:00:00 2001 From: krokosik Date: Wed, 15 Apr 2026 13:54:10 +0200 Subject: [PATCH 4/6] add type hints --- src/xarray_einstats/__init__.pyi | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/xarray_einstats/__init__.pyi b/src/xarray_einstats/__init__.pyi index 870c3c1..55192d9 100644 --- a/src/xarray_einstats/__init__.pyi +++ b/src/xarray_einstats/__init__.pyi @@ -13,6 +13,7 @@ from .accessors import EinopsAccessor, LinAlgAccessor from .linalg import einsum, einsum_path, matmul __all__ = [ + "default_linalg_dims", "einsum", "einsum_path", "matmul", @@ -52,3 +53,4 @@ def ones_ref( dims: Sequence[Hashable], dtype: np.typing.DTypeLike | None = ..., ) -> xarray.DataArray: ... +def default_linalg_dims(func_or_dims: callable | list[Unknown]) -> Generator[None, Any, None]: ... From 96fd80109f6db9e8bebeb4b53fd28c291a01f1f0 Mon Sep 17 00:00:00 2001 From: krokosik Date: Wed, 15 Apr 2026 14:47:07 +0200 Subject: [PATCH 5/6] Handle all iterables --- src/xarray_einstats/__init__.py | 9 +++++---- src/xarray_einstats/__init__.pyi | 5 ++++- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/xarray_einstats/__init__.py b/src/xarray_einstats/__init__.py index 955767f..09b18dc 100644 --- a/src/xarray_einstats/__init__.py +++ b/src/xarray_einstats/__init__.py @@ -2,6 +2,7 @@ from __future__ import annotations from contextlib import contextmanager +from collections.abc import Iterable import numpy as np import xarray as xr @@ -193,7 +194,7 @@ def ones_ref(*args, dims, dtype=None): @contextmanager -def default_linalg_dims(func_or_dims: callable | list): +def default_linalg_dims(func_or_dims): """Context manager to temporarily set the default dimensions for linalg functions. Safer alternative to monkey patching the `get_default_dims` function in `linalg` module, @@ -201,10 +202,10 @@ def default_linalg_dims(func_or_dims: callable | list): Parameters ---------- - func_or_dims : callable or list + func_or_dims : callable or iterable If a callable is provided, it should take the same arguments as `get_default_dims` and return the default dimensions based on those arguments. - If a list is provided, it will be used as the default dimensions + If an iterable is provided, it will be used as the default dimensions regardless of the input arguments. Yields @@ -216,7 +217,7 @@ def default_linalg_dims(func_or_dims: callable | list): original_get_default_dims = linalg.get_default_dims def func(*args): - if isinstance(func_or_dims, list): + if isinstance(func_or_dims, Iterable): return func_or_dims return func_or_dims(*args) diff --git a/src/xarray_einstats/__init__.pyi b/src/xarray_einstats/__init__.pyi index 55192d9..f6e6646 100644 --- a/src/xarray_einstats/__init__.pyi +++ b/src/xarray_einstats/__init__.pyi @@ -3,6 +3,7 @@ from __future__ import annotations from collections.abc import Hashable, Iterable, Sequence +from typing import Any, Callable, Generator import numpy as np import xarray @@ -53,4 +54,6 @@ def ones_ref( dims: Sequence[Hashable], dtype: np.typing.DTypeLike | None = ..., ) -> xarray.DataArray: ... -def default_linalg_dims(func_or_dims: callable | list[Unknown]) -> Generator[None, Any, None]: ... +def default_linalg_dims( + func_or_dims: Callable | Iterable, +) -> Generator[None, Any, None]: ... From 21a607c02defec522173bee090415b7ded5bb9dd Mon Sep 17 00:00:00 2001 From: krokosik Date: Wed, 15 Apr 2026 14:54:20 +0200 Subject: [PATCH 6/6] Add missing contextmanager wrapper --- src/xarray_einstats/__init__.pyi | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/xarray_einstats/__init__.pyi b/src/xarray_einstats/__init__.pyi index f6e6646..0645295 100644 --- a/src/xarray_einstats/__init__.pyi +++ b/src/xarray_einstats/__init__.pyi @@ -3,6 +3,7 @@ from __future__ import annotations from collections.abc import Hashable, Iterable, Sequence +from contextlib import contextmanager from typing import Any, Callable, Generator import numpy as np @@ -54,6 +55,7 @@ def ones_ref( dims: Sequence[Hashable], dtype: np.typing.DTypeLike | None = ..., ) -> xarray.DataArray: ... +@contextmanager def default_linalg_dims( func_or_dims: Callable | Iterable, ) -> Generator[None, Any, None]: ...