Skip to content

Commit d3f6f67

Browse files
authored
TST: run tests on CPU+GPU (#221)
* TST: run tests on CPU+GPU * lock * Simplify conftest * capitalisation
1 parent 06c4308 commit d3f6f67

File tree

10 files changed

+91
-47
lines changed

10 files changed

+91
-47
lines changed

pixi.lock

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ disable_error_code = ["redundant-expr", "unreachable", "no-any-return"]
238238

239239
[[tool.mypy.overrides]]
240240
# slow/unavailable on Windows; do not add to the lint env
241-
module = ["dask.*", "jax.*"]
241+
module = ["dask.*", "jax.*", "torch.*"]
242242
ignore_missing_imports = true
243243

244244
# pyright

src/array_api_extra/_lib/_at.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@
88
from types import ModuleType
99
from typing import TYPE_CHECKING, ClassVar, cast
1010

11+
from ._utils import _compat
1112
from ._utils._compat import (
1213
array_namespace,
1314
is_dask_array,
1415
is_jax_array,
16+
is_torch_array,
1517
is_writeable_array,
1618
)
1719
from ._utils._helpers import meta_namespace
@@ -298,7 +300,7 @@ def _op(
298300
and idx.dtype == xp.bool
299301
and idx.shape == x.shape
300302
):
301-
y_xp = xp.asarray(y, dtype=x.dtype)
303+
y_xp = xp.asarray(y, dtype=x.dtype, device=_compat.device(x))
302304
if y_xp.ndim == 0:
303305
if out_of_place_op: # add(), subtract(), ...
304306
# suppress inf warnings on Dask
@@ -344,6 +346,12 @@ def _op(
344346
msg = f"Can't update read-only array {x}"
345347
raise ValueError(msg)
346348

349+
# Work around bug in PyTorch where __setitem__ doesn't
350+
# always support mismatched dtypes
351+
# https://github.com/pytorch/pytorch/issues/150017
352+
if is_torch_array(y):
353+
y = xp.astype(y, x.dtype, copy=False)
354+
347355
# Backends without boolean indexing (other than JAX) crash here
348356
if in_place_op: # add(), subtract(), ...
349357
x[idx] = in_place_op(x[idx], y)

src/array_api_extra/_lib/_backends.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,11 @@ class Backend(Enum): # numpydoc ignore=PR01,PR02 # type: ignore[no-subclass-an
3434
NUMPY_READONLY = "numpy:readonly", _compat.is_numpy_namespace
3535
CUPY = "cupy", _compat.is_cupy_namespace
3636
TORCH = "torch", _compat.is_torch_namespace
37+
TORCH_GPU = "torch:gpu", _compat.is_torch_namespace
3738
DASK = "dask.array", _compat.is_dask_namespace
3839
SPARSE = "sparse", _compat.is_pydata_sparse_namespace
3940
JAX = "jax.numpy", _compat.is_jax_namespace
41+
JAX_GPU = "jax.numpy:gpu", _compat.is_jax_namespace
4042

4143
def __new__(
4244
cls, value: str, _is_namespace: Callable[[ModuleType], bool]
@@ -54,7 +56,9 @@ def __init__(
5456

5557
def __str__(self) -> str: # type: ignore[explicit-override] # pyright: ignore[reportImplicitOverride] # numpydoc ignore=RT01
5658
"""Pretty-print parameterized test names."""
57-
return self.name.lower()
59+
return (
60+
self.name.lower().replace("_gpu", ":gpu").replace("_readonly", ":readonly")
61+
)
5862

5963
@property
6064
def modname(self) -> str: # numpydoc ignore=RT01

src/array_api_extra/_lib/_testing.py

+28-35
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,24 @@ def _check_ns_shape_dtype(
6565
return desired_xp
6666

6767

68+
def _prepare_for_test(array: Array, xp: ModuleType) -> Array:
69+
"""
70+
Ensure that the array can be compared with xp.testing or np.testing.
71+
72+
This involves transferring it from GPU to CPU memory, densifying it, etc.
73+
"""
74+
if is_torch_namespace(xp):
75+
return array.cpu() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
76+
if is_pydata_sparse_namespace(xp):
77+
return array.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
78+
if is_array_api_strict_namespace(xp):
79+
# Note: we deliberately did not add a `.to_device` method in _typing.pyi
80+
# even if it is required by the standard as many backends don't support it
81+
return array.to_device(xp.Device("CPU_DEVICE")) # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
82+
# Note: nothing to do for CuPy, because it uses a bespoke test function
83+
return array
84+
85+
6886
def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
6987
"""
7088
Array-API compatible version of `np.testing.assert_array_equal`.
@@ -84,6 +102,8 @@ def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
84102
numpy.testing.assert_array_equal : Similar function for NumPy arrays.
85103
"""
86104
xp = _check_ns_shape_dtype(actual, desired)
105+
actual = _prepare_for_test(actual, xp)
106+
desired = _prepare_for_test(desired, xp)
87107

88108
if is_cupy_namespace(xp):
89109
xp.testing.assert_array_equal(actual, desired, err_msg=err_msg)
@@ -102,22 +122,7 @@ def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
102122
else:
103123
import numpy as np # pylint: disable=import-outside-toplevel
104124

105-
if is_pydata_sparse_namespace(xp):
106-
actual = actual.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
107-
desired = desired.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
108-
109-
actual_np = None
110-
desired_np = None
111-
if is_array_api_strict_namespace(xp):
112-
# __array__ doesn't work on array-api-strict device arrays
113-
# We need to convert to the CPU device first
114-
actual_np = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE")))
115-
desired_np = np.asarray(xp.asarray(desired, device=xp.Device("CPU_DEVICE")))
116-
117-
# JAX/Dask arrays work with `np.testing`
118-
actual_np = actual if actual_np is None else actual_np
119-
desired_np = desired if desired_np is None else desired_np
120-
np.testing.assert_array_equal(actual_np, desired_np, err_msg=err_msg) # pyright: ignore[reportUnknownArgumentType]
125+
np.testing.assert_array_equal(actual, desired, err_msg=err_msg)
121126

122127

123128
def xp_assert_close(
@@ -165,6 +170,9 @@ def xp_assert_close(
165170
elif rtol is None:
166171
rtol = 1e-7
167172

173+
actual = _prepare_for_test(actual, xp)
174+
desired = _prepare_for_test(desired, xp)
175+
168176
if is_cupy_namespace(xp):
169177
xp.testing.assert_allclose(
170178
actual, desired, rtol=rtol, atol=atol, err_msg=err_msg
@@ -176,26 +184,11 @@ def xp_assert_close(
176184
else:
177185
import numpy as np # pylint: disable=import-outside-toplevel
178186

179-
if is_pydata_sparse_namespace(xp):
180-
actual = actual.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
181-
desired = desired.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
182-
183-
actual_np = None
184-
desired_np = None
185-
if is_array_api_strict_namespace(xp):
186-
# __array__ doesn't work on array-api-strict device arrays
187-
# We need to convert to the CPU device first
188-
actual_np = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE")))
189-
desired_np = np.asarray(xp.asarray(desired, device=xp.Device("CPU_DEVICE")))
190-
191-
# JAX/Dask arrays work with `np.testing`
192-
actual_np = actual if actual_np is None else actual_np
193-
desired_np = desired if desired_np is None else desired_np
194-
187+
# JAX/Dask arrays work directly with `np.testing`
195188
assert isinstance(rtol, float)
196-
np.testing.assert_allclose( # pyright: ignore[reportCallIssue]
197-
actual_np, # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
198-
desired_np, # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
189+
np.testing.assert_allclose( # type: ignore[call-overload] # pyright: ignore[reportCallIssue]
190+
actual, # pyright: ignore[reportArgumentType]
191+
desired, # pyright: ignore[reportArgumentType]
199192
rtol=rtol,
200193
atol=atol,
201194
err_msg=err_msg,

tests/conftest.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -139,12 +139,31 @@ def xp(
139139
# in the global scope of the module containing the test function.
140140
patch_lazy_xp_functions(request, monkeypatch, xp=xp)
141141

142-
if library == Backend.JAX:
142+
if library.like(Backend.JAX):
143143
import jax
144144

145145
# suppress unused-ignore to run mypy in -e lint as well as -e dev
146146
jax.config.update("jax_enable_x64", True) # type: ignore[no-untyped-call,unused-ignore]
147147

148+
if library == Backend.JAX_GPU:
149+
try:
150+
device = jax.devices("cuda")[0]
151+
except RuntimeError:
152+
pytest.skip("no CUDA device available")
153+
else:
154+
device = jax.devices("cpu")[0]
155+
jax.config.update("jax_default_device", device)
156+
157+
elif library == Backend.TORCH_GPU:
158+
import torch.cuda
159+
160+
if not torch.cuda.is_available():
161+
pytest.skip("no CUDA device available")
162+
xp.set_default_device("cuda")
163+
164+
elif library == Backend.TORCH: # CPU
165+
xp.set_default_device("cpu")
166+
148167
yield xp
149168

150169

tests/test_at.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,9 @@ def assert_copy(
118118
pytest.mark.skip_xp_backend( # test passes when copy=False
119119
Backend.JAX, reason="bool mask update with shaped rhs"
120120
),
121+
pytest.mark.skip_xp_backend( # test passes when copy=False
122+
Backend.JAX_GPU, reason="bool mask update with shaped rhs"
123+
),
121124
pytest.mark.xfail_xp_backend(
122125
Backend.DASK, reason="bool mask update with shaped rhs"
123126
),
@@ -247,14 +250,14 @@ def test_incompatible_dtype(
247250
idx = xp.asarray([True, False]) if bool_mask else slice(None)
248251
z = None
249252

250-
if library is Backend.JAX:
253+
if library.like(Backend.JAX):
251254
if bool_mask:
252255
z = at_op(x, idx, op, 1.1, copy=copy)
253256
else:
254257
with pytest.warns(FutureWarning, match="cannot safely cast"):
255258
z = at_op(x, idx, op, 1.1, copy=copy)
256259

257-
elif library is Backend.DASK:
260+
elif library.like(Backend.DASK):
258261
z = at_op(x, idx, op, 1.1, copy=copy)
259262

260263
elif library.like(Backend.ARRAY_API_STRICT) and op is not _AtOp.SET:
@@ -302,6 +305,9 @@ def test_no_inf_warnings(xp: ModuleType, bool_mask: bool):
302305
Backend.NUMPY_READONLY, reason="read-only backend"
303306
),
304307
pytest.mark.skip_xp_backend(Backend.JAX, reason="read-only backend"),
308+
pytest.mark.skip_xp_backend(
309+
Backend.JAX_GPU, reason="read-only backend"
310+
),
305311
pytest.mark.skip_xp_backend(Backend.SPARSE, reason="read-only backend"),
306312
],
307313
),

tests/test_funcs.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ def test_hypothesis( # type: ignore[explicit-any,decorated-any]
234234

235235
# cupy/cupy#8382
236236
# https://github.com/jax-ml/jax/issues/26658
237-
elements = {"allow_subnormal": library not in (Backend.CUPY, Backend.JAX)}
237+
elements = {"allow_subnormal": not library.like(Backend.CUPY, Backend.JAX)}
238238

239239
fill_value = xp.asarray(
240240
data.draw(npst.arrays(dtype=dtype, shape=(), elements=elements))
@@ -930,6 +930,9 @@ class TestSetDiff1D:
930930
@pytest.mark.xfail_xp_backend(
931931
Backend.TORCH, reason="index_select not implemented for uint32"
932932
)
933+
@pytest.mark.xfail_xp_backend(
934+
Backend.TORCH_GPU, reason="index_select not implemented for uint32"
935+
)
933936
def test_setdiff1d(self, xp: ModuleType):
934937
x1 = xp.asarray([6, 5, 4, 7, 1, 2, 7, 4])
935938
x2 = xp.asarray([2, 4, 3, 3, 2, 1, 5])

tests/test_lazy.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727
True,
2828
marks=[
2929
pytest.mark.skip_xp_backend(Backend.CUPY, reason="device->host copy"),
30+
pytest.mark.skip_xp_backend(
31+
Backend.TORCH_GPU, reason="device->host copy"
32+
),
3033
pytest.mark.skip_xp_backend(Backend.SPARSE, reason="densification"),
3134
],
3235
),
@@ -100,6 +103,9 @@ def f(x: Array) -> tuple[Array, Array]:
100103
True,
101104
marks=[
102105
pytest.mark.skip_xp_backend(Backend.CUPY, reason="device->host copy"),
106+
pytest.mark.skip_xp_backend(
107+
Backend.TORCH_GPU, reason="device->host copy"
108+
),
103109
pytest.mark.skip_xp_backend(Backend.SPARSE, reason="densification"),
104110
],
105111
),
@@ -216,7 +222,7 @@ def test_lazy_apply_none_shape_in_args(xp: ModuleType, library: Backend):
216222
int_type = xp.asarray(0).dtype
217223

218224
ctx: contextlib.AbstractContextManager[object]
219-
if library is Backend.JAX:
225+
if library.like(Backend.JAX):
220226
ctx = pytest.raises(ValueError, match="Output shape must be fully known")
221227
elif library is Backend.ARRAY_API_STRICTEST:
222228
ctx = pytest.raises(RuntimeError, match="data-dependent shapes")
@@ -254,6 +260,7 @@ def f(x: Array) -> Array:
254260

255261
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="index by sparse array")
256262
@pytest.mark.skip_xp_backend(Backend.JAX, reason="boolean indexing")
263+
@pytest.mark.skip_xp_backend(Backend.JAX_GPU, reason="boolean indexing")
257264
@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="boolean indexing")
258265
def test_lazy_apply_none_shape_broadcast(xp: ModuleType):
259266
"""Broadcast from input array with unknown shape"""
@@ -273,6 +280,9 @@ def test_lazy_apply_none_shape_broadcast(xp: ModuleType):
273280
Backend.ARRAY_API_STRICT, reason="device->host copy"
274281
),
275282
pytest.mark.skip_xp_backend(Backend.CUPY, reason="device->host copy"),
283+
pytest.mark.skip_xp_backend(
284+
Backend.TORCH_GPU, reason="device->host copy"
285+
),
276286
pytest.mark.skip_xp_backend(Backend.SPARSE, reason="densification"),
277287
],
278288
),
@@ -321,7 +331,7 @@ def f(x: Array, y: None, z: int | Array) -> Array:
321331
assert isinstance(x, mtyp)
322332
assert y is None
323333
# jax.pure_callback wraps scalar args
324-
assert isinstance(z, mtyp if library is Backend.JAX else int)
334+
assert isinstance(z, mtyp if library.like(Backend.JAX) else int)
325335
return x + z
326336

327337
x = xp.asarray([1, 2])

tests/test_testing.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ def test_lazy_xp_function_static_params(xp: ModuleType, func: Callable[..., Arra
218218
erf = None
219219

220220

221+
@pytest.mark.skip_xp_backend(Backend.TORCH_GPU, reason="device->host copy")
221222
@pytest.mark.filterwarnings("ignore:__array_wrap__:DeprecationWarning") # PyTorch
222223
def test_lazy_xp_function_cython_ufuncs(xp: ModuleType, library: Backend):
223224
pytest.importorskip("scipy")
@@ -293,12 +294,12 @@ def test_lazy_xp_modules(xp: ModuleType, library: Backend):
293294
y = naked.f(x)
294295
xp_assert_equal(y, x)
295296

296-
if library is Backend.JAX:
297+
if library.like(Backend.JAX):
297298
with pytest.raises(
298299
TypeError, match="Attempted boolean conversion of traced array"
299300
):
300301
wrapped.f(x)
301-
elif library is Backend.DASK:
302+
elif library.like(Backend.DASK):
302303
with pytest.raises(AssertionError, match=r"dask\.compute"):
303304
wrapped.f(x)
304305
else:

0 commit comments

Comments
 (0)