Skip to content

Commit ae8d83b

Browse files
committed
Merge branch 'main' into xp_assert_enhancements
2 parents d7f7549 + bb6129b commit ae8d83b

File tree

11 files changed

+124
-83
lines changed

11 files changed

+124
-83
lines changed

.github/workflows/docs-deploy.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name: Docs Deploy
22

33
permissions:
4-
contents: read
4+
contents: write # needed for the deploy step
55

66
on:
77
workflow_run:

codecov.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,6 @@ github_checks:
44
ignore:
55
- "src/array_api_extra/_lib/_compat"
66
- "src/array_api_extra/_lib/_typing"
7+
coverage:
8+
status:
9+
project: off

pixi.lock

Lines changed: 42 additions & 43 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,17 +54,17 @@ array-api-compat = ">=1.11.2,<2"
5454
array-api-extra = { path = ".", editable = true }
5555

5656
[tool.pixi.feature.lint.dependencies]
57-
typing-extensions = ">=4.13.1"
57+
typing-extensions = ">=4.13.2"
5858
pre-commit = ">=4.2.0"
5959
pylint = ">=3.3.6"
6060
basedmypy = ">=2.10.0"
61-
basedpyright = ">=1.28.3"
61+
basedpyright = ">=1.28.5"
6262
numpydoc = ">=1.8.0,<2"
6363
# import dependencies for mypy:
6464
array-api-strict = ">=2.3.1"
6565
numpy = ">=2.1.3"
6666
pytest = ">=8.3.5"
67-
hypothesis = ">=6.130.11"
67+
hypothesis = ">=6.131.8"
6868
dask-core = ">=2025.3.0" # No distributed, tornado, etc.
6969
# NOTE: don't add jax, pytorch, sparse, cupy here
7070
# as they slow down mypy and are not portable across target OSs
@@ -80,7 +80,7 @@ lint = { depends-on = ["pre-commit", "pylint", "mypy", "pyright"] , description
8080
[tool.pixi.feature.tests.dependencies]
8181
pytest = ">=8.3.5"
8282
pytest-cov = ">=6.1.1"
83-
hypothesis = ">=6.130.11"
83+
hypothesis = ">=6.131.8"
8484
array-api-strict = ">=2.3.1"
8585
numpy = ">=1.22.0"
8686

@@ -107,7 +107,7 @@ sphinx-autodoc-typehints = ">=1.25.3"
107107
# Needed to import parsed modules with autodoc
108108
dask-core = ">=2025.3.0"
109109
pytest = ">=8.3.5"
110-
typing-extensions = ">=4.13.1"
110+
typing-extensions = ">=4.13.2"
111111
numpy = ">=2.1.3"
112112

113113
[tool.pixi.feature.docs.tasks]
@@ -136,7 +136,7 @@ numpy = "=1.22.0"
136136
[tool.pixi.feature.backends.dependencies]
137137
pytorch = ">=2.6.0"
138138
dask = ">=2025.3.0"
139-
numba = ">=0.61.0" # sparse dependency
139+
numba = ">=0.61.2" # sparse dependency
140140
llvmlite = ">=0.44.0" # sparse dependency
141141

142142
[tool.pixi.feature.backends.pypi-dependencies]

renovate.json

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,22 @@
5050
"matchManagers": ["github-actions"],
5151
"matchPackageNames": ["python"],
5252
"enabled": false
53+
},
54+
{
55+
"description": "Group Dask packages.",
56+
"matchPackageNames": ["dask", "dask-core"],
57+
"groupName": "dask"
58+
},
59+
{
60+
"description": "Group JAX packages.",
61+
"matchPackageNames": ["jax", "jaxlib"],
62+
"groupName": "jax"
63+
},
64+
{
65+
"description": "Schedule hypothesis monthly as releases are frequent.",
66+
"matchManagers": ["pixi"],
67+
"matchPackageNames": ["hypothesis"],
68+
"schedule": ["* * 10 * *"]
5369
}
5470
]
5571
}

src/array_api_extra/_lib/_utils/_compat.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
is_torch_namespace,
2424
is_writeable_array,
2525
size,
26+
to_device,
2627
)
2728
except ImportError:
2829
from array_api_compat import (
@@ -45,6 +46,7 @@
4546
is_torch_namespace,
4647
is_writeable_array,
4748
size,
49+
to_device,
4850
)
4951

5052
__all__ = [
@@ -67,4 +69,5 @@
6769
"is_torch_namespace",
6870
"is_writeable_array",
6971
"size",
72+
"to_device",
7073
]

src/array_api_extra/_lib/_utils/_compat.pyi

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from __future__ import annotations
55

66
from types import ModuleType
7+
from typing import Any, TypeGuard
78

89
# TODO import from typing (requires Python >=3.13)
910
from typing_extensions import TypeIs
@@ -12,29 +13,33 @@ from ._typing import Array, Device
1213

1314
# pylint: disable=missing-class-docstring,unused-argument
1415

15-
class Namespace(ModuleType):
16-
def device(self, x: Array, /) -> Device: ...
17-
1816
def array_namespace(
1917
*xs: Array | complex | None,
2018
api_version: str | None = None,
2119
use_compat: bool | None = None,
22-
) -> Namespace: ...
20+
) -> ModuleType: ...
2321
def device(x: Array, /) -> Device: ...
2422
def is_array_api_obj(x: object, /) -> TypeIs[Array]: ...
25-
def is_array_api_strict_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ...
26-
def is_cupy_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ...
27-
def is_dask_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ...
28-
def is_jax_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ...
29-
def is_numpy_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ...
30-
def is_pydata_sparse_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ...
31-
def is_torch_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ...
32-
def is_cupy_array(x: object, /) -> TypeIs[Array]: ...
33-
def is_dask_array(x: object, /) -> TypeIs[Array]: ...
34-
def is_jax_array(x: object, /) -> TypeIs[Array]: ...
35-
def is_numpy_array(x: object, /) -> TypeIs[Array]: ...
36-
def is_pydata_sparse_array(x: object, /) -> TypeIs[Array]: ...
37-
def is_torch_array(x: object, /) -> TypeIs[Array]: ...
38-
def is_lazy_array(x: object, /) -> TypeIs[Array]: ...
39-
def is_writeable_array(x: object, /) -> TypeIs[Array]: ...
23+
def is_array_api_strict_namespace(xp: ModuleType, /) -> bool: ...
24+
def is_cupy_namespace(xp: ModuleType, /) -> bool: ...
25+
def is_dask_namespace(xp: ModuleType, /) -> bool: ...
26+
def is_jax_namespace(xp: ModuleType, /) -> bool: ...
27+
def is_numpy_namespace(xp: ModuleType, /) -> bool: ...
28+
def is_pydata_sparse_namespace(xp: ModuleType, /) -> bool: ...
29+
def is_torch_namespace(xp: ModuleType, /) -> bool: ...
30+
def is_cupy_array(x: object, /) -> TypeGuard[Array]: ...
31+
def is_dask_array(x: object, /) -> TypeGuard[Array]: ...
32+
def is_jax_array(x: object, /) -> TypeGuard[Array]: ...
33+
def is_numpy_array(x: object, /) -> TypeGuard[Array]: ...
34+
def is_pydata_sparse_array(x: object, /) -> TypeGuard[Array]: ...
35+
def is_torch_array(x: object, /) -> TypeGuard[Array]: ...
36+
def is_lazy_array(x: object, /) -> TypeGuard[Array]: ...
37+
def is_writeable_array(x: object, /) -> TypeGuard[Array]: ...
4038
def size(x: Array, /) -> int | None: ...
39+
def to_device( # type: ignore[explicit-any]
40+
x: Array,
41+
device: Device, # pylint: disable=redefined-outer-name
42+
/,
43+
*,
44+
stream: int | Any | None = None,
45+
) -> Array: ...

src/array_api_extra/testing.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def override(func: object) -> object:
3939
def lazy_xp_function( # type: ignore[explicit-any]
4040
func: Callable[..., Any],
4141
*,
42-
allow_dask_compute: int = 0,
42+
allow_dask_compute: bool | int = False,
4343
jax_jit: bool = True,
4444
static_argnums: int | Sequence[int] | None = None,
4545
static_argnames: str | Iterable[str] | None = None,
@@ -59,9 +59,10 @@ def lazy_xp_function( # type: ignore[explicit-any]
5959
----------
6060
func : callable
6161
Function to be tested.
62-
allow_dask_compute : int, optional
63-
Number of times `func` is allowed to internally materialize the Dask graph. This
64-
is typically triggered by ``bool()``, ``float()``, or ``np.asarray()``.
62+
allow_dask_compute : bool | int, optional
63+
Whether `func` is allowed to internally materialize the Dask graph, or maximum
64+
number of times it is allowed to do so. This is typically triggered by
65+
``bool()``, ``float()``, or ``np.asarray()``.
6566
6667
Set to 1 if you are aware that `func` converts the input parameters to NumPy and
6768
want to let it do so at least for the time being, knowing that it is going to be
@@ -75,7 +76,10 @@ def lazy_xp_function( # type: ignore[explicit-any]
7576
a test function that invokes `func` multiple times should still work with this
7677
parameter set to 1.
7778
78-
Default: 0, meaning that `func` must be fully lazy and never materialize the
79+
Set to True to allow `func` to materialize the graph an unlimited number
80+
of times.
81+
82+
Default: False, meaning that `func` must be fully lazy and never materialize the
7983
graph.
8084
jax_jit : bool, optional
8185
Set to True to replace `func` with ``jax.jit(func)`` after calling the
@@ -235,6 +239,10 @@ def iter_tagged() -> ( # type: ignore[explicit-any]
235239
if is_dask_namespace(xp):
236240
for mod, name, func, tags in iter_tagged():
237241
n = tags["allow_dask_compute"]
242+
if n is True:
243+
n = 1_000_000
244+
elif n is False:
245+
n = 0
238246
wrapped = _dask_wrap(func, n)
239247
monkeypatch.setattr(mod, name, wrapped)
240248

tests/test_funcs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -521,7 +521,7 @@ def test_xp(self, xp: ModuleType):
521521
class TestExpandDims:
522522
def test_single_axis(self, xp: ModuleType):
523523
"""Trivial case where xpx.expand_dims doesn't add anything to xp.expand_dims"""
524-
a = xp.empty((2, 3, 4, 5))
524+
a = xp.asarray(np.reshape(np.arange(2 * 3 * 4 * 5), (2, 3, 4, 5)))
525525
for axis in range(-5, 4):
526526
b = expand_dims(a, axis=axis)
527527
xp_assert_equal(b, xp.expand_dims(a, axis=axis))

tests/test_testing.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -200,13 +200,18 @@ def non_materializable4(x: Array) -> Array:
200200
return non_materializable(x)
201201

202202

203+
def non_materializable5(x: Array) -> Array:
204+
return non_materializable(x)
205+
206+
203207
lazy_xp_function(good_lazy)
204208
# Works on JAX and Dask
205209
lazy_xp_function(non_materializable2, jax_jit=False, allow_dask_compute=2)
210+
lazy_xp_function(non_materializable3, jax_jit=False, allow_dask_compute=True)
206211
# Works on JAX, but not Dask
207-
lazy_xp_function(non_materializable3, jax_jit=False, allow_dask_compute=1)
212+
lazy_xp_function(non_materializable4, jax_jit=False, allow_dask_compute=1)
208213
# Works neither on Dask nor JAX
209-
lazy_xp_function(non_materializable4)
214+
lazy_xp_function(non_materializable5)
210215

211216

212217
def test_lazy_xp_function(xp: ModuleType):
@@ -217,29 +222,30 @@ def test_lazy_xp_function(xp: ModuleType):
217222
xp_assert_equal(non_materializable(x), xp.asarray([1.0, 2.0]))
218223
# Wrapping explicitly disabled
219224
xp_assert_equal(non_materializable2(x), xp.asarray([1.0, 2.0]))
225+
xp_assert_equal(non_materializable3(x), xp.asarray([1.0, 2.0]))
220226

221227
if is_jax_namespace(xp):
222-
xp_assert_equal(non_materializable3(x), xp.asarray([1.0, 2.0]))
228+
xp_assert_equal(non_materializable4(x), xp.asarray([1.0, 2.0]))
223229
with pytest.raises(
224230
TypeError, match="Attempted boolean conversion of traced array"
225231
):
226-
_ = non_materializable4(x) # Wrapped
232+
_ = non_materializable5(x) # Wrapped
227233

228234
elif is_dask_namespace(xp):
229235
with pytest.raises(
230236
AssertionError,
231237
match=r"dask\.compute.* 2 times, but only up to 1 calls are allowed",
232238
):
233-
_ = non_materializable3(x)
239+
_ = non_materializable4(x)
234240
with pytest.raises(
235241
AssertionError,
236242
match=r"dask\.compute.* 1 times, but no calls are allowed",
237243
):
238-
_ = non_materializable4(x)
244+
_ = non_materializable5(x)
239245

240246
else:
241-
xp_assert_equal(non_materializable3(x), xp.asarray([1.0, 2.0]))
242247
xp_assert_equal(non_materializable4(x), xp.asarray([1.0, 2.0]))
248+
xp_assert_equal(non_materializable5(x), xp.asarray([1.0, 2.0]))
243249

244250

245251
def static_params(x: Array, n: int, flag: bool = False) -> Array:

vendor_tests/test_vendor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,12 @@ def test_vendor_compat():
2323
is_torch_namespace,
2424
is_writeable_array,
2525
size,
26+
to_device,
2627
)
2728

2829
x = xp.asarray([1, 2, 3])
2930
assert array_namespace(x) is xp
30-
device(x)
31+
to_device(x, device(x))
3132
assert is_array_api_obj(x)
3233
assert is_array_api_strict_namespace(xp)
3334
assert not is_cupy_array(x)

0 commit comments

Comments
 (0)