Skip to content

Commit 53e8339

Browse files
committed
torch/result_type: add a regression test from array-api-compat#273
1 parent 1a154fb commit 53e8339

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

Diff for: array_api_tests/test_data_type_functions.py

+10
Original file line numberDiff line numberDiff line change
@@ -258,3 +258,13 @@ def test_with_scalars(self, dtypes, data):
258258
out_scalar = xp.result_type(*inputs)
259259
assert out_scalar == out
260260

261+
@pytest.mark.parametrize("dtype_a", (xp.int32, xp.int64) + dh.real_float_dtypes + dh.complex_dtypes)
262+
@pytest.mark.parametrize("dtype_b", (xp.int32, xp.int64) + dh.real_float_dtypes + dh.complex_dtypes)
263+
def test_gh_273(self, dtype_a, dtype_b):
264+
# Regression test for https://github.com/data-apis/array-api-compat/issues/273
265+
# Note it is manually parametrized instead of using hypothesis
266+
a = xp.asarray([2, 1], dtype=dtype_a)
267+
b = xp.asarray([1, -1], dtype=dtype_b)
268+
dtype_1 = xp.result_type(a, b, 1.0)
269+
dtype_2 = xp.result_type(b, a, 1.0)
270+
assert dtype_1 == dtype_2

0 commit comments

Comments
 (0)