diff --git a/docs/api-reference.md b/docs/api-reference.md index 771967af..a94f95a9 100644 --- a/docs/api-reference.md +++ b/docs/api-reference.md @@ -14,6 +14,7 @@ cov create_diagonal default_dtype + diag_indices expand_dims isclose isin @@ -25,5 +26,7 @@ partition setdiff1d sinc + tril_indices + triu_indices union1d ``` diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py index 2fcdcd8e..44b2dc47 100644 --- a/src/array_api_extra/__init__.py +++ b/src/array_api_extra/__init__.py @@ -5,6 +5,7 @@ atleast_nd, cov, create_diagonal, + diag_indices, expand_dims, isclose, isin, @@ -15,6 +16,8 @@ searchsorted, setdiff1d, sinc, + tril_indices, + triu_indices, union1d, ) from ._lib._at import at @@ -40,6 +43,7 @@ "cov", "create_diagonal", "default_dtype", + "diag_indices", "expand_dims", "isclose", "isin", @@ -53,5 +57,7 @@ "searchsorted", "setdiff1d", "sinc", + "tril_indices", + "triu_indices", "union1d", ] diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 46639559..98d01eda 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -16,12 +16,13 @@ ) from ._lib._utils._compat import device as get_device from ._lib._utils._helpers import asarrays, eager_shape -from ._lib._utils._typing import Array, DType +from ._lib._utils._typing import Array, Device, DType __all__ = [ "atleast_nd", "cov", "create_diagonal", + "diag_indices", "expand_dims", "isclose", "nan_to_num", @@ -29,6 +30,8 @@ "pad", "searchsorted", "sinc", + "tril_indices", + "triu_indices", ] @@ -238,6 +241,55 @@ def create_diagonal( return _funcs.create_diagonal(x, offset=offset, xp=xp) +def diag_indices( + n: int, /, *, ndim: int = 2, device: Device | None = None, xp: ModuleType +) -> tuple[Array, ...]: + """ + Return the indices to access the main diagonal of an array. + + Equivalent to ``numpy.diag_indices``. + + Parameters + ---------- + n : int + The size of each dimension of the (hyper-)cube ``(n, n, ..., n)`` + that the returned indices index into. + ndim : int, optional + The number of dimensions. Default: ``2``. + device : Device, optional + The device on which to place the returned arrays. Default: current device. + xp : array_namespace + The standard-compatible namespace to create the indices in. + + Returns + ------- + tuple of array + ``ndim`` 1-D integer arrays of length ``n`` that together index + the main diagonal of an array of shape ``(n,) * ndim``. + + Examples + -------- + >>> import array_api_strict as xp + >>> import array_api_extra as xpx + >>> rows, cols = xpx.diag_indices(3, xp=xp) + >>> rows + Array([0, 1, 2], dtype=array_api_strict.int64) + >>> cols + Array([0, 1, 2], dtype=array_api_strict.int64) + """ + if n < 0: + msg = f"`n` must be non-negative, got {n}" + raise ValueError(msg) + if ndim < 1: + msg = f"`ndim` must be >= 1, got {ndim}" + raise ValueError(msg) + if device is None and ( + is_numpy_namespace(xp) or is_cupy_namespace(xp) or is_jax_namespace(xp) + ): + return xp.diag_indices(n, ndim=ndim) + return _funcs.diag_indices(n, ndim=ndim, device=device, xp=xp) + + def expand_dims( a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType | None = None ) -> Array: @@ -1150,3 +1202,145 @@ def union1d(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array: return xp.union1d(a, b) return _funcs.union1d(a, b, xp=xp) + + +def tril_indices( + n: int, + /, + *, + offset: int = 0, + m: int | None = None, + device: Device | None = None, + xp: ModuleType, +) -> tuple[Array, Array]: + """ + Return the indices of the lower triangle of an ``(n, m)`` array. + + Equivalent to ``numpy.tril_indices`` with parameter ``k`` renamed to + ``offset`` to match ``xp.linalg.diagonal``'s naming. + + Parameters + ---------- + n : int + The row dimension of the array. + offset : int, optional + Diagonal offset; ``0`` (default) is the main diagonal. Corresponds + to ``k`` in ``numpy.tril_indices``. + m : int, optional + The column dimension. If ``None`` (default), assumed equal to `n`. + device : Device, optional + The device on which to place the returned arrays. Default: current device. + xp : array_namespace + The standard-compatible namespace to create the indices in. + + Returns + ------- + tuple of array + Row and column indices ``(rows, cols)`` of the lower triangle of + the ``(n, m)`` matrix, shifted by `offset`. + + Notes + ----- + The generic fallback uses ``xp.nonzero``, so namespaces without + ``nonzero`` are not supported on that path. + + Examples + -------- + >>> import array_api_strict as xp + >>> import array_api_extra as xpx + >>> rows, cols = xpx.tril_indices(3, xp=xp) + >>> rows + Array([0, 1, 1, 2, 2, 2], dtype=array_api_strict.int64) + >>> cols + Array([0, 0, 1, 0, 1, 2], dtype=array_api_strict.int64) + """ + if n < 0: + msg = f"`n` must be non-negative, got {n}" + raise ValueError(msg) + if m is not None and m < 0: + msg = f"`m` must be non-negative, got {m}" + raise ValueError(msg) + if device is None and ( + is_numpy_namespace(xp) + or is_cupy_namespace(xp) + or is_jax_namespace(xp) + or is_dask_namespace(xp) + ): + return xp.tril_indices(n, k=offset, m=m) + if is_torch_namespace(xp): + # `torch.tril_indices` returns a 2xN tensor, not a tuple, and + # takes (row, col) rather than (n, *, m=None). + cols = n if m is None else m + idx = xp.tril_indices(n, cols, offset=offset, device=device) + return (idx[0], idx[1]) + return _funcs.tril_indices(n, offset=offset, m=m, device=device, xp=xp) + + +def triu_indices( + n: int, + /, + *, + offset: int = 0, + m: int | None = None, + device: Device | None = None, + xp: ModuleType, +) -> tuple[Array, Array]: + """ + Return the indices of the upper triangle of an ``(n, m)`` array. + + Equivalent to ``numpy.triu_indices`` with parameter ``k`` renamed to + ``offset`` to match ``xp.linalg.diagonal``'s naming. + + Parameters + ---------- + n : int + The row dimension of the array. + offset : int, optional + Diagonal offset; ``0`` (default) is the main diagonal. Corresponds + to ``k`` in ``numpy.triu_indices``. + m : int, optional + The column dimension. If ``None`` (default), assumed equal to `n`. + device : Device, optional + The device on which to place the returned arrays. Default: current device. + xp : array_namespace + The standard-compatible namespace to create the indices in. + + Returns + ------- + tuple of array + Row and column indices ``(rows, cols)`` of the upper triangle of + the ``(n, m)`` matrix, shifted by `offset`. + + Notes + ----- + The generic fallback uses ``xp.nonzero``, so namespaces without + ``nonzero`` are not supported on that path. + + Examples + -------- + >>> import array_api_strict as xp + >>> import array_api_extra as xpx + >>> rows, cols = xpx.triu_indices(3, xp=xp) + >>> rows + Array([0, 0, 0, 1, 1, 2], dtype=array_api_strict.int64) + >>> cols + Array([0, 1, 2, 1, 2, 2], dtype=array_api_strict.int64) + """ + if n < 0: + msg = f"`n` must be non-negative, got {n}" + raise ValueError(msg) + if m is not None and m < 0: + msg = f"`m` must be non-negative, got {m}" + raise ValueError(msg) + if device is None and ( + is_numpy_namespace(xp) + or is_cupy_namespace(xp) + or is_jax_namespace(xp) + or is_dask_namespace(xp) + ): + return xp.triu_indices(n, k=offset, m=m) + if is_torch_namespace(xp): + cols = n if m is None else m + idx = xp.triu_indices(n, cols, offset=offset, device=device) + return (idx[0], idx[1]) + return _funcs.triu_indices(n, offset=offset, m=m, device=device, xp=xp) diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index 97904ddb..ba7120c0 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -28,6 +28,7 @@ "broadcast_shapes", "cov", "create_diagonal", + "diag_indices", "expand_dims", "kron", "nunique", @@ -35,6 +36,8 @@ "searchsorted", "setdiff1d", "sinc", + "tril_indices", + "triu_indices", ] @@ -346,6 +349,59 @@ def create_diagonal( return xp.reshape(diag, (*batch_dims, n, n)) +def diag_indices( + n: int, /, *, ndim: int = 2, device: Device | None = None, xp: ModuleType +) -> tuple[Array, ...]: # numpydoc ignore=PR01,RT01 + """See docstring in array_api_extra._delegation.""" + idx = xp.arange(n, device=device) + return (idx,) * ndim + + +def _tri_indices( + n: int, + *, + offset: int, + m: int | None, + upper: bool, + device: Device | None, + xp: ModuleType, +) -> tuple[Array, Array]: # numpydoc ignore=PR01,RT01 + """Shared implementation for `tril_indices` and `triu_indices`.""" + cols = n if m is None else m + rows = xp.arange(n, device=device)[:, None] + cols_a = xp.arange(cols, device=device)[None, :] + delta = cols_a - rows + mask = delta >= offset if upper else delta <= offset + r, c = xp.nonzero(mask) + return (r, c) + + +def tril_indices( + n: int, + /, + *, + offset: int = 0, + m: int | None = None, + device: Device | None = None, + xp: ModuleType, +) -> tuple[Array, Array]: # numpydoc ignore=PR01,RT01 + """See docstring in array_api_extra._delegation.""" + return _tri_indices(n, offset=offset, m=m, upper=False, device=device, xp=xp) + + +def triu_indices( + n: int, + /, + *, + offset: int = 0, + m: int | None = None, + device: Device | None = None, + xp: ModuleType, +) -> tuple[Array, Array]: # numpydoc ignore=PR01,RT01 + """See docstring in array_api_extra._delegation.""" + return _tri_indices(n, offset=offset, m=m, upper=True, device=device, xp=xp) + + def default_dtype( xp: ModuleType, kind: Literal[ diff --git a/tests/conftest.py b/tests/conftest.py index df703b97..6c735a20 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -96,9 +96,12 @@ def as_readonly(o: T) -> T: # numpydoc ignore=PR01,RT01 # Cannot interpret as a data type return o - # This works with namedtuples too if isinstance(o, tuple | list): - return type(o)(*(as_readonly(i) for i in o)) # type: ignore[arg-type,return-value] # pyright: ignore[reportArgumentType] + # namedtuple wants positional args; plain tuple/list wants an iterable. + items = (as_readonly(i) for i in o) + if hasattr(o, "_fields"): + return type(o)(*items) # type: ignore[arg-type,return-value] # pyright: ignore[reportArgumentType] + return type(o)(items) # type: ignore[return-value] return o diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 6a11e059..068e179c 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -21,6 +21,7 @@ cov, create_diagonal, default_dtype, + diag_indices, expand_dims, isclose, isin, @@ -32,6 +33,8 @@ partition, setdiff1d, sinc, + tril_indices, + triu_indices, union1d, ) from array_api_extra import ( @@ -56,6 +59,7 @@ lazy_xp_function(cov) lazy_xp_function(create_diagonal) lazy_xp_function(default_dtype) +lazy_xp_function(diag_indices) lazy_xp_function(expand_dims) lazy_xp_function(isclose) lazy_xp_function(isin) @@ -68,6 +72,8 @@ # FIXME calls in1d which calls xp.unique_values without size lazy_xp_function(setdiff1d, jax_jit=False) lazy_xp_function(sinc) +lazy_xp_function(tril_indices) +lazy_xp_function(triu_indices) lazy_xp_function(union1d, jax_jit=False) lazy_xp_function(xpx_searchsorted) lazy_xp_function(_funcs_searchsorted) @@ -803,6 +809,133 @@ def test_torch(self, torch: ModuleType): assert default_dtype(xp, "complex floating") == xp.complex64 +@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no arange", strict=False) +class TestDiagIndices: + def test_basic(self, xp: ModuleType): + rows, cols = diag_indices(5, xp=xp) + ref_rows, ref_cols = np.diag_indices(5) + xp_assert_equal(rows, xp.asarray(ref_rows)) + xp_assert_equal(cols, xp.asarray(ref_cols)) + + @pytest.mark.parametrize("n", [2, 4, 7]) + @pytest.mark.parametrize("ndim", [1, 2, 3, 4]) + def test_ndim(self, xp: ModuleType, n: int, ndim: int): + idx = diag_indices(n, ndim=ndim, xp=xp) + assert len(idx) == ndim + ref = np.diag_indices(n, ndim=ndim) + for got, expected in zip(idx, ref, strict=True): + xp_assert_equal(got, xp.asarray(expected)) + + def test_empty(self, xp: ModuleType): + rows, cols = diag_indices(0, xp=xp) + assert rows.shape == (0,) + assert cols.shape == (0,) + + def test_validation(self, xp: ModuleType): + with pytest.raises(ValueError, match="`n` must be non-negative"): + _ = diag_indices(-1, xp=xp) + with pytest.raises(ValueError, match="`ndim` must be >= 1"): + _ = diag_indices(3, ndim=0, xp=xp) + + def test_device(self, xp: ModuleType, device: Device): + default_device = get_device(xp.empty(0)) + rows, cols = diag_indices(3, device=None, xp=xp) + assert get_device(rows) == default_device + assert get_device(cols) == default_device + rows, cols = diag_indices(3, device=device, xp=xp) + assert get_device(rows) == device + assert get_device(cols) == device + + +@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no arange/nonzero", strict=False) +@pytest.mark.xfail_xp_backend( + Backend.ARRAY_API_STRICTEST, + reason="generic path uses nonzero (data-dependent)", + strict=False, +) +@pytest.mark.parametrize( + ("xpx_fn", "np_fn"), + [(tril_indices, np.tril_indices), (triu_indices, np.triu_indices)], + ids=["tril", "triu"], +) +class TestTriIndices: + def test_basic( + self, + xp: ModuleType, + xpx_fn: Callable[..., tuple[Array, Array]], + np_fn: Callable[..., tuple[Array, Array]], + ): + rows, cols = xpx_fn(4, xp=xp) + ref_rows, ref_cols = np_fn(4) + xp_assert_equal(rows, xp.asarray(ref_rows)) + xp_assert_equal(cols, xp.asarray(ref_cols)) + + @pytest.mark.parametrize("offset", [-2, -1, 0, 1, 2]) + def test_offset( + self, + xp: ModuleType, + xpx_fn: Callable[..., tuple[Array, Array]], + np_fn: Callable[..., tuple[Array, Array]], + offset: int, + ): + rows, cols = xpx_fn(5, offset=offset, xp=xp) + ref_rows, ref_cols = np_fn(5, k=offset) + xp_assert_equal(rows, xp.asarray(ref_rows)) + xp_assert_equal(cols, xp.asarray(ref_cols)) + + def test_rectangular( + self, + xp: ModuleType, + xpx_fn: Callable[..., tuple[Array, Array]], + np_fn: Callable[..., tuple[Array, Array]], + ): + rows, cols = xpx_fn(3, m=5, xp=xp) + ref_rows, ref_cols = np_fn(3, m=5) + xp_assert_equal(rows, xp.asarray(ref_rows)) + xp_assert_equal(cols, xp.asarray(ref_cols)) + + @pytest.mark.xfail_xp_backend( + Backend.DASK, reason="dask: no 2D fancy indexing", strict=False + ) + def test_use_to_read( + self, + xp: ModuleType, + xpx_fn: Callable[..., tuple[Array, Array]], + np_fn: Callable[..., tuple[Array, Array]], + ): + rng = np.random.default_rng(0) + a = rng.integers(0, 100, (4, 4)) + a_xp = xp.asarray(a) + rows, cols = xpx_fn(4, xp=xp) + xp_assert_equal(a_xp[rows, cols], xp.asarray(a[np_fn(4)])) + + def test_validation( + self, + xp: ModuleType, + xpx_fn: Callable[..., tuple[Array, Array]], + np_fn: Callable[..., tuple[Array, Array]], # noqa: ARG002 # pytest param + ): + with pytest.raises(ValueError, match="`n` must be non-negative"): + _ = xpx_fn(-1, xp=xp) + with pytest.raises(ValueError, match="`m` must be non-negative"): + _ = xpx_fn(3, m=-1, xp=xp) + + def test_device( + self, + xp: ModuleType, + device: Device, + xpx_fn: Callable[..., tuple[Array, Array]], + np_fn: Callable[..., tuple[Array, Array]], # noqa: ARG002 # pytest param + ): + default_device = get_device(xp.empty(0)) + rows, cols = xpx_fn(4, device=None, xp=xp) + assert get_device(rows) == default_device + assert get_device(cols) == default_device + rows, cols = xpx_fn(4, device=device, xp=xp) + assert get_device(rows) == device + assert get_device(cols) == device + + class TestExpandDims: def test_single_axis(self, xp: ModuleType): """Trivial case where xpx.expand_dims doesn't add anything to xp.expand_dims"""