@@ -604,13 +604,15 @@ def _async_unroll(self, unroll_length: int):
         effective_unroll_iters = 1 if unroll_length == 0 else effective_unroll_steps // unroll_length
         return experience, effective_unroll_iters
 
-    def post_process_experience(self, rollout_info, step_type: StepType,
-                                experiences: Experience) -> Tuple[List, int]:
-        """A function for postprocessing experience. By default, it returns the input
+    def preprocess_unroll_experience(
+            self, rollout_info, step_type: StepType,
+            experiences: Experience) -> Tuple[List, float]:
+        """A function for processing the experience obtained from an unroll step before
+        being saved into the replay buffer. By default, it returns the input
         experience unmodified. Users can customize this function in the derived
         class to achieve different effects. For example:
         - per-step processing: return the current step of experience unmodified (by default)
-          or a modified version according to the customized ``post_process_experience``.
+          or a modified version according to the customized ``preprocess_unroll_experience``.
           As another example, task filtering can be simply achieved by returning ``[]``
           for that particular task.
         - per-episode processing: this can be achieved by returning a list of processed
@@ -622,35 +624,79 @@ class to achieve different effects. For example:
             experiences: one step of experience.
 
         Returns:
-            - a list of experiences. Users can customize this functions in the
-              derived class to achieve different effects. For example:
+            - ``effective_experiences``: a list of experiences. Users can customize this
+              function in the derived class to achieve different effects. For example:
              * return a list that contains only the input experience (default behavior).
              * return a list that contains a number of experiences. This can be useful
                for episode processing such as success episode labeling.
-            - an integer representing the effective number of unroll steps per env. The
-              default value of 1, meaning the length of effective experience is 1
-              after calling ``post_process_experience``, the same as the input length
-              of experience.
+            - ``effective_unroll_steps``: a value representing the effective number of
+              unroll steps per env. The default value is 1, meaning the length of
+              effective experience is 1 after calling ``preprocess_unroll_experience``,
+              the same as the input length of experience.
+              The value of ``effective_unroll_steps`` can be set differently according
+              to different scenarios, e.g.:
+              (1) per-step saving without delay: saving each step of unroll experience
+                  into the replay buffer as we get it. Set ``effective_unroll_steps``
+                  to 1 so that each step will be counted as valid and there will be no
+                  impact on the train/unroll ratio.
+              (2) all-step saving with delay: saving all the steps of unroll experience into
+                  the replay buffer with delay. This can happen in the case where we want to
+                  annotate a trajectory based on some quantities that are not immediately
+                  available in the current step (e.g. task success/failure). In this case,
+                  we can simply cache the experiences and set ``effective_experiences=[]``
+                  before obtaining the quantities required for annotation.
+                  After obtaining the quantities required for annotation, we can
+                  set ``effective_experiences`` to the cached and annotated experiences.
+                  To maintain the original unroll/train iter ratio, we can set
+                  ``effective_unroll_steps=1``, meaning each unroll step is regarded as
+                  effective in terms of the unroll/train iter ratio, even though the
+                  pace of saving the unroll steps into the replay buffer has been altered.
+              (3) selective saving: exclude some of the unroll experiences and only save
+                  the rest. This could be useful in the case where there are transitions
+                  that are irrelevant to the training (e.g. in the multi-task case, where
+                  we want to exclude data from certain subtasks).
+                  This can be achieved by setting ``effective_experiences=[]`` for the
+                  steps to be excluded, while ``effective_experiences = [experiences]``
+                  otherwise. If we do not want to trigger a train iter for the unroll
+                  step that will be excluded, we can set ``effective_unroll_steps=0``;
+                  otherwise, we can set ``effective_unroll_steps=1``.
+              (4) parallel environments: in the case of parallel environments, the value
+                  of ``effective_unroll_steps`` can be set according to the modes described
+                  above and the status of each environment (e.g. ``effective_unroll_steps``
+                  can be set to an average value across environments). Note that this could
+                  result in a floating point number.
634668 """
635- return [experiences ], 1
669+ effective_experiences = [experiences ]
670+ effective_unroll_steps = 1
671+ return effective_experiences , effective_unroll_steps
636672
     def _process_unroll_step(self, policy_step, action, time_step,
                              transformed_time_step, policy_state,
                              experience_list,
-                             original_reward_list) -> Tuple[int, int]:
+                             original_reward_list) -> Tuple[int, float]:
677+ """
678+
679+ Returns:
680+ - ``store_exp_time``: the time spent on storing the experience
681+ - ``effective_unroll_steps``: a value representing the effective number
682+ of unroll steps per env. The default value of 1, meaning the length of
683+ effective experience is 1 after calling ``preprocess_unroll_experience``,
684+ the same as the input length of experience. For more details on it,
685+ please refer to the docstr of ``preprocess_unroll_experience``.
686+ """
         self.observe_for_metrics(time_step.cpu())
         exp = make_experience(time_step.cpu(),
                               alf.layers.to_float32(policy_step),
                               alf.layers.to_float32(policy_state))
         effective_unroll_steps = 1
         store_exp_time = 0
         if not self.on_policy:
-            # 1) post process
-            post_processed_exp_list, effective_unroll_steps = self.post_process_experience(
+            # 1) pre-process unroll experience
+            pre_processed_exp_list, effective_unroll_steps = self.preprocess_unroll_experience(
                 policy_step.info, time_step.step_type, exp)
             # 2) observe
             t0 = time.time()
-            for exp in post_processed_exp_list:
+            for exp in pre_processed_exp_list:
                 self.observe_for_replay(exp)
             store_exp_time = time.time() - t0
 
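
To make the new hook concrete, below is a minimal sketch of how a derived algorithm could override ``preprocess_unroll_experience`` for the selective-saving scenario (3) described in the docstring above. Everything beyond the hook's signature and its ``(effective_experiences, effective_unroll_steps)`` return convention is an assumption: the mixin name, the ``_excluded_task_id`` attribute, and the ``task_id`` field on the rollout info are hypothetical, and the import paths follow ALF's usual layout but should be double-checked.

    from typing import List, Tuple

    from alf.data_structures import Experience, StepType


    class TaskFilterMixin:
        """Sketch: mix into an off-policy algorithm to drop steps of an excluded task.

        Assumes the host class sets ``self._excluded_task_id`` and that the
        rollout info carries a (hypothetical) ``task_id`` field.
        """

        def preprocess_unroll_experience(
                self, rollout_info, step_type: StepType,
                experiences: Experience) -> Tuple[List, float]:
            if getattr(rollout_info, 'task_id', None) == self._excluded_task_id:
                # Exclude this step and do not let it count toward the
                # unroll/train iteration ratio.
                return [], 0
            # Default behavior: keep the step and count it as one unroll step.
            return [experiences], 1

The delayed-annotation scenario (2) would follow the same pattern, except that the override would buffer incoming experiences internally, return ``[]`` (with ``effective_unroll_steps=1``) until the annotation quantity such as task success is known, and then return the whole annotated list at once.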
@@ -771,7 +817,9 @@ def _sync_unroll(self, unroll_length: int):
         self._current_transform_state = common.detach(trans_state)
 
         # if the input unroll_length is 0 (e.g. fractional unroll), then this is treated as
-        # an effective unroll iter
+        # an effective unroll iter.
+        # One ``effective_unroll_iter`` corresponds to ``unroll_length`` calls of
+        # ``rollout_step`` during the unroll phase.
         effective_unroll_iters = 1 if unroll_length == 0 else effective_unroll_steps // unroll_length
         return experience, effective_unroll_iters
 
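
For the ``effective_unroll_iters`` computation on the last code line of the hunk above, a tiny worked example (with made-up numbers) may help; the only logic taken from the diff is the one-line conditional itself.

    # Illustrative values only; they do not come from the diff.
    unroll_length = 8            # rollout_step calls requested per unroll phase
    effective_unroll_steps = 16  # e.g. accumulated while cached experiences were flushed

    # A zero unroll_length means a fractional unroll, which always counts as one
    # effective iter; otherwise one effective iter per ``unroll_length`` effective steps.
    effective_unroll_iters = (1 if unroll_length == 0
                              else effective_unroll_steps // unroll_length)
    assert effective_unroll_iters == 2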