@@ -768,8 +768,7 @@ def __init__(
         self.set_truncated = set_truncated
 
         self._make_shuttle()
-        if self._use_buffers:
-            self._make_final_rollout()
+        self._maybe_make_final_rollout(make_rollout=self._use_buffers)
         self._set_truncated_keys()
 
         if split_trajs is None:
@@ -806,28 +805,30 @@ def _make_shuttle(self):
             traj_ids,
         )
 
-    def _make_final_rollout(self):
-        with torch.no_grad():
-            self._final_rollout = self.env.fake_tensordict()
-
-        # If storing device is not None, we use this to cast the storage.
-        # If it is None and the env and policy are on the same device,
-        # the storing device is already the same as those, so we don't need
-        # to consider this use case.
-        # In all other cases, we can't really put a device on the storage,
-        # since at least one data source has a device that is not clear.
-        if self.storing_device:
-            self._final_rollout = self._final_rollout.to(
-                self.storing_device, non_blocking=True
-            )
-        else:
-            # erase all devices
-            self._final_rollout.clear_device_()
+    def _maybe_make_final_rollout(self, make_rollout: bool):
+        if make_rollout:
+            with torch.no_grad():
+                self._final_rollout = self.env.fake_tensordict()
+
+            # If storing device is not None, we use this to cast the storage.
+            # If it is None and the env and policy are on the same device,
+            # the storing device is already the same as those, so we don't need
+            # to consider this use case.
+            # In all other cases, we can't really put a device on the storage,
+            # since at least one data source has a device that is not clear.
+            if self.storing_device:
+                self._final_rollout = self._final_rollout.to(
+                    self.storing_device, non_blocking=True
+                )
+            else:
+                # erase all devices
+                self._final_rollout.clear_device_()
 
         # If the policy has a valid spec, we use it
         self._policy_output_keys = set()
         if (
-            hasattr(self.policy, "spec")
+            make_rollout
+            and hasattr(self.policy, "spec")
             and self.policy.spec is not None
             and all(v is not None for v in self.policy.spec.values(True, True))
         ):
@@ -846,14 +847,20 @@ def _make_final_rollout(self):
                 if key in self._final_rollout.keys(True):
                     continue
                 self._final_rollout.set(key, spec.zero())
-
+        elif (
+            not make_rollout
+            and hasattr(self.policy, "out_keys")
+            and self.policy.out_keys
+        ):
+            self._policy_output_keys = list(self.policy.out_keys)
         else:
-            # otherwise, we perform a small number of steps with the policy to
-            # determine the relevant keys with which to pre-populate _final_rollout.
-            # This is the safest thing to do if the spec has None fields or if there is
-            # no spec at all.
-            # See #505 for additional context.
-            self._final_rollout.update(self._shuttle.copy())
+            if make_rollout:
+                # otherwise, we perform a small number of steps with the policy to
+                # determine the relevant keys with which to pre-populate _final_rollout.
+                # This is the safest thing to do if the spec has None fields or if there is
+                # no spec at all.
+                # See #505 for additional context.
+                self._final_rollout.update(self._shuttle.copy())
             with torch.no_grad():
                 policy_input = self._shuttle.copy()
                 if self.policy_device:
@@ -911,33 +918,35 @@ def filter_policy(name, value_output, value_input, value_input_clone):
                     set(filtered_policy_output.keys(True, True))
                 )
             )
-            self._final_rollout.update(
-                policy_output.select(*self._policy_output_keys)
-            )
+            if make_rollout:
+                self._final_rollout.update(
+                    policy_output.select(*self._policy_output_keys)
+                )
             del filtered_policy_output, policy_output, policy_input
 
         _env_output_keys = []
         for spec in ["full_observation_spec", "full_done_spec", "full_reward_spec"]:
             _env_output_keys += list(self.env.output_spec[spec].keys(True, True))
         self._env_output_keys = _env_output_keys
-        self._final_rollout = (
-            self._final_rollout.unsqueeze(-1)
-            .expand(*self.env.batch_size, self.frames_per_batch)
-            .clone()
-            .zero_()
-        )
+        if make_rollout:
+            self._final_rollout = (
+                self._final_rollout.unsqueeze(-1)
+                .expand(*self.env.batch_size, self.frames_per_batch)
+                .clone()
+                .zero_()
+            )
 
-        # in addition to outputs of the policy, we add traj_ids to
-        # _final_rollout which will be collected during rollout
-        self._final_rollout.set(
-            ("collector", "traj_ids"),
-            torch.zeros(
-                *self._final_rollout.batch_size,
-                dtype=torch.int64,
-                device=self.storing_device,
-            ),
-        )
-        self._final_rollout.refine_names(..., "time")
+            # in addition to outputs of the policy, we add traj_ids to
+            # _final_rollout which will be collected during rollout
+            self._final_rollout.set(
+                ("collector", "traj_ids"),
+                torch.zeros(
+                    *self._final_rollout.batch_size,
+                    dtype=torch.int64,
+                    device=self.storing_device,
+                ),
+            )
+            self._final_rollout.refine_names(..., "time")
 
     def _set_truncated_keys(self):
         self._truncated_keys = []
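
Note on the shape of this change: the formerly unconditional _make_final_rollout becomes _maybe_make_final_rollout(make_rollout=...), which skips every _final_rollout allocation step when buffers are disabled but still records which keys the policy will emit, via its out_keys. Below is a minimal, self-contained sketch of that guard pattern; the DummyPolicy class, MiniCollector class, and tensor shapes are hypothetical stand-ins for illustration only and are not part of this diff.

import torch


class DummyPolicy:
    # hypothetical stand-in for a module with declared out_keys
    out_keys = ["action"]


class MiniCollector:
    def __init__(self, policy, use_buffers: bool):
        self.policy = policy
        self._maybe_make_final_rollout(make_rollout=use_buffers)

    def _maybe_make_final_rollout(self, make_rollout: bool):
        if make_rollout:
            # buffered path: pre-allocate storage for the rollout
            self._final_rollout = {"action": torch.zeros(10, 4)}
            self._policy_output_keys = list(self._final_rollout)
        elif hasattr(self.policy, "out_keys") and self.policy.out_keys:
            # buffer-free path: no allocation, just record the policy's keys
            self._policy_output_keys = list(self.policy.out_keys)


# either path leaves _policy_output_keys usable by later collection code
assert MiniCollector(DummyPolicy(), use_buffers=False)._policy_output_keys == ["action"]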