File tree Expand file tree Collapse file tree 1 file changed +11
-0
lines changed Expand file tree Collapse file tree 1 file changed +11
-0
lines changed Original file line number Diff line number Diff line change @@ -795,6 +795,17 @@ def input_spec(self) -> TensorSpec:
795795 input_spec = self .__dict__ .get ("_input_spec" , None )
796796 return input_spec
797797
798+ def rand_action (self , tensordict : Optional [TensorDictBase ] = None ) -> TensorDict :
799+ if self .base_env .rand_action is not EnvBase .rand_action :
800+ # TODO: this will fail if the transform modifies the input.
801+ # For instance, if PendulumEnv overrides rand_action and we build a
802+ # env = PendulumEnv().append_transform(ActionDiscretizer(num_intervals=4))
803+ # env.rand_action will NOT have a discrete action!
804+ # Getting a discrete action would require coding the inverse transform of an action within
805+ # ActionDiscretizer (ie, float->int, not int->float).
806+ return self .base_env .rand_action (tensordict )
807+ return super ().rand_action (tensordict )
808+
798809 def _step (self , tensordict : TensorDictBase ) -> TensorDictBase :
799810 # No need to clone here because inv does it already
800811 # tensordict = tensordict.clone(False)
You can’t perform that action at this time.
0 commit comments