@@ -986,6 +986,27 @@ def refimpl(_x, _min, _max):
986
986
if _max is not None and _max not in dh .dtype_ranges [x .dtype ]:
987
987
return None
988
988
989
+ # If min or max are float64 and x is float32, they will need to be
990
+ # downcast to float32. This could result in a round in the wrong
991
+ # direction meaning the resulting clipped value might not actually be
992
+ # between min and max. This behavior is unspecified, so skip any cases
993
+ # where x is within the rounding error of downcasting min or max.
994
+ if x .dtype == xp .float32 :
995
+ if min is not None and not dh .is_scalar (min ) and min .dtype == xp .float64 and math .isfinite (_min ):
996
+ _min_float32 = float (xp .asarray (_min , dtype = xp .float32 ))
997
+ if math .isinf (_min_float32 ):
998
+ return None
999
+ tol = abs (_min - _min_float32 )
1000
+ if math .isclose (_min , _min_float32 , abs_tol = tol ):
1001
+ return None
1002
+ if max is not None and not dh .is_scalar (max ) and max .dtype == xp .float64 and math .isfinite (_max ):
1003
+ _max_float32 = float (xp .asarray (_max , dtype = xp .float32 ))
1004
+ if math .isinf (_max_float32 ):
1005
+ return None
1006
+ tol = abs (_max - _max_float32 )
1007
+ if math .isclose (_max , _max_float32 , abs_tol = tol ):
1008
+ return None
1009
+
989
1010
if (math .isnan (_x )
990
1011
or (_min is not None and math .isnan (_min ))
991
1012
or (_max is not None and math .isnan (_max ))):
0 commit comments