diff --git a/array_api_tests/test_data_type_functions.py b/array_api_tests/test_data_type_functions.py index 34c40024..c69e4143 100644 --- a/array_api_tests/test_data_type_functions.py +++ b/array_api_tests/test_data_type_functions.py @@ -107,31 +107,17 @@ def test_broadcast_to(x, data): # TODO: test values -@given(_from=non_complex_dtypes(), to=non_complex_dtypes(), data=st.data()) -def test_can_cast(_from, to, data): - from_ = data.draw( - st.just(_from) | hh.arrays(dtype=_from, shape=hh.shapes()), label="from_" - ) +@given(_from=hh.all_dtypes, to=hh.all_dtypes) +def test_can_cast(_from, to): + out = xp.can_cast(_from, to) - out = xp.can_cast(from_, to) + expected = False + for other in dh.all_dtypes: + if dh.promotion_table.get((_from, other)) == to: + expected = True + break f_func = f"[can_cast({dh.dtype_to_name[_from]}, {dh.dtype_to_name[to]})]" - assert isinstance(out, bool), f"{type(out)=}, but should be bool {f_func}" - if _from == xp.bool: - expected = to == xp.bool - else: - same_family = None - for dtypes in [dh.all_int_dtypes, dh.real_float_dtypes, dh.complex_dtypes]: - if _from in dtypes: - same_family = to in dtypes - break - assert same_family is not None # sanity check - if same_family: - from_min, from_max = dh.dtype_ranges[_from] - to_min, to_max = dh.dtype_ranges[to] - expected = from_min >= to_min and from_max <= to_max - else: - expected = False if expected: # cross-kind casting is not explicitly disallowed. We can only test # the cases where it should return True. TODO: if expected=False,