@@ -970,80 +970,65 @@ def test_clip(x, data):
970
970
expected_shape = sh .broadcast_shapes (* shapes )
971
971
ph .assert_shape ("clip" , out_shape = out .shape , expected = expected_shape )
972
972
973
- if min is max is None :
974
- ph .assert_array_elements ("clip" , out = out , expected = x )
975
- elif max is None :
976
- # If one operand is nan, the result is nan. See
977
- # https://github.com/data-apis/array-api/pull/813.
978
- def refimpl (_x , _min ):
979
- if math .isnan (_x ) or math .isnan (_min ):
980
- return math .nan
973
+ # This is based on right_scalar_assert_against_refimpl and
974
+ # binary_assert_against_refimpl. clip() is currently the only ternary
975
+ # elementwise function and the only function that supports arrays and
976
+ # scalars. However, where() (in test_searching_functions) is similar
977
+ # and if scalar support is added to it, we may want to factor out and
978
+ # reuse this logic.
979
+
980
+ def refimpl (_x , _min , _max ):
981
+ # Skip cases where _min and _max are integers whose values do not
982
+ # fit in the dtype of _x, since this behavior is unspecified.
983
+ if dh .is_int_dtype (x .dtype ):
984
+ if _min is not None and _min not in dh .dtype_ranges [x .dtype ]:
985
+ return None
986
+ if _max is not None and _max not in dh .dtype_ranges [x .dtype ]:
987
+ return None
988
+
989
+ if (math .isnan (_x )
990
+ or (_min is not None and math .isnan (_min ))
991
+ or (_max is not None and math .isnan (_max ))):
992
+ return math .nan
993
+ if _min is _max is None :
994
+ return _x
995
+ if _max is None :
981
996
return builtins .max (_x , _min )
982
- if dh .is_scalar (min ):
983
- right_scalar_assert_against_refimpl (
984
- "clip" , x , min , out , refimpl ,
985
- left_sym = "x" ,
986
- expr_template = "clip({}, min={})" ,
987
- )
988
- else :
989
- binary_assert_against_refimpl (
990
- "clip" , x , min , out , refimpl ,
991
- left_sym = "x" , right_sym = "min" ,
992
- expr_template = "clip({}, min={})" ,
993
- )
994
- elif min is None :
995
- def refimpl (_x , _max ):
996
- if math .isnan (_x ) or math .isnan (_max ):
997
- return math .nan
997
+ if _min is None :
998
998
return builtins .min (_x , _max )
999
- if dh .is_scalar (max ):
1000
- right_scalar_assert_against_refimpl (
1001
- "clip" , x , max , out , refimpl ,
1002
- left_sym = "x" ,
1003
- expr_template = "clip({}, max={})" ,
999
+ return builtins .min (builtins .max (_x , _min ), _max )
1000
+
1001
+ stype = dh .get_scalar_type (x .dtype )
1002
+ min_shape = () if min is None or dh .is_scalar (min ) else min .shape
1003
+ max_shape = () if max is None or dh .is_scalar (max ) else max .shape
1004
+
1005
+ for x_idx , min_idx , max_idx , o_idx in sh .iter_indices (
1006
+ x .shape , min_shape , max_shape , out .shape ):
1007
+ x_val = stype (x [x_idx ])
1008
+ if min is None or dh .is_scalar (min ):
1009
+ min_val = min
1010
+ else :
1011
+ min_val = stype (min [min_idx ])
1012
+ if max is None or dh .is_scalar (max ):
1013
+ max_val = max
1014
+ else :
1015
+ max_val = stype (max [max_idx ])
1016
+ expected = refimpl (x_val , min_val , max_val )
1017
+ if expected is None :
1018
+ continue
1019
+ out_val = stype (out [o_idx ])
1020
+ if math .isnan (expected ):
1021
+ assert math .isnan (out_val ), (
1022
+ f"out[{ o_idx } ]={ out [o_idx ]} but should be nan [clip()]\n "
1023
+ f"x[{ x_idx } ]={ x_val } , min[{ min_idx } ]={ min_val } , max[{ max_idx } ]={ max_val } "
1004
1024
)
1005
1025
else :
1006
- binary_assert_against_refimpl (
1007
- "clip" , x , max , out , refimpl ,
1008
- left_sym = "x" , right_sym = "max" ,
1009
- expr_template = "clip({}, max={})" ,
1026
+ assert out_val == expected , (
1027
+ f"out[{ o_idx } ]={ out [o_idx ]} but should be { expected } [clip()]\n "
1028
+ f"x[{ x_idx } ]={ x_val } , min[{ min_idx } ]={ min_val } , max[{ max_idx } ]={ max_val } "
1010
1029
)
1011
- else :
1012
- def refimpl (_x , _min , _max ):
1013
- if math .isnan (_x ) or math .isnan (_min ) or math .isnan (_max ):
1014
- return math .nan
1015
- return builtins .min (builtins .max (_x , _min ), _max )
1016
-
1017
- # This is based on right_scalar_assert_against_refimpl and
1018
- # binary_assert_against_refimpl. clip() is currently the only ternary
1019
- # elementwise function and the only function that supports arrays and
1020
- # scalars. However, where() (in test_searching_functions) is similar
1021
- # and if scalar support is added to it, we may want to factor out and
1022
- # reuse this logic.
1023
-
1024
- stype = dh .get_scalar_type (x .dtype )
1025
- min_shape = () if dh .is_scalar (min ) else min .shape
1026
- max_shape = () if dh .is_scalar (max ) else max .shape
1027
-
1028
- for x_idx , min_idx , max_idx , o_idx in sh .iter_indices (
1029
- x .shape , min_shape , max_shape , out .shape ):
1030
- x_val = stype (x [x_idx ])
1031
- min_val = min if dh .is_scalar (min ) else min [min_idx ]
1032
- min_val = stype (min_val )
1033
- max_val = max if dh .is_scalar (max ) else max [max_idx ]
1034
- max_val = stype (max_val )
1035
- expected = refimpl (x_val , min_val , max_val )
1036
- out_val = stype (out [o_idx ])
1037
- if math .isnan (expected ):
1038
- assert math .isnan (out_val ), (
1039
- f"out[{ o_idx } ]={ out [o_idx ]} but should be nan [clip()]\n "
1040
- f"x[{ x_idx } ]={ x_val } , min[{ min_idx } ]={ min_val } , max[{ max_idx } ]={ max_val } "
1041
- )
1042
- else :
1043
- assert out_val == expected , (
1044
- f"out[{ o_idx } ]={ out [o_idx ]} but should be { expected } [clip()]\n "
1045
- f"x[{ x_idx } ]={ x_val } , min[{ min_idx } ]={ min_val } , max[{ max_idx } ]={ max_val } "
1046
- )
1030
+
1031
+
1047
1032
if api_version >= "2022.12" :
1048
1033
1049
1034
@given (hh .arrays (dtype = hh .complex_dtypes , shape = hh .shapes ()))
0 commit comments