Update BART Hawks NBs with new plots #775

Merged 4 commits on Mar 11, 2025
419 changes: 244 additions & 175 deletions examples/bart/bart_categorical_hawks.ipynb

Large diffs are not rendered by default.

60 changes: 50 additions & 10 deletions examples/bart/bart_categorical_hawks.myst.md
@@ -5,9 +5,9 @@ jupytext:
format_name: myst
format_version: 0.13
kernelspec:
display_name: Python 3 (ipykernel)
display_name: pymc-examples
language: python
name: python3
name: pymc-examples
myst:
substitutions:
conda_dependencies: pymc-bart
@@ -43,6 +43,8 @@ import pymc as pm
import pymc_bart as pmb
import seaborn as sns

from scipy.special import softmax

warnings.simplefilter(action="ignore", category=FutureWarning)
```

@@ -143,21 +145,48 @@ vi_results = pmb.compute_variable_importance(idata, μ, x_0, method="VI", random
pmb.plot_variable_importance(vi_results);
```

It can be observed that with the covariables `Hallux`, `Culmen`, and `Wing` we achieve the same R$^2$ value that we obtained with all the covariables, this is that the last two covariables contribute less than the other three to the classification. One thing we have to take into account in this is that the HDI is quite wide, which gives us less precision on the results, later we are going to see a way to reduce this.
It can be observed that with the covariables `Hallux`, `Culmen`, and `Wing` we achieve the same $R^2$ value that we obtained with all the covariables; that is, the remaining two covariables contribute less to the classification than these three. One thing to keep in mind is that the HDI is quite wide, which gives us less precision on the results; later we will see a way to reduce this.

+++
We can also plot the submodels' predictions against the full model's predictions to get an idea of how each new covariate improves the submodel's predictions.

```{code-cell} ipython3
axes = pmb.plot_scatter_submodels(
vi_results, grid=(5, 3), figsize=(12, 14), plot_kwargs={"alpha_scatter": 0.05}
)
plt.suptitle("Comparison of submodels' predictions to full model's\n", fontsize=18)
for ax, cat in zip(axes, np.repeat(species, len(vi_results["labels"]))):
ax.set(title=f"Species {cat}")
```

### Partial Dependence Plot

Let's check the behavior of each covariable for each species with `pmb.plot_pdp()`, which shows the marginal effect a covariate has on the predicted variable, while we average over all the other covariates.
Let's check the behavior of each covariable for each species with `pmb.plot_pdp()`, which shows the marginal effect a covariate has on the predicted variable, while we average over all the other covariates. Since our response variable is categorical, we'll pass `softmax` as the inverse link function to `plot_pdp`.

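To make the idea of a partial dependence plot concrete, here is a minimal, hedged sketch of the underlying computation: sweep one covariate over a grid while keeping the observed values of the others, and average the predictions. The `predict` argument is a hypothetical stand-in for the model's prediction function; this is not how `pmb.plot_pdp` is implemented internally, just the general idea.

```python
import numpy as np


def partial_dependence(predict, X, col, grid):
    """Marginal effect of `col`: average predictions over the data while sweeping `col`."""
    pd_values = []
    for value in grid:
        X_mod = X.copy()
        X_mod[col] = value  # fix the covariate of interest at `value` for every row
        pd_values.append(predict(X_mod).mean())  # average over all the other covariates
    return np.array(pd_values)
```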
We have to be careful with the `softmax` function: it is not an element-wise transformation, because it couples the elements it normalizes over, so the axis along which we apply it matters. By default, SciPy's `softmax` normalizes over the whole array, but we want to apply it along the last axis, since that's where the categories are. To make sure of that, we use `np.apply_along_axis` and pass it inside a lambda function, as the small sketch below illustrates.

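As a quick illustration of why the axis matters, here is a minimal sketch with a small hypothetical array (not part of the Hawks data), reusing the `softmax` imported above:

```python
# Hypothetical latent values for 2 observations and 3 categories.
logits = np.array([[1.0, 2.0, 3.0], [0.0, 0.0, 5.0]])

# Default behaviour: softmax normalizes over the whole array,
# so all six entries together sum to 1 -- not what we want.
print(softmax(logits).sum())

# Applying softmax along the last axis normalizes each observation's
# categories, so every row sums to 1.
print(np.apply_along_axis(softmax, -1, logits).sum(axis=-1))
```

Passing `axis=-1` to `softmax` directly (e.g. `lambda x: softmax(x, axis=-1)`) would be equivalent; the lambda is only there to give `plot_pdp` a one-argument callable.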
```{code-cell} ipython3
pmb.plot_pdp(μ, X=x_0, Y=y_0, grid=(5, 3), figsize=(12, 7));
axes = pmb.plot_pdp(
μ,
X=x_0,
Y=y_0,
grid=(5, 3),
figsize=(12, 12),
func=lambda x: np.apply_along_axis(softmax, axis=-1, arr=x),
)
plt.suptitle("Partial Dependence Plots\n", fontsize=18)
for ax, cat in zip(axes, np.tile(species, len(vi_results["labels"]))):
ax.set(title=f"Species {cat}")
```

The pdp plot, together with the Variable Importance plot, confirms that `Tail` is the covariable with the smaller effect over the predicted variable. In the Variable Importance plot `Tail` is the last covariable to be added and does not improve the result, in the pdp plot `Tail` has the flattest response. For the rest of the covariables in this plot, it's hard to see which of them have more effect over the predicted variable, because they have great variability, showed in the HDI wide, same as before later we are going to see a way to reduce this variability. Finally, some variability depends on the amount of data for each species, which we can see in the `counts` from one of the covariables using Pandas `.describe()` and grouping the data from "Species" with `.groupby("Species")`.
The Partial Dependence Plot, together with the Variable Importance plot, confirms that `Tail` is the covariable with the smallest effect on the predicted variable: in the Variable Importance plot, `Tail` is the last covariate to be added and does not improve the result; in the PDP plot, `Tail` has the flattest response.

+++
For the rest of the covariables in this plot, it's hard to see which of them has more effect on the predicted variable, because they show great variability, as reflected in the width of the HDIs.

Finally, some of this variability depends on the amount of data available for each species, which we can check by counting the observations of each covariable per species:

```{code-cell} ipython3
Hawks.groupby("Species").count()
```

### Predicted vs Observed

@@ -222,10 +251,20 @@ pmb.plot_variable_importance(vi_results);
```

```{code-cell} ipython3
pmb.plot_pdp(μ_t, X=x_0, Y=y_0, grid=(5, 3), figsize=(12, 7));
axes = pmb.plot_pdp(
μ_t,
X=x_0,
Y=y_0,
grid=(5, 3),
figsize=(12, 12),
func=lambda x: np.apply_along_axis(softmax, axis=-1, arr=x),
)
plt.suptitle("Partial Dependence Plots\n", fontsize=18)
for ax, cat in zip(axes, np.tile(species, len(vi_results["labels"]))):
ax.set(title=f"Species {cat}")
```

Comparing these two plots with the previous ones shows a marked reduction in the variance for each one. In the case of `pmb.plot_variable_importance()` there are smallers error bands with an R$^{2}$ value more close to 1. And for `pm.plot_pdp()` we can see thinner bands and a reduction in the limits on the y-axis, this is a representation of the reduction of the uncertainty due to adjusting the trees separately. Another benefit of this is that is more visible the behavior of each covariable for each one of the species.
Comparing these two plots with the previous ones shows a marked reduction in the variance of each one. In the case of `pmb.plot_variable_importance()` there are smaller error bands, with an $R^{2}$ value closer to 1. And for `pmb.plot_pdp()` we can see thinner HDI bands. This reflects the reduction in uncertainty obtained by fitting the trees separately for each response category. Another benefit is that the behavior of each covariable for each of the species is more visible.

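For reference, here is a minimal sketch of how such an independent-trees model could be written with `pymc-bart`'s `separate_trees` option, reusing `x_0`, `y_0`, and `species` from earlier cells. The variable names and settings below (e.g. `model_t`, `m=50`, the dims ordering) are assumptions for illustration, and the actual model cells in this notebook may differ:

```python
with pm.Model(coords={"species": species, "obs": np.arange(len(y_0))}) as model_t:
    # `separate_trees=True` fits an independent ensemble of trees for each
    # category instead of sharing the same trees across categories, which is
    # what reduces the uncertainty seen in the plots above.
    μ_t = pmb.BART("μ_t", x_0, y_0, m=50, separate_trees=True, dims=["species", "obs"])
    # With species on the first axis, we normalize over axis=0 to get probabilities.
    θ_t = pm.Deterministic("θ_t", pm.math.softmax(μ_t, axis=0))
    y_obs = pm.Categorical("y_obs", p=θ_t.T, observed=y_0)
```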
With all this together, we can select `Hallux`, `Culmen`, and `Wing` as covariables to make the classification.

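As a concrete illustration (assuming `x_0` is the pandas DataFrame of predictors built earlier), restricting the design matrix to those three covariables is just a column selection:

```python
# Hypothetical name for the reduced design matrix; keep only the selected covariables.
x_selected = x_0[["Hallux", "Culmen", "Wing"]]
x_selected.head()
```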
@@ -259,6 +298,7 @@ all
## Authors
- Authored by [Pablo Garay](https://github.com/PabloGGaray) and [Osvaldo Martin](https://aloctavodia.github.io/) in May, 2024
- Updated by Osvaldo Martin in Dec, 2024
- Expanded by [Alex Andorra](https://github.com/AlexAndorra) in Feb, 2025

+++
