@@ -597,32 +597,42 @@ def solve_args():
597
597
of shape (..., M, M), and x2 is either shape (M,) or (..., M, K),
598
598
where the ... parts of x1 and x2 are broadcast compatible.
599
599
"""
600
+ mutual_dtypes = shared (mutually_promotable_dtypes (dtypes = dh .all_float_dtypes ))
601
+
600
602
stack_shapes = shared (two_mutually_broadcastable_shapes )
601
603
# Don't worry about dtypes since all floating dtypes are type promotable
602
604
# with each other.
603
- x1 = shared (invertible_matrices (stack_shapes = stack_shapes .map (lambda pair :
604
- pair [0 ])))
605
+ x1 = shared (invertible_matrices (
606
+ stack_shapes = stack_shapes .map (lambda pair : pair [0 ]),
607
+ dtypes = mutual_dtypes .map (lambda pair : pair [0 ])))
605
608
606
609
@composite
607
610
def _x2_shapes (draw ):
608
611
end = draw (integers (0 , SQRT_MAX_ARRAY_SIZE ))
609
612
return draw (stack_shapes )[1 ] + draw (x1 ).shape [- 1 :] + (end ,)
610
613
611
614
x2_shapes = one_of (x1 .map (lambda x : (x .shape [- 1 ],)), _x2_shapes ())
612
- x2 = arrays (dtype = all_floating_dtypes (), shape = x2_shapes )
615
+ x2 = arrays (shape = x2_shapes , dtype = mutual_dtypes . map ( lambda pair : pair [ 1 ]) )
613
616
return x1 , x2
614
617
615
618
@pytest .mark .xp_extension ('linalg' )
616
619
@given (* solve_args ())
617
620
def test_solve (x1 , x2 ):
618
621
res = linalg .solve (x1 , x2 )
619
622
623
+ ph .assert_dtype ("solve" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = res .dtype )
620
624
if x2 .ndim == 1 :
625
+ expected_shape = x1 .shape [:- 2 ] + x2 .shape [- 1 :]
621
626
_test_stacks (linalg .solve , x1 , x2 , res = res , dims = 1 ,
622
627
matrix_axes = [(- 2 , - 1 ), (0 ,)], res_axes = [- 1 ])
623
628
else :
629
+ stack_shape = sh .broadcast_shapes (x1 .shape [:- 2 ], x2 .shape [:- 2 ])
630
+ expected_shape = stack_shape + x2 .shape [- 2 :]
624
631
_test_stacks (linalg .solve , x1 , x2 , res = res , dims = 2 )
625
632
633
+ ph .assert_result_shape ("solve" , in_shapes = [x1 .shape , x2 .shape ],
634
+ out_shape = res .shape , expected = expected_shape )
635
+
626
636
@pytest .mark .xp_extension ('linalg' )
627
637
@given (
628
638
x = finite_matrices (),
0 commit comments