% triageSVM.m
% (Captured from a GitHub file view: 102 lines (88 loc), 4.01 KB. The page's
%  notification/fork banner and line-number gutter have been removed.)
% function [resultsTable,labelCV] = triageSVM(inputX,Y)
% By default this file runs as a script: the predictors are every column of
% the workspace table XY_pc except the first, and the class label is the
% first column. To run it as a function instead, uncomment the header above
% and comment out the two XY_pc assignments below.
inputX = XY_pc(:,2:end);
Y = XY_pc.(1);

% Result accumulators; one row is appended per evaluated model. Comment out
% these lines to keep appending to accumulators left over from an earlier run.
type = {};
[sensitivity, specificity, numObservations, numFeatures, errorRate, ...
    posPred, negPred, prevalence, genError] = deal([]);

% One-hot encode each predictor column on its own. Variable k of dumX holds
% the indicator matrix that dummyvar produced for column k of inputX.
oneHot = cell(1, size(inputX,2));
for col = 1:numel(oneHot)
    oneHot{col} = dummyvar(nominal(inputX.(col)));
end
dumX = table(oneHot{:});
%% Prep Data
% Hold-out split by row position: rows 1..trainCutoffIndex train the model and
% the remaining rows test it. Nothing is shuffled here, so this assumes the
% rows of XY_pc are not ordered by class or by time -- TODO confirm.
trainPrecentage = .7;   % fraction of rows used for training
trainCutoffIndex = fix(trainPrecentage*size(Y,1));

trainY = Y(1:trainCutoffIndex);
testY  = Y(trainCutoffIndex+1:end);
trainX = dumX(1:trainCutoffIndex,:);
testX  = dumX(trainCutoffIndex+1:end,:);

% Flatten the one-hot tables into plain numeric design matrices
% (one row per observation, one column per indicator).
toMatrix = @(tbl) cell2mat(table2cell(tbl));
trainXMat = toMatrix(trainX);
testXMat  = toMatrix(testX);
XMat      = toMatrix(dumX);
%% Feature Selection
% Sequential feature selection over the full one-hot design matrix, scored
% with a 5-fold partition that uses Y as the grouping variable. Progress is
% printed at each iteration.
% NOTE(review): my_fun is not defined in this file (presumably a separate
% my_fun.m criterion function -- confirm). The selection mask inmodel is not
% used when the SVMs below are trained.
opts = statset('display','iter');
c = cvpartition(Y,'k',5);
inmodel = sequentialfs(@my_fun, XMat, Y, ...
    'cv', c, ...
    'options', opts);
%% Train & Test SVM
% Hold-out evaluation: fit a linear SVM (standardized predictors, class order
% {'SELF','DR'}) on the training rows only, then predict both the held-out
% test rows and the training rows.
SVMModel = fitcsvm(trainXMat,trainY,'Standardize',true,'KernelFunction','linear','ClassNames',{'SELF','DR'});
% Optional fitcsvm arguments, kept for reference. NOTE(review): 'PredictorNames'
% needs one name per column of trainXMat (one per one-hot indicator column),
% so this per-question list would have to be expanded before use.
% 'Crossval','on','KFold',10,'PredictorNames',{'Breathing Problem','BreathingRate','ChestPain','Cold/flu-severity','CoughOnset','CoughSpasms','CoughType','Cough-severity','CurrentState','Inhaled/Ingest','Meds/Respiratory','OtherSymptoms','PMH/RespiratoryDis.','Precipfactors?','Sorethroat-severity','Sputum','Temperature','TreatmentTried','WorstTime'}
[labelTest,scoreTest] = predict(SVMModel,testXMat);
[labelTrain,scoreTrain] = predict(SVMModel,trainXMat);
% 'DR' is the positive class in every performance summary.
CP_Test = classperf(testY, labelTest,'Positive',{'DR'}, 'Negative', {'SELF'});
CP_Train = classperf(trainY, labelTrain,'Positive',{'DR'}, 'Negative', {'SELF'});

% Cross-validated evaluation: 10-fold CV over all rows with the same SVM
% settings. kfoldPredict gives each row the prediction of the fold model
% that was trained without it.
SVMModelCV = fitcsvm(XMat,Y,'Crossval','on','KFold',10,'Standardize',true,'KernelFunction','linear','ClassNames',{'SELF','DR'});
[labelCV,scoreCV] = kfoldPredict(SVMModelCV);
CP_CV = classperf(Y, labelCV,'Positive',{'DR'}, 'Negative', {'SELF'});

% Cross-validated error rate. It is deliberately not named "error": a variable
% of that name would shadow MATLAB's built-in error() in this workspace.
errorRate_CV = CP_CV.ErrorRate;

% F1 score for the DR class: harmonic mean of recall R and precision P.
% When both are exactly 0 (DR was predicted but never correctly), F1 is
% defined as 0 rather than the NaN that 0/0 would produce.
R = CP_CV.Sensitivity;
P = CP_CV.PositivePredictiveValue;
if R == 0 && P == 0
    fScore = 0;
else
    fScore = 2*R*P/(R+P);
end

% Generalization error: cross-validated loss of the 10-fold model
% (kfoldLoss with its default loss and averaging).
genError_CV = kfoldLoss(SVMModelCV);
%% Write to Results Table
% Append one row per evaluation (training-set fit, hold-out test, 10-fold CV)
% to the result accumulators, in that order, and assemble resultsTable.
% numFeatures counts the raw predictor columns of inputX, not the one-hot
% indicator columns that were passed to the SVM.
evalNames = {'CP_Train'; 'CP_Test'; 'CP_CV'};
evalPerf = {CP_Train; CP_Test; CP_CV};
% Only the cross-validated model has a kfoldLoss. The other two rows keep 0
% as a placeholder.
evalGenError = [0; 0; genError_CV];
for iEval = 1:numel(evalPerf)
    cp = evalPerf{iEval};
    type(end+1,1) = evalNames(iEval);
    numFeatures(end+1,1) = size(inputX,2);
    numObservations(end+1,1) = cp.NumberOfObservations;
    sensitivity(end+1,1) = cp.Sensitivity;
    specificity(end+1,1) = cp.Specificity;
    errorRate(end+1,1) = cp.ErrorRate;
    posPred(end+1,1) = cp.PositivePredictiveValue;
    negPred(end+1,1) = cp.NegativePredictiveValue;
    prevalence(end+1,1) = cp.Prevalence;
    genError(end+1,1) = evalGenError(iEval);
end
resultsTable = table(type,numFeatures,numObservations,sensitivity,specificity,errorRate,posPred,negPred,prevalence,genError);
% No closing "end" here. While the function header is commented out this file
% is a script, and an unmatched "end" is a syntax error in a script. A function
% file without local functions does not need a closing "end" either.