Skip to content

Commit ced1252

Browse files
committed
MNIST classification using perceptron
1 parent dd8e67d commit ced1252

File tree

1 file changed

+122
-0
lines changed

1 file changed

+122
-0
lines changed

Machine_Learning/perceptron_mnist.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# Author: Abhinav Dhere (abhitechnical41[at]gmail.com)
2+
# Originally written as { Problem 1, Assignment 1, SM in AI (CSE 471) - 2017 IIIT Hyderabad }
3+
# Perceptron-based binary classification for the MNIST database. Options available: single-sample update or batch update, each with or without a margin.
4+
# Coded from scratch, dependency is only Numpy.
5+
# Expects data in CSV files.
6+
7+
import numpy as np
8+
import sys
9+
import math
10+
import time
11+
12+
def getConfMatrix(pLabels,labels):
    '''
    Obtain the 2x2 confusion matrix for binary (0/1) predictions.

    pLabels: dict mapping a sample id to its predicted label.
    labels:  true labels, indexable by the same ids (dict, list or array).

    Returns a numpy integer array laid out as [[TN, FP], [FN, TP]]: the row
    is the true label and the column is the predicted label. Pairs in which
    either label is neither 0 nor 1 are not counted.
    '''
    confMat = np.zeros((2, 2), dtype=int)
    for key, predicted in pLabels.items():
        actual = labels[key]
        # Only 0/1 pairs are meaningful for a binary confusion matrix.
        if predicted in (0, 1) and actual in (0, 1):
            confMat[int(actual), int(predicted)] += 1
    return confMat
30+
31+
def getStats(labelsP,labels):
    '''
    Print accuracy and recall, as percentages, for binary predictions.

    labelsP: dict mapping a sample id to its predicted label (0 or 1).
    labels:  true labels, indexable by the same ids.

    A metric whose denominator is zero (no counted samples, or no positive
    samples in the case of recall) is printed as nan.
    '''
    # getConfMatrix returns [[TN, FP], [FN, TP]].
    c = getConfMatrix(labelsP,labels)
    tn, fp, fn, tp = (int(v) for v in (c[0,0], c[0,1], c[1,0], c[1,1]))
    total = tn + fp + fn + tp
    positives = tp + fn
    # Guard the divisions explicitly so an empty denominator yields nan
    # without triggering a numpy divide warning.
    acc = (tp + tn)/float(total) if total else float('nan')
    recall = tp/float(positives) if positives else float('nan')
    print("Accuracy: "+str(acc*100))
    print("Recall: "+str(recall*100))
37+
38+
def readFile(filename,datType):
    '''
    Read a comma-separated file of integers into numpy arrays.

    filename: path (or any other source accepted by np.genfromtxt).
    datType:  0 - labelled data: column 0 holds the label and columns 1..784
                  hold the features; returns (data, labels).
              1 - unlabelled data: every column is a feature; returns data.

    Raises ValueError for any other datType.
    '''
    if datType==0:
        # Parse the file once (label column plus the 784 feature columns)
        # and split afterwards, instead of parsing it separately for each part.
        raw = np.genfromtxt(filename,dtype=int,delimiter=',',autostrip=True,usecols=range(785))
        # genfromtxt squeezes a single-row file down to 1-D; restore the row
        # axis so data is always 2-D and labels always 1-D.
        raw = np.atleast_2d(raw)
        return raw[:,1:], raw[:,0]
    elif datType==1:
        return np.genfromtxt(filename,dtype=int,delimiter=',',autostrip=True)
    raise ValueError("datType must be 0 (labelled) or 1 (unlabelled), got "+repr(datType))
50+
51+
def predict(w,x,i,margin):
    '''
    Classify sample i of x with the weight vector w.

    Returns 1 when the activation w . x[i] is greater than or equal to
    margin, and 0 otherwise.
    '''
    # Compute the activation once and reuse it for the comparison.
    activation = np.dot(np.transpose(w),x[i])
    # A NaN activation fails both ">=" and "<", so fall through to 0 rather
    # than testing a second condition; a label is therefore always returned.
    if activation>=margin:
        return 1
    return 0
57+
58+
def augment(data):
    '''
    Return a copy of the 2-D array data with a column of ones prepended,
    i.e. a constant input of 1 in position 0 of every row.
    '''
    n_rows = data.shape[0]
    ones_column = np.ones((n_rows, 1))
    return np.concatenate([ones_column, data], axis=1)
62+
63+
def train(data_train_file,method,margin,max_iter=None):
    '''
    Train a binary perceptron on a labelled CSV file.

    data_train_file: training file, read with readFile(..., 0).
    method:   'single' - one pass over the samples, updating the weights
                         after every misclassified sample;
              'batch'  - repeated updates with the sum of all samples that
                         still violate the margin.
    margin:   decision threshold. In 'single' mode it is the threshold used
              by predict; in 'batch' mode a sample violates the margin when
              z * (w . x) <= margin, with z = +1 for label 1 and -1 otherwise.
    max_iter: optional cap on the number of batch iterations. None (the
              default) keeps iterating until at most one sample violates
              the margin, which may never happen on non-separable data.

    Returns the weight vector as a 1-D array with one entry per column of
    the augmented data (entry 0 multiplies the constant 1 added by augment).
    Raises ValueError if method is neither 'single' nor 'batch'.
    '''
    if method not in ('single','batch'):
        raise ValueError("method must be 'single' or 'batch', got "+repr(method))
    data,labels = readFile(data_train_file,0)
    x = augment(data)
    w = np.random.rand(x.shape[1])
    eta = 1  # learning rate
    if method=='single':
        for i in range(x.shape[0]):
            err = labels[i]-predict(w,x,i,margin)
            if err!=0:
                w = w + (err*eta)*x[i,:]
    else:
        # Map labels {1, other} to signs {+1, -1}.
        z = np.where(np.asarray(labels)==1,1,-1)
        # Initial step: add the signed sum of every sample.
        w = w + np.dot(z,x)
        active = np.arange(x.shape[0])
        iterations = 0
        # NOTE(review): only samples that violated the margin in the previous
        # iteration are re-checked; samples that drop out are never revisited.
        # This matches the existing behaviour - confirm it is intended.
        while active.size>1 and (max_iter is None or iterations<max_iter):
            scores = z[active]*np.dot(x[active],w)
            # Compare against margin (not 0) so training agrees with the
            # threshold predict applies at test time.
            active = active[scores<=margin]
            w = w + eta*np.dot(z[active],x[active])
            iterations += 1
    return w
90+
91+
def test(w,data_test_file,margin):
    '''
    Predict a label for every row of an unlabelled CSV file.

    The file is read with readFile(..., 1), augmented, and each row is
    classified with predict. Returns a dict mapping row index to the
    predicted label.
    '''
    samples = readFile(data_test_file,1)
    augmented = augment(samples)
    n_samples = samples.shape[0]
    return {row: predict(w,augmented,row,margin) for row in range(n_samples)}
98+
99+
100+
def classify(trainFile,testFile,method,margin):
    '''
    Train a perceptron on trainFile, classify testFile and print the
    predicted label of every test row, one per line, in row order.

    trainFile: labelled training CSV passed to train.
    testFile:  unlabelled test CSV passed to test.
    method:    training method forwarded to train ('single' or 'batch').
    margin:    margin forwarded to both train and test.
    '''
    # Use the function's own arguments rather than module-level globals, so
    # the function also works when this module is imported.
    w = train(trainFile,method,margin)
    labels_pred = test(w,testFile,margin)
    # getStats is not called here: no true labels for the test file are
    # loaded in this function.
    for row in labels_pred:
        print(labels_pred[row])
106+
107+
if __name__ == "__main__":
    # Command line: <script> <train_csv> <test_csv>
    if len(sys.argv) < 3:
        # Exit with a usage message instead of an IndexError on sys.argv.
        sys.exit("Usage: python "+sys.argv[0]+" <train_csv> <test_csv>")
    data_train_file = sys.argv[1]
    data_test_file = sys.argv[2]

    # Each configuration trains from scratch and prints its predictions:
    #   single-sample perceptron without margin, then with margin 6;
    #   batch perceptron without margin, then with margin 6.
    for method, margin in (('single',0),('single',6),('batch',0),('batch',6)):
        classify(data_train_file,data_test_file,method,margin)

0 commit comments

Comments
 (0)