Skip to content

Commit 70fb777

Browse files
committed
Fix special case testing signbit on NaNs
1 parent 3cf8ef6 commit 70fb777

File tree

1 file changed

+30
-3
lines changed

1 file changed

+30
-3
lines changed

array_api_tests/test_special_cases.py

+30-3
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from dataclasses import dataclass, field
2020
from decimal import ROUND_HALF_EVEN, Decimal
2121
from enum import Enum, auto
22-
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple
22+
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Literal
2323
from warnings import warn
2424

2525
import pytest
@@ -544,6 +544,10 @@ class UnaryCase(Case):
544544
"If two integers are equally close to ``x_i``, "
545545
"the result is the even integer closest to ``x_i``"
546546
)
547+
r_nan_signbit = re.compile(
548+
"If ``x_i`` is ``NaN`` and the sign bit of ``x_i`` is ``(.+)``, "
549+
"the result is ``(.+)``"
550+
)
547551

548552

549553
def integers_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
@@ -599,6 +603,25 @@ def trailing_halves_from_dtype(dtype: DataType) -> st.SearchStrategy[float]:
599603
)
600604

601605

606+
def make_nan_signbit_case(signbit: Literal[0, 1], expected: bool) -> UnaryCase:
607+
if signbit:
608+
sign = -1
609+
nan_expr = "-NaN"
610+
float_arg = "-nan"
611+
else:
612+
sign = 1
613+
nan_expr = "+NaN"
614+
float_arg = "nan"
615+
616+
return UnaryCase(
617+
cond_expr=f"x_i is {nan_expr}",
618+
cond=lambda i: math.isnan(i) and math.copysign(1, i) == sign,
619+
cond_from_dtype=lambda _: st.just(float(float_arg)),
620+
result_expr=str(expected),
621+
check_result=lambda _, result: result == float(expected),
622+
)
623+
624+
602625
def make_unary_check_result(check_just_result: UnaryCheck) -> UnaryResultCheck:
603626
def check_result(i: float, result: float) -> bool:
604627
return check_just_result(result)
@@ -655,10 +678,14 @@ def parse_unary_case_block(case_block: str) -> List[UnaryCase]:
655678
cases = []
656679
for case_m in r_case.finditer(case_block):
657680
case_str = case_m.group(1)
658-
if m := r_already_int_case.search(case_str):
681+
if r_already_int_case.search(case_str):
659682
cases.append(already_int_case)
660-
elif m := r_even_round_halves_case.search(case_str):
683+
elif r_even_round_halves_case.search(case_str):
661684
cases.append(even_round_halves_case)
685+
elif m := r_nan_signbit.search(case_str):
686+
signbit = parse_value(m.group(1))
687+
expected = bool(parse_value(m.group(2)))
688+
cases.append(make_nan_signbit_case(signbit, expected))
662689
elif m := r_unary_case.search(case_str):
663690
try:
664691
cond, cond_expr_template, cond_from_dtype = parse_cond(m.group(1))

0 commit comments

Comments
 (0)