12
12
from pytensor .graph .fg import FunctionGraph
13
13
from pytensor .graph .op import Op
14
14
from pytensor .raise_op import CheckAndRaise
15
+ from pytensor .tensor import alloc , arange , as_tensor , empty
15
16
from pytensor .tensor .type import scalar , vector
16
17
17
18
@@ -191,7 +192,7 @@ def test_shared_updates(device):
191
192
assert isinstance (a .get_value (), np .ndarray )
192
193
193
194
194
- def test_pytorch_checkandraise ():
195
+ def test_checkandraise ():
195
196
check_and_raise = CheckAndRaise (AssertionError , "testing" )
196
197
197
198
x = scalar ("x" )
@@ -203,3 +204,34 @@ def test_pytorch_checkandraise():
203
204
with pytest .raises (AssertionError , match = "testing" ):
204
205
y_fn (0.0 )
205
206
assert y_fn (4 ).item () == 4
207
+
208
+
209
+ def test_alloc_and_empty ():
210
+ dim0 = as_tensor (5 , dtype = "int64" )
211
+ dim1 = scalar ("dim1" , dtype = "int64" )
212
+
213
+ out = empty ((dim0 , dim1 , 3 ), dtype = "float32" )
214
+ fn = function ([dim1 ], out , mode = pytorch_mode )
215
+ res = fn (7 )
216
+ assert res .shape == (5 , 7 , 3 )
217
+ assert res .dtype == torch .float32
218
+
219
+ v = vector ("v" , shape = (3 ,), dtype = "float64" )
220
+ out = alloc (v , (dim0 , dim1 , 3 ))
221
+ compare_pytorch_and_py (
222
+ FunctionGraph ([v , dim1 ], [out ]),
223
+ [np .array ([1 , 2 , 3 ]), np .array (7 )],
224
+ )
225
+
226
+
227
+ def test_arange ():
228
+ start = scalar ("start" , dtype = "int64" )
229
+ stop = scalar ("stop" , dtype = "int64" )
230
+ step = scalar ("step" , dtype = "int64" )
231
+
232
+ out = arange (start , stop , step , dtype = "int16" )
233
+
234
+ compare_pytorch_and_py (
235
+ FunctionGraph ([start , stop , step ], [out ]),
236
+ [np .array (1 ), np .array (10 ), np .array (2 )],
237
+ )
0 commit comments