Skip to content

Commit 1fd435c

Browse files
committed
make binary_with_scalar tests less flaky
in particular, remainder(float64, float32) is a very bad idea (returns fp garbage on both numpy and pytorch).
1 parent 88b92a0 commit 1fd435c

File tree

2 files changed

+33
-15
lines changed

2 files changed

+33
-15
lines changed

Diff for: array_api_tests/hypothesis_helpers.py

+15-7
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,7 @@ def two_broadcastable_shapes(draw):
449449
)
450450

451451
@composite
452-
def scalars(draw, dtypes, finite=False):
452+
def scalars(draw, dtypes, finite=False, **kwds):
453453
"""
454454
Strategy to generate a scalar that matches a dtype strategy
455455
@@ -463,12 +463,12 @@ def scalars(draw, dtypes, finite=False):
463463
return draw(booleans())
464464
elif dtype == float64:
465465
if finite:
466-
return draw(floats(allow_nan=False, allow_infinity=False))
467-
return draw(floats())
466+
return draw(floats(allow_nan=False, allow_infinity=False, **kwds))
467+
return draw(floats(), **kwds)
468468
elif dtype == float32:
469469
if finite:
470-
return draw(floats(width=32, allow_nan=False, allow_infinity=False))
471-
return draw(floats(width=32))
470+
return draw(floats(width=32, allow_nan=False, allow_infinity=False, **kwds))
471+
return draw(floats(width=32, **kwds))
472472
elif dtype == complex64:
473473
if finite:
474474
return draw(complex_numbers(width=32, allow_nan=False, allow_infinity=False))
@@ -591,8 +591,16 @@ def two_mutual_arrays(
591591
def array_and_py_scalar(draw, dtypes):
592592
"""Draw a pair: (array, scalar) or (scalar, array)."""
593593
dtype = draw(sampled_from(dtypes))
594-
scalar_var = draw(scalars(just(dtype), finite=True))
595-
array_var = draw(arrays(dtype, shape=shapes(min_dims=1)))
594+
595+
scalar_var = draw(scalars(just(dtype), finite=True,
596+
**{'min_value': 1/ (2<<5), 'max_value': 2<<5}
597+
))
598+
599+
elements={}
600+
if dtype in dh.real_float_dtypes:
601+
elements = {'allow_nan': False, 'allow_infinity': False,
602+
'min_value': 1.0 / (2<<5), 'max_value': 2<<5}
603+
array_var = draw(arrays(dtype, shape=shapes(min_dims=1), elements=elements))
596604

597605
if draw(booleans()):
598606
return scalar_var, array_var

Diff for: array_api_tests/test_operators_and_elementwise_functions.py

+18-8
Original file line numberDiff line numberDiff line change
@@ -1816,19 +1816,11 @@ def _filter_zero(x):
18161816
(xp.less_equal, "les_equal", operator.le, {}, xp.bool),
18171817
(xp.greater, "greater", operator.gt, {}, xp.bool),
18181818
(xp.greater_equal, "greater_equal", operator.ge, {}, xp.bool),
1819-
(xp.remainder, "remainder", operator.mod, {}, None),
1820-
(xp.floor_divide, "floor_divide", operator.floordiv, {}, None),
18211819
],
18221820
ids=lambda func_data: func_data[1] # use names for test IDs
18231821
)
18241822
@given(x1x2=hh.array_and_py_scalar(dh.real_float_dtypes))
18251823
def test_binary_with_scalars_real(func_data, x1x2):
1826-
1827-
if func_data[1] == "remainder":
1828-
assume(_filter_zero(x1x2[1]))
1829-
if func_data[1] == "floor_divide":
1830-
assume(_filter_zero(x1x2[0]) and _filter_zero(x1x2[1]))
1831-
18321824
_check_binary_with_scalars(func_data, x1x2)
18331825

18341826

@@ -1847,6 +1839,24 @@ def test_binary_with_scalars_bool(func_data, x1x2):
18471839
_check_binary_with_scalars(func_data, x1x2)
18481840

18491841

1842+
@pytest.mark.min_version("2024.12")
1843+
@pytest.mark.parametrize('func_data',
1844+
# xp_func, name, refimpl, kwargs, expected_dtype
1845+
[
1846+
1847+
(xp.floor_divide, "floor_divide", operator.floordiv, {}, None),
1848+
(xp.remainder, "remainder", operator.mod, {}, None),
1849+
],
1850+
ids=lambda func_data: func_data[1] # use names for test IDs
1851+
)
1852+
@given(x1x2=hh.array_and_py_scalar([xp.int64]))
1853+
def test_binary_with_scalars_int(func_data, x1x2):
1854+
1855+
assume(_filter_zero(x1x2[1]))
1856+
assume(_filter_zero(x1x2[0]) and _filter_zero(x1x2[1]))
1857+
_check_binary_with_scalars(func_data, x1x2)
1858+
1859+
18501860
@pytest.mark.min_version("2024.12")
18511861
@pytest.mark.parametrize('func_data',
18521862
# xp_func, name, refimpl, kwargs, expected_dtype

0 commit comments

Comments
 (0)