Skip to content

Commit 5fc5379

Browse files
committed
fix iwhere
1 parent 1d7bd9f commit 5fc5379

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

array_api_compat/common/_helpers.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -962,7 +962,12 @@ def iwhere(condition, x, y, /):
962962
x in place, if it's possible and beneficial for performance.
963963
"""
964964
if is_writeable_array(x):
965-
x[condition] = y
965+
if is_array_api_obj(y) and len(y.shape) > 0:
966+
xp = array_namespace(x)
967+
x, y = xp.broadcast_arrays(x, y)
968+
x[condition] = y[condition]
969+
else:
970+
x[condition] = y
966971
return x
967972
else:
968973
xp = array_namespace(x)

0 commit comments

Comments
 (0)