87 | 87 | "<div class=\"admonition warning alert alert-danger\">\n",
88 | 88 | "<p class=\"first admonition-title\" style=\"font-weight: bold;\">Warning</p>\n",
89 | 89 | "<p class=\"last\">At this time, it is not possible to use <tt class=\"docutils literal\"><span class=\"pre\">response_method=\"predict_proba\"</span></tt> for\n",
90 |    | - "multiclass problems. This is a planned feature for a future version of\n",
91 |    | - "scikit-learn. In the mean time, you can use <tt class=\"docutils literal\"><span class=\"pre\">response_method=\"predict\"</span></tt>\n",
92 |    | - "instead.</p>\n",
   | 90 | + "multiclass problems on a single plot. This is a planned feature for a future\n",
   | 91 | + "version of scikit-learn. In the meantime, you can use\n",
   | 92 | + "<tt class=\"docutils literal\"><span class=\"pre\">response_method=\"predict\"</span></tt> instead.</p>\n",
93 | 93 | "</div>"
94 | 94 | ]
95 | 95 | },
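
The workaround named in this warning still yields a single overview plot. As a minimal sketch, assuming the fitted `tree` classifier and the two-column `data_test` DataFrame used throughout this notebook: `"predict"` returns hard class labels, so one plot works for any number of classes.

```python
import matplotlib.pyplot as plt
from sklearn.inspection import DecisionBoundaryDisplay

# Hard predictions give a single map of decision regions, even with three
# penguin species.
display = DecisionBoundaryDisplay.from_estimator(
    tree,
    data_test,
    response_method="predict",
    cmap="tab10",
    alpha=0.5,
)
display.ax_.set_title("Decision regions from hard predictions")
plt.show()
```
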
212 | 212 | "except that for a K-class problem you have K probability outputs for each\n",
213 | 213 | "data point. Visualizing all these on a single plot can quickly become tricky\n",
214 | 214 | "to interpret. It is then common to instead produce K separate plots, one for\n",
215 |     | - "each class, in a one-vs-rest (or one-vs-all) fashion.\n",
    | 215 | + "each class, in a one-vs-rest (or one-vs-all) fashion. This can be achieved by\n",
    | 216 | + "calling `DecisionBoundaryDisplay` several times, once for each class, and\n",
    | 217 | + "passing the `class_of_interest` parameter to the function.\n",
216 | 218 | "\n",
217 |     | - "For example, in the plot below, the first plot on the left shows in yellow the\n",
218 |     | - "certainty on classifying a data point as belonging to the \"Adelie\" class. In\n",
219 |     | - "the same plot, the spectre from green to purple represents the certainty of\n",
220 |     | - "**not** belonging to the \"Adelie\" class. The same logic applies to the other\n",
221 |     | - "plots in the figure."
    | 219 | + "For example, in the plot below, the first plot on the left shows the\n",
    | 220 | + "certainty of classifying a data point as belonging to the \"Adelie\" class. The\n",
    | 221 | + "darker the color, the more certain the model is that a given point in the\n",
    | 222 | + "feature space belongs to a given class. The same logic\n",
    | 223 | + "applies to the other plots in the figure."
222 | 224 | ]
223 | 225 | },
224 | 226 | {
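
The code cell in the next hunk builds the full three-panel figure. As a minimal standalone sketch, one such one-vs-rest panel looks like this, assuming a scikit-learn version that provides `class_of_interest` (added in 1.4) and the `tree` and `data_test` objects from this notebook:

```python
from sklearn.inspection import DecisionBoundaryDisplay

# Probability of the "Adelie" class versus the rest, on a 0-1 scale.
DecisionBoundaryDisplay.from_estimator(
    tree,
    data_test,
    response_method="predict_proba",
    class_of_interest="Adelie",
    vmin=0,
    vmax=1,
    cmap="Blues",
)
```
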
231 | 233 | },
232 | 234 | "outputs": [],
233 | 235 | "source": [
234 |     | - "import numpy as np\n",
235 |     | - "\n",
236 |     | - "xx = np.linspace(30, 60, 100)\n",
237 |     | - "yy = np.linspace(10, 23, 100)\n",
238 |     | - "xx, yy = np.meshgrid(xx, yy)\n",
239 |     | - "Xfull = pd.DataFrame(\n",
240 |     | - "    {\"Culmen Length (mm)\": xx.ravel(), \"Culmen Depth (mm)\": yy.ravel()}\n",
241 |     | - ")\n",
242 |     | - "\n",
243 |     | - "probas = tree.predict_proba(Xfull)\n",
244 |     | - "n_classes = len(np.unique(tree.classes_))\n",
    | 236 | + "# import numpy as np\n",
    | 237 | + "from matplotlib import cm\n",
245 | 238 | "\n",
246 | 239 | "_, axs = plt.subplots(ncols=3, nrows=1, sharey=True, figsize=(12, 5))\n",
247 |     | - "plt.suptitle(\"Predicted probabilities for decision tree model\", y=0.8)\n",
    | 240 | + "plt.suptitle(\"Predicted probabilities for decision tree model\", y=1.05)\n",
    | 241 | + "plt.subplots_adjust(bottom=0.45)\n",
248 | 242 | "\n",
249 |     | - "for class_of_interest in range(n_classes):\n",
250 |     | - "    axs[class_of_interest].set_title(\n",
251 |     | - "        f\"Class {tree.classes_[class_of_interest]}\"\n",
252 |     | - "    )\n",
253 |     | - "    imshow_handle = axs[class_of_interest].imshow(\n",
254 |     | - "        probas[:, class_of_interest].reshape((100, 100)),\n",
255 |     | - "        extent=(30, 60, 10, 23),\n",
256 |     | - "        vmin=0.0,\n",
257 |     | - "        vmax=1.0,\n",
258 |     | - "        origin=\"lower\",\n",
259 |     | - "        cmap=\"viridis\",\n",
    | 243 | + "for idx, (class_of_interest, ax) in enumerate(zip(tree.classes_, axs)):\n",
    | 244 | + "    ax.set_title(f\"Class {class_of_interest}\")\n",
    | 245 | + "    DecisionBoundaryDisplay.from_estimator(\n",
    | 246 | + "        tree,\n",
    | 247 | + "        data_test,\n",
    | 248 | + "        response_method=\"predict_proba\",\n",
    | 249 | + "        class_of_interest=class_of_interest,\n",
    | 250 | + "        ax=ax,\n",
    | 251 | + "        vmin=0,\n",
    | 252 | + "        vmax=1,\n",
    | 253 | + "        cmap=\"Blues\",\n",
260 | 254 | "    )\n",
261 |     | - "    axs[class_of_interest].set_xlabel(\"Culmen Length (mm)\")\n",
262 |     | - "    if class_of_interest == 0:\n",
263 |     | - "        axs[class_of_interest].set_ylabel(\"Culmen Depth (mm)\")\n",
264 |     | - "    idx = target_test == tree.classes_[class_of_interest]\n",
265 |     | - "    axs[class_of_interest].scatter(\n",
266 |     | - "        data_test[\"Culmen Length (mm)\"].loc[idx],\n",
267 |     | - "        data_test[\"Culmen Depth (mm)\"].loc[idx],\n",
    | 255 | + "    ax.scatter(\n",
    | 256 | + "        data_test[\"Culmen Length (mm)\"].loc[target_test == class_of_interest],\n",
    | 257 | + "        data_test[\"Culmen Depth (mm)\"].loc[target_test == class_of_interest],\n",
268 | 258 | "        marker=\"o\",\n",
269 | 259 | "        c=\"w\",\n",
270 | 260 | "        edgecolor=\"k\",\n",
271 | 261 | "    )\n",
    | 262 | + "    ax.set_xlabel(\"Culmen Length (mm)\")\n",
    | 263 | + "    if idx == 0:\n",
    | 264 | + "        ax.set_ylabel(\"Culmen Depth (mm)\")\n",
272 | 265 | "\n",
273 |     | - "ax = plt.axes([0.15, 0.04, 0.7, 0.05])\n",
274 |     | - "plt.colorbar(imshow_handle, cax=ax, orientation=\"horizontal\")\n",
275 |     | - "_ = plt.title(\"Probability\")"
    | 266 | + "ax = plt.axes([0.15, 0.14, 0.7, 0.05])\n",
    | 267 | + "plt.colorbar(cm.ScalarMappable(cmap=\"Blues\"), cax=ax, orientation=\"horizontal\")\n",
    | 268 | + "_ = ax.set_title(\"Predicted class membership probability\")"
276 | 269 | ]
277 | 270 | },
278 | 271 | {
283 | 276 | ]
284 | 277 | },
285 | 278 | "source": [
    | 279 | + "\n",
286 | 280 | "<div class=\"admonition note alert alert-info\">\n",
287 | 281 | "<p class=\"first admonition-title\" style=\"font-weight: bold;\">Note</p>\n",
288 |     | - "<p class=\"last\">You may have noticed that we are no longer using a diverging colormap. Indeed,\n",
289 |     | - "the chance level for a one-vs-rest binarization of the multi-class\n",
290 |     | - "classification problem is almost never at predicted probability of 0.5. So\n",
291 |     | - "using a colormap with a neutral white at 0.5 might give a false impression on\n",
292 |     | - "the certainty.</p>\n",
293 |     | - "</div>\n",
294 |     | - "\n",
295 |     | - "In future versions of scikit-learn `DecisionBoundaryDisplay` will support a\n",
296 |     | - "`class_of_interest` parameter that will allow in particular for a\n",
297 |     | - "visualization of `predict_proba` in multi-class settings.\n",
298 |     | - "\n",
299 |     | - "We also plan to make it possible to visualize the `predict_proba` values for\n",
300 |     | - "the class with the maximum predicted probability (without having to pass a\n",
301 |     | - "given a fixed `class_of_interest` value)."
    | 282 | + "<p class=\"last\">You may notice that we do not use a diverging colormap (two color gradients\n",
    | 283 | + "with white in the middle). Indeed, in a multiclass setting, 0.5 is not a\n",
    | 284 | + "meaningful value, hence using white as the center of the colormap is not\n",
    | 285 | + "appropriate. Instead, we use a sequential colormap, where the color intensity\n",
    | 286 | + "indicates the certainty of the classification. The darker the color, the more\n",
    | 287 | + "certain the model is that a given point in the feature space belongs to a\n",
    | 288 | + "given class.</p>\n",
    | 289 | + "</div>"
302 | 290 | ]
303 | 291 | }
304 | 292 | ],
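
To illustrate the note's point about colormaps: with three roughly balanced classes, the "no information" level of a one-vs-rest probability sits near 1/3, not at the 0.5 where a diverging colormap places its neutral white. The sketch below (purely illustrative) draws a diverging and a sequential colormap over the [0, 1] probability range and marks that chance level.

```python
import numpy as np
import matplotlib.pyplot as plt

# A 1 x 256 gradient image spanning the probability range [0, 1].
gradient = np.linspace(0, 1, 256)[np.newaxis, :]

fig, axs = plt.subplots(nrows=2, figsize=(6, 1.6), constrained_layout=True)
for ax, cmap in zip(axs, ["RdBu_r", "Blues"]):
    ax.imshow(gradient, extent=(0, 1, 0, 1), aspect="auto", cmap=cmap)
    # Chance level for three balanced classes (~1/3) does not coincide with
    # the white midpoint (0.5) of the diverging colormap.
    ax.axvline(1 / 3, color="k", linestyle="--")
    ax.set_yticks([])
    ax.set_ylabel(cmap, rotation=0, ha="right", va="center")
axs[-1].set_xlabel("predicted probability of the class of interest")
plt.show()
```
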