diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 84fb504788..e042593d88 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -25,7 +25,10 @@ Also check out the [milestones](https://github.com/pymc-devs/pymc/milestones) fo All of the above apply to: -Signature and default parameters changed for several distributions: +⚠ Random seeding behavior changed! + - Sampling results will differ from those of V3 when passing the same random_state as before. They will be consistent across subsequent V4 releases unless mentioned otherwise. + - Sampling functions no longer respect user-specified global seeding! Always pass `random_seed` to ensure reproducible behavior. +- Signature and default parameters changed for several distributions: - `pm.StudentT` now requires either `sigma` or `lam` as kwarg (see [#5628](https://github.com/pymc-devs/pymc/pull/5628)) - `pm.StudentT` now requires `nu` to be specified (no longer defaults to 1) (see [#5628](https://github.com/pymc-devs/pymc/pull/5628)) - `pm.AsymmetricLaplace` positional arguments re-ordered (see [#5628](https://github.com/pymc-devs/pymc/pull/5628)) diff --git a/benchmarks/benchmarks/benchmarks.py b/benchmarks/benchmarks/benchmarks.py index ea3c464c6d..9ebe1494cd 100644 --- a/benchmarks/benchmarks/benchmarks.py +++ b/benchmarks/benchmarks/benchmarks.py @@ -174,13 +174,16 @@ def time_glm_hierarchical_init(self, init): """How long does it take to run the initialization.""" with glm_hierarchical_model(): pm.init_nuts( - init=init, chains=self.chains, progressbar=False, seeds=np.arange(self.chains) + init=init, + chains=self.chains, + progressbar=False, + random_seed=np.arange(self.chains), ) def track_glm_hierarchical_ess(self, init): with glm_hierarchical_model(): start, step = pm.init_nuts( - init=init, chains=self.chains, progressbar=False, seeds=np.arange(self.chains) + init=init, chains=self.chains, progressbar=False, random_seed=np.arange(self.chains) ) t0 = time.time() idata = pm.sample( @@ -201,7 +204,7 @@ def track_marginal_mixture_model_ess(self, init): model, start = mixture_model() with model: _, step = pm.init_nuts( - init=init, chains=self.chains, progressbar=False, seeds=np.arange(self.chains) + init=init, chains=self.chains, progressbar=False, random_seed=np.arange(self.chains) ) start = [{k: v for k, v in start.items()} for _ in range(self.chains)] t0 = time.time() diff --git a/pymc/aesaraf.py b/pymc/aesaraf.py index b77cada8c3..0160635d0b 100644 --- a/pymc/aesaraf.py +++ b/pymc/aesaraf.py @@ -20,6 +20,7 @@ Iterable, List, Optional, + Sequence, Set, Tuple, Union, @@ -893,11 +894,64 @@ def local_check_parameter_to_ninf_switch(fgraph, node): ) + +def find_rng_nodes(variables: Iterable[TensorVariable]): + """Return RNG variables in a graph""" + return [ + node + for node in graph_inputs(variables) + if isinstance( + node, + ( + at.random.var.RandomStateSharedVariable, + at.random.var.RandomGeneratorSharedVariable, + ), + ) + ] + + +SeedSequenceSeed = Optional[Union[int, Sequence[int], np.ndarray, np.random.SeedSequence]] + + +def reseed_rngs( + rngs: Sequence[SharedVariable], + seed: SeedSequenceSeed, +) -> None: + """Create a new set of RandomState/Generator for each rng based on a seed""" + bit_generators = [ + np.random.PCG64(sub_seed) for sub_seed in np.random.SeedSequence(seed).spawn(len(rngs)) + ] + for rng, bit_generator in zip(rngs, bit_generators): + if isinstance(rng, at.random.var.RandomStateSharedVariable): + new_rng = np.random.RandomState(bit_generator) + else: + new_rng = np.random.Generator(bit_generator) +
rng.set_value(new_rng, borrow=True) + + def compile_pymc( - inputs, outputs, mode=None, **kwargs + inputs, + outputs, + random_seed: SeedSequenceSeed = None, + mode=None, + **kwargs, ) -> Callable[..., Union[np.ndarray, List[np.ndarray]]]: """Use ``aesara.function`` with specialized pymc rewrites always enabled. + This function also ensures shared RandomState/Generator used by RandomVariables + in the graph are updated across calls, to ensure independent draws. + + Parameters + ---------- + inputs: list of TensorVariables, optional + Inputs of the compiled Aesara function + outputs: list of TensorVariables, optional + Outputs of the compiled Aesara function + random_seed: int, array-like of int or SeedSequence, optional + Seed used to override any RandomState/Generator shared variables in the graph. + If not specified, the value of original shared variables will still be overwritten. + mode: optional + Aesara mode used to compile the function + Included rewrites ----------------- random_make_inplace @@ -917,7 +971,6 @@ def compile_pymc( """ # Create an update mapping of RandomVariable's RNG so that it is automatically # updated after every function call - # TODO: This won't work for variables with InnerGraphs (Scan and OpFromGraph) rng_updates = {} output_to_list = outputs if isinstance(outputs, (list, tuple)) else [outputs] for random_var in ( @@ -931,11 +984,17 @@ def compile_pymc( rng = random_var.owner.inputs[0] if not hasattr(rng, "default_update"): rng_updates[rng] = random_var.owner.outputs[0] + else: + rng_updates[rng] = rng.default_update else: update_fn = getattr(random_var.owner.op, "update", None) if update_fn is not None: rng_updates.update(update_fn(random_var.owner)) + # We always reseed random variables as this provides RNGs with no chances of collision + if rng_updates: + reseed_rngs(rng_updates.keys(), random_seed) + # If called inside a model context, see if check_bounds flag is set to False try: from pymc.model import modelcontext diff --git a/pymc/distributions/censored.py b/pymc/distributions/censored.py index 57dde86218..c38681d79f 100644 --- a/pymc/distributions/censored.py +++ b/pymc/distributions/censored.py @@ -85,16 +85,12 @@ def dist(cls, dist, lower, upper, **kwargs): check_dist_not_registered(dist) return super().dist([dist, lower, upper], **kwargs) - @classmethod - def num_rngs(cls, *args, **kwargs): - return 1 - @classmethod def ndim_supp(cls, *dist_params): return 0 @classmethod - def rv_op(cls, dist, lower=None, upper=None, size=None, rngs=None): + def rv_op(cls, dist, lower=None, upper=None, size=None): lower = at.constant(-np.inf) if lower is None else at.as_tensor_variable(lower) upper = at.constant(np.inf) if upper is None else at.as_tensor_variable(upper) @@ -112,21 +108,8 @@ def rv_op(cls, dist, lower=None, upper=None, size=None, rngs=None): rv_out.tag.lower = lower rv_out.tag.upper = upper - if rngs is not None: - rv_out = cls._change_rngs(rv_out, rngs) - return rv_out - @classmethod - def _change_rngs(cls, rv, new_rngs): - (new_rng,) = new_rngs - dist_node = rv.tag.dist.owner - lower = rv.tag.lower - upper = rv.tag.upper - olg_rng, size, dtype, *dist_params = dist_node.inputs - new_dist = dist_node.op.make_node(new_rng, size, dtype, *dist_params).default_output() - return cls.rv_op(new_dist, lower, upper) - @classmethod def change_size(cls, rv, new_size, expand=False): dist = rv.tag.dist diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 138b3cd253..98b661ca23 100644 --- a/pymc/distributions/distribution.py 
+++ b/pymc/distributions/distribution.py @@ -19,7 +19,7 @@ from abc import ABCMeta from functools import singledispatch -from typing import Callable, Iterable, Optional, Sequence, Tuple, Union, cast +from typing import Callable, Optional, Sequence, Tuple, Union, cast import aesara import numpy as np @@ -258,13 +258,10 @@ def __new__( if not isinstance(name, string_types): raise TypeError(f"Name needs to be a string but got: {name}") - if rng is None: - rng = model.next_rng() - # Create the RV and process dims and observed to determine # a shape by which the created RV may need to be resized. rv_out, dims, observed, resize_shape = _make_rv_and_resize_shape( - cls=cls, dims=dims, model=model, observed=observed, args=args, rng=rng, **kwargs + cls=cls, dims=dims, model=model, observed=observed, args=args, **kwargs ) if resize_shape: @@ -383,9 +380,6 @@ class SymbolicDistribution: to a canonical parametrization. It should call `super().dist()`, passing a list with the default parameters as the first and only non keyword argument, followed by other keyword arguments like size and rngs, and return the result - cls.num_rngs - Returns the number of rngs given the same arguments passed by the user when - calling the distribution cls.ndim_supp Returns the support of the symbolic distribution, given the default set of parameters. This may not always be constant, for instance if the symbolic @@ -402,7 +396,6 @@ def __new__( cls, name: str, *args, - rngs: Optional[Iterable] = None, dims: Optional[Dims] = None, initval=None, observed=None, @@ -419,8 +412,6 @@ def __new__( A distribution class that inherits from SymbolicDistribution. name : str Name for the new model variable. - rngs : optional - Random number generator to use for the RandomVariable(s) in the graph. dims : tuple, optional A tuple of dimension names known to the model. initval : optional @@ -468,17 +459,10 @@ def __new__( if not isinstance(name, string_types): raise TypeError(f"Name needs to be a string but got: {name}") - if rngs is None: - # Instead of passing individual RNG variables we could pass a RandomStream - # and let the classes create as many RNGs as they need - rngs = [model.next_rng() for _ in range(cls.num_rngs(*args, **kwargs))] - elif not isinstance(rngs, (list, tuple)): - rngs = [rngs] - # Create the RV and process dims and observed to determine # a shape by which the created RV may need to be resized. 
rv_out, dims, observed, resize_shape = _make_rv_and_resize_shape( - cls=cls, dims=dims, model=model, observed=observed, args=args, rngs=rngs, **kwargs + cls=cls, dims=dims, model=model, observed=observed, args=args, **kwargs ) if resize_shape: diff --git a/pymc/distributions/mixture.py b/pymc/distributions/mixture.py index a1c6129ebe..952016d0b7 100644 --- a/pymc/distributions/mixture.py +++ b/pymc/distributions/mixture.py @@ -205,27 +205,15 @@ def dist(cls, w, comp_dists, **kwargs): w = at.as_tensor_variable(w) return super().dist([w, *comp_dists], **kwargs) - @classmethod - def num_rngs(cls, w, comp_dists, **kwargs): - if not isinstance(comp_dists, (tuple, list)): - # comp_dists is a single component - comp_dists = [comp_dists] - return len(comp_dists) + 1 - @classmethod def ndim_supp(cls, weights, *components): # We already checked that all components have the same support dimensionality return components[0].owner.op.ndim_supp @classmethod - def rv_op(cls, weights, *components, size=None, rngs=None): - # Update rngs if provided - if rngs is not None: - components = cls._reseed_components(rngs, *components) - *_, mix_indexes_rng = rngs - else: - # Create new rng for the mix_indexes internal RV - mix_indexes_rng = aesara.shared(np.random.default_rng()) + def rv_op(cls, weights, *components, size=None): + # Create new rng for the mix_indexes internal RV + mix_indexes_rng = aesara.shared(np.random.default_rng()) single_component = len(components) == 1 ndim_supp = components[0].owner.op.ndim_supp @@ -317,19 +305,6 @@ def rv_op(cls, weights, *components, size=None, rngs=None): return mix_out - @classmethod - def _reseed_components(cls, rngs, *components): - *components_rngs, mix_indexes_rng = rngs - assert len(components) == len(components_rngs) - new_components = [] - for component, component_rng in zip(components, components_rngs): - component_node = component.owner - old_rng, *inputs = component_node.inputs - new_components.append( - component_node.op.make_node(component_rng, *inputs).default_output() - ) - return new_components - @classmethod def _resize_components(cls, size, *components): if len(components) == 1: @@ -345,7 +320,6 @@ def _resize_components(cls, size, *components): def change_size(cls, rv, new_size, expand=False): weights = rv.tag.weights components = rv.tag.components - rngs = [component.owner.inputs[0] for component in components] + [rv.tag.choices_rng] if expand: component = rv.tag.components[0] @@ -360,7 +334,7 @@ def change_size(cls, rv, new_size, expand=False): components = cls._resize_components(new_size, *components) - return cls.rv_op(weights, *components, rngs=rngs, size=None) + return cls.rv_op(weights, *components, size=None) @_get_measurable_outputs.register(MarginalMixtureRV) diff --git a/pymc/distributions/timeseries.py b/pymc/distributions/timeseries.py index 5e1bb1e1a2..9c0b547a67 100644 --- a/pymc/distributions/timeseries.py +++ b/pymc/distributions/timeseries.py @@ -494,28 +494,12 @@ def _get_ar_order(cls, rhos: TensorVariable, ar_order: Optional[int], constant: return ar_order - @classmethod - def num_rngs(cls, *args, **kwargs): - return 2 - @classmethod def ndim_supp(cls, *args): return 1 @classmethod - def rv_op(cls, rhos, sigma, init_dist, steps, ar_order, constant_term, size=None, rngs=None): - - if rngs is None: - rngs = [ - aesara.shared(np.random.default_rng(seed)) - for seed in np.random.SeedSequence().spawn(2) - ] - (init_dist_rng, noise_rng) = rngs - # Re-seed init_dist - if init_dist.owner.inputs[0] is not init_dist_rng: - _, *inputs = 
init_dist.owner.inputs - init_dist = init_dist.owner.op.make_node(init_dist_rng, *inputs).default_output() - + def rv_op(cls, rhos, sigma, init_dist, steps, ar_order, constant_term, size=None): # Init dist should have shape (*size, ar_order) if size is not None: batch_size = size @@ -543,6 +527,8 @@ def rv_op(cls, rhos, sigma, init_dist, steps, ar_order, constant_term, size=None rhos_bcast_shape_ = (*rhos_bcast_shape_[:-1], rhos_bcast_shape_[-1] + 1) rhos_bcast_ = at.broadcast_to(rhos_, rhos_bcast_shape_) + noise_rng = aesara.shared(np.random.default_rng()) + def step(*args): *prev_xs, reversed_rhos, sigma, rng = args if constant_term: @@ -581,16 +567,12 @@ def change_size(cls, rv, new_size, expand=False): old_size = rv.shape[:-1] new_size = at.concatenate([new_size, old_size]) - init_dist_rng = rv.owner.inputs[2].owner.inputs[0] - noise_rng = rv.owner.inputs[-1] - op = rv.owner.op return cls.rv_op( *rv.owner.inputs, ar_order=op.ar_order, constant_term=op.constant_term, size=new_size, - rngs=(init_dist_rng, noise_rng), ) diff --git a/pymc/initial_point.py b/pymc/initial_point.py index b1149c73e0..11e80d0fa2 100644 --- a/pymc/initial_point.py +++ b/pymc/initial_point.py @@ -20,11 +20,11 @@ import aesara.tensor as at import numpy as np -from aesara.graph.basic import Variable, graph_inputs +from aesara.graph.basic import Variable from aesara.graph.fg import FunctionGraph from aesara.tensor.var import TensorVariable -from pymc.aesaraf import compile_pymc +from pymc.aesaraf import compile_pymc, find_rng_nodes, reseed_rngs from pymc.util import get_transformed_name, get_untransformed_name, is_transformed_name StartDict = Dict[Union[Variable, str], Union[np.ndarray, Variable, str]] @@ -150,19 +150,6 @@ def make_initial_point_fn( If `True` the returned variables will correspond to transformed initial values. """ - def find_rng_nodes(variables): - return [ - node - for node in graph_inputs(variables) - if isinstance( - node, - ( - at.random.var.RandomStateSharedVariable, - at.random.var.RandomGeneratorSharedVariable, - ), - ) - ] - sdict_overrides = convert_str_to_rv_dict(model, overrides or {}) initval_strats = { **model.initial_values, @@ -208,16 +195,7 @@ def make_seeded_function(func): @functools.wraps(func) def inner(seed, *args, **kwargs): - seeds = [ - np.random.PCG64(sub_seed) - for sub_seed in np.random.SeedSequence(seed).spawn(len(rngs)) - ] - for rng, seed in zip(rngs, seeds): - if isinstance(rng, at.random.var.RandomStateSharedVariable): - new_rng = np.random.RandomState(seed) - else: - new_rng = np.random.Generator(seed) - rng.set_value(new_rng, True) + reseed_rngs(rngs, seed) values = func(*args, **kwargs) return dict(zip(varnames, values)) diff --git a/pymc/model.py b/pymc/model.py index 59ae618b8d..185e30644c 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -44,7 +44,6 @@ from aesara.graph.basic import Constant, Variable, graph_inputs from aesara.graph.fg import FunctionGraph from aesara.tensor.random.opt import local_subtensor_rv_lift -from aesara.tensor.random.var import RandomStateSharedVariable from aesara.tensor.sharedvar import ScalarSharedVariable from aesara.tensor.var import TensorConstant, TensorVariable @@ -445,13 +444,6 @@ class Model(WithMemoization, metaclass=ContextMeta): parameters can only take on valid values you can set this to False for increased speed. This should not be used if your model contains discrete variables. 
- rng_seeder: int or numpy.random.RandomState - The ``numpy.random.RandomState`` used to seed the - ``RandomStateSharedVariable`` sequence used by a model - ``RandomVariable``s, or an int used to seed a new - ``numpy.random.RandomState``. If ``None``, a - ``RandomStateSharedVariable`` will be generated and used. Incremental - access to the state sequence is provided by ``Model.next_rng``. Examples -------- @@ -549,20 +541,10 @@ def __init__( name="", coords=None, check_bounds=True, - rng_seeder: Optional[Union[int, np.random.RandomState]] = None, ): self.name = self._validate_name(name) self.check_bounds = check_bounds - if rng_seeder is None: - self.rng_seeder = np.random.RandomState() - elif isinstance(rng_seeder, int): - self.rng_seeder = np.random.RandomState(rng_seeder) - else: - self.rng_seeder = rng_seeder - - # The sequence of model-generated RNGs - self.rng_seq: List[SharedVariable] = [] self._initial_values: Dict[TensorVariable, Optional[Union[np.ndarray, Variable, str]]] = {} if self.parent is not None: @@ -1016,8 +998,6 @@ def initial_point(self, seed=None) -> Dict[str, np.ndarray]: ip : dict Maps names of transformed variables to numeric initial values in the transformed space. """ - if seed is None: - seed = self.rng_seeder.randint(2**30, dtype=np.int64) fn = make_initial_point_fn(model=self, return_transformed=True) return Point(fn(seed), model=self) @@ -1038,20 +1018,6 @@ def set_initval(self, rv_var, initval): self.initial_values[rv_var] = initval - def next_rng(self) -> RandomStateSharedVariable: - """Generate a new ``RandomStateSharedVariable``. - - The new ``RandomStateSharedVariable`` is also added to - ``Model.rng_seq``. - """ - new_seed = self.rng_seeder.randint(2**30, dtype=np.int64) - next_rng = aesara.shared(np.random.RandomState(new_seed), borrow=True) - next_rng.tag.is_rng = True - - self.rng_seq.append(next_rng) - - return next_rng - def shape_from_dims(self, dims): shape = [] if len(set(dims)) != len(dims): @@ -1379,14 +1345,11 @@ def make_obs_var( clone=False, ) (observed_rv_var,) = local_subtensor_rv_lift.transform(fgraph, fgraph.outputs[0].owner) - # Make a clone of the RV, but change the rng so that observed and missing - # are not treated as equivalent nodes by aesara. This would happen if the - # size of the masked and unmasked array happened to coincide + # Make a clone of the RV, but let it create a new rng so that observed and + # missing are not treated as equivalent nodes by aesara. 
This would happen + if the size of the masked and unmasked array happened to coincide _, size, _, *inps = observed_rv_var.owner.inputs - rng = self.model.next_rng() - observed_rv_var = observed_rv_var.owner.op( - *inps, size=size, rng=rng, name=f"{name}_observed" - ) + observed_rv_var = observed_rv_var.owner.op(*inps, size=size, name=f"{name}_observed") observed_rv_var.tag.observations = nonmissing_data self.create_value_var(observed_rv_var, transform=None, value_var=nonmissing_data) diff --git a/pymc/parallel_sampling.py b/pymc/parallel_sampling.py index 7364a7021b..bb03b89ac0 100644 --- a/pymc/parallel_sampling.py +++ b/pymc/parallel_sampling.py @@ -21,7 +21,7 @@ import traceback from collections import namedtuple -from typing import Dict, List, Sequence +from typing import TYPE_CHECKING, Dict, List, Sequence import cloudpickle import numpy as np @@ -32,6 +32,10 @@ from pymc.blocking import DictToArrayBijection from pymc.exceptions import SamplingError +# Avoid circular import +if TYPE_CHECKING: + from pymc.sampling import RandomSeed + logger = logging.getLogger("pymc") @@ -389,7 +393,7 @@ def __init__( tune: int, chains: int, cores: int, - seeds: list, + seeds: Sequence["RandomSeed"], start_points: Sequence[Dict[str, np.ndarray]], step_method, start_chain_num: int = 0, diff --git a/pymc/sampling.py b/pymc/sampling.py index d7982e923e..9cb058a15a 100644 --- a/pymc/sampling.py +++ b/pymc/sampling.py @@ -14,7 +14,6 @@ """Functions for MCMC sampling.""" -import collections.abc as abc import logging import pickle import sys @@ -104,6 +103,9 @@ PointList: TypeAlias = List[PointType] Backend: TypeAlias = Union[BaseTrace, MultiTrace, NDArray] +RandomSeed = Optional[Union[int, Sequence[int], np.ndarray]] +RandomState = Union[RandomSeed, np.random.RandomState, np.random.Generator] + _log = logging.getLogger("pymc") @@ -253,6 +255,56 @@ def all_continuous(vars): return True +def _get_seeds_per_chain( + random_state: RandomState, + chains: int, +) -> Union[Sequence[int], np.ndarray]: + """Obtain or validate specified integer seeds per chain. + + This function processes different possible sources of seeding and returns one integer + seed per chain: + 1. If the input is an integer and a single chain is requested, the input is + returned inside a tuple. + 2. If the input is a sequence or NumPy array with as many entries as chains, + the input is returned. + 3. If the input is an integer and multiple chains are requested, new unique seeds + are generated from NumPy default Generator seeded with that integer. + 4. If the input is None, new unique seeds are generated from an unseeded NumPy default + Generator. + 5. If a RandomState or Generator is provided, new unique seeds are generated from it.
+ + Raises + ------ + ValueError + If none of the conditions above are met + """ + + def _get_unique_seeds_per_chain(integers_fn): + seeds = [] + while len(set(seeds)) != chains: + seeds = [int(seed) for seed in integers_fn(2**30, dtype=np.int64, size=chains)] + return seeds + + if random_state is None or isinstance(random_state, int): + if chains == 1 and isinstance(random_state, int): + return (random_state,) + return _get_unique_seeds_per_chain(np.random.default_rng(random_state).integers) + if isinstance(random_state, np.random.Generator): + return _get_unique_seeds_per_chain(random_state.integers) + if isinstance(random_state, np.random.RandomState): + return _get_unique_seeds_per_chain(random_state.randint) + + if not isinstance(random_state, (list, tuple, np.ndarray)): + raise ValueError(f"The `seeds` must be array-like. Got {type(random_state)} instead.") + + if len(random_state) != chains: + raise ValueError( + f"Number of seeds ({len(random_state)}) does not match the number of chains ({chains})." + ) + + return random_state + + def sample( draws: int = 1000, step=None, @@ -266,7 +318,7 @@ def sample( tune: int = 1000, progressbar: bool = True, model=None, - random_seed=None, + random_seed: RandomState = None, discard_tuned_samples: bool = True, compute_convergence_checks: bool = True, callback=None, @@ -323,9 +375,10 @@ def sample( time until completion ("expected time of arrival"; ETA). model : Model (optional if in ``with`` context) Model to sample from. The model needs to have free random variables. - random_seed : int or list of ints - Random seed(s) used by the sampling steps. A list is accepted if ``cores`` is greater than - one. + random_seed : int, array-like of int, RandomState or Generator, optional + Random seed(s) used by the sampling steps. If a list, tuple or array of ints + is passed, each entry will be used to seed each chain. A ValueError will be + raised if the length does not match the number of chains. discard_tuned_samples : bool Whether to discard posterior samples of the tune interval. compute_convergence_checks : bool, default=True @@ -434,18 +487,10 @@ def sample( if chains is None: chains = max(2, cores) + if random_seed == -1: random_seed = None - if chains == 1 and isinstance(random_seed, int): - random_seed = [random_seed] - - if random_seed is None or isinstance(random_seed, int): - if random_seed is not None: - np.random.seed(random_seed) - random_seed = [np.random.randint(2**30) for _ in range(chains)] - - if not isinstance(random_seed, abc.Iterable): - raise TypeError("Invalid value for `random_seed`. Must be tuple, list or int") + random_seed_list = _get_seeds_per_chain(random_seed, chains) if not discard_tuned_samples and not return_inferencedata: warnings.warn( @@ -490,7 +535,7 @@ def sample( chains=chains, n_init=n_init, model=model, - seeds=random_seed, + random_seed=random_seed_list, progressbar=progressbar, jitter_max_retries=jitter_max_retries, tune=tune, @@ -506,7 +551,7 @@ def sample( jitter_rvs=filter_rvs_to_jitter(step), chains=chains, ) - initial_points = [ipfn(seed) for ipfn, seed in zip(ipfns, random_seed)] + initial_points = [ipfn(seed) for ipfn, seed in zip(ipfns, random_seed_list)] # One final check that shapes and logps at the starting points are okay. 
for ip in initial_points: @@ -523,7 +568,6 @@ def sample( "tune": tune, "progressbar": progressbar, "model": model, - "random_seed": random_seed, "cores": cores, "callback": callback, "discard_tuned_samples": discard_tuned_samples, @@ -542,6 +586,19 @@ def sample( ) parallel = cores > 1 and chains > 1 and not has_population_samplers + # At some point it was decided that PyMC should not set a global seed by default, + # unless the user specified a seed. This is a symptom of the fact that PyMC samplers + # are built around global seeding. This branch makes sure we maintain this unspoken + # rule. See https://github.com/pymc-devs/pymc/pull/1395. + if parallel: + # For parallel sampling we can pass the list of random seeds directly, as + # global seeding will only be called inside each process + sample_args["random_seed"] = random_seed_list + else: + # We pass None if the original random seed was None. The single core sampler + # methods will only set a global seed when it is not None. + sample_args["random_seed"] = random_seed if random_seed is None else random_seed_list + t_start = time.time() if parallel: _log.info(f"Multiprocess sampling ({chains} chains in {cores} jobs)") @@ -674,7 +731,7 @@ def _sample_many( chain: int, chains: int, start: Sequence[PointType], - random_seed: list, + random_seed: Optional[Sequence[RandomSeed]], step, callback=None, **kwargs, @@ -691,7 +748,7 @@ def _sample_many( Total number of chains to sample. start: list Starting points for each chain - random_seed: list + random_seed: list of random seeds, optional A list of seeds, one for each chain step: function Step function @@ -708,7 +765,7 @@ def _sample_many( chain=chain + i, start=start[i], step=step, - random_seed=random_seed[i], + random_seed=None if random_seed is None else random_seed[i], callback=callback, **kwargs, ) @@ -731,7 +788,7 @@ def _sample_population( chain: int, chains: int, start: Sequence[PointType], - random_seed, + random_seed: RandomSeed, step, tune: int, model, @@ -751,8 +808,7 @@ def _sample_population( The total number of chains in the population start : list Start points for each chain - random_seed : int or list of ints, optional - A list is accepted if more if ``cores`` is greater than one. + random_seed : single random seed, optional step : function Step function (should be or contain a population step method) tune : int @@ -793,7 +849,7 @@ def _sample( *, chain: int, progressbar: bool, - random_seed, + random_seed: RandomSeed, start: PointType, draws: int, step=None, @@ -815,8 +871,7 @@ def _sample( Whether or not to display a progress bar in the command line. The bar shows the percentage of completion, the sampling speed in samples per second (SPS), and the estimated remaining time until completion ("expected time of arrival"; ETA). - random_seed : int or list of ints - A list is accepted if ``cores`` is greater than one. + random_seed : single random seed start : dict Starting point in parameter space (or partial point) draws : int @@ -871,7 +926,7 @@ def iter_sample( chain: int = 0, tune: int = 0, model: Optional[Model] = None, - random_seed: Optional[Union[int, List[int]]] = None, + random_seed: RandomSeed = None, callback=None, ) -> Iterator[MultiTrace]: """Generate a trace on each iteration using the given step method. @@ -896,8 +951,7 @@ def iter_sample( tune : int, optional Number of iterations to tune (defaults to 0). model : Model (optional if in ``with`` context) - random_seed : int or list of ints, optional - A list is accepted if more if ``cores`` is greater than one. 
+ random_seed : single random seed, optional callback : A function which gets called for every sample from the trace of a chain. The function is called with the trace and the current draw and will contain all samples for a single trace. @@ -930,7 +984,7 @@ def _iter_sample( chain: int = 0, tune: int = 0, model=None, - random_seed=None, + random_seed: RandomSeed = None, callback=None, ) -> Iterator[Tuple[BaseTrace, bool]]: """Generator for sampling one chain. (Used in singleprocess sampling.) @@ -953,8 +1007,7 @@ def _iter_sample( tune : int, optional Number of iterations to tune (defaults to 0). model : Model (optional if in ``with`` context) - random_seed : int or list of ints, optional - A list is accepted if more if ``cores`` is greater than one. + random_seed : single random seed, optional Yields ------ @@ -1194,7 +1247,7 @@ def _prepare_iter_population( parallelize: bool, tune: int, model=None, - random_seed=None, + random_seed: RandomSeed = None, progressbar=True, ) -> Iterator[Sequence[BaseTrace]]: """Prepare a PopulationStepper and traces for population sampling. @@ -1214,8 +1267,7 @@ def _prepare_iter_population( tune : int Number of iterations to tune. model : Model (optional if in ``with`` context) - random_seed : int or list of ints, optional - A list is accepted if more if ``cores`` is greater than one. + random_seed : single random seed, optional progressbar : bool ``progressbar`` argument for the ``PopulationStepper``, (defaults to True) @@ -1400,7 +1452,7 @@ def _mp_sample( chains: int, cores: int, chain: int, - random_seed: list, + random_seed: Sequence[RandomSeed], start: Sequence[PointType], progressbar: bool = True, trace: Optional[Union[BaseTrace, List[str]]] = None, @@ -1426,7 +1478,7 @@ def _mp_sample( The number of chains to run in parallel. chain : int Number of the first chain. - random_seed : list of ints + random_seed : list of random seeds Random seeds for each chain. start : list Starting points for each chain. @@ -1695,7 +1747,7 @@ def sample_posterior_predictive( model: Optional[Model] = None, var_names: Optional[List[str]] = None, keep_size: Optional[bool] = None, - random_seed=None, + random_seed: RandomState = None, progressbar: bool = True, return_inferencedata: bool = True, extend_inferencedata: bool = False, @@ -1729,7 +1781,7 @@ def sample_posterior_predictive( keep_size : bool, default True Force posterior predictive sample to have the same shape as posterior and sample stats data: ``(nchains, ndraws, ...)``. Overrides samples parameter. - random_seed : int + random_seed : int, RandomState or Generator, optional Seed for the random number generator. progressbar : bool Whether or not to display a progress bar in the command line. The bar shows the percentage @@ -1844,14 +1896,6 @@ def sample_posterior_predictive( else: vars_ = model.observed_RVs + model.auto_deterministics - if random_seed is not None: - warnings.warn( - "In this version, RNG seeding is managed by the Model objects. 
" - "See the `rng_seeder` argument in Model's constructor.", - FutureWarning, - stacklevel=2, - ) - indices = np.arange(samples) if progressbar: @@ -1868,6 +1912,9 @@ def sample_posterior_predictive( vars_in_trace = get_vars_in_point_list(_trace, model) + if random_seed is not None: + (random_seed,) = _get_seeds_per_chain(random_seed, 1) + if compile_kwargs is None: compile_kwargs = {} compile_kwargs.setdefault("allow_input_downcast", True) @@ -1879,6 +1926,7 @@ def sample_posterior_predictive( vars_in_trace=vars_in_trace, basic_rvs=model.basic_RVs, givens_dict=None, + random_seed=random_seed, **compile_kwargs, ) ) @@ -1941,7 +1989,7 @@ def sample_posterior_predictive_w( samples: Optional[int] = None, models: Optional[List[Model]] = None, weights: Optional[ArrayLike] = None, - random_seed: Optional[int] = None, + random_seed: RandomState = None, progressbar: bool = True, return_inferencedata: bool = True, idata_kwargs: dict = None, @@ -1965,7 +2013,7 @@ def sample_posterior_predictive_w( only be meaningful if all models share the same distributions for the observed RVs. weights : array-like, optional Individual weights for each trace. Default, same weight for each model. - random_seed : int, optional + random_seed : int, RandomState or Generator, optional Seed for the random number generator. progressbar : bool, optional default True Whether or not to display a progress bar in the command line. The bar shows the percentage @@ -1998,13 +2046,8 @@ def sample_posterior_predictive_w( if models is None: models = [modelcontext(models)] * len(traces) - if random_seed: - warnings.warn( - "In this version, RNG seeding is managed by the Model objects. " - "See the `rng_seeder` argument in Model's constructor.", - FutureWarning, - stacklevel=2, - ) + if random_seed is not None: + (random_seed,) = _get_seeds_per_chain(random_seed, 1) for model in models: if model.potentials: @@ -2114,7 +2157,7 @@ def sample_prior_predictive( samples: int = 500, model: Optional[Model] = None, var_names: Optional[Iterable[str]] = None, - random_seed=None, + random_seed: RandomState = None, return_inferencedata: bool = True, idata_kwargs: dict = None, compile_kwargs: dict = None, @@ -2130,7 +2173,7 @@ def sample_prior_predictive( A list of names of variables for which to compute the prior predictive samples. Defaults to both observed and unobserved RVs. Transformed values are not included unless explicitly defined in var_names. - random_seed : int + random_seed : int, RandomState or Generator, optional Seed for the random number generator. return_inferencedata : bool Whether to return an :class:`arviz:arviz.InferenceData` (True) object or a dictionary (False). @@ -2165,21 +2208,13 @@ def sample_prior_predictive( else: vars_ = set(var_names) - if random_seed is not None: - warnings.warn( - "In this version, RNG seeding is managed by the Model objects. " - "See the `rng_seeder` argument in Model's constructor.", - FutureWarning, - stacklevel=2, - ) - - names = get_default_varnames(vars_, include_transformed=False) + names = sorted(get_default_varnames(vars_, include_transformed=False)) vars_to_sample = [model[name] for name in names] # Any variables from var_names that are missing must be transformed variables. # Misspelled variables would have raised a KeyError above. 
missing_names = vars_.difference(names) - for name in missing_names: + for name in sorted(missing_names): transformed_value_var = model[name] rv_var = model.values_to_rvs[transformed_value_var] transform = transformed_value_var.tag.transform @@ -2194,6 +2229,9 @@ def sample_prior_predictive( names.append(rv_var.name) vars_to_sample.append(rv_var) + if random_seed is not None: + (random_seed,) = _get_seeds_per_chain(random_seed, 1) + if compile_kwargs is None: compile_kwargs = {} compile_kwargs.setdefault("allow_input_downcast", True) @@ -2204,6 +2242,7 @@ def sample_prior_predictive( vars_in_trace=[], basic_rvs=model.basic_RVs, givens_dict=None, + random_seed=random_seed, **compile_kwargs, ) @@ -2229,6 +2268,7 @@ def sample_prior_predictive( def draw( vars: Union[Variable, Sequence[Variable]], draws: int = 1, + random_seed: RandomState = None, **kwargs, ) -> Union[np.ndarray, List[np.ndarray]]: """Draw samples for one variable or a list of variables @@ -2239,6 +2279,8 @@ def draw( A variable or a list of variables for which to draw samples. draws : int, default 1 Number of samples needed to draw. + random_seed : int, RandomState or Generator, optional + Seed for the random number generator. **kwargs : dict, optional Keyword arguments for :func:`pymc.aesara.compile_pymc`. @@ -2271,8 +2313,10 @@ def draw( assert draws[1].shape == (num_draws, 10) assert draws[2].shape == (num_draws, 5) """ + if random_seed is not None: + (random_seed,) = _get_seeds_per_chain(random_seed, 1) - draw_fn = compile_pymc(inputs=[], outputs=vars, **kwargs) + draw_fn = compile_pymc(inputs=[], outputs=vars, random_seed=random_seed, **kwargs) if draws == 1: return draw_fn() @@ -2291,7 +2335,7 @@ def draw( def _init_jitter( model: Model, initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]], - seeds: Union[List[Any], Tuple[Any, ...], np.ndarray], + seeds: Union[Sequence[int], np.ndarray], jitter: bool, jitter_max_retries: int, ) -> List[PointType]: @@ -2348,7 +2392,7 @@ def init_nuts( chains: int = 1, n_init: int = 500_000, model=None, - seeds: Iterable[Any] = None, + random_seed: RandomSeed = None, progressbar=True, jitter_max_retries: int = 10, tune: Optional[int] = None, @@ -2395,8 +2439,8 @@ def init_nuts( n_init : int Number of iterations of initializer. Only works for 'ADVI' init methods. model : Model (optional if in ``with`` context) - seeds : list - Seed values for each chain. + random_seed : int, array-like of int, RandomState or Generator, optional + Seed for the random number generator. progressbar : bool Whether or not to display a progressbar for advi sampling. jitter_max_retries : int @@ -2429,14 +2473,7 @@ def init_nuts( if init == "auto": init = "jitter+adapt_diag" - if seeds is None: - seeds = model.rng_seeder.randint(2**30, dtype=np.int64, size=chains) - if not isinstance(seeds, (list, tuple, np.ndarray)): - raise ValueError(f"The `seeds` must be array-like. Got {type(seeds)} instead.") - if len(seeds) != chains: - raise ValueError( - f"Number of seeds ({len(seeds)}) does not match the number of chains ({chains})." 
- ) + random_seed_list = _get_seeds_per_chain(random_seed, chains) _log.info(f"Initializing NUTS using {init}...") @@ -2448,7 +2485,7 @@ def init_nuts( initial_points = _init_jitter( model, initvals, - seeds=seeds, + seeds=random_seed_list, jitter="jitter" in init, jitter_max_retries=jitter_max_retries, ) @@ -2486,7 +2523,7 @@ def init_nuts( ) elif init == "advi+adapt_diag": approx = pm.fit( - random_seed=seeds[0], + random_seed=random_seed_list[0], n=n_init, method="advi", model=model, @@ -2494,7 +2531,9 @@ def init_nuts( progressbar=progressbar, obj_optimizer=pm.adagrad_window, ) - approx_sample = approx.sample(draws=chains, return_inferencedata=False) + approx_sample = approx.sample( + draws=chains, random_seed=random_seed_list[0], return_inferencedata=False + ) initial_points = [approx_sample[i] for i in range(chains)] std_apoint = approx.std.eval() cov = std_apoint**2 @@ -2504,7 +2543,7 @@ def init_nuts( potential = quadpotential.QuadPotentialDiagAdapt(n, mean, cov, weight) elif init == "advi": approx = pm.fit( - random_seed=seeds[0], + random_seed=random_seed_list[0], n=n_init, method="advi", model=model, @@ -2512,27 +2551,31 @@ def init_nuts( progressbar=progressbar, obj_optimizer=pm.adagrad_window, ) - approx_sample = approx.sample(draws=chains, return_inferencedata=False) + approx_sample = approx.sample( + draws=chains, random_seed=random_seed_list[0], return_inferencedata=False + ) initial_points = [approx_sample[i] for i in range(chains)] cov = approx.std.eval() ** 2 potential = quadpotential.QuadPotentialDiag(cov) elif init == "advi_map": - start = pm.find_MAP(include_transformed=True, seed=seeds[0]) + start = pm.find_MAP(include_transformed=True, seed=random_seed_list[0]) approx = pm.MeanField(model=model, start=start) pm.fit( - random_seed=seeds[0], + random_seed=random_seed_list[0], n=n_init, method=pm.KLqp(approx), callbacks=cb, progressbar=progressbar, obj_optimizer=pm.adagrad_window, ) - approx_sample = approx.sample(draws=chains, return_inferencedata=False) + approx_sample = approx.sample( + draws=chains, random_seed=random_seed_list[0], return_inferencedata=False + ) initial_points = [approx_sample[i] for i in range(chains)] cov = approx.std.eval() ** 2 potential = quadpotential.QuadPotentialDiag(cov) elif init == "map": - start = pm.find_MAP(include_transformed=True, seed=seeds[0]) + start = pm.find_MAP(include_transformed=True, seed=random_seed_list[0]) cov = pm.find_hessian(point=start) initial_points = [start] * chains potential = quadpotential.QuadPotentialFull(cov) diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py index e3d2e1034d..a087e005ca 100644 --- a/pymc/sampling_jax.py +++ b/pymc/sampling_jax.py @@ -7,7 +7,7 @@ from typing import Callable, Dict, List, Optional, Sequence, Union from pymc.initial_point import StartDict -from pymc.sampling import _init_jitter +from pymc.sampling import RandomSeed, _get_seeds_per_chain, _init_jitter xla_flags = os.getenv("XLA_FLAGS", "") xla_flags = re.sub(r"--xla_force_host_platform_device_count=.+\s", "", xla_flags).split() @@ -147,7 +147,7 @@ def _get_batched_jittered_initial_points( model: Model, chains: int, initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]], - random_seed: int, + random_seed: RandomSeed, jitter: bool = True, jitter_max_retries: int = 10, ) -> Union[np.ndarray, List[np.ndarray]]: @@ -160,14 +160,10 @@ def _get_batched_jittered_initial_points( Each item has shape `(chains, *var.shape)` """ - random_seed = np.random.default_rng(random_seed).integers(2**30, size=chains) - - assert 
len(random_seed) == chains - initial_points = _init_jitter( model, initvals, - seeds=random_seed, + seeds=_get_seeds_per_chain(random_seed, chains), jitter=jitter, jitter_max_retries=jitter_max_retries, ) @@ -220,7 +216,7 @@ def sample_blackjax_nuts( tune=1000, chains=4, target_accept=0.8, - random_seed=10, + random_seed: RandomSeed = None, initvals=None, model=None, var_names=None, @@ -245,7 +241,7 @@ def sample_blackjax_nuts( target_accept : float in [0, 1]. The step size is tuned such that we approximate this acceptance rate. Higher values like 0.9 or 0.95 often work better for problematic posteriors. - random_seed : int, default 10 + random_seed : int, RandomState or Generator, optional Random seed used by the sampling steps. model : Model, optional Model to sample from. The model needs to have free random variables. When inside a ``with`` model @@ -292,6 +288,8 @@ def sample_blackjax_nuts( else: dims = {} + (random_seed,) = _get_seeds_per_chain(random_seed, 1) + tic1 = datetime.now() print("Compiling...", file=sys.stdout) @@ -387,7 +385,7 @@ def sample_numpyro_nuts( tune: int = 1000, chains: int = 4, target_accept: float = 0.8, - random_seed: int = None, + random_seed: RandomSeed = None, initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None, model: Optional[Model] = None, var_names=None, @@ -414,7 +412,7 @@ def sample_numpyro_nuts( target_accept : float in [0, 1]. The step size is tuned such that we approximate this acceptance rate. Higher values like 0.9 or 0.95 often work better for problematic posteriors. - random_seed : int, default 10 + random_seed : int, RandomState or Generator, optional Random seed used by the sampling steps. model : Model, optional Model to sample from. The model needs to have free random variables. When inside a ``with`` model @@ -470,8 +468,7 @@ def sample_numpyro_nuts( else: dims = {} - if random_seed is None: - random_seed = model.rng_seeder.randint(2**30, dtype=np.int64) + (random_seed,) = _get_seeds_per_chain(random_seed, 1) tic1 = datetime.now() print("Compiling...", file=sys.stdout) diff --git a/pymc/tests/models.py b/pymc/tests/models.py index 0c1f176a75..8b4d6b319e 100644 --- a/pymc/tests/models.py +++ b/pymc/tests/models.py @@ -189,7 +189,7 @@ def simple_normal(bounded_prior=False): sigma = 1.0 a, b = (9, 12) # bounds for uniform RV, need non-symmetric to reproduce issue - with pm.Model(rng_seeder=2482) as model: + with pm.Model() as model: if bounded_prior: mu_i = pm.Uniform("mu_i", a, b) else: diff --git a/pymc/tests/test_aesaraf.py b/pymc/tests/test_aesaraf.py index e5b288a6d3..e529f39c68 100644 --- a/pymc/tests/test_aesaraf.py +++ b/pymc/tests/test_aesaraf.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from unittest import mock import aesara import aesara.tensor as at @@ -26,6 +27,7 @@ from aesara.graph.basic import Constant, Variable, ancestors, equal_computations from aesara.tensor.random.basic import normal, uniform from aesara.tensor.random.op import RandomVariable +from aesara.tensor.random.var import RandomStateSharedVariable from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1 from aesara.tensor.var import TensorVariable @@ -36,6 +38,7 @@ compile_pymc, convert_observed_data, extract_obs_data, + reseed_rngs, rvs_to_value_vars, walk_model, ) @@ -487,7 +490,9 @@ def test_compile_pymc_updates_inputs(self): # Each RV adds a shared output for its rng assert len(fn_fgraph.outputs) == 1 + rvs_in_graph - def test_compile_pymc_custom_update_op(self): + # Disable `reseed_rngs` so that we can test with simpler update rule + @mock.patch("pymc.aesaraf.reseed_rngs") + def test_compile_pymc_custom_update_op(self, _): """Test that custom MeasurableVariable Op updates are used by compile_pymc""" class UnmeasurableOp(OpFromGraph): @@ -507,3 +512,62 @@ def update(self, node): fn = compile_pymc(inputs=[], outputs=dummy_x) assert fn() == 2.0 assert fn() == 3.0 + + def test_random_seed(self): + seedx = aesara.shared(np.random.default_rng(1)) + seedy = aesara.shared(np.random.default_rng(1)) + x = at.random.normal(rng=seedx) + y = at.random.normal(rng=seedy) + + # Shared variables are the same, so outputs will be identical + f0 = aesara.function([], [x, y]) + x0_eval, y0_eval = f0() + assert x0_eval == y0_eval + + # The variables will be reseeded with new seeds by default + f1 = compile_pymc([], [x, y]) + x1_eval, y1_eval = f1() + assert x1_eval != y1_eval + + # Check that seeding works + f2 = compile_pymc([], [x, y], random_seed=1) + x2_eval, y2_eval = f2() + assert x2_eval != x1_eval + assert y2_eval != y1_eval + + f3 = compile_pymc([], [x, y], random_seed=1) + x3_eval, y3_eval = f3() + assert x3_eval == x2_eval + assert y3_eval == y2_eval + + +def test_reseed_rngs(): + # Reseed_rngs uses the `PCG64` bit_generator, which is currently the default + # bit_generator used by NumPy. If this default changes in the future, this test will + # catch that. We will then have to decide whether to switch to the new default in + # PyMC or whether to stick with the older one (PCG64). This will pose a trade-off + # between backwards reproducibility and better/faster seeding. 
If we decide to change, + # the next line should be updated: + default_rng = np.random.PCG64 + assert isinstance(np.random.default_rng().bit_generator, default_rng) + + seed = 543 + + bit_generators = [default_rng(sub_seed) for sub_seed in np.random.SeedSequence(seed).spawn(2)] + + rngs = [ + aesara.shared(rng_type(default_rng())) + for rng_type in (np.random.Generator, np.random.RandomState) + ] + for rng, bit_generator in zip(rngs, bit_generators): + if isinstance(rng, RandomStateSharedVariable): + assert rng.get_value()._bit_generator.state != bit_generator.state + else: + assert rng.get_value().bit_generator.state != bit_generator.state + + reseed_rngs(rngs, seed) + for rng, bit_generator in zip(rngs, bit_generators): + if isinstance(rng, RandomStateSharedVariable): + assert rng.get_value()._bit_generator.state == bit_generator.state + else: + assert rng.get_value().bit_generator.state == bit_generator.state diff --git a/pymc/tests/test_distributions.py b/pymc/tests/test_distributions.py index f58b6d16b7..c805e01f12 100644 --- a/pymc/tests/test_distributions.py +++ b/pymc/tests/test_distributions.py @@ -3222,21 +3222,23 @@ def random(rng, size): def test_distinct_rvs(): """Make sure `RandomVariable`s generated using a `Model`'s default RNG state all have distinct states.""" - with pm.Model(rng_seeder=np.random.RandomState(2023532)) as model: + with pm.Model() as model: X_rv = pm.Normal("x") Y_rv = pm.Normal("y") - pp_samples = pm.sample_prior_predictive(samples=2, return_inferencedata=False) + pp_samples = pm.sample_prior_predictive( + samples=2, return_inferencedata=False, random_seed=np.random.RandomState(2023532) + ) assert X_rv.owner.inputs[0] != Y_rv.owner.inputs[0] - assert len(model.rng_seq) == 2 - - with pm.Model(rng_seeder=np.random.RandomState(2023532)): + with pm.Model(): X_rv = pm.Normal("x") Y_rv = pm.Normal("y") - pp_samples_2 = pm.sample_prior_predictive(samples=2, return_inferencedata=False) + pp_samples_2 = pm.sample_prior_predictive( + samples=2, return_inferencedata=False, random_seed=np.random.RandomState(2023532) + ) assert np.array_equal(pp_samples["y"], pp_samples_2["y"]) @@ -3312,7 +3314,8 @@ def test_censored_workflow(self, censored): data[data <= low] = low data[data >= high] = high - with pm.Model(rng_seeder=17092021) as m: + rng = 17092021 + with pm.Model() as m: mu = pm.Normal( "mu", mu=((high - low) / 2) + low, @@ -3328,9 +3331,9 @@ def test_censored_workflow(self, censored): observed=data, ) - prior_pred = pm.sample_prior_predictive() - posterior = pm.sample(tune=500, draws=500) - posterior_pred = pm.sample_posterior_predictive(posterior) + prior_pred = pm.sample_prior_predictive(random_seed=rng) + posterior = pm.sample(tune=500, draws=500, random_seed=rng) + posterior_pred = pm.sample_posterior_predictive(posterior, random_seed=rng) expected = True if censored else False assert (9 < prior_pred.prior_predictive.mean() < 10) == expected diff --git a/pymc/tests/test_distributions_random.py b/pymc/tests/test_distributions_random.py index 9fd7c2ce81..5da279f85b 100644 --- a/pymc/tests/test_distributions_random.py +++ b/pymc/tests/test_distributions_random.py @@ -1682,14 +1682,14 @@ def check_draws(self): def ref_rand(mu, rowcov, colcov): return st.matrix_normal.rvs(mean=mu, rowcov=rowcov, colcov=colcov) - with pm.Model(rng_seeder=1): + with pm.Model(): matrixnormal = pm.MatrixNormal( "matnormal", mu=np.random.random((3, 3)), rowcov=np.eye(3), colcov=np.eye(3), ) - check = pm.sample_prior_predictive(n_fails, return_inferencedata=False) + check = 
pm.sample_prior_predictive(n_fails, return_inferencedata=False, random_seed=1) ref_smp = ref_rand(mu=np.random.random((3, 3)), rowcov=np.eye(3), colcov=np.eye(3)) @@ -2328,10 +2328,10 @@ def test_car_rng_fn(sparse): if sparse: W = aesara.sparse.csr_from_dense(W) - with pm.Model(rng_seeder=1): + with pm.Model(): car = pm.CAR("car", mu, W, alpha, tau, size=size) mn = pm.MvNormal("mn", mu, cov, size=size) - check = pm.sample_prior_predictive(n_fails, return_inferencedata=False) + check = pm.sample_prior_predictive(n_fails, return_inferencedata=False, random_seed=1) p, f = delta, n_fails while p <= delta and f > 0: diff --git a/pymc/tests/test_mixture.py b/pymc/tests/test_mixture.py index 4562911f95..d164418135 100644 --- a/pymc/tests/test_mixture.py +++ b/pymc/tests/test_mixture.py @@ -232,12 +232,8 @@ def test_list_univariate_components_deterministic_weights(self, weights, compone @pytest.mark.parametrize("size", [None, (4,), (5, 4)]) def test_single_multivariate_component_deterministic_weights(self, weights, component, size): # This test needs seeding to avoid repetitions - rngs = [ - aesara.shared(np.random.default_rng(seed)) - for seed in self.get_random_state().randint(2**30, size=2) - ] - mix = Mixture.dist(weights, component, size=size, rngs=rngs) - mix_eval = mix.eval() + mix = Mixture.dist(weights, component, size=size) + mix_eval = draw(mix, random_seed=self.get_random_state()) # Test shape # component shape is either (4, 2, 3), (2, 3) @@ -853,7 +849,7 @@ def test_scalar_components(self): # [[0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]] mus = at.constant(np.full((nd, npop), np.arange(npop))) - with Model(rng_seeder=self.get_random_state()) as model: + with Model() as model: m = NormalMixture( "m", w=np.ones(npop) / npop, @@ -867,8 +863,8 @@ def test_scalar_components(self): latent_m = Normal("latent_m", mu=mu, sigma=1e-5, shape=nd) size = 100 - m_val = draw(m, draws=size) - latent_m_val = draw(latent_m, draws=size) + m_val = draw(m, draws=size, random_seed=self.get_random_state()) + latent_m_val = draw(latent_m, draws=size, random_seed=self.get_random_state()) assert m_val.shape == latent_m_val.shape # Test that each element in axis = -1 can come from independent @@ -888,7 +884,7 @@ def test_vector_components(self): # [[0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]] mus = at.constant(np.full((nd, npop), np.arange(npop))) - with Model(rng_seeder=self.get_random_state()) as model: + with Model() as model: m = Mixture( "m", w=np.ones(npop) / npop, @@ -900,8 +896,8 @@ def test_vector_components(self): latent_m = Normal("latent_m", mu=mus[..., z], sigma=1e-5, shape=nd) size = 100 - m_val = draw(m, draws=size) - latent_m_val = draw(latent_m, draws=size) + m_val = draw(m, draws=size, random_seed=998) + latent_m_val = draw(latent_m, draws=size, random_seed=998 * 2) assert m_val.shape == latent_m_val.shape # Test that each element in axis = -1 comes from the same mixture # component diff --git a/pymc/tests/test_model.py b/pymc/tests/test_model.py index 6a01bc28d0..a4c2181906 100644 --- a/pymc/tests/test_model.py +++ b/pymc/tests/test_model.py @@ -657,7 +657,7 @@ def test_set_initval(): # generating initial values rng = np.random.RandomState(392) - with pm.Model(rng_seeder=rng) as model: + with pm.Model() as model: eta = pm.Uniform("eta", 1.0, 2.0, size=(1, 1)) mu = pm.Normal("mu", sigma=eta, initval=[[100]]) alpha = pm.HalfNormal("alpha", initval=100) diff --git a/pymc/tests/test_sampling.py b/pymc/tests/test_sampling.py index ec7bac8395..cdc2d05192 100644 --- a/pymc/tests/test_sampling.py +++ 
b/pymc/tests/test_sampling.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import re import unittest.mock as mock from contextlib import ExitStack as does_not_raise -from itertools import combinations from typing import Tuple import aesara @@ -39,7 +39,7 @@ from pymc.backends.base import MultiTrace from pymc.backends.ndarray import NDArray from pymc.exceptions import IncorrectArgumentsError, SamplingError -from pymc.sampling import compile_forward_sampling_function +from pymc.sampling import _get_seeds_per_chain, compile_forward_sampling_function from pymc.tests.helpers import SeededTest, fast_unstable_sampling_mode from pymc.tests.models import simple_init @@ -51,10 +51,8 @@ def setup_method(self): def test_checks_seeds_kwarg(self): with self.model: - with pytest.raises(ValueError, match="must be array-like"): - pm.sampling.init_nuts(seeds=1) with pytest.raises(ValueError, match="Number of seeds"): - pm.sampling.init_nuts(chains=2, seeds=[1]) + pm.sampling.init_nuts(chains=2, random_seed=[1]) class TestSample(SeededTest): @@ -76,7 +74,7 @@ def setup_method(self): ], ) def test_random_seed(self, chains, seeds, cores, init): - with pm.Model(rng_seeder=3): + with pm.Model(): x = pm.Normal("x", 0, 10, initval="prior") tr1 = pm.sample( chains=chains, @@ -102,43 +100,47 @@ def test_random_seed(self, chains, seeds, cores, init): allequal = np.all(tr1["x"] == tr2["x"]) if seeds is None: assert not allequal - # TODO: ADVI init methods are not correctly seeded, as they rely on the state of - # the model RandomState/Generators which is updated in place when the function - # is compiled and evaluated. This elif branch must be removed once this is fixed - elif init == "advi": - assert not allequal else: assert allequal - def test_sample_does_not_set_seed(self): - # This tests that when random_seed is None, the global seed is not affected - random_numbers = [] - for _ in range(2): + @mock.patch("numpy.random.seed") + def test_default_sample_does_not_set_global_seed(self, mocked_seed): + # Test that when random_seed is None, `np.random.seed` is not called in the main + # process. Ideally it would never be called, but PyMC step samplers still rely + # on global seeding for reproducible behavior. + kwargs = dict(tune=2, draws=2, random_seed=None) + with self.model: + pm.sample(chains=1, **kwargs) + pm.sample(chains=2, cores=1, **kwargs) + pm.sample(chains=2, cores=2, **kwargs) + mocked_seed.assert_not_called() + + def test_sample_does_not_rely_on_external_global_seeding(self): + # Tests that sampling does not depend on external global seeding + kwargs = dict( + tune=2, + draws=20, + random_seed=None, + return_inferencedata=False, + ) + with self.model: np.random.seed(1) - with self.model: - pm.sample(1, tune=0, chains=1, random_seed=None) - random_numbers.append(np.random.random()) - assert random_numbers[0] == random_numbers[1] - - def test_parallel_sample_does_not_reuse_seed(self): - cores = 4 - random_numbers = [] - draws = [] - for _ in range(2): - np.random.seed(1) # seeds in other processes don't effect main process - with self.model: - idata = pm.sample(100, tune=0, cores=cores) - # numpy thread mentioned race condition.
might as well check none are equal - for first, second in combinations(range(cores), 2): - first_chain = idata.posterior["x"].sel(chain=first).values - second_chain = idata.posterior["x"].sel(chain=second).values - assert not np.allclose(first_chain, second_chain) - draws.append(idata.posterior["x"].values) - random_numbers.append(np.random.random()) - - # Make sure future random processes aren't effected by this - assert random_numbers[0] == random_numbers[1] - assert (draws[0] == draws[1]).all() + idata11 = pm.sample(chains=1, **kwargs) + np.random.seed(1) + idata12 = pm.sample(chains=2, cores=1, **kwargs) + np.random.seed(1) + idata13 = pm.sample(chains=2, cores=2, **kwargs) + + np.random.seed(1) + idata21 = pm.sample(chains=1, **kwargs) + np.random.seed(1) + idata22 = pm.sample(chains=2, cores=1, **kwargs) + np.random.seed(1) + idata23 = pm.sample(chains=2, cores=2, **kwargs) + + assert np.all(idata11["x"] != idata21["x"]) + assert np.all(idata12["x"] != idata22["x"]) + assert np.all(idata13["x"] != idata23["x"]) def test_sample(self): test_cores = [1] @@ -223,7 +225,7 @@ def test_reset_tuning(self): with self.model: tune = 50 chains = 2 - start, step = pm.sampling.init_nuts(chains=chains, seeds=[1, 2]) + start, step = pm.sampling.init_nuts(chains=chains, random_seed=[1, 2]) pm.sample(draws=2, tune=tune, chains=chains, step=step, start=start, cores=1) assert step.potential._n_samples == tune assert step.step_adapt._count == tune + 1 @@ -734,12 +736,14 @@ def test_model_shared_variable(self): y = x > 0 x_shared = aesara.shared(x) y_shared = aesara.shared(y) - with pm.Model(rng_seeder=rng) as model: + with pm.Model() as model: coeff = pm.Normal("x", mu=0, sigma=1) logistic = pm.Deterministic("p", pm.math.sigmoid(coeff * x_shared)) obs = pm.Bernoulli("obs", p=logistic, observed=y_shared) - trace = pm.sample(100, return_inferencedata=False, compute_convergence_checks=False) + trace = pm.sample( + 100, return_inferencedata=False, compute_convergence_checks=False, random_seed=rng + ) x_shared.set_value([-1, 0, 1.0]) y_shared.set_value([0, 0, 0]) @@ -760,7 +764,7 @@ def test_deterministic_of_observed(self): meas_in_1 = pm.aesaraf.floatX(2 + 4 * rng.randn(10)) meas_in_2 = pm.aesaraf.floatX(5 + 4 * rng.randn(10)) nchains = 2 - with pm.Model(rng_seeder=rng) as model: + with pm.Model() as model: mu_in_1 = pm.Normal("mu_in_1", 0, 2) sigma_in_1 = pm.HalfNormal("sd_in_1", 1) mu_in_2 = pm.Normal("mu_in_2", 0, 2) @@ -779,6 +783,7 @@ def test_deterministic_of_observed(self): step=pm.Metropolis(), return_inferencedata=False, compute_convergence_checks=False, + random_seed=rng, ) rtol = 1e-5 if aesara.config.floatX == "float64" else 1e-4 @@ -799,7 +804,7 @@ def test_deterministic_of_observed_modified_interface(self): meas_in_1 = pm.aesaraf.floatX(2 + 4 * rng.randn(100)) meas_in_2 = pm.aesaraf.floatX(5 + 4 * rng.randn(100)) - with pm.Model(rng_seeder=rng) as model: + with pm.Model() as model: mu_in_1 = pm.Normal("mu_in_1", 0, 1, initval=0) sigma_in_1 = pm.HalfNormal("sd_in_1", 1, initval=1) mu_in_2 = pm.Normal("mu_in_2", 0, 1, initval=0) @@ -817,6 +822,7 @@ def test_deterministic_of_observed_modified_interface(self): step=pm.Metropolis(), return_inferencedata=False, compute_convergence_checks=False, + random_seed=rng, ) varnames = [v for v in trace.varnames if v != "out"] ppc_trace = [ @@ -998,12 +1004,12 @@ def check_exec_nuts_init(method): pm.Normal("a", mu=0, sigma=1, size=2) pm.HalfNormal("b", sigma=1) with model: - start, _ = pm.init_nuts(init=method, n_init=10, seeds=[1]) + start, _ = 
pm.init_nuts(init=method, n_init=10, random_seed=[1]) assert isinstance(start, list) assert len(start) == 1 assert isinstance(start[0], dict) assert set(start[0].keys()) == {v.name for v in model.value_vars} - start, _ = pm.init_nuts(init=method, n_init=10, chains=2, seeds=[1, 2]) + start, _ = pm.init_nuts(init=method, n_init=10, chains=2, random_seed=[1, 2]) assert isinstance(start, list) assert len(start) == 2 assert isinstance(start[0], dict) @@ -1134,11 +1140,11 @@ def test_multivariate2(self): assert sim_ppc["obs"].shape == (20,) + mn_data.shape def test_layers(self): - with pm.Model(rng_seeder=232093) as model: + with pm.Model() as model: a = pm.Uniform("a", lower=0, upper=1, size=10) b = pm.Binomial("b", n=1, p=a, size=10) - b_sampler = compile_pymc([], b, mode="FAST_RUN") + b_sampler = compile_pymc([], b, mode="FAST_RUN", random_seed=232093) avg = np.stack([b_sampler() for i in range(10000)]).mean(0) npt.assert_array_almost_equal(avg, 0.5 * np.ones((10,)), decimal=2) @@ -1239,13 +1245,14 @@ def ub_interval_forward(x, ub): # Interval transform assuming lower bound is zero return np.log(x - 0) - np.log(ub - x) - with pm.Model(rng_seeder=123) as model: + with pm.Model() as model: ub = pm.HalfNormal("ub", 10) x = pm.Uniform("x", 0, ub) prior = pm.sample_prior_predictive( var_names=["ub", "ub_log__", "x", "x_interval__"], samples=10, + random_seed=123, ) # Check values are correct @@ -1256,13 +1263,14 @@ def ub_interval_forward(x, ub): ) # Check that it works when the original RVs are not mentioned in var_names - with pm.Model(rng_seeder=123) as model_transformed_only: + with pm.Model() as model_transformed_only: ub = pm.HalfNormal("ub", 10) x = pm.Uniform("x", 0, ub) prior_transformed_only = pm.sample_prior_predictive( var_names=["ub_log__", "x_interval__"], samples=10, + random_seed=123, ) assert ( "ub" not in prior_transformed_only.prior.data_vars @@ -1279,19 +1287,23 @@ def test_issue_4490(self): # Test that samples do not depend on var_name order or, more fundamentally, # that they do not depend on the set order used inside `sample_prior_predictive` seed = 4490 - with pm.Model(rng_seeder=seed) as m1: + with pm.Model() as m1: a = pm.Normal("a") b = pm.Normal("b") c = pm.Normal("c") d = pm.Normal("d") - prior1 = pm.sample_prior_predictive(samples=1, var_names=["a", "b", "c", "d"]) + prior1 = pm.sample_prior_predictive( + samples=1, var_names=["a", "b", "c", "d"], random_seed=seed + ) - with pm.Model(rng_seeder=seed) as m2: + with pm.Model() as m2: a = pm.Normal("a") b = pm.Normal("b") c = pm.Normal("c") d = pm.Normal("d") - prior2 = pm.sample_prior_predictive(samples=1, var_names=["b", "a", "d", "c"]) + prior2 = pm.sample_prior_predictive( + samples=1, var_names=["b", "a", "d", "c"], random_seed=seed + ) assert prior1.prior["a"] == prior2.prior["a"] assert prior1.prior["b"] == prior2.prior["b"] @@ -1433,19 +1445,21 @@ def test_draw_aesara_function_kwargs(self): def test_step_args(): - with pm.Model(rng_seeder=1410) as model: + with pm.Model() as model: a = pm.Normal("a") - idata0 = pm.sample(target_accept=0.5) - idata1 = pm.sample(nuts={"target_accept": 0.5}) + idata0 = pm.sample(target_accept=0.5, random_seed=1410) + idata1 = pm.sample(nuts={"target_accept": 0.5}, random_seed=1410 * 2) npt.assert_almost_equal(idata0.sample_stats.acceptance_rate.mean(), 0.5, decimal=1) npt.assert_almost_equal(idata1.sample_stats.acceptance_rate.mean(), 0.5, decimal=1) - with pm.Model(rng_seeder=1418) as model: + with pm.Model() as model: a = pm.Normal("a") b = pm.Poisson("b", 1) - idata0 = 
pm.sample(target_accept=0.5) - idata1 = pm.sample(nuts={"target_accept": 0.5}, metropolis={"scaling": 0}) + idata0 = pm.sample(target_accept=0.5, random_seed=1418) + idata1 = pm.sample( + nuts={"target_accept": 0.5}, metropolis={"scaling": 0}, random_seed=1418 * 2 + ) npt.assert_almost_equal(idata0.sample_stats.acceptance_rate.mean(), 0.5, decimal=1) npt.assert_almost_equal(idata1.sample_stats.acceptance_rate.mean(), 0.5, decimal=1) @@ -1676,3 +1690,39 @@ def test_lkj_cholesky_cov(self): ) assert {i.name for i in self.get_function_inputs(f)} == set() assert {i.name for i in self.get_function_roots(f)} == set() + + +def test_get_seeds_per_chain(): + ret = _get_seeds_per_chain(None, chains=1) + assert len(ret) == 1 and isinstance(ret[0], int) + + ret = _get_seeds_per_chain(None, chains=2) + assert len(ret) == 2 and isinstance(ret[0], int) + + ret = _get_seeds_per_chain(5, chains=1) + assert ret == (5,) + + ret = _get_seeds_per_chain(5, chains=3) + assert len(ret) == 3 and isinstance(ret[0], int) and not any(r == 5 for r in ret) + + rng = np.random.default_rng(123) + expected_ret = rng.integers(2**30, dtype=np.int64, size=1) + rng = np.random.default_rng(123) + ret = _get_seeds_per_chain(rng, chains=1) + assert ret == expected_ret + + rng = np.random.RandomState(456) + expected_ret = rng.randint(2**30, dtype=np.int64, size=2) + rng = np.random.RandomState(456) + ret = _get_seeds_per_chain(rng, chains=2) + assert np.all(ret == expected_ret) + + for expected_ret in ([0, 1, 2], (0, 1, 2, 3), np.arange(5)): + ret = _get_seeds_per_chain(expected_ret, chains=len(expected_ret)) + assert ret is expected_ret + + with pytest.raises(ValueError, match="does not match the number of chains"): + _get_seeds_per_chain(expected_ret, chains=len(expected_ret) + 1) + + with pytest.raises(ValueError, match=re.escape("The `seeds` must be array-like")): + _get_seeds_per_chain({1: 1, 2: 2}, 2) diff --git a/pymc/tests/test_sampling_jax.py b/pymc/tests/test_sampling_jax.py index 488db8c5e4..dc3e4ad4c8 100644 --- a/pymc/tests/test_sampling_jax.py +++ b/pymc/tests/test_sampling_jax.py @@ -199,9 +199,16 @@ def test_get_batched_jittered_initial_points(): assert np.all(ips[0][0] != ips[0][1]) +@pytest.mark.parametrize( + "sampler", + [ + sample_blackjax_nuts, + sample_numpyro_nuts, + ], +) @pytest.mark.parametrize("random_seed", (None, 123)) @pytest.mark.parametrize("chains", (1, 2)) -def test_seeding(chains, random_seed): +def test_seeding(chains, random_seed, sampler): sample_kwargs = dict( tune=100, draws=5, @@ -209,20 +216,17 @@ def test_seeding(chains, random_seed): random_seed=random_seed, ) - with pm.Model(rng_seeder=456) as m: - pm.Normal("x", mu=0, sigma=1) - result1 = sample_numpyro_nuts(**sample_kwargs) - - with pm.Model(rng_seeder=456) as m: + with pm.Model() as m: pm.Normal("x", mu=0, sigma=1) - result2 = sample_numpyro_nuts(**sample_kwargs) - result3 = sample_numpyro_nuts(**sample_kwargs) + result1 = sampler(**sample_kwargs) + result2 = sampler(**sample_kwargs) - assert np.all(result1.posterior["x"] == result2.posterior["x"]) - expected_equal_result3 = random_seed is not None - assert np.all(result2.posterior["x"] == result3.posterior["x"]) == expected_equal_result3 + all_equal = np.all(result1.posterior["x"] == result2.posterior["x"]) + if random_seed is None: + assert not all_equal + else: + assert all_equal if chains > 1: assert np.all(result1.posterior["x"].sel(chain=0) != result1.posterior["x"].sel(chain=1)) assert np.all(result2.posterior["x"].sel(chain=0) != result2.posterior["x"].sel(chain=1)) - assert 
np.all(result3.posterior["x"].sel(chain=0) != result3.posterior["x"].sel(chain=1)) diff --git a/pymc/tests/test_step.py b/pymc/tests/test_step.py index 0a9b923017..7c89c91005 100644 --- a/pymc/tests/test_step.py +++ b/pymc/tests/test_step.py @@ -523,7 +523,7 @@ def test_linalg(self, caplog): b = at.slinalg.solve(floatX(np.eye(2)), a, check_finite=False) Normal("c", mu=b, size=2, initval=floatX(np.r_[0.0, 0.0])) caplog.clear() - trace = sample(20, tune=5, chains=2, return_inferencedata=False) + trace = sample(20, tune=5, chains=2, return_inferencedata=False, random_seed=526) warns = [msg.msg for msg in caplog.records] assert np.any(trace["diverging"]) assert ( @@ -1183,7 +1183,7 @@ def perform(self, node, inputs, outputs): rng = np.random.RandomState(seed) - with Model(rng_seeder=rng) as coarse_model_0: + with Model() as coarse_model_0: if aesara.config.floatX == "float32": Q = Data("Q", np.float32(0.0)) else: @@ -1202,7 +1202,7 @@ def perform(self, node, inputs, outputs): rng = np.random.RandomState(seed) - with Model(rng_seeder=rng) as coarse_model_1: + with Model() as coarse_model_1: if aesara.config.floatX == "float32": Q = Data("Q", np.float32(0.0)) else: @@ -1221,7 +1221,7 @@ def perform(self, node, inputs, outputs): rng = np.random.RandomState(seed) - with Model(rng_seeder=rng) as model: + with Model() as model: if aesara.config.floatX == "float32": Q = Data("Q", np.float32(0.0)) else: diff --git a/pymc/tuning/starting.py b/pymc/tuning/starting.py index f2f94951dd..ccca3aa2f7 100644 --- a/pymc/tuning/starting.py +++ b/pymc/tuning/starting.py @@ -107,8 +107,6 @@ def find_MAP( return_transformed=True, overrides=start, ) - if seed is None: - seed = model.rng_seeder.randint(2**30, dtype=np.int64) start = ipfn(seed) model.check_start_vals(start) diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 0b6d6dddc6..eae2c6a849 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -57,11 +57,20 @@ import pymc as pm -from pymc.aesaraf import at_rng, compile_pymc, identity, rvs_to_value_vars +from pymc.aesaraf import ( + SeedSequenceSeed, + at_rng, + compile_pymc, + find_rng_nodes, + identity, + reseed_rngs, + rvs_to_value_vars, +) from pymc.backends import NDArray from pymc.blocking import DictToArrayBijection from pymc.initial_point import make_initial_point_fn from pymc.model import modelcontext +from pymc.sampling import RandomState, _get_seeds_per_chain from pymc.util import WithMemoization, locally_cachedmethod from pymc.variational.updates import adagrad_window from pymc.vartypes import discrete_types @@ -846,6 +855,7 @@ def __init__( self._vfam = vfam self._local = local self._batched = rowwise + self.rng = np.random.default_rng(random_seed) self._rng = at_rng(random_seed) model = modelcontext(model) self.model = model @@ -867,7 +877,7 @@ def _prepare_start(self, start=None): jitter_rvs={}, return_transformed=True, ) - start = ipfn(self.model.rng_seeder.randint(2**30, dtype=np.int64)) + start = ipfn(self.rng.integers(2**30, dtype=np.int64)) group_vars = {self.model.rvs_to_values[v].name for v in self.group} start = {k: v for k, v in start.items() if k in group_vars} if self.batched: @@ -1640,22 +1650,30 @@ def sample_dict_fn(self): sampled = [self.rslice(name) for name in names] sampled = self.set_size_and_deterministic(sampled, s, 0) sample_fn = compile_pymc([s], sampled) + rng_nodes = find_rng_nodes(sampled) - def inner(draws=100): + def inner(draws=100, *, random_seed: SeedSequenceSeed = None): + if random_seed is not None: + reseed_rngs(rng_nodes, 
random_seed) _samples = sample_fn(draws) + return {v_: s_ for v_, s_ in zip(names, _samples)} return inner - def sample(self, draws=500, return_inferencedata=True, **kwargs): + def sample( + self, draws=500, *, random_seed: RandomState = None, return_inferencedata=True, **kwargs + ): """Draw samples from variational posterior. Parameters ---------- - draws: `int` + draws : int Number of random samples. - return_inferencedata: `bool` - Return trace in Arviz format + random_seed : int, RandomState or Generator, optional + Seed for the random number generator. + return_inferencedata : bool + Return trace in Arviz format. Returns ------- @@ -1665,7 +1683,9 @@ def sample(self, draws=500, return_inferencedata=True, **kwargs): # TODO: add tests for include_transformed case kwargs["log_likelihood"] = False - samples = self.sample_dict_fn(draws) # type: dict + if random_seed is not None: + (random_seed,) = _get_seeds_per_chain(random_seed, 1) + samples = self.sample_dict_fn(draws, random_seed=random_seed) # type: dict points = ({name: records[i] for name, records in samples.items()} for i in range(draws)) trace = NDArray(