Skip to content

Commit d95c2ab

Browse files
committed
Test special cases for operators
1 parent 456fc6c commit d95c2ab

File tree

1 file changed

+86
-8
lines changed

1 file changed

+86
-8
lines changed

array_api_tests/test_special_cases.py

+86-8
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import inspect
66
import math
7+
import operator
78
import re
89
from dataclasses import dataclass, field
910
from decimal import ROUND_HALF_EVEN, Decimal
@@ -24,6 +25,10 @@
2425
from . import xps
2526
from ._array_module import mod as xp
2627
from .stubs import category_to_funcs
28+
from .test_operators_and_elementwise_functions import (
29+
oneway_broadcastable_shapes,
30+
oneway_promotable_dtypes,
31+
)
2732

2833
pytestmark = pytest.mark.ci
2934

@@ -1138,6 +1143,8 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
11381143

11391144
unary_params = []
11401145
binary_params = []
1146+
iop_params = []
1147+
func_to_op: Dict[str, str] = {v: k for k, v in dh.op_to_func.items()}
11411148
for stub in category_to_funcs["elementwise"]:
11421149
if stub.__doc__ is None:
11431150
warn(f"{stub.__name__}() stub has no docstring")
@@ -1157,20 +1164,39 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
11571164
continue
11581165
if param_names[0] == "x":
11591166
if cases := parse_unary_docstring(stub.__doc__):
1160-
for case in cases:
1161-
id_ = f"{stub.__name__}({case.cond_expr}) -> {case.result_expr}"
1162-
p = pytest.param(stub.__name__, func, case, id=id_)
1163-
unary_params.append(p)
1167+
func_name_to_func = {stub.__name__: func}
1168+
if stub.__name__ in func_to_op.keys():
1169+
op_name = func_to_op[stub.__name__]
1170+
op = getattr(operator, op_name)
1171+
func_name_to_func[op_name] = op
1172+
for func_name, func in func_name_to_func.items():
1173+
for case in cases:
1174+
id_ = f"{func_name}({case.cond_expr}) -> {case.result_expr}"
1175+
p = pytest.param(func_name, func, case, id=id_)
1176+
unary_params.append(p)
11641177
continue
11651178
if len(sig.parameters) == 1:
11661179
warn(f"{func=} has one parameter '{param_names[0]}' which is not named 'x'")
11671180
continue
11681181
if param_names[0] == "x1" and param_names[1] == "x2":
11691182
if cases := parse_binary_docstring(stub.__doc__):
1170-
for case in cases:
1171-
id_ = f"{stub.__name__}({case.cond_expr}) -> {case.result_expr}"
1172-
p = pytest.param(stub.__name__, func, case, id=id_)
1173-
binary_params.append(p)
1183+
func_name_to_func = {stub.__name__: func}
1184+
if stub.__name__ in func_to_op.keys():
1185+
op_name = func_to_op[stub.__name__]
1186+
op = getattr(operator, op_name)
1187+
func_name_to_func[op_name] = op
1188+
# We collect inplaceoperator test cases seperately
1189+
iop_name = "__i" + op_name[2:]
1190+
iop = getattr(operator, iop_name)
1191+
for case in cases:
1192+
id_ = f"{iop_name}({case.cond_expr}) -> {case.result_expr}"
1193+
p = pytest.param(iop_name, iop, case, id=id_)
1194+
iop_params.append(p)
1195+
for func_name, func in func_name_to_func.items():
1196+
for case in cases:
1197+
id_ = f"{func_name}({case.cond_expr}) -> {case.result_expr}"
1198+
p = pytest.param(func_name, func, case, id=id_)
1199+
binary_params.append(p)
11741200
continue
11751201
else:
11761202
warn(
@@ -1264,3 +1290,55 @@ def test_binary(func_name, func, case, x1, x2, data):
12641290
)
12651291
break
12661292
assume(good_example)
1293+
1294+
1295+
@pytest.mark.parametrize("iop_name, iop, case", iop_params)
1296+
@given(
1297+
oneway_dtypes=oneway_promotable_dtypes(dh.float_dtypes),
1298+
oneway_shapes=oneway_broadcastable_shapes(),
1299+
data=st.data(),
1300+
)
1301+
def test_iop(iop_name, iop, case, oneway_dtypes, oneway_shapes, data):
1302+
x1 = data.draw(
1303+
xps.arrays(dtype=oneway_dtypes.result_dtype, shape=oneway_shapes.result_shape),
1304+
label="x1",
1305+
)
1306+
x2 = data.draw(
1307+
xps.arrays(dtype=oneway_dtypes.input_dtype, shape=oneway_shapes.input_shape),
1308+
label="x2",
1309+
)
1310+
1311+
all_indices = list(sh.iter_indices(x1.shape, x2.shape, x1.shape))
1312+
1313+
indices_strat = st.shared(st.sampled_from(all_indices))
1314+
set_x1_idx = data.draw(indices_strat.map(lambda t: t[0]), label="set x1 idx")
1315+
set_x1_value = data.draw(case.x1_cond_from_dtype(x1.dtype), label="set x1 value")
1316+
x1[set_x1_idx] = set_x1_value
1317+
note(f"{x1=}")
1318+
set_x2_idx = data.draw(indices_strat.map(lambda t: t[1]), label="set x2 idx")
1319+
set_x2_value = data.draw(case.x2_cond_from_dtype(x2.dtype), label="set x2 value")
1320+
x2[set_x2_idx] = set_x2_value
1321+
note(f"{x2=}")
1322+
1323+
res = xp.asarray(x1, copy=True)
1324+
iop(res, x2)
1325+
# sanity check
1326+
ph.assert_result_shape(iop_name, [x1.shape, x2.shape], res.shape)
1327+
1328+
good_example = False
1329+
for l_idx, r_idx, o_idx in all_indices:
1330+
l = float(x1[l_idx])
1331+
r = float(x2[r_idx])
1332+
if case.cond(l, r):
1333+
good_example = True
1334+
o = float(res[o_idx])
1335+
f_left = f"{sh.fmt_idx('x1', l_idx)}={l}"
1336+
f_right = f"{sh.fmt_idx('x2', r_idx)}={r}"
1337+
f_out = f"{sh.fmt_idx('out', o_idx)}={o}"
1338+
assert case.check_result(l, r, o), (
1339+
f"{f_out}, but should be {case.result_expr} [{iop_name}()]\n"
1340+
f"condition: {case}\n"
1341+
f"{f_left}, {f_right}"
1342+
)
1343+
break
1344+
assume(good_example)

0 commit comments

Comments
 (0)