@@ -26,7 +26,7 @@ def eval(self, x1, x2):
26
26
27
27
class SVM :
28
28
# Support Vector Machine Classifier
29
- def __init__ (self , kernel , C ):
29
+ def __init__ (self , C , kernel = LinKernel () ):
30
30
self .k = kernel
31
31
self .C = C
32
32
self .optimizer = SMO (kernel , C )
@@ -38,7 +38,7 @@ def train(self, X, y):
38
38
39
39
# use the SMO module to compute alphas and b
40
40
self .alphas = self .optimizer .compute_alphas (X ,y )
41
- self .b = self .optimizer .threshold
41
+ self .b = self .optimizer .b
42
42
43
43
def _eval (self , x ):
44
44
# evaluate the SVM on a single example
@@ -51,7 +51,10 @@ def _eval(self, x):
51
51
52
52
def eval (self , X ):
53
53
# evaluate a matrix of example vectors
54
- return np .vectorize (self ._eval )(X )
54
+ result = np .zeros (len (X ))
55
+ for i in xrange (len (X )):
56
+ result [i ] = self ._eval (X [i ])
57
+ return result
55
58
56
59
def classify (self , X ):
57
60
# classify a matrix of example vectors
@@ -63,3 +66,36 @@ def test(self, X, y):
63
66
guess = self .classify (X )
64
67
error [guess != y ] = 1
65
68
return np .float (np .sum (error )) / len (X )
69
+
70
+ def findC (self , X , y , count = 50 , kfolds = 5 ):
71
+ # find a good estimate of C with kfold cross validation
72
+
73
+ yt = y .reshape (len (y ), 1 )
74
+ data = np .hstack ([yt ,X ])
75
+ np .random .shuffle (data )
76
+ partitions = np .array_split (data , kfolds )
77
+ candidates = np .logspace (0 , 5 , num = count , base = np .e )
78
+
79
+ minErr = np .inf
80
+ C = candidates [0 ]
81
+ for c in candidates :
82
+ errors = np .zeros (kfolds )
83
+ for i in range (kfolds ):
84
+ test = partitions [i ]
85
+ train = np .vstack ([partitions [x ] for x in range (kfolds ) if x != i ])
86
+ testy = test [:,0 ]
87
+ testx = test [:,1 :]
88
+ trainy = train [:,0 ]
89
+ trainx = train [:,1 :]
90
+ temp = SVM (c , self .k )
91
+ temp .train (trainx , trainy )
92
+ errors [i ] = temp .test (testx , testy )
93
+ err = np .mean (errors )
94
+ print c , err , minErr
95
+ if err < minErr :
96
+ C = c
97
+ minErr = err
98
+
99
+ print candidates
100
+ print "Final value: " , C
101
+ return C
0 commit comments