Skip to content

Commit 87c9f10

Browse files
committed
pathfinder: Update notebook to reflect Pathfinder upgrades in pymc-extras>=0.2.2
- Add NUTS sampling results for comparison - Set for better comparison of Pathfinder and NUTS
1 parent 9a36c1a commit 87c9f10

File tree

2 files changed

+33
-17
lines changed

2 files changed

+33
-17
lines changed

examples/variational_inference/pathfinder.ipynb

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
"name": "stdout",
6161
"output_type": "stream",
6262
"text": [
63-
"Running on PyMC v5.20.0\n"
63+
"Running on PyMC v5.20.1+3.gce5f2a271\n"
6464
]
6565
}
6666
],
@@ -143,7 +143,7 @@
143143
{
144144
"data": {
145145
"application/vnd.jupyter.widget-view+json": {
146-
"model_id": "2347bdf3049a4a419f05da238b6c0ec2",
146+
"model_id": "b0e3c66fb3e542979c86b37bde07a125",
147147
"version_major": 2,
148148
"version_minor": 0
149149
},
@@ -174,7 +174,7 @@
174174
{
175175
"data": {
176176
"application/vnd.jupyter.widget-view+json": {
177-
"model_id": "966bbf9be33641d598bd80e3b5c8e3d5",
177+
"model_id": "bc7cc1b0d9284662890da3da2d7603a0",
178178
"version_major": 2,
179179
"version_minor": 0
180180
},
@@ -226,9 +226,9 @@
226226
" Pareto k 0.75 \n",
227227
" \n",
228228
" Timing (seconds): \n",
229-
" Compile 4.91 \n",
230-
" Compute 0.24 \n",
231-
" Total 5.15 \n",
229+
" Compile 4.94 \n",
230+
" Compute 0.25 \n",
231+
" Total 5.19 \n",
232232
"</pre>\n"
233233
],
234234
"text/plain": [
@@ -260,9 +260,9 @@
260260
" Pareto k 0.75 \n",
261261
" \n",
262262
" Timing (seconds): \n",
263-
" Compile 4.91 \n",
264-
" Compute 0.24 \n",
265-
" Total 5.15 \n"
263+
" Compile 4.94 \n",
264+
" Compute 0.25 \n",
265+
" Total 5.19 \n"
266266
]
267267
},
268268
"metadata": {},
@@ -386,19 +386,19 @@
386386
"name": "stdout",
387387
"output_type": "stream",
388388
"text": [
389-
"Last updated: Fri Jan 31 2025\n",
389+
"Last updated: Thu Feb 13 2025\n",
390390
"\n",
391391
"Python implementation: CPython\n",
392392
"Python version : 3.10.16\n",
393393
"IPython version : 8.31.0\n",
394394
"\n",
395395
"xarray: 2025.1.1\n",
396396
"\n",
397-
"matplotlib : 3.10.0\n",
397+
"pymc_extras: 0.2.3\n",
398398
"arviz : 0.20.0\n",
399-
"pymc_extras: 0.2.1\n",
400-
"pymc : 5.20.0\n",
399+
"matplotlib : 3.10.0\n",
401400
"numpy : 1.26.4\n",
401+
"pymc : 5.20.0+15.g5f3f5ec5c\n",
402402
"\n",
403403
"Watermark: 2.5.0\n",
404404
"\n"

examples/variational_inference/pathfinder.myst.md

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ jupytext:
55
format_name: myst
66
format_version: 0.13
77
kernelspec:
8-
display_name: Python 3 (ipykernel)
8+
display_name: python-3.10
99
language: python
1010
name: python3
1111
---
@@ -59,21 +59,37 @@ with pm.Model() as model:
5959
tau = pm.HalfCauchy("tau", 5.0)
6060
6161
z = pm.Normal("z", mu=0, sigma=1, shape=J)
62-
theta = mu + tau * z
62+
theta = pm.Deterministic("theta", mu + tau * z)
6363
obs = pm.Normal("obs", mu=theta, sigma=sigma, shape=J, observed=y)
6464
```
6565

6666
Next, we call `pmx.fit()` and pass in the algorithm we want it to use.
6767

6868
```{code-cell} ipython3
69+
rng = np.random.default_rng(123)
6970
with model:
70-
idata = pmx.fit(method="pathfinder", num_samples=1000)
71+
idata_ref = pm.sample(target_accept=0.9, random_seed=rng)
72+
idata_path = pmx.fit(
73+
method="pathfinder",
74+
jitter=12,
75+
num_draws=1000,
76+
random_seed=123,
77+
)
7178
```
7279

7380
Just like `pymc.sample()`, this returns an idata with samples from the posterior. Note that because these samples do not come from an MCMC chain, convergence can not be assessed in the regular way.
7481

7582
```{code-cell} ipython3
76-
az.plot_trace(idata)
83+
az.plot_forest(
84+
[idata_ref, idata_path],
85+
var_names=["~z"],
86+
model_names=["ref", "path"],
87+
combined=True,
88+
);
89+
```
90+
91+
```{code-cell} ipython3
92+
az.plot_trace(idata_path)
7793
plt.tight_layout();
7894
```
7995

0 commit comments

Comments
 (0)