@@ -147,7 +147,6 @@ def __init__(self,
147147 optimizer = None ,
148148 checkpoint = None ,
149149 is_eval : bool = False ,
150- episodic_annotation : bool = False ,
151150 overwrite_policy_output = False ,
152151 debug_summaries = False ,
153152 name = "RLAlgorithm" ):
@@ -187,10 +186,6 @@ def __init__(self,
187186 during deployment. In this case, the algorithm does not need to
188187 create certain components such as value_network for ActorCriticAlgorithm,
189188 critic_networks for SacAlgorithm.
190- episodic_annotation: episodic annotation is an operation that annotates the
191- episode after it being collected, and then the annotated episode will be
192- observed by the replay buffer. If True, annotate the episode before being
193- observed by the replay buffer. Otherwise, episodic annotation is not applied.
194189 overwrite_policy_output (bool): if True, overwrite the policy output
195190 with next_step.prev_action. This option can be used in some
196191 cases such as data collection.
@@ -208,7 +203,6 @@ def __init__(self,
208203 debug_summaries = debug_summaries ,
209204 name = name )
210205 self ._is_eval = is_eval
211- self ._episodic_annotation = episodic_annotation
212206
213207 self ._env = env
214208 self ._observation_spec = observation_spec
@@ -241,7 +235,6 @@ def __init__(self,
241235 self ._current_time_step = None
242236 self ._current_policy_state = None
243237 self ._current_transform_state = None
244- self ._cached_exp = [] # for lazy observation
245238 if self ._env is not None and not self .on_policy :
246239 replay_buffer_length = adjust_replay_buffer_length (
247240 config , self ._num_earliest_frames_ignored )
@@ -550,6 +543,7 @@ def _async_unroll(self, unroll_length: int):
550543 store_exp_time = 0.
551544 step_time = 0.
552545 max_step_time = 0.
546+ effective_unroll_steps = 0
553547 qsize = self ._async_unroller .get_queue_size ()
554548 unroll_results = self ._async_unroller .gather_unroll_results (
555549 unroll_length , self ._config .max_unroll_length )
@@ -572,11 +566,12 @@ def _async_unroll(self, unroll_length: int):
572566 step_time += unroll_result .step_time
573567 max_step_time = max (max_step_time , unroll_result .step_time )
574568
575- store_exp_time_i , effective_unroll_steps = self ._process_unroll_step (
569+ store_exp_time_i , effective_unroll_steps_i = self ._process_unroll_step (
576570 policy_step , policy_step .output , time_step ,
577571 transformed_time_step , policy_state , experience_list ,
578572 original_reward_list )
579573 store_exp_time += store_exp_time_i
574+ effective_unroll_steps += effective_unroll_steps_i
580575
581576 alf .summary .scalar ("time/unroll_env_step" ,
582577 env_step_time ,
@@ -603,70 +598,80 @@ def _async_unroll(self, unroll_length: int):
603598
604599 self ._current_transform_state = common .detach (trans_state )
605600
606- return experience , effective_unroll_steps
607-
608- def should_post_process_episode (self , rollout_info , step_type : StepType ):
609- """A function that determines whether the ``post_process_episode`` function should
610- be applied to the current list of experiences.
611- Users can customize this function in the derived class.
612- Bu default, it returns True all the time steps. When this is combined with
613- ``post_process_episode`` which simply return the input unmodified (as the default
614- implementation in this class), it is a dummy version of eposodic annotation with
615- logic equivalent to the case of episodic_annotation=False.
601+ effective_unroll_iters = effective_unroll_steps // unroll_length
602+ return experience , effective_unroll_iters
603+
604+ def should_post_process_experience (self , rollout_info ,
605+ step_type : StepType ):
606+ """A function that determines whether the ``post_process_experience`` function should
607+ be called. Users can customize this pair of functions in the derived class to achieve
608+ different effects. For example:
609+ - per-step processing: ``should_post_process_experience``
610+ returns True for all the steps (by default), and ``post_process_experience``
611+ returns the current step of experience unmodified (by default) or a modified version
612+ according to their customized ``post_process_experience`` function.
613+ As another example, task filtering can be simply achieved by returning ``[]``
614+ in ``post_process_experience`` for that particular task.
615+ - per-episode processing: ``should_post_process_experience`` returns True on episode
616+ end and ``post_process_experience`` can return a list of cached and processed
617+ experiences. For example, this can be used for success episode labeling.
616618 """
617619 return True
618620
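A minimal sketch of the per-step gate on its own (not part of this diff): it assumes an ALF `SacAlgorithm` base, a single unroll environment, and a hypothetical `task_id` field on `rollout_info`; the class name and excluded id are made up. As the docstring notes, the same filtering effect can also be obtained by returning `[]` from `post_process_experience`.

```python
from alf.algorithms.sac_algorithm import SacAlgorithm  # assumed base class
from alf.data_structures import StepType


class TaskFilteredAlgorithm(SacAlgorithm):
    """Sketch: skip replay storage for steps that belong to one task.

    Assumes a single unroll environment and that ``rollout_info`` carries a
    one-element ``task_id`` tensor; both are assumptions of this sketch.
    """

    EXCLUDED_TASK_ID = 2  # illustrative value

    def should_post_process_experience(self, rollout_info, step_type: StepType):
        # Returning False means the step is neither post-processed nor
        # observed by the replay buffer, and it contributes 0 effective
        # unroll steps.
        return int(rollout_info.task_id) != self.EXCLUDED_TASK_ID
```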
619- def post_process_episode (self , experiences : List [ Experience ] ):
621+ def post_process_experience (self , experiences : Experience ):
620622 """A function for postprocessing a list of experience. It is called when
621- ``should_post_process_episode `` is True.
623+ ``should_post_process_experience`` is True.
622624 By default, it returns a list containing only the unmodified input.
623625 Users can customize this function in a derived class to implement useful
624626 features such as hindsight relabeling of a trajectory.
625627
626628 Args:
627- experiences: a list of experience, containing the experience starting from the
628- initial time when ``should_post_process_episode`` is False to the step where
629- ``should_post_process_episode`` is True.
629+ experiences: one step of experience.
630+
631+ Returns:
632+ A list of experiences to be observed by the replay buffer. For example:
633+ - a list containing only the input experience (the default behavior);
634+ - an empty list, which drops the step (e.g. for task filtering);
635+ - a list of several experiences, e.g. a cached episode that has been
636+ processed for success episode labeling.
630637 """
631- return experiences
638+ return [ experiences ]
632639
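To illustrate the per-episode path described in the docstrings, here is a hedged sketch (not part of the diff) that keeps the default gate returning True for every step, caches each step inside `post_process_experience`, and flushes the whole relabeled episode at episode end. The base class, the single-environment assumption, the `_episode_cache` attribute, and the "positive final reward means success" convention are all assumptions made for the sketch.

```python
import torch
from alf.algorithms.sac_algorithm import SacAlgorithm  # assumed base class
from alf.data_structures import StepType


class SuccessLabelingAlgorithm(SacAlgorithm):
    """Sketch: per-episode success labeling with the default (always True) gate.

    Assumes a single unroll environment. Intermediate steps are cached here and
    contribute 0 effective unroll steps; the step that ends the episode returns
    the whole cached episode so ``_process_unroll_step`` observes it for replay.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._episode_cache = []

    def post_process_experience(self, experience):
        self._episode_cache.append(experience)
        if not bool((experience.time_step.step_type == StepType.LAST).all()):
            # Not the end of the episode yet: observe nothing for now.
            return []
        episode, self._episode_cache = self._episode_cache, []
        # Assumed convention (sketch only): a positive reward on the last
        # step marks a successful episode.
        success = float(bool((episode[-1].time_step.reward > 0).all()))
        # Relabel every step's reward with the episode outcome before the
        # episode is observed by the replay buffer.
        return [
            exp._replace(time_step=exp.time_step._replace(
                reward=torch.full_like(exp.time_step.reward, success)))
            for exp in episode
        ]
```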
633640 def _process_unroll_step (self , policy_step , action , time_step ,
634641 transformed_time_step , policy_state ,
635642 experience_list , original_reward_list ):
643+ """A function for processing the unroll steps.
644+ By default, it returns the input unmodified.
645+ Users can customize this function in the derived class, to create a number of
646+ useful features such as 'hindsight relabeling' of a trajectory etc.
647+
648+ Args:
649+ experiences: a list of experience, containing the experience starting from the
650+ initial time when ``should_post_process_experience`` is False to the step where
651+ ``should_post_process_experience`` is True.
652+ """
653+
636654 self .observe_for_metrics (time_step .cpu ())
637655 exp = make_experience (time_step .cpu (),
638656 alf .layers .to_float32 (policy_step ),
639657 alf .layers .to_float32 (policy_state ))
640- effective_number_of_unroll_steps = 1
641- if self ._episodic_annotation :
642- assert not self .on_policy , "only support episodic annotation for off policy training"
643- store_exp_time = 0
644- # if last step, annotate
658+ effective_unroll_steps = 1
659+ store_exp_time = 0
660+ if not self .on_policy :
645661 rollout_info = policy_step .info
646- self ._cached_exp .append (exp )
647- if self .should_post_process_episode (rollout_info ,
648- time_step .step_type ):
649-
662+ if self .should_post_process_experience (rollout_info ,
663+ time_step .step_type ):
650664 # 1) process
651- annotated_exp_list = self .post_process_episode (
652- self ._cached_exp )
653- effective_number_of_unroll_steps = len (annotated_exp_list )
665+ post_processed_exp_list = self .post_process_experience (exp )
666+ effective_unroll_steps = len (post_processed_exp_list )
654667 # 2) observe
655668 t0 = time .time ()
656- for exp in annotated_exp_list :
669+ for exp in post_processed_exp_list :
657670 self .observe_for_replay (exp )
658671 store_exp_time = time .time () - t0
659- # clean up the exp cache
660- self ._cached_exp = []
661672 else :
662- # effective unroll steps as 0 if not post_process_episode timepoint yet
663- effective_number_of_unroll_steps = 0
664- else :
665- store_exp_time = 0
666- if not self .on_policy :
667- t0 = time .time ()
668- self .observe_for_replay (exp )
669- store_exp_time = time .time () - t0
673+ # effective unroll steps is 0 if ``should_post_process_experience`` returns False
674+ effective_unroll_steps = 0
670675
671676 exp_for_training = Experience (
672677 time_step = transformed_time_step ,
@@ -676,7 +681,7 @@ def _process_unroll_step(self, policy_step, action, time_step,
676681
677682 experience_list .append (exp_for_training )
678683 original_reward_list .append (time_step .reward )
679- return store_exp_time , effective_number_of_unroll_steps
684+ return store_exp_time , effective_unroll_steps
680685
681686 def reset_state (self ):
682687 """Reset the state of the algorithm.
@@ -700,6 +705,8 @@ def _sync_unroll(self, unroll_length: int):
700705 Returns:
701706 Experience: The stacked experience with shape :math:`[T, B, \ldots]`
702707 for each of its members.
708+ effective_unroll_iters: the effective number of unroll iterations.
709+ Each unroll iteration contains ``unroll_length`` unroll steps.
703710 """
704711 if self ._current_time_step is None :
705712 self ._current_time_step = common .get_initial_time_step (self ._env )
@@ -750,10 +757,11 @@ def _sync_unroll(self, unroll_length: int):
750757 if self ._overwrite_policy_output :
751758 policy_step = policy_step ._replace (
752759 output = next_time_step .prev_action )
753- store_exp_time_i , effective_unroll_steps = self ._process_unroll_step (
760+ store_exp_time_i , effective_unroll_steps_i = self ._process_unroll_step (
754761 policy_step , action , time_step , transformed_time_step ,
755762 policy_state , experience_list , original_reward_list )
756763 store_exp_time += store_exp_time_i
764+ effective_unroll_steps += effective_unroll_steps_i
757765
758766 time_step = next_time_step
759767 policy_state = policy_step .state
@@ -781,7 +789,8 @@ def _sync_unroll(self, unroll_length: int):
781789 self ._current_policy_state = common .detach (policy_state )
782790 self ._current_transform_state = common .detach (trans_state )
783791
784- return experience , effective_unroll_steps
792+ effective_unroll_iters = effective_unroll_steps // unroll_length
793+ return experience , effective_unroll_iters
785794
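A small worked illustration of the new bookkeeping (the numbers are made up and not part of the diff): each unroll path sums the per-step counts returned by `_process_unroll_step` and then integer-divides by `unroll_length`.

```python
# Illustrative arithmetic only.
unroll_length = 8

# Default per-step storage: every unroll step is observed once.
assert 8 // unroll_length == 1    # one effective unroll iteration

# Per-episode processing: nothing is observed until the episode ends...
assert 0 // unroll_length == 0    # reported as zero effective iterations

# ...and the unroll that flushes a cached 20-step episode reports two.
assert 20 // unroll_length == 2
```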
786795 def train_iter (self ):
787796 """Perform one iteration of training.