
Commit afbd84b

Remove model.rng_seeder and allow compile_pymc to reseed RNG variables in compiled function
Sampling functions now also accept RandomState or Generator objects as input to random_seed, similar to how random_state behaves in scipy distributions. For backwards compatibility this argument was not renamed.
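
To illustrate the behavior described above, a minimal sketch (not part of this commit; draw, tune, chain counts and seeds are arbitrary):

    import numpy as np
    import pymc as pm

    with pm.Model():
        x = pm.Normal("x")
        # Integer seeds keep working as before
        idata_int = pm.sample(draws=100, tune=100, chains=2, random_seed=42)
        # Generator (or legacy RandomState) objects are now accepted too,
        # mirroring scipy's random_state argument
        idata_gen = pm.sample(draws=100, tune=100, chains=2, random_seed=np.random.default_rng(42))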
1 parent: fc9dabe

21 files changed (+307/-288 lines)

RELEASE-NOTES.md (+4/-1)
@@ -25,7 +25,10 @@ Also check out the [milestones](https://github.com/pymc-devs/pymc/milestones) fo
 
 All of the above apply to:
 
-- Signature and default parameters changed for several distributions:
+- ⚠ Random seeding behavior changed!
+  - Sampling results will differ from those of V3 when passing the same random_seed as before. They will be consistent across subsequent V4 releases unless mentioned otherwise.
+  - Sampling functions no longer respect user-specified global seeding! Always pass `random_seed` to ensure reproducible behavior.
+- Signature and default parameters changed for several distributions:
   - `pm.StudentT` now requires either `sigma` or `lam` as kwarg (see [#5628](https://github.com/pymc-devs/pymc/pull/5628))
   - `pm.StudentT` now requires `nu` to be specified (no longer defaults to 1) (see [#5628](https://github.com/pymc-devs/pymc/pull/5628))
   - `pm.AsymmetricLaplace` positional arguments re-ordered (see [#5628](https://github.com/pymc-devs/pymc/pull/5628))
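
A sketch of the global-seeding change noted above (illustrative only; pm.sample_prior_predictive stands in for any sampling function):

    import numpy as np
    import pymc as pm

    with pm.Model():
        x = pm.Normal("x")

        # Global seeding is ignored: these two calls generally return different draws
        np.random.seed(123)
        prior_a = pm.sample_prior_predictive()
        np.random.seed(123)
        prior_b = pm.sample_prior_predictive()

        # Passing random_seed explicitly makes results reproducible
        prior_c = pm.sample_prior_predictive(random_seed=123)
        prior_d = pm.sample_prior_predictive(random_seed=123)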

benchmarks/benchmarks/benchmarks.py (+6/-3)
@@ -174,13 +174,16 @@ def time_glm_hierarchical_init(self, init):
         """How long does it take to run the initialization."""
         with glm_hierarchical_model():
             pm.init_nuts(
-                init=init, chains=self.chains, progressbar=False, seeds=np.arange(self.chains)
+                init=init,
+                chains=self.chains,
+                progressbar=False,
+                random_seed=np.arange(self.chains),
             )
 
     def track_glm_hierarchical_ess(self, init):
         with glm_hierarchical_model():
             start, step = pm.init_nuts(
-                init=init, chains=self.chains, progressbar=False, seeds=np.arange(self.chains)
+                init=init, chains=self.chains, progressbar=False, random_seed=np.arange(self.chains)
             )
             t0 = time.time()
             idata = pm.sample(
@@ -201,7 +204,7 @@ def track_marginal_mixture_model_ess(self, init):
         model, start = mixture_model()
         with model:
             _, step = pm.init_nuts(
-                init=init, chains=self.chains, progressbar=False, seeds=np.arange(self.chains)
+                init=init, chains=self.chains, progressbar=False, random_seed=np.arange(self.chains)
             )
         start = [{k: v for k, v in start.items()} for _ in range(self.chains)]
         t0 = time.time()

pymc/aesaraf.py (+26/-2)
@@ -929,10 +929,29 @@ def reseed_rngs(
 
 
 def compile_pymc(
-    inputs, outputs, mode=None, **kwargs
+    inputs,
+    outputs,
+    random_seed: SeedSequenceSeed = None,
+    mode=None,
+    **kwargs,
 ) -> Callable[..., Union[np.ndarray, List[np.ndarray]]]:
     """Use ``aesara.function`` with specialized pymc rewrites always enabled.
 
+    This function also ensures that shared RandomState/Generator variables used by
+    RandomVariables in the graph are updated across calls, so that draws are independent.
+
+    Parameters
+    ----------
+    inputs: list of TensorVariables, optional
+        Inputs of the compiled Aesara function
+    outputs: list of TensorVariables, optional
+        Outputs of the compiled Aesara function
+    random_seed: int, array-like of int or SeedSequence, optional
+        Seed used to override any RandomState/Generator shared variables in the graph.
+        If not specified, the original shared variables are still overwritten with fresh states.
+    mode: optional
+        Aesara mode used to compile the function
+
     Included rewrites
     -----------------
     random_make_inplace
@@ -952,7 +971,6 @@ def compile_pymc(
     """
     # Create an update mapping of RandomVariable's RNG so that it is automatically
     # updated after every function call
-    # TODO: This won't work for variables with InnerGraphs (Scan and OpFromGraph)
     rng_updates = {}
     output_to_list = outputs if isinstance(outputs, (list, tuple)) else [outputs]
     for random_var in (
@@ -966,11 +984,17 @@ def compile_pymc(
             rng = random_var.owner.inputs[0]
             if not hasattr(rng, "default_update"):
                 rng_updates[rng] = random_var.owner.outputs[0]
+            else:
+                rng_updates[rng] = rng.default_update
         else:
             update_fn = getattr(random_var.owner.op, "update", None)
             if update_fn is not None:
                 rng_updates.update(update_fn(random_var.owner))
 
+    # We always reseed random variables as this provides RNGs with no chances of collision
+    if rng_updates:
+        reseed_rngs(rng_updates.keys(), random_seed)
+
     # If called inside a model context, see if check_bounds flag is set to False
     try:
        from pymc.model import modelcontext
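
A hypothetical usage sketch of the reseeding behavior added here (the assertion assumes the compile-time reseeding introduced in this commit; values are illustrative):

    import numpy as np
    import pymc as pm
    from pymc.aesaraf import compile_pymc

    x = pm.Normal.dist()
    fn = compile_pymc([], x, random_seed=42)
    draws_a = [fn() for _ in range(3)]  # RNG updates make successive draws independent

    # Recompiling reseeds the shared RNG, reproducing the same stream of draws
    fn = compile_pymc([], x, random_seed=42)
    draws_b = [fn() for _ in range(3)]
    np.testing.assert_array_equal(draws_a, draws_b)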

pymc/distributions/censored.py (+1/-18)
@@ -85,16 +85,12 @@ def dist(cls, dist, lower, upper, **kwargs):
         check_dist_not_registered(dist)
         return super().dist([dist, lower, upper], **kwargs)
 
-    @classmethod
-    def num_rngs(cls, *args, **kwargs):
-        return 1
-
     @classmethod
     def ndim_supp(cls, *dist_params):
         return 0
 
     @classmethod
-    def rv_op(cls, dist, lower=None, upper=None, size=None, rngs=None):
+    def rv_op(cls, dist, lower=None, upper=None, size=None):
 
         lower = at.constant(-np.inf) if lower is None else at.as_tensor_variable(lower)
         upper = at.constant(np.inf) if upper is None else at.as_tensor_variable(upper)
@@ -112,21 +108,8 @@ def rv_op(cls, dist, lower=None, upper=None, size=None, rngs=None):
         rv_out.tag.lower = lower
         rv_out.tag.upper = upper
 
-        if rngs is not None:
-            rv_out = cls._change_rngs(rv_out, rngs)
-
         return rv_out
 
-    @classmethod
-    def _change_rngs(cls, rv, new_rngs):
-        (new_rng,) = new_rngs
-        dist_node = rv.tag.dist.owner
-        lower = rv.tag.lower
-        upper = rv.tag.upper
-        olg_rng, size, dtype, *dist_params = dist_node.inputs
-        new_dist = dist_node.op.make_node(new_rng, size, dtype, *dist_params).default_output()
-        return cls.rv_op(new_dist, lower, upper)
-
     @classmethod
     def change_size(cls, rv, new_size, expand=False):
         dist = rv.tag.dist
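
With the rngs machinery gone, a censored variable is built without passing RNGs, and seeding happens at compile time instead. A minimal sketch (bounds and seed are arbitrary):

    import pymc as pm
    from pymc.aesaraf import compile_pymc

    censored_normal = pm.Censored.dist(pm.Normal.dist(), lower=-1.0, upper=1.0)
    fn = compile_pymc([], censored_normal, random_seed=2022)
    print(fn())  # a single draw, clipped to [-1, 1]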

pymc/distributions/distribution.py (+3/-19)
@@ -19,7 +19,7 @@
 
 from abc import ABCMeta
 from functools import singledispatch
-from typing import Callable, Iterable, Optional, Sequence, Tuple, Union, cast
+from typing import Callable, Optional, Sequence, Tuple, Union, cast
 
 import aesara
 import numpy as np
@@ -258,13 +258,10 @@ def __new__(
         if not isinstance(name, string_types):
             raise TypeError(f"Name needs to be a string but got: {name}")
 
-        if rng is None:
-            rng = model.next_rng()
-
         # Create the RV and process dims and observed to determine
         # a shape by which the created RV may need to be resized.
         rv_out, dims, observed, resize_shape = _make_rv_and_resize_shape(
-            cls=cls, dims=dims, model=model, observed=observed, args=args, rng=rng, **kwargs
+            cls=cls, dims=dims, model=model, observed=observed, args=args, **kwargs
         )
 
         if resize_shape:
@@ -383,9 +380,6 @@ class SymbolicDistribution:
         to a canonical parametrization. It should call `super().dist()`, passing a
         list with the default parameters as the first and only non keyword argument,
         followed by other keyword arguments like size and rngs, and return the result
-    cls.num_rngs
-        Returns the number of rngs given the same arguments passed by the user when
-        calling the distribution
     cls.ndim_supp
         Returns the support of the symbolic distribution, given the default set of
         parameters. This may not always be constant, for instance if the symbolic
@@ -402,7 +396,6 @@ def __new__(
         cls,
         name: str,
         *args,
-        rngs: Optional[Iterable] = None,
         dims: Optional[Dims] = None,
         initval=None,
         observed=None,
@@ -419,8 +412,6 @@ def __new__(
             A distribution class that inherits from SymbolicDistribution.
         name : str
             Name for the new model variable.
-        rngs : optional
-            Random number generator to use for the RandomVariable(s) in the graph.
         dims : tuple, optional
             A tuple of dimension names known to the model.
         initval : optional
@@ -468,17 +459,10 @@ def __new__(
         if not isinstance(name, string_types):
             raise TypeError(f"Name needs to be a string but got: {name}")
 
-        if rngs is None:
-            # Instead of passing individual RNG variables we could pass a RandomStream
-            # and let the classes create as many RNGs as they need
-            rngs = [model.next_rng() for _ in range(cls.num_rngs(*args, **kwargs))]
-        elif not isinstance(rngs, (list, tuple)):
-            rngs = [rngs]
-
         # Create the RV and process dims and observed to determine
         # a shape by which the created RV may need to be resized.
         rv_out, dims, observed, resize_shape = _make_rv_and_resize_shape(
-            cls=cls, dims=dims, model=model, observed=observed, args=args, rngs=rngs, **kwargs
+            cls=cls, dims=dims, model=model, observed=observed, args=args, **kwargs
        )
 
         if resize_shape:
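
The API change above means model variables are declared without any rng/rngs keyword; seeding is deferred to the sampling or compile step. A brief sketch of the new convention:

    import pymc as pm

    with pm.Model():
        x = pm.Normal("x")  # no rng/rngs keyword anymore
        # Reproducibility is controlled when draws are actually taken
        prior = pm.sample_prior_predictive(random_seed=7)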

pymc/distributions/mixture.py (+4/-30)
@@ -205,27 +205,15 @@ def dist(cls, w, comp_dists, **kwargs):
         w = at.as_tensor_variable(w)
         return super().dist([w, *comp_dists], **kwargs)
 
-    @classmethod
-    def num_rngs(cls, w, comp_dists, **kwargs):
-        if not isinstance(comp_dists, (tuple, list)):
-            # comp_dists is a single component
-            comp_dists = [comp_dists]
-        return len(comp_dists) + 1
-
     @classmethod
     def ndim_supp(cls, weights, *components):
         # We already checked that all components have the same support dimensionality
         return components[0].owner.op.ndim_supp
 
     @classmethod
-    def rv_op(cls, weights, *components, size=None, rngs=None):
-        # Update rngs if provided
-        if rngs is not None:
-            components = cls._reseed_components(rngs, *components)
-            *_, mix_indexes_rng = rngs
-        else:
-            # Create new rng for the mix_indexes internal RV
-            mix_indexes_rng = aesara.shared(np.random.default_rng())
+    def rv_op(cls, weights, *components, size=None):
+        # Create new rng for the mix_indexes internal RV
+        mix_indexes_rng = aesara.shared(np.random.default_rng())
 
         single_component = len(components) == 1
         ndim_supp = components[0].owner.op.ndim_supp
@@ -317,19 +305,6 @@ def rv_op(cls, weights, *components, size=None):
 
         return mix_out
 
-    @classmethod
-    def _reseed_components(cls, rngs, *components):
-        *components_rngs, mix_indexes_rng = rngs
-        assert len(components) == len(components_rngs)
-        new_components = []
-        for component, component_rng in zip(components, components_rngs):
-            component_node = component.owner
-            old_rng, *inputs = component_node.inputs
-            new_components.append(
-                component_node.op.make_node(component_rng, *inputs).default_output()
-            )
-        return new_components
-
     @classmethod
     def _resize_components(cls, size, *components):
         if len(components) == 1:
@@ -345,7 +320,6 @@ def _resize_components(cls, size, *components):
     def change_size(cls, rv, new_size, expand=False):
         weights = rv.tag.weights
         components = rv.tag.components
-        rngs = [component.owner.inputs[0] for component in components] + [rv.tag.choices_rng]
 
         if expand:
             component = rv.tag.components[0]
@@ -360,7 +334,7 @@ def change_size(cls, rv, new_size, expand=False):
 
         components = cls._resize_components(new_size, *components)
 
-        return cls.rv_op(weights, *components, rngs=rngs, size=None)
+        return cls.rv_op(weights, *components, size=None)
 
 
 @_get_measurable_outputs.register(MarginalMixtureRV)
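
Since rv_op now always creates a fresh RNG for the mixture indexes, reproducibility relies on compile-time reseeding rather than caller-supplied rngs. A sketch (weights and components are arbitrary):

    import pymc as pm
    from pymc.aesaraf import compile_pymc

    components = [pm.Normal.dist(mu=-5.0), pm.Normal.dist(mu=5.0)]
    mix = pm.Mixture.dist(w=[0.5, 0.5], comp_dists=components)
    # The internal mix_indexes RNG and both component RNGs are reseeded here
    fn = compile_pymc([], mix, random_seed=0)
    print(fn())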

pymc/distributions/timeseries.py (+3/-21)
@@ -494,28 +494,12 @@ def _get_ar_order(cls, rhos: TensorVariable, ar_order: Optional[int], constant:
 
         return ar_order
 
-    @classmethod
-    def num_rngs(cls, *args, **kwargs):
-        return 2
-
     @classmethod
     def ndim_supp(cls, *args):
         return 1
 
     @classmethod
-    def rv_op(cls, rhos, sigma, init_dist, steps, ar_order, constant_term, size=None, rngs=None):
-
-        if rngs is None:
-            rngs = [
-                aesara.shared(np.random.default_rng(seed))
-                for seed in np.random.SeedSequence().spawn(2)
-            ]
-        (init_dist_rng, noise_rng) = rngs
-        # Re-seed init_dist
-        if init_dist.owner.inputs[0] is not init_dist_rng:
-            _, *inputs = init_dist.owner.inputs
-            init_dist = init_dist.owner.op.make_node(init_dist_rng, *inputs).default_output()
-
+    def rv_op(cls, rhos, sigma, init_dist, steps, ar_order, constant_term, size=None):
         # Init dist should have shape (*size, ar_order)
         if size is not None:
             batch_size = size
@@ -543,6 +527,8 @@ def rv_op(cls, rhos, sigma, init_dist, steps, ar_order, constant_term, size=None
         rhos_bcast_shape_ = (*rhos_bcast_shape_[:-1], rhos_bcast_shape_[-1] + 1)
         rhos_bcast_ = at.broadcast_to(rhos_, rhos_bcast_shape_)
 
+        noise_rng = aesara.shared(np.random.default_rng())
+
         def step(*args):
             *prev_xs, reversed_rhos, sigma, rng = args
             if constant_term:
@@ -581,16 +567,12 @@ def change_size(cls, rv, new_size, expand=False):
             old_size = rv.shape[:-1]
             new_size = at.concatenate([new_size, old_size])
 
-        init_dist_rng = rv.owner.inputs[2].owner.inputs[0]
-        noise_rng = rv.owner.inputs[-1]
-
         op = rv.owner.op
         return cls.rv_op(
             *rv.owner.inputs,
             ar_order=op.ar_order,
             constant_term=op.constant_term,
             size=new_size,
-            rngs=(init_dist_rng, noise_rng),
         )
 
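The AR noise RNG is likewise created inside rv_op and picked up by the compile-time reseeding (its OpFromGraph exposes an update method, handled in the pymc/aesaraf.py change above). A sketch, assuming an AR.dist signature matching the parameters shown in this diff; values are arbitrary:

    import pymc as pm
    from pymc.aesaraf import compile_pymc

    ar = pm.AR.dist(rho=[0.9], sigma=1.0, init_dist=pm.Normal.dist(), steps=10)
    fn = compile_pymc([], ar, random_seed=5)
    print(fn())  # one AR(1) trajectory; the scan's noise RNG was reseeded too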