
Commit 1024d61

Update
[ghstack-poisoned]

2 parents: 55934fb + 6e83807

File tree

3 files changed: +105, -13 lines

test/test_cost.py (+4, -1)
@@ -8409,7 +8409,6 @@ def test_ppo_composite_no_aggregate(
         if isinstance(loss_fn, KLPENPPOLoss):
             kl = loss.pop("kl_approx")
             assert (kl != 0).any()
-
         loss_critic = loss["loss_critic"]
         loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
         loss_critic.backward(retain_graph=True)
@@ -8637,12 +8636,16 @@ def test_ppo_shared_seq(
         )
 
         loss = loss_fn(td).exclude("entropy")
+        if composite_action_dist:
+            loss = loss.exclude("composite_entropy")
 
         sum(val for key, val in loss.items() if key.startswith("loss_")).backward()
         grad = TensorDict(dict(model.named_parameters()), []).apply(
             lambda x: x.grad.clone()
         )
         loss2 = loss_fn2(td).exclude("entropy")
+        if composite_action_dist:
+            loss2 = loss2.exclude("composite_entropy")
 
         model.zero_grad()
         sum(val for key, val in loss2.items() if key.startswith("loss_")).backward()

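This test change reflects the new `composite_entropy` entry that the PPO losses emit (see `torchrl/objectives/ppo.py` below); the test strips it before comparing loss TensorDicts. A minimal sketch of the aggregation pattern the test relies on; `loss_fn` and `td` stand in for any of the PPO losses and their input batch:

```python
# Sketch only: `loss_fn` and `td` are placeholders for a TorchRL PPO loss
# module and a batch of input data.
loss_td = loss_fn(td)

# Only keys prefixed with "loss_" carry gradients; everything else is logging.
total_loss = sum(val for key, val in loss_td.items() if key.startswith("loss_"))
total_loss.backward()

# Logging-only entries are detached and can be read separately:
entropy = loss_td.get("entropy", None)
per_head_entropy = loss_td.get("composite_entropy", None)  # composite dists only
```
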
test/test_transforms.py (+79)
@@ -689,6 +689,85 @@ def test_single_trans_env_check(self, out_keys):
         )
         check_env_specs(env)
 
+    @pytest.mark.parametrize("cat_dim", [-1, -2, -3])
+    @pytest.mark.parametrize("cat_N", [3, 10])
+    @pytest.mark.parametrize("device", get_default_devices())
+    def test_with_permute_no_env(self, cat_dim, cat_N, device):
+        torch.manual_seed(cat_dim * cat_N)
+        pixels = torch.randn(8, 5, 3, 10, 4, device=device)
+
+        a = TensorDict(
+            {
+                "pixels": pixels,
+            },
+            [
+                pixels.shape[0],
+            ],
+            device=device,
+        )
+
+        t0 = Compose(
+            CatFrames(N=cat_N, dim=cat_dim),
+        )
+
+        def get_rand_perm(ndim):
+            cat_dim_perm = cat_dim
+            # Ensure that the permutation moves the cat_dim
+            while cat_dim_perm == cat_dim:
+                perm_pos = torch.randperm(ndim)
+                perm = perm_pos - ndim
+                cat_dim_perm = (perm == cat_dim).nonzero().item() - ndim
+            perm_inv = perm_pos.argsort() - ndim
+            return perm.tolist(), perm_inv.tolist(), cat_dim_perm
+
+        perm, perm_inv, cat_dim_perm = get_rand_perm(pixels.dim() - 1)
+
+        t1 = Compose(
+            PermuteTransform(perm, in_keys=["pixels"]),
+            CatFrames(N=cat_N, dim=cat_dim_perm),
+            PermuteTransform(perm_inv, in_keys=["pixels"]),
+        )
+
+        b = t0._call(a.clone())
+        c = t1._call(a.clone())
+        assert (b == c).all()
+
+    @pytest.mark.skipif(not _has_gym, reason="Test executed on gym")
+    @pytest.mark.parametrize("cat_dim", [-1, -2])
+    def test_with_permute_env(self, cat_dim):
+        env0 = TransformedEnv(
+            GymEnv("Pendulum-v1"),
+            Compose(
+                UnsqueezeTransform(-1, in_keys=["observation"]),
+                CatFrames(N=4, dim=cat_dim, in_keys=["observation"]),
+            ),
+        )
+
+        env1 = TransformedEnv(
+            GymEnv("Pendulum-v1"),
+            Compose(
+                UnsqueezeTransform(-1, in_keys=["observation"]),
+                PermuteTransform((-1, -2), in_keys=["observation"]),
+                CatFrames(N=4, dim=-3 - cat_dim, in_keys=["observation"]),
+                PermuteTransform((-1, -2), in_keys=["observation"]),
+            ),
+        )
+
+        torch.manual_seed(0)
+        env0.set_seed(0)
+        td0 = env0.reset()
+
+        torch.manual_seed(0)
+        env1.set_seed(0)
+        td1 = env1.reset()
+
+        assert (td0 == td1).all()
+
+        td0 = env0.step(td0.update(env0.full_action_spec.rand()))
+        td1 = env1.step(td1.update(td0.select("action")))
+
+        assert (td0 == td1).all()
+
     def test_serial_trans_env_check(self):
         env = SerialEnv(
             2,

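Both new tests assert the same invariant: `CatFrames` applied on dim `d` is equivalent to permuting `d` elsewhere, stacking there, and permuting back. A self-contained sketch of that invariant (shapes and the `pixels` key are illustrative, not taken from the test suite):

```python
import torch
from tensordict import TensorDict
from torchrl.envs.transforms import CatFrames, Compose, PermuteTransform

pixels = torch.randn(8, 3, 10, 4)  # batch x C x H x W, purely illustrative
td = TensorDict({"pixels": pixels}, [8])

direct = Compose(CatFrames(N=4, dim=-3, in_keys=["pixels"]))
roundtrip = Compose(
    PermuteTransform((-1, -2, -3), in_keys=["pixels"]),  # reverse the last 3 dims
    CatFrames(N=4, dim=-1, in_keys=["pixels"]),          # cat on the moved dim
    PermuteTransform((-1, -2, -3), in_keys=["pixels"]),  # a reversal is its own inverse
)

# Same result either way, as the tests above check via `_call`.
assert (direct._call(td.clone()) == roundtrip._call(td.clone())).all()
```
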
torchrl/objectives/ppo.py (+22, -12)
@@ -494,7 +494,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
     def reset(self) -> None:
         pass
 
-    def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
+    def _get_entropy(self, dist: d.Distribution) -> torch.Tensor | TensorDict:
         try:
             entropy = dist.entropy()
         except NotImplementedError:
@@ -513,13 +513,11 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
             log_prob = log_prob.select(*self.tensor_keys.sample_log_prob)
 
         entropy = -log_prob.mean(0)
-        if is_tensor_collection(entropy):
-            entropy = _sum_td_features(entropy)
         return entropy.unsqueeze(-1)
 
     def _log_weight(
         self, tensordict: TensorDictBase
-    ) -> Tuple[torch.Tensor, d.Distribution]:
+    ) -> Tuple[torch.Tensor, d.Distribution, torch.Tensor]:
 
         with self.actor_network_params.to_module(
             self.actor_network
@@ -681,10 +679,14 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         log_weight = log_weight.view(advantage.shape)
         neg_loss = log_weight.exp() * advantage
         td_out = TensorDict({"loss_objective": -neg_loss}, batch_size=[])
+        td_out.set("kl_approx", kl_approx.detach().mean())  # for logging
         if self.entropy_bonus:
-            entropy = self.get_entropy_bonus(dist)
+            entropy = self._get_entropy(dist)
+            if is_tensor_collection(entropy):
+                # Reports the entropy of each action head.
+                td_out.set("composite_entropy", entropy.detach())
+                entropy = _sum_td_features(entropy)
             td_out.set("entropy", entropy.detach().mean())  # for logging
-            td_out.set("kl_approx", kl_approx.detach().mean())  # for logging
             td_out.set("loss_entropy", -self.entropy_coef * entropy)
         if self.critic_coef is not None:
             loss_critic, value_clip_fraction = self.loss_critic(tensordict)
@@ -956,8 +958,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         # ESS for logging
         with torch.no_grad():
             # In theory, ESS should be computed on particles sampled from the same source. Here we sample according
-            # to different, unrelated trajectories, which is not standard. Still it can give a idea of the dispersion
-            # of the weights.
+            # to different, unrelated trajectories, which is not standard. Still, it can give an idea of the weights'
+            # dispersion.
             lw = log_weight.squeeze()
             if not isinstance(lw, torch.Tensor):
                 lw = _sum_td_features(lw)
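
For context, the effective sample size referred to in this comment is conventionally ESS = (Σᵢ wᵢ)² / Σᵢ wᵢ², computed stably from the log-weights. A minimal sketch of that formula (not necessarily the exact code used in this file):

```python
import torch

def effective_sample_size(lw: torch.Tensor) -> torch.Tensor:
    # log ESS = 2 * log(sum w) - log(sum w^2), evaluated with logsumexp
    # over the particle dimension for numerical stability.
    return (2 * lw.logsumexp(0) - (2 * lw).logsumexp(0)).exp()
```
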
@@ -976,11 +978,15 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             gain = _sum_td_features(gain)
         td_out = TensorDict({"loss_objective": -gain}, batch_size=[])
         td_out.set("clip_fraction", clip_fraction)
+        td_out.set("kl_approx", kl_approx.detach().mean())  # for logging
 
         if self.entropy_bonus:
-            entropy = self.get_entropy_bonus(dist)
+            entropy = self._get_entropy(dist)
+            if is_tensor_collection(entropy):
+                # Reports the entropy of each action head.
+                td_out.set("composite_entropy", entropy.detach())
+                entropy = _sum_td_features(entropy)
             td_out.set("entropy", entropy.detach().mean())  # for logging
-            td_out.set("kl_approx", kl_approx.detach().mean())  # for logging
             td_out.set("loss_entropy", -self.entropy_coef * entropy)
         if self.critic_coef is not None:
             loss_critic, value_clip_fraction = self.loss_critic(tensordict)
@@ -1282,14 +1288,18 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
             {
                 "loss_objective": -neg_loss,
                 "kl": kl.detach(),
+                "kl_approx": kl_approx.detach().mean(),
             },
             batch_size=[],
         )
 
         if self.entropy_bonus:
-            entropy = self.get_entropy_bonus(dist)
+            entropy = self._get_entropy(dist)
+            if is_tensor_collection(entropy):
+                # Reports the entropy of each action head.
+                td_out.set("composite_entropy", entropy.detach())
+                entropy = _sum_td_features(entropy)
             td_out.set("entropy", entropy.detach().mean())  # for logging
-            td_out.set("kl_approx", kl_approx.detach().mean())  # for logging
             td_out.set("loss_entropy", -self.entropy_coef * entropy)
         if self.critic_coef is not None:
             loss_critic, value_clip_fraction = self.loss_critic(tensordict_copy)
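
Taken together, every PPO variant now reports `kl_approx` unconditionally and, when the actor uses a composite action distribution, a per-head `composite_entropy` alongside the scalar `entropy`. A hedged usage sketch; `actor`, `critic`, and `data` are placeholders for a configured policy, value network, and collected batch:

```python
from torchrl.objectives import ClipPPOLoss

# `actor`, `critic` and `data` are placeholders, not defined here.
loss_fn = ClipPPOLoss(actor, critic, entropy_bonus=True)
loss_td = loss_fn(data)

loss_td["kl_approx"]   # always present now, detached, for logging
loss_td["entropy"]     # scalar mean entropy, detached
if "composite_entropy" in loss_td.keys():
    per_head = loss_td["composite_entropy"]  # one entry per action head
```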
