Skip to content

Commit a451ad7

Browse files
author
Omegastick
committed
Report clip fraction from PPO
1 parent 219b5dd commit a451ad7

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

src/algorithms/ppo.cpp

Lines changed: 14 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -61,6 +61,7 @@ std::vector<UpdateDatum> PPO::update(RolloutStorage &rollouts, float decay_level
6161
float total_entropy = 0;
6262
float kl_divergence = 0;
6363
float kl_early_stopped = -1;
64+
float clip_fraction = 0;
6465
int num_updates = 0;
6566

6667
// Epoch loop
@@ -107,11 +108,19 @@ std::vector<UpdateDatum> PPO::update(RolloutStorage &rollouts, float decay_level
107108
mini_batch.action_log_probs);
108109

109110
// PPO loss formula
110-
auto surr_1 = ratio * mini_batch.advantages;
111+
auto surr_1 = ratio * mini_batch.advantages.mean();
111112
auto surr_2 = (torch::clamp(ratio,
112113
1.0 - clip_param,
113114
1.0 + clip_param) *
114-
mini_batch.advantages);
115+
mini_batch.advantages)
116+
.mean();
117+
clip_fraction += (ratio - 1.0)
118+
.abs()
119+
.gt(clip_param)
120+
.to(torch::kFloat)
121+
.mean()
122+
.item()
123+
.toFloat();
115124
auto action_loss = -torch::min(surr_1, surr_2).mean();
116125

117126
// Value loss
@@ -148,11 +157,13 @@ std::vector<UpdateDatum> PPO::update(RolloutStorage &rollouts, float decay_level
148157
total_value_loss /= num_updates;
149158
total_action_loss /= num_updates;
150159
total_entropy /= num_updates;
160+
clip_fraction /= num_updates;
151161

152162
if (kl_early_stopped > -1)
153163
{
154164
return {{"Value loss", total_value_loss},
155165
{"Action loss", total_action_loss},
166+
{"Clip fraction", clip_fraction},
156167
{"Entropy", total_entropy},
157168
{"KL divergence", kl_divergence},
158169
{"KL divergence early stop update", kl_early_stopped}};
@@ -161,6 +172,7 @@ std::vector<UpdateDatum> PPO::update(RolloutStorage &rollouts, float decay_level
161172
{
162173
return {{"Value loss", total_value_loss},
163174
{"Action loss", total_action_loss},
175+
{"Clip fraction", clip_fraction},
164176
{"Entropy", total_entropy},
165177
{"KL divergence", kl_divergence}};
166178
}

0 commit comments

Comments
 (0)