Skip to content

Commit 7750407

Browse files
committed
added loading satimage files
1 parent 9e54e2a commit 7750407

File tree

1 file changed

+44
-3
lines changed

1 file changed

+44
-3
lines changed

svmplot.py

+44-3
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,55 @@
88
# numeric stuff
99
import numpy as np
1010

11+
# plotting data
12+
from matplotlib import pyplot as pl
13+
1114
# generate test data
1215
from sklearn.datasets.samples_generator import make_blobs
1316

1417
# actual svm module
1518
import svm
1619

17-
def make_2d_data(n_samples):
18-
# generate two clusters of
19-
return make_blobs(n_samples=n_samples, centers=2)
20+
def make2dData(n_samples):
21+
# generate a 2d 2-class dataset
22+
X,y = make_blobs(n_samples=n_samples, centers=2)
23+
y[y==0] = -1
24+
return X,y
2025

26+
def readSatData(fname='tr'):
27+
data = []
28+
with open('data/satimage.scale.' + fname, 'r') as f:
29+
for line in f.readlines():
30+
s = line.split()
31+
row = np.zeros(37)
32+
# set label
33+
if s[0] == '6':
34+
row[0] = 1
35+
else:
36+
row[0] = -1
37+
hand = 1
38+
for i in range(1,37):
39+
k,v = s[hand].split(':')
40+
if int(k) == i:
41+
hand += 1
42+
row[i] = float(v)
43+
else:
44+
row[i] = 0
45+
data.append(row)
46+
data = np.vstack(data)
47+
y = data[:,0]
48+
X = data[:,1:]
49+
return X,y
2150

51+
s = svm.SVM(10)
52+
#X,y = make2dData(200)
53+
#c.train(X,y)
54+
#print c.alphas
55+
#print "final error", c.test(X,y)
56+
#print c.findC(X,y,count=5)
57+
X,y = readSatData()
58+
c = s.findC(X,y,count=10,kfolds=3)
59+
s = svm.SVM(c)
60+
s.train(X,y)
61+
Xt,yt = readSatData('t')
62+
print "final error", s.test(Xt,yt)

0 commit comments

Comments
 (0)