@@ -494,7 +494,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
494
494
def reset (self ) -> None :
495
495
pass
496
496
497
- def get_entropy_bonus (self , dist : d .Distribution ) -> torch .Tensor :
497
+ def _get_entropy (self , dist : d .Distribution ) -> torch .Tensor | TensorDict :
498
498
try :
499
499
entropy = dist .entropy ()
500
500
except NotImplementedError :
@@ -513,13 +513,11 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
513
513
log_prob = log_prob .select (* self .tensor_keys .sample_log_prob )
514
514
515
515
entropy = - log_prob .mean (0 )
516
- if is_tensor_collection (entropy ):
517
- entropy = _sum_td_features (entropy )
518
516
return entropy .unsqueeze (- 1 )
519
517
520
518
def _log_weight (
521
519
self , tensordict : TensorDictBase
522
- ) -> Tuple [torch .Tensor , d .Distribution ]:
520
+ ) -> Tuple [torch .Tensor , d .Distribution , torch . Tensor ]:
523
521
524
522
with self .actor_network_params .to_module (
525
523
self .actor_network
@@ -681,10 +679,14 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
681
679
log_weight = log_weight .view (advantage .shape )
682
680
neg_loss = log_weight .exp () * advantage
683
681
td_out = TensorDict ({"loss_objective" : - neg_loss }, batch_size = [])
682
+ td_out .set ("kl_approx" , kl_approx .detach ().mean ()) # for logging
684
683
if self .entropy_bonus :
685
- entropy = self .get_entropy_bonus (dist )
684
+ entropy = self ._get_entropy (dist )
685
+ if is_tensor_collection (entropy ):
686
+ # Reports the entropy of each action head.
687
+ td_out .set ("composite_entropy" , entropy .detach ())
688
+ entropy = _sum_td_features (entropy )
686
689
td_out .set ("entropy" , entropy .detach ().mean ()) # for logging
687
- td_out .set ("kl_approx" , kl_approx .detach ().mean ()) # for logging
688
690
td_out .set ("loss_entropy" , - self .entropy_coef * entropy )
689
691
if self .critic_coef is not None :
690
692
loss_critic , value_clip_fraction = self .loss_critic (tensordict )
@@ -956,8 +958,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
956
958
# ESS for logging
957
959
with torch .no_grad ():
958
960
# In theory, ESS should be computed on particles sampled from the same source. Here we sample according
959
- # to different, unrelated trajectories, which is not standard. Still it can give a idea of the dispersion
960
- # of the weights .
961
+ # to different, unrelated trajectories, which is not standard. Still, it can give an idea of the weights'
962
+ # dispersion .
961
963
lw = log_weight .squeeze ()
962
964
if not isinstance (lw , torch .Tensor ):
963
965
lw = _sum_td_features (lw )
@@ -976,11 +978,15 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
976
978
gain = _sum_td_features (gain )
977
979
td_out = TensorDict ({"loss_objective" : - gain }, batch_size = [])
978
980
td_out .set ("clip_fraction" , clip_fraction )
981
+ td_out .set ("kl_approx" , kl_approx .detach ().mean ()) # for logging
979
982
980
983
if self .entropy_bonus :
981
- entropy = self .get_entropy_bonus (dist )
984
+ entropy = self ._get_entropy (dist )
985
+ if is_tensor_collection (entropy ):
986
+ # Reports the entropy of each action head.
987
+ td_out .set ("composite_entropy" , entropy .detach ())
988
+ entropy = _sum_td_features (entropy )
982
989
td_out .set ("entropy" , entropy .detach ().mean ()) # for logging
983
- td_out .set ("kl_approx" , kl_approx .detach ().mean ()) # for logging
984
990
td_out .set ("loss_entropy" , - self .entropy_coef * entropy )
985
991
if self .critic_coef is not None :
986
992
loss_critic , value_clip_fraction = self .loss_critic (tensordict )
@@ -1282,14 +1288,18 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
1282
1288
{
1283
1289
"loss_objective" : - neg_loss ,
1284
1290
"kl" : kl .detach (),
1291
+ "kl_approx" : kl_approx .detach ().mean (),
1285
1292
},
1286
1293
batch_size = [],
1287
1294
)
1288
1295
1289
1296
if self .entropy_bonus :
1290
- entropy = self .get_entropy_bonus (dist )
1297
+ entropy = self ._get_entropy (dist )
1298
+ if is_tensor_collection (entropy ):
1299
+ # Reports the entropy of each action head.
1300
+ td_out .set ("composite_entropy" , entropy .detach ())
1301
+ entropy = _sum_td_features (entropy )
1291
1302
td_out .set ("entropy" , entropy .detach ().mean ()) # for logging
1292
- td_out .set ("kl_approx" , kl_approx .detach ().mean ()) # for logging
1293
1303
td_out .set ("loss_entropy" , - self .entropy_coef * entropy )
1294
1304
if self .critic_coef is not None :
1295
1305
loss_critic , value_clip_fraction = self .loss_critic (tensordict_copy )
0 commit comments