Skip to content

Commit da4df53

Browse files
committed
change delta to target_accept
1 parent 586b009 commit da4df53

File tree

2 files changed

+464
-443
lines changed

2 files changed

+464
-443
lines changed

examples/diagnostics_and_criticism/Diagnosing_biased_Inference_with_Divergences.ipynb

Lines changed: 448 additions & 433 deletions
Large diffs are not rendered by default.

examples/diagnostics_and_criticism/Diagnosing_biased_Inference_with_Divergences.myst.md

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -327,11 +327,11 @@ df["Divergent"] = pd.Series(
327327
acceptance_runs[0.99].sample_stats["diverging"].sum().item(),
328328
]
329329
)
330-
df["delta_target"] = pd.Series([".80", ".85", ".90", ".95", ".99"])
330+
df["target_accept"] = pd.Series([".80", ".85", ".90", ".95", ".99"])
331331
df
332332
```
333333

334-
Here, the number of divergent transitions dropped dramatically when delta was increased to 0.99.
334+
Here, the number of divergent transitions dropped dramatically when the target acceptance rate was increased to 0.99.
335335

336336
This behavior also has a nice geometric intuition. The more we decrease the step size the more the Hamiltonian Markov chain can explore the neck of the funnel. Consequently, the marginal posterior distribution for $log (\tau)$ stretches further and further towards negative values with the decreasing step size.
337337

@@ -345,7 +345,7 @@ pairplot_divergence(acceptance_runs[0.99], ax=ax, color="C3", divergence=False)
345345
346346
pairplot_divergence(longer_trace, ax=ax, color="C1", divergence=False)
347347
348-
ax.legend(["Centered, delta=0.99", "Centered, delta=0.85"]);
348+
ax.legend(["Centered, target_accept=0.99", "Centered, target_accept=0.85"]);
349349
```
350350

351351
```{code-cell} ipython3
@@ -357,11 +357,11 @@ plt.figure(figsize=(15, 4))
357357
plt.axhline(0.7657852, lw=2.5, color="gray")
358358
359359
mlogtau0 = [logtau0[:, :i].mean() for i in longer_trace.posterior.coords["draw"].values]
360-
plt.plot(mlogtau0, label="Centered, delta=0.85", lw=2.5)
360+
plt.plot(mlogtau0, label="Centered, target_accept=0.85", lw=2.5)
361361
mlogtau2 = [logtau2[:, :i].mean() for i in acceptance_runs[0.90].posterior.coords["draw"].values]
362-
plt.plot(mlogtau2, label="Centered, delta=0.90", lw=2.5)
362+
plt.plot(mlogtau2, label="Centered, target_accept=0.90", lw=2.5)
363363
mlogtau1 = [logtau1[:, :i].mean() for i in acceptance_runs[0.99].posterior.coords["draw"].values]
364-
plt.plot(mlogtau1, label="Centered, delta=0.99", lw=2.5)
364+
plt.plot(mlogtau1, label="Centered, target_accept=0.99", lw=2.5)
365365
plt.ylim(0, 2)
366366
plt.xlabel("Iteration")
367367
plt.ylabel("MCMC mean of log(tau)")
@@ -459,7 +459,13 @@ pairplot_divergence(acceptance_runs[0.99], ax=ax, color="C3", divergence=False)
459459
acceptance_runs[0.90].posterior["log_tau"] = np.log(acceptance_runs[0.90].posterior["tau"])
460460
pairplot_divergence(acceptance_runs[0.90], ax=ax, color="C1", divergence=False)
461461
462-
ax.legend(["Non-Centered, delta=0.80", "Centered, delta=0.99", "Centered, delta=0.90"]);
462+
ax.legend(
463+
[
464+
"Non-Centered, target_accept=0.80",
465+
"Centered, target_accept=0.99",
466+
"Centered, target_accept=0.90",
467+
]
468+
);
463469
```
464470

465471
```{code-cell} ipython3
@@ -468,11 +474,11 @@ plt.axhline(0.7657852, lw=2.5, color="gray")
468474
mlogtaun = [
469475
fit_ncp80.posterior["log_tau"][:, :i].mean() for i in fit_ncp80.posterior.coords["draw"].values
470476
]
471-
plt.plot(mlogtaun, color="C0", lw=2.5, label="Non-Centered, delta=0.80")
477+
plt.plot(mlogtaun, color="C0", lw=2.5, label="Non-Centered, target_accept=0.80")
472478
mlogtau2 = [logtau2[:, :i].mean() for i in acceptance_runs[0.90].posterior.coords["draw"].values]
473-
plt.plot(mlogtau2, color="C2", label="Centered, delta=0.90", lw=2.5)
479+
plt.plot(mlogtau2, color="C2", label="Centered, target_accept=0.90", lw=2.5)
474480
mlogtau1 = [logtau1[:, :i].mean() for i in acceptance_runs[0.99].posterior.coords["draw"].values]
475-
plt.plot(mlogtau1, color="C1", label="Centered, delta=0.99", lw=2.5)
481+
plt.plot(mlogtau1, color="C1", label="Centered, target_accept=0.99", lw=2.5)
476482
plt.ylim(0, 2)
477483
plt.xlabel("Iteration")
478484
plt.ylabel("MCMC mean of log(tau)")

0 commit comments

Comments
 (0)