Skip to content

Commit 8a6c340

Browse files
authored
Merge pull request #324 from ev-br/reciprocal
Add testing of `reciprocal`, `cumulative_prod` from 2024.12 spec revision
2 parents 4ee45a0 + 6bba09e commit 8a6c340

File tree

2 files changed

+74
-0
lines changed

2 files changed

+74
-0
lines changed

array_api_tests/test_operators_and_elementwise_functions.py

+16
Original file line numberDiff line numberDiff line change
@@ -1570,6 +1570,22 @@ def test_real(x):
15701570
unary_assert_against_refimpl("real", x, out, operator.attrgetter("real"))
15711571

15721572

1573+
@pytest.mark.min_version("2024.12")
1574+
@given(hh.arrays(dtype=hh.floating_dtypes, shape=hh.shapes(), elements=finite_kw))
1575+
def test_reciprocal(x):
1576+
out = xp.reciprocal(x)
1577+
ph.assert_dtype("reciprocal", in_dtype=x.dtype, out_dtype=out.dtype)
1578+
ph.assert_shape("reciprocal", out_shape=out.shape, expected=x.shape)
1579+
refimpl = lambda x: 1.0 / x
1580+
unary_assert_against_refimpl(
1581+
"reciprocal",
1582+
x,
1583+
out,
1584+
refimpl,
1585+
strict_check=True,
1586+
)
1587+
1588+
15731589
@pytest.mark.skip(reason="flaky")
15741590
@pytest.mark.parametrize("ctx", make_binary_params("remainder", dh.real_dtypes))
15751591
@given(data=st.data())

array_api_tests/test_statistical_functions.py

+58
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,64 @@ def test_cumulative_sum(x, data):
9191
idx=out_idx.raw, out=out_val,
9292
expected=expected)
9393

94+
95+
96+
@pytest.mark.min_version("2024.12")
97+
@pytest.mark.unvectorized
98+
@given(
99+
x=hh.arrays(
100+
dtype=hh.numeric_dtypes,
101+
shape=hh.shapes(min_dims=1)),
102+
data=st.data(),
103+
)
104+
def test_cumulative_prod(x, data):
105+
axes = st.integers(-x.ndim, x.ndim - 1)
106+
if x.ndim == 1:
107+
axes = axes | st.none()
108+
axis = data.draw(axes, label='axis')
109+
_axis, = sh.normalize_axis(axis, x.ndim)
110+
dtype = data.draw(kwarg_dtypes(x.dtype))
111+
include_initial = data.draw(st.booleans(), label="include_initial")
112+
113+
kw = data.draw(
114+
hh.specified_kwargs(
115+
("axis", axis, None),
116+
("dtype", dtype, None),
117+
("include_initial", include_initial, False),
118+
),
119+
label="kw",
120+
)
121+
122+
out = xp.cumulative_prod(x, **kw)
123+
124+
expected_shape = list(x.shape)
125+
if include_initial:
126+
expected_shape[_axis] += 1
127+
expected_shape = tuple(expected_shape)
128+
ph.assert_shape("cumulative_prod", out_shape=out.shape, expected=expected_shape)
129+
130+
expected_dtype = dh.accumulation_result_dtype(x.dtype, dtype)
131+
if expected_dtype is None:
132+
# If a default uint cannot exist (i.e. in PyTorch which doesn't support
133+
# uint32 or uint64), we skip testing the output dtype.
134+
# See https://github.com/data-apis/array-api-tests/issues/106
135+
if x.dtype in dh.uint_dtypes:
136+
assert dh.is_int_dtype(out.dtype) # sanity check
137+
else:
138+
ph.assert_dtype("cumulative_prod", in_dtype=x.dtype, out_dtype=out.dtype, expected=expected_dtype)
139+
140+
scalar_type = dh.get_scalar_type(out.dtype)
141+
142+
for x_idx, out_idx, in iter_indices(x.shape, expected_shape, skip_axes=_axis):
143+
#x_arr = x[x_idx.raw]
144+
out_arr = out[out_idx.raw]
145+
146+
if include_initial:
147+
ph.assert_scalar_equals("cumulative_prod", type_=scalar_type, idx=out_idx.raw, out=out_arr[0], expected=1)
148+
149+
#TODO: add value testing of cumulative_prod
150+
151+
94152
def kwarg_dtypes(dtype: DataType) -> st.SearchStrategy[Optional[DataType]]:
95153
dtypes = [d2 for d1, d2 in dh.promotion_table if d1 == dtype]
96154
dtypes = [d for d in dtypes if not isinstance(d, _UndefinedStub)]

0 commit comments

Comments
 (0)