@@ -109,83 +109,36 @@ def test_pytorch_AdvSubtensor():
109
109
compare_pytorch_and_py (out_fg , [x_np ])
110
110
111
111
112
- def test_pytorch_SetSubtensor ():
112
+ @pytest .mark .parametrize (
113
+ "subtensor_op" , [pt_subtensor .set_subtensor , pt_subtensor .inc_subtensor ]
114
+ )
115
+ def test_pytorch_SetSubtensor (subtensor_op ):
113
116
x_pt = pt .tensor3 ("x" )
114
117
x_test = np .arange (3 * 4 * 5 ).reshape ((3 , 4 , 5 )).astype (config .floatX )
115
118
116
119
# "Set" basic indices
117
120
st_pt = pt .as_tensor_variable (np .array (- 10.0 , dtype = config .floatX ))
118
- out_pt = pt_subtensor . set_subtensor (x_pt [1 , 2 , 3 ], st_pt )
121
+ out_pt = subtensor_op (x_pt [1 , 2 , 3 ], st_pt )
119
122
assert isinstance (out_pt .owner .op , pt_subtensor .IncSubtensor )
120
123
out_fg = FunctionGraph ([x_pt ], [out_pt ])
121
124
compare_pytorch_and_py (out_fg , [x_test ])
122
125
123
126
st_pt = pt .as_tensor_variable (np .r_ [- 1.0 , 0.0 ].astype (config .floatX ))
124
- out_pt = pt_subtensor . set_subtensor (x_pt [:2 , 0 , 0 ], st_pt )
127
+ out_pt = subtensor_op (x_pt [:2 , 0 , 0 ], st_pt )
125
128
assert isinstance (out_pt .owner .op , pt_subtensor .IncSubtensor )
126
129
out_fg = FunctionGraph ([x_pt ], [out_pt ])
127
130
compare_pytorch_and_py (out_fg , [x_test ])
128
131
129
- out_pt = pt_subtensor . set_subtensor (x_pt [0 , 1 :3 , 0 ], st_pt )
132
+ out_pt = subtensor_op (x_pt [0 , 1 :3 , 0 ], st_pt )
130
133
assert isinstance (out_pt .owner .op , pt_subtensor .IncSubtensor )
131
134
out_fg = FunctionGraph ([x_pt ], [out_pt ])
132
135
compare_pytorch_and_py (out_fg , [x_test ])
133
136
134
- out_pt = pt_subtensor .set_subtensor (x_pt [0 , 1 :3 , 0 ], st_pt )
135
- assert isinstance (out_pt .owner .op , pt_subtensor .IncSubtensor )
136
- out_fg = FunctionGraph ([x_pt ], [out_pt ])
137
- compare_pytorch_and_py (out_fg , [x_test ])
138
137
139
-
140
- def test_pytorch_AdvSetSubtensor ():
141
- rng = np .random .default_rng (42 )
142
-
143
- x_np = rng .uniform (- 1 , 1 , size = (3 , 4 , 5 )).astype (config .floatX )
144
- x_pt = pt .tensor3 ("x" )
145
- x_test = np .arange (3 * 4 * 5 ).reshape ((3 , 4 , 5 )).astype (config .floatX )
146
-
147
- # "Set" advanced indices
148
- st_pt = pt .as_tensor_variable (
149
- rng .uniform (- 1 , 1 , size = (2 , 4 , 5 )).astype (config .floatX )
150
- )
151
- out_pt = pt_subtensor .set_subtensor (x_pt [np .r_ [0 , 2 ]], st_pt )
152
- assert isinstance (out_pt .owner .op , pt_subtensor .AdvancedIncSubtensor )
153
- out_fg = FunctionGraph ([x_pt ], [out_pt ])
154
- compare_pytorch_and_py (out_fg , [x_test ])
155
-
156
- st_pt = pt .as_tensor_variable (np .r_ [- 1.0 , 0.0 ].astype (config .floatX ))
157
- out_pt = pt_subtensor .set_subtensor (x_pt [[0 , 2 ], 0 , 0 ], st_pt )
158
- assert isinstance (out_pt .owner .op , pt_subtensor .AdvancedIncSubtensor )
159
- out_fg = FunctionGraph ([x_pt ], [out_pt ])
160
- compare_pytorch_and_py (out_fg , [x_test ])
161
-
162
- # "Set" boolean indices
163
- mask_pt = pt .constant (x_np > 0 )
164
- out_pt = pt_subtensor .set_subtensor (x_pt [mask_pt ], 0.0 )
165
- assert isinstance (out_pt .owner .op , pt_subtensor .AdvancedIncSubtensor )
166
- out_fg = FunctionGraph ([x_pt ], [out_pt ])
167
- compare_pytorch_and_py (out_fg , [x_test ])
168
-
169
-
170
- def test_pytorch_IncSubtensor ():
171
- x_pt = pt .tensor3 ("x" )
172
- x_test = np .arange (3 * 4 * 5 ).reshape ((3 , 4 , 5 )).astype (config .floatX )
173
-
174
- # "Increment" basic indices
175
- st_pt = pt .as_tensor_variable (np .array (- 10.0 , dtype = config .floatX ))
176
- out_pt = pt_subtensor .inc_subtensor (x_pt [1 , 2 , 3 ], st_pt )
177
- assert isinstance (out_pt .owner .op , pt_subtensor .IncSubtensor )
178
- out_fg = FunctionGraph ([x_pt ], [out_pt ])
179
- compare_pytorch_and_py (out_fg , [x_test ])
180
-
181
- st_pt = pt .as_tensor_variable (np .r_ [- 1.0 , 0.0 ].astype (config .floatX ))
182
- out_pt = pt_subtensor .inc_subtensor (x_pt [:2 , 0 , 0 ], st_pt )
183
- assert isinstance (out_pt .owner .op , pt_subtensor .IncSubtensor )
184
- out_fg = FunctionGraph ([x_pt ], [out_pt ])
185
- compare_pytorch_and_py (out_fg , [x_test ])
186
-
187
-
188
- def test_pytorch_AvdancedIncSubtensor ():
138
+ @pytest .mark .parametrize (
139
+ "advsubtensor_op" , [pt_subtensor .set_subtensor , pt_subtensor .inc_subtensor ]
140
+ )
141
+ def test_pytorch_AvdancedIncSubtensor (advsubtensor_op ):
189
142
rng = np .random .default_rng (42 )
190
143
191
144
x_np = rng .uniform (- 1 , 1 , size = (3 , 4 , 5 )).astype (config .floatX )
@@ -196,32 +149,26 @@ def test_pytorch_AvdancedIncSubtensor():
196
149
st_pt = pt .as_tensor_variable (
197
150
rng .uniform (- 1 , 1 , size = (2 , 4 , 5 )).astype (config .floatX )
198
151
)
199
- out_pt = pt_subtensor . inc_subtensor (x_pt [np .r_ [0 , 2 ]], st_pt )
152
+ out_pt = advsubtensor_op (x_pt [np .r_ [0 , 2 ]], st_pt )
200
153
assert isinstance (out_pt .owner .op , pt_subtensor .AdvancedIncSubtensor )
201
154
out_fg = FunctionGraph ([x_pt ], [out_pt ])
202
155
compare_pytorch_and_py (out_fg , [x_test ])
203
156
204
157
st_pt = pt .as_tensor_variable (np .r_ [- 1.0 , 0.0 ].astype (config .floatX ))
205
- out_pt = pt_subtensor . inc_subtensor (x_pt [[0 , 2 ], 0 , 0 ], st_pt )
158
+ out_pt = advsubtensor_op (x_pt [[0 , 2 ], 0 , 0 ], st_pt )
206
159
assert isinstance (out_pt .owner .op , pt_subtensor .AdvancedIncSubtensor )
207
160
out_fg = FunctionGraph ([x_pt ], [out_pt ])
208
161
compare_pytorch_and_py (out_fg , [x_test ])
209
162
210
163
# "Increment" boolean indices
211
164
mask_pt = pt .constant (x_np > 0 )
212
- out_pt = pt_subtensor .set_subtensor (x_pt [mask_pt ], 1.0 )
213
- assert isinstance (out_pt .owner .op , pt_subtensor .AdvancedIncSubtensor )
214
- out_fg = FunctionGraph ([x_pt ], [out_pt ])
215
- compare_pytorch_and_py (out_fg , [x_test ])
216
-
217
- st_pt = pt .as_tensor_variable (x_np [[0 , 2 ], 0 , :3 ])
218
- out_pt = pt_subtensor .set_subtensor (x_pt [[0 , 2 ], 0 , :3 ], st_pt )
165
+ out_pt = advsubtensor_op (x_pt [mask_pt ], 1.0 )
219
166
assert isinstance (out_pt .owner .op , pt_subtensor .AdvancedIncSubtensor )
220
167
out_fg = FunctionGraph ([x_pt ], [out_pt ])
221
168
compare_pytorch_and_py (out_fg , [x_test ])
222
169
223
170
st_pt = pt .as_tensor_variable (x_np [[0 , 2 ], 0 , :3 ])
224
- out_pt = pt_subtensor . inc_subtensor (x_pt [[0 , 2 ], 0 , :3 ], st_pt )
171
+ out_pt = advsubtensor_op (x_pt [[0 , 2 ], 0 , :3 ], st_pt )
225
172
assert isinstance (out_pt .owner .op , pt_subtensor .AdvancedIncSubtensor )
226
173
out_fg = FunctionGraph ([x_pt ], [out_pt ])
227
174
compare_pytorch_and_py (out_fg , [x_test ])
0 commit comments