Skip to content

Commit a2e1fc4

Browse files
committed
Update
[ghstack-poisoned]
1 parent 6e7adf8 commit a2e1fc4

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

test/test_env.py

-1
Original file line numberDiff line numberDiff line change
@@ -4148,7 +4148,6 @@ def test_parallel_partial_step_and_maybe_reset(
41484148

41494149
td.set("action", penv.full_action_spec[penv.action_key].one())
41504150
td, tdreset = penv.step_and_maybe_reset(td)
4151-
print(td)
41524151
assert_allclose_td(td[0].get("next"), td[0], intersection=True)
41534152
assert (td[1].get("next") != 0).any()
41544153
assert_allclose_td(td[2].get("next"), td[2], intersection=True)

torchrl/envs/batched_envs.py

+8
Original file line numberDiff line numberDiff line change
@@ -1691,6 +1691,14 @@ def step_and_maybe_reset(
16911691
device=device,
16921692
filter_empty=True,
16931693
)
1694+
if tensordict.device != device:
1695+
tensordict = tensordict._fast_apply(
1696+
lambda x: x.to(device, non_blocking=self.non_blocking)
1697+
if x.device != device
1698+
else x,
1699+
device=device,
1700+
filter_empty=True,
1701+
)
16941702
self._sync_w2m()
16951703
else:
16961704
next_td = next_td.clone().clear_device_()

0 commit comments

Comments
 (0)