Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TYP: Type annotations overhaul, part 1 #257

Merged
merged 14 commits into from
Mar 22, 2025
248 changes: 140 additions & 108 deletions array_api_compat/common/_aliases.py

Large diffs are not rendered by default.

87 changes: 45 additions & 42 deletions array_api_compat/common/_fft.py
Original file line number Diff line number Diff line change
@@ -1,149 +1,148 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Union, Optional, Literal
from collections.abc import Sequence
from typing import Union, Optional, Literal

if TYPE_CHECKING:
from ._typing import Device, ndarray, DType
from collections.abc import Sequence
from ._typing import Device, Array, DType, Namespace

# Note: NumPy fft functions improperly upcast float32 and complex64 to
# complex128, which is why we require wrapping them all here.

def fft(
x: ndarray,
x: Array,
/,
xp,
xp: Namespace,
*,
n: Optional[int] = None,
axis: int = -1,
norm: Literal["backward", "ortho", "forward"] = "backward",
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This literal is repeated quite often, so it's probably a good idea to extract it as something like

_Norm: TypeAlias = Literal["backward", "ortho", "forward"]

) -> ndarray:
) -> Array:
res = xp.fft.fft(x, n=n, axis=axis, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.complex64)
return res

def ifft(
x: ndarray,
x: Array,
/,
xp,
xp: Namespace,
*,
n: Optional[int] = None,
axis: int = -1,
norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
) -> Array:
res = xp.fft.ifft(x, n=n, axis=axis, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.complex64)
return res

def fftn(
x: ndarray,
x: Array,
/,
xp,
xp: Namespace,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
Comment on lines 44 to 45
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
s: Sequence[int] = None,
axes: Sequence[int] = None,
s: Sequence[int] | None = None,
axes: Sequence[int] | None = None,

FYI, a Sequence[int] is very broad, and will even accept things like bytes. Maybe a tuple[int, ...] is more appropriate here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
) -> Array:
res = xp.fft.fftn(x, s=s, axes=axes, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.complex64)
return res

def ifftn(
x: ndarray,
x: Array,
/,
xp,
xp: Namespace,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
Comment on lines 58 to 59
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
s: Sequence[int] = None,
axes: Sequence[int] = None,
s: Sequence[int] | None = None,
axes: Sequence[int] | None = None,

norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
) -> Array:
res = xp.fft.ifftn(x, s=s, axes=axes, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.complex64)
return res

def rfft(
x: ndarray,
x: Array,
/,
xp,
xp: Namespace,
*,
n: Optional[int] = None,
axis: int = -1,
norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
) -> Array:
res = xp.fft.rfft(x, n=n, axis=axis, norm=norm)
if x.dtype == xp.float32:
return res.astype(xp.complex64)
return res

def irfft(
x: ndarray,
x: Array,
/,
xp,
xp: Namespace,
*,
n: Optional[int] = None,
axis: int = -1,
norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
) -> Array:
res = xp.fft.irfft(x, n=n, axis=axis, norm=norm)
if x.dtype == xp.complex64:
return res.astype(xp.float32)
return res

def rfftn(
x: ndarray,
x: Array,
/,
xp,
xp: Namespace,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
Comment on lines 100 to 101
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
s: Sequence[int] = None,
axes: Sequence[int] = None,
s: Sequence[int] | None = None,
axes: Sequence[int] | None = None,

norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
) -> Array:
res = xp.fft.rfftn(x, s=s, axes=axes, norm=norm)
if x.dtype == xp.float32:
return res.astype(xp.complex64)
return res

def irfftn(
x: ndarray,
x: Array,
/,
xp,
xp: Namespace,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
Comment on lines 114 to 115
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
s: Sequence[int] = None,
axes: Sequence[int] = None,
s: Sequence[int] | None = None,
axes: Sequence[int] | None = None,

norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
) -> Array:
res = xp.fft.irfftn(x, s=s, axes=axes, norm=norm)
if x.dtype == xp.complex64:
return res.astype(xp.float32)
return res

def hfft(
x: ndarray,
x: Array,
/,
xp,
xp: Namespace,
*,
n: Optional[int] = None,
axis: int = -1,
norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
) -> Array:
res = xp.fft.hfft(x, n=n, axis=axis, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.float32)
return res

def ihfft(
x: ndarray,
x: Array,
/,
xp,
xp: Namespace,
*,
n: Optional[int] = None,
axis: int = -1,
norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
) -> Array:
res = xp.fft.ihfft(x, n=n, axis=axis, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.complex64)
Expand All @@ -152,12 +151,12 @@ def ihfft(
def fftfreq(
n: int,
/,
xp,
xp: Namespace,
*,
d: float = 1.0,
dtype: Optional[DType] = None,
device: Optional[Device] = None
) -> ndarray:
device: Optional[Device] = None,
) -> Array:
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
res = xp.fft.fftfreq(n, d=d)
Expand All @@ -168,23 +167,27 @@ def fftfreq(
def rfftfreq(
n: int,
/,
xp,
xp: Namespace,
*,
d: float = 1.0,
dtype: Optional[DType] = None,
device: Optional[Device] = None
) -> ndarray:
device: Optional[Device] = None,
) -> Array:
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
res = xp.fft.rfftfreq(n, d=d)
if dtype is not None:
return res.astype(dtype)
return res

def fftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray:
def fftshift(
x: Array, /, xp: Namespace, *, axes: Union[int, Sequence[int]] = None
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
x: Array, /, xp: Namespace, *, axes: Union[int, Sequence[int]] = None
x: Array, /, xp: Namespace, *, axes: int | Sequence[int] | None = None

) -> Array:
return xp.fft.fftshift(x, axes=axes)

def ifftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray:
def ifftshift(
x: Array, /, xp: Namespace, *, axes: Union[int, Sequence[int]] = None
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
x: Array, /, xp: Namespace, *, axes: Union[int, Sequence[int]] = None
x: Array, /, xp: Namespace, *, axes: int | Sequence[int] | None = None

) -> Array:
return xp.fft.ifftshift(x, axes=axes)

__all__ = [
Expand Down
32 changes: 17 additions & 15 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,14 @@
"""
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from typing import Optional, Union, Any
from ._typing import Array, Device, Namespace

import sys
import math
import inspect
import warnings
from typing import Optional, Union, Any

from ._typing import Array, Device, Namespace


def _is_jax_zero_gradient_array(x: object) -> bool:
"""Return True if `x` is a zero-gradient array.
Expand Down Expand Up @@ -268,7 +266,7 @@
return __name__.removesuffix('.common._helpers')


def is_numpy_namespace(xp) -> bool:
def is_numpy_namespace(xp: Namespace) -> bool:
"""
Returns True if `xp` is a NumPy namespace.

Expand All @@ -289,7 +287,7 @@
return xp.__name__ in {'numpy', _compat_module_name() + '.numpy'}


def is_cupy_namespace(xp) -> bool:
def is_cupy_namespace(xp: Namespace) -> bool:
"""
Returns True if `xp` is a CuPy namespace.

Expand All @@ -310,7 +308,7 @@
return xp.__name__ in {'cupy', _compat_module_name() + '.cupy'}


def is_torch_namespace(xp) -> bool:
def is_torch_namespace(xp: Namespace) -> bool:
"""
Returns True if `xp` is a PyTorch namespace.

Expand All @@ -331,7 +329,7 @@
return xp.__name__ in {'torch', _compat_module_name() + '.torch'}


def is_ndonnx_namespace(xp) -> bool:
def is_ndonnx_namespace(xp: Namespace) -> bool:
"""
Returns True if `xp` is an NDONNX namespace.

Expand All @@ -350,7 +348,7 @@
return xp.__name__ == 'ndonnx'


def is_dask_namespace(xp) -> bool:
def is_dask_namespace(xp: Namespace) -> bool:
"""
Returns True if `xp` is a Dask namespace.

Expand All @@ -371,7 +369,7 @@
return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'}


def is_jax_namespace(xp) -> bool:
def is_jax_namespace(xp: Namespace) -> bool:
"""
Returns True if `xp` is a JAX namespace.

Expand All @@ -393,7 +391,7 @@
return xp.__name__ in {'jax.numpy', 'jax.experimental.array_api'}


def is_pydata_sparse_namespace(xp) -> bool:
def is_pydata_sparse_namespace(xp: Namespace) -> bool:
"""
Returns True if `xp` is a pydata/sparse namespace.

Expand All @@ -412,7 +410,7 @@
return xp.__name__ == 'sparse'


def is_array_api_strict_namespace(xp) -> bool:
def is_array_api_strict_namespace(xp: Namespace) -> bool:
"""
Returns True if `xp` is an array-api-strict namespace.

Expand All @@ -439,7 +437,11 @@
raise ValueError("Only the 2024.12 version of the array API specification is currently supported")


def array_namespace(*xs, api_version=None, use_compat=None) -> Namespace:
def array_namespace(
*xs: Union[Array, bool, int, float, complex, None],

Check failure on line 441 in array_api_compat/common/_helpers.py

View workflow job for this annotation

GitHub Actions / check-ruff

array_api_compat/common/_helpers.py:441:10: SyntaxError: Cannot use star annotation on Python 3.9 (syntax was added in Python 3.11)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's sufficient to write

Suggested change
*xs: Union[Array, bool, int, float, complex, None],
*xs: Array | complex | None,

Keep in mind that in case of type-checker errors, often the entire signature is printed. So it's usually a good idea to avoid long annotations (in function signatures, but also in general).

api_version: Optional[str] = None,
use_compat: Optional[bool] = None,
) -> Namespace:
"""
Get the array API compatible namespace for the arrays `xs`.

Expand Down
Loading
Loading