Skip to content

Commit eb3d690

Browse files
authored
Merge pull request #353 from ev-br/block_remainder_nan
make binary_with_scalar tests less flaky
2 parents abdb4a0 + 1fd435c commit eb3d690

File tree

2 files changed

+33
-15
lines changed

2 files changed

+33
-15
lines changed

array_api_tests/hypothesis_helpers.py

+15-7
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,7 @@ def two_broadcastable_shapes(draw):
449449
)
450450

451451
@composite
452-
def scalars(draw, dtypes, finite=False):
452+
def scalars(draw, dtypes, finite=False, **kwds):
453453
"""
454454
Strategy to generate a scalar that matches a dtype strategy
455455
@@ -463,12 +463,12 @@ def scalars(draw, dtypes, finite=False):
463463
return draw(booleans())
464464
elif dtype == float64:
465465
if finite:
466-
return draw(floats(allow_nan=False, allow_infinity=False))
467-
return draw(floats())
466+
return draw(floats(allow_nan=False, allow_infinity=False, **kwds))
467+
return draw(floats(), **kwds)
468468
elif dtype == float32:
469469
if finite:
470-
return draw(floats(width=32, allow_nan=False, allow_infinity=False))
471-
return draw(floats(width=32))
470+
return draw(floats(width=32, allow_nan=False, allow_infinity=False, **kwds))
471+
return draw(floats(width=32, **kwds))
472472
elif dtype == complex64:
473473
if finite:
474474
return draw(complex_numbers(width=32, allow_nan=False, allow_infinity=False))
@@ -591,8 +591,16 @@ def two_mutual_arrays(
591591
def array_and_py_scalar(draw, dtypes):
592592
"""Draw a pair: (array, scalar) or (scalar, array)."""
593593
dtype = draw(sampled_from(dtypes))
594-
scalar_var = draw(scalars(just(dtype), finite=True))
595-
array_var = draw(arrays(dtype, shape=shapes(min_dims=1)))
594+
595+
scalar_var = draw(scalars(just(dtype), finite=True,
596+
**{'min_value': 1/ (2<<5), 'max_value': 2<<5}
597+
))
598+
599+
elements={}
600+
if dtype in dh.real_float_dtypes:
601+
elements = {'allow_nan': False, 'allow_infinity': False,
602+
'min_value': 1.0 / (2<<5), 'max_value': 2<<5}
603+
array_var = draw(arrays(dtype, shape=shapes(min_dims=1), elements=elements))
596604

597605
if draw(booleans()):
598606
return scalar_var, array_var

array_api_tests/test_operators_and_elementwise_functions.py

+18-8
Original file line numberDiff line numberDiff line change
@@ -1823,19 +1823,11 @@ def _filter_zero(x):
18231823
(xp.less_equal, "les_equal", operator.le, {}, xp.bool),
18241824
(xp.greater, "greater", operator.gt, {}, xp.bool),
18251825
(xp.greater_equal, "greater_equal", operator.ge, {}, xp.bool),
1826-
(xp.remainder, "remainder", operator.mod, {}, None),
1827-
(xp.floor_divide, "floor_divide", operator.floordiv, {}, None),
18281826
],
18291827
ids=lambda func_data: func_data[1] # use names for test IDs
18301828
)
18311829
@given(x1x2=hh.array_and_py_scalar(dh.real_float_dtypes))
18321830
def test_binary_with_scalars_real(func_data, x1x2):
1833-
1834-
if func_data[1] == "remainder":
1835-
assume(_filter_zero(x1x2[1]))
1836-
if func_data[1] == "floor_divide":
1837-
assume(_filter_zero(x1x2[0]) and _filter_zero(x1x2[1]))
1838-
18391831
_check_binary_with_scalars(func_data, x1x2)
18401832

18411833

@@ -1854,6 +1846,24 @@ def test_binary_with_scalars_bool(func_data, x1x2):
18541846
_check_binary_with_scalars(func_data, x1x2)
18551847

18561848

1849+
@pytest.mark.min_version("2024.12")
1850+
@pytest.mark.parametrize('func_data',
1851+
# xp_func, name, refimpl, kwargs, expected_dtype
1852+
[
1853+
1854+
(xp.floor_divide, "floor_divide", operator.floordiv, {}, None),
1855+
(xp.remainder, "remainder", operator.mod, {}, None),
1856+
],
1857+
ids=lambda func_data: func_data[1] # use names for test IDs
1858+
)
1859+
@given(x1x2=hh.array_and_py_scalar([xp.int64]))
1860+
def test_binary_with_scalars_int(func_data, x1x2):
1861+
1862+
assume(_filter_zero(x1x2[1]))
1863+
assume(_filter_zero(x1x2[0]) and _filter_zero(x1x2[1]))
1864+
_check_binary_with_scalars(func_data, x1x2)
1865+
1866+
18571867
@pytest.mark.min_version("2024.12")
18581868
@pytest.mark.parametrize('func_data',
18591869
# xp_func, name, refimpl, kwargs, expected_dtype

0 commit comments

Comments
 (0)