Skip to content

Commit bb5f3df

Browse files
author
ArturoAmorQ
committed
Fix commit history
1 parent f413287 commit bb5f3df

File tree

1 file changed

+43
-56
lines changed

1 file changed

+43
-56
lines changed

python_scripts/trees_sol_01.py

+43-56
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,9 @@
5757
#
5858
# ```{warning}
5959
# At this time, it is not possible to use `response_method="predict_proba"` for
60-
# multiclass problems. This is a planned feature for a future version of
61-
# scikit-learn. In the mean time, you can use `response_method="predict"`
62-
# instead.
60+
# multiclass problems on a single plot. This is a planned feature for a future
61+
# version of scikit-learn. In the mean time, you can use
62+
# `response_method="predict"` instead.
6363
# ```
6464

6565
# %%
@@ -140,71 +140,58 @@
140140
# except that for a K-class problem you have K probability outputs for each
141141
# data point. Visualizing all these on a single plot can quickly become tricky
142142
# to interpret. It is then common to instead produce K separate plots, one for
143-
# each class, in a one-vs-rest (or one-vs-all) fashion.
143+
# each class, in a one-vs-rest (or one-vs-all) fashion. This can be achieved by
144+
# calling `DecisionBoundaryDisplay` several times, once for each class, and
145+
# passing the `class_of_interest` parameter to the function.
144146
#
145-
# For example, in the plot below, the first plot on the left shows in yellow the
146-
# certainty on classifying a data point as belonging to the "Adelie" class. In
147-
# the same plot, the spectre from green to purple represents the certainty of
148-
# **not** belonging to the "Adelie" class. The same logic applies to the other
149-
# plots in the figure.
147+
# For example, in the plot below, the first plot on the left shows the
148+
# certainty of classifying a data point as belonging to the "Adelie" class. The
149+
# darker the color, the more certain the model is that a given point in the
150+
# feature space belongs to a given class the predictions. The same logic
151+
# applies to the other plots in the figure.
150152

151153
# %% tags=["solution"]
152-
import numpy as np
153-
154-
xx = np.linspace(30, 60, 100)
155-
yy = np.linspace(10, 23, 100)
156-
xx, yy = np.meshgrid(xx, yy)
157-
Xfull = pd.DataFrame(
158-
{"Culmen Length (mm)": xx.ravel(), "Culmen Depth (mm)": yy.ravel()}
159-
)
160-
161-
probas = tree.predict_proba(Xfull)
162-
n_classes = len(np.unique(tree.classes_))
154+
from matplotlib import cm
163155

164156
_, axs = plt.subplots(ncols=3, nrows=1, sharey=True, figsize=(12, 5))
165-
plt.suptitle("Predicted probabilities for decision tree model", y=0.8)
166-
167-
for class_of_interest in range(n_classes):
168-
axs[class_of_interest].set_title(
169-
f"Class {tree.classes_[class_of_interest]}"
157+
plt.suptitle("Predicted probabilities for decision tree model", y=1.05)
158+
plt.subplots_adjust(bottom=0.45)
159+
160+
for idx, (class_of_interest, ax) in enumerate(zip(tree.classes_, axs)):
161+
ax.set_title(f"Class {class_of_interest}")
162+
DecisionBoundaryDisplay.from_estimator(
163+
tree,
164+
data_test,
165+
response_method="predict_proba",
166+
class_of_interest=class_of_interest,
167+
ax=ax,
168+
vmin=0,
169+
vmax=1,
170+
cmap="Blues",
170171
)
171-
imshow_handle = axs[class_of_interest].imshow(
172-
probas[:, class_of_interest].reshape((100, 100)),
173-
extent=(30, 60, 10, 23),
174-
vmin=0.0,
175-
vmax=1.0,
176-
origin="lower",
177-
cmap="viridis",
178-
)
179-
axs[class_of_interest].set_xlabel("Culmen Length (mm)")
180-
if class_of_interest == 0:
181-
axs[class_of_interest].set_ylabel("Culmen Depth (mm)")
182-
idx = target_test == tree.classes_[class_of_interest]
183-
axs[class_of_interest].scatter(
184-
data_test["Culmen Length (mm)"].loc[idx],
185-
data_test["Culmen Depth (mm)"].loc[idx],
172+
ax.scatter(
173+
data_test["Culmen Length (mm)"].loc[target_test == class_of_interest],
174+
data_test["Culmen Depth (mm)"].loc[target_test == class_of_interest],
186175
marker="o",
187176
c="w",
188177
edgecolor="k",
189178
)
179+
ax.set_xlabel("Culmen Length (mm)")
180+
if idx == 0:
181+
ax.set_ylabel("Culmen Depth (mm)")
190182

191-
ax = plt.axes([0.15, 0.04, 0.7, 0.05])
192-
plt.colorbar(imshow_handle, cax=ax, orientation="horizontal")
193-
_ = plt.title("Probability")
183+
ax = plt.axes([0.15, 0.14, 0.7, 0.05])
184+
plt.colorbar(cm.ScalarMappable(cmap="Blues"), cax=ax, orientation="horizontal")
185+
_ = ax.set_title("Predicted class membership probability")
194186

195187
# %% [markdown] tags=["solution"]
188+
#
196189
# ```{note}
197-
# You may have noticed that we are no longer using a diverging colormap. Indeed,
198-
# the chance level for a one-vs-rest binarization of the multi-class
199-
# classification problem is almost never at predicted probability of 0.5. So
200-
# using a colormap with a neutral white at 0.5 might give a false impression on
201-
# the certainty.
190+
# You may notice that we do not use a diverging colormap (2 color gradients with
191+
# white in the middle). Indeed, in a multiclass setting, 0.5 is not a
192+
# meaningful value, hence using white as the center of the colormap is not
193+
# appropriate. Instead, we use a sequential colormap, where the color intensity
194+
# indicates the certainty of the classification. The darker the color, the more
195+
# certain the model is that a given point in the feature space belongs to a
196+
# given class.
202197
# ```
203-
#
204-
# In future versions of scikit-learn `DecisionBoundaryDisplay` will support a
205-
# `class_of_interest` parameter that will allow in particular for a
206-
# visualization of `predict_proba` in multi-class settings.
207-
#
208-
# We also plan to make it possible to visualize the `predict_proba` values for
209-
# the class with the maximum predicted probability (without having to pass a
210-
# given a fixed `class_of_interest` value).

0 commit comments

Comments
 (0)