Skip to content

Commit 573ed3c

Browse files
authored
Merge pull request #143 from crusaderky/isclose_inf
BUG: `isclose` finite vs. infinite
2 parents 13b9f2f + 116558a commit 573ed3c

File tree

2 files changed

+18
-12
lines changed

2 files changed

+18
-12
lines changed

src/array_api_extra/_lib/_funcs.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -386,8 +386,12 @@ def isclose(
386386
b_inexact = xp.isdtype(b.dtype, ("real floating", "complex floating"))
387387
if a_inexact or b_inexact:
388388
# FIXME: use scipy's lazywhere to suppress warnings on inf
389-
out = xp.abs(a - b) <= (atol + rtol * xp.abs(b))
390-
out = xp.where(xp.isinf(a) & xp.isinf(b), xp.sign(a) == xp.sign(b), out)
389+
out = xp.where(
390+
xp.isinf(a) | xp.isinf(b),
391+
xp.isinf(a) & xp.isinf(b) & (xp.sign(a) == xp.sign(b)),
392+
# Note: inf <= inf is True!
393+
xp.abs(a - b) <= (atol + rtol * xp.abs(b)),
394+
)
391395
if equal_nan:
392396
out = xp.where(xp.isnan(a) & xp.isnan(b), xp.asarray(True), out)
393397
return out

tests/test_funcs.py

+12-10
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,7 @@ def test_xp(self, xp: ModuleType):
336336
class TestIsClose:
337337
# FIXME use lazywhere to avoid warnings on inf
338338
@pytest.mark.filterwarnings("ignore:invalid value encountered")
339+
@pytest.mark.parametrize("swap", [False, True])
339340
@pytest.mark.parametrize(
340341
("a", "b"),
341342
[
@@ -353,9 +354,9 @@ class TestIsClose:
353354
(float("inf"), float("inf")),
354355
(float("inf"), 100.0),
355356
(float("inf"), float("-inf")),
357+
(float("-inf"), float("-inf")),
356358
(float("nan"), float("nan")),
357-
(float("nan"), 0.0),
358-
(0.0, float("nan")),
359+
(float("nan"), 100.0),
359360
(1e6, 1e6 + 1), # True - within rtol
360361
(1e6, 1e6 + 100), # False - outside rtol
361362
(1e-6, 1.1e-6), # False - outside atol
@@ -364,19 +365,20 @@ class TestIsClose:
364365
(1e6 + 0j, 1e6 + 100j), # False - outside rtol
365366
],
366367
)
367-
def test_basic(self, a: float, b: float, xp: ModuleType):
368+
def test_basic(self, a: float, b: float, swap: bool, xp: ModuleType):
369+
if swap:
370+
b, a = a, b
368371
a_xp = xp.asarray(a)
369372
b_xp = xp.asarray(b)
370373

371374
xp_assert_equal(isclose(a_xp, b_xp), xp.asarray(np.isclose(a, b)))
372375

373376
with warnings.catch_warnings():
374377
warnings.simplefilter("ignore")
375-
r_xp = xp.asarray(np.arange(10), dtype=a_xp.dtype)
376-
ar_xp = a_xp * r_xp
377-
br_xp = b_xp * r_xp
378378
ar_np = a * np.arange(10)
379379
br_np = b * np.arange(10)
380+
ar_xp = xp.asarray(ar_np)
381+
br_xp = xp.asarray(br_np)
380382

381383
xp_assert_equal(isclose(ar_xp, br_xp), xp.asarray(np.isclose(ar_np, br_np)))
382384

@@ -395,14 +397,14 @@ def test_broadcast(self, dtype: str, xp: ModuleType):
395397
# FIXME use lazywhere to avoid warnings on inf
396398
@pytest.mark.filterwarnings("ignore:invalid value encountered")
397399
def test_some_inf(self, xp: ModuleType):
398-
a = xp.asarray([0.0, 1.0, float("inf"), float("inf"), float("inf")])
399-
b = xp.asarray([1e-9, 1.0, float("inf"), float("-inf"), 2.0])
400+
a = xp.asarray([0.0, 1.0, xp.inf, xp.inf, xp.inf])
401+
b = xp.asarray([1e-9, 1.0, xp.inf, -xp.inf, 2.0])
400402
actual = isclose(a, b)
401403
xp_assert_equal(actual, xp.asarray([True, True, True, False, False]))
402404

403405
def test_equal_nan(self, xp: ModuleType):
404-
a = xp.asarray([float("nan"), float("nan"), 1.0])
405-
b = xp.asarray([float("nan"), 1.0, float("nan")])
406+
a = xp.asarray([xp.nan, xp.nan, 1.0])
407+
b = xp.asarray([xp.nan, 1.0, xp.nan])
406408
xp_assert_equal(isclose(a, b), xp.asarray([False, False, False]))
407409
xp_assert_equal(isclose(a, b, equal_nan=True), xp.asarray([True, False, False]))
408410

0 commit comments

Comments
 (0)