Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: test binary functions with python scalars #348

Merged
merged 4 commits into from
Mar 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions array_api_tests/dtype_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def get_scalar_type(dtype: DataType) -> ScalarType:
def is_scalar(x):
return isinstance(x, (int, float, complex, bool))


def _make_dtype_mapping_from_names(mapping: Dict[str, Any]) -> EqualityMapping:
dtype_value_pairs = []
for name, value in mapping.items():
Expand Down
26 changes: 24 additions & 2 deletions array_api_tests/hypothesis_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from hypothesis import assume, reject
from hypothesis.strategies import (SearchStrategy, booleans, composite, floats,
integers, just, lists, none, one_of,
integers, complex_numbers, just, lists, none, one_of,
sampled_from, shared, builds, nothing)

from . import _array_module as xp, api_version
Expand All @@ -19,7 +19,7 @@
from . import xps
from ._array_module import _UndefinedStub
from ._array_module import bool as bool_dtype
from ._array_module import broadcast_to, eye, float32, float64, full
from ._array_module import broadcast_to, eye, float32, float64, full, complex64, complex128
from .stubs import category_to_funcs
from .pytest_helpers import nargs
from .typing import Array, DataType, Scalar, Shape
Expand Down Expand Up @@ -462,6 +462,14 @@ def scalars(draw, dtypes, finite=False):
if finite:
return draw(floats(width=32, allow_nan=False, allow_infinity=False))
return draw(floats(width=32))
elif dtype == complex64:
if finite:
return draw(complex_numbers(width=32, allow_nan=False, allow_infinity=False))
return draw(complex_numbers(width=32))
elif dtype == complex128:
if finite:
return draw(complex_numbers(allow_nan=False, allow_infinity=False))
return draw(complex_numbers())
else:
raise ValueError(f"Unrecognized dtype {dtype}")

Expand Down Expand Up @@ -571,6 +579,20 @@ def two_mutual_arrays(
)
return arrays1, arrays2


@composite
def array_and_py_scalar(draw, dtypes):
"""Draw a pair: (array, scalar) or (scalar, array)."""
dtype = draw(sampled_from(dtypes))
scalar_var = draw(scalars(just(dtype), finite=True))
array_var = draw(arrays(dtype, shape=shapes(min_dims=1)))

if draw(booleans()):
return scalar_var, array_var
else:
return array_var, scalar_var


@composite
def kwargs(draw, **kw):
"""
Expand Down
213 changes: 179 additions & 34 deletions array_api_tests/test_operators_and_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,40 @@ def binary_param_assert_against_refimpl(
)


def _convert_scalars_helper(x1, x2):
"""Convert python scalar to arrays, record the shapes/dtypes of arrays.

For inputs being scalars or arrays, return the dtypes and shapes of array arguments,
and all arguments converted to arrays.

dtypes are separate to help distinguishing between
`py_scalar + f32_array -> f32_array` and `f64_array + f32_array -> f64_array`
"""
if dh.is_scalar(x1):
in_dtypes = [x2.dtype]
in_shapes = [x2.shape]
x1a, x2a = xp.asarray(x1), x2
elif dh.is_scalar(x2):
in_dtypes = [x1.dtype]
in_shapes = [x1.shape]
x1a, x2a = x1, xp.asarray(x2)
else:
in_dtypes = [x1.dtype, x2.dtype]
in_shapes = [x1.shape, x2.shape]
x1a, x2a = x1, x2

return in_dtypes, in_shapes, (x1a, x2a)


def _assert_correctness_binary(
name, func, in_dtypes, in_shapes, in_arrs, out, expected_dtype=None, **kwargs
):
x1a, x2a = in_arrs
ph.assert_dtype(name, in_dtype=in_dtypes, out_dtype=out.dtype, expected=expected_dtype)
ph.assert_result_shape(name, in_shapes=in_shapes, out_shape=out.shape)
binary_assert_against_refimpl(name, x1a, x2a, out, func, **kwargs)


@pytest.mark.parametrize("ctx", make_unary_params("abs", dh.numeric_dtypes))
@given(data=st.data())
def test_abs(ctx, data):
Expand Down Expand Up @@ -789,10 +823,14 @@ def test_atan(x):
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
def test_atan2(x1, x2):
out = xp.atan2(x1, x2)
ph.assert_dtype("atan2", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
ph.assert_result_shape("atan2", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
refimpl = cmath.atan2 if x1.dtype in dh.complex_dtypes else math.atan2
binary_assert_against_refimpl("atan2", x1, x2, out, refimpl)
_assert_correctness_binary(
"atan",
cmath.atan2 if x1.dtype in dh.complex_dtypes else math.atan2,
in_dtypes=[x1.dtype, x2.dtype],
in_shapes=[x1.shape, x2.shape],
in_arrs=[x1, x2],
out=out,
)


@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
Expand Down Expand Up @@ -1258,10 +1296,14 @@ def test_greater_equal(ctx, data):
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
def test_hypot(x1, x2):
out = xp.hypot(x1, x2)
ph.assert_dtype("hypot", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
ph.assert_result_shape("hypot", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
binary_assert_against_refimpl("hypot", x1, x2, out, math.hypot)

_assert_correctness_binary(
"hypot",
math.hypot,
in_dtypes=[x1.dtype, x2.dtype],
in_shapes=[x1.shape, x2.shape],
in_arrs=[x1, x2],
out=out
)


@pytest.mark.min_version("2022.12")
Expand Down Expand Up @@ -1411,21 +1453,17 @@ def logaddexp_refimpl(l: float, r: float) -> float:
raise OverflowError


@pytest.mark.min_version("2023.12")
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
def test_logaddexp(x1, x2):
out = xp.logaddexp(x1, x2)
ph.assert_dtype("logaddexp", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
ph.assert_result_shape("logaddexp", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
binary_assert_against_refimpl("logaddexp", x1, x2, out, logaddexp_refimpl)


@given(*hh.two_mutual_arrays([xp.bool]))
def test_logical_and(x1, x2):
out = xp.logical_and(x1, x2)
ph.assert_dtype("logical_and", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
ph.assert_result_shape("logical_and", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
binary_assert_against_refimpl(
"logical_and", x1, x2, out, operator.and_, expr_template="({} and {})={}"
_assert_correctness_binary(
"logaddexp",
logaddexp_refimpl,
in_dtypes=[x1.dtype, x2.dtype],
in_shapes=[x1.shape, x2.shape],
in_arrs=[x1, x2],
out=out
)


Expand All @@ -1439,42 +1477,64 @@ def test_logical_not(x):
)


@given(*hh.two_mutual_arrays([xp.bool]))
def test_logical_and(x1, x2):
out = xp.logical_and(x1, x2)
_assert_correctness_binary(
"logical_and",
operator.and_,
in_dtypes=[x1.dtype, x2.dtype],
in_shapes=[x1.shape, x2.shape],
in_arrs=[x1, x2],
out=out,
expr_template="({} and {})={}"
)


@given(*hh.two_mutual_arrays([xp.bool]))
def test_logical_or(x1, x2):
out = xp.logical_or(x1, x2)
ph.assert_dtype("logical_or", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
ph.assert_result_shape("logical_or", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
binary_assert_against_refimpl(
"logical_or", x1, x2, out, operator.or_, expr_template="({} or {})={}"
_assert_correctness_binary(
"logical_or",
operator.or_,
in_dtypes=[x1.dtype, x2.dtype],
in_shapes=[x1.shape, x2.shape],
in_arrs=[x1, x2],
out=out,
expr_template="({} or {})={}"
)


@given(*hh.two_mutual_arrays([xp.bool]))
def test_logical_xor(x1, x2):
out = xp.logical_xor(x1, x2)
ph.assert_dtype("logical_xor", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
ph.assert_result_shape("logical_xor", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
binary_assert_against_refimpl(
"logical_xor", x1, x2, out, operator.xor, expr_template="({} ^ {})={}"
_assert_correctness_binary(
"logical_xor",
operator.xor,
in_dtypes=[x1.dtype, x2.dtype],
in_shapes=[x1.shape, x2.shape],
in_arrs=[x1, x2],
out=out,
expr_template="({} ^ {})={}"
)


@pytest.mark.min_version("2023.12")
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
def test_maximum(x1, x2):
out = xp.maximum(x1, x2)
ph.assert_dtype("maximum", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
ph.assert_result_shape("maximum", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
binary_assert_against_refimpl("maximum", x1, x2, out, max, strict_check=True)
_assert_correctness_binary(
"maximum", max, [x1.dtype, x2.dtype], [x1.shape, x2.shape], (x1, x2), out, strict_check=True
)


@pytest.mark.min_version("2023.12")
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
def test_minimum(x1, x2):
out = xp.minimum(x1, x2)
ph.assert_dtype("minimum", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
ph.assert_result_shape("minimum", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
binary_assert_against_refimpl("minimum", x1, x2, out, min, strict_check=True)
_assert_correctness_binary(
"minimum", min, [x1.dtype, x2.dtype], [x1.shape, x2.shape], (x1, x2), out, strict_check=True
)


@pytest.mark.parametrize("ctx", make_binary_params("multiply", dh.numeric_dtypes))
Expand Down Expand Up @@ -1719,3 +1779,88 @@ def test_trunc(x):
ph.assert_dtype("trunc", in_dtype=x.dtype, out_dtype=out.dtype)
ph.assert_shape("trunc", out_shape=out.shape, expected=x.shape)
unary_assert_against_refimpl("trunc", x, out, math.trunc, strict_check=True)


def _check_binary_with_scalars(func_data, x1x2):
x1, x2 = x1x2
func, name, refimpl, kwds, expected_dtype = func_data
out = func(x1, x2)
in_dtypes, in_shapes, (x1a, x2a) = _convert_scalars_helper(x1, x2)
_assert_correctness_binary(
name, refimpl, in_dtypes, in_shapes, (x1a, x2a), out, expected_dtype, **kwds
)


def _filter_zero(x):
return x != 0 if dh.is_scalar(x) else (not xp.any(x == 0))


@pytest.mark.min_version("2024.12")
@pytest.mark.parametrize('func_data',
# xp_func, name, refimpl, kwargs, expected_dtype
[
(xp.add, "add", operator.add, {}, None),
(xp.atan2, "atan2", math.atan2, {}, None),
(xp.copysign, "copysign", math.copysign, {}, None),
(xp.divide, "divide", operator.truediv, {"filter_": lambda s: s != 0}, None),
(xp.hypot, "hypot", math.hypot, {}, None),
(xp.logaddexp, "logaddexp", logaddexp_refimpl, {}, None),
(xp.maximum, "maximum", max, {'strict_check': True}, None),
(xp.minimum, "minimum", min, {'strict_check': True}, None),
(xp.multiply, "mul", operator.mul, {}, None),
(xp.subtract, "sub", operator.sub, {}, None),

(xp.equal, "equal", operator.eq, {}, xp.bool),
(xp.not_equal, "neq", operator.ne, {}, xp.bool),
(xp.less, "less", operator.lt, {}, xp.bool),
(xp.less_equal, "les_equal", operator.le, {}, xp.bool),
(xp.greater, "greater", operator.gt, {}, xp.bool),
(xp.greater_equal, "greater_equal", operator.ge, {}, xp.bool),
(xp.remainder, "remainder", operator.mod, {}, None),
(xp.floor_divide, "floor_divide", operator.floordiv, {}, None),
],
ids=lambda func_data: func_data[1] # use names for test IDs
)
@given(x1x2=hh.array_and_py_scalar(dh.real_float_dtypes))
def test_binary_with_scalars_real(func_data, x1x2):

if func_data[1] == "remainder":
assume(_filter_zero(x1x2[1]))
if func_data[1] == "floor_divide":
assume(_filter_zero(x1x2[0]) and _filter_zero(x1x2[1]))

_check_binary_with_scalars(func_data, x1x2)


@pytest.mark.min_version("2024.12")
@pytest.mark.parametrize('func_data',
# xp_func, name, refimpl, kwargs, expected_dtype
[
(xp.logical_and, "logical_and", operator.and_, {"expr_template": "({} or {})={}"}, None),
(xp.logical_or, "logical_or", operator.or_, {"expr_template": "({} or {})={}"}, None),
(xp.logical_xor, "logical_xor", operator.xor, {"expr_template": "({} or {})={}"}, None),
],
ids=lambda func_data: func_data[1] # use names for test IDs
)
@given(x1x2=hh.array_and_py_scalar([xp.bool]))
def test_binary_with_scalars_bool(func_data, x1x2):
_check_binary_with_scalars(func_data, x1x2)


@pytest.mark.min_version("2024.12")
@pytest.mark.parametrize('func_data',
# xp_func, name, refimpl, kwargs, expected_dtype
[
(xp.bitwise_and, "bitwise_and", operator.and_, {}, None),
(xp.bitwise_or, "bitwise_or", operator.or_, {}, None),
(xp.bitwise_xor, "bitwise_xor", operator.xor, {}, None),
],
ids=lambda func_data: func_data[1] # use names for test IDs
)
@given(x1x2=hh.array_and_py_scalar([xp.int32]))
def test_binary_with_scalars_bitwise(func_data, x1x2):
xp_func, name, refimpl, kwargs, expected = func_data
# repack the refimpl
refimpl_ = lambda l, r: mock_int_dtype(refimpl(l, r), xp.int32 )
_check_binary_with_scalars((xp_func, name, refimpl_, kwargs,expected), x1x2)