1
1
import pymc3 as pm
2
2
import numpy as np
3
3
from pymc3 .step_methods import smc
4
+ from pymc3 .backends .smc_text import TextStage
4
5
import pytest
5
6
from tempfile import mkdtemp
6
7
import shutil
13
14
@pytest .mark .xfail (condition = (theano .config .floatX == "float32" ), reason = "Fails on float32" )
14
15
class TestSMC (SeededTest ):
15
16
16
- def setup_method (self ):
17
- super (TestSMC , self ).setup_method ()
17
+ def setup_class (self ):
18
+ super (TestSMC , self ).setup_class ()
18
19
self .test_folder = mkdtemp (prefix = 'ATMIP_TEST' )
19
20
20
- @pytest .mark .parametrize ('n_jobs' , [1 , 2 ])
21
- def test_sample_n_core (self , n_jobs ):
22
- n_chains = 300
23
- n_steps = 100
24
- tune_interval = 25
21
+ self .n_chains = 300
22
+ self .n_steps = 100
23
+ self .tune_interval = 25
25
24
26
25
n = 4
27
26
@@ -36,9 +35,6 @@ def test_sample_n_core(self, n_jobs):
36
35
w1 = stdev
37
36
w2 = (1 - stdev )
38
37
39
- def last_sample (x ):
40
- return x [(n_steps - 1 )::n_steps ]
41
-
42
38
def two_gaussians (x ):
43
39
log_like1 = - 0.5 * n * tt .log (2 * np .pi ) \
44
40
- 0.5 * tt .log (dsigma ) \
@@ -48,7 +44,7 @@ def two_gaussians(x):
48
44
- 0.5 * (x - mu2 ).T .dot (isigma ).dot (x - mu2 )
49
45
return tt .log (w1 * tt .exp (log_like1 ) + w2 * tt .exp (log_like2 ))
50
46
51
- with pm .Model () as ATMIP_test :
47
+ with pm .Model () as self . ATMIP_test :
52
48
X = pm .Uniform ('X' ,
53
49
shape = n ,
54
50
lower = - 2. * np .ones_like (mu1 ),
@@ -58,25 +54,44 @@ def two_gaussians(x):
58
54
like = pm .Deterministic ('like' , two_gaussians (X ))
59
55
llk = pm .Potential ('like_potential' , like )
60
56
61
- with ATMIP_test :
62
- step = smc .SMC (
63
- n_chains = n_chains ,
64
- tune_interval = tune_interval ,
65
- likelihood_name = ATMIP_test .deterministics [0 ].name )
57
+ self .muref = mu1
58
+
59
+ @pytest .mark .parametrize ('n_jobs' , [1 , 2 ])
60
+ def test_sample_n_core (self , n_jobs ):
61
+
62
+ def last_sample (x ):
63
+ return x [(self .n_steps - 1 )::self .n_steps ]
64
+
65
+ step = smc .SMC (
66
+ n_chains = self .n_chains ,
67
+ tune_interval = self .tune_interval ,
68
+ model = self .ATMIP_test ,
69
+ likelihood_name = self .ATMIP_test .deterministics [0 ].name )
66
70
67
71
mtrace = smc .ATMIP_sample (
68
- n_steps = n_steps ,
72
+ n_steps = self . n_steps ,
69
73
step = step ,
70
74
n_jobs = n_jobs ,
71
75
progressbar = True ,
72
76
homepath = self .test_folder ,
73
- model = ATMIP_test ,
77
+ model = self . ATMIP_test ,
74
78
rm_flag = True )
75
79
76
80
d = mtrace .get_values ('X' , combine = True , squeeze = True )
77
81
x = last_sample (d )
78
82
mu1d = np .abs (x ).mean (axis = 0 )
79
- np .testing .assert_allclose (mu1 , mu1d , rtol = 0. , atol = 0.03 )
83
+ np .testing .assert_allclose (self .muref , mu1d , rtol = 0. , atol = 0.03 )
84
+
85
+ def test_stage_handler (self ):
86
+ stage_number = - 1
87
+ stage_handler = TextStage (self .test_folder )
88
+
89
+ step = stage_handler .load_atmip_params (stage_number , model = self .ATMIP_test )
90
+ assert step .stage == stage_number
91
+
92
+ corrupted_chains = stage_handler .recover_existing_results (
93
+ stage_number , self .n_steps , step , n_jobs = 1 , model = self .ATMIP_test )
94
+ assert len (corrupted_chains ) == 0
80
95
81
- def teardown_method (self ):
96
+ def teardown_class (self ):
82
97
shutil .rmtree (self .test_folder )
0 commit comments