Commit d4f8846

[BugFix] Fix collector with no buffers and devices
ghstack-source-id: 5367df9
Pull Request resolved: #2809
1 parent: 3acf491

2 files changed: +65 −49


test/test_collector.py (+9 −2)

@@ -3093,16 +3093,23 @@ def test_dynamic_sync_collector(self):
         assert isinstance(data, LazyStackedTensorDict)
         assert data.names[-1] == "time"
 
-    def test_dynamic_multisync_collector(self):
+    @pytest.mark.parametrize("policy_device", [None, *get_default_devices()])
+    def test_dynamic_multisync_collector(self, policy_device):
         env = EnvWithDynamicSpec
-        policy = RandomPolicy(env().action_spec)
+        spec = env().action_spec
+        if policy_device is not None:
+            spec = spec.to(policy_device)
+        policy = RandomPolicy(spec)
         collector = MultiSyncDataCollector(
             [env],
             policy,
             frames_per_batch=20,
             total_frames=100,
             use_buffers=False,
             cat_results="stack",
+            policy_device=policy_device,
+            env_device="cpu",
+            storing_device="cpu",
         )
         for data in collector:
             assert isinstance(data, LazyStackedTensorDict)
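For orientation, here is a minimal standalone sketch of what the parametrized test exercises: a MultiSyncDataCollector running without pre-allocated buffers and with explicit policy, env, and storing devices. The Gym environment, the RandomPolicy import path, and the device selection are illustrative assumptions, not part of this commit.

import torch
from torchrl.collectors import MultiSyncDataCollector
from torchrl.envs import GymEnv
from torchrl.envs.utils import RandomPolicy


def make_env():
    # top-level factory so it can be shipped to worker processes
    return GymEnv("Pendulum-v1")  # assumed available; any TorchRL env works


if __name__ == "__main__":  # multiprocessing collectors need the main guard
    policy_device = "cuda:0" if torch.cuda.is_available() else None
    spec = make_env().action_spec
    if policy_device is not None:
        spec = spec.to(policy_device)
    policy = RandomPolicy(spec)
    collector = MultiSyncDataCollector(
        [make_env],
        policy,
        frames_per_batch=20,
        total_frames=100,
        use_buffers=False,  # the buffer-free path this commit fixes
        cat_results="stack",
        policy_device=policy_device,
        env_device="cpu",
        storing_device="cpu",
    )
    for data in collector:
        print(data.shape)  # stacked batches, stored on CPU
    collector.shutdown()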

torchrl/collectors/collectors.py (+56 −47)

@@ -768,8 +768,7 @@ def __init__(
         self.set_truncated = set_truncated
 
         self._make_shuttle()
-        if self._use_buffers:
-            self._make_final_rollout()
+        self._maybe_make_final_rollout(make_rollout=self._use_buffers)
         self._set_truncated_keys()
 
         if split_trajs is None:
@@ -806,28 +805,30 @@ def _make_shuttle(self):
             traj_ids,
         )
 
-    def _make_final_rollout(self):
-        with torch.no_grad():
-            self._final_rollout = self.env.fake_tensordict()
-
-        # If storing device is not None, we use this to cast the storage.
-        # If it is None and the env and policy are on the same device,
-        # the storing device is already the same as those, so we don't need
-        # to consider this use case.
-        # In all other cases, we can't really put a device on the storage,
-        # since at least one data source has a device that is not clear.
-        if self.storing_device:
-            self._final_rollout = self._final_rollout.to(
-                self.storing_device, non_blocking=True
-            )
-        else:
-            # erase all devices
-            self._final_rollout.clear_device_()
+    def _maybe_make_final_rollout(self, make_rollout: bool):
+        if make_rollout:
+            with torch.no_grad():
+                self._final_rollout = self.env.fake_tensordict()
+
+            # If storing device is not None, we use this to cast the storage.
+            # If it is None and the env and policy are on the same device,
+            # the storing device is already the same as those, so we don't need
+            # to consider this use case.
+            # In all other cases, we can't really put a device on the storage,
+            # since at least one data source has a device that is not clear.
+            if self.storing_device:
+                self._final_rollout = self._final_rollout.to(
+                    self.storing_device, non_blocking=True
+                )
+            else:
+                # erase all devices
+                self._final_rollout.clear_device_()
 
         # If the policy has a valid spec, we use it
         self._policy_output_keys = set()
         if (
-            hasattr(self.policy, "spec")
+            make_rollout
+            and hasattr(self.policy, "spec")
             and self.policy.spec is not None
             and all(v is not None for v in self.policy.spec.values(True, True))
         ):
@@ -846,14 +847,20 @@ def _make_final_rollout(self):
                 if key in self._final_rollout.keys(True):
                     continue
                 self._final_rollout.set(key, spec.zero())
-
+        elif (
+            not make_rollout
+            and hasattr(self.policy, "out_keys")
+            and self.policy.out_keys
+        ):
+            self._policy_output_keys = list(self.policy.out_keys)
         else:
-            # otherwise, we perform a small number of steps with the policy to
-            # determine the relevant keys with which to pre-populate _final_rollout.
-            # This is the safest thing to do if the spec has None fields or if there is
-            # no spec at all.
-            # See #505 for additional context.
-            self._final_rollout.update(self._shuttle.copy())
+            if make_rollout:
+                # otherwise, we perform a small number of steps with the policy to
+                # determine the relevant keys with which to pre-populate _final_rollout.
+                # This is the safest thing to do if the spec has None fields or if there is
+                # no spec at all.
+                # See #505 for additional context.
+                self._final_rollout.update(self._shuttle.copy())
             with torch.no_grad():
                 policy_input = self._shuttle.copy()
                 if self.policy_device:
@@ -911,33 +918,35 @@ def filter_policy(name, value_output, value_input, value_input_clone):
                     set(filtered_policy_output.keys(True, True))
                 )
             )
-            self._final_rollout.update(
-                policy_output.select(*self._policy_output_keys)
-            )
+            if make_rollout:
+                self._final_rollout.update(
+                    policy_output.select(*self._policy_output_keys)
+                )
             del filtered_policy_output, policy_output, policy_input
 
         _env_output_keys = []
         for spec in ["full_observation_spec", "full_done_spec", "full_reward_spec"]:
             _env_output_keys += list(self.env.output_spec[spec].keys(True, True))
         self._env_output_keys = _env_output_keys
-        self._final_rollout = (
-            self._final_rollout.unsqueeze(-1)
-            .expand(*self.env.batch_size, self.frames_per_batch)
-            .clone()
-            .zero_()
-        )
+        if make_rollout:
+            self._final_rollout = (
+                self._final_rollout.unsqueeze(-1)
+                .expand(*self.env.batch_size, self.frames_per_batch)
+                .clone()
+                .zero_()
+            )
 
-        # in addition to outputs of the policy, we add traj_ids to
-        # _final_rollout which will be collected during rollout
-        self._final_rollout.set(
-            ("collector", "traj_ids"),
-            torch.zeros(
-                *self._final_rollout.batch_size,
-                dtype=torch.int64,
-                device=self.storing_device,
-            ),
-        )
-        self._final_rollout.refine_names(..., "time")
+            # in addition to outputs of the policy, we add traj_ids to
+            # _final_rollout which will be collected during rollout
+            self._final_rollout.set(
+                ("collector", "traj_ids"),
+                torch.zeros(
+                    *self._final_rollout.batch_size,
+                    dtype=torch.int64,
+                    device=self.storing_device,
+                ),
+            )
+            self._final_rollout.refine_names(..., "time")
 
     def _set_truncated_keys(self):
         self._truncated_keys = []
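The new elif branch gives buffer-free collectors a fast path: when the policy declares out_keys, the collector takes its output keys directly from them instead of dry-running the policy. Below is a minimal sketch of a policy hitting that path, assuming a Gym install with Pendulum-v1 (the env choice and layer sizes are illustrative, not part of the commit).

import torch
from tensordict.nn import TensorDictModule
from torchrl.collectors import SyncDataCollector
from torchrl.envs import GymEnv

if __name__ == "__main__":
    env = GymEnv("Pendulum-v1")  # assumed available; any TorchRL env works
    # A TensorDictModule declares out_keys, so with use_buffers=False the
    # collector can fill _policy_output_keys without a dry run.
    policy = TensorDictModule(
        torch.nn.Linear(3, 1),  # Pendulum-v1: 3-dim observation, 1-dim action
        in_keys=["observation"],
        out_keys=["action"],
    )
    collector = SyncDataCollector(
        env,
        policy,
        frames_per_batch=20,
        total_frames=40,
        use_buffers=False,  # the code path this commit fixes
    )
    for data in collector:
        assert "action" in data.keys()
    collector.shutdown()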
