@@ -1010,22 +1010,28 @@ def partial_cond(i1: float, i2: float) -> bool:
1010
1010
st .sampled_from ([(True , False ), (False , True ), (True , True )])
1011
1011
)
1012
1012
1013
- def _x1_cond_from_dtype (dtype ) -> st .SearchStrategy [float ]:
1013
+ def _x1_cond_from_dtype (dtype , ** kw ) -> st .SearchStrategy [float ]:
1014
+ assert len (kw ) == 0 # sanity check
1014
1015
return use_x1_or_x2_strat .flatmap (
1015
1016
lambda t : cond_from_dtype (dtype )
1016
1017
if t [0 ]
1017
1018
else xps .from_dtype (dtype )
1018
1019
)
1019
1020
1020
- def _x2_cond_from_dtype (dtype ) -> st .SearchStrategy [float ]:
1021
+ def _x2_cond_from_dtype (dtype , ** kw ) -> st .SearchStrategy [float ]:
1022
+ assert len (kw ) == 0 # sanity check
1021
1023
return use_x1_or_x2_strat .flatmap (
1022
1024
lambda t : cond_from_dtype (dtype )
1023
1025
if t [1 ]
1024
1026
else xps .from_dtype (dtype )
1025
1027
)
1026
1028
1027
- x1_cond_from_dtypes .append (_x1_cond_from_dtype )
1028
- x2_cond_from_dtypes .append (_x2_cond_from_dtype )
1029
+ x1_cond_from_dtypes .append (
1030
+ BoundFromDtype (base_func = _x1_cond_from_dtype )
1031
+ )
1032
+ x2_cond_from_dtypes .append (
1033
+ BoundFromDtype (base_func = _x2_cond_from_dtype )
1034
+ )
1029
1035
1030
1036
partial_conds .append (partial_cond )
1031
1037
partial_exprs .append (partial_expr )
@@ -1050,18 +1056,8 @@ def _x2_cond_from_dtype(dtype) -> st.SearchStrategy[float]:
1050
1056
def cond (i1 : float , i2 : float ) -> bool :
1051
1057
return all (pc (i1 , i2 ) for pc in partial_conds )
1052
1058
1053
- if len (x1_cond_from_dtypes ) == 0 :
1054
- x1_cond_from_dtype = xps .from_dtype
1055
- elif len (x1_cond_from_dtypes ) == 1 :
1056
- x1_cond_from_dtype = x1_cond_from_dtypes [0 ]
1057
- else :
1058
- x1_cond_from_dtype = sum (x1_cond_from_dtypes , start = BoundFromDtype ())
1059
- if len (x2_cond_from_dtypes ) == 0 :
1060
- x2_cond_from_dtype = xps .from_dtype
1061
- elif len (x2_cond_from_dtypes ) == 1 :
1062
- x2_cond_from_dtype = x2_cond_from_dtypes [0 ]
1063
- else :
1064
- x2_cond_from_dtype = sum (x2_cond_from_dtypes , start = BoundFromDtype ())
1059
+ x1_cond_from_dtype = sum (x1_cond_from_dtypes , start = BoundFromDtype ())
1060
+ x2_cond_from_dtype = sum (x2_cond_from_dtypes , start = BoundFromDtype ())
1065
1061
1066
1062
return BinaryCase (
1067
1063
cond_expr = cond_expr ,
0 commit comments