Skip to content

Commit a71b4c0

Browse files
authored
Merge pull request #314 from ev-br/vecdot_conj
ENH: test vecdot values, incl complex conj
2 parents c2e010e + 6ea8ae2 commit a71b4c0

File tree

1 file changed

+18
-5
lines changed

1 file changed

+18
-5
lines changed

array_api_tests/test_linalg.py

+18-5
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from . import _array_module as xp
4646
from ._array_module import linalg
4747

48+
4849
def assert_equal(x, y, msg_extra=None):
4950
extra = '' if not msg_extra else f' ({msg_extra})'
5051
if x.dtype in dh.all_float_dtypes:
@@ -60,6 +61,7 @@ def assert_equal(x, y, msg_extra=None):
6061
else:
6162
assert_exactly_equal(x, y, msg_extra=msg_extra)
6263

64+
6365
def _test_stacks(f, *args, res=None, dims=2, true_val=None,
6466
matrix_axes=(-2, -1),
6567
res_axes=None,
@@ -106,6 +108,7 @@ def _test_stacks(f, *args, res=None, dims=2, true_val=None,
106108
if true_val:
107109
assert_equal(decomp_res_stack, true_val(*x_stacks, **kw), msg_extra)
108110

111+
109112
def _test_namedtuple(res, fields, func_name):
110113
"""
111114
Test that res is a namedtuple with the correct fields.
@@ -121,6 +124,7 @@ def _test_namedtuple(res, fields, func_name):
121124
assert hasattr(res, field), f"{func_name}() result namedtuple doesn't have the '{field}' field"
122125
assert res[i] is getattr(res, field), f"{func_name}() result namedtuple '{field}' field is not in position {i}"
123126

127+
124128
@pytest.mark.unvectorized
125129
@pytest.mark.xp_extension('linalg')
126130
@given(
@@ -901,6 +905,15 @@ def true_trace(x_stack, offset=0):
901905

902906
_test_stacks(linalg.trace, x, **kw, res=res, dims=0, true_val=true_trace)
903907

908+
909+
def _conj(x):
910+
# XXX: replace with xp.dtype when all array libraries implement it
911+
if x.dtype in (xp.complex64, xp.complex128):
912+
return xp.conj(x)
913+
else:
914+
return x
915+
916+
904917
def _test_vecdot(namespace, x1, x2, data):
905918
vecdot = namespace.vecdot
906919
broadcasted_shape = sh.broadcast_shapes(x1.shape, x2.shape)
@@ -925,11 +938,8 @@ def _test_vecdot(namespace, x1, x2, data):
925938
ph.assert_result_shape("vecdot", in_shapes=[x1.shape, x2.shape],
926939
out_shape=res.shape, expected=expected_shape)
927940

928-
if x1.dtype in dh.int_dtypes:
929-
def true_val(x, y, axis=-1):
930-
return xp.sum(xp.multiply(x, y), dtype=res.dtype)
931-
else:
932-
true_val = None
941+
def true_val(x, y, axis=-1):
942+
return xp.sum(xp.multiply(_conj(x), y), dtype=res.dtype)
933943

934944
_test_stacks(vecdot, x1, x2, res=res, dims=0,
935945
matrix_axes=(axis,), true_val=true_val)
@@ -944,6 +954,7 @@ def true_val(x, y, axis=-1):
944954
def test_linalg_vecdot(x1, x2, data):
945955
_test_vecdot(linalg, x1, x2, data)
946956

957+
947958
@pytest.mark.unvectorized
948959
@given(
949960
*two_mutual_arrays(dh.numeric_dtypes, mutually_broadcastable_shapes(2, min_dims=1)),
@@ -952,10 +963,12 @@ def test_linalg_vecdot(x1, x2, data):
952963
def test_vecdot(x1, x2, data):
953964
_test_vecdot(_array_module, x1, x2, data)
954965

966+
955967
# Insanely large orders might not work. There isn't a limit specified in the
956968
# spec, so we just limit to reasonable values here.
957969
max_ord = 100
958970

971+
959972
@pytest.mark.unvectorized
960973
@pytest.mark.xp_extension('linalg')
961974
@given(

0 commit comments

Comments
 (0)