Skip to content

Commit e5d5f7e

Browse files
committed
mypy passing
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
1 parent 389e042 commit e5d5f7e

File tree

5 files changed

+87
-54
lines changed

5 files changed

+87
-54
lines changed

.pre-commit-config.yaml

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
repos:
2+
- repo: https://github.com/pre-commit/mirrors-mypy
3+
rev: "v1.0.0"
4+
hooks:
5+
- id: mypy
6+
additional_dependencies: [typing_extensions>=4.4.0]
7+
args:
8+
- --ignore-missing-imports
9+
- --config=pyproject.toml
10+
files: ".*(_draft.*)$"
11+
exclude: |
12+
(?x)^(
13+
.*creation_functions.py|
14+
.*data_type_functions.py|
15+
.*elementwise_functions.py|
16+
.*fft.py|
17+
.*indexing_functions.py|
18+
.*linalg.py|
19+
.*linear_algebra_functions.py|
20+
.*manipulation_functions.py|
21+
.*searching_functions.py|
22+
.*set_functions.py|
23+
.*sorting_functions.py|
24+
.*statistical_functions.py|
25+
.*utility_functions.py|
26+
)$

pyproject.toml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,12 @@ doc = [
2929
[build-system]
3030
requires = ["setuptools"]
3131
build-backend = "setuptools.build_meta"
32+
33+
34+
[tool.mypy]
35+
python_version = "3.9"
36+
mypy_path = "$MYPY_CONFIG_FILE_DIR/src/array_api_stubs/_draft/"
37+
files = [
38+
"src/array_api_stubs/_draft/**/*.py"
39+
]
40+
follow_imports = "silent"

src/array_api_stubs/_draft/_types.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class finfo_object:
4646
max: float
4747
min: float
4848
smallest_normal: float
49-
dtype: dtype
49+
dtype: DType
5050

5151
@dataclass
5252
class iinfo_object:
@@ -55,7 +55,7 @@ class iinfo_object:
5555
bits: int
5656
max: int
5757
min: int
58-
dtype: dtype
58+
dtype: DType
5959

6060
_T_co = TypeVar("_T_co", covariant=True)
6161

@@ -68,7 +68,6 @@ def __len__(self, /) -> int:
6868
...
6969

7070

71-
7271
__all__ = [
7372
"Any",
7473
"List",

src/array_api_stubs/_draft/array_object.py

Lines changed: 48 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,18 @@
22

33
from typing import TYPE_CHECKING, Protocol, TypeVar
44

5-
from ._types import (
6-
dtype as Dtype,
7-
device as Device,
8-
Any,
9-
PyCapsule,
10-
Enum,
11-
ellipsis,
12-
)
5+
if TYPE_CHECKING:
6+
from ._types import (
7+
device as Device,
8+
Any,
9+
PyCapsule,
10+
Enum,
11+
ellipsis,
12+
)
13+
from .data_types import DType
1314

14-
Self = TypeVar("Self", bound="Array")
15-
# NOTE: when working with py3.11+ this can be ``typing.Self``.
15+
array = TypeVar("array", bound="Array")
16+
# NOTE: when working with py3.11+ this can be ``typing.array``.
1617

1718

1819
class Array(Protocol):
@@ -21,7 +22,7 @@ def __init__(self) -> None:
2122
...
2223

2324
@property
24-
def dtype(self) -> Dtype:
25+
def dtype(self) -> DType:
2526
"""
2627
Data type of the array elements.
2728
@@ -33,7 +34,7 @@ def dtype(self) -> Dtype:
3334
...
3435

3536
@property
36-
def device(self) -> Device:
37+
def device(self) -> Device: # type: ignore[type-var]
3738
"""
3839
Hardware device the array data resides on.
3940
@@ -45,7 +46,7 @@ def device(self) -> Device:
4546
...
4647

4748
@property
48-
def mT(self: Self) -> Self:
49+
def mT(self: array) -> array:
4950
"""
5051
Transpose of a matrix (or a stack of matrices).
5152
@@ -109,7 +110,7 @@ def size(self) -> int | None:
109110
...
110111

111112
@property
112-
def T(self: Self) -> Self:
113+
def T(self: array) -> array:
113114
"""
114115
Transpose of the array.
115116
@@ -126,7 +127,7 @@ def T(self: Self) -> Self:
126127
"""
127128
...
128129

129-
def __abs__(self: Self, /) -> Self:
130+
def __abs__(self: array, /) -> array:
130131
"""
131132
Calculates the absolute value for each element of an array instance.
132133
@@ -156,7 +157,7 @@ def __abs__(self: Self, /) -> Self:
156157
"""
157158
...
158159

159-
def __add__(self: Self, other: int | float | Self, /) -> Self:
160+
def __add__(self: array, other: int | float | array, /) -> array:
160161
"""
161162
Calculates the sum for each element of an array instance with the respective element of the array ``other``.
162163
@@ -183,7 +184,7 @@ def __add__(self: Self, other: int | float | Self, /) -> Self:
183184
"""
184185
...
185186

186-
def __and__(self: Self, other: int | bool | Self, /) -> Self:
187+
def __and__(self: array, other: int | bool | array, /) -> array:
187188
"""
188189
Evaluates ``self_i & other_i`` for each element of an array instance with the respective element of the array ``other``.
189190
@@ -294,7 +295,7 @@ def __complex__(self, /) -> complex:
294295

295296
def __dlpack__(
296297
self, /, *, stream: int | Any | None = None
297-
) -> PyCapsule:
298+
) -> PyCapsule: # type: ignore[type-var]
298299
"""
299300
Exports the array for consumption by :func:`~array_api.from_dlpack` as a DLPack capsule.
300301
@@ -394,7 +395,7 @@ def __dlpack_device__(self, /) -> tuple[Enum, int]:
394395
# Note that __eq__ returns an array while `object.__eq__` returns a bool.
395396
# Hence Mypy will complain that this violates the Liskov substitution
396397
# principle - ignore that.
397-
def __eq__(self: Self, other: int | float | bool | Self, /) -> Self: # xtype: ignore
398+
def __eq__(self: array, other: int | float | bool | array, /) -> array: # type: ignore[override]
398399
r"""
399400
Computes the truth value of ``self_i == other_i`` for each element of an array instance with the respective element of the array ``other``.
400401
@@ -448,7 +449,7 @@ def __float__(self, /) -> float:
448449
"""
449450
...
450451

451-
def __floordiv__(self: Self, other: int | float | Self, /) -> Self:
452+
def __floordiv__(self: array, other: int | float | array, /) -> array:
452453
"""
453454
Evaluates ``self_i // other_i`` for each element of an array instance with the respective element of the array ``other``.
454455
@@ -473,7 +474,7 @@ def __floordiv__(self: Self, other: int | float | Self, /) -> Self:
473474
"""
474475
...
475476

476-
def __ge__(self: Self, other: int | float | Self, /) -> Self:
477+
def __ge__(self: array, other: int | float | array, /) -> array:
477478
"""
478479
Computes the truth value of ``self_i >= other_i`` for each element of an array instance with the respective element of the array ``other``.
479480
@@ -499,10 +500,10 @@ def __ge__(self: Self, other: int | float | Self, /) -> Self:
499500
...
500501

501502
def __getitem__(
502-
self: Self,
503-
key: int | slice | ellipsis | tuple[int | slice | ellipsis, ...] | Self,
503+
self: array,
504+
key: int | slice | ellipsis | tuple[int | slice | ellipsis, ...] | array,
504505
/,
505-
) -> Self:
506+
) -> array:
506507
"""
507508
Returns ``self[key]``.
508509
@@ -520,7 +521,7 @@ def __getitem__(
520521
"""
521522
...
522523

523-
def __gt__(self: Self, other: int | float | Self, /) -> Self:
524+
def __gt__(self: array, other: int | float | array, /) -> array:
524525
"""
525526
Computes the truth value of ``self_i > other_i`` for each element of an array instance with the respective element of the array ``other``.
526527
@@ -605,7 +606,7 @@ def __int__(self, /) -> int:
605606
"""
606607
...
607608

608-
def __invert__(self: Self, /) -> Self:
609+
def __invert__(self: array, /) -> array:
609610
"""
610611
Evaluates ``~self_i`` for each element of an array instance.
611612
@@ -625,7 +626,7 @@ def __invert__(self: Self, /) -> Self:
625626
"""
626627
...
627628

628-
def __le__(self: Self, other: int | float | Self, /) -> Self:
629+
def __le__(self: array, other: int | float | array, /) -> array:
629630
"""
630631
Computes the truth value of ``self_i <= other_i`` for each element of an array instance with the respective element of the array ``other``.
631632
@@ -650,7 +651,7 @@ def __le__(self: Self, other: int | float | Self, /) -> Self:
650651
"""
651652
...
652653

653-
def __lshift__(self: Self, other: int | Self, /) -> Self:
654+
def __lshift__(self: array, other: int | array, /) -> array:
654655
"""
655656
Evaluates ``self_i << other_i`` for each element of an array instance with the respective element of the array ``other``.
656657
@@ -672,7 +673,7 @@ def __lshift__(self: Self, other: int | Self, /) -> Self:
672673
"""
673674
...
674675

675-
def __lt__(self: Self, other: int | float | Self, /) -> Self:
676+
def __lt__(self: array, other: int | float | array, /) -> array:
676677
"""
677678
Computes the truth value of ``self_i < other_i`` for each element of an array instance with the respective element of the array ``other``.
678679
@@ -697,7 +698,7 @@ def __lt__(self: Self, other: int | float | Self, /) -> Self:
697698
"""
698699
...
699700

700-
def __matmul__(self: Self, other: Self, /) -> Self:
701+
def __matmul__(self: array, other: array, /) -> array:
701702
"""
702703
Computes the matrix product.
703704
@@ -746,7 +747,7 @@ def __matmul__(self: Self, other: Self, /) -> Self:
746747
"""
747748
...
748749

749-
def __mod__(self: Self, other: int | float | Self, /) -> Self:
750+
def __mod__(self: array, other: int | float | array, /) -> array:
750751
"""
751752
Evaluates ``self_i % other_i`` for each element of an array instance with the respective element of the array ``other``.
752753
@@ -771,7 +772,7 @@ def __mod__(self: Self, other: int | float | Self, /) -> Self:
771772
"""
772773
...
773774

774-
def __mul__(self: Self, other: int | float | Self, /) -> Self:
775+
def __mul__(self: array, other: int | float | array, /) -> array:
775776
r"""
776777
Calculates the product for each element of an array instance with the respective element of the array ``other``.
777778
@@ -802,7 +803,7 @@ def __mul__(self: Self, other: int | float | Self, /) -> Self:
802803
...
803804

804805
# See note above __eq__ method for explanation of the `type: ignore`
805-
def __ne__(self: Self, other: int | float | bool | Self, /) -> Self: # type: ignore
806+
def __ne__(self: array, other: int | float | bool | array, /) -> array: # type: ignore[override]
806807
"""
807808
Computes the truth value of ``self_i != other_i`` for each element of an array instance with the respective element of the array ``other``.
808809
@@ -830,7 +831,7 @@ def __ne__(self: Self, other: int | float | bool | Self, /) -> Self: # type: ig
830831
"""
831832
...
832833

833-
def __neg__(self: Self, /) -> Self:
834+
def __neg__(self: array, /) -> array:
834835
"""
835836
Evaluates ``-self_i`` for each element of an array instance.
836837
@@ -861,7 +862,7 @@ def __neg__(self: Self, /) -> Self:
861862
"""
862863
...
863864

864-
def __or__(self: Self, other: int | bool | Self, /) -> Self:
865+
def __or__(self: array, other: int | bool | array, /) -> array:
865866
"""
866867
Evaluates ``self_i | other_i`` for each element of an array instance with the respective element of the array ``other``.
867868
@@ -883,7 +884,7 @@ def __or__(self: Self, other: int | bool | Self, /) -> Self:
883884
"""
884885
...
885886

886-
def __pos__(self: Self, /) -> Self:
887+
def __pos__(self: array, /) -> array:
887888
"""
888889
Evaluates ``+self_i`` for each element of an array instance.
889890
@@ -908,7 +909,7 @@ def __pos__(self: Self, /) -> Self:
908909
"""
909910
...
910911

911-
def __pow__(self: Self, other: int | float | Self, /) -> Self:
912+
def __pow__(self: array, other: int | float | array, /) -> array:
912913
r"""
913914
Calculates an implementation-dependent approximation of exponentiation by raising each element (the base) of an array instance to the power of ``other_i`` (the exponent), where ``other_i`` is the corresponding element of the array ``other``.
914915
@@ -940,7 +941,7 @@ def __pow__(self: Self, other: int | float | Self, /) -> Self:
940941
"""
941942
...
942943

943-
def __rshift__(self: Self, other: int | Self, /) -> Self:
944+
def __rshift__(self: array, other: int | array, /) -> array:
944945
"""
945946
Evaluates ``self_i >> other_i`` for each element of an array instance with the respective element of the array ``other``.
946947
@@ -963,9 +964,9 @@ def __rshift__(self: Self, other: int | Self, /) -> Self:
963964
...
964965

965966
def __setitem__(
966-
self: Self,
967-
key: int | slice | ellipsis | tuple[int | slice | ellipsis, ...] | Self,
968-
value: int | float | bool | Self,
967+
self: array,
968+
key: int | slice | ellipsis | tuple[int | slice | ellipsis, ...] | array,
969+
value: int | float | bool | array,
969970
/,
970971
) -> None:
971972
"""
@@ -991,7 +992,7 @@ def __setitem__(
991992
"""
992993
...
993994

994-
def __sub__(self: Self, other: int | float | Self, /) -> Self:
995+
def __sub__(self: array, other: int | float | array, /) -> array:
995996
"""
996997
Calculates the difference for each element of an array instance with the respective element of the array ``other``.
997998
@@ -1020,7 +1021,7 @@ def __sub__(self: Self, other: int | float | Self, /) -> Self:
10201021
"""
10211022
...
10221023

1023-
def __truediv__(self: Self, other: int | float | Self, /) -> Self:
1024+
def __truediv__(self: array, other: int | float | array, /) -> array:
10241025
r"""
10251026
Evaluates ``self_i / other_i`` for each element of an array instance with the respective element of the array ``other``.
10261027
@@ -1052,7 +1053,7 @@ def __truediv__(self: Self, other: int | float | Self, /) -> Self:
10521053
"""
10531054
...
10541055

1055-
def __xor__(self: Self, other: int | bool | Self, /) -> Self:
1056+
def __xor__(self: array, other: int | bool | array, /) -> array:
10561057
"""
10571058
Evaluates ``self_i ^ other_i`` for each element of an array instance with the respective element of the array ``other``.
10581059
@@ -1075,8 +1076,8 @@ def __xor__(self: Self, other: int | bool | Self, /) -> Self:
10751076
...
10761077

10771078
def to_device(
1078-
self: Self, device: Device, /, *, stream: int | Any | None = None
1079-
) -> Self:
1079+
self: array, device: Device, /, *, stream: int | Any | None = None
1080+
) -> array:
10801081
"""
10811082
Copy the array from the device on which it currently resides to the specified ``device``.
10821083
@@ -1101,6 +1102,4 @@ def to_device(
11011102
...
11021103

11031104

1104-
array = Array
1105-
1106-
__all__ = ["array"]
1105+
__all__ = ["Array"]

0 commit comments

Comments
 (0)