|
4 | 4 | "cell_type": "code",
|
5 | 5 | "execution_count": null,
|
6 | 6 | "metadata": {
|
7 |
| - "collapsed": true |
| 7 | + "id": "SrL47MttOLS3" |
8 | 8 | },
|
9 | 9 | "outputs": [],
|
10 | 10 | "source": [
|
|
29 | 29 | "cell_type": "code",
|
30 | 30 | "execution_count": null,
|
31 | 31 | "metadata": {
|
32 |
| - "collapsed": true |
| 32 | + "id": "V3Jwg9JJOLTE" |
33 | 33 | },
|
34 | 34 | "outputs": [],
|
35 | 35 | "source": [
|
|
65 | 65 | "cell_type": "code",
|
66 | 66 | "execution_count": null,
|
67 | 67 | "metadata": {
|
68 |
| - "collapsed": true |
| 68 | + "colab": { |
| 69 | + "base_uri": "https://localhost:8080/" |
| 70 | + }, |
| 71 | + "id": "LqB1PotzOLTF", |
| 72 | + "outputId": "2fef74c0-eba7-4dc5-a35a-b238432ee2c3" |
69 | 73 | },
|
70 | 74 | "outputs": [],
|
71 | 75 | "source": [
|
|
83 | 87 | "bs = 32\n",
|
84 | 88 | "\n",
|
85 | 89 | "# Number of classes\n",
|
86 |
| - "num_classes = len(os.listdir(valid_directory))-1 #10#2#257\n", |
| 90 | + "num_classes = len(os.listdir(valid_directory)) #10#2#257\n", |
87 | 91 | "print(num_classes)\n",
|
88 | 92 | "\n",
|
89 | 93 | "# Load Data from folders\n",
|
|
112 | 116 | "cell_type": "code",
|
113 | 117 | "execution_count": null,
|
114 | 118 | "metadata": {
|
115 |
| - "collapsed": true |
| 119 | + "colab": { |
| 120 | + "base_uri": "https://localhost:8080/" |
| 121 | + }, |
| 122 | + "id": "2mUcSMIMOLTG", |
| 123 | + "outputId": "88cc5e5d-2c37-41e6-ef9d-72ca9237efd8" |
116 | 124 | },
|
117 | 125 | "outputs": [],
|
118 | 126 | "source": [
|
119 |
| - "train_data_size, valid_data_size, test_data_size" |
| 127 | + "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", |
| 128 | + "\n", |
| 129 | + "print(train_data_size, valid_data_size, test_data_size)" |
120 | 130 | ]
|
121 | 131 | },
|
122 | 132 | {
|
123 | 133 | "cell_type": "code",
|
124 | 134 | "execution_count": null,
|
125 | 135 | "metadata": {
|
126 |
| - "collapsed": true |
| 136 | + "id": "ViY8viSLOLTH" |
127 | 137 | },
|
128 | 138 | "outputs": [],
|
129 | 139 | "source": [
|
130 | 140 | "# Load pretrained ResNet50 Model\n",
|
131 | 141 | "resnet50 = models.resnet50(pretrained=True)\n",
|
132 |
| - "resnet50 = resnet50.to('cuda:0')\n" |
| 142 | + "resnet50 = resnet50.to(device)" |
133 | 143 | ]
|
134 | 144 | },
|
135 | 145 | {
|
136 | 146 | "cell_type": "code",
|
137 | 147 | "execution_count": null,
|
138 | 148 | "metadata": {
|
139 |
| - "collapsed": true |
| 149 | + "id": "3eVi7mJZOLTH" |
140 | 150 | },
|
141 | 151 | "outputs": [],
|
142 | 152 | "source": [
|
|
149 | 159 | "cell_type": "code",
|
150 | 160 | "execution_count": null,
|
151 | 161 | "metadata": {
|
152 |
| - "collapsed": true |
| 162 | + "id": "EqMKMqdWOLTH" |
153 | 163 | },
|
154 | 164 | "outputs": [],
|
155 | 165 | "source": [
|
|
164 | 174 | " nn.LogSoftmax(dim=1) # For using NLLLoss()\n",
|
165 | 175 | ")\n",
|
166 | 176 | "\n",
|
| 177 | + "\n", |
| 178 | + "\n", |
167 | 179 | "# Convert model to be used on GPU\n",
|
168 |
| - "resnet50 = resnet50.to('cuda:0')\n" |
| 180 | + "resnet50 = resnet50.to(device)\n" |
169 | 181 | ]
|
170 | 182 | },
|
171 | 183 | {
|
172 | 184 | "cell_type": "code",
|
173 | 185 | "execution_count": null,
|
174 | 186 | "metadata": {
|
175 |
| - "collapsed": true |
| 187 | + "id": "3Lske-iPOLTI" |
176 | 188 | },
|
177 | 189 | "outputs": [],
|
178 | 190 | "source": [
|
|
185 | 197 | "cell_type": "code",
|
186 | 198 | "execution_count": null,
|
187 | 199 | "metadata": {
|
188 |
| - "collapsed": true |
| 200 | + "id": "dyG4bCHHOLTI" |
189 | 201 | },
|
190 | 202 | "outputs": [],
|
191 | 203 | "source": [
|
|
205 | 217 | " \n",
|
206 | 218 | " start = time.time()\n",
|
207 | 219 | " history = []\n",
|
208 |
| - " best_acc = 0.0\n", |
| 220 | + " best_loss = 100000.0\n", |
| 221 | + " best_epoch = None\n", |
209 | 222 | "\n",
|
210 | 223 | " for epoch in range(epochs):\n",
|
211 | 224 | " epoch_start = time.time()\n",
|
|
256 | 269 | " \n",
|
257 | 270 | " #print(\"Batch number: {:03d}, Training: Loss: {:.4f}, Accuracy: {:.4f}\".format(i, loss.item(), acc.item()))\n",
|
258 | 271 | "\n",
|
259 |
| - " \n", |
| 272 | + " \n", |
260 | 273 | " # Validation - No gradient tracking needed\n",
|
261 | 274 | " with torch.no_grad():\n",
|
262 | 275 | "\n",
|
|
288 | 301 | " valid_acc += acc.item() * inputs.size(0)\n",
|
289 | 302 | "\n",
|
290 | 303 | " #print(\"Validation Batch number: {:03d}, Validation: Loss: {:.4f}, Accuracy: {:.4f}\".format(j, loss.item(), acc.item()))\n",
|
291 |
| - " \n", |
| 304 | + " if valid_loss < best_loss:\n", |
| 305 | + " best_loss = valid_loss\n", |
| 306 | + " best_epoch = epoch\n", |
| 307 | + "\n", |
292 | 308 | " # Find average training loss and training accuracy\n",
|
293 | 309 | " avg_train_loss = train_loss/train_data_size \n",
|
294 | 310 | " avg_train_acc = train_acc/train_data_size\n",
|
|
301 | 317 | " \n",
|
302 | 318 | " epoch_end = time.time()\n",
|
303 | 319 | " \n",
|
304 |
| - " print(\"Epoch : {:03d}, Training: Loss: {:.4f}, Accuracy: {:.4f}%, \\n\\t\\tValidation : Loss : {:.4f}, Accuracy: {:.4f}%, Time: {:.4f}s\".format(epoch, avg_train_loss, avg_train_acc*100, avg_valid_loss, avg_valid_acc*100, epoch_end-epoch_start))\n", |
| 320 | + " print(\"Epoch : {:03d}, Training: Loss - {:.4f}, Accuracy - {:.4f}%, \\n\\t\\tValidation : Loss - {:.4f}, Accuracy - {:.4f}%, Time: {:.4f}s\".format(epoch, avg_train_loss, avg_train_acc*100, avg_valid_loss, avg_valid_acc*100, epoch_end-epoch_start))\n", |
305 | 321 | " \n",
|
306 | 322 | " # Save if the model has best accuracy till now\n",
|
307 | 323 | " torch.save(model, dataset+'_model_'+str(epoch)+'.pt')\n",
|
308 | 324 | " \n",
|
309 |
| - " return model, history\n", |
| 325 | + " return model, history, best_epoch\n", |
310 | 326 | " "
|
311 | 327 | ]
|
312 | 328 | },
|
313 | 329 | {
|
314 | 330 | "cell_type": "code",
|
315 | 331 | "execution_count": null,
|
316 | 332 | "metadata": {
|
317 |
| - "collapsed": true |
| 333 | + "colab": { |
| 334 | + "base_uri": "https://localhost:8080/" |
| 335 | + }, |
| 336 | + "id": "GIVZTrxAOLTN", |
| 337 | + "outputId": "ebdcc4b0-a98e-4488-dddb-bd948e452cc0" |
318 | 338 | },
|
319 | 339 | "outputs": [],
|
320 | 340 | "source": [
|
321 |
| - "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", |
322 |
| - "\n", |
323 | 341 | "# Print the model to be trained\n",
|
324 | 342 | "#summary(resnet50, input_size=(3, 224, 224), batch_size=bs, device='cuda')\n",
|
325 | 343 | "\n",
|
326 | 344 | "# Train the model for 25 epochs\n",
|
327 | 345 | "num_epochs = 30\n",
|
328 |
| - "trained_model, history = train_and_validate(resnet50, loss_func, optimizer, num_epochs)\n", |
| 346 | + "trained_model, history, best_epoch = train_and_validate(resnet50, loss_func, optimizer, num_epochs)\n", |
329 | 347 | "\n",
|
330 | 348 | "torch.save(history, dataset+'_history.pt')"
|
331 | 349 | ]
|
|
334 | 352 | "cell_type": "code",
|
335 | 353 | "execution_count": null,
|
336 | 354 | "metadata": {
|
337 |
| - "collapsed": true |
| 355 | + "colab": { |
| 356 | + "base_uri": "https://localhost:8080/", |
| 357 | + "height": 283 |
| 358 | + }, |
| 359 | + "id": "pGcRnfbiOLTS", |
| 360 | + "outputId": "1326a7fe-4daf-45ea-9f69-7219e23d277b" |
338 | 361 | },
|
339 | 362 | "outputs": [],
|
340 | 363 | "source": [
|
|
352 | 375 | "cell_type": "code",
|
353 | 376 | "execution_count": null,
|
354 | 377 | "metadata": {
|
355 |
| - "collapsed": true |
| 378 | + "colab": { |
| 379 | + "base_uri": "https://localhost:8080/", |
| 380 | + "height": 283 |
| 381 | + }, |
| 382 | + "id": "1e8oWgPWOLTS", |
| 383 | + "outputId": "c74fcaa4-1b41-436f-de02-8be4ee5c1a11" |
356 | 384 | },
|
357 | 385 | "outputs": [],
|
358 | 386 | "source": [
|
|
369 | 397 | "cell_type": "code",
|
370 | 398 | "execution_count": null,
|
371 | 399 | "metadata": {
|
372 |
| - "collapsed": true |
| 400 | + "id": "FGHpMyvHOLTS" |
373 | 401 | },
|
374 | 402 | "outputs": [],
|
375 | 403 | "source": [
|
|
429 | 457 | "cell_type": "code",
|
430 | 458 | "execution_count": null,
|
431 | 459 | "metadata": {
|
432 |
| - "collapsed": true |
| 460 | + "id": "QBHeFS7QOLTT" |
433 | 461 | },
|
434 | 462 | "outputs": [],
|
435 | 463 | "source": [
|
|
444 | 472 | " \n",
|
445 | 473 | " transform = image_transforms['test']\n",
|
446 | 474 | "\n",
|
| 475 | + "\n", |
447 | 476 | " test_image = Image.open(test_image_name)\n",
|
448 | 477 | " plt.imshow(test_image)\n",
|
449 | 478 | " \n",
|
450 | 479 | " test_image_tensor = transform(test_image)\n",
|
451 |
| - "\n", |
452 | 480 | " if torch.cuda.is_available():\n",
|
453 | 481 | " test_image_tensor = test_image_tensor.view(1, 3, 224, 224).cuda()\n",
|
454 | 482 | " else:\n",
|
|
459 | 487 | " # Model outputs log probabilities\n",
|
460 | 488 | " out = model(test_image_tensor)\n",
|
461 | 489 | " ps = torch.exp(out)\n",
|
| 490 | + "\n", |
462 | 491 | " topk, topclass = ps.topk(3, dim=1)\n",
|
| 492 | + " cls = idx_to_class[topclass.cpu().numpy()[0][0]]\n", |
| 493 | + " score = topk.cpu().numpy()[0][0]\n", |
| 494 | + "\n", |
463 | 495 | " for i in range(3):\n",
|
464 | 496 | " print(\"Predcition\", i+1, \":\", idx_to_class[topclass.cpu().numpy()[0][i]], \", Score: \", topk.cpu().numpy()[0][i])\n",
|
465 | 497 | "\n",
|
|
470 | 502 | "cell_type": "code",
|
471 | 503 | "execution_count": null,
|
472 | 504 | "metadata": {
|
473 |
| - "collapsed": true |
| 505 | + "colab": { |
| 506 | + "base_uri": "https://localhost:8080/", |
| 507 | + "height": 507 |
| 508 | + }, |
| 509 | + "id": "vK8eHndnOLTU", |
| 510 | + "outputId": "581a2e73-b90f-446e-fc73-d4820adc4b68" |
474 | 511 | },
|
475 | 512 | "outputs": [],
|
476 | 513 | "source": [
|
477 | 514 | "# Test a particular model on a test image\n",
|
478 |
| - "\n", |
| 515 | + "! wget https://cdn.pixabay.com/photo/2018/10/01/12/28/skunk-3716043_1280.jpg -O skunk.jpg\n", |
479 | 516 | "dataset = 'caltech_10'\n",
|
480 |
| - "model = torch.load('caltech_10_model_8.pt')\n", |
481 |
| - "predict(model, 'pixabay-test-animals/triceratops-954293_640.jpg')\n", |
| 517 | + "model = torch.load(\"{}_model_{}.pt\".format(dataset, best_epoch))\n", |
| 518 | + "predict(model, 'skunk.jpg')\n", |
482 | 519 | "\n",
|
483 | 520 | "# Load Data from folders\n",
|
484 | 521 | "#computeTestSetAccuracy(model, loss_func)\n",
|
|
488 | 525 | }
|
489 | 526 | ],
|
490 | 527 | "metadata": {
|
| 528 | + "accelerator": "GPU", |
| 529 | + "colab": { |
| 530 | + "collapsed_sections": [], |
| 531 | + "name": "image_classification_using_transfer_learning_in_pytorch.ipynb", |
| 532 | + "provenance": [] |
| 533 | + }, |
491 | 534 | "kernelspec": {
|
492 | 535 | "display_name": "Python 3",
|
493 | 536 | "language": "python",
|
|
503 | 546 | "name": "python",
|
504 | 547 | "nbconvert_exporter": "python",
|
505 | 548 | "pygments_lexer": "ipython3",
|
506 |
| - "version": "3.6.7" |
| 549 | + "version": "3.6.9" |
507 | 550 | }
|
508 | 551 | },
|
509 | 552 | "nbformat": 4,
|
510 |
| - "nbformat_minor": 2 |
| 553 | + "nbformat_minor": 1 |
511 | 554 | }
|
0 commit comments