Skip to content

Commit c6b786b

Browse files
hvasbathColCarroll
authored andcommitted
fixes in smc backend refactoring, added test (#2286)
* fixes in smc backend refactoring, added test * added multicore testing again
1 parent 88e2941 commit c6b786b

File tree

3 files changed

+47
-26
lines changed

3 files changed

+47
-26
lines changed

pymc3/backends/smc_text.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -177,13 +177,13 @@ def highest_sampled_stage(self):
177177
-------
178178
stage number : int
179179
"""
180-
return max(self.stage_number(s) for s in glob(self.path('*')))
180+
return max(self.stage_number(s) for s in glob(self.stage_path('*')))
181181

182182
def atmip_path(self, stage_number):
183183
"""Consistent naming for atmip params."""
184184
return os.path.join(self.stage_path(stage_number), 'atmip.params.pkl')
185185

186-
def load_atmip_params(self, stage_number):
186+
def load_atmip_params(self, stage_number, model):
187187
"""Load saved parameters from last sampled ATMIP stage.
188188
189189
Parameters
@@ -196,8 +196,14 @@ def load_atmip_params(self, stage_number):
196196
else:
197197
prev = stage_number - 1
198198
pm._log.info('Loading parameters from completed stage {}'.format(prev))
199-
with open(self.atmip_path(prev), 'rb') as buff:
200-
return pickle.load(buff)
199+
200+
with model:
201+
with open(self.atmip_path(prev), 'rb') as buff:
202+
step = pickle.load(buff)
203+
204+
# update step stage to current stage
205+
step.stage = stage_number
206+
return step
201207

202208
def dump_atmip_params(self, step):
203209
"""Save atmip params to file."""
@@ -278,7 +284,7 @@ def recover_existing_results(self, stage, draws, step, n_jobs, model=None):
278284
# load incomplete stage results
279285
pm._log.info('Reloading existing results ...')
280286
mtrace = self.load_multitrace(stage, model=model)
281-
if len(mtrace) > 0:
287+
if len(mtrace.chains) > 0:
282288
# continue sampling if traces exist
283289
pm._log.info('Checking for corrupted files ...')
284290
return self.check_multitrace(mtrace, draws=draws, n_chains=step.n_chains)

pymc3/step_methods/smc.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,7 @@ def ATMIP_sample(n_steps, step=None, start=None, homepath=None, chain=0, stage=0
520520
step.stage = stage
521521
draws = 1
522522
else:
523-
step = stage_handler.load_atmip_params(stage)
523+
step = stage_handler.load_atmip_params(stage, model=model)
524524
draws = step.n_steps
525525

526526
stage_handler.clean_directory(stage, None, rm_flag)

pymc3/tests/test_smc.py

+35-20
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pymc3 as pm
22
import numpy as np
33
from pymc3.step_methods import smc
4+
from pymc3.backends.smc_text import TextStage
45
import pytest
56
from tempfile import mkdtemp
67
import shutil
@@ -13,15 +14,13 @@
1314
@pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
1415
class TestSMC(SeededTest):
1516

16-
def setup_method(self):
17-
super(TestSMC, self).setup_method()
17+
def setup_class(self):
18+
super(TestSMC, self).setup_class()
1819
self.test_folder = mkdtemp(prefix='ATMIP_TEST')
1920

20-
@pytest.mark.parametrize('n_jobs', [1, 2])
21-
def test_sample_n_core(self, n_jobs):
22-
n_chains = 300
23-
n_steps = 100
24-
tune_interval = 25
21+
self.n_chains = 300
22+
self.n_steps = 100
23+
self.tune_interval = 25
2524

2625
n = 4
2726

@@ -36,9 +35,6 @@ def test_sample_n_core(self, n_jobs):
3635
w1 = stdev
3736
w2 = (1 - stdev)
3837

39-
def last_sample(x):
40-
return x[(n_steps - 1)::n_steps]
41-
4238
def two_gaussians(x):
4339
log_like1 = - 0.5 * n * tt.log(2 * np.pi) \
4440
- 0.5 * tt.log(dsigma) \
@@ -48,7 +44,7 @@ def two_gaussians(x):
4844
- 0.5 * (x - mu2).T.dot(isigma).dot(x - mu2)
4945
return tt.log(w1 * tt.exp(log_like1) + w2 * tt.exp(log_like2))
5046

51-
with pm.Model() as ATMIP_test:
47+
with pm.Model() as self.ATMIP_test:
5248
X = pm.Uniform('X',
5349
shape=n,
5450
lower=-2. * np.ones_like(mu1),
@@ -58,25 +54,44 @@ def two_gaussians(x):
5854
like = pm.Deterministic('like', two_gaussians(X))
5955
llk = pm.Potential('like_potential', like)
6056

61-
with ATMIP_test:
62-
step = smc.SMC(
63-
n_chains=n_chains,
64-
tune_interval=tune_interval,
65-
likelihood_name=ATMIP_test.deterministics[0].name)
57+
self.muref = mu1
58+
59+
@pytest.mark.parametrize('n_jobs', [1, 2])
60+
def test_sample_n_core(self, n_jobs):
61+
62+
def last_sample(x):
63+
return x[(self.n_steps - 1)::self.n_steps]
64+
65+
step = smc.SMC(
66+
n_chains=self.n_chains,
67+
tune_interval=self.tune_interval,
68+
model=self.ATMIP_test,
69+
likelihood_name=self.ATMIP_test.deterministics[0].name)
6670

6771
mtrace = smc.ATMIP_sample(
68-
n_steps=n_steps,
72+
n_steps=self.n_steps,
6973
step=step,
7074
n_jobs=n_jobs,
7175
progressbar=True,
7276
homepath=self.test_folder,
73-
model=ATMIP_test,
77+
model=self.ATMIP_test,
7478
rm_flag=True)
7579

7680
d = mtrace.get_values('X', combine=True, squeeze=True)
7781
x = last_sample(d)
7882
mu1d = np.abs(x).mean(axis=0)
79-
np.testing.assert_allclose(mu1, mu1d, rtol=0., atol=0.03)
83+
np.testing.assert_allclose(self.muref, mu1d, rtol=0., atol=0.03)
84+
85+
def test_stage_handler(self):
86+
stage_number = -1
87+
stage_handler = TextStage(self.test_folder)
88+
89+
step = stage_handler.load_atmip_params(stage_number, model=self.ATMIP_test)
90+
assert step.stage == stage_number
91+
92+
corrupted_chains = stage_handler.recover_existing_results(
93+
stage_number, self.n_steps, step, n_jobs=1, model=self.ATMIP_test)
94+
assert len(corrupted_chains) == 0
8095

81-
def teardown_method(self):
96+
def teardown_class(self):
8297
shutil.rmtree(self.test_folder)

0 commit comments

Comments
 (0)