Skip to content

Commit 558ffdc

Browse files
committed
Case expression fixes
1 parent f82fee5 commit 558ffdc

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

array_api_tests/test_special_cases.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -364,13 +364,13 @@ def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, BoundFromDtype]:
364364
if m := r_code.match(cond_str):
365365
value = parse_value(m.group(1))
366366
cond = make_strict_eq(value)
367-
expr_template = "{} == " + m.group(1)
367+
expr_template = "{} is " + m.group(1)
368368
from_dtype = wrap_strat_as_from_dtype(st.just(value))
369369
elif m := r_either_code.match(cond_str):
370370
v1 = parse_value(m.group(1))
371371
v2 = parse_value(m.group(2))
372372
cond = make_or(make_strict_eq(v1), make_strict_eq(v2))
373-
expr_template = "({} == " + m.group(1) + " or {} == " + m.group(2) + ")"
373+
expr_template = "({} is " + m.group(1) + " or {} == " + m.group(2) + ")"
374374
from_dtype = wrap_strat_as_from_dtype(st.sampled_from([v1, v2]))
375375
elif m := r_equal_to.match(cond_str):
376376
value = parse_value(m.group(1))
@@ -487,7 +487,7 @@ def check_result(result: float) -> bool:
487487
return True
488488
return math.copysign(1, result) == 1
489489

490-
expr = "+"
490+
expr = "positive sign"
491491
elif "negative" in result_str:
492492

493493
def check_result(result: float) -> bool:
@@ -496,7 +496,7 @@ def check_result(result: float) -> bool:
496496
return True
497497
return math.copysign(1, result) == -1
498498

499-
expr = "-"
499+
expr = "negative sign"
500500
else:
501501
raise ParseError(result_str)
502502

@@ -927,7 +927,7 @@ def _x2_cond_from_dtype(dtype, **kw) -> st.SearchStrategy[float]:
927927
unary_cond, expr_template, cond_from_dtype = parse_cond(m.group(1))
928928
left_expr = expr_template.replace("{}", "x1_i")
929929
right_expr = expr_template.replace("{}", "x2_i")
930-
partial_expr = f"({left_expr}) and ({right_expr})"
930+
partial_expr = f"{left_expr} and {right_expr}"
931931
partial_cond = make_binary_cond( # type: ignore
932932
BinaryCondArg.BOTH, unary_cond
933933
)
@@ -972,7 +972,7 @@ def partial_cond(i1: float, i2: float) -> bool:
972972
elif r_and_input.match(input_str):
973973
left_expr = expr_template.replace("{}", "x1_i")
974974
right_expr = expr_template.replace("{}", "x2_i")
975-
partial_expr = f"({left_expr}) and ({right_expr})"
975+
partial_expr = f"{left_expr} and {right_expr}"
976976
cond_arg = BinaryCondArg.BOTH
977977
elif r_or_input.match(input_str):
978978
left_expr = expr_template.replace("{}", "x1_i")

0 commit comments

Comments
 (0)