45
45
from . import _array_module as xp
46
46
from ._array_module import linalg
47
47
48
+
48
49
def assert_equal (x , y , msg_extra = None ):
49
50
extra = '' if not msg_extra else f' ({ msg_extra } )'
50
51
if x .dtype in dh .all_float_dtypes :
@@ -60,6 +61,7 @@ def assert_equal(x, y, msg_extra=None):
60
61
else :
61
62
assert_exactly_equal (x , y , msg_extra = msg_extra )
62
63
64
+
63
65
def _test_stacks (f , * args , res = None , dims = 2 , true_val = None ,
64
66
matrix_axes = (- 2 , - 1 ),
65
67
res_axes = None ,
@@ -106,6 +108,7 @@ def _test_stacks(f, *args, res=None, dims=2, true_val=None,
106
108
if true_val :
107
109
assert_equal (decomp_res_stack , true_val (* x_stacks , ** kw ), msg_extra )
108
110
111
+
109
112
def _test_namedtuple (res , fields , func_name ):
110
113
"""
111
114
Test that res is a namedtuple with the correct fields.
@@ -121,6 +124,7 @@ def _test_namedtuple(res, fields, func_name):
121
124
assert hasattr (res , field ), f"{ func_name } () result namedtuple doesn't have the '{ field } ' field"
122
125
assert res [i ] is getattr (res , field ), f"{ func_name } () result namedtuple '{ field } ' field is not in position { i } "
123
126
127
+
124
128
@pytest .mark .unvectorized
125
129
@pytest .mark .xp_extension ('linalg' )
126
130
@given (
@@ -901,6 +905,15 @@ def true_trace(x_stack, offset=0):
901
905
902
906
_test_stacks (linalg .trace , x , ** kw , res = res , dims = 0 , true_val = true_trace )
903
907
908
+
909
+ def _conj (x ):
910
+ # XXX: replace with xp.dtype when all array libraries implement it
911
+ if x .dtype in (xp .complex64 , xp .complex128 ):
912
+ return xp .conj (x )
913
+ else :
914
+ return x
915
+
916
+
904
917
def _test_vecdot (namespace , x1 , x2 , data ):
905
918
vecdot = namespace .vecdot
906
919
broadcasted_shape = sh .broadcast_shapes (x1 .shape , x2 .shape )
@@ -925,11 +938,8 @@ def _test_vecdot(namespace, x1, x2, data):
925
938
ph .assert_result_shape ("vecdot" , in_shapes = [x1 .shape , x2 .shape ],
926
939
out_shape = res .shape , expected = expected_shape )
927
940
928
- if x1 .dtype in dh .int_dtypes :
929
- def true_val (x , y , axis = - 1 ):
930
- return xp .sum (xp .multiply (x , y ), dtype = res .dtype )
931
- else :
932
- true_val = None
941
+ def true_val (x , y , axis = - 1 ):
942
+ return xp .sum (xp .multiply (_conj (x ), y ), dtype = res .dtype )
933
943
934
944
_test_stacks (vecdot , x1 , x2 , res = res , dims = 0 ,
935
945
matrix_axes = (axis ,), true_val = true_val )
@@ -944,6 +954,7 @@ def true_val(x, y, axis=-1):
944
954
def test_linalg_vecdot (x1 , x2 , data ):
945
955
_test_vecdot (linalg , x1 , x2 , data )
946
956
957
+
947
958
@pytest .mark .unvectorized
948
959
@given (
949
960
* two_mutual_arrays (dh .numeric_dtypes , mutually_broadcastable_shapes (2 , min_dims = 1 )),
@@ -952,10 +963,12 @@ def test_linalg_vecdot(x1, x2, data):
952
963
def test_vecdot (x1 , x2 , data ):
953
964
_test_vecdot (_array_module , x1 , x2 , data )
954
965
966
+
955
967
# Insanely large orders might not work. There isn't a limit specified in the
956
968
# spec, so we just limit to reasonable values here.
957
969
max_ord = 100
958
970
971
+
959
972
@pytest .mark .unvectorized
960
973
@pytest .mark .xp_extension ('linalg' )
961
974
@given (
0 commit comments