Skip to content

Commit 202c46b

Browse files
committed
bug: where: check condition is boolean
1 parent 444830f commit 202c46b

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

Diff for: array_api_strict/_searching_functions.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from ._array_object import Array
4-
from ._dtypes import _result_type, _real_numeric_dtypes
4+
from ._dtypes import _result_type, _real_numeric_dtypes, bool as _bool
55
from ._flags import requires_data_dependent_shapes, requires_api_version
66

77
from typing import TYPE_CHECKING
@@ -80,6 +80,9 @@ def where(condition: Array, x1: Array, x2: Array, /) -> Array:
8080
"""
8181
# Call result type here just to raise on disallowed type combinations
8282
_result_type(x1.dtype, x2.dtype)
83+
84+
if condition.dtype != _bool:
85+
raise TypeError("`condition` must be have a boolean data type")
8386

8487
if len({a.device for a in (condition, x1, x2)}) > 1:
8588
raise ValueError("where inputs must all be on the same device")

0 commit comments

Comments
 (0)