diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 35262d3a..0d123b99 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -4,15 +4,11 @@ from __future__ import annotations -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import Optional, Sequence, Tuple, Union - from ._typing import ndarray, Device, Dtype - -from typing import NamedTuple import inspect +from typing import NamedTuple, Optional, Sequence, Tuple, Union from ._helpers import array_namespace, _check_device, device, is_cupy_namespace +from ._typing import Array, Device, DType, Namespace # These functions are modified from the NumPy versions. @@ -24,29 +20,34 @@ def arange( stop: Optional[Union[int, float]] = None, step: Union[int, float] = 1, *, - xp, - dtype: Optional[Dtype] = None, + xp: Namespace, + dtype: Optional[DType] = None, device: Optional[Device] = None, - **kwargs -) -> ndarray: + **kwargs, +) -> Array: _check_device(xp, device) return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs) def empty( shape: Union[int, Tuple[int, ...]], - xp, + xp: Namespace, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, - **kwargs -) -> ndarray: + **kwargs, +) -> Array: _check_device(xp, device) return xp.empty(shape, dtype=dtype, **kwargs) def empty_like( - x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, - **kwargs -) -> ndarray: + x: Array, + /, + xp: Namespace, + *, + dtype: Optional[DType] = None, + device: Optional[Device] = None, + **kwargs, +) -> Array: _check_device(xp, device) return xp.empty_like(x, dtype=dtype, **kwargs) @@ -55,37 +56,37 @@ def eye( n_cols: Optional[int] = None, /, *, - xp, + xp: Namespace, k: int = 0, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, **kwargs, -) -> ndarray: +) -> Array: _check_device(xp, device) return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs) def full( shape: Union[int, Tuple[int, ...]], - fill_value: Union[int, float], - xp, + fill_value: complex, + xp: Namespace, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, **kwargs, -) -> ndarray: +) -> Array: _check_device(xp, device) return xp.full(shape, fill_value, dtype=dtype, **kwargs) def full_like( - x: ndarray, + x: Array, /, - fill_value: Union[int, float], + fill_value: complex, *, - xp, - dtype: Optional[Dtype] = None, + xp: Namespace, + dtype: Optional[DType] = None, device: Optional[Device] = None, **kwargs, -) -> ndarray: +) -> Array: _check_device(xp, device) return xp.full_like(x, fill_value, dtype=dtype, **kwargs) @@ -95,48 +96,58 @@ def linspace( /, num: int, *, - xp, - dtype: Optional[Dtype] = None, + xp: Namespace, + dtype: Optional[DType] = None, device: Optional[Device] = None, endpoint: bool = True, **kwargs, -) -> ndarray: +) -> Array: _check_device(xp, device) return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs) def ones( shape: Union[int, Tuple[int, ...]], - xp, + xp: Namespace, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, **kwargs, -) -> ndarray: +) -> Array: _check_device(xp, device) return xp.ones(shape, dtype=dtype, **kwargs) def ones_like( - x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, + x: Array, + /, + xp: Namespace, + *, + dtype: Optional[DType] = None, + device: Optional[Device] 
= None, **kwargs, -) -> ndarray: +) -> Array: _check_device(xp, device) return xp.ones_like(x, dtype=dtype, **kwargs) def zeros( shape: Union[int, Tuple[int, ...]], - xp, + xp: Namespace, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, **kwargs, -) -> ndarray: +) -> Array: _check_device(xp, device) return xp.zeros(shape, dtype=dtype, **kwargs) def zeros_like( - x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, + x: Array, + /, + xp: Namespace, + *, + dtype: Optional[DType] = None, + device: Optional[Device] = None, **kwargs, -) -> ndarray: +) -> Array: _check_device(xp, device) return xp.zeros_like(x, dtype=dtype, **kwargs) @@ -150,23 +161,23 @@ def zeros_like( # Note that these named tuples aren't actually part of the standard namespace, # but I don't see any issue with exporting the names here regardless. class UniqueAllResult(NamedTuple): - values: ndarray - indices: ndarray - inverse_indices: ndarray - counts: ndarray + values: Array + indices: Array + inverse_indices: Array + counts: Array class UniqueCountsResult(NamedTuple): - values: ndarray - counts: ndarray + values: Array + counts: Array class UniqueInverseResult(NamedTuple): - values: ndarray - inverse_indices: ndarray + values: Array + inverse_indices: Array -def _unique_kwargs(xp): +def _unique_kwargs(xp: Namespace) -> dict[str, bool]: # Older versions of NumPy and CuPy do not have equal_nan. Rather than # trying to parse version numbers, just check if equal_nan is in the # signature. @@ -175,7 +186,7 @@ def _unique_kwargs(xp): return {'equal_nan': False} return {} -def unique_all(x: ndarray, /, xp) -> UniqueAllResult: +def unique_all(x: Array, /, xp: Namespace) -> UniqueAllResult: kwargs = _unique_kwargs(xp) values, indices, inverse_indices, counts = xp.unique( x, @@ -195,7 +206,7 @@ def unique_all(x: ndarray, /, xp) -> UniqueAllResult: ) -def unique_counts(x: ndarray, /, xp) -> UniqueCountsResult: +def unique_counts(x: Array, /, xp: Namespace) -> UniqueCountsResult: kwargs = _unique_kwargs(xp) res = xp.unique( x, @@ -208,7 +219,7 @@ def unique_counts(x: ndarray, /, xp) -> UniqueCountsResult: return UniqueCountsResult(*res) -def unique_inverse(x: ndarray, /, xp) -> UniqueInverseResult: +def unique_inverse(x: Array, /, xp: Namespace) -> UniqueInverseResult: kwargs = _unique_kwargs(xp) values, inverse_indices = xp.unique( x, @@ -223,7 +234,7 @@ def unique_inverse(x: ndarray, /, xp) -> UniqueInverseResult: return UniqueInverseResult(values, inverse_indices) -def unique_values(x: ndarray, /, xp) -> ndarray: +def unique_values(x: Array, /, xp: Namespace) -> Array: kwargs = _unique_kwargs(xp) return xp.unique( x, @@ -236,42 +247,42 @@ def unique_values(x: ndarray, /, xp) -> ndarray: # These functions have different keyword argument names def std( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, # correction instead of ddof + correction: Union[int, float] = 0.0, # correction instead of ddof keepdims: bool = False, **kwargs, -) -> ndarray: +) -> Array: return xp.std(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs) def var( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, # correction instead of ddof + correction: Union[int, float] = 0.0, # correction instead of ddof keepdims: bool = False, **kwargs, -) -> ndarray: +) -> Array: return xp.var(x, 
axis=axis, ddof=correction, keepdims=keepdims, **kwargs) # cumulative_sum is renamed from cumsum, and adds the include_initial keyword # argument def cumulative_sum( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, axis: Optional[int] = None, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, include_initial: bool = False, - **kwargs -) -> ndarray: + **kwargs, +) -> Array: wrapped_xp = array_namespace(x) # TODO: The standard is not clear about what should happen when x.ndim == 0. @@ -294,15 +305,15 @@ def cumulative_sum( def cumulative_prod( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, axis: Optional[int] = None, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, include_initial: bool = False, - **kwargs -) -> ndarray: + **kwargs, +) -> Array: wrapped_xp = array_namespace(x) if axis is None: @@ -325,17 +336,18 @@ def cumulative_prod( # The min and max argument names in clip are different and not optional in numpy, and type # promotion behavior is different. def clip( - x: ndarray, + x: Array, /, - min: Optional[Union[int, float, ndarray]] = None, - max: Optional[Union[int, float, ndarray]] = None, + min: Optional[Union[int, float, Array]] = None, + max: Optional[Union[int, float, Array]] = None, *, - xp, + xp: Namespace, # TODO: np.clip has other ufunc kwargs - out: Optional[ndarray] = None, -) -> ndarray: + out: Optional[Array] = None, +) -> Array: def _isscalar(a): return isinstance(a, (int, float, type(None))) + min_shape = () if _isscalar(min) else min.shape max_shape = () if _isscalar(max) else max.shape @@ -390,15 +402,19 @@ def _isscalar(a): return out[()] # Unlike transpose(), the axes argument to permute_dims() is required. -def permute_dims(x: ndarray, /, axes: Tuple[int, ...], xp) -> ndarray: +def permute_dims(x: Array, /, axes: Tuple[int, ...], xp: Namespace) -> Array: return xp.transpose(x, axes) # np.reshape calls the keyword argument 'newshape' instead of 'shape' -def reshape(x: ndarray, - /, - shape: Tuple[int, ...], - xp, copy: Optional[bool] = None, - **kwargs) -> ndarray: +def reshape( + x: Array, + /, + shape: Tuple[int, ...], + xp: Namespace, + *, + copy: Optional[bool] = None, + **kwargs, +) -> Array: if copy is True: x = x.copy() elif copy is False: @@ -410,9 +426,15 @@ def reshape(x: ndarray, # The descending keyword is new in sort and argsort, and 'kind' replaced with # 'stable' def argsort( - x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True, + x: Array, + /, + xp: Namespace, + *, + axis: int = -1, + descending: bool = False, + stable: bool = True, **kwargs, -) -> ndarray: +) -> Array: # Note: this keyword argument is different, and the default is different. # We set it in kwargs like this because numpy.sort uses kind='quicksort' # as the default whereas cupy.sort uses kind=None. @@ -435,9 +457,15 @@ def argsort( return res def sort( - x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True, + x: Array, + /, + xp: Namespace, + *, + axis: int = -1, + descending: bool = False, + stable: bool = True, **kwargs, -) -> ndarray: +) -> Array: # Note: this keyword argument is different, and the default is different. # We set it in kwargs like this because numpy.sort uses kind='quicksort' # as the default whereas cupy.sort uses kind=None. 
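A quick aside on the pattern these signature changes formalize: every helper takes the backing array module as an explicit xp: Namespace argument, which get_xp later partially applies for each backend. A minimal sketch of calling one of these wrappers directly, assuming NumPy is installed (_aliases is an internal module, so this is purely illustrative):

import numpy as np
from array_api_compat.common import _aliases

x = np.asarray([3, 1, 2])
# descending=True is emulated on top of a stable sort kind, per the comments above
print(_aliases.argsort(x, xp=np, descending=True))  # [0 2 1]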
@@ -449,50 +477,51 @@ def sort( return res # nonzero should error for zero-dimensional arrays -def nonzero(x: ndarray, /, xp, **kwargs) -> Tuple[ndarray, ...]: +def nonzero(x: Array, /, xp: Namespace, **kwargs) -> Tuple[Array, ...]: if x.ndim == 0: raise ValueError("nonzero() does not support zero-dimensional arrays") return xp.nonzero(x, **kwargs) # ceil, floor, and trunc return integers for integer inputs -def ceil(x: ndarray, /, xp, **kwargs) -> ndarray: +def ceil(x: Array, /, xp: Namespace, **kwargs) -> Array: if xp.issubdtype(x.dtype, xp.integer): return x return xp.ceil(x, **kwargs) -def floor(x: ndarray, /, xp, **kwargs) -> ndarray: +def floor(x: Array, /, xp: Namespace, **kwargs) -> Array: if xp.issubdtype(x.dtype, xp.integer): return x return xp.floor(x, **kwargs) -def trunc(x: ndarray, /, xp, **kwargs) -> ndarray: +def trunc(x: Array, /, xp: Namespace, **kwargs) -> Array: if xp.issubdtype(x.dtype, xp.integer): return x return xp.trunc(x, **kwargs) # linear algebra functions -def matmul(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray: +def matmul(x1: Array, x2: Array, /, xp: Namespace, **kwargs) -> Array: return xp.matmul(x1, x2, **kwargs) # Unlike transpose, matrix_transpose only transposes the last two axes. -def matrix_transpose(x: ndarray, /, xp) -> ndarray: +def matrix_transpose(x: Array, /, xp: Namespace) -> Array: if x.ndim < 2: raise ValueError("x must be at least 2-dimensional for matrix_transpose") return xp.swapaxes(x, -1, -2) -def tensordot(x1: ndarray, - x2: ndarray, - /, - xp, - *, - axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, - **kwargs, -) -> ndarray: +def tensordot( + x1: Array, + x2: Array, + /, + xp: Namespace, + *, + axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, + **kwargs, +) -> Array: return xp.tensordot(x1, x2, axes=axes, **kwargs) -def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray: +def vecdot(x1: Array, x2: Array, /, xp: Namespace, *, axis: int = -1) -> Array: if x1.shape[axis] != x2.shape[axis]: raise ValueError("x1 and x2 must have the same size along the given axis") @@ -511,8 +540,11 @@ def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray: # isdtype is a new function in the 2022.12 array API specification. def isdtype( - dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]], xp, - *, _tuple=True, # Disallow nested tuples + dtype: DType, + kind: Union[DType, str, Tuple[Union[DType, str], ...]], + xp: Namespace, + *, + _tuple: bool = True, # Disallow nested tuples ) -> bool: """ Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``. @@ -551,14 +583,14 @@ def isdtype( return dtype == kind # unstack is a new function in the 2023.12 array API standard -def unstack(x: ndarray, /, xp, *, axis: int = 0) -> Tuple[ndarray, ...]: +def unstack(x: Array, /, xp: Namespace, *, axis: int = 0) -> Tuple[Array, ...]: if x.ndim == 0: raise ValueError("Input array must be at least 1-d.") return tuple(xp.moveaxis(x, axis, 0)) # numpy 1.26 does not use the standard definition for sign on complex numbers -def sign(x: ndarray, /, xp, **kwargs) -> ndarray: +def sign(x: Array, /, xp: Namespace, **kwargs) -> Array: if isdtype(x.dtype, 'complex floating', xp=xp): out = (x/xp.abs(x, **kwargs))[...] 
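# (Aside, not part of the diff: for nonzero complex z, z/abs(z) projects z
# onto the unit circle, which is the standard's definition of sign for
# complex inputs, e.g. z = 3+4j gives 0.6+0.8j, with modulus 1.)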
# sign(0) = 0 but the above formula would give nan diff --git a/array_api_compat/common/_fft.py b/array_api_compat/common/_fft.py index e5caebef..bd2a4e1a 100644 --- a/array_api_compat/common/_fft.py +++ b/array_api_compat/common/_fft.py @@ -1,149 +1,148 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Union, Optional, Literal +from collections.abc import Sequence +from typing import Union, Optional, Literal -if TYPE_CHECKING: - from ._typing import Device, ndarray, DType - from collections.abc import Sequence +from ._typing import Device, Array, DType, Namespace # Note: NumPy fft functions improperly upcast float32 and complex64 to # complex128, which is why we require wrapping them all here. def fft( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, n: Optional[int] = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.fft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.complex64) return res def ifft( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, n: Optional[int] = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.ifft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.complex64) return res def fftn( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.fftn(x, s=s, axes=axes, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.complex64) return res def ifftn( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.ifftn(x, s=s, axes=axes, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.complex64) return res def rfft( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, n: Optional[int] = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.rfft(x, n=n, axis=axis, norm=norm) if x.dtype == xp.float32: return res.astype(xp.complex64) return res def irfft( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, n: Optional[int] = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.irfft(x, n=n, axis=axis, norm=norm) if x.dtype == xp.complex64: return res.astype(xp.float32) return res def rfftn( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.rfftn(x, s=s, axes=axes, norm=norm) if x.dtype == xp.float32: return res.astype(xp.complex64) return res def irfftn( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.irfftn(x, s=s, axes=axes, norm=norm) if x.dtype == xp.complex64: return res.astype(xp.float32) return res def hfft( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, n: Optional[int] = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.hfft(x, n=n, 
axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.float32) return res def ihfft( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, n: Optional[int] = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.ihfft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.complex64) @@ -152,12 +151,12 @@ def ihfft( def fftfreq( n: int, /, - xp, + xp: Namespace, *, d: float = 1.0, dtype: Optional[DType] = None, - device: Optional[Device] = None -) -> ndarray: + device: Optional[Device] = None, +) -> Array: if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") res = xp.fft.fftfreq(n, d=d) @@ -168,12 +167,12 @@ def fftfreq( def rfftfreq( n: int, /, - xp, + xp: Namespace, *, d: float = 1.0, dtype: Optional[DType] = None, - device: Optional[Device] = None -) -> ndarray: + device: Optional[Device] = None, +) -> Array: if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") res = xp.fft.rfftfreq(n, d=d) @@ -181,10 +180,14 @@ def rfftfreq( return res.astype(dtype) return res -def fftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray: +def fftshift( + x: Array, /, xp: Namespace, *, axes: Union[int, Sequence[int]] = None +) -> Array: return xp.fft.fftshift(x, axes=axes) -def ifftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray: +def ifftshift( + x: Array, /, xp: Namespace, *, axes: Union[int, Sequence[int]] = None +) -> Array: return xp.fft.ifftshift(x, axes=axes) __all__ = [ diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 791edb81..6d95069d 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -7,16 +7,14 @@ """ from __future__ import annotations -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from typing import Optional, Union, Any - from ._typing import Array, Device, Namespace - import sys import math import inspect import warnings +from typing import Optional, Union, Any + +from ._typing import Array, Device, Namespace + def _is_jax_zero_gradient_array(x: object) -> bool: """Return True if `x` is a zero-gradient array. @@ -268,7 +266,7 @@ def _compat_module_name() -> str: return __name__.removesuffix('.common._helpers') -def is_numpy_namespace(xp) -> bool: +def is_numpy_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a NumPy namespace. @@ -289,7 +287,7 @@ def is_numpy_namespace(xp) -> bool: return xp.__name__ in {'numpy', _compat_module_name() + '.numpy'} -def is_cupy_namespace(xp) -> bool: +def is_cupy_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a CuPy namespace. @@ -310,7 +308,7 @@ def is_cupy_namespace(xp) -> bool: return xp.__name__ in {'cupy', _compat_module_name() + '.cupy'} -def is_torch_namespace(xp) -> bool: +def is_torch_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a PyTorch namespace. @@ -331,7 +329,7 @@ def is_torch_namespace(xp) -> bool: return xp.__name__ in {'torch', _compat_module_name() + '.torch'} -def is_ndonnx_namespace(xp) -> bool: +def is_ndonnx_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is an NDONNX namespace. @@ -350,7 +348,7 @@ def is_ndonnx_namespace(xp) -> bool: return xp.__name__ == 'ndonnx' -def is_dask_namespace(xp) -> bool: +def is_dask_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a Dask namespace. 
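These predicates pair with array_namespace() below; the module object it returns is exactly what this diff now annotates as Namespace everywhere. A small usage sketch, assuming NumPy is installed:

import numpy as np
from array_api_compat import array_namespace, is_numpy_namespace

xp = array_namespace(np.asarray([1.0, 2.0]))
# True for bare numpy as well as the array_api_compat.numpy wrapper
assert is_numpy_namespace(xp)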
@@ -371,7 +369,7 @@ def is_dask_namespace(xp) -> bool: return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'} -def is_jax_namespace(xp) -> bool: +def is_jax_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a JAX namespace. @@ -393,7 +391,7 @@ def is_jax_namespace(xp) -> bool: return xp.__name__ in {'jax.numpy', 'jax.experimental.array_api'} -def is_pydata_sparse_namespace(xp) -> bool: +def is_pydata_sparse_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a pydata/sparse namespace. @@ -412,7 +410,7 @@ def is_pydata_sparse_namespace(xp) -> bool: return xp.__name__ == 'sparse' -def is_array_api_strict_namespace(xp) -> bool: +def is_array_api_strict_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is an array-api-strict namespace. @@ -439,7 +437,11 @@ def _check_api_version(api_version: str) -> None: raise ValueError("Only the 2024.12 version of the array API specification is currently supported") -def array_namespace(*xs, api_version=None, use_compat=None) -> Namespace: +def array_namespace( + *xs: Union[Array, bool, int, float, complex, None], + api_version: Optional[str] = None, + use_compat: Optional[bool] = None, +) -> Namespace: """ Get the array API compatible namespace for the arrays `xs`. diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index bfa1f1b9..c77ee3b8 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -1,11 +1,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, NamedTuple -if TYPE_CHECKING: - from typing import Literal, Optional, Tuple, Union - from ._typing import ndarray - import math +from typing import Literal, NamedTuple, Optional, Tuple, Union import numpy as np if np.__version__[0] == "2": @@ -15,50 +11,53 @@ from ._aliases import matmul, matrix_transpose, tensordot, vecdot, isdtype from .._internal import get_xp +from ._typing import Array, Namespace # These are in the main NumPy namespace but not in numpy.linalg -def cross(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1, **kwargs) -> ndarray: +def cross(x1: Array, x2: Array, /, xp: Namespace, *, axis: int = -1, **kwargs) -> Array: return xp.cross(x1, x2, axis=axis, **kwargs) -def outer(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray: +def outer(x1: Array, x2: Array, /, xp: Namespace, **kwargs) -> Array: return xp.outer(x1, x2, **kwargs) class EighResult(NamedTuple): - eigenvalues: ndarray - eigenvectors: ndarray + eigenvalues: Array + eigenvectors: Array class QRResult(NamedTuple): - Q: ndarray - R: ndarray + Q: Array + R: Array class SlogdetResult(NamedTuple): - sign: ndarray - logabsdet: ndarray + sign: Array + logabsdet: Array class SVDResult(NamedTuple): - U: ndarray - S: ndarray - Vh: ndarray + U: Array + S: Array + Vh: Array # These functions are the same as their NumPy counterparts except they return # a namedtuple. 
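# Illustration of the namedtuple behavior described above (assumes NumPy;
# not part of the diff): results unpack positionally and by field name.
#
#   import numpy as np
#   from array_api_compat.numpy import linalg
#   res = linalg.eigh(np.eye(2))
#   w, v = res                        # tuple-style unpacking
#   assert res.eigenvalues is res[0]  # namedtuple field access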
-def eigh(x: ndarray, /, xp, **kwargs) -> EighResult: +def eigh(x: Array, /, xp: Namespace, **kwargs) -> EighResult: return EighResult(*xp.linalg.eigh(x, **kwargs)) -def qr(x: ndarray, /, xp, *, mode: Literal['reduced', 'complete'] = 'reduced', +def qr(x: Array, /, xp: Namespace, *, mode: Literal['reduced', 'complete'] = 'reduced', **kwargs) -> QRResult: return QRResult(*xp.linalg.qr(x, mode=mode, **kwargs)) -def slogdet(x: ndarray, /, xp, **kwargs) -> SlogdetResult: +def slogdet(x: Array, /, xp: Namespace, **kwargs) -> SlogdetResult: return SlogdetResult(*xp.linalg.slogdet(x, **kwargs)) -def svd(x: ndarray, /, xp, *, full_matrices: bool = True, **kwargs) -> SVDResult: +def svd( + x: Array, /, xp: Namespace, *, full_matrices: bool = True, **kwargs +) -> SVDResult: return SVDResult(*xp.linalg.svd(x, full_matrices=full_matrices, **kwargs)) # These functions have additional keyword arguments # The upper keyword argument is new from NumPy -def cholesky(x: ndarray, /, xp, *, upper: bool = False, **kwargs) -> ndarray: +def cholesky(x: Array, /, xp: Namespace, *, upper: bool = False, **kwargs) -> Array: L = xp.linalg.cholesky(x, **kwargs) if upper: U = get_xp(xp)(matrix_transpose)(L) @@ -69,12 +68,12 @@ def cholesky(x: ndarray, /, xp, *, upper: bool = False, **kwargs) -> ndarray: # The rtol keyword argument of matrix_rank() and pinv() is new from NumPy. # Note that it has a different semantic meaning from tol and rcond. -def matrix_rank(x: ndarray, +def matrix_rank(x: Array, /, - xp, + xp: Namespace, *, - rtol: Optional[Union[float, ndarray]] = None, - **kwargs) -> ndarray: + rtol: Optional[Union[float, Array]] = None, + **kwargs) -> Array: # this is different from xp.linalg.matrix_rank, which supports 1 # dimensional arrays. if x.ndim < 2: @@ -88,7 +87,9 @@ def matrix_rank(x: ndarray, tol = S.max(axis=-1, keepdims=True)*xp.asarray(rtol)[..., xp.newaxis] return xp.count_nonzero(S > tol, axis=-1) -def pinv(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = None, **kwargs) -> ndarray: +def pinv( + x: Array, /, xp: Namespace, *, rtol: Optional[Union[float, Array]] = None, **kwargs +) -> Array: # this is different from xp.linalg.pinv, which does not multiply the # default tolerance by max(M, N). if rtol is None: @@ -97,15 +98,30 @@ def pinv(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = None, **k # These functions are new in the array API spec -def matrix_norm(x: ndarray, /, xp, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> ndarray: +def matrix_norm( + x: Array, + /, + xp: Namespace, + *, + keepdims: bool = False, + ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro', +) -> Array: return xp.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord) # svdvals is not in NumPy (but it is in SciPy). It is equivalent to # xp.linalg.svd(compute_uv=False). 
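# That is, roughly (sketch, assuming NumPy):
#
#   import numpy as np
#   A = np.asarray([[3.0, 0.0], [0.0, 2.0]])
#   s = np.linalg.svd(A, compute_uv=False)  # singular values only
#   assert np.allclose(s, [3.0, 2.0])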
-def svdvals(x: ndarray, /, xp) -> Union[ndarray, Tuple[ndarray, ...]]: +def svdvals(x: Array, /, xp: Namespace) -> Union[Array, Tuple[Array, ...]]: return xp.linalg.svd(x, compute_uv=False) -def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> ndarray: +def vector_norm( + x: Array, + /, + xp: Namespace, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, + ord: Optional[Union[int, float]] = 2, +) -> Array: # xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or # when axis=None and the input is 2-D, so to force a vector norm, we make # it so the input is 1-D (for axis=None), or reshape so that norm is done @@ -143,11 +159,15 @@ def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]] # xp.diagonal and xp.trace operate on the first two axes whereas these # operates on the last two -def diagonal(x: ndarray, /, xp, *, offset: int = 0, **kwargs) -> ndarray: +def diagonal(x: Array, /, xp: Namespace, *, offset: int = 0, **kwargs) -> Array: return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs) -def trace(x: ndarray, /, xp, *, offset: int = 0, dtype=None, **kwargs) -> ndarray: - return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs)) +def trace( + x: Array, /, xp: Namespace, *, offset: int = 0, dtype=None, **kwargs +) -> Array: + return xp.asarray( + xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs) + ) __all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult', 'QRResult', 'SlogdetResult', 'SVDResult', 'eigh', 'qr', 'slogdet', diff --git a/array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py index d8acdef7..4c3b356b 100644 --- a/array_api_compat/common/_typing.py +++ b/array_api_compat/common/_typing.py @@ -1,26 +1,24 @@ from __future__ import annotations +from types import ModuleType as Namespace +from typing import Any, TypeVar, Protocol __all__ = [ + "Array", + "DType", + "Device", + "Namespace", "NestedSequence", "SupportsBufferProtocol", ] -from types import ModuleType -from typing import ( - Any, - TypeVar, - Protocol, -) - _T_co = TypeVar("_T_co", covariant=True) class NestedSequence(Protocol[_T_co]): def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ... def __len__(self, /) -> int: ... -SupportsBufferProtocol = Any +SupportsBufferProtocol = Any Array = Any Device = Any DType = Any -Namespace = ModuleType diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 30d9fe48..ebc7ccd9 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -1,16 +1,14 @@ from __future__ import annotations +from typing import Optional + import cupy as cp from ..common import _aliases, _helpers +from ..common._typing import NestedSequence, SupportsBufferProtocol from .._internal import get_xp - from ._info import __array_namespace_info__ - -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import Optional, Union - from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol +from ._typing import Array, Device, DType bool = cp.bool_ @@ -66,23 +64,19 @@ _copy_default = object() + # asarray also adds the copy keyword, which is not present in numpy 1.0. 
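# Sketch of the copy contract from the standard that the wrapper below
# implements (illustrative; requires CuPy):
#
#   import array_api_compat.cupy as xp
#   a = xp.asarray([1, 2, 3])
#   b = xp.asarray(a, copy=True)   # always a new array
#   c = xp.asarray(a, copy=None)   # copy only if necessary
#   d = xp.asarray(a, copy=False)  # must raise if a copy cannot be avoided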
def asarray( - obj: Union[ - ndarray, - bool, - int, - float, - NestedSequence[bool | int | float], - SupportsBufferProtocol, - ], + obj: ( + Array | bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol + ), /, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, copy: Optional[bool] = _copy_default, **kwargs, -) -> ndarray: +) -> Array: """ Array API compatibility wrapper for asarray(). @@ -112,13 +106,13 @@ def asarray( def astype( - x: ndarray, - dtype: Dtype, + x: Array, + dtype: DType, /, *, copy: bool = True, device: Optional[Device] = None, -) -> ndarray: +) -> Array: if device is None: return x.astype(dtype=dtype, copy=copy) out = _helpers.to_device(x.astype(dtype=dtype, copy=False), device) @@ -127,10 +121,10 @@ def astype( # cupy.count_nonzero does not have keepdims def count_nonzero( - x: ndarray, + x: Array, axis=None, keepdims=False -) -> ndarray: +) -> Array: result = cp.count_nonzero(x, axis) if keepdims: if axis is None: diff --git a/array_api_compat/cupy/_typing.py b/array_api_compat/cupy/_typing.py index f3d9aab6..66af5d19 100644 --- a/array_api_compat/cupy/_typing.py +++ b/array_api_compat/cupy/_typing.py @@ -1,46 +1,31 @@ from __future__ import annotations -__all__ = [ - "ndarray", - "Device", - "Dtype", -] +__all__ = ["Array", "DType", "Device"] +_all_ignore = ["cp"] -import sys -from typing import ( - Union, - TYPE_CHECKING, -) - -from cupy import ( - ndarray, - dtype, - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - float32, - float64, -) +from typing import TYPE_CHECKING +import cupy as cp +from cupy import ndarray as Array from cupy.cuda.device import Device -if TYPE_CHECKING or sys.version_info >= (3, 9): - Dtype = dtype[Union[ - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - float32, - float64, - ]] +if TYPE_CHECKING: + # NumPy 1.x on Python 3.9 and 3.10 fails to parse np.dtype[] + DType = cp.dtype[ + cp.intp + | cp.int8 + | cp.int16 + | cp.int32 + | cp.int64 + | cp.uint8 + | cp.uint16 + | cp.uint32 + | cp.uint64 + | cp.float32 + | cp.float64 + | cp.complex64 + | cp.complex128 + | cp.bool_ + ] else: - Dtype = dtype + DType = cp.dtype diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index 80d66281..e737cebd 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -1,16 +1,10 @@ from __future__ import annotations -from typing import Callable - -from ...common import _aliases, array_namespace - -from ..._internal import get_xp - -from ._info import __array_namespace_info__ +from typing import Callable, Optional, Union import numpy as np from numpy import ( - # Dtypes + # dtypes iinfo, finfo, bool_ as bool, @@ -29,22 +23,19 @@ can_cast, result_type, ) - -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from typing import Optional, Union - - from ...common._typing import ( - Device, - Dtype, - Array, - NestedSequence, - SupportsBufferProtocol, - ) - import dask.array as da +from ...common import _aliases, array_namespace +from ...common._typing import ( + Array, + Device, + DType, + NestedSequence, + SupportsBufferProtocol, +) +from ..._internal import get_xp +from ._info import __array_namespace_info__ + isdtype = get_xp(np)(_aliases.isdtype) unstack = get_xp(da)(_aliases.unstack) @@ -52,7 +43,7 @@ # da.astype doesn't respect copy=True def astype( x: Array, - dtype: Dtype, + dtype: DType, /, *, copy: bool = True, @@ -84,7 +75,7 @@ def arange( stop: 
Optional[Union[int, float]] = None, step: Union[int, float] = 1, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, **kwargs, ) -> Array: @@ -144,17 +135,12 @@ def arange( # asarray also adds the copy keyword, which is not present in numpy 1.0. def asarray( - obj: Union[ - Array, - bool, - int, - float, - NestedSequence[bool | int | float], - SupportsBufferProtocol, - ], + obj: ( + Array | bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol + ), /, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, copy: Optional[Union[bool, np._CopyMode]] = None, **kwargs, @@ -360,4 +346,4 @@ def count_nonzero( 'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'count_nonzero', 'result_type'] -_all_ignore = ["Callable", "array_namespace", "get_xp", "da", "np"] +_all_ignore = ["array_namespace", "get_xp", "da", "np"] diff --git a/array_api_compat/dask/array/fft.py b/array_api_compat/dask/array/fft.py index aebd86f7..3f40dffe 100644 --- a/array_api_compat/dask/array/fft.py +++ b/array_api_compat/dask/array/fft.py @@ -4,9 +4,10 @@ # from dask.array.fft import __all__ as linalg_all _n = {} exec('from dask.array.fft import *', _n) -del _n['__builtins__'] +for k in ("__builtins__", "Sequence", "annotations", "warnings"): + _n.pop(k, None) fft_all = list(_n) -del _n +del _n, k from ...common import _fft from ..._internal import get_xp @@ -16,9 +17,5 @@ fftfreq = get_xp(da)(_fft.fftfreq) rfftfreq = get_xp(da)(_fft.rfftfreq) -__all__ = [elem for elem in fft_all if elem != "annotations"] + ["fftfreq", "rfftfreq"] - -del get_xp -del da -del fft_all -del _fft +__all__ = fft_all + ["fftfreq", "rfftfreq"] +_all_ignore = ["da", "fft_all", "get_xp", "warnings"] diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py index 49c26d8b..bd53f0df 100644 --- a/array_api_compat/dask/array/linalg.py +++ b/array_api_compat/dask/array/linalg.py @@ -1,33 +1,28 @@ from __future__ import annotations -from ...common import _linalg -from ..._internal import get_xp +from typing import Literal +import dask.array as da # Exports from dask.array.linalg import * # noqa: F403 from dask.array import outer - # These functions are in both the main and linalg namespaces from dask.array import matmul, tensordot -from ._aliases import matrix_transpose, vecdot -import dask.array as da - -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from ...common._typing import Array - from typing import Literal +from ..._internal import get_xp +from ...common import _linalg +from ...common._typing import Array +from ._aliases import matrix_transpose, vecdot # dask.array.linalg doesn't have __all__. 
If it is added, replace this with # # from dask.array.linalg import __all__ as linalg_all _n = {} exec('from dask.array.linalg import *', _n) -del _n['__builtins__'] -if 'annotations' in _n: - del _n['annotations'] +for k in ('__builtins__', 'annotations', 'operator', 'warnings', 'Array'): + _n.pop(k, None) linalg_all = list(_n) -del _n +del _n, k EighResult = _linalg.EighResult QRResult = _linalg.QRResult @@ -70,4 +65,4 @@ def svdvals(x: Array) -> Array: "cholesky", "matrix_rank", "matrix_norm", "svdvals", "vector_norm", "diagonal"] -_all_ignore = ['get_xp', 'da', 'linalg_all'] +_all_ignore = ['get_xp', 'da', 'linalg_all', 'warnings'] diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index a47f7121..6536d9a8 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -1,17 +1,15 @@ from __future__ import annotations -from ..common import _aliases +from typing import Optional, Union from .._internal import get_xp - +from ..common import _aliases +from ..common._typing import NestedSequence, SupportsBufferProtocol from ._info import __array_namespace_info__ - -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import Optional, Union - from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol +from ._typing import Array, Device, DType import numpy as np + bool = np.bool_ # Basic renames @@ -64,6 +62,7 @@ tensordot = get_xp(np)(_aliases.tensordot) sign = get_xp(np)(_aliases.sign) + def _supports_buffer_protocol(obj): try: memoryview(obj) @@ -71,26 +70,22 @@ def _supports_buffer_protocol(obj): return False return True + # asarray also adds the copy keyword, which is not present in numpy 1.0. # asarray() is different enough between numpy, cupy, and dask, the logic # complicated enough that it's easier to define it separately for each module # rather than trying to combine everything into one function in common/ def asarray( - obj: Union[ - ndarray, - bool, - int, - float, - NestedSequence[bool | int | float], - SupportsBufferProtocol, - ], + obj: ( + Array | bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol + ), /, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, copy: "Optional[Union[bool, np._CopyMode]]" = None, **kwargs, -) -> ndarray: +) -> Array: """ Array API compatibility wrapper for asarray(). 
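For context, a short sketch of the wrapped NumPy namespace these annotations describe (illustrative):

import array_api_compat.numpy as xp

a = xp.asarray([1.0, 2.0], device="cpu")  # Device is Literal["cpu"] for NumPy
b = xp.astype(a, xp.int64)                # astype() is a function here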
@@ -117,23 +112,19 @@ def asarray( def astype( - x: ndarray, - dtype: Dtype, + x: Array, + dtype: DType, /, *, copy: bool = True, device: Optional[Device] = None, -) -> ndarray: +) -> Array: return x.astype(dtype=dtype, copy=copy) # count_nonzero returns a python int for axis=None and keepdims=False # https://github.com/numpy/numpy/issues/17562 -def count_nonzero( - x : ndarray, - axis=None, - keepdims=False -) -> ndarray: +def count_nonzero(x: Array, axis=None, keepdims=False) -> Array: result = np.count_nonzero(x, axis=axis, keepdims=keepdims) if axis is None and not keepdims: return np.asarray(result) diff --git a/array_api_compat/numpy/_typing.py b/array_api_compat/numpy/_typing.py index c5ebb5ab..6a18a3b2 100644 --- a/array_api_compat/numpy/_typing.py +++ b/array_api_compat/numpy/_typing.py @@ -1,46 +1,31 @@ from __future__ import annotations -__all__ = [ - "ndarray", - "Device", - "Dtype", -] +__all__ = ["Array", "DType", "Device"] +_all_ignore = ["np"] -import sys -from typing import ( - Literal, - Union, - TYPE_CHECKING, -) +from typing import Literal, TYPE_CHECKING -from numpy import ( - ndarray, - dtype, - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - float32, - float64, -) +import numpy as np +from numpy import ndarray as Array Device = Literal["cpu"] -if TYPE_CHECKING or sys.version_info >= (3, 9): - Dtype = dtype[Union[ - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - float32, - float64, - ]] +if TYPE_CHECKING: + # NumPy 1.x on Python 3.9 and 3.10 fails to parse np.dtype[] + DType = np.dtype[ + np.intp + | np.int8 + | np.int16 + | np.int32 + | np.int64 + | np.uint8 + | np.uint16 + | np.uint32 + | np.uint64 + | np.float32 + | np.float64 + | np.complex64 + | np.complex128 + | np.bool + ] else: - Dtype = dtype + DType = np.dtype diff --git a/array_api_compat/py.typed b/array_api_compat/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 4b727f1c..87d32d85 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -2,21 +2,14 @@ from functools import reduce as _reduce, wraps as _wraps from builtins import all as _builtin_all, any as _builtin_any - -from ..common import _aliases -from .._internal import get_xp - -from ._info import __array_namespace_info__ +from typing import List, Optional, Sequence, Tuple, Union import torch -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import List, Optional, Sequence, Tuple, Union - from ..common._typing import Device - from torch import dtype as Dtype - - array = torch.Tensor +from .._internal import get_xp +from ..common import _aliases +from ._info import __array_namespace_info__ +from ._typing import Array, Device, DType _int_dtypes = { torch.uint8, @@ -123,7 +116,7 @@ def _fix_promotion(x1, x2, only_scalar=True): _py_scalars = (bool, int, float, complex) -def result_type(*arrays_and_dtypes: Union[array, Dtype, bool, int, float, complex]) -> Dtype: +def result_type(*arrays_and_dtypes: Array | DType | complex) -> DType: num = len(arrays_and_dtypes) if num == 0: @@ -170,7 +163,7 @@ def _result_type(x, y): return torch.result_type(x, y) -def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool: +def can_cast(from_: Union[DType, Array], to: DType, /) -> bool: if not isinstance(from_, torch.dtype): from_ = from_.dtype return torch.can_cast(from_, to) @@ -216,13 +209,13 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool: # of 
'axis'. # torch.min and torch.max return a tuple and don't support multiple axes https://github.com/pytorch/pytorch/issues/58745 -def max(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array: +def max(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: # https://github.com/pytorch/pytorch/issues/29137 if axis == (): return torch.clone(x) return torch.amax(x, axis, keepdims=keepdims) -def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array: +def min(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: # https://github.com/pytorch/pytorch/issues/29137 if axis == (): return torch.clone(x) @@ -235,7 +228,7 @@ def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep # torch.sort also returns a tuple # https://github.com/pytorch/pytorch/issues/70921 -def sort(x: array, /, *, axis: int = -1, descending: bool = False, stable: bool = True, **kwargs) -> array: +def sort(x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True, **kwargs) -> Array: return torch.sort(x, dim=axis, descending=descending, stable=stable, **kwargs).values def _normalize_axes(axis, ndim): @@ -280,13 +273,13 @@ def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs): out = torch.unsqueeze(out, a) return out -def prod(x: array, +def prod(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, keepdims: bool = False, - **kwargs) -> array: + **kwargs) -> Array: x = torch.asarray(x) ndim = x.ndim @@ -316,13 +309,13 @@ def prod(x: array, return torch.prod(x, axis, dtype=dtype, keepdims=keepdims, **kwargs) -def sum(x: array, +def sum(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, keepdims: bool = False, - **kwargs) -> array: + **kwargs) -> Array: x = torch.asarray(x) ndim = x.ndim @@ -347,12 +340,12 @@ def sum(x: array, return torch.sum(x, axis, dtype=dtype, keepdims=keepdims, **kwargs) -def any(x: array, +def any(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, - **kwargs) -> array: + **kwargs) -> Array: x = torch.asarray(x) ndim = x.ndim if axis == (): @@ -372,12 +365,12 @@ def any(x: array, # torch.any doesn't return bool for uint8 return torch.any(x, axis, keepdims=keepdims).to(torch.bool) -def all(x: array, +def all(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, - **kwargs) -> array: + **kwargs) -> Array: x = torch.asarray(x) ndim = x.ndim if axis == (): @@ -397,12 +390,12 @@ def all(x: array, # torch.all doesn't return bool for uint8 return torch.all(x, axis, keepdims=keepdims).to(torch.bool) -def mean(x: array, +def mean(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, - **kwargs) -> array: + **kwargs) -> Array: # https://github.com/pytorch/pytorch/issues/29137 if axis == (): return torch.clone(x) @@ -414,13 +407,13 @@ def mean(x: array, return res return torch.mean(x, axis, keepdims=keepdims, **kwargs) -def std(x: array, +def std(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, keepdims: bool = False, - **kwargs) -> array: + **kwargs) -> Array: # Note, float correction is not supported # 
https://github.com/pytorch/pytorch/issues/61492. We don't try to # implement it here for now. @@ -445,13 +438,13 @@ def std(x: array, return res return torch.std(x, axis, correction=_correction, keepdims=keepdims, **kwargs) -def var(x: array, +def var(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, keepdims: bool = False, - **kwargs) -> array: + **kwargs) -> Array: # Note, float correction is not supported # https://github.com/pytorch/pytorch/issues/61492. We don't try to # implement it here for now. @@ -474,11 +467,11 @@ def var(x: array, # torch.concat doesn't support dim=None # https://github.com/pytorch/pytorch/issues/70925 -def concat(arrays: Union[Tuple[array, ...], List[array]], +def concat(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[int] = 0, - **kwargs) -> array: + **kwargs) -> Array: if axis is None: arrays = tuple(ar.flatten() for ar in arrays) axis = 0 @@ -487,7 +480,7 @@ def concat(arrays: Union[Tuple[array, ...], List[array]], # torch.squeeze only accepts int dim and doesn't require it # https://github.com/pytorch/pytorch/issues/70924. Support for tuple dim was # added at https://github.com/pytorch/pytorch/pull/89017. -def squeeze(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array: +def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array: if isinstance(axis, int): axis = (axis,) for a in axis: @@ -501,27 +494,27 @@ def squeeze(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array: return x # torch.broadcast_to uses size instead of shape -def broadcast_to(x: array, /, shape: Tuple[int, ...], **kwargs) -> array: +def broadcast_to(x: Array, /, shape: Tuple[int, ...], **kwargs) -> Array: return torch.broadcast_to(x, shape, **kwargs) # torch.permute uses dims instead of axes -def permute_dims(x: array, /, axes: Tuple[int, ...]) -> array: +def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array: return torch.permute(x, axes) # The axis parameter doesn't work for flip() and roll() # https://github.com/pytorch/pytorch/issues/71210. 
Also torch.flip() doesn't # accept axis=None -def flip(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> array: +def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> Array: if axis is None: axis = tuple(range(x.ndim)) # torch.flip doesn't accept dim as an int but the method does # https://github.com/pytorch/pytorch/issues/18095 return x.flip(axis, **kwargs) -def roll(x: array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> array: +def roll(x: Array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> Array: return torch.roll(x, shift, axis, **kwargs) -def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]: +def nonzero(x: Array, /, **kwargs) -> Tuple[Array, ...]: if x.ndim == 0: raise ValueError("nonzero() does not support zero-dimensional arrays") return torch.nonzero(x, as_tuple=True, **kwargs) @@ -529,25 +522,25 @@ def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]: # torch uses `dim` instead of `axis` def diff( - x: array, + x: Array, /, *, axis: int = -1, n: int = 1, - prepend: Optional[array] = None, - append: Optional[array] = None, -) -> array: + prepend: Optional[Array] = None, + append: Optional[Array] = None, +) -> Array: return torch.diff(x, dim=axis, n=n, prepend=prepend, append=append) # torch uses `dim` instead of `axis`, does not have keepdims def count_nonzero( - x: array, + x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, -) -> array: +) -> Array: result = torch.count_nonzero(x, dim=axis) if keepdims: if axis is not None: @@ -557,17 +550,17 @@ def count_nonzero( return result - -def where(condition: array, x1: array, x2: array, /) -> array: +def where(condition: Array, x1: Array, x2: Array, /) -> Array: x1, x2 = _fix_promotion(x1, x2) return torch.where(condition, x1, x2) # torch.reshape doesn't have the copy keyword -def reshape(x: array, +def reshape(x: Array, /, shape: Tuple[int, ...], + *, copy: Optional[bool] = None, - **kwargs) -> array: + **kwargs) -> Array: if copy is not None: raise NotImplementedError("torch.reshape doesn't yet support the copy keyword") return torch.reshape(x, shape, **kwargs) @@ -581,9 +574,9 @@ def arange(start: Union[int, float], stop: Optional[Union[int, float]] = None, step: Union[int, float] = 1, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, - **kwargs) -> array: + **kwargs) -> Array: if stop is None: start, stop = 0, start if step > 0 and stop <= start or step < 0 and stop >= start: @@ -602,9 +595,9 @@ def eye(n_rows: int, /, *, k: int = 0, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, - **kwargs) -> array: + **kwargs) -> Array: if n_cols is None: n_cols = n_rows z = torch.zeros(n_rows, n_cols, dtype=dtype, device=device, **kwargs) @@ -618,10 +611,10 @@ def linspace(start: Union[int, float], /, num: int, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, endpoint: bool = True, - **kwargs) -> array: + **kwargs) -> Array: if not endpoint: return torch.linspace(start, stop, num+1, dtype=dtype, device=device, **kwargs)[:-1] return torch.linspace(start, stop, num, dtype=dtype, device=device, **kwargs) @@ -629,11 +622,11 @@ def linspace(start: Union[int, float], # torch.full does not accept an int size # https://github.com/pytorch/pytorch/issues/70906 def 
full(shape: Union[int, Tuple[int, ...]], - fill_value: Union[bool, int, float, complex], + fill_value: complex, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, - **kwargs) -> array: + **kwargs) -> Array: if isinstance(shape, int): shape = (shape,) @@ -642,52 +635,52 @@ def full(shape: Union[int, Tuple[int, ...]], # ones, zeros, and empty do not accept shape as a keyword argument def ones(shape: Union[int, Tuple[int, ...]], *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, - **kwargs) -> array: + **kwargs) -> Array: return torch.ones(shape, dtype=dtype, device=device, **kwargs) def zeros(shape: Union[int, Tuple[int, ...]], *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, - **kwargs) -> array: + **kwargs) -> Array: return torch.zeros(shape, dtype=dtype, device=device, **kwargs) def empty(shape: Union[int, Tuple[int, ...]], *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, - **kwargs) -> array: + **kwargs) -> Array: return torch.empty(shape, dtype=dtype, device=device, **kwargs) # tril and triu do not call the keyword argument k -def tril(x: array, /, *, k: int = 0) -> array: +def tril(x: Array, /, *, k: int = 0) -> Array: return torch.tril(x, k) -def triu(x: array, /, *, k: int = 0) -> array: +def triu(x: Array, /, *, k: int = 0) -> Array: return torch.triu(x, k) # Functions that aren't in torch https://github.com/pytorch/pytorch/issues/58742 -def expand_dims(x: array, /, *, axis: int = 0) -> array: +def expand_dims(x: Array, /, *, axis: int = 0) -> Array: return torch.unsqueeze(x, axis) def astype( - x: array, - dtype: Dtype, + x: Array, + dtype: DType, /, *, copy: bool = True, device: Optional[Device] = None, -) -> array: +) -> Array: if device is not None: return x.to(device, dtype=dtype, copy=copy) return x.to(dtype=dtype, copy=copy) -def broadcast_arrays(*arrays: array) -> List[array]: +def broadcast_arrays(*arrays: Array) -> List[Array]: shape = torch.broadcast_shapes(*[a.shape for a in arrays]) return [torch.broadcast_to(a, shape) for a in arrays] @@ -697,7 +690,7 @@ def broadcast_arrays(*arrays: array) -> List[array]: UniqueInverseResult) # https://github.com/pytorch/pytorch/issues/70920 -def unique_all(x: array) -> UniqueAllResult: +def unique_all(x: Array) -> UniqueAllResult: # torch.unique doesn't support returning indices. # https://github.com/pytorch/pytorch/issues/36748. The workaround # suggested in that issue doesn't actually function correctly (it relies @@ -710,7 +703,7 @@ def unique_all(x: array) -> UniqueAllResult: # counts[torch.isnan(values)] = 1 # return UniqueAllResult(values, indices, inverse_indices, counts) -def unique_counts(x: array) -> UniqueCountsResult: +def unique_counts(x: Array) -> UniqueCountsResult: values, counts = torch.unique(x, return_counts=True) # torch.unique incorrectly gives a 0 count for nan values. 
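An illustration of the NaN fix-up the unique_* wrappers here apply (assumes PyTorch):

import torch
from array_api_compat.torch import unique_counts

res = unique_counts(torch.tensor([1.0, float("nan"), 1.0]))
# torch.unique reports a count of 0 for NaN values; the wrapper patches each
# to 1, so res.counts lines up with res.values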
@@ -718,14 +711,14 @@ def unique_counts(x: array) -> UniqueCountsResult: counts[torch.isnan(values)] = 1 return UniqueCountsResult(values, counts) -def unique_inverse(x: array) -> UniqueInverseResult: +def unique_inverse(x: Array) -> UniqueInverseResult: values, inverse = torch.unique(x, return_inverse=True) return UniqueInverseResult(values, inverse) -def unique_values(x: array) -> array: +def unique_values(x: Array) -> Array: return torch.unique(x) -def matmul(x1: array, x2: array, /, **kwargs) -> array: +def matmul(x1: Array, x2: Array, /, **kwargs) -> Array: # torch.matmul doesn't type promote (but differently from _fix_promotion) x1, x2 = _fix_promotion(x1, x2, only_scalar=False) return torch.matmul(x1, x2, **kwargs) @@ -733,12 +726,19 @@ def matmul(x1: array, x2: array, /, **kwargs) -> array: matrix_transpose = get_xp(torch)(_aliases.matrix_transpose) _vecdot = get_xp(torch)(_aliases.vecdot) -def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array: +def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: x1, x2 = _fix_promotion(x1, x2, only_scalar=False) return _vecdot(x1, x2, axis=axis) # torch.tensordot uses dims instead of axes -def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, **kwargs) -> array: +def tensordot( + x1: Array, + x2: Array, + /, + *, + axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, + **kwargs, +) -> Array: # Note: torch.tensordot fails with integer dtypes when there is only 1 # element in the axis (https://github.com/pytorch/pytorch/issues/84530). x1, x2 = _fix_promotion(x1, x2, only_scalar=False) @@ -746,7 +746,7 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], def isdtype( - dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]], + dtype: DType, kind: Union[DType, str, Tuple[Union[DType, str], ...]], *, _tuple=True, # Disallow nested tuples ) -> bool: """ @@ -781,7 +781,7 @@ def isdtype( else: return dtype == kind -def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -> array: +def take(x: Array, indices: Array, /, *, axis: Optional[int] = None, **kwargs) -> Array: if axis is None: if x.ndim != 1: raise ValueError("axis must be specified when ndim > 1") @@ -789,11 +789,11 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) - return torch.index_select(x, axis, indices, **kwargs) -def take_along_axis(x: array, indices: array, /, *, axis: int = -1) -> array: +def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: return torch.take_along_dim(x, indices, dim=axis) -def sign(x: array, /) -> array: +def sign(x: Array, /) -> Array: # torch sign() does not support complex numbers and does not propagate # nans. 
diff --git a/array_api_compat/torch/_typing.py b/array_api_compat/torch/_typing.py
new file mode 100644
index 00000000..29ad3fa7
--- /dev/null
+++ b/array_api_compat/torch/_typing.py
@@ -0,0 +1,4 @@
+__all__ = ["Array", "DType", "Device"]
+
+from torch import dtype as DType, Tensor as Array
+from ..common._typing import Device
diff --git a/array_api_compat/torch/fft.py b/array_api_compat/torch/fft.py
index 3c9117ee..50e6a0d0 100644
--- a/array_api_compat/torch/fft.py
+++ b/array_api_compat/torch/fft.py
@@ -1,76 +1,75 @@
 from __future__ import annotations

-from typing import TYPE_CHECKING
-if TYPE_CHECKING:
-    import torch
-    array = torch.Tensor
-    from typing import Union, Sequence, Literal
+from typing import Union, Sequence, Literal

-from torch.fft import * # noqa: F403
+import torch
 import torch.fft
+from torch.fft import * # noqa: F403
+
+from ._typing import Array

 # Several torch fft functions do not map axes to dim

 def fftn(
-    x: array,
+    x: Array,
     /,
     *,
     s: Sequence[int] = None,
     axes: Sequence[int] = None,
     norm: Literal["backward", "ortho", "forward"] = "backward",
     **kwargs,
-) -> array:
+) -> Array:
     return torch.fft.fftn(x, s=s, dim=axes, norm=norm, **kwargs)

 def ifftn(
-    x: array,
+    x: Array,
     /,
     *,
     s: Sequence[int] = None,
     axes: Sequence[int] = None,
     norm: Literal["backward", "ortho", "forward"] = "backward",
     **kwargs,
-) -> array:
+) -> Array:
     return torch.fft.ifftn(x, s=s, dim=axes, norm=norm, **kwargs)

 def rfftn(
-    x: array,
+    x: Array,
     /,
     *,
     s: Sequence[int] = None,
     axes: Sequence[int] = None,
     norm: Literal["backward", "ortho", "forward"] = "backward",
     **kwargs,
-) -> array:
+) -> Array:
     return torch.fft.rfftn(x, s=s, dim=axes, norm=norm, **kwargs)

 def irfftn(
-    x: array,
+    x: Array,
     /,
     *,
     s: Sequence[int] = None,
     axes: Sequence[int] = None,
     norm: Literal["backward", "ortho", "forward"] = "backward",
     **kwargs,
-) -> array:
+) -> Array:
     return torch.fft.irfftn(x, s=s, dim=axes, norm=norm, **kwargs)

 def fftshift(
-    x: array,
+    x: Array,
     /,
     *,
     axes: Union[int, Sequence[int]] = None,
     **kwargs,
-) -> array:
+) -> Array:
     return torch.fft.fftshift(x, dim=axes, **kwargs)

 def ifftshift(
-    x: array,
+    x: Array,
     /,
     *,
     axes: Union[int, Sequence[int]] = None,
     **kwargs,
-) -> array:
+) -> Array:
     return torch.fft.ifftshift(x, dim=axes, **kwargs)
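[Editor's note: a quick demonstration of what the new torch/_typing.py buys — the annotation names are ordinary runtime objects rather than TYPE_CHECKING-only aliases, so they can be imported and introspected. Assumes torch and this branch of array_api_compat are installed.]

import torch
from array_api_compat.torch._typing import Array, DType

# Plain runtime aliases, per the new module above.
assert Array is torch.Tensor
assert DType is torch.dtype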
diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py
index e26198b9..7b59a670 100644
--- a/array_api_compat/torch/linalg.py
+++ b/array_api_compat/torch/linalg.py
@@ -1,14 +1,7 @@
 from __future__ import annotations

-from typing import TYPE_CHECKING
-if TYPE_CHECKING:
-    import torch
-    array = torch.Tensor
-    from torch import dtype as Dtype
-    from typing import Optional, Union, Tuple, Literal
-    inf = float('inf')
-
-from ._aliases import _fix_promotion, sum
+import torch
+from typing import Optional, Union, Tuple

 from torch.linalg import * # noqa: F403

@@ -19,15 +12,17 @@
 # outer is implemented in torch but aren't in the linalg namespace
 from torch import outer
+from ._aliases import _fix_promotion, sum
 # These functions are in both the main and linalg namespaces
 from ._aliases import matmul, matrix_transpose, tensordot
+from ._typing import Array, DType

 # Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
 # first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
 # torch.cross also does not support broadcasting when it would add new
 # dimensions https://github.com/pytorch/pytorch/issues/39656
-def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
+def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
     x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
     if not (-min(x1.ndim, x2.ndim) <= axis < max(x1.ndim, x2.ndim)):
         raise ValueError(f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}")
@@ -36,7 +31,7 @@ def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
     x1, x2 = torch.broadcast_tensors(x1, x2)
     return torch_linalg.cross(x1, x2, dim=axis)

-def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
+def vecdot(x1: Array, x2: Array, /, *, axis: int = -1, **kwargs) -> Array:
     from ._aliases import isdtype

     x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
@@ -58,7 +53,7 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
         return res[..., 0, 0]
     return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs)

-def solve(x1: array, x2: array, /, **kwargs) -> array:
+def solve(x1: Array, x2: Array, /, **kwargs) -> Array:
     x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
     # Torch tries to emulate NumPy 1 solve behavior by using batched 1-D solve
     # whenever
@@ -79,19 +74,20 @@ def solve(x1: array, x2: array, /, **kwargs) -> array:
     return torch.linalg.solve(x1, x2, **kwargs)

 # torch.trace doesn't support the offset argument and doesn't support stacking
-def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array:
+def trace(x: Array, /, *, offset: int = 0, dtype: Optional[DType] = None) -> Array:
     # Use our wrapped sum to make sure it does upcasting correctly
     return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)

 def vector_norm(
-    x: array,
+    x: Array,
     /,
     *,
     axis: Optional[Union[int, Tuple[int, ...]]] = None,
     keepdims: bool = False,
-    ord: Union[int, float, Literal[inf, -inf]] = 2,
+    # float stands for inf | -inf, which are not valid for Literal
+    ord: Union[int, float] = 2,
     **kwargs,
-) -> array:
+) -> Array:
     # torch.vector_norm incorrectly treats axis=() the same as axis=None
     if axis == ():
         out = kwargs.get('out')
diff --git a/setup.py b/setup.py
index 3d2b68a2..2368ccc4 100644
--- a/setup.py
+++ b/setup.py
@@ -33,5 +33,8 @@
         "Programming Language :: Python :: 3.13",
         "License :: OSI Approved :: MIT License",
         "Operating System :: OS Independent",
-    ]
+    ],
+    package_data={
+        "array_api_compat": ["py.typed"],
+    },
 )
diff --git a/tests/test_all.py b/tests/test_all.py
index d2e9b768..598fab62 100644
--- a/tests/test_all.py
+++ b/tests/test_all.py
@@ -15,6 +15,16 @@
 from ._helpers import import_, wrapped_libraries

 import pytest
+import typing
+
+TYPING_NAMES = frozenset((
+    "Array",
+    "Device",
+    "DType",
+    "Namespace",
+    "NestedSequence",
+    "SupportsBufferProtocol",
+))

 @pytest.mark.skip(reason="TODO: starts failing after adding test_torch.py in gh-277")
 @pytest.mark.parametrize("library", ["common"] + wrapped_libraries)
@@ -38,8 +48,11 @@ def test_all(library):
     dir_names = [n for n in dir(module) if not n.startswith('_')]
     if '__array_namespace_info__' in dir(module):
         dir_names.append('__array_namespace_info__')
-    ignore_all_names = getattr(module, '_all_ignore', [])
-    ignore_all_names += ['annotations', 'TYPE_CHECKING']
+    ignore_all_names = set(getattr(module, '_all_ignore', ()))
+    ignore_all_names |= set(dir(typing))
+    ignore_all_names |= {"annotations"}
+    if not module.__name__.endswith("._typing"):
+        ignore_all_names |= TYPING_NAMES
     dir_names = set(dir_names) - set(ignore_all_names)

     all_names = module.__all__
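[Editor's note: a standalone sketch of the filtering the updated test applies, with a hypothetical helper name public_names. It mirrors the hunk above: typing re-exports are always ignored, and the shared annotation names are ignored everywhere except in the _typing modules that are meant to export them.]

import typing

TYPING_NAMES = frozenset((
    "Array", "Device", "DType", "Namespace",
    "NestedSequence", "SupportsBufferProtocol",
))

def public_names(module):
    # Same logic as the test: public dir() entries minus typing noise.
    names = {n for n in dir(module) if not n.startswith("_")}
    ignore = set(getattr(module, "_all_ignore", ()))
    ignore |= set(dir(typing))
    ignore |= {"annotations"}
    if not module.__name__.endswith("._typing"):
        ignore |= TYPING_NAMES
    return names - ignore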