@@ -61,6 +61,7 @@ std::vector<UpdateDatum> PPO::update(RolloutStorage &rollouts, float decay_level
61
61
float total_entropy = 0 ;
62
62
float kl_divergence = 0 ;
63
63
float kl_early_stopped = -1 ;
64
+ float clip_fraction = 0 ;
64
65
int num_updates = 0 ;
65
66
66
67
// Epoch loop
@@ -107,11 +108,19 @@ std::vector<UpdateDatum> PPO::update(RolloutStorage &rollouts, float decay_level
107
108
mini_batch.action_log_probs );
108
109
109
110
// PPO loss formula
110
- auto surr_1 = ratio * mini_batch.advantages ;
111
+ auto surr_1 = ratio * mini_batch.advantages . mean () ;
111
112
auto surr_2 = (torch::clamp (ratio,
112
113
1.0 - clip_param,
113
114
1.0 + clip_param) *
114
- mini_batch.advantages );
115
+ mini_batch.advantages )
116
+ .mean ();
117
+ clip_fraction += (ratio - 1.0 )
118
+ .abs ()
119
+ .gt (clip_param)
120
+ .to (torch::kFloat )
121
+ .mean ()
122
+ .item ()
123
+ .toFloat ();
115
124
auto action_loss = -torch::min (surr_1, surr_2).mean ();
116
125
117
126
// Value loss
@@ -148,11 +157,13 @@ std::vector<UpdateDatum> PPO::update(RolloutStorage &rollouts, float decay_level
148
157
total_value_loss /= num_updates;
149
158
total_action_loss /= num_updates;
150
159
total_entropy /= num_updates;
160
+ clip_fraction /= num_updates;
151
161
152
162
if (kl_early_stopped > -1 )
153
163
{
154
164
return {{" Value loss" , total_value_loss},
155
165
{" Action loss" , total_action_loss},
166
+ {" Clip fraction" , clip_fraction},
156
167
{" Entropy" , total_entropy},
157
168
{" KL divergence" , kl_divergence},
158
169
{" KL divergence early stop update" , kl_early_stopped}};
@@ -161,6 +172,7 @@ std::vector<UpdateDatum> PPO::update(RolloutStorage &rollouts, float decay_level
161
172
{
162
173
return {{" Value loss" , total_value_loss},
163
174
{" Action loss" , total_action_loss},
175
+ {" Clip fraction" , clip_fraction},
164
176
{" Entropy" , total_entropy},
165
177
{" KL divergence" , kl_divergence}};
166
178
}
0 commit comments