diff --git a/src/array_api_extra/_lib/_at.py b/src/array_api_extra/_lib/_at.py index 1915da5..a008efa 100644 --- a/src/array_api_extra/_lib/_at.py +++ b/src/array_api_extra/_lib/_at.py @@ -275,16 +275,11 @@ def _op( msg = f"copy must be True, False, or None; got {copy!r}" raise ValueError(msg) - if copy is None: - writeable = is_writeable_array(x) - copy = not writeable - elif copy: - writeable = None - else: - writeable = is_writeable_array(x) + writeable = None if copy else is_writeable_array(x) - # JAX inside jax.jit and Dask don't support in-place updates with boolean - # mask. However we can handle the common special case of 0-dimensional y + # JAX inside jax.jit doesn't support in-place updates with boolean + # masks; Dask exclusively supports __setitem__ but not iops. + # We can handle the common special case of 0-dimensional y # with where(idx, y, x) instead. if ( (is_dask_array(idx) or is_jax_array(idx)) @@ -293,21 +288,22 @@ def _op( ): y_xp = xp.asarray(y, dtype=x.dtype) if y_xp.ndim == 0: - if out_of_place_op: + if out_of_place_op: # add(), subtract(), ... # FIXME: suppress inf warnings on dask with lazywhere out = xp.where(idx, out_of_place_op(x, y_xp), x) # Undo int->float promotion on JAX after _AtOp.DIVIDE out = xp.astype(out, x.dtype, copy=False) - else: + else: # set() out = xp.where(idx, y_xp, x) - if copy: - return out - x[()] = out - return x + if copy is False: + x[()] = out + return x + return out + # else: this will work on eager JAX and crash on jax.jit and Dask - if copy: + if copy or (copy is None and not writeable): if is_jax_array(x): # Use JAX's at[] func = cast(Callable[[Array], Array], getattr(x.at[idx], at_op.value)) @@ -331,7 +327,7 @@ def _op( msg = f"Can't update read-only array {x}" raise ValueError(msg) - if in_place_op: + if in_place_op: # add(), subtract(), ... x[self._idx] = in_place_op(x[self._idx], y) else: # set() x[self._idx] = y diff --git a/tests/test_at.py b/tests/test_at.py index 447b099..ce27fbf 100644 --- a/tests/test_at.py +++ b/tests/test_at.py @@ -3,7 +3,7 @@ from collections.abc import Callable, Generator from contextlib import contextmanager from types import ModuleType -from typing import Any, cast +from typing import cast import numpy as np import pytest @@ -23,12 +23,13 @@ ] -def at_op( # type: ignore[no-any-explicit] +def at_op( x: Array, idx: Index, op: _AtOp, y: Array | object, - **kwargs: Any, # Test the default copy=None + copy: bool | None = None, + xp: ModuleType | None = None, ) -> Array: """ Wrapper around at(x, idx).op(y, copy=copy, xp=xp). @@ -39,30 +40,33 @@ def at_op( # type: ignore[no-any-explicit] which is not a common use case. """ if isinstance(idx, (slice | tuple)): - return _at_op(x, None, pickle.dumps(idx), op, y, **kwargs) - return _at_op(x, idx, None, op, y, **kwargs) + return _at_op(x, None, pickle.dumps(idx), op, y, copy=copy, xp=xp) + return _at_op(x, idx, None, op, y, copy=copy, xp=xp) -def _at_op( # type: ignore[no-any-explicit] +def _at_op( x: Array, idx: Index | None, idx_pickle: bytes | None, op: _AtOp, y: Array | object, - **kwargs: Any, + copy: bool | None, + xp: ModuleType | None = None, ) -> Array: """jitted helper of at_op""" if idx_pickle: idx = pickle.loads(idx_pickle) meth = cast(Callable[..., Array], getattr(at(x, idx), op.value)) # type: ignore[no-any-explicit] - return meth(y, **kwargs) + return meth(y, copy=copy, xp=xp) lazy_xp_function(_at_op, static_argnames=("op", "idx_pickle", "copy", "xp")) @contextmanager -def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]: +def assert_copy( + array: Array, copy: bool | None, expect_copy: bool | None = None +) -> Generator[None, None, None]: if copy is False and not is_writeable_array(array): with pytest.raises((TypeError, ValueError)): yield @@ -72,24 +76,23 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]: array_orig = xp.asarray(array, copy=True) yield - if copy is None: - copy = not is_writeable_array(array) - xp_assert_equal(xp.all(array == array_orig), xp.asarray(copy)) + if expect_copy is None: + expect_copy = copy + if expect_copy: + # Original has not been modified + xp_assert_equal(array, array_orig) + elif expect_copy is False: + # Original has been modified + with pytest.raises(AssertionError): + xp_assert_equal(array, array_orig) + # Test nothing for copy=None. Dask changes behaviour depending on + # whether it's a special case of a bool mask with scalar RHS or not. + +@pytest.mark.parametrize("copy", [False, True, None]) @pytest.mark.parametrize( - ("kwargs", "expect_copy"), - [ - pytest.param({"copy": True}, True, id="copy=True"), - pytest.param({"copy": False}, False, id="copy=False"), - # Behavior is backend-specific - pytest.param({"copy": None}, None, id="copy=None"), - # Test that the copy parameter defaults to None - pytest.param({}, None, id="no copy kwarg"), - ], -) -@pytest.mark.parametrize( - ("op", "y", "expect"), + ("op", "y", "expect_list"), [ (_AtOp.SET, 40.0, [10.0, 40.0, 40.0]), (_AtOp.ADD, 40.0, [10.0, 60.0, 70.0]), @@ -102,14 +105,13 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]: ], ) @pytest.mark.parametrize( - ("bool_mask", "shaped_y"), + ("bool_mask", "x_ndim", "y_ndim"), [ - (False, False), - (False, True), - (True, False), # Uses xp.where(idx, y, x) on JAX and Dask + (False, 1, 0), + (False, 1, 1), + (True, 1, 0), # Uses xp.where(idx, y, x) on JAX and Dask pytest.param( - True, - True, + *(True, 1, 1), marks=( pytest.mark.skip_xp_backend( # test passes when copy=False Backend.JAX, reason="bool mask update with shaped rhs" @@ -119,29 +121,65 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]: ), ), ), + (False, 0, 0), + (True, 0, 0), ], ) def test_update_ops( xp: ModuleType, - kwargs: dict[str, bool | None], - expect_copy: bool | None, + copy: bool | None, op: _AtOp, y: float, - expect: list[float], + expect_list: list[float], bool_mask: bool, - shaped_y: bool, + x_ndim: int, + y_ndim: int, ): - x = xp.asarray([10.0, 20.0, 30.0]) - idx = xp.asarray([False, True, True]) if bool_mask else slice(1, None) - if shaped_y: + if x_ndim == 1: + x = xp.asarray([10.0, 20.0, 30.0]) + idx = xp.asarray([False, True, True]) if bool_mask else slice(1, None) + expect: list[float] | float = expect_list + else: + idx = xp.asarray(True) if bool_mask else () + # Pick an element that does change with the operation + if op is _AtOp.MIN: + x = xp.asarray(30.0) + expect = expect_list[2] + else: + x = xp.asarray(20.0) + expect = expect_list[1] + + if y_ndim == 1: y = xp.asarray([y, y]) - with assert_copy(x, expect_copy): - z = at_op(x, idx, op, y, **kwargs) + with assert_copy(x, copy): + z = at_op(x, idx, op, y, copy=copy) assert isinstance(z, type(x)) xp_assert_equal(z, xp.asarray(expect)) +@pytest.mark.parametrize("op", list(_AtOp)) +def test_copy_default(xp: ModuleType, library: Backend, op: _AtOp): + """ + Test that the default copy behaviour is False for writeable arrays + and True for read-only ones. + """ + x = xp.asarray([1.0, 10.0, 20.0]) + expect_copy = not is_writeable_array(x) + meth = cast(Callable[..., Array], getattr(at(x)[:2], op.value)) # type: ignore[no-any-explicit] + with assert_copy(x, None, expect_copy): + _ = meth(2.0) + + x = xp.asarray([1.0, 10.0, 20.0]) + # Dask's default copy value is True for bool masks, + # even if the arrays are writeable. + expect_copy = not is_writeable_array(x) or library is Backend.DASK + idx = xp.asarray([True, True, False]) + meth = cast(Callable[..., Array], getattr(at(x, idx), op.value)) # type: ignore[no-any-explicit] + with assert_copy(x, None, expect_copy): + _ = meth(2.0) + + def test_copy_invalid(): a = np.asarray([1, 2, 3]) with pytest.raises(ValueError, match="copy"): @@ -259,3 +297,46 @@ def test_no_inf_warnings(xp: ModuleType, bool_mask: bool): # inf - inf -> nan with a warning z = at_op(x, idx, _AtOp.SUBTRACT, math.inf) xp_assert_equal(z, xp.asarray([math.inf, -math.inf, -math.inf])) + + +@pytest.mark.parametrize( + "copy", + [ + None, + pytest.param( + False, + marks=[ + pytest.mark.skip_xp_backend( + Backend.NUMPY, reason="np.generic is read-only" + ), + pytest.mark.skip_xp_backend( + Backend.NUMPY_READONLY, reason="read-only backend" + ), + pytest.mark.skip_xp_backend(Backend.JAX, reason="read-only backend"), + pytest.mark.skip_xp_backend(Backend.SPARSE, reason="read-only backend"), + ], + ), + ], +) +@pytest.mark.parametrize("bool_mask", [False, True]) +def test_gh134(xp: ModuleType, bool_mask: bool, copy: bool | None): + """ + Test that xpx.at doesn't encroach in a bug of dask.array.Array.__setitem__, which + blindly assumes that chunk contents are writeable np.ndarray objects: + + https://github.com/dask/dask/issues/11722 + + In other words: when special-casing bool masks for Dask, unless the user explicitly + asks for copy=False, do not needlessly write back to the input. + """ + x = xp.zeros(1) + + # In numpy, we have a writeable np.ndarray in input and a read-only np.generic in + # output. As both are Arrays, this behaviour is Array API compliant. + # In Dask, we have a writeable da.Array on both sides, and if you call __setitem__ + # on it all seems fine, but when you compute() your graph is corrupted. + y = x[0] + + idx = xp.asarray(True) if bool_mask else () + z = at_op(y, idx, _AtOp.SET, 1, copy=copy) + xp_assert_equal(z, xp.asarray(1, dtype=x.dtype))