Skip to content

Commit 45b36d6

Browse files
committed
Test stacking for tensordot
1 parent a776cd4 commit 45b36d6

File tree

1 file changed

+46
-7
lines changed

1 file changed

+46
-7
lines changed

array_api_tests/test_linalg.py

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@
1515

1616
import pytest
1717
from hypothesis import assume, given
18-
from hypothesis.strategies import (booleans, composite, none, lists, tuples,
19-
floats, integers, shared, sampled_from,
20-
one_of, data, just)
18+
from hypothesis.strategies import (booleans, composite, none, tuples, floats,
19+
integers, shared, sampled_from, one_of,
20+
data, just)
2121
from ndindex import iter_indices
2222

23+
import itertools
24+
2325
from .array_helpers import assert_exactly_equal, asarray
2426
from .hypothesis_helpers import (xps, dtypes, shapes, kwargs, matrix_shapes,
2527
square_matrix_shapes, symmetric_matrices,
@@ -619,15 +621,52 @@ def tensordot_shapes(draw):
619621
shape1, shape2 = map(tuple, [_shape1, _shape2])
620622
return (shape1, shape2)
621623

624+
def _test_tensordot_stacks(x1, x2, kw, res):
625+
"""
626+
Variant of _test_stacks for tensordot
627+
628+
tensordot doesn't stack directly along the non-contracted dimensions like
629+
the other linalg functions. Rather, it is stacked along the product of
630+
each non-contracted dimension. These dimensions are independent of one
631+
another and do not broadcast.
632+
"""
633+
shape1, shape2 = x1.shape, x2.shape
634+
635+
axes = kw.get('axes', 2)
636+
637+
if isinstance(axes, int):
638+
res_axes = axes
639+
axes = [list(range(-axes, 0)), list(range(0, axes))]
640+
else:
641+
# Convert something like (0, 4, 2) into (0, 2, 1)
642+
res_axes = []
643+
for a, s in zip(axes, [shape1, shape2]):
644+
indices = [range(len(s))[i] for i in a]
645+
repl = dict(zip(sorted(indices), range(len(indices))))
646+
res_axes.append(tuple(repl[i] for i in indices))
647+
648+
for ((i,), (j,)), (res_idx,) in zip(
649+
itertools.product(
650+
iter_indices(shape1, skip_axes=axes[0]),
651+
iter_indices(shape2, skip_axes=axes[1])),
652+
iter_indices(res.shape)):
653+
i, j, res_idx = i.raw, j.raw, res_idx.raw
654+
655+
res_stack = res[res_idx]
656+
x1_stack = x1[i]
657+
x2_stack = x2[j]
658+
decomp_res_stack = xp.tensordot(x1_stack, x2_stack, axes=res_axes)
659+
assert_exactly_equal(res_stack, decomp_res_stack)
660+
622661
@given(
623662
*two_mutual_arrays(dh.numeric_dtypes, two_shapes=tensordot_shapes()),
624663
tensordot_kw,
625664
)
626665
def test_tensordot(x1, x2, kw):
627666
# TODO: vary shapes, vary contracted axes, test different axes arguments
628-
out = xp.tensordot(x1, x2, **kw)
667+
res = xp.tensordot(x1, x2, **kw)
629668

630-
ph.assert_dtype("tensordot", [x1.dtype, x2.dtype], out.dtype)
669+
ph.assert_dtype("tensordot", [x1.dtype, x2.dtype], res.dtype)
631670

632671
axes = _axes = kw.get('axes', 2)
633672

@@ -641,10 +680,10 @@ def test_tensordot(x1, x2, kw):
641680
_shape1 = tuple([i for i in _shape1 if i is not None])
642681
_shape2 = tuple([i for i in _shape2 if i is not None])
643682
result_shape = _shape1 + _shape2
644-
ph.assert_result_shape('tensordot', [x1.shape, x2.shape], out.shape,
683+
ph.assert_result_shape('tensordot', [x1.shape, x2.shape], res.shape,
645684
expected=result_shape)
646685
# TODO: assert stacking and elements
647-
686+
_test_tensordot_stacks(x1, x2, kw, res)
648687

649688
@pytest.mark.xp_extension('linalg')
650689
@given(

0 commit comments

Comments
 (0)