
Commit 9cbb346

BART: Fully non-parametric curve fit example (#519)
* init
* initial complete version
* plot improvements
* improve last sentence
1 parent 9028ba3 commit 9cbb346

File tree

4 files changed: +879 -0 lines changed

examples/case_studies/bart_heteroscedasticity.ipynb

+505
Large diffs are not rendered by default.
@@ -0,0 +1,166 @@
---
jupytext:
  text_representation:
    extension: .md
    format_name: myst
    format_version: 0.13
kernelspec:
  display_name: pymc-examples-env
  language: python
  name: python3
---

(bart_heteroscedasticity)=
# Modeling Heteroscedasticity with BART

:::{post} January, 2023
:tags: BART, regression
:category: beginner, reference
:author: [Juan Orduz](https://juanitorduz.github.io/)
:::

+++

In this notebook we show how to use BART to model heteroscedasticity, as described in Section 4.1 of the [`pymc-bart`](https://github.com/pymc-devs/pymc-bart) paper {cite:p}`quiroga2022bart`. We use the `marketing` data set provided by the R package `datarium` {cite:p}`kassambara2019datarium`. The idea is to model the contribution of a marketing channel to sales as a function of its budget.

```{code-cell} ipython3
:tags: []

import os

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import pymc_bart as pmb
```

```{code-cell} ipython3
:tags: []

%config InlineBackend.figure_format = "retina"
az.style.use("arviz-darkgrid")
plt.rcParams["figure.figsize"] = [10, 6]
rng = np.random.default_rng(42)
```

## Read Data

```{code-cell} ipython3
try:
    df = pd.read_csv(os.path.join("..", "data", "marketing.csv"), sep=";", decimal=",")
except FileNotFoundError:
    df = pd.read_csv(pm.get_data("marketing.csv"), sep=";", decimal=",")

n_obs = df.shape[0]

df.head()
```

## EDA

We start by looking at the data, focusing on the *YouTube* channel.

```{code-cell} ipython3
fig, ax = plt.subplots()
ax.plot(df["youtube"], df["sales"], "o", c="C0")
ax.set(title="Sales as a function of Youtube budget", xlabel="budget", ylabel="sales");
```

We clearly see that both the mean and the variance of sales increase with the budget. One possibility is to pick an explicit parametrization of these functions by hand, e.g. a square root or a logarithm. In this example, however, we want to learn both functions from the data with a BART model.
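
For contrast, the following is a minimal sketch of what such a hand-picked parametrization could look like. It is not part of the original example: the model name `model_sqrt_sketch` and the coefficients `a`, `b`, `c`, `d` are hypothetical choices of ours, and we arbitrarily let both the mean and the standard deviation grow like the square root of the budget.

```{code-cell} ipython3
# Hypothetical parametric alternative, for contrast only (not in the original
# example): mean and standard deviation both grow like sqrt(budget).
x_budget = df["youtube"].to_numpy()

with pm.Model() as model_sqrt_sketch:
    a = pm.HalfNormal("a", sigma=5.0)  # baseline mean sales
    b = pm.HalfNormal("b", sigma=5.0)  # square-root slope of the mean
    c = pm.HalfNormal("c", sigma=5.0)  # baseline noise scale
    d = pm.HalfNormal("d", sigma=5.0)  # square-root slope of the noise scale
    mu = a + b * pm.math.sqrt(x_budget)
    sigma = c + d * pm.math.sqrt(x_budget)
    pm.Gamma("sales", mu=mu, sigma=sigma, observed=df["sales"].to_numpy())
```

Every choice above (the square-root form, the priors) would need to be justified and checked. The BART model below sidesteps this by learning both functions non-parametrically.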

+++

## Model Specification

We now prepare the data for modeling: we use `budget` as the predictor and `sales` as the response.

```{code-cell} ipython3
X = df["youtube"].to_numpy().reshape(-1, 1)
Y = df["sales"].to_numpy()
```

Next, we specify the model. Note that a single BART variable, vectorized via `shape=(2, n_obs)`, suffices to model both the mean and the standard deviation of the response. We use a Gamma likelihood, as we expect sales to be positive, and take the absolute value of `w[1]` so the scale parameter stays positive.

```{code-cell} ipython3
with pm.Model() as model_marketing_full:
    w = pmb.BART(name="w", X=X, Y=Y, m=200, shape=(2, n_obs))
    y = pm.Gamma(name="y", mu=w[0], sigma=pm.math.abs(w[1]), observed=Y)

pm.model_to_graphviz(model=model_marketing_full)
```

We now fit the model.

```{code-cell} ipython3
with model_marketing_full:
    idata_marketing_full = pm.sample(random_seed=rng)
    posterior_predictive_marketing_full = pm.sample_posterior_predictive(
        trace=idata_marketing_full, random_seed=rng
    )
```
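
Before looking at the fit we can run a quick convergence check. This cell is an addition to the original example, using standard ArviZ diagnostics; note that `pymc-bart` samples the trees with a particle-Gibbs step rather than NUTS, so these numbers are best read as a rough signal.

```{code-cell} ipython3
# Rough convergence check (an addition, not in the original example):
# R-hat across all 2 * n_obs entries of w; values near 1 suggest good mixing.
rhat = az.rhat(idata_marketing_full, var_names=["w"])["w"]
print(f"max R-hat for w: {float(rhat.max()):.3f}")
```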

## Results

We can now visualize the posterior predictive distribution of the mean and the likelihood.

```{code-cell} ipython3
posterior_mean = idata_marketing_full.posterior["w"].mean(dim=("chain", "draw"))[0]

w_hdi = az.hdi(ary=idata_marketing_full, group="posterior", var_names=["w"])

pps = az.extract(
    posterior_predictive_marketing_full, group="posterior_predictive", var_names=["y"]
).T
```
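
To make the shapes in the plotting code below easier to follow, here is a quick check of the extracted objects (an addition of ours, not in the original example):

```{code-cell} ipython3
# Shape check (an addition, not in the original example).
print(posterior_mean.shape)  # (n_obs,): posterior mean of w[0] per observation
print(w_hdi["w"].sel(w_dim_0=0).shape)  # (n_obs, 2): HDI bounds for the mean
print(pps.shape)  # (n_samples, n_obs): flattened posterior predictive draws
```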

```{code-cell} ipython3
idx = np.argsort(X[:, 0])


fig, ax = plt.subplots()
az.plot_hdi(x=X[:, 0], y=pps, ax=ax, fill_kwargs={"alpha": 0.3, "label": r"Likelihood $94\%$ HDI"})
az.plot_hdi(
    x=X[:, 0],
    hdi_data=w_hdi["w"].sel(w_dim_0=0),
    ax=ax,
    fill_kwargs={"alpha": 0.6, "label": r"Mean $94\%$ HDI"},
)
ax.plot(X[:, 0][idx], posterior_mean[idx], c="black", lw=3, label="Posterior Mean")
ax.plot(df["youtube"], df["sales"], "o", c="C0", label="Raw Data")
ax.legend(loc="upper left")
ax.set(
    title="Sales as a function of Youtube budget - Posterior Predictive",
    xlabel="budget",
    ylabel="sales",
);
```

The fit looks good! Indeed, we see that both the mean and the variance increase as a function of the budget.
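
As a rough numerical companion to the plot, we can check how many of the raw observations fall inside a central 94% posterior-predictive interval. This check is an addition of ours; note it uses an equal-tailed interval built from quantiles, not the HDI shown in the plot.

```{code-cell} ipython3
# Empirical coverage (an addition, not in the original example): the fraction
# of observations inside the central 94% posterior-predictive interval.
pps_np = pps.to_numpy()  # (n_samples, n_obs) after the transpose above
lower, upper = np.quantile(pps_np, [0.03, 0.97], axis=0)
coverage = ((Y >= lower) & (Y <= upper)).mean()
print(f"empirical coverage of the central 94% interval: {coverage:.1%}")
```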

+++

## Authors
- Authored by [Juan Orduz](https://juanitorduz.github.io/) in February 2023

+++

## References
:::{bibliography}
:filter: docname in docnames
:::

+++

## Watermark

```{code-cell} ipython3
:tags: []

%load_ext watermark
%watermark -n -u -v -iv -w -p pytensor
```

:::{include} ../page_footer.md
:::

examples/data/marketing.csv

+201
@@ -0,0 +1,201 @@
youtube;facebook;newspaper;sales
276,12;45,36;83,04;26,52
53,40;47,16;54,12;12,48
20,64;55,08;83,16;11,16
181,80;49,56;70,20;22,20
216,96;12,96;70,08;15,48
10,44;58,68;90,00;8,64
69,00;39,36;28,20;14,16
144,24;23,52;13,92;15,84
10,32;2,52;1,20;5,76
239,76;3,12;25,44;12,72
79,32;6,96;29,04;10,32
257,64;28,80;4,80;20,88
28,56;42,12;79,08;11,04
117,00;9,12;8,64;11,64
244,92;39,48;55,20;22,80
234,48;57,24;63,48;26,88
81,36;43,92;136,80;15,00
337,68;47,52;66,96;29,28
83,04;24,60;21,96;13,56
176,76;28,68;22,92;17,52
262,08;33,24;64,08;21,60
284,88;6,12;28,20;15,00
15,84;19,08;59,52;6,72
273,96;20,28;31,44;18,60
74,76;15,12;21,96;11,64
315,48;4,20;23,40;14,40
171,48;35,16;15,12;18,00
288,12;20,04;27,48;19,08
298,56;32,52;27,48;22,68
84,72;19,20;48,96;12,60
351,48;33,96;51,84;25,68
135,48;20,88;46,32;14,28
116,64;1,80;36,00;11,52
318,72;24,00;0,36;20,88
114,84;1,68;8,88;11,40
348,84;4,92;10,20;15,36
320,28;52,56;6,00;30,48
89,64;59,28;54,84;17,64
51,72;32,04;42,12;12,12
273,60;45,24;38,40;25,80
243,00;26,76;37,92;19,92
212,40;40,08;46,44;20,52
352,32;33,24;2,16;24,84
248,28;10,08;31,68;15,48
30,12;30,84;51,96;10,20
210,12;27,00;37,80;17,88
107,64;11,88;42,84;12,72
287,88;49,80;22,20;27,84
272,64;18,96;59,88;17,76
80,28;14,04;44,16;11,64
239,76;3,72;41,52;13,68
120,48;11,52;4,32;12,84
259,68;50,04;47,52;27,12
219,12;55,44;70,44;25,44
315,24;34,56;19,08;24,24
238,68;59,28;72,00;28,44
8,76;33,72;49,68;6,60
163,44;23,04;19,92;15,84
252,96;59,52;45,24;28,56
252,84;35,40;11,16;22,08
64,20;2,40;25,68;9,72
313,56;51,24;65,64;29,04
287,16;18,60;32,76;18,84
123,24;35,52;10,08;16,80
157,32;51,36;34,68;21,60
82,80;11,16;1,08;11,16
37,80;29,52;2,64;11,40
167,16;17,40;12,24;16,08
284,88;33,00;13,20;22,68
260,16;52,68;32,64;26,76
238,92;36,72;46,44;21,96
131,76;17,16;38,04;14,88
32,16;39,60;23,16;10,56
155,28;6,84;37,56;13,20
256,08;29,52;15,72;20,40
20,28;52,44;107,28;10,44
33,00;1,92;24,84;8,28
144,60;34,20;17,04;17,04
6,48;35,88;11,28;6,36
139,20;9,24;27,72;13,20
91,68;32,04;26,76;14,16
287,76;4,92;44,28;14,76
90,36;24,36;39,00;13,56
82,08;53,40;42,72;16,32
256,20;51,60;40,56;26,04
231,84;22,08;78,84;18,24
91,56;33,00;19,20;14,40
132,84;48,72;75,84;19,20
105,96;30,60;88,08;15,48
131,76;57,36;61,68;20,04
161,16;5,88;11,16;13,44
34,32;1,80;39,60;8,76
261,24;40,20;70,80;23,28
301,08;43,80;86,76;26,64
128,88;16,80;13,08;13,80
195,96;37,92;63,48;20,28
237,12;4,20;7,08;14,04
221,88;25,20;26,40;18,60
347,64;50,76;61,44;30,48
162,24;50,04;55,08;20,64
266,88;5,16;59,76;14,04
355,68;43,56;121,08;28,56
336,24;12,12;25,68;17,76
225,48;20,64;21,48;17,64
285,84;41,16;6,36;24,84
165,48;55,68;70,80;23,04
30,00;13,20;35,64;8,64
108,48;0,36;27,84;10,44
15,72;0,48;30,72;6,36
306,48;32,28;6,60;23,76
270,96;9,84;67,80;16,08
290,04;45,60;27,84;26,16
210,84;18,48;2,88;16,92
251,52;24,72;12,84;19,08
93,84;56,16;41,40;17,52
90,12;42,00;63,24;15,12
167,04;17,16;30,72;14,64
91,68;0,96;17,76;11,28
150,84;44,28;95,04;19,08
23,28;19,20;26,76;7,92
169,56;32,16;55,44;18,60
22,56;26,04;60,48;8,40
268,80;2,88;18,72;13,92
147,72;41,52;14,88;18,24
275,40;38,76;89,04;23,64
104,64;14,16;31,08;12,72
9,36;46,68;60,72;7,92
96,24;0,00;11,04;10,56
264,36;58,80;3,84;29,64
71,52;14,40;51,72;11,64
0,84;47,52;10,44;1,92
318,24;3,48;51,60;15,24
10,08;32,64;2,52;6,84
263,76;40,20;54,12;23,52
44,28;46,32;78,72;12,96
57,96;56,40;10,20;13,92
30,72;46,80;11,16;11,40
328,44;34,68;71,64;24,96
51,60;31,08;24,60;11,52
221,88;52,68;2,04;24,84
88,08;20,40;15,48;13,08
232,44;42,48;90,72;23,04
264,60;39,84;45,48;24,12
125,52;6,84;41,28;12,48
115,44;17,76;46,68;13,68
168,36;2,28;10,80;12,36
288,12;8,76;10,44;15,84
291,84;58,80;53,16;30,48
45,60;48,36;14,28;13,08
53,64;30,96;24,72;12,12
336,84;16,68;44,40;19,32
145,20;10,08;58,44;13,92
237,12;27,96;17,04;19,92
205,56;47,64;45,24;22,80
225,36;25,32;11,40;18,72
4,92;13,92;6,84;3,84
112,68;52,20;60,60;18,36
179,76;1,56;29,16;12,12
14,04;44,28;54,24;8,76
158,04;22,08;41,52;15,48
207,00;21,72;36,84;17,28
102,84;42,96;59,16;15,96
226,08;21,72;30,72;17,88
196,20;44,16;8,88;21,60
140,64;17,64;6,48;14,28
281,40;4,08;101,76;14,28
21,48;45,12;25,92;9,60
248,16;6,24;23,28;14,64
258,48;28,32;69,12;20,52
341,16;12,72;7,68;18,00
60,00;13,92;22,08;10,08
197,40;25,08;56,88;17,40
23,52;24,12;20,40;9,12
202,08;8,52;15,36;14,04
266,88;4,08;15,72;13,80
332,28;58,68;50,16;32,40
298,08;36,24;24,36;24,24
204,24;9,36;42,24;14,04
332,04;2,76;28,44;14,16
198,72;12,00;21,12;15,12
187,92;3,12;9,96;12,60
262,20;6,48;32,88;14,64
67,44;6,84;35,64;10,44
345,12;51,60;86,16;31,44
304,56;25,56;36,00;21,12
246,00;54,12;23,52;27,12
167,40;2,52;31,92;12,36
229,32;34,44;21,84;20,76
343,20;16,68;4,44;19,08
22,44;14,52;28,08;8,04
47,40;49,32;6,96;12,96
90,60;12,96;7,20;11,88
20,64;4,92;37,92;7,08
200,16;50,40;4,32;23,52
179,64;42,72;7,20;20,76
45,84;4,44;16,56;9,12
113,04;5,88;9,72;11,64
212,40;11,16;7,68;15,36
340,32;50,40;79,44;30,60
278,52;10,32;10,44;16,08

examples/references.bib

+7
@@ -343,6 +343,13 @@ @article{johnson1999
   title = {The Insignificance of Statistical Significance Testing},
   journal = {The Journal of Wildlife Management}
 }
+@manual{kassambara2019datarium,
+  title = {datarium: Data Bank for Statistical Analysis and Visualization},
+  author = {Alboukadel Kassambara},
+  year = {2019},
+  note = {R package version 0.1.0},
+  url = {https://CRAN.R-project.org/package=datarium}
+}
 @misc{kingma2014autoencoding,
   title = {Auto-Encoding Variational Bayes},
   author = {Diederik P Kingma and Max Welling},
