1212from pytensor .graph .fg import FunctionGraph
1313from pytensor .graph .op import Op
1414from pytensor .raise_op import CheckAndRaise
15+ from pytensor .tensor import alloc , arange , as_tensor , empty
1516from pytensor .tensor .type import scalar , vector
1617
1718
@@ -191,7 +192,7 @@ def test_shared_updates(device):
191192 assert isinstance (a .get_value (), np .ndarray )
192193
193194
194- def test_pytorch_checkandraise ():
195+ def test_checkandraise ():
195196 check_and_raise = CheckAndRaise (AssertionError , "testing" )
196197
197198 x = scalar ("x" )
@@ -203,3 +204,34 @@ def test_pytorch_checkandraise():
203204 with pytest .raises (AssertionError , match = "testing" ):
204205 y_fn (0.0 )
205206 assert y_fn (4 ).item () == 4
207+
208+
209+ def test_alloc_and_empty ():
210+ dim0 = as_tensor (5 , dtype = "int64" )
211+ dim1 = scalar ("dim1" , dtype = "int64" )
212+
213+ out = empty ((dim0 , dim1 , 3 ), dtype = "float32" )
214+ fn = function ([dim1 ], out , mode = pytorch_mode )
215+ res = fn (7 )
216+ assert res .shape == (5 , 7 , 3 )
217+ assert res .dtype == torch .float32
218+
219+ v = vector ("v" , shape = (3 ,), dtype = "float64" )
220+ out = alloc (v , (dim0 , dim1 , 3 ))
221+ compare_pytorch_and_py (
222+ FunctionGraph ([v , dim1 ], [out ]),
223+ [np .array ([1 , 2 , 3 ]), np .array (7 )],
224+ )
225+
226+
227+ def test_arange ():
228+ start = scalar ("start" , dtype = "int64" )
229+ stop = scalar ("stop" , dtype = "int64" )
230+ step = scalar ("step" , dtype = "int64" )
231+
232+ out = arange (start , stop , step , dtype = "int16" )
233+
234+ compare_pytorch_and_py (
235+ FunctionGraph ([start , stop , step ], [out ]),
236+ [np .array (1 ), np .array (10 ), np .array (2 )],
237+ )
0 commit comments