Skip to content

Commit 43dc1f3

Browse files
committed
Tweaks
1 parent 718b67e commit 43dc1f3

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

Diff for: src/array_api_extra/_lib/_funcs.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,8 @@ def apply_where( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,PR02
9090
Argument(s) to `f1` (and `f2`). Must be broadcastable with `cond`.
9191
fill_value : Array or scalar, optional
9292
If provided, value with which to fill output array where `cond` is False.
93-
It does not need to be scalar.
93+
It does not need to be scalar; it needs however to be broadcastable with
94+
`cond` and `args`.
9495
Mutually exclusive with `f2`. You must provide one or the other.
9596
xp : array_namespace, optional
9697
The standard-compatible namespace for `cond` and `args`. Default: infer.
@@ -147,7 +148,8 @@ def apply_where( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,PR02
147148
cond, *args = xp.broadcast_arrays(cond, *args)
148149

149150
if is_dask_namespace(xp):
150-
meta_xp = meta_namespace(cond, fill_value, *args, xp=xp)
151+
meta_xp = meta_namespace(cond, *args, fill_value, xp=xp)
152+
# map_blocks doesn't descend into tuples of Arrays
151153
return xp.map_blocks(_apply_where, cond, f1, f2, fill_value, *args, xp=meta_xp)
152154
return _apply_where(cond, f1, f2, fill_value, *args, xp=xp)
153155

@@ -166,21 +168,20 @@ def _apply_where( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01
166168
# jax.jit does not support assignment by boolean mask
167169
return xp.where(cond, f1(*args), f2(*args) if f2 is not None else fill_value)
168170

169-
device = _compat.device(cond)
170171
temp1 = f1(*(arr[cond] for arr in args))
171172

172173
if f2 is None:
173174
# TODO remove asarrays once all backends support Array API 2024.12
174175
dtype = xp.result_type(*asarrays(temp1, fill_value, xp=xp))
175176
if getattr(fill_value, "ndim", 0):
176-
fill_value = xp.astype(fill_value, dtype)
177-
return at(fill_value, cond).set(temp1, copy=True)
178-
out = xp.full(cond.shape, fill_value=fill_value, dtype=dtype, device=device)
177+
out = xp.astype(fill_value, dtype, copy=True)
178+
else:
179+
out = xp.full_like(cond, dtype=dtype, fill_value=fill_value)
179180
else:
180181
ncond = ~cond
181182
temp2 = f2(*(arr[ncond] for arr in args))
182183
dtype = xp.result_type(temp1, temp2)
183-
out = xp.empty(cond.shape, dtype=dtype, device=device)
184+
out = xp.empty_like(cond, dtype=dtype)
184185
out = at(out, ncond).set(temp2)
185186

186187
return at(out, cond).set(temp1)

0 commit comments

Comments
 (0)