Skip to content

Commit 0ecc4f0

Browse files
committed
ENH: add a test that result_type does not depend on the order of arguments
1 parent 835a9ca commit 0ecc4f0

File tree

2 files changed

+25
-4
lines changed

2 files changed

+25
-4
lines changed

array_api_tests/hypothesis_helpers.py

+11
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,11 @@
1010
from hypothesis import assume, reject
1111
from hypothesis.strategies import (SearchStrategy, booleans, composite, floats,
1212
integers, complex_numbers, just, lists, none, one_of,
13+
<<<<<<< HEAD
1314
sampled_from, shared, builds, nothing)
15+
=======
16+
sampled_from, shared, builds, nothing, permutations)
17+
>>>>>>> ENH: add a test that result_type does not depend on the order of arguments
1418

1519
from . import _array_module as xp, api_version
1620
from . import array_helpers as ah
@@ -148,6 +152,13 @@ def mutually_promotable_dtypes(
148152
return one_of(strats).map(tuple)
149153

150154

155+
@composite
156+
def pair_of_mutually_promotable_dtypes(draw, max_size=2, *, dtypes=dh.all_dtypes):
157+
sample = draw(mutually_promotable_dtypes( max_size, dtypes=dtypes))
158+
permuted = draw(permutations(sample))
159+
return sample, permuted
160+
161+
151162
class OnewayPromotableDtypes(NamedTuple):
152163
input_dtype: DataType
153164
result_dtype: DataType

array_api_tests/test_data_type_functions.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,17 @@ def test_isdtype(dtype, kind):
208208
assert out == expected, f"{out=}, but should be {expected} [isdtype()]"
209209

210210

211-
@given(hh.mutually_promotable_dtypes(None))
212-
def test_result_type(dtypes):
213-
out = xp.result_type(*dtypes)
214-
ph.assert_dtype("result_type", in_dtype=dtypes, out_dtype=out, repr_name="out")
211+
class TestResultType:
212+
@given(dtypes=hh.mutually_promotable_dtypes(None))
213+
def test_result_type(self, dtypes):
214+
out = xp.result_type(*dtypes)
215+
ph.assert_dtype("result_type", in_dtype=dtypes, out_dtype=out, repr_name="out")
216+
217+
@given(pair=hh.pair_of_mutually_promotable_dtypes(None))
218+
def test_shuffled(self, pair):
219+
"""Test that result_type is insensitive to the order of arguments."""
220+
s1, s2 = pair
221+
out1 = xp.result_type(*s1)
222+
out2 = xp.result_type(*s2)
223+
assert out1 == out2
224+

0 commit comments

Comments
 (0)