@@ -100,101 +100,128 @@ def test_pytorch_AdvSubtensor():
100
100
out_fg = FunctionGraph ([x_pt , a_pt ], [out_pt ])
101
101
compare_pytorch_and_py (out_fg , [x_np , a_np ])
102
102
103
+ with pytest .raises (
104
+ NotImplementedError , match = "Negative step sizes are not supported in Pytorch"
105
+ ):
106
+ out_pt = x_pt [[1 , 2 ], ::- 1 ]
107
+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
108
+ assert isinstance (out_pt .owner .op , pt_subtensor .AdvancedSubtensor )
109
+ compare_pytorch_and_py (out_fg , [x_np ])
103
110
104
- def test_pytorch_IncSubtensor ():
105
- rng = np .random .default_rng (42 )
106
111
107
- x_np = rng .uniform (- 1 , 1 , size = (3 , 4 , 5 )).astype (config .floatX )
108
- x_pt = pt .constant (np .arange (3 * 4 * 5 ).reshape ((3 , 4 , 5 )).astype (config .floatX ))
112
+ def test_pytorch_SetSubtensor ():
113
+ x_pt = pt .tensor3 ("x" )
114
+ x_test = np .arange (3 * 4 * 5 ).reshape ((3 , 4 , 5 )).astype (config .floatX )
109
115
110
116
# "Set" basic indices
111
117
st_pt = pt .as_tensor_variable (np .array (- 10.0 , dtype = config .floatX ))
112
118
out_pt = pt_subtensor .set_subtensor (x_pt [1 , 2 , 3 ], st_pt )
113
119
assert isinstance (out_pt .owner .op , pt_subtensor .IncSubtensor )
114
- out_fg = FunctionGraph ([], [out_pt ])
115
- compare_pytorch_and_py (out_fg , [])
120
+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
121
+ compare_pytorch_and_py (out_fg , [x_test ])
116
122
117
123
st_pt = pt .as_tensor_variable (np .r_ [- 1.0 , 0.0 ].astype (config .floatX ))
118
124
out_pt = pt_subtensor .set_subtensor (x_pt [:2 , 0 , 0 ], st_pt )
119
125
assert isinstance (out_pt .owner .op , pt_subtensor .IncSubtensor )
120
- out_fg = FunctionGraph ([], [out_pt ])
121
- compare_pytorch_and_py (out_fg , [])
126
+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
127
+ compare_pytorch_and_py (out_fg , [x_test ])
128
+
129
+ out_pt = pt_subtensor .set_subtensor (x_pt [0 , 1 :3 , 0 ], st_pt )
130
+ assert isinstance (out_pt .owner .op , pt_subtensor .IncSubtensor )
131
+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
132
+ compare_pytorch_and_py (out_fg , [x_test ])
122
133
123
134
out_pt = pt_subtensor .set_subtensor (x_pt [0 , 1 :3 , 0 ], st_pt )
124
135
assert isinstance (out_pt .owner .op , pt_subtensor .IncSubtensor )
125
- out_fg = FunctionGraph ([], [out_pt ])
126
- compare_pytorch_and_py (out_fg , [])
136
+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
137
+ compare_pytorch_and_py (out_fg , [x_test ])
138
+
139
+
140
+ def test_pytorch_AdvSetSubtensor ():
141
+ rng = np .random .default_rng (42 )
142
+
143
+ x_np = rng .uniform (- 1 , 1 , size = (3 , 4 , 5 )).astype (config .floatX )
144
+ x_pt = pt .tensor3 ("x" )
145
+ x_test = np .arange (3 * 4 * 5 ).reshape ((3 , 4 , 5 )).astype (config .floatX )
127
146
128
147
# "Set" advanced indices
129
148
st_pt = pt .as_tensor_variable (
130
149
rng .uniform (- 1 , 1 , size = (2 , 4 , 5 )).astype (config .floatX )
131
150
)
132
151
out_pt = pt_subtensor .set_subtensor (x_pt [np .r_ [0 , 2 ]], st_pt )
133
152
assert isinstance (out_pt .owner .op , pt_subtensor .AdvancedIncSubtensor )
134
- out_fg = FunctionGraph ([], [out_pt ])
135
- compare_pytorch_and_py (out_fg , [])
153
+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
154
+ compare_pytorch_and_py (out_fg , [x_test ])
136
155
137
156
st_pt = pt .as_tensor_variable (np .r_ [- 1.0 , 0.0 ].astype (config .floatX ))
138
157
out_pt = pt_subtensor .set_subtensor (x_pt [[0 , 2 ], 0 , 0 ], st_pt )
139
158
assert isinstance (out_pt .owner .op , pt_subtensor .AdvancedIncSubtensor )
140
- out_fg = FunctionGraph ([], [out_pt ])
141
- compare_pytorch_and_py (out_fg , [])
159
+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
160
+ compare_pytorch_and_py (out_fg , [x_test ])
142
161
143
162
# "Set" boolean indices
144
163
mask_pt = pt .constant (x_np > 0 )
145
164
out_pt = pt_subtensor .set_subtensor (x_pt [mask_pt ], 0.0 )
146
165
assert isinstance (out_pt .owner .op , pt_subtensor .AdvancedIncSubtensor )
147
- out_fg = FunctionGraph ([], [out_pt ])
148
- compare_pytorch_and_py (out_fg , [])
166
+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
167
+ compare_pytorch_and_py (out_fg , [x_test ])
168
+
169
+
170
+ def test_pytorch_IncSubtensor ():
171
+ x_pt = pt .tensor3 ("x" )
172
+ x_test = np .arange (3 * 4 * 5 ).reshape ((3 , 4 , 5 )).astype (config .floatX )
149
173
150
174
# "Increment" basic indices
151
175
st_pt = pt .as_tensor_variable (np .array (- 10.0 , dtype = config .floatX ))
152
176
out_pt = pt_subtensor .inc_subtensor (x_pt [1 , 2 , 3 ], st_pt )
153
177
assert isinstance (out_pt .owner .op , pt_subtensor .IncSubtensor )
154
- out_fg = FunctionGraph ([], [out_pt ])
155
- compare_pytorch_and_py (out_fg , [])
178
+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
179
+ compare_pytorch_and_py (out_fg , [x_test ])
156
180
157
181
st_pt = pt .as_tensor_variable (np .r_ [- 1.0 , 0.0 ].astype (config .floatX ))
158
182
out_pt = pt_subtensor .inc_subtensor (x_pt [:2 , 0 , 0 ], st_pt )
159
183
assert isinstance (out_pt .owner .op , pt_subtensor .IncSubtensor )
160
- out_fg = FunctionGraph ([], [out_pt ])
161
- compare_pytorch_and_py (out_fg , [])
184
+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
185
+ compare_pytorch_and_py (out_fg , [x_test ])
162
186
163
- out_pt = pt_subtensor .set_subtensor (x_pt [0 , 1 :3 , 0 ], st_pt )
164
- assert isinstance (out_pt .owner .op , pt_subtensor .IncSubtensor )
165
- out_fg = FunctionGraph ([], [out_pt ])
166
- compare_pytorch_and_py (out_fg , [])
187
+
188
+ def test_pytorch_AvdancedIncSubtensor ():
189
+ rng = np .random .default_rng (42 )
190
+
191
+ x_np = rng .uniform (- 1 , 1 , size = (3 , 4 , 5 )).astype (config .floatX )
192
+ x_pt = pt .tensor3 ("x" )
193
+ x_test = np .arange (3 * 4 * 5 ).reshape ((3 , 4 , 5 )).astype (config .floatX )
167
194
168
195
# "Increment" advanced indices
169
196
st_pt = pt .as_tensor_variable (
170
197
rng .uniform (- 1 , 1 , size = (2 , 4 , 5 )).astype (config .floatX )
171
198
)
172
199
out_pt = pt_subtensor .inc_subtensor (x_pt [np .r_ [0 , 2 ]], st_pt )
173
200
assert isinstance (out_pt .owner .op , pt_subtensor .AdvancedIncSubtensor )
174
- out_fg = FunctionGraph ([], [out_pt ])
175
- compare_pytorch_and_py (out_fg , [])
201
+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
202
+ compare_pytorch_and_py (out_fg , [x_test ])
176
203
177
204
st_pt = pt .as_tensor_variable (np .r_ [- 1.0 , 0.0 ].astype (config .floatX ))
178
205
out_pt = pt_subtensor .inc_subtensor (x_pt [[0 , 2 ], 0 , 0 ], st_pt )
179
206
assert isinstance (out_pt .owner .op , pt_subtensor .AdvancedIncSubtensor )
180
- out_fg = FunctionGraph ([], [out_pt ])
181
- compare_pytorch_and_py (out_fg , [])
207
+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
208
+ compare_pytorch_and_py (out_fg , [x_test ])
182
209
183
210
# "Increment" boolean indices
184
211
mask_pt = pt .constant (x_np > 0 )
185
212
out_pt = pt_subtensor .set_subtensor (x_pt [mask_pt ], 1.0 )
186
213
assert isinstance (out_pt .owner .op , pt_subtensor .AdvancedIncSubtensor )
187
- out_fg = FunctionGraph ([], [out_pt ])
188
- compare_pytorch_and_py (out_fg , [])
214
+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
215
+ compare_pytorch_and_py (out_fg , [x_test ])
189
216
190
217
st_pt = pt .as_tensor_variable (x_np [[0 , 2 ], 0 , :3 ])
191
218
out_pt = pt_subtensor .set_subtensor (x_pt [[0 , 2 ], 0 , :3 ], st_pt )
192
219
assert isinstance (out_pt .owner .op , pt_subtensor .AdvancedIncSubtensor )
193
- out_fg = FunctionGraph ([], [out_pt ])
194
- compare_pytorch_and_py (out_fg , [])
220
+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
221
+ compare_pytorch_and_py (out_fg , [x_test ])
195
222
196
223
st_pt = pt .as_tensor_variable (x_np [[0 , 2 ], 0 , :3 ])
197
224
out_pt = pt_subtensor .inc_subtensor (x_pt [[0 , 2 ], 0 , :3 ], st_pt )
198
225
assert isinstance (out_pt .owner .op , pt_subtensor .AdvancedIncSubtensor )
199
- out_fg = FunctionGraph ([], [out_pt ])
200
- compare_pytorch_and_py (out_fg , [])
226
+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
227
+ compare_pytorch_and_py (out_fg , [x_test ])
0 commit comments