diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index d7346aa1..f7fa306b 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -198,6 +198,7 @@ def get_scalar_type(dtype: DataType) -> ScalarType: def is_scalar(x): return isinstance(x, (int, float, complex, bool)) + def _make_dtype_mapping_from_names(mapping: Dict[str, Any]) -> EqualityMapping: dtype_value_pairs = [] for name, value in mapping.items(): diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index d967daa4..3c09a1d5 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -9,7 +9,7 @@ from hypothesis import assume, reject from hypothesis.strategies import (SearchStrategy, booleans, composite, floats, - integers, just, lists, none, one_of, + integers, complex_numbers, just, lists, none, one_of, sampled_from, shared, builds, nothing) from . import _array_module as xp, api_version @@ -19,7 +19,7 @@ from . import xps from ._array_module import _UndefinedStub from ._array_module import bool as bool_dtype -from ._array_module import broadcast_to, eye, float32, float64, full +from ._array_module import broadcast_to, eye, float32, float64, full, complex64, complex128 from .stubs import category_to_funcs from .pytest_helpers import nargs from .typing import Array, DataType, Scalar, Shape @@ -462,6 +462,14 @@ def scalars(draw, dtypes, finite=False): if finite: return draw(floats(width=32, allow_nan=False, allow_infinity=False)) return draw(floats(width=32)) + elif dtype == complex64: + if finite: + return draw(complex_numbers(width=32, allow_nan=False, allow_infinity=False)) + return draw(complex_numbers(width=32)) + elif dtype == complex128: + if finite: + return draw(complex_numbers(allow_nan=False, allow_infinity=False)) + return draw(complex_numbers()) else: raise ValueError(f"Unrecognized dtype {dtype}") @@ -571,6 +579,20 @@ def two_mutual_arrays( ) return arrays1, arrays2 + +@composite +def array_and_py_scalar(draw, dtypes): + """Draw a pair: (array, scalar) or (scalar, array).""" + dtype = draw(sampled_from(dtypes)) + scalar_var = draw(scalars(just(dtype), finite=True)) + array_var = draw(arrays(dtype, shape=shapes(min_dims=1))) + + if draw(booleans()): + return scalar_var, array_var + else: + return array_var, scalar_var + + @composite def kwargs(draw, **kw): """ diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index dbd44223..129f5c21 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -690,6 +690,40 @@ def binary_param_assert_against_refimpl( ) +def _convert_scalars_helper(x1, x2): + """Convert python scalar to arrays, record the shapes/dtypes of arrays. + + For inputs being scalars or arrays, return the dtypes and shapes of array arguments, + and all arguments converted to arrays. + + dtypes are separate to help distinguishing between + `py_scalar + f32_array -> f32_array` and `f64_array + f32_array -> f64_array` + """ + if dh.is_scalar(x1): + in_dtypes = [x2.dtype] + in_shapes = [x2.shape] + x1a, x2a = xp.asarray(x1), x2 + elif dh.is_scalar(x2): + in_dtypes = [x1.dtype] + in_shapes = [x1.shape] + x1a, x2a = x1, xp.asarray(x2) + else: + in_dtypes = [x1.dtype, x2.dtype] + in_shapes = [x1.shape, x2.shape] + x1a, x2a = x1, x2 + + return in_dtypes, in_shapes, (x1a, x2a) + + +def _assert_correctness_binary( + name, func, in_dtypes, in_shapes, in_arrs, out, expected_dtype=None, **kwargs +): + x1a, x2a = in_arrs + ph.assert_dtype(name, in_dtype=in_dtypes, out_dtype=out.dtype, expected=expected_dtype) + ph.assert_result_shape(name, in_shapes=in_shapes, out_shape=out.shape) + binary_assert_against_refimpl(name, x1a, x2a, out, func, **kwargs) + + @pytest.mark.parametrize("ctx", make_unary_params("abs", dh.numeric_dtypes)) @given(data=st.data()) def test_abs(ctx, data): @@ -789,10 +823,14 @@ def test_atan(x): @given(*hh.two_mutual_arrays(dh.real_float_dtypes)) def test_atan2(x1, x2): out = xp.atan2(x1, x2) - ph.assert_dtype("atan2", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype) - ph.assert_result_shape("atan2", in_shapes=[x1.shape, x2.shape], out_shape=out.shape) - refimpl = cmath.atan2 if x1.dtype in dh.complex_dtypes else math.atan2 - binary_assert_against_refimpl("atan2", x1, x2, out, refimpl) + _assert_correctness_binary( + "atan", + cmath.atan2 if x1.dtype in dh.complex_dtypes else math.atan2, + in_dtypes=[x1.dtype, x2.dtype], + in_shapes=[x1.shape, x2.shape], + in_arrs=[x1, x2], + out=out, + ) @given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) @@ -1258,10 +1296,14 @@ def test_greater_equal(ctx, data): @given(*hh.two_mutual_arrays(dh.real_float_dtypes)) def test_hypot(x1, x2): out = xp.hypot(x1, x2) - ph.assert_dtype("hypot", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype) - ph.assert_result_shape("hypot", in_shapes=[x1.shape, x2.shape], out_shape=out.shape) - binary_assert_against_refimpl("hypot", x1, x2, out, math.hypot) - + _assert_correctness_binary( + "hypot", + math.hypot, + in_dtypes=[x1.dtype, x2.dtype], + in_shapes=[x1.shape, x2.shape], + in_arrs=[x1, x2], + out=out + ) @pytest.mark.min_version("2022.12") @@ -1411,21 +1453,17 @@ def logaddexp_refimpl(l: float, r: float) -> float: raise OverflowError +@pytest.mark.min_version("2023.12") @given(*hh.two_mutual_arrays(dh.real_float_dtypes)) def test_logaddexp(x1, x2): out = xp.logaddexp(x1, x2) - ph.assert_dtype("logaddexp", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype) - ph.assert_result_shape("logaddexp", in_shapes=[x1.shape, x2.shape], out_shape=out.shape) - binary_assert_against_refimpl("logaddexp", x1, x2, out, logaddexp_refimpl) - - -@given(*hh.two_mutual_arrays([xp.bool])) -def test_logical_and(x1, x2): - out = xp.logical_and(x1, x2) - ph.assert_dtype("logical_and", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype) - ph.assert_result_shape("logical_and", in_shapes=[x1.shape, x2.shape], out_shape=out.shape) - binary_assert_against_refimpl( - "logical_and", x1, x2, out, operator.and_, expr_template="({} and {})={}" + _assert_correctness_binary( + "logaddexp", + logaddexp_refimpl, + in_dtypes=[x1.dtype, x2.dtype], + in_shapes=[x1.shape, x2.shape], + in_arrs=[x1, x2], + out=out ) @@ -1439,23 +1477,45 @@ def test_logical_not(x): ) +@given(*hh.two_mutual_arrays([xp.bool])) +def test_logical_and(x1, x2): + out = xp.logical_and(x1, x2) + _assert_correctness_binary( + "logical_and", + operator.and_, + in_dtypes=[x1.dtype, x2.dtype], + in_shapes=[x1.shape, x2.shape], + in_arrs=[x1, x2], + out=out, + expr_template="({} and {})={}" + ) + + @given(*hh.two_mutual_arrays([xp.bool])) def test_logical_or(x1, x2): out = xp.logical_or(x1, x2) - ph.assert_dtype("logical_or", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype) - ph.assert_result_shape("logical_or", in_shapes=[x1.shape, x2.shape], out_shape=out.shape) - binary_assert_against_refimpl( - "logical_or", x1, x2, out, operator.or_, expr_template="({} or {})={}" + _assert_correctness_binary( + "logical_or", + operator.or_, + in_dtypes=[x1.dtype, x2.dtype], + in_shapes=[x1.shape, x2.shape], + in_arrs=[x1, x2], + out=out, + expr_template="({} or {})={}" ) @given(*hh.two_mutual_arrays([xp.bool])) def test_logical_xor(x1, x2): out = xp.logical_xor(x1, x2) - ph.assert_dtype("logical_xor", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype) - ph.assert_result_shape("logical_xor", in_shapes=[x1.shape, x2.shape], out_shape=out.shape) - binary_assert_against_refimpl( - "logical_xor", x1, x2, out, operator.xor, expr_template="({} ^ {})={}" + _assert_correctness_binary( + "logical_xor", + operator.xor, + in_dtypes=[x1.dtype, x2.dtype], + in_shapes=[x1.shape, x2.shape], + in_arrs=[x1, x2], + out=out, + expr_template="({} ^ {})={}" ) @@ -1463,18 +1523,18 @@ def test_logical_xor(x1, x2): @given(*hh.two_mutual_arrays(dh.real_float_dtypes)) def test_maximum(x1, x2): out = xp.maximum(x1, x2) - ph.assert_dtype("maximum", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype) - ph.assert_result_shape("maximum", in_shapes=[x1.shape, x2.shape], out_shape=out.shape) - binary_assert_against_refimpl("maximum", x1, x2, out, max, strict_check=True) + _assert_correctness_binary( + "maximum", max, [x1.dtype, x2.dtype], [x1.shape, x2.shape], (x1, x2), out, strict_check=True + ) @pytest.mark.min_version("2023.12") @given(*hh.two_mutual_arrays(dh.real_float_dtypes)) def test_minimum(x1, x2): out = xp.minimum(x1, x2) - ph.assert_dtype("minimum", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype) - ph.assert_result_shape("minimum", in_shapes=[x1.shape, x2.shape], out_shape=out.shape) - binary_assert_against_refimpl("minimum", x1, x2, out, min, strict_check=True) + _assert_correctness_binary( + "minimum", min, [x1.dtype, x2.dtype], [x1.shape, x2.shape], (x1, x2), out, strict_check=True + ) @pytest.mark.parametrize("ctx", make_binary_params("multiply", dh.numeric_dtypes)) @@ -1719,3 +1779,88 @@ def test_trunc(x): ph.assert_dtype("trunc", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape("trunc", out_shape=out.shape, expected=x.shape) unary_assert_against_refimpl("trunc", x, out, math.trunc, strict_check=True) + + +def _check_binary_with_scalars(func_data, x1x2): + x1, x2 = x1x2 + func, name, refimpl, kwds, expected_dtype = func_data + out = func(x1, x2) + in_dtypes, in_shapes, (x1a, x2a) = _convert_scalars_helper(x1, x2) + _assert_correctness_binary( + name, refimpl, in_dtypes, in_shapes, (x1a, x2a), out, expected_dtype, **kwds + ) + + +def _filter_zero(x): + return x != 0 if dh.is_scalar(x) else (not xp.any(x == 0)) + + +@pytest.mark.min_version("2024.12") +@pytest.mark.parametrize('func_data', + # xp_func, name, refimpl, kwargs, expected_dtype + [ + (xp.add, "add", operator.add, {}, None), + (xp.atan2, "atan2", math.atan2, {}, None), + (xp.copysign, "copysign", math.copysign, {}, None), + (xp.divide, "divide", operator.truediv, {"filter_": lambda s: s != 0}, None), + (xp.hypot, "hypot", math.hypot, {}, None), + (xp.logaddexp, "logaddexp", logaddexp_refimpl, {}, None), + (xp.maximum, "maximum", max, {'strict_check': True}, None), + (xp.minimum, "minimum", min, {'strict_check': True}, None), + (xp.multiply, "mul", operator.mul, {}, None), + (xp.subtract, "sub", operator.sub, {}, None), + + (xp.equal, "equal", operator.eq, {}, xp.bool), + (xp.not_equal, "neq", operator.ne, {}, xp.bool), + (xp.less, "less", operator.lt, {}, xp.bool), + (xp.less_equal, "les_equal", operator.le, {}, xp.bool), + (xp.greater, "greater", operator.gt, {}, xp.bool), + (xp.greater_equal, "greater_equal", operator.ge, {}, xp.bool), + (xp.remainder, "remainder", operator.mod, {}, None), + (xp.floor_divide, "floor_divide", operator.floordiv, {}, None), + ], + ids=lambda func_data: func_data[1] # use names for test IDs +) +@given(x1x2=hh.array_and_py_scalar(dh.real_float_dtypes)) +def test_binary_with_scalars_real(func_data, x1x2): + + if func_data[1] == "remainder": + assume(_filter_zero(x1x2[1])) + if func_data[1] == "floor_divide": + assume(_filter_zero(x1x2[0]) and _filter_zero(x1x2[1])) + + _check_binary_with_scalars(func_data, x1x2) + + +@pytest.mark.min_version("2024.12") +@pytest.mark.parametrize('func_data', + # xp_func, name, refimpl, kwargs, expected_dtype + [ + (xp.logical_and, "logical_and", operator.and_, {"expr_template": "({} or {})={}"}, None), + (xp.logical_or, "logical_or", operator.or_, {"expr_template": "({} or {})={}"}, None), + (xp.logical_xor, "logical_xor", operator.xor, {"expr_template": "({} or {})={}"}, None), + ], + ids=lambda func_data: func_data[1] # use names for test IDs +) +@given(x1x2=hh.array_and_py_scalar([xp.bool])) +def test_binary_with_scalars_bool(func_data, x1x2): + _check_binary_with_scalars(func_data, x1x2) + + +@pytest.mark.min_version("2024.12") +@pytest.mark.parametrize('func_data', + # xp_func, name, refimpl, kwargs, expected_dtype + [ + (xp.bitwise_and, "bitwise_and", operator.and_, {}, None), + (xp.bitwise_or, "bitwise_or", operator.or_, {}, None), + (xp.bitwise_xor, "bitwise_xor", operator.xor, {}, None), + ], + ids=lambda func_data: func_data[1] # use names for test IDs +) +@given(x1x2=hh.array_and_py_scalar([xp.int32])) +def test_binary_with_scalars_bitwise(func_data, x1x2): + xp_func, name, refimpl, kwargs, expected = func_data + # repack the refimpl + refimpl_ = lambda l, r: mock_int_dtype(refimpl(l, r), xp.int32 ) + _check_binary_with_scalars((xp_func, name, refimpl_, kwargs,expected), x1x2) +