Skip to content

Commit 3fb693c

Browse files
authored
Merge pull request #354 from ev-br/lighten_scalar_tests
MAINT: lighten the notation in _with_scalar tests
2 parents eb3d690 + 18c3a12 commit 3fb693c

File tree

1 file changed

+38
-46
lines changed

1 file changed

+38
-46
lines changed

array_api_tests/test_operators_and_elementwise_functions.py

+38-46
Original file line numberDiff line numberDiff line change
@@ -1783,48 +1783,42 @@ def test_trunc(x):
17831783

17841784
def _check_binary_with_scalars(func_data, x1x2):
17851785
x1, x2 = x1x2
1786-
func, name, refimpl, kwds, expected_dtype = func_data
1786+
func_name, refimpl, kwds, expected_dtype = func_data
1787+
func = getattr(xp, func_name)
17871788
out = func(x1, x2)
17881789
in_dtypes, in_shapes, (x1a, x2a) = _convert_scalars_helper(x1, x2)
17891790
_assert_correctness_binary(
1790-
name, refimpl, in_dtypes, in_shapes, (x1a, x2a), out, expected_dtype, **kwds
1791+
func_name, refimpl, in_dtypes, in_shapes, (x1a, x2a), out, expected_dtype, **kwds
17911792
)
17921793

17931794

17941795
def _filter_zero(x):
17951796
return x != 0 if dh.is_scalar(x) else (not xp.any(x == 0))
17961797

1797-
# workarounds for xp.copysign etc only available in 2023.12
1798-
# Without it, test suite fails to import with ARRAY_API_VERSION=2022.12
1799-
_xp_copysign = getattr(xp, "copysign", None)
1800-
_xp_hypot = getattr(xp, "hypot", None)
1801-
_xp_maximum = getattr(xp, "maximum", None)
1802-
_xp_minimum = getattr(xp, "minimum", None)
1803-
18041798

18051799
@pytest.mark.min_version("2024.12")
18061800
@pytest.mark.parametrize('func_data',
1807-
# xp_func, name, refimpl, kwargs, expected_dtype
1801+
# func_name, refimpl, kwargs, expected_dtype
18081802
[
1809-
(xp.add, "add", operator.add, {}, None),
1810-
(xp.atan2, "atan2", math.atan2, {}, None),
1811-
(_xp_copysign, "copysign", math.copysign, {}, None),
1812-
(xp.divide, "divide", operator.truediv, {"filter_": lambda s: s != 0}, None),
1813-
(_xp_hypot, "hypot", math.hypot, {}, None),
1814-
(xp.logaddexp, "logaddexp", logaddexp_refimpl, {}, None),
1815-
(_xp_maximum, "maximum", max, {'strict_check': True}, None),
1816-
(_xp_minimum, "minimum", min, {'strict_check': True}, None),
1817-
(xp.multiply, "mul", operator.mul, {}, None),
1818-
(xp.subtract, "sub", operator.sub, {}, None),
1819-
1820-
(xp.equal, "equal", operator.eq, {}, xp.bool),
1821-
(xp.not_equal, "neq", operator.ne, {}, xp.bool),
1822-
(xp.less, "less", operator.lt, {}, xp.bool),
1823-
(xp.less_equal, "les_equal", operator.le, {}, xp.bool),
1824-
(xp.greater, "greater", operator.gt, {}, xp.bool),
1825-
(xp.greater_equal, "greater_equal", operator.ge, {}, xp.bool),
1803+
("add", operator.add, {}, None),
1804+
("atan2", math.atan2, {}, None),
1805+
("copysign", math.copysign, {}, None),
1806+
("divide", operator.truediv, {"filter_": lambda s: s != 0}, None),
1807+
("hypot", math.hypot, {}, None),
1808+
("logaddexp", logaddexp_refimpl, {}, None),
1809+
("maximum", max, {'strict_check': True}, None),
1810+
("minimum", min, {'strict_check': True}, None),
1811+
("multiply", operator.mul, {}, None),
1812+
("subtract", operator.sub, {}, None),
1813+
1814+
("equal", operator.eq, {}, xp.bool),
1815+
("not_equal", operator.ne, {}, xp.bool),
1816+
("less", operator.lt, {}, xp.bool),
1817+
("less_equal", operator.le, {}, xp.bool),
1818+
("greater", operator.gt, {}, xp.bool),
1819+
("greater_equal", operator.ge, {}, xp.bool),
18261820
],
1827-
ids=lambda func_data: func_data[1] # use names for test IDs
1821+
ids=lambda func_data: func_data[0] # use names for test IDs
18281822
)
18291823
@given(x1x2=hh.array_and_py_scalar(dh.real_float_dtypes))
18301824
def test_binary_with_scalars_real(func_data, x1x2):
@@ -1833,13 +1827,13 @@ def test_binary_with_scalars_real(func_data, x1x2):
18331827

18341828
@pytest.mark.min_version("2024.12")
18351829
@pytest.mark.parametrize('func_data',
1836-
# xp_func, name, refimpl, kwargs, expected_dtype
1830+
# func_name, refimpl, kwargs, expected_dtype
18371831
[
1838-
(xp.logical_and, "logical_and", operator.and_, {"expr_template": "({} or {})={}"}, None),
1839-
(xp.logical_or, "logical_or", operator.or_, {"expr_template": "({} or {})={}"}, None),
1840-
(xp.logical_xor, "logical_xor", operator.xor, {"expr_template": "({} or {})={}"}, None),
1832+
("logical_and", operator.and_, {"expr_template": "({} or {})={}"}, None),
1833+
("logical_or", operator.or_, {"expr_template": "({} or {})={}"}, None),
1834+
("logical_xor", operator.xor, {"expr_template": "({} or {})={}"}, None),
18411835
],
1842-
ids=lambda func_data: func_data[1] # use names for test IDs
1836+
ids=lambda func_data: func_data[0] # use names for test IDs
18431837
)
18441838
@given(x1x2=hh.array_and_py_scalar([xp.bool]))
18451839
def test_binary_with_scalars_bool(func_data, x1x2):
@@ -1848,36 +1842,34 @@ def test_binary_with_scalars_bool(func_data, x1x2):
18481842

18491843
@pytest.mark.min_version("2024.12")
18501844
@pytest.mark.parametrize('func_data',
1851-
# xp_func, name, refimpl, kwargs, expected_dtype
1845+
# func_name, refimpl, kwargs, expected_dtype
18521846
[
1853-
1854-
(xp.floor_divide, "floor_divide", operator.floordiv, {}, None),
1855-
(xp.remainder, "remainder", operator.mod, {}, None),
1847+
("floor_divide", operator.floordiv, {}, None),
1848+
("remainder", operator.mod, {}, None),
18561849
],
1857-
ids=lambda func_data: func_data[1] # use names for test IDs
1850+
ids=lambda func_data: func_data[0] # use names for test IDs
18581851
)
18591852
@given(x1x2=hh.array_and_py_scalar([xp.int64]))
18601853
def test_binary_with_scalars_int(func_data, x1x2):
1861-
18621854
assume(_filter_zero(x1x2[1]))
18631855
assume(_filter_zero(x1x2[0]) and _filter_zero(x1x2[1]))
18641856
_check_binary_with_scalars(func_data, x1x2)
18651857

18661858

18671859
@pytest.mark.min_version("2024.12")
18681860
@pytest.mark.parametrize('func_data',
1869-
# xp_func, name, refimpl, kwargs, expected_dtype
1861+
# func_name, refimpl, kwargs, expected_dtype
18701862
[
1871-
(xp.bitwise_and, "bitwise_and", operator.and_, {}, None),
1872-
(xp.bitwise_or, "bitwise_or", operator.or_, {}, None),
1873-
(xp.bitwise_xor, "bitwise_xor", operator.xor, {}, None),
1863+
("bitwise_and", operator.and_, {}, None),
1864+
("bitwise_or", operator.or_, {}, None),
1865+
("bitwise_xor", operator.xor, {}, None),
18741866
],
1875-
ids=lambda func_data: func_data[1] # use names for test IDs
1867+
ids=lambda func_data: func_data[0] # use names for test IDs
18761868
)
18771869
@given(x1x2=hh.array_and_py_scalar([xp.int32]))
18781870
def test_binary_with_scalars_bitwise(func_data, x1x2):
1879-
xp_func, name, refimpl, kwargs, expected = func_data
1871+
func_name, refimpl, kwargs, expected = func_data
18801872
# repack the refimpl
18811873
refimpl_ = lambda l, r: mock_int_dtype(refimpl(l, r), xp.int32 )
1882-
_check_binary_with_scalars((xp_func, name, refimpl_, kwargs,expected), x1x2)
1874+
_check_binary_with_scalars((func_name, refimpl_, kwargs, expected), x1x2)
18831875

0 commit comments

Comments
 (0)