File tree 2 files changed +9
-1
lines changed
2 files changed +9
-1
lines changed Original file line number Diff line number Diff line change @@ -617,6 +617,8 @@ def _step(
617
617
618
618
619
619
class ContinuousActionVecMockEnv (_MockEnv ):
620
+ adapt_dtype : bool = True
621
+
620
622
@classmethod
621
623
def __new__ (
622
624
cls ,
@@ -715,7 +717,12 @@ def _step(
715
717
done = done .any (- 1 )
716
718
done = reward = done .unsqueeze (- 1 )
717
719
tensordict .set (
718
- "reward" , reward .to (self .reward_spec .dtype ).expand (self .reward_spec .shape )
720
+ "reward" ,
721
+ reward .to (
722
+ self .reward_spec .dtype
723
+ if self .adapt_dtype
724
+ else torch .get_default_dtype ()
725
+ ).expand (self .reward_spec .shape ),
719
726
)
720
727
tensordict .set ("done" , done )
721
728
tensordict .set ("terminated" , done )
Original file line number Diff line number Diff line change 235
235
class TestEnvBase :
236
236
def test_run_type_checks (self ):
237
237
env = ContinuousActionVecMockEnv ()
238
+ env .adapt_dtype = False
238
239
env ._run_type_checks = False
239
240
check_env_specs (env )
240
241
env ._run_type_checks = True
You can’t perform that action at this time.
0 commit comments