@@ -90,7 +90,8 @@ def apply_where( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,PR02
90
90
Argument(s) to `f1` (and `f2`). Must be broadcastable with `cond`.
91
91
fill_value : Array or scalar, optional
92
92
If provided, value with which to fill output array where `cond` is False.
93
- It does not need to be scalar.
93
+ It does not need to be scalar; it needs however to be broadcastable with
94
+ `cond` and `args`.
94
95
Mutually exclusive with `f2`. You must provide one or the other.
95
96
xp : array_namespace, optional
96
97
The standard-compatible namespace for `cond` and `args`. Default: infer.
@@ -147,7 +148,8 @@ def apply_where( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,PR02
147
148
cond , * args = xp .broadcast_arrays (cond , * args )
148
149
149
150
if is_dask_namespace (xp ):
150
- meta_xp = meta_namespace (cond , fill_value , * args , xp = xp )
151
+ meta_xp = meta_namespace (cond , * args , fill_value , xp = xp )
152
+ # map_blocks doesn't descend into tuples of Arrays
151
153
return xp .map_blocks (_apply_where , cond , f1 , f2 , fill_value , * args , xp = meta_xp )
152
154
return _apply_where (cond , f1 , f2 , fill_value , * args , xp = xp )
153
155
@@ -166,21 +168,20 @@ def _apply_where( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01
166
168
# jax.jit does not support assignment by boolean mask
167
169
return xp .where (cond , f1 (* args ), f2 (* args ) if f2 is not None else fill_value )
168
170
169
- device = _compat .device (cond )
170
171
temp1 = f1 (* (arr [cond ] for arr in args ))
171
172
172
173
if f2 is None :
173
174
# TODO remove asarrays once all backends support Array API 2024.12
174
175
dtype = xp .result_type (* asarrays (temp1 , fill_value , xp = xp ))
175
176
if getattr (fill_value , "ndim" , 0 ):
176
- fill_value = xp .astype (fill_value , dtype )
177
- return at ( fill_value , cond ). set ( temp1 , copy = True )
178
- out = xp .full (cond . shape , fill_value = fill_value , dtype = dtype , device = device )
177
+ out = xp .astype (fill_value , dtype , copy = True )
178
+ else :
179
+ out = xp .full_like (cond , dtype = dtype , fill_value = fill_value )
179
180
else :
180
181
ncond = ~ cond
181
182
temp2 = f2 (* (arr [ncond ] for arr in args ))
182
183
dtype = xp .result_type (temp1 , temp2 )
183
- out = xp .empty (cond . shape , dtype = dtype , device = device )
184
+ out = xp .empty_like (cond , dtype = dtype )
184
185
out = at (out , ncond ).set (temp2 )
185
186
186
187
return at (out , cond ).set (temp1 )
0 commit comments