
Commit 0e74cef

Update notebook Dirichlet mixture of multinomials to PyMC 5 (#581)

* Update notebook to PyMC v5; also remove unused import
* Make randomness deterministic
* Update notebook to match new simulated data. Since the simulated data has changed, the numerical issues we faced are slightly different. The notebook has been modified to reflect this:
  - No need to do Metropolis sampling for the multinomial model
  - Worse numerical issues for the explicit DM model, addressed by increasing target_accept and modifying the discussion about poor rhat values
  - Minor changes to match new data
  - Also refer to species by name instead of index
* Minor changes to notebook text
* Change target_accept from 0.95 to 0.9; this gets rid of most divergences and gives much better r_hat statistics
* Add myself to author list
* Minor change
* Fix minor text mismatches after updates
* Update unrelated files

1 parent 5177871 · commit 0e74cef

File tree

4 files changed: +684 −525 lines

examples/case_studies/nyc_bym.ipynb

-1

````diff
@@ -634,7 +634,6 @@
 "outputs": [],
 "source": [
 "with pm.Model(coords=coords) as BYM_model:\n",
-"\n",
 " # intercept\n",
 " beta0 = pm.Normal(\"beta0\", 0, 1)\n",
 "\n",
````

examples/case_studies/nyc_bym.myst.md

-1

````diff
@@ -319,7 +319,6 @@ Lastly, we'll use a Poisson outcome distribution. The number of traffic accident
 
 ```{code-cell} ipython3
 with pm.Model(coords=coords) as BYM_model:
-
     # intercept
     beta0 = pm.Normal("beta0", 0, 1)
````

examples/mixture_models/dirichlet_mixture_of_multinomials.ipynb

+648-452
Large diffs are not rendered by default.

examples/mixture_models/dirichlet_mixture_of_multinomials.myst.md

+36-71
````diff
@@ -47,12 +47,10 @@ This notebook will demonstrate the performance benefits that come from taking th
 import arviz as az
 import matplotlib.pyplot as plt
 import numpy as np
-import pymc3 as pm
+import pymc as pm
 import scipy as sp
-import scipy.stats
-import seaborn as sns
 
-print(f"Running on PyMC3 v{pm.__version__}")
+print(f"Running on PyMC v{pm.__version__}")
 ```
 
 ```{code-cell} ipython3
````
````diff
@@ -121,8 +119,10 @@ k = len(trees)
 n = len(forests)
 total_count = 50
 
-true_p = sp.stats.dirichlet(true_conc * true_frac).rvs(size=n)
-observed_counts = np.vstack([sp.stats.multinomial(n=total_count, p=p_i).rvs() for p_i in true_p])
+true_p = sp.stats.dirichlet(true_conc * true_frac).rvs(size=n, random_state=rng)
+observed_counts = np.vstack(
+    [sp.stats.multinomial(n=total_count, p=p_i).rvs(random_state=rng) for p_i in true_p]
+)
 
 observed_counts
 ```
````
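The reseeded simulation in this hunk can be sketched end to end. The values of `true_frac`, `true_conc`, `n`, and `total_count` below are illustrative assumptions standing in for the notebook's setup, not its actual data:

```python
import numpy as np
import scipy.stats as st

# Illustrative stand-ins for the notebook's variables (assumed values):
# 5 tree species, 10 forests, 50 trees counted per forest.
rng = np.random.default_rng(42)
true_frac = np.array([0.4, 0.3, 0.15, 0.1, 0.05])
true_conc = 6.0
n, total_count = 10, 50

# Passing random_state=rng (as the commit adds) makes every draw reproducible.
true_p = st.dirichlet(true_conc * true_frac).rvs(size=n, random_state=rng)
observed_counts = np.vstack(
    [st.multinomial(n=total_count, p=p_i).rvs(random_state=rng) for p_i in true_p]
)
```

Rerunning this with the same seed yields the same `observed_counts`, which is the point of the "make randomness deterministic" change in the commit message.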
````diff
@@ -149,46 +149,16 @@ with pm.Model(coords=coords) as model_multinomial:
 pm.model_to_graphviz(model_multinomial)
 ```
 
-Interestingly, NUTS frequently runs into numerical problems on this model, perhaps an example of the
-["Folk Theorem of Statistical Computing"](https://statmodeling.stat.columbia.edu/2008/05/13/the_folk_theore/).
-
-Because of a couple of identities of the multinomial distribution,
-we could reparameterize this model in a number of ways—we
-would obtain equivalent models by exploding our $n$ observations
-of $\mathrm{total\_count}$ items into $(n \times \mathrm{total\_count})$
-independent categorical trials, or collapsing them down into
-one Multinomial draw with $(n \times \mathrm{total\_count})$ items.
-(Importantly, this is _not_ true for the DM distribution.)
-
-Rather than _actually_ fixing our problem through reparameterization,
-here we'll instead switch to the Metropolis step method,
-which ignores some of the geometric pathologies of our naïve model.
-
-**Important**: switching to Metropolis does not not _fix_ our model's issues, rather it _sweeps them under the rug_.
-In fact, if you try running this model with NUTS (PyMC3's default step method), it will break loudly during sampling.
-When that happens, this should be a **red alert** that there is something wrong in our model.
-
-You'll also notice below that we have to increase considerably the number of draws we take from the posterior;
-this is because Metropolis is much less efficient at
-exploring the posterior than NUTS.
-
 ```{code-cell} ipython3
 with model_multinomial:
-    trace_multinomial = pm.sample(
-        draws=5000, chains=4, step=pm.Metropolis(), return_inferencedata=True
-    )
+    trace_multinomial = pm.sample(chains=4)
 ```
 
-Let's ignore the warning about inefficient sampling for now.
-
 ```{code-cell} ipython3
 az.plot_trace(data=trace_multinomial, var_names=["frac"]);
 ```
 
-The trace plots look fairly good;
-visually, each parameter appears to be moving around the posterior well,
-although some sharp parts of the KDE plot suggests that
-sampling sometimes gets stuck in one place for a few steps.
+The trace plots look fairly good; visually, each parameter appears to be moving around the posterior well.
 
 ```{code-cell} ipython3
 summary_multinomial = az.summary(trace_multinomial, var_names=["frac"])
````
````diff
@@ -201,12 +171,9 @@ summary_multinomial
 ```
 
 Likewise, diagnostics in the parameter summary table all look fine.
-Here I've added a column estimating the effective sample size per
+Here we've added a column estimating the effective sample size per
 second of sampling.
 
-Nonetheless, the fact that we were unable to use NUTS is still a red flag, and we should be
-very cautious in using these results.
-
 ```{code-cell} ipython3
 az.plot_forest(trace_multinomial, var_names=["frac"])
 for j, (y_tick, frac_j) in enumerate(zip(plt.gca().get_yticks(), reversed(true_frac))):
````
````diff
@@ -219,7 +186,7 @@ frequencies are for each species (dashed lines), we can comment on the accuracy
 of our inferences.
 And now the issues with our model become apparent;
 notice that the 94% HDIs _don't include the true values_ for
-tree species 0, 2, 3.
+tree species 0, 1, 3.
 We might have seen _one_ HDI miss, but _three_???
 
 ...what's going on?
````
````diff
@@ -228,9 +195,7 @@ Let's troubleshoot this model using a posterior-predictive check, comparing our
 
 ```{code-cell} ipython3
 with model_multinomial:
-    pp_samples = az.from_pymc3(
-        posterior_predictive=pm.fast_sample_posterior_predictive(trace=trace_multinomial)
-    )
+    pp_samples = pm.sample_posterior_predictive(trace=trace_multinomial)
 
 # Concatenate with InferenceData object
 trace_multinomial.extend(pp_samples)
````
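The call swapped in here, `pm.sample_posterior_predictive`, draws replicated datasets conditioned on posterior parameter draws. The same idea can be sketched library-agnostically with NumPy; the `post_frac` draws below are a hypothetical stand-in for the actual posterior trace:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in for posterior draws of `frac` (the notebook instead
# takes these from trace_multinomial): 1000 draws over 5 species.
post_frac = rng.dirichlet(np.full(5, 50.0), size=1000)

# Posterior predictive: for each posterior draw, simulate a replicated
# dataset with the same shape as the observations (10 forests x 50 trees).
n_forests, total_count = 10, 50
pp_counts = np.array(
    [rng.multinomial(total_count, f, size=n_forests) for f in post_frac]
)
```

Comparing such replicated datasets against the observed counts is exactly what the histograms in the next hunks do.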
````diff
@@ -283,12 +248,12 @@ axs[-1].set_ylim(0, 0.6);
 Here we're plotting histograms of the predicted counts
 against the observed counts for each species.
 
-_(Notice that the y-axis isn't full height and clips the distributions for species-4 in purple.)_
+_(Notice that the y-axis isn't full height and clips the distributions for species ``mahogany`` in purple.)_
 
 And now we can start to see why our posterior HDI deviates from the _true_ parameters for three of five species (vertical lines).
 See that for all of the species the observed counts are frequently quite far from the predictions
 conditioned on the posterior distribution.
-This is particularly obvious for (e.g.) species-2 where we have one observation of more than 20
+This is particularly obvious for (e.g.) ``oak`` where we have one observation of more than 30
 trees of this species, despite the posterior predicitive mass being concentrated far below that.
 
 This is overdispersion at work, and a clear sign that we need to adjust our model to accommodate it.
````
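The overdispersion described in this hunk can be reproduced numerically outside PyMC: redrawing the probability vector for every multinomial draw inflates the count variance relative to a fixed-`p` multinomial. The `frac` and `conc` values below are assumptions for illustration, not the notebook's data:

```python
import numpy as np

rng = np.random.default_rng(0)
N, reps = 50, 20_000
frac = np.array([0.4, 0.3, 0.15, 0.1, 0.05])  # assumed species fractions
conc = 6.0  # assumed Dirichlet concentration

# Multinomial: the same probability vector `frac` for every draw.
multi = rng.multinomial(N, frac, size=reps)

# Dirichlet-multinomial: redraw p ~ Dirichlet(conc * frac) before each
# multinomial draw (the compound / marginal construction).
p = rng.dirichlet(conc * frac, size=reps)
dm = np.vstack([rng.multinomial(N, p_i) for p_i in p])

# Both have mean N * frac, but the DM counts are overdispersed: analytically,
# Var_DM = Var_multinomial * (N + conc) / (1 + conc) per category.
```

With these assumed values the variance inflation factor is (50 + 6) / (1 + 6) = 8, which is why a multinomial fit to DM-like data ends up with HDIs that miss the truth.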
````diff
@@ -330,13 +295,14 @@ accounting for overdispersion of counts relative to the simple multinomial model
 
 ```{code-cell} ipython3
 with model_dm_explicit:
-    trace_dm_explicit = pm.sample(chains=4, return_inferencedata=True)
+    trace_dm_explicit = pm.sample(chains=4, target_accept=0.9)
 ```
 
-We got a warning, although we'll ignore it for now.
-More interesting is how much longer it took to sample this model than the
-first.
-This may be because our model has an additional ~$(n \times k)$ parameters,
+Here we had to increase ``target_accept`` from 0.8 to 0.9 to not get drowned in divergences.
+
+We also got a warning about the ``rhat`` statistic, although we'll ignore it for now.
+More interesting is how much longer it took to sample this model than the first.
+This is partly because our model has an additional ~$(n \times k)$ parameters,
 but it seems like there are other geometric challenges for NUTS as well.
 
 We'll see if we can fix these in the next model, but for now let's take a look at the traces.
````
````diff
@@ -345,7 +311,7 @@ We'll see if we can fix these in the next model, but for now let's take a look a
 az.plot_trace(data=trace_dm_explicit, var_names=["frac", "conc"]);
 ```
 
-Obviously some sampling issues, but it's hard to see where divergences are occurring.
+The divergences seem to occur when the estimated fraction of the rare species (``mahogany``) is very close to zero.
 
 ```{code-cell} ipython3
 az.plot_forest(trace_dm_explicit, var_names=["frac"])
````
````diff
@@ -372,8 +338,8 @@ summary_dm_explicit
 ```
 
 This is great, but _we can do better_.
-The larger $\hat{R}$ value for $\mathrm{frac}_4$ is mildly concerning, and it's surprising
-that our $\mathrm{ESS} \; \mathrm{sec}^{-1}$ is relatively small.
+The slightly too large $\hat{R}$ value for ``frac[mahogany]`` is a bit concerning, and it's surprising
+that our $\mathrm{ESS} \; \mathrm{sec}^{-1}$ is quite small.
 
 +++
````
````diff
@@ -383,7 +349,7 @@ that our $\mathrm{ESS} \; \mathrm{sec}^{-1}$ is relatively small.
 
 Happily, the Dirichlet distribution is conjugate to the multinomial
 and therefore there's a convenient, closed-form for the marginalized
-distribution, i.e. the Dirichlet-multinomial distribution, which was added to PyMC3 in [3.11.0](https://github.com/pymc-devs/pymc3/releases/tag/v3.11.0).
+distribution, i.e. the Dirichlet-multinomial distribution, which was added to PyMC in [3.11.0](https://github.com/pymc-devs/pymc3/releases/tag/v3.11.0).
 
 Let's take advantage of this, marginalizing out the explicit latent parameter, $p_i$,
 replacing the combination of this node and the multinomial
````
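The conjugacy mentioned in this hunk means the marginal pmf has a closed form. A minimal sketch of that pmf using `scipy.special.gammaln` follows; this is an independent illustration, not PyMC's implementation:

```python
import numpy as np
from scipy.special import gammaln

def dm_logpmf(x, alpha):
    """Log-pmf of the Dirichlet-multinomial: a multinomial whose probability
    vector p has been marginalized out against a Dirichlet(alpha) prior."""
    x, alpha = np.asarray(x, float), np.asarray(alpha, float)
    n, a0 = x.sum(), alpha.sum()
    return (
        gammaln(a0) + gammaln(n + 1) - gammaln(n + a0)
        + np.sum(gammaln(x + alpha) - gammaln(alpha) - gammaln(x + 1))
    )

# Sanity check: with alpha = (1, 1) and n = 3 trials, the DM reduces to a
# beta-binomial(1, 1), so every split of the 3 items between the 2
# categories is equally likely (probability 1/4).
probs = [np.exp(dm_logpmf([k, 3 - k], [1.0, 1.0])) for k in range(4)]
```

Because this density integrates out $p_i$ analytically, the sampler no longer has to explore the $n \times k$ latent probability vectors, which is why the marginalized model in the next hunk samples so much faster.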
````diff
@@ -405,7 +371,7 @@ nodes together into a single DM node.
 
 ```{code-cell} ipython3
 with model_dm_marginalized:
-    trace_dm_marginalized = pm.sample(chains=4, return_inferencedata=True)
+    trace_dm_marginalized = pm.sample(chains=4)
 ```
 
 It samples much more quickly and without any of the warnings from before!
````
````diff
@@ -437,9 +403,7 @@ Posterior predictive checks to the rescue (again)!
 
 ```{code-cell} ipython3
 with model_dm_marginalized:
-    pp_samples = az.from_pymc3(
-        posterior_predictive=pm.fast_sample_posterior_predictive(trace_dm_marginalized)
-    )
+    pp_samples = pm.sample_posterior_predictive(trace_dm_marginalized)
 
 # Concatenate with InferenceData object
 trace_dm_marginalized.extend(pp_samples)
````
````diff
@@ -494,7 +458,7 @@ axs[-1, 0].set_ylim(0, 0.6)
 ax.set_ylim(0, 0.6);
 ```
 
-_(Notice, again, that the y-axis isn't full height, and clips the distributions for species-4 in purple.)_
+_(Notice, again, that the y-axis isn't full height, and clips the distributions for ``mahogany`` in purple.)_
 
 Compared to the multinomial (plots on the right), PPCs for the DM (left) show that the observed data is
 an entirely reasonable realization of our model.
````
````diff
@@ -512,15 +476,19 @@ We'll use leave-one-out cross validation to compare the
 out-of-sample predictive ability of the two.
 
 ```{code-cell} ipython3
+with model_multinomial:
+    pm.compute_log_likelihood(trace_multinomial)
+
+with model_dm_marginalized:
+    pm.compute_log_likelihood(trace_dm_marginalized)
+
 az.compare(
     {"multinomial": trace_multinomial, "dirichlet_multinomial": trace_dm_marginalized}, ic="loo"
 )
 ```
 
-Unsurprisingly, the DM outclasses the multinomial by a mile, assigning a weight of nearly
-100% to the over-dispersed model.
-We can conclude that between the two, the DM should be greatly favored for prediction,
-parameter inference, etc.
+Unsurprisingly, the DM outclasses the multinomial by a mile, assigning a weight of 100% to the over-dispersed model.
+While the ``warning=True`` flag for the multinomial distribution indicates that the numerical value cannot be fully trusted, the large difference in ``elpd_loo`` is further confirmation that between the two, the DM should be greatly favored for prediction, parameter inference, etc.
 
 +++
````
````diff
@@ -549,19 +517,16 @@ In that case, swapping the vanilla Dirichlet distribution for something fancier
 ## Authors
 * Authored by [Byron J. Smith](https://github.com/bsmith89) on Jan, 2021 ([pymc-examples#18](https://github.com/pymc-devs/pymc-examples/pull/18))
 * Updated by Abhipsha Das and Oriol Abril-Pla on August, 2021 ([pymc-examples#212](https://github.com/pymc-devs/pymc-examples/pull/212))
+* Updated to PyMC v5 by Erik Werner on Oct, 2023 ([pymc-examples#581](https://github.com/pymc-devs/pymc-examples/pull/581))
 
 +++
 
 ## Watermark
 
 ```{code-cell} ipython3
 %load_ext watermark
-%watermark -n -u -v -iv -w -p theano,xarray
+%watermark -n -u -v -iv -w -p pytensor,xarray
 ```
 
 :::{include} page_footer.md
 :::
-
-```{code-cell} ipython3
-
-```
````
