Skip to content

Commit d12a5e3

Browse files
betatimkgryte
andauthoredJan 9, 2025
feat: add scalar support to where
PR-URL: #860 Ref: #807 Co-authored-by: Athan Reines <kgryte@gmail.com> Reviewed-by: Athan Reines <kgryte@gmail.com> Reviewed-by: Evgeni Burovski Reviewed-by: Lucas Colley <lucas.colley8@gmail.com>
1 parent fd6f507 commit d12a5e3

File tree

2 files changed

+20
-3
lines changed

2 files changed

+20
-3
lines changed
 

‎spec/draft/API_specification/type_promotion.rst

+3
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,9 @@ Notes
120120
.. note::
121121
Mixed integer and floating-point type promotion rules are not specified because behavior varies between implementations.
122122

123+
124+
.. _mixing-scalars-and-arrays:
125+
123126
Mixing arrays with Python scalars
124127
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
125128

‎src/array_api_stubs/_draft/searching_functions.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -168,21 +168,35 @@ def searchsorted(
168168
"""
169169

170170

171-
def where(condition: array, x1: array, x2: array, /) -> array:
171+
def where(
172+
condition: array,
173+
x1: Union[array, int, float, complex, bool],
174+
x2: Union[array, int, float, complex, bool],
175+
/,
176+
) -> array:
172177
"""
173178
Returns elements chosen from ``x1`` or ``x2`` depending on ``condition``.
174179
175180
Parameters
176181
----------
177182
condition: array
178183
when ``True``, yield ``x1_i``; otherwise, yield ``x2_i``. Should have a boolean data type. Must be compatible with ``x1`` and ``x2`` (see :ref:`broadcasting`).
179-
x1: array
184+
x1: Union[array, int, float, complex, bool]
180185
first input array. Must be compatible with ``condition`` and ``x2`` (see :ref:`broadcasting`).
181-
x2: array
186+
x2: Union[array, int, float, complex, bool]
182187
second input array. Must be compatible with ``condition`` and ``x1`` (see :ref:`broadcasting`).
183188
184189
Returns
185190
-------
186191
out: array
187192
an array with elements from ``x1`` where ``condition`` is ``True``, and elements from ``x2`` elsewhere. The returned array must have a data type determined by :ref:`type-promotion` rules with the arrays ``x1`` and ``x2``.
193+
194+
Notes
195+
-----
196+
197+
- At least one of ``x1`` and ``x2`` must be an array.
198+
- If either ``x1`` or ``x2`` is a scalar value, the returned array must have a data type determined according to :ref:`mixing-scalars-and-arrays`.
199+
200+
.. versionchanged:: 2024.12
201+
Added support for scalar arguments.
188202
"""

0 commit comments

Comments
 (0)