Skip to content

Commit c243c9b

Browse files
committed
Fix non-deterministic NUTS initialization when using ADVI
1 parent afbd84b commit c243c9b

File tree

3 files changed

+35
-15
lines changed

3 files changed

+35
-15
lines changed

pymc/sampling.py

+9 -3
Original file line number | Diff line number | Diff line change
@@ -2531,7 +2531,9 @@ def init_nuts(
25312531
progressbar=progressbar,
25322532
obj_optimizer=pm.adagrad_window,
25332533
)
2534-
approx_sample = approx.sample(draws=chains, return_inferencedata=False)
2534+
approx_sample = approx.sample(
2535+
draws=chains, random_seed=random_seed_list[0], return_inferencedata=False
2536+
)
25352537
initial_points = [approx_sample[i] for i in range(chains)]
25362538
std_apoint = approx.std.eval()
25372539
cov = std_apoint**2
@@ -2549,7 +2551,9 @@ def init_nuts(
25492551
progressbar=progressbar,
25502552
obj_optimizer=pm.adagrad_window,
25512553
)
2552-
approx_sample = approx.sample(draws=chains, return_inferencedata=False)
2554+
approx_sample = approx.sample(
2555+
draws=chains, random_seed=random_seed_list[0], return_inferencedata=False
2556+
)
25532557
initial_points = [approx_sample[i] for i in range(chains)]
25542558
cov = approx.std.eval() ** 2
25552559
potential = quadpotential.QuadPotentialDiag(cov)
@@ -2564,7 +2568,9 @@ def init_nuts(
25642568
progressbar=progressbar,
25652569
obj_optimizer=pm.adagrad_window,
25662570
)
2567-
approx_sample = approx.sample(draws=chains, return_inferencedata=False)
2571+
approx_sample = approx.sample(
2572+
draws=chains, random_seed=random_seed_list[0], return_inferencedata=False
2573+
)
25682574
initial_points = [approx_sample[i] for i in range(chains)]
25692575
cov = approx.std.eval() ** 2
25702576
potential = quadpotential.QuadPotentialDiag(cov)

pymc/tests/test_sampling.py

-5
Original file line number | Diff line number | Diff line change
@@ -100,11 +100,6 @@ def test_random_seed(self, chains, seeds, cores, init):
100100
allequal = np.all(tr1["x"] == tr2["x"])
101101
if seeds is None:
102102
assert not allequal
103-
# TODO: ADVI init methods are not correctly seeded, as they rely on the state of
104-
# the model RandomState/Generators which is updated in place when the function
105-
# is compiled and evaluated. This elif branch must be removed once this is fixed
106-
elif init == "advi":
107-
assert not allequal
108103
else:
109104
assert allequal
110105

pymc/variational/opvi.py

+26 -7
Original file line number | Diff line number | Diff line change
@@ -57,11 +57,20 @@
5757

5858
import pymc as pm
5959

60-
from pymc.aesaraf import at_rng, compile_pymc, identity, rvs_to_value_vars
60+
from pymc.aesaraf import (
61+
SeedSequenceSeed,
62+
at_rng,
63+
compile_pymc,
64+
find_rng_nodes,
65+
identity,
66+
reseed_rngs,
67+
rvs_to_value_vars,
68+
)
6169
from pymc.backends import NDArray
6270
from pymc.blocking import DictToArrayBijection
6371
from pymc.initial_point import make_initial_point_fn
6472
from pymc.model import modelcontext
73+
from pymc.sampling import RandomState, _get_seeds_per_chain
6574
from pymc.util import WithMemoization, locally_cachedmethod
6675
from pymc.variational.updates import adagrad_window
6776
from pymc.vartypes import discrete_types
@@ -1641,22 +1650,30 @@ def sample_dict_fn(self):
16411650
sampled = [self.rslice(name) for name in names]
16421651
sampled = self.set_size_and_deterministic(sampled, s, 0)
16431652
sample_fn = compile_pymc([s], sampled)
1653+
rng_nodes = find_rng_nodes(sampled)
16441654

1645-
def inner(draws=100):
1655+
def inner(draws=100, *, random_seed: SeedSequenceSeed = None):
1656+
if random_seed is not None:
1657+
reseed_rngs(rng_nodes, random_seed)
16461658
_samples = sample_fn(draws)
1659+
16471660
return {v_: s_ for v_, s_ in zip(names, _samples)}
16481661

16491662
return inner
16501663

1651-
def sample(self, draws=500, return_inferencedata=True, **kwargs):
1664+
def sample(
1665+
self, draws=500, *, random_seed: RandomState = None, return_inferencedata=True, **kwargs
1666+
):
16521667
"""Draw samples from variational posterior.
16531668
16541669
Parameters
16551670
----------
1656-
draws: `int`
1671+
draws : int
16571672
Number of random samples.
1658-
return_inferencedata: `bool`
1659-
Return trace in Arviz format
1673+
random_seed : int, RandomState or Generator, optional
1674+
Seed for the random number generator.
1675+
return_inferencedata : bool
1676+
Return trace in Arviz format.
16601677
16611678
Returns
16621679
-------
@@ -1666,7 +1683,9 @@ def sample(self, draws=500, return_inferencedata=True, **kwargs):
16661683
# TODO: add tests for include_transformed case
16671684
kwargs["log_likelihood"] = False
16681685

1669-
samples = self.sample_dict_fn(draws) # type: dict
1686+
if random_seed is not None:
1687+
(random_seed,) = _get_seeds_per_chain(random_seed, 1)
1688+
samples = self.sample_dict_fn(draws, random_seed=random_seed) # type: dict
16701689
points = ({name: records[i] for name, records in samples.items()} for i in range(draws))
16711690

16721691
trace = NDArray(

0 commit comments

Comments
 (0)