Skip to content

Commit de582f7

Browse files
authored
add error bars to variable importance (#90)
1 parent 7cf8595 commit de582f7

File tree

1 file changed

+20
-3
lines changed

1 file changed

+20
-3
lines changed

pymc_bart/utils.py

+20-3
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from scipy.interpolate import griddata
1313
from scipy.signal import savgol_filter
1414
from scipy.stats import norm, pearsonr
15+
from xarray import concat
1516

1617
from .tree import Tree
1718

@@ -742,6 +743,12 @@ def plot_variable_importance(
742743
labels = X.columns
743744
X = X.values
744745

746+
n_draws = idata["posterior"].dims["draw"]
747+
half = n_draws // 2
748+
f_half = idata["sample_stats"]["variable_inclusion"].sel(draw=slice(0, half - 1))
749+
s_half = idata["sample_stats"]["variable_inclusion"].sel(draw=slice(half, n_draws))
750+
751+
var_imp_chains = concat([f_half, s_half], dim="chain", join="override").mean(("draw")).values
745752
var_imp = idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values
746753
if labels is None:
747754
labels_ary = np.arange(len(var_imp))
@@ -759,7 +766,16 @@ def plot_variable_importance(
759766
indices = idxs[::-1]
760767
else:
761768
indices = np.arange(len(var_imp))
762-
axes[0].plot((var_imp / var_imp.sum())[indices], "o-")
769+
770+
chains_mean = (var_imp / var_imp.sum())[indices]
771+
chains_hdi = az.hdi((var_imp_chains.T / var_imp_chains.sum(axis=1)).T)[indices]
772+
773+
axes[0].errorbar(
774+
ticks,
775+
chains_mean,
776+
np.array((chains_mean - chains_hdi[:, 0], chains_hdi[:, 1] - chains_mean)),
777+
color="C0",
778+
)
763779
axes[0].set_xticks(ticks)
764780
axes[0].set_xticklabels(labels_ary[indices])
765781
axes[0].set_xlabel("covariables")
@@ -790,8 +806,9 @@ def plot_variable_importance(
790806
ev_mean[idx] = np.mean(pearson)
791807
ev_hdi[idx] = az.hdi(pearson)
792808

793-
axes[1].errorbar(ticks, ev_mean, np.array((ev_mean - ev_hdi[:, 0], ev_hdi[:, 1] - ev_mean)))
794-
809+
axes[1].errorbar(
810+
ticks, ev_mean, np.array((ev_mean - ev_hdi[:, 0], ev_hdi[:, 1] - ev_mean)), color="C0"
811+
)
795812
axes[1].axhline(ev_mean[-1], ls="--", color="0.5")
796813
axes[1].set_xticks(ticks)
797814
axes[1].set_xticklabels(ticks + 1)

0 commit comments

Comments
 (0)