Skip to content

Commit 6ef9cd6

Browse files
authored
Add files via upload
1 parent f2ca48c commit 6ef9cd6

File tree

2 files changed

+415
-0
lines changed

2 files changed

+415
-0
lines changed

Demo_PSO.py

+63
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import numpy as np
2+
import pandas as pd
3+
from sklearn.neighbors import KNeighborsClassifier
4+
from sklearn.model_selection import train_test_split
5+
from FS.pso import jfs
6+
import matplotlib.pyplot as plt
7+
8+
9+
# load data
10+
data = pd.read_csv('ionosphere.csv')
11+
data = data.values
12+
feat = np.asarray(data[:, 0:-1])
13+
label = np.asarray(data[:, -1])
14+
15+
# split data into train & validation using k-fold cross-validation
16+
xtrain, xtest, ytrain, ytest = train_test_split(feat, label, test_size=0.3, stratify=label)
17+
fold = {'xt':xtrain, 'yt':ytrain, 'xv':xtest, 'yv':ytest}
18+
19+
# feature selection
20+
k = 5
21+
opts = {'k':k, 'fold':fold, 'N':10, 'T':100, 'w':0.9, 'c1':2, 'c2':2}
22+
23+
fmdl = jfs(feat, label, opts)
24+
sf = fmdl['sf']
25+
26+
# model with selected features
27+
num_train = np.size(xtrain, 0)
28+
num_valid = np.size(xtest, 0)
29+
x_train = xtrain[:, sf]
30+
y_train = ytrain.reshape(num_train) # Solve bug
31+
x_valid = xtest[:, sf]
32+
y_valid = ytest.reshape(num_valid) # Solve bug
33+
34+
mdl = KNeighborsClassifier(n_neighbors = k)
35+
mdl.fit(x_train, y_train)
36+
37+
# validation accuracy
38+
pred = mdl.predict(x_valid)
39+
correct = 0
40+
for i in range(num_valid):
41+
if pred[i] == y_valid[i]:
42+
correct += 1
43+
44+
accuracy = correct / num_valid
45+
print("Accuracy:", 100 * accuracy)
46+
47+
# number of selected features
48+
num_feat = fmdl['nf']
49+
print("Feature Size:", num_feat)
50+
51+
# plot convergence
52+
curve = fmdl['c']
53+
curve = curve.reshape(np.size(curve,1))
54+
x = np.arange(0, opts['T'], 1.0) + 1.0
55+
56+
fig, ax = plt.subplots()
57+
ax.plot(x, curve, 'o-')
58+
ax.set_xlabel('Number of Iterations')
59+
ax.set_ylabel('Fitness')
60+
ax.set_title('PSO')
61+
ax.grid()
62+
plt.show()
63+

0 commit comments

Comments
 (0)