@@ -494,7 +494,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
     def reset(self) -> None:
         pass
 
-    def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
+    def _get_entropy(self, dist: d.Distribution) -> torch.Tensor | TensorDict:
         try:
             entropy = dist.entropy()
         except NotImplementedError:
@@ -513,13 +513,11 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
                         log_prob = log_prob.select(*self.tensor_keys.sample_log_prob)
 
             entropy = -log_prob.mean(0)
-        if is_tensor_collection(entropy):
-            entropy = _sum_td_features(entropy)
         return entropy.unsqueeze(-1)
 
     def _log_weight(
         self, tensordict: TensorDictBase
-    ) -> Tuple[torch.Tensor, d.Distribution]:
+    ) -> Tuple[torch.Tensor, d.Distribution, torch.Tensor]:
 
         with self.actor_network_params.to_module(
             self.actor_network
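For context, a minimal sketch of the dispatch this change enables: with a composite (multi-head) action distribution, _get_entropy can now return a TensorDict of per-head entropies, and the feature summation that used to live inside it moves to the call sites below. The _sum_td_features here is a hypothetical stand-in for torchrl's private helper, not its actual implementation.

import torch
from tensordict import TensorDict, is_tensor_collection

def _sum_td_features(td):
    # Hypothetical stand-in: reduce a tensordict of per-head entropy
    # tensors to a single tensor by summing the leaves.
    return sum(td.values(include_nested=True, leaves_only=True))

# What _get_entropy might return for a two-head composite distribution:
entropy = TensorDict(
    {"head0": torch.rand(4, 1), "head1": torch.rand(4, 1)}, batch_size=[4]
)
if is_tensor_collection(entropy):
    per_head = entropy.detach()          # logged as "composite_entropy"
    entropy = _sum_td_features(entropy)  # single tensor for the loss term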
@@ -681,10 +679,14 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             log_weight = log_weight.view(advantage.shape)
         neg_loss = log_weight.exp() * advantage
         td_out = TensorDict({"loss_objective": -neg_loss}, batch_size=[])
+        td_out.set("kl_approx", kl_approx.detach().mean())  # for logging
         if self.entropy_bonus:
-            entropy = self.get_entropy_bonus(dist)
+            entropy = self._get_entropy(dist)
+            if is_tensor_collection(entropy):
+                # Reports the entropy of each action head.
+                td_out.set("composite_entropy", entropy.detach())
+                entropy = _sum_td_features(entropy)
             td_out.set("entropy", entropy.detach().mean())  # for logging
-            td_out.set("kl_approx", kl_approx.detach().mean())  # for logging
             td_out.set("loss_entropy", -self.entropy_coef * entropy)
         if self.critic_coef is not None:
             loss_critic, value_clip_fraction = self.loss_critic(tensordict)
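Assuming the keys set in the hunk above, the loss tensordict returned by forward would now look roughly like this; the values are placeholders, "composite_entropy" only appears for composite distributions, and "kl_approx" is now logged whether or not the entropy bonus is enabled.

import torch
from tensordict import TensorDict

td_out = TensorDict(
    {
        "loss_objective": torch.tensor(0.12),
        "kl_approx": torch.tensor(3e-3),
        "composite_entropy": TensorDict(
            {"head0": torch.tensor(1.4), "head1": torch.tensor(0.9)}, []
        ),
        "entropy": torch.tensor(2.3),  # summed over heads before logging
        "loss_entropy": torch.tensor(-2.3e-2),
    },
    batch_size=[],
)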
@@ -956,8 +958,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         # ESS for logging
         with torch.no_grad():
             # In theory, ESS should be computed on particles sampled from the same source. Here we sample according
-            # to different, unrelated trajectories, which is not standard. Still it can give a idea of the dispersion
-            # of the weights.
+            # to different, unrelated trajectories, which is not standard. Still, it can give an idea of the weights'
+            # dispersion.
             lw = log_weight.squeeze()
             if not isinstance(lw, torch.Tensor):
                 lw = _sum_td_features(lw)
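For reference, the effective sample size computed here is ESS = (sum_i w_i)^2 / sum_i w_i^2, evaluated from the log-weights with logsumexp for numerical stability. A small self-contained sketch; the helper name is ours, not torchrl's:

import torch

def effective_sample_size(lw: torch.Tensor) -> torch.Tensor:
    # ESS = (sum_i w_i)^2 / sum_i w_i^2 from log-weights lw:
    # log ESS = 2 * logsumexp(lw) - logsumexp(2 * lw)
    return torch.exp(2 * torch.logsumexp(lw, 0) - torch.logsumexp(2 * lw, 0))

lw = torch.randn(256)            # e.g. flattened log importance weights
ess = effective_sample_size(lw)  # in (0, 256]; 256 means uniform weights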
@@ -976,11 +978,15 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             gain = _sum_td_features(gain)
         td_out = TensorDict({"loss_objective": -gain}, batch_size=[])
         td_out.set("clip_fraction", clip_fraction)
+        td_out.set("kl_approx", kl_approx.detach().mean())  # for logging
 
         if self.entropy_bonus:
-            entropy = self.get_entropy_bonus(dist)
+            entropy = self._get_entropy(dist)
+            if is_tensor_collection(entropy):
+                # Reports the entropy of each action head.
+                td_out.set("composite_entropy", entropy.detach())
+                entropy = _sum_td_features(entropy)
             td_out.set("entropy", entropy.detach().mean())  # for logging
-            td_out.set("kl_approx", kl_approx.detach().mean())  # for logging
             td_out.set("loss_entropy", -self.entropy_coef * entropy)
         if self.critic_coef is not None:
             loss_critic, value_clip_fraction = self.loss_critic(tensordict)
@@ -1282,14 +1288,18 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
             {
                 "loss_objective": -neg_loss,
                 "kl": kl.detach(),
+                "kl_approx": kl_approx.detach().mean(),
             },
             batch_size=[],
        )
 
         if self.entropy_bonus:
-            entropy = self.get_entropy_bonus(dist)
+            entropy = self._get_entropy(dist)
+            if is_tensor_collection(entropy):
+                # Reports the entropy of each action head.
+                td_out.set("composite_entropy", entropy.detach())
+                entropy = _sum_td_features(entropy)
             td_out.set("entropy", entropy.detach().mean())  # for logging
-            td_out.set("kl_approx", kl_approx.detach().mean())  # for logging
             td_out.set("loss_entropy", -self.entropy_coef * entropy)
         if self.critic_coef is not None:
             loss_critic, value_clip_fraction = self.loss_critic(tensordict_copy)