18
18
import sys
19
19
sys .setrecursionlimit (10000 )
20
20
21
- __all__ = ['sample' , 'iter_sample' , 'sample_ppc' , 'init_nuts' ]
21
+ __all__ = ['sample' , 'iter_sample' , 'sample_ppc' , 'sample_ppc_w' , ' init_nuts' ]
22
22
23
23
STEP_METHODS = (NUTS , HamiltonianMC , Metropolis , BinaryMetropolis ,
24
24
BinaryGibbsMetropolis , Slice , CategoricalGibbsMetropolis )
@@ -484,14 +484,15 @@ def _update_start_vals(a, b, model):
484
484
485
485
a .update ({k : v for k , v in b .items () if k not in a })
486
486
487
+
487
488
def sample_ppc (trace , samples = None , model = None , vars = None , size = None ,
488
489
random_seed = None , progressbar = True ):
489
490
"""Generate posterior predictive samples from a model given a trace.
490
491
491
492
Parameters
492
493
----------
493
494
trace : backend, list, or MultiTrace
494
- Trace generated from MCMC sampling
495
+ Trace generated from MCMC sampling.
495
496
samples : int
496
497
Number of posterior predictive samples to generate. Defaults to the
497
498
length of `trace`
@@ -503,12 +504,19 @@ def sample_ppc(trace, samples=None, model=None, vars=None, size=None,
503
504
size : int
504
505
The number of random draws from the distribution specified by the
505
506
parameters in each sample of the trace.
507
+ random_seed : int
508
+ Seed for the random number generator.
509
+ progressbar : bool
510
+ Whether or not to display a progress bar in the command line. The
511
+ bar shows the percentage of completion, the sampling speed in
512
+ samples per second (SPS), and the estimated remaining time until
513
+ completion ("expected time of arrival"; ETA).
506
514
507
515
Returns
508
516
-------
509
517
samples : dict
510
- Dictionary with the variables as keys. The values corresponding
511
- to the posterior predictive samples.
518
+ Dictionary with the variables as keys. The values corresponding to the
519
+ posterior predictive samples.
512
520
"""
513
521
if samples is None :
514
522
samples = len (trace )
@@ -521,18 +529,128 @@ def sample_ppc(trace, samples=None, model=None, vars=None, size=None,
521
529
522
530
seed (random_seed )
523
531
532
+ indices = randint (0 , len (trace ), samples )
524
533
if progressbar :
525
- indices = tqdm (randint (0 , len (trace ), samples ), total = samples )
526
- else :
527
- indices = randint (0 , len (trace ), samples )
534
+ indices = tqdm (indices , total = samples )
535
+
536
+ try :
537
+ ppc = defaultdict (list )
538
+ for idx in indices :
539
+ param = trace [idx ]
540
+ for var in vars :
541
+ ppc [var .name ].append (var .distribution .random (point = param ,
542
+ size = size ))
543
+
544
+ except KeyboardInterrupt :
545
+ pass
546
+
547
+ finally :
548
+ if progressbar :
549
+ indices .close ()
550
+
551
+ return {k : np .asarray (v ) for k , v in ppc .items ()}
552
+
553
+
554
+ def sample_ppc_w (traces , samples = None , models = None , size = None , weights = None ,
555
+ random_seed = None , progressbar = True ):
556
+ """Generate weighted posterior predictive samples from a list of models and
557
+ a list of traces according to a set of weights.
558
+
559
+ Parameters
560
+ ----------
561
+ traces : list
562
+ List of traces generated from MCMC sampling. The number of traces should
563
+ be equal to the number of weights.
564
+ samples : int
565
+ Number of posterior predictive samples to generate. Defaults to the
566
+ length of the shorter trace in traces.
567
+ models : list
568
+ List of models used to generate the list of traces. The number of models
569
+ should be equal to the number of weights and the number of observed RVs
570
+ should be the same for all models.
571
+ By default a single model will be inferred from `with` context, in this
572
+ case results will only be meaningful if all models share the same
573
+ distributions for the observed RVs.
574
+ size : int
575
+ The number of random draws from the distributions specified by the
576
+ parameters in each sample of the trace.
577
+ weights: array-like
578
+ Individual weights for each trace. Default, same weight for each model.
579
+ random_seed : int
580
+ Seed for the random number generator.
581
+ progressbar : bool
582
+ Whether or not to display a progress bar in the command line. The
583
+ bar shows the percentage of completion, the sampling speed in
584
+ samples per second (SPS), and the estimated remaining time until
585
+ completion ("expected time of arrival"; ETA).
586
+
587
+ Returns
588
+ -------
589
+ samples : dict
590
+ Dictionary with the variables as keys. The values corresponding to the
591
+ posterior predictive samples from the weighted models.
592
+ """
593
+ seed (random_seed )
594
+
595
+ if models is None :
596
+ models = [modelcontext (models )] * len (traces )
597
+
598
+ if weights is None :
599
+ weights = [1 ] * len (traces )
600
+
601
+ if len (traces ) != len (weights ):
602
+ raise ValueError ('The number of traces and weights should be the same' )
603
+
604
+ if len (models ) != len (weights ):
605
+ raise ValueError ('The number of models and weights should be the same' )
606
+
607
+ lenght_morv = len (models [0 ].observed_RVs )
608
+ if not all (len (i .observed_RVs ) == lenght_morv for i in models ):
609
+ raise ValueError (
610
+ 'The number of observed RVs should be the same for all models' )
611
+
612
+ weights = np .asarray (weights )
613
+ p = weights / np .sum (weights )
614
+
615
+ min_tr = min ([len (i ) for i in traces ])
616
+
617
+ n = (min_tr * p ).astype ('int' )
618
+ # ensure n sum up to min_tr
619
+ idx = np .argmax (n )
620
+ n [idx ] = n [idx ] + min_tr - np .sum (n )
528
621
529
- ppc = defaultdict (list )
530
- for idx in indices :
531
- param = trace [idx ]
532
- for var in vars :
622
+ trace = np .concatenate ([np .random .choice (traces [i ], j )
623
+ for i , j in enumerate (n )])
624
+
625
+ variables = []
626
+ for i , m in enumerate (models ):
627
+ variables .extend (m .observed_RVs * n [i ])
628
+
629
+ len_trace = len (trace )
630
+
631
+ if samples is None :
632
+ samples = len_trace
633
+
634
+ indices = randint (0 , len_trace , samples )
635
+
636
+ if progressbar :
637
+ indices = tqdm (indices , total = samples )
638
+
639
+ try :
640
+ ppc = defaultdict (list )
641
+ for idx in indices :
642
+ param = trace [idx ]
643
+ var = variables [idx ]
533
644
ppc [var .name ].append (var .distribution .random (point = param ,
534
645
size = size ))
535
646
647
+ except KeyboardInterrupt :
648
+ pass
649
+
650
+ finally :
651
+ if progressbar :
652
+ indices .close ()
653
+
536
654
return {k : np .asarray (v ) for k , v in ppc .items ()}
537
655
538
656
0 commit comments