diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py
index 35262d3a..0d123b99 100644
--- a/array_api_compat/common/_aliases.py
+++ b/array_api_compat/common/_aliases.py
@@ -4,15 +4,11 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
-if TYPE_CHECKING:
-    from typing import Optional, Sequence, Tuple, Union
-    from ._typing import ndarray, Device, Dtype
-
-from typing import NamedTuple
 import inspect
+from typing import NamedTuple, Optional, Sequence, Tuple, Union
 
 from ._helpers import array_namespace, _check_device, device, is_cupy_namespace
+from ._typing import Array, Device, DType, Namespace
 
 # These functions are modified from the NumPy versions.
 
@@ -24,29 +20,34 @@ def arange(
     stop: Optional[Union[int, float]] = None,
     step: Union[int, float] = 1,
     *,
-    xp,
-    dtype: Optional[Dtype] = None,
+    xp: Namespace,
+    dtype: Optional[DType] = None,
     device: Optional[Device] = None,
-    **kwargs
-) -> ndarray:
+    **kwargs,
+) -> Array:
     _check_device(xp, device)
     return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs)
 
 def empty(
     shape: Union[int, Tuple[int, ...]],
-    xp,
+    xp: Namespace,
     *,
-    dtype: Optional[Dtype] = None,
+    dtype: Optional[DType] = None,
     device: Optional[Device] = None,
-    **kwargs
-) -> ndarray:
+    **kwargs,
+) -> Array:
     _check_device(xp, device)
     return xp.empty(shape, dtype=dtype, **kwargs)
 
 def empty_like(
-    x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None,
-    **kwargs
-) -> ndarray:
+    x: Array,
+    /,
+    xp: Namespace,
+    *,
+    dtype: Optional[DType] = None,
+    device: Optional[Device] = None,
+    **kwargs,
+) -> Array:
     _check_device(xp, device)
     return xp.empty_like(x, dtype=dtype, **kwargs)
 
@@ -55,37 +56,37 @@ def eye(
     n_cols: Optional[int] = None,
     /,
     *,
-    xp,
+    xp: Namespace,
     k: int = 0,
-    dtype: Optional[Dtype] = None,
+    dtype: Optional[DType] = None,
     device: Optional[Device] = None,
     **kwargs,
-) -> ndarray:
+) -> Array:
     _check_device(xp, device)
     return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs)
 
 def full(
     shape: Union[int, Tuple[int, ...]],
-    fill_value: Union[int, float],
-    xp,
+    fill_value: complex,
+    xp: Namespace,
     *,
-    dtype: Optional[Dtype] = None,
+    dtype: Optional[DType] = None,
     device: Optional[Device] = None,
     **kwargs,
-) -> ndarray:
+) -> Array:
     _check_device(xp, device)
     return xp.full(shape, fill_value, dtype=dtype, **kwargs)
 
 def full_like(
-    x: ndarray,
+    x: Array,
     /,
-    fill_value: Union[int, float],
+    fill_value: complex,
     *,
-    xp,
-    dtype: Optional[Dtype] = None,
+    xp: Namespace,
+    dtype: Optional[DType] = None,
     device: Optional[Device] = None,
     **kwargs,
-) -> ndarray:
+) -> Array:
     _check_device(xp, device)
     return xp.full_like(x, fill_value, dtype=dtype, **kwargs)
 
@@ -95,48 +96,58 @@ def linspace(
     /,
     num: int,
     *,
-    xp,
-    dtype: Optional[Dtype] = None,
+    xp: Namespace,
+    dtype: Optional[DType] = None,
     device: Optional[Device] = None,
     endpoint: bool = True,
     **kwargs,
-) -> ndarray:
+) -> Array:
     _check_device(xp, device)
     return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs)
 
 def ones(
     shape: Union[int, Tuple[int, ...]],
-    xp,
+    xp: Namespace,
     *,
-    dtype: Optional[Dtype] = None,
+    dtype: Optional[DType] = None,
     device: Optional[Device] = None,
     **kwargs,
-) -> ndarray:
+) -> Array:
     _check_device(xp, device)
     return xp.ones(shape, dtype=dtype, **kwargs)
 
 def ones_like(
-    x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None,
+    x: Array,
+    /,
+    xp: Namespace,
+    *,
+    dtype: Optional[DType] = None,
+    device: Optional[Device] = None,
     **kwargs,
-) -> ndarray:
+) -> Array:
     _check_device(xp, device)
     return xp.ones_like(x, dtype=dtype, **kwargs)
 
 def zeros(
     shape: Union[int, Tuple[int, ...]],
-    xp,
+    xp: Namespace,
     *,
-    dtype: Optional[Dtype] = None,
+    dtype: Optional[DType] = None,
     device: Optional[Device] = None,
     **kwargs,
-) -> ndarray:
+) -> Array:
     _check_device(xp, device)
     return xp.zeros(shape, dtype=dtype, **kwargs)
 
 def zeros_like(
-    x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None,
+    x: Array,
+    /,
+    xp: Namespace,
+    *,
+    dtype: Optional[DType] = None,
+    device: Optional[Device] = None,
     **kwargs,
-) -> ndarray:
+) -> Array:
     _check_device(xp, device)
     return xp.zeros_like(x, dtype=dtype, **kwargs)
 
@@ -150,23 +161,23 @@ def zeros_like(
 # Note that these named tuples aren't actually part of the standard namespace,
 # but I don't see any issue with exporting the names here regardless.
 class UniqueAllResult(NamedTuple):
-    values: ndarray
-    indices: ndarray
-    inverse_indices: ndarray
-    counts: ndarray
+    values: Array
+    indices: Array
+    inverse_indices: Array
+    counts: Array
 
 
 class UniqueCountsResult(NamedTuple):
-    values: ndarray
-    counts: ndarray
+    values: Array
+    counts: Array
 
 
 class UniqueInverseResult(NamedTuple):
-    values: ndarray
-    inverse_indices: ndarray
+    values: Array
+    inverse_indices: Array
 
 
-def _unique_kwargs(xp):
+def _unique_kwargs(xp: Namespace) -> dict[str, bool]:
     # Older versions of NumPy and CuPy do not have equal_nan. Rather than
     # trying to parse version numbers, just check if equal_nan is in the
     # signature.
@@ -175,7 +186,7 @@ def _unique_kwargs(xp):
         return {'equal_nan': False}
     return {}
 
-def unique_all(x: ndarray, /, xp) -> UniqueAllResult:
+def unique_all(x: Array, /, xp: Namespace) -> UniqueAllResult:
     kwargs = _unique_kwargs(xp)
     values, indices, inverse_indices, counts = xp.unique(
         x,
@@ -195,7 +206,7 @@ def unique_all(x: ndarray, /, xp) -> UniqueAllResult:
     )
 
 
-def unique_counts(x: ndarray, /, xp) -> UniqueCountsResult:
+def unique_counts(x: Array, /, xp: Namespace) -> UniqueCountsResult:
     kwargs = _unique_kwargs(xp)
     res = xp.unique(
         x,
@@ -208,7 +219,7 @@ def unique_counts(x: ndarray, /, xp) -> UniqueCountsResult:
     return UniqueCountsResult(*res)
 
 
-def unique_inverse(x: ndarray, /, xp) -> UniqueInverseResult:
+def unique_inverse(x: Array, /, xp: Namespace) -> UniqueInverseResult:
     kwargs = _unique_kwargs(xp)
     values, inverse_indices = xp.unique(
         x,
@@ -223,7 +234,7 @@ def unique_inverse(x: ndarray, /, xp) -> UniqueInverseResult:
     return UniqueInverseResult(values, inverse_indices)
 
 
-def unique_values(x: ndarray, /, xp) -> ndarray:
+def unique_values(x: Array, /, xp: Namespace) -> Array:
     kwargs = _unique_kwargs(xp)
     return xp.unique(
         x,
@@ -236,42 +247,42 @@ def unique_values(x: ndarray, /, xp) -> ndarray:
 # These functions have different keyword argument names
 
 def std(
-    x: ndarray,
+    x: Array,
     /,
-    xp,
+    xp: Namespace,
     *,
     axis: Optional[Union[int, Tuple[int, ...]]] = None,
-    correction: Union[int, float] = 0.0, # correction instead of ddof
+    correction: Union[int, float] = 0.0,  # correction instead of ddof
     keepdims: bool = False,
     **kwargs,
-) -> ndarray:
+) -> Array:
     return xp.std(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)
 
 def var(
-    x: ndarray,
+    x: Array,
     /,
-    xp,
+    xp: Namespace,
     *,
     axis: Optional[Union[int, Tuple[int, ...]]] = None,
-    correction: Union[int, float] = 0.0, # correction instead of ddof
+    correction: Union[int, float] = 0.0,  # correction instead of ddof
     keepdims: bool = False,
     **kwargs,
-) -> ndarray:
+) -> Array:
     return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)
 
 # cumulative_sum is renamed from cumsum, and adds the include_initial keyword
 # argument
 
 def cumulative_sum(
-    x: ndarray,
+    x: Array,
     /,
-    xp,
+    xp: Namespace,
     *,
     axis: Optional[int] = None,
-    dtype: Optional[Dtype] = None,
+    dtype: Optional[DType] = None,
     include_initial: bool = False,
-    **kwargs
-) -> ndarray:
+    **kwargs,
+) -> Array:
     wrapped_xp = array_namespace(x)
 
     # TODO: The standard is not clear about what should happen when x.ndim == 0.
@@ -294,15 +305,15 @@ def cumulative_sum(
 
 
 def cumulative_prod(
-    x: ndarray,
+    x: Array,
     /,
-    xp,
+    xp: Namespace,
     *,
     axis: Optional[int] = None,
-    dtype: Optional[Dtype] = None,
+    dtype: Optional[DType] = None,
     include_initial: bool = False,
-    **kwargs
-) -> ndarray:
+    **kwargs,
+) -> Array:
     wrapped_xp = array_namespace(x)
 
     if axis is None:
@@ -325,17 +336,18 @@ def cumulative_prod(
 # The min and max argument names in clip are different and not optional in numpy, and type
 # promotion behavior is different.
 def clip(
-    x: ndarray,
+    x: Array,
     /,
-    min: Optional[Union[int, float, ndarray]] = None,
-    max: Optional[Union[int, float, ndarray]] = None,
+    min: Optional[Union[int, float, Array]] = None,
+    max: Optional[Union[int, float, Array]] = None,
     *,
-    xp,
+    xp: Namespace,
     # TODO: np.clip has other ufunc kwargs
-    out: Optional[ndarray] = None,
-) -> ndarray:
+    out: Optional[Array] = None,
+) -> Array:
     def _isscalar(a):
         return isinstance(a, (int, float, type(None)))
+
     min_shape = () if _isscalar(min) else min.shape
     max_shape = () if _isscalar(max) else max.shape
 
@@ -390,15 +402,19 @@ def _isscalar(a):
     return out[()]
 
 # Unlike transpose(), the axes argument to permute_dims() is required.
-def permute_dims(x: ndarray, /, axes: Tuple[int, ...], xp) -> ndarray:
+def permute_dims(x: Array, /, axes: Tuple[int, ...], xp: Namespace) -> Array:
     return xp.transpose(x, axes)
 
 # np.reshape calls the keyword argument 'newshape' instead of 'shape'
-def reshape(x: ndarray,
-            /,
-            shape: Tuple[int, ...],
-            xp, copy: Optional[bool] = None,
-            **kwargs) -> ndarray:
+def reshape(
+    x: Array,
+    /,
+    shape: Tuple[int, ...],
+    xp: Namespace,
+    *,
+    copy: Optional[bool] = None,
+    **kwargs,
+) -> Array:
     if copy is True:
         x = x.copy()
     elif copy is False:
@@ -410,9 +426,15 @@ def reshape(x: ndarray,
 # The descending keyword is new in sort and argsort, and 'kind' replaced with
 # 'stable'
 def argsort(
-    x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True,
+    x: Array,
+    /,
+    xp: Namespace,
+    *,
+    axis: int = -1,
+    descending: bool = False,
+    stable: bool = True,
     **kwargs,
-) -> ndarray:
+) -> Array:
     # Note: this keyword argument is different, and the default is different.
     # We set it in kwargs like this because numpy.sort uses kind='quicksort'
     # as the default whereas cupy.sort uses kind=None.
@@ -435,9 +457,15 @@ def argsort(
     return res
 
 def sort(
-    x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True,
+    x: Array,
+    /,
+    xp: Namespace,
+    *,
+    axis: int = -1,
+    descending: bool = False,
+    stable: bool = True,
     **kwargs,
-) -> ndarray:
+) -> Array:
     # Note: this keyword argument is different, and the default is different.
     # We set it in kwargs like this because numpy.sort uses kind='quicksort'
     # as the default whereas cupy.sort uses kind=None.
@@ -449,50 +477,51 @@ def sort(
     return res
 
 # nonzero should error for zero-dimensional arrays
-def nonzero(x: ndarray, /, xp, **kwargs) -> Tuple[ndarray, ...]:
+def nonzero(x: Array, /, xp: Namespace, **kwargs) -> Tuple[Array, ...]:
     if x.ndim == 0:
         raise ValueError("nonzero() does not support zero-dimensional arrays")
     return xp.nonzero(x, **kwargs)
 
 # ceil, floor, and trunc return integers for integer inputs
 
-def ceil(x: ndarray, /, xp, **kwargs) -> ndarray:
+def ceil(x: Array, /, xp: Namespace, **kwargs) -> Array:
     if xp.issubdtype(x.dtype, xp.integer):
         return x
     return xp.ceil(x, **kwargs)
 
-def floor(x: ndarray, /, xp, **kwargs) -> ndarray:
+def floor(x: Array, /, xp: Namespace, **kwargs) -> Array:
     if xp.issubdtype(x.dtype, xp.integer):
         return x
     return xp.floor(x, **kwargs)
 
-def trunc(x: ndarray, /, xp, **kwargs) -> ndarray:
+def trunc(x: Array, /, xp: Namespace, **kwargs) -> Array:
     if xp.issubdtype(x.dtype, xp.integer):
         return x
     return xp.trunc(x, **kwargs)
 
 # linear algebra functions
 
-def matmul(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray:
+def matmul(x1: Array, x2: Array, /, xp: Namespace, **kwargs) -> Array:
     return xp.matmul(x1, x2, **kwargs)
 
 # Unlike transpose, matrix_transpose only transposes the last two axes.
-def matrix_transpose(x: ndarray, /, xp) -> ndarray:
+def matrix_transpose(x: Array, /, xp: Namespace) -> Array:
     if x.ndim < 2:
         raise ValueError("x must be at least 2-dimensional for matrix_transpose")
     return xp.swapaxes(x, -1, -2)
 
-def tensordot(x1: ndarray,
-              x2: ndarray,
-              /,
-              xp,
-              *,
-              axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2,
-              **kwargs,
-) -> ndarray:
+def tensordot(
+    x1: Array,
+    x2: Array,
+    /,
+    xp: Namespace,
+    *,
+    axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2,
+    **kwargs,
+) -> Array:
     return xp.tensordot(x1, x2, axes=axes, **kwargs)
 
-def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray:
+def vecdot(x1: Array, x2: Array, /, xp: Namespace, *, axis: int = -1) -> Array:
     if x1.shape[axis] != x2.shape[axis]:
         raise ValueError("x1 and x2 must have the same size along the given axis")
 
@@ -511,8 +540,11 @@ def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray:
 # isdtype is a new function in the 2022.12 array API specification.
 
 def isdtype(
-    dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]], xp,
-    *, _tuple=True, # Disallow nested tuples
+    dtype: DType,
+    kind: Union[DType, str, Tuple[Union[DType, str], ...]],
+    xp: Namespace,
+    *,
+    _tuple: bool = True,  # Disallow nested tuples
 ) -> bool:
     """
     Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``.
@@ -551,14 +583,14 @@ def isdtype(
         return dtype == kind
 
 # unstack is a new function in the 2023.12 array API standard
-def unstack(x: ndarray, /, xp, *, axis: int = 0) -> Tuple[ndarray, ...]:
+def unstack(x: Array, /, xp: Namespace, *, axis: int = 0) -> Tuple[Array, ...]:
     if x.ndim == 0:
         raise ValueError("Input array must be at least 1-d.")
     return tuple(xp.moveaxis(x, axis, 0))
 
 # numpy 1.26 does not use the standard definition for sign on complex numbers
 
-def sign(x: ndarray, /, xp, **kwargs) -> ndarray:
+def sign(x: Array, /, xp: Namespace, **kwargs) -> Array:
     if isdtype(x.dtype, 'complex floating', xp=xp):
         out = (x/xp.abs(x, **kwargs))[...]
         # sign(0) = 0 but the above formula would give nan
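(For orientation: a minimal sketch of how these xp-parameterized aliases behave. Illustrative only; `common._aliases` is internal and is normally reached through the wrapped per-library namespaces. Assumes NumPy is installed.)

    import numpy as np
    from array_api_compat.common._aliases import clip, cumulative_sum

    x = np.asarray([3, 1, 2])
    clip(x, 0, 2, xp=np)                            # array([2, 1, 2])
    cumulative_sum(x, xp=np, include_initial=True)  # array([0, 3, 4, 6])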
diff --git a/array_api_compat/common/_fft.py b/array_api_compat/common/_fft.py
index e5caebef..bd2a4e1a 100644
--- a/array_api_compat/common/_fft.py
+++ b/array_api_compat/common/_fft.py
@@ -1,149 +1,148 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Union, Optional, Literal
+from collections.abc import Sequence
+from typing import Union, Optional, Literal
 
-if TYPE_CHECKING:
-    from ._typing import Device, ndarray, DType
-    from collections.abc import Sequence
+from ._typing import Device, Array, DType, Namespace
 
 # Note: NumPy fft functions improperly upcast float32 and complex64 to
 # complex128, which is why we require wrapping them all here.
 
 def fft(
-    x: ndarray,
+    x: Array,
     /,
-    xp,
+    xp: Namespace,
     *,
     n: Optional[int] = None,
     axis: int = -1,
     norm: Literal["backward", "ortho", "forward"] = "backward",
-) -> ndarray:
+) -> Array:
     res = xp.fft.fft(x, n=n, axis=axis, norm=norm)
     if x.dtype in [xp.float32, xp.complex64]:
         return res.astype(xp.complex64)
     return res
 
 def ifft(
-    x: ndarray,
+    x: Array,
     /,
-    xp,
+    xp: Namespace,
     *,
     n: Optional[int] = None,
     axis: int = -1,
     norm: Literal["backward", "ortho", "forward"] = "backward",
-) -> ndarray:
+) -> Array:
     res = xp.fft.ifft(x, n=n, axis=axis, norm=norm)
     if x.dtype in [xp.float32, xp.complex64]:
         return res.astype(xp.complex64)
     return res
 
 def fftn(
-    x: ndarray,
+    x: Array,
     /,
-    xp,
+    xp: Namespace,
     *,
     s: Sequence[int] = None,
     axes: Sequence[int] = None,
     norm: Literal["backward", "ortho", "forward"] = "backward",
-) -> ndarray:
+) -> Array:
     res = xp.fft.fftn(x, s=s, axes=axes, norm=norm)
     if x.dtype in [xp.float32, xp.complex64]:
         return res.astype(xp.complex64)
     return res
 
 def ifftn(
-    x: ndarray,
+    x: Array,
     /,
-    xp,
+    xp: Namespace,
     *,
     s: Sequence[int] = None,
     axes: Sequence[int] = None,
     norm: Literal["backward", "ortho", "forward"] = "backward",
-) -> ndarray:
+) -> Array:
     res = xp.fft.ifftn(x, s=s, axes=axes, norm=norm)
     if x.dtype in [xp.float32, xp.complex64]:
         return res.astype(xp.complex64)
     return res
 
 def rfft(
-    x: ndarray,
+    x: Array,
     /,
-    xp,
+    xp: Namespace,
     *,
     n: Optional[int] = None,
     axis: int = -1,
     norm: Literal["backward", "ortho", "forward"] = "backward",
-) -> ndarray:
+) -> Array:
     res = xp.fft.rfft(x, n=n, axis=axis, norm=norm)
     if x.dtype == xp.float32:
         return res.astype(xp.complex64)
     return res
 
 def irfft(
-    x: ndarray,
+    x: Array,
     /,
-    xp,
+    xp: Namespace,
     *,
     n: Optional[int] = None,
     axis: int = -1,
     norm: Literal["backward", "ortho", "forward"] = "backward",
-) -> ndarray:
+) -> Array:
     res = xp.fft.irfft(x, n=n, axis=axis, norm=norm)
     if x.dtype == xp.complex64:
         return res.astype(xp.float32)
     return res
 
 def rfftn(
-    x: ndarray,
+    x: Array,
     /,
-    xp,
+    xp: Namespace,
     *,
     s: Sequence[int] = None,
     axes: Sequence[int] = None,
     norm: Literal["backward", "ortho", "forward"] = "backward",
-) -> ndarray:
+) -> Array:
     res = xp.fft.rfftn(x, s=s, axes=axes, norm=norm)
     if x.dtype == xp.float32:
         return res.astype(xp.complex64)
     return res
 
 def irfftn(
-    x: ndarray,
+    x: Array,
     /,
-    xp,
+    xp: Namespace,
     *,
     s: Sequence[int] = None,
     axes: Sequence[int] = None,
     norm: Literal["backward", "ortho", "forward"] = "backward",
-) -> ndarray:
+) -> Array:
     res = xp.fft.irfftn(x, s=s, axes=axes, norm=norm)
     if x.dtype == xp.complex64:
         return res.astype(xp.float32)
     return res
 
 def hfft(
-    x: ndarray,
+    x: Array,
     /,
-    xp,
+    xp: Namespace,
     *,
     n: Optional[int] = None,
     axis: int = -1,
     norm: Literal["backward", "ortho", "forward"] = "backward",
-) -> ndarray:
+) -> Array:
     res = xp.fft.hfft(x, n=n, axis=axis, norm=norm)
     if x.dtype in [xp.float32, xp.complex64]:
         return res.astype(xp.float32)
     return res
 
 def ihfft(
-    x: ndarray,
+    x: Array,
     /,
-    xp,
+    xp: Namespace,
     *,
     n: Optional[int] = None,
     axis: int = -1,
     norm: Literal["backward", "ortho", "forward"] = "backward",
-) -> ndarray:
+) -> Array:
     res = xp.fft.ihfft(x, n=n, axis=axis, norm=norm)
     if x.dtype in [xp.float32, xp.complex64]:
         return res.astype(xp.complex64)
@@ -152,12 +151,12 @@ def ihfft(
 def fftfreq(
     n: int,
     /,
-    xp,
+    xp: Namespace,
     *,
     d: float = 1.0,
     dtype: Optional[DType] = None,
-    device: Optional[Device] = None
-) -> ndarray:
+    device: Optional[Device] = None,
+) -> Array:
     if device not in ["cpu", None]:
         raise ValueError(f"Unsupported device {device!r}")
     res = xp.fft.fftfreq(n, d=d)
@@ -168,12 +167,12 @@ def fftfreq(
 def rfftfreq(
     n: int,
     /,
-    xp,
+    xp: Namespace,
     *,
     d: float = 1.0,
     dtype: Optional[DType] = None,
-    device: Optional[Device] = None
-) -> ndarray:
+    device: Optional[Device] = None,
+) -> Array:
     if device not in ["cpu", None]:
         raise ValueError(f"Unsupported device {device!r}")
     res = xp.fft.rfftfreq(n, d=d)
@@ -181,10 +180,14 @@ def rfftfreq(
         return res.astype(dtype)
     return res
 
-def fftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray:
+def fftshift(
+    x: Array, /, xp: Namespace, *, axes: Union[int, Sequence[int]] = None
+) -> Array:
     return xp.fft.fftshift(x, axes=axes)
 
-def ifftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray:
+def ifftshift(
+    x: Array, /, xp: Namespace, *, axes: Union[int, Sequence[int]] = None
+) -> Array:
     return xp.fft.ifftshift(x, axes=axes)
 
 __all__ = [
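(A quick sanity check of the dtype-preserving behavior these wrappers add; a sketch, assuming NumPy is installed.)

    import numpy as np
    import array_api_compat.numpy as xp

    x = xp.asarray([1.0, 2.0, 3.0, 4.0], dtype=xp.float32)
    np.fft.fft(x).dtype  # complex128 -- NumPy upcasts float32 input
    xp.fft.fft(x).dtype  # complex64  -- the wrapper casts back down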
diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py
index 791edb81..6d95069d 100644
--- a/array_api_compat/common/_helpers.py
+++ b/array_api_compat/common/_helpers.py
@@ -7,16 +7,14 @@
 """
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
-    from typing import Optional, Union, Any
-    from ._typing import Array, Device, Namespace
-
 import sys
 import math
 import inspect
 import warnings
+from typing import Optional, Union, Any
+
+from ._typing import Array, Device, Namespace
+
 
 def _is_jax_zero_gradient_array(x: object) -> bool:
     """Return True if `x` is a zero-gradient array.
@@ -268,7 +266,7 @@ def _compat_module_name() -> str:
     return __name__.removesuffix('.common._helpers')
 
 
-def is_numpy_namespace(xp) -> bool:
+def is_numpy_namespace(xp: Namespace) -> bool:
     """
     Returns True if `xp` is a NumPy namespace.
 
@@ -289,7 +287,7 @@ def is_numpy_namespace(xp) -> bool:
     return xp.__name__ in {'numpy', _compat_module_name() + '.numpy'}
 
 
-def is_cupy_namespace(xp) -> bool:
+def is_cupy_namespace(xp: Namespace) -> bool:
     """
     Returns True if `xp` is a CuPy namespace.
 
@@ -310,7 +308,7 @@ def is_cupy_namespace(xp) -> bool:
     return xp.__name__ in {'cupy', _compat_module_name() + '.cupy'}
 
 
-def is_torch_namespace(xp) -> bool:
+def is_torch_namespace(xp: Namespace) -> bool:
     """
     Returns True if `xp` is a PyTorch namespace.
 
@@ -331,7 +329,7 @@ def is_torch_namespace(xp) -> bool:
     return xp.__name__ in {'torch', _compat_module_name() + '.torch'}
 
 
-def is_ndonnx_namespace(xp) -> bool:
+def is_ndonnx_namespace(xp: Namespace) -> bool:
     """
     Returns True if `xp` is an NDONNX namespace.
 
@@ -350,7 +348,7 @@ def is_ndonnx_namespace(xp) -> bool:
     return xp.__name__ == 'ndonnx'
 
 
-def is_dask_namespace(xp) -> bool:
+def is_dask_namespace(xp: Namespace) -> bool:
     """
     Returns True if `xp` is a Dask namespace.
 
@@ -371,7 +369,7 @@ def is_dask_namespace(xp) -> bool:
     return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'}
 
 
-def is_jax_namespace(xp) -> bool:
+def is_jax_namespace(xp: Namespace) -> bool:
     """
     Returns True if `xp` is a JAX namespace.
 
@@ -393,7 +391,7 @@ def is_jax_namespace(xp) -> bool:
     return xp.__name__ in {'jax.numpy', 'jax.experimental.array_api'}
 
 
-def is_pydata_sparse_namespace(xp) -> bool:
+def is_pydata_sparse_namespace(xp: Namespace) -> bool:
     """
     Returns True if `xp` is a pydata/sparse namespace.
 
@@ -412,7 +410,7 @@ def is_pydata_sparse_namespace(xp) -> bool:
     return xp.__name__ == 'sparse'
 
 
-def is_array_api_strict_namespace(xp) -> bool:
+def is_array_api_strict_namespace(xp: Namespace) -> bool:
     """
     Returns True if `xp` is an array-api-strict namespace.
 
@@ -439,7 +437,11 @@ def _check_api_version(api_version: str) -> None:
         raise ValueError("Only the 2024.12 version of the array API specification is currently supported")
 
 
-def array_namespace(*xs, api_version=None, use_compat=None) -> Namespace:
+def array_namespace(
+    *xs: Union[Array, bool, int, float, complex, None],
+    api_version: Optional[str] = None,
+    use_compat: Optional[bool] = None,
+) -> Namespace:
     """
     Get the array API compatible namespace for the arrays `xs`.
 
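(The now fully typed `array_namespace` is the library's main public entry point; a small usage sketch, assuming NumPy is installed.)

    import numpy as np
    from array_api_compat import array_namespace

    def normalize(x):
        xp = array_namespace(x)  # compat namespace for x's array library
        return (x - xp.mean(x)) / xp.std(x)

    normalize(np.asarray([1.0, 2.0, 3.0]))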
diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py
index bfa1f1b9..c77ee3b8 100644
--- a/array_api_compat/common/_linalg.py
+++ b/array_api_compat/common/_linalg.py
@@ -1,11 +1,7 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, NamedTuple
-if TYPE_CHECKING:
-    from typing import Literal, Optional, Tuple, Union
-    from ._typing import ndarray
-
 import math
+from typing import Literal, NamedTuple, Optional, Tuple, Union
 
 import numpy as np
 if np.__version__[0] == "2":
@@ -15,50 +11,53 @@
 
 from ._aliases import matmul, matrix_transpose, tensordot, vecdot, isdtype
 from .._internal import get_xp
+from ._typing import Array, Namespace
 
 # These are in the main NumPy namespace but not in numpy.linalg
-def cross(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1, **kwargs) -> ndarray:
+def cross(x1: Array, x2: Array, /, xp: Namespace, *, axis: int = -1, **kwargs) -> Array:
     return xp.cross(x1, x2, axis=axis, **kwargs)
 
-def outer(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray:
+def outer(x1: Array, x2: Array, /, xp: Namespace, **kwargs) -> Array:
     return xp.outer(x1, x2, **kwargs)
 
 class EighResult(NamedTuple):
-    eigenvalues: ndarray
-    eigenvectors: ndarray
+    eigenvalues: Array
+    eigenvectors: Array
 
 class QRResult(NamedTuple):
-    Q: ndarray
-    R: ndarray
+    Q: Array
+    R: Array
 
 class SlogdetResult(NamedTuple):
-    sign: ndarray
-    logabsdet: ndarray
+    sign: Array
+    logabsdet: Array
 
 class SVDResult(NamedTuple):
-    U: ndarray
-    S: ndarray
-    Vh: ndarray
+    U: Array
+    S: Array
+    Vh: Array
 
 # These functions are the same as their NumPy counterparts except they return
 # a namedtuple.
-def eigh(x: ndarray, /, xp, **kwargs) -> EighResult:
+def eigh(x: Array, /, xp: Namespace, **kwargs) -> EighResult:
     return EighResult(*xp.linalg.eigh(x, **kwargs))
 
-def qr(x: ndarray, /, xp, *, mode: Literal['reduced', 'complete'] = 'reduced',
+def qr(x: Array, /, xp: Namespace, *, mode: Literal['reduced', 'complete'] = 'reduced',
        **kwargs) -> QRResult:
     return QRResult(*xp.linalg.qr(x, mode=mode, **kwargs))
 
-def slogdet(x: ndarray, /, xp, **kwargs) -> SlogdetResult:
+def slogdet(x: Array, /, xp: Namespace, **kwargs) -> SlogdetResult:
     return SlogdetResult(*xp.linalg.slogdet(x, **kwargs))
 
-def svd(x: ndarray, /, xp, *, full_matrices: bool = True, **kwargs) -> SVDResult:
+def svd(
+    x: Array, /, xp: Namespace, *, full_matrices: bool = True, **kwargs
+) -> SVDResult:
     return SVDResult(*xp.linalg.svd(x, full_matrices=full_matrices, **kwargs))
 
 # These functions have additional keyword arguments
 
 # The upper keyword argument is new from NumPy
-def cholesky(x: ndarray, /, xp, *, upper: bool = False, **kwargs) -> ndarray:
+def cholesky(x: Array, /, xp: Namespace, *, upper: bool = False, **kwargs) -> Array:
     L = xp.linalg.cholesky(x, **kwargs)
     if upper:
         U = get_xp(xp)(matrix_transpose)(L)
@@ -69,12 +68,12 @@ def cholesky(x: ndarray, /, xp, *, upper: bool = False, **kwargs) -> ndarray:
 
 # The rtol keyword argument of matrix_rank() and pinv() is new from NumPy.
 # Note that it has a different semantic meaning from tol and rcond.
-def matrix_rank(x: ndarray,
+def matrix_rank(x: Array,
                 /,
-                xp,
+                xp: Namespace,
                 *,
-                rtol: Optional[Union[float, ndarray]] = None,
-                **kwargs) -> ndarray:
+                rtol: Optional[Union[float, Array]] = None,
+                **kwargs) -> Array:
     # this is different from xp.linalg.matrix_rank, which supports 1
     # dimensional arrays.
     if x.ndim < 2:
@@ -88,7 +87,9 @@ def matrix_rank(x: ndarray,
         tol = S.max(axis=-1, keepdims=True)*xp.asarray(rtol)[..., xp.newaxis]
     return xp.count_nonzero(S > tol, axis=-1)
 
-def pinv(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = None, **kwargs) -> ndarray:
+def pinv(
+    x: Array, /, xp: Namespace, *, rtol: Optional[Union[float, Array]] = None, **kwargs
+) -> Array:
     # this is different from xp.linalg.pinv, which does not multiply the
     # default tolerance by max(M, N).
     if rtol is None:
@@ -97,15 +98,30 @@ def pinv(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = None, **k
 
 # These functions are new in the array API spec
 
-def matrix_norm(x: ndarray, /, xp, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> ndarray:
+def matrix_norm(
+    x: Array,
+    /,
+    xp: Namespace,
+    *,
+    keepdims: bool = False,
+    ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro',
+) -> Array:
     return xp.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord)
 
 # svdvals is not in NumPy (but it is in SciPy). It is equivalent to
 # xp.linalg.svd(compute_uv=False).
-def svdvals(x: ndarray, /, xp) -> Union[ndarray, Tuple[ndarray, ...]]:
+def svdvals(x: Array, /, xp: Namespace) -> Union[Array, Tuple[Array, ...]]:
     return xp.linalg.svd(x, compute_uv=False)
 
-def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> ndarray:
+def vector_norm(
+    x: Array,
+    /,
+    xp: Namespace,
+    *,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    keepdims: bool = False,
+    ord: Optional[Union[int, float]] = 2,
+) -> Array:
     # xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or
     # when axis=None and the input is 2-D, so to force a vector norm, we make
     # it so the input is 1-D (for axis=None), or reshape so that norm is done
@@ -143,11 +159,15 @@ def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]]
 # xp.diagonal and xp.trace operate on the first two axes whereas these
 # operates on the last two
 
-def diagonal(x: ndarray, /, xp, *, offset: int = 0, **kwargs) -> ndarray:
+def diagonal(x: Array, /, xp: Namespace, *, offset: int = 0, **kwargs) -> Array:
     return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs)
 
-def trace(x: ndarray, /, xp, *, offset: int = 0, dtype=None, **kwargs) -> ndarray:
-    return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs))
+def trace(
+    x: Array, /, xp: Namespace, *, offset: int = 0, dtype=None, **kwargs
+) -> Array:
+    return xp.asarray(
+        xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs)
+    )
 
 __all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult',
            'QRResult', 'SlogdetResult', 'SVDResult', 'eigh', 'qr', 'slogdet',
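(The norm helpers above in use; a sketch, assuming NumPy is installed.)

    import array_api_compat.numpy as xp

    A = xp.asarray([[3.0, 0.0], [0.0, 4.0]])
    xp.linalg.matrix_norm(A)          # Frobenius norm: 5.0
    xp.linalg.vector_norm(A, axis=0)  # per-column 2-norms: [3., 4.]
    xp.linalg.svdvals(A)              # singular values: [4., 3.]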
diff --git a/array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py
index d8acdef7..4c3b356b 100644
--- a/array_api_compat/common/_typing.py
+++ b/array_api_compat/common/_typing.py
@@ -1,26 +1,24 @@
 from __future__ import annotations
+from types import ModuleType as Namespace
+from typing import Any, TypeVar, Protocol
 
 __all__ = [
+    "Array",
+    "DType",
+    "Device",
+    "Namespace",
     "NestedSequence",
     "SupportsBufferProtocol",
 ]
 
-from types import ModuleType
-from typing import (
-    Any,
-    TypeVar,
-    Protocol,
-)
-
 _T_co = TypeVar("_T_co", covariant=True)
 
 class NestedSequence(Protocol[_T_co]):
     def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
     def __len__(self, /) -> int: ...
 
-SupportsBufferProtocol = Any
 
+SupportsBufferProtocol = Any
 Array = Any
 Device = Any
 DType = Any
-Namespace = ModuleType
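(How the structural `NestedSequence` protocol is consumed: plain nested lists satisfy it with no registration. A sketch; `flatten` is a hypothetical helper, not part of the package.)

    from array_api_compat.common._typing import NestedSequence

    def flatten(seq: NestedSequence[float]) -> list[float]:
        out: list[float] = []
        for i in range(len(seq)):  # the protocol guarantees __len__
            item = seq[i]          # ... and __getitem__
            if isinstance(item, (int, float)):
                out.append(float(item))
            else:
                out.extend(flatten(item))
        return out

    flatten([[1.0, 2.0], [3.0]])  # [1.0, 2.0, 3.0]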
diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py
index 30d9fe48..ebc7ccd9 100644
--- a/array_api_compat/cupy/_aliases.py
+++ b/array_api_compat/cupy/_aliases.py
@@ -1,16 +1,14 @@
 from __future__ import annotations
 
+from typing import Optional
+
 import cupy as cp
 
 from ..common import _aliases, _helpers
+from ..common._typing import NestedSequence, SupportsBufferProtocol
 from .._internal import get_xp
-
 from ._info import __array_namespace_info__
-
-from typing import TYPE_CHECKING
-if TYPE_CHECKING:
-    from typing import Optional, Union
-    from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol
+from ._typing import Array, Device, DType
 
 bool = cp.bool_
 
@@ -66,23 +64,19 @@
 
 _copy_default = object()
 
+
 # asarray also adds the copy keyword, which is not present in numpy 1.0.
 def asarray(
-    obj: Union[
-        ndarray,
-        bool,
-        int,
-        float,
-        NestedSequence[bool | int | float],
-        SupportsBufferProtocol,
-    ],
+    obj: (
+        Array | bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol
+    ),
     /,
     *,
-    dtype: Optional[Dtype] = None,
+    dtype: Optional[DType] = None,
     device: Optional[Device] = None,
     copy: Optional[bool] = _copy_default,
     **kwargs,
-) -> ndarray:
+) -> Array:
     """
     Array API compatibility wrapper for asarray().
 
@@ -112,13 +106,13 @@ def asarray(
 
 
 def astype(
-    x: ndarray,
-    dtype: Dtype,
+    x: Array,
+    dtype: DType,
     /,
     *,
     copy: bool = True,
     device: Optional[Device] = None,
-) -> ndarray:
+) -> Array:
     if device is None:
         return x.astype(dtype=dtype, copy=copy)
     out = _helpers.to_device(x.astype(dtype=dtype, copy=False), device)
@@ -127,10 +121,10 @@ def astype(
 
 # cupy.count_nonzero does not have keepdims
 def count_nonzero(
-    x: ndarray,
+    x: Array,
     axis=None,
     keepdims=False
-) -> ndarray:
+) -> Array:
    result = cp.count_nonzero(x, axis)
    if keepdims:
        if axis is None:
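(A sketch of the CuPy `astype` with its `device` keyword; requires CuPy and a CUDA device, so illustrative only.)

    import array_api_compat.cupy as xp

    x = xp.asarray([1, 2, 3])
    y = xp.astype(x, xp.float32, device=x.device)  # cast, then to_device
    y.dtype  # dtype('float32')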
diff --git a/array_api_compat/cupy/_typing.py b/array_api_compat/cupy/_typing.py
index f3d9aab6..66af5d19 100644
--- a/array_api_compat/cupy/_typing.py
+++ b/array_api_compat/cupy/_typing.py
@@ -1,46 +1,31 @@
 from __future__ import annotations
 
-__all__ = [
-    "ndarray",
-    "Device",
-    "Dtype",
-]
+__all__ = ["Array", "DType", "Device"]
+_all_ignore = ["cp"]
 
-import sys
-from typing import (
-    Union,
-    TYPE_CHECKING,
-)
-
-from cupy import (
-    ndarray,
-    dtype,
-    int8,
-    int16,
-    int32,
-    int64,
-    uint8,
-    uint16,
-    uint32,
-    uint64,
-    float32,
-    float64,
-)
+from typing import TYPE_CHECKING
 
+import cupy as cp
+from cupy import ndarray as Array
 from cupy.cuda.device import Device
 
-if TYPE_CHECKING or sys.version_info >= (3, 9):
-    Dtype = dtype[Union[
-        int8,
-        int16,
-        int32,
-        int64,
-        uint8,
-        uint16,
-        uint32,
-        uint64,
-        float32,
-        float64,
-    ]]
+if TYPE_CHECKING:
+    # NumPy 1.x on Python 3.9 and 3.10 fails to parse np.dtype[]
+    DType = cp.dtype[
+        cp.intp
+        | cp.int8
+        | cp.int16
+        | cp.int32
+        | cp.int64
+        | cp.uint8
+        | cp.uint16
+        | cp.uint32
+        | cp.uint64
+        | cp.float32
+        | cp.float64
+        | cp.complex64
+        | cp.complex128
+        | cp.bool_
+    ]
 else:
-    Dtype = dtype
+    DType = cp.dtype
diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py
index 80d66281..e737cebd 100644
--- a/array_api_compat/dask/array/_aliases.py
+++ b/array_api_compat/dask/array/_aliases.py
@@ -1,16 +1,10 @@
 from __future__ import annotations
 
-from typing import Callable
-
-from ...common import _aliases, array_namespace
-
-from ..._internal import get_xp
-
-from ._info import __array_namespace_info__
+from typing import Callable, Optional, Union
 
 import numpy as np
 from numpy import (
-    # Dtypes
+    # dtypes
     iinfo,
     finfo,
     bool_ as bool,
@@ -29,22 +23,19 @@
     can_cast,
     result_type,
 )
-
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
-    from typing import Optional, Union
-
-    from ...common._typing import (
-        Device,
-        Dtype,
-        Array,
-        NestedSequence,
-        SupportsBufferProtocol,
-    )
-
 import dask.array as da
 
+from ...common import _aliases, array_namespace
+from ...common._typing import (
+    Array,
+    Device,
+    DType,
+    NestedSequence,
+    SupportsBufferProtocol,
+)
+from ..._internal import get_xp
+from ._info import __array_namespace_info__
+
 isdtype = get_xp(np)(_aliases.isdtype)
 unstack = get_xp(da)(_aliases.unstack)
 
@@ -52,7 +43,7 @@
 # da.astype doesn't respect copy=True
 def astype(
     x: Array,
-    dtype: Dtype,
+    dtype: DType,
     /,
     *,
     copy: bool = True,
@@ -84,7 +75,7 @@ def arange(
     stop: Optional[Union[int, float]] = None,
     step: Union[int, float] = 1,
     *,
-    dtype: Optional[Dtype] = None,
+    dtype: Optional[DType] = None,
     device: Optional[Device] = None,
     **kwargs,
 ) -> Array:
@@ -144,17 +135,12 @@ def arange(
 
 # asarray also adds the copy keyword, which is not present in numpy 1.0.
 def asarray(
-    obj: Union[
-        Array,
-        bool,
-        int,
-        float,
-        NestedSequence[bool | int | float],
-        SupportsBufferProtocol,
-    ],
+    obj: (
+        Array | bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol
+    ),
     /,
     *,
-    dtype: Optional[Dtype] = None,
+    dtype: Optional[DType] = None,
     device: Optional[Device] = None,
     copy: Optional[Union[bool, np._CopyMode]] = None,
     **kwargs,
@@ -360,4 +346,4 @@ def count_nonzero(
                     'complex64', 'complex128', 'iinfo', 'finfo',
                     'can_cast', 'count_nonzero', 'result_type']
 
-_all_ignore = ["Callable", "array_namespace", "get_xp", "da", "np"]
+_all_ignore = ["array_namespace", "get_xp", "da", "np"]
diff --git a/array_api_compat/dask/array/fft.py b/array_api_compat/dask/array/fft.py
index aebd86f7..3f40dffe 100644
--- a/array_api_compat/dask/array/fft.py
+++ b/array_api_compat/dask/array/fft.py
@@ -4,9 +4,10 @@
 # from dask.array.fft import __all__ as linalg_all
 _n = {}
 exec('from dask.array.fft import *', _n)
-del _n['__builtins__']
+for k in ("__builtins__", "Sequence", "annotations", "warnings"):
+    _n.pop(k, None)
 fft_all = list(_n)
-del _n
+del _n, k
 
 from ...common import _fft
 from ..._internal import get_xp
@@ -16,9 +17,5 @@
 fftfreq = get_xp(da)(_fft.fftfreq)
 rfftfreq = get_xp(da)(_fft.rfftfreq)
 
-__all__ = [elem for elem in fft_all if elem != "annotations"] + ["fftfreq", "rfftfreq"]
-
-del get_xp
-del da
-del fft_all
-del _fft
+__all__ = fft_all + ["fftfreq", "rfftfreq"]
+_all_ignore = ["da", "fft_all", "get_xp", "warnings"]
diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py
index 49c26d8b..bd53f0df 100644
--- a/array_api_compat/dask/array/linalg.py
+++ b/array_api_compat/dask/array/linalg.py
@@ -1,33 +1,28 @@
 from __future__ import annotations
 
-from ...common import _linalg
-from ..._internal import get_xp
+from typing import Literal
 
+import dask.array as da
 # Exports
 from dask.array.linalg import * # noqa: F403
 from dask.array import outer
-
 # These functions are in both the main and linalg namespaces
 from dask.array import matmul, tensordot
-from ._aliases import matrix_transpose, vecdot
 
-import dask.array as da
-
-from typing import TYPE_CHECKING
-if TYPE_CHECKING:
-    from ...common._typing import Array
-    from typing import Literal
+from ..._internal import get_xp
+from ...common import _linalg
+from ...common._typing import Array
+from ._aliases import matrix_transpose, vecdot
 
 # dask.array.linalg doesn't have __all__. If it is added, replace this with
 #
 # from dask.array.linalg import __all__ as linalg_all
 _n = {}
 exec('from dask.array.linalg import *', _n)
-del _n['__builtins__']
-if 'annotations' in _n:
-    del _n['annotations']
+for k in ('__builtins__', 'annotations', 'operator', 'warnings', 'Array'):
+    _n.pop(k, None)
 linalg_all = list(_n)
-del _n
+del _n, k
 
 EighResult = _linalg.EighResult
 QRResult = _linalg.QRResult
@@ -70,4 +65,4 @@ def svdvals(x: Array) -> Array:
                         "cholesky", "matrix_rank", "matrix_norm", "svdvals",
                         "vector_norm", "diagonal"]
 
-_all_ignore = ['get_xp', 'da', 'linalg_all']
+_all_ignore = ['get_xp', 'da', 'linalg_all', 'warnings']
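(Likewise for linalg; a sketch, assuming dask is installed and the input fits in a single chunk.)

    import array_api_compat.dask.array as xp
    import array_api_compat.dask.array.linalg as xpl

    x = xp.asarray([[2.0, 0.0], [0.0, 1.0]])
    xpl.svdvals(x).compute()  # array([2., 1.])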
diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py
index a47f7121..6536d9a8 100644
--- a/array_api_compat/numpy/_aliases.py
+++ b/array_api_compat/numpy/_aliases.py
@@ -1,17 +1,15 @@
 from __future__ import annotations
 
-from ..common import _aliases
+from typing import Optional, Union
 
 from .._internal import get_xp
-
+from ..common import _aliases
+from ..common._typing import NestedSequence, SupportsBufferProtocol
 from ._info import __array_namespace_info__
-
-from typing import TYPE_CHECKING
-if TYPE_CHECKING:
-    from typing import Optional, Union
-    from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol
+from ._typing import Array, Device, DType
 
 import numpy as np
+
 bool = np.bool_
 
 # Basic renames
@@ -64,6 +62,7 @@
 tensordot = get_xp(np)(_aliases.tensordot)
 sign = get_xp(np)(_aliases.sign)
 
+
 def _supports_buffer_protocol(obj):
     try:
         memoryview(obj)
@@ -71,26 +70,22 @@ def _supports_buffer_protocol(obj):
         return False
     return True
 
+
 # asarray also adds the copy keyword, which is not present in numpy 1.0.
 # asarray() is different enough between numpy, cupy, and dask, the logic
 # complicated enough that it's easier to define it separately for each module
 # rather than trying to combine everything into one function in common/
 def asarray(
-    obj: Union[
-        ndarray,
-        bool,
-        int,
-        float,
-        NestedSequence[bool | int | float],
-        SupportsBufferProtocol,
-    ],
+    obj: (
+        Array | bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol
+    ),
     /,
     *,
-    dtype: Optional[Dtype] = None,
+    dtype: Optional[DType] = None,
     device: Optional[Device] = None,
     copy: "Optional[Union[bool, np._CopyMode]]" = None,
     **kwargs,
-) -> ndarray:
+) -> Array:
     """
     Array API compatibility wrapper for asarray().
 
@@ -117,23 +112,19 @@ def asarray(
 
 
 def astype(
-    x: ndarray,
-    dtype: Dtype,
+    x: Array,
+    dtype: DType,
     /,
     *,
     copy: bool = True,
     device: Optional[Device] = None,
-) -> ndarray:
+) -> Array:
     return x.astype(dtype=dtype, copy=copy)
 
 
 # count_nonzero returns a python int for axis=None and keepdims=False
 # https://github.com/numpy/numpy/issues/17562
-def count_nonzero(
-    x : ndarray,
-    axis=None,
-    keepdims=False
-) -> ndarray:
+def count_nonzero(x: Array, axis=None, keepdims=False) -> Array:
     result = np.count_nonzero(x, axis=axis, keepdims=keepdims)
     if axis is None and not keepdims:
         return np.asarray(result)
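(The wrapped behavior in one line; a sketch, assuming NumPy is installed.)

    import numpy as np
    import array_api_compat.numpy as xp

    out = xp.count_nonzero(xp.asarray([[0, 1], [2, 3]]))
    isinstance(out, np.ndarray), out.ndim  # (True, 0) -- not a Python int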
diff --git a/array_api_compat/numpy/_typing.py b/array_api_compat/numpy/_typing.py
index c5ebb5ab..6a18a3b2 100644
--- a/array_api_compat/numpy/_typing.py
+++ b/array_api_compat/numpy/_typing.py
@@ -1,46 +1,31 @@
 from __future__ import annotations
 
-__all__ = [
-    "ndarray",
-    "Device",
-    "Dtype",
-]
+__all__ = ["Array", "DType", "Device"]
+_all_ignore = ["np"]
 
-import sys
-from typing import (
-    Literal,
-    Union,
-    TYPE_CHECKING,
-)
+from typing import Literal, TYPE_CHECKING
 
-from numpy import (
-    ndarray,
-    dtype,
-    int8,
-    int16,
-    int32,
-    int64,
-    uint8,
-    uint16,
-    uint32,
-    uint64,
-    float32,
-    float64,
-)
+import numpy as np
+from numpy import ndarray as Array
 
 Device = Literal["cpu"]
-if TYPE_CHECKING or sys.version_info >= (3, 9):
-    Dtype = dtype[Union[
-        int8,
-        int16,
-        int32,
-        int64,
-        uint8,
-        uint16,
-        uint32,
-        uint64,
-        float32,
-        float64,
-    ]]
+if TYPE_CHECKING:
+    # NumPy 1.x on Python 3.9 and 3.10 fails to parse np.dtype[]
+    DType = np.dtype[
+        np.intp
+        | np.int8
+        | np.int16
+        | np.int32
+        | np.int64
+        | np.uint8
+        | np.uint16
+        | np.uint32
+        | np.uint64
+        | np.float32
+        | np.float64
+        | np.complex64
+        | np.complex128
+        | np.bool_
+    ]
 else:
-    Dtype = dtype
+    DType = np.dtype
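(The TYPE_CHECKING split above is a general idiom: give type checkers the precise parameterized dtype while old runtimes only ever see the plain class. A standalone sketch; `FloatDType` is a hypothetical alias, not part of the package.)

    from __future__ import annotations
    from typing import TYPE_CHECKING

    import numpy as np

    if TYPE_CHECKING:
        FloatDType = np.dtype[np.float32 | np.float64]  # checkers only
    else:
        FloatDType = np.dtype  # runtime; never parsed by old NumPy

    def zeros_f(n: int, dtype: FloatDType) -> np.ndarray:
        return np.zeros(n, dtype=dtype)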
diff --git a/array_api_compat/py.typed b/array_api_compat/py.typed
new file mode 100644
index 00000000..e69de29b
diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py
index 4b727f1c..87d32d85 100644
--- a/array_api_compat/torch/_aliases.py
+++ b/array_api_compat/torch/_aliases.py
@@ -2,21 +2,14 @@
 
 from functools import reduce as _reduce, wraps as _wraps
 from builtins import all as _builtin_all, any as _builtin_any
-
-from ..common import _aliases
-from .._internal import get_xp
-
-from ._info import __array_namespace_info__
+from typing import List, Optional, Sequence, Tuple, Union
 
 import torch
 
-from typing import TYPE_CHECKING
-if TYPE_CHECKING:
-    from typing import List, Optional, Sequence, Tuple, Union
-    from ..common._typing import Device
-    from torch import dtype as Dtype
-
-    array = torch.Tensor
+from .._internal import get_xp
+from ..common import _aliases
+from ._info import __array_namespace_info__
+from ._typing import Array, Device, DType
 
 _int_dtypes = {
     torch.uint8,
@@ -123,7 +116,7 @@ def _fix_promotion(x1, x2, only_scalar=True):
 _py_scalars = (bool, int, float, complex)
 
 
-def result_type(*arrays_and_dtypes: Union[array, Dtype, bool, int, float, complex]) -> Dtype:
+def result_type(*arrays_and_dtypes: Array | DType | complex) -> DType:
     num = len(arrays_and_dtypes)
 
     if num == 0:
@@ -170,7 +163,7 @@ def _result_type(x, y):
     return torch.result_type(x, y)
 
 
-def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
+def can_cast(from_: Union[DType, Array], to: DType, /) -> bool:
     if not isinstance(from_, torch.dtype):
         from_ = from_.dtype
     return torch.can_cast(from_, to)
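(A sketch of the retyped promotion helpers, assuming torch is installed.)

    import torch
    import array_api_compat.torch as xp

    xp.can_cast(torch.int32, torch.float64)                  # True
    xp.result_type(torch.ones(2, dtype=torch.float32), 2.0)  # torch.float32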
@@ -216,13 +209,13 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
 # of 'axis'.
 
 # torch.min and torch.max return a tuple and don't support multiple axes https://github.com/pytorch/pytorch/issues/58745
-def max(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array:
+def max(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array:
     # https://github.com/pytorch/pytorch/issues/29137
     if axis == ():
         return torch.clone(x)
     return torch.amax(x, axis, keepdims=keepdims)
 
-def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array:
+def min(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array:
     # https://github.com/pytorch/pytorch/issues/29137
     if axis == ():
         return torch.clone(x)
@@ -235,7 +228,7 @@ def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep
 
 # torch.sort also returns a tuple
 # https://github.com/pytorch/pytorch/issues/70921
-def sort(x: array, /, *, axis: int = -1, descending: bool = False, stable: bool = True, **kwargs) -> array:
+def sort(x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True, **kwargs) -> Array:
     return torch.sort(x, dim=axis, descending=descending, stable=stable, **kwargs).values
 
 def _normalize_axes(axis, ndim):
@@ -280,13 +273,13 @@ def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs):
             out = torch.unsqueeze(out, a)
     return out
 
-def prod(x: array,
+def prod(x: Array,
          /,
          *,
          axis: Optional[Union[int, Tuple[int, ...]]] = None,
-         dtype: Optional[Dtype] = None,
+         dtype: Optional[DType] = None,
          keepdims: bool = False,
-         **kwargs) -> array:
+         **kwargs) -> Array:
     x = torch.asarray(x)
     ndim = x.ndim
 
@@ -316,13 +309,13 @@ def prod(x: array,
     return torch.prod(x, axis, dtype=dtype, keepdims=keepdims, **kwargs)
 
 
-def sum(x: array,
+def sum(x: Array,
          /,
          *,
          axis: Optional[Union[int, Tuple[int, ...]]] = None,
-         dtype: Optional[Dtype] = None,
+         dtype: Optional[DType] = None,
          keepdims: bool = False,
-         **kwargs) -> array:
+         **kwargs) -> Array:
     x = torch.asarray(x)
     ndim = x.ndim
 
@@ -347,12 +340,12 @@ def sum(x: array,
 
     return torch.sum(x, axis, dtype=dtype, keepdims=keepdims, **kwargs)
 
-def any(x: array,
+def any(x: Array,
         /,
         *,
         axis: Optional[Union[int, Tuple[int, ...]]] = None,
         keepdims: bool = False,
-        **kwargs) -> array:
+        **kwargs) -> Array:
     x = torch.asarray(x)
     ndim = x.ndim
     if axis == ():
@@ -372,12 +365,12 @@ def any(x: array,
     # torch.any doesn't return bool for uint8
     return torch.any(x, axis, keepdims=keepdims).to(torch.bool)
 
-def all(x: array,
+def all(x: Array,
         /,
         *,
         axis: Optional[Union[int, Tuple[int, ...]]] = None,
         keepdims: bool = False,
-        **kwargs) -> array:
+        **kwargs) -> Array:
     x = torch.asarray(x)
     ndim = x.ndim
     if axis == ():
@@ -397,12 +390,12 @@ def all(x: array,
     # torch.all doesn't return bool for uint8
     return torch.all(x, axis, keepdims=keepdims).to(torch.bool)
 
-def mean(x: array,
+def mean(x: Array,
          /,
          *,
          axis: Optional[Union[int, Tuple[int, ...]]] = None,
          keepdims: bool = False,
-         **kwargs) -> array:
+         **kwargs) -> Array:
     # https://github.com/pytorch/pytorch/issues/29137
     if axis == ():
         return torch.clone(x)
@@ -414,13 +407,13 @@ def mean(x: array,
         return res
     return torch.mean(x, axis, keepdims=keepdims, **kwargs)
 
-def std(x: array,
+def std(x: Array,
         /,
         *,
         axis: Optional[Union[int, Tuple[int, ...]]] = None,
         correction: Union[int, float] = 0.0,
         keepdims: bool = False,
-        **kwargs) -> array:
+        **kwargs) -> Array:
     # Note, float correction is not supported
     # https://github.com/pytorch/pytorch/issues/61492. We don't try to
     # implement it here for now.
@@ -445,13 +438,13 @@ def std(x: array,
         return res
     return torch.std(x, axis, correction=_correction, keepdims=keepdims, **kwargs)
 
-def var(x: array,
+def var(x: Array,
         /,
         *,
         axis: Optional[Union[int, Tuple[int, ...]]] = None,
         correction: Union[int, float] = 0.0,
         keepdims: bool = False,
-        **kwargs) -> array:
+        **kwargs) -> Array:
     # Note, float correction is not supported
     # https://github.com/pytorch/pytorch/issues/61492. We don't try to
     # implement it here for now.
@@ -474,11 +467,11 @@ def var(x: array,
 
 # torch.concat doesn't support dim=None
 # https://github.com/pytorch/pytorch/issues/70925
-def concat(arrays: Union[Tuple[array, ...], List[array]],
+def concat(arrays: Union[Tuple[Array, ...], List[Array]],
            /,
            *,
            axis: Optional[int] = 0,
-           **kwargs) -> array:
+           **kwargs) -> Array:
     if axis is None:
         arrays = tuple(ar.flatten() for ar in arrays)
         axis = 0
@@ -487,7 +480,7 @@ def concat(arrays: Union[Tuple[array, ...], List[array]],
 # torch.squeeze only accepts int dim and doesn't require it
 # https://github.com/pytorch/pytorch/issues/70924. Support for tuple dim was
 # added at https://github.com/pytorch/pytorch/pull/89017.
-def squeeze(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array:
+def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array:
     if isinstance(axis, int):
         axis = (axis,)
     for a in axis:
@@ -501,27 +494,27 @@ def squeeze(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array:
     return x
 
 # torch.broadcast_to uses size instead of shape
-def broadcast_to(x: array, /, shape: Tuple[int, ...], **kwargs) -> array:
+def broadcast_to(x: Array, /, shape: Tuple[int, ...], **kwargs) -> Array:
     return torch.broadcast_to(x, shape, **kwargs)
 
 # torch.permute uses dims instead of axes
-def permute_dims(x: array, /, axes: Tuple[int, ...]) -> array:
+def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array:
     return torch.permute(x, axes)
 
 # The axis parameter doesn't work for flip() and roll()
 # https://github.com/pytorch/pytorch/issues/71210. Also torch.flip() doesn't
 # accept axis=None
-def flip(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> array:
+def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> Array:
     if axis is None:
         axis = tuple(range(x.ndim))
     # torch.flip doesn't accept dim as an int but the method does
     # https://github.com/pytorch/pytorch/issues/18095
     return x.flip(axis, **kwargs)
 
-def roll(x: array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> array:
+def roll(x: Array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> Array:
     return torch.roll(x, shift, axis, **kwargs)
 
-def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]:
+def nonzero(x: Array, /, **kwargs) -> Tuple[Array, ...]:
     if x.ndim == 0:
         raise ValueError("nonzero() does not support zero-dimensional arrays")
     return torch.nonzero(x, as_tuple=True, **kwargs)
@@ -529,25 +522,25 @@ def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]:
 
 # torch uses `dim` instead of `axis`
 def diff(
-    x: array,
+    x: Array,
     /,
     *,
     axis: int = -1,
     n: int = 1,
-    prepend: Optional[array] = None,
-    append: Optional[array] = None,
-) -> array:
+    prepend: Optional[Array] = None,
+    append: Optional[Array] = None,
+) -> Array:
     return torch.diff(x, dim=axis, n=n, prepend=prepend, append=append)
 
 
 # torch uses `dim` instead of `axis`, does not have keepdims
 def count_nonzero(
-    x: array,
+    x: Array,
     /,
     *,
     axis: Optional[Union[int, Tuple[int, ...]]] = None,
     keepdims: bool = False,
-) -> array:
+) -> Array:
     result = torch.count_nonzero(x, dim=axis)
     if keepdims:
         if axis is not None:
@@ -557,17 +550,17 @@ def count_nonzero(
         return result
 
 
-
-def where(condition: array, x1: array, x2: array, /) -> array:
+def where(condition: Array, x1: Array, x2: Array, /) -> Array:
     x1, x2 = _fix_promotion(x1, x2)
     return torch.where(condition, x1, x2)
 
 # torch.reshape doesn't have the copy keyword
-def reshape(x: array,
+def reshape(x: Array,
             /,
             shape: Tuple[int, ...],
+            *,
             copy: Optional[bool] = None,
-            **kwargs) -> array:
+            **kwargs) -> Array:
     if copy is not None:
         raise NotImplementedError("torch.reshape doesn't yet support the copy keyword")
     return torch.reshape(x, shape, **kwargs)
@@ -581,9 +574,9 @@ def arange(start: Union[int, float],
            stop: Optional[Union[int, float]] = None,
            step: Union[int, float] = 1,
            *,
-           dtype: Optional[Dtype] = None,
+           dtype: Optional[DType] = None,
            device: Optional[Device] = None,
-           **kwargs) -> array:
+           **kwargs) -> Array:
     if stop is None:
         start, stop = 0, start
     if step > 0 and stop <= start or step < 0 and stop >= start:
@@ -602,9 +595,9 @@ def eye(n_rows: int,
         /,
         *,
         k: int = 0,
-        dtype: Optional[Dtype] = None,
+        dtype: Optional[DType] = None,
         device: Optional[Device] = None,
-        **kwargs) -> array:
+        **kwargs) -> Array:
     if n_cols is None:
         n_cols = n_rows
     z = torch.zeros(n_rows, n_cols, dtype=dtype, device=device, **kwargs)
@@ -618,10 +611,10 @@ def linspace(start: Union[int, float],
              /,
              num: int,
              *,
-             dtype: Optional[Dtype] = None,
+             dtype: Optional[DType] = None,
              device: Optional[Device] = None,
              endpoint: bool = True,
-             **kwargs) -> array:
+             **kwargs) -> Array:
     if not endpoint:
         return torch.linspace(start, stop, num+1, dtype=dtype, device=device, **kwargs)[:-1]
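     # e.g. (illustrative): linspace(0, 1, 4, endpoint=False) evaluates
     # torch.linspace(0, 1, 5)[:-1] -> tensor([0.0000, 0.2500, 0.5000, 0.7500])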
     return torch.linspace(start, stop, num, dtype=dtype, device=device, **kwargs)
@@ -629,11 +622,11 @@ def linspace(start: Union[int, float],
 # torch.full does not accept an int size
 # https://github.com/pytorch/pytorch/issues/70906
 def full(shape: Union[int, Tuple[int, ...]],
-         fill_value: Union[bool, int, float, complex],
+         fill_value: complex,
          *,
-         dtype: Optional[Dtype] = None,
+         dtype: Optional[DType] = None,
          device: Optional[Device] = None,
-         **kwargs) -> array:
+         **kwargs) -> Array:
     if isinstance(shape, int):
         shape = (shape,)
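     # e.g. (illustrative): full(3, 0.0) becomes torch.full((3,), 0.0)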
 
@@ -642,52 +635,52 @@ def full(shape: Union[int, Tuple[int, ...]],
 # ones, zeros, and empty do not accept shape as a keyword argument
 def ones(shape: Union[int, Tuple[int, ...]],
          *,
-         dtype: Optional[Dtype] = None,
+         dtype: Optional[DType] = None,
          device: Optional[Device] = None,
-         **kwargs) -> array:
+         **kwargs) -> Array:
     return torch.ones(shape, dtype=dtype, device=device, **kwargs)
 
 def zeros(shape: Union[int, Tuple[int, ...]],
          *,
-         dtype: Optional[Dtype] = None,
+         dtype: Optional[DType] = None,
          device: Optional[Device] = None,
-         **kwargs) -> array:
+         **kwargs) -> Array:
     return torch.zeros(shape, dtype=dtype, device=device, **kwargs)
 
 def empty(shape: Union[int, Tuple[int, ...]],
          *,
-         dtype: Optional[Dtype] = None,
+         dtype: Optional[DType] = None,
          device: Optional[Device] = None,
-         **kwargs) -> array:
+         **kwargs) -> Array:
     return torch.empty(shape, dtype=dtype, device=device, **kwargs)
 
 # torch.tril and torch.triu call the keyword argument `diagonal`, not `k`,
 # so it is passed positionally
 
-def tril(x: array, /, *, k: int = 0) -> array:
+def tril(x: Array, /, *, k: int = 0) -> Array:
     return torch.tril(x, k)
 
-def triu(x: array, /, *, k: int = 0) -> array:
+def triu(x: Array, /, *, k: int = 0) -> Array:
     return torch.triu(x, k)
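 # e.g. (illustrative): triu(x, k=1) lowers to torch.triu(x, 1)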
 
 # Functions that aren't in torch https://github.com/pytorch/pytorch/issues/58742
-def expand_dims(x: array, /, *, axis: int = 0) -> array:
+def expand_dims(x: Array, /, *, axis: int = 0) -> Array:
     return torch.unsqueeze(x, axis)
 
 
 def astype(
-    x: array,
-    dtype: Dtype,
+    x: Array,
+    dtype: DType,
     /,
     *,
     copy: bool = True,
     device: Optional[Device] = None,
-) -> array:
+) -> Array:
     if device is not None:
         return x.to(device, dtype=dtype, copy=copy)
     return x.to(dtype=dtype, copy=copy)
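 # e.g. (illustrative): astype(x, torch.float32, device="cpu") is
 # x.to("cpu", dtype=torch.float32, copy=True)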
 
 
-def broadcast_arrays(*arrays: array) -> List[array]:
+def broadcast_arrays(*arrays: Array) -> List[Array]:
     shape = torch.broadcast_shapes(*[a.shape for a in arrays])
     return [torch.broadcast_to(a, shape) for a in arrays]
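 # e.g. (illustrative): inputs of shapes (3, 1) and (4,) broadcast to (3, 4);
 # torch.broadcast_to returns views, so nothing is copied here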
 
@@ -697,7 +690,7 @@ def broadcast_arrays(*arrays: array) -> List[array]:
                                UniqueInverseResult)
 
 # https://github.com/pytorch/pytorch/issues/70920
-def unique_all(x: array) -> UniqueAllResult:
+def unique_all(x: Array) -> UniqueAllResult:
     # torch.unique doesn't support returning indices.
     # https://github.com/pytorch/pytorch/issues/36748. The workaround
     # suggested in that issue doesn't actually function correctly (it relies
@@ -710,7 +703,7 @@ def unique_all(x: array) -> UniqueAllResult:
     # counts[torch.isnan(values)] = 1
     # return UniqueAllResult(values, indices, inverse_indices, counts)
 
-def unique_counts(x: array) -> UniqueCountsResult:
+def unique_counts(x: Array) -> UniqueCountsResult:
     values, counts = torch.unique(x, return_counts=True)
 
     # torch.unique incorrectly gives a 0 count for nan values.
@@ -718,14 +711,14 @@ def unique_counts(x: array) -> UniqueCountsResult:
     counts[torch.isnan(values)] = 1
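     # (nan != nan, so torch.unique reports each nan as a distinct value;
     # the correct count for each is therefore 1)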
     return UniqueCountsResult(values, counts)
 
-def unique_inverse(x: array) -> UniqueInverseResult:
+def unique_inverse(x: Array) -> UniqueInverseResult:
     values, inverse = torch.unique(x, return_inverse=True)
     return UniqueInverseResult(values, inverse)
 
-def unique_values(x: array) -> array:
+def unique_values(x: Array) -> Array:
     return torch.unique(x)
 
-def matmul(x1: array, x2: array, /, **kwargs) -> array:
+def matmul(x1: Array, x2: Array, /, **kwargs) -> Array:
     # torch.matmul doesn't type promote its arguments per the array API,
     # so promote explicitly first
     x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
     return torch.matmul(x1, x2, **kwargs)
@@ -733,12 +726,19 @@ def matmul(x1: array, x2: array, /, **kwargs) -> array:
 matrix_transpose = get_xp(torch)(_aliases.matrix_transpose)
 _vecdot = get_xp(torch)(_aliases.vecdot)
 
-def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array:
+def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
     x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
     return _vecdot(x1, x2, axis=axis)
 
 # torch.tensordot uses dims instead of axes
-def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, **kwargs) -> array:
+def tensordot(
+    x1: Array,
+    x2: Array,
+    /,
+    *,
+    axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2,
+    **kwargs,
+) -> Array:
     # Note: torch.tensordot fails with integer dtypes when there is only 1
     # element in the axis (https://github.com/pytorch/pytorch/issues/84530).
     x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
@@ -746,7 +746,7 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int],
 
 
 def isdtype(
-    dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]],
+    dtype: DType, kind: Union[DType, str, Tuple[Union[DType, str], ...]],
     *, _tuple=True, # Disallow nested tuples
 ) -> bool:
     """
@@ -781,7 +781,7 @@ def isdtype(
     else:
         return dtype == kind
 
-def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -> array:
+def take(x: Array, indices: Array, /, *, axis: Optional[int] = None, **kwargs) -> Array:
     if axis is None:
         if x.ndim != 1:
             raise ValueError("axis must be specified when ndim > 1")
@@ -789,11 +789,11 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
     return torch.index_select(x, axis, indices, **kwargs)
 
 
-def take_along_axis(x: array, indices: array, /, *, axis: int = -1) -> array:
+def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
     return torch.take_along_dim(x, indices, dim=axis)
 
 
-def sign(x: array, /) -> array:
+def sign(x: Array, /) -> Array:
     # torch sign() does not support complex numbers and does not propagate
     # nans. See https://github.com/data-apis/array-api-compat/issues/136
     if x.dtype.is_complex:
diff --git a/array_api_compat/torch/_typing.py b/array_api_compat/torch/_typing.py
new file mode 100644
index 00000000..29ad3fa7
--- /dev/null
+++ b/array_api_compat/torch/_typing.py
@@ -0,0 +1,4 @@
+__all__ = ["Array", "DType", "Device"]
+
+from torch import dtype as DType, Tensor as Array
+from ..common._typing import Device
diff --git a/array_api_compat/torch/fft.py b/array_api_compat/torch/fft.py
index 3c9117ee..50e6a0d0 100644
--- a/array_api_compat/torch/fft.py
+++ b/array_api_compat/torch/fft.py
@@ -1,76 +1,75 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
-if TYPE_CHECKING:
-    import torch
-    array = torch.Tensor
-    from typing import Union, Sequence, Literal
+from typing import Union, Sequence, Literal
 
-from torch.fft import * # noqa: F403
+import torch
 import torch.fft
+from torch.fft import * # noqa: F403
+
+from ._typing import Array
 
 # Several torch fft functions do not map axes to dim
 
 def fftn(
-    x: array,
+    x: Array,
     /,
     *,
     s: Sequence[int] = None,
     axes: Sequence[int] = None,
     norm: Literal["backward", "ortho", "forward"] = "backward",
     **kwargs,
-) -> array:
+) -> Array:
     return torch.fft.fftn(x, s=s, dim=axes, norm=norm, **kwargs)
 
 def ifftn(
-    x: array,
+    x: Array,
     /,
     *,
     s: Sequence[int] = None,
     axes: Sequence[int] = None,
     norm: Literal["backward", "ortho", "forward"] = "backward",
     **kwargs,
-) -> array:
+) -> Array:
     return torch.fft.ifftn(x, s=s, dim=axes, norm=norm, **kwargs)
 
 def rfftn(
-    x: array,
+    x: Array,
     /,
     *,
     s: Sequence[int] = None,
     axes: Sequence[int] = None,
     norm: Literal["backward", "ortho", "forward"] = "backward",
     **kwargs,
-) -> array:
+) -> Array:
     return torch.fft.rfftn(x, s=s, dim=axes, norm=norm, **kwargs)
 
 def irfftn(
-    x: array,
+    x: Array,
     /,
     *,
     s: Sequence[int] = None,
     axes: Sequence[int] = None,
     norm: Literal["backward", "ortho", "forward"] = "backward",
     **kwargs,
-) -> array:
+) -> Array:
     return torch.fft.irfftn(x, s=s, dim=axes, norm=norm, **kwargs)
 
 def fftshift(
-    x: array,
+    x: Array,
     /,
     *,
     axes: Union[int, Sequence[int]] = None,
     **kwargs,
-) -> array:
+) -> Array:
     return torch.fft.fftshift(x, dim=axes, **kwargs)
 
 def ifftshift(
-    x: array,
+    x: Array,
     /,
     *,
     axes: Union[int, Sequence[int]] = None,
     **kwargs,
-) -> array:
+) -> Array:
     return torch.fft.ifftshift(x, dim=axes, **kwargs)
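 # e.g. (illustrative): fftshift(x, axes=(0,)) lowers to
 # torch.fft.fftshift(x, dim=(0,)); only the keyword name differs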
 
 
diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py
index e26198b9..7b59a670 100644
--- a/array_api_compat/torch/linalg.py
+++ b/array_api_compat/torch/linalg.py
@@ -1,14 +1,7 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
-if TYPE_CHECKING:
-    import torch
-    array = torch.Tensor
-    from torch import dtype as Dtype
-    from typing import Optional, Union, Tuple, Literal
-    inf = float('inf')
-
-from ._aliases import _fix_promotion, sum
+import torch
+from typing import Optional, Union, Tuple
 
 from torch.linalg import * # noqa: F403
 
@@ -19,15 +12,17 @@
 
 # outer is implemented in torch but isn't in the linalg namespace
 from torch import outer
+from ._aliases import _fix_promotion, sum
 # These functions are in both the main and linalg namespaces
 from ._aliases import matmul, matrix_transpose, tensordot
+from ._typing import Array, DType
 
 # Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
 # first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
 
 # torch.cross also does not support broadcasting when it would add new
 # dimensions https://github.com/pytorch/pytorch/issues/39656
-def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
+def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
     x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
     if not (-min(x1.ndim, x2.ndim) <= axis < max(x1.ndim, x2.ndim)):
         raise ValueError(f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}")
@@ -36,7 +31,7 @@ def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
     x1, x2 = torch.broadcast_tensors(x1, x2)
     return torch_linalg.cross(x1, x2, dim=axis)
 
-def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
+def vecdot(x1: Array, x2: Array, /, *, axis: int = -1, **kwargs) -> Array:
     from ._aliases import isdtype
 
     x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
@@ -58,7 +53,7 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
         return res[..., 0, 0]
     return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs)
 
-def solve(x1: array, x2: array, /, **kwargs) -> array:
+def solve(x1: Array, x2: Array, /, **kwargs) -> Array:
     x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
     # Torch tries to emulate NumPy 1 solve behavior by using batched 1-D solve
     # whenever
@@ -79,19 +74,20 @@ def solve(x1: array, x2: array, /, **kwargs) -> array:
     return torch.linalg.solve(x1, x2, **kwargs)
 
 # torch.trace doesn't support the offset argument and doesn't support stacking
-def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array:
+def trace(x: Array, /, *, offset: int = 0, dtype: Optional[DType] = None) -> Array:
     # Use our wrapped sum to make sure it does upcasting correctly
     return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)
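 # e.g. (illustrative): for x of shape (2, 3, 3), torch.diagonal(x, dim1=-2,
 # dim2=-1) has shape (2, 3) and summing over axis=-1 yields both traces,
 # which the 2-D-only torch.trace cannot do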
 
 def vector_norm(
-    x: array,
+    x: Array,
     /,
     *,
     axis: Optional[Union[int, Tuple[int, ...]]] = None,
     keepdims: bool = False,
-    ord: Union[int, float, Literal[inf, -inf]] = 2,
+    # float covers inf and -inf, which are not valid Literal values
+    ord: Union[int, float] = 2,
     **kwargs,
-) -> array:
+) -> Array:
     # torch.vector_norm incorrectly treats axis=() the same as axis=None
     if axis == ():
         out = kwargs.get('out')
diff --git a/setup.py b/setup.py
index 3d2b68a2..2368ccc4 100644
--- a/setup.py
+++ b/setup.py
@@ -33,5 +33,8 @@
         "Programming Language :: Python :: 3.13",
         "License :: OSI Approved :: MIT License",
         "Operating System :: OS Independent",
-    ]
+    ],
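+    # PEP 561 marker: ship py.typed so type checkers use the inline annotations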
+    package_data={
+        "array_api_compat": ["py.typed"],
+    },
 )
diff --git a/tests/test_all.py b/tests/test_all.py
index d2e9b768..598fab62 100644
--- a/tests/test_all.py
+++ b/tests/test_all.py
@@ -15,6 +15,16 @@
 from ._helpers import import_, wrapped_libraries
 
 import pytest
+import typing
+
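+# Typing-only names re-exported by the wrapped namespaces for annotations;
+# they are not required in __all__ except in the _typing submodules themselves.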
+TYPING_NAMES = frozenset((
+    "Array",
+    "Device",
+    "DType",
+    "Namespace",
+    "NestedSequence",
+    "SupportsBufferProtocol",
+))
 
 @pytest.mark.skip(reason="TODO: starts failing after adding test_torch.py in gh-277")
 @pytest.mark.parametrize("library", ["common"] + wrapped_libraries)
@@ -38,8 +48,11 @@ def test_all(library):
         dir_names = [n for n in dir(module) if not n.startswith('_')]
         if '__array_namespace_info__' in dir(module):
             dir_names.append('__array_namespace_info__')
-        ignore_all_names = getattr(module, '_all_ignore', [])
-        ignore_all_names += ['annotations', 'TYPE_CHECKING']
+        ignore_all_names = set(getattr(module, '_all_ignore', ()))
+        ignore_all_names |= set(dir(typing))
+        ignore_all_names |= {"annotations"}
+        if not module.__name__.endswith("._typing"):
+            ignore_all_names |= TYPING_NAMES
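+        # drop the typing machinery and (outside the _typing modules) the
+        # shared typing aliases before comparing against __all__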
         dir_names = set(dir_names) - set(ignore_all_names)
         all_names = module.__all__