From bde69fd950ff273ad24d9f2c115aafee66e57b65 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 6 Mar 2025 11:27:08 +0100 Subject: [PATCH 1/2] BUG: fix where(cond, float_array, int) This should be allowed by the spec: python int scalars can combine with float arrays. --- array_api_strict/_helpers.py | 2 +- array_api_strict/_searching_functions.py | 14 ++------------ .../tests/test_searching_functions.py | 15 ++++++++++++++- 3 files changed, 17 insertions(+), 14 deletions(-) diff --git a/array_api_strict/_helpers.py b/array_api_strict/_helpers.py index 2258d29..d3fc9c9 100644 --- a/array_api_strict/_helpers.py +++ b/array_api_strict/_helpers.py @@ -31,7 +31,7 @@ def _maybe_normalize_py_scalars(x1, x2, dtype_category, func_name): x2 = x1._promote_scalar(x2) else: if x1.dtype not in _allowed_dtypes or x2.dtype not in _allowed_dtypes: - raise TypeError(f"Only {dtype_category} dtypes are allowed {func_name}. " + raise TypeError(f"Only {dtype_category} dtypes are allowed in {func_name}(...). " f"Got {x1.dtype} and {x2.dtype}.") return x1, x2 diff --git a/array_api_strict/_searching_functions.py b/array_api_strict/_searching_functions.py index ad32aaa..9864132 100644 --- a/array_api_strict/_searching_functions.py +++ b/array_api_strict/_searching_functions.py @@ -3,6 +3,7 @@ from ._array_object import Array from ._dtypes import _result_type, _real_numeric_dtypes, bool as _bool from ._flags import requires_data_dependent_shapes, requires_api_version, get_array_api_strict_flags +from ._helpers import _maybe_normalize_py_scalars from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -101,18 +102,7 @@ def where( See its docstring for more information. """ if get_array_api_strict_flags()['api_version'] > '2023.12': - num_scalars = 0 - - if isinstance(x1, (bool, float, complex, int)): - x1 = Array._new(np.asarray(x1), device=condition.device) - num_scalars += 1 - - if isinstance(x2, (bool, float, complex, int)): - x2 = Array._new(np.asarray(x2), device=condition.device) - num_scalars += 1 - - if num_scalars == 2: - raise ValueError("One of x1, x2 arguments must be an array.") + x1, x2 = _maybe_normalize_py_scalars(x1, x2, "all", "where") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) diff --git a/array_api_strict/tests/test_searching_functions.py b/array_api_strict/tests/test_searching_functions.py index 0e54d5f..016862c 100644 --- a/array_api_strict/tests/test_searching_functions.py +++ b/array_api_strict/tests/test_searching_functions.py @@ -20,5 +20,18 @@ def test_where_with_scalars(): assert xp.all(x_where == expected) # The spec does not allow both x1 and x2 to be scalars - with pytest.raises(ValueError, match="One of"): + with pytest.raises(TypeError, match="Two scalars"): xp.where(x == 1, 42, 44) + + +def test_where_mixed_dtypes(): + # https://github.com/data-apis/array-api-strict/issues/131 + x = xp.asarray([1., 2.]) + res = xp.where(x > 1.5, x, 0) + assert res.dtype == x.dtype + assert all(res == xp.asarray([0., 2.])) + + # retry with boolean x1, x2 + c = x > 1.5 + res = xp.where(c, False, c) + assert all(res == xp.asarray([False, False])) From 709f4febd201829777c667ee2fe090912e65981b Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 14 Mar 2025 10:11:46 +0100 Subject: [PATCH 2/2] TST: add a regression test for `where` Check that mixing scalars with arrays preserves the dtype --- array_api_strict/tests/test_searching_functions.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/array_api_strict/tests/test_searching_functions.py b/array_api_strict/tests/test_searching_functions.py index 016862c..2a3a79e 100644 --- a/array_api_strict/tests/test_searching_functions.py +++ b/array_api_strict/tests/test_searching_functions.py @@ -35,3 +35,10 @@ def test_where_mixed_dtypes(): c = x > 1.5 res = xp.where(c, False, c) assert all(res == xp.asarray([False, False])) + + +def test_where_f32(): + # https://github.com/data-apis/array-api-strict/issues/131#issuecomment-2723016294 + res = xp.where(xp.asarray([True, False]), 1., xp.asarray([2, 2], dtype=xp.float32)) + assert res.dtype == xp.float32 +