Skip to content

Commit 2d9c6c1

Browse files
author
juanitorduz
committed
cleanup
1 parent b2433dd commit 2d9c6c1

File tree

2 files changed

+93
-85
lines changed

2 files changed

+93
-85
lines changed

examples/time_series/Time_Series_Generative_Graph.ipynb

Lines changed: 76 additions & 68 deletions
Large diffs are not rendered by default.

examples/time_series/Time_Series_Generative_Graph.myst.md

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ Now that we have implemented the AR(2) step, we can assign priors to the paramet
128128
coords = {
129129
"lags": range(-lags, 0),
130130
"steps": range(timeseries_length - lags),
131-
"trials": range(timeseries_length),
131+
"timeseries_length": range(timeseries_length),
132132
}
133133
with pm.Model(coords=coords, check_bounds=False) as model:
134134
rho = pm.Normal(name="rho", mu=0, sigma=0.2, dims=("lags",))
@@ -146,7 +146,7 @@ with pm.Model(coords=coords, check_bounds=False) as model:
146146
)
147147
148148
ar = pm.Deterministic(
149-
name="ar", var=pt.concatenate([ar_init, ar_innov], axis=-1), dims=("trials",)
149+
name="ar", var=pt.concatenate([ar_init, ar_innov], axis=-1), dims=("timeseries_length",)
150150
)
151151
152152
@@ -191,6 +191,7 @@ fig, ax = plt.subplots(
191191
nrows=5, ncols=1, figsize=(12, 12), sharex=True, sharey=True, layout="constrained"
192192
)
193193
chosen_draw = 2
194+
194195
for i, axi in enumerate(ax, start=chosen_draw):
195196
axi.plot(
196197
prior.prior["ar"].isel(draw=i, chain=0),
@@ -207,10 +208,10 @@ Next, we want to condition the AR(2) model on some observed data so that we can
207208
```{code-cell} ipython3
208209
# select a random draw from the prior
209210
prior_draw = prior.prior.isel(chain=0, draw=chosen_draw)
210-
test_data = prior_draw["ar_dist"].values
211+
test_data = prior_draw["ar_dist"].to_numpy()
211212
212213
with pm.observe(model, {"ar_dist": test_data}) as observed_model:
213-
trace = pm.sample()
214+
trace = pm.sample(chains=4, random_seed=rng)
214215
```
215216

216217
Let's plot the trace and the posterior distribution of the parameters.
@@ -229,7 +230,7 @@ axes = az.plot_trace(
229230
("rho", {}, rho_true),
230231
("sigma", {}, sigma_true),
231232
],
232-
backend_kwargs={"figsize": (12, 5), "layout": "constrained"},
233+
backend_kwargs={"figsize": (12, 7), "layout": "constrained"},
233234
)
234235
plt.gcf().suptitle("AR(2) Model Trace", fontsize=18, fontweight="bold");
235236
```
@@ -263,15 +264,15 @@ for i, hdi_prob in enumerate((0.94, 0.64), 1):
263264
lower = hdi.sel(hdi="lower")
264265
upper = hdi.sel(hdi="higher")
265266
ax.fill_between(
266-
x=post_pred_ar["trials"],
267+
x=post_pred_ar["timeseries_length"],
267268
y1=lower,
268269
y2=upper,
269270
alpha=(i - 0.2) * 0.2,
270271
color="C0",
271272
label=f"{hdi_prob:.0%} HDI",
272273
)
273274
ax.plot(
274-
post_pred_ar["trials"],
275+
post_pred_ar["timeseries_length"],
275276
post_pred_ar.mean(("chain", "draw")),
276277
color="C0",
277278
label="Mean",
@@ -343,7 +344,7 @@ We need to shift the coordinate `steps` forward by one! The reasons is that the
343344
coords = {
344345
"lags": range(-lags, 0),
345346
"steps": range(-1, timeseries_length - lags - 1), # <- Coordinate shift!
346-
"trials": range(1, timeseries_length + 1), # <- Coordinate shift!
347+
"timeseries_length": range(1, timeseries_length + 1), # <- Coordinate shift!
347348
}
348349
with pm.Model(coords=coords, check_bounds=False) as conditional_model:
349350
y_data = pm.Data("y_data", ar_obs)
@@ -360,7 +361,7 @@ with pm.Model(coords=coords, check_bounds=False) as conditional_model:
360361
)
361362
362363
ar = pm.Deterministic(
363-
name="ar", var=pt.concatenate([ar_init, ar_innov], axis=-1), dims=("trials",)
364+
name="ar", var=pt.concatenate([ar_init, ar_innov], axis=-1), dims=("timeseries_length",)
364365
)
365366
366367
post_pred_conditional = pm.sample_posterior_predictive(trace, var_names=["ar"], random_seed=rng)
@@ -384,27 +385,27 @@ for i, hdi_prob in enumerate((0.94, 0.64), 1):
384385
lower = hdi.sel(hdi="lower")
385386
upper = hdi.sel(hdi="higher")
386387
ax.fill_between(
387-
x=conditional_post_pred_ar["trials"],
388+
x=conditional_post_pred_ar["timeseries_length"],
388389
y1=lower,
389390
y2=upper,
390391
alpha=(i - 0.2) * 0.2,
391392
color="C1",
392393
label=f"{hdi_prob:.0%} HDI",
393394
)
394395
ax.plot(
395-
conditional_post_pred_ar["trials"],
396+
conditional_post_pred_ar["timeseries_length"],
396397
conditional_post_pred_ar.mean(("chain", "draw")),
397398
color="C1",
398399
label="Mean",
399400
)
400401
ax.plot(ar_obs, color="black", label="Observed")
401402
ax.plot(
402-
conditional_post_pred_ar["trials"],
403+
conditional_post_pred_ar["timeseries_length"],
403404
res.fittedvalues,
404405
color="C2",
405406
label="statsmodels",
406407
)
407-
ax.legend(loc="upper right")
408+
ax.legend(loc="lower center", bbox_to_anchor=(0.5, -0.2), ncol=5)
408409
ax.set_xlabel("time")
409410
ax.set_title("AR(2) Conditional Posterior Predictive Samples", fontsize=18, fontweight="bold");
410411
```
@@ -438,7 +439,6 @@ The idea is to use the posterior samples and the latest available two data point
438439
```{code-cell} ipython3
439440
coords = {
440441
"lags": range(-lags, 0),
441-
"trials": range(timeseries_length),
442442
"steps": range(timeseries_length, timeseries_length + forecast_steps),
443443
}
444444
with pm.Model(coords=coords, check_bounds=False) as forecasting_model:
@@ -481,7 +481,7 @@ for i, hdi_prob in enumerate((0.94, 0.64), 1):
481481
lower = hdi.sel(hdi="lower")
482482
upper = hdi.sel(hdi="higher")
483483
ax.fill_between(
484-
x=conditional_post_pred_ar["trials"],
484+
x=conditional_post_pred_ar["timeseries_length"],
485485
y1=lower,
486486
y2=upper,
487487
alpha=(i - 0.2) * 0.2,
@@ -490,7 +490,7 @@ for i, hdi_prob in enumerate((0.94, 0.64), 1):
490490
)
491491
492492
ax.plot(
493-
conditional_post_pred_ar["trials"],
493+
conditional_post_pred_ar["timeseries_length"],
494494
conditional_post_pred_ar.mean(("chain", "draw")),
495495
color="C1",
496496
label="Mean",
@@ -519,7 +519,7 @@ ax.plot(
519519
520520
ax.plot(ar_obs, color="black", label="Observed")
521521
ax.plot(
522-
conditional_post_pred_ar["trials"],
522+
conditional_post_pred_ar["timeseries_length"],
523523
res.fittedvalues,
524524
color="C2",
525525
label="statsmodels",

0 commit comments

Comments
 (0)