|
15 | 15 | import unittest.mock as mock
|
16 | 16 |
|
17 | 17 | from contextlib import ExitStack as does_not_raise
|
18 |
| -from itertools import combinations |
19 | 18 | from typing import Tuple
|
20 | 19 |
|
21 | 20 | import aesara
|
@@ -110,35 +109,45 @@ def test_random_seed(self, chains, seeds, cores, init):
|
110 | 109 | else:
|
111 | 110 | assert allequal
|
112 | 111 |
|
113 |
| - def test_sample_does_not_set_seed(self): |
114 |
| - # This tests that when random_seed is None, the global seed is not affected |
115 |
| - random_numbers = [] |
116 |
| - for _ in range(2): |
| 112 | + @mock.patch("numpy.random.seed") |
| 113 | + def test_default_sample_does_not_set_global_seed(self, mocked_seed): |
| 114 | + # Test that when random_seed is None, `np.random.seed` is not called in the main |
| 115 | + # process. Ideally it would never be called, but PyMC step samplers still rely |
| 116 | + # on global seeding for reproducible behavior. |
| 117 | + kwargs = dict(tune=2, draws=2, random_seed=None) |
| 118 | + with self.model: |
| 119 | + pm.sample(chains=1, **kwargs) |
| 120 | + pm.sample(chains=2, cores=1, **kwargs) |
| 121 | + pm.sample(chains=2, cores=2, **kwargs) |
| 122 | + mocked_seed.assert_not_called() |
| 123 | + |
| 124 | + @pytest.mark.xfail(reason="Sampling relies on external global seeding") |
| 125 | + def test_sample_does_not_rely_on_external_global_seeding(self): |
| 126 | + # Tests that sampling does not depend on exertenal global seeding |
| 127 | + kwargs = dict( |
| 128 | + tune=2, |
| 129 | + draws=20, |
| 130 | + random_seed=None, |
| 131 | + return_inferencedata=False, |
| 132 | + ) |
| 133 | + with self.model: |
| 134 | + np.random.seed(1) |
| 135 | + idata11 = pm.sample(chains=1, **kwargs) |
| 136 | + np.random.seed(1) |
| 137 | + idata12 = pm.sample(chains=2, cores=1, **kwargs) |
117 | 138 | np.random.seed(1)
|
118 |
| - with self.model: |
119 |
| - pm.sample(1, tune=0, chains=1, random_seed=None) |
120 |
| - random_numbers.append(np.random.random()) |
121 |
| - assert random_numbers[0] == random_numbers[1] |
122 |
| - |
123 |
| - def test_parallel_sample_does_not_reuse_seed(self): |
124 |
| - cores = 4 |
125 |
| - random_numbers = [] |
126 |
| - draws = [] |
127 |
| - for _ in range(2): |
128 |
| - np.random.seed(1) # seeds in other processes don't effect main process |
129 |
| - with self.model: |
130 |
| - idata = pm.sample(100, tune=0, cores=cores) |
131 |
| - # numpy thread mentioned race condition. might as well check none are equal |
132 |
| - for first, second in combinations(range(cores), 2): |
133 |
| - first_chain = idata.posterior["x"].sel(chain=first).values |
134 |
| - second_chain = idata.posterior["x"].sel(chain=second).values |
135 |
| - assert not np.allclose(first_chain, second_chain) |
136 |
| - draws.append(idata.posterior["x"].values) |
137 |
| - random_numbers.append(np.random.random()) |
138 |
| - |
139 |
| - # Make sure future random processes aren't effected by this |
140 |
| - assert random_numbers[0] == random_numbers[1] |
141 |
| - assert (draws[0] == draws[1]).all() |
| 139 | + idata13 = pm.sample(chains=2, cores=2, **kwargs) |
| 140 | + |
| 141 | + np.random.seed(1) |
| 142 | + idata21 = pm.sample(chains=1, **kwargs) |
| 143 | + np.random.seed(1) |
| 144 | + idata22 = pm.sample(chains=2, cores=1, **kwargs) |
| 145 | + np.random.seed(1) |
| 146 | + idata23 = pm.sample(chains=2, cores=2, **kwargs) |
| 147 | + |
| 148 | + assert np.all(idata11["x"] != idata21["x"]) |
| 149 | + assert np.all(idata12["x"] != idata22["x"]) |
| 150 | + assert np.all(idata13["x"] != idata23["x"]) |
142 | 151 |
|
143 | 152 | def test_sample(self):
|
144 | 153 | test_cores = [1]
|
|
0 commit comments