Skip to content

Commit 1832c60

Browse files
author
Omegastick
committed
Fix typos in PPO
1 parent f1f4bd5 commit 1832c60

File tree

2 files changed

+8
-9
lines changed

2 files changed

+8
-9
lines changed

example/gym_client.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,15 @@ const int batch_size = 40;
2121
const float clip_param = 0.2;
2222
const float discount_factor = 0.99;
2323
const float entropy_coef = 1e-3;
24-
const float gae = 0.95;
25-
const float kl_target = 0.05;
26-
const float learning_rate = 7e-4;
27-
const int log_interval = 1;
24+
const float gae = 0.9;
25+
const float kl_target = 0.5;
26+
const float learning_rate = 1e-3;
27+
const int log_interval = 10;
2828
const int max_frames = 10e+7;
2929
const int num_epoch = 3;
3030
const int num_mini_batch = 20;
3131
const int reward_average_window_size = 10;
32-
const float reward_clip_value = 10; // Post scaling
32+
const float reward_clip_value = 100; // Post scaling
3333
const bool use_gae = true;
3434
const bool use_lr_decay = false;
3535
const float value_loss_coef = 0.5;

src/algorithms/ppo.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ std::vector<UpdateDatum> PPO::update(RolloutStorage &rollouts, float decay_level
5454
value_preds.narrow(0, 0, value_preds.size(0) - 1));
5555

5656
// Normalize advantages
57-
advantages = (advantages - advantages.mean() / (advantages.std() + 1e-5));
57+
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-5);
5858

5959
float total_value_loss = 0;
6060
float total_action_loss = 0;
@@ -108,12 +108,11 @@ std::vector<UpdateDatum> PPO::update(RolloutStorage &rollouts, float decay_level
108108
mini_batch.action_log_probs);
109109

110110
// PPO loss formula
111-
auto surr_1 = ratio * mini_batch.advantages.mean();
111+
auto surr_1 = ratio * mini_batch.advantages;
112112
auto surr_2 = (torch::clamp(ratio,
113113
1.0 - clip_param,
114114
1.0 + clip_param) *
115-
mini_batch.advantages)
116-
.mean();
115+
mini_batch.advantages);
117116
clip_fraction += (ratio - 1.0)
118117
.abs()
119118
.gt(clip_param)

0 commit comments

Comments
 (0)