Skip to content

Commit 5e7b208

Browse files
committed
Improvements to test_solve
- Fix it to only generate arguments with mutually promotable dtypes - Test dtype promotion - Test output shape
1 parent 1cf4a07 commit 5e7b208

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

array_api_tests/test_linalg.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -597,32 +597,42 @@ def solve_args():
597597
of shape (..., M, M), and x2 is either shape (M,) or (..., M, K),
598598
where the ... parts of x1 and x2 are broadcast compatible.
599599
"""
600+
mutual_dtypes = shared(mutually_promotable_dtypes(dtypes=dh.all_float_dtypes))
601+
600602
stack_shapes = shared(two_mutually_broadcastable_shapes)
601603
# Don't worry about dtypes since all floating dtypes are type promotable
602604
# with each other.
603-
x1 = shared(invertible_matrices(stack_shapes=stack_shapes.map(lambda pair:
604-
pair[0])))
605+
x1 = shared(invertible_matrices(
606+
stack_shapes=stack_shapes.map(lambda pair: pair[0]),
607+
dtypes=mutual_dtypes.map(lambda pair: pair[0])))
605608

606609
@composite
607610
def _x2_shapes(draw):
608611
end = draw(integers(0, SQRT_MAX_ARRAY_SIZE))
609612
return draw(stack_shapes)[1] + draw(x1).shape[-1:] + (end,)
610613

611614
x2_shapes = one_of(x1.map(lambda x: (x.shape[-1],)), _x2_shapes())
612-
x2 = arrays(dtype=all_floating_dtypes(), shape=x2_shapes)
615+
x2 = arrays(shape=x2_shapes, dtype=mutual_dtypes.map(lambda pair: pair[1]))
613616
return x1, x2
614617

615618
@pytest.mark.xp_extension('linalg')
616619
@given(*solve_args())
617620
def test_solve(x1, x2):
618621
res = linalg.solve(x1, x2)
619622

623+
ph.assert_dtype("solve", in_dtype=[x1.dtype, x2.dtype], out_dtype=res.dtype)
620624
if x2.ndim == 1:
625+
expected_shape = x1.shape[:-2] + x2.shape[-1:]
621626
_test_stacks(linalg.solve, x1, x2, res=res, dims=1,
622627
matrix_axes=[(-2, -1), (0,)], res_axes=[-1])
623628
else:
629+
stack_shape = sh.broadcast_shapes(x1.shape[:-2], x2.shape[:-2])
630+
expected_shape = stack_shape + x2.shape[-2:]
624631
_test_stacks(linalg.solve, x1, x2, res=res, dims=2)
625632

633+
ph.assert_result_shape("solve", in_shapes=[x1.shape, x2.shape],
634+
out_shape=res.shape, expected=expected_shape)
635+
626636
@pytest.mark.xp_extension('linalg')
627637
@given(
628638
x=finite_matrices(),

0 commit comments

Comments
 (0)