WW6_train_notParallel.m
% Train a DQN attacker agent on an 11-line power grid (serial version,
% no parallel workers).
clc
clear
close all
addpath('dcsimsep-master')  % DC cascading-failure simulator used by the environment
setup
% Observation: binary status of each of the 11 transmission lines
% (1 = line in service, 0 = line tripped).
ObservationInfo = rlNumericSpec([1 11]);
ObservationInfo.Name = 'Line State';
ObservationInfo.Description = 'line1, line2, line3, line4, line5, line6, line7, line8, line9, line10, line11';
ObservationInfo.LowerLimit = 0;
ObservationInfo.UpperLimit = 1;
% Action: attack (trip) one of the 11 lines per step.
ActionInfo = rlFiniteSetSpec(1:11);
ActionInfo.Name = 'Attacker Action';
ActionInfo.Description = ['attack-line1, attack-line2, attack-line3, attack-line4, ' ...
    'attack-line5, attack-line6, attack-line7, attack-line8, attack-line9, attack-line10, attack-line11'];
% Custom environment defined by named step and reset functions.
env = rlFunctionEnv(ObservationInfo, ActionInfo, 'WW6_StepFunction', 'WW6_ResetFunction');
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);
% obsInfo.Dimension % 1 11
% actInfo.Dimension % 1 1
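% rlFunctionEnv requires the named step/reset functions to follow the
% Reinforcement Learning Toolbox interfaces sketched below. This is only a
% minimal illustration of the required signatures; the real
% WW6_StepFunction/WW6_ResetFunction (with the dcsimsep cascading-failure
% dynamics and reward) live in their own files.
%
%   function [InitialObservation,LoggedSignals] = WW6_ResetFunction()
%       LoggedSignals.State = ones(1,11);   % all 11 lines start in service
%       InitialObservation = LoggedSignals.State;
%   end
%
%   function [NextObs,Reward,IsDone,LoggedSignals] = WW6_StepFunction(Action,LoggedSignals)
%       LoggedSignals.State(Action) = 0;    % attacked line trips
%       NextObs = LoggedSignals.State;
%       Reward = 0;                         % placeholder; the real reward comes from the simulator
%       IsDone = all(NextObs == 0);
%   end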
%% DQN critic ("hard" target-update configuration)
% Multi-output Q-network: maps the 11-element line-state observation to
% one Q-value per candidate attack action.
dnn = [
    featureInputLayer(obsInfo.Dimension(2),'Normalization','none','Name','state')
    fullyConnectedLayer(24,'Name','CriticStateFC1')
    reluLayer('Name','CriticRelu1')
    fullyConnectedLayer(24,'Name','CriticStateFC2')
    reluLayer('Name','CriticCommonRelu')
    fullyConnectedLayer(length(actInfo.Elements),'Name','output')];
figure
plot(layerGraph(dnn))  % visualize the critic architecture
criticOpts = rlRepresentationOptions('LearnRate',0.001,'GradientThreshold',1);
critic = rlQValueRepresentation(dnn,obsInfo,actInfo,'Observation',{'state'},criticOpts);
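% Optional sanity check (not in the original script; variable name qCheck
% is illustrative): query the untrained critic with a dummy all-lines-up
% observation. It should return one Q-value per action, i.e. 11 values.
qCheck = getValue(critic, {ones(1,11)});
assert(numel(qCheck) == numel(actInfo.Elements))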
% TargetSmoothFactor = 1 with TargetUpdateFrequency = 4 gives periodic
% (hard) target-network updates every 4 learning steps.
agentOpts = rlDQNAgentOptions(...
    'UseDoubleDQN',false, ...
    'TargetSmoothFactor',1, ...
    'TargetUpdateFrequency',4, ...
    'ExperienceBufferLength',100000, ...
    'DiscountFactor',0.99, ...
    'MiniBatchSize',256);
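% Exploration uses the toolbox's default epsilon-greedy schedule. To pin it
% down explicitly, the fields below could be set before constructing the
% agent (these values are illustrative assumptions, not from the original):
% agentOpts.EpsilonGreedyExploration.Epsilon = 1;
% agentOpts.EpsilonGreedyExploration.EpsilonDecay = 0.005;
% agentOpts.EpsilonGreedyExploration.EpsilonMin = 0.01;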
agent = rlDQNAgent(critic,agentOpts);
% Train for up to 1000 episodes of at most 10 attack steps each; stop early
% once the average reward reaches 4.5.
trainOpts = rlTrainingOptions(...
    'MaxEpisodes',1000, ...
    'MaxStepsPerEpisode',10, ...
    'Verbose',false, ...
    'Plots','training-progress', ...
    'StopTrainingCriteria','AverageReward', ...
    'StopTrainingValue',4.5);
trainingStats = train(agent,env,trainOpts);
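% Post-training follow-up: roll out the greedy policy once and archive the
% agent. A minimal sketch; the .mat file name is an assumption.
simOpts = rlSimulationOptions('MaxSteps',10);
experience = sim(env,agent,simOpts);  % one greedy 10-step attack rollout
save('WW6_DQN_agent.mat','agent')     % hypothetical archive file name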