TST: run tests on CPU+GPU #221

Merged (5 commits) on Apr 15, 2025
2 changes: 1 addition & 1 deletion pixi.lock

(Generated lockfile; diff not rendered.)

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -238,7 +238,7 @@ disable_error_code = ["redundant-expr", "unreachable", "no-any-return"]

[[tool.mypy.overrides]]
# slow/unavailable on Windows; do not add to the lint env
module = ["dask.*", "jax.*"]
module = ["dask.*", "jax.*", "torch.*"]
ignore_missing_imports = true

# pyright
10 changes: 9 additions & 1 deletion src/array_api_extra/_lib/_at.py
@@ -8,10 +8,12 @@
from types import ModuleType
from typing import TYPE_CHECKING, ClassVar, cast

+from ._utils import _compat
from ._utils._compat import (
    array_namespace,
    is_dask_array,
    is_jax_array,
+   is_torch_array,
    is_writeable_array,
)
from ._utils._helpers import meta_namespace
@@ -298,7 +300,7 @@ def _op(
            and idx.dtype == xp.bool
            and idx.shape == x.shape
        ):
-           y_xp = xp.asarray(y, dtype=x.dtype)
+           y_xp = xp.asarray(y, dtype=x.dtype, device=_compat.device(x))

Contributor Author commented:
Untested fix, which only benefits eager JAX. Under jax.jit, device propagation fails due to jax-ml/jax#26000.

In a follow-up PR I'll rework the test_device tests to align them with the pattern recently established in SciPy.
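
For illustration, a minimal sketch of the code path in question (hypothetical usage via array-api-extra's public at() wrapper, eager JAX only; not part of this diff):

import jax.numpy as jnp
import array_api_extra as xpx

x = jnp.asarray([1.0, 2.0, 3.0])        # may live on a non-default device
mask = jnp.asarray([True, False, True])
# The scalar rhs is now converted with xp.asarray(..., device=_compat.device(x)),
# so the update buffer is allocated on x's device. Eager JAX only: under
# jax.jit, device propagation still fails (jax-ml/jax#26000).
y = xpx.at(x)[mask].set(0.0)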

            if y_xp.ndim == 0:
                if out_of_place_op:  # add(), subtract(), ...
                    # suppress inf warnings on Dask
@@ -344,6 +346,12 @@ def _op(
msg = f"Can't update read-only array {x}"
raise ValueError(msg)

# Work around bug in PyTorch where __setitem__ doesn't
# always support mismatched dtypes
# https://github.com/pytorch/pytorch/issues/150017
if is_torch_array(y):
y = xp.astype(y, x.dtype, copy=False)

# Backends without boolean indexing (other than JAX) crash here
if in_place_op: # add(), subtract(), ...
x[idx] = in_place_op(x[idx], y)
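A hedged illustration of the path this guards (a sketch only; the exact failing dtype combinations are documented in pytorch/pytorch#150017):

import torch
import array_api_extra as xpx

x = torch.zeros(4, dtype=torch.float32)
idx = torch.tensor([True, False, True, False])
y = torch.ones(2, dtype=torch.float64)  # dtype differs from x
# at() now casts y to x.dtype before executing x[idx] = y;
# copy=False makes the cast a no-op when the dtypes already match.
z = xpx.at(x)[idx].set(y)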
6 changes: 5 additions & 1 deletion src/array_api_extra/_lib/_backends.py
@@ -34,9 +34,11 @@ class Backend(Enum):  # numpydoc ignore=PR01,PR02  # type: ignore[no-subclass-any]
    NUMPY_READONLY = "numpy:readonly", _compat.is_numpy_namespace
    CUPY = "cupy", _compat.is_cupy_namespace
    TORCH = "torch", _compat.is_torch_namespace
+   TORCH_GPU = "torch:gpu", _compat.is_torch_namespace
    DASK = "dask.array", _compat.is_dask_namespace
    SPARSE = "sparse", _compat.is_pydata_sparse_namespace
    JAX = "jax.numpy", _compat.is_jax_namespace
+   JAX_GPU = "jax.numpy:gpu", _compat.is_jax_namespace

Contributor Author commented:
As an aside, I profoundly dislike that this enum is used in production by the dispatch mechanism.
@lucascolley would it be OK if I make it test-only again and just use `is_*_namespace` in dispatch?


Member replied:

> I profoundly dislike that this enum is used in production by the dispatch mechanism.

what motivates your dislike?


Contributor Author replied:
This enum contains a lot of "backends" that are just variant duplicates, an artifact of using the enum to parametrize the xp fixture. That's a decent hack for pytest, but it doesn't make any sense in the context of the dispatch system.


Member replied:
okay, I see the problem. Happy to go with whatever works best for you. FWIW I would like to keep using an enum in the dispatch mechanism, but fair enough if it is too messy to maintain the `like` mapping.
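
For reference, a plausible sketch of the `like` comparison discussed above; this is an assumption based on how `library.like(...)` is used in the tests below, not necessarily the actual implementation:

def like(self, *others: "Backend") -> bool:
    """Check whether self is any of `others` or a variant of one.

    Variants share the base value before the ":" suffix, e.g.
    Backend.JAX_GPU ("jax.numpy:gpu") is like Backend.JAX ("jax.numpy").
    Sketch only.
    """
    return any(
        self.value.split(":")[0] == other.value.split(":")[0] for other in others
    )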


    def __new__(
        cls, value: str, _is_namespace: Callable[[ModuleType], bool]
@@ -54,7 +56,9 @@ def __init__(

    def __str__(self) -> str:  # type: ignore[explicit-override]  # pyright: ignore[reportImplicitOverride]  # numpydoc ignore=RT01
        """Pretty-print parameterized test names."""
-       return self.name.lower()
+       return (
+           self.name.lower().replace("_gpu", ":gpu").replace("_readonly", ":readonly")
+       )

    @property
    def modname(self) -> str:  # numpydoc ignore=RT01
63 changes: 28 additions & 35 deletions src/array_api_extra/_lib/_testing.py
@@ -65,6 +65,24 @@ def _check_ns_shape_dtype(
    return desired_xp


+def _prepare_for_test(array: Array, xp: ModuleType) -> Array:
+   """
+   Ensure that the array can be compared with xp.testing or np.testing.
+
+   This involves transferring it from GPU to CPU memory, densifying it, etc.
+   """
+   if is_torch_namespace(xp):
+       return array.cpu()  # type: ignore[attr-defined]  # pyright: ignore[reportAttributeAccessIssue]
+   if is_pydata_sparse_namespace(xp):
+       return array.todense()  # type: ignore[attr-defined]  # pyright: ignore[reportAttributeAccessIssue]
+   if is_array_api_strict_namespace(xp):
+       # Note: we deliberately did not add a `.to_device` method to _typing.pyi,
+       # even though the standard requires it, since many backends don't support it
+       return array.to_device(xp.Device("CPU_DEVICE"))  # type: ignore[attr-defined]  # pyright: ignore[reportAttributeAccessIssue]
+   # Note: nothing to do for CuPy, because it uses a bespoke test function
+   return array


def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
    """
    Array-API compatible version of `np.testing.assert_array_equal`.
@@ -84,6 +102,8 @@ def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
    numpy.testing.assert_array_equal : Similar function for NumPy arrays.
    """
    xp = _check_ns_shape_dtype(actual, desired)
+   actual = _prepare_for_test(actual, xp)
+   desired = _prepare_for_test(desired, xp)

    if is_cupy_namespace(xp):
        xp.testing.assert_array_equal(actual, desired, err_msg=err_msg)
@@ -102,22 +122,7 @@ def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
    else:
        import numpy as np  # pylint: disable=import-outside-toplevel

-       if is_pydata_sparse_namespace(xp):
-           actual = actual.todense()  # type: ignore[attr-defined]  # pyright: ignore[reportAttributeAccessIssue]
-           desired = desired.todense()  # type: ignore[attr-defined]  # pyright: ignore[reportAttributeAccessIssue]
-
-       actual_np = None
-       desired_np = None
-       if is_array_api_strict_namespace(xp):
-           # __array__ doesn't work on array-api-strict device arrays;
-           # we need to convert to the CPU device first
-           actual_np = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE")))
-           desired_np = np.asarray(xp.asarray(desired, device=xp.Device("CPU_DEVICE")))
-
-       # JAX/Dask arrays work with `np.testing`
-       actual_np = actual if actual_np is None else actual_np
-       desired_np = desired if desired_np is None else desired_np
-       np.testing.assert_array_equal(actual_np, desired_np, err_msg=err_msg)  # pyright: ignore[reportUnknownArgumentType]
+       np.testing.assert_array_equal(actual, desired, err_msg=err_msg)


def xp_assert_close(
@@ -165,6 +170,9 @@ def xp_assert_close(
    elif rtol is None:
        rtol = 1e-7

+   actual = _prepare_for_test(actual, xp)
+   desired = _prepare_for_test(desired, xp)

    if is_cupy_namespace(xp):
        xp.testing.assert_allclose(
            actual, desired, rtol=rtol, atol=atol, err_msg=err_msg
@@ -176,26 +184,11 @@ def xp_assert_close(
    else:
        import numpy as np  # pylint: disable=import-outside-toplevel

-       if is_pydata_sparse_namespace(xp):
-           actual = actual.todense()  # type: ignore[attr-defined]  # pyright: ignore[reportAttributeAccessIssue]
-           desired = desired.todense()  # type: ignore[attr-defined]  # pyright: ignore[reportAttributeAccessIssue]
-
-       actual_np = None
-       desired_np = None
-       if is_array_api_strict_namespace(xp):
-           # __array__ doesn't work on array-api-strict device arrays;
-           # we need to convert to the CPU device first
-           actual_np = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE")))
-           desired_np = np.asarray(xp.asarray(desired, device=xp.Device("CPU_DEVICE")))
-
-       # JAX/Dask arrays work with `np.testing`
-       actual_np = actual if actual_np is None else actual_np
-       desired_np = desired if desired_np is None else desired_np
-
+       # JAX/Dask arrays work directly with `np.testing`
        assert isinstance(rtol, float)
-       np.testing.assert_allclose(  # pyright: ignore[reportCallIssue]
-           actual_np,  # type: ignore[arg-type]  # pyright: ignore[reportArgumentType]
-           desired_np,  # type: ignore[arg-type]  # pyright: ignore[reportArgumentType]
+       np.testing.assert_allclose(  # type: ignore[call-overload]  # pyright: ignore[reportCallIssue]
+           actual,  # pyright: ignore[reportArgumentType]
+           desired,  # pyright: ignore[reportArgumentType]
            rtol=rtol,
            atol=atol,
            err_msg=err_msg,
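A hedged usage sketch of what this buys (assumes a CUDA-enabled PyTorch build; `xp_assert_equal` is an internal test helper, imported here only for illustration):

import torch
from array_api_extra._lib._testing import xp_assert_equal

# Both operands live on the GPU. _prepare_for_test() moves each to CPU
# memory via Tensor.cpu() before delegating to np.testing, so the
# comparison no longer trips over device mismatches.
actual = torch.asarray([1.0, 2.0], device="cuda")
desired = torch.asarray([1.0, 2.0], device="cuda")
xp_assert_equal(actual, desired)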
21 changes: 20 additions & 1 deletion tests/conftest.py
@@ -139,12 +139,31 @@ def xp(
    # in the global scope of the module containing the test function.
    patch_lazy_xp_functions(request, monkeypatch, xp=xp)

-   if library == Backend.JAX:
+   if library.like(Backend.JAX):
        import jax

        # suppress unused-ignore to run mypy in -e lint as well as -e dev
        jax.config.update("jax_enable_x64", True)  # type: ignore[no-untyped-call,unused-ignore]

+       if library == Backend.JAX_GPU:
+           try:
+               device = jax.devices("cuda")[0]
+           except RuntimeError:
+               pytest.skip("no CUDA device available")
+       else:
+           device = jax.devices("cpu")[0]
+       jax.config.update("jax_default_device", device)

+   elif library == Backend.TORCH_GPU:
+       import torch.cuda

+       if not torch.cuda.is_available():
+           pytest.skip("no CUDA device available")
+       xp.set_default_device("cuda")

+   elif library == Backend.TORCH:  # CPU
+       xp.set_default_device("cpu")

    yield xp


Expand Down
10 changes: 8 additions & 2 deletions tests/test_at.py
@@ -118,6 +118,9 @@ def assert_copy(
        pytest.mark.skip_xp_backend(  # test passes when copy=False
            Backend.JAX, reason="bool mask update with shaped rhs"
        ),
+       pytest.mark.skip_xp_backend(  # test passes when copy=False
+           Backend.JAX_GPU, reason="bool mask update with shaped rhs"
+       ),
        pytest.mark.xfail_xp_backend(
            Backend.DASK, reason="bool mask update with shaped rhs"
        ),
@@ -247,14 +250,14 @@ def test_incompatible_dtype(
    idx = xp.asarray([True, False]) if bool_mask else slice(None)
    z = None

-   if library is Backend.JAX:
+   if library.like(Backend.JAX):
        if bool_mask:
            z = at_op(x, idx, op, 1.1, copy=copy)
        else:
            with pytest.warns(FutureWarning, match="cannot safely cast"):
                z = at_op(x, idx, op, 1.1, copy=copy)

-   elif library is Backend.DASK:
+   elif library.like(Backend.DASK):
        z = at_op(x, idx, op, 1.1, copy=copy)

    elif library.like(Backend.ARRAY_API_STRICT) and op is not _AtOp.SET:
@@ -302,6 +305,9 @@ def test_no_inf_warnings(xp: ModuleType, bool_mask: bool):
            Backend.NUMPY_READONLY, reason="read-only backend"
        ),
        pytest.mark.skip_xp_backend(Backend.JAX, reason="read-only backend"),
+       pytest.mark.skip_xp_backend(
+           Backend.JAX_GPU, reason="read-only backend"
+       ),
        pytest.mark.skip_xp_backend(Backend.SPARSE, reason="read-only backend"),
    ],
),
5 changes: 4 additions & 1 deletion tests/test_funcs.py
@@ -234,7 +234,7 @@ def test_hypothesis(  # type: ignore[explicit-any,decorated-any]

    # cupy/cupy#8382
    # https://github.com/jax-ml/jax/issues/26658
-   elements = {"allow_subnormal": library not in (Backend.CUPY, Backend.JAX)}
+   elements = {"allow_subnormal": not library.like(Backend.CUPY, Backend.JAX)}

    fill_value = xp.asarray(
        data.draw(npst.arrays(dtype=dtype, shape=(), elements=elements))
@@ -930,6 +930,9 @@ class TestSetDiff1D:
    @pytest.mark.xfail_xp_backend(
        Backend.TORCH, reason="index_select not implemented for uint32"
    )
+   @pytest.mark.xfail_xp_backend(
+       Backend.TORCH_GPU, reason="index_select not implemented for uint32"
+   )
    def test_setdiff1d(self, xp: ModuleType):
        x1 = xp.asarray([6, 5, 4, 7, 1, 2, 7, 4])
        x2 = xp.asarray([2, 4, 3, 3, 2, 1, 5])
14 changes: 12 additions & 2 deletions tests/test_lazy.py
@@ -27,6 +27,9 @@
        True,
        marks=[
            pytest.mark.skip_xp_backend(Backend.CUPY, reason="device->host copy"),
+           pytest.mark.skip_xp_backend(
+               Backend.TORCH_GPU, reason="device->host copy"
+           ),
            pytest.mark.skip_xp_backend(Backend.SPARSE, reason="densification"),
        ],
    ),
@@ -100,6 +103,9 @@ def f(x: Array) -> tuple[Array, Array]:
        True,
        marks=[
            pytest.mark.skip_xp_backend(Backend.CUPY, reason="device->host copy"),
+           pytest.mark.skip_xp_backend(
+               Backend.TORCH_GPU, reason="device->host copy"
+           ),
            pytest.mark.skip_xp_backend(Backend.SPARSE, reason="densification"),
        ],
    ),
@@ -216,7 +222,7 @@ def test_lazy_apply_none_shape_in_args(xp: ModuleType, library: Backend):
    int_type = xp.asarray(0).dtype

    ctx: contextlib.AbstractContextManager[object]
-   if library is Backend.JAX:
+   if library.like(Backend.JAX):
        ctx = pytest.raises(ValueError, match="Output shape must be fully known")
    elif library is Backend.ARRAY_API_STRICTEST:
        ctx = pytest.raises(RuntimeError, match="data-dependent shapes")
@@ -254,6 +260,7 @@ def f(x: Array) -> Array:

@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="index by sparse array")
@pytest.mark.skip_xp_backend(Backend.JAX, reason="boolean indexing")
+@pytest.mark.skip_xp_backend(Backend.JAX_GPU, reason="boolean indexing")
@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="boolean indexing")
def test_lazy_apply_none_shape_broadcast(xp: ModuleType):
    """Broadcast from input array with unknown shape"""
@@ -273,6 +280,9 @@ def test_lazy_apply_none_shape_broadcast(xp: ModuleType):
                Backend.ARRAY_API_STRICT, reason="device->host copy"
            ),
            pytest.mark.skip_xp_backend(Backend.CUPY, reason="device->host copy"),
+           pytest.mark.skip_xp_backend(
+               Backend.TORCH_GPU, reason="device->host copy"
+           ),
            pytest.mark.skip_xp_backend(Backend.SPARSE, reason="densification"),
        ],
    ),
@@ -321,7 +331,7 @@ def f(x: Array, y: None, z: int | Array) -> Array:
        assert isinstance(x, mtyp)
        assert y is None
        # jax.pure_callback wraps scalar args
-       assert isinstance(z, mtyp if library is Backend.JAX else int)
+       assert isinstance(z, mtyp if library.like(Backend.JAX) else int)
        return x + z

    x = xp.asarray([1, 2])
5 changes: 3 additions & 2 deletions tests/test_testing.py
@@ -218,6 +218,7 @@ def test_lazy_xp_function_static_params(xp: ModuleType, func: Callable[..., Array]):
erf = None


+@pytest.mark.skip_xp_backend(Backend.TORCH_GPU, reason="device->host copy")
@pytest.mark.filterwarnings("ignore:__array_wrap__:DeprecationWarning")  # PyTorch
def test_lazy_xp_function_cython_ufuncs(xp: ModuleType, library: Backend):
    pytest.importorskip("scipy")
@@ -293,12 +294,12 @@ def test_lazy_xp_modules(xp: ModuleType, library: Backend):
    y = naked.f(x)
    xp_assert_equal(y, x)

-   if library is Backend.JAX:
+   if library.like(Backend.JAX):
        with pytest.raises(
            TypeError, match="Attempted boolean conversion of traced array"
        ):
            wrapped.f(x)
-   elif library is Backend.DASK:
+   elif library.like(Backend.DASK):
        with pytest.raises(AssertionError, match=r"dask\.compute"):
            wrapped.f(x)
    else: