4
4
import cmath
5
5
import math
6
6
import operator
7
+ import builtins
7
8
from copy import copy
8
9
from enum import Enum , auto
9
10
from typing import Callable , List , NamedTuple , Optional , Sequence , TypeVar , Union
@@ -369,6 +370,8 @@ def right_scalar_assert_against_refimpl(
369
370
370
371
See unary_assert_against_refimpl for more information.
371
372
"""
373
+ if expr_template is None :
374
+ expr_template = func_name + "({}, {})={}"
372
375
if left .dtype in dh .complex_dtypes :
373
376
component_filter = copy (filter_ )
374
377
filter_ = lambda s : component_filter (s .real ) and component_filter (s .imag )
@@ -422,7 +425,7 @@ def right_scalar_assert_against_refimpl(
422
425
)
423
426
424
427
425
- # When appropiate , this module tests operators alongside their respective
428
+ # When appropriate , this module tests operators alongside their respective
426
429
# elementwise methods. We do this by parametrizing a generalised test method
427
430
# with every relevant method and operator.
428
431
#
@@ -432,8 +435,8 @@ def right_scalar_assert_against_refimpl(
432
435
# - The argument strategies, which can be used to draw arguments for the test
433
436
# case. They may require additional filtering for certain test cases.
434
437
# - right_is_scalar (binary parameters only), which denotes if the right
435
- # argument is a scalar in a test case. This can be used to appropiately adjust
436
- # draw filtering and test logic.
438
+ # argument is a scalar in a test case. This can be used to appropriately
439
+ # adjust draw filtering and test logic.
437
440
438
441
439
442
func_to_op = {v : k for k , v in dh .op_to_func .items ()}
@@ -475,7 +478,7 @@ def make_unary_params(
475
478
)
476
479
if api_version < min_version :
477
480
marks = pytest .mark .skip (
478
- reason = f"requires ARRAY_API_TESTS_VERSION=> { min_version } "
481
+ reason = f"requires ARRAY_API_TESTS_VERSION >= { min_version } "
479
482
)
480
483
else :
481
484
marks = ()
@@ -924,15 +927,125 @@ def test_ceil(x):
924
927
925
928
926
929
@pytest .mark .min_version ("2023.12" )
927
- @given (hh .arrays (dtype = hh .real_floating_dtypes , shape = hh .shapes ()))
928
- def test_clip (x ):
930
+ @given (x = hh .arrays (dtype = hh .real_dtypes , shape = hh .shapes ()), data = st . data ( ))
931
+ def test_clip (x , data ):
929
932
# TODO: test min/max kwargs, adjust values testing accordingly
930
- out = xp .clip (x )
931
- ph .assert_dtype ("clip" , in_dtype = x .dtype , out_dtype = out .dtype )
932
- ph .assert_shape ("clip" , out_shape = out .shape , expected = x .shape )
933
- ph .assert_array_elements ("clip" , out = out , expected = x )
934
933
934
+ # Ensure that if both min and max are arrays that all three of x, min, max
935
+ # are broadcast compatible.
936
+ shape1 , shape2 = data .draw (hh .mutually_broadcastable_shapes (2 ,
937
+ base_shape = x .shape ),
938
+ label = "min.shape, max.shape" )
939
+
940
+ dtypes = hh .real_floating_dtypes if dh .is_float_dtype (x .dtype ) else hh .int_dtypes
941
+
942
+ min = data .draw (st .one_of (
943
+ st .none (),
944
+ hh .scalars (dtypes = st .just (x .dtype )),
945
+ hh .arrays (dtype = dtypes , shape = shape1 ),
946
+ ), label = "min" )
947
+ max = data .draw (st .one_of (
948
+ st .none (),
949
+ hh .scalars (dtypes = st .just (x .dtype )),
950
+ hh .arrays (dtype = dtypes , shape = shape2 ),
951
+ ), label = "max" )
952
+
953
+ # min > max is undefined (but allow nans)
954
+ assume (min is None or max is None or not xp .any (xp .asarray (min ) > xp .asarray (max )))
955
+
956
+ kw = data .draw (
957
+ hh .specified_kwargs (
958
+ ("min" , min , None ),
959
+ ("max" , max , None )),
960
+ label = "kwargs" )
961
+
962
+ out = xp .clip (x , ** kw )
963
+
964
+ # min and max do not participate in type promotion
965
+ ph .assert_dtype ("clip" , in_dtype = x .dtype , out_dtype = out .dtype )
935
966
967
+ shapes = [x .shape ]
968
+ if min is not None and not dh .is_scalar (min ):
969
+ shapes .append (min .shape )
970
+ if max is not None and not dh .is_scalar (max ):
971
+ shapes .append (max .shape )
972
+ expected_shape = sh .broadcast_shapes (* shapes )
973
+ ph .assert_shape ("clip" , out_shape = out .shape , expected = expected_shape )
974
+
975
+ if min is max is None :
976
+ ph .assert_array_elements ("clip" , out = out , expected = x )
977
+ elif max is None :
978
+ # If one operand is nan, the result is nan. See
979
+ # https://github.com/data-apis/array-api/pull/813.
980
+ def refimpl (_x , _min ):
981
+ if math .isnan (_x ) or math .isnan (_min ):
982
+ return math .nan
983
+ return builtins .max (_x , _min )
984
+ if dh .is_scalar (min ):
985
+ right_scalar_assert_against_refimpl (
986
+ "clip" , x , min , out , refimpl ,
987
+ left_sym = "x" ,
988
+ expr_template = "clip({}, min={})" ,
989
+ )
990
+ else :
991
+ binary_assert_against_refimpl (
992
+ "clip" , x , min , out , refimpl ,
993
+ left_sym = "x" , right_sym = "min" ,
994
+ expr_template = "clip({}, min={})" ,
995
+ )
996
+ elif min is None :
997
+ def refimpl (_x , _max ):
998
+ if math .isnan (_x ) or math .isnan (_max ):
999
+ return math .nan
1000
+ return builtins .min (_x , _max )
1001
+ if dh .is_scalar (max ):
1002
+ right_scalar_assert_against_refimpl (
1003
+ "clip" , x , max , out , refimpl ,
1004
+ left_sym = "x" ,
1005
+ expr_template = "clip({}, max={})" ,
1006
+ )
1007
+ else :
1008
+ binary_assert_against_refimpl (
1009
+ "clip" , x , max , out , refimpl ,
1010
+ left_sym = "x" , right_sym = "max" ,
1011
+ expr_template = "clip({}, max={})" ,
1012
+ )
1013
+ else :
1014
+ def refimpl (_x , _min , _max ):
1015
+ if math .isnan (_x ) or math .isnan (_min ) or math .isnan (_max ):
1016
+ return math .nan
1017
+ return builtins .min (builtins .max (_x , _min ), _max )
1018
+
1019
+ # This is based on right_scalar_assert_against_refimpl and
1020
+ # binary_assert_against_refimpl. clip() is currently the only ternary
1021
+ # elementwise function and the only function that supports arrays and
1022
+ # scalars. However, where() (in test_searching_functions) is similar
1023
+ # and if scalar support is added to it, we may want to factor out and
1024
+ # reuse this logic.
1025
+
1026
+ stype = dh .get_scalar_type (x .dtype )
1027
+ min_shape = () if dh .is_scalar (min ) else min .shape
1028
+ max_shape = () if dh .is_scalar (max ) else max .shape
1029
+
1030
+ for x_idx , min_idx , max_idx , o_idx in sh .iter_indices (
1031
+ x .shape , min_shape , max_shape , out .shape ):
1032
+ x_val = stype (x [x_idx ])
1033
+ min_val = min if dh .is_scalar (min ) else min [min_idx ]
1034
+ min_val = stype (min_val )
1035
+ max_val = max if dh .is_scalar (max ) else max [max_idx ]
1036
+ max_val = stype (max_val )
1037
+ expected = refimpl (x_val , min_val , max_val )
1038
+ out_val = stype (out [o_idx ])
1039
+ if math .isnan (expected ):
1040
+ assert math .isnan (out_val ), (
1041
+ f"out[{ o_idx } ]={ out [o_idx ]} but should be nan [clip()]\n "
1042
+ f"x[{ x_idx } ]={ x_val } , min[{ min_idx } ]={ min_val } , max[{ max_idx } ]={ max_val } "
1043
+ )
1044
+ else :
1045
+ assert out_val == expected , (
1046
+ f"out[{ o_idx } ]={ out [o_idx ]} but should be { expected } [clip()]\n "
1047
+ f"x[{ x_idx } ]={ x_val } , min[{ min_idx } ]={ min_val } , max[{ max_idx } ]={ max_val } "
1048
+ )
936
1049
if api_version >= "2022.12" :
937
1050
938
1051
@given (hh .arrays (dtype = hh .complex_dtypes , shape = hh .shapes ()))
0 commit comments