Commit 6de3151

Ch0ronomato authored and ricardoV94 committed

Improve torch elemwise operator

1 parent 0ba554b, commit 6de3151

2 files changed: +48 -3 lines


pytensor/link/pytorch/dispatch/elemwise.py (+15, -3)
@@ -11,9 +11,21 @@ def pytorch_funcify_Elemwise(op, node, **kwargs):
     scalar_op = op.scalar_op
     base_fn = pytorch_funcify(scalar_op, node=node, **kwargs)
 
-    def elemwise_fn(*inputs):
-        Elemwise._check_runtime_broadcast(node, inputs)
-        return base_fn(*inputs)
+    if hasattr(scalar_op, "nfunc_spec") and hasattr(torch, scalar_op.nfunc_spec[0]):
+        # torch can handle this scalar op's
+        # broadcasting natively, so we let it.
+        def elemwise_fn(*inputs):
+            Elemwise._check_runtime_broadcast(node, inputs)
+            return base_fn(*inputs)
+    else:
+
+        def elemwise_fn(*inputs):
+            Elemwise._check_runtime_broadcast(node, inputs)
+            broadcast_inputs = torch.broadcast_tensors(*inputs)
+            ufunc = base_fn
+            for _ in range(broadcast_inputs[0].dim()):
+                ufunc = torch.vmap(ufunc)
+            return ufunc(*broadcast_inputs)
 
     return elemwise_fn
 
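The fallback branch above is the core of the change: when torch has no native ufunc matching the scalar op's nfunc_spec, the inputs are broadcast explicitly and the scalar function is wrapped in one torch.vmap per dimension, so a function that only understands 0-d tensors gets applied elementwise. A minimal standalone sketch of that idea, assuming PyTorch >= 2.0 for torch.vmap and using a hypothetical scalar_add in place of base_fn:

import torch


def scalar_add(x, y):
    # Stand-in for a scalar `base_fn`: conceptually combines two 0-d tensors.
    return x + y


a = torch.randn(2, 3)
b = torch.randn(3)  # broadcasts against `a`

# Same recipe as the fallback branch: broadcast first, then nest one
# torch.vmap per dimension so `scalar_add` only ever sees 0-d elements.
a_b, b_b = torch.broadcast_tensors(a, b)  # both become shape (2, 3)

ufunc = scalar_add
for _ in range(a_b.dim()):
    ufunc = torch.vmap(ufunc)

out = ufunc(a_b, b_b)
assert torch.allclose(out, a + b)

Each torch.vmap layer maps over one leading batch dimension, which is why the loop runs broadcast_inputs[0].dim() times.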

tests/link/pytorch/test_elemwise.py (+33)
@@ -1,10 +1,13 @@
 import numpy as np
 import pytest
 
+import pytensor
 import pytensor.tensor as pt
 import pytensor.tensor.math as ptm
 from pytensor.configdefaults import config
 from pytensor.graph.fg import FunctionGraph
+from pytensor.scalar.basic import ScalarOp, get_scalar_type
+from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax
 from pytensor.tensor.type import matrix, tensor, tensor3, vector
 from tests.link.pytorch.test_basic import compare_pytorch_and_py
@@ -150,3 +153,33 @@ def test_cast():
         fgraph, [np.arange(6, dtype="float32").reshape(2, 3)]
     )
     assert res.dtype == torch.int32
+
+
+def test_vmap_elemwise():
+    from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
+
+    class TestOp(ScalarOp):
+        def __init__(self):
+            super().__init__(
+                output_types_preference=lambda *_: [get_scalar_type("float32")]
+            )
+            self.call_shapes = []
+            self.nin = 1
+
+        def perform(self, *_):
+            raise RuntimeError("In perform")
+
+    @pytorch_funcify.register(TestOp)
+    def relu(op, node, **kwargs):
+        def relu(row):
+            op.call_shapes.append(row.size())
+            return torch.max(torch.zeros_like(row), row)
+
+        return relu
+
+    x = matrix("x", shape=(2, 3))
+    op = TestOp()
+    f = pytensor.function([x], Elemwise(op)(x), mode="PYTORCH")
+    vals = torch.zeros(2, 3).normal_()
+    np.testing.assert_allclose(f(vals), torch.relu(vals))
+    assert op.call_shapes == [torch.Size([])], op.call_shapes
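The final assertion is the key check: because the registered scalar implementation is traced through the nested torch.vmap wrappers, it sees a 0-d tensor, so op.call_shapes ends up as [torch.Size([])] rather than the full (2, 3) shape it would see if base_fn were called on the whole array. To run just this test from a checkout (standard pytest node-id syntax):

pytest tests/link/pytorch/test_elemwise.py::test_vmap_elemwise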
