
Commit f01f612

MAINT Use class_of_interest in DecisionBoundaryDisplay (#772)
1 parent 57d821e commit f01f612

File tree

4 files changed: +90 −116 lines changed


notebooks/trees_ex_01.ipynb

+3 −3

@@ -83,9 +83,9 @@
     "<div class=\"admonition warning alert alert-danger\">\n",
     "<p class=\"first admonition-title\" style=\"font-weight: bold;\">Warning</p>\n",
     "<p class=\"last\">At this time, it is not possible to use <tt class=\"docutils literal\"><span class=\"pre\">response_method=\"predict_proba\"</span></tt> for\n",
-    "multiclass problems. This is a planned feature for a future version of\n",
-    "scikit-learn. In the mean time, you can use <tt class=\"docutils literal\"><span class=\"pre\">response_method=\"predict\"</span></tt>\n",
-    "instead.</p>\n",
+    "multiclass problems on a single plot. This is a planned feature for a future\n",
+    "version of scikit-learn. In the mean time, you can use\n",
+    "<tt class=\"docutils literal\"><span class=\"pre\">response_method=\"predict\"</span></tt> instead.</p>\n",
     "</div>"
    ]
   },

notebooks/trees_sol_01.ipynb

+42 −55
@@ -87,9 +87,9 @@
     "<div class=\"admonition warning alert alert-danger\">\n",
     "<p class=\"first admonition-title\" style=\"font-weight: bold;\">Warning</p>\n",
     "<p class=\"last\">At this time, it is not possible to use <tt class=\"docutils literal\"><span class=\"pre\">response_method=\"predict_proba\"</span></tt> for\n",
-    "multiclass problems. This is a planned feature for a future version of\n",
-    "scikit-learn. In the mean time, you can use <tt class=\"docutils literal\"><span class=\"pre\">response_method=\"predict\"</span></tt>\n",
-    "instead.</p>\n",
+    "multiclass problems on a single plot. This is a planned feature for a future\n",
+    "version of scikit-learn. In the mean time, you can use\n",
+    "<tt class=\"docutils literal\"><span class=\"pre\">response_method=\"predict\"</span></tt> instead.</p>\n",
     "</div>"
    ]
   },
@@ -212,12 +212,14 @@
     "except that for a K-class problem you have K probability outputs for each\n",
     "data point. Visualizing all these on a single plot can quickly become tricky\n",
     "to interpret. It is then common to instead produce K separate plots, one for\n",
-    "each class, in a one-vs-rest (or one-vs-all) fashion.\n",
+    "each class, in a one-vs-rest (or one-vs-all) fashion. This can be achieved by\n",
+    "calling `DecisionBoundaryDisplay` several times, once for each class, and\n",
+    "passing the `class_of_interest` parameter to the function.\n",
     "\n",
-    "For example, in the plot below, the first plot on the left shows in yellow the\n",
-    "certainty on classifying a data point as belonging to the \"Adelie\" class. In\n",
-    "the same plot, the spectre from green to purple represents the certainty of\n",
-    "**not** belonging to the \"Adelie\" class. The same logic applies to the other\n",
+    "For example, in the plot below, the first plot on the left shows the\n",
+    "certainty of classifying a data point as belonging to the \"Adelie\" class. The\n",
+    "darker the color, the more certain the model is that a given point in the\n",
+    "feature space belongs to a given class. The same logic applies to the other\n",
     "plots in the figure."
    ]
   },
@@ -231,48 +233,38 @@
   },
   "outputs": [],
   "source": [
-    "import numpy as np\n",
-    "\n",
-    "xx = np.linspace(30, 60, 100)\n",
-    "yy = np.linspace(10, 23, 100)\n",
-    "xx, yy = np.meshgrid(xx, yy)\n",
-    "Xfull = pd.DataFrame(\n",
-    "    {\"Culmen Length (mm)\": xx.ravel(), \"Culmen Depth (mm)\": yy.ravel()}\n",
-    ")\n",
-    "\n",
-    "probas = tree.predict_proba(Xfull)\n",
-    "n_classes = len(np.unique(tree.classes_))\n",
+    "from matplotlib import cm\n",
     "\n",
     "_, axs = plt.subplots(ncols=3, nrows=1, sharey=True, figsize=(12, 5))\n",
-    "plt.suptitle(\"Predicted probabilities for decision tree model\", y=0.8)\n",
+    "plt.suptitle(\"Predicted probabilities for decision tree model\", y=1.05)\n",
+    "plt.subplots_adjust(bottom=0.45)\n",
     "\n",
-    "for class_of_interest in range(n_classes):\n",
-    "    axs[class_of_interest].set_title(\n",
-    "        f\"Class {tree.classes_[class_of_interest]}\"\n",
-    "    )\n",
-    "    imshow_handle = axs[class_of_interest].imshow(\n",
-    "        probas[:, class_of_interest].reshape((100, 100)),\n",
-    "        extent=(30, 60, 10, 23),\n",
-    "        vmin=0.0,\n",
-    "        vmax=1.0,\n",
-    "        origin=\"lower\",\n",
-    "        cmap=\"viridis\",\n",
+    "for idx, (class_of_interest, ax) in enumerate(zip(tree.classes_, axs)):\n",
+    "    ax.set_title(f\"Class {class_of_interest}\")\n",
+    "    DecisionBoundaryDisplay.from_estimator(\n",
+    "        tree,\n",
+    "        data_test,\n",
+    "        response_method=\"predict_proba\",\n",
+    "        class_of_interest=class_of_interest,\n",
+    "        ax=ax,\n",
+    "        vmin=0,\n",
+    "        vmax=1,\n",
+    "        cmap=\"Blues\",\n",
     "    )\n",
-    "    axs[class_of_interest].set_xlabel(\"Culmen Length (mm)\")\n",
-    "    if class_of_interest == 0:\n",
-    "        axs[class_of_interest].set_ylabel(\"Culmen Depth (mm)\")\n",
-    "    idx = target_test == tree.classes_[class_of_interest]\n",
-    "    axs[class_of_interest].scatter(\n",
-    "        data_test[\"Culmen Length (mm)\"].loc[idx],\n",
-    "        data_test[\"Culmen Depth (mm)\"].loc[idx],\n",
+    "    ax.scatter(\n",
+    "        data_test[\"Culmen Length (mm)\"].loc[target_test == class_of_interest],\n",
+    "        data_test[\"Culmen Depth (mm)\"].loc[target_test == class_of_interest],\n",
     "        marker=\"o\",\n",
     "        c=\"w\",\n",
    "        edgecolor=\"k\",\n",
     "    )\n",
+    "    ax.set_xlabel(\"Culmen Length (mm)\")\n",
+    "    if idx == 0:\n",
+    "        ax.set_ylabel(\"Culmen Depth (mm)\")\n",
     "\n",
-    "ax = plt.axes([0.15, 0.04, 0.7, 0.05])\n",
-    "plt.colorbar(imshow_handle, cax=ax, orientation=\"horizontal\")\n",
-    "_ = plt.title(\"Probability\")"
+    "ax = plt.axes([0.15, 0.14, 0.7, 0.05])\n",
+    "plt.colorbar(cm.ScalarMappable(cmap=\"Blues\"), cax=ax, orientation=\"horizontal\")\n",
+    "_ = ax.set_title(\"Predicted class membership probability\")"
    ]
   },
   {
@@ -283,22 +275,17 @@
    ]
   },
   "source": [
+    "\n",
     "<div class=\"admonition note alert alert-info\">\n",
     "<p class=\"first admonition-title\" style=\"font-weight: bold;\">Note</p>\n",
-    "<p class=\"last\">You may have noticed that we are no longer using a diverging colormap. Indeed,\n",
-    "the chance level for a one-vs-rest binarization of the multi-class\n",
-    "classification problem is almost never at predicted probability of 0.5. So\n",
-    "using a colormap with a neutral white at 0.5 might give a false impression on\n",
-    "the certainty.</p>\n",
-    "</div>\n",
-    "\n",
-    "In future versions of scikit-learn `DecisionBoundaryDisplay` will support a\n",
-    "`class_of_interest` parameter that will allow in particular for a\n",
-    "visualization of `predict_proba` in multi-class settings.\n",
-    "\n",
-    "We also plan to make it possible to visualize the `predict_proba` values for\n",
-    "the class with the maximum predicted probability (without having to pass a\n",
-    "given a fixed `class_of_interest` value)."
+    "<p class=\"last\">You may notice that we do not use a diverging colormap (2 color gradients with\n",
+    "white in the middle). Indeed, in a multiclass setting, 0.5 is not a\n",
+    "meaningful value, hence using white as the center of the colormap is not\n",
+    "appropriate. Instead, we use a sequential colormap, where the color intensity\n",
+    "indicates the certainty of the classification. The darker the color, the more\n",
+    "certain the model is that a given point in the feature space belongs to a\n",
+    "given class.</p>\n",
+    "</div>"
    ]
   }
  ],

python_scripts/trees_ex_01.py

+3 −3

@@ -59,9 +59,9 @@
 #
 # ```{warning}
 # At this time, it is not possible to use `response_method="predict_proba"` for
-# multiclass problems. This is a planned feature for a future version of
-# scikit-learn. In the mean time, you can use `response_method="predict"`
-# instead.
+# multiclass problems on a single plot. This is a planned feature for a future
+# version of scikit-learn. In the mean time, you can use
+# `response_method="predict"` instead.
 # ```
 
 # %%

python_scripts/trees_sol_01.py

+42 −55
@@ -57,9 +57,9 @@
 #
 # ```{warning}
 # At this time, it is not possible to use `response_method="predict_proba"` for
-# multiclass problems. This is a planned feature for a future version of
-# scikit-learn. In the mean time, you can use `response_method="predict"`
-# instead.
+# multiclass problems on a single plot. This is a planned feature for a future
+# version of scikit-learn. In the mean time, you can use
+# `response_method="predict"` instead.
 # ```
 
 # %%
@@ -140,71 +140,58 @@
 # except that for a K-class problem you have K probability outputs for each
 # data point. Visualizing all these on a single plot can quickly become tricky
 # to interpret. It is then common to instead produce K separate plots, one for
-# each class, in a one-vs-rest (or one-vs-all) fashion.
+# each class, in a one-vs-rest (or one-vs-all) fashion. This can be achieved by
+# calling `DecisionBoundaryDisplay` several times, once for each class, and
+# passing the `class_of_interest` parameter to the function.
 #
-# For example, in the plot below, the first plot on the left shows in yellow the
-# certainty on classifying a data point as belonging to the "Adelie" class. In
-# the same plot, the spectre from green to purple represents the certainty of
-# **not** belonging to the "Adelie" class. The same logic applies to the other
+# For example, in the plot below, the first plot on the left shows the
+# certainty of classifying a data point as belonging to the "Adelie" class. The
+# darker the color, the more certain the model is that a given point in the
+# feature space belongs to a given class. The same logic applies to the other
 # plots in the figure.
 
 # %% tags=["solution"]
-import numpy as np
-
-xx = np.linspace(30, 60, 100)
-yy = np.linspace(10, 23, 100)
-xx, yy = np.meshgrid(xx, yy)
-Xfull = pd.DataFrame(
-    {"Culmen Length (mm)": xx.ravel(), "Culmen Depth (mm)": yy.ravel()}
-)
-
-probas = tree.predict_proba(Xfull)
-n_classes = len(np.unique(tree.classes_))
+from matplotlib import cm
 
 _, axs = plt.subplots(ncols=3, nrows=1, sharey=True, figsize=(12, 5))
-plt.suptitle("Predicted probabilities for decision tree model", y=0.8)
-
-for class_of_interest in range(n_classes):
-    axs[class_of_interest].set_title(
-        f"Class {tree.classes_[class_of_interest]}"
+plt.suptitle("Predicted probabilities for decision tree model", y=1.05)
+plt.subplots_adjust(bottom=0.45)
+
+for idx, (class_of_interest, ax) in enumerate(zip(tree.classes_, axs)):
+    ax.set_title(f"Class {class_of_interest}")
+    DecisionBoundaryDisplay.from_estimator(
+        tree,
+        data_test,
+        response_method="predict_proba",
+        class_of_interest=class_of_interest,
+        ax=ax,
+        vmin=0,
+        vmax=1,
+        cmap="Blues",
     )
-    imshow_handle = axs[class_of_interest].imshow(
-        probas[:, class_of_interest].reshape((100, 100)),
-        extent=(30, 60, 10, 23),
-        vmin=0.0,
-        vmax=1.0,
-        origin="lower",
-        cmap="viridis",
-    )
-    axs[class_of_interest].set_xlabel("Culmen Length (mm)")
-    if class_of_interest == 0:
-        axs[class_of_interest].set_ylabel("Culmen Depth (mm)")
-    idx = target_test == tree.classes_[class_of_interest]
-    axs[class_of_interest].scatter(
-        data_test["Culmen Length (mm)"].loc[idx],
-        data_test["Culmen Depth (mm)"].loc[idx],
+    ax.scatter(
+        data_test["Culmen Length (mm)"].loc[target_test == class_of_interest],
+        data_test["Culmen Depth (mm)"].loc[target_test == class_of_interest],
         marker="o",
         c="w",
         edgecolor="k",
     )
+    ax.set_xlabel("Culmen Length (mm)")
+    if idx == 0:
+        ax.set_ylabel("Culmen Depth (mm)")
 
-ax = plt.axes([0.15, 0.04, 0.7, 0.05])
-plt.colorbar(imshow_handle, cax=ax, orientation="horizontal")
-_ = plt.title("Probability")
+ax = plt.axes([0.15, 0.14, 0.7, 0.05])
+plt.colorbar(cm.ScalarMappable(cmap="Blues"), cax=ax, orientation="horizontal")
+_ = ax.set_title("Predicted class membership probability")
 
 # %% [markdown] tags=["solution"]
+#
 # ```{note}
-# You may have noticed that we are no longer using a diverging colormap. Indeed,
-# the chance level for a one-vs-rest binarization of the multi-class
-# classification problem is almost never at predicted probability of 0.5. So
-# using a colormap with a neutral white at 0.5 might give a false impression on
-# the certainty.
+# You may notice that we do not use a diverging colormap (2 color gradients with
+# white in the middle). Indeed, in a multiclass setting, 0.5 is not a
+# meaningful value, hence using white as the center of the colormap is not
+# appropriate. Instead, we use a sequential colormap, where the color intensity
+# indicates the certainty of the classification. The darker the color, the more
+# certain the model is that a given point in the feature space belongs to a
+# given class.
 # ```
-#
-# In future versions of scikit-learn `DecisionBoundaryDisplay` will support a
-# `class_of_interest` parameter that will allow in particular for a
-# visualization of `predict_proba` in multi-class settings.
-#
-# We also plan to make it possible to visualize the `predict_proba` values for
-# the class with the maximum predicted probability (without having to pass a
-# given a fixed `class_of_interest` value).
