From 80522f04160d6c0bdc21d594ad05da85cb20a7ac Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 29 Nov 2024 12:29:14 +0200 Subject: [PATCH 1/2] ENH: add testing of reciprocal --- .../test_operators_and_elementwise_functions.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index fecf9d91..dbd44223 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -1570,6 +1570,22 @@ def test_real(x): unary_assert_against_refimpl("real", x, out, operator.attrgetter("real")) +@pytest.mark.min_version("2024.12") +@given(hh.arrays(dtype=hh.floating_dtypes, shape=hh.shapes(), elements=finite_kw)) +def test_reciprocal(x): + out = xp.reciprocal(x) + ph.assert_dtype("reciprocal", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("reciprocal", out_shape=out.shape, expected=x.shape) + refimpl = lambda x: 1.0 / x + unary_assert_against_refimpl( + "reciprocal", + x, + out, + refimpl, + strict_check=True, + ) + + @pytest.mark.skip(reason="flaky") @pytest.mark.parametrize("ctx", make_binary_params("remainder", dh.real_dtypes)) @given(data=st.data()) From 6bba09e6898389bee0fd8393723d025796169127 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 28 Dec 2024 16:43:03 +0200 Subject: [PATCH 2/2] ENH: add testing of cumulative_prod --- array_api_tests/test_statistical_functions.py | 58 +++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index 29f8adbe..0e3aa9d4 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -91,6 +91,64 @@ def test_cumulative_sum(x, data): idx=out_idx.raw, out=out_val, expected=expected) + + +@pytest.mark.min_version("2024.12") +@pytest.mark.unvectorized +@given( + x=hh.arrays( + dtype=hh.numeric_dtypes, + shape=hh.shapes(min_dims=1)), + data=st.data(), +) +def test_cumulative_prod(x, data): + axes = st.integers(-x.ndim, x.ndim - 1) + if x.ndim == 1: + axes = axes | st.none() + axis = data.draw(axes, label='axis') + _axis, = sh.normalize_axis(axis, x.ndim) + dtype = data.draw(kwarg_dtypes(x.dtype)) + include_initial = data.draw(st.booleans(), label="include_initial") + + kw = data.draw( + hh.specified_kwargs( + ("axis", axis, None), + ("dtype", dtype, None), + ("include_initial", include_initial, False), + ), + label="kw", + ) + + out = xp.cumulative_prod(x, **kw) + + expected_shape = list(x.shape) + if include_initial: + expected_shape[_axis] += 1 + expected_shape = tuple(expected_shape) + ph.assert_shape("cumulative_prod", out_shape=out.shape, expected=expected_shape) + + expected_dtype = dh.accumulation_result_dtype(x.dtype, dtype) + if expected_dtype is None: + # If a default uint cannot exist (i.e. in PyTorch which doesn't support + # uint32 or uint64), we skip testing the output dtype. + # See https://github.com/data-apis/array-api-tests/issues/106 + if x.dtype in dh.uint_dtypes: + assert dh.is_int_dtype(out.dtype) # sanity check + else: + ph.assert_dtype("cumulative_prod", in_dtype=x.dtype, out_dtype=out.dtype, expected=expected_dtype) + + scalar_type = dh.get_scalar_type(out.dtype) + + for x_idx, out_idx, in iter_indices(x.shape, expected_shape, skip_axes=_axis): + #x_arr = x[x_idx.raw] + out_arr = out[out_idx.raw] + + if include_initial: + ph.assert_scalar_equals("cumulative_prod", type_=scalar_type, idx=out_idx.raw, out=out_arr[0], expected=1) + + #TODO: add value testing of cumulative_prod + + def kwarg_dtypes(dtype: DataType) -> st.SearchStrategy[Optional[DataType]]: dtypes = [d2 for d1, d2 in dh.promotion_table if d1 == dtype] dtypes = [d for d in dtypes if not isinstance(d, _UndefinedStub)]