|
87 | 87 | "<div class=\"admonition warning alert alert-danger\">\n", |
88 | 88 | "<p class=\"first admonition-title\" style=\"font-weight: bold;\">Warning</p>\n", |
89 | 89 | "<p class=\"last\">At this time, it is not possible to use <tt class=\"docutils literal\"><span class=\"pre\">response_method=\"predict_proba\"</span></tt> for\n", |
90 | | - "multiclass problems. This is a planned feature for a future version of\n", |
91 | | - "scikit-learn. In the mean time, you can use <tt class=\"docutils literal\"><span class=\"pre\">response_method=\"predict\"</span></tt>\n", |
92 | | - "instead.</p>\n", |
| 90 | + "multiclass problems on a single plot. This is a planned feature for a future\n", |
|  | 91 | + "version of scikit-learn. In the meantime, you can use\n", |
| 92 | + "<tt class=\"docutils literal\"><span class=\"pre\">response_method=\"predict\"</span></tt> instead.</p>\n", |
93 | 93 | "</div>" |
94 | 94 | ] |
95 | 95 | }, |
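For reference, the workaround named in the warning looks as follows. This is a minimal sketch, assuming the fitted classifier `tree` and the two-feature DataFrame `data_test` defined earlier in the notebook, not the author's exact cell:

```python
# Sketch of the workaround from the warning above: hard class predictions
# (response_method="predict") can be drawn on a single plot even in a
# multiclass setting. Assumes `tree` (fitted classifier) and `data_test`
# (two-feature DataFrame) from earlier cells.
import matplotlib.pyplot as plt
from sklearn.inspection import DecisionBoundaryDisplay

display = DecisionBoundaryDisplay.from_estimator(
    tree,
    data_test,
    response_method="predict",
    cmap="tab10",
    alpha=0.5,
)
_ = display.ax_.set_title("Predicted class")
```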
|
212 | 212 | "except that for a K-class problem you have K probability outputs for each\n", |
213 | 213 | "data point. Visualizing all these on a single plot can quickly become tricky\n", |
214 | 214 | "to interpret. It is then common to instead produce K separate plots, one for\n", |
215 | | - "each class, in a one-vs-rest (or one-vs-all) fashion.\n", |
|  | 215 | + "each class, in a one-vs-rest (or one-vs-all) fashion. This can be achieved by\n", |
|  | 216 | + "calling `DecisionBoundaryDisplay.from_estimator` once per class, setting the\n", |
|  | 217 | + "`class_of_interest` parameter accordingly.\n", |
216 | 218 | "\n", |
217 | | - "For example, in the plot below, the first plot on the left shows in yellow the\n", |
218 | | - "certainty on classifying a data point as belonging to the \"Adelie\" class. In\n", |
219 | | - "the same plot, the spectre from green to purple represents the certainty of\n", |
220 | | - "**not** belonging to the \"Adelie\" class. The same logic applies to the other\n", |
221 | | - "plots in the figure." |
|  | 219 | + "For example, in the figure below, the leftmost plot shows the certainty of\n", |
|  | 220 | + "classifying a data point as belonging to the \"Adelie\" class. The darker the\n", |
|  | 221 | + "color, the more certain the model is that a given point in the feature space\n", |
|  | 222 | + "belongs to a given class. The same logic applies to the other plots in the\n", |
|  | 223 | + "figure." |
222 | 224 | ] |
223 | 225 | }, |
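Before the full three-panel figure below, a minimal sketch of a single one-vs-rest panel may help fix the API in mind. It assumes `tree` and `data_test` from earlier cells, and that `"Adelie"` is one of the labels in `tree.classes_`:

```python
# Minimal sketch of one one-vs-rest panel; the cell below assembles the
# full three-panel figure from the same ingredients. Assumes `tree` and
# `data_test` from earlier cells; "Adelie" is one of tree.classes_.
from sklearn.inspection import DecisionBoundaryDisplay

display = DecisionBoundaryDisplay.from_estimator(
    tree,
    data_test,
    response_method="predict_proba",
    class_of_interest="Adelie",
    vmin=0,
    vmax=1,
    cmap="Blues",
)
```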
224 | 226 | { |
|
231 | 233 | }, |
232 | 234 | "outputs": [], |
233 | 235 | "source": [ |
234 | | - "import numpy as np\n", |
235 | | - "\n", |
236 | | - "xx = np.linspace(30, 60, 100)\n", |
237 | | - "yy = np.linspace(10, 23, 100)\n", |
238 | | - "xx, yy = np.meshgrid(xx, yy)\n", |
239 | | - "Xfull = pd.DataFrame(\n", |
240 | | - " {\"Culmen Length (mm)\": xx.ravel(), \"Culmen Depth (mm)\": yy.ravel()}\n", |
241 | | - ")\n", |
242 | | - "\n", |
243 | | - "probas = tree.predict_proba(Xfull)\n", |
244 | | - "n_classes = len(np.unique(tree.classes_))\n", |
|  | 236 | + "# colormap utilities for the shared probability colorbar below\n", |
| 237 | + "from matplotlib import cm\n", |
245 | 238 | "\n", |
246 | 239 | "_, axs = plt.subplots(ncols=3, nrows=1, sharey=True, figsize=(12, 5))\n", |
247 | | - "plt.suptitle(\"Predicted probabilities for decision tree model\", y=0.8)\n", |
| 240 | + "plt.suptitle(\"Predicted probabilities for decision tree model\", y=1.05)\n", |
| 241 | + "plt.subplots_adjust(bottom=0.45)\n", |
248 | 242 | "\n", |
249 | | - "for class_of_interest in range(n_classes):\n", |
250 | | - " axs[class_of_interest].set_title(\n", |
251 | | - " f\"Class {tree.classes_[class_of_interest]}\"\n", |
252 | | - " )\n", |
253 | | - " imshow_handle = axs[class_of_interest].imshow(\n", |
254 | | - " probas[:, class_of_interest].reshape((100, 100)),\n", |
255 | | - " extent=(30, 60, 10, 23),\n", |
256 | | - " vmin=0.0,\n", |
257 | | - " vmax=1.0,\n", |
258 | | - " origin=\"lower\",\n", |
259 | | - " cmap=\"viridis\",\n", |
| 243 | + "for idx, (class_of_interest, ax) in enumerate(zip(tree.classes_, axs)):\n", |
| 244 | + " ax.set_title(f\"Class {class_of_interest}\")\n", |
| 245 | + " DecisionBoundaryDisplay.from_estimator(\n", |
| 246 | + " tree,\n", |
| 247 | + " data_test,\n", |
| 248 | + " response_method=\"predict_proba\",\n", |
| 249 | + " class_of_interest=class_of_interest,\n", |
| 250 | + " ax=ax,\n", |
| 251 | + " vmin=0,\n", |
| 252 | + " vmax=1,\n", |
| 253 | + " cmap=\"Blues\",\n", |
260 | 254 | " )\n", |
261 | | - " axs[class_of_interest].set_xlabel(\"Culmen Length (mm)\")\n", |
262 | | - " if class_of_interest == 0:\n", |
263 | | - " axs[class_of_interest].set_ylabel(\"Culmen Depth (mm)\")\n", |
264 | | - " idx = target_test == tree.classes_[class_of_interest]\n", |
265 | | - " axs[class_of_interest].scatter(\n", |
266 | | - " data_test[\"Culmen Length (mm)\"].loc[idx],\n", |
267 | | - " data_test[\"Culmen Depth (mm)\"].loc[idx],\n", |
| 255 | + " ax.scatter(\n", |
| 256 | + " data_test[\"Culmen Length (mm)\"].loc[target_test == class_of_interest],\n", |
| 257 | + " data_test[\"Culmen Depth (mm)\"].loc[target_test == class_of_interest],\n", |
268 | 258 | " marker=\"o\",\n", |
269 | 259 | " c=\"w\",\n", |
270 | 260 | " edgecolor=\"k\",\n", |
271 | 261 | " )\n", |
| 262 | + " ax.set_xlabel(\"Culmen Length (mm)\")\n", |
| 263 | + " if idx == 0:\n", |
| 264 | + " ax.set_ylabel(\"Culmen Depth (mm)\")\n", |
272 | 265 | "\n", |
273 | | - "ax = plt.axes([0.15, 0.04, 0.7, 0.05])\n", |
274 | | - "plt.colorbar(imshow_handle, cax=ax, orientation=\"horizontal\")\n", |
275 | | - "_ = plt.title(\"Probability\")" |
|  | 266 | + "cax = plt.axes([0.15, 0.14, 0.7, 0.05])\n", |
|  | 267 | + "plt.colorbar(cm.ScalarMappable(cmap=\"Blues\"), cax=cax, orientation=\"horizontal\")\n", |
|  | 268 | + "_ = cax.set_title(\"Predicted class membership probability\")" |
276 | 269 | ] |
277 | 270 | }, |
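Since the three panels are one-vs-rest views of the same `predict_proba` output, the per-class probabilities at any point sum to one. A quick sanity check, under the same assumptions (`tree` and `data_test` from the cells above):

```python
# Sanity check: the three probability maps are one-vs-rest views of the
# same predict_proba output, so the class probabilities for each point
# sum to one.
import numpy as np

probas = tree.predict_proba(data_test)
assert probas.shape[1] == len(tree.classes_)
np.testing.assert_allclose(probas.sum(axis=1), 1.0)
```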
278 | 271 | { |
|
283 | 276 | ] |
284 | 277 | }, |
285 | 278 | "source": [ |
| 279 | + "\n", |
286 | 280 | "<div class=\"admonition note alert alert-info\">\n", |
287 | 281 | "<p class=\"first admonition-title\" style=\"font-weight: bold;\">Note</p>\n", |
288 | | - "<p class=\"last\">You may have noticed that we are no longer using a diverging colormap. Indeed,\n", |
289 | | - "the chance level for a one-vs-rest binarization of the multi-class\n", |
290 | | - "classification problem is almost never at predicted probability of 0.5. So\n", |
291 | | - "using a colormap with a neutral white at 0.5 might give a false impression on\n", |
292 | | - "the certainty.</p>\n", |
293 | | - "</div>\n", |
294 | | - "\n", |
295 | | - "In future versions of scikit-learn `DecisionBoundaryDisplay` will support a\n", |
296 | | - "`class_of_interest` parameter that will allow in particular for a\n", |
297 | | - "visualization of `predict_proba` in multi-class settings.\n", |
298 | | - "\n", |
299 | | - "We also plan to make it possible to visualize the `predict_proba` values for\n", |
300 | | - "the class with the maximum predicted probability (without having to pass a\n", |
301 | | - "given a fixed `class_of_interest` value)." |
|  | 282 | + "<p class=\"last\">You may notice that we no longer use a diverging colormap (two color\n", |
|  | 283 | + "gradients with white in the middle). Indeed, in a multiclass setting, a\n", |
|  | 284 | + "predicted probability of 0.5 is not a meaningful chance level, so centering\n", |
|  | 285 | + "the colormap on white would be misleading. Instead, we use a sequential\n", |
|  | 286 | + "colormap, where the color intensity reflects the certainty of the\n", |
|  | 287 | + "classification: the darker the color, the higher the predicted probability\n", |
|  | 288 | + "for the class of interest.</p>\n", |
| 289 | + "</div>" |
302 | 290 | ] |
303 | 291 | } |
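A side note on the shared colorbar in the plotting cell: `cm.ScalarMappable(cmap="Blues")` relies on matplotlib's default normalization, which here spans 0 to 1. A sketch making the probability range explicit, assuming `plt` and the colorbar axes `cax` from that cell:

```python
# Sketch: pin the shared colorbar explicitly to the [0, 1] probability
# range instead of relying on the default normalization. Assumes `plt`
# and the colorbar axes `cax` from the plotting cell above.
from matplotlib import cm
from matplotlib.colors import Normalize

mappable = cm.ScalarMappable(norm=Normalize(vmin=0.0, vmax=1.0), cmap="Blues")
plt.colorbar(mappable, cax=cax, orientation="horizontal")
```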
304 | 292 | ], |
|