@@ -5,9 +5,9 @@ jupytext:
5
5
format_name : myst
6
6
format_version : 0.13
7
7
kernelspec :
8
- display_name : Python 3 (ipykernel)
8
+ display_name : pymc-examples
9
9
language : python
10
- name : python3
10
+ name : pymc-examples
11
11
---
12
12
13
13
(spline)=
@@ -43,14 +43,15 @@ import numpy as np
43
43
import pandas as pd
44
44
import pymc as pm
45
45
46
- from patsy import dmatrix
46
+ from patsy import build_design_matrices, dmatrix
47
47
```
48
48
49
49
``` {code-cell} ipython3
50
50
%matplotlib inline
51
51
%config InlineBackend.figure_format = "retina"
52
52
53
- RANDOM_SEED = 8927
53
+ seed = sum(map(ord, "splines"))
54
+ rng = np.random.default_rng(seed)
54
55
az.style.use("arviz-darkgrid")
55
56
```
56
57
@@ -84,7 +85,12 @@ If we visualize the data, it is clear that there a lot of annual variation, but
84
85
85
86
``` {code-cell} ipython3
86
87
blossom_data.plot.scatter(
87
- "year", "doy", color="cornflowerblue", s=10, title="Cherry Blossom Data", ylabel="Days in bloom"
88
+ "year",
89
+ "doy",
90
+ color="cornflowerblue",
91
+ s=10,
92
+ title="Cherry Blossom Data",
93
+ ylabel="Days in bloom",
88
94
);
89
95
```
90
96
@@ -106,18 +112,23 @@ The spline will have 15 *knots*, splitting the year into 16 sections (including
106
112
107
113
``` {code-cell} ipython3
108
114
num_knots = 15
109
- knot_list = np.quantile (blossom_data.year, np.linspace(0, 1 , num_knots))
115
+ knot_list = np.percentile (blossom_data.year, np.linspace(0, 100 , num_knots + 2))[1:-1]
110
116
knot_list
111
117
```
112
118
113
119
Below is a plot of the locations of the knots over the data.
114
120
115
121
``` {code-cell} ipython3
116
122
blossom_data.plot.scatter(
117
- "year", "doy", color="cornflowerblue", s=10, title="Cherry Blossom Data", ylabel="Day of Year"
123
+ "year",
124
+ "doy",
125
+ color="cornflowerblue",
126
+ s=10,
127
+ title="Cherry Blossom Data",
128
+ ylabel="Day of Year",
118
129
)
119
130
for knot in knot_list:
120
- plt.gca().axvline(knot, color="grey", alpha=0.4);
131
+ plt.gca().axvline(knot, color="grey", alpha=0.4)
121
132
```
122
133
123
134
We can use ` patsy ` to create the matrix $B$ that will be the b-spline basis for the regression.
@@ -128,7 +139,7 @@ The degree is set to 3 to create a cubic b-spline.
128
139
129
140
B = dmatrix(
130
141
"bs(year, knots=knots, degree=3, include_intercept=True) - 1",
131
- {"year": blossom_data.year.values, "knots": knot_list[1:-1] },
142
+ {"year": blossom_data.year.values, "knots": knot_list},
132
143
)
133
144
B
134
145
```
@@ -160,9 +171,14 @@ COORDS = {"splines": np.arange(B.shape[1])}
160
171
with pm.Model(coords=COORDS) as spline_model:
161
172
a = pm.Normal("a", 100, 5)
162
173
w = pm.Normal("w", mu=0, sigma=3, size=B.shape[1], dims="splines")
163
- mu = pm.Deterministic("mu", a + pm.math.dot(np.asarray(B, order="F"), w.T))
174
+
175
+ mu = pm.Deterministic(
176
+ "mu",
177
+ a + pm.math.dot(np.asarray(B, order="F"), w.T),
178
+ )
164
179
sigma = pm.Exponential("sigma", 1)
165
- D = pm.Normal("D", mu=mu, sigma=sigma, observed=blossom_data.doy, dims="obs")
180
+
181
+ D = pm.Normal("D", mu=mu, sigma=sigma, observed=blossom_data.doy)
166
182
```
167
183
168
184
``` {code-cell} ipython3
@@ -172,7 +188,15 @@ pm.model_to_graphviz(spline_model)
172
188
``` {code-cell} ipython3
173
189
with spline_model:
174
190
idata = pm.sample_prior_predictive()
175
- idata.extend(pm.sample(draws=1000, tune=1000, random_seed=RANDOM_SEED, chains=4))
191
+ idata.extend(
192
+ pm.sample(
193
+ nuts_sampler="nutpie",
194
+ draws=1000,
195
+ tune=1000,
196
+ random_seed=rng,
197
+ chains=4,
198
+ )
199
+ )
176
200
pm.sample_posterior_predictive(idata, extend_inferencedata=True)
177
201
```
178
202
@@ -230,7 +254,7 @@ spline_df_merged.plot("year", "value", c="black", lw=2, ax=plt.gca())
230
254
plt.legend(title="Spline Index", loc="lower center", fontsize=8, ncol=6)
231
255
232
256
for knot in knot_list:
233
- plt.gca().axvline(knot, color="grey", alpha=0.4);
257
+ plt.gca().axvline(knot, color="grey", alpha=0.4)
234
258
```
235
259
236
260
### Model predictions
@@ -267,6 +291,150 @@ plt.fill_between(
267
291
);
268
292
```
269
293
294
+ ## Predicting on new data
295
+
296
+ Now imagine we got a new data set, with the same range of years as the original data set, and we want to get predictions for this new data set with our fitted model. We can do this with the classic PyMC workflow of ` Data ` containers and ` set_data ` method.
297
+
298
+ Before we get there though, let's note that we didn't say the new data set contains * new* years, i.e out-of-sample years. And that's on purpose, because splines can't extrapolate beyond the range of the data set used to fit the model -- hence their limitation for time series analysis. On data ranges previously seen though, that's no problem.
299
+
300
+ That precision out of the way, let's redefine our model, this time adding ` Data ` containers.
301
+
302
+ ``` {code-cell} ipython3
303
+ COORDS = {"obs": blossom_data.index}
304
+ ```
305
+
306
+ ``` {code-cell} ipython3
307
+ with pm.Model(coords=COORDS) as spline_model:
308
+ year_data = pm.Data("year", blossom_data.year)
309
+ doy = pm.Data("doy", blossom_data.doy)
310
+
311
+ # intercept
312
+ a = pm.Normal("a", 100, 5)
313
+
314
+ # Create spline bases & coefficients
315
+ ## Store knots & design matrix for prediction
316
+ spline_model.knots = np.percentile(year_data.eval(), np.linspace(0, 100, num_knots + 2))[1:-1]
317
+ spline_model.dm = dmatrix(
318
+ "bs(x, knots=spline_model.knots, degree=3, include_intercept=False) - 1",
319
+ {"x": year_data.eval()},
320
+ )
321
+ spline_model.add_coords({"spline": np.arange(spline_model.dm.shape[1])})
322
+ splines_basis = pm.Data("splines_basis", np.asarray(spline_model.dm), dims=("obs", "spline"))
323
+ w = pm.Normal("w", mu=0, sigma=3, dims="spline")
324
+
325
+ mu = pm.Deterministic(
326
+ "mu",
327
+ a + pm.math.dot(splines_basis, w),
328
+ )
329
+ sigma = pm.Exponential("sigma", 1)
330
+
331
+ D = pm.Normal("D", mu=mu, sigma=sigma, observed=doy)
332
+ ```
333
+
334
+ ``` {code-cell} ipython3
335
+ pm.model_to_graphviz(spline_model)
336
+ ```
337
+
338
+ ``` {code-cell} ipython3
339
+ with spline_model:
340
+ idata = pm.sample(
341
+ nuts_sampler="nutpie",
342
+ random_seed=rng,
343
+ )
344
+ idata.extend(pm.sample_posterior_predictive(idata, random_seed=rng))
345
+ ```
346
+
347
+ Now we can swap out the data and update the design matrix with the new data:
348
+
349
+ ``` {code-cell} ipython3
350
+ new_blossom_data = (
351
+ blossom_data.sample(50, random_state=rng).sort_values("year").reset_index(drop=True)
352
+ )
353
+
354
+ # update design matrix with new data
355
+ year_data_new = new_blossom_data.year.to_numpy()
356
+ dm_new = build_design_matrices([spline_model.dm.design_info], {"x": year_data_new})[0]
357
+ ```
358
+
359
+ Use ` set_data ` to update the data in the model:
360
+
361
+ ``` {code-cell} ipython3
362
+ with spline_model:
363
+ pm.set_data(
364
+ new_data={
365
+ "year": year_data_new,
366
+ "doy": new_blossom_data.doy.to_numpy(),
367
+ "splines_basis": np.asarray(dm_new),
368
+ },
369
+ coords={
370
+ "obs": new_blossom_data.index,
371
+ },
372
+ )
373
+ ```
374
+
375
+ And all that's left is to sample from the posterior predictive distribution:
376
+
377
+ ``` {code-cell} ipython3
378
+ with spline_model:
379
+ preds = pm.sample_posterior_predictive(idata, var_names=["mu"])
380
+ ```
381
+
382
+ Plot the predictions, to check if everything went well:
383
+
384
+ ``` {code-cell} ipython3
385
+ _, axes = plt.subplots(1, 2, figsize=(16, 5), sharex=True, sharey=True)
386
+
387
+ blossom_data.plot.scatter(
388
+ "year",
389
+ "doy",
390
+ color="cornflowerblue",
391
+ s=10,
392
+ title="Posterior predictions",
393
+ ylabel="Days in bloom",
394
+ ax=axes[0],
395
+ )
396
+ axes[0].vlines(
397
+ spline_model.knots,
398
+ blossom_data.doy.min(),
399
+ blossom_data.doy.max(),
400
+ color="grey",
401
+ alpha=0.4,
402
+ )
403
+ axes[0].plot(
404
+ blossom_data.year,
405
+ idata.posterior["mu"].mean(("chain", "draw")),
406
+ color="firebrick",
407
+ )
408
+ az.plot_hdi(blossom_data.year, idata.posterior["mu"], color="firebrick", ax=axes[0])
409
+
410
+ new_blossom_data.plot.scatter(
411
+ "year",
412
+ "doy",
413
+ color="cornflowerblue",
414
+ s=10,
415
+ title="Predictions on new data",
416
+ ylabel="Days in bloom",
417
+ ax=axes[1],
418
+ )
419
+ axes[1].vlines(
420
+ spline_model.knots,
421
+ blossom_data.doy.min(),
422
+ blossom_data.doy.max(),
423
+ color="grey",
424
+ alpha=0.4,
425
+ )
426
+ axes[1].plot(
427
+ new_blossom_data.year,
428
+ preds.posterior_predictive.mu.mean(("chain", "draw")),
429
+ color="firebrick",
430
+ )
431
+ az.plot_hdi(new_blossom_data.year, preds.posterior_predictive.mu, color="firebrick", ax=axes[1]);
432
+ ```
433
+
434
+ And... voilà! Granted, this example is not the most realistic one, but we trust you to adapt it to your wildest dreams ;)
435
+
436
+ +++
437
+
270
438
## References
271
439
272
440
:::{bibliography}
@@ -280,6 +448,7 @@ plt.fill_between(
280
448
- Created by Joshua Cook
281
449
- Updated by Tyler James Burch
282
450
- Updated by Chris Fonnesbeck
451
+ - Predictions on new data added by Alex Andorra
283
452
284
453
+++
285
454
0 commit comments