Skip to content

Commit d53cb39

Browse files
committed
ENH: add a result_type test with python scalars
1 parent af389a8 commit d53cb39

File tree

2 files changed

+28
-4
lines changed

2 files changed

+28
-4
lines changed

Diff for: array_api_tests/hypothesis_helpers.py

-4
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,7 @@
1010
from hypothesis import assume, reject
1111
from hypothesis.strategies import (SearchStrategy, booleans, composite, floats,
1212
integers, complex_numbers, just, lists, none, one_of,
13-
<<<<<<< HEAD
14-
sampled_from, shared, builds, nothing)
15-
=======
1613
sampled_from, shared, builds, nothing, permutations)
17-
>>>>>>> ENH: add a test that result_type does not depend on the order of arguments
1814

1915
from . import _array_module as xp, api_version
2016
from . import array_helpers as ah

Diff for: array_api_tests/test_data_type_functions.py

+28
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ def test_isdtype(dtype, kind):
208208
assert out == expected, f"{out=}, but should be {expected} [isdtype()]"
209209

210210

211+
@pytest.mark.min_version("2024.12")
211212
class TestResultType:
212213
@given(dtypes=hh.mutually_promotable_dtypes(None))
213214
def test_result_type(self, dtypes):
@@ -230,3 +231,30 @@ def test_arrays_and_dtypes(self, pair, data):
230231
out = xp.result_type(*a_and_dt)
231232
ph.assert_dtype("result_type", in_dtype=s1+s2, out_dtype=out, repr_name="out")
232233

234+
@given(dtypes=hh.mutually_promotable_dtypes(2), data=st.data())
235+
def test_with_scalars(self, dtypes, data):
236+
out = xp.result_type(*dtypes)
237+
238+
if out == xp.bool:
239+
scalars = [True]
240+
elif out in dh.all_int_dtypes:
241+
scalars = [1]
242+
elif out in dh.real_dtypes:
243+
scalars = [1, 1.0]
244+
elif out in dh.numeric_dtypes:
245+
scalars = [1, 1.0, 1j] # numeric_types - real_types == complex_types
246+
else:
247+
raise ValueError(f"unknown dtype {out = }.")
248+
249+
scalar = data.draw(st.sampled_from(scalars))
250+
inputs = data.draw(st.permutations(dtypes + (scalar,)))
251+
252+
out_scalar = xp.result_type(*inputs)
253+
assert out_scalar == out
254+
255+
# retry with arrays
256+
arrays = tuple(xp.empty(1, dtype=dt) for dt in dtypes)
257+
inputs = data.draw(st.permutations(arrays + (scalar,)))
258+
out_scalar = xp.result_type(*inputs)
259+
assert out_scalar == out
260+

0 commit comments

Comments
 (0)