1818import sys
# Raise the interpreter's recursion limit for the whole process.
sys.setrecursionlimit(10000)

# Public names exported by a star-import of this module. Every entry must be
# the exact identifier (no surrounding whitespace), otherwise the star-import
# fails with AttributeError.
__all__ = ['sample', 'iter_sample', 'sample_ppc', 'sample_ppc_w', 'init_nuts']

# Tuple of the step-method classes referenced by this module (the classes
# themselves are imported elsewhere in the file).
STEP_METHODS = (NUTS, HamiltonianMC, Metropolis, BinaryMetropolis,
                BinaryGibbsMetropolis, Slice, CategoricalGibbsMetropolis)
@@ -484,14 +484,15 @@ def _update_start_vals(a, b, model):
484484
485485 a .update ({k : v for k , v in b .items () if k not in a })
486486
487+
487488def sample_ppc (trace , samples = None , model = None , vars = None , size = None ,
488489 random_seed = None , progressbar = True ):
489490 """Generate posterior predictive samples from a model given a trace.
490491
491492 Parameters
492493 ----------
493494 trace : backend, list, or MultiTrace
494- Trace generated from MCMC sampling
495+ Trace generated from MCMC sampling.
495496 samples : int
496497 Number of posterior predictive samples to generate. Defaults to the
497498 length of `trace`
@@ -503,12 +504,19 @@ def sample_ppc(trace, samples=None, model=None, vars=None, size=None,
503504 size : int
504505 The number of random draws from the distribution specified by the
505506 parameters in each sample of the trace.
507+ random_seed : int
508+ Seed for the random number generator.
509+ progressbar : bool
510+ Whether or not to display a progress bar in the command line. The
511+ bar shows the percentage of completion, the sampling speed in
512+ samples per second (SPS), and the estimated remaining time until
513+ completion ("expected time of arrival"; ETA).
506514
507515 Returns
508516 -------
509517 samples : dict
510- Dictionary with the variables as keys. The values corresponding
511- to the posterior predictive samples.
518+ Dictionary with the variables as keys. The values corresponding to the
519+ posterior predictive samples.
512520 """
513521 if samples is None :
514522 samples = len (trace )
@@ -521,18 +529,128 @@ def sample_ppc(trace, samples=None, model=None, vars=None, size=None,
521529
522530 seed (random_seed )
523531
532+ indices = randint (0 , len (trace ), samples )
524533 if progressbar :
525- indices = tqdm (randint (0 , len (trace ), samples ), total = samples )
526- else :
527- indices = randint (0 , len (trace ), samples )
534+ indices = tqdm (indices , total = samples )
535+
536+ try :
537+ ppc = defaultdict (list )
538+ for idx in indices :
539+ param = trace [idx ]
540+ for var in vars :
541+ ppc [var .name ].append (var .distribution .random (point = param ,
542+ size = size ))
543+
544+ except KeyboardInterrupt :
545+ pass
546+
547+ finally :
548+ if progressbar :
549+ indices .close ()
550+
551+ return {k : np .asarray (v ) for k , v in ppc .items ()}
552+
553+
def sample_ppc_w(traces, samples=None, models=None, size=None, weights=None,
                 random_seed=None, progressbar=True):
    """Generate weighted posterior predictive samples from a list of models
    and a list of traces according to a set of weights.

    A pooled trace is built by drawing points (with replacement) from each
    input trace in proportion to its weight. Posterior predictive samples are
    then generated from randomly chosen points of that pooled trace, always
    using the observed RVs of the model the chosen point came from.

    Parameters
    ----------
    traces : list
        List of traces generated from MCMC sampling. The number of traces
        should be equal to the number of weights.
    samples : int
        Number of posterior predictive samples to generate. Defaults to the
        length of the shortest trace in `traces`.
    models : list
        List of models used to generate the list of traces. The number of
        models should be equal to the number of weights and the number of
        observed RVs should be the same for all models.
        By default a single model is inferred from the `with` context and
        reused for every trace; in that case results will only be meaningful
        if all traces share the same distributions for the observed RVs.
    size : int
        The number of random draws from the distributions specified by the
        parameters in each sample of the trace.
    weights : array-like
        Individual non-negative weights for each trace. They are normalised
        internally. Defaults to the same weight for each model.
    random_seed : int
        Seed for the random number generator.
    progressbar : bool
        Whether or not to display a progress bar in the command line. The
        bar shows the percentage of completion, the sampling speed in
        samples per second (SPS), and the estimated remaining time until
        completion ("expected time of arrival"; ETA).

    Returns
    -------
    samples : dict
        Dictionary with the observed variable names as keys. The values are
        arrays of posterior predictive samples from the weighted models.

    Raises
    ------
    ValueError
        If the numbers of traces, models and weights differ, if the models
        do not all have the same number of observed RVs, or if the weights
        are negative or do not sum to a positive value.
    """
    seed(random_seed)

    if models is None:
        models = [modelcontext(models)] * len(traces)

    if weights is None:
        weights = [1] * len(traces)

    if len(traces) != len(weights):
        raise ValueError('The number of traces and weights should be the same')

    if len(models) != len(weights):
        raise ValueError('The number of models and weights should be the same')

    n_observed = len(models[0].observed_RVs)
    if not all(len(m.observed_RVs) == n_observed for m in models):
        raise ValueError(
            'The number of observed RVs should be the same for all models')

    # Cast to float so that normalisation is a true division even for integer
    # weights (integer arrays floor-divide under Python 2 semantics).
    weights = np.asarray(weights, dtype=float)
    total_weight = np.sum(weights)
    # `not total_weight > 0` also rejects NaN.
    if np.any(weights < 0) or not total_weight > 0:
        raise ValueError(
            'The weights should be non-negative and sum to a positive value')
    p = weights / total_weight

    min_tr = min(len(tr) for tr in traces)

    # Number of points each trace contributes to the pooled trace.
    n = (min_tr * p).astype('int')
    # Truncation can drop a few points; hand the remainder to the largest
    # share so that the counts sum up to min_tr.
    largest = np.argmax(n)
    n[largest] = n[largest] + min_tr - np.sum(n)

    trace = np.concatenate([np.random.choice(traces[i], j)
                            for i, j in enumerate(n)])

    # One entry per point of the pooled trace: the observed RVs of the model
    # that point was drawn from. Keeping this aligned one-to-one with `trace`
    # guarantees a point is never paired with another model's variables, even
    # when the models have more than one observed RV.
    rvs_per_point = []
    for model, count in zip(models, n):
        rvs_per_point.extend([model.observed_RVs] * int(count))

    len_trace = len(trace)

    if samples is None:
        samples = len_trace

    indices = randint(0, len_trace, samples)

    if progressbar:
        indices = tqdm(indices, total=samples)

    ppc = defaultdict(list)
    try:
        for idx in indices:
            param = trace[idx]
            for var in rvs_per_point[idx]:
                ppc[var.name].append(var.distribution.random(point=param,
                                                             size=size))

    except KeyboardInterrupt:
        # Deliberate: an interrupt stops sampling early and the samples
        # collected so far are returned.
        pass

    finally:
        if progressbar:
            indices.close()

    return {k: np.asarray(v) for k, v in ppc.items()}
537655
538656
0 commit comments