Skip to content

Commit aefe195

Browse files
committed
Split out assert_scalar_equals and assert_scalar_isclose
1 parent 827edd8 commit aefe195

File tree

1 file changed

+37
-3
lines changed

1 file changed

+37
-3
lines changed

array_api_tests/pytest_helpers.py

+37-3
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ def assert_scalar_equals(
397397
kw: dict = {},
398398
):
399399
"""
400-
Assert a 0d array, convered to a scalar, is as expected, e.g.
400+
Assert a 0d array, converted to a scalar, is as expected, e.g.
401401
402402
>>> x = xp.ones(5, dtype=xp.uint8)
403403
>>> out = xp.sum(x)
@@ -407,6 +407,8 @@ def assert_scalar_equals(
407407
408408
>>> assert int(out) == 5
409409
410+
NOTE: This function does *exact* comparison, even for floats. For
411+
approximate float comparisons use assert_scalar_isclose
410412
"""
411413
__tracebackhide__ = True
412414
repr_name = repr_name if idx == () else f"{repr_name}[{idx}]"
@@ -418,8 +420,40 @@ def assert_scalar_equals(
418420
msg = f"{repr_name}={out}, but should be {expected} [{f_func}]"
419421
assert cmath.isnan(out), msg
420422
else:
421-
msg = f"{repr_name}={out}, but should be roughly {expected} [{f_func}]"
422-
assert cmath.isclose(out, expected, rel_tol=0.25, abs_tol=1), msg
423+
msg = f"{repr_name}={out}, but should be {expected} [{f_func}]"
424+
assert out == expected, msg
425+
426+
427+
def assert_scalar_isclose(
428+
func_name: str,
429+
*,
430+
rel_tol: float = 0.25,
431+
abs_tol: float = 1,
432+
type_: ScalarType,
433+
idx: Shape,
434+
out: Scalar,
435+
expected: Scalar,
436+
repr_name: str = "out",
437+
kw: dict = {},
438+
):
439+
"""
440+
Assert a 0d array, converted to a scalar, is close to the expected value, e.g.
441+
442+
>>> x = xp.ones(5., dtype=xp.float64)
443+
>>> out = xp.sum(x)
444+
>>> assert_scalar_isclose('sum', type_int, out=(), out=int(out), expected=5.)
445+
446+
is equivalent to
447+
448+
>>> assert math.isclose(float(out) == 5.)
449+
450+
"""
451+
__tracebackhide__ = True
452+
repr_name = repr_name if idx == () else f"{repr_name}[{idx}]"
453+
f_func = f"{func_name}({fmt_kw(kw)})"
454+
msg = f"{repr_name}={out}, but should be roughly {expected} [{f_func}]"
455+
assert type_ in [float, complex] # Sanity check
456+
assert cmath.isclose(out, expected, rel_tol=0.25, abs_tol=1), msg
423457

424458

425459
def assert_fill(

0 commit comments

Comments
 (0)