Skip to content

Commit 6b051f9

Browse files
committed
Disable progressbar in SMC tests on Windows
1 parent 5360939 commit 6b051f9

File tree

1 file changed

+20
-6
lines changed

1 file changed

+20
-6
lines changed

tests/smc/test_smc.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import logging
15+
import platform
1516
import warnings
1617

1718
import numpy as np
@@ -28,6 +29,8 @@
2829
from pymc.smc.kernels import IMH, systematic_resampling
2930
from tests.helpers import assert_random_state_equal
3031

32+
_IS_WINDOWS = platform.system() == "Windows"
33+
3134

3235
class TestSMC:
3336
"""Tests for the default SMC kernel"""
@@ -75,7 +78,9 @@ def two_gaussians(x):
7578
def test_sample(self):
7679
initial_rng_state = np.random.get_state()
7780
with self.SMC_test:
78-
mtrace = pm.sample_smc(draws=self.samples, return_inferencedata=False)
81+
mtrace = pm.sample_smc(
82+
draws=self.samples, return_inferencedata=False, progressbar=not _IS_WINDOWS
83+
)
7984

8085
# Verify sampling was done with a non-global random generator
8186
assert_random_state_equal(initial_rng_state, np.random.get_state())
@@ -142,7 +147,9 @@ def test_marginal_likelihood(self):
142147
with pm.Model() as model:
143148
a = pm.Beta("a", alpha, beta)
144149
y = pm.Bernoulli("y", a, observed=data)
145-
trace = pm.sample_smc(2000, chains=2, return_inferencedata=False)
150+
trace = pm.sample_smc(
151+
2000, chains=2, return_inferencedata=False, progressbar=not _IS_WINDOWS
152+
)
146153
# log_marginal_likelihood is found in the last value of each chain
147154
lml = np.mean([chain[-1] for chain in trace.report.log_marginal_likelihood])
148155
marginals.append(lml)
@@ -203,8 +210,15 @@ def test_return_datatype(self, chains):
203210
with warnings.catch_warnings():
204211
warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
205212
warnings.filterwarnings("ignore", "More chains .* than draws .*", UserWarning)
206-
idata = pm.sample_smc(chains=chains, draws=draws)
207-
mt = pm.sample_smc(chains=chains, draws=draws, return_inferencedata=False)
213+
idata = pm.sample_smc(
214+
chains=chains, draws=draws, progressbar=not (chains > 1 and _IS_WINDOWS)
215+
)
216+
mt = pm.sample_smc(
217+
chains=chains,
218+
draws=draws,
219+
return_inferencedata=False,
220+
progressbar=not (chains > 1 and _IS_WINDOWS),
221+
)
208222

209223
assert isinstance(idata, InferenceData)
210224
assert "sample_stats" in idata
@@ -218,7 +232,7 @@ def test_return_datatype(self, chains):
218232
def test_convergence_checks(self, caplog):
219233
with caplog.at_level(logging.INFO):
220234
with self.fast_model:
221-
pm.sample_smc(draws=99)
235+
pm.sample_smc(draws=99, progressbar=not _IS_WINDOWS)
222236
assert "The number of samples is too small" in caplog.text
223237

224238
def test_deprecated_parallel_arg(self):
@@ -265,7 +279,7 @@ def test_normal_model(self):
265279
mu = pm.Normal("mu", 0, 3)
266280
sigma = pm.HalfNormal("sigma", 1)
267281
y = pm.Normal("y", mu, sigma, observed=data)
268-
idata = pm.sample_smc(draws=2000, kernel=pm.smc.MH)
282+
idata = pm.sample_smc(draws=2000, kernel=pm.smc.MH, progressbar=not _IS_WINDOWS)
269283
assert_random_state_equal(initial_rng_state, np.random.get_state())
270284

271285
post = idata.posterior.stack(sample=("chain", "draw"))

0 commit comments

Comments
 (0)