Skip to content

Commit 7c9e31f

Browse files
committed
Set default_update of RandomVariables rng in compile_rv_inplace
1 parent 3dd1253 commit 7c9e31f

File tree

2 files changed

+17
-5
lines changed

2 files changed

+17
-5
lines changed

pymc3/aesaraf.py

+16
Original file line numberDiff line numberDiff line change
@@ -895,6 +895,22 @@ def compile_rv_inplace(inputs, outputs, mode=None, **kwargs):
895895
Using this function ensures that compiled functions containing random
896896
variables will produce new samples on each call.
897897
"""
898+
899+
# Avoid circular dependency
900+
from pymc3.distributions import NoDistribution
901+
902+
# Set the default update of a NoDistribution RNG so that it is automatically
903+
# updated after every function call
904+
output_to_list = outputs if isinstance(outputs, list) else [outputs]
905+
for rv in (
906+
node
907+
for node in walk_model(output_to_list, walk_past_rvs=True)
908+
if node.owner and isinstance(node.owner.op, NoDistribution)
909+
):
910+
rng = rv.owner.inputs[0]
911+
if not hasattr(rng, "default_update"):
912+
rng.default_update = rv.owner.outputs[0]
913+
898914
mode = get_mode(mode)
899915
opt_qry = mode.provided_optimizer.including("random_make_inplace")
900916
mode = Mode(linker=mode.linker, optimizer=opt_qry)

pymc3/distributions/simulator.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -244,13 +244,9 @@ def logp(cls, value, sim_rv):
244244

245245
# Create a new simulatorRV with identical inputs as the original one
246246
sim_op = sim_rv.owner.op
247-
new_rng, sim_value = sim_op.make_node(rng, *sim_rv.owner.inputs[1:]).outputs
247+
sim_value = sim_op.make_node(rng, *sim_rv.owner.inputs[1:]).default_output()
248248
sim_value.name = "sim_value"
249249

250-
# Automatically update rng when expression is evaluated
251-
sim_value.update = (rng, new_rng)
252-
rng.default_update = new_rng
253-
254250
return sim_op.distance(
255251
sim_op.epsilon,
256252
sim_op.sum_stat(value),

0 commit comments

Comments
 (0)