import numpy as np
import pytest

+import pytensor
import pytensor.tensor as pt
import pytensor.tensor.math as ptm
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
+from pytensor.scalar.basic import ScalarOp, get_scalar_type
+from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax
from pytensor.tensor.type import matrix, tensor, tensor3, vector
from tests.link.pytorch.test_basic import compare_pytorch_and_py
@@ -150,3 +153,33 @@ def test_cast():
        fgraph, [np.arange(6, dtype="float32").reshape(2, 3)]
    )
    assert res.dtype == torch.int32
+
+
+def test_vmap_elemwise():
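+    # Compiled with the PyTorch backend, an Elemwise op should run through
+    # torch.vmap: the scalar kernel sees 0-d elements and the Op's Python
+    # perform is never called (it raises if it is).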
+    from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
+
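+    # Minimal ScalarOp stand-in: records every shape its kernel is called
+    # with and fails loudly if the Python perform fallback is ever taken.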
+    class TestOp(ScalarOp):
+        def __init__(self):
+            super().__init__(
+                output_types_preference=lambda *_: [get_scalar_type("float32")]
+            )
+            self.call_shapes = []
+            self.nin = 1
+
+        def perform(self, *_):
+            raise RuntimeError("In perform")
+
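+    # PyTorch dispatch for TestOp: a ReLU whose inner function logs the
+    # size of each tensor vmap hands it.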
+    @pytorch_funcify.register(TestOp)
+    def relu(op, node, **kwargs):
+        def relu(row):
+            op.call_shapes.append(row.size())
+            return torch.max(torch.zeros_like(row), row)
+
+        return relu
+
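+    # Build a (2, 3) elemwise graph in PYTORCH mode and compare the result
+    # against torch.relu on random input.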
+    x = matrix("x", shape=(2, 3))
+    op = TestOp()
+    f = pytensor.function([x], Elemwise(op)(x), mode="PYTORCH")
+    vals = torch.zeros(2, 3).normal_()
+    np.testing.assert_allclose(f(vals), torch.relu(vals))
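+    # vmap should have invoked the kernel once, on a 0-d element, so a
+    # single scalar shape is recorded.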
+    assert op.call_shapes == [torch.Size([])], op.call_shapes