Skip to content

Commit ce0f674

Browse files
authored
update 9 notebooks to use az.extract + PyMC v5 upgrade (#522)
* update moderation analysis notebook * update GLM binomial regression * update simpsons paradox * update rolling regression * update glm robust with outlier detection * update truncated/censored notebook * update stochastic volatility * update dp mixture * update putting * update author blocks where I'd missed them before * add link to PR in author block * finish putting notebook * polish off rough edges * remove wrong/duplicate v5 update claim in author block
1 parent 081f2f9 commit ce0f674

18 files changed

+583
-3088
lines changed

examples/case_studies/moderation_analysis.ipynb

+26-180
Large diffs are not rendered by default.

examples/case_studies/moderation_analysis.myst.md

+6-5
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ jupytext:
77
kernelspec:
88
display_name: pymc_env
99
language: python
10-
name: pymc_env
10+
name: python3
1111
---
1212

1313
(moderation_analysis)=
@@ -71,7 +71,7 @@ def posterior_prediction_plot(result, x, moderator, m_quantiles, ax=None):
7171
if ax is None:
7272
fig, ax = plt.subplots(1, 1)
7373
74-
post = result.posterior.stack(sample=("chain", "draw"))
74+
post = az.extract(result)
7575
xi = xr.DataArray(np.linspace(np.min(x), np.max(x), 20), dims=["x_plot"])
7676
m_levels = result.constant_data["m"].quantile(m_quantiles).rename({"quantile": "m_level"})
7777
@@ -99,13 +99,13 @@ def posterior_prediction_plot(result, x, moderator, m_quantiles, ax=None):
9999
return ax
100100
101101
102-
def plot_moderation_effect(m, m_quantiles, ax=None):
102+
def plot_moderation_effect(result, m, m_quantiles, ax=None):
103103
"""Spotlight graph"""
104104
105105
if ax is None:
106106
fig, ax = plt.subplots(1, 1)
107107
108-
post = result.posterior.stack(sample=("chain", "draw"))
108+
post = az.extract(result)
109109
110110
# calculate 95% CI region and median
111111
xi = xr.DataArray(np.linspace(np.min(m), np.max(m), 20), dims=["x_plot"])
@@ -316,7 +316,7 @@ We can also visualise the moderation effect by plotting $\beta_1 + \beta_2 \cdot
316316

317317
```{code-cell} ipython3
318318
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
319-
plot_moderation_effect(m, m_quantiles, ax[0])
319+
plot_moderation_effect(result, m, m_quantiles, ax[0])
320320
az.plot_posterior(result, var_names="β2", ax=ax[1]);
321321
```
322322

@@ -362,6 +362,7 @@ But readers are strongly encouraged to read {cite:t}`mcclelland2017multicollinea
362362
- Authored by Benjamin T. Vincent in June 2021
363363
- Updated by Benjamin T. Vincent in March 2022
364364
- Updated by Benjamin T. Vincent in February 2023 to run on PyMC v5
365+
- Updated to use `az.extract` by [Benjamin T. Vincent](https://github.com/drbenvincent) in February 2023 ([pymc-examples#522](https://github.com/pymc-devs/pymc-examples/pull/522))
365366

366367
+++
367368

examples/case_studies/putting_workflow.ipynb

+111-873
Large diffs are not rendered by default.

examples/case_studies/putting_workflow.myst.md

+37-48
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ jupytext:
55
format_name: myst
66
format_version: 0.13
77
kernelspec:
8-
display_name: Python 3.10.6 ('pymc_env')
8+
display_name: pymc_env
99
language: python
1010
name: python3
1111
substitutions:
@@ -43,8 +43,6 @@ import scipy.stats as st
4343
import xarray as xr
4444
4545
from xarray_einstats.stats import XrContinuousRV
46-
47-
print(f"Running on PyMC3 v{pm.__version__}")
4846
```
4947

5048
```{code-cell} ipython3
@@ -172,30 +170,30 @@ We plot 50 posterior draws of $p(\text{success})$ along with the expected value.
172170
```{code-cell} ipython3
173171
# Draw posterior predictive samples
174172
with logit_model:
175-
# hard to plot more than 400 sensibly
176-
# we generate a posterior predictive sample for only 1 in every 10 draws
177-
logit_trace.extend(pm.sample_posterior_predictive(logit_trace.sel(draw=slice(None, None, 10))))
178-
logit_post = logit_trace.posterior
179-
logit_ppc = logit_trace.posterior_predictive
180-
const_data = logit_trace.constant_data
181-
logit_ppc_success = (logit_ppc["success"] / const_data["tries"]).stack(sample=("chain", "draw"))
173+
logit_trace.extend(pm.sample_posterior_predictive(logit_trace))
174+
175+
# hard to plot more than 400 sensibly
176+
logit_post = az.extract(logit_trace, num_samples=400)
177+
logit_ppc = az.extract(logit_trace, group="posterior_predictive", num_samples=400)
178+
const_data = logit_trace["constant_data"]
179+
180+
logit_ppc_success = logit_ppc["success"] / const_data["tries"]
182181
183182
# Plotting
184183
ax = plot_golf_data(golf_data)
185184
t_ary = np.linspace(CUP_RADIUS - BALL_RADIUS, golf_data.distance.max(), 200)
186185
t = xr.DataArray(t_ary, coords=[("distance", t_ary)])
187186
logit_post["expit"] = scipy.special.expit(logit_post["a"] * t + logit_post["b"])
188-
logit_post_subset = az.extract_dataset(logit_post, num_samples=50, rng=RANDOM_SEED)
189187
190188
ax.plot(
191189
t,
192-
logit_post_subset["expit"],
190+
logit_post["expit"].T,
193191
lw=1,
194192
color="C1",
195193
alpha=0.5,
196194
)
197195
198-
ax.plot(t, logit_post["expit"].mean(("chain", "draw")), color="C2")
196+
ax.plot(t, logit_post["expit"].mean(dim="sample"), color="C2")
199197
200198
ax.plot(golf_data.distance, logit_ppc_success, "k.", alpha=0.01)
201199
ax.set_title("Logit mean and posterior predictive");
@@ -204,9 +202,7 @@ ax.set_title("Logit mean and posterior predictive");
204202
The fit is ok, but not great! It is a good start for a baseline, and lets us answer curve-fitting type questions. We may not trust much extrapolation beyond the end of the data, especially given how the curve does not fit the last four values very well. For example, putts from 50 feet are expected to be made with probability:
205203

206204
```{code-cell} ipython3
207-
prob_at_50 = (
208-
scipy.special.expit(logit_post["a"] * 50 + logit_post["b"]).mean(("chain", "draw")).item()
209-
)
205+
prob_at_50 = scipy.special.expit(logit_post["a"] * 50 + logit_post["b"]).mean(dim="sample").item()
210206
print(f"{100 * prob_at_50:.5f}%")
211207
```
212208

@@ -220,13 +216,13 @@ this appeared here in using
220216

221217
```python
222218
# Right!
223-
scipy.special.expit(logit_trace.posterior["a"] * 50 + logit_trace.posterior["b"]).mean(('chain', 'draw'))
219+
scipy.special.expit(logit_post["a"] * 50 + logit_post["b"]).mean(dim="sample")
224220
```
225221
rather than
226222

227223
```python
228224
# Wrong!
229-
scipy.special.expit(logit_trace.posterior["a"].mean(('chain', 'draw')) * 50 + logit_trace.posterior["b"].mean(('chain', 'draw')))
225+
scipy.special.expit(logit_post["a"].mean(dim="sample") * 50 + logit_post["b"].mean(dim="sample"))
230226
```
231227

232228
to calculate our expectation at 50 feet.
@@ -333,7 +329,7 @@ This is a little funny! Most obviously, it should probably be not this common to
333329
with angle_model:
334330
angle_trace.extend(pm.sample(1000, tune=1000, target_accept=0.85))
335331
336-
angle_post = angle_trace.posterior
332+
angle_post = az.extract(angle_trace)
337333
```
338334

339335
```{code-cell} ipython3
@@ -343,21 +339,21 @@ angle_post["expit"] = forward_angle_model(angle_post["variance_of_shot"], t)
343339
344340
ax.plot(
345341
t,
346-
az.extract_dataset(angle_post, num_samples=50)["expit"],
342+
angle_post["expit"][:, ::100],
347343
lw=1,
348344
color="C1",
349345
alpha=0.1,
350346
)
351347
352348
ax.plot(
353349
t,
354-
angle_post["expit"].mean(("chain", "draw")),
350+
angle_post["expit"].mean(dim="sample"),
355351
label="Geometry-based model",
356352
)
357353
358354
ax.plot(
359355
t,
360-
logit_post["expit"].mean(("chain", "draw")),
356+
logit_post["expit"].mean(dim="sample"),
361357
label="Logit-binomial model",
362358
)
363359
ax.set_title("Comparing the fit of geometry-based and logit-binomial model")
@@ -374,21 +370,18 @@ print(f"{100 * angle_prob_at_50.mean().item():.2f}% vs {100 * prob_at_50:.5f}%")
374370
We can also recreate our prior predictive plot, giving us some confidence that the prior was not leading to unreasonable situations in the posterior distribution: the variance in angle is quite small!
375371

376372
```{code-cell} ipython3
377-
angle_of_shot = XrContinuousRV(
378-
st.norm, 0, az.extract_dataset(angle_post, num_samples=500)["variance_of_shot"]
379-
).rvs(
373+
angle_of_shot = XrContinuousRV(st.norm, 0, angle_post["variance_of_shot"]).rvs(
380374
random_state=RANDOM_SEED
381375
) # radians
382376
distance = 20 # feet
383377
384378
end_positions = xr.Dataset(
385-
{"endx": distance * np.cos(angle_of_shot), "endy": distance * np.sin(angle_of_shot)}
379+
{"endx": distance * np.cos(angle_of_shot.data), "endy": distance * np.sin(angle_of_shot.data)}
386380
)
387381
388382
fig, ax = plt.subplots()
389-
for sample in range(end_positions.dims["sample"]):
390-
end = end_positions.isel(sample=sample)
391-
ax.plot([0, end["endx"]], [0, end["endy"]], "k-o", lw=1, mfc="w", alpha=0.5)
383+
for x, y in zip(end_positions.endx, end_positions.endy):
384+
ax.plot([0, x], [0, y], "k-o", lw=1, mfc="w", alpha=0.5)
392385
ax.plot(0, 0, "go", label="Start", mfc="g", ms=20)
393386
ax.plot(distance, 0, "ro", label="Goal", mfc="r", ms=20)
394387
@@ -482,16 +475,15 @@ and convergence warnings have no other solution than using a different model tha
482475
ax = plot_golf_data(new_golf_data)
483476
plot_golf_data(golf_data, ax=ax, color="C1")
484477
485-
new_angle_post = new_angle_trace.posterior
486-
478+
new_angle_post = az.extract(new_angle_trace)
487479
ax.plot(
488480
t,
489-
forward_angle_model(angle_post["variance_of_shot"], t).mean(("chain", "draw")),
481+
forward_angle_model(angle_post["variance_of_shot"], t).mean(dim="sample"),
490482
label="Trained on original data",
491483
)
492484
ax.plot(
493485
t,
494-
forward_angle_model(new_angle_post["variance_of_shot"], t).mean(("chain", "draw")),
486+
forward_angle_model(new_angle_post["variance_of_shot"], t).mean(dim="sample"),
495487
label="Trained on new data",
496488
)
497489
ax.set_title("Retraining the model on new data")
@@ -576,11 +568,11 @@ def forward_distance_angle_model(variance_of_shot, variance_of_distance, t):
576568
577569
ax = plot_golf_data(new_golf_data)
578570
579-
distance_angle_post = distance_angle_trace.posterior
571+
distance_angle_post = az.extract(distance_angle_trace)
580572
581573
ax.plot(
582574
t,
583-
forward_angle_model(new_angle_post["variance_of_shot"], t).mean(("chain", "draw")),
575+
forward_angle_model(new_angle_post["variance_of_shot"], t).mean(dim="sample"),
584576
label="Just angle",
585577
)
586578
ax.plot(
@@ -589,7 +581,7 @@ ax.plot(
589581
distance_angle_post["variance_of_shot"],
590582
distance_angle_post["variance_of_distance"],
591583
t,
592-
).mean(("chain", "draw")),
584+
).mean(dim="sample"),
593585
label="Distance and angle",
594586
)
595587
@@ -675,15 +667,15 @@ with disp_distance_angle_model:
675667
```{code-cell} ipython3
676668
ax = plot_golf_data(new_golf_data, ax=None)
677669
678-
disp_distance_angle_post = disp_distance_angle_trace.posterior
670+
disp_distance_angle_post = az.extract(disp_distance_angle_trace)
679671
680672
ax.plot(
681673
t,
682674
forward_distance_angle_model(
683675
distance_angle_post["variance_of_shot"],
684676
distance_angle_post["variance_of_distance"],
685677
t,
686-
).mean(("chain", "draw")),
678+
).mean(dim="sample"),
687679
label="Distance and angle",
688680
)
689681
ax.plot(
@@ -692,7 +684,7 @@ ax.plot(
692684
disp_distance_angle_post["variance_of_shot"],
693685
disp_distance_angle_post["variance_of_distance"],
694686
t,
695-
).mean(("chain", "draw")),
687+
).mean(dim="sample"),
696688
label="Dispersed model",
697689
)
698690
ax.set_title("Comparing dispersed model with binomial distance/angle model")
@@ -703,14 +695,14 @@ This new model does better between 10 and 30 feet, as we can also see using the
703695

704696
```{code-cell} ipython3
705697
const_data = distance_angle_trace.constant_data
706-
old_pp = distance_angle_trace.posterior_predictive
698+
old_pp = az.extract(distance_angle_trace, group="posterior_predictive")
707699
old_residuals = 100 * ((const_data["successes"] - old_pp["success"]) / const_data["tries"]).mean(
708-
("chain", "draw")
700+
dim="sample"
709701
)
710702
711-
pp = disp_distance_angle_trace.posterior_predictive
703+
pp = az.extract(disp_distance_angle_trace, group="posterior_predictive")
712704
residuals = 100 * (const_data["successes"] / const_data["tries"] - pp["p_success"]).mean(
713-
("chain", "draw")
705+
dim="sample"
714706
)
715707
716708
fig, ax = plt.subplots()
@@ -789,7 +781,7 @@ Note that this is again something we might check experimentally. In particular,
789781
def expected_num_putts(trace, distance_to_hole, trials=100_000):
790782
distance_to_hole = distance_to_hole * np.ones(trials)
791783
792-
combined_trace = trace.posterior.stack(sample=("chain", "draw"))
784+
combined_trace = az.extract(trace)
793785
794786
n_samples = combined_trace.dims["sample"]
795787
@@ -844,6 +836,7 @@ fig.suptitle("Simulated number of putts from\na few distances");
844836
* Adapted by Colin Carroll from the [Model building and expansion for golf putting] case study in the Stan documentation ([pymc#3666](https://github.com/pymc-devs/pymc/pull/3666))
845837
* Updated by Marco Gorelli ([pymc-examples#39](https://github.com/pymc-devs/pymc-examples/pull/39))
846838
* Updated by Oriol Abril-Pla to use PyMC v4 and xarray-einstats
839+
* Updated by [Benjamin T. Vincent](https://github.com/drbenvincent) to use `az.extract` in February 2023 ([pymc-examples#522](https://github.com/pymc-devs/pymc-examples/pull/522))
847840

848841
+++
849842

@@ -864,7 +857,3 @@ fig.suptitle("Simulated number of putts from\na few distances");
864857

865858
:::{include} ../page_footer.md
866859
:::
867-
868-
```{code-cell} ipython3
869-
870-
```

examples/case_studies/stochastic_volatility.ipynb

+32-191
Large diffs are not rendered by default.

examples/case_studies/stochastic_volatility.myst.md

+10-9
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ jupytext:
55
format_name: myst
66
format_version: 0.13
77
kernelspec:
8-
display_name: Python 3 (ipykernel)
8+
display_name: pymc_env
99
language: python
1010
name: python3
1111
---
@@ -106,7 +106,7 @@ pm.model_to_graphviz(stochastic_vol_model)
106106
with stochastic_vol_model:
107107
idata = pm.sample_prior_predictive(500, random_seed=rng)
108108
109-
prior_predictive = idata.prior_predictive.stack(pooled_chain=("chain", "draw"))
109+
prior_predictive = az.extract(idata, group="prior_predictive")
110110
```
111111

112112
We plot and inspect the prior predictive. This is *many* orders of magnitude larger than the actual returns we observed. In fact, I cherry-picked a few draws to keep the plot from looking silly. This may suggest changing our priors, because a return that our model considers plausible would violate all sorts of constraints by a huge margin: the total value of all goods and services the world produces is ~$\$10^{14}$, so we might reasonably *not* expect any returns above that magnitude.
@@ -117,7 +117,7 @@ That said, we get somewhat reasonable results fitting this model anyways, and it
117117
fig, ax = plt.subplots(figsize=(14, 4))
118118
returns["change"].plot(ax=ax, lw=1, color="black")
119119
ax.plot(
120-
prior_predictive["returns"].isel(pooled_chain=slice(4, 6, None)),
120+
prior_predictive["returns"][:, 0::10],
121121
"g",
122122
alpha=0.5,
123123
lw=1,
@@ -138,17 +138,17 @@ Once we are happy with our model, we can sample from the posterior. This is a so
138138

139139
```{code-cell} ipython3
140140
with stochastic_vol_model:
141-
idata.extend(pm.sample(2000, tune=2000, random_seed=rng))
141+
idata.extend(pm.sample(random_seed=rng))
142142
143-
posterior = idata.posterior.stack(pooled_chain=("chain", "draw"))
143+
posterior = az.extract(idata)
144144
posterior["exp_volatility"] = np.exp(posterior["volatility"])
145145
```
146146

147147
```{code-cell} ipython3
148148
with stochastic_vol_model:
149149
idata.extend(pm.sample_posterior_predictive(idata, random_seed=rng))
150150
151-
posterior_predictive = idata.posterior_predictive.stack(pooled_chain=("chain", "draw"))
151+
posterior_predictive = az.extract(idata, group="posterior_predictive")
152152
```
153153

154154
Note that the `step_size` parameter does not look perfect: the different chains look somewhat different. This again indicates some weakness in our model: it may make sense to allow the step_size to change over time, especially over this 11-year time span.
@@ -162,7 +162,7 @@ Now we can look at our posterior estimates of the volatility in S&P 500 returns
162162
```{code-cell} ipython3
163163
fig, ax = plt.subplots(figsize=(14, 4))
164164
165-
y_vals = posterior["exp_volatility"].isel(pooled_chain=slice(None, None, 5))
165+
y_vals = posterior["exp_volatility"]
166166
x_vals = y_vals.time.astype(np.datetime64)
167167
168168
plt.plot(x_vals, y_vals, "k", alpha=0.002)
@@ -177,9 +177,9 @@ Finally, we can use the posterior predictive distribution to see the how the lea
177177
fig, axes = plt.subplots(nrows=2, figsize=(14, 7), sharex=True)
178178
returns["change"].plot(ax=axes[0], color="black")
179179
180-
axes[1].plot(posterior["exp_volatility"].isel(pooled_chain=slice(None, None, 100)), "r", alpha=0.5)
180+
axes[1].plot(posterior["exp_volatility"], "r", alpha=0.5)
181181
axes[0].plot(
182-
posterior_predictive["returns"].isel(pooled_chain=slice(None, None, 100)),
182+
posterior_predictive["returns"],
183183
"g",
184184
alpha=0.5,
185185
zorder=-10,
@@ -215,6 +215,7 @@ axes[1].set_title("Posterior volatility");
215215
* Updated by Michael Osthege on June 1, 2022 ([pymc-examples#343](https://github.com/pymc-devs/pymc-examples/pull/343))
216216
* Updated by Christopher Krapu on June 17, 2022 ([pymc-examples#378](https://github.com/pymc-devs/pymc-examples/pull/378))
217217
* Updated for compatibility with PyMC v5 by Beryl Kanali and Sangam Swadik on Jan 22, 2023 ([pymc-examples#517](https://github.com/pymc-devs/pymc-examples/pull/517))
218+
* Updated by [Benjamin T. Vincent](https://github.com/drbenvincent) to use `az.extract` ([pymc-examples#522](https://github.com/pymc-devs/pymc-examples/pull/522))
218219

219220
+++
220221

examples/generalized_linear_models/GLM-binomial-regression.ipynb

+19-131
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)