-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathelm_kernel_train.m
96 lines (78 loc) · 2.95 KB
/
elm_kernel_train.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
function model = elm_kernel_train(TrainingData,C,Kernel_type, Kernel_para)
%ELM_KERNEL_TRAIN Train a kernel extreme learning machine (ELM) classifier.
%   MODEL = ELM_KERNEL_TRAIN(TRAININGDATA, C, KERNEL_TYPE, KERNEL_PARA)
%
%   Inputs:
%     TrainingData - numeric matrix; column 1 holds the class label of each
%                    sample, columns 2:end hold its features.
%     C            - penalty (regularization) coefficient; the identity
%                    matrix scaled by 1/C is added to the kernel matrix.
%     Kernel_type  - kernel name forwarded to kernel_matrix ('RBF_kernel',
%                    'lin_kernel', 'poly_kernel' or 'wav_kernel').
%     Kernel_para  - kernel parameter vector forwarded to kernel_matrix.
%
%   Output MODEL fields:
%     label        - sorted unique class labels; entry j corresponds to
%                    output column j of beta.
%     X            - training features (stored, presumably for kernel
%                    evaluation at prediction time - confirm with caller).
%     beta         - output weights, (Omega_train + I/C) \ T, where T is the
%                    {-1,+1} one-hot target matrix.
%     TrainingTime - seconds spent building the kernel matrix and solving
%                    for beta (target encoding is not timed).
%     Kernel_type, Kernel_para - copies of the inputs.
%
%   Only C, the kernel type and Kernel_para need to be tuned.
%   Author's original note (translated from Chinese): "Kernel_para is the
%   w and b in the formula, because compared with OS-ELM, ELM generates them
%   randomly." NOTE(review): kept as the author's wording; confirm against
%   the kernel ELM formulation.
%
%   Raises 'elm_kernel_train:invalidLabel' if any label is NaN.

dataLabel = TrainingData(:,1);
data = TrainingData(:,2:end);
NumberofTrainingData = size(data,1);

%%%%%%%%%%%% Preprocessing the data of classification
label = unique(dataLabel);   % sorted; fixes the column order of the targets
number_class = numel(label);
model.label = label;
model.X = data;

%%%%%%%%%% Processing the targets of training
% Map each sample to the index of its class, then one-hot encode it with
% entries in {-1, +1} (multi-output-node form).
[isKnown, classIndex] = ismember(dataLabel, label);
if ~all(isKnown)
    % label comes from unique(dataLabel), so only NaN labels can fail to
    % match (NaN never compares equal to itself).
    error('elm_kernel_train:invalidLabel', ...
        'Class labels in column 1 of TrainingData must not be NaN.');
end
oneHot = eye(number_class);
dataLabelMP = oneHot(classIndex, :) * 2 - 1;

%%%%%%%%%%% Training Phase %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Private timer handle so a caller's own tic/toc is not reset.
timerStart = tic;
Omega_train = kernel_matrix(data, Kernel_type, Kernel_para);
% Regularized least-squares solution for the output weights.
model.beta = (Omega_train + speye(NumberofTrainingData)/C) \ dataLabelMP;
model.TrainingTime = toc(timerStart);
model.Kernel_type = Kernel_type;
model.Kernel_para = Kernel_para;
end
%%%%%%%%%%%%%%%%%% Kernel Matrix %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function omega = kernel_matrix(Xtrain,kernel_type, kernel_pars,Xt)
%KERNEL_MATRIX Evaluate a kernel function between rows of data matrices.
%   OMEGA = KERNEL_MATRIX(XTRAIN, KERNEL_TYPE, KERNEL_PARS) returns the
%   size(XTRAIN,1)-by-size(XTRAIN,1) kernel matrix among the rows of XTRAIN.
%   OMEGA = KERNEL_MATRIX(XTRAIN, KERNEL_TYPE, KERNEL_PARS, XT) returns the
%   size(XTRAIN,1)-by-size(XT,1) kernel matrix between rows of XTRAIN and
%   rows of XT.
%
%   Supported KERNEL_TYPE values and how KERNEL_PARS is used
%   (x, y are rows; d2 = ||x - y||^2):
%     'RBF_kernel'  exp(-d2 / kernel_pars(1))
%     'lin_kernel'  x*y'                       (kernel_pars is ignored)
%     'poly_kernel' (x*y' + kernel_pars(1)) .^ kernel_pars(2)
%     'wav_kernel'  cos(kernel_pars(3) * (sum(x) - sum(y)) / kernel_pars(2))
%                   .* exp(-d2 / kernel_pars(1))
%   Any other KERNEL_TYPE raises error 'kernel_matrix:unknownKernel'
%   (previously omega was left unassigned).

nb_data = size(Xtrain,1);
if strcmp(kernel_type,'RBF_kernel')
    if nargin<4
        % Pairwise squared Euclidean distances via ||a||^2 + ||b||^2 - 2ab'.
        XXh = sum(Xtrain.^2,2)*ones(1,nb_data);
        omega = XXh+XXh'-2*(Xtrain*Xtrain');
        omega = exp(-omega./kernel_pars(1));
    else
        XXh1 = sum(Xtrain.^2,2)*ones(1,size(Xt,1));
        XXh2 = sum(Xt.^2,2)*ones(1,nb_data);
        omega = XXh1+XXh2' - 2*Xtrain*Xt';
        omega = exp(-omega./kernel_pars(1));
    end
elseif strcmp(kernel_type,'lin_kernel')
    if nargin<4
        omega = Xtrain*Xtrain';
    else
        omega = Xtrain*Xt';
    end
elseif strcmp(kernel_type,'poly_kernel')
    if nargin<4
        omega = (Xtrain*Xtrain'+kernel_pars(1)).^kernel_pars(2);
    else
        omega = (Xtrain*Xt'+kernel_pars(1)).^kernel_pars(2);
    end
elseif strcmp(kernel_type,'wav_kernel')
    if nargin<4
        XXh = sum(Xtrain.^2,2)*ones(1,nb_data);
        omega = XXh+XXh'-2*(Xtrain*Xtrain');
        % Pairwise differences of the per-row feature sums.
        XXh1 = sum(Xtrain,2)*ones(1,nb_data);
        omega1 = XXh1-XXh1';
        omega = cos(kernel_pars(3)*omega1./kernel_pars(2)).*exp(-omega./kernel_pars(1));
    else
        XXh1 = sum(Xtrain.^2,2)*ones(1,size(Xt,1));
        XXh2 = sum(Xt.^2,2)*ones(1,nb_data);
        omega = XXh1+XXh2' - 2*(Xtrain*Xt');
        XXh11 = sum(Xtrain,2)*ones(1,size(Xt,1));
        XXh22 = sum(Xt,2)*ones(1,nb_data);
        omega1 = XXh11-XXh22';
        omega = cos(kernel_pars(3)*omega1./kernel_pars(2)).*exp(-omega./kernel_pars(1));
    end
else
    % Fail loudly instead of returning with omega unassigned.
    error('kernel_matrix:unknownKernel', ...
        'Unknown kernel type ''%s''. Expected RBF_kernel, lin_kernel, poly_kernel or wav_kernel.', ...
        kernel_type);
end
end