Skip to content

Commit 60246ad

Browse files
committed
Implement basic Alloc Ops in PyTorch
1 parent 320bac4 commit 60246ad

File tree

2 files changed

+64
-1
lines changed

2 files changed

+64
-1
lines changed

pytensor/link/pytorch/dispatch/basic.py

+31
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from pytensor.graph.fg import FunctionGraph
77
from pytensor.link.utils import fgraph_to_python
88
from pytensor.raise_op import CheckAndRaise
9+
from pytensor.tensor.basic import Alloc, AllocEmpty, ARange
910

1011

1112
@singledispatch
@@ -58,3 +59,33 @@ def deepcopyop(x):
5859
return x.clone()
5960

6061
return deepcopyop
62+
63+
64+
@pytorch_funcify.register(AllocEmpty)
65+
def pytorch_funcify_AllocEmpty(op, **kwargs):
66+
dtype = getattr(torch, op.dtype)
67+
68+
def alloc_empty(*shape):
69+
return torch.empty(shape, dtype=dtype)
70+
71+
return alloc_empty
72+
73+
74+
@pytorch_funcify.register(Alloc)
75+
def pytorch_funcify_alloc(op, **kwargs):
76+
def alloc(value, *shape):
77+
out = torch.empty(shape, dtype=value.dtype)
78+
out[...] = value # broadcast value to shape of out
79+
return out
80+
81+
return alloc
82+
83+
84+
@pytorch_funcify.register(ARange)
85+
def pytorch_funcify_arange(op, **kwargs):
86+
dtype = getattr(torch, op.dtype)
87+
88+
def arange(start, stop, step):
89+
return torch.arange(start, stop, step, dtype=dtype)
90+
91+
return arange

tests/link/pytorch/test_basic.py

+33-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from pytensor.graph.fg import FunctionGraph
1313
from pytensor.graph.op import Op
1414
from pytensor.raise_op import CheckAndRaise
15+
from pytensor.tensor import alloc, arange, as_tensor, empty
1516
from pytensor.tensor.type import scalar, vector
1617

1718

@@ -191,7 +192,7 @@ def test_shared_updates(device):
191192
assert isinstance(a.get_value(), np.ndarray)
192193

193194

194-
def test_pytorch_checkandraise():
195+
def test_checkandraise():
195196
check_and_raise = CheckAndRaise(AssertionError, "testing")
196197

197198
x = scalar("x")
@@ -203,3 +204,34 @@ def test_pytorch_checkandraise():
203204
with pytest.raises(AssertionError, match="testing"):
204205
y_fn(0.0)
205206
assert y_fn(4).item() == 4
207+
208+
209+
def test_alloc_and_empty():
210+
dim0 = as_tensor(5, dtype="int64")
211+
dim1 = scalar("dim1", dtype="int64")
212+
213+
out = empty((dim0, dim1, 3), dtype="float32")
214+
fn = function([dim1], out, mode=pytorch_mode)
215+
res = fn(7)
216+
assert res.shape == (5, 7, 3)
217+
assert res.dtype == torch.float32
218+
219+
v = vector("v", shape=(3,), dtype="float64")
220+
out = alloc(v, (dim0, dim1, 3))
221+
compare_pytorch_and_py(
222+
FunctionGraph([v, dim1], [out]),
223+
[np.array([1, 2, 3]), np.array(7)],
224+
)
225+
226+
227+
def test_arange():
228+
start = scalar("start", dtype="int64")
229+
stop = scalar("stop", dtype="int64")
230+
step = scalar("step", dtype="int64")
231+
232+
out = arange(start, stop, step, dtype="int16")
233+
234+
compare_pytorch_and_py(
235+
FunctionGraph([start, stop, step], [out]),
236+
[np.array(1), np.array(10), np.array(2)],
237+
)

0 commit comments

Comments
 (0)