Skip to content

Commit 825af47

Browse files
committed
Update
[ghstack-poisoned]
2 parents 9666e5b + 5447009 commit 825af47

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

test/mocking_classes.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,8 @@ def _step(
617617

618618

619619
class ContinuousActionVecMockEnv(_MockEnv):
620+
adapt_dtype: bool = True
621+
620622
@classmethod
621623
def __new__(
622624
cls,
@@ -715,7 +717,12 @@ def _step(
715717
done = done.any(-1)
716718
done = reward = done.unsqueeze(-1)
717719
tensordict.set(
718-
"reward", reward.to(self.reward_spec.dtype).expand(self.reward_spec.shape)
720+
"reward",
721+
reward.to(
722+
self.reward_spec.dtype
723+
if self.adapt_dtype
724+
else torch.get_default_dtype()
725+
).expand(self.reward_spec.shape),
719726
)
720727
tensordict.set("done", done)
721728
tensordict.set("terminated", done)

test/test_env.py

+1
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@
235235
class TestEnvBase:
236236
def test_run_type_checks(self):
237237
env = ContinuousActionVecMockEnv()
238+
env.adapt_dtype = False
238239
env._run_type_checks = False
239240
check_env_specs(env)
240241
env._run_type_checks = True

0 commit comments

Comments
 (0)