diff --git a/pixi.lock b/pixi.lock
index c2890eee..71b1afc2 100644
--- a/pixi.lock
+++ b/pixi.lock
@@ -5256,7 +5256,7 @@ packages:
 - pypi: .
   name: array-api-extra
   version: 0.7.2.dev0
-  sha256: 038ce201a10c2f3747ee69453356a561627455e1caa462151433eaeea9bdcea6
+  sha256: a46c6db2ae9462de7b2b2078ea507a1bd7e21f2a6170bfed292f306878055331
   requires_dist:
   - array-api-compat>=1.11.2,<2
   requires_python: '>=3.10'
diff --git a/pyproject.toml b/pyproject.toml
index 0471921e..473e933e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -238,7 +238,7 @@ disable_error_code = ["redundant-expr", "unreachable", "no-any-return"]
 
 [[tool.mypy.overrides]]
 # slow/unavailable on Windows; do not add to the lint env
-module = ["dask.*", "jax.*"]
+module = ["dask.*", "jax.*", "torch.*"]
 ignore_missing_imports = true
 
 # pyright
diff --git a/src/array_api_extra/_lib/_at.py b/src/array_api_extra/_lib/_at.py
index 736f77bd..870884b8 100644
--- a/src/array_api_extra/_lib/_at.py
+++ b/src/array_api_extra/_lib/_at.py
@@ -8,10 +8,12 @@
 from types import ModuleType
 from typing import TYPE_CHECKING, ClassVar, cast
 
+from ._utils import _compat
 from ._utils._compat import (
     array_namespace,
     is_dask_array,
     is_jax_array,
+    is_torch_array,
     is_writeable_array,
 )
 from ._utils._helpers import meta_namespace
@@ -298,7 +300,7 @@ def _op(
             and idx.dtype == xp.bool
             and idx.shape == x.shape
         ):
-            y_xp = xp.asarray(y, dtype=x.dtype)
+            y_xp = xp.asarray(y, dtype=x.dtype, device=_compat.device(x))
             if y_xp.ndim == 0:
                 if out_of_place_op:  # add(), subtract(), ...
                     # suppress inf warnings on Dask
@@ -344,6 +346,12 @@
             msg = f"Can't update read-only array {x}"
             raise ValueError(msg)
 
+        # Work around bug in PyTorch where __setitem__ doesn't
+        # always support mismatched dtypes
+        # https://github.com/pytorch/pytorch/issues/150017
+        if is_torch_array(y):
+            y = xp.astype(y, x.dtype, copy=False)
+
         # Backends without boolean indexing (other than JAX) crash here
         if in_place_op:  # add(), subtract(), ...
             x[idx] = in_place_op(x[idx], y)
diff --git a/src/array_api_extra/_lib/_backends.py b/src/array_api_extra/_lib/_backends.py
index 3beb676d..e30afd55 100644
--- a/src/array_api_extra/_lib/_backends.py
+++ b/src/array_api_extra/_lib/_backends.py
@@ -34,9 +34,11 @@ class Backend(Enum):  # numpydoc ignore=PR01,PR02  # type: ignore[no-subclass-an
     NUMPY_READONLY = "numpy:readonly", _compat.is_numpy_namespace
     CUPY = "cupy", _compat.is_cupy_namespace
     TORCH = "torch", _compat.is_torch_namespace
+    TORCH_GPU = "torch:gpu", _compat.is_torch_namespace
     DASK = "dask.array", _compat.is_dask_namespace
     SPARSE = "sparse", _compat.is_pydata_sparse_namespace
     JAX = "jax.numpy", _compat.is_jax_namespace
+    JAX_GPU = "jax.numpy:gpu", _compat.is_jax_namespace
 
     def __new__(
         cls, value: str, _is_namespace: Callable[[ModuleType], bool]
@@ -54,7 +56,9 @@ def __init__(
 
     def __str__(self) -> str:  # type: ignore[explicit-override]  # pyright: ignore[reportImplicitOverride]  # numpydoc ignore=RT01
         """Pretty-print parameterized test names."""
-        return self.name.lower()
+        return (
+            self.name.lower().replace("_gpu", ":gpu").replace("_readonly", ":readonly")
+        )
 
     @property
     def modname(self) -> str:  # numpydoc ignore=RT01
diff --git a/src/array_api_extra/_lib/_testing.py b/src/array_api_extra/_lib/_testing.py
index e5ec16a6..319297c8 100644
--- a/src/array_api_extra/_lib/_testing.py
+++ b/src/array_api_extra/_lib/_testing.py
@@ -65,6 +65,24 @@ def _check_ns_shape_dtype(
     return desired_xp
 
 
+def _prepare_for_test(array: Array, xp: ModuleType) -> Array:
+    """
+    Ensure that the array can be compared with xp.testing or np.testing.
+
+    This involves transferring it from GPU to CPU memory, densifying it, etc.
+    """
+    if is_torch_namespace(xp):
+        return array.cpu()  # type: ignore[attr-defined]  # pyright: ignore[reportAttributeAccessIssue]
+    if is_pydata_sparse_namespace(xp):
+        return array.todense()  # type: ignore[attr-defined]  # pyright: ignore[reportAttributeAccessIssue]
+    if is_array_api_strict_namespace(xp):
+        # Note: we deliberately did not add a `.to_device` method in _typing.pyi
+        # even if it is required by the standard as many backends don't support it
+        return array.to_device(xp.Device("CPU_DEVICE"))  # type: ignore[attr-defined]  # pyright: ignore[reportAttributeAccessIssue]
+    # Note: nothing to do for CuPy, because it uses a bespoke test function
+    return array
+
+
 def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
     """
     Array-API compatible version of `np.testing.assert_array_equal`.
@@ -84,6 +102,8 @@ def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
     numpy.testing.assert_array_equal : Similar function for NumPy arrays.
     """
     xp = _check_ns_shape_dtype(actual, desired)
+    actual = _prepare_for_test(actual, xp)
+    desired = _prepare_for_test(desired, xp)
 
     if is_cupy_namespace(xp):
         xp.testing.assert_array_equal(actual, desired, err_msg=err_msg)
@@ -102,22 +122,7 @@ def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
     else:
         import numpy as np  # pylint: disable=import-outside-toplevel
 
-        if is_pydata_sparse_namespace(xp):
-            actual = actual.todense()  # type: ignore[attr-defined]  # pyright: ignore[reportAttributeAccessIssue]
-            desired = desired.todense()  # type: ignore[attr-defined]  # pyright: ignore[reportAttributeAccessIssue]
-
-        actual_np = None
-        desired_np = None
-        if is_array_api_strict_namespace(xp):
-            # __array__ doesn't work on array-api-strict device arrays
-            # We need to convert to the CPU device first
-            actual_np = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE")))
-            desired_np = np.asarray(xp.asarray(desired, device=xp.Device("CPU_DEVICE")))
-
-        # JAX/Dask arrays work with `np.testing`
-        actual_np = actual if actual_np is None else actual_np
-        desired_np = desired if desired_np is None else desired_np
-        np.testing.assert_array_equal(actual_np, desired_np, err_msg=err_msg)  # pyright: ignore[reportUnknownArgumentType]
+        np.testing.assert_array_equal(actual, desired, err_msg=err_msg)
 
 
 def xp_assert_close(
@@ -165,6 +170,9 @@ def xp_assert_close(
     elif rtol is None:
         rtol = 1e-7
 
+    actual = _prepare_for_test(actual, xp)
+    desired = _prepare_for_test(desired, xp)
+
     if is_cupy_namespace(xp):
         xp.testing.assert_allclose(
             actual, desired, rtol=rtol, atol=atol, err_msg=err_msg
@@ -176,26 +184,11 @@ def xp_assert_close(
     else:
         import numpy as np  # pylint: disable=import-outside-toplevel
 
-        if is_pydata_sparse_namespace(xp):
-            actual = actual.todense()  # type: ignore[attr-defined]  # pyright: ignore[reportAttributeAccessIssue]
-            desired = desired.todense()  # type: ignore[attr-defined]  # pyright: ignore[reportAttributeAccessIssue]
-
-        actual_np = None
-        desired_np = None
-        if is_array_api_strict_namespace(xp):
-            # __array__ doesn't work on array-api-strict device arrays
-            # We need to convert to the CPU device first
-            actual_np = np.asarray(xp.asarray(actual, device=xp.Device("CPU_DEVICE")))
-            desired_np = np.asarray(xp.asarray(desired, device=xp.Device("CPU_DEVICE")))
-
-        # JAX/Dask arrays work with `np.testing`
-        actual_np = actual if actual_np is None else actual_np
-        desired_np = desired if desired_np is None else desired_np
-
+        # JAX/Dask arrays work directly with `np.testing`
         assert isinstance(rtol, float)
-        np.testing.assert_allclose(  # pyright: ignore[reportCallIssue]
-            actual_np,  # type: ignore[arg-type]  # pyright: ignore[reportArgumentType]
-            desired_np,  # type: ignore[arg-type]  # pyright: ignore[reportArgumentType]
+        np.testing.assert_allclose(  # type: ignore[call-overload]  # pyright: ignore[reportCallIssue]
+            actual,  # pyright: ignore[reportArgumentType]
+            desired,  # pyright: ignore[reportArgumentType]
             rtol=rtol,
             atol=atol,
             err_msg=err_msg,
diff --git a/tests/conftest.py b/tests/conftest.py
index 4e368850..70854249 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -139,12 +139,31 @@ def xp(
     # in the global scope of the module containing the test function.
     patch_lazy_xp_functions(request, monkeypatch, xp=xp)
 
-    if library == Backend.JAX:
+    if library.like(Backend.JAX):
         import jax
 
         # suppress unused-ignore to run mypy in -e lint as well as -e dev
         jax.config.update("jax_enable_x64", True)  # type: ignore[no-untyped-call,unused-ignore]
 
+        if library == Backend.JAX_GPU:
+            try:
+                device = jax.devices("cuda")[0]
+            except RuntimeError:
+                pytest.skip("no CUDA device available")
+        else:
+            device = jax.devices("cpu")[0]
+        jax.config.update("jax_default_device", device)
+
+    elif library == Backend.TORCH_GPU:
+        import torch.cuda
+
+        if not torch.cuda.is_available():
+            pytest.skip("no CUDA device available")
+        xp.set_default_device("cuda")
+
+    elif library == Backend.TORCH:  # CPU
+        xp.set_default_device("cpu")
+
     yield xp
 
 
diff --git a/tests/test_at.py b/tests/test_at.py
index 4bde5ce2..926685cb 100644
--- a/tests/test_at.py
+++ b/tests/test_at.py
@@ -118,6 +118,9 @@ def assert_copy(
             pytest.mark.skip_xp_backend(  # test passes when copy=False
                 Backend.JAX, reason="bool mask update with shaped rhs"
             ),
+            pytest.mark.skip_xp_backend(  # test passes when copy=False
+                Backend.JAX_GPU, reason="bool mask update with shaped rhs"
+            ),
             pytest.mark.xfail_xp_backend(
                 Backend.DASK, reason="bool mask update with shaped rhs"
             ),
@@ -247,14 +250,14 @@ def test_incompatible_dtype(
     idx = xp.asarray([True, False]) if bool_mask else slice(None)
     z = None
 
-    if library is Backend.JAX:
+    if library.like(Backend.JAX):
         if bool_mask:
             z = at_op(x, idx, op, 1.1, copy=copy)
         else:
            with pytest.warns(FutureWarning, match="cannot safely cast"):
                z = at_op(x, idx, op, 1.1, copy=copy)
 
-    elif library is Backend.DASK:
+    elif library.like(Backend.DASK):
         z = at_op(x, idx, op, 1.1, copy=copy)
 
     elif library.like(Backend.ARRAY_API_STRICT) and op is not _AtOp.SET:
@@ -302,6 +305,9 @@ def test_no_inf_warnings(xp: ModuleType, bool_mask: bool):
                 Backend.NUMPY_READONLY, reason="read-only backend"
             ),
             pytest.mark.skip_xp_backend(Backend.JAX, reason="read-only backend"),
+            pytest.mark.skip_xp_backend(
+                Backend.JAX_GPU, reason="read-only backend"
+            ),
             pytest.mark.skip_xp_backend(Backend.SPARSE, reason="read-only backend"),
         ],
     ),
diff --git a/tests/test_funcs.py b/tests/test_funcs.py
index 48ad7b06..553df5dc 100644
--- a/tests/test_funcs.py
+++ b/tests/test_funcs.py
@@ -234,7 +234,7 @@ def test_hypothesis(  # type: ignore[explicit-any,decorated-any]
 
         # cupy/cupy#8382
        # https://github.com/jax-ml/jax/issues/26658
-        elements = {"allow_subnormal": library not in (Backend.CUPY, Backend.JAX)}
+        elements = {"allow_subnormal": not library.like(Backend.CUPY, Backend.JAX)}
 
         fill_value = xp.asarray(
             data.draw(npst.arrays(dtype=dtype, shape=(), elements=elements))
@@ -930,6 +930,9 @@ class TestSetDiff1D:
     @pytest.mark.xfail_xp_backend(
         Backend.TORCH, reason="index_select not implemented for uint32"
     )
+    @pytest.mark.xfail_xp_backend(
+        Backend.TORCH_GPU, reason="index_select not implemented for uint32"
+    )
     def test_setdiff1d(self, xp: ModuleType):
         x1 = xp.asarray([6, 5, 4, 7, 1, 2, 7, 4])
         x2 = xp.asarray([2, 4, 3, 3, 2, 1, 5])
diff --git a/tests/test_lazy.py b/tests/test_lazy.py
index d360e50d..8690c33e 100644
--- a/tests/test_lazy.py
+++ b/tests/test_lazy.py
@@ -27,6 +27,9 @@
         True,
         marks=[
             pytest.mark.skip_xp_backend(Backend.CUPY, reason="device->host copy"),
+            pytest.mark.skip_xp_backend(
+                Backend.TORCH_GPU, reason="device->host copy"
+            ),
             pytest.mark.skip_xp_backend(Backend.SPARSE, reason="densification"),
         ],
     ),
@@ -100,6 +103,9 @@ def f(x: Array) -> tuple[Array, Array]:
         True,
         marks=[
             pytest.mark.skip_xp_backend(Backend.CUPY, reason="device->host copy"),
+            pytest.mark.skip_xp_backend(
+                Backend.TORCH_GPU, reason="device->host copy"
+            ),
             pytest.mark.skip_xp_backend(Backend.SPARSE, reason="densification"),
         ],
     ),
@@ -216,7 +222,7 @@ def test_lazy_apply_none_shape_in_args(xp: ModuleType, library: Backend):
     int_type = xp.asarray(0).dtype
 
     ctx: contextlib.AbstractContextManager[object]
-    if library is Backend.JAX:
+    if library.like(Backend.JAX):
         ctx = pytest.raises(ValueError, match="Output shape must be fully known")
     elif library is Backend.ARRAY_API_STRICTEST:
         ctx = pytest.raises(RuntimeError, match="data-dependent shapes")
@@ -254,6 +260,7 @@ def f(x: Array) -> Array:
 
 @pytest.mark.skip_xp_backend(Backend.SPARSE, reason="index by sparse array")
 @pytest.mark.skip_xp_backend(Backend.JAX, reason="boolean indexing")
+@pytest.mark.skip_xp_backend(Backend.JAX_GPU, reason="boolean indexing")
 @pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="boolean indexing")
 def test_lazy_apply_none_shape_broadcast(xp: ModuleType):
     """Broadcast from input array with unknown shape"""
@@ -273,6 +280,9 @@ def test_lazy_apply_none_shape_broadcast(xp: ModuleType):
                 Backend.ARRAY_API_STRICT, reason="device->host copy"
             ),
             pytest.mark.skip_xp_backend(Backend.CUPY, reason="device->host copy"),
+            pytest.mark.skip_xp_backend(
+                Backend.TORCH_GPU, reason="device->host copy"
+            ),
             pytest.mark.skip_xp_backend(Backend.SPARSE, reason="densification"),
         ],
     ),
@@ -321,7 +331,7 @@ def f(x: Array, y: None, z: int | Array) -> Array:
         assert isinstance(x, mtyp)
         assert y is None
         # jax.pure_callback wraps scalar args
-        assert isinstance(z, mtyp if library is Backend.JAX else int)
+        assert isinstance(z, mtyp if library.like(Backend.JAX) else int)
         return x + z
 
     x = xp.asarray([1, 2])
diff --git a/tests/test_testing.py b/tests/test_testing.py
index 10ce7ab3..9976e6fd 100644
--- a/tests/test_testing.py
+++ b/tests/test_testing.py
@@ -218,6 +218,7 @@ def test_lazy_xp_function_static_params(xp: ModuleType, func: Callable[..., Arra
     erf = None
 
 
+@pytest.mark.skip_xp_backend(Backend.TORCH_GPU, reason="device->host copy")
 @pytest.mark.filterwarnings("ignore:__array_wrap__:DeprecationWarning")  # PyTorch
 def test_lazy_xp_function_cython_ufuncs(xp: ModuleType, library: Backend):
     pytest.importorskip("scipy")
@@ -293,12 +294,12 @@ def test_lazy_xp_modules(xp: ModuleType, library: Backend):
     y = naked.f(x)
     xp_assert_equal(y, x)
 
-    if library is Backend.JAX:
+    if library.like(Backend.JAX):
         with pytest.raises(
             TypeError, match="Attempted boolean conversion of traced array"
         ):
             wrapped.f(x)
-    elif library is Backend.DASK:
+    elif library.like(Backend.DASK):
         with pytest.raises(AssertionError, match=r"dask\.compute"):
             wrapped.f(x)
     else: