Skip to content

Commit 6c5254f

Browse files
Include n_tune, n_draws and t_sampling in SamplerReport (#3827)
* include n_tune, n_draws and t_sampling in SamplerReport * count tune/draw samples instead of trusting parameters (because of KeyboardInterrupt) * fall back to tune and len(trace) if tune stat is unavailable * add test for SamplerReport n_tune and n_draws * clarify that n_tune are not necessarily in the trace * use actual number of chains to compute totals * mention new SamplerReport properties in release notes Co-authored-by: Michael Osthege <[email protected]>
1 parent b5891be commit 6c5254f

File tree

4 files changed

+73
-3
lines changed

4 files changed

+73
-3
lines changed

RELEASE-NOTES.md

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
- `DEMetropolisZ`, an improved variant of `DEMetropolis` brings better parallelization and higher efficiency with fewer chains with a slower initial convergence. This implementation is experimental. See [#3784](https://github.com/pymc-devs/pymc3/pull/3784) for more info.
99
- Notebooks that give insight into `DEMetropolis`, `DEMetropolisZ` and the `DifferentialEquation` interface are now located in the [Tutorials/Deep Dive](https://docs.pymc.io/nb_tutorials/index.html) section.
1010
- Add `fast_sample_posterior_predictive`, a vectorized alternative to `sample_posterior_predictive`. This alternative is substantially faster for large models.
11+
- `SamplerReport` (`MultiTrace.report`) now has properties `n_tune`, `n_draws`, `t_sampling` for increased convenience (see [#3827](https://github.com/pymc-devs/pymc3/pull/3827))
1112

1213
### Maintenance
1314
- Remove `sample_ppc` and `sample_ppc_w` that were deprecated in 3.6.

pymc3/backends/report.py

+24-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from collections import namedtuple
1616
import logging
1717
import enum
18+
import typing
1819
from ..util import is_transformed_name, get_untransformed_name
1920

2021

@@ -51,11 +52,15 @@ class WarningType(enum.Enum):
5152

5253

5354
class SamplerReport:
55+
"""This object bundles warnings, convergence statistics and metadata of a sampling run."""
5456
def __init__(self):
5557
self._chain_warnings = {}
5658
self._global_warnings = []
5759
self._ess = None
5860
self._rhat = None
61+
self._n_tune = None
62+
self._n_draws = None
63+
self._t_sampling = None
5964

6065
@property
6166
def _warnings(self):
@@ -68,6 +73,25 @@ def ok(self):
6873
return all(_LEVELS[warn.level] < _LEVELS['warn']
6974
for warn in self._warnings)
7075

76+
@property
77+
def n_tune(self) -> typing.Optional[int]:
78+
"""Number of tune iterations - not necessarily kept in trace!"""
79+
return self._n_tune
80+
81+
@property
82+
def n_draws(self) -> typing.Optional[int]:
83+
"""Number of draw iterations."""
84+
return self._n_draws
85+
86+
@property
87+
def t_sampling(self) -> typing.Optional[float]:
88+
"""
89+
Number of seconds that the sampling procedure took.
90+
91+
(Includes parallelization overhead.)
92+
"""
93+
return self._t_sampling
94+
7195
def raise_ok(self, level='error'):
7296
errors = [warn for warn in self._warnings
7397
if _LEVELS[warn.level] >= _LEVELS[level]]
@@ -151,7 +175,6 @@ def _add_warnings(self, warnings, chain=None):
151175
warn_list.extend(warnings)
152176

153177
def _log_summary(self):
154-
155178
def log_warning(warn):
156179
level = _LEVELS[warn.level]
157180
logger.log(level, warn.message)

pymc3/sampling.py

+32-2
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from copy import copy
2525
import pickle
2626
import logging
27+
import time
2728
import warnings
2829

2930
import numpy as np
@@ -488,6 +489,7 @@ def sample(
488489
)
489490

490491
parallel = cores > 1 and chains > 1 and not has_population_samplers
492+
t_start = time.time()
491493
if parallel:
492494
_log.info("Multiprocess sampling ({} chains in {} jobs)".format(chains, cores))
493495
_print_step_hierarchy(step)
@@ -533,8 +535,36 @@ def sample(
533535
_print_step_hierarchy(step)
534536
trace = _sample_many(**sample_args)
535537

536-
discard = tune if discard_tuned_samples else 0
537-
trace = trace[discard:]
538+
t_sampling = time.time() - t_start
539+
# count the number of tune/draw iterations that happened
540+
# ideally via the "tune" statistic, but not all samplers record it!
541+
if 'tune' in trace.stat_names:
542+
stat = trace.get_sampler_stats('tune', chains=0)
543+
# when CompoundStep is used, the stat is 2 dimensional!
544+
if len(stat.shape) == 2:
545+
stat = stat[:,0]
546+
stat = tuple(stat)
547+
n_tune = stat.count(True)
548+
n_draws = stat.count(False)
549+
else:
550+
# these may be wrong when KeyboardInterrupt happened, but they're better than nothing
551+
n_tune = min(tune, len(trace))
552+
n_draws = max(0, len(trace) - n_tune)
553+
554+
if discard_tuned_samples:
555+
trace = trace[n_tune:]
556+
557+
# save metadata in SamplerReport
558+
trace.report._n_tune = n_tune
559+
trace.report._n_draws = n_draws
560+
trace.report._t_sampling = t_sampling
561+
562+
n_chains = len(trace.chains)
563+
_log.info(
564+
f'Sampling {n_chains} chain{"s" if n_chains > 1 else ""} for {n_tune:_d} tune and {n_draws:_d} draw iterations '
565+
f'({n_tune*n_chains:_d} + {n_draws*n_chains:_d} draws total) '
566+
f'took {trace.report.t_sampling:.0f} seconds.'
567+
)
538568

539569
if compute_convergence_checks:
540570
if draws - tune < 100:

pymc3/tests/test_sampling.py

+16
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,22 @@ def test_sample_tune_len(self):
142142
trace = pm.sample(draws=100, tune=50, cores=4)
143143
assert len(trace) == 100
144144

145+
@pytest.mark.parametrize("step_cls", [pm.NUTS, pm.Metropolis, pm.Slice])
146+
@pytest.mark.parametrize("discard", [True, False])
147+
def test_trace_report(self, step_cls, discard):
148+
with self.model:
149+
# add more variables, because stats are 2D with CompoundStep!
150+
pm.Uniform('uni')
151+
trace = pm.sample(
152+
draws=100, tune=50, cores=1,
153+
discard_tuned_samples=discard,
154+
step=step_cls()
155+
)
156+
assert trace.report.n_tune == 50
157+
assert trace.report.n_draws == 100
158+
assert isinstance(trace.report.t_sampling, float)
159+
pass
160+
145161
@pytest.mark.parametrize('cores', [1, 2])
146162
def test_sampler_stat_tune(self, cores):
147163
with self.model:

0 commit comments

Comments
 (0)