|
| 1 | +import torch as tch |
| 2 | +import torchvision.datasets as dt |
| 3 | +import torchvision.transforms as trans |
| 4 | +import torch.nn as nn |
| 5 | +import matplotlib.pyplot as plt |
| 6 | +from time import time |
| 7 | + |
| 8 | +train = dt.MNIST(root="./datasets", train=True, transform=trans.ToTensor(), download=True) |
| 9 | +test = dt.MNIST(root="./datasets", train=False, transform=trans.ToTensor(), download=True) |
| 10 | +print("No. of Training examples: ",len(train)) |
| 11 | +print("No. of Test examples: ",len(test)) |
| 12 | + |
| 13 | +train_batch = tch.utils.data.DataLoader(train, batch_size=30, shuffle=True) |
| 14 | + |
| 15 | + |
| 16 | +input = 784 |
| 17 | +hidden = 490 |
| 18 | +output = 10 |
| 19 | + |
| 20 | +model = nn.Sequential(nn.Linear(input, hidden), |
| 21 | + nn.LeakyReLU(), |
| 22 | + nn.Linear(hidden, output), |
| 23 | + nn.LogSoftmax(dim=1)) |
| 24 | + |
| 25 | +lossfn = nn.NLLLoss() |
| 26 | +images, labels = next(iter(train_batch)) |
| 27 | +images = images.view(images.shape[0], -1) |
| 28 | + |
| 29 | +logps = model(images) |
| 30 | +loss = lossfn(logps, labels) |
| 31 | +loss.backward() |
| 32 | + |
| 33 | +optimize = tch.optim.SGD(model.parameters(), lr=0.003, momentum=0.9) |
| 34 | +time_start = time() |
| 35 | +epochs = 18 |
| 36 | +for num in range(epochs): |
| 37 | + run=0 |
| 38 | + for images, labels in train_batch: |
| 39 | + images = images.view(images.shape[0], -1) |
| 40 | + optimize.zero_grad() |
| 41 | + output = model(images) |
| 42 | + loss = lossfn(output, labels) |
| 43 | + loss.backward() |
| 44 | + optimize.step() |
| 45 | + run += loss.item() |
| 46 | + else: |
| 47 | + print("Epoch Number : {} = Loss : {}".format(num, run/len(train_batch))) |
| 48 | +Elapsed=(time()-time_start)/60 |
| 49 | +print("\nTraining Time (in minutes) : ",Elapsed) |
| 50 | + |
| 51 | +correct=0 |
| 52 | +all = 0 |
| 53 | +for images,labels in test: |
| 54 | + img = images.view(1, 784) |
| 55 | + with tch.no_grad(): |
| 56 | + logps = model(img) |
| 57 | + ps = tch.exp(logps) |
| 58 | + probab = list(ps.numpy()[0]) |
| 59 | + prediction = probab.index(max(probab)) |
| 60 | + truth = labels |
| 61 | + if(truth == prediction): |
| 62 | + correct += 1 |
| 63 | + all += 1 |
| 64 | + |
| 65 | +print("Number Of Images Tested : ", all) |
| 66 | +print("Model Accuracy : ", (correct/all)) |
| 67 | + |
| 68 | +tch.save(model, './mnist_model.pt') |
0 commit comments