File tree 1 file changed +11
-0
lines changed
1 file changed +11
-0
lines changed Original file line number Diff line number Diff line change @@ -795,6 +795,17 @@ def input_spec(self) -> TensorSpec:
795
795
input_spec = self .__dict__ .get ("_input_spec" , None )
796
796
return input_spec
797
797
798
+ def rand_action (self , tensordict : Optional [TensorDictBase ] = None ) -> TensorDict :
799
+ if self .base_env .rand_action is not EnvBase .rand_action :
800
+ # TODO: this will fail if the transform modifies the input.
801
+ # For instance, if PendulumEnv overrides rand_action and we build a
802
+ # env = PendulumEnv().append_transform(ActionDiscretizer(num_intervals=4))
803
+ # env.rand_action will NOT have a discrete action!
804
+ # Getting a discrete action would require coding the inverse transform of an action within
805
+ # ActionDiscretizer (ie, float->int, not int->float).
806
+ return self .base_env .rand_action (tensordict )
807
+ return super ().rand_action (tensordict )
808
+
798
809
def _step (self , tensordict : TensorDictBase ) -> TensorDictBase :
799
810
# No need to clone here because inv does it already
800
811
# tensordict = tensordict.clone(False)
You can’t perform that action at this time.
0 commit comments