Skip to content

Commit 3773a4d

Browse files
committed
Document parse_binary_case()
1 parent f5b0975 commit 3773a4d

File tree

1 file changed

+33
-12
lines changed

1 file changed

+33
-12
lines changed

array_api_tests/test_special_cases.py

+33-12
Original file line numberDiff line numberDiff line change
@@ -327,14 +327,14 @@ def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, BoundFromDtype]:
327327
condition, otherwise False.
328328
2. A string template for expressing the condition.
329329
3. A xps.from_dtype()-like function which returns a strategy that generates
330-
elements which meet the condition.
330+
elements that meet the condition.
331331
332332
e.g.
333333
334334
>>> cond, expr_template, from_dtype = parse_cond('greater than ``0``')
335335
>>> cond(42)
336336
True
337-
>>> cond(-128)
337+
>>> cond(-123)
338338
False
339339
>>> expr_template.replace('{}', 'x_i')
340340
'x_i > 0'
@@ -582,8 +582,8 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]:
582582
...
583583
... For floating-point operands,
584584
...
585-
... - If ``x_i`` is ``NaN``, the result is ``NaN``.
586585
... - If ``x_i`` is less than ``0``, the result is ``NaN``.
586+
... - If ``x_i`` is ``NaN``, the result is ``NaN``.
587587
... - If ``x_i`` is ``+0``, the result is ``+0``.
588588
... - If ``x_i`` is ``-0``, the result is ``-0``.
589589
... - If ``x_i`` is ``+infinity``, the result is ``+infinity``.
@@ -602,11 +602,16 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]:
602602
>>> unary_cases = parse_unary_docstring(sqrt.__doc__)
603603
>>> for case in unary_cases:
604604
... print(repr(case))
605-
UnaryCase(x_i == NaN -> NaN)
606-
UnaryCase(x_i < 0 -> NaN)
607-
UnaryCase(x_i == +0 -> +0)
608-
UnaryCase(x_i == -0 -> -0)
609-
UnaryCase(x_i == +infinity -> +infinity)
605+
UnaryCase(<x_i < 0 -> NaN>)
606+
UnaryCase(<x_i == NaN -> NaN>)
607+
UnaryCase(<x_i == +0 -> +0>)
608+
UnaryCase(<x_i == -0 -> -0>)
609+
UnaryCase(<x_i == +infinity -> +infinity>)
610+
>>> lt_0_case = unary_cases[0]
611+
>>> lt_0_case.cond(-123)
612+
True
613+
>>> lt_0_case.check_result(-123, float('nan'))
614+
True
610615
611616
"""
612617

@@ -841,6 +846,22 @@ def integers_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
841846

842847

843848
def parse_binary_case(case_str: str) -> BinaryCase:
849+
"""
850+
Parses a Sphinx-formatted binary case string to return codified binary cases, e.g.
851+
852+
>>> case_str = (
853+
... "If ``x1_i`` is greater than ``0``, ``x1_i`` is a finite number, "
854+
... "and ``x2_i`` is ``+infinity``, the result is ``NaN``."
855+
... )
856+
>>> case = parse_binary_case(case_str)
857+
>>> case
858+
BinaryCase(<x1_i > 0 and isfinite(x1_i) and x2_i == +infinity -> NaN>)
859+
>>> case.cond(42, float('inf'))
860+
True
861+
>>> case.check_result(42, float('inf'), float('nan'))
862+
True
863+
864+
"""
844865
case_m = r_binary_case.match(case_str)
845866
if case_m is None:
846867
raise ParseError(case_str)
@@ -857,12 +878,12 @@ def parse_binary_case(case_str: str) -> BinaryCase:
857878
raise ParseError(cond_str)
858879
partial_expr = f"{in_sign}x{in_no}_i == {other_sign}x{other_no}_i"
859880

860-
input_wrapper = lambda i: -i if other_sign == "-" else noop
861881
# For these scenarios, we want to make sure both array elements
862882
# generate respective to one another by using a shared strategy.
863883
shared_from_dtype = lambda d, **kw: st.shared(
864884
xps.from_dtype(d, **kw), key=cond_str
865885
)
886+
input_wrapper = lambda i: -i if other_sign == "-" else noop
866887
if other_no == "1":
867888

868889
def partial_cond(i1: float, i2: float) -> bool:
@@ -1077,9 +1098,9 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
10771098
>>> binary_cases = parse_binary_docstring(logaddexp.__doc__)
10781099
>>> for case in binary_cases:
10791100
... print(repr(case))
1080-
BinaryCase(x1_i == NaN or x2_i == NaN -> NaN)
1081-
BinaryCase(x1_i == +infinity and not x2_i == NaN -> +infinity)
1082-
BinaryCase(not x1_i == NaN and x2_i == +infinity -> +infinity)
1101+
BinaryCase(<x1_i == NaN or x2_i == NaN -> NaN>)
1102+
BinaryCase(<x1_i == +infinity and not x2_i == NaN -> +infinity>)
1103+
BinaryCase(<not x1_i == NaN and x2_i == +infinity -> +infinity>)
10831104
10841105
"""
10851106

0 commit comments

Comments
 (0)