Skip to content

Commit 709f4fe

Browse files
committed
TST: add a regression test for where
Check that mixing scalars with arrays preserves the dtype
1 parent bde69fd commit 709f4fe

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

array_api_strict/tests/test_searching_functions.py

+7
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,10 @@ def test_where_mixed_dtypes():
3535
c = x > 1.5
3636
res = xp.where(c, False, c)
3737
assert all(res == xp.asarray([False, False]))
38+
39+
40+
def test_where_f32():
41+
# https://github.com/data-apis/array-api-strict/issues/131#issuecomment-2723016294
42+
res = xp.where(xp.asarray([True, False]), 1., xp.asarray([2, 2], dtype=xp.float32))
43+
assert res.dtype == xp.float32
44+

0 commit comments

Comments
 (0)