Skip to content

Commit dad7d7f

Browse files
committed
add example, fix minor issues
1 parent 2852edc commit dad7d7f

File tree

5 files changed

+543
-11
lines changed

5 files changed

+543
-11
lines changed

docs/source/examples.rst

+2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ Howto
1212
notebooks/sampler-stats.ipynb
1313
notebooks/Diagnosing_biased_Inference_with_Divergences.ipynb
1414
notebooks/posterior_predictive.ipynb
15+
notebooks/model_comparison.ipynb
16+
notebooks/model_averaging.ipynb
1517
notebooks/howto_debugging.ipynb
1618
notebooks/PyMC3_tips_and_heuristic.ipynb
1719
notebooks/LKJ.ipynb

docs/source/notebooks/model_averaging.ipynb

+495
Large diffs are not rendered by default.

pymc3/examples/data/milk.csv

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
kcal.per.g,neocortex,log_mass
2+
0.490,0.552,0.668
3+
0.470,0.645,1.658
4+
0.560,0.645,1.681
5+
0.890,0.676,0.920
6+
0.920,0.688,-0.386
7+
0.800,0.589,-2.120
8+
0.460,0.617,-0.755
9+
0.710,0.603,-1.139
10+
0.680,0.700,0.438
11+
0.970,0.704,1.176
12+
0.840,0.734,2.510
13+
0.620,0.675,1.681
14+
0.540,0.713,3.569
15+
0.490,0.726,4.375
16+
0.480,0.702,3.707
17+
0.550,0.763,3.500
18+
0.710,0.755,4.006

pymc3/sampling.py

+28-11
Original file line numberDiff line numberDiff line change
@@ -534,12 +534,20 @@ def sample_ppc(trace, samples=None, model=None, vars=None, size=None,
534534
if progressbar:
535535
indices = tqdm(indices, total=samples)
536536

537-
ppc = defaultdict(list)
538-
for idx in indices:
539-
param = trace[idx]
540-
for var in vars:
541-
ppc[var.name].append(var.distribution.random(point=param,
542-
size=size))
537+
try:
538+
ppc = defaultdict(list)
539+
for idx in indices:
540+
param = trace[idx]
541+
for var in vars:
542+
ppc[var.name].append(var.distribution.random(point=param,
543+
size=size))
544+
545+
except KeyboardInterrupt:
546+
pass
547+
548+
finally:
549+
if progressbar:
550+
indices.close()
543551

544552
return {k: np.asarray(v) for k, v in ppc.items()}
545553

@@ -629,11 +637,20 @@ def sample_ppc_w(traces, samples=None, models=None, size=None, weights=None,
629637
if progressbar:
630638
indices = tqdm(indices, total=samples)
631639

632-
ppc = defaultdict(list)
633-
for idx in indices:
634-
param = trace[idx]
635-
var = variables[idx]
636-
ppc[var.name].append(var.distribution.random(point=param, size=size))
640+
try:
641+
ppc = defaultdict(list)
642+
for idx in indices:
643+
param = trace[idx]
644+
var = variables[idx]
645+
ppc[var.name].append(var.distribution.random(point=param,
646+
size=size))
647+
648+
except KeyboardInterrupt:
649+
pass
650+
651+
finally:
652+
if progressbar:
653+
indices.close()
637654

638655
return {k: np.asarray(v) for k, v in ppc.items()}
639656

0 commit comments

Comments
 (0)