Skip to content

Commit 448e558

Browse files
committed
Replace RNG update output in RV lift rewrites
Otherwise we end up with multiple RVs if the RNGs are an output of the function or used elsewhere in it
1 parent 2d81cca commit 448e558

File tree

2 files changed

+47
-12
lines changed

2 files changed

+47
-12
lines changed

pytensor/tensor/random/rewriting/basic.py

+18-5
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ def local_dimshuffle_rv_lift(fgraph, node):
130130
if ds_op.drop:
131131
return False
132132

133+
[ds_rv] = node.outputs
133134
rv_node = node.inputs[0].owner
134135

135136
if not (rv_node and isinstance(rv_node.op, RandomVariable)):
@@ -182,10 +183,17 @@ def local_dimshuffle_rv_lift(fgraph, node):
182183
if config.compute_test_value != "off":
183184
compute_test_value(new_node)
184185

185-
new_rv = new_node.default_output()
186+
new_next_rng, new_rv = new_node.outputs
187+
186188
if rv.name:
187189
new_rv.name = f"{rv.name}_lifted"
188-
return [new_rv]
190+
191+
# We replace uses of the dimshuffled RV by the new RV
192+
# And uses of the old RNG update by the new RNG update
193+
return {
194+
ds_rv: new_rv,
195+
next_rng: new_next_rng,
196+
}
189197

190198

191199
@node_rewriter([Subtensor, AdvancedSubtensor1, AdvancedSubtensor])
@@ -217,7 +225,7 @@ def is_nd_advanced_idx(idx, dtype) -> bool:
217225

218226
rv_op = rv_node.op
219227
rng, size, *dist_params = rv_node.inputs
220-
rv = rv_node.default_output()
228+
next_rng, rv = rv_node.outputs
221229

222230
# If no one else is using the underlying `RandomVariable`, then we can
223231
# do this; otherwise, the graph would be internally inconsistent.
@@ -331,8 +339,13 @@ def is_nd_advanced_idx(idx, dtype) -> bool:
331339

332340
# Create new RV
333341
new_node = rv_op.make_node(rng, new_size, *new_dist_params)
334-
new_rv = new_node.default_output()
342+
new_next_rng, new_rv = new_node.outputs
335343

336344
copy_stack_trace(rv, new_rv)
337345

338-
return [new_rv]
346+
# We replace uses of the indexed RV by the new RV
347+
# And uses of the old RNG update by the new RNG update
348+
return {
349+
indexed_rv: new_rv,
350+
next_rng: new_next_rng,
351+
}

tests/tensor/random/rewriting/test_basic.py

+29-7
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1+
from collections.abc import Sequence
2+
13
import numpy as np
24
import pytest
35

46
import pytensor.tensor as pt
57
from pytensor import config, shared
68
from pytensor.compile.function import function
79
from pytensor.compile.mode import Mode
8-
from pytensor.graph.basic import Constant
10+
from pytensor.graph.basic import Constant, Variable, ancestors
911
from pytensor.graph.fg import FunctionGraph
1012
from pytensor.graph.rewriting.basic import EquilibriumGraphRewriter
1113
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
@@ -36,6 +38,16 @@
3638
no_mode = Mode("py", RewriteDatabaseQuery(include=[], exclude=[]))
3739

3840

41+
def count_rv_nodes_in_graph(outputs: Sequence[Variable]) -> int:
42+
return len(
43+
{
44+
var.owner
45+
for var in ancestors(outputs)
46+
if var.owner and isinstance(var.owner.op, RandomVariable)
47+
}
48+
)
49+
50+
3951
def apply_local_rewrite_to_rv(
4052
rewrite, op_fn, dist_op, dist_params, size, rng, name=None
4153
):
@@ -58,7 +70,14 @@ def apply_local_rewrite_to_rv(
5870
s_pt.tag.test_value = s
5971
size_pt.append(s_pt)
6072

61-
dist_st = op_fn(dist_op(*dist_params_pt, size=size_pt, rng=rng, name=name))
73+
next_rng, rv = dist_op(
74+
*dist_params_pt, size=size_pt, rng=rng, name=name
75+
).owner.outputs
76+
dist_st = op_fn(rv)
77+
78+
assert (
79+
count_rv_nodes_in_graph([dist_st, next_rng]) == 1
80+
), "Function expects a single RV in the graph"
6281

6382
f_inputs = [
6483
p
@@ -72,13 +91,16 @@ def apply_local_rewrite_to_rv(
7291

7392
f_rewritten = function(
7493
f_inputs,
75-
dist_st,
94+
[dist_st, next_rng],
7695
mode=mode,
7796
)
7897

79-
(new_out,) = f_rewritten.maker.fgraph.outputs
98+
new_rv, new_next_rng = f_rewritten.maker.fgraph.outputs
99+
assert (
100+
count_rv_nodes_in_graph([new_rv, new_next_rng]) == 1
101+
), "Rewritten should have a single RV in the graph"
80102

81-
return new_out, f_inputs, dist_st, f_rewritten
103+
return new_rv, f_inputs, dist_st, f_rewritten
82104

83105

84106
class TestRVExpraProps(RandomVariable):
@@ -422,7 +444,7 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
422444

423445
arg_values = [p.get_test_value() for p in f_inputs]
424446
res_base = f_base(*arg_values)
425-
res_rewritten = f_rewritten(*arg_values)
447+
res_rewritten, _ = f_rewritten(*arg_values)
426448

427449
np.testing.assert_allclose(res_base, res_rewritten, rtol=rtol)
428450

@@ -825,7 +847,7 @@ def is_subtensor_or_dimshuffle_subtensor(inp) -> bool:
825847

826848
arg_values = [p.get_test_value() for p in f_inputs]
827849
res_base = f_base(*arg_values)
828-
res_rewritten = f_rewritten(*arg_values)
850+
res_rewritten, _ = f_rewritten(*arg_values)
829851

830852
np.testing.assert_allclose(res_base, res_rewritten, rtol=1e-3, atol=1e-2)
831853

0 commit comments

Comments (0)