diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 82ab3351..6b4e5550 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -1595,6 +1595,26 @@ def test_not_equal(ctx, data): ) +@pytest.mark.min_version("2024.12") +@given( + shapes=hh.two_mutually_broadcastable_shapes, + dtype=hh.real_floating_dtypes, + data=st.data() +) +def test_nextafter(shapes, dtype, data): + x1 = data.draw(hh.arrays(dtype=dtype, shape=shapes[0]), label="x1") + x2 = data.draw(hh.arrays(dtype=dtype, shape=shapes[0]), label="x2") + + out = xp.nextafter(x1, x2) + _assert_correctness_binary( + "nextafter", + math.nextafter, + in_dtypes=[x1.dtype, x2.dtype], + in_shapes=[x1.shape, x2.shape], + in_arrs=[x1, x2], + out=out + ) + @pytest.mark.parametrize("ctx", make_unary_params("positive", dh.numeric_dtypes)) @given(data=st.data()) def test_positive(ctx, data): @@ -1813,6 +1833,7 @@ def _filter_zero(x): ("divide", operator.truediv, {"filter_": lambda s: s != 0}, None), ("hypot", math.hypot, {}, None), ("logaddexp", logaddexp_refimpl, {}, None), + ("nextafter", math.nextafter, {}, None), ("maximum", max, {'strict_check': True}, None), ("minimum", min, {'strict_check': True}, None), ("multiply", operator.mul, {}, None),