Skip to content

Commit 9e54e2a

Browse files
committed
finished svm
1 parent 63a96be commit 9e54e2a

File tree

1 file changed

+39
-3
lines changed

1 file changed

+39
-3
lines changed

svm.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def eval(self, x1, x2):
2626

2727
class SVM:
2828
# Support Vector Machine Classifier
29-
def __init__(self, kernel, C):
29+
def __init__(self, C, kernel=LinKernel()):
3030
self.k = kernel
3131
self.C = C
3232
self.optimizer = SMO(kernel, C)
@@ -38,7 +38,7 @@ def train(self, X, y):
3838

3939
# use the SMO module to compute alphas and b
4040
self.alphas = self.optimizer.compute_alphas(X,y)
41-
self.b = self.optimizer.threshold
41+
self.b = self.optimizer.b
4242

4343
def _eval(self, x):
4444
# evaluate the SVM on a single example
@@ -51,7 +51,10 @@ def _eval(self, x):
5151

5252
def eval(self, X):
    """Evaluate the SVM on every example in ``X``.

    Args:
        X: sequence (typically a 2-D array) of example vectors; each
            element ``X[i]`` is passed to ``self._eval`` on its own.

    Returns:
        1-D float array of length ``len(X)`` whose i-th entry is
        ``self._eval(X[i])``.
    """
    result = np.zeros(len(X))
    # Iterate over the examples directly (no ``xrange``, which only exists
    # on Python 2) so that ``_eval`` receives one whole example per call.
    for i, example in enumerate(X):
        result[i] = self._eval(example)
    return result
5558

5659
def classify(self, X):
5760
# classify a matrix of example vectors
@@ -63,3 +66,36 @@ def test(self, X, y):
6366
guess = self.classify(X)
6467
error[guess != y] = 1
6568
return np.float(np.sum(error)) / len(X)
69+
70+
def findC(self, X, y, count=50, kfolds=5):
    """Estimate a good value of C with k-fold cross-validation.

    ``count`` candidate values, log-spaced from e**0 to e**5, are each
    scored by training a fresh ``SVM`` (using this instance's kernel) on
    ``kfolds - 1`` partitions of the shuffled data and testing it on the
    held-out partition. The candidate with the lowest mean test error is
    returned; on ties the earliest candidate tried is kept. This
    instance's own ``C`` is not modified.

    The data is shuffled with NumPy's global random state, so the result
    can differ between runs unless that state is seeded by the caller.
    Progress is printed to standard output for every candidate.

    Args:
        X: 2-D array of examples, one row per example.
        y: labels, one per row of ``X``.
        count: number of candidate values of C to try.
        kfolds: number of cross-validation partitions.

    Returns:
        The candidate value of C with the lowest mean validation error.

    Raises:
        ValueError: if ``kfolds`` is smaller than 2 or larger than the
            number of examples (either would leave a fold with no
            training data or no test data).
    """
    labels = np.asarray(y)
    n_examples = len(labels)
    if kfolds < 2 or kfolds > n_examples:
        raise ValueError(
            "kfolds must be between 2 and the number of examples (%d), got %r"
            % (n_examples, kfolds))

    # Glue each label onto its example (column 0) so that shuffling keeps
    # the pairs together.
    data = np.hstack([labels.reshape(n_examples, 1), X])
    np.random.shuffle(data)
    partitions = np.array_split(data, kfolds)
    candidates = np.logspace(0, 5, num=count, base=np.e)

    minErr = np.inf
    C = candidates[0]
    for c in candidates:
        errors = np.zeros(kfolds)
        for i, held_out in enumerate(partitions):
            rest = np.vstack(
                [part for j, part in enumerate(partitions) if j != i])
            # A new classifier per fold, so nothing learned leaks between
            # folds or candidates.
            temp = SVM(c, self.k)
            temp.train(rest[:, 1:], rest[:, 0])
            errors[i] = temp.test(held_out[:, 1:], held_out[:, 0])
        err = np.mean(errors)
        # Single-argument print calls produce the same text on Python 2
        # and Python 3.
        print("%s %s %s" % (c, err, minErr))
        # Strict comparison: the first candidate reaching a given error wins.
        if err < minErr:
            C = c
            minErr = err

    print(candidates)
    print("Final value:  %s" % C)
    return C

0 commit comments

Comments
 (0)