Skip to content

Commit 96105b3

Browse files
committed
BUG: at should not force overwrite in Dask when copy=None
1 parent: 03f0b3e; commit: 96105b3

File tree

2 files changed

+101
-32
lines changed

2 files changed

+101
-32
lines changed

src/array_api_extra/_lib/_at.py

+13-17
Original file line number | Diff line number | Diff line change
@@ -275,16 +275,11 @@ def _op(
275275
msg = f"copy must be True, False, or None; got {copy!r}"
276276
raise ValueError(msg)
277277

278-
if copy is None:
279-
writeable = is_writeable_array(x)
280-
copy = not writeable
281-
elif copy:
282-
writeable = None
283-
else:
284-
writeable = is_writeable_array(x)
278+
writeable = None if copy else is_writeable_array(x)
285279

286-
# JAX inside jax.jit and Dask don't support in-place updates with boolean
287-
# mask. However we can handle the common special case of 0-dimensional y
280+
# JAX inside jax.jit doesn't support in-place updates with boolean
281+
# masks; Dask exclusively supports __setitem__ but not iops.
282+
# We can handle the common special case of 0-dimensional y
288283
# with where(idx, y, x) instead.
289284
if (
290285
(is_dask_array(idx) or is_jax_array(idx))
@@ -293,21 +288,22 @@ def _op(
293288
):
294289
y_xp = xp.asarray(y, dtype=x.dtype)
295290
if y_xp.ndim == 0:
296-
if out_of_place_op:
291+
if out_of_place_op: # add(), subtract(), ...
297292
# FIXME: suppress inf warnings on dask with lazywhere
298293
out = xp.where(idx, out_of_place_op(x, y_xp), x)
299294
# Undo int->float promotion on JAX after _AtOp.DIVIDE
300295
out = xp.astype(out, x.dtype, copy=False)
301-
else:
296+
else: # set()
302297
out = xp.where(idx, y_xp, x)
303298

304-
if copy:
305-
return out
306-
x[()] = out
307-
return x
299+
if copy is False:
300+
x[()] = out
301+
return x
302+
return out
303+
308304
# else: this will work on eager JAX and crash on jax.jit and Dask
309305

310-
if copy:
306+
if copy or (copy is None and not writeable):
311307
if is_jax_array(x):
312308
# Use JAX's at[]
313309
func = cast(Callable[[Array], Array], getattr(x.at[idx], at_op.value))
@@ -331,7 +327,7 @@ def _op(
331327
msg = f"Can't update read-only array {x}"
332328
raise ValueError(msg)
333329

334-
if in_place_op:
330+
if in_place_op: # add(), subtract(), ...
335331
x[self._idx] = in_place_op(x[self._idx], y)
336332
else: # set()
337333
x[self._idx] = y

tests/test_at.py

+88-15
Original file line number | Diff line number | Diff line change
@@ -72,9 +72,15 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
7272
array_orig = xp.asarray(array, copy=True)
7373
yield
7474

75-
if copy is None:
76-
copy = not is_writeable_array(array)
77-
xp_assert_equal(xp.all(array == array_orig), xp.asarray(copy))
75+
if copy is True:
76+
# Original has not been modified
77+
xp_assert_equal(array, array_orig)
78+
elif copy is False:
79+
# Original has been modified
80+
with pytest.raises(AssertionError):
81+
xp_assert_equal(array, array_orig)
82+
# Test nothing for copy=None. Dask changes behaviour depending on
83+
# whether it's a special case of a bool mask with scalar RHS or not.
7884

7985

8086
@pytest.mark.parametrize(
@@ -89,7 +95,7 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
8995
],
9096
)
9197
@pytest.mark.parametrize(
92-
("op", "y", "expect"),
98+
("op", "y", "expect_list"),
9399
[
94100
(_AtOp.SET, 40.0, [10.0, 40.0, 40.0]),
95101
(_AtOp.ADD, 40.0, [10.0, 60.0, 70.0]),
@@ -102,14 +108,13 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
102108
],
103109
)
104110
@pytest.mark.parametrize(
105-
("bool_mask", "shaped_y"),
111+
("bool_mask", "x_ndim", "y_ndim"),
106112
[
107-
(False, False),
108-
(False, True),
109-
(True, False), # Uses xp.where(idx, y, x) on JAX and Dask
113+
(False, 1, 0),
114+
(False, 1, 1),
115+
(True, 1, 0), # Uses xp.where(idx, y, x) on JAX and Dask
110116
pytest.param(
111-
True,
112-
True,
117+
*(True, 1, 1),
113118
marks=(
114119
pytest.mark.skip_xp_backend( # test passes when copy=False
115120
Backend.JAX, reason="bool mask update with shaped rhs"
@@ -119,6 +124,8 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
119124
),
120125
),
121126
),
127+
(False, 0, 0),
128+
(True, 0, 0),
122129
],
123130
)
124131
def test_update_ops(
@@ -127,13 +134,26 @@ def test_update_ops(
127134
expect_copy: bool | None,
128135
op: _AtOp,
129136
y: float,
130-
expect: list[float],
137+
expect_list: list[float],
131138
bool_mask: bool,
132-
shaped_y: bool,
139+
x_ndim: int,
140+
y_ndim: int,
133141
):
134-
x = xp.asarray([10.0, 20.0, 30.0])
135-
idx = xp.asarray([False, True, True]) if bool_mask else slice(1, None)
136-
if shaped_y:
142+
if x_ndim == 1:
143+
x = xp.asarray([10.0, 20.0, 30.0])
144+
idx = xp.asarray([False, True, True]) if bool_mask else slice(1, None)
145+
expect: list[float] | float = expect_list
146+
else:
147+
idx = xp.asarray(True) if bool_mask else ()
148+
# Pick an element that does change with the operation
149+
if op is _AtOp.MIN:
150+
x = xp.asarray(30.0)
151+
expect = expect_list[2]
152+
else:
153+
x = xp.asarray(20.0)
154+
expect = expect_list[1]
155+
156+
if y_ndim == 1:
137157
y = xp.asarray([y, y])
138158

139159
with assert_copy(x, expect_copy):
@@ -259,3 +279,56 @@ def test_no_inf_warnings(xp: ModuleType, bool_mask: bool):
259279
# inf - inf -> nan with a warning
260280
z = at_op(x, idx, _AtOp.SUBTRACT, math.inf)
261281
xp_assert_equal(z, xp.asarray([math.inf, -math.inf, -math.inf]))
282+
283+
284+
@pytest.mark.parametrize(
285+
"copy",
286+
[
287+
None,
288+
pytest.param(
289+
False,
290+
marks=[
291+
pytest.mark.skip_xp_backend(
292+
Backend.NUMPY, reason="np.generic is read-only"
293+
),
294+
pytest.mark.skip_xp_backend(
295+
Backend.NUMPY_READONLY, reason="read-only backend"
296+
),
297+
pytest.mark.skip_xp_backend(Backend.JAX, reason="read-only backend"),
298+
pytest.mark.skip_xp_backend(Backend.SPARSE, reason="read-only backend"),
299+
pytest.mark.xfail_xp_backend(Backend.DASK, reason="dask/dask#11722"),
300+
],
301+
),
302+
],
303+
)
304+
@pytest.mark.parametrize(
305+
"bool_mask",
306+
[
307+
pytest.param(
308+
False,
309+
marks=pytest.mark.xfail_xp_backend(Backend.DASK, reason="dask/dask#11722"),
310+
),
311+
True,
312+
],
313+
)
314+
def test_gh134(xp: ModuleType, bool_mask: bool, copy: bool | None):
315+
"""
316+
Test that xpx.at doesn't encroach in a bug of dask.array.Array.__setitem__, which
317+
blindly assumes that chunk contents are writeable np.ndarray objects:
318+
319+
https://github.com/dask/dask/issues/11722
320+
321+
In other words: when special-casing bool masks for Dask, unless the user explicitly
322+
asks for copy=False, do not needlessly write back to the input.
323+
"""
324+
x = xp.zeros(1)
325+
326+
# In numpy, we have a writeable np.ndarray in input and a read-only np.generic in
327+
# output. As both are Arrays, this behaviour is Array API compliant.
328+
# In Dask, we have a writeable da.Array on both sides, and if you call __setitem__
329+
# on it all seems fine, but when you compute() your graph is corrupted.
330+
y = x[0]
331+
332+
idx = xp.asarray(True) if bool_mask else ()
333+
z = at_op(y, idx, _AtOp.SET, 1, copy=copy)
334+
xp_assert_equal(z, xp.asarray(1, dtype=x.dtype))

0 commit comments

Comments
 (0)