@@ -43,47 +43,52 @@ def test_pytorch_CumOp(axis, dtype):
4343 compare_pytorch_and_py (fgraph , [test_value ])
4444
4545
46- def test_pytorch_Repeat ():
46+ @pytest .mark .parametrize ("axis" , [0 , 1 ])
47+ def test_pytorch_Repeat (axis ):
4748 a = pt .matrix ("a" , dtype = "float64" )
4849
4950 test_value = np .arange (6 , dtype = "float64" ).reshape ((3 , 2 ))
5051
51- # Test along axis 0
52- out = pt .repeat (a , (1 , 2 , 3 ), axis = 0 )
52+ out = pt .repeat (a , (1 , 2 , 3 ) if axis == 0 else (3 , 3 ), axis = axis )
5353 fgraph = FunctionGraph ([a ], [out ])
5454 compare_pytorch_and_py (fgraph , [test_value ])
5555
56- # Test along axis 1
57- out = pt .repeat (a , (3 , 3 ), axis = 1 )
58- fgraph = FunctionGraph ([a ], [out ])
59- compare_pytorch_and_py (fgraph , [test_value ])
6056
61-
62- def test_pytorch_Unique ( ):
57+ @ pytest . mark . parametrize ( "axis" , [ 0 , 1 ])
58+ def test_pytorch_Unique_axis ( axis ):
6359 a = pt .matrix ("a" , dtype = "float64" )
6460
6561 test_value = np .array (
6662 [[1.0 , 1.0 , 2.0 ], [1.0 , 1.0 , 2.0 ], [3.0 , 3.0 , 0.0 ]], dtype = "float64"
6763 )
6864
69- # Test along axis 0
70- out = pt .unique (a , axis = 0 )
71- fgraph = FunctionGraph ([a ], [out ])
72- compare_pytorch_and_py (fgraph , [test_value ])
73-
74- # Test along axis 1
75- out = pt .unique (a , axis = 1 )
65+ out = pt .unique (a , axis = axis )
7666 fgraph = FunctionGraph ([a ], [out ])
7767 compare_pytorch_and_py (fgraph , [test_value ])
7868
79- # Test with params
80- out = pt .unique (a , return_inverse = True , return_counts = True , axis = 0 )
81- fgraph = FunctionGraph ([a ], [out [0 ]])
82- compare_pytorch_and_py (fgraph , [test_value ])
8369
84- # Test with return_index=True
85- out = pt .unique (a , return_index = True , axis = 0 )
86- fgraph = FunctionGraph ([a ], [out [0 ]])
70+ @pytest .mark .parametrize (
71+ "return_index, return_inverse, return_counts" ,
72+ [
73+ (False , True , False ),
74+ (False , True , True ),
75+ pytest .param (
76+ True , False , False , marks = pytest .mark .xfail (raises = NotImplementedError )
77+ ),
78+ ],
79+ )
80+ def test_pytorch_Unique_params (return_index , return_inverse , return_counts ):
81+ a = pt .matrix ("a" , dtype = "float64" )
82+ test_value = np .array (
83+ [[1.0 , 1.0 , 2.0 ], [1.0 , 1.0 , 2.0 ], [3.0 , 3.0 , 0.0 ]], dtype = "float64"
84+ )
8785
88- with pytest .raises (NotImplementedError ):
89- compare_pytorch_and_py (fgraph , [test_value ])
86+ out = pt .unique (
87+ a ,
88+ return_index = return_index ,
89+ return_inverse = return_inverse ,
90+ return_counts = return_counts ,
91+ axis = 0 ,
92+ )
93+ fgraph = FunctionGraph ([a ], [out [0 ] if isinstance (out , list ) else out ])
94+ compare_pytorch_and_py (fgraph , [test_value ])
0 commit comments