Skip to content

Commit de99fb7

Browse files
committed
Fix code in training
1 parent 556b1b5 commit de99fb7

File tree

1 file changed

+76
-33
lines changed

1 file changed

+76
-33
lines changed

Image-Classification-in-PyTorch/image_classification_using_transfer_learning_in_pytorch.ipynb

+76-33
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"cell_type": "code",
55
"execution_count": null,
66
"metadata": {
7-
"collapsed": true
7+
"id": "SrL47MttOLS3"
88
},
99
"outputs": [],
1010
"source": [
@@ -29,7 +29,7 @@
2929
"cell_type": "code",
3030
"execution_count": null,
3131
"metadata": {
32-
"collapsed": true
32+
"id": "V3Jwg9JJOLTE"
3333
},
3434
"outputs": [],
3535
"source": [
@@ -65,7 +65,11 @@
6565
"cell_type": "code",
6666
"execution_count": null,
6767
"metadata": {
68-
"collapsed": true
68+
"colab": {
69+
"base_uri": "https://localhost:8080/"
70+
},
71+
"id": "LqB1PotzOLTF",
72+
"outputId": "2fef74c0-eba7-4dc5-a35a-b238432ee2c3"
6973
},
7074
"outputs": [],
7175
"source": [
@@ -83,7 +87,7 @@
8387
"bs = 32\n",
8488
"\n",
8589
"# Number of classes\n",
86-
"num_classes = len(os.listdir(valid_directory))-1 #10#2#257\n",
90+
"num_classes = len(os.listdir(valid_directory)) #10#2#257\n",
8791
"print(num_classes)\n",
8892
"\n",
8993
"# Load Data from folders\n",
@@ -112,31 +116,37 @@
112116
"cell_type": "code",
113117
"execution_count": null,
114118
"metadata": {
115-
"collapsed": true
119+
"colab": {
120+
"base_uri": "https://localhost:8080/"
121+
},
122+
"id": "2mUcSMIMOLTG",
123+
"outputId": "88cc5e5d-2c37-41e6-ef9d-72ca9237efd8"
116124
},
117125
"outputs": [],
118126
"source": [
119-
"train_data_size, valid_data_size, test_data_size"
127+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
128+
"\n",
129+
"print(train_data_size, valid_data_size, test_data_size)"
120130
]
121131
},
122132
{
123133
"cell_type": "code",
124134
"execution_count": null,
125135
"metadata": {
126-
"collapsed": true
136+
"id": "ViY8viSLOLTH"
127137
},
128138
"outputs": [],
129139
"source": [
130140
"# Load pretrained ResNet50 Model\n",
131141
"resnet50 = models.resnet50(pretrained=True)\n",
132-
"resnet50 = resnet50.to('cuda:0')\n"
142+
"resnet50 = resnet50.to(device)"
133143
]
134144
},
135145
{
136146
"cell_type": "code",
137147
"execution_count": null,
138148
"metadata": {
139-
"collapsed": true
149+
"id": "3eVi7mJZOLTH"
140150
},
141151
"outputs": [],
142152
"source": [
@@ -149,7 +159,7 @@
149159
"cell_type": "code",
150160
"execution_count": null,
151161
"metadata": {
152-
"collapsed": true
162+
"id": "EqMKMqdWOLTH"
153163
},
154164
"outputs": [],
155165
"source": [
@@ -164,15 +174,17 @@
164174
" nn.LogSoftmax(dim=1) # For using NLLLoss()\n",
165175
")\n",
166176
"\n",
177+
"\n",
178+
"\n",
167179
"# Convert model to be used on GPU\n",
168-
"resnet50 = resnet50.to('cuda:0')\n"
180+
"resnet50 = resnet50.to(device)\n"
169181
]
170182
},
171183
{
172184
"cell_type": "code",
173185
"execution_count": null,
174186
"metadata": {
175-
"collapsed": true
187+
"id": "3Lske-iPOLTI"
176188
},
177189
"outputs": [],
178190
"source": [
@@ -185,7 +197,7 @@
185197
"cell_type": "code",
186198
"execution_count": null,
187199
"metadata": {
188-
"collapsed": true
200+
"id": "dyG4bCHHOLTI"
189201
},
190202
"outputs": [],
191203
"source": [
@@ -205,7 +217,8 @@
205217
" \n",
206218
" start = time.time()\n",
207219
" history = []\n",
208-
" best_acc = 0.0\n",
220+
" best_loss = 100000.0\n",
221+
" best_epoch = None\n",
209222
"\n",
210223
" for epoch in range(epochs):\n",
211224
" epoch_start = time.time()\n",
@@ -256,7 +269,7 @@
256269
" \n",
257270
" #print(\"Batch number: {:03d}, Training: Loss: {:.4f}, Accuracy: {:.4f}\".format(i, loss.item(), acc.item()))\n",
258271
"\n",
259-
" \n",
272+
" \n",
260273
" # Validation - No gradient tracking needed\n",
261274
" with torch.no_grad():\n",
262275
"\n",
@@ -288,7 +301,10 @@
288301
" valid_acc += acc.item() * inputs.size(0)\n",
289302
"\n",
290303
" #print(\"Validation Batch number: {:03d}, Validation: Loss: {:.4f}, Accuracy: {:.4f}\".format(j, loss.item(), acc.item()))\n",
291-
" \n",
304+
" if valid_loss < best_loss:\n",
305+
" best_loss = valid_loss\n",
306+
" best_epoch = epoch\n",
307+
"\n",
292308
" # Find average training loss and training accuracy\n",
293309
" avg_train_loss = train_loss/train_data_size \n",
294310
" avg_train_acc = train_acc/train_data_size\n",
@@ -301,31 +317,33 @@
301317
" \n",
302318
" epoch_end = time.time()\n",
303319
" \n",
304-
" print(\"Epoch : {:03d}, Training: Loss: {:.4f}, Accuracy: {:.4f}%, \\n\\t\\tValidation : Loss : {:.4f}, Accuracy: {:.4f}%, Time: {:.4f}s\".format(epoch, avg_train_loss, avg_train_acc*100, avg_valid_loss, avg_valid_acc*100, epoch_end-epoch_start))\n",
320+
" print(\"Epoch : {:03d}, Training: Loss - {:.4f}, Accuracy - {:.4f}%, \\n\\t\\tValidation : Loss - {:.4f}, Accuracy - {:.4f}%, Time: {:.4f}s\".format(epoch, avg_train_loss, avg_train_acc*100, avg_valid_loss, avg_valid_acc*100, epoch_end-epoch_start))\n",
305321
" \n",
306322
" # Save if the model has best accuracy till now\n",
307323
" torch.save(model, dataset+'_model_'+str(epoch)+'.pt')\n",
308324
" \n",
309-
" return model, history\n",
325+
" return model, history, best_epoch\n",
310326
" "
311327
]
312328
},
313329
{
314330
"cell_type": "code",
315331
"execution_count": null,
316332
"metadata": {
317-
"collapsed": true
333+
"colab": {
334+
"base_uri": "https://localhost:8080/"
335+
},
336+
"id": "GIVZTrxAOLTN",
337+
"outputId": "ebdcc4b0-a98e-4488-dddb-bd948e452cc0"
318338
},
319339
"outputs": [],
320340
"source": [
321-
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
322-
"\n",
323341
"# Print the model to be trained\n",
324342
"#summary(resnet50, input_size=(3, 224, 224), batch_size=bs, device='cuda')\n",
325343
"\n",
326344
"# Train the model for 25 epochs\n",
327345
"num_epochs = 30\n",
328-
"trained_model, history = train_and_validate(resnet50, loss_func, optimizer, num_epochs)\n",
346+
"trained_model, history, best_epoch = train_and_validate(resnet50, loss_func, optimizer, num_epochs)\n",
329347
"\n",
330348
"torch.save(history, dataset+'_history.pt')"
331349
]
@@ -334,7 +352,12 @@
334352
"cell_type": "code",
335353
"execution_count": null,
336354
"metadata": {
337-
"collapsed": true
355+
"colab": {
356+
"base_uri": "https://localhost:8080/",
357+
"height": 283
358+
},
359+
"id": "pGcRnfbiOLTS",
360+
"outputId": "1326a7fe-4daf-45ea-9f69-7219e23d277b"
338361
},
339362
"outputs": [],
340363
"source": [
@@ -352,7 +375,12 @@
352375
"cell_type": "code",
353376
"execution_count": null,
354377
"metadata": {
355-
"collapsed": true
378+
"colab": {
379+
"base_uri": "https://localhost:8080/",
380+
"height": 283
381+
},
382+
"id": "1e8oWgPWOLTS",
383+
"outputId": "c74fcaa4-1b41-436f-de02-8be4ee5c1a11"
356384
},
357385
"outputs": [],
358386
"source": [
@@ -369,7 +397,7 @@
369397
"cell_type": "code",
370398
"execution_count": null,
371399
"metadata": {
372-
"collapsed": true
400+
"id": "FGHpMyvHOLTS"
373401
},
374402
"outputs": [],
375403
"source": [
@@ -429,7 +457,7 @@
429457
"cell_type": "code",
430458
"execution_count": null,
431459
"metadata": {
432-
"collapsed": true
460+
"id": "QBHeFS7QOLTT"
433461
},
434462
"outputs": [],
435463
"source": [
@@ -444,11 +472,11 @@
444472
" \n",
445473
" transform = image_transforms['test']\n",
446474
"\n",
475+
"\n",
447476
" test_image = Image.open(test_image_name)\n",
448477
" plt.imshow(test_image)\n",
449478
" \n",
450479
" test_image_tensor = transform(test_image)\n",
451-
"\n",
452480
" if torch.cuda.is_available():\n",
453481
" test_image_tensor = test_image_tensor.view(1, 3, 224, 224).cuda()\n",
454482
" else:\n",
@@ -459,7 +487,11 @@
459487
" # Model outputs log probabilities\n",
460488
" out = model(test_image_tensor)\n",
461489
" ps = torch.exp(out)\n",
490+
"\n",
462491
" topk, topclass = ps.topk(3, dim=1)\n",
492+
" cls = idx_to_class[topclass.cpu().numpy()[0][0]]\n",
493+
" score = topk.cpu().numpy()[0][0]\n",
494+
"\n",
463495
" for i in range(3):\n",
464496
" print(\"Predcition\", i+1, \":\", idx_to_class[topclass.cpu().numpy()[0][i]], \", Score: \", topk.cpu().numpy()[0][i])\n",
465497
"\n",
@@ -470,15 +502,20 @@
470502
"cell_type": "code",
471503
"execution_count": null,
472504
"metadata": {
473-
"collapsed": true
505+
"colab": {
506+
"base_uri": "https://localhost:8080/",
507+
"height": 507
508+
},
509+
"id": "vK8eHndnOLTU",
510+
"outputId": "581a2e73-b90f-446e-fc73-d4820adc4b68"
474511
},
475512
"outputs": [],
476513
"source": [
477514
"# Test a particular model on a test image\n",
478-
"\n",
515+
"! wget https://cdn.pixabay.com/photo/2018/10/01/12/28/skunk-3716043_1280.jpg -O skunk.jpg\n",
479516
"dataset = 'caltech_10'\n",
480-
"model = torch.load('caltech_10_model_8.pt')\n",
481-
"predict(model, 'pixabay-test-animals/triceratops-954293_640.jpg')\n",
517+
"model = torch.load(\"{}_model_{}.pt\".format(dataset, best_epoch))\n",
518+
"predict(model, 'skunk.jpg')\n",
482519
"\n",
483520
"# Load Data from folders\n",
484521
"#computeTestSetAccuracy(model, loss_func)\n",
@@ -488,6 +525,12 @@
488525
}
489526
],
490527
"metadata": {
528+
"accelerator": "GPU",
529+
"colab": {
530+
"collapsed_sections": [],
531+
"name": "image_classification_using_transfer_learning_in_pytorch.ipynb",
532+
"provenance": []
533+
},
491534
"kernelspec": {
492535
"display_name": "Python 3",
493536
"language": "python",
@@ -503,9 +546,9 @@
503546
"name": "python",
504547
"nbconvert_exporter": "python",
505548
"pygments_lexer": "ipython3",
506-
"version": "3.6.7"
549+
"version": "3.6.9"
507550
}
508551
},
509552
"nbformat": 4,
510-
"nbformat_minor": 2
553+
"nbformat_minor": 1
511554
}

0 commit comments

Comments
 (0)