Commit cd17f07

Update 01 exercise with solution

1 parent 7906dc7 commit cd17f07

3 files changed: +724 -119 lines changed

exercises/01_penguin_classification.ipynb

Lines changed: 214 additions & 33 deletions
@@ -39,6 +39,90 @@
 "from palmerpenguins import load_penguins"
 ]
 },
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"<div style=\"text-align: center;\">\n",
+" <img src=\"https://raw.githubusercontent.com/allisonhorst/palmerpenguins/c19a904462482430170bfe2c718775ddb7dbb885/man/figures/culmen_depth.png\" width=\"500\" />\n",
+"</div>"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"### Task 1 -- Part (b): Use seaborn to plot the distribution of the penguin species in the dataset."
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"# import seaborn as sns\n",
+"# sns.pairplot(data.drop(\"year\", axis=1), hue='species')"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {
+"vscode": {
+"languageId": "markdown"
+}
+},
+"source": [
+"### Task 1 -- Part (c): Apply umap to visualise the data"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"import umap\n",
+"import matplotlib.pyplot as plt\n",
+"import seaborn as sns\n",
+"from sklearn.preprocessing import StandardScaler\n",
+"\n",
+"# Drop rows with missing values\n",
+"data = data.dropna()\n",
+"\n",
+"# Extract features\n",
+"penguin_data = data[\n",
+"    [\n",
+"        \"bill_length_mm\",\n",
+"        \"bill_depth_mm\",\n",
+"        \"flipper_length_mm\",\n",
+"        \"body_mass_g\",\n",
+"    ]\n",
+"].values\n",
+"scaled_penguin_data = StandardScaler().fit_transform(penguin_data)\n",
+"\n",
+"# Fit and transform\n",
+"reducer = umap.UMAP(random_state=42)\n",
+"embedding = reducer.fit_transform(scaled_penguin_data)\n",
+"\n",
+"colors = sns.color_palette()\n",
+"\n",
+"for i, (species, group) in enumerate(data.groupby(\"species\")):\n",
+"    plt.scatter(\n",
+"        embedding[data.species == species, 0],\n",
+"        embedding[data.species == species, 1],\n",
+"        label=species,\n",
+"        color=colors[i],\n",
+"    )\n",
+"\n",
+"plt.gca().set_aspect(\"equal\", \"datalim\")\n",
+"plt.title(\"UMAP projection of the Penguin dataset\", fontsize=24)\n",
+"plt.xlabel(\"UMAP 1\", fontsize=18)\n",
+"plt.ylabel(\"UMAP 2\", fontsize=18)\n",
+"plt.legend(loc=\"upper right\", fontsize=10, title=\"Species\")\n",
+"plt.show()"
+]
+},
 {
 "cell_type": "markdown",
 "metadata": {},
@@ -181,7 +265,7 @@
 "source": [
 "data_set = PenguinDataset(\n",
 "    input_keys=[\"bill_length_mm\", \"body_mass_g\"],\n",
-"    target_keys=...,\n",
+"    target_key=...,\n",
 "    train=True,\n",
 ")\n",
 "\n",
@@ -203,10 +287,7 @@
 " <li>We must represent these data as <code>torch.Tensor</code>s. This is the fundamental data abstraction used by PyTorch; they are the PyTorch equivalent to Numpy arrays, while also providing support for GPU acceleration. See <a href=\"https://pytorch.org/tutorials/beginner/introyt/tensors_deeper_tutorial.html\">pytorch tensors documentation</a>.</li>\n",
 " <li>The targets are tuples of strings i.e. ('Gentoo', )\n",
 " <ul>\n",
-" <li>One idea is to represent as ordinal values i.e. [1] or [2] or [3]. But this implies that the class encoded by value 1 is closer to 2 than 1 is to 3. This is not desirable for categorical data. One-hot encoding avoids this by representing each species independently.<br>\n",
-" \"A\" — [1, 0, 0]<br>\n",
-" \"B\" — [0, 1, 0]<br>\n",
-" \"C\" — [0, 0, 1]</li>\n",
+" <li>One idea is to represent as categorical indices i.e. [1] or [2] or [3]. Will this work?</li>\n",
 " </ul>\n",
 " </li>\n",
 " </ul>\n",
@@ -219,9 +300,9 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"### Task 4 -- Part (a) and (b): Applying transforms to the data\n",
+"### Task 3 -- Part (a) and (b): Applying transforms to the data\n",
 "\n",
-"Modify the `PenguinDataset` class above so that the tuples of numbers are converted to PyTorch `torch.Tensor` s and the string targets are converted to one-hot vectors.\n",
+"Modify the `PenguinDataset` class above so that the tuples of numbers are converted to PyTorch `torch.Tensor`s and the string targets are converted to indices.\n",
 "\n",
 "- Begin by importing relevant PyTorch functions.\n",
 "- Complete `__len__()` and `__getitem__()` functions above.\n",
@@ -242,8 +323,8 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"# Apply the transforms we need to PenguinDataset class to convert input\n",
-"# data and target class to tensors. See Task 4 ``TODOs`` in PenguinDataset class.\n",
+"# Complete __len__() and __getitem__() functions\n",
+"# See Task 3 ``TODOs`` in PenguinDataset class.\n",
 "\n",
 "# Create train_set\n",
 "\n",
@@ -298,7 +379,7 @@
 "source": [
 "### Task 5: Creating ``DataLoaders``—and why\n",
 "\n",
-"Once we have created a ``Dataset`` object, we wrap it in a ``DataLoader``.\n",
+"Once we have created a ``Dataset`` object, we wrap it in a ``DataLoader``. This comes with a number of useful features:\n",
 "#### Mini-batches\n",
 "The ``DataLoader`` object allows us to put our inputs and targets in **mini-batches**, which makes for more efficient training.\n",
 "- Note: rather than supplying one input-target pair to the model at a time, we supply \"mini-batches\" of these data at once (typically a small power of 2, like 16 or 32).\n",
@@ -664,7 +745,7 @@
 " # run forward model and compute proxy probabilities over dimension 1 (columns of tensor).\n",
 "\n",
 " # compute loss\n",
-" # e.g. pred = [0.2, 0.7, 0.1] and target = [0, 1, 0]\n",
+" # e.g. pred : Tensor([3]) and target : int\n",
 "\n",
 " # compute gradients\n",
 "\n",
@@ -742,7 +823,58 @@
 "source": [
 "### Task 11: Visualise some results\n",
 "\n",
-"Let's do this part together—though feel free to make a start on your own if you have completed the previous exercises."
+"Let's do this part together—though feel free to make a start on your own if you have completed the previous exercises.\n",
+"\n",
+"<details>\n",
+"<summary>Visualising results</summary>\n",
+"\n",
+"```python\n",
+"import numpy as np\n",
+"import matplotlib.pyplot as plt\n",
+"\n",
+"quantities = [\"loss\", \"accuracy\"]\n",
+"splits = [\"train\", \"valid\"]\n",
+"\n",
+"epochs_range = np.arange(1, epochs + 1)\n",
+"\n",
+"fig, axes = plt.subplots(1, 2, figsize=(8, 4))\n",
+"\n",
+"for i, quant in enumerate(quantities):\n",
+"    ax = axes[i]\n",
+"    for split in splits:\n",
+"        values = metrics[f\"{quant}_{split}\"]\n",
+"        ax.plot(epochs_range, values, marker='o', markersize=2, label=split.capitalize())\n",
+"    ax.set_title(quant.capitalize())\n",
+"    ax.set_xlabel(\"Epoch\")\n",
+"    ax.set_ylabel(quant.capitalize())\n",
+"    ax.set_xlim(1, epochs)\n",
+"    ax.set_ylim(0.0, 1.0)\n",
+"    ax.legend()\n",
+"\n",
+"fig.tight_layout()\n",
+"plt.show()\n",
+"```\n",
+"\n",
+"</details>"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 2,
+"metadata": {},
+"outputs": [],
+"source": [
+"import matplotlib.pyplot as plt\n",
+"\n",
+"quantities = [\"loss\", \"accuracy\"]\n",
+"splits = [\"train\", \"valid\"]\n"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"### Task 12 -- Part (a): Confusion matrix"
 ]
 },
 {
@@ -751,34 +883,83 @@
 "metadata": {},
 "outputs": [],
 "source": [
+"from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay\n",
+"from torch import no_grad\n",
 "import matplotlib.pyplot as plt\n",
-"from numpy import linspace\n",
+"import numpy as np\n",
 "\n",
+"class_names = sorted(data.species.unique())\n",
 "\n",
-"quantities = [\"loss\", \"accuracy\"]\n",
-"splits = [\"train\", \"valid\"]\n",
+"all_preds = []\n",
+"all_labels = []\n",
 "\n",
-"fig, axes = plt.subplots(1, 2, figsize=(8, 4))\n",
+"model.eval()\n",
+"with no_grad():\n",
+"    for batch, label in valid_loader:\n",
+"        preds = model(batch).softmax(dim=1)\n",
+"        all_preds.append(preds.argmax(dim=1).numpy())\n",
+"        all_labels.append(label.numpy())\n",
+"\n",
+"# concatenate all predictions and labels\n",
+"all_preds = np.concatenate(all_preds)\n",
+"all_labels = np.concatenate(all_labels)\n",
+"\n",
+"cm = confusion_matrix(all_labels, all_preds, labels=[0, 1, 2])\n",
+"cm_normalized = cm.astype(\"float\") / (cm.sum(axis=1)[:, np.newaxis] + 1e-8)\n",
+"disp = ConfusionMatrixDisplay(\n",
+"    confusion_matrix=cm_normalized, display_labels=class_names\n",
+")\n",
+"\n",
+"# plotting\n",
+"fig, ax = plt.subplots(figsize=(6, 5))\n",
+"disp.plot(ax=ax, cmap=\"Blues\", colorbar=True, values_format=\".2f\")\n",
+"disp.ax_.set_xlabel(\"Predicted Label\")\n",
+"disp.ax_.set_ylabel(\"True Label\")\n",
+"plt.xticks(rotation=45)\n",
+"plt.grid(False)  # cleaner plot\n",
+"plt.title(\"Normalized Confusion Matrix\")\n",
+"plt.tight_layout()\n",
+"plt.show()"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"### Task 12 -- Part (b): Classification report"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"from sklearn.metrics import classification_report\n",
+"import pandas as pd\n",
+"\n",
+"# class_names = ['Adelie', 'Chinstrap', 'Gentoo']\n",
+"report = classification_report(\n",
+"    y_true=all_labels,\n",
+"    y_pred=all_preds,\n",
+"    target_names=class_names,\n",
+"    output_dict=True  # <- so we can plot it\n",
+")\n",
 "\n",
-"for axis, quant in zip(axes.ravel(), quantities):\n",
-"    for split in splits:\n",
-"        key = f\"{quant}_{split}\"\n",
-"        axis.plot(\n",
-"            linspace(1, epochs, epochs),\n",
-"            metrics[key],\n",
-"            \"-o\",\n",
-"            ms=1.5,\n",
-"            label=split.capitalize(),\n",
-"        )\n",
-"    axis.set_ylabel(quant.capitalize(), fontsize=15)\n",
 "\n",
-"for axis in axes.ravel():\n",
-"    axis.legend(fontsize=15)\n",
-"    axis.set_ylim(bottom=0.0, top=1.0)\n",
-"    axis.set_xlim(left=1, right=epochs)\n",
-"    axis.set_xlabel(\"Epoch\", fontsize=15)\n",
+"# Convert the report dict to DataFrame for plotting\n",
+"report_df = pd.DataFrame(report).transpose()\n",
+"report_df = report_df.loc[class_names, ['precision', 'recall', 'f1-score']]\n",
 "\n",
-"fig.tight_layout()"
+"# Plot\n",
+"report_df.plot(kind='bar', figsize=(8, 5))\n",
+"plt.title(\"Per-Class Precision, Recall, and F1 Score\")\n",
+"plt.ylim(0.0, 1.05)\n",
+"plt.ylabel(\"Score\")\n",
+"plt.xticks(rotation=0)\n",
+"plt.grid(axis='y')\n",
+"plt.tight_layout()\n",
+"plt.show()"
 ]
 },
 {

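One detail of the confusion-matrix cell worth spelling out: the matrix is row-normalised, so each entry becomes the fraction of a true class assigned to each predicted class, and the `1e-8` guards against division by zero for a class with no samples. A small worked example with made-up counts:

```python
import numpy as np

# Hypothetical counts: rows are true classes, columns are predictions.
cm = np.array([[10, 2, 0],
               [1, 8, 1],
               [0, 0, 12]])

# Divide each row by its total, as in the cell above.
cm_normalized = cm.astype("float") / (cm.sum(axis=1)[:, np.newaxis] + 1e-8)
print(np.round(cm_normalized[0], 3))  # [0.833 0.167 0.   ] for true class 0
```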
pyproject.toml

Lines changed: 2 additions & 0 deletions

@@ -30,6 +30,8 @@ dependencies = [
     "torch_tools @ git+https://github.com/jdenholm/TorchTools.git",
     "matplotlib",
     "numpy<2.0.0",
+    "umap-learn",
+    "seaborn"
 ]

 [project.optional-dependencies]
