From edfcc34e72adf3dca5933c381ce6a584485550c3 Mon Sep 17 00:00:00 2001 From: krokosik Date: Wed, 15 Apr 2026 13:08:46 +0200 Subject: [PATCH 1/8] Handle transpose via rename not swap_dims --- src/xarray_einstats/linalg.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/xarray_einstats/linalg.py b/src/xarray_einstats/linalg.py index 0bead5b..2f48087 100644 --- a/src/xarray_einstats/linalg.py +++ b/src/xarray_einstats/linalg.py @@ -455,7 +455,16 @@ def matrix_transpose(da, dims): if dims is None: dims = _attempt_default_dims("matrix_transpose", da.dims) dim1, dim2 = dims - return da.swap_dims({dim1: dim2, dim2: dim1}).transpose(..., *dims) + rename_dict = {dim1: dim2, dim2: dim1} + + for sub_dim1, sub_dim2 in zip(da.indexes[dim1].names, da.indexes[dim2].names): + rename_dict[sub_dim1] = sub_dim2 + rename_dict[sub_dim2] = sub_dim1 + + daT = da.rename(rename_dict).transpose(..., *dims) + + # Purely cosmetic change to preserve order of coordinates in the output + return daT.assign_coords({k: daT.coords[k] for k in da.coords}) def matrix_power(da, n, dims=None, **kwargs): From 1c8d188a9b2bae2bd06e2d5e7d34dd3ce3e21fd7 Mon Sep 17 00:00:00 2001 From: krokosik Date: Wed, 15 Apr 2026 13:09:22 +0200 Subject: [PATCH 2/8] Make matrix_transpose dims optional as documented --- src/xarray_einstats/accessors.py | 2 +- src/xarray_einstats/linalg.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/xarray_einstats/accessors.py b/src/xarray_einstats/accessors.py index 559b17c..746ad37 100644 --- a/src/xarray_einstats/accessors.py +++ b/src/xarray_einstats/accessors.py @@ -31,7 +31,7 @@ class LinAlgAccessor: def __init__(self, xarray_obj): self._obj = xarray_obj - def matrix_transpose(self, dims): + def matrix_transpose(self, dims=None): """Call :func:`xarray_einstats.linalg.matrix_transpose` on this DataArray.""" return matrix_transpose(self._obj, dims=dims) diff --git a/src/xarray_einstats/linalg.py b/src/xarray_einstats/linalg.py index 2f48087..f5e0880 100644 --- a/src/xarray_einstats/linalg.py +++ b/src/xarray_einstats/linalg.py @@ -434,7 +434,7 @@ def matmul(da, db, dims=None, *, out_append="2", **kwargs): return matmul_aux -def matrix_transpose(da, dims): +def matrix_transpose(da, dims=None): """Transpose the underlying matrix without modifying the dimensions. This convenience function uses :meth:`~xarray.DataArray.swap_dims` followed @@ -444,7 +444,7 @@ def matrix_transpose(da, dims): ---------- da : DataArray Input DataArray - dims : list of str + dims : list of str, optional Matrix dimensions Returns From 03d48953deba9f961ad226473aa319a165e6c5aa Mon Sep 17 00:00:00 2001 From: krokosik Date: Wed, 15 Apr 2026 13:09:32 +0200 Subject: [PATCH 3/8] Add missing pinv to accessor --- src/xarray_einstats/accessors.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/xarray_einstats/accessors.py b/src/xarray_einstats/accessors.py index 746ad37..af6639d 100644 --- a/src/xarray_einstats/accessors.py +++ b/src/xarray_einstats/accessors.py @@ -12,6 +12,7 @@ eigvals, eigvalsh, inv, + pinv, matrix_power, matrix_rank, matrix_transpose, @@ -120,6 +121,10 @@ def inv(self, dims=None, **kwargs): """Call :func:`xarray_einstats.linalg.inv` on this DataArray.""" return inv(self._obj, dims=dims, **kwargs) + def pinv(self, dims=None, **kwargs): + """Call :func:`xarray_einstats.linalg.pinv` on this DataArray.""" + return pinv(self._obj, dims=dims, **kwargs) + @xr.register_dataarray_accessor("einops") class EinopsAccessor: From ed13bf4b0d9f6f937cc9d725a4a9158a58e1a535 Mon Sep 17 00:00:00 2001 From: krokosik Date: Wed, 15 Apr 2026 13:09:43 +0200 Subject: [PATCH 4/8] Add a multiindex test to show there are no errors --- tests/test_linalg.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_linalg.py b/tests/test_linalg.py index f688f5b..0c40f19 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -166,6 +166,10 @@ def test_pinv_dataarray_tol(self, matrices, kind): def test_transpose(self, hermitian): assert_equal(hermitian, matrix_transpose(hermitian, dims=("dim", "dim2"))) + + def test_transpose_multiindex(self, matrices): + stacked = matrices.stack(batch_experiment=("batch", "experiment"), dim_dim2=("dim", "dim2")) + matrix_transpose(stacked, dims=("batch_experiment", "dim_dim2")) def test_matrix_power(self, matrices): out = matrix_power(matrices, 2, dims=("dim", "dim2")) From 428a1115d90705da3c577f809ef523a663da8f45 Mon Sep 17 00:00:00 2001 From: krokosik Date: Wed, 15 Apr 2026 13:10:46 +0200 Subject: [PATCH 5/8] Update docstring --- src/xarray_einstats/linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/xarray_einstats/linalg.py b/src/xarray_einstats/linalg.py index f5e0880..a2b4b83 100644 --- a/src/xarray_einstats/linalg.py +++ b/src/xarray_einstats/linalg.py @@ -437,7 +437,7 @@ def matmul(da, db, dims=None, *, out_append="2", **kwargs): def matrix_transpose(da, dims=None): """Transpose the underlying matrix without modifying the dimensions. - This convenience function uses :meth:`~xarray.DataArray.swap_dims` followed + This convenience function uses :meth:`~xarray.DataArray.rename` followed by :meth:`~xarray.DataArray.transpose` to get the equivalent of a matrix transposition. Parameters From a3af36e79e4ba30f2b8a602c923da073040d2c35 Mon Sep 17 00:00:00 2001 From: krokosik Date: Wed, 15 Apr 2026 13:11:59 +0200 Subject: [PATCH 6/8] Apply tox fixes --- src/xarray_einstats/accessors.py | 2 +- src/xarray_einstats/linalg.py | 4 ++-- src/xarray_einstats/linalg.pyi | 2 +- tests/test_linalg.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/xarray_einstats/accessors.py b/src/xarray_einstats/accessors.py index af6639d..7578070 100644 --- a/src/xarray_einstats/accessors.py +++ b/src/xarray_einstats/accessors.py @@ -12,11 +12,11 @@ eigvals, eigvalsh, inv, - pinv, matrix_power, matrix_rank, matrix_transpose, norm, + pinv, qr, slogdet, solve, diff --git a/src/xarray_einstats/linalg.py b/src/xarray_einstats/linalg.py index a2b4b83..a838553 100644 --- a/src/xarray_einstats/linalg.py +++ b/src/xarray_einstats/linalg.py @@ -461,10 +461,10 @@ def matrix_transpose(da, dims=None): rename_dict[sub_dim1] = sub_dim2 rename_dict[sub_dim2] = sub_dim1 - daT = da.rename(rename_dict).transpose(..., *dims) + da_transposed = da.rename(rename_dict).transpose(..., *dims) # Purely cosmetic change to preserve order of coordinates in the output - return daT.assign_coords({k: daT.coords[k] for k in da.coords}) + return da_transposed.assign_coords({k: da_transposed.coords[k] for k in da.coords}) def matrix_power(da, n, dims=None, **kwargs): diff --git a/src/xarray_einstats/linalg.pyi b/src/xarray_einstats/linalg.pyi index d7ef47c..515f7d0 100644 --- a/src/xarray_einstats/linalg.pyi +++ b/src/xarray_einstats/linalg.pyi @@ -87,7 +87,7 @@ def matmul( out_append: str = ..., **kwargs: Incomplete, ) -> xarray.DataArray: ... -def matrix_transpose(da: xarray.DataArray, dims: list[str]) -> xarray.DataArray: ... +def matrix_transpose(da: xarray.DataArray, dims: list[str] | None = ...) -> xarray.DataArray: ... def matrix_power( da: xarray.DataArray, n: int, dims: Sequence[Hashable] | None = ..., **kwargs: Incomplete ) -> xarray.DataArray: ... diff --git a/tests/test_linalg.py b/tests/test_linalg.py index 0c40f19..2e33c3a 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -166,7 +166,7 @@ def test_pinv_dataarray_tol(self, matrices, kind): def test_transpose(self, hermitian): assert_equal(hermitian, matrix_transpose(hermitian, dims=("dim", "dim2"))) - + def test_transpose_multiindex(self, matrices): stacked = matrices.stack(batch_experiment=("batch", "experiment"), dim_dim2=("dim", "dim2")) matrix_transpose(stacked, dims=("batch_experiment", "dim_dim2")) From e2fd7e8b77e9793fd182499a91c70fd999e17372 Mon Sep 17 00:00:00 2001 From: krokosik Date: Wed, 15 Apr 2026 13:23:42 +0200 Subject: [PATCH 7/8] Fix 1D case ;) --- src/xarray_einstats/linalg.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/xarray_einstats/linalg.py b/src/xarray_einstats/linalg.py index a838553..d02819b 100644 --- a/src/xarray_einstats/linalg.py +++ b/src/xarray_einstats/linalg.py @@ -457,9 +457,15 @@ def matrix_transpose(da, dims=None): dim1, dim2 = dims rename_dict = {dim1: dim2, dim2: dim1} - for sub_dim1, sub_dim2 in zip(da.indexes[dim1].names, da.indexes[dim2].names): - rename_dict[sub_dim1] = sub_dim2 - rename_dict[sub_dim2] = sub_dim1 + if ( + dim1 in da.indexes + and dim2 in da.indexes + and len(da.indexes[dim1].names) == len(da.indexes[dim2].names) + and len(da.indexes[dim1].names) > 1 + ): + for sub_dim1, sub_dim2 in zip(da.indexes[dim1].names, da.indexes[dim2].names): + rename_dict[sub_dim1] = sub_dim2 + rename_dict[sub_dim2] = sub_dim1 da_transposed = da.rename(rename_dict).transpose(..., *dims) From f95726f77613f4713adf3a9e7632161e0581626b Mon Sep 17 00:00:00 2001 From: krokosik Date: Wed, 15 Apr 2026 14:08:16 +0200 Subject: [PATCH 8/8] Add missing matmul to accessor --- src/xarray_einstats/accessors.py | 5 +++++ src/xarray_einstats/accessors.pyi | 2 ++ 2 files changed, 7 insertions(+) diff --git a/src/xarray_einstats/accessors.py b/src/xarray_einstats/accessors.py index 7578070..a5abc82 100644 --- a/src/xarray_einstats/accessors.py +++ b/src/xarray_einstats/accessors.py @@ -12,6 +12,7 @@ eigvals, eigvalsh, inv, + matmul, matrix_power, matrix_rank, matrix_transpose, @@ -40,6 +41,10 @@ def matrix_power(self, n, dims=None, **kwargs): """Call :func:`xarray_einstats.linalg.matrix_power` on this DataArray.""" return matrix_power(self._obj, n, dims=dims, **kwargs) + def matmul(self, other, dims=None, **kwargs): + """Call :func:`xarray_einstats.linalg.matmul` with this DataArray as ``a/da``.""" + return matmul(self._obj, other, dims=dims, **kwargs) + def cholesky(self, dims=None, **kwargs): """Call :func:`xarray_einstats.linalg.cholesky` on this DataArray.""" return cholesky(self._obj, dims=dims, **kwargs) diff --git a/src/xarray_einstats/accessors.pyi b/src/xarray_einstats/accessors.pyi index 7966f77..a6d419c 100644 --- a/src/xarray_einstats/accessors.pyi +++ b/src/xarray_einstats/accessors.pyi @@ -13,6 +13,7 @@ from .linalg import ( eigvals, eigvalsh, inv, + matmul, matrix_power, matrix_rank, matrix_transpose, @@ -28,6 +29,7 @@ class LinAlgAccessor: def __init__(self, xarray_obj: Incomplete) -> None: ... def matrix_transpose(self, dims: Incomplete) -> None: ... def matrix_power(self, n: Incomplete, dims: Incomplete = ..., **kwargs: Incomplete) -> None: ... + def matmul(self, other: Incomplete, dims: Incomplete = ..., **kwargs: Incomplete) -> None: ... def cholesky(self, dims: Incomplete = ..., **kwargs: Incomplete) -> None: ... def qr( self,