|
54 | 54 | "import pymc3 as pm\n",
|
55 | 55 | "import theano\n",
|
56 | 56 | "import pymc3.sampling_jax\n",
|
57 |
| - "print('Running on PyMC3 v{}'.format(pm.__version__))" |
| 57 | + "\n", |
| 58 | + "print(\"Running on PyMC3 v{}\".format(pm.__version__))" |
58 | 59 | ]
|
59 | 60 | },
|
60 | 61 | {
|
|
64 | 65 | "outputs": [],
|
65 | 66 | "source": [
|
66 | 67 | "%config InlineBackend.figure_format = 'retina'\n",
|
67 |
| - "az.style.use('arviz-darkgrid')" |
| 68 | + "az.style.use(\"arviz-darkgrid\")" |
68 | 69 | ]
|
69 | 70 | },
|
70 | 71 | {
|
|
73 | 74 | "metadata": {},
|
74 | 75 | "outputs": [],
|
75 | 76 | "source": [
|
76 |
| - "data = pd.read_csv(pm.get_data('radon.csv'))\n", |
77 |
| - "data['log_radon'] = data['log_radon'].astype(theano.config.floatX)\n", |
| 77 | + "data = pd.read_csv(pm.get_data(\"radon.csv\"))\n", |
| 78 | + "data[\"log_radon\"] = data[\"log_radon\"].astype(theano.config.floatX)\n", |
78 | 79 | "county_names = data.county.unique()\n",
|
79 |
| - "county_idx = data.county_code.values.astype('int32')\n", |
| 80 | + "county_idx = data.county_code.values.astype(\"int32\")\n", |
80 | 81 | "\n",
|
81 | 82 | "n_counties = len(data.county.unique())"
|
82 | 83 | ]
|
|
96 | 97 | "source": [
|
97 | 98 | "with pm.Model() as hierarchical_model:\n",
|
98 | 99 | " # Hyperpriors for group nodes\n",
|
99 |
| - " mu_a = pm.Normal('mu_a', mu=0., sigma=100.)\n", |
100 |
| - " sigma_a = pm.HalfNormal('sigma_a', 5.)\n", |
101 |
| - " mu_b = pm.Normal('mu_b', mu=0., sigma=100.)\n", |
102 |
| - " sigma_b = pm.HalfNormal('sigma_b', 5.)\n", |
| 100 | + " mu_a = pm.Normal(\"mu_a\", mu=0.0, sigma=100.0)\n", |
| 101 | + " sigma_a = pm.HalfNormal(\"sigma_a\", 5.0)\n", |
| 102 | + " mu_b = pm.Normal(\"mu_b\", mu=0.0, sigma=100.0)\n", |
| 103 | + " sigma_b = pm.HalfNormal(\"sigma_b\", 5.0)\n", |
103 | 104 | "\n",
|
104 | 105 | " # Intercept for each county, distributed around group mean mu_a\n",
|
105 | 106 | " # Above we just set mu and sd to a fixed value while here we\n",
|
106 | 107 | " # plug in a common group distribution for all a and b (which are\n",
|
107 | 108 | " # vectors of length n_counties).\n",
|
108 |
| - " a = pm.Normal('a', mu=mu_a, sigma=sigma_a, shape=n_counties)\n", |
| 109 | + " a = pm.Normal(\"a\", mu=mu_a, sigma=sigma_a, shape=n_counties)\n", |
109 | 110 | " # Intercept for each county, distributed around group mean mu_a\n",
|
110 |
| - " b = pm.Normal('b', mu=mu_b, sigma=sigma_b, shape=n_counties)\n", |
| 111 | + " b = pm.Normal(\"b\", mu=mu_b, sigma=sigma_b, shape=n_counties)\n", |
111 | 112 | "\n",
|
112 | 113 | " # Model error\n",
|
113 |
| - " eps = pm.HalfCauchy('eps', 5.)\n", |
| 114 | + " eps = pm.HalfCauchy(\"eps\", 5.0)\n", |
114 | 115 | "\n",
|
115 |
| - " radon_est = a[county_idx] + b[county_idx]*data.floor.values\n", |
| 116 | + " radon_est = a[county_idx] + b[county_idx] * data.floor.values\n", |
116 | 117 | "\n",
|
117 | 118 | " # Data likelihood\n",
|
118 |
| - " radon_like = pm.Normal('radon_like', mu=radon_est,\n", |
119 |
| - " sigma=eps, observed=data.log_radon)" |
| 119 | + " radon_like = pm.Normal(\"radon_like\", mu=radon_est, sigma=eps, observed=data.log_radon)" |
120 | 120 | ]
|
121 | 121 | },
|
122 | 122 | {
|
|
193 | 193 | "source": [
|
194 | 194 | "%%time\n",
|
195 | 195 | "with hierarchical_model:\n",
|
196 |
| - " hierarchical_trace = pm.sample(2000, tune=2000, target_accept=.9, \n", |
197 |
| - " compute_convergence_checks=False)" |
| 196 | + " hierarchical_trace = pm.sample(\n", |
| 197 | + " 2000, tune=2000, target_accept=0.9, compute_convergence_checks=False\n", |
| 198 | + " )" |
198 | 199 | ]
|
199 | 200 | },
|
200 | 201 | {
|
|
238 | 239 | "%%time\n",
|
239 | 240 | "# Inference button (TM)!\n",
|
240 | 241 | "with hierarchical_model:\n",
|
241 |
| - " hierarchical_trace_jax = pm.sampling_jax.sample_numpyro_nuts(\n", |
242 |
| - " 2000, tune=2000, target_accept=.9)" |
| 242 | + " hierarchical_trace_jax = pm.sampling_jax.sample_numpyro_nuts(2000, tune=2000, target_accept=0.9)" |
243 | 243 | ]
|
244 | 244 | },
|
245 | 245 | {
|
|
281 | 281 | }
|
282 | 282 | ],
|
283 | 283 | "source": [
|
284 |
| - "pm.traceplot(hierarchical_trace_jax, \n", |
285 |
| - " var_names=['mu_a', 'mu_b',\n", |
286 |
| - " 'sigma_a_log__', 'sigma_b_log__',\n", |
287 |
| - " 'eps_log__']);" |
| 284 | + "pm.traceplot(\n", |
| 285 | + " hierarchical_trace_jax,\n", |
| 286 | + " var_names=[\"mu_a\", \"mu_b\", \"sigma_a_log__\", \"sigma_b_log__\", \"eps_log__\"],\n", |
| 287 | + ");" |
288 | 288 | ]
|
289 | 289 | },
|
290 | 290 | {
|
|
309 | 309 | }
|
310 | 310 | ],
|
311 | 311 | "source": [
|
312 |
| - "pm.traceplot(hierarchical_trace_jax, \n", |
313 |
| - " var_names=['a'], coords={'a_dim_0': range(5)});" |
| 312 | + "pm.traceplot(hierarchical_trace_jax, var_names=[\"a\"], coords={\"a_dim_0\": range(5)});" |
314 | 313 | ]
|
315 | 314 | },
|
316 | 315 | {
|
|
0 commit comments