|
9 | 9 | from ._array_module import logical_not, subtract, floor, ceil, where
|
10 | 10 | from . import dtype_helpers as dh
|
11 | 11 |
|
| 12 | +from ndindex import iter_indices |
| 13 | + |
| 14 | +import math |
12 | 15 |
|
13 | 16 | __all__ = ['all', 'any', 'logical_and', 'logical_or', 'logical_not', 'less',
|
14 | 17 | 'less_equal', 'greater', 'subtract', 'negative', 'floor', 'ceil',
|
@@ -146,6 +149,43 @@ def exactly_equal(x, y):
|
146 | 149 |
|
147 | 150 | return equal(x, y)
|
148 | 151 |
|
| 152 | +def allclose(x, y, rel_tol=0.25, abs_tol=1, return_indices=False): |
| 153 | + """ |
| 154 | + Return True all elements of x and y are within tolerance |
| 155 | +
|
| 156 | + If return_indices=True, returns (False, (i, j)) when the arrays are not |
| 157 | + close, where i and j are the indices into x and y of corresponding |
| 158 | + non-close elements. |
| 159 | + """ |
| 160 | + for i, j in iter_indices(x.shape, y.shape): |
| 161 | + i, j = i.raw, j.raw |
| 162 | + a = x[i] |
| 163 | + b = y[j] |
| 164 | + if not (math.isfinite(a) and math.isfinite(b)): |
| 165 | + # TODO: If a and b are both infinite, require the same type of infinity |
| 166 | + continue |
| 167 | + close = math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol) |
| 168 | + if not close: |
| 169 | + if return_indices: |
| 170 | + return (False, (i, j)) |
| 171 | + return False |
| 172 | + return True |
| 173 | + |
| 174 | +def assert_allclose(x, y, rel_tol=0.25, abs_tol=1): |
| 175 | + """ |
| 176 | + Test that x and y are approximately equal to each other. |
| 177 | +
|
| 178 | + Also asserts that x and y have the same shape and dtype. |
| 179 | + """ |
| 180 | + assert x.shape == y.shape, f"The input arrays do not have the same shapes ({x.shape} != {y.shape})" |
| 181 | + |
| 182 | + assert x.dtype == y.dtype, f"The input arrays do not have the same dtype ({x.dtype} != {y.dtype})" |
| 183 | + |
| 184 | + c = allclose(x, y, rel_tol=rel_tol, abs_tol=abs_tol, return_indices=True) |
| 185 | + if c is not True: |
| 186 | + _, (i, j) = c |
| 187 | + raise AssertionError(f"The input arrays are not close with {rel_tol = } and {abs_tol = } at indices {i = } and {j = }") |
| 188 | + |
149 | 189 | def notequal(x, y):
|
150 | 190 | """
|
151 | 191 | Same as not_equal(x, y) except it gives False when both values are nan.
|
@@ -305,4 +345,3 @@ def same_sign(x, y):
|
305 | 345 |
|
306 | 346 | def assert_same_sign(x, y):
|
307 | 347 | assert all(same_sign(x, y)), "The input arrays do not have the same sign"
|
308 |
| - |
|
0 commit comments