#
# ```{warning}
# At this time, it is not possible to use `response_method="predict_proba"` for
# multiclass problems on a single plot. This is a planned feature for a future
# version of scikit-learn. In the meantime, you can use
# `response_method="predict"` instead.
# ```
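
# %% [markdown]
# As a quick illustration of the workaround, here is a minimal sketch that
# draws the hard `predict` output on a single plot. It assumes a fitted
# classifier named `tree` and the test dataframe `data_test` used in the
# solution below.

# %%
from sklearn.inspection import DecisionBoundaryDisplay

_ = DecisionBoundaryDisplay.from_estimator(
    tree,  # assumption: a classifier fitted earlier in this notebook
    data_test,
    response_method="predict",  # hard class predictions work for multiclass
    cmap="tab10",
    alpha=0.5,
)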

# %%

# except that for a K-class problem you have K probability outputs for each
# data point. Visualizing all these on a single plot can quickly become tricky
# to interpret. It is then common to instead produce K separate plots, one for
# each class, in a one-vs-rest (or one-vs-all) fashion. This can be achieved by
# calling `DecisionBoundaryDisplay` several times, once for each class, and
# passing the `class_of_interest` parameter to the function.
#
# For example, in the plot below, the first plot on the left shows the
# certainty of classifying a data point as belonging to the "Adelie" class.
# The darker the color, the more certain the model is that a given point in
# the feature space belongs to that class. The same logic applies to the
# other plots in the figure.

# %% tags=["solution"]
from matplotlib import cm

_, axs = plt.subplots(ncols=3, nrows=1, sharey=True, figsize=(12, 5))
plt.suptitle("Predicted probabilities for decision tree model", y=1.05)
plt.subplots_adjust(bottom=0.45)

# Plot one panel per class, in a one-vs-rest fashion
for idx, (class_of_interest, ax) in enumerate(zip(tree.classes_, axs)):
    ax.set_title(f"Class {class_of_interest}")
    DecisionBoundaryDisplay.from_estimator(
        tree,
        data_test,
        response_method="predict_proba",
        class_of_interest=class_of_interest,
        ax=ax,
        vmin=0,
        vmax=1,
        cmap="Blues",
    )
    # Overlay the test data points belonging to the class of interest
    ax.scatter(
        data_test["Culmen Length (mm)"].loc[target_test == class_of_interest],
        data_test["Culmen Depth (mm)"].loc[target_test == class_of_interest],
        marker="o",
        c="w",
        edgecolor="k",
    )
    ax.set_xlabel("Culmen Length (mm)")
    if idx == 0:
        ax.set_ylabel("Culmen Depth (mm)")

# Add a single horizontal colorbar shared by the three plots
ax = plt.axes([0.15, 0.14, 0.7, 0.05])
plt.colorbar(cm.ScalarMappable(cmap="Blues"), cax=ax, orientation="horizontal")
_ = ax.set_title("Predicted class membership probability")

# %% [markdown] tags=["solution"]
#
# ```{note}
# You may notice that we do not use a diverging colormap (two color gradients
# with white in the middle). Indeed, in a multiclass setting, 0.5 is not a
# meaningful value, hence using white as the center of the colormap is not
# appropriate. Instead, we use a sequential colormap, where the color intensity
# indicates the certainty of the classification: the darker the color, the more
# certain the model is that a given point in the feature space belongs to a
# given class.
# ```
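
# %% [markdown] tags=["solution"]
# As a quick sanity check (a minimal sketch reusing `tree` and `data_test`
# from the solution above), we can inspect the raw `predict_proba` output:
# with three classes, each row of probabilities sums to 1, so the neutral
# "chance" value is closer to 1/3 than to 0.5.

# %% tags=["solution"]
import numpy as np

probas = tree.predict_proba(data_test)
print(tree.classes_)  # the three penguin species
print(probas[:5].round(3))  # first five rows, one column per class
print(np.allclose(probas.sum(axis=1), 1.0))  # each row sums to one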