Commit 319bb68

Louis Faury authored
[Feature] Log each entropy for composite distributions in PPO (#2707)
Co-authored-by: Louis Faury <[email protected]>
1 parent d4e4019 commit 319bb68
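
In short, the PPO losses now expose the entropy of each action head separately when the policy uses a composite (multi-head) action distribution, under a new "composite_entropy" logging entry, while the entropy bonus itself keeps using the summed entropy. A minimal sketch of the idea with plain torch.distributions (head names, shapes and the entropy_coef value are illustrative, not taken from this commit):

import torch
from torch.distributions import Categorical, Independent, Normal

batch = 4
# Two independent action heads, as in a composite action distribution.
head_dists = {
    "move": Independent(Normal(torch.zeros(batch, 2), torch.ones(batch, 2)), 1),
    "gripper": Categorical(logits=torch.zeros(batch, 3)),
}

# What the new "composite_entropy" entry reports: one entropy per head.
composite_entropy = {name: d.entropy() for name, d in head_dists.items()}  # each (batch,)

# What the "entropy" / "loss_entropy" terms keep using: the sum over heads.
entropy = sum(composite_entropy.values())  # (batch,)
entropy_coef = 0.01  # illustrative value
loss_entropy = -entropy_coef * entropy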

File tree

2 files changed: +26 added, -13 removed


test/test_cost.py (+4 -1)
@@ -8409,7 +8409,6 @@ def test_ppo_composite_no_aggregate(
         if isinstance(loss_fn, KLPENPPOLoss):
             kl = loss.pop("kl_approx")
             assert (kl != 0).any()
-
         loss_critic = loss["loss_critic"]
         loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
         loss_critic.backward(retain_graph=True)
@@ -8637,12 +8636,16 @@ def test_ppo_shared_seq(
         )

         loss = loss_fn(td).exclude("entropy")
+        if composite_action_dist:
+            loss = loss.exclude("composite_entropy")

         sum(val for key, val in loss.items() if key.startswith("loss_")).backward()
         grad = TensorDict(dict(model.named_parameters()), []).apply(
             lambda x: x.grad.clone()
         )
         loss2 = loss_fn2(td).exclude("entropy")
+        if composite_action_dist:
+            loss2 = loss2.exclude("composite_entropy")

         model.zero_grad()
         sum(val for key, val in loss2.items() if key.startswith("loss_")).backward()
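
The shared-sequence test drops the logging-only entries before summing the "loss_*" terms; with a composite action distribution, the new "composite_entropy" entry is excluded the same way "entropy" already was. A self-contained sketch of that pattern (keys and values are made up, not the test's actual data):

import torch
from tensordict import TensorDict

loss = TensorDict(
    {
        "loss_objective": torch.tensor(2.0),
        "entropy": torch.tensor(0.7),
        "composite_entropy": TensorDict({"head_0": torch.tensor(0.3)}, []),
    },
    [],
)
loss = loss.exclude("entropy", "composite_entropy")  # keep only the optimizable terms
total = sum(val for key, val in loss.items() if key.startswith("loss_"))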

torchrl/objectives/ppo.py (+22 -12)
@@ -494,7 +494,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
     def reset(self) -> None:
         pass

-    def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
+    def _get_entropy(self, dist: d.Distribution) -> torch.Tensor | TensorDict:
         try:
             entropy = dist.entropy()
         except NotImplementedError:
@@ -513,13 +513,11 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
                 log_prob = log_prob.select(*self.tensor_keys.sample_log_prob)

             entropy = -log_prob.mean(0)
-        if is_tensor_collection(entropy):
-            entropy = _sum_td_features(entropy)
         return entropy.unsqueeze(-1)

     def _log_weight(
         self, tensordict: TensorDictBase
-    ) -> Tuple[torch.Tensor, d.Distribution]:
+    ) -> Tuple[torch.Tensor, d.Distribution, torch.Tensor]:

         with self.actor_network_params.to_module(
             self.actor_network
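
The helper is renamed from get_entropy_bonus to _get_entropy and can now return a per-head TensorDict: the summation over heads (previously done here with _sum_td_features) is deferred to the callers so the per-head values can be logged first. For context on the try/except branch above, a standalone sketch of the Monte Carlo fallback used when a distribution has no closed-form entropy (the tanh-squashed Gaussian and the sample count are illustrative, not taken from this diff):

import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform

dist = TransformedDistribution(Normal(torch.zeros(8), torch.ones(8)), [TanhTransform()])
samples_mc_entropy = 16  # number of Monte Carlo samples (illustrative)

try:
    entropy = dist.entropy()  # raises NotImplementedError for this distribution
except NotImplementedError:
    x = dist.rsample((samples_mc_entropy,))  # (samples_mc_entropy, 8)
    log_prob = dist.log_prob(x)
    entropy = -log_prob.mean(0)  # Monte Carlo estimate, shape (8,)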
@@ -681,10 +679,14 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         log_weight = log_weight.view(advantage.shape)
         neg_loss = log_weight.exp() * advantage
         td_out = TensorDict({"loss_objective": -neg_loss}, batch_size=[])
+        td_out.set("kl_approx", kl_approx.detach().mean())  # for logging
         if self.entropy_bonus:
-            entropy = self.get_entropy_bonus(dist)
+            entropy = self._get_entropy(dist)
+            if is_tensor_collection(entropy):
+                # Reports the entropy of each action head.
+                td_out.set("composite_entropy", entropy.detach())
+                entropy = _sum_td_features(entropy)
             td_out.set("entropy", entropy.detach().mean())  # for logging
-            td_out.set("kl_approx", kl_approx.detach().mean())  # for logging
             td_out.set("loss_entropy", -self.entropy_coef * entropy)
         if self.critic_coef is not None:
             loss_critic, value_clip_fraction = self.loss_critic(tensordict)
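
In forward, the per-head entropies are first logged under "composite_entropy" and then collapsed into a single tensor for the bonus. An illustrative stand-in for that aggregation step (the head layout and values are made up; this is not the _sum_td_features implementation):

import torch
from tensordict import TensorDict

composite_entropy = TensorDict(
    {
        "agent0": {"action": torch.full((5, 1), 0.9)},
        "agent1": {"action": torch.full((5, 1), 1.2)},
    },
    batch_size=[5],
)
# Sum every leaf into one tensor, since the entropy bonus expects a single value per sample.
entropy = sum(composite_entropy.values(include_nested=True, leaves_only=True))  # (5, 1)
loss_entropy = -0.01 * entropy  # -entropy_coef * entropy, with a made-up coefficient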
@@ -956,8 +958,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         # ESS for logging
         with torch.no_grad():
             # In theory, ESS should be computed on particles sampled from the same source. Here we sample according
-            # to different, unrelated trajectories, which is not standard. Still it can give a idea of the dispersion
-            # of the weights.
+            # to different, unrelated trajectories, which is not standard. Still, it can give an idea of the weights'
+            # dispersion.
             lw = log_weight.squeeze()
             if not isinstance(lw, torch.Tensor):
                 lw = _sum_td_features(lw)
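
The reworded comment refers to the effective sample size (ESS) of the importance weights. A minimal sketch of that diagnostic, computed stably from log-weights (the weight vector is random, made-up data, not the module's code):

import torch

lw = torch.randn(256)  # log importance weights
# ESS = (sum w)^2 / sum w^2, evaluated in log-space for numerical stability.
ess = torch.exp(2 * torch.logsumexp(lw, 0) - torch.logsumexp(2 * lw, 0))
# ess approaches 256 for uniform weights and shrinks as the weights become more dispersed.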
@@ -976,11 +978,15 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             gain = _sum_td_features(gain)
         td_out = TensorDict({"loss_objective": -gain}, batch_size=[])
         td_out.set("clip_fraction", clip_fraction)
+        td_out.set("kl_approx", kl_approx.detach().mean())  # for logging

         if self.entropy_bonus:
-            entropy = self.get_entropy_bonus(dist)
+            entropy = self._get_entropy(dist)
+            if is_tensor_collection(entropy):
+                # Reports the entropy of each action head.
+                td_out.set("composite_entropy", entropy.detach())
+                entropy = _sum_td_features(entropy)
             td_out.set("entropy", entropy.detach().mean())  # for logging
-            td_out.set("kl_approx", kl_approx.detach().mean())  # for logging
             td_out.set("loss_entropy", -self.entropy_coef * entropy)
         if self.critic_coef is not None:
             loss_critic, value_clip_fraction = self.loss_critic(tensordict)
@@ -1282,14 +1288,18 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
             {
                 "loss_objective": -neg_loss,
                 "kl": kl.detach(),
+                "kl_approx": kl_approx.detach().mean(),
             },
             batch_size=[],
         )

         if self.entropy_bonus:
-            entropy = self.get_entropy_bonus(dist)
+            entropy = self._get_entropy(dist)
+            if is_tensor_collection(entropy):
+                # Reports the entropy of each action head.
+                td_out.set("composite_entropy", entropy.detach())
+                entropy = _sum_td_features(entropy)
             td_out.set("entropy", entropy.detach().mean())  # for logging
-            td_out.set("kl_approx", kl_approx.detach().mean())  # for logging
             td_out.set("loss_entropy", -self.entropy_coef * entropy)
         if self.critic_coef is not None:
             loss_critic, value_clip_fraction = self.loss_critic(tensordict_copy)
