|
19 | 19 | from dataclasses import dataclass, field
|
20 | 20 | from decimal import ROUND_HALF_EVEN, Decimal
|
21 | 21 | from enum import Enum, auto
|
22 |
| -from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple |
| 22 | +from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Literal |
23 | 23 | from warnings import warn
|
24 | 24 |
|
25 | 25 | import pytest
|
@@ -544,6 +544,10 @@ class UnaryCase(Case):
|
544 | 544 | "If two integers are equally close to ``x_i``, "
|
545 | 545 | "the result is the even integer closest to ``x_i``"
|
546 | 546 | )
|
| 547 | +r_nan_signbit = re.compile( |
| 548 | + "If ``x_i`` is ``NaN`` and the sign bit of ``x_i`` is ``(.+)``, " |
| 549 | + "the result is ``(.+)``" |
| 550 | +) |
547 | 551 |
|
548 | 552 |
|
549 | 553 | def integers_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
|
@@ -599,6 +603,25 @@ def trailing_halves_from_dtype(dtype: DataType) -> st.SearchStrategy[float]:
|
599 | 603 | )
|
600 | 604 |
|
601 | 605 |
|
| 606 | +def make_nan_signbit_case(signbit: Literal[0, 1], expected: bool) -> UnaryCase: |
| 607 | + if signbit: |
| 608 | + sign = -1 |
| 609 | + nan_expr = "-NaN" |
| 610 | + float_arg = "-nan" |
| 611 | + else: |
| 612 | + sign = 1 |
| 613 | + nan_expr = "+NaN" |
| 614 | + float_arg = "nan" |
| 615 | + |
| 616 | + return UnaryCase( |
| 617 | + cond_expr=f"x_i is {nan_expr}", |
| 618 | + cond=lambda i: math.isnan(i) and math.copysign(1, i) == sign, |
| 619 | + cond_from_dtype=lambda _: st.just(float(float_arg)), |
| 620 | + result_expr=str(expected), |
| 621 | + check_result=lambda _, result: result == float(expected), |
| 622 | + ) |
| 623 | + |
| 624 | + |
602 | 625 | def make_unary_check_result(check_just_result: UnaryCheck) -> UnaryResultCheck:
|
603 | 626 | def check_result(i: float, result: float) -> bool:
|
604 | 627 | return check_just_result(result)
|
@@ -655,10 +678,14 @@ def parse_unary_case_block(case_block: str) -> List[UnaryCase]:
|
655 | 678 | cases = []
|
656 | 679 | for case_m in r_case.finditer(case_block):
|
657 | 680 | case_str = case_m.group(1)
|
658 |
| - if m := r_already_int_case.search(case_str): |
| 681 | + if r_already_int_case.search(case_str): |
659 | 682 | cases.append(already_int_case)
|
660 |
| - elif m := r_even_round_halves_case.search(case_str): |
| 683 | + elif r_even_round_halves_case.search(case_str): |
661 | 684 | cases.append(even_round_halves_case)
|
| 685 | + elif m := r_nan_signbit.search(case_str): |
| 686 | + signbit = parse_value(m.group(1)) |
| 687 | + expected = bool(parse_value(m.group(2))) |
| 688 | + cases.append(make_nan_signbit_case(signbit, expected)) |
662 | 689 | elif m := r_unary_case.search(case_str):
|
663 | 690 | try:
|
664 | 691 | cond, cond_expr_template, cond_from_dtype = parse_cond(m.group(1))
|
|
0 commit comments