Commit 502b2ce

knn code examples to train the model and metrics/codes for evaluation in episode 4
1 parent 1488525 commit 502b2ce

1 file changed

+28
-0
lines changed

content/04-supervised-ML-classification.rst

Lines changed: 28 additions & 0 deletions
@@ -239,9 +239,37 @@ It is noted that the choice of *k* (the number of neighbors) significantly affec
:width: 512px

Let's create the KNN model. Here we set *k* to 3, which means that each sample is assigned the class held by the majority of its 3 nearest neighbors in the training set. We then fit the model to the training data using the ``fit()`` method.

.. code-block:: python

   from sklearn.neighbors import KNeighborsClassifier

   # classify each sample by majority vote among its 3 nearest neighbors
   knn_clf = KNeighborsClassifier(n_neighbors=3)
   # store the scaled training samples and their species labels
   knn_clf.fit(X_train_scaled, y_train)

After fitting the model to the training data, we use it to predict species on the test set and evaluate its performance.

For classification tasks, metrics such as accuracy, precision, recall, and the F1-score together provide a comprehensive view of model performance.

- **accuracy** measures the proportion of correctly classified instances across all species (Adelie, Chinstrap, Gentoo). It gives an overall measure of how often the model is correct, but it can be misleading for imbalanced datasets.
- **precision** is computed per species: of all samples predicted as a given species, it is the proportion that truly belong to it. **recall** is the complementary view: of all samples that truly belong to a species, it is the proportion that the model correctly identifies.
- the **F1-score**, the harmonic mean of precision and recall, balances these two metrics for each class, which is especially useful given the dataset's imbalanced species distribution.
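
The definitions above can be checked by hand on a small example. The following sketch uses only the Python standard library. The helper ``per_class_metrics`` and the label lists are hypothetical, made up for illustration, and are not results from the penguin model.

.. code-block:: python

   def per_class_metrics(y_true, y_pred, label):
       """Precision, recall, and F1 for one class, treated as the positive class."""
       pairs = list(zip(y_true, y_pred))
       tp = sum(t == label and p == label for t, p in pairs)  # true positives
       fp = sum(t != label and p == label for t, p in pairs)  # false positives
       fn = sum(t == label and p != label for t, p in pairs)  # false negatives
       precision = tp / (tp + fp) if (tp + fp) else 0.0
       recall = tp / (tp + fn) if (tp + fn) else 0.0
       total = precision + recall
       f1 = 2 * precision * recall / total if total else 0.0
       return precision, recall, f1

   # made-up true and predicted labels (hypothetical data)
   y_true = ["Adelie", "Adelie", "Adelie", "Chinstrap", "Chinstrap",
             "Gentoo", "Gentoo", "Gentoo"]
   y_pred = ["Adelie", "Adelie", "Chinstrap", "Chinstrap", "Adelie",
             "Gentoo", "Gentoo", "Gentoo"]

   # accuracy: fraction of samples whose prediction matches the true label
   accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
   chinstrap = per_class_metrics(y_true, y_pred, "Chinstrap")
   gentoo = per_class_metrics(y_true, y_pred, "Gentoo")

   print("accuracy:", accuracy)
   print("Chinstrap (precision, recall, F1):", chinstrap)
   print("Gentoo (precision, recall, F1):", gentoo)

In this toy case the overall accuracy is 0.75, yet precision, recall, and F1 for Chinstrap are all 0.5, while Gentoo scores 1.0 on every metric. This is the kind of per-class weakness that accuracy alone can hide.
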

.. code-block:: python

   # predict on test data
   y_pred_knn = knn_clf.predict(X_test_scaled)

   # evaluate model performance
   from sklearn.metrics import classification_report, accuracy_score

   score_knn = accuracy_score(y_test, y_pred_knn)

   print("Accuracy for KNN:", score_knn)
   print("\nClassification Report:\n", classification_report(y_test, y_pred_knn))
