Skip to content

Commit 44ec09b

Browse files
author
Vincent Moens
committed
Update
[ghstack-poisoned]
1 parent 16d647b commit 44ec09b

File tree

2 files changed

+80
-39
lines changed

2 files changed

+80
-39
lines changed

test/test_transforms.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13456,11 +13456,14 @@ class TestConditionalSkip(TransformBase):
1345613456
@pytest.mark.parametrize("bwad", [False, True])
1345713457
def test_single_trans_env_check(self, bwad):
1345813458
env = CountingEnv()
13459-
base_env = TransformedEnv(env, StepCounter(step_count_key="other_count"))
13460-
env = TransformedEnv(base_env, StepCounter(), auto_unwrap=False)
13461-
env = env.append_transform(
13462-
ConditionalSkip(cond=lambda td: td["step_count"] % 2 == 1)
13459+
base_env = TransformedEnv(
13460+
env,
13461+
Compose(
13462+
StepCounter(step_count_key="other_count"),
13463+
ConditionalSkip(cond=lambda td: td["step_count"] % 2 == 1),
13464+
),
1346313465
)
13466+
env = TransformedEnv(base_env, StepCounter(), auto_unwrap=False)
1346413467
env.set_seed(0)
1346513468
env.check_env_specs()
1346613469
policy = lambda td: td.set("action", torch.ones((1,)))
@@ -13472,13 +13475,16 @@ def test_single_trans_env_check(self, bwad):
1347213475
def test_serial_trans_env_check(self, bwad):
1347313476
def make_env(i):
1347413477
env = CountingEnv()
13475-
base_env = TransformedEnv(env, StepCounter(step_count_key="other_count"))
13476-
return TransformedEnv(
13477-
base_env,
13478+
base_env = TransformedEnv(
13479+
env,
1347813480
Compose(
13479-
StepCounter(),
13481+
StepCounter(step_count_key="other_count"),
1348013482
ConditionalSkip(cond=lambda td, i=i: (td["step_count"] % 2 == i)),
1348113483
),
13484+
)
13485+
return TransformedEnv(
13486+
base_env,
13487+
StepCounter(),
1348213488
auto_unwrap=False,
1348313489
)
1348413490

@@ -13494,13 +13500,16 @@ def make_env(i):
1349413500
def test_parallel_trans_env_check(self, bwad):
1349513501
def make_env(i):
1349613502
env = CountingEnv()
13497-
base_env = TransformedEnv(env, StepCounter(step_count_key="other_count"))
13498-
return TransformedEnv(
13499-
base_env,
13503+
base_env = TransformedEnv(
13504+
env,
1350013505
Compose(
13501-
StepCounter(),
13506+
StepCounter(step_count_key="other_count"),
1350213507
ConditionalSkip(cond=lambda td, i=i: (td["step_count"] % 2 == i)),
1350313508
),
13509+
)
13510+
return TransformedEnv(
13511+
base_env,
13512+
StepCounter(),
1350413513
auto_unwrap=False,
1350513514
)
1350613515

@@ -13533,7 +13542,8 @@ def cond(td):
1353313542
sc = td["step_count"] + torch.tensor([[0], [1]])
1353413543
return sc.squeeze() % 2 == 0
1353513544

13536-
env = TransformedEnv(base_env, Compose(StepCounter(), ConditionalSkip(cond)))
13545+
env = TransformedEnv(base_env, ConditionalSkip(cond))
13546+
env = TransformedEnv(env, StepCounter(), auto_unwrap=False)
1353713547
env.check_env_specs()
1353813548
policy = lambda td: td.set("action", torch.ones((2, 1)))
1353913549
r = env.rollout(10, policy, break_when_any_done=bwad)
@@ -13558,9 +13568,8 @@ def cond(td):
1355813568
sc = td["step_count"] + torch.tensor([[0], [1]])
1355913569
return sc.squeeze() % 2 == 0
1356013570

13561-
env = TransformedEnv(
13562-
base_env, Compose(StepCounter(), ConditionalSkip(cond))
13563-
)
13571+
env = TransformedEnv(base_env, ConditionalSkip(cond))
13572+
env = TransformedEnv(env, StepCounter(), auto_unwrap=False)
1356413573
env.check_env_specs()
1356513574
policy = lambda td: td.set("action", torch.ones((2, 1)))
1356613575
r = env.rollout(10, policy, break_when_any_done=bwad)

torchrl/envs/transforms/transforms.py

Lines changed: 55 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -635,7 +635,7 @@ def parent(self) -> Optional[EnvBase]:
635635
)
636636
parent, _ = container._rebuild_up_to(self)
637637
elif isinstance(container, TransformedEnv):
638-
parent = TransformedEnv(container.base_env)
638+
parent = TransformedEnv(container.base_env, auto_unwrap=False)
639639
else:
640640
raise ValueError(f"container is of type {type(container)}")
641641
self.__dict__["_parent"] = parent
@@ -958,22 +958,22 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
958958

959959
if next_tensordict is None:
960960
next_tensordict = self.base_env._step(tensordict_in)
961+
if next_preset is not None:
962+
# tensordict could already have a "next" key
963+
# this could be done more efficiently by not excluding but just passing
964+
# the necessary keys
965+
next_tensordict.update(
966+
next_preset.exclude(*next_tensordict.keys(True, True))
967+
)
968+
self.base_env._complete_done(self.base_env.full_done_spec, next_tensordict)
969+
# we want the input entries to remain unchanged
970+
next_tensordict = self.transform._step(tensordict, next_tensordict)
961971

962972
if partial_steps is not None and tensordict_batch_size != self.batch_size:
963973
result = next_tensordict.new_zeros(tensordict_batch_size)
964974
result[partial_steps] = next_tensordict
965975
next_tensordict = result
966976

967-
if next_preset is not None:
968-
# tensordict could already have a "next" key
969-
# this could be done more efficiently by not excluding but just passing
970-
# the necessary keys
971-
next_tensordict.update(
972-
next_preset.exclude(*next_tensordict.keys(True, True))
973-
)
974-
self.base_env._complete_done(self.base_env.full_done_spec, next_tensordict)
975-
# we want the input entries to remain unchanged
976-
next_tensordict = self.transform._step(tensordict, next_tensordict)
977977
return next_tensordict
978978

979979
def set_seed(
@@ -9079,6 +9079,7 @@ class _CallableTransform(Transform):
90799079
# A wrapper around a custom callable to make it possible to transform any data type
90809080
def __init__(self, func):
90819081
super().__init__()
9082+
raise RuntimeError(isinstance(func, Transform), func)
90829083
self.func = func
90839084

90849085
def forward(self, *args, **kwargs):
@@ -10266,21 +10267,40 @@ class ConditionalSkip(Transform):
1026610267
value in `"_step"` is ``True``. Otherwise, it is trusted that the environment will account for the
1026710268
`"_step"` signal accordingly.
1026810269
10270+
.. note:: The skip will affect transforms that modify the environment output too, i.e., any transform
10271+
that is to be executed on the tensordict returned by :meth:`~torchrl.envs.EnvBase.step` will be
10272+
skipped if the condition is met. To mitigate this effect when it is undesirable, one can wrap
10273+
the transformed env in another transformed env, since the skip only affects the first-degree parent
10274+
of the ``ConditionalSkip`` transform. See example below.
10275+
1026910276
Args:
1027010277
cond (Callable[[TensorDictBase], bool | torch.Tensor]): a callable for the tensordict input
1027110278
that checks whether the next env step must be skipped (`True` = skipped, `False` = execute
1027210279
env.step).
1027310280
1027410281
Examples:
10275-
>>> from torchrl.envs.transforms.transforms import ConditionalSkip, StepCounter, TransformedEnv
10276-
>>> from torchrl.envs import GymEnv
1027710282
>>> import torch
1027810283
>>>
10284+
>>> from torchrl.envs import GymEnv
10285+
>>> from torchrl.envs.transforms.transforms import ConditionalSkip, StepCounter, TransformedEnv, Compose
10286+
>>>
1027910287
>>> torch.manual_seed(0)
1028010288
>>>
10281-
>>> base_env = TransformedEnv(GymEnv("Pendulum-v1"), StepCounter(step_count_key="other_count"))
10282-
>>> env = TransformedEnv(base_env, StepCounter(), auto_unwrap=False)
10283-
>>> env = env.append_transform(ConditionalSkip(cond=lambda td: td["step_count"] % 2 == 1))
10289+
>>> base_env = TransformedEnv(
10290+
... GymEnv("Pendulum-v1"),
10291+
... StepCounter(step_count_key="inner_count"),
10292+
... )
10293+
>>> middle_env = TransformedEnv(
10294+
... base_env,
10295+
... Compose(
10296+
... StepCounter(step_count_key="middle_count"),
10297+
... ConditionalSkip(cond=lambda td: td["step_count"] % 2 == 1),
10298+
... ),
10299+
... auto_unwrap=False) # makes sure that transformed envs are properly wrapped
10300+
>>> env = TransformedEnv(
10301+
... middle_env,
10302+
... StepCounter(step_count_key="step_count"),
10303+
... auto_unwrap=False)
1028410304
>>> env.set_seed(0)
1028510305
>>>
1028610306
>>> r = env.rollout(10)
@@ -10295,18 +10315,18 @@ class ConditionalSkip(Transform):
1029510315
[-0.9984, 0.0561, -1.7933],
1029610316
[-0.9984, 0.0561, -1.7933],
1029710317
[-0.9895, 0.1445, -1.7779]])
10298-
>>> print(r["step_count"])
10318+
>>> print(r["inner_count"])
1029910319
tensor([[0],
1030010320
[1],
10321+
[1],
10322+
[2],
1030110323
[2],
1030210324
[3],
10325+
[3],
1030310326
[4],
10304-
[5],
10305-
[6],
10306-
[7],
10307-
[8],
10308-
[9]])
10309-
>>> print(r["other_count"])
10327+
[4],
10328+
[5]])
10329+
>>> print(r["middle_count"])
1031010330
tensor([[0],
1031110331
[1],
1031210332
[1],
@@ -10317,6 +10337,18 @@ class ConditionalSkip(Transform):
1031710337
[4],
1031810338
[4],
1031910339
[5]])
10340+
>>> print(r["step_count"])
10341+
tensor([[0],
10342+
[1],
10343+
[2],
10344+
[3],
10345+
[4],
10346+
[5],
10347+
[6],
10348+
[7],
10349+
[8],
10350+
[9]])
10351+
1032010352
1032110353
"""
1032210354

0 commit comments

Comments
 (0)