Include n_tune, n_draws and t_sampling in SamplerReport #3827

Merged
27 changes: 26 additions & 1 deletion pymc3/backends/report.py
@@ -15,6 +15,7 @@
from collections import namedtuple
import logging
import enum
import typing
from ..util import is_transformed_name, get_untransformed_name


@@ -51,11 +52,15 @@ class WarningType(enum.Enum):


class SamplerReport:
"""This object bundles warnings, convergence statistics and metadata of a sampling run."""
def __init__(self):
self._chain_warnings = {}
self._global_warnings = []
self._ess = None
self._rhat = None
self._n_tune = None
self._n_draws = None
self._t_sampling = None

@property
def _warnings(self):
@@ -68,6 +73,25 @@ def ok(self):
return all(_LEVELS[warn.level] < _LEVELS['warn']
for warn in self._warnings)

@property
def n_tune(self) -> typing.Optional[int]:
"""Number of tune iterations."""
return self._n_tune

@property
def n_draws(self) -> typing.Optional[int]:
"""Number of draw iterations."""
return self._n_draws

@property
def t_sampling(self) -> typing.Optional[float]:
"""
Number of seconds that the sampling procedure took.

(Includes parallelization overhead.)
"""
return self._t_sampling

def raise_ok(self, level='error'):
errors = [warn for warn in self._warnings
if _LEVELS[warn.level] >= _LEVELS[level]]
@@ -151,7 +175,8 @@ def _add_warnings(self, warnings, chain=None):
warn_list.extend(warnings)

def _log_summary(self):

if self._n_tune is not None and self._n_draws is not None and self._t_sampling is not None:
logger.info(f'Sampling {self.n_tune} tune and {self.n_draws} draw iterations took {self.t_sampling:.0f} seconds.')
Member

Can this include the number of chains and total number of draws? That might clarify how we count: For 1000 draws in 4 chains, the progressbar will go to 4000, so a message like "Sampling 500 tune and 1000 draws in 4 chains (2000 plus 4000 draws total) took 18 seconds." might be helpful.

What do you think?

Also, I might use {self.n_tune:,d} and {self.n_draws:,d} to get commas in the number, but that's a very weak desire.
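
For illustration, here is a minimal sketch of how such a message could be assembled (the variable names are hypothetical, not the PR's actual code):

```python
# Hypothetical values, for illustration only.
n_tune, n_draws, n_chains, t_sampling = 500, 1000, 4, 18.2

# The progressbar counts draws per chain, so the totals are the
# per-chain counts multiplied by the number of chains.
msg = (
    f"Sampling {n_tune:,d} tune and {n_draws:,d} draw iterations in {n_chains} chains "
    f"({n_tune * n_chains:,d} + {n_draws * n_chains:,d} draws total) "
    f"took {t_sampling:.0f} seconds."
)
print(msg)
# Sampling 500 tune and 1,000 draw iterations in 4 chains (2,000 + 4,000 draws total) took 18 seconds.
```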

Member Author (@michaelosthege, Mar 8, 2020)

I'll implement your suggestions tomorrow - thanks!

As a German, I read 10,000 as float(10), but maybe we can go with the pythonic neutral ground of 10_000? (According to SI, one should use a thin space U+202F, but in monospaced fonts that's somewhat pointless.)
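
Both grouping options are standard Python format specs; a quick check:

```python
# "," and "_" are grouping options in the format spec mini-language (Python 3.6+);
# "_" is locale-neutral and cannot be misread as a decimal separator.
print(f"{10000:,d}")  # 10,000
print(f"{10000:_d}")  # 10_000
```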

Member Author

This is with the update I just pushed:

5 chains sequentially, interrupted during tuning of the first: [screenshot of log output]

Interrupted during the second chain: [screenshot of log output]

Not interrupted: [screenshot of log output]

Member

Haha! I thought there was a locale-safe version of this. This looks great to me -- feel free to merge.

def log_warning(warn):
level = _LEVELS[warn.level]
logger.log(level, warn.message)
Expand Down
5 changes: 5 additions & 0 deletions pymc3/sampling.py
@@ -24,6 +24,7 @@
from copy import copy
import pickle
import logging
import time
import warnings

import numpy as np
@@ -486,6 +487,7 @@ def sample(
)

parallel = cores > 1 and chains > 1 and not has_population_samplers
t_start = time.time()
if parallel:
_log.info("Multiprocess sampling ({} chains in {} jobs)".format(chains, cores))
_print_step_hierarchy(step)
@@ -533,6 +535,9 @@

discard = tune if discard_tuned_samples else 0
trace = trace[discard:]
trace.report._n_tune = tune
trace.report._n_draws = draws
trace.report._t_sampling = time.time() - t_start

if compute_convergence_checks:
if draws - tune < 100:
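
After this change, the new metadata can be read off the trace's report once sampling finishes. A minimal usage sketch (assuming a trivial model; the attribute names are the ones added in this PR):

```python
import pymc3 as pm

with pm.Model():
    pm.Normal("x", 0, 1)
    trace = pm.sample(draws=1000, tune=500, chains=4)

report = trace.report
print(report.n_tune)      # 500
print(report.n_draws)     # 1000
print(report.t_sampling)  # wall-clock seconds, including parallelization overhead
```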