Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/api-reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
cov
create_diagonal
default_dtype
diag_indices
expand_dims
isclose
isin
Expand All @@ -25,5 +26,7 @@
partition
setdiff1d
sinc
tril_indices
triu_indices
union1d
```
6 changes: 6 additions & 0 deletions src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
atleast_nd,
cov,
create_diagonal,
diag_indices,
expand_dims,
isclose,
isin,
Expand All @@ -15,6 +16,8 @@
searchsorted,
setdiff1d,
sinc,
tril_indices,
triu_indices,
union1d,
)
from ._lib._at import at
Expand All @@ -40,6 +43,7 @@
"cov",
"create_diagonal",
"default_dtype",
"diag_indices",
"expand_dims",
"isclose",
"isin",
Expand All @@ -53,5 +57,7 @@
"searchsorted",
"setdiff1d",
"sinc",
"tril_indices",
"triu_indices",
"union1d",
]
196 changes: 195 additions & 1 deletion src/array_api_extra/_delegation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,22 @@
)
from ._lib._utils._compat import device as get_device
from ._lib._utils._helpers import asarrays, eager_shape
from ._lib._utils._typing import Array, DType
from ._lib._utils._typing import Array, Device, DType

__all__ = [
"atleast_nd",
"cov",
"create_diagonal",
"diag_indices",
"expand_dims",
"isclose",
"nan_to_num",
"one_hot",
"pad",
"searchsorted",
"sinc",
"tril_indices",
"triu_indices",
]


Expand Down Expand Up @@ -238,6 +241,55 @@ def create_diagonal(
return _funcs.create_diagonal(x, offset=offset, xp=xp)


def diag_indices(
n: int, /, *, ndim: int = 2, device: Device | None = None, xp: ModuleType
) -> tuple[Array, ...]:
"""
Return the indices to access the main diagonal of an array.

Equivalent to ``numpy.diag_indices``.

Parameters
----------
n : int
The size of each dimension of the (hyper-)cube ``(n, n, ..., n)``
that the returned indices index into.
ndim : int, optional
The number of dimensions. Default: ``2``.
device : Device, optional
The device on which to place the returned arrays. Default: current device.
xp : array_namespace
The standard-compatible namespace to create the indices in.

Returns
-------
tuple of array
``ndim`` 1-D integer arrays of length ``n`` that together index
the main diagonal of an array of shape ``(n,) * ndim``.

Examples
--------
>>> import array_api_strict as xp
>>> import array_api_extra as xpx
>>> rows, cols = xpx.diag_indices(3, xp=xp)
>>> rows
Array([0, 1, 2], dtype=array_api_strict.int64)
>>> cols
Array([0, 1, 2], dtype=array_api_strict.int64)
"""
if n < 0:
msg = f"`n` must be non-negative, got {n}"
raise ValueError(msg)
if ndim < 1:
msg = f"`ndim` must be >= 1, got {ndim}"
raise ValueError(msg)
if device is None and (
is_numpy_namespace(xp) or is_cupy_namespace(xp) or is_jax_namespace(xp)
):
return xp.diag_indices(n, ndim=ndim)
return _funcs.diag_indices(n, ndim=ndim, device=device, xp=xp)


def expand_dims(
a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType | None = None
) -> Array:
Expand Down Expand Up @@ -1150,3 +1202,145 @@ def union1d(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
return xp.union1d(a, b)

return _funcs.union1d(a, b, xp=xp)


def tril_indices(
n: int,
/,
*,
offset: int = 0,
m: int | None = None,
device: Device | None = None,
xp: ModuleType,
) -> tuple[Array, Array]:
"""
Return the indices of the lower triangle of an ``(n, m)`` array.

Equivalent to ``numpy.tril_indices`` with parameter ``k`` renamed to
``offset`` to match ``xp.linalg.diagonal``'s naming.

Parameters
----------
n : int
The row dimension of the array.
offset : int, optional
Diagonal offset; ``0`` (default) is the main diagonal. Corresponds
to ``k`` in ``numpy.tril_indices``.
m : int, optional
The column dimension. If ``None`` (default), assumed equal to `n`.
device : Device, optional
The device on which to place the returned arrays. Default: current device.
xp : array_namespace
The standard-compatible namespace to create the indices in.

Returns
-------
tuple of array
Row and column indices ``(rows, cols)`` of the lower triangle of
the ``(n, m)`` matrix, shifted by `offset`.

Notes
-----
The generic fallback uses ``xp.nonzero``, so namespaces without
``nonzero`` are not supported on that path.

Examples
--------
>>> import array_api_strict as xp
>>> import array_api_extra as xpx
>>> rows, cols = xpx.tril_indices(3, xp=xp)
>>> rows
Array([0, 1, 1, 2, 2, 2], dtype=array_api_strict.int64)
>>> cols
Array([0, 0, 1, 0, 1, 2], dtype=array_api_strict.int64)
"""
if n < 0:
msg = f"`n` must be non-negative, got {n}"
raise ValueError(msg)
if m is not None and m < 0:
msg = f"`m` must be non-negative, got {m}"
raise ValueError(msg)
if device is None and (
is_numpy_namespace(xp)
or is_cupy_namespace(xp)
or is_jax_namespace(xp)
or is_dask_namespace(xp)
):
return xp.tril_indices(n, k=offset, m=m)
if is_torch_namespace(xp):
# `torch.tril_indices` returns a 2xN tensor, not a tuple, and
# takes (row, col) rather than (n, *, m=None).
cols = n if m is None else m
idx = xp.tril_indices(n, cols, offset=offset, device=device)
return (idx[0], idx[1])
return _funcs.tril_indices(n, offset=offset, m=m, device=device, xp=xp)


def triu_indices(
n: int,
/,
*,
offset: int = 0,
m: int | None = None,
device: Device | None = None,
xp: ModuleType,
) -> tuple[Array, Array]:
"""
Return the indices of the upper triangle of an ``(n, m)`` array.

Equivalent to ``numpy.triu_indices`` with parameter ``k`` renamed to
``offset`` to match ``xp.linalg.diagonal``'s naming.

Parameters
----------
n : int
The row dimension of the array.
offset : int, optional
Diagonal offset; ``0`` (default) is the main diagonal. Corresponds
to ``k`` in ``numpy.triu_indices``.
m : int, optional
The column dimension. If ``None`` (default), assumed equal to `n`.
device : Device, optional
The device on which to place the returned arrays. Default: current device.
xp : array_namespace
The standard-compatible namespace to create the indices in.

Returns
-------
tuple of array
Row and column indices ``(rows, cols)`` of the upper triangle of
the ``(n, m)`` matrix, shifted by `offset`.

Notes
-----
The generic fallback uses ``xp.nonzero``, so namespaces without
``nonzero`` are not supported on that path.

Examples
--------
>>> import array_api_strict as xp
>>> import array_api_extra as xpx
>>> rows, cols = xpx.triu_indices(3, xp=xp)
>>> rows
Array([0, 0, 0, 1, 1, 2], dtype=array_api_strict.int64)
>>> cols
Array([0, 1, 2, 1, 2, 2], dtype=array_api_strict.int64)
"""
if n < 0:
msg = f"`n` must be non-negative, got {n}"
raise ValueError(msg)
if m is not None and m < 0:
msg = f"`m` must be non-negative, got {m}"
raise ValueError(msg)
if device is None and (
is_numpy_namespace(xp)
or is_cupy_namespace(xp)
or is_jax_namespace(xp)
or is_dask_namespace(xp)
):
return xp.triu_indices(n, k=offset, m=m)
if is_torch_namespace(xp):
cols = n if m is None else m
idx = xp.triu_indices(n, cols, offset=offset, device=device)
return (idx[0], idx[1])
return _funcs.triu_indices(n, offset=offset, m=m, device=device, xp=xp)
56 changes: 56 additions & 0 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,16 @@
"broadcast_shapes",
"cov",
"create_diagonal",
"diag_indices",
"expand_dims",
"kron",
"nunique",
"pad",
"searchsorted",
"setdiff1d",
"sinc",
"tril_indices",
"triu_indices",
]


Expand Down Expand Up @@ -346,6 +349,59 @@ def create_diagonal(
return xp.reshape(diag, (*batch_dims, n, n))


def diag_indices(
n: int, /, *, ndim: int = 2, device: Device | None = None, xp: ModuleType
) -> tuple[Array, ...]: # numpydoc ignore=PR01,RT01
"""See docstring in array_api_extra._delegation."""
idx = xp.arange(n, device=device)
return (idx,) * ndim


def _tri_indices(
n: int,
*,
offset: int,
m: int | None,
upper: bool,
device: Device | None,
xp: ModuleType,
) -> tuple[Array, Array]: # numpydoc ignore=PR01,RT01
"""Shared implementation for `tril_indices` and `triu_indices`."""
cols = n if m is None else m
rows = xp.arange(n, device=device)[:, None]
cols_a = xp.arange(cols, device=device)[None, :]
delta = cols_a - rows
mask = delta >= offset if upper else delta <= offset
r, c = xp.nonzero(mask)
return (r, c)


def tril_indices(
n: int,
/,
*,
offset: int = 0,
m: int | None = None,
device: Device | None = None,
xp: ModuleType,
) -> tuple[Array, Array]: # numpydoc ignore=PR01,RT01
"""See docstring in array_api_extra._delegation."""
return _tri_indices(n, offset=offset, m=m, upper=False, device=device, xp=xp)


def triu_indices(
n: int,
/,
*,
offset: int = 0,
m: int | None = None,
device: Device | None = None,
xp: ModuleType,
) -> tuple[Array, Array]: # numpydoc ignore=PR01,RT01
"""See docstring in array_api_extra._delegation."""
return _tri_indices(n, offset=offset, m=m, upper=True, device=device, xp=xp)


def default_dtype(
xp: ModuleType,
kind: Literal[
Expand Down
7 changes: 5 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,12 @@ def as_readonly(o: T) -> T: # numpydoc ignore=PR01,RT01
# Cannot interpret as a data type
return o

# This works with namedtuples too
if isinstance(o, tuple | list):
return type(o)(*(as_readonly(i) for i in o)) # type: ignore[arg-type,return-value] # pyright: ignore[reportArgumentType]
# namedtuple wants positional args; plain tuple/list wants an iterable.
items = (as_readonly(i) for i in o)
if hasattr(o, "_fields"):
return type(o)(*items) # type: ignore[arg-type,return-value] # pyright: ignore[reportArgumentType]
return type(o)(items) # type: ignore[return-value]

return o

Expand Down
Loading