Skip to content

Commit 6ee70c0

Browse files
authored
ENH: new function isclose (#113)
1 parent 3754e7c commit 6ee70c0

File tree

8 files changed

+447
-145
lines changed

8 files changed

+447
-145
lines changed

docs/api-reference.md

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
cov
1212
create_diagonal
1313
expand_dims
14+
isclose
1415
kron
1516
nunique
1617
pad

pixi.lock

+125-128
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/array_api_extra/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Extra array functions built on top of the array API standard."""
22

3-
from ._delegation import pad
3+
from ._delegation import isclose, pad
44
from ._lib._at import at
55
from ._lib._funcs import (
66
atleast_nd,
@@ -23,6 +23,7 @@
2323
"cov",
2424
"create_diagonal",
2525
"expand_dims",
26+
"isclose",
2627
"kron",
2728
"nunique",
2829
"pad",

src/array_api_extra/_delegation.py

+91-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from ._lib._utils._compat import array_namespace
99
from ._lib._utils._typing import Array
1010

11-
__all__ = ["pad"]
11+
__all__ = ["isclose", "pad"]
1212

1313

1414
def _delegate(xp: ModuleType, *backends: Backend) -> bool:
@@ -30,6 +30,96 @@ def _delegate(xp: ModuleType, *backends: Backend) -> bool:
3030
return any(backend.is_namespace(xp) for backend in backends)
3131

3232

33+
def isclose(
34+
a: Array,
35+
b: Array,
36+
*,
37+
rtol: float = 1e-05,
38+
atol: float = 1e-08,
39+
equal_nan: bool = False,
40+
xp: ModuleType | None = None,
41+
) -> Array:
42+
"""
43+
Return a boolean array where two arrays are element-wise equal within a tolerance.
44+
45+
The tolerance values are positive, typically very small numbers. The relative
46+
difference ``(rtol * abs(b))`` and the absolute difference `atol` are added together
47+
to compare against the absolute difference between `a` and `b`.
48+
49+
NaNs are treated as equal if they are in the same place and if ``equal_nan=True``.
50+
Infs are treated as equal if they are in the same place and of the same sign in both
51+
arrays.
52+
53+
Parameters
54+
----------
55+
a, b : Array
56+
Input arrays to compare.
57+
rtol : array_like, optional
58+
The relative tolerance parameter (see Notes).
59+
atol : array_like, optional
60+
The absolute tolerance parameter (see Notes).
61+
equal_nan : bool, optional
62+
Whether to compare NaN's as equal. If True, NaN's in `a` will be considered
63+
equal to NaN's in `b` in the output array.
64+
xp : array_namespace, optional
65+
The standard-compatible namespace for `a` and `b`. Default: infer.
66+
67+
Returns
68+
-------
69+
Array
70+
A boolean array of shape broadcasted from `a` and `b`, containing ``True`` where
71+
`a` is close to `b`, and ``False`` otherwise.
72+
73+
Warnings
74+
--------
75+
The default `atol` is not appropriate for comparing numbers with magnitudes much
76+
smaller than one (see notes).
77+
78+
See Also
79+
--------
80+
math.isclose : Similar function in stdlib for Python scalars.
81+
82+
Notes
83+
-----
84+
For finite values, `isclose` uses the following equation to test whether two
85+
floating point values are equivalent::
86+
87+
absolute(a - b) <= (atol + rtol * absolute(b))
88+
89+
Unlike the built-in `math.isclose`,
90+
the above equation is not symmetric in `a` and `b`,
91+
so that ``isclose(a, b)`` might be different from ``isclose(b, a)`` in some rare
92+
cases.
93+
94+
The default value of `atol` is not appropriate when the reference value `b` has
95+
magnitude smaller than one. For example, it is unlikely that ``a = 1e-9`` and
96+
``b = 2e-9`` should be considered "close", yet ``isclose(1e-9, 2e-9)`` is ``True``
97+
with default settings. Be sure to select `atol` for the use case at hand, especially
98+
for defining the threshold below which a non-zero value in `a` will be considered
99+
"close" to a very small or zero value in `b`.
100+
101+
The comparison of `a` and `b` uses standard broadcasting, which means that `a` and
102+
`b` need not have the same shape in order for ``isclose(a, b)`` to evaluate to
103+
``True``.
104+
105+
`isclose` is not defined for non-numeric data types.
106+
``bool`` is considered a numeric data-type for this purpose.
107+
"""
108+
xp = array_namespace(a, b) if xp is None else xp
109+
110+
if _delegate(
111+
xp,
112+
Backend.NUMPY,
113+
Backend.CUPY,
114+
Backend.DASK,
115+
Backend.JAX,
116+
Backend.TORCH,
117+
):
118+
return xp.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
119+
120+
return _funcs.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan, xp=xp)
121+
122+
33123
def pad(
34124
x: Array,
35125
pad_width: int | tuple[int, int] | Sequence[tuple[int, int]],

src/array_api_extra/_lib/_funcs.py

+34
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,40 @@ def expand_dims(
305305
return a
306306

307307

308+
def isclose(
309+
a: Array,
310+
b: Array,
311+
*,
312+
rtol: float = 1e-05,
313+
atol: float = 1e-08,
314+
equal_nan: bool = False,
315+
xp: ModuleType,
316+
) -> Array: # numpydoc ignore=PR01,RT01
317+
"""See docstring in array_api_extra._delegation."""
318+
319+
a_inexact = xp.isdtype(a.dtype, ("real floating", "complex floating"))
320+
b_inexact = xp.isdtype(b.dtype, ("real floating", "complex floating"))
321+
if a_inexact or b_inexact:
322+
# FIXME: use scipy's lazywhere to suppress warnings on inf
323+
out = xp.abs(a - b) <= (atol + rtol * xp.abs(b))
324+
out = xp.where(xp.isinf(a) & xp.isinf(b), xp.sign(a) == xp.sign(b), out)
325+
if equal_nan:
326+
out = xp.where(xp.isnan(a) & xp.isnan(b), xp.asarray(True), out)
327+
return out
328+
329+
if xp.isdtype(a.dtype, "bool") or xp.isdtype(b.dtype, "bool"):
330+
if atol >= 1 or rtol >= 1:
331+
return xp.ones_like(a == b)
332+
return a == b
333+
334+
# integer types
335+
atol = int(atol)
336+
if rtol == 0:
337+
return xp.abs(a - b) <= atol
338+
nrtol = int(1.0 / rtol)
339+
return xp.abs(a - b) <= (atol + xp.abs(b) // nrtol)
340+
341+
308342
def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
309343
"""
310344
Kronecker product of two arrays.

src/array_api_extra/_lib/_testing.py

+27-2
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44
Note that this is private API; don't expect it to be stable.
55
"""
66

7+
import math
78
from types import ModuleType
89

910
from ._utils._compat import (
1011
array_namespace,
1112
is_cupy_namespace,
13+
is_dask_namespace,
1214
is_pydata_sparse_namespace,
1315
is_torch_namespace,
1416
)
@@ -40,8 +42,16 @@ def _check_ns_shape_dtype(
4042
msg = f"namespaces do not match: {actual_xp} != f{desired_xp}"
4143
assert actual_xp == desired_xp, msg
4244

43-
msg = f"shapes do not match: {actual.shape} != f{desired.shape}"
44-
assert actual.shape == desired.shape, msg
45+
actual_shape = actual.shape
46+
desired_shape = desired.shape
47+
if is_dask_namespace(desired_xp):
48+
if any(math.isnan(i) for i in actual_shape):
49+
actual_shape = actual.compute().shape
50+
if any(math.isnan(i) for i in desired_shape):
51+
desired_shape = desired.compute().shape
52+
53+
msg = f"shapes do not match: {actual_shape} != f{desired_shape}"
54+
assert actual_shape == desired_shape, msg
4555

4656
msg = f"dtypes do not match: {actual.dtype} != {desired.dtype}"
4757
assert actual.dtype == desired.dtype, msg
@@ -61,6 +71,11 @@ def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
6171
The expected array (typically hardcoded).
6272
err_msg : str, optional
6373
Error message to display on failure.
74+
75+
See Also
76+
--------
77+
xp_assert_close : Similar function for inexact equality checks.
78+
numpy.testing.assert_array_equal : Similar function for NumPy arrays.
6479
"""
6580
xp = _check_ns_shape_dtype(actual, desired)
6681

@@ -112,6 +127,16 @@ def xp_assert_close(
112127
Absolute tolerance. Default: 0.
113128
err_msg : str, optional
114129
Error message to display on failure.
130+
131+
See Also
132+
--------
133+
xp_assert_equal : Similar function for exact equality checks.
134+
isclose : Public function for checking closeness.
135+
numpy.testing.assert_allclose : Similar function for NumPy arrays.
136+
137+
Notes
138+
-----
139+
The default `atol` and `rtol` differ from `xp.all(xpx.isclose(a, b))`.
115140
"""
116141
xp = _check_ns_shape_dtype(actual, desired)
117142

tests/test_funcs.py

+135-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
cov,
1212
create_diagonal,
1313
expand_dims,
14+
isclose,
1415
kron,
1516
nunique,
1617
pad,
@@ -23,7 +24,7 @@
2324
from array_api_extra._lib._utils._typing import Array, Device
2425

2526
# some xp backends are untyped
26-
# mypy: disable-error-code=no-untyped-usage
27+
# mypy: disable-error-code=no-untyped-def
2728

2829

2930
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no expand_dims")
@@ -252,6 +253,139 @@ def test_xp(self, xp: ModuleType):
252253
assert y.shape == (1, 1, 1, 3)
253254

254255

256+
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no isdtype")
257+
class TestIsClose:
258+
# FIXME use lazywhere to avoid warnings on inf
259+
@pytest.mark.filterwarnings("ignore:invalid value encountered")
260+
@pytest.mark.parametrize(
261+
("a", "b"),
262+
[
263+
(0.0, 0.0),
264+
(1.0, 1.0),
265+
(1.0, 2.0),
266+
(1.0, -1.0),
267+
(100.0, 101.0),
268+
(0, 0),
269+
(1, 1),
270+
(1, 2),
271+
(1, -1),
272+
(1.0 + 1j, 1.0 + 1j),
273+
(1.0 + 1j, 1.0 - 1j),
274+
(float("inf"), float("inf")),
275+
(float("inf"), 100.0),
276+
(float("inf"), float("-inf")),
277+
(float("nan"), float("nan")),
278+
(float("nan"), 0.0),
279+
(0.0, float("nan")),
280+
(1e6, 1e6 + 1), # True - within rtol
281+
(1e6, 1e6 + 100), # False - outside rtol
282+
(1e-6, 1.1e-6), # False - outside atol
283+
(1e-7, 1.1e-7), # True - outside atol
284+
(1e6 + 0j, 1e6 + 1j), # True - within rtol
285+
(1e6 + 0j, 1e6 + 100j), # False - outside rtol
286+
],
287+
)
288+
def test_basic(self, a: float, b: float, xp: ModuleType):
289+
a_xp = xp.asarray(a)
290+
b_xp = xp.asarray(b)
291+
292+
xp_assert_equal(isclose(a_xp, b_xp), xp.asarray(np.isclose(a, b)))
293+
294+
with warnings.catch_warnings():
295+
warnings.simplefilter("ignore")
296+
r_xp = xp.asarray(np.arange(10), dtype=a_xp.dtype)
297+
ar_xp = a_xp * r_xp
298+
br_xp = b_xp * r_xp
299+
ar_np = a * np.arange(10)
300+
br_np = b * np.arange(10)
301+
302+
xp_assert_equal(isclose(ar_xp, br_xp), xp.asarray(np.isclose(ar_np, br_np)))
303+
304+
@pytest.mark.parametrize("dtype", ["float32", "int32"])
305+
def test_broadcast(self, dtype: str, xp: ModuleType):
306+
dtype = getattr(xp, dtype)
307+
a = xp.asarray([1, 2, 3], dtype=dtype)
308+
b = xp.asarray([[1], [5]], dtype=dtype)
309+
actual = isclose(a, b)
310+
expect = xp.asarray(
311+
[[True, False, False], [False, False, False]], dtype=xp.bool
312+
)
313+
314+
xp_assert_equal(actual, expect)
315+
316+
# FIXME use lazywhere to avoid warnings on inf
317+
@pytest.mark.filterwarnings("ignore:invalid value encountered")
318+
def test_some_inf(self, xp: ModuleType):
319+
a = xp.asarray([0.0, 1.0, float("inf"), float("inf"), float("inf")])
320+
b = xp.asarray([1e-9, 1.0, float("inf"), float("-inf"), 2.0])
321+
actual = isclose(a, b)
322+
xp_assert_equal(actual, xp.asarray([True, True, True, False, False]))
323+
324+
def test_equal_nan(self, xp: ModuleType):
325+
a = xp.asarray([float("nan"), float("nan"), 1.0])
326+
b = xp.asarray([float("nan"), 1.0, float("nan")])
327+
xp_assert_equal(isclose(a, b), xp.asarray([False, False, False]))
328+
xp_assert_equal(isclose(a, b, equal_nan=True), xp.asarray([True, False, False]))
329+
330+
@pytest.mark.parametrize("dtype", ["float32", "complex64", "int32"])
331+
def test_tolerance(self, dtype: str, xp: ModuleType):
332+
dtype = getattr(xp, dtype)
333+
a = xp.asarray([100, 100], dtype=dtype)
334+
b = xp.asarray([101, 102], dtype=dtype)
335+
xp_assert_equal(isclose(a, b), xp.asarray([False, False]))
336+
xp_assert_equal(isclose(a, b, atol=1), xp.asarray([True, False]))
337+
xp_assert_equal(isclose(a, b, rtol=0.01), xp.asarray([True, False]))
338+
339+
# Attempt to trigger division by 0 in rtol on int dtype
340+
xp_assert_equal(isclose(a, b, rtol=0), xp.asarray([False, False]))
341+
xp_assert_equal(isclose(a, b, atol=1, rtol=0), xp.asarray([True, False]))
342+
343+
def test_very_small_numbers(self, xp: ModuleType):
344+
a = xp.asarray([1e-9, 1e-9])
345+
b = xp.asarray([1.0001e-9, 1.00001e-9])
346+
# Difference is below default atol
347+
xp_assert_equal(isclose(a, b), xp.asarray([True, True]))
348+
# Use only rtol
349+
xp_assert_equal(isclose(a, b, atol=0), xp.asarray([False, True]))
350+
xp_assert_equal(isclose(a, b, atol=0, rtol=0), xp.asarray([False, False]))
351+
352+
def test_bool_dtype(self, xp: ModuleType):
353+
a = xp.asarray([False, True, False])
354+
b = xp.asarray([True, True, False])
355+
xp_assert_equal(isclose(a, b), xp.asarray([False, True, True]))
356+
xp_assert_equal(isclose(a, b, atol=1), xp.asarray([True, True, True]))
357+
xp_assert_equal(isclose(a, b, atol=2), xp.asarray([True, True, True]))
358+
xp_assert_equal(isclose(a, b, rtol=1), xp.asarray([True, True, True]))
359+
xp_assert_equal(isclose(a, b, rtol=2), xp.asarray([True, True, True]))
360+
361+
# Test broadcasting
362+
xp_assert_equal(
363+
isclose(a, xp.asarray(True), atol=1), xp.asarray([True, True, True])
364+
)
365+
xp_assert_equal(
366+
isclose(xp.asarray(True), b, atol=1), xp.asarray([True, True, True])
367+
)
368+
369+
def test_none_shape(self, xp: ModuleType):
370+
a = xp.asarray([1, 5, 0])
371+
b = xp.asarray([1, 4, 2])
372+
b = b[a < 5]
373+
a = a[a < 5]
374+
xp_assert_equal(isclose(a, b), xp.asarray([True, False]))
375+
376+
def test_none_shape_bool(self, xp: ModuleType):
377+
a = xp.asarray([True, True, False])
378+
b = xp.asarray([True, False, True])
379+
b = b[a]
380+
a = a[a]
381+
xp_assert_equal(isclose(a, b), xp.asarray([True, False]))
382+
383+
def test_xp(self, xp: ModuleType):
384+
a = xp.asarray([0.0, 0.0])
385+
b = xp.asarray([1e-9, 1e-4])
386+
xp_assert_equal(isclose(a, b, xp=xp), xp.asarray([True, False]))
387+
388+
255389
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no expand_dims")
256390
class TestKron:
257391
def test_basic(self, xp: ModuleType):

0 commit comments

Comments
 (0)