Skip to content

Commit 3650a6a

Browse files
committed
ENH: binary functions with arrays and python scalars
- logical_{and,or,xor} - atan2, hypot, logaddexp, minimum, maximum
1 parent 0b89c52 commit 3650a6a

File tree

3 files changed

+149
-34
lines changed

3 files changed

+149
-34
lines changed

array_api_tests/dtype_helpers.py

+1
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ def get_scalar_type(dtype: DataType) -> ScalarType:
198198
def is_scalar(x):
199199
return isinstance(x, (int, float, complex, bool))
200200

201+
201202
def _make_dtype_mapping_from_names(mapping: Dict[str, Any]) -> EqualityMapping:
202203
dtype_value_pairs = []
203204
for name, value in mapping.items():

array_api_tests/hypothesis_helpers.py

+14
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,20 @@ def two_mutual_arrays(
571571
)
572572
return arrays1, arrays2
573573

574+
575+
@composite
576+
def array_and_py_scalar(draw, dtypes):
577+
"""Draw a pair: (array, scalar) or (scalar, array)."""
578+
dtype = draw(sampled_from(dtypes))
579+
scalar_var = draw(scalars(just(dtype), finite=True))
580+
array_var = draw(arrays(dtype, shape=shapes(min_dims=1)))
581+
582+
if draw(booleans()):
583+
return scalar_var, array_var
584+
else:
585+
return array_var, scalar_var
586+
587+
574588
@composite
575589
def kwargs(draw, **kw):
576590
"""

array_api_tests/test_operators_and_elementwise_functions.py

+134-34
Original file line numberDiff line numberDiff line change
@@ -690,6 +690,38 @@ def binary_param_assert_against_refimpl(
690690
)
691691

692692

693+
def _convert_scalars_helper(x1, x2):
694+
"""Convert python scalar to arrays, record the shapes/dtypes of arrays.
695+
696+
For inputs being scalars or arrays, return the dtypes and shapes of array arguments,
697+
and all arguments converted to arrays.
698+
699+
dtypes are separate to help distinguishing between
700+
`py_scalar + f32_array -> f32_array` and `f64_array + f32_array -> f64_array`
701+
"""
702+
if dh.is_scalar(x1):
703+
in_dtypes = [x2.dtype]
704+
in_shapes = [x2.shape]
705+
x1a, x2a = xp.asarray(x1), x2
706+
elif dh.is_scalar(x2):
707+
in_dtypes = [x1.dtype]
708+
in_shapes = [x1.shape]
709+
x1a, x2a = x1, xp.asarray(x2)
710+
else:
711+
in_dtypes = [x1.dtype, x2.dtype]
712+
in_shapes = [x1.shape, x2.shape]
713+
x1a, x2a = x1, x2
714+
715+
return in_dtypes, in_shapes, (x1a, x2a)
716+
717+
718+
def _assert_correctness_binary(name, func, in_dtypes, in_shapes, in_arrs, out, **kwargs):
719+
x1a, x2a = in_arrs
720+
ph.assert_dtype(name, in_dtype=in_dtypes, out_dtype=out.dtype)
721+
ph.assert_result_shape(name, in_shapes=in_shapes, out_shape=out.shape)
722+
binary_assert_against_refimpl(name, x1a, x2a, out, func, **kwargs)
723+
724+
693725
@pytest.mark.parametrize("ctx", make_unary_params("abs", dh.numeric_dtypes))
694726
@given(data=st.data())
695727
def test_abs(ctx, data):
@@ -789,10 +821,14 @@ def test_atan(x):
789821
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
790822
def test_atan2(x1, x2):
791823
out = xp.atan2(x1, x2)
792-
ph.assert_dtype("atan2", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
793-
ph.assert_result_shape("atan2", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
794-
refimpl = cmath.atan2 if x1.dtype in dh.complex_dtypes else math.atan2
795-
binary_assert_against_refimpl("atan2", x1, x2, out, refimpl)
824+
_assert_correctness_binary(
825+
"atan",
826+
cmath.atan2 if x1.dtype in dh.complex_dtypes else math.atan2,
827+
in_dtypes=[x1.dtype, x2.dtype],
828+
in_shapes=[x1.shape, x2.shape],
829+
in_arrs=[x1, x2],
830+
out=out,
831+
)
796832

797833

798834
@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
@@ -1258,10 +1294,14 @@ def test_greater_equal(ctx, data):
12581294
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
12591295
def test_hypot(x1, x2):
12601296
out = xp.hypot(x1, x2)
1261-
ph.assert_dtype("hypot", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
1262-
ph.assert_result_shape("hypot", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
1263-
binary_assert_against_refimpl("hypot", x1, x2, out, math.hypot)
1264-
1297+
_assert_correctness_binary(
1298+
"hypot",
1299+
math.hypot,
1300+
in_dtypes=[x1.dtype, x2.dtype],
1301+
in_shapes=[x1.shape, x2.shape],
1302+
in_arrs=[x1, x2],
1303+
out=out
1304+
)
12651305

12661306

12671307
@pytest.mark.min_version("2022.12")
@@ -1411,21 +1451,17 @@ def logaddexp_refimpl(l: float, r: float) -> float:
14111451
raise OverflowError
14121452

14131453

1454+
@pytest.mark.min_version("2023.12")
14141455
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
14151456
def test_logaddexp(x1, x2):
14161457
out = xp.logaddexp(x1, x2)
1417-
ph.assert_dtype("logaddexp", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
1418-
ph.assert_result_shape("logaddexp", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
1419-
binary_assert_against_refimpl("logaddexp", x1, x2, out, logaddexp_refimpl)
1420-
1421-
1422-
@given(*hh.two_mutual_arrays([xp.bool]))
1423-
def test_logical_and(x1, x2):
1424-
out = xp.logical_and(x1, x2)
1425-
ph.assert_dtype("logical_and", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
1426-
ph.assert_result_shape("logical_and", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
1427-
binary_assert_against_refimpl(
1428-
"logical_and", x1, x2, out, operator.and_, expr_template="({} and {})={}"
1458+
_assert_correctness_binary(
1459+
"logaddexp",
1460+
logaddexp_refimpl,
1461+
in_dtypes=[x1.dtype, x2.dtype],
1462+
in_shapes=[x1.shape, x2.shape],
1463+
in_arrs=[x1, x2],
1464+
out=out
14291465
)
14301466

14311467

@@ -1439,42 +1475,64 @@ def test_logical_not(x):
14391475
)
14401476

14411477

1478+
@given(*hh.two_mutual_arrays([xp.bool]))
1479+
def test_logical_and(x1, x2):
1480+
out = xp.logical_and(x1, x2)
1481+
_assert_correctness_binary(
1482+
"logical_and",
1483+
operator.and_,
1484+
in_dtypes=[x1.dtype, x2.dtype],
1485+
in_shapes=[x1.shape, x2.shape],
1486+
in_arrs=[x1, x2],
1487+
out=out,
1488+
expr_template="({} and {})={}"
1489+
)
1490+
1491+
14421492
@given(*hh.two_mutual_arrays([xp.bool]))
14431493
def test_logical_or(x1, x2):
14441494
out = xp.logical_or(x1, x2)
1445-
ph.assert_dtype("logical_or", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
1446-
ph.assert_result_shape("logical_or", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
1447-
binary_assert_against_refimpl(
1448-
"logical_or", x1, x2, out, operator.or_, expr_template="({} or {})={}"
1495+
_assert_correctness_binary(
1496+
"logical_or",
1497+
operator.or_,
1498+
in_dtypes=[x1.dtype, x2.dtype],
1499+
in_shapes=[x1.shape, x2.shape],
1500+
in_arrs=[x1, x2],
1501+
out=out,
1502+
expr_template="({} or {})={}"
14491503
)
14501504

14511505

14521506
@given(*hh.two_mutual_arrays([xp.bool]))
14531507
def test_logical_xor(x1, x2):
14541508
out = xp.logical_xor(x1, x2)
1455-
ph.assert_dtype("logical_xor", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
1456-
ph.assert_result_shape("logical_xor", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
1457-
binary_assert_against_refimpl(
1458-
"logical_xor", x1, x2, out, operator.xor, expr_template="({} ^ {})={}"
1509+
_assert_correctness_binary(
1510+
"logical_xor",
1511+
operator.xor,
1512+
in_dtypes=[x1.dtype, x2.dtype],
1513+
in_shapes=[x1.shape, x2.shape],
1514+
in_arrs=[x1, x2],
1515+
out=out,
1516+
expr_template="({} ^ {})={}"
14591517
)
14601518

14611519

14621520
@pytest.mark.min_version("2023.12")
14631521
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
14641522
def test_maximum(x1, x2):
14651523
out = xp.maximum(x1, x2)
1466-
ph.assert_dtype("maximum", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
1467-
ph.assert_result_shape("maximum", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
1468-
binary_assert_against_refimpl("maximum", x1, x2, out, max, strict_check=True)
1524+
_assert_correctness_binary(
1525+
"maximum", max, [x1.dtype, x2.dtype], [x1.shape, x2.shape], (x1, x2), out, strict_check=True
1526+
)
14691527

14701528

14711529
@pytest.mark.min_version("2023.12")
14721530
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
14731531
def test_minimum(x1, x2):
14741532
out = xp.minimum(x1, x2)
1475-
ph.assert_dtype("minimum", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
1476-
ph.assert_result_shape("minimum", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
1477-
binary_assert_against_refimpl("minimum", x1, x2, out, min, strict_check=True)
1533+
_assert_correctness_binary(
1534+
"minimum", min, [x1.dtype, x2.dtype], [x1.shape, x2.shape], (x1, x2), out, strict_check=True
1535+
)
14781536

14791537

14801538
@pytest.mark.parametrize("ctx", make_binary_params("multiply", dh.numeric_dtypes))
@@ -1719,3 +1777,45 @@ def test_trunc(x):
17191777
ph.assert_dtype("trunc", in_dtype=x.dtype, out_dtype=out.dtype)
17201778
ph.assert_shape("trunc", out_shape=out.shape, expected=x.shape)
17211779
unary_assert_against_refimpl("trunc", x, out, math.trunc, strict_check=True)
1780+
1781+
1782+
def _check_binary_with_scalars(func_data, x1x2):
1783+
x1, x2 = x1x2
1784+
func, name, refimpl, kwds = func_data
1785+
out = func(x1, x2)
1786+
in_dtypes, in_shapes, (x1a, x2a) = _convert_scalars_helper(x1, x2)
1787+
_assert_correctness_binary(
1788+
name, refimpl, in_dtypes, in_shapes, (x1a, x2a), out, **kwds
1789+
)
1790+
1791+
1792+
@pytest.mark.min_version("2024.12")
1793+
@pytest.mark.parametrize('func_data',
1794+
# xp_func, name, refimpl, kwargs
1795+
[
1796+
(xp.atan2, "atan2", math.atan2, {}),
1797+
(xp.hypot, "hypot", math.hypot, {}),
1798+
(xp.logaddexp, "logaddexp", logaddexp_refimpl, {}),
1799+
(xp.maximum, "maximum", max, {'strict_check': True}),
1800+
(xp.minimum, "minimum", min, {'strict_check': True}),
1801+
],
1802+
ids=lambda func_data: func_data[1] # use names for test IDs
1803+
)
1804+
@given(x1x2=hh.array_and_py_scalar(dh.real_float_dtypes))
1805+
def test_binary_with_scalars_real(func_data, x1x2):
1806+
_check_binary_with_scalars(func_data, x1x2)
1807+
1808+
1809+
@pytest.mark.min_version("2024.12")
1810+
@pytest.mark.parametrize('func_data',
1811+
# xp_func, name, refimpl, kwargs
1812+
[
1813+
(xp.logical_and, "logical_and", operator.and_, {"expr_template": "({} or {})={}"}),
1814+
(xp.logical_or, "logical_or", operator.or_, {"expr_template": "({} or {})={}"}),
1815+
(xp.logical_xor, "logical_xor", operator.xor, {"expr_template": "({} or {})={}"}),
1816+
],
1817+
ids=lambda func_data: func_data[1] # use names for test IDs
1818+
)
1819+
@given(x1x2=hh.array_and_py_scalar([xp.bool]))
1820+
def test_binary_with_scalars_bool(func_data, x1x2):
1821+
_check_binary_with_scalars(func_data, x1x2)

0 commit comments

Comments
 (0)