Skip to content

Commit 414b322

Browse files
committed
Add allclose() and assert_allclose() helper functions
1 parent 45b36d6 commit 414b322

File tree

1 file changed

+40
-1
lines changed

1 file changed

+40
-1
lines changed

array_api_tests/array_helpers.py

+40-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
from ._array_module import logical_not, subtract, floor, ceil, where
1010
from . import dtype_helpers as dh
1111

12+
from ndindex import iter_indices
13+
14+
import math
1215

1316
__all__ = ['all', 'any', 'logical_and', 'logical_or', 'logical_not', 'less',
1417
'less_equal', 'greater', 'subtract', 'negative', 'floor', 'ceil',
@@ -146,6 +149,43 @@ def exactly_equal(x, y):
146149

147150
return equal(x, y)
148151

152+
def allclose(x, y, rel_tol=0.25, abs_tol=1, return_indices=False):
153+
"""
154+
Return True all elements of x and y are within tolerance
155+
156+
If return_indices=True, returns (False, (i, j)) when the arrays are not
157+
close, where i and j are the indices into x and y of corresponding
158+
non-close elements.
159+
"""
160+
for i, j in iter_indices(x.shape, y.shape):
161+
i, j = i.raw, j.raw
162+
a = x[i]
163+
b = y[j]
164+
if not (math.isfinite(a) and math.isfinite(b)):
165+
# TODO: If a and b are both infinite, require the same type of infinity
166+
continue
167+
close = math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)
168+
if not close:
169+
if return_indices:
170+
return (False, (i, j))
171+
return False
172+
return True
173+
174+
def assert_allclose(x, y, rel_tol=0.25, abs_tol=1):
175+
"""
176+
Test that x and y are approximately equal to each other.
177+
178+
Also asserts that x and y have the same shape and dtype.
179+
"""
180+
assert x.shape == y.shape, f"The input arrays do not have the same shapes ({x.shape} != {y.shape})"
181+
182+
assert x.dtype == y.dtype, f"The input arrays do not have the same dtype ({x.dtype} != {y.dtype})"
183+
184+
c = allclose(x, y, rel_tol=rel_tol, abs_tol=abs_tol, return_indices=True)
185+
if c is not True:
186+
_, (i, j) = c
187+
raise AssertionError(f"The input arrays are not close with {rel_tol = } and {abs_tol = } at indices {i = } and {j = }")
188+
149189
def notequal(x, y):
150190
"""
151191
Same as not_equal(x, y) except it gives False when both values are nan.
@@ -305,4 +345,3 @@ def same_sign(x, y):
305345

306346
def assert_same_sign(x, y):
307347
assert all(same_sign(x, y)), "The input arrays do not have the same sign"
308-

0 commit comments

Comments
 (0)