@@ -795,6 +795,28 @@ def input_spec(self) -> TensorSpec:
795
795
input_spec = self .__dict__ .get ("_input_spec" , None )
796
796
return input_spec
797
797
798
def rand_action(self, tensordict: Optional[TensorDictBase] = None) -> TensorDict:
    """Produce a random action, honoring a custom ``rand_action`` of the base env.

    When the class of ``self.base_env`` does not override
    ``EnvBase.rand_action``, the parent implementation is used. When it
    does, the call is forwarded to the base env, provided that the action
    spec of this environment equals the action spec of the base env.

    Args:
        tensordict: optional tensordict forwarded unchanged to whichever
            ``rand_action`` implementation ends up being called.

    Returns:
        The value returned by the selected ``rand_action`` implementation.

    Raises:
        RuntimeError: if the base env overrides ``rand_action`` and
            ``self.full_action_spec`` differs from
            ``self.base_env.full_action_spec``.
    """
    # Base env relies on the stock implementation: nothing special to do.
    if type(self.base_env).rand_action is EnvBase.rand_action:
        return super().rand_action(tensordict)

    # TODO: forwarding to the base env is wrong whenever a transform alters
    # the input. Example: with an env that overrides rand_action,
    #   env = PendulumEnv().append_transform(ActionDiscretizer(num_intervals=4))
    # env.rand_action will NOT yield a discrete action. Producing one would
    # need the inverse action transform (float->int rather than int->float)
    # to be implemented in ActionDiscretizer.
    # Comparing action specs is only a loose safeguard: equal specs do not
    # prove the action is untouched, but a mismatch catches part of these
    # alterations.
    #
    # This comparison may be expensive to run and could be cached.
    transformed_spec = self.full_action_spec
    base_spec = self.base_env.full_action_spec
    if transformed_spec != base_spec:
        raise RuntimeError(
            f"The rand_action method from the base env {self.base_env.__class__.__name__} "
            "has been overwritten, but the transforms appended to the environment modify "
            "the action. To call the base env rand_action method, we should then invert the "
            "action transform, which is (in general) not doable."
        )

    # Specs agree: defer to the base env's own sampling logic.
    return self.base_env.rand_action(tensordict)
819
+
798
820
def _step (self , tensordict : TensorDictBase ) -> TensorDictBase :
799
821
# No need to clone here because inv does it already
800
822
# tensordict = tensordict.clone(False)
0 commit comments