Skip to content

Commit 73813e1

Browse files
Enhance plot_pdp and fix plot_scatter_submodels (#218)
* Add YML env files
* Expand scatter_submodels to categorical likelihood
* Add softmax option to plot_pdp
* Remove comments
* Use func for softmax
* handle func upstream
* move func upstream
* ensure p_d is an array

---------

Co-authored-by: aloctavodia <[email protected]>
1 parent 16a78df commit 73813e1

File tree

3 files changed

+102
-31
lines changed

3 files changed

+102
-31
lines changed

env-dev.yml

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
name: pymc-bart-dev
2+
channels:
3+
- conda-forge
4+
- defaults
5+
dependencies:
6+
- pymc>=5.16.2,<=5.19.1
7+
- arviz>=0.18.0
8+
- numba
9+
- matplotlib
10+
- numpy
11+
- pytensor
12+
# Development dependencies
13+
- pytest>=4.4.0
14+
- pytest-cov>=2.6.1
15+
- click==8.0.4
16+
- pylint==2.17.4
17+
- pre-commit
18+
- black
19+
- isort
20+
- flake8
21+
- pip
22+
- pip:
23+
- -e .

env.yml

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
name: pymc-bart
2+
channels:
3+
- conda-forge
4+
- defaults
5+
dependencies:
6+
- pymc>=5.16.2,<=5.19.1
7+
- arviz>=0.18.0
8+
- numba
9+
- matplotlib
10+
- numpy
11+
- pytensor
12+
- pip
13+
- pip:
14+
- pymc-bart

pymc_bart/utils.py

+65-31
Original file line numberDiff line numberDiff line change
@@ -254,13 +254,13 @@ def identity(x):
254254
)
255255

256256
new_x = fake_X[:, var]
257-
p_d = np.array(y_pred)
257+
p_d = func(np.array(y_pred))
258258

259259
for s_i in range(shape):
260260
if centered:
261-
p_di = func(p_d[:, :, s_i]) - func(p_d[:, :, s_i][:, 0][:, None])
261+
p_di = p_d[:, :, s_i] - p_d[:, :, s_i][:, 0][:, None]
262262
else:
263-
p_di = func(p_d[:, :, s_i])
263+
p_di = p_d[:, :, s_i]
264264
if var in var_discrete:
265265
axes[count].plot(new_x, p_di.mean(0), "o", color=color_mean)
266266
axes[count].plot(new_x, p_di.T, ".", color=color, alpha=alpha)
@@ -393,14 +393,17 @@ def identity(x):
393393
for var in range(len(var_idx)):
394394
excluded = indices[:]
395395
excluded.remove(var)
396-
p_d = _sample_posterior(
397-
all_trees, X=fake_X, rng=rng, size=samples, excluded=excluded, shape=shape
396+
p_d = func(
397+
_sample_posterior(
398+
all_trees, X=fake_X, rng=rng, size=samples, excluded=excluded, shape=shape
399+
)
398400
)
401+
399402
with warnings.catch_warnings():
400403
warnings.filterwarnings("ignore", message="hdi currently interprets 2d data")
401404
new_x = fake_X[:, var]
402405
for s_i in range(shape):
403-
p_di = func(p_d[:, :, s_i])
406+
p_di = p_d[:, :, s_i]
404407
null_pd.append(p_di.mean())
405408
if var in var_discrete:
406409
_, idx_uni = np.unique(new_x, return_index=True)
@@ -1125,8 +1128,11 @@ def plot_scatter_submodels(
11251128
plot_kwargs : dict
11261129
Additional keyword arguments for the plot. Defaults to None.
11271130
Valid keys are:
1128-
- color_ref: matplotlib valid color for the 45 degree line
1131+
- marker_scatter: matplotlib valid marker for the scatter plot
11291132
- color_scatter: matplotlib valid color for the scatter plot
1133+
- alpha_scatter: matplotlib valid alpha for the scatter plot
1134+
- color_ref: matplotlib valid color for the 45 degree line
1135+
- ls_ref: matplotlib valid linestyle for the reference line
11301136
axes : axes
11311137
Matplotlib axes.
11321138
@@ -1140,41 +1146,69 @@ def plot_scatter_submodels(
11401146
submodels = np.sort(submodels)
11411147

11421148
indices = vi_results["indices"][submodels]
1143-
preds = vi_results["preds"][submodels]
1149+
preds_sub = vi_results["preds"][submodels]
11441150
preds_all = vi_results["preds_all"]
11451151

1152+
if labels is None:
1153+
labels = vi_results["labels"][submodels]
1154+
1155+
# handle categorical regression case:
1156+
n_cats = None
1157+
if preds_all.ndim > 2:
1158+
n_cats = preds_all.shape[-1]
1159+
indices = np.tile(indices, n_cats)
1160+
11461161
if ax is None:
11471162
_, ax = _get_axes(grid, len(indices), True, True, figsize)
11481163

11491164
if plot_kwargs is None:
11501165
plot_kwargs = {}
11511166

1152-
if labels is None:
1153-
labels = vi_results["labels"][submodels]
1154-
11551167
if func is not None:
1156-
preds = func(preds)
1168+
preds_sub = func(preds_sub)
11571169
preds_all = func(preds_all)
11581170

1159-
min_ = min(np.min(preds), np.min(preds_all))
1160-
max_ = max(np.max(preds), np.max(preds_all))
1161-
1162-
for pred, x_label, axi in zip(preds, labels, ax.ravel()):
1163-
axi.plot(
1164-
pred,
1165-
preds_all,
1166-
marker=plot_kwargs.get("marker_scatter", "."),
1167-
ls="",
1168-
color=plot_kwargs.get("color_scatter", "C0"),
1169-
alpha=plot_kwargs.get("alpha_scatter", 0.1),
1170-
)
1171-
axi.set_xlabel(x_label)
1172-
axi.axline(
1173-
[min_, min_],
1174-
[max_, max_],
1175-
color=plot_kwargs.get("color_ref", "0.5"),
1176-
ls=plot_kwargs.get("ls_ref", "--"),
1177-
)
1171+
min_ = min(np.min(preds_sub), np.min(preds_all))
1172+
max_ = max(np.max(preds_sub), np.max(preds_all))
1173+
1174+
# handle categorical regression case:
1175+
if n_cats is not None:
1176+
i = 0
1177+
for cat in range(n_cats):
1178+
for pred_sub, x_label in zip(preds_sub, labels):
1179+
ax[i].plot(
1180+
pred_sub[..., cat],
1181+
preds_all[..., cat],
1182+
marker=plot_kwargs.get("marker_scatter", "."),
1183+
ls="",
1184+
color=plot_kwargs.get("color_scatter", f"C{cat}"),
1185+
alpha=plot_kwargs.get("alpha_scatter", 0.1),
1186+
)
1187+
ax[i].set(xlabel=x_label, ylabel="ref model", title=f"Category {cat}")
1188+
ax[i].axline(
1189+
[min_, min_],
1190+
[max_, max_],
1191+
color=plot_kwargs.get("color_ref", "0.5"),
1192+
ls=plot_kwargs.get("ls_ref", "--"),
1193+
)
1194+
i += 1
1195+
else:
1196+
for pred_sub, x_label, axi in zip(preds_sub, labels, ax.ravel()):
1197+
axi.plot(
1198+
pred_sub,
1199+
preds_all,
1200+
marker=plot_kwargs.get("marker_scatter", "."),
1201+
ls="",
1202+
color=plot_kwargs.get("color_scatter", "C0"),
1203+
alpha=plot_kwargs.get("alpha_scatter", 0.1),
1204+
)
1205+
axi.set(xlabel=x_label, ylabel="ref model")
1206+
axi.axline(
1207+
[min_, min_],
1208+
[max_, max_],
1209+
color=plot_kwargs.get("color_ref", "0.5"),
1210+
ls=plot_kwargs.get("ls_ref", "--"),
1211+
)
11781212
return ax
11791213

11801214

0 commit comments

Comments
 (0)