|
4 | 4 | "cell_type": "code", |
5 | 5 | "execution_count": null, |
6 | 6 | "metadata": { |
7 | | - "collapsed": true |
| 7 | + "id": "SrL47MttOLS3" |
8 | 8 | }, |
9 | 9 | "outputs": [], |
10 | 10 | "source": [ |
|
29 | 29 | "cell_type": "code", |
30 | 30 | "execution_count": null, |
31 | 31 | "metadata": { |
32 | | - "collapsed": true |
| 32 | + "id": "V3Jwg9JJOLTE" |
33 | 33 | }, |
34 | 34 | "outputs": [], |
35 | 35 | "source": [ |
|
65 | 65 | "cell_type": "code", |
66 | 66 | "execution_count": null, |
67 | 67 | "metadata": { |
68 | | - "collapsed": true |
| 68 | + "colab": { |
| 69 | + "base_uri": "https://localhost:8080/" |
| 70 | + }, |
| 71 | + "id": "LqB1PotzOLTF", |
| 72 | + "outputId": "2fef74c0-eba7-4dc5-a35a-b238432ee2c3" |
69 | 73 | }, |
70 | 74 | "outputs": [], |
71 | 75 | "source": [ |
|
83 | 87 | "bs = 32\n", |
84 | 88 | "\n", |
85 | 89 | "# Number of classes\n", |
86 | | - "num_classes = len(os.listdir(valid_directory))-1 #10#2#257\n", |
| 90 | + "num_classes = len(os.listdir(valid_directory)) #10#2#257\n", |
87 | 91 | "print(num_classes)\n", |
88 | 92 | "\n", |
89 | 93 | "# Load Data from folders\n", |
|
112 | 116 | "cell_type": "code", |
113 | 117 | "execution_count": null, |
114 | 118 | "metadata": { |
115 | | - "collapsed": true |
| 119 | + "colab": { |
| 120 | + "base_uri": "https://localhost:8080/" |
| 121 | + }, |
| 122 | + "id": "2mUcSMIMOLTG", |
| 123 | + "outputId": "88cc5e5d-2c37-41e6-ef9d-72ca9237efd8" |
116 | 124 | }, |
117 | 125 | "outputs": [], |
118 | 126 | "source": [ |
119 | | - "train_data_size, valid_data_size, test_data_size" |
| 127 | + "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", |
| 128 | + "\n", |
| 129 | + "print(train_data_size, valid_data_size, test_data_size)" |
120 | 130 | ] |
121 | 131 | }, |
122 | 132 | { |
123 | 133 | "cell_type": "code", |
124 | 134 | "execution_count": null, |
125 | 135 | "metadata": { |
126 | | - "collapsed": true |
| 136 | + "id": "ViY8viSLOLTH" |
127 | 137 | }, |
128 | 138 | "outputs": [], |
129 | 139 | "source": [ |
130 | 140 | "# Load pretrained ResNet50 Model\n", |
131 | 141 | "resnet50 = models.resnet50(pretrained=True)\n", |
132 | | - "resnet50 = resnet50.to('cuda:0')\n" |
| 142 | + "resnet50 = resnet50.to(device)" |
133 | 143 | ] |
134 | 144 | }, |
135 | 145 | { |
136 | 146 | "cell_type": "code", |
137 | 147 | "execution_count": null, |
138 | 148 | "metadata": { |
139 | | - "collapsed": true |
| 149 | + "id": "3eVi7mJZOLTH" |
140 | 150 | }, |
141 | 151 | "outputs": [], |
142 | 152 | "source": [ |
|
149 | 159 | "cell_type": "code", |
150 | 160 | "execution_count": null, |
151 | 161 | "metadata": { |
152 | | - "collapsed": true |
| 162 | + "id": "EqMKMqdWOLTH" |
153 | 163 | }, |
154 | 164 | "outputs": [], |
155 | 165 | "source": [ |
|
164 | 174 | " nn.LogSoftmax(dim=1) # For using NLLLoss()\n", |
165 | 175 | ")\n", |
166 | 176 | "\n", |
| 177 | + "\n", |
| 178 | + "\n", |
167 | 179 | "# Convert model to be used on GPU\n", |
168 | | - "resnet50 = resnet50.to('cuda:0')\n" |
| 180 | + "resnet50 = resnet50.to(device)\n" |
169 | 181 | ] |
170 | 182 | }, |
171 | 183 | { |
172 | 184 | "cell_type": "code", |
173 | 185 | "execution_count": null, |
174 | 186 | "metadata": { |
175 | | - "collapsed": true |
| 187 | + "id": "3Lske-iPOLTI" |
176 | 188 | }, |
177 | 189 | "outputs": [], |
178 | 190 | "source": [ |
|
185 | 197 | "cell_type": "code", |
186 | 198 | "execution_count": null, |
187 | 199 | "metadata": { |
188 | | - "collapsed": true |
| 200 | + "id": "dyG4bCHHOLTI" |
189 | 201 | }, |
190 | 202 | "outputs": [], |
191 | 203 | "source": [ |
|
205 | 217 | " \n", |
206 | 218 | " start = time.time()\n", |
207 | 219 | " history = []\n", |
208 | | - " best_acc = 0.0\n", |
| 220 | + " best_loss = 100000.0\n", |
| 221 | + " best_epoch = None\n", |
209 | 222 | "\n", |
210 | 223 | " for epoch in range(epochs):\n", |
211 | 224 | " epoch_start = time.time()\n", |
|
256 | 269 | " \n", |
257 | 270 | " #print(\"Batch number: {:03d}, Training: Loss: {:.4f}, Accuracy: {:.4f}\".format(i, loss.item(), acc.item()))\n", |
258 | 271 | "\n", |
259 | | - " \n", |
| 272 | + " \n", |
260 | 273 | " # Validation - No gradient tracking needed\n", |
261 | 274 | " with torch.no_grad():\n", |
262 | 275 | "\n", |
|
288 | 301 | " valid_acc += acc.item() * inputs.size(0)\n", |
289 | 302 | "\n", |
290 | 303 | " #print(\"Validation Batch number: {:03d}, Validation: Loss: {:.4f}, Accuracy: {:.4f}\".format(j, loss.item(), acc.item()))\n", |
291 | | - " \n", |
| 304 | + " if valid_loss < best_loss:\n", |
| 305 | + " best_loss = valid_loss\n", |
| 306 | + " best_epoch = epoch\n", |
| 307 | + "\n", |
292 | 308 | " # Find average training loss and training accuracy\n", |
293 | 309 | " avg_train_loss = train_loss/train_data_size \n", |
294 | 310 | " avg_train_acc = train_acc/train_data_size\n", |
|
301 | 317 | " \n", |
302 | 318 | " epoch_end = time.time()\n", |
303 | 319 | " \n", |
304 | | - " print(\"Epoch : {:03d}, Training: Loss: {:.4f}, Accuracy: {:.4f}%, \\n\\t\\tValidation : Loss : {:.4f}, Accuracy: {:.4f}%, Time: {:.4f}s\".format(epoch, avg_train_loss, avg_train_acc*100, avg_valid_loss, avg_valid_acc*100, epoch_end-epoch_start))\n", |
| 320 | + " print(\"Epoch : {:03d}, Training: Loss - {:.4f}, Accuracy - {:.4f}%, \\n\\t\\tValidation : Loss - {:.4f}, Accuracy - {:.4f}%, Time: {:.4f}s\".format(epoch, avg_train_loss, avg_train_acc*100, avg_valid_loss, avg_valid_acc*100, epoch_end-epoch_start))\n", |
305 | 321 | " \n", |
306 | 322 | " # Save if the model has best accuracy till now\n", |
307 | 323 | " torch.save(model, dataset+'_model_'+str(epoch)+'.pt')\n", |
308 | 324 | " \n", |
309 | | - " return model, history\n", |
| 325 | + " return model, history, best_epoch\n", |
310 | 326 | " " |
311 | 327 | ] |
312 | 328 | }, |
313 | 329 | { |
314 | 330 | "cell_type": "code", |
315 | 331 | "execution_count": null, |
316 | 332 | "metadata": { |
317 | | - "collapsed": true |
| 333 | + "colab": { |
| 334 | + "base_uri": "https://localhost:8080/" |
| 335 | + }, |
| 336 | + "id": "GIVZTrxAOLTN", |
| 337 | + "outputId": "ebdcc4b0-a98e-4488-dddb-bd948e452cc0" |
318 | 338 | }, |
319 | 339 | "outputs": [], |
320 | 340 | "source": [ |
321 | | - "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", |
322 | | - "\n", |
323 | 341 | "# Print the model to be trained\n", |
324 | 342 | "#summary(resnet50, input_size=(3, 224, 224), batch_size=bs, device='cuda')\n", |
325 | 343 | "\n", |
326 | 344 | "# Train the model for 25 epochs\n", |
327 | 345 | "num_epochs = 30\n", |
328 | | - "trained_model, history = train_and_validate(resnet50, loss_func, optimizer, num_epochs)\n", |
| 346 | + "trained_model, history, best_epoch = train_and_validate(resnet50, loss_func, optimizer, num_epochs)\n", |
329 | 347 | "\n", |
330 | 348 | "torch.save(history, dataset+'_history.pt')" |
331 | 349 | ] |
|
334 | 352 | "cell_type": "code", |
335 | 353 | "execution_count": null, |
336 | 354 | "metadata": { |
337 | | - "collapsed": true |
| 355 | + "colab": { |
| 356 | + "base_uri": "https://localhost:8080/", |
| 357 | + "height": 283 |
| 358 | + }, |
| 359 | + "id": "pGcRnfbiOLTS", |
| 360 | + "outputId": "1326a7fe-4daf-45ea-9f69-7219e23d277b" |
338 | 361 | }, |
339 | 362 | "outputs": [], |
340 | 363 | "source": [ |
|
352 | 375 | "cell_type": "code", |
353 | 376 | "execution_count": null, |
354 | 377 | "metadata": { |
355 | | - "collapsed": true |
| 378 | + "colab": { |
| 379 | + "base_uri": "https://localhost:8080/", |
| 380 | + "height": 283 |
| 381 | + }, |
| 382 | + "id": "1e8oWgPWOLTS", |
| 383 | + "outputId": "c74fcaa4-1b41-436f-de02-8be4ee5c1a11" |
356 | 384 | }, |
357 | 385 | "outputs": [], |
358 | 386 | "source": [ |
|
369 | 397 | "cell_type": "code", |
370 | 398 | "execution_count": null, |
371 | 399 | "metadata": { |
372 | | - "collapsed": true |
| 400 | + "id": "FGHpMyvHOLTS" |
373 | 401 | }, |
374 | 402 | "outputs": [], |
375 | 403 | "source": [ |
|
429 | 457 | "cell_type": "code", |
430 | 458 | "execution_count": null, |
431 | 459 | "metadata": { |
432 | | - "collapsed": true |
| 460 | + "id": "QBHeFS7QOLTT" |
433 | 461 | }, |
434 | 462 | "outputs": [], |
435 | 463 | "source": [ |
|
444 | 472 | " \n", |
445 | 473 | " transform = image_transforms['test']\n", |
446 | 474 | "\n", |
| 475 | + "\n", |
447 | 476 | " test_image = Image.open(test_image_name)\n", |
448 | 477 | " plt.imshow(test_image)\n", |
449 | 478 | " \n", |
450 | 479 | " test_image_tensor = transform(test_image)\n", |
451 | | - "\n", |
452 | 480 | " if torch.cuda.is_available():\n", |
453 | 481 | " test_image_tensor = test_image_tensor.view(1, 3, 224, 224).cuda()\n", |
454 | 482 | " else:\n", |
|
459 | 487 | " # Model outputs log probabilities\n", |
460 | 488 | " out = model(test_image_tensor)\n", |
461 | 489 | " ps = torch.exp(out)\n", |
| 490 | + "\n", |
462 | 491 | " topk, topclass = ps.topk(3, dim=1)\n", |
| 492 | + " cls = idx_to_class[topclass.cpu().numpy()[0][0]]\n", |
| 493 | + " score = topk.cpu().numpy()[0][0]\n", |
| 494 | + "\n", |
463 | 495 | " for i in range(3):\n", |
464 | 496 | " print(\"Predcition\", i+1, \":\", idx_to_class[topclass.cpu().numpy()[0][i]], \", Score: \", topk.cpu().numpy()[0][i])\n", |
465 | 497 | "\n", |
|
470 | 502 | "cell_type": "code", |
471 | 503 | "execution_count": null, |
472 | 504 | "metadata": { |
473 | | - "collapsed": true |
| 505 | + "colab": { |
| 506 | + "base_uri": "https://localhost:8080/", |
| 507 | + "height": 507 |
| 508 | + }, |
| 509 | + "id": "vK8eHndnOLTU", |
| 510 | + "outputId": "581a2e73-b90f-446e-fc73-d4820adc4b68" |
474 | 511 | }, |
475 | 512 | "outputs": [], |
476 | 513 | "source": [ |
477 | 514 | "# Test a particular model on a test image\n", |
478 | | - "\n", |
| 515 | + "! wget https://cdn.pixabay.com/photo/2018/10/01/12/28/skunk-3716043_1280.jpg -O skunk.jpg\n", |
479 | 516 | "dataset = 'caltech_10'\n", |
480 | | - "model = torch.load('caltech_10_model_8.pt')\n", |
481 | | - "predict(model, 'pixabay-test-animals/triceratops-954293_640.jpg')\n", |
| 517 | + "model = torch.load(\"{}_model_{}.pt\".format(dataset, best_epoch))\n", |
| 518 | + "predict(model, 'skunk.jpg')\n", |
482 | 519 | "\n", |
483 | 520 | "# Load Data from folders\n", |
484 | 521 | "#computeTestSetAccuracy(model, loss_func)\n", |
|
488 | 525 | } |
489 | 526 | ], |
490 | 527 | "metadata": { |
| 528 | + "accelerator": "GPU", |
| 529 | + "colab": { |
| 530 | + "collapsed_sections": [], |
| 531 | + "name": "image_classification_using_transfer_learning_in_pytorch.ipynb", |
| 532 | + "provenance": [] |
| 533 | + }, |
491 | 534 | "kernelspec": { |
492 | 535 | "display_name": "Python 3", |
493 | 536 | "language": "python", |
|
503 | 546 | "name": "python", |
504 | 547 | "nbconvert_exporter": "python", |
505 | 548 | "pygments_lexer": "ipython3", |
506 | | - "version": "3.6.7" |
| 549 | + "version": "3.6.9" |
507 | 550 | } |
508 | 551 | }, |
509 | 552 | "nbformat": 4, |
510 | | - "nbformat_minor": 2 |
| 553 | + "nbformat_minor": 1 |
511 | 554 | } |
0 commit comments