15
15
16
16
import pytest
17
17
from hypothesis import assume , given
18
- from hypothesis .strategies import (booleans , composite , none , lists , tuples ,
19
- floats , integers , shared , sampled_from ,
20
- one_of , data , just )
18
+ from hypothesis .strategies import (booleans , composite , none , tuples , floats ,
19
+ integers , shared , sampled_from , one_of ,
20
+ data , just )
21
21
from ndindex import iter_indices
22
22
23
+ import itertools
24
+
23
25
from .array_helpers import assert_exactly_equal , asarray
24
26
from .hypothesis_helpers import (xps , dtypes , shapes , kwargs , matrix_shapes ,
25
27
square_matrix_shapes , symmetric_matrices ,
@@ -619,15 +621,52 @@ def tensordot_shapes(draw):
619
621
shape1 , shape2 = map (tuple , [_shape1 , _shape2 ])
620
622
return (shape1 , shape2 )
621
623
624
+ def _test_tensordot_stacks (x1 , x2 , kw , res ):
625
+ """
626
+ Variant of _test_stacks for tensordot
627
+
628
+ tensordot doesn't stack directly along the non-contracted dimensions like
629
+ the other linalg functions. Rather, it is stacked along the product of
630
+ each non-contracted dimension. These dimensions are independent of one
631
+ another and do not broadcast.
632
+ """
633
+ shape1 , shape2 = x1 .shape , x2 .shape
634
+
635
+ axes = kw .get ('axes' , 2 )
636
+
637
+ if isinstance (axes , int ):
638
+ res_axes = axes
639
+ axes = [list (range (- axes , 0 )), list (range (0 , axes ))]
640
+ else :
641
+ # Convert something like (0, 4, 2) into (0, 2, 1)
642
+ res_axes = []
643
+ for a , s in zip (axes , [shape1 , shape2 ]):
644
+ indices = [range (len (s ))[i ] for i in a ]
645
+ repl = dict (zip (sorted (indices ), range (len (indices ))))
646
+ res_axes .append (tuple (repl [i ] for i in indices ))
647
+
648
+ for ((i ,), (j ,)), (res_idx ,) in zip (
649
+ itertools .product (
650
+ iter_indices (shape1 , skip_axes = axes [0 ]),
651
+ iter_indices (shape2 , skip_axes = axes [1 ])),
652
+ iter_indices (res .shape )):
653
+ i , j , res_idx = i .raw , j .raw , res_idx .raw
654
+
655
+ res_stack = res [res_idx ]
656
+ x1_stack = x1 [i ]
657
+ x2_stack = x2 [j ]
658
+ decomp_res_stack = xp .tensordot (x1_stack , x2_stack , axes = res_axes )
659
+ assert_exactly_equal (res_stack , decomp_res_stack )
660
+
622
661
@given (
623
662
* two_mutual_arrays (dh .numeric_dtypes , two_shapes = tensordot_shapes ()),
624
663
tensordot_kw ,
625
664
)
626
665
def test_tensordot (x1 , x2 , kw ):
627
666
# TODO: vary shapes, vary contracted axes, test different axes arguments
628
- out = xp .tensordot (x1 , x2 , ** kw )
667
+ res = xp .tensordot (x1 , x2 , ** kw )
629
668
630
- ph .assert_dtype ("tensordot" , [x1 .dtype , x2 .dtype ], out .dtype )
669
+ ph .assert_dtype ("tensordot" , [x1 .dtype , x2 .dtype ], res .dtype )
631
670
632
671
axes = _axes = kw .get ('axes' , 2 )
633
672
@@ -641,10 +680,10 @@ def test_tensordot(x1, x2, kw):
641
680
_shape1 = tuple ([i for i in _shape1 if i is not None ])
642
681
_shape2 = tuple ([i for i in _shape2 if i is not None ])
643
682
result_shape = _shape1 + _shape2
644
- ph .assert_result_shape ('tensordot' , [x1 .shape , x2 .shape ], out .shape ,
683
+ ph .assert_result_shape ('tensordot' , [x1 .shape , x2 .shape ], res .shape ,
645
684
expected = result_shape )
646
685
# TODO: assert stacking and elements
647
-
686
+ _test_tensordot_stacks ( x1 , x2 , kw , res )
648
687
649
688
@pytest .mark .xp_extension ('linalg' )
650
689
@given (
0 commit comments