diff --git a/src/xarray_einstats/accessors.py b/src/xarray_einstats/accessors.py index 559b17c..a5abc82 100644 --- a/src/xarray_einstats/accessors.py +++ b/src/xarray_einstats/accessors.py @@ -12,10 +12,12 @@ eigvals, eigvalsh, inv, + matmul, matrix_power, matrix_rank, matrix_transpose, norm, + pinv, qr, slogdet, solve, @@ -31,7 +33,7 @@ class LinAlgAccessor: def __init__(self, xarray_obj): self._obj = xarray_obj - def matrix_transpose(self, dims): + def matrix_transpose(self, dims=None): """Call :func:`xarray_einstats.linalg.matrix_transpose` on this DataArray.""" return matrix_transpose(self._obj, dims=dims) @@ -39,6 +41,10 @@ def matrix_power(self, n, dims=None, **kwargs): """Call :func:`xarray_einstats.linalg.matrix_power` on this DataArray.""" return matrix_power(self._obj, n, dims=dims, **kwargs) + def matmul(self, other, dims=None, **kwargs): + """Call :func:`xarray_einstats.linalg.matmul` with this DataArray as ``a/da``.""" + return matmul(self._obj, other, dims=dims, **kwargs) + def cholesky(self, dims=None, **kwargs): """Call :func:`xarray_einstats.linalg.cholesky` on this DataArray.""" return cholesky(self._obj, dims=dims, **kwargs) @@ -120,6 +126,10 @@ def inv(self, dims=None, **kwargs): """Call :func:`xarray_einstats.linalg.inv` on this DataArray.""" return inv(self._obj, dims=dims, **kwargs) + def pinv(self, dims=None, **kwargs): + """Call :func:`xarray_einstats.linalg.pinv` on this DataArray.""" + return pinv(self._obj, dims=dims, **kwargs) + @xr.register_dataarray_accessor("einops") class EinopsAccessor: diff --git a/src/xarray_einstats/accessors.pyi b/src/xarray_einstats/accessors.pyi index 7966f77..a6d419c 100644 --- a/src/xarray_einstats/accessors.pyi +++ b/src/xarray_einstats/accessors.pyi @@ -13,6 +13,7 @@ from .linalg import ( eigvals, eigvalsh, inv, + matmul, matrix_power, matrix_rank, matrix_transpose, @@ -28,6 +29,7 @@ class LinAlgAccessor: def __init__(self, xarray_obj: Incomplete) -> None: ... def matrix_transpose(self, dims: Incomplete) -> None: ... def matrix_power(self, n: Incomplete, dims: Incomplete = ..., **kwargs: Incomplete) -> None: ... + def matmul(self, other: Incomplete, dims: Incomplete = ..., **kwargs: Incomplete) -> None: ... def cholesky(self, dims: Incomplete = ..., **kwargs: Incomplete) -> None: ... def qr( self, diff --git a/src/xarray_einstats/linalg.py b/src/xarray_einstats/linalg.py index 0bead5b..d02819b 100644 --- a/src/xarray_einstats/linalg.py +++ b/src/xarray_einstats/linalg.py @@ -434,17 +434,17 @@ def matmul(da, db, dims=None, *, out_append="2", **kwargs): return matmul_aux -def matrix_transpose(da, dims): +def matrix_transpose(da, dims=None): """Transpose the underlying matrix without modifying the dimensions. - This convenience function uses :meth:`~xarray.DataArray.swap_dims` followed + This convenience function uses :meth:`~xarray.DataArray.rename` followed by :meth:`~xarray.DataArray.transpose` to get the equivalent of a matrix transposition. Parameters ---------- da : DataArray Input DataArray - dims : list of str + dims : list of str, optional Matrix dimensions Returns @@ -455,7 +455,22 @@ def matrix_transpose(da, dims): if dims is None: dims = _attempt_default_dims("matrix_transpose", da.dims) dim1, dim2 = dims - return da.swap_dims({dim1: dim2, dim2: dim1}).transpose(..., *dims) + rename_dict = {dim1: dim2, dim2: dim1} + + if ( + dim1 in da.indexes + and dim2 in da.indexes + and len(da.indexes[dim1].names) == len(da.indexes[dim2].names) + and len(da.indexes[dim1].names) > 1 + ): + for sub_dim1, sub_dim2 in zip(da.indexes[dim1].names, da.indexes[dim2].names): + rename_dict[sub_dim1] = sub_dim2 + rename_dict[sub_dim2] = sub_dim1 + + da_transposed = da.rename(rename_dict).transpose(..., *dims) + + # Purely cosmetic change to preserve order of coordinates in the output + return da_transposed.assign_coords({k: da_transposed.coords[k] for k in da.coords}) def matrix_power(da, n, dims=None, **kwargs): diff --git a/src/xarray_einstats/linalg.pyi b/src/xarray_einstats/linalg.pyi index d7ef47c..515f7d0 100644 --- a/src/xarray_einstats/linalg.pyi +++ b/src/xarray_einstats/linalg.pyi @@ -87,7 +87,7 @@ def matmul( out_append: str = ..., **kwargs: Incomplete, ) -> xarray.DataArray: ... -def matrix_transpose(da: xarray.DataArray, dims: list[str]) -> xarray.DataArray: ... +def matrix_transpose(da: xarray.DataArray, dims: list[str] | None = ...) -> xarray.DataArray: ... def matrix_power( da: xarray.DataArray, n: int, dims: Sequence[Hashable] | None = ..., **kwargs: Incomplete ) -> xarray.DataArray: ... diff --git a/tests/test_linalg.py b/tests/test_linalg.py index f688f5b..2e33c3a 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -167,6 +167,10 @@ def test_pinv_dataarray_tol(self, matrices, kind): def test_transpose(self, hermitian): assert_equal(hermitian, matrix_transpose(hermitian, dims=("dim", "dim2"))) + def test_transpose_multiindex(self, matrices): + stacked = matrices.stack(batch_experiment=("batch", "experiment"), dim_dim2=("dim", "dim2")) + matrix_transpose(stacked, dims=("batch_experiment", "dim_dim2")) + def test_matrix_power(self, matrices): out = matrix_power(matrices, 2, dims=("dim", "dim2")) assert out.shape == matrices.shape