|
167 | 167 | "metadata": {},
|
168 | 168 | "outputs": [],
|
169 | 169 | "source": [
|
170 |
| - "from typing import List, Tuple, Dict\n", |
| 170 | + "from typing import List, Tuple, Any, Dict\n", |
171 | 171 | "\n",
|
172 | 172 | "# import some useful functions here, see https://pytorch.org/docs/stable/torch.html\n",
|
173 | 173 | "# where `tensor` is used for constructing tensors,\n",
|
174 | 174 | "# and using a lower-precision float32 is advised for performance\n",
|
175 |
| - "\n", |
176 | 175 | "# Task 4: add imports here\n",
|
177 |
| - "# from torch import tensor, Tensor, float32\n", |
| 176 | + "# from torch import tensor, float32\n", |
178 | 177 | "\n",
|
179 | 178 | "from torch.utils.data import Dataset\n",
|
180 | 179 | "\n",
|
181 | 180 | "from palmerpenguins import load_penguins\n",
|
182 | 181 | "\n",
|
183 | 182 | "\n",
|
184 | 183 | "class PenguinDataset(Dataset):\n",
|
185 |
| - " \"\"\"Simplified Penguin dataset for classification tasks.\n", |
| 184 | + " \"\"\"Penguin dataset class.\n", |
186 | 185 | "\n",
|
187 | 186 | " Parameters\n",
|
188 | 187 | " ----------\n",
|
189 | 188 | " input_keys : List[str]\n",
|
190 |
| - " Column names to use as input features.\n", |
| 189 | + " The column titles to use in the input feature vectors.\n", |
191 | 190 | " target_key : str\n",
|
192 |
| - " Categorical target column (e.g., \"species\").\n", |
| 191 | + " The column titles to use in the target feature vectors.\n", |
193 | 192 | " train : bool\n",
|
194 | 193 | " If ``True``, this object will serve as the training set, and if\n",
|
195 | 194 | " ``False``, the validation set.\n",
|
|
198 | 197 | " -----\n",
|
199 | 198 | " The validation split contains 10 male and 10 female penguins of each\n",
|
200 | 199 | " species.\n",
|
| 200 | + "\n", |
201 | 201 | " \"\"\"\n",
|
202 | 202 | "\n",
|
203 | 203 | " def __init__(\n",
|
|
206 | 206 | " target_key: str,\n",
|
207 | 207 | " train: bool,\n",
|
208 | 208 | " ):\n",
|
209 |
| - " \"\"\"Build `PenguinDataset` for classification.\"\"\"\n", |
| 209 | + " \"\"\"Build ``PenguinDataset``.\"\"\"\n", |
210 | 210 | " self.input_keys = input_keys\n",
|
211 | 211 | " self.target_key = target_key\n",
|
212 | 212 | "\n",
|
213 |
| - " # Load and clean full dataset\n", |
214 | 213 | " data = load_penguins()\n",
|
215 |
| - " data = data.dropna().sort_values(by=sorted(data.columns)).reset_index(drop=True)\n", |
216 |
| - " data[\"sex\"] = (data[\"sex\"] == \"male\").astype(float)\n", |
| 214 | + " data = (\n", |
| 215 | + " data.loc[~data.isna().any(axis=1)]\n", |
| 216 | + " .sort_values(by=sorted(data.keys()))\n", |
| 217 | + " .reset_index(drop=True)\n", |
| 218 | + " )\n", |
| 219 | + " # Transform the sex field into a float, with male represented by 1.0, female by 0.0\n", |
| 220 | + " data.sex = (data.sex == \"male\").astype(float)\n", |
217 | 221 | " self.full_df = data\n",
|
218 | 222 | "\n",
|
219 |
| - " # Create stratified validation split\n", |
220 |
| - " valid_df = data.groupby([\"species\", \"sex\"]).sample(n=10, random_state=123)\n", |
221 |
| - " train_df = data[~data.index.isin(valid_df.index)]\n", |
| 223 | + " valid_df = self.full_df.groupby(by=[\"species\", \"sex\"]).sample(\n", |
| 224 | + " n=10,\n", |
| 225 | + " random_state=123,\n", |
| 226 | + " )\n", |
| 227 | + " # The training items are simply the items *not* in the valid split\n", |
| 228 | + " train_df = self.full_df.loc[~self.full_df.index.isin(valid_df.index)]\n", |
222 | 229 | "\n",
|
223 |
| - " # Choose split\n", |
224 |
| - " split_df = train_df if train else valid_df\n", |
| 230 | + " self.split = {\"train\": train_df, \"valid\": valid_df}[\n", |
| 231 | + " \"train\" if train is True else \"valid\"\n", |
| 232 | + " ]\n", |
225 | 233 | "\n",
|
226 | 234 | " # Build label map from the full dataset\n",
|
227 |
| - " unique_labels = sorted(self.full_df[target_key].unique())\n", |
228 |
| - " self.label_map = {label: idx for idx, label in enumerate(unique_labels)}\n", |
| 235 | + " unique_labels = sorted(self.full_df[self.target_key].unique())\n", |
| 236 | + " self.label_map: Dict[str, int] = {\n", |
| 237 | + " label: idx for idx, label in enumerate(unique_labels)\n", |
| 238 | + " }\n", |
229 | 239 | "\n",
|
230 |
| - " # Precompute tensors from split only\n", |
231 |
| - " self.features = tensor(split_df[input_keys].values, dtype=float32)\n", |
232 |
| - " self.targets = tensor(\n", |
233 |
| - " split_df[target_key].map(self.label_map).values, dtype=long\n", |
234 |
| - " )\n", |
| 240 | + " def __len__(self) -> int:\n", |
| 241 | + " \"\"\"Return the length of requested split.\n", |
235 | 242 | "\n",
|
236 |
| - " def get_label_map(self) -> Dict:\n", |
237 |
| - " \"\"\"Return the label-to-index mapping.\"\"\"\n", |
238 |
| - " return self.label_map\n", |
| 243 | + " Returns\n", |
| 244 | + " -------\n", |
| 245 | + " int\n", |
| 246 | + " The number of items in the dataset.\n", |
239 | 247 | "\n",
|
240 |
| - " def __len__(self) -> int:\n", |
241 |
| - " # Task 4 - Exercise #1: Return length of features\n", |
242 |
| - " return ...\n", |
| 248 | + " \"\"\"\n", |
| 249 | + " return len(self.split)\n", |
| 250 | + "\n", |
| 251 | + " def __getitem__(self, idx: int) -> Tuple[Any, Any]:\n", |
| 252 | + " \"\"\"Return an input-target pair.\n", |
243 | 253 | "\n",
|
244 |
| - " def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor]:\n", |
245 |
| - " # Task 4 - Exercise #2: Return example\n", |
246 |
| - " return ..." |
| 254 | + " Parameters\n", |
| 255 | + " ----------\n", |
| 256 | + " idx : int\n", |
| 257 | + " Index of the input-target pair to return.\n", |
| 258 | + "\n", |
| 259 | + " Returns\n", |
| 260 | + " -------\n", |
| 261 | + " in_feats : Any\n", |
| 262 | + " Inputs.\n", |
| 263 | + " target : Any\n", |
| 264 | + " Targets.\n", |
| 265 | + "\n", |
| 266 | + " \"\"\"\n", |
| 267 | + " # get the row index (idx) from the dataframe and\n", |
| 268 | + " # select relevant column features (provided as input_keys)\n", |
| 269 | + " feats = tuple(self.split.iloc[idx][self.input_keys])\n", |
| 270 | + "\n", |
| 271 | + " # this gives a 'species' i.e. one of ('Gentoo',), ('Chinstrap',), or ('Adelie',)\n", |
| 272 | + " tgt = self.split.iloc[idx][self.target_key]\n", |
| 273 | + "\n", |
| 274 | + " # Task 4 -- Part (a): Convert the tuple features to PyTorch Tensors\n", |
| 275 | + "\n", |
| 276 | + " # Task 4 -- Part (b): Convert the target (a Python integer) to a 0-D tensor (scalar tensor).\n", |
| 277 | + "\n", |
| 278 | + "\n", |
| 279 | + " return feats, tgt" |
247 | 280 | ]
|
248 | 281 | },
|
249 | 282 | {
|
|
300 | 333 | "cell_type": "markdown",
|
301 | 334 | "metadata": {},
|
302 | 335 | "source": [
|
303 |
| - "### Task 3 -- Part (a) and (b): Applying transforms to the data\n", |
| 336 | + "### Task 4 -- Part (a) and (b): Convert Dataset outputs to PyTorch Tensors\n", |
304 | 337 | "\n",
|
305 | 338 | "Modify the `PenguinDataset` class above so that the tuples of numbers are converted to PyTorch `torch.Tensor` s and the string targets are converted to indices.\n",
|
306 | 339 | "\n",
|
307 | 340 | "- Begin by importing relevant PyTorch functions.\n",
|
308 |
| - "- Complete `__len__()` and `__getitem__()` functions above.\n", |
| 341 | + "- Complete the `__getitem__()` function above.\n", |
309 | 342 | "\n",
|
310 | 343 | "Then create a training and validation set.\n",
|
311 | 344 | "\n",
|
|
314 | 347 | " \n",
|
315 | 348 | "For the validation set, we choose ten males and ten females of each species. This means the validation set is less likely to be biased by sex and species, and is potentially a more reliable measure of performance. You should always be _very_ careful when choosing metrics and splitting data.\n",
|
316 | 349 | "\n",
|
317 |
| - "- Is this transformation approach general? No, but it's a good start. " |
| 350 | + "- Is this transformation approach general? No, but it's a good start. \n", |
| 351 | + " - Switch between validation/train time transformations?" |
318 | 352 | ]
|
319 | 353 | },
|
320 | 354 | {
|
|
364 | 398 | "source": [
|
365 | 399 | "from torchvision.transforms import Compose\n",
|
366 | 400 | "\n",
|
| 401 | + "# from ml_workshop import PenguinDataset\n", |
| 402 | + "\n", |
367 | 403 | "# import some useful functions here, see https://pytorch.org/docs/stable/torch.html\n",
|
368 | 404 | "# where `tensor` is used for constructing tensors,\n",
|
369 | 405 | "# and using a lower-precision float32 is advised for performance\n",
|
|
0 commit comments