@@ -43,18 +43,29 @@ def test_pytorch_CumOp(axis, dtype):
43
43
compare_pytorch_and_py (fgraph , [test_value ])
44
44
45
45
46
- @pytest .mark .parametrize ("axis" , [0 , 1 ])
47
- def test_pytorch_Repeat (axis ):
46
+ @pytest .mark .parametrize (
47
+ "axis, repeats" ,
48
+ [
49
+ (0 , (1 , 2 , 3 )),
50
+ (1 , (3 , 3 )),
51
+ pytest .param (
52
+ None ,
53
+ 3 ,
54
+ marks = pytest .mark .xfail (reason = "Reshape not implemented" ),
55
+ ),
56
+ ],
57
+ )
58
+ def test_pytorch_Repeat (axis , repeats ):
48
59
a = pt .matrix ("a" , dtype = "float64" )
49
60
50
61
test_value = np .arange (6 , dtype = "float64" ).reshape ((3 , 2 ))
51
62
52
- out = pt .repeat (a , ( 1 , 2 , 3 ) if axis == 0 else ( 3 , 3 ) , axis = axis )
63
+ out = pt .repeat (a , repeats , axis = axis )
53
64
fgraph = FunctionGraph ([a ], [out ])
54
65
compare_pytorch_and_py (fgraph , [test_value ])
55
66
56
67
57
- @pytest .mark .parametrize ("axis" , [0 , 1 ])
68
+ @pytest .mark .parametrize ("axis" , [None , 0 , 1 ])
58
69
def test_pytorch_Unique_axis (axis ):
59
70
a = pt .matrix ("a" , dtype = "float64" )
60
71
0 commit comments