
Commit 621316b

Use condition numbers to filter ill-conditioned inputs out of value testing for sum/cumulative_sum/prod
This isn't completely rigorous (I haven't tweaked the tolerances used in isclose from the generous ones we were using before), but I haven't gotten Hypothesis to find any bad corner cases for this yet. If any crop up we can easily tweak the values. Fixes #168
1 parent 8f240f6 commit 621316b
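
Background on the failure mode being filtered (a standalone sketch, not part of the diff): when a summation involves heavy cancellation, the order of floating-point additions changes the result by far more than any reasonable isclose tolerance, so an exactly computed expected value is not a fair reference.

import math

# Illustrative only: the true sum of these three values is 1.0, but almost all
# of the magnitude cancels.
elements = [1e16, 1.0, -1e16]

print(sum(elements))        # 0.0  (naive left-to-right summation loses the 1.0)
print(math.fsum(elements))  # 1.0  (correctly rounded summation keeps it)
# No tolerance makes 0.0 and 1.0 "close", so such inputs can only be skipped.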

File tree

1 file changed (+45, -7 lines)


array_api_tests/test_statistical_functions.py (+45, -7)
@@ -17,6 +17,7 @@


 @pytest.mark.min_version("2023.12")
+@pytest.mark.unvectorized
 @given(
     x=hh.arrays(
         dtype=hh.numeric_dtypes,
@@ -80,10 +81,15 @@ def test_cumulative_sum(x, data):
         if dh.is_int_dtype(out.dtype):
             m, M = dh.dtype_ranges[out.dtype]
             assume(m <= expected <= M)
-        ph.assert_scalar_equals("cumulative_sum", type_=scalar_type,
-                                idx=out_idx.raw, out=out_val,
-                                expected=expected)
-
+            ph.assert_scalar_equals("cumulative_sum", type_=scalar_type,
+                                    idx=out_idx.raw, out=out_val,
+                                    expected=expected)
+        else:
+            condition_number = _sum_condition_number(elements)
+            assume(condition_number < 1e6)
+            ph.assert_scalar_isclose("cumulative_sum", type_=scalar_type,
+                                     idx=out_idx.raw, out=out_val,
+                                     expected=expected)

 def kwarg_dtypes(dtype: DataType) -> st.SearchStrategy[Optional[DataType]]:
     dtypes = [d2 for d1, d2 in dh.promotion_table if d1 == dtype]
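
One plausible reading of the hunk above: for cumulative_sum the filter is applied per output element, so each prefix gets its own condition number and a single cancelling prefix is enough to discard the generated example. An illustrative standalone sketch with made-up inputs:

def _sum_condition_number(elements):
    sum_abs = sum([abs(i) for i in elements])
    abs_sum = abs(sum(elements))
    return float('inf') if abs_sum == 0 else sum_abs / abs_sum

x = [1e8, -1e8, 5.0]
for i in range(1, len(x) + 1):
    print(x[:i], _sum_condition_number(x[:i]))
# [1e8]             -> 1.0
# [1e8, -1e8]       -> inf  (the prefix cancels exactly)
# [1e8, -1e8, 5.0]  -> ~4e7 (still above the 1e6 cutoff)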
@@ -176,6 +182,16 @@ def test_min(x, data):
         ph.assert_scalar_equals("min", type_=scalar_type, idx=out_idx, out=min_, expected=expected)


+def _prod_condition_number(elements):
+    # Relative condition number using the infinity norm
+    abs_max = max([abs(i) for i in elements])
+    abs_min = min([abs(i) for i in elements])
+
+    if abs_min == 0:
+        return float('inf')
+
+    return abs_max / abs_min
+
 @pytest.mark.unvectorized
 @given(
     x=hh.arrays(
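
A quick look at how the new helper behaves, re-using the definition from the hunk above together with the 1e15 cutoff that test_prod applies below; the example inputs are made up:

def _prod_condition_number(elements):
    # Relative condition number using the infinity norm (as in the hunk above)
    abs_max = max([abs(i) for i in elements])
    abs_min = min([abs(i) for i in elements])

    if abs_min == 0:
        return float('inf')

    return abs_max / abs_min

print(_prod_condition_number([0.5, 2.0, 3.0]))   # 6.0  -> well below 1e15, value-tested
print(_prod_condition_number([1e-8, 4.0, 1e8]))  # 1e16 -> above 1e15, assume() discards the case
print(_prod_condition_number([0.0, 7.0]))        # inf  -> always discarded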
@@ -225,7 +241,13 @@ def test_prod(x, data):
         if dh.is_int_dtype(out.dtype):
             m, M = dh.dtype_ranges[out.dtype]
             assume(m <= expected <= M)
-        ph.assert_scalar_equals("prod", type_=scalar_type, idx=out_idx, out=prod, expected=expected)
+            ph.assert_scalar_equals("prod", type_=scalar_type, idx=out_idx,
+                                    out=prod, expected=expected)
+        else:
+            condition_number = _prod_condition_number(elements)
+            assume(condition_number < 1e15)
+            ph.assert_scalar_isclose("prod", type_=scalar_type, idx=out_idx,
+                                     out=prod, expected=expected)


 @pytest.mark.skip(reason="flaky") # TODO: fix!
@@ -264,8 +286,16 @@ def test_std(x, data):
     )
     # We can't easily test the result(s) as standard deviation methods vary a lot

+def _sum_condition_number(elements):
+    sum_abs = sum([abs(i) for i in elements])
+    abs_sum = abs(sum(elements))

-@pytest.mark.unvectorized
+    if abs_sum == 0:
+        return float('inf')
+
+    return sum_abs / abs_sum
+
+# @pytest.mark.unvectorized
 @given(
     x=hh.arrays(
         dtype=hh.numeric_dtypes,
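
One property of the new helper worth noting: it conditions on the naively computed Python sum, so any input whose running sum cancels to exactly 0.0 comes back as inf and is always discarded by the assume() calls below. A small illustration with made-up inputs:

def _sum_condition_number(elements):
    # sum(|x_i|) / |sum(x_i)|, as defined in the hunk above
    sum_abs = sum([abs(i) for i in elements])
    abs_sum = abs(sum(elements))

    if abs_sum == 0:
        return float('inf')

    return sum_abs / abs_sum

print(_sum_condition_number([1.0, 2.0, 3.0]))     # 1.0: no cancellation, safe to value-test
print(_sum_condition_number([1e20, 1.0, -1e20]))  # inf: the float sum collapses to exactly 0.0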
@@ -314,7 +344,15 @@ def test_sum(x, data):
         if dh.is_int_dtype(out.dtype):
             m, M = dh.dtype_ranges[out.dtype]
             assume(m <= expected <= M)
-        ph.assert_scalar_equals("sum", type_=scalar_type, idx=out_idx, out=sum_, expected=expected)
+            ph.assert_scalar_equals("sum", type_=scalar_type, idx=out_idx,
+                                    out=sum_, expected=expected)
+        else:
+            # Avoid value testing for ill conditioned summations. See
+            # https://en.wikipedia.org/wiki/Kahan_summation_algorithm#Accuracy and
+            # https://en.wikipedia.org/wiki/Condition_number.
+            condition_number = _sum_condition_number(elements)
+            assume(condition_number < 1e6)
+            ph.assert_scalar_isclose("sum", type_=scalar_type, idx=out_idx, out=sum_, expected=expected)


 @pytest.mark.unvectorized
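
How the filter plugs into a property-based test: hypothesis.assume() silently discards a generated example when its argument is false, so ill-conditioned draws are skipped rather than reported as failures. A self-contained sketch of that pattern; the strategy, tolerances, and test name here are illustrative, not taken from the suite:

import math
from hypothesis import given, assume, strategies as st

def _sum_condition_number(elements):
    sum_abs = sum([abs(i) for i in elements])
    abs_sum = abs(sum(elements))
    if abs_sum == 0:
        return float('inf')
    return sum_abs / abs_sum

@given(st.lists(st.floats(min_value=-1e6, max_value=1e6), min_size=1))
def test_naive_sum_close_to_fsum(elements):
    # Discard ill-conditioned inputs, mirroring the assume() calls in the diff.
    assume(_sum_condition_number(elements) < 1e6)
    assert math.isclose(sum(elements), math.fsum(elements), rel_tol=1e-6, abs_tol=1e-9)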
