From a1bf1ff6b73e256f980f76fe27c384fce3abbfb9 Mon Sep 17 00:00:00 2001 From: Colin Carroll Date: Tue, 16 Jul 2019 13:21:14 -0400 Subject: [PATCH 1/2] Add live divergence statistics --- pymc3/parallel_sampling.py | 8 +++++++- pymc3/sampling.py | 26 +++++++++++++++++--------- pymc3/tests/test_sampling.py | 2 +- 3 files changed, 25 insertions(+), 11 deletions(-) diff --git a/pymc3/parallel_sampling.py b/pymc3/parallel_sampling.py index ed8bafd4aa..6ed241c20a 100644 --- a/pymc3/parallel_sampling.py +++ b/pymc3/parallel_sampling.py @@ -368,11 +368,14 @@ def __init__( self._start_chain_num = start_chain_num self._progress = None + self._divergences = 0 + self._desc = "Sampling {0._chains:d} chains, {0._divergences:,d} divergences" + self._chains = chains if progressbar: self._progress = tqdm( total=chains * (draws + tune), unit="draws", - desc="Sampling %s chains" % chains, + desc=self._desc.format(self) ) def _make_active(self): @@ -391,6 +394,9 @@ def __iter__(self): draw = ProcessAdapter.recv_draw(self._active) proc, is_last, draw, tuning, stats, warns = draw if self._progress is not None: + if not tuning and stats and stats[0].get('diverging'): + self._divergences += 1 + self._progress.set_description(self._desc.format(self)) self._progress.update() if is_last: diff --git a/pymc3/sampling.py b/pymc3/sampling.py index 6d01b6fa67..8fb17df4f5 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -24,7 +24,7 @@ from .exceptions import IncorrectArgumentsError from pymc3.step_methods.hmc import quadpotential import pymc3 as pm -from tqdm import tqdm +from tqdm import tqdm import sys @@ -539,13 +539,19 @@ def _sample(chain, progressbar, random_seed, start, draws=None, step=None, sampling = _iter_sample(draws, step, start, trace, chain, tune, model, random_seed) + _pbar_data = None if progressbar: - sampling = tqdm(sampling, total=draws) + _pbar_data = {"chain": chain, "divergences": 0} + _desc = "Sampling chain {chain:d}, {divergences:,d} divergences" + sampling = tqdm(sampling, total=draws, desc=_desc.format(**_pbar_data)) try: strace = None - for it, strace in enumerate(sampling): + for it, (strace, diverging) in enumerate(sampling): if it >= skip_first: trace = MultiTrace([strace]) + if diverging and _pbar_data is not None: + _pbar_data["divergences"] += 1 + sampling.set_description(_desc.format(**_pbar_data)) except KeyboardInterrupt: pass finally: @@ -591,7 +597,7 @@ def iter_sample(draws, step, start=None, trace=None, chain=0, tune=None, """ sampling = _iter_sample(draws, step, start, trace, chain, tune, model, random_seed) - for i, strace in enumerate(sampling): + for i, (strace, _) in enumerate(sampling): yield MultiTrace([strace[:i + 1]]) @@ -632,15 +638,17 @@ def _iter_sample(draws, step, start=None, trace=None, chain=0, tune=None, if i == tune: step = stop_tuning(step) if step.generates_stats: - point, states = step.step(point) + point, stats = step.step(point) if strace.supports_sampler_stats: - strace.record(point, states) + strace.record(point, stats) + diverging = i > tune and stats and stats[0].get('diverging') else: strace.record(point) else: point = step.step(point) strace.record(point) - yield strace + diverging = False + yield strace, diverging except KeyboardInterrupt: strace.close() if hasattr(step, 'warnings'): @@ -892,9 +900,9 @@ def _iter_population(draws, tune, popstep, steppers, traces, points): # apply the update to the points and record to the traces for c, strace in enumerate(traces): if steppers[c].generates_stats: - points[c], states = updates[c] + points[c], stats = updates[c] if strace.supports_sampler_stats: - strace.record(points[c], states) + strace.record(points[c], stats) else: strace.record(points[c]) else: diff --git a/pymc3/tests/test_sampling.py b/pymc3/tests/test_sampling.py index 9fd36af11d..1ce5cba00b 100644 --- a/pymc3/tests/test_sampling.py +++ b/pymc3/tests/test_sampling.py @@ -102,7 +102,7 @@ def test_iter_sample(self): tune=0, random_seed=self.random_seed, ) - for i, trace in enumerate(samps): + for i, (trace, _) in enumerate(samps): assert i == len(trace) - 1, "Trace does not have correct length." def test_parallel_start(self): From c4ab3094dce288615f203e389f7b24b1dcd9ab43 Mon Sep 17 00:00:00 2001 From: Colin Carroll Date: Tue, 16 Jul 2019 14:51:05 -0400 Subject: [PATCH 2/2] Overly enthusiastic on the refactor --- pymc3/tests/test_sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc3/tests/test_sampling.py b/pymc3/tests/test_sampling.py index 1ce5cba00b..9fd36af11d 100644 --- a/pymc3/tests/test_sampling.py +++ b/pymc3/tests/test_sampling.py @@ -102,7 +102,7 @@ def test_iter_sample(self): tune=0, random_seed=self.random_seed, ) - for i, (trace, _) in enumerate(samps): + for i, trace in enumerate(samps): assert i == len(trace) - 1, "Trace does not have correct length." def test_parallel_start(self):