18
18
import sys
19
19
sys .setrecursionlimit (10000 )
20
20
21
- __all__ = ['sample' , 'iter_sample' , 'sample_ppc' , 'init_nuts' ]
21
+ __all__ = ['sample' , 'iter_sample' , 'sample_ppc' , 'sample_ppc_w' , ' init_nuts' ]
22
22
23
23
STEP_METHODS = (NUTS , HamiltonianMC , Metropolis , BinaryMetropolis ,
24
24
BinaryGibbsMetropolis , Slice , CategoricalGibbsMetropolis )
@@ -489,14 +489,15 @@ def _update_start_vals(a, b, model):
489
489
490
490
a .update ({k : v for k , v in b .items () if k not in a })
491
491
492
+
492
493
def sample_ppc (trace , samples = None , model = None , vars = None , size = None ,
493
494
random_seed = None , progressbar = True ):
494
495
"""Generate posterior predictive samples from a model given a trace.
495
496
496
497
Parameters
497
498
----------
498
499
trace : backend, list, or MultiTrace
499
- Trace generated from MCMC sampling
500
+ Trace generated from MCMC sampling.
500
501
samples : int
501
502
Number of posterior predictive samples to generate. Defaults to the
502
503
length of `trace`
@@ -508,12 +509,19 @@ def sample_ppc(trace, samples=None, model=None, vars=None, size=None,
508
509
size : int
509
510
The number of random draws from the distribution specified by the
510
511
parameters in each sample of the trace.
512
+ random_seed : int
513
+ Seed for the random number generator.
514
+ progressbar : bool
515
+ Whether or not to display a progress bar in the command line. The
516
+ bar shows the percentage of completion, the sampling speed in
517
+ samples per second (SPS), and the estimated remaining time until
518
+ completion ("expected time of arrival"; ETA).
511
519
512
520
Returns
513
521
-------
514
522
samples : dict
515
- Dictionary with the variables as keys. The values corresponding
516
- to the posterior predictive samples.
523
+ Dictionary with the variables as keys. The values corresponding to the
524
+ posterior predictive samples.
517
525
"""
518
526
if samples is None :
519
527
samples = len (trace )
@@ -526,18 +534,124 @@ def sample_ppc(trace, samples=None, model=None, vars=None, size=None,
526
534
527
535
seed (random_seed )
528
536
537
+ indices = randint (0 , len (trace ), samples )
529
538
if progressbar :
530
- indices = tqdm (randint (0 , len (trace ), samples ), total = samples )
531
- else :
532
- indices = randint (0 , len (trace ), samples )
539
+ indices = tqdm (indices , total = samples )
533
540
534
541
try :
535
542
ppc = defaultdict (list )
536
543
for idx in indices :
537
544
param = trace [idx ]
538
545
for var in vars :
539
- vals = var .distribution .random (point = param , size = size )
540
- ppc [var .name ].append (vals )
546
+ ppc [var .name ].append (var .distribution .random (point = param ,
547
+ size = size ))
548
+
549
+ except KeyboardInterrupt :
550
+ pass
551
+
552
+ finally :
553
+ if progressbar :
554
+ indices .close ()
555
+
556
+ return {k : np .asarray (v ) for k , v in ppc .items ()}
557
+
558
+
559
+ def sample_ppc_w (traces , samples = None , models = None , size = None , weights = None ,
560
+ random_seed = None , progressbar = True ):
561
+ """Generate weighted posterior predictive samples from a list of models and
562
+ a list of traces according to a set of weights.
563
+
564
+ Parameters
565
+ ----------
566
+ traces : list
567
+ List of traces generated from MCMC sampling. The number of traces should
568
+ be equal to the number of weights.
569
+ samples : int
570
+ Number of posterior predictive samples to generate. Defaults to the
571
+ length of the shorter trace in traces.
572
+ models : list
573
+ List of models used to generate the list of traces. The number of models
574
+ should be equal to the number of weights and the number of observed RVs
575
+ should be the same for all models.
576
+ By default a single model will be inferred from `with` context, in this
577
+ case results will only be meaningful if all models share the same
578
+ distributions for the observed RVs.
579
+ size : int
580
+ The number of random draws from the distributions specified by the
581
+ parameters in each sample of the trace.
582
+ weights: array-like
583
+ Individual weights for each trace. Default, same weight for each model.
584
+ random_seed : int
585
+ Seed for the random number generator.
586
+ progressbar : bool
587
+ Whether or not to display a progress bar in the command line. The
588
+ bar shows the percentage of completion, the sampling speed in
589
+ samples per second (SPS), and the estimated remaining time until
590
+ completion ("expected time of arrival"; ETA).
591
+
592
+ Returns
593
+ -------
594
+ samples : dict
595
+ Dictionary with the variables as keys. The values corresponding to the
596
+ posterior predictive samples from the weighted models.
597
+ """
598
+ seed (random_seed )
599
+
600
+ if models is None :
601
+ models = [modelcontext (models )] * len (traces )
602
+
603
+ if weights is None :
604
+ weights = [1 ] * len (traces )
605
+
606
+ if len (traces ) != len (weights ):
607
+ raise ValueError ('The number of traces and weights should be the same' )
608
+
609
+ if len (models ) != len (weights ):
610
+ raise ValueError ('The number of models and weights should be the same' )
611
+
612
+ lenght_morv = len (models [0 ].observed_RVs )
613
+ if not all (len (i .observed_RVs ) == lenght_morv for i in models ):
614
+ raise ValueError (
615
+ 'The number of observed RVs should be the same for all models' )
616
+
617
+ weights = np .asarray (weights )
618
+ p = weights / np .sum (weights )
619
+
620
+ min_tr = min ([len (i ) for i in traces ])
621
+
622
+ n = (min_tr * p ).astype ('int' )
623
+ # ensure n sum up to min_tr
624
+ idx = np .argmax (n )
625
+ n [idx ] = n [idx ] + min_tr - np .sum (n )
626
+
627
+ trace = np .concatenate ([np .random .choice (traces [i ], j )
628
+ for i , j in enumerate (n )])
629
+
630
+ variables = []
631
+ for i , m in enumerate (models ):
632
+ variables .extend (m .observed_RVs * n [i ])
633
+
634
+ len_trace = len (trace )
635
+
636
+ if samples is None :
637
+ samples = len_trace
638
+
639
+ indices = randint (0 , len_trace , samples )
640
+
641
+ if progressbar :
642
+ indices = tqdm (indices , total = samples )
643
+
644
+ try :
645
+ ppc = defaultdict (list )
646
+ for idx in indices :
647
+ param = trace [idx ]
648
+ var = variables [idx ]
649
+ ppc [var .name ].append (var .distribution .random (point = param ,
650
+ size = size ))
651
+
652
+ except KeyboardInterrupt :
653
+ pass
654
+
541
655
finally :
542
656
if progressbar :
543
657
indices .close ()
0 commit comments