Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 2 additions & 26 deletions array_api_strict/_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from collections.abc import Iterator
from enum import IntEnum
from types import EllipsisType, ModuleType
from typing import Any, Final, Literal, SupportsIndex, Callable
from typing import Any, Literal, SupportsIndex, Callable

import numpy as np
import numpy.typing as npt
Expand All @@ -40,35 +40,11 @@
_real_to_complex_map,
_result_type,
)
from ._devices import CPU_DEVICE, Device
from ._flags import get_array_api_strict_flags, set_array_api_strict_flags
from ._typing import PyCapsule


class Device:
_device: Final[str]
__slots__ = ("_device", "__weakref__")

def __init__(self, device: str = "CPU_DEVICE"):
if device not in ("CPU_DEVICE", "device1", "device2"):
raise ValueError(f"The device '{device}' is not a valid choice.")
self._device = device

def __repr__(self) -> str:
return f"array_api_strict.Device('{self._device}')"

def __eq__(self, other: object) -> bool:
if not isinstance(other, Device):
return False
return self._device == other._device

def __hash__(self) -> int:
return hash(("Device", self._device))


CPU_DEVICE = Device()
ALL_DEVICES = (CPU_DEVICE, Device("device1"), Device("device2"))


class Array:
"""
n-d array object for the array API namespace.
Expand Down
94 changes: 64 additions & 30 deletions array_api_strict/_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@

import numpy as np

from ._dtypes import DType, _all_dtypes, _np_dtype
from ._dtypes import DType, _all_dtypes, _np_dtype, bool as xp_bool
from ._devices import (
Device, device_supports_dtype, get_default_dtypes,
check_device as _check_device
)
from ._flags import get_array_api_strict_flags
from ._typing import NestedSequence, SupportsBufferProtocol, SupportsDLPack

Expand All @@ -14,7 +18,7 @@
from typing_extensions import TypeIs

# Circular import
from ._array_object import Array, Device
from ._array_object import Array


class Undef(Enum):
Expand All @@ -24,10 +28,15 @@ class Undef(Enum):
_undef = Undef.UNDEF


def _check_valid_dtype(dtype: DType | None) -> None:
def _check_valid_dtype(dtype: DType | None, device: Device | None = None) -> None:
# Note: Only spelling dtypes as the dtype objects is supported.
if dtype not in (None,) + _all_dtypes:
raise ValueError(f"dtype must be one of the supported dtypes, got {dtype!r}")
if dtype is not None:
if dtype not in _all_dtypes:
raise ValueError(f"dtype must be one of the supported dtypes, got {dtype!r}")

if device is not None:
if not device_supports_dtype(device, dtype):
raise ValueError(f"Device {device!r} does not support dtype={dtype!r}.")


def _supports_buffer_protocol(obj: object) -> TypeIs[SupportsBufferProtocol]:
Expand All @@ -38,18 +47,6 @@ def _supports_buffer_protocol(obj: object) -> TypeIs[SupportsBufferProtocol]:
return True


def _check_device(device: Device | None) -> None:
# _array_object imports in this file are inside the functions to avoid
# circular imports
from ._array_object import ALL_DEVICES, Device

if device is not None and not isinstance(device, Device):
raise ValueError(f"Unsupported device {device!r}")

if device is not None and device not in ALL_DEVICES:
raise ValueError(f"Unsupported device {device!r}")


def asarray(
obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol,
/,
Expand All @@ -65,11 +62,12 @@ def asarray(
"""
from ._array_object import Array

_check_valid_dtype(dtype)
_check_device(device)
_check_valid_dtype(dtype, device)
_np_dtype = None
if dtype is not None:
_np_dtype = dtype._np_dtype
_check_device(device)

if isinstance(obj, Array) and device is None:
device = obj.device

Expand Down Expand Up @@ -127,8 +125,13 @@ def arange(
"""
from ._array_object import Array

_check_valid_dtype(dtype)
_check_device(device)
_check_valid_dtype(dtype, device)
if dtype is None:
if any(isinstance(x, float) for x in (start, stop, step)):
dtype = get_default_dtypes(device)["real floating"]
else:
dtype = get_default_dtypes(device)["integral"]

return Array._new(
np.arange(start, stop, step, dtype=_np_dtype(dtype)),
Expand All @@ -149,8 +152,10 @@ def empty(
"""
from ._array_object import Array

_check_valid_dtype(dtype)
_check_device(device)
_check_valid_dtype(dtype, device)
if dtype is None:
dtype = get_default_dtypes(device)["real floating"]

return Array._new(np.empty(shape, dtype=_np_dtype(dtype)), device=device)

Expand All @@ -165,10 +170,12 @@ def empty_like(
"""
from ._array_object import Array

_check_valid_dtype(dtype)
_check_device(device)
if device is None:
device = x.device
if dtype is None:
dtype = x.dtype
_check_valid_dtype(dtype, device)

return Array._new(np.empty_like(x._array, dtype=_np_dtype(dtype)), device=device)

Expand All @@ -189,8 +196,10 @@ def eye(
"""
from ._array_object import Array

_check_valid_dtype(dtype)
_check_device(device)
_check_valid_dtype(dtype, device)
if dtype is None:
dtype = get_default_dtypes(device)["real floating"]

return Array._new(
np.eye(n_rows, M=n_cols, k=k, dtype=_np_dtype(dtype)), device=device
Expand Down Expand Up @@ -237,12 +246,22 @@ def full(
"""
from ._array_object import Array

_check_valid_dtype(dtype)
_check_device(device)
_check_valid_dtype(dtype, device)

if not isinstance(fill_value, bool | int | float | complex):
msg = f"Expected Python scalar fill_value, got type {type(fill_value)}"
raise TypeError(msg)

if dtype is None:
if type(fill_value) == bool:
dtype = xp_bool
else:
kind = {
int: "integral", float: "real floating", complex: "complex floating"
}[type(fill_value)]
dtype = get_default_dtypes(device)[kind]

res = np.full(shape, fill_value, dtype=_np_dtype(dtype))
if DType(res.dtype) not in _all_dtypes:
# This will happen if the fill value is not something that NumPy
Expand All @@ -266,10 +285,12 @@ def full_like(
"""
from ._array_object import Array

_check_valid_dtype(dtype)
_check_device(device)
if device is None:
device = x.device
if dtype is None:
dtype = x.dtype
_check_valid_dtype(dtype, device)

if not isinstance(fill_value, bool | int | float | complex):
msg = f"Expected Python scalar fill_value, got type {type(fill_value)}"
Expand Down Expand Up @@ -300,8 +321,13 @@ def linspace(
"""
from ._array_object import Array

_check_valid_dtype(dtype)
_check_device(device)
_check_valid_dtype(dtype, device)
if dtype is None:
if isinstance(start, complex) or isinstance(stop, complex):
dtype = get_default_dtypes(device)["complex floating"]
else:
dtype = get_default_dtypes(device)["real floating"]

return Array._new(
np.linspace(start, stop, num, dtype=_np_dtype(dtype), endpoint=endpoint),
Expand Down Expand Up @@ -353,8 +379,10 @@ def ones(
"""
from ._array_object import Array

_check_valid_dtype(dtype)
_check_device(device)
_check_valid_dtype(dtype, device)
if dtype is None:
dtype = get_default_dtypes(device)["real floating"]

return Array._new(np.ones(shape, dtype=_np_dtype(dtype)), device=device)

Expand All @@ -369,10 +397,12 @@ def ones_like(
"""
from ._array_object import Array

_check_valid_dtype(dtype)
_check_device(device)
if device is None:
device = x.device
if dtype is None:
dtype = x.dtype
_check_valid_dtype(dtype, device)

return Array._new(np.ones_like(x._array, dtype=_np_dtype(dtype)), device=device)

Expand Down Expand Up @@ -418,8 +448,10 @@ def zeros(
"""
from ._array_object import Array

_check_valid_dtype(dtype)
_check_device(device)
_check_valid_dtype(dtype, device)
if dtype is None:
dtype = get_default_dtypes(device)["real floating"]

return Array._new(np.zeros(shape, dtype=_np_dtype(dtype)), device=device)

Expand All @@ -434,9 +466,11 @@ def zeros_like(
"""
from ._array_object import Array

_check_valid_dtype(dtype)
_check_device(device)
if device is None:
device = x.device
if dtype is None:
dtype = x.dtype
_check_valid_dtype(dtype, device)

return Array._new(np.zeros_like(x._array, dtype=_np_dtype(dtype)), device=device)
101 changes: 101 additions & 0 deletions array_api_strict/_devices.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from typing import Final

from ._dtypes import (
DType, float32, float64, complex64, complex128, int64,
_all_dtypes, _boolean_dtypes, _signed_integer_dtypes,
_unsigned_integer_dtypes, _integer_dtypes, _real_floating_dtypes,
_complex_floating_dtypes, _numeric_dtypes
)

_ALL_DEVICE_NAMES = ("CPU_DEVICE", "device1", "device2", "F32_device")

class Device:
_device: Final[str]
__slots__ = ("_device", "__weakref__")

def __init__(self, device: str = "CPU_DEVICE"):
if device not in _ALL_DEVICE_NAMES:
raise ValueError(f"The device '{device}' is not a valid choice.")
self._device = device

def __repr__(self) -> str:
return f"array_api_strict.Device('{self._device}')"

def __eq__(self, other: object) -> bool:
if not isinstance(other, Device):
return False
return self._device == other._device

def __hash__(self) -> int:
return hash(("Device", self._device))

def _supported_dtypes(self) -> list[DType]:
# XXX useful? Unused ATM
return list(dt for dt in _all_dtypes if device_supports_dtype(self, dt))


CPU_DEVICE = Device()
_F32_DEVICE = Device("F32_device")

ALL_DEVICES = (CPU_DEVICE, Device("device1"), Device("device2"), _F32_DEVICE)


def check_device(device: Device | None) -> None:
if device is not None and not isinstance(device, Device):
raise ValueError(f"Unsupported device {device!r}")

if device is not None and device not in ALL_DEVICES:
raise ValueError(f"Unsupported device {device!r}")


# Helpers for device-specific dtype support

def get_default_dtypes(device: Device | None = None) -> dict[str, Device]:
if device == _F32_DEVICE:
return {
"real floating": float32,
"complex floating": complex64,
"integral": int64,
"indexing": int64,
}
else:
return {
"real floating": float64,
"complex floating": complex128,
"integral": int64,
"indexing": int64,
}


def device_supports_dtype(device: Device | None, dtype: DType |None) -> bool:
"""True if `device` supports `dtype`, False otherwise."""
# special-case F32_device
if device == _F32_DEVICE:
return dtype not in (float64, complex128)

# All other devices support all dtypes
return True


def _map_supported(dtypes: list[DType], device: Device) -> dict[str, DType]:
return {
dt._canonic_name: dt
for dt in dtypes
if device_supports_dtype(device, dt)
}


# _info.dtypes() maps "kind" -> dict of {name: dtype}
# Note that "kinds" differ from "categories" above, per the spec.

_kind_to_dtypes = {
None: _all_dtypes,
"bool": _boolean_dtypes,
"signed integer": _signed_integer_dtypes,
"unsigned integer": _unsigned_integer_dtypes,
"integral": _integer_dtypes,
"real floating": _real_floating_dtypes,
"complex floating": _complex_floating_dtypes,
"numeric": _numeric_dtypes
}

4 changes: 3 additions & 1 deletion array_api_strict/_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@

class DType:
_np_dtype: Final[np.dtype[Any]]
__slots__ = ("_np_dtype", "__weakref__")
_canonic_name: Final[Any]
__slots__ = ("_np_dtype", "_canonic_name", "__weakref__")

def __init__(self, np_dtype: npt.DTypeLike):
self._canonic_name = np_dtype
self._np_dtype = np.dtype(np_dtype)

def __repr__(self) -> str:
Expand Down
Loading
Loading