12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
import logging
15
+ import platform
15
16
import warnings
16
17
17
18
import numpy as np
28
29
from pymc .smc .kernels import IMH , systematic_resampling
29
30
from tests .helpers import assert_random_state_equal
30
31
32
+ _IS_WINDOWS = platform .system () == "Windows"
33
+
31
34
32
35
class TestSMC :
33
36
"""Tests for the default SMC kernel"""
@@ -75,7 +78,9 @@ def two_gaussians(x):
75
78
def test_sample (self ):
76
79
initial_rng_state = np .random .get_state ()
77
80
with self .SMC_test :
78
- mtrace = pm .sample_smc (draws = self .samples , return_inferencedata = False )
81
+ mtrace = pm .sample_smc (
82
+ draws = self .samples , return_inferencedata = False , progressbar = not _IS_WINDOWS
83
+ )
79
84
80
85
# Verify sampling was done with a non-global random generator
81
86
assert_random_state_equal (initial_rng_state , np .random .get_state ())
@@ -142,7 +147,9 @@ def test_marginal_likelihood(self):
142
147
with pm .Model () as model :
143
148
a = pm .Beta ("a" , alpha , beta )
144
149
y = pm .Bernoulli ("y" , a , observed = data )
145
- trace = pm .sample_smc (2000 , chains = 2 , return_inferencedata = False )
150
+ trace = pm .sample_smc (
151
+ 2000 , chains = 2 , return_inferencedata = False , progressbar = not _IS_WINDOWS
152
+ )
146
153
# log_marginal_likelihood is found in the last value of each chain
147
154
lml = np .mean ([chain [- 1 ] for chain in trace .report .log_marginal_likelihood ])
148
155
marginals .append (lml )
@@ -203,8 +210,15 @@ def test_return_datatype(self, chains):
203
210
with warnings .catch_warnings ():
204
211
warnings .filterwarnings ("ignore" , ".*number of samples.*" , UserWarning )
205
212
warnings .filterwarnings ("ignore" , "More chains .* than draws .*" , UserWarning )
206
- idata = pm .sample_smc (chains = chains , draws = draws )
207
- mt = pm .sample_smc (chains = chains , draws = draws , return_inferencedata = False )
213
+ idata = pm .sample_smc (
214
+ chains = chains , draws = draws , progressbar = not (chains > 1 and _IS_WINDOWS )
215
+ )
216
+ mt = pm .sample_smc (
217
+ chains = chains ,
218
+ draws = draws ,
219
+ return_inferencedata = False ,
220
+ progressbar = not (chains > 1 and _IS_WINDOWS ),
221
+ )
208
222
209
223
assert isinstance (idata , InferenceData )
210
224
assert "sample_stats" in idata
@@ -218,7 +232,7 @@ def test_return_datatype(self, chains):
218
232
def test_convergence_checks (self , caplog ):
219
233
with caplog .at_level (logging .INFO ):
220
234
with self .fast_model :
221
- pm .sample_smc (draws = 99 )
235
+ pm .sample_smc (draws = 99 , progressbar = not _IS_WINDOWS )
222
236
assert "The number of samples is too small" in caplog .text
223
237
224
238
def test_deprecated_parallel_arg (self ):
@@ -265,7 +279,7 @@ def test_normal_model(self):
265
279
mu = pm .Normal ("mu" , 0 , 3 )
266
280
sigma = pm .HalfNormal ("sigma" , 1 )
267
281
y = pm .Normal ("y" , mu , sigma , observed = data )
268
- idata = pm .sample_smc (draws = 2000 , kernel = pm .smc .MH )
282
+ idata = pm .sample_smc (draws = 2000 , kernel = pm .smc .MH , progressbar = not _IS_WINDOWS )
269
283
assert_random_state_equal (initial_rng_state , np .random .get_state ())
270
284
271
285
post = idata .posterior .stack (sample = ("chain" , "draw" ))
0 commit comments