Skip to content

Commit fc9dabe

Browse files
committed
Update expected behavior of sample-related tests in relation to global seeding.
Together, `test_sample_does_not_set_seed` and `test_parallel_sample_does_not_reuse_seed` covered two unspoken behaviors of `sample`: 1. When no seed is specified, PyMC shall not set the global seed state of NumPy in the main process. 2. When no seed is specified, sampling will depend on NumPy's global seeding state for reproducible behavior. Point 1 is due to PyMC's legacy dependency on global seeding for step samplers. PyMC tries to minimize "damage" by only setting global seeds when it absolutely needs to, in order to ensure deterministic sampling. Ideally, calls to `numpy.random.seed` would never be made. Point 2 goes against NumPy's current best practice of using None when defining new Generators / SeedSequences (https://numpy.org/doc/stable/reference/random/bit_generators/generated/numpy.random.SeedSequence.html#numpy.random.SeedSequence). The refactored tests cover point 1 more directly, and assert the opposite of point 2.
1 parent ca692e9 commit fc9dabe

File tree

1 file changed

+38
-29
lines changed

1 file changed

+38
-29
lines changed

pymc/tests/test_sampling.py

+38-29
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import unittest.mock as mock
1616

1717
from contextlib import ExitStack as does_not_raise
18-
from itertools import combinations
1918
from typing import Tuple
2019

2120
import aesara
@@ -110,35 +109,45 @@ def test_random_seed(self, chains, seeds, cores, init):
110109
else:
111110
assert allequal
112111

113-
def test_sample_does_not_set_seed(self):
114-
# This tests that when random_seed is None, the global seed is not affected
115-
random_numbers = []
116-
for _ in range(2):
112+
@mock.patch("numpy.random.seed")
113+
def test_default_sample_does_not_set_global_seed(self, mocked_seed):
114+
# Test that when random_seed is None, `np.random.seed` is not called in the main
115+
# process. Ideally it would never be called, but PyMC step samplers still rely
116+
# on global seeding for reproducible behavior.
117+
kwargs = dict(tune=2, draws=2, random_seed=None)
118+
with self.model:
119+
pm.sample(chains=1, **kwargs)
120+
pm.sample(chains=2, cores=1, **kwargs)
121+
pm.sample(chains=2, cores=2, **kwargs)
122+
mocked_seed.assert_not_called()
123+
124+
@pytest.mark.xfail(reason="Sampling relies on external global seeding")
125+
def test_sample_does_not_rely_on_external_global_seeding(self):
126+
# Tests that sampling does not depend on external global seeding
127+
kwargs = dict(
128+
tune=2,
129+
draws=20,
130+
random_seed=None,
131+
return_inferencedata=False,
132+
)
133+
with self.model:
134+
np.random.seed(1)
135+
idata11 = pm.sample(chains=1, **kwargs)
136+
np.random.seed(1)
137+
idata12 = pm.sample(chains=2, cores=1, **kwargs)
117138
np.random.seed(1)
118-
with self.model:
119-
pm.sample(1, tune=0, chains=1, random_seed=None)
120-
random_numbers.append(np.random.random())
121-
assert random_numbers[0] == random_numbers[1]
122-
123-
def test_parallel_sample_does_not_reuse_seed(self):
124-
cores = 4
125-
random_numbers = []
126-
draws = []
127-
for _ in range(2):
128-
np.random.seed(1) # seeds in other processes don't effect main process
129-
with self.model:
130-
idata = pm.sample(100, tune=0, cores=cores)
131-
# numpy thread mentioned race condition. might as well check none are equal
132-
for first, second in combinations(range(cores), 2):
133-
first_chain = idata.posterior["x"].sel(chain=first).values
134-
second_chain = idata.posterior["x"].sel(chain=second).values
135-
assert not np.allclose(first_chain, second_chain)
136-
draws.append(idata.posterior["x"].values)
137-
random_numbers.append(np.random.random())
138-
139-
# Make sure future random processes aren't affected by this
140-
assert random_numbers[0] == random_numbers[1]
141-
assert (draws[0] == draws[1]).all()
139+
idata13 = pm.sample(chains=2, cores=2, **kwargs)
140+
141+
np.random.seed(1)
142+
idata21 = pm.sample(chains=1, **kwargs)
143+
np.random.seed(1)
144+
idata22 = pm.sample(chains=2, cores=1, **kwargs)
145+
np.random.seed(1)
146+
idata23 = pm.sample(chains=2, cores=2, **kwargs)
147+
148+
assert np.all(idata11["x"] != idata21["x"])
149+
assert np.all(idata12["x"] != idata22["x"])
150+
assert np.all(idata13["x"] != idata23["x"])
142151

143152
def test_sample(self):
144153
test_cores = [1]

0 commit comments

Comments
 (0)