@@ -336,6 +336,7 @@ def test_xp(self, xp: ModuleType):
336
336
class TestIsClose :
337
337
# FIXME use lazywhere to avoid warnings on inf
338
338
@pytest .mark .filterwarnings ("ignore:invalid value encountered" )
339
+ @pytest .mark .parametrize ("swap" , [False , True ])
339
340
@pytest .mark .parametrize (
340
341
("a" , "b" ),
341
342
[
@@ -353,9 +354,9 @@ class TestIsClose:
353
354
(float ("inf" ), float ("inf" )),
354
355
(float ("inf" ), 100.0 ),
355
356
(float ("inf" ), float ("-inf" )),
357
+ (float ("-inf" ), float ("-inf" )),
356
358
(float ("nan" ), float ("nan" )),
357
- (float ("nan" ), 0.0 ),
358
- (0.0 , float ("nan" )),
359
+ (float ("nan" ), 100.0 ),
359
360
(1e6 , 1e6 + 1 ), # True - within rtol
360
361
(1e6 , 1e6 + 100 ), # False - outside rtol
361
362
(1e-6 , 1.1e-6 ), # False - outside atol
@@ -364,19 +365,20 @@ class TestIsClose:
364
365
(1e6 + 0j , 1e6 + 100j ), # False - outside rtol
365
366
],
366
367
)
367
- def test_basic (self , a : float , b : float , xp : ModuleType ):
368
+ def test_basic (self , a : float , b : float , swap : bool , xp : ModuleType ):
369
+ if swap :
370
+ b , a = a , b
368
371
a_xp = xp .asarray (a )
369
372
b_xp = xp .asarray (b )
370
373
371
374
xp_assert_equal (isclose (a_xp , b_xp ), xp .asarray (np .isclose (a , b )))
372
375
373
376
with warnings .catch_warnings ():
374
377
warnings .simplefilter ("ignore" )
375
- r_xp = xp .asarray (np .arange (10 ), dtype = a_xp .dtype )
376
- ar_xp = a_xp * r_xp
377
- br_xp = b_xp * r_xp
378
378
ar_np = a * np .arange (10 )
379
379
br_np = b * np .arange (10 )
380
+ ar_xp = xp .asarray (ar_np )
381
+ br_xp = xp .asarray (br_np )
380
382
381
383
xp_assert_equal (isclose (ar_xp , br_xp ), xp .asarray (np .isclose (ar_np , br_np )))
382
384
@@ -395,14 +397,14 @@ def test_broadcast(self, dtype: str, xp: ModuleType):
395
397
# FIXME use lazywhere to avoid warnings on inf
396
398
@pytest .mark .filterwarnings ("ignore:invalid value encountered" )
397
399
def test_some_inf (self , xp : ModuleType ):
398
- a = xp .asarray ([0.0 , 1.0 , float ( " inf" ), float ( " inf" ), float ( " inf" ) ])
399
- b = xp .asarray ([1e-9 , 1.0 , float ( " inf" ), float ( "- inf" ) , 2.0 ])
400
+ a = xp .asarray ([0.0 , 1.0 , xp . inf , xp . inf , xp . inf ])
401
+ b = xp .asarray ([1e-9 , 1.0 , xp . inf , - xp . inf , 2.0 ])
400
402
actual = isclose (a , b )
401
403
xp_assert_equal (actual , xp .asarray ([True , True , True , False , False ]))
402
404
403
405
def test_equal_nan (self , xp : ModuleType ):
404
- a = xp .asarray ([float ( " nan" ), float ( " nan" ) , 1.0 ])
405
- b = xp .asarray ([float ( " nan" ) , 1.0 , float ( " nan" ) ])
406
+ a = xp .asarray ([xp . nan , xp . nan , 1.0 ])
407
+ b = xp .asarray ([xp . nan , 1.0 , xp . nan ])
406
408
xp_assert_equal (isclose (a , b ), xp .asarray ([False , False , False ]))
407
409
xp_assert_equal (isclose (a , b , equal_nan = True ), xp .asarray ([True , False , False ]))
408
410
0 commit comments