Skip to content

Commit b905bca

Browse files
authored
Merge pull request #290 from asmeurer/test_sum-fix
Fix test_sum to be more numerically correct
2 parents 3ab322a + 621316b commit b905bca

File tree

2 files changed

+82
-10
lines changed

2 files changed

+82
-10
lines changed

Diff for: array_api_tests/pytest_helpers.py

+37-3
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ def assert_scalar_equals(
397397
kw: dict = {},
398398
):
399399
"""
400-
Assert a 0d array, convered to a scalar, is as expected, e.g.
400+
Assert a 0d array, converted to a scalar, is as expected, e.g.
401401
402402
>>> x = xp.ones(5, dtype=xp.uint8)
403403
>>> out = xp.sum(x)
@@ -407,6 +407,8 @@ def assert_scalar_equals(
407407
408408
>>> assert int(out) == 5
409409
410+
NOTE: This function does *exact* comparison, even for floats. For
411+
approximate float comparisons use assert_scalar_isclose
410412
"""
411413
__tracebackhide__ = True
412414
repr_name = repr_name if idx == () else f"{repr_name}[{idx}]"
@@ -418,8 +420,40 @@ def assert_scalar_equals(
418420
msg = f"{repr_name}={out}, but should be {expected} [{f_func}]"
419421
assert cmath.isnan(out), msg
420422
else:
421-
msg = f"{repr_name}={out}, but should be roughly {expected} [{f_func}]"
422-
assert cmath.isclose(out, expected, rel_tol=0.25, abs_tol=1), msg
423+
msg = f"{repr_name}={out}, but should be {expected} [{f_func}]"
424+
assert out == expected, msg
425+
426+
427+
def assert_scalar_isclose(
    func_name: str,
    *,
    rel_tol: float = 0.25,
    abs_tol: float = 1,
    type_: ScalarType,
    idx: Shape,
    out: Scalar,
    expected: Scalar,
    repr_name: str = "out",
    kw: dict = {},  # not mutated in this function; only passed to fmt_kw
):
    """
    Assert a 0d array, converted to a scalar, is close to the expected value, e.g.

        >>> x = xp.ones(5, dtype=xp.float64)
        >>> out = xp.sum(x)
        >>> assert_scalar_isclose('sum', type_=float, idx=(), out=float(out), expected=5.)

    is equivalent to

        >>> assert cmath.isclose(float(out), 5., rel_tol=0.25, abs_tol=1)

    Parameters:
        func_name: name of the function under test, used in the failure message.
        rel_tol: relative tolerance forwarded to ``cmath.isclose``.
        abs_tol: absolute tolerance forwarded to ``cmath.isclose``.
        type_: scalar type of the comparison; must be ``float`` or ``complex``.
        idx: index of the compared element; ``()`` leaves ``repr_name`` as is,
            any other value is appended to it as ``[idx]``.
        out: the scalar actually produced.
        expected: the scalar that ``out`` should be close to.
        repr_name: label used for ``out`` in the failure message.
        kw: keyword arguments rendered into the failure message via ``fmt_kw``.

    Raises:
        AssertionError: if ``type_`` is not ``float``/``complex``, if
            ``expected`` is NaN but ``out`` is not, or if ``out`` is not within
            the given tolerances of ``expected``.
    """
    # pytest convention: hide this helper's frame from failure tracebacks.
    __tracebackhide__ = True
    repr_name = repr_name if idx == () else f"{repr_name}[{idx}]"
    f_func = f"{func_name}({fmt_kw(kw)})"
    assert type_ in [float, complex]  # Sanity check
    if cmath.isnan(expected):
        # cmath.isclose() is always False when either argument is NaN, so a
        # NaN expectation could never be satisfied; require a NaN result
        # instead, as assert_scalar_equals does.
        msg = f"{repr_name}={out}, but should be {expected} [{f_func}]"
        assert cmath.isnan(out), msg
    else:
        msg = f"{repr_name}={out}, but should be roughly {expected} [{f_func}]"
        assert cmath.isclose(out, expected, rel_tol=rel_tol, abs_tol=abs_tol), msg
423457

424458

425459
def assert_fill(

Diff for: array_api_tests/test_statistical_functions.py

+45-7
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818

1919
@pytest.mark.min_version("2023.12")
20+
@pytest.mark.unvectorized
2021
@given(
2122
x=hh.arrays(
2223
dtype=hh.numeric_dtypes,
@@ -80,10 +81,15 @@ def test_cumulative_sum(x, data):
8081
if dh.is_int_dtype(out.dtype):
8182
m, M = dh.dtype_ranges[out.dtype]
8283
assume(m <= expected <= M)
83-
ph.assert_scalar_equals("cumulative_sum", type_=scalar_type,
84-
idx=out_idx.raw, out=out_val,
85-
expected=expected)
86-
84+
ph.assert_scalar_equals("cumulative_sum", type_=scalar_type,
85+
idx=out_idx.raw, out=out_val,
86+
expected=expected)
87+
else:
88+
condition_number = _sum_condition_number(elements)
89+
assume(condition_number < 1e6)
90+
ph.assert_scalar_isclose("cumulative_sum", type_=scalar_type,
91+
idx=out_idx.raw, out=out_val,
92+
expected=expected)
8793

8894
def kwarg_dtypes(dtype: DataType) -> st.SearchStrategy[Optional[DataType]]:
8995
dtypes = [d2 for d1, d2 in dh.promotion_table if d1 == dtype]
@@ -176,6 +182,16 @@ def test_min(x, data):
176182
ph.assert_scalar_equals("min", type_=scalar_type, idx=out_idx, out=min_, expected=expected)
177183

178184

185+
def _prod_condition_number(elements):
186+
# Relative condition number using the infinity norm
187+
abs_max = max([abs(i) for i in elements])
188+
abs_min = min([abs(i) for i in elements])
189+
190+
if abs_min == 0:
191+
return float('inf')
192+
193+
return abs_max / abs_min
194+
179195
@pytest.mark.unvectorized
180196
@given(
181197
x=hh.arrays(
@@ -225,7 +241,13 @@ def test_prod(x, data):
225241
if dh.is_int_dtype(out.dtype):
226242
m, M = dh.dtype_ranges[out.dtype]
227243
assume(m <= expected <= M)
228-
ph.assert_scalar_equals("prod", type_=scalar_type, idx=out_idx, out=prod, expected=expected)
244+
ph.assert_scalar_equals("prod", type_=scalar_type, idx=out_idx,
245+
out=prod, expected=expected)
246+
else:
247+
condition_number = _prod_condition_number(elements)
248+
assume(condition_number < 1e15)
249+
ph.assert_scalar_isclose("prod", type_=scalar_type, idx=out_idx,
250+
out=prod, expected=expected)
229251

230252

231253
@pytest.mark.skip(reason="flaky") # TODO: fix!
@@ -264,8 +286,16 @@ def test_std(x, data):
264286
)
265287
# We can't easily test the result(s) as standard deviation methods vary a lot
266288

289+
def _sum_condition_number(elements):
290+
sum_abs = sum([abs(i) for i in elements])
291+
abs_sum = abs(sum(elements))
267292

268-
@pytest.mark.unvectorized
293+
if abs_sum == 0:
294+
return float('inf')
295+
296+
return sum_abs / abs_sum
297+
298+
# @pytest.mark.unvectorized
269299
@given(
270300
x=hh.arrays(
271301
dtype=hh.numeric_dtypes,
@@ -314,7 +344,15 @@ def test_sum(x, data):
314344
if dh.is_int_dtype(out.dtype):
315345
m, M = dh.dtype_ranges[out.dtype]
316346
assume(m <= expected <= M)
317-
ph.assert_scalar_equals("sum", type_=scalar_type, idx=out_idx, out=sum_, expected=expected)
347+
ph.assert_scalar_equals("sum", type_=scalar_type, idx=out_idx,
348+
out=sum_, expected=expected)
349+
else:
350+
# Avoid value testing for ill conditioned summations. See
351+
# https://en.wikipedia.org/wiki/Kahan_summation_algorithm#Accuracy and
352+
# https://en.wikipedia.org/wiki/Condition_number.
353+
condition_number = _sum_condition_number(elements)
354+
assume(condition_number < 1e6)
355+
ph.assert_scalar_isclose("sum", type_=scalar_type, idx=out_idx, out=sum_, expected=expected)
318356

319357

320358
@pytest.mark.unvectorized

0 commit comments

Comments
 (0)