@@ -43,47 +43,52 @@ def test_pytorch_CumOp(axis, dtype):
43
43
compare_pytorch_and_py (fgraph , [test_value ])
44
44
45
45
46
- def test_pytorch_Repeat ():
46
+ @pytest .mark .parametrize ("axis" , [0 , 1 ])
47
+ def test_pytorch_Repeat (axis ):
47
48
a = pt .matrix ("a" , dtype = "float64" )
48
49
49
50
test_value = np .arange (6 , dtype = "float64" ).reshape ((3 , 2 ))
50
51
51
- # Test along axis 0
52
- out = pt .repeat (a , (1 , 2 , 3 ), axis = 0 )
52
+ out = pt .repeat (a , (1 , 2 , 3 ) if axis == 0 else (3 , 3 ), axis = axis )
53
53
fgraph = FunctionGraph ([a ], [out ])
54
54
compare_pytorch_and_py (fgraph , [test_value ])
55
55
56
- # Test along axis 1
57
- out = pt .repeat (a , (3 , 3 ), axis = 1 )
58
- fgraph = FunctionGraph ([a ], [out ])
59
- compare_pytorch_and_py (fgraph , [test_value ])
60
56
61
-
62
- def test_pytorch_Unique ( ):
57
+ @ pytest . mark . parametrize ( "axis" , [ 0 , 1 ])
58
+ def test_pytorch_Unique_axis ( axis ):
63
59
a = pt .matrix ("a" , dtype = "float64" )
64
60
65
61
test_value = np .array (
66
62
[[1.0 , 1.0 , 2.0 ], [1.0 , 1.0 , 2.0 ], [3.0 , 3.0 , 0.0 ]], dtype = "float64"
67
63
)
68
64
69
- # Test along axis 0
70
- out = pt .unique (a , axis = 0 )
71
- fgraph = FunctionGraph ([a ], [out ])
72
- compare_pytorch_and_py (fgraph , [test_value ])
73
-
74
- # Test along axis 1
75
- out = pt .unique (a , axis = 1 )
65
+ out = pt .unique (a , axis = axis )
76
66
fgraph = FunctionGraph ([a ], [out ])
77
67
compare_pytorch_and_py (fgraph , [test_value ])
78
68
79
- # Test with params
80
- out = pt .unique (a , return_inverse = True , return_counts = True , axis = 0 )
81
- fgraph = FunctionGraph ([a ], [out [0 ]])
82
- compare_pytorch_and_py (fgraph , [test_value ])
83
69
84
- # Test with return_index=True
85
- out = pt .unique (a , return_index = True , axis = 0 )
86
- fgraph = FunctionGraph ([a ], [out [0 ]])
70
+ @pytest .mark .parametrize (
71
+ "return_index, return_inverse, return_counts" ,
72
+ [
73
+ (False , True , False ),
74
+ (False , True , True ),
75
+ pytest .param (
76
+ True , False , False , marks = pytest .mark .xfail (raises = NotImplementedError )
77
+ ),
78
+ ],
79
+ )
80
+ def test_pytorch_Unique_params (return_index , return_inverse , return_counts ):
81
+ a = pt .matrix ("a" , dtype = "float64" )
82
+ test_value = np .array (
83
+ [[1.0 , 1.0 , 2.0 ], [1.0 , 1.0 , 2.0 ], [3.0 , 3.0 , 0.0 ]], dtype = "float64"
84
+ )
87
85
88
- with pytest .raises (NotImplementedError ):
89
- compare_pytorch_and_py (fgraph , [test_value ])
86
+ out = pt .unique (
87
+ a ,
88
+ return_index = return_index ,
89
+ return_inverse = return_inverse ,
90
+ return_counts = return_counts ,
91
+ axis = 0 ,
92
+ )
93
+ fgraph = FunctionGraph ([a ], [out [0 ] if isinstance (out , list ) else out ])
94
+ compare_pytorch_and_py (fgraph , [test_value ])
0 commit comments