12
12
from scipy .interpolate import griddata
13
13
from scipy .signal import savgol_filter
14
14
from scipy .stats import norm , pearsonr
15
+ from xarray import concat
15
16
16
17
from .tree import Tree
17
18
@@ -742,6 +743,12 @@ def plot_variable_importance(
742
743
labels = X .columns
743
744
X = X .values
744
745
746
+ n_draws = idata ["posterior" ].dims ["draw" ]
747
+ half = n_draws // 2
748
+ f_half = idata ["sample_stats" ]["variable_inclusion" ].sel (draw = slice (0 , half - 1 ))
749
+ s_half = idata ["sample_stats" ]["variable_inclusion" ].sel (draw = slice (half , n_draws ))
750
+
751
+ var_imp_chains = concat ([f_half , s_half ], dim = "chain" , join = "override" ).mean (("draw" )).values
745
752
var_imp = idata ["sample_stats" ]["variable_inclusion" ].mean (("chain" , "draw" )).values
746
753
if labels is None :
747
754
labels_ary = np .arange (len (var_imp ))
@@ -759,7 +766,16 @@ def plot_variable_importance(
759
766
indices = idxs [::- 1 ]
760
767
else :
761
768
indices = np .arange (len (var_imp ))
762
- axes [0 ].plot ((var_imp / var_imp .sum ())[indices ], "o-" )
769
+
770
+ chains_mean = (var_imp / var_imp .sum ())[indices ]
771
+ chains_hdi = az .hdi ((var_imp_chains .T / var_imp_chains .sum (axis = 1 )).T )[indices ]
772
+
773
+ axes [0 ].errorbar (
774
+ ticks ,
775
+ chains_mean ,
776
+ np .array ((chains_mean - chains_hdi [:, 0 ], chains_hdi [:, 1 ] - chains_mean )),
777
+ color = "C0" ,
778
+ )
763
779
axes [0 ].set_xticks (ticks )
764
780
axes [0 ].set_xticklabels (labels_ary [indices ])
765
781
axes [0 ].set_xlabel ("covariables" )
@@ -790,8 +806,9 @@ def plot_variable_importance(
790
806
ev_mean [idx ] = np .mean (pearson )
791
807
ev_hdi [idx ] = az .hdi (pearson )
792
808
793
- axes [1 ].errorbar (ticks , ev_mean , np .array ((ev_mean - ev_hdi [:, 0 ], ev_hdi [:, 1 ] - ev_mean )))
794
-
809
+ axes [1 ].errorbar (
810
+ ticks , ev_mean , np .array ((ev_mean - ev_hdi [:, 0 ], ev_hdi [:, 1 ] - ev_mean )), color = "C0"
811
+ )
795
812
axes [1 ].axhline (ev_mean [- 1 ], ls = "--" , color = "0.5" )
796
813
axes [1 ].set_xticks (ticks )
797
814
axes [1 ].set_xticklabels (ticks + 1 )
0 commit comments