4
4
5
5
import inspect
6
6
import math
7
+ import operator
7
8
import re
8
9
from dataclasses import dataclass , field
9
10
from decimal import ROUND_HALF_EVEN , Decimal
24
25
from . import xps
25
26
from ._array_module import mod as xp
26
27
from .stubs import category_to_funcs
28
+ from .test_operators_and_elementwise_functions import (
29
+ oneway_broadcastable_shapes ,
30
+ oneway_promotable_dtypes ,
31
+ )
27
32
28
33
pytestmark = pytest .mark .ci
29
34
@@ -1138,6 +1143,8 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
1138
1143
1139
1144
unary_params = []
1140
1145
binary_params = []
1146
+ iop_params = []
1147
+ func_to_op : Dict [str , str ] = {v : k for k , v in dh .op_to_func .items ()}
1141
1148
for stub in category_to_funcs ["elementwise" ]:
1142
1149
if stub .__doc__ is None :
1143
1150
warn (f"{ stub .__name__ } () stub has no docstring" )
@@ -1157,20 +1164,39 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
1157
1164
continue
1158
1165
if param_names [0 ] == "x" :
1159
1166
if cases := parse_unary_docstring (stub .__doc__ ):
1160
- for case in cases :
1161
- id_ = f"{ stub .__name__ } ({ case .cond_expr } ) -> { case .result_expr } "
1162
- p = pytest .param (stub .__name__ , func , case , id = id_ )
1163
- unary_params .append (p )
1167
+ func_name_to_func = {stub .__name__ : func }
1168
+ if stub .__name__ in func_to_op .keys ():
1169
+ op_name = func_to_op [stub .__name__ ]
1170
+ op = getattr (operator , op_name )
1171
+ func_name_to_func [op_name ] = op
1172
+ for func_name , func in func_name_to_func .items ():
1173
+ for case in cases :
1174
+ id_ = f"{ func_name } ({ case .cond_expr } ) -> { case .result_expr } "
1175
+ p = pytest .param (func_name , func , case , id = id_ )
1176
+ unary_params .append (p )
1164
1177
continue
1165
1178
if len (sig .parameters ) == 1 :
1166
1179
warn (f"{ func = } has one parameter '{ param_names [0 ]} ' which is not named 'x'" )
1167
1180
continue
1168
1181
if param_names [0 ] == "x1" and param_names [1 ] == "x2" :
1169
1182
if cases := parse_binary_docstring (stub .__doc__ ):
1170
- for case in cases :
1171
- id_ = f"{ stub .__name__ } ({ case .cond_expr } ) -> { case .result_expr } "
1172
- p = pytest .param (stub .__name__ , func , case , id = id_ )
1173
- binary_params .append (p )
1183
+ func_name_to_func = {stub .__name__ : func }
1184
+ if stub .__name__ in func_to_op .keys ():
1185
+ op_name = func_to_op [stub .__name__ ]
1186
+ op = getattr (operator , op_name )
1187
+ func_name_to_func [op_name ] = op
1188
+ # We collect inplaceoperator test cases seperately
1189
+ iop_name = "__i" + op_name [2 :]
1190
+ iop = getattr (operator , iop_name )
1191
+ for case in cases :
1192
+ id_ = f"{ iop_name } ({ case .cond_expr } ) -> { case .result_expr } "
1193
+ p = pytest .param (iop_name , iop , case , id = id_ )
1194
+ iop_params .append (p )
1195
+ for func_name , func in func_name_to_func .items ():
1196
+ for case in cases :
1197
+ id_ = f"{ func_name } ({ case .cond_expr } ) -> { case .result_expr } "
1198
+ p = pytest .param (func_name , func , case , id = id_ )
1199
+ binary_params .append (p )
1174
1200
continue
1175
1201
else :
1176
1202
warn (
@@ -1264,3 +1290,55 @@ def test_binary(func_name, func, case, x1, x2, data):
1264
1290
)
1265
1291
break
1266
1292
assume (good_example )
1293
+
1294
+
1295
+ @pytest .mark .parametrize ("iop_name, iop, case" , iop_params )
1296
+ @given (
1297
+ oneway_dtypes = oneway_promotable_dtypes (dh .float_dtypes ),
1298
+ oneway_shapes = oneway_broadcastable_shapes (),
1299
+ data = st .data (),
1300
+ )
1301
+ def test_iop (iop_name , iop , case , oneway_dtypes , oneway_shapes , data ):
1302
+ x1 = data .draw (
1303
+ xps .arrays (dtype = oneway_dtypes .result_dtype , shape = oneway_shapes .result_shape ),
1304
+ label = "x1" ,
1305
+ )
1306
+ x2 = data .draw (
1307
+ xps .arrays (dtype = oneway_dtypes .input_dtype , shape = oneway_shapes .input_shape ),
1308
+ label = "x2" ,
1309
+ )
1310
+
1311
+ all_indices = list (sh .iter_indices (x1 .shape , x2 .shape , x1 .shape ))
1312
+
1313
+ indices_strat = st .shared (st .sampled_from (all_indices ))
1314
+ set_x1_idx = data .draw (indices_strat .map (lambda t : t [0 ]), label = "set x1 idx" )
1315
+ set_x1_value = data .draw (case .x1_cond_from_dtype (x1 .dtype ), label = "set x1 value" )
1316
+ x1 [set_x1_idx ] = set_x1_value
1317
+ note (f"{ x1 = } " )
1318
+ set_x2_idx = data .draw (indices_strat .map (lambda t : t [1 ]), label = "set x2 idx" )
1319
+ set_x2_value = data .draw (case .x2_cond_from_dtype (x2 .dtype ), label = "set x2 value" )
1320
+ x2 [set_x2_idx ] = set_x2_value
1321
+ note (f"{ x2 = } " )
1322
+
1323
+ res = xp .asarray (x1 , copy = True )
1324
+ iop (res , x2 )
1325
+ # sanity check
1326
+ ph .assert_result_shape (iop_name , [x1 .shape , x2 .shape ], res .shape )
1327
+
1328
+ good_example = False
1329
+ for l_idx , r_idx , o_idx in all_indices :
1330
+ l = float (x1 [l_idx ])
1331
+ r = float (x2 [r_idx ])
1332
+ if case .cond (l , r ):
1333
+ good_example = True
1334
+ o = float (res [o_idx ])
1335
+ f_left = f"{ sh .fmt_idx ('x1' , l_idx )} ={ l } "
1336
+ f_right = f"{ sh .fmt_idx ('x2' , r_idx )} ={ r } "
1337
+ f_out = f"{ sh .fmt_idx ('out' , o_idx )} ={ o } "
1338
+ assert case .check_result (l , r , o ), (
1339
+ f"{ f_out } , but should be { case .result_expr } [{ iop_name } ()]\n "
1340
+ f"condition: { case } \n "
1341
+ f"{ f_left } , { f_right } "
1342
+ )
1343
+ break
1344
+ assume (good_example )
0 commit comments