Skip to content

Commit 590a2de

Browse files
authored
Add support for scalar arguments to xp.where (#78)
Reviewed at #78
1 parent 1a288de commit 590a2de

File tree

2 files changed

+51
-3
lines changed

2 files changed

+51
-3
lines changed

array_api_strict/_searching_functions.py

+21-3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from ._array_object import Array
44
from ._dtypes import _result_type, _real_numeric_dtypes, bool as _bool
5-
from ._flags import requires_data_dependent_shapes, requires_api_version
5+
from ._flags import requires_data_dependent_shapes, requires_api_version, get_array_api_strict_flags
66

77
from typing import TYPE_CHECKING
88
if TYPE_CHECKING:
@@ -90,20 +90,38 @@ def searchsorted(
9090
# x1 must be 1-D, but NumPy already requires this.
9191
return Array._new(np.searchsorted(x1._array, x2._array, side=side, sorter=sorter), device=x1.device)
9292

93-
def where(condition: Array, x1: Array, x2: Array, /) -> Array:
93+
def where(
94+
condition: Array,
95+
x1: bool | int | float | complex | Array,
96+
x2: bool | int | float | complex | Array, /
97+
) -> Array:
9498
"""
9599
Array API compatible wrapper for :py:func:`np.where <numpy.where>`.
96100
97101
See its docstring for more information.
98102
"""
103+
if get_array_api_strict_flags()['api_version'] > '2023.12':
104+
num_scalars = 0
105+
106+
if isinstance(x1, (bool, float, complex, int)):
107+
x1 = Array._new(np.asarray(x1), device=condition.device)
108+
num_scalars += 1
109+
110+
if isinstance(x2, (bool, float, complex, int)):
111+
x2 = Array._new(np.asarray(x2), device=condition.device)
112+
num_scalars += 1
113+
114+
if num_scalars == 2:
115+
raise ValueError("One of x1, x2 arguments must be an array.")
116+
99117
# Call result type here just to raise on disallowed type combinations
100118
_result_type(x1.dtype, x2.dtype)
101119

102120
if condition.dtype != _bool:
103121
raise TypeError("`condition` must be have a boolean data type")
104122

105123
if len({a.device for a in (condition, x1, x2)}) > 1:
106-
raise ValueError("where inputs must all be on the same device")
124+
raise ValueError("Inputs to `where` must all use the same device")
107125

108126
x1, x2 = Array._normalize_two_args(x1, x2)
109127
return Array._new(np.where(condition._array, x1._array, x2._array), device=x1.device)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import pytest
2+
3+
import array_api_strict as xp
4+
5+
from array_api_strict import ArrayAPIStrictFlags
6+
from array_api_strict._flags import draft_version
7+
8+
9+
def test_where_with_scalars():
10+
x = xp.asarray([1, 2, 3, 1])
11+
12+
# Versions up to and including 2023.12 don't support scalar arguments
13+
with pytest.raises(AttributeError, match="object has no attribute 'dtype'"):
14+
xp.where(x == 1, 42, 44)
15+
16+
# Versions after 2023.12 support scalar arguments
17+
with (pytest.warns(
18+
UserWarning,
19+
match="The 2024.12 version of the array API specification is in draft status"
20+
),
21+
ArrayAPIStrictFlags(api_version=draft_version),
22+
):
23+
x_where = xp.where(x == 1, xp.asarray(42), 44)
24+
25+
expected = xp.asarray([42, 44, 44, 42])
26+
assert xp.all(x_where == expected)
27+
28+
# The spec does not allow both x1 and x2 to be scalars
29+
with pytest.raises(ValueError, match="One of"):
30+
xp.where(x == 1, 42, 44)

0 commit comments

Comments
 (0)