Skip to content

Commit 78cbf30

Browse files
authored
update model_comp examples and ppc_w to work with idata (#4042)
* update model_comp examples and ppc_w to work with idata
* add test
* update release notes
* update release notes
* fix typos and style
1 parent 05d6655 commit 78cbf30

File tree

5 files changed

+497
-454
lines changed

5 files changed

+497
-454
lines changed

RELEASE-NOTES.md

+4
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,12 @@
55
### Maintenance
66
- Mentioned the way to do any random walk with `theano.tensor.cumsum()` in `GaussianRandomWalk` docstrings (see [#4048](https://github.com/pymc-devs/pymc3/pull/4048)).
77

8+
89
### Documentation
910

1011
### New features
12+
- `sample_posterior_predictive_w` can now feed on `xarray.Dataset` - e.g. from `InferenceData.posterior`. (see [#4042](https://github.com/pymc-devs/pymc3/pull/4042))
13+
1114

1215
## PyMC3 3.9.3 (11 August 2020)
1316

@@ -27,6 +30,7 @@
2730

2831
_NB: The `docs/*` folder is still removed from the tarball due to an upload size limit on PyPi._
2932

33+
3034
## PyMC3 3.9.2 (24 June 2020)
3135
### Maintenance
3236
- Warning added in GP module when `input_dim` is lower than the number of columns in `X` to compute the covariance function (see [#3974](https://github.com/pymc-devs/pymc3/pull/3974)).

docs/source/notebooks/model_averaging.ipynb

+234-88
Large diffs are not rendered by default.

docs/source/notebooks/model_comparison.ipynb

+217-287
Large diffs are not rendered by default.

pymc3/sampling.py

+35-74
Original file line numberDiff line numberDiff line change
@@ -226,9 +226,7 @@ def _print_step_hierarchy(s, level=0):
226226
else:
227227
varnames = ", ".join(
228228
[
229-
get_untransformed_name(v.name)
230-
if is_transformed_name(v.name)
231-
else v.name
229+
get_untransformed_name(v.name) if is_transformed_name(v.name) else v.name
232230
for v in s.vars
233231
]
234232
)
@@ -491,10 +489,7 @@ def sample(
491489
start = start_
492490
except (AttributeError, NotImplementedError, tg.NullTypeGradError):
493491
# gradient computation failed
494-
_log.info(
495-
"Initializing NUTS failed. "
496-
"Falling back to elementwise auto-assignment."
497-
)
492+
_log.info("Initializing NUTS failed. " "Falling back to elementwise auto-assignment.")
498493
_log.debug("Exception in init nuts", exec_info=True)
499494
step = assign_step_methods(model, step, step_kwargs=kwargs)
500495
else:
@@ -559,9 +554,7 @@ def sample(
559554
has_demcmc = np.any(
560555
[
561556
isinstance(m, DEMetropolis)
562-
for m in (
563-
step.methods if isinstance(step, CompoundStep) else [step]
564-
)
557+
for m in (step.methods if isinstance(step, CompoundStep) else [step])
565558
]
566559
)
567560
_log.info("Population sampling ({} chains)".format(chains))
@@ -625,9 +618,7 @@ def sample(
625618

626619
if compute_convergence_checks:
627620
if draws - tune < 100:
628-
warnings.warn(
629-
"The number of samples is too small to check convergence reliably."
630-
)
621+
warnings.warn("The number of samples is too small to check convergence reliably.")
631622
else:
632623
trace.report._run_convergence_checks(idata, model)
633624
trace.report._log_summary()
@@ -664,14 +655,7 @@ def _check_start_shape(model, start):
664655

665656

666657
def _sample_many(
667-
draws,
668-
chain: int,
669-
chains: int,
670-
start: list,
671-
random_seed: list,
672-
step,
673-
callback=None,
674-
**kwargs,
658+
draws, chain: int, chains: int, start: list, random_seed: list, step, callback=None, **kwargs,
675659
):
676660
"""Samples all chains sequentially.
677661
@@ -833,9 +817,7 @@ def _sample(
833817
"""
834818
skip_first = kwargs.get("skip_first", 0)
835819

836-
sampling = _iter_sample(
837-
draws, step, start, trace, chain, tune, model, random_seed, callback
838-
)
820+
sampling = _iter_sample(draws, step, start, trace, chain, tune, model, random_seed, callback)
839821
_pbar_data = {"chain": chain, "divergences": 0}
840822
_desc = "Sampling chain {chain:d}, {divergences:,d} divergences"
841823
if progressbar:
@@ -909,9 +891,7 @@ def iter_sample(
909891
for trace in iter_sample(500, step):
910892
...
911893
"""
912-
sampling = _iter_sample(
913-
draws, step, start, trace, chain, tune, model, random_seed, callback
914-
)
894+
sampling = _iter_sample(draws, step, start, trace, chain, tune, model, random_seed, callback)
915895
for i, (strace, _) in enumerate(sampling):
916896
yield MultiTrace([strace[: i + 1]])
917897

@@ -1012,8 +992,7 @@ def _iter_sample(
1012992
if callback is not None:
1013993
warns = getattr(step, "warnings", None)
1014994
callback(
1015-
trace=strace,
1016-
draw=Draw(chain, i == draws, i, i < tune, stats, point, warns),
995+
trace=strace, draw=Draw(chain, i == draws, i, i < tune, stats, point, warns),
1017996
)
1018997

1019998
yield strace, diverging
@@ -1067,9 +1046,7 @@ def __init__(self, steppers, parallelize, progressbar=True):
10671046
import multiprocessing
10681047

10691048
for c, stepper in (
1070-
enumerate(progress_bar(steppers))
1071-
if progressbar
1072-
else enumerate(steppers)
1049+
enumerate(progress_bar(steppers)) if progressbar else enumerate(steppers)
10731050
):
10741051
secondary_end, primary_end = multiprocessing.Pipe()
10751052
stepper_dumps = pickle.dumps(stepper, protocol=4)
@@ -1136,9 +1113,7 @@ def _run_secondary(c, stepper_dumps, secondary_end):
11361113
# but rather a CompoundStep. PopulationArrayStepShared.population
11371114
# has to be updated, therefore we identify the substeppers first.
11381115
population_steppers = []
1139-
for sm in (
1140-
stepper.methods if isinstance(stepper, CompoundStep) else [stepper]
1141-
):
1116+
for sm in stepper.methods if isinstance(stepper, CompoundStep) else [stepper]:
11421117
if isinstance(sm, arraystep.PopulationArrayStepShared):
11431118
population_steppers.append(sm)
11441119
while True:
@@ -1682,13 +1657,9 @@ def sample_posterior_predictive(
16821657
nchain = 1
16831658

16841659
if keep_size and samples is not None:
1685-
raise IncorrectArgumentsError(
1686-
"Should not specify both keep_size and samples arguments"
1687-
)
1660+
raise IncorrectArgumentsError("Should not specify both keep_size and samples arguments")
16881661
if keep_size and size is not None:
1689-
raise IncorrectArgumentsError(
1690-
"Should not specify both keep_size and size arguments"
1691-
)
1662+
raise IncorrectArgumentsError("Should not specify both keep_size and size arguments")
16921663

16931664
if samples is None:
16941665
if isinstance(_trace, MultiTrace):
@@ -1714,15 +1685,11 @@ def sample_posterior_predictive(
17141685

17151686
if var_names is not None:
17161687
if vars is not None:
1717-
raise IncorrectArgumentsError(
1718-
"Should not specify both vars and var_names arguments."
1719-
)
1688+
raise IncorrectArgumentsError("Should not specify both vars and var_names arguments.")
17201689
else:
17211690
vars = [model[x] for x in var_names]
17221691
elif vars is not None: # var_names is None, and vars is not.
1723-
warnings.warn(
1724-
"vars argument is deprecated in favor of var_names.", DeprecationWarning
1725-
)
1692+
warnings.warn("vars argument is deprecated in favor of var_names.", DeprecationWarning)
17261693
if vars is None:
17271694
vars = model.observed_RVs
17281695

@@ -1741,11 +1708,7 @@ def sample_posterior_predictive(
17411708
# the trace object will either be a MultiTrace (and have _straces)...
17421709
if hasattr(_trace, "_straces"):
17431710
chain_idx, point_idx = np.divmod(idx, len_trace)
1744-
param = (
1745-
cast(MultiTrace, _trace)
1746-
._straces[chain_idx % nchain]
1747-
.point(point_idx)
1748-
)
1711+
param = cast(MultiTrace, _trace)._straces[chain_idx % nchain].point(point_idx)
17491712
# ... or a PointList
17501713
else:
17511714
param = cast(PointList, _trace)[idx % len_trace]
@@ -1783,9 +1746,9 @@ def sample_posterior_predictive_w(
17831746
Parameters
17841747
----------
17851748
traces : list or list of lists
1786-
List of traces generated from MCMC sampling, or a list of list
1787-
containing dicts from find_MAP() or points. The number of traces should
1788-
be equal to the number of weights.
1749+
List of traces generated from MCMC sampling (xarray.Dataset, arviz.InferenceData, or
1750+
MultiTrace), or a list of list containing dicts from find_MAP() or points. The number of
1751+
traces should be equal to the number of weights.
17891752
samples : int, optional
17901753
Number of posterior predictive samples to generate. Defaults to the
17911754
length of the shorter trace in traces.
@@ -1811,6 +1774,17 @@ def sample_posterior_predictive_w(
18111774
"""
18121775
np.random.seed(random_seed)
18131776

1777+
if isinstance(traces[0], InferenceData):
1778+
n_samples = [
1779+
trace.posterior.sizes["chain"] * trace.posterior.sizes["draw"] for trace in traces
1780+
]
1781+
traces = [dataset_to_point_dict(trace.posterior) for trace in traces]
1782+
elif isinstance(traces[0], xarray.Dataset):
1783+
n_samples = [trace.sizes["chain"] * trace.sizes["draw"] for trace in traces]
1784+
traces = [dataset_to_point_dict(trace) for trace in traces]
1785+
else:
1786+
n_samples = [len(i) * i.nchains for i in traces]
1787+
18141788
if models is None:
18151789
models = [modelcontext(models)] * len(traces)
18161790

@@ -1830,7 +1804,7 @@ def sample_posterior_predictive_w(
18301804
weights = np.asarray(weights)
18311805
p = weights / np.sum(weights)
18321806

1833-
min_tr = min([len(i) * i.nchains for i in traces])
1807+
min_tr = min(n_samples)
18341808

18351809
n = (min_tr * p).astype("int")
18361810
# ensure n sum up to min_tr
@@ -1933,18 +1907,15 @@ def sample_prior_predictive(
19331907
if vars is None and var_names is None:
19341908
prior_pred_vars = model.observed_RVs
19351909
prior_vars = (
1936-
get_default_varnames(model.unobserved_RVs, include_transformed=True)
1937-
+ model.potentials
1910+
get_default_varnames(model.unobserved_RVs, include_transformed=True) + model.potentials
19381911
)
19391912
vars_ = [var.name for var in prior_vars + prior_pred_vars]
19401913
vars = set(vars_)
19411914
elif vars is None:
19421915
vars = var_names
19431916
vars_ = vars
19441917
elif vars is not None:
1945-
warnings.warn(
1946-
"vars argument is deprecated in favor of var_names.", DeprecationWarning
1947-
)
1918+
warnings.warn("vars argument is deprecated in favor of var_names.", DeprecationWarning)
19481919
vars_ = vars
19491920
else:
19501921
raise ValueError("Cannot supply both vars and var_names arguments.")
@@ -1974,13 +1945,7 @@ def sample_prior_predictive(
19741945

19751946

19761947
def init_nuts(
1977-
init="auto",
1978-
chains=1,
1979-
n_init=500000,
1980-
model=None,
1981-
random_seed=None,
1982-
progressbar=True,
1983-
**kwargs,
1948+
init="auto", chains=1, n_init=500000, model=None, random_seed=None, progressbar=True, **kwargs,
19841949
):
19851950
"""Set up the mass matrix initialization for NUTS.
19861951
@@ -2036,9 +2001,7 @@ def init_nuts(
20362001
if set(vars) != set(model.vars):
20372002
raise ValueError("Must use init_nuts on all variables of a model.")
20382003
if not all_continuous(vars):
2039-
raise ValueError(
2040-
"init_nuts can only be used for models with only " "continuous variables."
2041-
)
2004+
raise ValueError("init_nuts can only be used for models with only " "continuous variables.")
20422005

20432006
if not isinstance(init, str):
20442007
raise TypeError("init must be a string.")
@@ -2092,9 +2055,7 @@ def init_nuts(
20922055
mean = approx.bij.rmap(approx.mean.get_value())
20932056
mean = model.dict_to_array(mean)
20942057
weight = 50
2095-
potential = quadpotential.QuadPotentialDiagAdaptGrad(
2096-
model.ndim, mean, cov, weight
2097-
)
2058+
potential = quadpotential.QuadPotentialDiagAdaptGrad(model.ndim, mean, cov, weight)
20982059
elif init == "advi+adapt_diag":
20992060
approx = pm.fit(
21002061
random_seed=random_seed,

pymc3/tests/test_sampling.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -718,22 +718,24 @@ def test_sample_posterior_predictive_w(self):
718718
mu = pm.Normal("mu", mu=0, sigma=1)
719719
y = pm.Normal("y", mu=mu, sigma=1, observed=data0)
720720
trace_0 = pm.sample()
721+
idata_0 = az.from_pymc3(trace_0)
721722

722723
with pm.Model() as model_1:
723724
mu = pm.Normal("mu", mu=0, sigma=1, shape=len(data0))
724725
y = pm.Normal("y", mu=mu, sigma=1, observed=data0)
725726
trace_1 = pm.sample()
726-
727-
traces = [trace_0, trace_0]
728-
models = [model_0, model_0]
729-
ppc = pm.sample_posterior_predictive_w(traces, 100, models)
730-
assert ppc["y"].shape == (100, 500)
727+
idata_1 = az.from_pymc3(trace_1)
731728

732729
traces = [trace_0, trace_1]
730+
idatas = [idata_0, idata_1]
733731
models = [model_0, model_1]
732+
734733
ppc = pm.sample_posterior_predictive_w(traces, 100, models)
735734
assert ppc["y"].shape == (100, 500)
736735

736+
ppc = pm.sample_posterior_predictive_w(idatas, 100, models)
737+
assert ppc["y"].shape == (100, 500)
738+
737739

738740
@pytest.mark.parametrize(
739741
"method",

0 commit comments

Comments (0)