Skip to content

Commit a5a1d8b

Browse files
authored
TYP: Type annotations overhaul, part 1 (#257)
* ENH: Type annotations overhaul * Re-add py.typed * code review * lint * asarray * fill_value * result_type * lint * Arrays don't need to support buffer protocol * bool is a subclass of int * reshape: copy kwarg is keyword-only * tensordot formatting * Reinstate explicit bool | complex
1 parent 0cebd55 commit a5a1d8b

19 files changed

+511
-508
lines changed

array_api_compat/common/_aliases.py

+140-108
Large diffs are not rendered by default.

array_api_compat/common/_fft.py

+45-42
Original file line numberDiff line numberDiff line change
@@ -1,149 +1,148 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Union, Optional, Literal
3+
from collections.abc import Sequence
4+
from typing import Union, Optional, Literal
45

5-
if TYPE_CHECKING:
6-
from ._typing import Device, ndarray, DType
7-
from collections.abc import Sequence
6+
from ._typing import Device, Array, DType, Namespace
87

98
# Note: NumPy fft functions improperly upcast float32 and complex64 to
109
# complex128, which is why we require wrapping them all here.
1110

1211
def fft(
13-
x: ndarray,
12+
x: Array,
1413
/,
15-
xp,
14+
xp: Namespace,
1615
*,
1716
n: Optional[int] = None,
1817
axis: int = -1,
1918
norm: Literal["backward", "ortho", "forward"] = "backward",
20-
) -> ndarray:
19+
) -> Array:
2120
res = xp.fft.fft(x, n=n, axis=axis, norm=norm)
2221
if x.dtype in [xp.float32, xp.complex64]:
2322
return res.astype(xp.complex64)
2423
return res
2524

2625
def ifft(
27-
x: ndarray,
26+
x: Array,
2827
/,
29-
xp,
28+
xp: Namespace,
3029
*,
3130
n: Optional[int] = None,
3231
axis: int = -1,
3332
norm: Literal["backward", "ortho", "forward"] = "backward",
34-
) -> ndarray:
33+
) -> Array:
3534
res = xp.fft.ifft(x, n=n, axis=axis, norm=norm)
3635
if x.dtype in [xp.float32, xp.complex64]:
3736
return res.astype(xp.complex64)
3837
return res
3938

4039
def fftn(
41-
x: ndarray,
40+
x: Array,
4241
/,
43-
xp,
42+
xp: Namespace,
4443
*,
4544
s: Sequence[int] = None,
4645
axes: Sequence[int] = None,
4746
norm: Literal["backward", "ortho", "forward"] = "backward",
48-
) -> ndarray:
47+
) -> Array:
4948
res = xp.fft.fftn(x, s=s, axes=axes, norm=norm)
5049
if x.dtype in [xp.float32, xp.complex64]:
5150
return res.astype(xp.complex64)
5251
return res
5352

5453
def ifftn(
55-
x: ndarray,
54+
x: Array,
5655
/,
57-
xp,
56+
xp: Namespace,
5857
*,
5958
s: Sequence[int] = None,
6059
axes: Sequence[int] = None,
6160
norm: Literal["backward", "ortho", "forward"] = "backward",
62-
) -> ndarray:
61+
) -> Array:
6362
res = xp.fft.ifftn(x, s=s, axes=axes, norm=norm)
6463
if x.dtype in [xp.float32, xp.complex64]:
6564
return res.astype(xp.complex64)
6665
return res
6766

6867
def rfft(
69-
x: ndarray,
68+
x: Array,
7069
/,
71-
xp,
70+
xp: Namespace,
7271
*,
7372
n: Optional[int] = None,
7473
axis: int = -1,
7574
norm: Literal["backward", "ortho", "forward"] = "backward",
76-
) -> ndarray:
75+
) -> Array:
7776
res = xp.fft.rfft(x, n=n, axis=axis, norm=norm)
7877
if x.dtype == xp.float32:
7978
return res.astype(xp.complex64)
8079
return res
8180

8281
def irfft(
83-
x: ndarray,
82+
x: Array,
8483
/,
85-
xp,
84+
xp: Namespace,
8685
*,
8786
n: Optional[int] = None,
8887
axis: int = -1,
8988
norm: Literal["backward", "ortho", "forward"] = "backward",
90-
) -> ndarray:
89+
) -> Array:
9190
res = xp.fft.irfft(x, n=n, axis=axis, norm=norm)
9291
if x.dtype == xp.complex64:
9392
return res.astype(xp.float32)
9493
return res
9594

9695
def rfftn(
97-
x: ndarray,
96+
x: Array,
9897
/,
99-
xp,
98+
xp: Namespace,
10099
*,
101100
s: Sequence[int] = None,
102101
axes: Sequence[int] = None,
103102
norm: Literal["backward", "ortho", "forward"] = "backward",
104-
) -> ndarray:
103+
) -> Array:
105104
res = xp.fft.rfftn(x, s=s, axes=axes, norm=norm)
106105
if x.dtype == xp.float32:
107106
return res.astype(xp.complex64)
108107
return res
109108

110109
def irfftn(
111-
x: ndarray,
110+
x: Array,
112111
/,
113-
xp,
112+
xp: Namespace,
114113
*,
115114
s: Sequence[int] = None,
116115
axes: Sequence[int] = None,
117116
norm: Literal["backward", "ortho", "forward"] = "backward",
118-
) -> ndarray:
117+
) -> Array:
119118
res = xp.fft.irfftn(x, s=s, axes=axes, norm=norm)
120119
if x.dtype == xp.complex64:
121120
return res.astype(xp.float32)
122121
return res
123122

124123
def hfft(
125-
x: ndarray,
124+
x: Array,
126125
/,
127-
xp,
126+
xp: Namespace,
128127
*,
129128
n: Optional[int] = None,
130129
axis: int = -1,
131130
norm: Literal["backward", "ortho", "forward"] = "backward",
132-
) -> ndarray:
131+
) -> Array:
133132
res = xp.fft.hfft(x, n=n, axis=axis, norm=norm)
134133
if x.dtype in [xp.float32, xp.complex64]:
135134
return res.astype(xp.float32)
136135
return res
137136

138137
def ihfft(
139-
x: ndarray,
138+
x: Array,
140139
/,
141-
xp,
140+
xp: Namespace,
142141
*,
143142
n: Optional[int] = None,
144143
axis: int = -1,
145144
norm: Literal["backward", "ortho", "forward"] = "backward",
146-
) -> ndarray:
145+
) -> Array:
147146
res = xp.fft.ihfft(x, n=n, axis=axis, norm=norm)
148147
if x.dtype in [xp.float32, xp.complex64]:
149148
return res.astype(xp.complex64)
@@ -152,12 +151,12 @@ def ihfft(
152151
def fftfreq(
153152
n: int,
154153
/,
155-
xp,
154+
xp: Namespace,
156155
*,
157156
d: float = 1.0,
158157
dtype: Optional[DType] = None,
159-
device: Optional[Device] = None
160-
) -> ndarray:
158+
device: Optional[Device] = None,
159+
) -> Array:
161160
if device not in ["cpu", None]:
162161
raise ValueError(f"Unsupported device {device!r}")
163162
res = xp.fft.fftfreq(n, d=d)
@@ -168,23 +167,27 @@ def fftfreq(
168167
def rfftfreq(
169168
n: int,
170169
/,
171-
xp,
170+
xp: Namespace,
172171
*,
173172
d: float = 1.0,
174173
dtype: Optional[DType] = None,
175-
device: Optional[Device] = None
176-
) -> ndarray:
174+
device: Optional[Device] = None,
175+
) -> Array:
177176
if device not in ["cpu", None]:
178177
raise ValueError(f"Unsupported device {device!r}")
179178
res = xp.fft.rfftfreq(n, d=d)
180179
if dtype is not None:
181180
return res.astype(dtype)
182181
return res
183182

184-
def fftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray:
183+
def fftshift(
184+
x: Array, /, xp: Namespace, *, axes: Union[int, Sequence[int]] = None
185+
) -> Array:
185186
return xp.fft.fftshift(x, axes=axes)
186187

187-
def ifftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray:
188+
def ifftshift(
189+
x: Array, /, xp: Namespace, *, axes: Union[int, Sequence[int]] = None
190+
) -> Array:
188191
return xp.fft.ifftshift(x, axes=axes)
189192

190193
__all__ = [

array_api_compat/common/_helpers.py

+17-15
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,14 @@
77
"""
88
from __future__ import annotations
99

10-
from typing import TYPE_CHECKING
11-
12-
if TYPE_CHECKING:
13-
from typing import Optional, Union, Any
14-
from ._typing import Array, Device, Namespace
15-
1610
import sys
1711
import math
1812
import inspect
1913
import warnings
14+
from typing import Optional, Union, Any
15+
16+
from ._typing import Array, Device, Namespace
17+
2018

2119
def _is_jax_zero_gradient_array(x: object) -> bool:
2220
"""Return True if `x` is a zero-gradient array.
@@ -268,7 +266,7 @@ def _compat_module_name() -> str:
268266
return __name__.removesuffix('.common._helpers')
269267

270268

271-
def is_numpy_namespace(xp) -> bool:
269+
def is_numpy_namespace(xp: Namespace) -> bool:
272270
"""
273271
Returns True if `xp` is a NumPy namespace.
274272
@@ -289,7 +287,7 @@ def is_numpy_namespace(xp) -> bool:
289287
return xp.__name__ in {'numpy', _compat_module_name() + '.numpy'}
290288

291289

292-
def is_cupy_namespace(xp) -> bool:
290+
def is_cupy_namespace(xp: Namespace) -> bool:
293291
"""
294292
Returns True if `xp` is a CuPy namespace.
295293
@@ -310,7 +308,7 @@ def is_cupy_namespace(xp) -> bool:
310308
return xp.__name__ in {'cupy', _compat_module_name() + '.cupy'}
311309

312310

313-
def is_torch_namespace(xp) -> bool:
311+
def is_torch_namespace(xp: Namespace) -> bool:
314312
"""
315313
Returns True if `xp` is a PyTorch namespace.
316314
@@ -331,7 +329,7 @@ def is_torch_namespace(xp) -> bool:
331329
return xp.__name__ in {'torch', _compat_module_name() + '.torch'}
332330

333331

334-
def is_ndonnx_namespace(xp) -> bool:
332+
def is_ndonnx_namespace(xp: Namespace) -> bool:
335333
"""
336334
Returns True if `xp` is an NDONNX namespace.
337335
@@ -350,7 +348,7 @@ def is_ndonnx_namespace(xp) -> bool:
350348
return xp.__name__ == 'ndonnx'
351349

352350

353-
def is_dask_namespace(xp) -> bool:
351+
def is_dask_namespace(xp: Namespace) -> bool:
354352
"""
355353
Returns True if `xp` is a Dask namespace.
356354
@@ -371,7 +369,7 @@ def is_dask_namespace(xp) -> bool:
371369
return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'}
372370

373371

374-
def is_jax_namespace(xp) -> bool:
372+
def is_jax_namespace(xp: Namespace) -> bool:
375373
"""
376374
Returns True if `xp` is a JAX namespace.
377375
@@ -393,7 +391,7 @@ def is_jax_namespace(xp) -> bool:
393391
return xp.__name__ in {'jax.numpy', 'jax.experimental.array_api'}
394392

395393

396-
def is_pydata_sparse_namespace(xp) -> bool:
394+
def is_pydata_sparse_namespace(xp: Namespace) -> bool:
397395
"""
398396
Returns True if `xp` is a pydata/sparse namespace.
399397
@@ -412,7 +410,7 @@ def is_pydata_sparse_namespace(xp) -> bool:
412410
return xp.__name__ == 'sparse'
413411

414412

415-
def is_array_api_strict_namespace(xp) -> bool:
413+
def is_array_api_strict_namespace(xp: Namespace) -> bool:
416414
"""
417415
Returns True if `xp` is an array-api-strict namespace.
418416
@@ -439,7 +437,11 @@ def _check_api_version(api_version: str) -> None:
439437
raise ValueError("Only the 2024.12 version of the array API specification is currently supported")
440438

441439

442-
def array_namespace(*xs, api_version=None, use_compat=None) -> Namespace:
440+
def array_namespace(
441+
*xs: Union[Array, bool, int, float, complex, None],
442+
api_version: Optional[str] = None,
443+
use_compat: Optional[bool] = None,
444+
) -> Namespace:
443445
"""
444446
Get the array API compatible namespace for the arrays `xs`.
445447

0 commit comments

Comments
 (0)