From a415c1e8c5ab12040450c4f8ba38aada408346ec Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 23 Apr 2026 13:08:53 +0200 Subject: [PATCH 01/12] BUG: _info.dtypes(kind=tuple) does not drop the device argument --- array_api_strict/_info.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_strict/_info.py b/array_api_strict/_info.py index 12beed0..cf3e21a 100644 --- a/array_api_strict/_info.py +++ b/array_api_strict/_info.py @@ -125,7 +125,7 @@ def dtypes( if isinstance(kind, tuple): res: DataTypes = {} for k in kind: - res.update(self.dtypes(kind=k)) + res.update(self.dtypes(kind=k, device=device)) return res raise ValueError(f"unsupported kind: {kind!r}") From ef2bd5ea21000342ba6284117aaf2ceb617e2197 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 23 Apr 2026 12:52:58 +0200 Subject: [PATCH 02/12] MAINT: refactor _info.dtypes(), make check_dtype device-aware --- array_api_strict/_array_object.py | 28 +----- array_api_strict/_creation_functions.py | 53 ++++++------ array_api_strict/_devices.py | 85 +++++++++++++++++++ array_api_strict/_dtypes.py | 4 +- array_api_strict/_info.py | 84 +++--------------- array_api_strict/tests/test_device_support.py | 31 +++++-- 6 files changed, 152 insertions(+), 133 deletions(-) create mode 100644 array_api_strict/_devices.py diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 629af98..dcb20e9 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -20,7 +20,7 @@ from collections.abc import Iterator from enum import IntEnum from types import EllipsisType, ModuleType -from typing import Any, Final, Literal, SupportsIndex, Callable +from typing import Any, Literal, SupportsIndex, Callable import numpy as np import numpy.typing as npt @@ -40,35 +40,11 @@ _real_to_complex_map, _result_type, ) +from ._devices import CPU_DEVICE, ALL_DEVICES, Device from ._flags import get_array_api_strict_flags, set_array_api_strict_flags from ._typing import PyCapsule -class Device: - _device: Final[str] - __slots__ = ("_device", "__weakref__") - - def __init__(self, device: str = "CPU_DEVICE"): - if device not in ("CPU_DEVICE", "device1", "device2"): - raise ValueError(f"The device '{device}' is not a valid choice.") - self._device = device - - def __repr__(self) -> str: - return f"array_api_strict.Device('{self._device}')" - - def __eq__(self, other: object) -> bool: - if not isinstance(other, Device): - return False - return self._device == other._device - - def __hash__(self) -> int: - return hash(("Device", self._device)) - - -CPU_DEVICE = Device() -ALL_DEVICES = (CPU_DEVICE, Device("device1"), Device("device2")) - - class Array: """ n-d array object for the array API namespace. diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index 8d3dc60..4afa411 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -6,6 +6,7 @@ import numpy as np from ._dtypes import DType, _all_dtypes, _np_dtype +from ._devices import CPU_DEVICE, Device, device_supports_dtype, check_device as _check_device from ._flags import get_array_api_strict_flags from ._typing import NestedSequence, SupportsBufferProtocol, SupportsDLPack @@ -14,7 +15,7 @@ from typing_extensions import TypeIs # Circular import - from ._array_object import Array, Device + from ._array_object import Array class Undef(Enum): @@ -24,10 +25,15 @@ class Undef(Enum): _undef = Undef.UNDEF -def _check_valid_dtype(dtype: DType | None) -> None: +def _check_valid_dtype(dtype: DType | None, device: Device | None = None) -> None: # Note: Only spelling dtypes as the dtype objects is supported. - if dtype not in (None,) + _all_dtypes: - raise ValueError(f"dtype must be one of the supported dtypes, got {dtype!r}") + if dtype is not None: + if dtype not in _all_dtypes: + raise ValueError(f"dtype must be one of the supported dtypes, got {dtype!r}") + + if device is not None: + if not device_supports_dtype(device, dtype): + raise ValueError(f"Device {device!r} does not support dtype={dtype!r}.") def _supports_buffer_protocol(obj: object) -> TypeIs[SupportsBufferProtocol]: @@ -38,18 +44,6 @@ def _supports_buffer_protocol(obj: object) -> TypeIs[SupportsBufferProtocol]: return True -def _check_device(device: Device | None) -> None: - # _array_object imports in this file are inside the functions to avoid - # circular imports - from ._array_object import ALL_DEVICES, Device - - if device is not None and not isinstance(device, Device): - raise ValueError(f"Unsupported device {device!r}") - - if device is not None and device not in ALL_DEVICES: - raise ValueError(f"Unsupported device {device!r}") - - def asarray( obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol, /, @@ -65,11 +59,12 @@ def asarray( """ from ._array_object import Array - _check_valid_dtype(dtype) + _check_device(device) + _check_valid_dtype(dtype, device) _np_dtype = None if dtype is not None: _np_dtype = dtype._np_dtype - _check_device(device) + if isinstance(obj, Array) and device is None: device = obj.device @@ -127,8 +122,8 @@ def arange( """ from ._array_object import Array - _check_valid_dtype(dtype) _check_device(device) + _check_valid_dtype(dtype, device) return Array._new( np.arange(start, stop, step, dtype=_np_dtype(dtype)), @@ -149,8 +144,8 @@ def empty( """ from ._array_object import Array - _check_valid_dtype(dtype) _check_device(device) + _check_valid_dtype(dtype, device) return Array._new(np.empty(shape, dtype=_np_dtype(dtype)), device=device) @@ -165,10 +160,10 @@ def empty_like( """ from ._array_object import Array - _check_valid_dtype(dtype) _check_device(device) if device is None: device = x.device + _check_valid_dtype(dtype, device) return Array._new(np.empty_like(x._array, dtype=_np_dtype(dtype)), device=device) @@ -189,8 +184,8 @@ def eye( """ from ._array_object import Array - _check_valid_dtype(dtype) _check_device(device) + _check_valid_dtype(dtype, device) return Array._new( np.eye(n_rows, M=n_cols, k=k, dtype=_np_dtype(dtype)), device=device @@ -237,8 +232,8 @@ def full( """ from ._array_object import Array - _check_valid_dtype(dtype) _check_device(device) + _check_valid_dtype(dtype, device) if not isinstance(fill_value, bool | int | float | complex): msg = f"Expected Python scalar fill_value, got type {type(fill_value)}" @@ -266,10 +261,10 @@ def full_like( """ from ._array_object import Array - _check_valid_dtype(dtype) _check_device(device) if device is None: device = x.device + _check_valid_dtype(dtype, device) if not isinstance(fill_value, bool | int | float | complex): msg = f"Expected Python scalar fill_value, got type {type(fill_value)}" @@ -300,8 +295,8 @@ def linspace( """ from ._array_object import Array - _check_valid_dtype(dtype) _check_device(device) + _check_valid_dtype(dtype, device) return Array._new( np.linspace(start, stop, num, dtype=_np_dtype(dtype), endpoint=endpoint), @@ -353,8 +348,8 @@ def ones( """ from ._array_object import Array - _check_valid_dtype(dtype) _check_device(device) + _check_valid_dtype(dtype, device) return Array._new(np.ones(shape, dtype=_np_dtype(dtype)), device=device) @@ -369,10 +364,10 @@ def ones_like( """ from ._array_object import Array - _check_valid_dtype(dtype) _check_device(device) if device is None: device = x.device + _check_valid_dtype(dtype, device) return Array._new(np.ones_like(x._array, dtype=_np_dtype(dtype)), device=device) @@ -418,8 +413,8 @@ def zeros( """ from ._array_object import Array - _check_valid_dtype(dtype) _check_device(device) + _check_valid_dtype(dtype, device) return Array._new(np.zeros(shape, dtype=_np_dtype(dtype)), device=device) @@ -434,9 +429,9 @@ def zeros_like( """ from ._array_object import Array - _check_valid_dtype(dtype) _check_device(device) if device is None: device = x.device + _check_valid_dtype(dtype, device) return Array._new(np.zeros_like(x._array, dtype=_np_dtype(dtype)), device=device) diff --git a/array_api_strict/_devices.py b/array_api_strict/_devices.py new file mode 100644 index 0000000..eb03d92 --- /dev/null +++ b/array_api_strict/_devices.py @@ -0,0 +1,85 @@ +from typing import Final + +from ._dtypes import DType, float64, complex128 +from ._dtypes import ( + _all_dtypes, _boolean_dtypes, _signed_integer_dtypes, + _unsigned_integer_dtypes, _integer_dtypes, _real_floating_dtypes, + _complex_floating_dtypes, _numeric_dtypes +) + +_ALL_DEVICE_NAMES = ("CPU_DEVICE", "device1", "device2", "F32_device") + +class Device: + _device: Final[str] + __slots__ = ("_device", "__weakref__") + + def __init__(self, device: str = "CPU_DEVICE"): + if device not in _ALL_DEVICE_NAMES: + raise ValueError(f"The device '{device}' is not a valid choice.") + self._device = device + + def __repr__(self) -> str: + return f"array_api_strict.Device('{self._device}')" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Device): + return False + return self._device == other._device + + def __hash__(self) -> int: + return hash(("Device", self._device)) + + def _supported_dtypes(self) -> list[DType]: + # XXX useful? Unused ATM + return list(dt for dt in _all_dtypes if device_supports_dtype(self, dt)) + + +CPU_DEVICE = Device() +_F32_DEVICE = Device("F32_device") + +ALL_DEVICES = (CPU_DEVICE, Device("device1"), Device("device2"), _F32_DEVICE) + + +def check_device(device: Device | None) -> None: + if device is not None and not isinstance(device, Device): + raise ValueError(f"Unsupported device {device!r}") + + if device is not None and device not in ALL_DEVICES: + raise ValueError(f"Unsupported device {device!r}") + + +# Helpers for device-specific dtype support + + +def device_supports_dtype(device: Device | None, dtype: DType |None) -> bool: + """True if `device` supports `dtype`, False otherwise.""" + # special-case F32_device + if device == _F32_DEVICE: + return dtype not in (float64, complex128) + + # All other devices support all dtypes + return True + + +def _map_supported(dtypes: list[DType], device: Device) -> dict[str, DType]: + return { + dt._canonic_name: dt + for dt in dtypes + if device_supports_dtype(device, dt) + } + + +# _info.dtypes() maps "kind" -> dict of {name: dtype} +# Note that "kinds" differ from "categories" above, per the spec. + +_kind_to_dtypes = { + None: _all_dtypes, + "bool": _boolean_dtypes, + "signed integer": _signed_integer_dtypes, + "unsigned integer": _unsigned_integer_dtypes, + "integral": _integer_dtypes, + "real floating": _real_floating_dtypes, + "complex floating": _complex_floating_dtypes, + "numeric": _numeric_dtypes +} + diff --git a/array_api_strict/_dtypes.py b/array_api_strict/_dtypes.py index 564db5a..2d5ec33 100644 --- a/array_api_strict/_dtypes.py +++ b/array_api_strict/_dtypes.py @@ -11,9 +11,11 @@ class DType: _np_dtype: Final[np.dtype[Any]] - __slots__ = ("_np_dtype", "__weakref__") + _canonic_name: Final[Any] + __slots__ = ("_np_dtype", "_canonic_name", "__weakref__") def __init__(self, np_dtype: npt.DTypeLike): + self._canonic_name = np_dtype self._np_dtype = np.dtype(np_dtype) def __repr__(self) -> str: diff --git a/array_api_strict/_info.py b/array_api_strict/_info.py index cf3e21a..fa25ad7 100644 --- a/array_api_strict/_info.py +++ b/array_api_strict/_info.py @@ -1,6 +1,7 @@ import numpy as np from . import _dtypes as dt +from . import _devices from ._array_object import ALL_DEVICES, CPU_DEVICE, Device from ._flags import get_array_api_strict_flags, requires_api_version from ._typing import Capabilities, DataTypes, DefaultDataTypes @@ -54,78 +55,21 @@ def dtypes( device: Device | None = None, kind: str | tuple[str, ...] | None = None, ) -> DataTypes: - if kind is None: - return { - "bool": dt.bool, - "int8": dt.int8, - "int16": dt.int16, - "int32": dt.int32, - "int64": dt.int64, - "uint8": dt.uint8, - "uint16": dt.uint16, - "uint32": dt.uint32, - "uint64": dt.uint64, - "float32": dt.float32, - "float64": dt.float64, - "complex64": dt.complex64, - "complex128": dt.complex128, - } - if kind == "bool": - return {"bool": dt.bool} - if kind == "signed integer": - return { - "int8": dt.int8, - "int16": dt.int16, - "int32": dt.int32, - "int64": dt.int64, - } - if kind == "unsigned integer": - return { - "uint8": dt.uint8, - "uint16": dt.uint16, - "uint32": dt.uint32, - "uint64": dt.uint64, - } - if kind == "integral": - return { - "int8": dt.int8, - "int16": dt.int16, - "int32": dt.int32, - "int64": dt.int64, - "uint8": dt.uint8, - "uint16": dt.uint16, - "uint32": dt.uint32, - "uint64": dt.uint64, - } - if kind == "real floating": - return { - "float32": dt.float32, - "float64": dt.float64, - } - if kind == "complex floating": - return { - "complex64": dt.complex64, - "complex128": dt.complex128, - } - if kind == "numeric": - return { - "int8": dt.int8, - "int16": dt.int16, - "int32": dt.int32, - "int64": dt.int64, - "uint8": dt.uint8, - "uint16": dt.uint16, - "uint32": dt.uint32, - "uint64": dt.uint64, - "float32": dt.float32, - "float64": dt.float64, - "complex64": dt.complex64, - "complex128": dt.complex128, - } - if isinstance(kind, tuple): + if device is None: + device = CPU_DEVICE + if isinstance(kind, type(None) | str): + + try: + dtypes = _devices._kind_to_dtypes[kind] + except KeyError: + raise ValueError(f"unsupported kind: {kind!r}") + res = _devices._map_supported(dtypes, device) + return res + + elif isinstance(kind, tuple): res: DataTypes = {} for k in kind: - res.update(self.dtypes(kind=k, device=device)) + res.update(self.dtypes(kind=kind, device=device)) return res raise ValueError(f"unsupported kind: {kind!r}") diff --git a/array_api_strict/tests/test_device_support.py b/array_api_strict/tests/test_device_support.py index 0f3d6b5..8726266 100644 --- a/array_api_strict/tests/test_device_support.py +++ b/array_api_strict/tests/test_device_support.py @@ -1,6 +1,6 @@ import pytest -import array_api_strict +import array_api_strict as xp @pytest.mark.parametrize( @@ -18,11 +18,11 @@ ), ) def test_fft_device_support_complex(func_name): - func = getattr(array_api_strict.fft, func_name) - x = array_api_strict.asarray( + func = getattr(xp.fft, func_name) + x = xp.asarray( [1, 2.0], - dtype=array_api_strict.complex64, - device=array_api_strict.Device("device1"), + dtype=xp.complex64, + device=xp.Device("device1"), ) y = func(x) @@ -31,8 +31,25 @@ def test_fft_device_support_complex(func_name): @pytest.mark.parametrize("func_name", ("rfft", "rfftn", "ihfft")) def test_fft_device_support_real(func_name): - func = getattr(array_api_strict.fft, func_name) - x = array_api_strict.asarray([1, 2.0], device=array_api_strict.Device("device1")) + func = getattr(xp.fft, func_name) + x = xp.asarray([1, 2.0], device=xp.Device("device1")) y = func(x) assert x.device == y.device + + +class TestF32Device: + @pytest.mark.parametrize("dtype_str", ["float64", "complex128"]) + def test_f64_raises(self, dtype_str): + f32_device = xp.Device("F32_device") + dtype = getattr(xp, dtype_str) + with pytest.raises(ValueError): + xp.arange(3, device=f32_device, dtype=dtype) + + def test_info_no_f64(self): + f32_device = xp.Device("F32_device") + + info = xp.__array_namespace_info__() + all_dtypes = info.dtypes(device=f32_device) + assert "float64" not in all_dtypes + assert "complex128" not in all_dtypes From 8be2e9187f801f666f4b91e5513bf165a3202375 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 24 Apr 2026 11:10:03 +0200 Subject: [PATCH 03/12] ENH: info.default_dtypes: make device-aware --- array_api_strict/_devices.py | 18 +++++++++++++++++- array_api_strict/_info.py | 10 ++-------- array_api_strict/tests/test_device_support.py | 13 +++++++++++++ 3 files changed, 32 insertions(+), 9 deletions(-) diff --git a/array_api_strict/_devices.py b/array_api_strict/_devices.py index eb03d92..cd76eae 100644 --- a/array_api_strict/_devices.py +++ b/array_api_strict/_devices.py @@ -1,7 +1,7 @@ from typing import Final -from ._dtypes import DType, float64, complex128 from ._dtypes import ( + DType, float32, float64, complex64, complex128, int64, _all_dtypes, _boolean_dtypes, _signed_integer_dtypes, _unsigned_integer_dtypes, _integer_dtypes, _real_floating_dtypes, _complex_floating_dtypes, _numeric_dtypes @@ -50,6 +50,22 @@ def check_device(device: Device | None) -> None: # Helpers for device-specific dtype support +def get_default_dtypes(device: Device | None = None) -> dict[str, Device]: + if device == _F32_DEVICE: + return { + "real floating": float32, + "complex floating": complex64, + "integral": int64, + "indexing": int64, + } + else: + return { + "real floating": float64, + "complex floating": complex128, + "integral": int64, + "indexing": int64, + } + def device_supports_dtype(device: Device | None, dtype: DType |None) -> bool: """True if `device` supports `dtype`, False otherwise.""" diff --git a/array_api_strict/_info.py b/array_api_strict/_info.py index fa25ad7..d0d5ff6 100644 --- a/array_api_strict/_info.py +++ b/array_api_strict/_info.py @@ -1,8 +1,7 @@ import numpy as np -from . import _dtypes as dt from . import _devices -from ._array_object import ALL_DEVICES, CPU_DEVICE, Device +from ._devices import ALL_DEVICES, CPU_DEVICE, Device from ._flags import get_array_api_strict_flags, requires_api_version from ._typing import Capabilities, DataTypes, DefaultDataTypes @@ -41,12 +40,7 @@ def default_dtypes( *, device: Device | None = None, ) -> DefaultDataTypes: - return { - "real floating": dt.float64, - "complex floating": dt.complex128, - "integral": dt.int64, - "indexing": dt.int64, - } + return _devices.get_default_dtypes(device) @requires_api_version('2023.12') def dtypes( diff --git a/array_api_strict/tests/test_device_support.py b/array_api_strict/tests/test_device_support.py index 8726266..242151a 100644 --- a/array_api_strict/tests/test_device_support.py +++ b/array_api_strict/tests/test_device_support.py @@ -53,3 +53,16 @@ def test_info_no_f64(self): all_dtypes = info.dtypes(device=f32_device) assert "float64" not in all_dtypes assert "complex128" not in all_dtypes + + def test_info_default_dtypes(self): + f32_device = xp.Device("F32_device") + info = xp.__array_namespace_info__() + defaults = info.default_dtypes(device=f32_device) + assert defaults["real floating"] == xp.float32 + assert defaults["complex floating"] == xp.complex64 + + cpu_device = xp.Device() + info = xp.__array_namespace_info__() + defaults = info.default_dtypes(device=cpu_device) + assert defaults["real floating"] == xp.float64 + assert defaults["complex floating"] == xp.complex128 From 782689e3e9331d0115f31ffa1c89776c8695998d Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 24 Apr 2026 11:26:22 +0200 Subject: [PATCH 04/12] TST: use per-device dtypes in test_elementwise_functions --- array_api_strict/tests/test_elementwise_functions.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index 050f2bc..5cc65d8 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -21,6 +21,7 @@ int64, uint64, ) +from .._info import __array_namespace_info__ from .test_array_object import _check_op_array_scalar, BIG_INT import array_api_strict @@ -144,6 +145,10 @@ def _array_vals(dtypes): yield asarray(1., dtype=dtype, device=device) dtypes = _dtype_categories[types] + + supported_dtypes = __array_namespace_info__().dtypes(device=device) + dtypes = [dt for dt in dtypes if dt in supported_dtypes] + func = getattr(_elementwise_functions, func_name) for x in _array_vals(dtypes): From d48f41122e0a6754263fdb1e4a26e811a733f410 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 24 Apr 2026 12:27:17 +0200 Subject: [PATCH 05/12] ENH: device-dependent default dtypes in creation_functions --- array_api_strict/_creation_functions.py | 13 ++++- .../tests/test_creation_functions.py | 56 ++++++++++++++++++- 2 files changed, 67 insertions(+), 2 deletions(-) diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index 4afa411..196738b 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -6,7 +6,10 @@ import numpy as np from ._dtypes import DType, _all_dtypes, _np_dtype -from ._devices import CPU_DEVICE, Device, device_supports_dtype, check_device as _check_device +from ._devices import ( + CPU_DEVICE, Device, device_supports_dtype, get_default_dtypes, + check_device as _check_device +) from ._flags import get_array_api_strict_flags from ._typing import NestedSequence, SupportsBufferProtocol, SupportsDLPack @@ -146,6 +149,8 @@ def empty( _check_device(device) _check_valid_dtype(dtype, device) + if dtype is None: + dtype = get_default_dtypes(device)["real floating"] return Array._new(np.empty(shape, dtype=_np_dtype(dtype)), device=device) @@ -234,6 +239,8 @@ def full( _check_device(device) _check_valid_dtype(dtype, device) + if dtype is None: + dtype = get_default_dtypes(device)["real floating"] if not isinstance(fill_value, bool | int | float | complex): msg = f"Expected Python scalar fill_value, got type {type(fill_value)}" @@ -350,6 +357,8 @@ def ones( _check_device(device) _check_valid_dtype(dtype, device) + if dtype is None: + dtype = get_default_dtypes(device)["real floating"] return Array._new(np.ones(shape, dtype=_np_dtype(dtype)), device=device) @@ -415,6 +424,8 @@ def zeros( _check_device(device) _check_valid_dtype(dtype, device) + if dtype is None: + dtype = get_default_dtypes(device)["real floating"] return Array._new(np.zeros(shape, dtype=_np_dtype(dtype)), device=device) diff --git a/array_api_strict/tests/test_creation_functions.py b/array_api_strict/tests/test_creation_functions.py index 6736826..67e0cd8 100644 --- a/array_api_strict/tests/test_creation_functions.py +++ b/array_api_strict/tests/test_creation_functions.py @@ -23,7 +23,9 @@ zeros_like, ) from .._dtypes import float32, float64 -from .._array_object import Array, CPU_DEVICE, Device +from .._array_object import Array +from .._devices import CPU_DEVICE, ALL_DEVICES, Device +from .._info import __array_namespace_info__ from .._flags import set_array_api_strict_flags def test_asarray_errors(): @@ -212,6 +214,7 @@ def test_zeros_like_errors(): assert_raises(ValueError, lambda: zeros_like(asarray(1), dtype=int)) assert_raises(ValueError, lambda: zeros_like(asarray(1), dtype="i")) + def test_meshgrid_dtype_errors(): # Doesn't raise meshgrid() @@ -221,6 +224,57 @@ def test_meshgrid_dtype_errors(): assert_raises(ValueError, lambda: meshgrid(asarray([1.], dtype=float32), asarray([1.], dtype=float64))) + +def _full(a, *args, **kwds): + return full(a, fill_value=42, *args, **kwds) + + +def _full_like(a, *args, **kwds): + return full_like(a, fill_value=42, *args, **kwds) + + +class TestDefaultDType: + + info = __array_namespace_info__() + + @pytest.mark.parametrize("device", ALL_DEVICES) + @pytest.mark.parametrize("func", [empty, zeros, ones, _full]) + def test_ones_etc(self, func, device): + a = func(1, device=device) + assert a.dtype == self.info.default_dtypes(device=device)["real floating"] + + @pytest.mark.parametrize("func", [empty_like, zeros_like, ones_like, _full_like]) + def test_ones_like_etc_correct(self, func): + # float32 is preserved + a = ones(2, dtype=float32) + device = Device('F32_device') + b = func(a, device=device) + assert b.dtype == self.info.default_dtypes(device=device)["real floating"] + + @pytest.mark.parametrize("func", [empty_like, zeros_like, ones_like, _full_like]) + def test_ones_like_etc_incorrect(self, func): + a = ones(2) + assert a.dtype == float64 + assert a.device == Device() + + # XXX: a.dtype not supported by the device: ValueError or TypeError? + + # >>> a = torch.ones(3, dtype=torch.float64, device='cpu') + # >>> torch.ones_like(a, device='mps') + # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework + # doesn't support float64. + with pytest.raises(TypeError): + func(a, device=Device('F32_device')) +# TODO: +# def asarray( +# def arange( +# def eye( +# def linspace( +# def meshgrid(*arrays: Array, indexing: Literal["xy", "ij"] = "xy") -> tuple[Array, ...]: +# def tril(x: Array, /, *, k: int = 0) -> Array: +# def triu(x: Array, /, *, k: int = 0) -> Array: + + @pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12']) def from_dlpack_2023_12(api_version): if api_version != '2022.12': From a35e4085c25357d41863075815fe62b84c7d3921 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 24 Apr 2026 21:59:18 +0200 Subject: [PATCH 06/12] BUG: fix full() output dtype (depends on fill_value type) --- array_api_strict/_creation_functions.py | 14 +++++++++++--- array_api_strict/tests/test_creation_functions.py | 4 ++-- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index 196738b..89fed1c 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -5,7 +5,7 @@ import numpy as np -from ._dtypes import DType, _all_dtypes, _np_dtype +from ._dtypes import DType, _all_dtypes, _np_dtype, bool as xp_bool from ._devices import ( CPU_DEVICE, Device, device_supports_dtype, get_default_dtypes, check_device as _check_device @@ -239,12 +239,20 @@ def full( _check_device(device) _check_valid_dtype(dtype, device) - if dtype is None: - dtype = get_default_dtypes(device)["real floating"] if not isinstance(fill_value, bool | int | float | complex): msg = f"Expected Python scalar fill_value, got type {type(fill_value)}" raise TypeError(msg) + + if dtype is None: + if type(fill_value) == bool: + dtype = xp_bool + else: + kind = { + int: "integral", float: "real floating", complex: "complex floating" + }[type(fill_value)] + dtype = get_default_dtypes(device)[kind] + res = np.full(shape, fill_value, dtype=_np_dtype(dtype)) if DType(res.dtype) not in _all_dtypes: # This will happen if the fill value is not something that NumPy diff --git a/array_api_strict/tests/test_creation_functions.py b/array_api_strict/tests/test_creation_functions.py index 67e0cd8..166d51b 100644 --- a/array_api_strict/tests/test_creation_functions.py +++ b/array_api_strict/tests/test_creation_functions.py @@ -226,11 +226,11 @@ def test_meshgrid_dtype_errors(): def _full(a, *args, **kwds): - return full(a, fill_value=42, *args, **kwds) + return full(a, fill_value=42.0, *args, **kwds) def _full_like(a, *args, **kwds): - return full_like(a, fill_value=42, *args, **kwds) + return full_like(a, fill_value=42.0, *args, **kwds) class TestDefaultDType: From 0fdb4f8c0173ce02f4f96fe262d9fe103f4fc180 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 24 Apr 2026 22:01:52 +0200 Subject: [PATCH 07/12] BUG: fix info.dtypes(kind=tuple) --- array_api_strict/_info.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_strict/_info.py b/array_api_strict/_info.py index d0d5ff6..cbee036 100644 --- a/array_api_strict/_info.py +++ b/array_api_strict/_info.py @@ -63,7 +63,7 @@ def dtypes( elif isinstance(kind, tuple): res: DataTypes = {} for k in kind: - res.update(self.dtypes(kind=kind, device=device)) + res.update(self.dtypes(kind=k, device=device)) return res raise ValueError(f"unsupported kind: {kind!r}") From 6ba0f5b9ce2a599db5de35b83ed606019e17db42 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 24 Apr 2026 22:23:44 +0200 Subject: [PATCH 08/12] MAINT: appease ruff --- array_api_strict/_array_object.py | 2 +- array_api_strict/_creation_functions.py | 2 +- array_api_strict/_fft.py | 3 ++- array_api_strict/tests/test_elementwise_functions.py | 2 +- array_api_strict/tests/test_searching_functions.py | 3 ++- 5 files changed, 7 insertions(+), 5 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index dcb20e9..5c933ef 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -40,7 +40,7 @@ _real_to_complex_map, _result_type, ) -from ._devices import CPU_DEVICE, ALL_DEVICES, Device +from ._devices import CPU_DEVICE, Device from ._flags import get_array_api_strict_flags, set_array_api_strict_flags from ._typing import PyCapsule diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index 89fed1c..885d005 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -7,7 +7,7 @@ from ._dtypes import DType, _all_dtypes, _np_dtype, bool as xp_bool from ._devices import ( - CPU_DEVICE, Device, device_supports_dtype, get_default_dtypes, + Device, device_supports_dtype, get_default_dtypes, check_device as _check_device ) from ._flags import get_array_api_strict_flags diff --git a/array_api_strict/_fft.py b/array_api_strict/_fft.py index c2c617e..9f0dfcf 100644 --- a/array_api_strict/_fft.py +++ b/array_api_strict/_fft.py @@ -3,7 +3,8 @@ import numpy as np -from ._array_object import ALL_DEVICES, Array, Device +from ._array_object import Array +from ._devices import ALL_DEVICES, Device from ._data_type_functions import astype from ._dtypes import ( DType, diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index 5cc65d8..7fb6e33 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -6,7 +6,7 @@ from .. import asarray, _elementwise_functions -from .._array_object import ALL_DEVICES, CPU_DEVICE, Device +from .._devices import ALL_DEVICES, CPU_DEVICE, Device from .._elementwise_functions import bitwise_left_shift, bitwise_right_shift from .._dtypes import ( _dtype_categories, diff --git a/array_api_strict/tests/test_searching_functions.py b/array_api_strict/tests/test_searching_functions.py index abe1949..18775ed 100644 --- a/array_api_strict/tests/test_searching_functions.py +++ b/array_api_strict/tests/test_searching_functions.py @@ -3,10 +3,11 @@ import array_api_strict as xp from array_api_strict import ArrayAPIStrictFlags -from .._array_object import ALL_DEVICES, CPU_DEVICE, Device +from .._devices import ALL_DEVICES, CPU_DEVICE, Device from .._dtypes import _all_dtypes + def test_where_with_scalars(): x = xp.asarray([1, 2, 3, 1]) From 6d72df1c69b1382011997cd77f4eaee5f91a680b Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 26 Apr 2026 20:34:24 +0200 Subject: [PATCH 09/12] BUG: make _like functions raise if array arg dtype is incompatible with the device --- array_api_strict/_creation_functions.py | 8 ++++++++ array_api_strict/tests/test_creation_functions.py | 11 ++++++++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index 885d005..5170fd7 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -168,6 +168,8 @@ def empty_like( _check_device(device) if device is None: device = x.device + if dtype is None: + dtype = x.dtype _check_valid_dtype(dtype, device) return Array._new(np.empty_like(x._array, dtype=_np_dtype(dtype)), device=device) @@ -279,6 +281,8 @@ def full_like( _check_device(device) if device is None: device = x.device + if dtype is None: + dtype = x.dtype _check_valid_dtype(dtype, device) if not isinstance(fill_value, bool | int | float | complex): @@ -384,6 +388,8 @@ def ones_like( _check_device(device) if device is None: device = x.device + if dtype is None: + dtype = x.dtype _check_valid_dtype(dtype, device) return Array._new(np.ones_like(x._array, dtype=_np_dtype(dtype)), device=device) @@ -451,6 +457,8 @@ def zeros_like( _check_device(device) if device is None: device = x.device + if dtype is None: + dtype = x.dtype _check_valid_dtype(dtype, device) return Array._new(np.zeros_like(x._array, dtype=_np_dtype(dtype)), device=device) diff --git a/array_api_strict/tests/test_creation_functions.py b/array_api_strict/tests/test_creation_functions.py index 166d51b..4226018 100644 --- a/array_api_strict/tests/test_creation_functions.py +++ b/array_api_strict/tests/test_creation_functions.py @@ -263,8 +263,17 @@ def test_ones_like_etc_incorrect(self, func): # >>> torch.ones_like(a, device='mps') # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework # doesn't support float64. - with pytest.raises(TypeError): + + # incompatible dtype inferred from `a.dtype` + with pytest.raises((TypeError, ValueError)): func(a, device=Device('F32_device')) + + # `a.dtype` is compatible but the explicit dtype= argument is incompatible + a = ones(2, dtype=float32) + with pytest.raises((TypeError, ValueError)): + func(a, device=Device('F32_device'), dtype=float64) + + # TODO: # def asarray( # def arange( From 917312e89cce2b31c718a691bb569a5b95d2905c Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 26 Apr 2026 20:38:42 +0200 Subject: [PATCH 10/12] ENH: eye default dtype --- array_api_strict/_creation_functions.py | 2 ++ array_api_strict/tests/test_creation_functions.py | 13 +++++++++---- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index 5170fd7..5408499 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -193,6 +193,8 @@ def eye( _check_device(device) _check_valid_dtype(dtype, device) + if dtype is None: + dtype = get_default_dtypes(device)["real floating"] return Array._new( np.eye(n_rows, M=n_cols, k=k, dtype=_np_dtype(dtype)), device=device diff --git a/array_api_strict/tests/test_creation_functions.py b/array_api_strict/tests/test_creation_functions.py index 4226018..89ec655 100644 --- a/array_api_strict/tests/test_creation_functions.py +++ b/array_api_strict/tests/test_creation_functions.py @@ -273,15 +273,20 @@ def test_ones_like_etc_incorrect(self, func): with pytest.raises((TypeError, ValueError)): func(a, device=Device('F32_device'), dtype=float64) + def test_eye(self): + device = Device('F32_device') + a = eye(3, device=device) + assert a.dtype == self.info.default_dtypes(device=device)["real floating"] + + with pytest.raises((TypeError, ValueError)): + eye(3, device=device, dtype=float64) + # TODO: # def asarray( # def arange( -# def eye( # def linspace( -# def meshgrid(*arrays: Array, indexing: Literal["xy", "ij"] = "xy") -> tuple[Array, ...]: -# def tril(x: Array, /, *, k: int = 0) -> Array: -# def triu(x: Array, /, *, k: int = 0) -> Array: + @pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12']) From 147050539b85326034df850fdc95e0dcc9507151 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 26 Apr 2026 20:49:46 +0200 Subject: [PATCH 11/12] ENH: linspace default dtype --- array_api_strict/_creation_functions.py | 5 +++++ .../tests/test_creation_functions.py | 16 ++++++++++++++-- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index 5408499..ddb744e 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -318,6 +318,11 @@ def linspace( _check_device(device) _check_valid_dtype(dtype, device) + if dtype is None: + if isinstance(start, complex) or isinstance(stop, complex): + dtype = get_default_dtypes(device)["complex floating"] + else: + dtype = get_default_dtypes(device)["real floating"] return Array._new( np.linspace(start, stop, num, dtype=_np_dtype(dtype), endpoint=endpoint), diff --git a/array_api_strict/tests/test_creation_functions.py b/array_api_strict/tests/test_creation_functions.py index 89ec655..975ceb6 100644 --- a/array_api_strict/tests/test_creation_functions.py +++ b/array_api_strict/tests/test_creation_functions.py @@ -282,11 +282,23 @@ def test_eye(self): eye(3, device=device, dtype=float64) + def test_linspace(self): + device = Device('F32_device') + + a = linspace(1, 10, 11, device=device) + assert a.dtype == self.info.default_dtypes(device=device)["real floating"] + + a = linspace(1+0j, 10, 11, device=device) + assert a.dtype == self.info.default_dtypes(device=device)["complex floating"] + + with pytest.raises((TypeError, ValueError)): + linspace(1, 10, 11, device=device, dtype=float64) + + + # TODO: # def asarray( # def arange( -# def linspace( - @pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12']) From 5e72e78c3b0f567757c77ecc0af8df2cc1131a75 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 26 Apr 2026 20:55:42 +0200 Subject: [PATCH 12/12] ENH: arange default dtype --- array_api_strict/_creation_functions.py | 5 +++++ array_api_strict/tests/test_creation_functions.py | 15 +++++++++++++-- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index ddb744e..b6e1d67 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -127,6 +127,11 @@ def arange( _check_device(device) _check_valid_dtype(dtype, device) + if dtype is None: + if any(isinstance(x, float) for x in (start, stop, step)): + dtype = get_default_dtypes(device)["real floating"] + else: + dtype = get_default_dtypes(device)["integral"] return Array._new( np.arange(start, stop, step, dtype=_np_dtype(dtype)), diff --git a/array_api_strict/tests/test_creation_functions.py b/array_api_strict/tests/test_creation_functions.py index 975ceb6..b8b71be 100644 --- a/array_api_strict/tests/test_creation_functions.py +++ b/array_api_strict/tests/test_creation_functions.py @@ -281,7 +281,6 @@ def test_eye(self): with pytest.raises((TypeError, ValueError)): eye(3, device=device, dtype=float64) - def test_linspace(self): device = Device('F32_device') @@ -294,11 +293,23 @@ def test_linspace(self): with pytest.raises((TypeError, ValueError)): linspace(1, 10, 11, device=device, dtype=float64) + def test_arange(self): + device = Device('F32_device') + + a = arange(0, 10, 1, device=device) + assert a.dtype == self.info.default_dtypes(device=device)["integral"] + + a = arange(0.0, 10, 1, device=device) + assert a.dtype == self.info.default_dtypes(device=device)["real floating"] + + with pytest.raises((TypeError, ValueError)): + arange(0, 10, 1, device=device, dtype=float64) + with pytest.raises((TypeError, ValueError)): + arange(0.0, 10, 1, device=device, dtype=float64) # TODO: # def asarray( -# def arange( @pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12'])