@@ -208,6 +208,7 @@ def test_isdtype(dtype, kind):
208
208
assert out == expected , f"{ out = } , but should be { expected } [isdtype()]"
209
209
210
210
211
+ @pytest .mark .min_version ("2024.12" )
211
212
class TestResultType :
212
213
@given (dtypes = hh .mutually_promotable_dtypes (None ))
213
214
def test_result_type (self , dtypes ):
@@ -230,3 +231,30 @@ def test_arrays_and_dtypes(self, pair, data):
230
231
out = xp .result_type (* a_and_dt )
231
232
ph .assert_dtype ("result_type" , in_dtype = s1 + s2 , out_dtype = out , repr_name = "out" )
232
233
234
+ @given (dtypes = hh .mutually_promotable_dtypes (2 ), data = st .data ())
235
+ def test_with_scalars (self , dtypes , data ):
236
+ out = xp .result_type (* dtypes )
237
+
238
+ if out == xp .bool :
239
+ scalars = [True ]
240
+ elif out in dh .all_int_dtypes :
241
+ scalars = [1 ]
242
+ elif out in dh .real_dtypes :
243
+ scalars = [1 , 1.0 ]
244
+ elif out in dh .numeric_dtypes :
245
+ scalars = [1 , 1.0 , 1j ] # numeric_types - real_types == complex_types
246
+ else :
247
+ raise ValueError (f"unknown dtype { out = } ." )
248
+
249
+ scalar = data .draw (st .sampled_from (scalars ))
250
+ inputs = data .draw (st .permutations (dtypes + (scalar ,)))
251
+
252
+ out_scalar = xp .result_type (* inputs )
253
+ assert out_scalar == out
254
+
255
+ # retry with arrays
256
+ arrays = tuple (xp .empty (1 , dtype = dt ) for dt in dtypes )
257
+ inputs = data .draw (st .permutations (arrays + (scalar ,)))
258
+ out_scalar = xp .result_type (* inputs )
259
+ assert out_scalar == out
260
+
0 commit comments