@@ -254,13 +254,13 @@ def identity(x):
254
254
)
255
255
256
256
new_x = fake_X [:, var ]
257
- p_d = np .array (y_pred )
257
+ p_d = func ( np .array (y_pred ) )
258
258
259
259
for s_i in range (shape ):
260
260
if centered :
261
- p_di = func ( p_d [:, :, s_i ]) - func ( p_d [:, :, s_i ][:, 0 ][:, None ])
261
+ p_di = p_d [:, :, s_i ] - p_d [:, :, s_i ][:, 0 ][:, None ]
262
262
else :
263
- p_di = func ( p_d [:, :, s_i ])
263
+ p_di = p_d [:, :, s_i ]
264
264
if var in var_discrete :
265
265
axes [count ].plot (new_x , p_di .mean (0 ), "o" , color = color_mean )
266
266
axes [count ].plot (new_x , p_di .T , "." , color = color , alpha = alpha )
@@ -393,14 +393,17 @@ def identity(x):
393
393
for var in range (len (var_idx )):
394
394
excluded = indices [:]
395
395
excluded .remove (var )
396
- p_d = _sample_posterior (
397
- all_trees , X = fake_X , rng = rng , size = samples , excluded = excluded , shape = shape
396
+ p_d = func (
397
+ _sample_posterior (
398
+ all_trees , X = fake_X , rng = rng , size = samples , excluded = excluded , shape = shape
399
+ )
398
400
)
401
+
399
402
with warnings .catch_warnings ():
400
403
warnings .filterwarnings ("ignore" , message = "hdi currently interprets 2d data" )
401
404
new_x = fake_X [:, var ]
402
405
for s_i in range (shape ):
403
- p_di = func ( p_d [:, :, s_i ])
406
+ p_di = p_d [:, :, s_i ]
404
407
null_pd .append (p_di .mean ())
405
408
if var in var_discrete :
406
409
_ , idx_uni = np .unique (new_x , return_index = True )
@@ -1125,8 +1128,11 @@ def plot_scatter_submodels(
1125
1128
plot_kwargs : dict
1126
1129
Additional keyword arguments for the plot. Defaults to None.
1127
1130
Valid keys are:
1128
- - color_ref : matplotlib valid color for the 45 degree line
1131
+ - marker_scatter : matplotlib valid marker for the scatter plot
1129
1132
- color_scatter: matplotlib valid color for the scatter plot
1133
+ - alpha_scatter: matplotlib valid alpha for the scatter plot
1134
+ - color_ref: matplotlib valid color for the 45 degree line
1135
+ - ls_ref: matplotlib valid linestyle for the reference line
1130
1136
axes : axes
1131
1137
Matplotlib axes.
1132
1138
@@ -1140,41 +1146,69 @@ def plot_scatter_submodels(
1140
1146
submodels = np .sort (submodels )
1141
1147
1142
1148
indices = vi_results ["indices" ][submodels ]
1143
- preds = vi_results ["preds" ][submodels ]
1149
+ preds_sub = vi_results ["preds" ][submodels ]
1144
1150
preds_all = vi_results ["preds_all" ]
1145
1151
1152
+ if labels is None :
1153
+ labels = vi_results ["labels" ][submodels ]
1154
+
1155
+ # handle categorical regression case:
1156
+ n_cats = None
1157
+ if preds_all .ndim > 2 :
1158
+ n_cats = preds_all .shape [- 1 ]
1159
+ indices = np .tile (indices , n_cats )
1160
+
1146
1161
if ax is None :
1147
1162
_ , ax = _get_axes (grid , len (indices ), True , True , figsize )
1148
1163
1149
1164
if plot_kwargs is None :
1150
1165
plot_kwargs = {}
1151
1166
1152
- if labels is None :
1153
- labels = vi_results ["labels" ][submodels ]
1154
-
1155
1167
if func is not None :
1156
- preds = func (preds )
1168
+ preds_sub = func (preds_sub )
1157
1169
preds_all = func (preds_all )
1158
1170
1159
- min_ = min (np .min (preds ), np .min (preds_all ))
1160
- max_ = max (np .max (preds ), np .max (preds_all ))
1161
-
1162
- for pred , x_label , axi in zip (preds , labels , ax .ravel ()):
1163
- axi .plot (
1164
- pred ,
1165
- preds_all ,
1166
- marker = plot_kwargs .get ("marker_scatter" , "." ),
1167
- ls = "" ,
1168
- color = plot_kwargs .get ("color_scatter" , "C0" ),
1169
- alpha = plot_kwargs .get ("alpha_scatter" , 0.1 ),
1170
- )
1171
- axi .set_xlabel (x_label )
1172
- axi .axline (
1173
- [min_ , min_ ],
1174
- [max_ , max_ ],
1175
- color = plot_kwargs .get ("color_ref" , "0.5" ),
1176
- ls = plot_kwargs .get ("ls_ref" , "--" ),
1177
- )
1171
+ min_ = min (np .min (preds_sub ), np .min (preds_all ))
1172
+ max_ = max (np .max (preds_sub ), np .max (preds_all ))
1173
+
1174
+ # handle categorical regression case:
1175
+ if n_cats is not None :
1176
+ i = 0
1177
+ for cat in range (n_cats ):
1178
+ for pred_sub , x_label in zip (preds_sub , labels ):
1179
+ ax [i ].plot (
1180
+ pred_sub [..., cat ],
1181
+ preds_all [..., cat ],
1182
+ marker = plot_kwargs .get ("marker_scatter" , "." ),
1183
+ ls = "" ,
1184
+ color = plot_kwargs .get ("color_scatter" , f"C{ cat } " ),
1185
+ alpha = plot_kwargs .get ("alpha_scatter" , 0.1 ),
1186
+ )
1187
+ ax [i ].set (xlabel = x_label , ylabel = "ref model" , title = f"Category { cat } " )
1188
+ ax [i ].axline (
1189
+ [min_ , min_ ],
1190
+ [max_ , max_ ],
1191
+ color = plot_kwargs .get ("color_ref" , "0.5" ),
1192
+ ls = plot_kwargs .get ("ls_ref" , "--" ),
1193
+ )
1194
+ i += 1
1195
+ else :
1196
+ for pred_sub , x_label , axi in zip (preds_sub , labels , ax .ravel ()):
1197
+ axi .plot (
1198
+ pred_sub ,
1199
+ preds_all ,
1200
+ marker = plot_kwargs .get ("marker_scatter" , "." ),
1201
+ ls = "" ,
1202
+ color = plot_kwargs .get ("color_scatter" , "C0" ),
1203
+ alpha = plot_kwargs .get ("alpha_scatter" , 0.1 ),
1204
+ )
1205
+ axi .set (xlabel = x_label , ylabel = "ref model" )
1206
+ axi .axline (
1207
+ [min_ , min_ ],
1208
+ [max_ , max_ ],
1209
+ color = plot_kwargs .get ("color_ref" , "0.5" ),
1210
+ ls = plot_kwargs .get ("ls_ref" , "--" ),
1211
+ )
1178
1212
return ax
1179
1213
1180
1214
0 commit comments