Skip to content

Commit 2c19fcc

Browse files
committed
[BugFix] patch rand_action in TransformedEnv to read the base_env method
ghstack-source-id: 04e2e85e2675cf34c349ebadb8fa85a5aff2e532 Pull Request resolved: #2699
1 parent ec370c6 commit 2c19fcc

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

torchrl/envs/transforms/transforms.py

+22
Original file line numberDiff line numberDiff line change
@@ -795,6 +795,28 @@ def input_spec(self) -> TensorSpec:
795795
input_spec = self.__dict__.get("_input_spec", None)
796796
return input_spec
797797

798+
def rand_action(self, tensordict: Optional[TensorDictBase] = None) -> TensorDict:
799+
if type(self.base_env).rand_action is not EnvBase.rand_action:
800+
# TODO: this will fail if the transform modifies the input.
801+
# For instance, if an env overrides rand_action and we build a
802+
# env = PendulumEnv().append_transform(ActionDiscretizer(num_intervals=4))
803+
# env.rand_action will NOT have a discrete action!
804+
# Getting a discrete action would require coding the inverse transform of an action within
805+
# ActionDiscretizer (ie, float->int, not int->float).
806+
# We can loosely check that the action_spec isn't altered - that doesn't mean the action is
807+
# intact but it covers part of these alterations.
808+
#
809+
# The following check may be expensive to run and could be cached.
810+
if self.full_action_spec != self.base_env.full_action_spec:
811+
raise RuntimeError(
812+
f"The rand_action method from the base env {self.base_env.__class__.__name__} "
813+
"has been overwritten, but the transforms appended to the environment modify "
814+
"the action. To call the base env rand_action method, we should then invert the "
815+
"action transform, which is (in general) not doable."
816+
)
817+
return self.base_env.rand_action(tensordict)
818+
return super().rand_action(tensordict)
819+
798820
def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
799821
# No need to clone here because inv does it already
800822
# tensordict = tensordict.clone(False)

0 commit comments

Comments
 (0)