Skip to content

Commit f413287

Browse files
author
ArturoAmorQ
committed
Synchronize exercise and notebooks
1 parent 7093c62 commit f413287

File tree

3 files changed

+50
-62
lines changed

3 files changed

+50
-62
lines changed

notebooks/trees_ex_01.ipynb

+3-3
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,9 @@
8383
"<div class=\"admonition warning alert alert-danger\">\n",
8484
"<p class=\"first admonition-title\" style=\"font-weight: bold;\">Warning</p>\n",
8585
"<p class=\"last\">At this time, it is not possible to use <tt class=\"docutils literal\"><span class=\"pre\">response_method=\"predict_proba\"</span></tt> for\n",
86-
"multiclass problems. This is a planned feature for a future version of\n",
87-
"scikit-learn. In the mean time, you can use <tt class=\"docutils literal\"><span class=\"pre\">response_method=\"predict\"</span></tt>\n",
88-
"instead.</p>\n",
86+
"multiclass problems on a single plot. This is a planned feature for a future\n",
87+
"version of scikit-learn. In the mean time, you can use\n",
88+
"<tt class=\"docutils literal\"><span class=\"pre\">response_method=\"predict\"</span></tt> instead.</p>\n",
8989
"</div>"
9090
]
9191
},

notebooks/trees_sol_01.ipynb

+44-56
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,9 @@
8787
"<div class=\"admonition warning alert alert-danger\">\n",
8888
"<p class=\"first admonition-title\" style=\"font-weight: bold;\">Warning</p>\n",
8989
"<p class=\"last\">At this time, it is not possible to use <tt class=\"docutils literal\"><span class=\"pre\">response_method=\"predict_proba\"</span></tt> for\n",
90-
"multiclass problems. This is a planned feature for a future version of\n",
91-
"scikit-learn. In the mean time, you can use <tt class=\"docutils literal\"><span class=\"pre\">response_method=\"predict\"</span></tt>\n",
92-
"instead.</p>\n",
90+
"multiclass problems on a single plot. This is a planned feature for a future\n",
91+
"version of scikit-learn. In the mean time, you can use\n",
92+
"<tt class=\"docutils literal\"><span class=\"pre\">response_method=\"predict\"</span></tt> instead.</p>\n",
9393
"</div>"
9494
]
9595
},
@@ -212,13 +212,15 @@
212212
"except that for a K-class problem you have K probability outputs for each\n",
213213
"data point. Visualizing all these on a single plot can quickly become tricky\n",
214214
"to interpret. It is then common to instead produce K separate plots, one for\n",
215-
"each class, in a one-vs-rest (or one-vs-all) fashion.\n",
215+
"each class, in a one-vs-rest (or one-vs-all) fashion. This can be achieved by\n",
216+
"calling `DecisionBoundaryDisplay` several times, once for each class, and\n",
217+
"passing the `class_of_interest` parameter to the function.\n",
216218
"\n",
217-
"For example, in the plot below, the first plot on the left shows in yellow the\n",
218-
"certainty on classifying a data point as belonging to the \"Adelie\" class. In\n",
219-
"the same plot, the spectre from green to purple represents the certainty of\n",
220-
"**not** belonging to the \"Adelie\" class. The same logic applies to the other\n",
221-
"plots in the figure."
219+
"For example, in the plot below, the first plot on the left shows the\n",
220+
"certainty of classifying a data point as belonging to the \"Adelie\" class. The\n",
221+
"darker the color, the more certain the model is that a given point in the\n",
222+
"feature space belongs to a given class the predictions. The same logic\n",
223+
"applies to the other plots in the figure."
222224
]
223225
},
224226
{
@@ -231,48 +233,39 @@
231233
},
232234
"outputs": [],
233235
"source": [
234-
"import numpy as np\n",
235-
"\n",
236-
"xx = np.linspace(30, 60, 100)\n",
237-
"yy = np.linspace(10, 23, 100)\n",
238-
"xx, yy = np.meshgrid(xx, yy)\n",
239-
"Xfull = pd.DataFrame(\n",
240-
" {\"Culmen Length (mm)\": xx.ravel(), \"Culmen Depth (mm)\": yy.ravel()}\n",
241-
")\n",
242-
"\n",
243-
"probas = tree.predict_proba(Xfull)\n",
244-
"n_classes = len(np.unique(tree.classes_))\n",
236+
"# import numpy as np\n",
237+
"from matplotlib import cm\n",
245238
"\n",
246239
"_, axs = plt.subplots(ncols=3, nrows=1, sharey=True, figsize=(12, 5))\n",
247-
"plt.suptitle(\"Predicted probabilities for decision tree model\", y=0.8)\n",
240+
"plt.suptitle(\"Predicted probabilities for decision tree model\", y=1.05)\n",
241+
"plt.subplots_adjust(bottom=0.45)\n",
248242
"\n",
249-
"for class_of_interest in range(n_classes):\n",
250-
" axs[class_of_interest].set_title(\n",
251-
" f\"Class {tree.classes_[class_of_interest]}\"\n",
252-
" )\n",
253-
" imshow_handle = axs[class_of_interest].imshow(\n",
254-
" probas[:, class_of_interest].reshape((100, 100)),\n",
255-
" extent=(30, 60, 10, 23),\n",
256-
" vmin=0.0,\n",
257-
" vmax=1.0,\n",
258-
" origin=\"lower\",\n",
259-
" cmap=\"viridis\",\n",
243+
"for idx, (class_of_interest, ax) in enumerate(zip(tree.classes_, axs)):\n",
244+
" ax.set_title(f\"Class {class_of_interest}\")\n",
245+
" DecisionBoundaryDisplay.from_estimator(\n",
246+
" tree,\n",
247+
" data_test,\n",
248+
" response_method=\"predict_proba\",\n",
249+
" class_of_interest=class_of_interest,\n",
250+
" ax=ax,\n",
251+
" vmin=0,\n",
252+
" vmax=1,\n",
253+
" cmap=\"Blues\",\n",
260254
" )\n",
261-
" axs[class_of_interest].set_xlabel(\"Culmen Length (mm)\")\n",
262-
" if class_of_interest == 0:\n",
263-
" axs[class_of_interest].set_ylabel(\"Culmen Depth (mm)\")\n",
264-
" idx = target_test == tree.classes_[class_of_interest]\n",
265-
" axs[class_of_interest].scatter(\n",
266-
" data_test[\"Culmen Length (mm)\"].loc[idx],\n",
267-
" data_test[\"Culmen Depth (mm)\"].loc[idx],\n",
255+
" ax.scatter(\n",
256+
" data_test[\"Culmen Length (mm)\"].loc[target_test == class_of_interest],\n",
257+
" data_test[\"Culmen Depth (mm)\"].loc[target_test == class_of_interest],\n",
268258
" marker=\"o\",\n",
269259
" c=\"w\",\n",
270260
" edgecolor=\"k\",\n",
271261
" )\n",
262+
" ax.set_xlabel(\"Culmen Length (mm)\")\n",
263+
" if idx == 0:\n",
264+
" ax.set_ylabel(\"Culmen Depth (mm)\")\n",
272265
"\n",
273-
"ax = plt.axes([0.15, 0.04, 0.7, 0.05])\n",
274-
"plt.colorbar(imshow_handle, cax=ax, orientation=\"horizontal\")\n",
275-
"_ = plt.title(\"Probability\")"
266+
"ax = plt.axes([0.15, 0.14, 0.7, 0.05])\n",
267+
"plt.colorbar(cm.ScalarMappable(cmap=\"Blues\"), cax=ax, orientation=\"horizontal\")\n",
268+
"_ = ax.set_title(\"Predicted class membership probability\")"
276269
]
277270
},
278271
{
@@ -283,22 +276,17 @@
283276
]
284277
},
285278
"source": [
279+
"\n",
286280
"<div class=\"admonition note alert alert-info\">\n",
287281
"<p class=\"first admonition-title\" style=\"font-weight: bold;\">Note</p>\n",
288-
"<p class=\"last\">You may have noticed that we are no longer using a diverging colormap. Indeed,\n",
289-
"the chance level for a one-vs-rest binarization of the multi-class\n",
290-
"classification problem is almost never at predicted probability of 0.5. So\n",
291-
"using a colormap with a neutral white at 0.5 might give a false impression on\n",
292-
"the certainty.</p>\n",
293-
"</div>\n",
294-
"\n",
295-
"In future versions of scikit-learn `DecisionBoundaryDisplay` will support a\n",
296-
"`class_of_interest` parameter that will allow in particular for a\n",
297-
"visualization of `predict_proba` in multi-class settings.\n",
298-
"\n",
299-
"We also plan to make it possible to visualize the `predict_proba` values for\n",
300-
"the class with the maximum predicted probability (without having to pass a\n",
301-
"given a fixed `class_of_interest` value)."
282+
"<p class=\"last\">You may notice that we do not use a diverging colormap (2 color gradients with\n",
283+
"white in the middle). Indeed, in a multiclass setting, 0.5 is not a\n",
284+
"meaningful value, hence using white as the center of the colormap is not\n",
285+
"appropriate. Instead, we use a sequential colormap, where the color intensity\n",
286+
"indicates the certainty of the classification. The darker the color, the more\n",
287+
"certain the model is that a given point in the feature space belongs to a\n",
288+
"given class.</p>\n",
289+
"</div>"
302290
]
303291
}
304292
],

python_scripts/trees_ex_01.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,9 @@
5959
#
6060
# ```{warning}
6161
# At this time, it is not possible to use `response_method="predict_proba"` for
62-
# multiclass problems. This is a planned feature for a future version of
63-
# scikit-learn. In the mean time, you can use `response_method="predict"`
64-
# instead.
62+
# multiclass problems on a single plot. This is a planned feature for a future
63+
# version of scikit-learn. In the mean time, you can use
64+
# `response_method="predict"` instead.
6565
# ```
6666

6767
# %%

0 commit comments

Comments
 (0)