diff --git a/array_api_strict/_data_type_functions.py b/array_api_strict/_data_type_functions.py index 5af46d2..1643043 100644 --- a/array_api_strict/_data_type_functions.py +++ b/array_api_strict/_data_type_functions.py @@ -197,7 +197,7 @@ def isdtype( else: raise TypeError(f"'kind' must be a dtype, str, or tuple of dtypes and strs, not {type(kind).__name__}") -def result_type(*arrays_and_dtypes: Union[Array, Dtype]) -> Dtype: +def result_type(*arrays_and_dtypes: Union[Array, Dtype, int, float, complex, bool]) -> Dtype: """ Array API compatible wrapper for :py:func:`np.result_type `. @@ -208,19 +208,40 @@ def result_type(*arrays_and_dtypes: Union[Array, Dtype]) -> Dtype: # too many extra type promotions like int64 + uint64 -> float64, and does # value-based casting on scalar arrays. A = [] + scalars = [] for a in arrays_and_dtypes: if isinstance(a, Array): a = a.dtype + elif isinstance(a, (bool, int, float, complex)): + scalars.append(a) elif isinstance(a, np.ndarray) or a not in _all_dtypes: raise TypeError("result_type() inputs must be array_api arrays or dtypes") A.append(a) + # remove python scalars + A = [a for a in A if not isinstance(a, (bool, int, float, complex))] + if len(A) == 0: raise ValueError("at least one array or dtype is required") elif len(A) == 1: - return A[0] + result = A[0] else: t = A[0] for t2 in A[1:]: t = _result_type(t, t2) - return t + result = t + + if len(scalars) == 0: + return result + + if get_array_api_strict_flags()['api_version'] <= '2023.12': + raise TypeError("result_type() inputs must be array_api arrays or dtypes") + + # promote python scalars given the result_type for all arrays/dtypes + from ._creation_functions import empty + arr = empty(1, dtype=result) + for s in scalars: + x = arr._promote_scalar(s) + result = _result_type(x.dtype, result) + + return result diff --git a/array_api_strict/tests/test_data_type_functions.py b/array_api_strict/tests/test_data_type_functions.py index 488eab7..863d3d4 100644 --- a/array_api_strict/tests/test_data_type_functions.py +++ b/array_api_strict/tests/test_data_type_functions.py @@ -6,9 +6,9 @@ import numpy as np from .._creation_functions import asarray -from .._data_type_functions import astype, can_cast, isdtype +from .._data_type_functions import astype, can_cast, isdtype, result_type from .._dtypes import ( - bool, int8, int16, uint8, float64, + bool, int8, int16, uint8, float64, int64 ) from .._flags import set_array_api_strict_flags @@ -70,3 +70,22 @@ def astype_device(api_version): else: pytest.raises(TypeError, lambda: astype(a, int8, device=None)) pytest.raises(TypeError, lambda: astype(a, int8, device=a.device)) + + +@pytest.mark.parametrize("api_version", ['2023.12', '2024.12']) +def test_result_type_py_scalars(api_version): + if api_version <= '2023.12': + set_array_api_strict_flags(api_version=api_version) + + with pytest.raises(TypeError): + result_type(int16, 3) + else: + with pytest.warns(UserWarning): + set_array_api_strict_flags(api_version=api_version) + + assert result_type(int8, 3) == int8 + assert result_type(uint8, 3) == uint8 + assert result_type(float64, 3) == float64 + + with pytest.raises(TypeError): + result_type(int64, True) diff --git a/array_api_strict/tests/test_manipulation_functions.py b/array_api_strict/tests/test_manipulation_functions.py index 9969651..bd247ee 100644 --- a/array_api_strict/tests/test_manipulation_functions.py +++ b/array_api_strict/tests/test_manipulation_functions.py @@ -11,7 +11,7 @@ def test_concat_errors(): - assert_raises(TypeError, lambda: concat((1, 1), axis=None)) + assert_raises((TypeError, ValueError), lambda: concat((1, 1), axis=None)) assert_raises(TypeError, lambda: concat([asarray([1], dtype=int8), asarray([1], dtype=float64)])) diff --git a/array_api_strict/tests/test_validation.py b/array_api_strict/tests/test_validation.py index 035b6f4..bd76ec6 100644 --- a/array_api_strict/tests/test_validation.py +++ b/array_api_strict/tests/test_validation.py @@ -18,7 +18,7 @@ def p(func: Callable, *args, **kwargs): [ p(xp.can_cast, 42, xp.int8), p(xp.can_cast, xp.int8, 42), - p(xp.result_type, 42), + p(xp.result_type, "42"), ], ) def test_raises_on_invalid_types(func, args, kwargs):