Skip to content

Commit 64417e5

Browse files
Create MNIST Classification Model..py
1 parent 93229ff commit 64417e5

File tree

1 file changed

+68
-0
lines changed

1 file changed

+68
-0
lines changed

MNIST Classification Model..py

+68
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import torch as tch
2+
import torchvision.datasets as dt
3+
import torchvision.transforms as trans
4+
import torch.nn as nn
5+
import matplotlib.pyplot as plt
6+
from time import time
7+
8+
train = dt.MNIST(root="./datasets", train=True, transform=trans.ToTensor(), download=True)
9+
test = dt.MNIST(root="./datasets", train=False, transform=trans.ToTensor(), download=True)
10+
print("No. of Training examples: ",len(train))
11+
print("No. of Test examples: ",len(test))
12+
13+
train_batch = tch.utils.data.DataLoader(train, batch_size=30, shuffle=True)
14+
15+
16+
input = 784
17+
hidden = 490
18+
output = 10
19+
20+
model = nn.Sequential(nn.Linear(input, hidden),
21+
nn.LeakyReLU(),
22+
nn.Linear(hidden, output),
23+
nn.LogSoftmax(dim=1))
24+
25+
lossfn = nn.NLLLoss()
26+
images, labels = next(iter(train_batch))
27+
images = images.view(images.shape[0], -1)
28+
29+
logps = model(images)
30+
loss = lossfn(logps, labels)
31+
loss.backward()
32+
33+
optimize = tch.optim.SGD(model.parameters(), lr=0.003, momentum=0.9)
34+
time_start = time()
35+
epochs = 18
36+
for num in range(epochs):
37+
run=0
38+
for images, labels in train_batch:
39+
images = images.view(images.shape[0], -1)
40+
optimize.zero_grad()
41+
output = model(images)
42+
loss = lossfn(output, labels)
43+
loss.backward()
44+
optimize.step()
45+
run += loss.item()
46+
else:
47+
print("Epoch Number : {} = Loss : {}".format(num, run/len(train_batch)))
48+
Elapsed=(time()-time_start)/60
49+
print("\nTraining Time (in minutes) : ",Elapsed)
50+
51+
correct=0
52+
all = 0
53+
for images,labels in test:
54+
img = images.view(1, 784)
55+
with tch.no_grad():
56+
logps = model(img)
57+
ps = tch.exp(logps)
58+
probab = list(ps.numpy()[0])
59+
prediction = probab.index(max(probab))
60+
truth = labels
61+
if(truth == prediction):
62+
correct += 1
63+
all += 1
64+
65+
print("Number Of Images Tested : ", all)
66+
print("Model Accuracy : ", (correct/all))
67+
68+
tch.save(model, './mnist_model.pt')

0 commit comments

Comments
 (0)