Skip to content

Commit b91283e

Browse files
ricardoV94twiecki
authored andcommitted
Pass seed to find_map and set global seeds in _iter_sample and _prepare_iter_sample
This reverts some changes in 47b61de which wrongly disabled global seeding in some sampling contexts that still depended on it.
1 parent 56ad6a9 commit b91283e

File tree

2 files changed

+56
-2
lines changed

2 files changed

+56
-2
lines changed

pymc/sampling.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -969,6 +969,9 @@ def _iter_sample(
969969
if draws < 1:
970970
raise ValueError("Argument `draws` must be greater than 0.")
971971

972+
if random_seed is not None:
973+
np.random.seed(random_seed)
974+
972975
try:
973976
step = CompoundStep(step)
974977
except TypeError:
@@ -1229,6 +1232,9 @@ def _prepare_iter_population(
12291232
if draws < 1:
12301233
raise ValueError("Argument `draws` should be above 0.")
12311234

1235+
if random_seed is not None:
1236+
np.random.seed(random_seed)
1237+
12321238
# The initialization of traces, samplers and points must happen in the right order:
12331239
# 1. population of points is created
12341240
# 2. steppers are initialized and linked to the points object
@@ -2511,7 +2517,7 @@ def init_nuts(
25112517
cov = approx.std.eval() ** 2
25122518
potential = quadpotential.QuadPotentialDiag(cov)
25132519
elif init == "advi_map":
2514-
start = pm.find_MAP(include_transformed=True)
2520+
start = pm.find_MAP(include_transformed=True, seed=seeds[0])
25152521
approx = pm.MeanField(model=model, start=start)
25162522
pm.fit(
25172523
random_seed=seeds[0],
@@ -2526,7 +2532,7 @@ def init_nuts(
25262532
cov = approx.std.eval() ** 2
25272533
potential = quadpotential.QuadPotentialDiag(cov)
25282534
elif init == "map":
2529-
start = pm.find_MAP(include_transformed=True)
2535+
start = pm.find_MAP(include_transformed=True, seed=seeds[0])
25302536
cov = pm.find_hessian(point=start)
25312537
initial_points = [start] * chains
25322538
potential = quadpotential.QuadPotentialFull(cov)

pymc/tests/test_sampling.py

+48
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,54 @@ def setup_method(self):
6262
super().setup_method()
6363
self.model, self.start, self.step, _ = simple_init()
6464

65+
@pytest.mark.parametrize("init", ("jitter+adapt_diag", "advi", "map"))
66+
@pytest.mark.parametrize("cores", (1, 2))
67+
@pytest.mark.parametrize(
68+
"chains, seeds",
69+
[
70+
(1, None),
71+
(1, 1),
72+
(1, [1]),
73+
(2, None),
74+
(2, 1),
75+
(2, [1, 2]),
76+
],
77+
)
78+
def test_random_seed(self, chains, seeds, cores, init):
79+
with pm.Model(rng_seeder=3):
80+
x = pm.Normal("x", 0, 10, initval="prior")
81+
tr1 = pm.sample(
82+
chains=chains,
83+
random_seed=seeds,
84+
cores=cores,
85+
init=init,
86+
tune=0,
87+
draws=10,
88+
return_inferencedata=False,
89+
compute_convergence_checks=False,
90+
)
91+
tr2 = pm.sample(
92+
chains=chains,
93+
random_seed=seeds,
94+
cores=cores,
95+
init=init,
96+
tune=0,
97+
draws=10,
98+
return_inferencedata=False,
99+
compute_convergence_checks=False,
100+
)
101+
102+
allequal = np.all(tr1["x"] == tr2["x"])
103+
if seeds is None:
104+
assert not allequal
105+
# TODO: ADVI init methods are not correctly seeded, as they rely on the state of
106+
# the model RandomState/Generators which is updated in place when the function
107+
# is compiled and evaluated. This elif branch must be removed once this is fixed
108+
elif init == "advi":
109+
assert not allequal
110+
else:
111+
assert allequal
112+
65113
def test_sample_does_not_set_seed(self):
66114
# This tests that when random_seed is None, the global seed is not affected
67115
random_numbers = []

0 commit comments

Comments
 (0)