@@ -635,7 +635,7 @@ def parent(self) -> Optional[EnvBase]:
635635 )
636636 parent , _ = container ._rebuild_up_to (self )
637637 elif isinstance (container , TransformedEnv ):
638- parent = TransformedEnv (container .base_env )
638+ parent = TransformedEnv (container .base_env , auto_unwrap = False )
639639 else :
640640 raise ValueError (f"container is of type { type (container )} " )
641641 self .__dict__ ["_parent" ] = parent
@@ -958,22 +958,22 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
958958
959959 if next_tensordict is None :
960960 next_tensordict = self .base_env ._step (tensordict_in )
961+ if next_preset is not None :
962+ # tensordict could already have a "next" key
963+ # this could be done more efficiently by not excluding but just passing
964+ # the necessary keys
965+ next_tensordict .update (
966+ next_preset .exclude (* next_tensordict .keys (True , True ))
967+ )
968+ self .base_env ._complete_done (self .base_env .full_done_spec , next_tensordict )
969+ # we want the input entries to remain unchanged
970+ next_tensordict = self .transform ._step (tensordict , next_tensordict )
961971
962972 if partial_steps is not None and tensordict_batch_size != self .batch_size :
963973 result = next_tensordict .new_zeros (tensordict_batch_size )
964974 result [partial_steps ] = next_tensordict
965975 next_tensordict = result
966976
967- if next_preset is not None :
968- # tensordict could already have a "next" key
969- # this could be done more efficiently by not excluding but just passing
970- # the necessary keys
971- next_tensordict .update (
972- next_preset .exclude (* next_tensordict .keys (True , True ))
973- )
974- self .base_env ._complete_done (self .base_env .full_done_spec , next_tensordict )
975- # we want the input entries to remain unchanged
976- next_tensordict = self .transform ._step (tensordict , next_tensordict )
977977 return next_tensordict
978978
979979 def set_seed (
@@ -9079,6 +9079,7 @@ class _CallableTransform(Transform):
90799079 # A wrapper around a custom callable to make it possible to transform any data type
90809080 def __init__ (self , func ):
90819081 super ().__init__ ()
9082+ raise RuntimeError (isinstance (func , Transform ), func )
90829083 self .func = func
90839084
90849085 def forward (self , * args , ** kwargs ):
@@ -10266,21 +10267,40 @@ class ConditionalSkip(Transform):
1026610267 value in `"_step"` is ``True``. Otherwise, it is trusted that the environment will account for the
1026710268 `"_step"` signal accordingly.
1026810269
10270+ .. note:: The skip will affect transforms that modify the environment output too, i.e., any transform
10271+ that is to be exectued on the tensordict returned by :meth:`~torchrl.envs.EnvBase.step` will be
10272+ skipped if the condition is met. To palliate this effect if it is not desirable, one can wrap
10273+ the transformed env in another transformed env, since the skip only affects the first-degree parent
10274+ of the ``ConditionalSkip`` transform. See example below.
10275+
1026910276 Args:
1027010277 cond (Callable[[TensorDictBase], bool | torch.Tensor]): a callable for the tensordict input
1027110278 that checks whether the next env step must be skipped (`True` = skipped, `False` = execute
1027210279 env.step).
1027310280
1027410281 Examples:
10275- >>> from torchrl.envs.transforms.transforms import ConditionalSkip, StepCounter, TransformedEnv
10276- >>> from torchrl.envs import GymEnv
1027710282 >>> import torch
1027810283 >>>
10284+ >>> from torchrl.envs import GymEnv
10285+ >>> from torchrl.envs.transforms.transforms import ConditionalSkip, StepCounter, TransformedEnv, Compose
10286+ >>>
1027910287 >>> torch.manual_seed(0)
1028010288 >>>
10281- >>> base_env = TransformedEnv(GymEnv("Pendulum-v1"), StepCounter(step_count_key="other_count"))
10282- >>> env = TransformedEnv(base_env, StepCounter(), auto_unwrap=False)
10283- >>> env = env.append_transform(ConditionalSkip(cond=lambda td: td["step_count"] % 2 == 1))
10289+ >>> base_env = TransformedEnv(
10290+ ... GymEnv("Pendulum-v1"),
10291+ ... StepCounter(step_count_key="inner_count"),
10292+ ... )
10293+ >>> middle_env = TransformedEnv(
10294+ ... base_env,
10295+ ... Compose(
10296+ ... StepCounter(step_count_key="middle_count"),
10297+ ... ConditionalSkip(cond=lambda td: td["step_count"] % 2 == 1),
10298+ ... ),
10299+ ... auto_unwrap=False) # makes sure that transformed envs are properly wrapped
10300+ >>> env = TransformedEnv(
10301+ ... middle_env,
10302+ ... StepCounter(step_count_key="step_count"),
10303+ ... auto_unwrap=False)
1028410304 >>> env.set_seed(0)
1028510305 >>>
1028610306 >>> r = env.rollout(10)
@@ -10295,18 +10315,18 @@ class ConditionalSkip(Transform):
1029510315 [-0.9984, 0.0561, -1.7933],
1029610316 [-0.9984, 0.0561, -1.7933],
1029710317 [-0.9895, 0.1445, -1.7779]])
10298- >>> print(r["step_count "])
10318+ >>> print(r["inner_count "])
1029910319 tensor([[0],
1030010320 [1],
10321+ [1],
10322+ [2],
1030110323 [2],
1030210324 [3],
10325+ [3],
1030310326 [4],
10304- [5],
10305- [6],
10306- [7],
10307- [8],
10308- [9]])
10309- >>> print(r["other_count"])
10327+ [4],
10328+ [5]])
10329+ >>> print(r["middle_count"])
1031010330 tensor([[0],
1031110331 [1],
1031210332 [1],
@@ -10317,6 +10337,18 @@ class ConditionalSkip(Transform):
1031710337 [4],
1031810338 [4],
1031910339 [5]])
10340+ >>> print(r["step_count"])
10341+ tensor([[0],
10342+ [1],
10343+ [2],
10344+ [3],
10345+ [4],
10346+ [5],
10347+ [6],
10348+ [7],
10349+ [8],
10350+ [9]])
10351+
1032010352
1032110353 """
1032210354
0 commit comments