Skip to content

Commit 3b773db

Browse files
committed
Combine different clip test cases together and omit out-of-bounds cases
1 parent 9465959 commit 3b773db

File tree

1 file changed

+54
-69
lines changed

1 file changed

+54
-69
lines changed

array_api_tests/test_operators_and_elementwise_functions.py

+54-69
Original file line numberDiff line numberDiff line change
@@ -970,80 +970,65 @@ def test_clip(x, data):
970970
expected_shape = sh.broadcast_shapes(*shapes)
971971
ph.assert_shape("clip", out_shape=out.shape, expected=expected_shape)
972972

973-
if min is max is None:
974-
ph.assert_array_elements("clip", out=out, expected=x)
975-
elif max is None:
976-
# If one operand is nan, the result is nan. See
977-
# https://github.com/data-apis/array-api/pull/813.
978-
def refimpl(_x, _min):
979-
if math.isnan(_x) or math.isnan(_min):
980-
return math.nan
973+
# This is based on right_scalar_assert_against_refimpl and
974+
# binary_assert_against_refimpl. clip() is currently the only ternary
975+
# elementwise function and the only function that supports arrays and
976+
# scalars. However, where() (in test_searching_functions) is similar
977+
# and if scalar support is added to it, we may want to factor out and
978+
# reuse this logic.
979+
980+
def refimpl(_x, _min, _max):
981+
# Skip cases where _min and _max are integers whose values do not
982+
# fit in the dtype of _x, since this behavior is unspecified.
983+
if dh.is_int_dtype(x.dtype):
984+
if _min is not None and _min not in dh.dtype_ranges[x.dtype]:
985+
return None
986+
if _max is not None and _max not in dh.dtype_ranges[x.dtype]:
987+
return None
988+
989+
if (math.isnan(_x)
990+
or (_min is not None and math.isnan(_min))
991+
or (_max is not None and math.isnan(_max))):
992+
return math.nan
993+
if _min is _max is None:
994+
return _x
995+
if _max is None:
981996
return builtins.max(_x, _min)
982-
if dh.is_scalar(min):
983-
right_scalar_assert_against_refimpl(
984-
"clip", x, min, out, refimpl,
985-
left_sym="x",
986-
expr_template="clip({}, min={})",
987-
)
988-
else:
989-
binary_assert_against_refimpl(
990-
"clip", x, min, out, refimpl,
991-
left_sym="x", right_sym="min",
992-
expr_template="clip({}, min={})",
993-
)
994-
elif min is None:
995-
def refimpl(_x, _max):
996-
if math.isnan(_x) or math.isnan(_max):
997-
return math.nan
997+
if _min is None:
998998
return builtins.min(_x, _max)
999-
if dh.is_scalar(max):
1000-
right_scalar_assert_against_refimpl(
1001-
"clip", x, max, out, refimpl,
1002-
left_sym="x",
1003-
expr_template="clip({}, max={})",
999+
return builtins.min(builtins.max(_x, _min), _max)
1000+
1001+
stype = dh.get_scalar_type(x.dtype)
1002+
min_shape = () if min is None or dh.is_scalar(min) else min.shape
1003+
max_shape = () if max is None or dh.is_scalar(max) else max.shape
1004+
1005+
for x_idx, min_idx, max_idx, o_idx in sh.iter_indices(
1006+
x.shape, min_shape, max_shape, out.shape):
1007+
x_val = stype(x[x_idx])
1008+
if min is None or dh.is_scalar(min):
1009+
min_val = min
1010+
else:
1011+
min_val = stype(min[min_idx])
1012+
if max is None or dh.is_scalar(max):
1013+
max_val = max
1014+
else:
1015+
max_val = stype(max[max_idx])
1016+
expected = refimpl(x_val, min_val, max_val)
1017+
if expected is None:
1018+
continue
1019+
out_val = stype(out[o_idx])
1020+
if math.isnan(expected):
1021+
assert math.isnan(out_val), (
1022+
f"out[{o_idx}]={out[o_idx]} but should be nan [clip()]\n"
1023+
f"x[{x_idx}]={x_val}, min[{min_idx}]={min_val}, max[{max_idx}]={max_val}"
10041024
)
10051025
else:
1006-
binary_assert_against_refimpl(
1007-
"clip", x, max, out, refimpl,
1008-
left_sym="x", right_sym="max",
1009-
expr_template="clip({}, max={})",
1026+
assert out_val == expected, (
1027+
f"out[{o_idx}]={out[o_idx]} but should be {expected} [clip()]\n"
1028+
f"x[{x_idx}]={x_val}, min[{min_idx}]={min_val}, max[{max_idx}]={max_val}"
10101029
)
1011-
else:
1012-
def refimpl(_x, _min, _max):
1013-
if math.isnan(_x) or math.isnan(_min) or math.isnan(_max):
1014-
return math.nan
1015-
return builtins.min(builtins.max(_x, _min), _max)
1016-
1017-
# This is based on right_scalar_assert_against_refimpl and
1018-
# binary_assert_against_refimpl. clip() is currently the only ternary
1019-
# elementwise function and the only function that supports arrays and
1020-
# scalars. However, where() (in test_searching_functions) is similar
1021-
# and if scalar support is added to it, we may want to factor out and
1022-
# reuse this logic.
1023-
1024-
stype = dh.get_scalar_type(x.dtype)
1025-
min_shape = () if dh.is_scalar(min) else min.shape
1026-
max_shape = () if dh.is_scalar(max) else max.shape
1027-
1028-
for x_idx, min_idx, max_idx, o_idx in sh.iter_indices(
1029-
x.shape, min_shape, max_shape, out.shape):
1030-
x_val = stype(x[x_idx])
1031-
min_val = min if dh.is_scalar(min) else min[min_idx]
1032-
min_val = stype(min_val)
1033-
max_val = max if dh.is_scalar(max) else max[max_idx]
1034-
max_val = stype(max_val)
1035-
expected = refimpl(x_val, min_val, max_val)
1036-
out_val = stype(out[o_idx])
1037-
if math.isnan(expected):
1038-
assert math.isnan(out_val), (
1039-
f"out[{o_idx}]={out[o_idx]} but should be nan [clip()]\n"
1040-
f"x[{x_idx}]={x_val}, min[{min_idx}]={min_val}, max[{max_idx}]={max_val}"
1041-
)
1042-
else:
1043-
assert out_val == expected, (
1044-
f"out[{o_idx}]={out[o_idx]} but should be {expected} [clip()]\n"
1045-
f"x[{x_idx}]={x_val}, min[{min_idx}]={min_val}, max[{max_idx}]={max_val}"
1046-
)
1030+
1031+
10471032
if api_version >= "2022.12":
10481033

10491034
@given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes()))

0 commit comments

Comments
 (0)