
Commit 0ac544f

Created using Colaboratory
1 parent 7cc81f0 commit 0ac544f


1 file changed, +9 -33 lines changed


colabs/Training_an_Image_Classification_Model_in_PyTorch.ipynb

@@ -68,20 +68,6 @@
 "execution_count": null,
 "outputs": []
 },
-{
-"cell_type": "code",
-"metadata": {
-"id": "SOkA83IsRWYo"
-},
-"source": [
-"# IMPORTANT - Please restart your Colab runtime after installing Hub!\n",
-"# This is a Colab-specific issue that prevents PIL from working properly.\n",
-"import os\n",
-"os.kill(os.getpid(), 9)"
-],
-"execution_count": null,
-"outputs": []
-},
 {
 "cell_type": "markdown",
 "metadata": {
@@ -136,7 +122,7 @@
 "id": "jPSz9kml03Aa"
 },
 "source": [
-"print(ds_train.labels.info.class_names[str(ds_train.labels[0].numpy()[0])])"
+"print(ds_train.labels.info.class_names[ds_train.labels[0].numpy()[0]])"
 ],
 "execution_count": null,
 "outputs": []
@@ -147,7 +133,7 @@
 "id": "Np5fIbViHlCu"
 },
 "source": [
-"The next step is to define a transformation function that will process the data and convert it into a format that can be passed into a deep learning model. The syntax for the transformation function is that the input parameter is a sample from a Hub dataset in dictionary syntax, and the return value is a dictionary containing the data that the training loop uses to train the model. In this particular example, `torchvision.transforms` is used as a part of the transformation pipeline that performs operations such as normalization and image augmentation (rotation)."
+"The next step is to define a transformation function that will process the data and convert it into a format that can be passed into a deep learning model. In this particular example, `torchvision.transforms` is used as a part of the transformation pipeline that performs operations such as normalization and image augmentation (rotation)."
 ]
 },
 {
@@ -156,9 +142,6 @@
 "id": "WqdWgumwQ1d6"
 },
 "source": [
-"def transform(sample_in):\n",
-" return {'images': tform(sample_in['images']), 'labels': sample_in['labels']}\n",
-"\n",
 "tform = transforms.Compose([\n",
 " transforms.ToPILImage(), # Must convert to PIL image for subsequent operations to run\n",
 " transforms.RandomRotation(20), # Image augmentation\n",
@@ -169,22 +152,15 @@
 "execution_count": null,
 "outputs": []
 },
-{
-"cell_type": "markdown",
-"metadata": {
-"id": "ToNQ3WwfIJZf"
-},
-"source": [
-"**Note:** Don't worry if the above syntax is a bit confusing 😵! We're currently improving it."
-]
-},
 {
 "cell_type": "markdown",
 "metadata": {
 "id": "DGmWr44PIQMk"
 },
 "source": [
-"You are now ready to create a pytorch dataloader that connects the Hub dataset to the PyTorch model. This can be done using the provided method `ds.pytorch()` , which automatically applies the user-defined transformation function, takes care of random shuffling (if desired), and converts hub data to PyTorch tensors. The `num_workers` parameter can be used to parallelize data preprocessing, which is critical for ensuring that preprocessing does not bottleneck the overall training workflow."
+"You can now create a pytorch dataloader that connects the Hub dataset to the PyTorch model using the provided method `ds.pytorch()`. This method automatically applies the transformation function, takes care of random shuffling (if desired), and converts hub data to PyTorch tensors. The `num_workers` parameter can be used to parallelize data preprocessing, which is critical for ensuring that preprocessing does not bottleneck the overall training workflow.\n",
+"\n",
+"The `transform` input is a dictionary where the `key` is the tensor name and the `value` is the transformation function that should be applied to that tensor. If a specific tensor's data does not need to be returned, it should be omitted from the keys. If a tensor's data does not need to be modified during preprocessing, the transformation function is set as `None`."
 ]
 },
 {
@@ -195,8 +171,8 @@
 "source": [
 "batch_size = 32\n",
 "\n",
-"train_loader = ds_train.pytorch(num_workers = 2, shuffle = True, transform = transform, batch_size = batch_size)\n",
-"test_loader = ds_test.pytorch(num_workers = 2, transform = transform, batch_size = batch_size)"
+"train_loader = ds_train.pytorch(num_workers = 0, shuffle = True, transform = {'images': tform, 'labels': None}, batch_size = batch_size)\n",
+"test_loader = ds_test.pytorch(num_workers = 0, transform = {'images': tform, 'labels': None}, batch_size = batch_size)"
 ],
 "execution_count": null,
 "outputs": []
@@ -349,7 +325,7 @@
 " _, predicted = torch.max(outputs.data, 1)\n",
 " total += labels.size(0)\n",
 " correct += (predicted == labels).sum().item()\n",
-" accuracy = 100 * correct / total\n",
+" accuracy = 100 * correct / total\n",
 " \n",
 " print('Finished Testing')\n",
 " print('Testing accuracy: %.1f %%' %(accuracy))"
@@ -387,4 +363,4 @@
 ]
 }
 ]
-}
+}
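
The dictionary form of `transform` described in the added markdown cell can be illustrated with a small stand-alone sketch. It does not call Hub: `apply_transform`, the sample keys, and the sample values are hypothetical names chosen for illustration, and Hub's actual implementation may differ. The sketch only mirrors the three rules stated in the cell: each key names a tensor, a value of `None` passes the tensor through unchanged, and tensors missing from the dictionary are not returned.

```python
from typing import Any, Callable, Dict, Optional

# Hypothetical helper (not part of Hub): applies a per-tensor transform
# dictionary to one sample, following the rules described in the notebook text.
def apply_transform(
    sample: Dict[str, Any],
    transform: Dict[str, Optional[Callable[[Any], Any]]],
) -> Dict[str, Any]:
    out = {}
    for name, fn in transform.items():
        value = sample[name]
        # None means "return this tensor unmodified".
        out[name] = value if fn is None else fn(value)
    # Tensors not listed in `transform` are omitted from the result.
    return out

# Made-up sample with three tensors; "boxes" is deliberately left out below.
sample = {"images": [1, 2, 3], "labels": 7, "boxes": [0, 0, 1, 1]}
transform = {"images": lambda img: [2 * p for p in img], "labels": None}

result = apply_transform(sample, transform)
print(result)
```

In the notebook itself, the equivalent dictionary is `{'images': tform, 'labels': None}`, which replaces the removed `transform(sample_in)` wrapper function.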
