
Commit be358ed

ricardoV94 authored and twiecki committed

Implement Cast in PyTorch backend

1 parent be6a032

File tree (2 files changed, +24 −0 lines):

- pytensor/link/pytorch/dispatch/scalar.py (+11)
- tests/link/pytorch/test_elemwise.py (+13)

Diff for: pytensor/link/pytorch/dispatch/scalar.py

+11 lines

@@ -2,6 +2,7 @@
 
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
 from pytensor.scalar.basic import (
+    Cast,
     ScalarOp,
 )
 
@@ -38,3 +39,13 @@ def pytorch_func(*args):
         )
 
     return pytorch_func
+
+
+@pytorch_funcify.register(Cast)
+def pytorch_funcify_Cast(op: Cast, node, **kwargs):
+    dtype = getattr(torch, op.o_type.dtype)
+
+    def cast(x):
+        return x.to(dtype=dtype)
+
+    return cast
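For readers unfamiliar with the dispatch mechanism: pytorch_funcify_Cast runs once at compile time and returns a closure that the linker then calls with the runtime tensor. A minimal standalone sketch of the same behaviour (the make_cast helper name is illustrative, not part of the commit):

import torch

# Look up the torch dtype by name, then convert with Tensor.to(),
# mirroring what the dispatched Cast closure does at runtime.
def make_cast(dtype_name: str):
    dtype = getattr(torch, dtype_name)  # e.g. "int32" -> torch.int32

    def cast(x):
        return x.to(dtype=dtype)

    return cast

cast_fn = make_cast("int32")
out = cast_fn(torch.arange(6, dtype=torch.float32).reshape(2, 3))
assert out.dtype == torch.int32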

Diff for: tests/link/pytorch/test_elemwise.py

+13 lines

@@ -10,6 +10,9 @@
 from tests.link.pytorch.test_basic import compare_pytorch_and_py
 
 
+torch = pytest.importorskip("torch")
+
+
 def test_pytorch_Dimshuffle():
     a_pt = matrix("a")
 
@@ -137,3 +140,13 @@ def test_softmax_grad(axis):
     out = SoftmaxGrad(axis=axis)(dy, sm)
     fgraph = FunctionGraph([dy, sm], [out])
     compare_pytorch_and_py(fgraph, [dy_value, sm_value])
+
+
+def test_cast():
+    x = matrix("x", dtype="float32")
+    out = pt.cast(x, "int32")
+    fgraph = FunctionGraph([x], [out])
+    _, [res] = compare_pytorch_and_py(
+        fgraph, [np.arange(6, dtype="float32").reshape(2, 3)]
+    )
+    assert res.dtype == torch.int32
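And a usage sketch from the user's side, assuming the PyTorch backend is selected via mode="PYTORCH" (the mode name is an assumption based on PyTensor's backend conventions; the commit's test goes through the compare_pytorch_and_py helper rather than compiling directly):

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x", dtype="float32")
out = pt.cast(x, "int32")

# mode="PYTORCH" is assumed here, by analogy with the JAX/NUMBA modes.
fn = pytensor.function([x], out, mode="PYTORCH")
res = fn(np.arange(6, dtype="float32").reshape(2, 3))
print(res.dtype)  # expected: torch.int32, since outputs are torch tensors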
