Skip to content

Commit 66a7ec1

Browse files
committed
nbqa NB
1 parent b650686 commit 66a7ec1

File tree

1 file changed

+24
-25
lines changed

1 file changed

+24
-25
lines changed

docs/source/notebooks/GLM-hierarchical-jax.ipynb

+24-25
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@
5454
"import pymc3 as pm\n",
5555
"import theano\n",
5656
"import pymc3.sampling_jax\n",
57-
"print('Running on PyMC3 v{}'.format(pm.__version__))"
57+
"\n",
58+
"print(\"Running on PyMC3 v{}\".format(pm.__version__))"
5859
]
5960
},
6061
{
@@ -64,7 +65,7 @@
6465
"outputs": [],
6566
"source": [
6667
"%config InlineBackend.figure_format = 'retina'\n",
67-
"az.style.use('arviz-darkgrid')"
68+
"az.style.use(\"arviz-darkgrid\")"
6869
]
6970
},
7071
{
@@ -73,10 +74,10 @@
7374
"metadata": {},
7475
"outputs": [],
7576
"source": [
76-
"data = pd.read_csv(pm.get_data('radon.csv'))\n",
77-
"data['log_radon'] = data['log_radon'].astype(theano.config.floatX)\n",
77+
"data = pd.read_csv(pm.get_data(\"radon.csv\"))\n",
78+
"data[\"log_radon\"] = data[\"log_radon\"].astype(theano.config.floatX)\n",
7879
"county_names = data.county.unique()\n",
79-
"county_idx = data.county_code.values.astype('int32')\n",
80+
"county_idx = data.county_code.values.astype(\"int32\")\n",
8081
"\n",
8182
"n_counties = len(data.county.unique())"
8283
]
@@ -96,27 +97,26 @@
9697
"source": [
9798
"with pm.Model() as hierarchical_model:\n",
9899
" # Hyperpriors for group nodes\n",
99-
" mu_a = pm.Normal('mu_a', mu=0., sigma=100.)\n",
100-
" sigma_a = pm.HalfNormal('sigma_a', 5.)\n",
101-
" mu_b = pm.Normal('mu_b', mu=0., sigma=100.)\n",
102-
" sigma_b = pm.HalfNormal('sigma_b', 5.)\n",
100+
" mu_a = pm.Normal(\"mu_a\", mu=0.0, sigma=100.0)\n",
101+
" sigma_a = pm.HalfNormal(\"sigma_a\", 5.0)\n",
102+
" mu_b = pm.Normal(\"mu_b\", mu=0.0, sigma=100.0)\n",
103+
" sigma_b = pm.HalfNormal(\"sigma_b\", 5.0)\n",
103104
"\n",
104105
" # Intercept for each county, distributed around group mean mu_a\n",
105106
" # Above we just set mu and sd to a fixed value while here we\n",
106107
" # plug in a common group distribution for all a and b (which are\n",
107108
" # vectors of length n_counties).\n",
108-
" a = pm.Normal('a', mu=mu_a, sigma=sigma_a, shape=n_counties)\n",
109+
" a = pm.Normal(\"a\", mu=mu_a, sigma=sigma_a, shape=n_counties)\n",
109110
" # Intercept for each county, distributed around group mean mu_a\n",
110-
" b = pm.Normal('b', mu=mu_b, sigma=sigma_b, shape=n_counties)\n",
111+
" b = pm.Normal(\"b\", mu=mu_b, sigma=sigma_b, shape=n_counties)\n",
111112
"\n",
112113
" # Model error\n",
113-
" eps = pm.HalfCauchy('eps', 5.)\n",
114+
" eps = pm.HalfCauchy(\"eps\", 5.0)\n",
114115
"\n",
115-
" radon_est = a[county_idx] + b[county_idx]*data.floor.values\n",
116+
" radon_est = a[county_idx] + b[county_idx] * data.floor.values\n",
116117
"\n",
117118
" # Data likelihood\n",
118-
" radon_like = pm.Normal('radon_like', mu=radon_est,\n",
119-
" sigma=eps, observed=data.log_radon)"
119+
" radon_like = pm.Normal(\"radon_like\", mu=radon_est, sigma=eps, observed=data.log_radon)"
120120
]
121121
},
122122
{
@@ -193,8 +193,9 @@
193193
"source": [
194194
"%%time\n",
195195
"with hierarchical_model:\n",
196-
" hierarchical_trace = pm.sample(2000, tune=2000, target_accept=.9, \n",
197-
" compute_convergence_checks=False)"
196+
" hierarchical_trace = pm.sample(\n",
197+
" 2000, tune=2000, target_accept=0.9, compute_convergence_checks=False\n",
198+
" )"
198199
]
199200
},
200201
{
@@ -238,8 +239,7 @@
238239
"%%time\n",
239240
"# Inference button (TM)!\n",
240241
"with hierarchical_model:\n",
241-
" hierarchical_trace_jax = pm.sampling_jax.sample_numpyro_nuts(\n",
242-
" 2000, tune=2000, target_accept=.9)"
242+
" hierarchical_trace_jax = pm.sampling_jax.sample_numpyro_nuts(2000, tune=2000, target_accept=0.9)"
243243
]
244244
},
245245
{
@@ -281,10 +281,10 @@
281281
}
282282
],
283283
"source": [
284-
"pm.traceplot(hierarchical_trace_jax, \n",
285-
" var_names=['mu_a', 'mu_b',\n",
286-
" 'sigma_a_log__', 'sigma_b_log__',\n",
287-
" 'eps_log__']);"
284+
"pm.traceplot(\n",
285+
" hierarchical_trace_jax,\n",
286+
" var_names=[\"mu_a\", \"mu_b\", \"sigma_a_log__\", \"sigma_b_log__\", \"eps_log__\"],\n",
287+
");"
288288
]
289289
},
290290
{
@@ -309,8 +309,7 @@
309309
}
310310
],
311311
"source": [
312-
"pm.traceplot(hierarchical_trace_jax, \n",
313-
" var_names=['a'], coords={'a_dim_0': range(5)});"
312+
"pm.traceplot(hierarchical_trace_jax, var_names=[\"a\"], coords={\"a_dim_0\": range(5)});"
314313
]
315314
},
316315
{

0 commit comments

Comments
 (0)