Skip to content

Commit 5ed004d

Browse files
committed
Improve BART categorical with new plots
1 parent c73b785 commit 5ed004d

File tree

2 files changed

+284
-195
lines changed

2 files changed

+284
-195
lines changed

examples/bart/bart_categorical_hawks.ipynb

Lines changed: 247 additions & 185 deletions
Large diffs are not rendered by default.

examples/bart/bart_categorical_hawks.myst.md

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ jupytext:
55
format_name: myst
66
format_version: 0.13
77
kernelspec:
8-
display_name: Python 3 (ipykernel)
8+
display_name: pymc-examples
99
language: python
10-
name: python3
10+
name: pymc-examples
1111
myst:
1212
substitutions:
1313
conda_dependencies: pymc-bart
@@ -52,6 +52,11 @@ RANDOM_SEED = 8457
5252
az.style.use("arviz-darkgrid")
5353
```
5454

55+
```{code-cell} ipython3
56+
%load_ext autoreload
57+
%autoreload 2
58+
```
59+
5560
## Hawks dataset
5661

5762
Here we will use a dataset that contains information about 3 species of hawks (*CH*=Cooper's, *RT*=Red-tailed, *SS*=Sharp-Shinned). This dataset has information for 908 individuals in total, each one containing 16 variables, in addition to the species. To simplify the example, we will use the following 5 covariables:
@@ -143,21 +148,39 @@ vi_results = pmb.compute_variable_importance(idata, μ, x_0, method="VI", random
143148
pmb.plot_variable_importance(vi_results);
144149
```
145150

146-
It can be observed that with the covariables `Hallux`, `Culmen`, and `Wing` we achieve the same R$^2$ value that we obtained with all the covariables, this is that the last two covariables contribute less than the other three to the classification. One thing we have to take into account in this is that the HDI is quite wide, which gives us less precision on the results, later we are going to see a way to reduce this.
151+
It can be observed that with the covariables `Hallux`, `Culmen`, and `Wing` we achieve the same $R^2$ value that we obtained with all the covariables, this is that the last two covariables contribute less than the other three to the classification. One thing we have to take into account in this is that the HDI is quite wide, which gives us less precision on the results; later we are going to see a way to reduce this.
147152

148-
+++
153+
We can also plot the scatter plot of the submodels' predictions to the full model's predictions to get an idea of how each new covariate improves the submodel's predictions.
154+
155+
```{code-cell} ipython3
156+
axes = pmb.plot_scatter_submodels(
157+
vi_results, grid=(5, 3), figsize=(12, 14), plot_kwargs={"alpha_scatter": 0.05}
158+
)
159+
plt.suptitle("Comparison of submodels' predictions to full model's\n", fontsize=18)
160+
for ax, cat in zip(axes, np.repeat(species, len(vi_results["labels"]))):
161+
ax.set(title=f"Species {cat}")
162+
```
149163

150164
### Partial Dependence Plot
151165

152-
Let's check the behavior of each covariable for each species with `pmb.plot_pdp()`, which shows the marginal effect a covariate has on the predicted variable, while we average over all the other covariates.
166+
Let's check the behavior of each covariable for each species with `pmb.plot_pdp()`, which shows the marginal effect a covariate has on the predicted variable, while we average over all the other covariates. Since our response variable is categorical, we use the `softmax_link=True` parameter to get the partial dependence plot in the probability space.
153167

154168
```{code-cell} ipython3
155-
pmb.plot_pdp(μ, X=x_0, Y=y_0, grid=(5, 3), figsize=(12, 7));
169+
axes = pmb.plot_pdp(μ, X=x_0, Y=y_0, grid=(5, 3), figsize=(12, 12), softmax_link=True)
170+
plt.suptitle("Partial Dependence Plots\n", fontsize=18)
171+
for (i, ax), cat in zip(enumerate(axes), np.tile(species, len(vi_results["labels"]))):
172+
ax.set(title=f"Species {cat}")
156173
```
157174

158-
The pdp plot, together with the Variable Importance plot, confirms that `Tail` is the covariable with the smaller effect over the predicted variable. In the Variable Importance plot `Tail` is the last covariable to be added and does not improve the result, in the pdp plot `Tail` has the flattest response. For the rest of the covariables in this plot, it's hard to see which of them have more effect over the predicted variable, because they have great variability, showed in the HDI wide, same as before later we are going to see a way to reduce this variability. Finally, some variability depends on the amount of data for each species, which we can see in the `counts` from one of the covariables using Pandas `.describe()` and grouping the data from "Species" with `.groupby("Species")`.
175+
The Partial Dependence Plot, together with the Variable Importance plot, confirms that `Tail` is the covariable with the smaller effect over the predicted variable: in the Variable Importance plot, `Tail` is the last covariate to be added and does not improve the result; in the PDP plot `Tail` has the flattest response.
159176

160-
+++
177+
For the rest of the covariate in this plot, it's hard to see which of them have more effect over the predicted variable, because they have great variability, showed in the HDI width.
178+
179+
Finally, some variability depends on the amount of data for each species, which we can see in the `counts` of each covariable for each species:
180+
181+
```{code-cell} ipython3
182+
Hawks.groupby("Species").count()
183+
```
161184

162185
### Predicted vs Observed
163186

@@ -222,10 +245,13 @@ pmb.plot_variable_importance(vi_results);
222245
```
223246

224247
```{code-cell} ipython3
225-
pmb.plot_pdp(μ_t, X=x_0, Y=y_0, grid=(5, 3), figsize=(12, 7));
248+
axes = pmb.plot_pdp(μ_t, X=x_0, Y=y_0, grid=(5, 3), figsize=(12, 12), softmax_link=True)
249+
plt.suptitle("Partial Dependence Plots\n", fontsize=18)
250+
for (i, ax), cat in zip(enumerate(axes), np.tile(species, len(vi_results["labels"]))):
251+
ax.set(title=f"Species {cat}")
226252
```
227253

228-
Comparing these two plots with the previous ones shows a marked reduction in the variance for each one. In the case of `pmb.plot_variable_importance()` there are smallers error bands with an R$^{2}$ value more close to 1. And for `pm.plot_pdp()` we can see thinner bands and a reduction in the limits on the y-axis, this is a representation of the reduction of the uncertainty due to adjusting the trees separately. Another benefit of this is that is more visible the behavior of each covariable for each one of the species.
254+
Comparing these two plots with the previous ones shows a marked reduction in the variance for each one. In the case of `pmb.plot_variable_importance()` there are smallers error bands with an $R^{2}$ value closer to 1. And for `pmb.plot_pdp()` we can see thinner HDI bands. This is a representation of the reduction of the uncertainty due to adjusting the trees separately. Another benefit of this is that the behavior of each covariable for each one of the species is more visible.
229255

230256
With all these together, we can select `Hallux`, `Culmen`, and, `Wing` as covariables to make the classification.
231257

@@ -259,6 +285,7 @@ all
259285
## Authors
260286
- Authored by [Pablo Garay](https://github.com/PabloGGaray) and [Osvaldo Martin](https://aloctavodia.github.io/) in May, 2024
261287
- Updated by Osvaldo Martin in Dec, 2024
288+
- Expanded by [Alex Andorra](https://github.com/AlexAndorra) in Feb, 2025
262289

263290
+++
264291

0 commit comments

Comments
 (0)