From 0ecc4f0fd051c33817495ae9cd3431252b53df02 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 16 Mar 2025 17:33:57 +0100 Subject: [PATCH 1/3] ENH: add a test that result_type does not depend on the order of arguments --- array_api_tests/hypothesis_helpers.py | 11 +++++++++++ array_api_tests/test_data_type_functions.py | 18 ++++++++++++++---- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index 3c09a1d5..4cb9b68c 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -10,7 +10,11 @@ from hypothesis import assume, reject from hypothesis.strategies import (SearchStrategy, booleans, composite, floats, integers, complex_numbers, just, lists, none, one_of, +<<<<<<< HEAD sampled_from, shared, builds, nothing) +======= + sampled_from, shared, builds, nothing, permutations) +>>>>>>> ENH: add a test that result_type does not depend on the order of arguments from . import _array_module as xp, api_version from . import array_helpers as ah @@ -148,6 +152,13 @@ def mutually_promotable_dtypes( return one_of(strats).map(tuple) +@composite +def pair_of_mutually_promotable_dtypes(draw, max_size=2, *, dtypes=dh.all_dtypes): + sample = draw(mutually_promotable_dtypes( max_size, dtypes=dtypes)) + permuted = draw(permutations(sample)) + return sample, permuted + + class OnewayPromotableDtypes(NamedTuple): input_dtype: DataType result_dtype: DataType diff --git a/array_api_tests/test_data_type_functions.py b/array_api_tests/test_data_type_functions.py index f9642f31..33261291 100644 --- a/array_api_tests/test_data_type_functions.py +++ b/array_api_tests/test_data_type_functions.py @@ -208,7 +208,17 @@ def test_isdtype(dtype, kind): assert out == expected, f"{out=}, but should be {expected} [isdtype()]" -@given(hh.mutually_promotable_dtypes(None)) -def test_result_type(dtypes): - out = xp.result_type(*dtypes) - ph.assert_dtype("result_type", in_dtype=dtypes, out_dtype=out, repr_name="out") +class TestResultType: + @given(dtypes=hh.mutually_promotable_dtypes(None)) + def test_result_type(self, dtypes): + out = xp.result_type(*dtypes) + ph.assert_dtype("result_type", in_dtype=dtypes, out_dtype=out, repr_name="out") + + @given(pair=hh.pair_of_mutually_promotable_dtypes(None)) + def test_shuffled(self, pair): + """Test that result_type is insensitive to the order of arguments.""" + s1, s2 = pair + out1 = xp.result_type(*s1) + out2 = xp.result_type(*s2) + assert out1 == out2 + From af389a89bb48462cc529ae4c9840e5d7d1fb2738 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 17 Mar 2025 13:10:21 +0100 Subject: [PATCH 2/3] ENH: add a result_type test with a mix of arrays and dtypes --- array_api_tests/hypothesis_helpers.py | 2 +- array_api_tests/test_data_type_functions.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index 4cb9b68c..0e0edfae 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -156,7 +156,7 @@ def mutually_promotable_dtypes( def pair_of_mutually_promotable_dtypes(draw, max_size=2, *, dtypes=dh.all_dtypes): sample = draw(mutually_promotable_dtypes( max_size, dtypes=dtypes)) permuted = draw(permutations(sample)) - return sample, permuted + return sample, tuple(permuted) class OnewayPromotableDtypes(NamedTuple): diff --git a/array_api_tests/test_data_type_functions.py b/array_api_tests/test_data_type_functions.py index 33261291..78306b03 100644 --- a/array_api_tests/test_data_type_functions.py +++ b/array_api_tests/test_data_type_functions.py @@ -222,3 +222,11 @@ def test_shuffled(self, pair): out2 = xp.result_type(*s2) assert out1 == out2 + @given(pair=hh.pair_of_mutually_promotable_dtypes(2), data=st.data()) + def test_arrays_and_dtypes(self, pair, data): + s1, s2 = pair + a2 = tuple(xp.empty(1, dtype=dt) for dt in s2) + a_and_dt = data.draw(st.permutations(s1 + a2)) + out = xp.result_type(*a_and_dt) + ph.assert_dtype("result_type", in_dtype=s1+s2, out_dtype=out, repr_name="out") + From d53cb3940e42d936dc064ed4462abe774931971e Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 17 Mar 2025 15:55:10 +0100 Subject: [PATCH 3/3] ENH: add a result_type test with python scalars --- array_api_tests/hypothesis_helpers.py | 4 --- array_api_tests/test_data_type_functions.py | 28 +++++++++++++++++++++ 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index 0e0edfae..7e7294f8 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -10,11 +10,7 @@ from hypothesis import assume, reject from hypothesis.strategies import (SearchStrategy, booleans, composite, floats, integers, complex_numbers, just, lists, none, one_of, -<<<<<<< HEAD - sampled_from, shared, builds, nothing) -======= sampled_from, shared, builds, nothing, permutations) ->>>>>>> ENH: add a test that result_type does not depend on the order of arguments from . import _array_module as xp, api_version from . import array_helpers as ah diff --git a/array_api_tests/test_data_type_functions.py b/array_api_tests/test_data_type_functions.py index 78306b03..e844c432 100644 --- a/array_api_tests/test_data_type_functions.py +++ b/array_api_tests/test_data_type_functions.py @@ -208,6 +208,7 @@ def test_isdtype(dtype, kind): assert out == expected, f"{out=}, but should be {expected} [isdtype()]" +@pytest.mark.min_version("2024.12") class TestResultType: @given(dtypes=hh.mutually_promotable_dtypes(None)) def test_result_type(self, dtypes): @@ -230,3 +231,30 @@ def test_arrays_and_dtypes(self, pair, data): out = xp.result_type(*a_and_dt) ph.assert_dtype("result_type", in_dtype=s1+s2, out_dtype=out, repr_name="out") + @given(dtypes=hh.mutually_promotable_dtypes(2), data=st.data()) + def test_with_scalars(self, dtypes, data): + out = xp.result_type(*dtypes) + + if out == xp.bool: + scalars = [True] + elif out in dh.all_int_dtypes: + scalars = [1] + elif out in dh.real_dtypes: + scalars = [1, 1.0] + elif out in dh.numeric_dtypes: + scalars = [1, 1.0, 1j] # numeric_types - real_types == complex_types + else: + raise ValueError(f"unknown dtype {out = }.") + + scalar = data.draw(st.sampled_from(scalars)) + inputs = data.draw(st.permutations(dtypes + (scalar,))) + + out_scalar = xp.result_type(*inputs) + assert out_scalar == out + + # retry with arrays + arrays = tuple(xp.empty(1, dtype=dt) for dt in dtypes) + inputs = data.draw(st.permutations(arrays + (scalar,))) + out_scalar = xp.result_type(*inputs) + assert out_scalar == out +