Skip to content

Commit 456fc6c

Browse files
committed
Drop outdated x<1/2>_cond_from_dtype logic
1 parent 558ffdc commit 456fc6c

File tree

1 file changed

+12
-16
lines changed

1 file changed

+12
-16
lines changed

array_api_tests/test_special_cases.py

+12-16
Original file line numberDiff line numberDiff line change
@@ -1010,22 +1010,28 @@ def partial_cond(i1: float, i2: float) -> bool:
10101010
st.sampled_from([(True, False), (False, True), (True, True)])
10111011
)
10121012

1013-
def _x1_cond_from_dtype(dtype) -> st.SearchStrategy[float]:
1013+
def _x1_cond_from_dtype(dtype, **kw) -> st.SearchStrategy[float]:
1014+
assert len(kw) == 0 # sanity check
10141015
return use_x1_or_x2_strat.flatmap(
10151016
lambda t: cond_from_dtype(dtype)
10161017
if t[0]
10171018
else xps.from_dtype(dtype)
10181019
)
10191020

1020-
def _x2_cond_from_dtype(dtype) -> st.SearchStrategy[float]:
1021+
def _x2_cond_from_dtype(dtype, **kw) -> st.SearchStrategy[float]:
1022+
assert len(kw) == 0 # sanity check
10211023
return use_x1_or_x2_strat.flatmap(
10221024
lambda t: cond_from_dtype(dtype)
10231025
if t[1]
10241026
else xps.from_dtype(dtype)
10251027
)
10261028

1027-
x1_cond_from_dtypes.append(_x1_cond_from_dtype)
1028-
x2_cond_from_dtypes.append(_x2_cond_from_dtype)
1029+
x1_cond_from_dtypes.append(
1030+
BoundFromDtype(base_func=_x1_cond_from_dtype)
1031+
)
1032+
x2_cond_from_dtypes.append(
1033+
BoundFromDtype(base_func=_x2_cond_from_dtype)
1034+
)
10291035

10301036
partial_conds.append(partial_cond)
10311037
partial_exprs.append(partial_expr)
@@ -1050,18 +1056,8 @@ def _x2_cond_from_dtype(dtype) -> st.SearchStrategy[float]:
10501056
def cond(i1: float, i2: float) -> bool:
10511057
return all(pc(i1, i2) for pc in partial_conds)
10521058

1053-
if len(x1_cond_from_dtypes) == 0:
1054-
x1_cond_from_dtype = xps.from_dtype
1055-
elif len(x1_cond_from_dtypes) == 1:
1056-
x1_cond_from_dtype = x1_cond_from_dtypes[0]
1057-
else:
1058-
x1_cond_from_dtype = sum(x1_cond_from_dtypes, start=BoundFromDtype())
1059-
if len(x2_cond_from_dtypes) == 0:
1060-
x2_cond_from_dtype = xps.from_dtype
1061-
elif len(x2_cond_from_dtypes) == 1:
1062-
x2_cond_from_dtype = x2_cond_from_dtypes[0]
1063-
else:
1064-
x2_cond_from_dtype = sum(x2_cond_from_dtypes, start=BoundFromDtype())
1059+
x1_cond_from_dtype = sum(x1_cond_from_dtypes, start=BoundFromDtype())
1060+
x2_cond_from_dtype = sum(x2_cond_from_dtypes, start=BoundFromDtype())
10651061

10661062
return BinaryCase(
10671063
cond_expr=cond_expr,

0 commit comments

Comments
 (0)