Skip to content

Commit 5aace72

Browse files
committed
Add missing bound from dtypes for sign special cases
1 parent fa5b9a7 commit 5aace72

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

array_api_tests/test_special_cases.py

+6
Original file line numberDiff line numberDiff line change
@@ -941,12 +941,18 @@ def _x2_cond_from_dtype(dtype, **kw) -> st.SearchStrategy[float]:
941941
def partial_cond(i1: float, i2: float) -> bool:
942942
return math.copysign(1, i1) == math.copysign(1, i2)
943943

944+
x1_cond_from_dtypes.append(BoundFromDtype(kwargs={"min_value": 1}))
945+
x2_cond_from_dtypes.append(BoundFromDtype(kwargs={"min_value": 1}))
946+
944947
elif value_str == "different mathematical signs":
945948
partial_expr = "copysign(1, x1_i) != copysign(1, x2_i)"
946949

947950
def partial_cond(i1: float, i2: float) -> bool:
948951
return math.copysign(1, i1) != math.copysign(1, i2)
949952

953+
x1_cond_from_dtypes.append(BoundFromDtype(kwargs={"min_value": 1}))
954+
x2_cond_from_dtypes.append(BoundFromDtype(kwargs={"max_value": -1}))
955+
950956
else:
951957
unary_cond, expr_template, cond_from_dtype = parse_cond(value_str)
952958
# Do not define partial_cond via the def keyword or lambda

0 commit comments

Comments
 (0)