@@ -187,8 +187,10 @@ def __init__(self,
187187 during deployment. In this case, the algorithm does not need to
188188 create certain components such as value_network for ActorCriticAlgorithm,
189189 critic_networks for SacAlgorithm.
190- episodic_annotation: if True, annotate the episode before being observed by the
191- replay buffer.
190+ episodic_annotation: episodic annotation is an operation that annotates an
191+ episode after it has been collected; the annotated episode is then
192+ observed by the replay buffer. If True, annotate each episode before it is
193+ observed by the replay buffer. Otherwise, episodic annotation is not applied.
192194 overwrite_policy_output (bool): if True, overwrite the policy output
193195 with next_step.prev_action. This option can be used in some
194196 cases such as data collection.
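To make the "annotate, then observe" ordering described by the new ``episodic_annotation`` docstring concrete, here is a minimal, self-contained sketch. The ``Experience`` record, the ``annotate_episode`` rule, and the plain-list replay buffer are illustrative stand-ins rather than ALF classes; only the ordering (cache the episode, annotate it once it ends, then hand the annotated steps to the replay buffer) mirrors the behavior described above.

```python
# Illustrative sketch only; Experience, annotate_episode, and replay_buffer are
# stand-ins, not ALF types. It shows the ordering implied by
# episodic_annotation=True: steps are cached per episode, the finished episode
# is annotated, and only the annotated steps reach the replay buffer.
from dataclasses import dataclass, replace
from typing import List

@dataclass
class Experience:
    observation: float
    reward: float
    is_last: bool

def annotate_episode(episode: List[Experience]) -> List[Experience]:
    # Example annotation rule: relabel each step's reward with its
    # undiscounted return-to-go over the finished episode.
    annotated, ret = [], 0.0
    for e in reversed(episode):
        ret += e.reward
        annotated.append(replace(e, reward=ret))
    return list(reversed(annotated))

replay_buffer: List[Experience] = []
episode: List[Experience] = []
for t in range(3):
    exp = Experience(observation=float(t), reward=float(t == 2), is_last=(t == 2))
    episode.append(exp)                      # cache until the episode ends
    if exp.is_last:
        for annotated in annotate_episode(episode):
            replay_buffer.append(annotated)  # observe only annotated experience
        episode.clear()
```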
@@ -244,9 +246,6 @@ def __init__(self,
244246 replay_buffer_length = adjust_replay_buffer_length(
245247 config, self._num_earliest_frames_ignored)
246248
247- if self._episodic_annotation:
248- assert self._env.batch_size == 1, "only support non-batched environment"
249-
250249 if config.whole_replay_buffer_training and config.clear_replay_buffer:
251250 # For whole replay buffer training, we would like to be sure
252251 # that the replay buffer has enough samples in it to perform
@@ -608,21 +607,27 @@ def _async_unroll(self, unroll_length: int):
608607 def should_post_process_episode(self, rollout_info, step_type: StepType):
609608 """A function that determines whether the ``post_process_episode`` function should
610609 be applied to the current list of experiences.
610+ Users can customize this function in a derived class.
611+ By default, it returns True for all time steps. When this is combined with
612+ ``post_process_episode``, which simply returns the input unmodified (the default
613+ implementation in this class), it is a no-op version of episodic annotation whose
614+ logic is equivalent to the case of episodic_annotation=False.
611615 """
612- return False
616+ return True
613617
614618 def post_process_episode(self, experiences: List[Experience]):
615619 """A function for postprocessing a list of experience. It is called when
616620 ``should_post_process_episode`` is True.
617- It can be used to create a number of useful features such as 'hindsight relabeling'
618- of a trajectory etc.
621+ By default, it returns the input unmodified.
622+ Users can customize this function in a derived class to implement features
623+ such as 'hindsight relabeling' of a trajectory.
619624
620625 Args:
621626 experiences: a list of experiences, starting from the first step at which
622627 ``should_post_process_episode`` is False up to the step at which
623628 ``should_post_process_episode`` is True.
624629 """
625- return None
630+ return experiences
626631
627632 def _process_unroll_step(self, policy_step, action, time_step,
628633 transformed_time_step, policy_state,
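The two hooks above are meant to be overridden together. Below is a sketch of a derived class that annotates only at episode boundaries; the ``StepType``/``Experience`` stand-ins and the final-reward relabeling rule are illustrative assumptions, while the hook names, signatures, and default behaviors come from the code above.

```python
# Sketch of overriding the two hooks in a derived algorithm. StepType and
# Experience are simplified stand-ins for the corresponding ALF types.
from dataclasses import dataclass, replace
from enum import Enum
from typing import Any, List

class StepType(Enum):
    FIRST = 0
    MID = 1
    LAST = 2

@dataclass
class Experience:
    reward: float

class BaseAlgorithm:
    # Defaults matching the docstrings above: annotate at every step,
    # and return the experience list unmodified.
    def should_post_process_episode(self, rollout_info: Any, step_type: StepType) -> bool:
        return True

    def post_process_episode(self, experiences: List[Experience]) -> List[Experience]:
        return experiences

class HindsightLikeAlgorithm(BaseAlgorithm):
    def should_post_process_episode(self, rollout_info: Any, step_type: StepType) -> bool:
        # Only annotate once the episode has terminated.
        return step_type == StepType.LAST

    def post_process_episode(self, experiences: List[Experience]) -> List[Experience]:
        # Illustrative hindsight-style relabeling: copy the final reward
        # back to every step of the episode.
        final_reward = experiences[-1].reward
        return [replace(e, reward=final_reward) for e in experiences]
```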
@@ -633,6 +638,7 @@ def _process_unroll_step(self, policy_step, action, time_step,
633638 alf.layers.to_float32(policy_state))
634639 effective_number_of_unroll_steps = 1
635640 if self._episodic_annotation:
641+ assert not self.on_policy, "episodic annotation is only supported for off-policy training"
636642 store_exp_time = 0
637643 # if last step, annotate
638644 rollout_info = policy_step.info
@@ -645,11 +651,10 @@ def _process_unroll_step(self, policy_step, action, time_step,
645651 self._cached_exp)
646652 effective_number_of_unroll_steps = len(annotated_exp_list)
647653 # 2) observe
648- if not self.on_policy:
649- t0 = time.time()
650- for exp in annotated_exp_list:
651- self.observe_for_replay(exp)
652- store_exp_time = time.time() - t0
654+ t0 = time.time()
655+ for exp in annotated_exp_list:
656+ self.observe_for_replay(exp)
657+ store_exp_time = time.time() - t0
653658 # clean up the exp cache
654659 self._cached_exp = []
655660 else:
@@ -903,8 +908,7 @@ def _train_iter_off_policy(self):
903908 self.train()
904909 steps = 0
905910 for i in range(effective_unroll_steps):
906- steps += self.train_from_replay_buffer(effective_unroll_steps=1,
907- update_global_counter=True)
911+ steps += self.train_from_replay_buffer(update_global_counter=True)
908912 if unrolled:
909913 with record_time("time/after_train_iter"):
910914 self.after_train_iter(root_inputs, rollout_info)
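The loop above runs ``train_from_replay_buffer`` once per effective unroll step, so the value returned by ``_process_unroll_step`` directly controls how many replay-buffer updates an iteration performs. The arithmetic sketch below rests on an assumption not shown in the patch: that cached, not-yet-annotated steps contribute zero effective unroll steps. Under that assumption the total number of updates per episode is unchanged; they are merely deferred to the episode's final step.

```python
# Assumption (not shown in the patch): cached steps contribute 0 effective
# unroll steps, while the final step contributes the full episode length N.
# Under that assumption the update count per episode matches the
# no-annotation case; updates are only deferred to the episode boundary.
episode_length = 5
updates_without_annotation = [1] * episode_length                        # one update per unroll step
updates_with_annotation = [0] * (episode_length - 1) + [episode_length]  # deferred to the last step
assert sum(updates_without_annotation) == sum(updates_with_annotation)
```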