@@ -226,9 +226,7 @@ def _print_step_hierarchy(s, level=0):
226
226
else :
227
227
varnames = ", " .join (
228
228
[
229
- get_untransformed_name (v .name )
230
- if is_transformed_name (v .name )
231
- else v .name
229
+ get_untransformed_name (v .name ) if is_transformed_name (v .name ) else v .name
232
230
for v in s .vars
233
231
]
234
232
)
@@ -491,10 +489,7 @@ def sample(
491
489
start = start_
492
490
except (AttributeError , NotImplementedError , tg .NullTypeGradError ):
493
491
# gradient computation failed
494
- _log .info (
495
- "Initializing NUTS failed. "
496
- "Falling back to elementwise auto-assignment."
497
- )
492
+ _log .info ("Initializing NUTS failed. " "Falling back to elementwise auto-assignment." )
498
493
_log .debug ("Exception in init nuts" , exec_info = True )
499
494
step = assign_step_methods (model , step , step_kwargs = kwargs )
500
495
else :
@@ -559,9 +554,7 @@ def sample(
559
554
has_demcmc = np .any (
560
555
[
561
556
isinstance (m , DEMetropolis )
562
- for m in (
563
- step .methods if isinstance (step , CompoundStep ) else [step ]
564
- )
557
+ for m in (step .methods if isinstance (step , CompoundStep ) else [step ])
565
558
]
566
559
)
567
560
_log .info ("Population sampling ({} chains)" .format (chains ))
@@ -625,9 +618,7 @@ def sample(
625
618
626
619
if compute_convergence_checks :
627
620
if draws - tune < 100 :
628
- warnings .warn (
629
- "The number of samples is too small to check convergence reliably."
630
- )
621
+ warnings .warn ("The number of samples is too small to check convergence reliably." )
631
622
else :
632
623
trace .report ._run_convergence_checks (idata , model )
633
624
trace .report ._log_summary ()
@@ -664,14 +655,7 @@ def _check_start_shape(model, start):
664
655
665
656
666
657
def _sample_many (
667
- draws ,
668
- chain : int ,
669
- chains : int ,
670
- start : list ,
671
- random_seed : list ,
672
- step ,
673
- callback = None ,
674
- ** kwargs ,
658
+ draws , chain : int , chains : int , start : list , random_seed : list , step , callback = None , ** kwargs ,
675
659
):
676
660
"""Samples all chains sequentially.
677
661
@@ -833,9 +817,7 @@ def _sample(
833
817
"""
834
818
skip_first = kwargs .get ("skip_first" , 0 )
835
819
836
- sampling = _iter_sample (
837
- draws , step , start , trace , chain , tune , model , random_seed , callback
838
- )
820
+ sampling = _iter_sample (draws , step , start , trace , chain , tune , model , random_seed , callback )
839
821
_pbar_data = {"chain" : chain , "divergences" : 0 }
840
822
_desc = "Sampling chain {chain:d}, {divergences:,d} divergences"
841
823
if progressbar :
@@ -909,9 +891,7 @@ def iter_sample(
909
891
for trace in iter_sample(500, step):
910
892
...
911
893
"""
912
- sampling = _iter_sample (
913
- draws , step , start , trace , chain , tune , model , random_seed , callback
914
- )
894
+ sampling = _iter_sample (draws , step , start , trace , chain , tune , model , random_seed , callback )
915
895
for i , (strace , _ ) in enumerate (sampling ):
916
896
yield MultiTrace ([strace [: i + 1 ]])
917
897
@@ -1012,8 +992,7 @@ def _iter_sample(
1012
992
if callback is not None :
1013
993
warns = getattr (step , "warnings" , None )
1014
994
callback (
1015
- trace = strace ,
1016
- draw = Draw (chain , i == draws , i , i < tune , stats , point , warns ),
995
+ trace = strace , draw = Draw (chain , i == draws , i , i < tune , stats , point , warns ),
1017
996
)
1018
997
1019
998
yield strace , diverging
@@ -1067,9 +1046,7 @@ def __init__(self, steppers, parallelize, progressbar=True):
1067
1046
import multiprocessing
1068
1047
1069
1048
for c , stepper in (
1070
- enumerate (progress_bar (steppers ))
1071
- if progressbar
1072
- else enumerate (steppers )
1049
+ enumerate (progress_bar (steppers )) if progressbar else enumerate (steppers )
1073
1050
):
1074
1051
secondary_end , primary_end = multiprocessing .Pipe ()
1075
1052
stepper_dumps = pickle .dumps (stepper , protocol = 4 )
@@ -1136,9 +1113,7 @@ def _run_secondary(c, stepper_dumps, secondary_end):
1136
1113
# but rather a CompoundStep. PopulationArrayStepShared.population
1137
1114
# has to be updated, therefore we identify the substeppers first.
1138
1115
population_steppers = []
1139
- for sm in (
1140
- stepper .methods if isinstance (stepper , CompoundStep ) else [stepper ]
1141
- ):
1116
+ for sm in stepper .methods if isinstance (stepper , CompoundStep ) else [stepper ]:
1142
1117
if isinstance (sm , arraystep .PopulationArrayStepShared ):
1143
1118
population_steppers .append (sm )
1144
1119
while True :
@@ -1682,13 +1657,9 @@ def sample_posterior_predictive(
1682
1657
nchain = 1
1683
1658
1684
1659
if keep_size and samples is not None :
1685
- raise IncorrectArgumentsError (
1686
- "Should not specify both keep_size and samples arguments"
1687
- )
1660
+ raise IncorrectArgumentsError ("Should not specify both keep_size and samples arguments" )
1688
1661
if keep_size and size is not None :
1689
- raise IncorrectArgumentsError (
1690
- "Should not specify both keep_size and size arguments"
1691
- )
1662
+ raise IncorrectArgumentsError ("Should not specify both keep_size and size arguments" )
1692
1663
1693
1664
if samples is None :
1694
1665
if isinstance (_trace , MultiTrace ):
@@ -1714,15 +1685,11 @@ def sample_posterior_predictive(
1714
1685
1715
1686
if var_names is not None :
1716
1687
if vars is not None :
1717
- raise IncorrectArgumentsError (
1718
- "Should not specify both vars and var_names arguments."
1719
- )
1688
+ raise IncorrectArgumentsError ("Should not specify both vars and var_names arguments." )
1720
1689
else :
1721
1690
vars = [model [x ] for x in var_names ]
1722
1691
elif vars is not None : # var_names is None, and vars is not.
1723
- warnings .warn (
1724
- "vars argument is deprecated in favor of var_names." , DeprecationWarning
1725
- )
1692
+ warnings .warn ("vars argument is deprecated in favor of var_names." , DeprecationWarning )
1726
1693
if vars is None :
1727
1694
vars = model .observed_RVs
1728
1695
@@ -1741,11 +1708,7 @@ def sample_posterior_predictive(
1741
1708
# the trace object will either be a MultiTrace (and have _straces)...
1742
1709
if hasattr (_trace , "_straces" ):
1743
1710
chain_idx , point_idx = np .divmod (idx , len_trace )
1744
- param = (
1745
- cast (MultiTrace , _trace )
1746
- ._straces [chain_idx % nchain ]
1747
- .point (point_idx )
1748
- )
1711
+ param = cast (MultiTrace , _trace )._straces [chain_idx % nchain ].point (point_idx )
1749
1712
# ... or a PointList
1750
1713
else :
1751
1714
param = cast (PointList , _trace )[idx % len_trace ]
@@ -1783,9 +1746,9 @@ def sample_posterior_predictive_w(
1783
1746
Parameters
1784
1747
----------
1785
1748
traces : list or list of lists
1786
- List of traces generated from MCMC sampling, or a list of list
1787
- containing dicts from find_MAP() or points. The number of traces should
1788
- be equal to the number of weights.
1749
+ List of traces generated from MCMC sampling (xarray.Dataset, arviz.InferenceData, or
1750
+ MultiTrace), or a list of list containing dicts from find_MAP() or points. The number of
1751
+ traces should be equal to the number of weights.
1789
1752
samples : int, optional
1790
1753
Number of posterior predictive samples to generate. Defaults to the
1791
1754
length of the shorter trace in traces.
@@ -1811,6 +1774,17 @@ def sample_posterior_predictive_w(
1811
1774
"""
1812
1775
np .random .seed (random_seed )
1813
1776
1777
+ if isinstance (traces [0 ], InferenceData ):
1778
+ n_samples = [
1779
+ trace .posterior .sizes ["chain" ] * trace .posterior .sizes ["draw" ] for trace in traces
1780
+ ]
1781
+ traces = [dataset_to_point_dict (trace .posterior ) for trace in traces ]
1782
+ elif isinstance (traces [0 ], xarray .Dataset ):
1783
+ n_samples = [trace .sizes ["chain" ] * trace .sizes ["draw" ] for trace in traces ]
1784
+ traces = [dataset_to_point_dict (trace ) for trace in traces ]
1785
+ else :
1786
+ n_samples = [len (i ) * i .nchains for i in traces ]
1787
+
1814
1788
if models is None :
1815
1789
models = [modelcontext (models )] * len (traces )
1816
1790
@@ -1830,7 +1804,7 @@ def sample_posterior_predictive_w(
1830
1804
weights = np .asarray (weights )
1831
1805
p = weights / np .sum (weights )
1832
1806
1833
- min_tr = min ([ len ( i ) * i . nchains for i in traces ] )
1807
+ min_tr = min (n_samples )
1834
1808
1835
1809
n = (min_tr * p ).astype ("int" )
1836
1810
# ensure n sum up to min_tr
@@ -1933,18 +1907,15 @@ def sample_prior_predictive(
1933
1907
if vars is None and var_names is None :
1934
1908
prior_pred_vars = model .observed_RVs
1935
1909
prior_vars = (
1936
- get_default_varnames (model .unobserved_RVs , include_transformed = True )
1937
- + model .potentials
1910
+ get_default_varnames (model .unobserved_RVs , include_transformed = True ) + model .potentials
1938
1911
)
1939
1912
vars_ = [var .name for var in prior_vars + prior_pred_vars ]
1940
1913
vars = set (vars_ )
1941
1914
elif vars is None :
1942
1915
vars = var_names
1943
1916
vars_ = vars
1944
1917
elif vars is not None :
1945
- warnings .warn (
1946
- "vars argument is deprecated in favor of var_names." , DeprecationWarning
1947
- )
1918
+ warnings .warn ("vars argument is deprecated in favor of var_names." , DeprecationWarning )
1948
1919
vars_ = vars
1949
1920
else :
1950
1921
raise ValueError ("Cannot supply both vars and var_names arguments." )
@@ -1974,13 +1945,7 @@ def sample_prior_predictive(
1974
1945
1975
1946
1976
1947
def init_nuts (
1977
- init = "auto" ,
1978
- chains = 1 ,
1979
- n_init = 500000 ,
1980
- model = None ,
1981
- random_seed = None ,
1982
- progressbar = True ,
1983
- ** kwargs ,
1948
+ init = "auto" , chains = 1 , n_init = 500000 , model = None , random_seed = None , progressbar = True , ** kwargs ,
1984
1949
):
1985
1950
"""Set up the mass matrix initialization for NUTS.
1986
1951
@@ -2036,9 +2001,7 @@ def init_nuts(
2036
2001
if set (vars ) != set (model .vars ):
2037
2002
raise ValueError ("Must use init_nuts on all variables of a model." )
2038
2003
if not all_continuous (vars ):
2039
- raise ValueError (
2040
- "init_nuts can only be used for models with only " "continuous variables."
2041
- )
2004
+ raise ValueError ("init_nuts can only be used for models with only " "continuous variables." )
2042
2005
2043
2006
if not isinstance (init , str ):
2044
2007
raise TypeError ("init must be a string." )
@@ -2092,9 +2055,7 @@ def init_nuts(
2092
2055
mean = approx .bij .rmap (approx .mean .get_value ())
2093
2056
mean = model .dict_to_array (mean )
2094
2057
weight = 50
2095
- potential = quadpotential .QuadPotentialDiagAdaptGrad (
2096
- model .ndim , mean , cov , weight
2097
- )
2058
+ potential = quadpotential .QuadPotentialDiagAdaptGrad (model .ndim , mean , cov , weight )
2098
2059
elif init == "advi+adapt_diag" :
2099
2060
approx = pm .fit (
2100
2061
random_seed = random_seed ,
0 commit comments