Skip to content

Commit b53d626

Browse files
committed
small modifications
1 parent 7750407 commit b53d626

File tree

2 files changed

+27
-11
lines changed

2 files changed

+27
-11
lines changed

svm.py

+21-7
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def __init__(self, sigma):
2222
self.gamma = 1 / (2.0 * sigma**2)
2323

2424
def eval(self, x1, x2):
25-
return np.exp(-gamma * lin.norm(x1 - x2)**2)
25+
return np.exp(-self.gamma * lin.norm(x1 - x2)**2)
2626

2727
class SVM:
2828
# Support Vector Machine Classifier
@@ -67,6 +67,13 @@ def test(self, X, y):
6767
error[guess != y] = 1
6868
return np.float(np.sum(error)) / len(X)
6969

70+
def countSupVectors(self):
71+
count = 0
72+
for a in self.alphas:
73+
if a == 0:
74+
count += 1
75+
return count
76+
7077
def findC(self, X, y, count=50, kfolds=5):
7178
# find a good estimate of C with kfold cross validation
7279

@@ -75,11 +82,14 @@ def findC(self, X, y, count=50, kfolds=5):
7582
np.random.shuffle(data)
7683
partitions = np.array_split(data, kfolds)
7784
candidates = np.logspace(0, 5, num=count, base=np.e)
85+
err = np.zeros(count)
86+
svec = err.copy()
7887

7988
minErr = np.inf
8089
C = candidates[0]
81-
for c in candidates:
90+
for index, c in enumerate(candidates):
8291
errors = np.zeros(kfolds)
92+
supvec_count = errors.copy()
8393
for i in range(kfolds):
8494
test = partitions[i]
8595
train = np.vstack([partitions[x] for x in range(kfolds) if x != i])
@@ -90,12 +100,16 @@ def findC(self, X, y, count=50, kfolds=5):
90100
temp = SVM(c, self.k)
91101
temp.train(trainx, trainy)
92102
errors[i] = temp.test(testx, testy)
93-
err = np.mean(errors)
94-
print c, err, minErr
95-
if err < minErr:
103+
supvec_count[i] = temp.countSupVectors()
104+
err[index] = np.mean(errors)
105+
svec[index] = np.mean(supvec_count)
106+
print c, err
107+
if err[index] < minErr:
96108
C = c
97-
minErr = err
109+
minErr = err[index]
98110

99-
print candidates
111+
print "C, err, #svec"
112+
for i in range(count):
113+
print candidates[i], err[i], svec[i]
100114
print "Final value: ", C
101115
return C

svmplot.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,17 @@ def readSatData(fname='tr'):
4848
X = data[:,1:]
4949
return X,y
5050

51-
s = svm.SVM(10)
51+
r = svm.RBFKernel(2)
52+
s = svm.SVM(10,kernel=r)
5253
#X,y = make2dData(200)
5354
#c.train(X,y)
5455
#print c.alphas
55-
#print "final error", c.test(X,y)
5656
#print c.findC(X,y,count=5)
5757
X,y = readSatData()
58-
c = s.findC(X,y,count=10,kfolds=3)
59-
s = svm.SVM(c)
58+
#c = s.findC(X,y,count=10,kfolds=3)
59+
c = 8.3
60+
s = svm.SVM(c,kernel=svm.RBFKernel(2))
6061
s.train(X,y)
6162
Xt,yt = readSatData('t')
6263
print "final error", s.test(Xt,yt)
64+
#print "final error", s.test(X,y)

0 commit comments

Comments
 (0)