|
2 | 2 |
|
3 | 3 | from ._array_object import Array
|
4 | 4 | from ._dtypes import _result_type, _real_numeric_dtypes, bool as _bool
|
5 |
| -from ._flags import requires_data_dependent_shapes, requires_api_version |
| 5 | +from ._flags import requires_data_dependent_shapes, requires_api_version, get_array_api_strict_flags |
6 | 6 |
|
7 | 7 | from typing import TYPE_CHECKING
|
8 | 8 | if TYPE_CHECKING:
|
@@ -90,20 +90,38 @@ def searchsorted(
|
90 | 90 | # x1 must be 1-D, but NumPy already requires this.
|
91 | 91 | return Array._new(np.searchsorted(x1._array, x2._array, side=side, sorter=sorter), device=x1.device)
|
92 | 92 |
|
93 |
| -def where(condition: Array, x1: Array, x2: Array, /) -> Array: |
| 93 | +def where( |
| 94 | + condition: Array, |
| 95 | + x1: bool | int | float | complex | Array, |
| 96 | + x2: bool | int | float | complex | Array, / |
| 97 | +) -> Array: |
94 | 98 | """
|
95 | 99 | Array API compatible wrapper for :py:func:`np.where <numpy.where>`.
|
96 | 100 |
|
97 | 101 | See its docstring for more information.
|
98 | 102 | """
|
| 103 | + if get_array_api_strict_flags()['api_version'] > '2023.12': |
| 104 | + num_scalars = 0 |
| 105 | + |
| 106 | + if isinstance(x1, (bool, float, complex, int)): |
| 107 | + x1 = Array._new(np.asarray(x1), device=condition.device) |
| 108 | + num_scalars += 1 |
| 109 | + |
| 110 | + if isinstance(x2, (bool, float, complex, int)): |
| 111 | + x2 = Array._new(np.asarray(x2), device=condition.device) |
| 112 | + num_scalars += 1 |
| 113 | + |
| 114 | + if num_scalars == 2: |
| 115 | + raise ValueError("One of x1, x2 arguments must be an array.") |
| 116 | + |
99 | 117 | # Call result type here just to raise on disallowed type combinations
|
100 | 118 | _result_type(x1.dtype, x2.dtype)
|
101 | 119 |
|
102 | 120 | if condition.dtype != _bool:
|
103 | 121 | raise TypeError("`condition` must be have a boolean data type")
|
104 | 122 |
|
105 | 123 | if len({a.device for a in (condition, x1, x2)}) > 1:
|
106 |
| - raise ValueError("where inputs must all be on the same device") |
| 124 | + raise ValueError("Inputs to `where` must all use the same device") |
107 | 125 |
|
108 | 126 | x1, x2 = Array._normalize_two_args(x1, x2)
|
109 | 127 | return Array._new(np.where(condition._array, x1._array, x2._array), device=x1.device)
|
0 commit comments