@@ -929,8 +929,6 @@ def test_ceil(x):
929
929
@pytest .mark .min_version ("2023.12" )
930
930
@given (x = hh .arrays (dtype = hh .real_dtypes , shape = hh .shapes ()), data = st .data ())
931
931
def test_clip (x , data ):
932
- # TODO: test min/max kwargs, adjust values testing accordingly
933
-
934
932
# Ensure that if both min and max are arrays that all three of x, min, max
935
933
# are broadcast compatible.
936
934
shape1 , shape2 = data .draw (hh .mutually_broadcastable_shapes (2 ,
@@ -951,7 +949,7 @@ def test_clip(x, data):
951
949
), label = "max" )
952
950
953
951
# min > max is undefined (but allow nans)
954
- assume (min is None or max is None or not xp .any (xp .asarray (min ) > xp .asarray (max )))
952
+ assume (min is None or max is None or not xp .any (ah . less ( xp .asarray (max ), xp .asarray (min ) )))
955
953
956
954
kw = data .draw (
957
955
hh .specified_kwargs (
@@ -972,80 +970,86 @@ def test_clip(x, data):
972
970
expected_shape = sh .broadcast_shapes (* shapes )
973
971
ph .assert_shape ("clip" , out_shape = out .shape , expected = expected_shape )
974
972
975
- if min is max is None :
976
- ph .assert_array_elements ("clip" , out = out , expected = x )
977
- elif max is None :
978
- # If one operand is nan, the result is nan. See
979
- # https://github.com/data-apis/array-api/pull/813.
980
- def refimpl (_x , _min ):
981
- if math .isnan (_x ) or math .isnan (_min ):
982
- return math .nan
973
+ # This is based on right_scalar_assert_against_refimpl and
974
+ # binary_assert_against_refimpl. clip() is currently the only ternary
975
+ # elementwise function and the only function that supports arrays and
976
+ # scalars. However, where() (in test_searching_functions) is similar
977
+ # and if scalar support is added to it, we may want to factor out and
978
+ # reuse this logic.
979
+
980
+ def refimpl (_x , _min , _max ):
981
+ # Skip cases where _min and _max are integers whose values do not
982
+ # fit in the dtype of _x, since this behavior is unspecified.
983
+ if dh .is_int_dtype (x .dtype ):
984
+ if _min is not None and _min not in dh .dtype_ranges [x .dtype ]:
985
+ return None
986
+ if _max is not None and _max not in dh .dtype_ranges [x .dtype ]:
987
+ return None
988
+
989
+ # If min or max are float64 and x is float32, they will need to be
990
+ # downcast to float32. This could result in a round in the wrong
991
+ # direction meaning the resulting clipped value might not actually be
992
+ # between min and max. This behavior is unspecified, so skip any cases
993
+ # where x is within the rounding error of downcasting min or max.
994
+ if x .dtype == xp .float32 :
995
+ if min is not None and not dh .is_scalar (min ) and min .dtype == xp .float64 and math .isfinite (_min ):
996
+ _min_float32 = float (xp .asarray (_min , dtype = xp .float32 ))
997
+ if math .isinf (_min_float32 ):
998
+ return None
999
+ tol = abs (_min - _min_float32 )
1000
+ if math .isclose (_min , _min_float32 , abs_tol = tol ):
1001
+ return None
1002
+ if max is not None and not dh .is_scalar (max ) and max .dtype == xp .float64 and math .isfinite (_max ):
1003
+ _max_float32 = float (xp .asarray (_max , dtype = xp .float32 ))
1004
+ if math .isinf (_max_float32 ):
1005
+ return None
1006
+ tol = abs (_max - _max_float32 )
1007
+ if math .isclose (_max , _max_float32 , abs_tol = tol ):
1008
+ return None
1009
+
1010
+ if (math .isnan (_x )
1011
+ or (_min is not None and math .isnan (_min ))
1012
+ or (_max is not None and math .isnan (_max ))):
1013
+ return math .nan
1014
+ if _min is _max is None :
1015
+ return _x
1016
+ if _max is None :
983
1017
return builtins .max (_x , _min )
984
- if dh .is_scalar (min ):
985
- right_scalar_assert_against_refimpl (
986
- "clip" , x , min , out , refimpl ,
987
- left_sym = "x" ,
988
- expr_template = "clip({}, min={})" ,
989
- )
990
- else :
991
- binary_assert_against_refimpl (
992
- "clip" , x , min , out , refimpl ,
993
- left_sym = "x" , right_sym = "min" ,
994
- expr_template = "clip({}, min={})" ,
995
- )
996
- elif min is None :
997
- def refimpl (_x , _max ):
998
- if math .isnan (_x ) or math .isnan (_max ):
999
- return math .nan
1018
+ if _min is None :
1000
1019
return builtins .min (_x , _max )
1001
- if dh .is_scalar (max ):
1002
- right_scalar_assert_against_refimpl (
1003
- "clip" , x , max , out , refimpl ,
1004
- left_sym = "x" ,
1005
- expr_template = "clip({}, max={})" ,
1020
+ return builtins .min (builtins .max (_x , _min ), _max )
1021
+
1022
+ stype = dh .get_scalar_type (x .dtype )
1023
+ min_shape = () if min is None or dh .is_scalar (min ) else min .shape
1024
+ max_shape = () if max is None or dh .is_scalar (max ) else max .shape
1025
+
1026
+ for x_idx , min_idx , max_idx , o_idx in sh .iter_indices (
1027
+ x .shape , min_shape , max_shape , out .shape ):
1028
+ x_val = stype (x [x_idx ])
1029
+ if min is None or dh .is_scalar (min ):
1030
+ min_val = min
1031
+ else :
1032
+ min_val = stype (min [min_idx ])
1033
+ if max is None or dh .is_scalar (max ):
1034
+ max_val = max
1035
+ else :
1036
+ max_val = stype (max [max_idx ])
1037
+ expected = refimpl (x_val , min_val , max_val )
1038
+ if expected is None :
1039
+ continue
1040
+ out_val = stype (out [o_idx ])
1041
+ if math .isnan (expected ):
1042
+ assert math .isnan (out_val ), (
1043
+ f"out[{ o_idx } ]={ out [o_idx ]} but should be nan [clip()]\n "
1044
+ f"x[{ x_idx } ]={ x_val } , min[{ min_idx } ]={ min_val } , max[{ max_idx } ]={ max_val } "
1006
1045
)
1007
1046
else :
1008
- binary_assert_against_refimpl (
1009
- "clip" , x , max , out , refimpl ,
1010
- left_sym = "x" , right_sym = "max" ,
1011
- expr_template = "clip({}, max={})" ,
1047
+ assert out_val == expected , (
1048
+ f"out[{ o_idx } ]={ out [o_idx ]} but should be { expected } [clip()]\n "
1049
+ f"x[{ x_idx } ]={ x_val } , min[{ min_idx } ]={ min_val } , max[{ max_idx } ]={ max_val } "
1012
1050
)
1013
- else :
1014
- def refimpl (_x , _min , _max ):
1015
- if math .isnan (_x ) or math .isnan (_min ) or math .isnan (_max ):
1016
- return math .nan
1017
- return builtins .min (builtins .max (_x , _min ), _max )
1018
-
1019
- # This is based on right_scalar_assert_against_refimpl and
1020
- # binary_assert_against_refimpl. clip() is currently the only ternary
1021
- # elementwise function and the only function that supports arrays and
1022
- # scalars. However, where() (in test_searching_functions) is similar
1023
- # and if scalar support is added to it, we may want to factor out and
1024
- # reuse this logic.
1025
-
1026
- stype = dh .get_scalar_type (x .dtype )
1027
- min_shape = () if dh .is_scalar (min ) else min .shape
1028
- max_shape = () if dh .is_scalar (max ) else max .shape
1029
-
1030
- for x_idx , min_idx , max_idx , o_idx in sh .iter_indices (
1031
- x .shape , min_shape , max_shape , out .shape ):
1032
- x_val = stype (x [x_idx ])
1033
- min_val = min if dh .is_scalar (min ) else min [min_idx ]
1034
- min_val = stype (min_val )
1035
- max_val = max if dh .is_scalar (max ) else max [max_idx ]
1036
- max_val = stype (max_val )
1037
- expected = refimpl (x_val , min_val , max_val )
1038
- out_val = stype (out [o_idx ])
1039
- if math .isnan (expected ):
1040
- assert math .isnan (out_val ), (
1041
- f"out[{ o_idx } ]={ out [o_idx ]} but should be nan [clip()]\n "
1042
- f"x[{ x_idx } ]={ x_val } , min[{ min_idx } ]={ min_val } , max[{ max_idx } ]={ max_val } "
1043
- )
1044
- else :
1045
- assert out_val == expected , (
1046
- f"out[{ o_idx } ]={ out [o_idx ]} but should be { expected } [clip()]\n "
1047
- f"x[{ x_idx } ]={ x_val } , min[{ min_idx } ]={ min_val } , max[{ max_idx } ]={ max_val } "
1048
- )
1051
+
1052
+
1049
1053
if api_version >= "2022.12" :
1050
1054
1051
1055
@given (hh .arrays (dtype = hh .complex_dtypes , shape = hh .shapes ()))
@@ -1062,7 +1066,7 @@ def test_copysign(x1, x2):
1062
1066
out = xp .copysign (x1 , x2 )
1063
1067
ph .assert_dtype ("copysign" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = out .dtype )
1064
1068
ph .assert_result_shape ("copysign" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
1065
- # TODO: values testing
1069
+ binary_assert_against_refimpl ( "copysign" , x1 , x2 , out , math . copysign )
1066
1070
1067
1071
1068
1072
@given (hh .arrays (dtype = hh .all_floating_dtypes (), shape = hh .shapes ()))
@@ -1535,7 +1539,8 @@ def test_signbit(x):
1535
1539
out = xp .signbit (x )
1536
1540
ph .assert_dtype ("signbit" , in_dtype = x .dtype , out_dtype = out .dtype , expected = xp .bool )
1537
1541
ph .assert_shape ("signbit" , out_shape = out .shape , expected = x .shape )
1538
- # TODO: values testing
1542
+ refimpl = lambda x : math .copysign (1.0 , x ) < 0
1543
+ unary_assert_against_refimpl ("round" , x , out , refimpl , strict_check = True )
1539
1544
1540
1545
1541
1546
@given (hh .arrays (dtype = hh .numeric_dtypes , shape = hh .shapes (), elements = finite_kw ))
0 commit comments