
Commit 9858b33

Ch0ronomato (Ian Schweer) and ricardoV94 authored

Implement ScalarLoop in torch backend (#958)
* Add for-loop-based scalar loop
* Pass all loop tests
* Fetch constants from op
* Add while-loop test
* Fix while loop and nasty stack over dtypes
* Disable compile here based on CI result
* Fix mypy signature
* Remove unnecessary torch stack
* Only call .cpu when necessary
* Recursive false for torch compiler
* Add elemwise test
* Late import torch
* Do iteration instead of vmap for elemwise
* Clean up and add description
* Add unit test to verify iteration
* Refactor to ravel method
* Fix unpacking (Co-authored-by: Ricardo Vieira <[email protected]>)
* Fix comment
* Remove extra return
* Update test
* Add single carry test
* Remove compiler disable
* Better name
* Lint
* Better docstring
* PR comments

---------

Co-authored-by: Ian Schweer <[email protected]>
Co-authored-by: Ricardo Vieira <[email protected]>
1 parent 07bd48d commit 9858b33
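
For orientation, here is a minimal sketch of what this commit enables, condensed from the new tests added below. It assumes the registered `"PYTORCH"` compile mode is available (the tests themselves import their own `pytorch_mode`):

```python
import numpy as np

from pytensor.compile.function import function
from pytensor.scalar import float64, int64
from pytensor.scalar.loop import ScalarLoop

# A scalar loop that adds `const` to the carried state on every step
n_steps = int64("n_steps")
x0 = float64("x0")
const = float64("const")
op = ScalarLoop(init=[x0], constant=[const], update=[x0 + const])

x = op(n_steps, x0, const)
fn = function([n_steps, x0, const], x, mode="PYTORCH")
np.testing.assert_allclose(fn(5, 0, 2), 10)  # five steps of +2 starting at 0
```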

File tree (4 files changed: +165 −9 lines changed)

* pytensor/link/pytorch/dispatch/elemwise.py (+39)
* pytensor/link/pytorch/dispatch/scalar.py (+35)
* pytensor/link/pytorch/linker.py (+5 −9)
* tests/link/pytorch/test_basic.py (+86)


pytensor/link/pytorch/dispatch/elemwise.py (+39)

```diff
@@ -3,6 +3,7 @@
 import torch
 
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
+from pytensor.scalar import ScalarLoop
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.math import All, Any, Max, Min, Prod, Sum
 from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
@@ -11,6 +12,7 @@
 @pytorch_funcify.register(Elemwise)
 def pytorch_funcify_Elemwise(op, node, **kwargs):
     scalar_op = op.scalar_op
+
     base_fn = pytorch_funcify(scalar_op, node=node, **kwargs)
 
     def check_special_scipy(func_name):
@@ -33,6 +35,9 @@ def elemwise_fn(*inputs):
             Elemwise._check_runtime_broadcast(node, inputs)
             return base_fn(*inputs)
 
+    elif isinstance(scalar_op, ScalarLoop):
+        return elemwise_ravel_fn(base_fn, op, node, **kwargs)
+
     else:
 
         def elemwise_fn(*inputs):
@@ -176,3 +181,37 @@ def softmax_grad(dy, sm):
         return dy_times_sm - torch.sum(dy_times_sm, dim=axis, keepdim=True) * sm
 
     return softmax_grad
+
+
+def elemwise_ravel_fn(base_fn, op, node, **kwargs):
+    """
+    Dispatch methods using `.item()` (ScalarLoop + Elemwise) are common, but vmap
+    in torch has a limitation: https://github.com/pymc-devs/pytensor/issues/1031.
+    Instead, we can ravel all the inputs, broadcasted according to torch.
+    """
+
+    n_outputs = len(node.outputs)
+
+    def elemwise_fn(*inputs):
+        bcasted_inputs = torch.broadcast_tensors(*inputs)
+        raveled_inputs = [inp.ravel() for inp in bcasted_inputs]
+
+        out_shape = bcasted_inputs[0].size()
+        out_size = out_shape.numel()
+        raveled_outputs = [torch.empty(out_size) for out in node.outputs]
+
+        for i in range(out_size):
+            core_outs = base_fn(*(inp[i] for inp in raveled_inputs))
+            if n_outputs == 1:
+                raveled_outputs[0][i] = core_outs
+            else:
+                for o in range(n_outputs):
+                    raveled_outputs[o][i] = core_outs[o]
+
+        outputs = tuple(out.view(out_shape) for out in raveled_outputs)
+        if n_outputs == 1:
+            return outputs[0]
+        else:
+            return outputs
+
+    return elemwise_fn
```
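
Why ravel instead of vmap: the compiled scalar function operates on 0-d values, and `torch.vmap` cannot trace through the `.item()`-style code paths this can involve (see the issue linked in the docstring). The new helper instead broadcasts all inputs to a common shape, flattens them, and iterates in plain Python. Below is a self-contained sketch of that pattern; `toy_scalar_fn` and `ravel_apply` are illustrative names, not part of the commit:

```python
import torch


def toy_scalar_fn(a, b):
    # Stand-in for a compiled scalar op: defined only on 0-d tensors
    return a * 2 + b


def ravel_apply(fn, *inputs):
    # Broadcast every input to a common shape, then flatten
    bcast = torch.broadcast_tensors(*inputs)
    flat = [t.ravel() for t in bcast]
    out = torch.empty(bcast[0].numel())
    # Plain Python iteration: no vmap, so scalar-only code stays safe
    for i in range(out.numel()):
        out[i] = fn(*(t[i] for t in flat))
    return out.view(bcast[0].size())


print(ravel_apply(toy_scalar_fn, torch.arange(3.0), torch.tensor(1.0)))
# tensor([1., 3., 5.])
```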

pytensor/link/pytorch/dispatch/scalar.py (+35)

```diff
@@ -7,6 +7,7 @@
     Cast,
     ScalarOp,
 )
+from pytensor.scalar.loop import ScalarLoop
 from pytensor.scalar.math import Softplus
 
 
@@ -62,3 +63,37 @@ def cast(x):
 @pytorch_funcify.register(Softplus)
 def pytorch_funcify_Softplus(op, node, **kwargs):
     return torch.nn.Softplus()
+
+
+@pytorch_funcify.register(ScalarLoop)
+def pytorch_funicify_ScalarLoop(op, node, **kwargs):
+    update = pytorch_funcify(op.fgraph, **kwargs)
+    state_length = op.nout
+    if op.is_while:
+
+        def scalar_loop(steps, *start_and_constants):
+            carry, constants = (
+                start_and_constants[:state_length],
+                start_and_constants[state_length:],
+            )
+            done = True
+            for _ in range(steps):
+                *carry, done = update(*carry, *constants)
+                if torch.any(done):
+                    break
+            return *carry, done
+    else:
+
+        def scalar_loop(steps, *start_and_constants):
+            carry, constants = (
+                start_and_constants[:state_length],
+                start_and_constants[state_length:],
+            )
+            for _ in range(steps):
+                carry = update(*carry, *constants)
+            if len(node.outputs) == 1:
+                return carry[0]
+            else:
+                return carry
+
+    return scalar_loop
```
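
The while variant relies on the update graph returning the carried state plus the `until` condition as its last output; the star-unpacking splits that condition off on every iteration (the commit's "Fix unpacking" item). A toy trace of that control flow, where `fake_update` is a hypothetical stand-in for the compiled `op.fgraph`:

```python
state_length = 2  # plays the role of op.nout


def fake_update(c0, c1, k):
    # Returns the new carries, plus the `until` condition as the last output
    return c0 + k, c1 * k, c0 + k >= 8


start_and_constants = (0.0, 1.0, 2.0)  # two carries, one constant
carry, constants = (
    start_and_constants[:state_length],
    start_and_constants[state_length:],
)
done = True
for _ in range(10):
    *carry, done = fake_update(*carry, *constants)
    if done:  # a while-ScalarLoop exits as soon as `until` is met
        break
print(carry, done)  # [8.0, 16.0] True
```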

pytensor/link/pytorch/linker.py (+5 −9)

```diff
@@ -54,34 +54,30 @@ def __init__(self, fn, gen_functors):
                 self.fn = torch.compile(fn)
                 self.gen_functors = gen_functors.copy()
 
-            def __call__(self, *args, **kwargs):
+            def __call__(self, *inputs, **kwargs):
                 import pytensor.link.utils
 
                 # set attrs
                 for n, fn in self.gen_functors:
                     setattr(pytensor.link.utils, n[1:], fn)
 
-                res = self.fn(*args, **kwargs)
+                # Torch does not accept numpy inputs and may return GPU objects
+                outs = self.fn(*(pytorch_typify(inp) for inp in inputs), **kwargs)
 
                 # unset attrs
                 for n, _ in self.gen_functors:
                     if getattr(pytensor.link.utils, n[1:], False):
                         delattr(pytensor.link.utils, n[1:])
 
-                return res
+                return tuple(out.cpu().numpy() for out in outs)
 
             def __del__(self):
                 del self.gen_functors
 
         inner_fn = wrapper(fn, self.gen_functors)
         self.gen_functors = []
 
-        # Torch does not accept numpy inputs and may return GPU objects
-        def fn(*inputs, inner_fn=inner_fn):
-            outs = inner_fn(*(pytorch_typify(inp) for inp in inputs))
-            return tuple(out.cpu().numpy() for out in outs)
-
-        return fn
+        return inner_fn
 
     def create_thunk_inputs(self, storage_map):
         thunk_inputs = []
```
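
The linker change folds the numpy-to-torch conversion (and the final `.cpu().numpy()` round trip) into `wrapper.__call__` itself, removing the extra closure around the compiled function; `.cpu()` is now only paid on outputs. A sketch of the conversion contract, with `torch.as_tensor` standing in for `pytorch_typify` and `call_compiled` as an illustrative name:

```python
import numpy as np
import torch


def call_compiled(fn, *np_inputs):
    # numpy in -> torch tensors for the compiled graph -> numpy out
    outs = fn(*(torch.as_tensor(x) for x in np_inputs))
    return tuple(out.cpu().numpy() for out in outs)


double = torch.compile(lambda x: (x * 2,))
print(call_compiled(double, np.arange(3)))  # (array([0, 2, 4]),)
```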

tests/link/pytorch/test_basic.py (+86)

```diff
@@ -4,6 +4,7 @@
 import numpy as np
 import pytest
 
+import pytensor.tensor as pt
 import pytensor.tensor.basic as ptb
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.function import function
@@ -17,7 +18,10 @@
 from pytensor.ifelse import ifelse
 from pytensor.link.pytorch.linker import PytorchLinker
 from pytensor.raise_op import CheckAndRaise
+from pytensor.scalar import float64, int64
+from pytensor.scalar.loop import ScalarLoop
 from pytensor.tensor import alloc, arange, as_tensor, empty, expit, eye, softplus
+from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.type import matrices, matrix, scalar, vector
 
 
@@ -385,3 +389,85 @@ def test_pytorch_softplus():
     out = softplus(x)
     f = FunctionGraph([x], [out])
     compare_pytorch_and_py(f, [np.random.rand(3)])
+
+
+def test_ScalarLoop():
+    n_steps = int64("n_steps")
+    x0 = float64("x0")
+    const = float64("const")
+    x = x0 + const
+
+    op = ScalarLoop(init=[x0], constant=[const], update=[x])
+    x = op(n_steps, x0, const)
+
+    fn = function([n_steps, x0, const], x, mode=pytorch_mode)
+    np.testing.assert_allclose(fn(5, 0, 1), 5)
+    np.testing.assert_allclose(fn(5, 0, 2), 10)
+    np.testing.assert_allclose(fn(4, 3, -1), -1)
+
+
+def test_ScalarLoop_while():
+    n_steps = int64("n_steps")
+    x0 = float64("x0")
+    x = x0 + 1
+    until = x >= 10
+
+    op = ScalarLoop(init=[x0], update=[x], until=until)
+    fn = function([n_steps, x0], op(n_steps, x0), mode=pytorch_mode)
+    for res, expected in zip(
+        [fn(n_steps=20, x0=0), fn(n_steps=20, x0=1), fn(n_steps=5, x0=1)],
+        [[10, True], [10, True], [6, False]],
+        strict=True,
+    ):
+        np.testing.assert_allclose(res[0], np.array(expected[0]))
+        np.testing.assert_allclose(res[1], np.array(expected[1]))
+
+
+def test_ScalarLoop_Elemwise_single_carries():
+    n_steps = int64("n_steps")
+    x0 = float64("x0")
+    x = x0 * 2
+    until = x >= 10
+
+    scalarop = ScalarLoop(init=[x0], update=[x], until=until)
+    op = Elemwise(scalarop)
+
+    n_steps = pt.scalar("n_steps", dtype="int32")
+    x0 = pt.vector("x0", dtype="float32")
+    state, done = op(n_steps, x0)
+
+    f = FunctionGraph([n_steps, x0], [state, done])
+    args = [
+        np.array(10).astype("int32"),
+        np.arange(0, 5).astype("float32"),
+    ]
+    compare_pytorch_and_py(
+        f, args, assert_fn=partial(np.testing.assert_allclose, rtol=1e-6)
+    )
+
+
+def test_ScalarLoop_Elemwise_multi_carries():
+    n_steps = int64("n_steps")
+    x0 = float64("x0")
+    x1 = float64("x1")
+    x = x0 * 2
+    x1_n = x1 * 3
+    until = x >= 10
+
+    scalarop = ScalarLoop(init=[x0, x1], update=[x, x1_n], until=until)
+    op = Elemwise(scalarop)
+
+    n_steps = pt.scalar("n_steps", dtype="int32")
+    x0 = pt.vector("x0", dtype="float32")
+    x1 = pt.tensor("c0", dtype="float32", shape=(7, 3, 1))
+    *states, done = op(n_steps, x0, x1)
+
+    f = FunctionGraph([n_steps, x0, x1], [*states, done])
+    args = [
+        np.array(10).astype("int32"),
+        np.arange(0, 5).astype("float32"),
+        np.random.rand(7, 3, 1).astype("float32"),
+    ]
+    compare_pytorch_and_py(
+        f, args, assert_fn=partial(np.testing.assert_allclose, rtol=1e-6)
+    )
```
