@@ -72,9 +72,15 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
72
72
array_orig = xp .asarray (array , copy = True )
73
73
yield
74
74
75
- if copy is None :
76
- copy = not is_writeable_array (array )
77
- xp_assert_equal (xp .all (array == array_orig ), xp .asarray (copy ))
75
+ if copy is True :
76
+ # Original has not been modified
77
+ xp_assert_equal (array , array_orig )
78
+ elif copy is False :
79
+ # Original has been modified
80
+ with pytest .raises (AssertionError ):
81
+ xp_assert_equal (array , array_orig )
82
+ # Test nothing for copy=None. Dask changes behaviour depending on
83
+ # whether it's a special case of a bool mask with scalar RHS or not.
78
84
79
85
80
86
@pytest .mark .parametrize (
@@ -89,7 +95,7 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
89
95
],
90
96
)
91
97
@pytest .mark .parametrize (
92
- ("op" , "y" , "expect " ),
98
+ ("op" , "y" , "expect_list " ),
93
99
[
94
100
(_AtOp .SET , 40.0 , [10.0 , 40.0 , 40.0 ]),
95
101
(_AtOp .ADD , 40.0 , [10.0 , 60.0 , 70.0 ]),
@@ -102,14 +108,13 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
102
108
],
103
109
)
104
110
@pytest .mark .parametrize (
105
- ("bool_mask" , "shaped_y " ),
111
+ ("bool_mask" , "x_ndim" , "y_ndim " ),
106
112
[
107
- (False , False ),
108
- (False , True ),
109
- (True , False ), # Uses xp.where(idx, y, x) on JAX and Dask
113
+ (False , 1 , 0 ),
114
+ (False , 1 , 1 ),
115
+ (True , 1 , 0 ), # Uses xp.where(idx, y, x) on JAX and Dask
110
116
pytest .param (
111
- True ,
112
- True ,
117
+ * (True , 1 , 1 ),
113
118
marks = (
114
119
pytest .mark .skip_xp_backend ( # test passes when copy=False
115
120
Backend .JAX , reason = "bool mask update with shaped rhs"
@@ -119,6 +124,8 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
119
124
),
120
125
),
121
126
),
127
+ (False , 0 , 0 ),
128
+ (True , 0 , 0 ),
122
129
],
123
130
)
124
131
def test_update_ops (
@@ -127,13 +134,26 @@ def test_update_ops(
127
134
expect_copy : bool | None ,
128
135
op : _AtOp ,
129
136
y : float ,
130
- expect : list [float ],
137
+ expect_list : list [float ],
131
138
bool_mask : bool ,
132
- shaped_y : bool ,
139
+ x_ndim : int ,
140
+ y_ndim : int ,
133
141
):
134
- x = xp .asarray ([10.0 , 20.0 , 30.0 ])
135
- idx = xp .asarray ([False , True , True ]) if bool_mask else slice (1 , None )
136
- if shaped_y :
142
+ if x_ndim == 1 :
143
+ x = xp .asarray ([10.0 , 20.0 , 30.0 ])
144
+ idx = xp .asarray ([False , True , True ]) if bool_mask else slice (1 , None )
145
+ expect : list [float ] | float = expect_list
146
+ else :
147
+ idx = xp .asarray (True ) if bool_mask else ()
148
+ # Pick an element that does change with the operation
149
+ if op is _AtOp .MIN :
150
+ x = xp .asarray (30.0 )
151
+ expect = expect_list [2 ]
152
+ else :
153
+ x = xp .asarray (20.0 )
154
+ expect = expect_list [1 ]
155
+
156
+ if y_ndim == 1 :
137
157
y = xp .asarray ([y , y ])
138
158
139
159
with assert_copy (x , expect_copy ):
@@ -259,3 +279,56 @@ def test_no_inf_warnings(xp: ModuleType, bool_mask: bool):
259
279
# inf - inf -> nan with a warning
260
280
z = at_op (x , idx , _AtOp .SUBTRACT , math .inf )
261
281
xp_assert_equal (z , xp .asarray ([math .inf , - math .inf , - math .inf ]))
282
+
283
+
284
+ @pytest .mark .parametrize (
285
+ "copy" ,
286
+ [
287
+ None ,
288
+ pytest .param (
289
+ False ,
290
+ marks = [
291
+ pytest .mark .skip_xp_backend (
292
+ Backend .NUMPY , reason = "np.generic is read-only"
293
+ ),
294
+ pytest .mark .skip_xp_backend (
295
+ Backend .NUMPY_READONLY , reason = "read-only backend"
296
+ ),
297
+ pytest .mark .skip_xp_backend (Backend .JAX , reason = "read-only backend" ),
298
+ pytest .mark .skip_xp_backend (Backend .SPARSE , reason = "read-only backend" ),
299
+ pytest .mark .xfail_xp_backend (Backend .DASK , reason = "dask/dask#11722" ),
300
+ ],
301
+ ),
302
+ ],
303
+ )
304
+ @pytest .mark .parametrize (
305
+ "bool_mask" ,
306
+ [
307
+ pytest .param (
308
+ False ,
309
+ marks = pytest .mark .xfail_xp_backend (Backend .DASK , reason = "dask/dask#11722" ),
310
+ ),
311
+ True ,
312
+ ],
313
+ )
314
+ def test_gh134 (xp : ModuleType , bool_mask : bool , copy : bool | None ):
315
+ """
316
+ Test that xpx.at doesn't encroach in a bug of dask.array.Array.__setitem__, which
317
+ blindly assumes that chunk contents are writeable np.ndarray objects:
318
+
319
+ https://github.com/dask/dask/issues/11722
320
+
321
+ In other words: when special-casing bool masks for Dask, unless the user explicitly
322
+ asks for copy=False, do not needlessly write back to the input.
323
+ """
324
+ x = xp .zeros (1 )
325
+
326
+ # In numpy, we have a writeable np.ndarray in input and a read-only np.generic in
327
+ # output. As both are Arrays, this behaviour is Array API compliant.
328
+ # In Dask, we have a writeable da.Array on both sides, and if you call __setitem__
329
+ # on it all seems fine, but when you compute() your graph is corrupted.
330
+ y = x [0 ]
331
+
332
+ idx = xp .asarray (True ) if bool_mask else ()
333
+ z = at_op (y , idx , _AtOp .SET , 1 , copy = copy )
334
+ xp_assert_equal (z , xp .asarray (1 , dtype = x .dtype ))
0 commit comments