Skip to content

Commit 4866b70

Browse files
committed
description of confusion matrix and code examples in episode 4
1 parent 502b2ce commit 4866b70

File tree

1 file changed

+21
-0
lines changed

1 file changed

+21
-0
lines changed

content/04-supervised-ML-classification.rst

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,13 +272,34 @@ For classification tasks, metrics like accuracy, precision, recall, and the F1-s
272272
print("\nClassification Report:\n", classification_report(y_test, y_pred_knn))
273273
274274
275+
In classification tasks, a **confusion matrix** is a valuable tool for evaluating model performance by comparing predicted labels against true labels.
276+
For a multiclass classification task like the penguins dataset, the confusion matrix is an **N x N** matrix, where **N** is the number of target classes (here **N=3** for three penguins species). Each cell $(i, j)$ in the matrix indicates the number of instances where the true class was $i$ and the model predicted class $j$. Diagonal elements represent correct predictions, while off-diagonal elements indicate misclassifications. The confusion matrix provides an easy-to-understand overview of how often the predictions match the actual labels and where the model tends to make mistakes.
275277

278+
Since we will plot the confusion matrix multiple times, we write a function and call this function later whenever needed, which promotes clarity and avoids redundancy. This is especially helpful as we evaluate multiple classifiers such as KNN, Decision Trees, or SVM on the penguins dataset.
276279

280+
.. code-block:: python
281+
282+
from sklearn.metrics import confusion_matrix
277283
284+
def plot_confusion_matrix(conf_matrix, title, fig_name):
285+
plt.figure(figsize=(6, 5))
286+
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='OrRd',
287+
xticklabels=["Adelie", "Chinstrap", "Gentoo"],
288+
yticklabels=['Adelie', 'Chinstrap', 'Gentoo'], cbar=True)
289+
290+
plt.xlabel("Predicted Label")
291+
plt.ylabel("True Label")
292+
plt.title(title)
293+
plt.tight_layout()
294+
plt.savefig(fig_name)
278295
296+
We compute the confusion matrix from the trined model using the KNN algorithm, and visualize the matrix.
279297

298+
.. code-block:: python
280299
300+
cm_knn = confusion_matrix(y_test, y_pred_knn)
281301
302+
plot_confusion_matrix(cm_knn, "Confusion Matrix using KNN algorithm", "confusion-matrix-knn.png")
282303
283304
284305

0 commit comments

Comments
 (0)