Skip to content

Commit f82fee5

Browse files
committed
Generate and test the even rounding halves case correctly
1 parent 3773a4d commit f82fee5

File tree

1 file changed

+19
-8
lines changed

1 file changed

+19
-8
lines changed

array_api_tests/test_special_cases.py

+19-8
Original file line numberDiff line numberDiff line change
@@ -540,20 +540,31 @@ class UnaryCase(Case):
540540

541541

542542
r_unary_case = re.compile("If ``x_i`` is (.+), the result is (.+)")
543-
r_even_int_round_case = re.compile(
543+
r_even_round_halves_case = re.compile(
544544
"If two integers are equally close to ``x_i``, "
545545
"the result is the even integer closest to ``x_i``"
546546
)
547547

548548

549549
def trailing_halves_from_dtype(dtype: DataType) -> st.SearchStrategy[float]:
550-
m, M = dh.dtype_ranges[dtype]
551-
return st.integers(math.ceil(m) // 2, math.floor(M) // 2).map(lambda n: n * 0.5)
550+
"""
551+
Returns a strategy that generates floats that end with .5 and are within the
552+
bounds of dtype.
553+
"""
554+
# We bound our base integers strategy to a range of values which should be
555+
# able to represent a decimal 5 when .5 is added or subtracted.
556+
if dtype == xp.float32:
557+
abs_max = 10**4
558+
else:
559+
abs_max = 10**16
560+
return st.sampled_from([0.5, -0.5]).flatmap(
561+
lambda half: st.integers(-abs_max, abs_max).map(lambda n: n + half)
562+
)
552563

553564

554-
even_int_round_case = UnaryCase(
555-
cond_expr="i % 0.5 == 0",
556-
cond=lambda i: i % 0.5 == 0,
565+
even_round_halves_case = UnaryCase(
566+
cond_expr="modf(i)[0] == 0.5",
567+
cond=lambda i: math.modf(i)[0] == 0.5,
557568
cond_from_dtype=trailing_halves_from_dtype,
558569
result_expr="Decimal(i).to_integral_exact(ROUND_HALF_EVEN)",
559570
check_result=lambda i, result: (
@@ -645,8 +656,8 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]:
645656
check_result=check_result,
646657
)
647658
cases.append(case)
648-
elif m := r_even_int_round_case.search(case):
649-
cases.append(even_int_round_case)
659+
elif m := r_even_round_halves_case.search(case):
660+
cases.append(even_round_halves_case)
650661
else:
651662
if not r_remaining_case.search(case):
652663
warn(f"case not machine-readable: '{case}'")

0 commit comments

Comments
 (0)