
Commit fd1e5ba

Merge pull request #2273 from aloctavodia/sample_ppc_ma
add weighted ppc
2 parents 53fa46b + f231a13

File tree

5 files changed: +638 -9 lines changed

docs/source/examples.rst (+2)

@@ -12,6 +12,8 @@ Howto
     notebooks/sampler-stats.ipynb
     notebooks/Diagnosing_biased_Inference_with_Divergences.ipynb
     notebooks/posterior_predictive.ipynb
+    notebooks/model_comparison.ipynb
+    notebooks/model_averaging.ipynb
     notebooks/howto_debugging.ipynb
     notebooks/PyMC3_tips_and_heuristic.ipynb
     notebooks/LKJ.ipynb

docs/source/notebooks/model_averaging.ipynb (+495)

Large diffs are not rendered by default.
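The notebook itself is not rendered in this view. As background for where the per-model weights consumed by the new sample_ppc_w typically come from, here is a minimal sketch (not taken from the notebook) of Akaike-type weighting of information-criterion scores; the WAIC values below are made-up numbers used only for illustration:

import numpy as np

# Hypothetical WAIC scores for three candidate models (lower is better).
waic = np.array([210.3, 212.1, 225.8])

# Akaike-type weights: proportional to exp(-0.5 * delta), normalized to sum to 1.
delta = waic - waic.min()
weights = np.exp(-0.5 * delta)
weights /= weights.sum()
print(weights)  # roughly [0.71, 0.29, 0.00]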

pymc3/examples/data/milk.csv (+18)

@@ -0,0 +1,18 @@
+kcal.per.g,neocortex,log_mass
+0.490,0.552,0.668
+0.470,0.645,1.658
+0.560,0.645,1.681
+0.890,0.676,0.920
+0.920,0.688,-0.386
+0.800,0.589,-2.120
+0.460,0.617,-0.755
+0.710,0.603,-1.139
+0.680,0.700,0.438
+0.970,0.704,1.176
+0.840,0.734,2.510
+0.620,0.675,1.681
+0.540,0.713,3.569
+0.490,0.726,4.375
+0.480,0.702,3.707
+0.550,0.763,3.500
+0.710,0.755,4.006
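For context, this new data file can be loaded like any other CSV; a minimal sketch, assuming the working directory is the repository root (the relative path is otherwise an assumption, not part of this commit):

import pandas as pd

# Assumed relative path from the repository root.
milk = pd.read_csv('pymc3/examples/data/milk.csv')
print(milk.shape)             # (17, 3): 17 rows, 3 columns
print(milk.columns.tolist())  # ['kcal.per.g', 'neocortex', 'log_mass']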

pymc3/sampling.py (+123 -9)

@@ -18,7 +18,7 @@
 import sys
 sys.setrecursionlimit(10000)
 
-__all__ = ['sample', 'iter_sample', 'sample_ppc', 'init_nuts']
+__all__ = ['sample', 'iter_sample', 'sample_ppc', 'sample_ppc_w', 'init_nuts']
 
 STEP_METHODS = (NUTS, HamiltonianMC, Metropolis, BinaryMetropolis,
                 BinaryGibbsMetropolis, Slice, CategoricalGibbsMetropolis)
@@ -489,14 +489,15 @@ def _update_start_vals(a, b, model):
 
     a.update({k: v for k, v in b.items() if k not in a})
 
+
 def sample_ppc(trace, samples=None, model=None, vars=None, size=None,
                random_seed=None, progressbar=True):
     """Generate posterior predictive samples from a model given a trace.
 
     Parameters
     ----------
     trace : backend, list, or MultiTrace
-        Trace generated from MCMC sampling
+        Trace generated from MCMC sampling.
     samples : int
         Number of posterior predictive samples to generate. Defaults to the
         length of `trace`
@@ -508,12 +509,19 @@ def sample_ppc(trace, samples=None, model=None, vars=None, size=None,
     size : int
         The number of random draws from the distribution specified by the
         parameters in each sample of the trace.
+    random_seed : int
+        Seed for the random number generator.
+    progressbar : bool
+        Whether or not to display a progress bar in the command line. The
+        bar shows the percentage of completion, the sampling speed in
+        samples per second (SPS), and the estimated remaining time until
+        completion ("expected time of arrival"; ETA).
 
     Returns
     -------
     samples : dict
-        Dictionary with the variables as keys. The values corresponding
-        to the posterior predictive samples.
+        Dictionary with the variables as keys. The values correspond to the
+        posterior predictive samples.
     """
     if samples is None:
         samples = len(trace)
@@ -526,18 +534,124 @@ def sample_ppc(trace, samples=None, model=None, vars=None, size=None,
 
     seed(random_seed)
 
+    indices = randint(0, len(trace), samples)
     if progressbar:
-        indices = tqdm(randint(0, len(trace), samples), total=samples)
-    else:
-        indices = randint(0, len(trace), samples)
+        indices = tqdm(indices, total=samples)
 
     try:
         ppc = defaultdict(list)
         for idx in indices:
             param = trace[idx]
             for var in vars:
-                vals = var.distribution.random(point=param, size=size)
-                ppc[var.name].append(vals)
+                ppc[var.name].append(var.distribution.random(point=param,
+                                                             size=size))
+
+    except KeyboardInterrupt:
+        pass
+
+    finally:
+        if progressbar:
+            indices.close()
+
+    return {k: np.asarray(v) for k, v in ppc.items()}
+
+
+def sample_ppc_w(traces, samples=None, models=None, size=None, weights=None,
+                 random_seed=None, progressbar=True):
+    """Generate weighted posterior predictive samples from a list of models
+    and a list of traces according to a set of weights.
+
+    Parameters
+    ----------
+    traces : list
+        List of traces generated from MCMC sampling. The number of traces
+        should be equal to the number of weights.
+    samples : int
+        Number of posterior predictive samples to generate. Defaults to the
+        length of the shortest trace in traces.
+    models : list
+        List of models used to generate the list of traces. The number of
+        models should be equal to the number of weights and the number of
+        observed RVs should be the same for all models.
+        By default a single model will be inferred from the `with` context;
+        in this case the results will only be meaningful if all models share
+        the same distributions for the observed RVs.
+    size : int
+        The number of random draws from the distributions specified by the
+        parameters in each sample of the trace.
+    weights : array-like
+        Individual weights for each trace. Defaults to equal weights for all
+        models.
+    random_seed : int
+        Seed for the random number generator.
+    progressbar : bool
+        Whether or not to display a progress bar in the command line. The
+        bar shows the percentage of completion, the sampling speed in
+        samples per second (SPS), and the estimated remaining time until
+        completion ("expected time of arrival"; ETA).
+
+    Returns
+    -------
+    samples : dict
+        Dictionary with the variables as keys. The values correspond to the
+        posterior predictive samples from the weighted models.
+    """
+    seed(random_seed)
+
+    if models is None:
+        models = [modelcontext(models)] * len(traces)
+
+    if weights is None:
+        weights = [1] * len(traces)
+
+    if len(traces) != len(weights):
+        raise ValueError('The number of traces and weights should be the same')
+
+    if len(models) != len(weights):
+        raise ValueError('The number of models and weights should be the same')
+
+    length_morv = len(models[0].observed_RVs)
+    if not all(len(i.observed_RVs) == length_morv for i in models):
+        raise ValueError(
+            'The number of observed RVs should be the same for all models')
+
+    weights = np.asarray(weights)
+    p = weights / np.sum(weights)
+
+    min_tr = min([len(i) for i in traces])
+
+    n = (min_tr * p).astype('int')
+    # ensure n sums up to min_tr
+    idx = np.argmax(n)
+    n[idx] = n[idx] + min_tr - np.sum(n)
+
+    trace = np.concatenate([np.random.choice(traces[i], j)
+                            for i, j in enumerate(n)])
+
+    variables = []
+    for i, m in enumerate(models):
+        variables.extend(m.observed_RVs * n[i])
+
+    len_trace = len(trace)
+
+    if samples is None:
+        samples = len_trace
+
+    indices = randint(0, len_trace, samples)
+
+    if progressbar:
+        indices = tqdm(indices, total=samples)
+
+    try:
+        ppc = defaultdict(list)
+        for idx in indices:
+            param = trace[idx]
+            var = variables[idx]
+            ppc[var.name].append(var.distribution.random(point=param,
+                                                         size=size))
+
+    except KeyboardInterrupt:
+        pass
+
     finally:
         if progressbar:
             indices.close()