Skip to content

Commit 704f456

Browse files
committed
Ignore unspecified rounding in downcasting in clip()
1 parent 3b773db commit 704f456

File tree

1 file changed

+21
-0
lines changed

1 file changed

+21
-0
lines changed

array_api_tests/test_operators_and_elementwise_functions.py

+21
Original file line numberDiff line numberDiff line change
@@ -986,6 +986,27 @@ def refimpl(_x, _min, _max):
986986
if _max is not None and _max not in dh.dtype_ranges[x.dtype]:
987987
return None
988988

989+
# If min or max are float64 and x is float32, they will need to be
990+
# downcast to float32. This could result in a round in the wrong
991+
# direction meaning the resulting clipped value might not actually be
992+
# between min and max. This behavior is unspecified, so skip any cases
993+
# where x is within the rounding error of downcasting min or max.
994+
if x.dtype == xp.float32:
995+
if min is not None and not dh.is_scalar(min) and min.dtype == xp.float64 and math.isfinite(_min):
996+
_min_float32 = float(xp.asarray(_min, dtype=xp.float32))
997+
if math.isinf(_min_float32):
998+
return None
999+
tol = abs(_min - _min_float32)
1000+
if math.isclose(_min, _min_float32, abs_tol=tol):
1001+
return None
1002+
if max is not None and not dh.is_scalar(max) and max.dtype == xp.float64 and math.isfinite(_max):
1003+
_max_float32 = float(xp.asarray(_max, dtype=xp.float32))
1004+
if math.isinf(_max_float32):
1005+
return None
1006+
tol = abs(_max - _max_float32)
1007+
if math.isclose(_max, _max_float32, abs_tol=tol):
1008+
return None
1009+
9891010
if (math.isnan(_x)
9901011
or (_min is not None and math.isnan(_min))
9911012
or (_max is not None and math.isnan(_max))):

0 commit comments

Comments
 (0)