Skip to content

Commit 3f24acc

Browse files
committed
ENH: more binary functions with arrays and python scalars
- equal, not_equal, greater, greater_equal, less, less_equal - add, subtract, multiply, divide
1 parent 3650a6a commit 3f24acc

File tree

2 files changed

+39
-16
lines changed

2 files changed

+39
-16
lines changed

array_api_tests/hypothesis_helpers.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from hypothesis import assume, reject
1111
from hypothesis.strategies import (SearchStrategy, booleans, composite, floats,
12-
integers, just, lists, none, one_of,
12+
integers, complex_numbers, just, lists, none, one_of,
1313
sampled_from, shared, builds, nothing)
1414

1515
from . import _array_module as xp, api_version
@@ -19,7 +19,7 @@
1919
from . import xps
2020
from ._array_module import _UndefinedStub
2121
from ._array_module import bool as bool_dtype
22-
from ._array_module import broadcast_to, eye, float32, float64, full
22+
from ._array_module import broadcast_to, eye, float32, float64, full, complex64, complex128
2323
from .stubs import category_to_funcs
2424
from .pytest_helpers import nargs
2525
from .typing import Array, DataType, Scalar, Shape
@@ -462,6 +462,14 @@ def scalars(draw, dtypes, finite=False):
462462
if finite:
463463
return draw(floats(width=32, allow_nan=False, allow_infinity=False))
464464
return draw(floats(width=32))
465+
elif dtype == complex64:
466+
if finite:
467+
return draw(complex_numbers(width=32, allow_nan=False, allow_infinity=False))
468+
return draw(complex_numbers(width=32))
469+
elif dtype == complex128:
470+
if finite:
471+
return draw(complex_numbers(allow_nan=False, allow_infinity=False))
472+
return draw(complex_numbers())
465473
else:
466474
raise ValueError(f"Unrecognized dtype {dtype}")
467475

array_api_tests/test_operators_and_elementwise_functions.py

+29-14
Original file line numberDiff line numberDiff line change
@@ -715,9 +715,11 @@ def _convert_scalars_helper(x1, x2):
715715
return in_dtypes, in_shapes, (x1a, x2a)
716716

717717

718-
def _assert_correctness_binary(name, func, in_dtypes, in_shapes, in_arrs, out, **kwargs):
718+
def _assert_correctness_binary(
719+
name, func, in_dtypes, in_shapes, in_arrs, out, expected_dtype=None, **kwargs
720+
):
719721
x1a, x2a = in_arrs
720-
ph.assert_dtype(name, in_dtype=in_dtypes, out_dtype=out.dtype)
722+
ph.assert_dtype(name, in_dtype=in_dtypes, out_dtype=out.dtype, expected=expected_dtype)
721723
ph.assert_result_shape(name, in_shapes=in_shapes, out_shape=out.shape)
722724
binary_assert_against_refimpl(name, x1a, x2a, out, func, **kwargs)
723725

@@ -1781,23 +1783,35 @@ def test_trunc(x):
17811783

17821784
def _check_binary_with_scalars(func_data, x1x2):
17831785
x1, x2 = x1x2
1784-
func, name, refimpl, kwds = func_data
1786+
func, name, refimpl, kwds, expected_dtype = func_data
17851787
out = func(x1, x2)
17861788
in_dtypes, in_shapes, (x1a, x2a) = _convert_scalars_helper(x1, x2)
17871789
_assert_correctness_binary(
1788-
name, refimpl, in_dtypes, in_shapes, (x1a, x2a), out, **kwds
1790+
name, refimpl, in_dtypes, in_shapes, (x1a, x2a), out, expected_dtype, **kwds
17891791
)
17901792

17911793

17921794
@pytest.mark.min_version("2024.12")
17931795
@pytest.mark.parametrize('func_data',
1794-
# xp_func, name, refimpl, kwargs
1796+
# xp_func, name, refimpl, kwargs, expected_dtype
17951797
[
1796-
(xp.atan2, "atan2", math.atan2, {}),
1797-
(xp.hypot, "hypot", math.hypot, {}),
1798-
(xp.logaddexp, "logaddexp", logaddexp_refimpl, {}),
1799-
(xp.maximum, "maximum", max, {'strict_check': True}),
1800-
(xp.minimum, "minimum", min, {'strict_check': True}),
1798+
(xp.add, "add", operator.add, {}, None),
1799+
(xp.atan2, "atan2", math.atan2, {}, None),
1800+
(xp.copysign, "copysign", math.copysign, {}, None),
1801+
(xp.divide, "divide", operator.truediv, {"filter_": lambda s: s != 0}, None),
1802+
(xp.hypot, "hypot", math.hypot, {}, None),
1803+
(xp.logaddexp, "logaddexp", logaddexp_refimpl, {}, None),
1804+
(xp.maximum, "maximum", max, {'strict_check': True}, None),
1805+
(xp.minimum, "minimum", min, {'strict_check': True}, None),
1806+
(xp.multiply, "mul", operator.mul, {}, None),
1807+
(xp.subtract, "sub", operator.sub, {}, None),
1808+
1809+
(xp.equal, "equal", operator.eq, {}, xp.bool),
1810+
(xp.not_equal, "neq", operator.ne, {}, xp.bool),
1811+
(xp.less, "less", operator.lt, {}, xp.bool),
1812+
(xp.less_equal, "les_equal", operator.le, {}, xp.bool),
1813+
(xp.greater, "greater", operator.gt, {}, xp.bool),
1814+
(xp.greater_equal, "greater_equal", operator.ge, {}, xp.bool),
18011815
],
18021816
ids=lambda func_data: func_data[1] # use names for test IDs
18031817
)
@@ -1808,14 +1822,15 @@ def test_binary_with_scalars_real(func_data, x1x2):
18081822

18091823
@pytest.mark.min_version("2024.12")
18101824
@pytest.mark.parametrize('func_data',
1811-
# xp_func, name, refimpl, kwargs
1825+
# xp_func, name, refimpl, kwargs, expected_dtype
18121826
[
1813-
(xp.logical_and, "logical_and", operator.and_, {"expr_template": "({} or {})={}"}),
1814-
(xp.logical_or, "logical_or", operator.or_, {"expr_template": "({} or {})={}"}),
1815-
(xp.logical_xor, "logical_xor", operator.xor, {"expr_template": "({} or {})={}"}),
1827+
(xp.logical_and, "logical_and", operator.and_, {"expr_template": "({} or {})={}"}, None),
1828+
(xp.logical_or, "logical_or", operator.or_, {"expr_template": "({} or {})={}"}, None),
1829+
(xp.logical_xor, "logical_xor", operator.xor, {"expr_template": "({} or {})={}"}, None),
18161830
],
18171831
ids=lambda func_data: func_data[1] # use names for test IDs
18181832
)
18191833
@given(x1x2=hh.array_and_py_scalar([xp.bool]))
18201834
def test_binary_with_scalars_bool(func_data, x1x2):
18211835
_check_binary_with_scalars(func_data, x1x2)
1836+

0 commit comments

Comments
 (0)