diff --git a/array_api_tests/test_linalg.py b/array_api_tests/test_linalg.py index bbb44cdc..6f4608da 100644 --- a/array_api_tests/test_linalg.py +++ b/array_api_tests/test_linalg.py @@ -206,11 +206,11 @@ def test_cross(x1_x2_kw): def exact_cross(a, b): assert a.shape == b.shape == (3,), "Invalid cross() stack shapes. This indicates a bug in the test suite." - return asarray([ + return asarray(xp.stack([ a[1]*b[2] - a[2]*b[1], a[2]*b[0] - a[0]*b[2], a[0]*b[1] - a[1]*b[0], - ], dtype=res.dtype) + ]), dtype=res.dtype) # We don't want to pass in **kw here because that would pass axis to # cross() on a single stack, but the axis is not meaningful on unstacked @@ -267,7 +267,7 @@ def true_diag(x_stack, offset=0): x_stack_diag = [x_stack[i, i + offset] for i in range(diag_size)] else: x_stack_diag = [x_stack[i - offset, i] for i in range(diag_size)] - return asarray(x_stack_diag, dtype=x.dtype) + return asarray(xp.stack(x_stack_diag) if x_stack_diag else [], dtype=x.dtype) _test_stacks(linalg.diagonal, x, **kw, res=res, dims=1, true_val=true_diag) @@ -901,7 +901,9 @@ def true_trace(x_stack, offset=0): x_stack_diag = [x_stack[i, i + offset] for i in range(diag_size)] else: x_stack_diag = [x_stack[i - offset, i] for i in range(diag_size)] - return _array_module.sum(asarray(x_stack_diag, dtype=x.dtype)) + result = xp.asarray(xp.stack(x_stack_diag) if x_stack_diag else [], dtype=x.dtype) + return _array_module.sum(result) + _test_stacks(linalg.trace, x, **kw, res=res, dims=0, true_val=true_trace)