Skip to content

Commit fa5b9a7

Browse files
committed
Store original special case text in processed case objects
1 parent cfd9601 commit fa5b9a7

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

array_api_tests/test_special_cases.py

+5
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,7 @@ def check_result(result: float) -> bool:
494494
class Case(Protocol):
495495
cond_expr: str
496496
result_expr: str
497+
raw_case: Optional[str]
497498

498499
def cond(self, *args) -> bool:
499500
...
@@ -532,6 +533,7 @@ class UnaryCase(Case):
532533
cond_from_dtype: FromDtypeFunc
533534
cond: UnaryCheck
534535
check_result: UnaryResultCheck
536+
raw_case: Optional[str] = field(default=None)
535537

536538

537539
r_unary_case = re.compile("If ``x_i`` is (.+), the result is (.+)")
@@ -674,6 +676,7 @@ def parse_unary_case_block(case_block: str) -> List[UnaryCase]:
674676
cond_from_dtype=cond_from_dtype,
675677
result_expr=result_expr,
676678
check_result=check_result,
679+
raw_case=case_str,
677680
)
678681
cases.append(case)
679682
else:
@@ -700,6 +703,7 @@ class BinaryCase(Case):
700703
x2_cond_from_dtype: FromDtypeFunc
701704
cond: BinaryCond
702705
check_result: BinaryResultCheck
706+
raw_case: Optional[str] = field(default=None)
703707

704708

705709
r_binary_case = re.compile("If (.+), the result (.+)")
@@ -1058,6 +1062,7 @@ def cond(i1: float, i2: float) -> bool:
10581062
x2_cond_from_dtype=x2_cond_from_dtype,
10591063
result_expr=result_expr,
10601064
check_result=check_result,
1065+
raw_case=case_str,
10611066
)
10621067

10631068

0 commit comments

Comments
 (0)