-
Notifications
You must be signed in to change notification settings - Fork 33
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
TYP: Type annotations overhaul, part 1 #257
Changes from all commits
b5582a7
8d888d5
4e4e84e
082c052
ac2cb73
84e28e7
8eaf862
40de1c9
a213ab5
0adac27
8fe4205
646fc61
2713c7c
1dc4955
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
@@ -1,149 +1,148 @@ | ||||||||||
from __future__ import annotations | ||||||||||
|
||||||||||
from typing import TYPE_CHECKING, Union, Optional, Literal | ||||||||||
from collections.abc import Sequence | ||||||||||
from typing import Union, Optional, Literal | ||||||||||
|
||||||||||
if TYPE_CHECKING: | ||||||||||
from ._typing import Device, ndarray, DType | ||||||||||
from collections.abc import Sequence | ||||||||||
from ._typing import Device, Array, DType, Namespace | ||||||||||
|
||||||||||
# Note: NumPy fft functions improperly upcast float32 and complex64 to | ||||||||||
# complex128, which is why we require wrapping them all here. | ||||||||||
|
||||||||||
def fft( | ||||||||||
x: ndarray, | ||||||||||
x: Array, | ||||||||||
/, | ||||||||||
xp, | ||||||||||
xp: Namespace, | ||||||||||
*, | ||||||||||
n: Optional[int] = None, | ||||||||||
axis: int = -1, | ||||||||||
norm: Literal["backward", "ortho", "forward"] = "backward", | ||||||||||
) -> ndarray: | ||||||||||
) -> Array: | ||||||||||
res = xp.fft.fft(x, n=n, axis=axis, norm=norm) | ||||||||||
if x.dtype in [xp.float32, xp.complex64]: | ||||||||||
return res.astype(xp.complex64) | ||||||||||
return res | ||||||||||
|
||||||||||
def ifft( | ||||||||||
x: ndarray, | ||||||||||
x: Array, | ||||||||||
/, | ||||||||||
xp, | ||||||||||
xp: Namespace, | ||||||||||
*, | ||||||||||
n: Optional[int] = None, | ||||||||||
axis: int = -1, | ||||||||||
norm: Literal["backward", "ortho", "forward"] = "backward", | ||||||||||
) -> ndarray: | ||||||||||
) -> Array: | ||||||||||
res = xp.fft.ifft(x, n=n, axis=axis, norm=norm) | ||||||||||
if x.dtype in [xp.float32, xp.complex64]: | ||||||||||
return res.astype(xp.complex64) | ||||||||||
return res | ||||||||||
|
||||||||||
def fftn( | ||||||||||
x: ndarray, | ||||||||||
x: Array, | ||||||||||
/, | ||||||||||
xp, | ||||||||||
xp: Namespace, | ||||||||||
*, | ||||||||||
s: Sequence[int] = None, | ||||||||||
axes: Sequence[int] = None, | ||||||||||
Comment on lines
44
to
45
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
FYI, a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We don't get to decide to restrict it to a tuple |
||||||||||
norm: Literal["backward", "ortho", "forward"] = "backward", | ||||||||||
) -> ndarray: | ||||||||||
) -> Array: | ||||||||||
res = xp.fft.fftn(x, s=s, axes=axes, norm=norm) | ||||||||||
if x.dtype in [xp.float32, xp.complex64]: | ||||||||||
return res.astype(xp.complex64) | ||||||||||
return res | ||||||||||
|
||||||||||
def ifftn( | ||||||||||
x: ndarray, | ||||||||||
x: Array, | ||||||||||
/, | ||||||||||
xp, | ||||||||||
xp: Namespace, | ||||||||||
*, | ||||||||||
s: Sequence[int] = None, | ||||||||||
axes: Sequence[int] = None, | ||||||||||
Comment on lines
58
to
59
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
norm: Literal["backward", "ortho", "forward"] = "backward", | ||||||||||
) -> ndarray: | ||||||||||
) -> Array: | ||||||||||
res = xp.fft.ifftn(x, s=s, axes=axes, norm=norm) | ||||||||||
if x.dtype in [xp.float32, xp.complex64]: | ||||||||||
return res.astype(xp.complex64) | ||||||||||
return res | ||||||||||
|
||||||||||
def rfft( | ||||||||||
x: ndarray, | ||||||||||
x: Array, | ||||||||||
/, | ||||||||||
xp, | ||||||||||
xp: Namespace, | ||||||||||
*, | ||||||||||
n: Optional[int] = None, | ||||||||||
axis: int = -1, | ||||||||||
norm: Literal["backward", "ortho", "forward"] = "backward", | ||||||||||
) -> ndarray: | ||||||||||
) -> Array: | ||||||||||
res = xp.fft.rfft(x, n=n, axis=axis, norm=norm) | ||||||||||
if x.dtype == xp.float32: | ||||||||||
return res.astype(xp.complex64) | ||||||||||
return res | ||||||||||
|
||||||||||
def irfft( | ||||||||||
x: ndarray, | ||||||||||
x: Array, | ||||||||||
/, | ||||||||||
xp, | ||||||||||
xp: Namespace, | ||||||||||
*, | ||||||||||
n: Optional[int] = None, | ||||||||||
axis: int = -1, | ||||||||||
norm: Literal["backward", "ortho", "forward"] = "backward", | ||||||||||
) -> ndarray: | ||||||||||
) -> Array: | ||||||||||
res = xp.fft.irfft(x, n=n, axis=axis, norm=norm) | ||||||||||
if x.dtype == xp.complex64: | ||||||||||
return res.astype(xp.float32) | ||||||||||
return res | ||||||||||
|
||||||||||
def rfftn( | ||||||||||
x: ndarray, | ||||||||||
x: Array, | ||||||||||
/, | ||||||||||
xp, | ||||||||||
xp: Namespace, | ||||||||||
*, | ||||||||||
s: Sequence[int] = None, | ||||||||||
axes: Sequence[int] = None, | ||||||||||
Comment on lines
100
to
101
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
norm: Literal["backward", "ortho", "forward"] = "backward", | ||||||||||
) -> ndarray: | ||||||||||
) -> Array: | ||||||||||
res = xp.fft.rfftn(x, s=s, axes=axes, norm=norm) | ||||||||||
if x.dtype == xp.float32: | ||||||||||
return res.astype(xp.complex64) | ||||||||||
return res | ||||||||||
|
||||||||||
def irfftn( | ||||||||||
x: ndarray, | ||||||||||
x: Array, | ||||||||||
/, | ||||||||||
xp, | ||||||||||
xp: Namespace, | ||||||||||
*, | ||||||||||
s: Sequence[int] = None, | ||||||||||
axes: Sequence[int] = None, | ||||||||||
Comment on lines
114
to
115
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
norm: Literal["backward", "ortho", "forward"] = "backward", | ||||||||||
) -> ndarray: | ||||||||||
) -> Array: | ||||||||||
res = xp.fft.irfftn(x, s=s, axes=axes, norm=norm) | ||||||||||
if x.dtype == xp.complex64: | ||||||||||
return res.astype(xp.float32) | ||||||||||
return res | ||||||||||
|
||||||||||
def hfft( | ||||||||||
x: ndarray, | ||||||||||
x: Array, | ||||||||||
/, | ||||||||||
xp, | ||||||||||
xp: Namespace, | ||||||||||
*, | ||||||||||
n: Optional[int] = None, | ||||||||||
axis: int = -1, | ||||||||||
norm: Literal["backward", "ortho", "forward"] = "backward", | ||||||||||
) -> ndarray: | ||||||||||
) -> Array: | ||||||||||
res = xp.fft.hfft(x, n=n, axis=axis, norm=norm) | ||||||||||
if x.dtype in [xp.float32, xp.complex64]: | ||||||||||
return res.astype(xp.float32) | ||||||||||
return res | ||||||||||
|
||||||||||
def ihfft( | ||||||||||
x: ndarray, | ||||||||||
x: Array, | ||||||||||
/, | ||||||||||
xp, | ||||||||||
xp: Namespace, | ||||||||||
*, | ||||||||||
n: Optional[int] = None, | ||||||||||
axis: int = -1, | ||||||||||
norm: Literal["backward", "ortho", "forward"] = "backward", | ||||||||||
) -> ndarray: | ||||||||||
) -> Array: | ||||||||||
res = xp.fft.ihfft(x, n=n, axis=axis, norm=norm) | ||||||||||
if x.dtype in [xp.float32, xp.complex64]: | ||||||||||
return res.astype(xp.complex64) | ||||||||||
|
@@ -152,12 +151,12 @@ def ihfft( | |||||||||
def fftfreq( | ||||||||||
n: int, | ||||||||||
/, | ||||||||||
xp, | ||||||||||
xp: Namespace, | ||||||||||
*, | ||||||||||
d: float = 1.0, | ||||||||||
dtype: Optional[DType] = None, | ||||||||||
device: Optional[Device] = None | ||||||||||
) -> ndarray: | ||||||||||
device: Optional[Device] = None, | ||||||||||
) -> Array: | ||||||||||
if device not in ["cpu", None]: | ||||||||||
raise ValueError(f"Unsupported device {device!r}") | ||||||||||
res = xp.fft.fftfreq(n, d=d) | ||||||||||
|
@@ -168,23 +167,27 @@ def fftfreq( | |||||||||
def rfftfreq( | ||||||||||
n: int, | ||||||||||
/, | ||||||||||
xp, | ||||||||||
xp: Namespace, | ||||||||||
*, | ||||||||||
d: float = 1.0, | ||||||||||
dtype: Optional[DType] = None, | ||||||||||
device: Optional[Device] = None | ||||||||||
) -> ndarray: | ||||||||||
device: Optional[Device] = None, | ||||||||||
) -> Array: | ||||||||||
if device not in ["cpu", None]: | ||||||||||
raise ValueError(f"Unsupported device {device!r}") | ||||||||||
res = xp.fft.rfftfreq(n, d=d) | ||||||||||
if dtype is not None: | ||||||||||
return res.astype(dtype) | ||||||||||
return res | ||||||||||
|
||||||||||
def fftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray: | ||||||||||
def fftshift( | ||||||||||
x: Array, /, xp: Namespace, *, axes: Union[int, Sequence[int]] = None | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
) -> Array: | ||||||||||
return xp.fft.fftshift(x, axes=axes) | ||||||||||
|
||||||||||
def ifftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray: | ||||||||||
def ifftshift( | ||||||||||
x: Array, /, xp: Namespace, *, axes: Union[int, Sequence[int]] = None | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
) -> Array: | ||||||||||
return xp.fft.ifftshift(x, axes=axes) | ||||||||||
|
||||||||||
__all__ = [ | ||||||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -7,16 +7,14 @@ | |||||
""" | ||||||
from __future__ import annotations | ||||||
|
||||||
from typing import TYPE_CHECKING | ||||||
|
||||||
if TYPE_CHECKING: | ||||||
from typing import Optional, Union, Any | ||||||
from ._typing import Array, Device, Namespace | ||||||
|
||||||
import sys | ||||||
import math | ||||||
import inspect | ||||||
import warnings | ||||||
from typing import Optional, Union, Any | ||||||
|
||||||
from ._typing import Array, Device, Namespace | ||||||
|
||||||
|
||||||
def _is_jax_zero_gradient_array(x: object) -> bool: | ||||||
"""Return True if `x` is a zero-gradient array. | ||||||
|
@@ -268,7 +266,7 @@ | |||||
return __name__.removesuffix('.common._helpers') | ||||||
|
||||||
|
||||||
def is_numpy_namespace(xp) -> bool: | ||||||
def is_numpy_namespace(xp: Namespace) -> bool: | ||||||
""" | ||||||
Returns True if `xp` is a NumPy namespace. | ||||||
|
||||||
|
@@ -289,7 +287,7 @@ | |||||
return xp.__name__ in {'numpy', _compat_module_name() + '.numpy'} | ||||||
|
||||||
|
||||||
def is_cupy_namespace(xp) -> bool: | ||||||
def is_cupy_namespace(xp: Namespace) -> bool: | ||||||
""" | ||||||
Returns True if `xp` is a CuPy namespace. | ||||||
|
||||||
|
@@ -310,7 +308,7 @@ | |||||
return xp.__name__ in {'cupy', _compat_module_name() + '.cupy'} | ||||||
|
||||||
|
||||||
def is_torch_namespace(xp) -> bool: | ||||||
def is_torch_namespace(xp: Namespace) -> bool: | ||||||
""" | ||||||
Returns True if `xp` is a PyTorch namespace. | ||||||
|
||||||
|
@@ -331,7 +329,7 @@ | |||||
return xp.__name__ in {'torch', _compat_module_name() + '.torch'} | ||||||
|
||||||
|
||||||
def is_ndonnx_namespace(xp) -> bool: | ||||||
def is_ndonnx_namespace(xp: Namespace) -> bool: | ||||||
""" | ||||||
Returns True if `xp` is an NDONNX namespace. | ||||||
|
||||||
|
@@ -350,7 +348,7 @@ | |||||
return xp.__name__ == 'ndonnx' | ||||||
|
||||||
|
||||||
def is_dask_namespace(xp) -> bool: | ||||||
def is_dask_namespace(xp: Namespace) -> bool: | ||||||
""" | ||||||
Returns True if `xp` is a Dask namespace. | ||||||
|
||||||
|
@@ -371,7 +369,7 @@ | |||||
return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'} | ||||||
|
||||||
|
||||||
def is_jax_namespace(xp) -> bool: | ||||||
def is_jax_namespace(xp: Namespace) -> bool: | ||||||
""" | ||||||
Returns True if `xp` is a JAX namespace. | ||||||
|
||||||
|
@@ -393,7 +391,7 @@ | |||||
return xp.__name__ in {'jax.numpy', 'jax.experimental.array_api'} | ||||||
|
||||||
|
||||||
def is_pydata_sparse_namespace(xp) -> bool: | ||||||
def is_pydata_sparse_namespace(xp: Namespace) -> bool: | ||||||
""" | ||||||
Returns True if `xp` is a pydata/sparse namespace. | ||||||
|
||||||
|
@@ -412,7 +410,7 @@ | |||||
return xp.__name__ == 'sparse' | ||||||
|
||||||
|
||||||
def is_array_api_strict_namespace(xp) -> bool: | ||||||
def is_array_api_strict_namespace(xp: Namespace) -> bool: | ||||||
""" | ||||||
Returns True if `xp` is an array-api-strict namespace. | ||||||
|
||||||
|
@@ -439,7 +437,11 @@ | |||||
raise ValueError("Only the 2024.12 version of the array API specification is currently supported") | ||||||
|
||||||
|
||||||
def array_namespace(*xs, api_version=None, use_compat=None) -> Namespace: | ||||||
def array_namespace( | ||||||
*xs: Union[Array, bool, int, float, complex, None], | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it's sufficient to write
Suggested change
Keep in mind that in case of type-checker errors, often the entire signature is printed. So it's usually a good idea to avoid long annotations (in function signatures, but also in general). |
||||||
api_version: Optional[str] = None, | ||||||
use_compat: Optional[bool] = None, | ||||||
) -> Namespace: | ||||||
""" | ||||||
Get the array API compatible namespace for the arrays `xs`. | ||||||
|
||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This literal is repeated quite often, so it's probably a good idea to extract it as something like