Skip to content

Commit fec6f35

Browse files
authored
Merge pull request #239 from asmeurer/test_solve-fixes
Improvements to test_solve
2 parents 6a1f943 + 5e7b208 commit fec6f35

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

array_api_tests/test_linalg.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -597,32 +597,42 @@ def solve_args():
597597
of shape (..., M, M), and x2 is either shape (M,) or (..., M, K),
598598
where the ... parts of x1 and x2 are broadcast compatible.
599599
"""
600+
mutual_dtypes = shared(mutually_promotable_dtypes(dtypes=dh.all_float_dtypes))
601+
600602
stack_shapes = shared(two_mutually_broadcastable_shapes)
601603
# Don't worry about dtypes since all floating dtypes are type promotable
602604
# with each other.
603-
x1 = shared(invertible_matrices(stack_shapes=stack_shapes.map(lambda pair:
604-
pair[0])))
605+
x1 = shared(invertible_matrices(
606+
stack_shapes=stack_shapes.map(lambda pair: pair[0]),
607+
dtypes=mutual_dtypes.map(lambda pair: pair[0])))
605608

606609
@composite
607610
def _x2_shapes(draw):
608611
end = draw(integers(0, SQRT_MAX_ARRAY_SIZE))
609612
return draw(stack_shapes)[1] + draw(x1).shape[-1:] + (end,)
610613

611614
x2_shapes = one_of(x1.map(lambda x: (x.shape[-1],)), _x2_shapes())
612-
x2 = arrays(dtype=all_floating_dtypes(), shape=x2_shapes)
615+
x2 = arrays(shape=x2_shapes, dtype=mutual_dtypes.map(lambda pair: pair[1]))
613616
return x1, x2
614617

615618
@pytest.mark.xp_extension('linalg')
616619
@given(*solve_args())
617620
def test_solve(x1, x2):
618621
res = linalg.solve(x1, x2)
619622

623+
ph.assert_dtype("solve", in_dtype=[x1.dtype, x2.dtype], out_dtype=res.dtype)
620624
if x2.ndim == 1:
625+
expected_shape = x1.shape[:-2] + x2.shape[-1:]
621626
_test_stacks(linalg.solve, x1, x2, res=res, dims=1,
622627
matrix_axes=[(-2, -1), (0,)], res_axes=[-1])
623628
else:
629+
stack_shape = sh.broadcast_shapes(x1.shape[:-2], x2.shape[:-2])
630+
expected_shape = stack_shape + x2.shape[-2:]
624631
_test_stacks(linalg.solve, x1, x2, res=res, dims=2)
625632

633+
ph.assert_result_shape("solve", in_shapes=[x1.shape, x2.shape],
634+
out_shape=res.shape, expected=expected_shape)
635+
626636
@pytest.mark.xp_extension('linalg')
627637
@given(
628638
x=finite_matrices(),

0 commit comments

Comments
 (0)