@@ -100,101 +100,128 @@ def test_pytorch_AdvSubtensor():
100100 out_fg = FunctionGraph ([x_pt , a_pt ], [out_pt ])
101101 compare_pytorch_and_py (out_fg , [x_np , a_np ])
102102
103+ with pytest .raises (
104+ NotImplementedError , match = "Negative step sizes are not supported in Pytorch"
105+ ):
106+ out_pt = x_pt [[1 , 2 ], ::- 1 ]
107+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
108+ assert isinstance (out_pt .owner .op , pt_subtensor .AdvancedSubtensor )
109+ compare_pytorch_and_py (out_fg , [x_np ])
103110
104- def test_pytorch_IncSubtensor ():
105- rng = np .random .default_rng (42 )
106111
107- x_np = rng .uniform (- 1 , 1 , size = (3 , 4 , 5 )).astype (config .floatX )
108- x_pt = pt .constant (np .arange (3 * 4 * 5 ).reshape ((3 , 4 , 5 )).astype (config .floatX ))
112+ def test_pytorch_SetSubtensor ():
113+ x_pt = pt .tensor3 ("x" )
114+ x_test = np .arange (3 * 4 * 5 ).reshape ((3 , 4 , 5 )).astype (config .floatX )
109115
110116 # "Set" basic indices
111117 st_pt = pt .as_tensor_variable (np .array (- 10.0 , dtype = config .floatX ))
112118 out_pt = pt_subtensor .set_subtensor (x_pt [1 , 2 , 3 ], st_pt )
113119 assert isinstance (out_pt .owner .op , pt_subtensor .IncSubtensor )
114- out_fg = FunctionGraph ([], [out_pt ])
115- compare_pytorch_and_py (out_fg , [])
120+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
121+ compare_pytorch_and_py (out_fg , [x_test ])
116122
117123 st_pt = pt .as_tensor_variable (np .r_ [- 1.0 , 0.0 ].astype (config .floatX ))
118124 out_pt = pt_subtensor .set_subtensor (x_pt [:2 , 0 , 0 ], st_pt )
119125 assert isinstance (out_pt .owner .op , pt_subtensor .IncSubtensor )
120- out_fg = FunctionGraph ([], [out_pt ])
121- compare_pytorch_and_py (out_fg , [])
126+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
127+ compare_pytorch_and_py (out_fg , [x_test ])
128+
129+ out_pt = pt_subtensor .set_subtensor (x_pt [0 , 1 :3 , 0 ], st_pt )
130+ assert isinstance (out_pt .owner .op , pt_subtensor .IncSubtensor )
131+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
132+ compare_pytorch_and_py (out_fg , [x_test ])
122133
123134 out_pt = pt_subtensor .set_subtensor (x_pt [0 , 1 :3 , 0 ], st_pt )
124135 assert isinstance (out_pt .owner .op , pt_subtensor .IncSubtensor )
125- out_fg = FunctionGraph ([], [out_pt ])
126- compare_pytorch_and_py (out_fg , [])
136+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
137+ compare_pytorch_and_py (out_fg , [x_test ])
138+
139+
140+ def test_pytorch_AdvSetSubtensor ():
141+ rng = np .random .default_rng (42 )
142+
143+ x_np = rng .uniform (- 1 , 1 , size = (3 , 4 , 5 )).astype (config .floatX )
144+ x_pt = pt .tensor3 ("x" )
145+ x_test = np .arange (3 * 4 * 5 ).reshape ((3 , 4 , 5 )).astype (config .floatX )
127146
128147 # "Set" advanced indices
129148 st_pt = pt .as_tensor_variable (
130149 rng .uniform (- 1 , 1 , size = (2 , 4 , 5 )).astype (config .floatX )
131150 )
132151 out_pt = pt_subtensor .set_subtensor (x_pt [np .r_ [0 , 2 ]], st_pt )
133152 assert isinstance (out_pt .owner .op , pt_subtensor .AdvancedIncSubtensor )
134- out_fg = FunctionGraph ([], [out_pt ])
135- compare_pytorch_and_py (out_fg , [])
153+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
154+ compare_pytorch_and_py (out_fg , [x_test ])
136155
137156 st_pt = pt .as_tensor_variable (np .r_ [- 1.0 , 0.0 ].astype (config .floatX ))
138157 out_pt = pt_subtensor .set_subtensor (x_pt [[0 , 2 ], 0 , 0 ], st_pt )
139158 assert isinstance (out_pt .owner .op , pt_subtensor .AdvancedIncSubtensor )
140- out_fg = FunctionGraph ([], [out_pt ])
141- compare_pytorch_and_py (out_fg , [])
159+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
160+ compare_pytorch_and_py (out_fg , [x_test ])
142161
143162 # "Set" boolean indices
144163 mask_pt = pt .constant (x_np > 0 )
145164 out_pt = pt_subtensor .set_subtensor (x_pt [mask_pt ], 0.0 )
146165 assert isinstance (out_pt .owner .op , pt_subtensor .AdvancedIncSubtensor )
147- out_fg = FunctionGraph ([], [out_pt ])
148- compare_pytorch_and_py (out_fg , [])
166+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
167+ compare_pytorch_and_py (out_fg , [x_test ])
168+
169+
170+ def test_pytorch_IncSubtensor ():
171+ x_pt = pt .tensor3 ("x" )
172+ x_test = np .arange (3 * 4 * 5 ).reshape ((3 , 4 , 5 )).astype (config .floatX )
149173
150174 # "Increment" basic indices
151175 st_pt = pt .as_tensor_variable (np .array (- 10.0 , dtype = config .floatX ))
152176 out_pt = pt_subtensor .inc_subtensor (x_pt [1 , 2 , 3 ], st_pt )
153177 assert isinstance (out_pt .owner .op , pt_subtensor .IncSubtensor )
154- out_fg = FunctionGraph ([], [out_pt ])
155- compare_pytorch_and_py (out_fg , [])
178+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
179+ compare_pytorch_and_py (out_fg , [x_test ])
156180
157181 st_pt = pt .as_tensor_variable (np .r_ [- 1.0 , 0.0 ].astype (config .floatX ))
158182 out_pt = pt_subtensor .inc_subtensor (x_pt [:2 , 0 , 0 ], st_pt )
159183 assert isinstance (out_pt .owner .op , pt_subtensor .IncSubtensor )
160- out_fg = FunctionGraph ([], [out_pt ])
161- compare_pytorch_and_py (out_fg , [])
184+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
185+ compare_pytorch_and_py (out_fg , [x_test ])
162186
163- out_pt = pt_subtensor .set_subtensor (x_pt [0 , 1 :3 , 0 ], st_pt )
164- assert isinstance (out_pt .owner .op , pt_subtensor .IncSubtensor )
165- out_fg = FunctionGraph ([], [out_pt ])
166- compare_pytorch_and_py (out_fg , [])
187+
188+ def test_pytorch_AvdancedIncSubtensor ():
189+ rng = np .random .default_rng (42 )
190+
191+ x_np = rng .uniform (- 1 , 1 , size = (3 , 4 , 5 )).astype (config .floatX )
192+ x_pt = pt .tensor3 ("x" )
193+ x_test = np .arange (3 * 4 * 5 ).reshape ((3 , 4 , 5 )).astype (config .floatX )
167194
168195 # "Increment" advanced indices
169196 st_pt = pt .as_tensor_variable (
170197 rng .uniform (- 1 , 1 , size = (2 , 4 , 5 )).astype (config .floatX )
171198 )
172199 out_pt = pt_subtensor .inc_subtensor (x_pt [np .r_ [0 , 2 ]], st_pt )
173200 assert isinstance (out_pt .owner .op , pt_subtensor .AdvancedIncSubtensor )
174- out_fg = FunctionGraph ([], [out_pt ])
175- compare_pytorch_and_py (out_fg , [])
201+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
202+ compare_pytorch_and_py (out_fg , [x_test ])
176203
177204 st_pt = pt .as_tensor_variable (np .r_ [- 1.0 , 0.0 ].astype (config .floatX ))
178205 out_pt = pt_subtensor .inc_subtensor (x_pt [[0 , 2 ], 0 , 0 ], st_pt )
179206 assert isinstance (out_pt .owner .op , pt_subtensor .AdvancedIncSubtensor )
180- out_fg = FunctionGraph ([], [out_pt ])
181- compare_pytorch_and_py (out_fg , [])
207+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
208+ compare_pytorch_and_py (out_fg , [x_test ])
182209
183210 # "Increment" boolean indices
184211 mask_pt = pt .constant (x_np > 0 )
185212 out_pt = pt_subtensor .set_subtensor (x_pt [mask_pt ], 1.0 )
186213 assert isinstance (out_pt .owner .op , pt_subtensor .AdvancedIncSubtensor )
187- out_fg = FunctionGraph ([], [out_pt ])
188- compare_pytorch_and_py (out_fg , [])
214+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
215+ compare_pytorch_and_py (out_fg , [x_test ])
189216
190217 st_pt = pt .as_tensor_variable (x_np [[0 , 2 ], 0 , :3 ])
191218 out_pt = pt_subtensor .set_subtensor (x_pt [[0 , 2 ], 0 , :3 ], st_pt )
192219 assert isinstance (out_pt .owner .op , pt_subtensor .AdvancedIncSubtensor )
193- out_fg = FunctionGraph ([], [out_pt ])
194- compare_pytorch_and_py (out_fg , [])
220+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
221+ compare_pytorch_and_py (out_fg , [x_test ])
195222
196223 st_pt = pt .as_tensor_variable (x_np [[0 , 2 ], 0 , :3 ])
197224 out_pt = pt_subtensor .inc_subtensor (x_pt [[0 , 2 ], 0 , :3 ], st_pt )
198225 assert isinstance (out_pt .owner .op , pt_subtensor .AdvancedIncSubtensor )
199- out_fg = FunctionGraph ([], [out_pt ])
200- compare_pytorch_and_py (out_fg , [])
226+ out_fg = FunctionGraph ([x_pt ], [out_pt ])
227+ compare_pytorch_and_py (out_fg , [x_test ])
0 commit comments