Skip to content

Commit 8ca055e

Browse files
authored
add new data section (#771)
1 parent 321c57e commit 8ca055e

File tree

2 files changed

+1373
-358
lines changed

2 files changed

+1373
-358
lines changed

examples/howto/spline.ipynb

+1,191-345
Large diffs are not rendered by default.

examples/howto/spline.myst.md

+182-13
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ jupytext:
55
format_name: myst
66
format_version: 0.13
77
kernelspec:
8-
display_name: Python 3 (ipykernel)
8+
display_name: pymc-examples
99
language: python
10-
name: python3
10+
name: pymc-examples
1111
---
1212

1313
(spline)=
@@ -43,14 +43,15 @@ import numpy as np
4343
import pandas as pd
4444
import pymc as pm
4545
46-
from patsy import dmatrix
46+
from patsy import build_design_matrices, dmatrix
4747
```
4848

4949
```{code-cell} ipython3
5050
%matplotlib inline
5151
%config InlineBackend.figure_format = "retina"
5252
53-
RANDOM_SEED = 8927
53+
seed = sum(map(ord, "splines"))
54+
rng = np.random.default_rng(seed)
5455
az.style.use("arviz-darkgrid")
5556
```
5657

@@ -84,7 +85,12 @@ If we visualize the data, it is clear that there a lot of annual variation, but
8485

8586
```{code-cell} ipython3
8687
blossom_data.plot.scatter(
87-
"year", "doy", color="cornflowerblue", s=10, title="Cherry Blossom Data", ylabel="Days in bloom"
88+
"year",
89+
"doy",
90+
color="cornflowerblue",
91+
s=10,
92+
title="Cherry Blossom Data",
93+
ylabel="Days in bloom",
8894
);
8995
```
9096

@@ -106,18 +112,23 @@ The spline will have 15 *knots*, splitting the year into 16 sections (including
106112

107113
```{code-cell} ipython3
108114
num_knots = 15
109-
knot_list = np.quantile(blossom_data.year, np.linspace(0, 1, num_knots))
115+
knot_list = np.percentile(blossom_data.year, np.linspace(0, 100, num_knots + 2))[1:-1]
110116
knot_list
111117
```
112118

113119
Below is a plot of the locations of the knots over the data.
114120

115121
```{code-cell} ipython3
116122
blossom_data.plot.scatter(
117-
"year", "doy", color="cornflowerblue", s=10, title="Cherry Blossom Data", ylabel="Day of Year"
123+
"year",
124+
"doy",
125+
color="cornflowerblue",
126+
s=10,
127+
title="Cherry Blossom Data",
128+
ylabel="Day of Year",
118129
)
119130
for knot in knot_list:
120-
plt.gca().axvline(knot, color="grey", alpha=0.4);
131+
plt.gca().axvline(knot, color="grey", alpha=0.4)
121132
```
122133

123134
We can use `patsy` to create the matrix $B$ that will be the b-spline basis for the regression.
@@ -128,7 +139,7 @@ The degree is set to 3 to create a cubic b-spline.
128139
129140
B = dmatrix(
130141
"bs(year, knots=knots, degree=3, include_intercept=True) - 1",
131-
{"year": blossom_data.year.values, "knots": knot_list[1:-1]},
142+
{"year": blossom_data.year.values, "knots": knot_list},
132143
)
133144
B
134145
```
@@ -160,9 +171,14 @@ COORDS = {"splines": np.arange(B.shape[1])}
160171
with pm.Model(coords=COORDS) as spline_model:
161172
a = pm.Normal("a", 100, 5)
162173
w = pm.Normal("w", mu=0, sigma=3, size=B.shape[1], dims="splines")
163-
mu = pm.Deterministic("mu", a + pm.math.dot(np.asarray(B, order="F"), w.T))
174+
175+
mu = pm.Deterministic(
176+
"mu",
177+
a + pm.math.dot(np.asarray(B, order="F"), w.T),
178+
)
164179
sigma = pm.Exponential("sigma", 1)
165-
D = pm.Normal("D", mu=mu, sigma=sigma, observed=blossom_data.doy, dims="obs")
180+
181+
D = pm.Normal("D", mu=mu, sigma=sigma, observed=blossom_data.doy)
166182
```
167183

168184
```{code-cell} ipython3
@@ -172,7 +188,15 @@ pm.model_to_graphviz(spline_model)
172188
```{code-cell} ipython3
173189
with spline_model:
174190
idata = pm.sample_prior_predictive()
175-
idata.extend(pm.sample(draws=1000, tune=1000, random_seed=RANDOM_SEED, chains=4))
191+
idata.extend(
192+
pm.sample(
193+
nuts_sampler="nutpie",
194+
draws=1000,
195+
tune=1000,
196+
random_seed=rng,
197+
chains=4,
198+
)
199+
)
176200
pm.sample_posterior_predictive(idata, extend_inferencedata=True)
177201
```
178202

@@ -230,7 +254,7 @@ spline_df_merged.plot("year", "value", c="black", lw=2, ax=plt.gca())
230254
plt.legend(title="Spline Index", loc="lower center", fontsize=8, ncol=6)
231255
232256
for knot in knot_list:
233-
plt.gca().axvline(knot, color="grey", alpha=0.4);
257+
plt.gca().axvline(knot, color="grey", alpha=0.4)
234258
```
235259

236260
### Model predictions
@@ -267,6 +291,150 @@ plt.fill_between(
267291
);
268292
```
269293

294+
## Predicting on new data
295+
296+
Now imagine we got a new data set, with the same range of years as the original data set, and we want to get predictions for this new data set with our fitted model. We can do this with the classic PyMC workflow of `Data` containers and `set_data` method.
297+
298+
Before we get there though, let's note that we didn't say the new data set contains *new* years, i.e out-of-sample years. And that's on purpose, because splines can't extrapolate beyond the range of the data set used to fit the model -- hence their limitation for time series analysis. On data ranges previously seen though, that's no problem.
299+
300+
That precision out of the way, let's redefine our model, this time adding `Data` containers.
301+
302+
```{code-cell} ipython3
303+
COORDS = {"obs": blossom_data.index}
304+
```
305+
306+
```{code-cell} ipython3
307+
with pm.Model(coords=COORDS) as spline_model:
308+
year_data = pm.Data("year", blossom_data.year)
309+
doy = pm.Data("doy", blossom_data.doy)
310+
311+
# intercept
312+
a = pm.Normal("a", 100, 5)
313+
314+
# Create spline bases & coefficients
315+
## Store knots & design matrix for prediction
316+
spline_model.knots = np.percentile(year_data.eval(), np.linspace(0, 100, num_knots + 2))[1:-1]
317+
spline_model.dm = dmatrix(
318+
"bs(x, knots=spline_model.knots, degree=3, include_intercept=False) - 1",
319+
{"x": year_data.eval()},
320+
)
321+
spline_model.add_coords({"spline": np.arange(spline_model.dm.shape[1])})
322+
splines_basis = pm.Data("splines_basis", np.asarray(spline_model.dm), dims=("obs", "spline"))
323+
w = pm.Normal("w", mu=0, sigma=3, dims="spline")
324+
325+
mu = pm.Deterministic(
326+
"mu",
327+
a + pm.math.dot(splines_basis, w),
328+
)
329+
sigma = pm.Exponential("sigma", 1)
330+
331+
D = pm.Normal("D", mu=mu, sigma=sigma, observed=doy)
332+
```
333+
334+
```{code-cell} ipython3
335+
pm.model_to_graphviz(spline_model)
336+
```
337+
338+
```{code-cell} ipython3
339+
with spline_model:
340+
idata = pm.sample(
341+
nuts_sampler="nutpie",
342+
random_seed=rng,
343+
)
344+
idata.extend(pm.sample_posterior_predictive(idata, random_seed=rng))
345+
```
346+
347+
Now we can swap out the data and update the design matrix with the new data:
348+
349+
```{code-cell} ipython3
350+
new_blossom_data = (
351+
blossom_data.sample(50, random_state=rng).sort_values("year").reset_index(drop=True)
352+
)
353+
354+
# update design matrix with new data
355+
year_data_new = new_blossom_data.year.to_numpy()
356+
dm_new = build_design_matrices([spline_model.dm.design_info], {"x": year_data_new})[0]
357+
```
358+
359+
Use `set_data` to update the data in the model:
360+
361+
```{code-cell} ipython3
362+
with spline_model:
363+
pm.set_data(
364+
new_data={
365+
"year": year_data_new,
366+
"doy": new_blossom_data.doy.to_numpy(),
367+
"splines_basis": np.asarray(dm_new),
368+
},
369+
coords={
370+
"obs": new_blossom_data.index,
371+
},
372+
)
373+
```
374+
375+
And all that's left is to sample from the posterior predictive distribution:
376+
377+
```{code-cell} ipython3
378+
with spline_model:
379+
preds = pm.sample_posterior_predictive(idata, var_names=["mu"])
380+
```
381+
382+
Plot the predictions, to check if everything went well:
383+
384+
```{code-cell} ipython3
385+
_, axes = plt.subplots(1, 2, figsize=(16, 5), sharex=True, sharey=True)
386+
387+
blossom_data.plot.scatter(
388+
"year",
389+
"doy",
390+
color="cornflowerblue",
391+
s=10,
392+
title="Posterior predictions",
393+
ylabel="Days in bloom",
394+
ax=axes[0],
395+
)
396+
axes[0].vlines(
397+
spline_model.knots,
398+
blossom_data.doy.min(),
399+
blossom_data.doy.max(),
400+
color="grey",
401+
alpha=0.4,
402+
)
403+
axes[0].plot(
404+
blossom_data.year,
405+
idata.posterior["mu"].mean(("chain", "draw")),
406+
color="firebrick",
407+
)
408+
az.plot_hdi(blossom_data.year, idata.posterior["mu"], color="firebrick", ax=axes[0])
409+
410+
new_blossom_data.plot.scatter(
411+
"year",
412+
"doy",
413+
color="cornflowerblue",
414+
s=10,
415+
title="Predictions on new data",
416+
ylabel="Days in bloom",
417+
ax=axes[1],
418+
)
419+
axes[1].vlines(
420+
spline_model.knots,
421+
blossom_data.doy.min(),
422+
blossom_data.doy.max(),
423+
color="grey",
424+
alpha=0.4,
425+
)
426+
axes[1].plot(
427+
new_blossom_data.year,
428+
preds.posterior_predictive.mu.mean(("chain", "draw")),
429+
color="firebrick",
430+
)
431+
az.plot_hdi(new_blossom_data.year, preds.posterior_predictive.mu, color="firebrick", ax=axes[1]);
432+
```
433+
434+
And... voilà! Granted, this example is not the most realistic one, but we trust you to adapt it to your wildest dreams ;)
435+
436+
+++
437+
270438
## References
271439

272440
:::{bibliography}
@@ -280,6 +448,7 @@ plt.fill_between(
280448
- Created by Joshua Cook
281449
- Updated by Tyler James Burch
282450
- Updated by Chris Fonnesbeck
451+
- Predictions on new data added by Alex Andorra
283452

284453
+++
285454

0 commit comments

Comments
 (0)