@@ -397,7 +397,7 @@ def assert_scalar_equals(
397
397
kw : dict = {},
398
398
):
399
399
"""
400
- Assert a 0d array, convered to a scalar, is as expected, e.g.
400
+ Assert a 0d array, converted to a scalar, is as expected, e.g.
401
401
402
402
>>> x = xp.ones(5, dtype=xp.uint8)
403
403
>>> out = xp.sum(x)
@@ -407,6 +407,8 @@ def assert_scalar_equals(
407
407
408
408
>>> assert int(out) == 5
409
409
410
+ NOTE: This function does *exact* comparison, even for floats. For
411
+ approximate float comparisons use assert_scalar_isclose
410
412
"""
411
413
__tracebackhide__ = True
412
414
repr_name = repr_name if idx == () else f"{ repr_name } [{ idx } ]"
@@ -418,8 +420,40 @@ def assert_scalar_equals(
418
420
msg = f"{ repr_name } ={ out } , but should be { expected } [{ f_func } ]"
419
421
assert cmath .isnan (out ), msg
420
422
else :
421
- msg = f"{ repr_name } ={ out } , but should be roughly { expected } [{ f_func } ]"
422
- assert cmath .isclose (out , expected , rel_tol = 0.25 , abs_tol = 1 ), msg
423
+ msg = f"{ repr_name } ={ out } , but should be { expected } [{ f_func } ]"
424
+ assert out == expected , msg
425
+
426
+
427
+ def assert_scalar_isclose (
428
+ func_name : str ,
429
+ * ,
430
+ rel_tol : float = 0.25 ,
431
+ abs_tol : float = 1 ,
432
+ type_ : ScalarType ,
433
+ idx : Shape ,
434
+ out : Scalar ,
435
+ expected : Scalar ,
436
+ repr_name : str = "out" ,
437
+ kw : dict = {},
438
+ ):
439
+ """
440
+ Assert a 0d array, converted to a scalar, is close to the expected value, e.g.
441
+
442
+ >>> x = xp.ones(5., dtype=xp.float64)
443
+ >>> out = xp.sum(x)
444
+ >>> assert_scalar_isclose('sum', type_int, out=(), out=int(out), expected=5.)
445
+
446
+ is equivalent to
447
+
448
+ >>> assert math.isclose(float(out) == 5.)
449
+
450
+ """
451
+ __tracebackhide__ = True
452
+ repr_name = repr_name if idx == () else f"{ repr_name } [{ idx } ]"
453
+ f_func = f"{ func_name } ({ fmt_kw (kw )} )"
454
+ msg = f"{ repr_name } ={ out } , but should be roughly { expected } [{ f_func } ]"
455
+ assert type_ in [float , complex ] # Sanity check
456
+ assert cmath .isclose (out , expected , rel_tol = 0.25 , abs_tol = 1 ), msg
423
457
424
458
425
459
def assert_fill (
0 commit comments