Skip to content

Commit 2f7b4d9

Browse files
MAINT: Array API 2024.12 typing nits (#156)
* MAINT: Array API 2024.12 typing nits * update docstrings * Update src/array_api_extra/_lib/_utils/_helpers.py Co-authored-by: Lucas Colley <[email protected]> * Add xref to pyright bug --------- Co-authored-by: Lucas Colley <[email protected]>
1 parent 95c0ead commit 2f7b4d9

File tree

5 files changed

+60
-31
lines changed

5 files changed

+60
-31
lines changed

docs/index.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ increase performance.
185185
In particular, the following kinds of function are also in-scope:
186186

187187
- Functions which implement
188-
[array API standard extension](https://data-apis.org/array-api/2023.12/extensions/index.html)
188+
[array API standard extension](https://data-apis.org/array-api/latest/extensions/index.html)
189189
functions in terms of functions from the base standard.
190190
- Functions which add functionality (e.g. extra parameters) to functions from
191191
the standard.

src/array_api_extra/_delegation.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ def _delegate(xp: ModuleType, *backends: Backend) -> bool:
3131

3232

3333
def isclose(
34-
a: Array,
35-
b: Array,
34+
a: Array | complex,
35+
b: Array | complex,
3636
*,
3737
rtol: float = 1e-05,
3838
atol: float = 1e-08,

src/array_api_extra/_lib/_funcs.py

+29-7
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import warnings
88
from collections.abc import Sequence
99
from types import ModuleType
10-
from typing import cast
10+
from typing import TYPE_CHECKING, cast
1111

1212
from ._at import at
1313
from ._utils import _compat, _helpers
@@ -375,8 +375,8 @@ def expand_dims(
375375

376376

377377
def isclose(
378-
a: Array,
379-
b: Array,
378+
a: Array | complex,
379+
b: Array | complex,
380380
*,
381381
rtol: float = 1e-05,
382382
atol: float = 1e-08,
@@ -385,6 +385,10 @@ def isclose(
385385
) -> Array: # numpydoc ignore=PR01,RT01
386386
"""See docstring in array_api_extra._delegation."""
387387
a, b = asarrays(a, b, xp=xp)
388+
# FIXME https://github.com/microsoft/pyright/issues/10085
389+
if TYPE_CHECKING: # pragma: nocover
390+
assert _compat.is_array_api_obj(a)
391+
assert _compat.is_array_api_obj(b)
388392

389393
a_inexact = xp.isdtype(a.dtype, ("real floating", "complex floating"))
390394
b_inexact = xp.isdtype(b.dtype, ("real floating", "complex floating"))
@@ -419,7 +423,13 @@ def isclose(
419423
return xp.abs(a - b) <= (atol + xp.abs(b) // nrtol)
420424

421425

422-
def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
426+
def kron(
427+
a: Array | complex,
428+
b: Array | complex,
429+
/,
430+
*,
431+
xp: ModuleType | None = None,
432+
) -> Array:
423433
"""
424434
Kronecker product of two arrays.
425435
@@ -495,9 +505,16 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
495505
if xp is None:
496506
xp = array_namespace(a, b)
497507
a, b = asarrays(a, b, xp=xp)
508+
# FIXME https://github.com/microsoft/pyright/issues/10085
509+
if TYPE_CHECKING: # pragma: nocover
510+
assert _compat.is_array_api_obj(a)
511+
assert _compat.is_array_api_obj(b)
498512

499513
singletons = (1,) * (b.ndim - a.ndim)
500514
a = xp.broadcast_to(a, singletons + a.shape)
515+
# FIXME https://github.com/microsoft/pyright/issues/10085
516+
if TYPE_CHECKING: # pragma: nocover
517+
assert _compat.is_array_api_obj(a)
501518

502519
nd_b, nd_a = b.ndim, a.ndim
503520
nd_max = max(nd_b, nd_a)
@@ -614,8 +631,8 @@ def pad(
614631

615632

616633
def setdiff1d(
617-
x1: Array,
618-
x2: Array,
634+
x1: Array | complex,
635+
x2: Array | complex,
619636
/,
620637
*,
621638
assume_unique: bool = False,
@@ -628,7 +645,7 @@ def setdiff1d(
628645
629646
Parameters
630647
----------
631-
x1 : array
648+
x1 : array | int | float | complex | bool
632649
Input array.
633650
x2 : array
634651
Input comparison array.
@@ -665,6 +682,11 @@ def setdiff1d(
665682
else:
666683
x1 = xp.unique_values(x1)
667684
x2 = xp.unique_values(x2)
685+
686+
# FIXME https://github.com/microsoft/pyright/issues/10085
687+
if TYPE_CHECKING: # pragma: nocover
688+
assert _compat.is_array_api_obj(x1)
689+
668690
return x1[_helpers.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)]
669691

670692

src/array_api_extra/_lib/_utils/_compat.pyi

+14-11
Original file line numberDiff line numberDiff line change
@@ -5,27 +5,30 @@ from __future__ import annotations
55

66
from types import ModuleType
77

8+
# TODO import from typing (requires Python >=3.13)
9+
from typing_extensions import TypeIs
10+
811
from ._typing import Array, Device
912

1013
# pylint: disable=missing-class-docstring,unused-argument
1114

12-
class ArrayModule(ModuleType):
15+
class Namespace(ModuleType):
1316
def device(self, x: Array, /) -> Device: ...
1417

1518
def array_namespace(
16-
*xs: Array,
19+
*xs: Array | complex | None,
1720
api_version: str | None = None,
1821
use_compat: bool | None = None,
19-
) -> ArrayModule: ...
22+
) -> Namespace: ...
2023
def device(x: Array, /) -> Device: ...
21-
def is_array_api_obj(x: object, /) -> bool: ...
22-
def is_array_api_strict_namespace(xp: ModuleType, /) -> bool: ...
23-
def is_cupy_namespace(xp: ModuleType, /) -> bool: ...
24-
def is_dask_namespace(xp: ModuleType, /) -> bool: ...
25-
def is_jax_namespace(xp: ModuleType, /) -> bool: ...
26-
def is_numpy_namespace(xp: ModuleType, /) -> bool: ...
27-
def is_pydata_sparse_namespace(xp: ModuleType, /) -> bool: ...
28-
def is_torch_namespace(xp: ModuleType, /) -> bool: ...
24+
def is_array_api_obj(x: object, /) -> TypeIs[Array]: ...
25+
def is_array_api_strict_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ...
26+
def is_cupy_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ...
27+
def is_dask_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ...
28+
def is_jax_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ...
29+
def is_numpy_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ...
30+
def is_pydata_sparse_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ...
31+
def is_torch_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ...
2932
def is_cupy_array(x: object, /) -> bool: ...
3033
def is_dask_array(x: object, /) -> bool: ...
3134
def is_jax_array(x: object, /) -> bool: ...

src/array_api_extra/_lib/_utils/_helpers.py

+14-10
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,17 @@
55

66
from collections.abc import Generator
77
from types import ModuleType
8-
from typing import cast
8+
from typing import TYPE_CHECKING
99

1010
from . import _compat
1111
from ._compat import array_namespace, is_array_api_obj, is_numpy_array
1212
from ._typing import Array
1313

14+
if TYPE_CHECKING: # pragma: no cover
15+
# TODO import from typing (requires Python >=3.13)
16+
from typing_extensions import TypeIs
17+
18+
1419
__all__ = ["asarrays", "in1d", "is_python_scalar", "mean"]
1520

1621

@@ -96,16 +101,17 @@ def mean(
96101
return xp.mean(x, axis=axis, keepdims=keepdims)
97102

98103

99-
def is_python_scalar(x: object) -> bool: # numpydoc ignore=PR01,RT01
104+
def is_python_scalar(x: object) -> TypeIs[complex]: # numpydoc ignore=PR01,RT01
100105
"""Return True if `x` is a Python scalar, False otherwise."""
101106
# isinstance(x, float) returns True for np.float64
102107
# isinstance(x, complex) returns True for np.complex128
103-
return isinstance(x, int | float | complex | bool) and not is_numpy_array(x)
108+
# bool is a subclass of int
109+
return isinstance(x, int | float | complex) and not is_numpy_array(x)
104110

105111

106112
def asarrays(
107-
a: Array | int | float | complex | bool,
108-
b: Array | int | float | complex | bool,
113+
a: Array | complex,
114+
b: Array | complex,
109115
xp: ModuleType,
110116
) -> tuple[Array, Array]:
111117
"""
@@ -150,9 +156,7 @@ def asarrays(
150156
if is_array_api_obj(a):
151157
# a is an Array API object
152158
# b is a int | float | complex | bool
153-
154-
# pyright doesn't like it if you reuse the same variable name
155-
xa = cast(Array, a)
159+
xa = a
156160

157161
# https://data-apis.org/array-api/draft/API_specification/type_promotion.html#mixing-arrays-with-python-scalars
158162
same_dtype = {
@@ -162,8 +166,8 @@ def asarrays(
162166
complex: "complex floating",
163167
}
164168
kind = same_dtype[type(b)] # type: ignore[index]
165-
if xp.isdtype(xa.dtype, kind):
166-
xb = xp.asarray(b, dtype=xa.dtype)
169+
if xp.isdtype(a.dtype, kind):
170+
xb = xp.asarray(b, dtype=a.dtype)
167171
else:
168172
# Undefined behaviour. Let the function deal with it, if it can.
169173
xb = xp.asarray(b)

0 commit comments

Comments
 (0)