|
17 | 17 |
|
18 | 18 |
|
19 | 19 | @pytest.mark.min_version("2023.12")
|
| 20 | +@pytest.mark.unvectorized |
20 | 21 | @given(
|
21 | 22 | x=hh.arrays(
|
22 | 23 | dtype=hh.numeric_dtypes,
|
@@ -80,10 +81,15 @@ def test_cumulative_sum(x, data):
|
80 | 81 | if dh.is_int_dtype(out.dtype):
|
81 | 82 | m, M = dh.dtype_ranges[out.dtype]
|
82 | 83 | assume(m <= expected <= M)
|
83 |
| - ph.assert_scalar_equals("cumulative_sum", type_=scalar_type, |
84 |
| - idx=out_idx.raw, out=out_val, |
85 |
| - expected=expected) |
86 |
| - |
| 84 | + ph.assert_scalar_equals("cumulative_sum", type_=scalar_type, |
| 85 | + idx=out_idx.raw, out=out_val, |
| 86 | + expected=expected) |
| 87 | + else: |
| 88 | + condition_number = _sum_condition_number(elements) |
| 89 | + assume(condition_number < 1e6) |
| 90 | + ph.assert_scalar_isclose("cumulative_sum", type_=scalar_type, |
| 91 | + idx=out_idx.raw, out=out_val, |
| 92 | + expected=expected) |
87 | 93 |
|
88 | 94 | def kwarg_dtypes(dtype: DataType) -> st.SearchStrategy[Optional[DataType]]:
|
89 | 95 | dtypes = [d2 for d1, d2 in dh.promotion_table if d1 == dtype]
|
@@ -176,6 +182,16 @@ def test_min(x, data):
|
176 | 182 | ph.assert_scalar_equals("min", type_=scalar_type, idx=out_idx, out=min_, expected=expected)
|
177 | 183 |
|
178 | 184 |
|
| 185 | +def _prod_condition_number(elements): |
| 186 | + # Relative condition number using the infinity norm |
| 187 | + abs_max = max([abs(i) for i in elements]) |
| 188 | + abs_min = min([abs(i) for i in elements]) |
| 189 | + |
| 190 | + if abs_min == 0: |
| 191 | + return float('inf') |
| 192 | + |
| 193 | + return abs_max / abs_min |
| 194 | + |
179 | 195 | @pytest.mark.unvectorized
|
180 | 196 | @given(
|
181 | 197 | x=hh.arrays(
|
@@ -225,7 +241,13 @@ def test_prod(x, data):
|
225 | 241 | if dh.is_int_dtype(out.dtype):
|
226 | 242 | m, M = dh.dtype_ranges[out.dtype]
|
227 | 243 | assume(m <= expected <= M)
|
228 |
| - ph.assert_scalar_equals("prod", type_=scalar_type, idx=out_idx, out=prod, expected=expected) |
| 244 | + ph.assert_scalar_equals("prod", type_=scalar_type, idx=out_idx, |
| 245 | + out=prod, expected=expected) |
| 246 | + else: |
| 247 | + condition_number = _prod_condition_number(elements) |
| 248 | + assume(condition_number < 1e15) |
| 249 | + ph.assert_scalar_isclose("prod", type_=scalar_type, idx=out_idx, |
| 250 | + out=prod, expected=expected) |
229 | 251 |
|
230 | 252 |
|
231 | 253 | @pytest.mark.skip(reason="flaky") # TODO: fix!
|
@@ -264,8 +286,16 @@ def test_std(x, data):
|
264 | 286 | )
|
265 | 287 | # We can't easily test the result(s) as standard deviation methods vary a lot
|
266 | 288 |
|
| 289 | +def _sum_condition_number(elements): |
| 290 | + sum_abs = sum([abs(i) for i in elements]) |
| 291 | + abs_sum = abs(sum(elements)) |
267 | 292 |
|
268 |
| -@pytest.mark.unvectorized |
| 293 | + if abs_sum == 0: |
| 294 | + return float('inf') |
| 295 | + |
| 296 | + return sum_abs / abs_sum |
| 297 | + |
| 298 | +# @pytest.mark.unvectorized |
269 | 299 | @given(
|
270 | 300 | x=hh.arrays(
|
271 | 301 | dtype=hh.numeric_dtypes,
|
@@ -314,7 +344,15 @@ def test_sum(x, data):
|
314 | 344 | if dh.is_int_dtype(out.dtype):
|
315 | 345 | m, M = dh.dtype_ranges[out.dtype]
|
316 | 346 | assume(m <= expected <= M)
|
317 |
| - ph.assert_scalar_equals("sum", type_=scalar_type, idx=out_idx, out=sum_, expected=expected) |
| 347 | + ph.assert_scalar_equals("sum", type_=scalar_type, idx=out_idx, |
| 348 | + out=sum_, expected=expected) |
| 349 | + else: |
| 350 | + # Avoid value testing for ill conditioned summations. See |
| 351 | + # https://en.wikipedia.org/wiki/Kahan_summation_algorithm#Accuracy and |
| 352 | + # https://en.wikipedia.org/wiki/Condition_number. |
| 353 | + condition_number = _sum_condition_number(elements) |
| 354 | + assume(condition_number < 1e6) |
| 355 | + ph.assert_scalar_isclose("sum", type_=scalar_type, idx=out_idx, out=sum_, expected=expected) |
318 | 356 |
|
319 | 357 |
|
320 | 358 | @pytest.mark.unvectorized
|
|
0 commit comments