@@ -494,6 +494,7 @@ def check_result(result: float) -> bool:
494
494
class Case (Protocol ):
495
495
cond_expr : str
496
496
result_expr : str
497
+ raw_case : Optional [str ]
497
498
498
499
def cond (self , * args ) -> bool :
499
500
...
@@ -532,6 +533,7 @@ class UnaryCase(Case):
532
533
cond_from_dtype : FromDtypeFunc
533
534
cond : UnaryCheck
534
535
check_result : UnaryResultCheck
536
+ raw_case : Optional [str ] = field (default = None )
535
537
536
538
537
539
r_unary_case = re .compile ("If ``x_i`` is (.+), the result is (.+)" )
@@ -674,6 +676,7 @@ def parse_unary_case_block(case_block: str) -> List[UnaryCase]:
674
676
cond_from_dtype = cond_from_dtype ,
675
677
result_expr = result_expr ,
676
678
check_result = check_result ,
679
+ raw_case = case_str ,
677
680
)
678
681
cases .append (case )
679
682
else :
@@ -700,6 +703,7 @@ class BinaryCase(Case):
700
703
x2_cond_from_dtype : FromDtypeFunc
701
704
cond : BinaryCond
702
705
check_result : BinaryResultCheck
706
+ raw_case : Optional [str ] = field (default = None )
703
707
704
708
705
709
r_binary_case = re .compile ("If (.+), the result (.+)" )
@@ -1058,6 +1062,7 @@ def cond(i1: float, i2: float) -> bool:
1058
1062
x2_cond_from_dtype = x2_cond_from_dtype ,
1059
1063
result_expr = result_expr ,
1060
1064
check_result = check_result ,
1065
+ raw_case = case_str ,
1061
1066
)
1062
1067
1063
1068
0 commit comments