Skip to content

Commit

Permalink
Add support for scalar arguments to xp.where (data-apis#78)
Browse files Browse the repository at this point in the history
Reviewed at data-apis#78
  • Loading branch information
betatim authored Feb 3, 2025
1 parent 1a288de commit 590a2de
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 3 deletions.
24 changes: 21 additions & 3 deletions array_api_strict/_searching_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from ._array_object import Array
from ._dtypes import _result_type, _real_numeric_dtypes, bool as _bool
from ._flags import requires_data_dependent_shapes, requires_api_version
from ._flags import requires_data_dependent_shapes, requires_api_version, get_array_api_strict_flags

from typing import TYPE_CHECKING
if TYPE_CHECKING:
Expand Down Expand Up @@ -90,20 +90,38 @@ def searchsorted(
# x1 must be 1-D, but NumPy already requires this.
return Array._new(np.searchsorted(x1._array, x2._array, side=side, sorter=sorter), device=x1.device)

def where(condition: Array, x1: Array, x2: Array, /) -> Array:
def where(
condition: Array,
x1: bool | int | float | complex | Array,
x2: bool | int | float | complex | Array, /
) -> Array:
"""
Array API compatible wrapper for :py:func:`np.where <numpy.where>`.
See its docstring for more information.
"""
if get_array_api_strict_flags()['api_version'] > '2023.12':
num_scalars = 0

if isinstance(x1, (bool, float, complex, int)):
x1 = Array._new(np.asarray(x1), device=condition.device)
num_scalars += 1

if isinstance(x2, (bool, float, complex, int)):
x2 = Array._new(np.asarray(x2), device=condition.device)
num_scalars += 1

if num_scalars == 2:
raise ValueError("One of x1, x2 arguments must be an array.")

# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)

if condition.dtype != _bool:
raise TypeError("`condition` must be have a boolean data type")

if len({a.device for a in (condition, x1, x2)}) > 1:
raise ValueError("where inputs must all be on the same device")
raise ValueError("Inputs to `where` must all use the same device")

x1, x2 = Array._normalize_two_args(x1, x2)
return Array._new(np.where(condition._array, x1._array, x2._array), device=x1.device)
30 changes: 30 additions & 0 deletions array_api_strict/tests/test_searching_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import pytest

import array_api_strict as xp

from array_api_strict import ArrayAPIStrictFlags
from array_api_strict._flags import draft_version


def test_where_with_scalars():
x = xp.asarray([1, 2, 3, 1])

# Versions up to and including 2023.12 don't support scalar arguments
with pytest.raises(AttributeError, match="object has no attribute 'dtype'"):
xp.where(x == 1, 42, 44)

# Versions after 2023.12 support scalar arguments
with (pytest.warns(
UserWarning,
match="The 2024.12 version of the array API specification is in draft status"
),
ArrayAPIStrictFlags(api_version=draft_version),
):
x_where = xp.where(x == 1, xp.asarray(42), 44)

expected = xp.asarray([42, 44, 44, 42])
assert xp.all(x_where == expected)

# The spec does not allow both x1 and x2 to be scalars
with pytest.raises(ValueError, match="One of"):
xp.where(x == 1, 42, 44)

0 comments on commit 590a2de

Please sign in to comment.