Skip to content

Commit b061dc2

Browse files
authored
Merge pull request #132 from ev-br/where_scalars
BUG: fix where(cond, float_array, int)
2 parents ca99593 + 709f4fe commit b061dc2

File tree

3 files changed

+24
-14
lines changed

3 files changed

+24
-14
lines changed

Diff for: array_api_strict/_helpers.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def _maybe_normalize_py_scalars(x1, x2, dtype_category, func_name):
3131
x2 = x1._promote_scalar(x2)
3232
else:
3333
if x1.dtype not in _allowed_dtypes or x2.dtype not in _allowed_dtypes:
34-
raise TypeError(f"Only {dtype_category} dtypes are allowed {func_name}. "
34+
raise TypeError(f"Only {dtype_category} dtypes are allowed in {func_name}(...). "
3535
f"Got {x1.dtype} and {x2.dtype}.")
3636
return x1, x2
3737

Diff for: array_api_strict/_searching_functions.py

+2-12
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from ._array_object import Array
44
from ._dtypes import _result_type, _real_numeric_dtypes, bool as _bool
55
from ._flags import requires_data_dependent_shapes, requires_api_version, get_array_api_strict_flags
6+
from ._helpers import _maybe_normalize_py_scalars
67

78
from typing import TYPE_CHECKING
89
if TYPE_CHECKING:
@@ -101,18 +102,7 @@ def where(
101102
See its docstring for more information.
102103
"""
103104
if get_array_api_strict_flags()['api_version'] > '2023.12':
104-
num_scalars = 0
105-
106-
if isinstance(x1, (bool, float, complex, int)):
107-
x1 = Array._new(np.asarray(x1), device=condition.device)
108-
num_scalars += 1
109-
110-
if isinstance(x2, (bool, float, complex, int)):
111-
x2 = Array._new(np.asarray(x2), device=condition.device)
112-
num_scalars += 1
113-
114-
if num_scalars == 2:
115-
raise ValueError("One of x1, x2 arguments must be an array.")
105+
x1, x2 = _maybe_normalize_py_scalars(x1, x2, "all", "where")
116106

117107
# Call result type here just to raise on disallowed type combinations
118108
_result_type(x1.dtype, x2.dtype)

Diff for: array_api_strict/tests/test_searching_functions.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,25 @@ def test_where_with_scalars():
2020
assert xp.all(x_where == expected)
2121

2222
# The spec does not allow both x1 and x2 to be scalars
23-
with pytest.raises(ValueError, match="One of"):
23+
with pytest.raises(TypeError, match="Two scalars"):
2424
xp.where(x == 1, 42, 44)
25+
26+
27+
def test_where_mixed_dtypes():
28+
# https://github.com/data-apis/array-api-strict/issues/131
29+
x = xp.asarray([1., 2.])
30+
res = xp.where(x > 1.5, x, 0)
31+
assert res.dtype == x.dtype
32+
assert all(res == xp.asarray([0., 2.]))
33+
34+
# retry with boolean x1, x2
35+
c = x > 1.5
36+
res = xp.where(c, False, c)
37+
assert all(res == xp.asarray([False, False]))
38+
39+
40+
def test_where_f32():
41+
# https://github.com/data-apis/array-api-strict/issues/131#issuecomment-2723016294
42+
res = xp.where(xp.asarray([True, False]), 1., xp.asarray([2, 2], dtype=xp.float32))
43+
assert res.dtype == xp.float32
44+

0 commit comments

Comments
 (0)