File tree 2 files changed +8
-1
lines changed
2 files changed +8
-1
lines changed Original file line number Diff line number Diff line change @@ -4148,7 +4148,6 @@ def test_parallel_partial_step_and_maybe_reset(
4148
4148
4149
4149
td .set ("action" , penv .full_action_spec [penv .action_key ].one ())
4150
4150
td , tdreset = penv .step_and_maybe_reset (td )
4151
- print (td )
4152
4151
assert_allclose_td (td [0 ].get ("next" ), td [0 ], intersection = True )
4153
4152
assert (td [1 ].get ("next" ) != 0 ).any ()
4154
4153
assert_allclose_td (td [2 ].get ("next" ), td [2 ], intersection = True )
Original file line number Diff line number Diff line change @@ -1691,6 +1691,14 @@ def step_and_maybe_reset(
1691
1691
device = device ,
1692
1692
filter_empty = True ,
1693
1693
)
1694
+ if tensordict .device != device :
1695
+ tensordict = tensordict ._fast_apply (
1696
+ lambda x : x .to (device , non_blocking = self .non_blocking )
1697
+ if x .device != device
1698
+ else x ,
1699
+ device = device ,
1700
+ filter_empty = True ,
1701
+ )
1694
1702
self ._sync_w2m ()
1695
1703
else :
1696
1704
next_td = next_td .clone ().clear_device_ ()
You can’t perform that action at this time.
0 commit comments