1212import math
1313import sys
1414import warnings
15- from collections . abc import Collection
15+ from types import NoneType
1616from typing import (
1717 TYPE_CHECKING ,
1818 Any ,
1919 Final ,
2020 Literal ,
21- SupportsIndex ,
2221 TypeAlias ,
2322 TypeGuard ,
24- TypeVar ,
2523 cast ,
2624 overload ,
2725)
2826
2927from ._typing import Array , Device , HasShape , Namespace , SupportsArrayNamespace
3028
3129if TYPE_CHECKING :
32-
30+ import cupy as cp
3331 import dask .array as da
3432 import jax
3533 import ndonnx as ndx
3634 import numpy as np
3735 import numpy .typing as npt
38- import sparse # pyright: ignore[reportMissingTypeStubs]
36+ import sparse
3937 import torch
4038
4139 # TODO: import from typing (requires Python >=3.13)
42- from typing_extensions import TypeIs , TypeVar
43-
44- _SizeT = TypeVar ("_SizeT" , bound = int | None )
40+ from typing_extensions import TypeIs
4541
4642 _ZeroGradientArray : TypeAlias = npt .NDArray [np .void ]
47- _CupyArray : TypeAlias = Any # cupy has no py.typed
4843
4944 _ArrayApiObj : TypeAlias = (
5045 npt .NDArray [Any ]
46+ | cp .ndarray
5147 | da .Array
5248 | jax .Array
5349 | ndx .Array
5450 | sparse .SparseArray
5551 | torch .Tensor
56- | SupportsArrayNamespace [Any ]
57- | _CupyArray
52+ | SupportsArrayNamespace
5853 )
5954
6055_API_VERSIONS_OLD : Final = frozenset ({"2021.12" , "2022.12" , "2023.12" })
6156_API_VERSIONS : Final = _API_VERSIONS_OLD | frozenset ({"2024.12" })
6257
6358
64- def _is_jax_zero_gradient_array (x : object ) -> TypeGuard [_ZeroGradientArray ]:
59+ def _is_jax_zero_gradient_array (x : object ) -> TypeIs [_ZeroGradientArray ]:
6560 """Return True if `x` is a zero-gradient array.
6661
6762 These arrays are a design quirk of Jax that may one day be removed.
@@ -80,7 +75,7 @@ def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]:
8075 )
8176
8277
83- def is_numpy_array (x : object ) -> TypeGuard [npt .NDArray [Any ]]:
78+ def is_numpy_array (x : object ) -> TypeIs [npt .NDArray [Any ]]:
8479 """
8580 Return True if `x` is a NumPy array.
8681
@@ -137,7 +132,7 @@ def is_cupy_array(x: object) -> bool:
137132 if "cupy" not in sys .modules :
138133 return False
139134
140- import cupy as cp # pyright: ignore[reportMissingTypeStubs]
135+ import cupy as cp
141136
142137 # TODO: Should we reject ndarray subclasses?
143138 return isinstance (x , cp .ndarray ) # pyright: ignore[reportUnknownMemberType]
@@ -280,13 +275,13 @@ def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]:
280275 if "sparse" not in sys .modules :
281276 return False
282277
283- import sparse # pyright: ignore[reportMissingTypeStubs]
278+ import sparse
284279
285280 # TODO: Account for other backends.
286281 return isinstance (x , sparse .SparseArray )
287282
288283
289- def is_array_api_obj (x : object ) -> TypeIs [_ArrayApiObj ]: # pyright: ignore[reportUnknownParameterType]
284+ def is_array_api_obj (x : object ) -> TypeGuard [_ArrayApiObj ]:
290285 """
291286 Return True if `x` is an array API compatible array object.
292287
@@ -587,7 +582,7 @@ def your_function(x, y):
587582
588583 namespaces .add (cupy_namespace )
589584 else :
590- import cupy as cp # pyright: ignore[reportMissingTypeStubs]
585+ import cupy as cp
591586
592587 namespaces .add (cp )
593588 elif is_torch_array (x ):
@@ -624,14 +619,14 @@ def your_function(x, y):
624619 if hasattr (jax .numpy , "__array_api_version__" ):
625620 jnp = jax .numpy
626621 else :
627- import jax .experimental .array_api as jnp # pyright : ignore[reportMissingImports ]
622+ import jax .experimental .array_api as jnp # type : ignore[no-redef ]
628623 namespaces .add (jnp )
629624 elif is_pydata_sparse_array (x ):
630625 if use_compat is True :
631626 _check_api_version (api_version )
632627 raise ValueError ("`sparse` does not have an array-api-compat wrapper" )
633628 else :
634- import sparse # pyright: ignore[reportMissingTypeStubs]
629+ import sparse
635630 # `sparse` is already an array namespace. We do not have a wrapper
636631 # submodule for it.
637632 namespaces .add (sparse )
@@ -640,9 +635,9 @@ def your_function(x, y):
640635 raise ValueError (
641636 "The given array does not have an array-api-compat wrapper"
642637 )
643- x = cast (" SupportsArrayNamespace[Any]" , x )
638+ x = cast (SupportsArrayNamespace , x )
644639 namespaces .add (x .__array_namespace__ (api_version = api_version ))
645- elif isinstance (x , ( bool , int , float , complex , type ( None )) ):
640+ elif isinstance (x , int | float | complex | NoneType ):
646641 continue
647642 else :
648643 # TODO: Support Python scalars?
@@ -738,7 +733,7 @@ def device(x: _ArrayApiObj, /) -> Device:
738733 return "cpu"
739734 elif is_dask_array (x ):
740735 # Peek at the metadata of the Dask array to determine type
741- if is_numpy_array (x ._meta ): # pyright: ignore
736+ if is_numpy_array (x ._meta ):
742737 # Must be on CPU since backed by numpy
743738 return "cpu"
744739 return _DASK_DEVICE
@@ -767,7 +762,7 @@ def device(x: _ArrayApiObj, /) -> Device:
767762 return "cpu"
768763 # Return the device of the constituent array
769764 return device (inner ) # pyright: ignore
770- return x .device # pyright: ignore
765+ return x .device # type: ignore # pyright: ignore
771766
772767
773768# Prevent shadowing, used below
@@ -776,12 +771,12 @@ def device(x: _ArrayApiObj, /) -> Device:
776771
777772# Based on cupy.array_api.Array.to_device
778773def _cupy_to_device (
779- x : _CupyArray ,
774+ x : cp . ndarray ,
780775 device : Device ,
781776 / ,
782777 stream : int | Any | None = None ,
783- ) -> _CupyArray :
784- import cupy as cp # pyright: ignore[reportMissingTypeStubs]
778+ ) -> cp . ndarray :
779+ import cupy as cp
785780 from cupy .cuda import Device as _Device # pyright: ignore
786781 from cupy .cuda import stream as stream_module # pyright: ignore
787782 from cupy_backends .cuda .api import runtime # pyright: ignore
@@ -797,10 +792,10 @@ def _cupy_to_device(
797792 raise ValueError (f"Unsupported device { device !r} " )
798793 else :
799794 # see cupy/cupy#5985 for the reason how we handle device/stream here
800- prev_device : Any = runtime .getDevice () # pyright: ignore[reportUnknownMemberType]
795+ prev_device : Device = runtime .getDevice () # pyright: ignore[reportUnknownMemberType]
801796 prev_stream = None
802797 if stream is not None :
803- prev_stream : Any = stream_module .get_current_stream () # pyright: ignore
798+ prev_stream = stream_module .get_current_stream () # pyright: ignore
804799 # stream can be an int as specified in __dlpack__, or a CuPy stream
805800 if isinstance (stream , int ):
806801 stream = cp .cuda .ExternalStream (stream ) # pyright: ignore
@@ -814,7 +809,7 @@ def _cupy_to_device(
814809 arr = x .copy ()
815810 finally :
816811 runtime .setDevice (prev_device ) # pyright: ignore[reportUnknownMemberType]
817- if stream is not None :
812+ if prev_stream is not None :
818813 prev_stream .use ()
819814 return arr
820815
@@ -823,7 +818,7 @@ def _torch_to_device(
823818 x : torch .Tensor ,
824819 device : torch .device | str | int ,
825820 / ,
826- stream : None = None ,
821+ stream : int | Any | None = None ,
827822) -> torch .Tensor :
828823 if stream is not None :
829824 raise NotImplementedError
@@ -889,7 +884,7 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) -
889884 # cupy does not yet have to_device
890885 return _cupy_to_device (x , device , stream = stream )
891886 elif is_torch_array (x ):
892- return _torch_to_device (x , device , stream = stream ) # pyright: ignore[reportArgumentType]
887+ return _torch_to_device (x , device , stream = stream )
893888 elif is_dask_array (x ):
894889 if stream is not None :
895890 raise ValueError ("The stream argument to to_device() is not supported" )
@@ -914,12 +909,12 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) -
914909
915910
916911@overload
917- def size (x : HasShape [Collection [ SupportsIndex ] ]) -> int : ...
912+ def size (x : HasShape [int ]) -> int : ...
918913@overload
919- def size (x : HasShape [Collection [ None ]] ) -> None : ...
914+ def size (x : HasShape [int | None ]) -> int | None : ...
920915@overload
921- def size (x : HasShape [Collection [ SupportsIndex | None ]] ) -> int | None : ...
922- def size (x : HasShape [Collection [ SupportsIndex | None ] ]) -> int | None :
916+ def size (x : HasShape [float ] ) -> int | None : ... # Dask special case
917+ def size (x : HasShape [float | None ]) -> int | None :
923918 """
924919 Return the total number of elements of x.
925920
@@ -934,12 +929,12 @@ def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None:
934929 # Lazy API compliant arrays, such as ndonnx, can contain None in their shape
935930 if None in x .shape :
936931 return None
937- out = math .prod (cast ("Collection[SupportsIndex]" , x .shape ))
932+ out = math .prod (cast (tuple [ float , ...] , x .shape ))
938933 # dask.array.Array.shape can contain NaN
939- return None if math .isnan (out ) else out
934+ return None if math .isnan (out ) else cast ( int , out )
940935
941936
942- def is_writeable_array (x : object ) -> bool :
937+ def is_writeable_array (x : object ) -> TypeGuard [ _ArrayApiObj ] :
943938 """
944939 Return False if ``x.__setitem__`` is expected to raise; True otherwise.
945940 Return False if `x` is not an array API compatible object.
@@ -956,7 +951,7 @@ def is_writeable_array(x: object) -> bool:
956951 return is_array_api_obj (x )
957952
958953
959- def is_lazy_array (x : object ) -> bool :
954+ def is_lazy_array (x : object ) -> TypeGuard [ _ArrayApiObj ] :
960955 """Return True if x is potentially a future or it may be otherwise impossible or
961956 expensive to eagerly read its contents, regardless of their size, e.g. by
962957 calling ``bool(x)`` or ``float(x)``.
@@ -997,7 +992,7 @@ def is_lazy_array(x: object) -> bool:
997992 # on __bool__ (dask is one such example, which however is special-cased above).
998993
999994 # Select a single point of the array
1000- s = size (cast (" HasShape[Collection[SupportsIndex | None]]" , x ))
995+ s = size (cast (HasShape , x ))
1001996 if s is None :
1002997 return True
1003998 xp = array_namespace (x )
@@ -1044,5 +1039,6 @@ def is_lazy_array(x: object) -> bool:
10441039
10451040_all_ignore = ["sys" , "math" , "inspect" , "warnings" ]
10461041
1042+
10471043def __dir__ () -> list [str ]:
10481044 return __all__
0 commit comments