diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index c154ad0b..d5342658 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -288,8 +288,14 @@ def is_array_api_obj(x: object) -> TypeGuard[_ArrayApiObj]: is_dask_array is_jax_array """ + try: + # TODO: drop this check after np.matrix is gone + if _issubclass_fast(type(x), "numpy", "matrix"): + return False + except Exception: + pass return ( - hasattr(x, '__array_namespace__') + hasattr(x, '__array_namespace__') or _is_array_api_cls(cast(Hashable, type(x))) ) diff --git a/tests/test_numpy.py b/tests/test_numpy.py new file mode 100644 index 00000000..a139d428 --- /dev/null +++ b/tests/test_numpy.py @@ -0,0 +1,19 @@ +"""Test "unspecified" behavior which we cannot easily test in the Array API test suite. +""" +import warnings +import pytest + +try: + import numpy as np +except ImportError: + pytestmark = pytest.skip(allow_module_level=True, reason="numpy not found") + +from array_api_compat import is_array_api_obj + +def test_matrix_is_not_array_api_obj(): + assert is_array_api_obj(np.asarray(3)) + assert is_array_api_obj(np.float64(3)) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", PendingDeprecationWarning) + assert not is_array_api_obj(np.matrix(3))