1919import os
2020import time
2121import torch
22- from typing import Callable , List , Optional
22+ from typing import Callable , Optional
2323from absl import logging
2424
2525import alf
@@ -601,35 +601,25 @@ def _async_unroll(self, unroll_length: int):
601601 effective_unroll_iters = effective_unroll_steps // unroll_length
602602 return experience , effective_unroll_iters
603603
604- def should_post_process_experience (self , rollout_info ,
605- step_type : StepType ):
606- """A function that determines whether the ``post_process_experience`` function should
607- be called. Users can customize this pair of functions in the derived class to achieve
608- different effects. For example:
609- - per-step processing: ``should_post_process_experience``
610- returns True for all the steps (by default), and ``post_process_experience``
611- returns the current step of experience unmodified (by default) or a modified version
612- according to their customized ``post_process_experience`` function.
604+ def post_process_experience (self , rollout_info , step_type : StepType ,
605+ experiences : Experience ):
606+ """A function for postprocessing experience. By default, it returns the input
607+ experience unmodified. Users can customize this function in the derived
608+ class to achieve different effects. For example:
609+ - per-step processing: return the current step of experience unmodified (by default)
610+ or a modified version according to the customized ``post_process_experience``.
613611 As another example, task filtering can be simply achieved by returning ``[]``
614- in ``post_process_experience`` for that particular task.
615- - per-episode processing: ``should_post_process_experience`` returns True on episode
616- end and ``post_process_experience`` can return a list of cached and processed
612+ for that particular task.
613+ - per-episode processing: this can be achieved by returning a list of processed
617614 experiences. For example, this can be used for success episode labeling.
618- """
619- return True
620-
621- def post_process_experience (self , experiences : Experience ):
622- """A function for postprocessing a list of experience. It is called when
623- ``should_post_process_experience`` is True.
624- By default, it returns the input unmodified.
625- Users can customize this function in the derived class, to create a number of
626- useful features such as 'hindsight relabeling' of a trajectory etc.
627615
628616 Args:
617+ rollout_info: the rollout info.
618+ step_type: the step type of the current experience.
629619 experiences: one step of experience.
630620
631621 Returns:
632- A list of experiences. Users can customize this pair of functions in the
622+ A list of experiences. Users can customize this functions in the
633623 derived class to achieve different effects. For example:
634624 - return a list that contains only the input experience (default behavior).
635625 - return a list that contains a number of experiences. This can be useful
@@ -640,17 +630,6 @@ def post_process_experience(self, experiences: Experience):
640630 def _process_unroll_step (self , policy_step , action , time_step ,
641631 transformed_time_step , policy_state ,
642632 experience_list , original_reward_list ):
643- """A function for processing the unroll steps.
644- By default, it returns the input unmodified.
645- Users can customize this function in the derived class, to create a number of
646- useful features such as 'hindsight relabeling' of a trajectory etc.
647-
648- Args:
649- experiences: a list of experience, containing the experience starting from the
650- initial time when ``should_post_process_experience`` is False to the step where
651- ``should_post_process_experience`` is True.
652- """
653-
654633 self .observe_for_metrics (time_step .cpu ())
655634 exp = make_experience (time_step .cpu (),
656635 alf .layers .to_float32 (policy_step ),
@@ -659,19 +638,15 @@ def _process_unroll_step(self, policy_step, action, time_step,
659638 store_exp_time = 0
660639 if not self .on_policy :
661640 rollout_info = policy_step .info
662- if self .should_post_process_experience (rollout_info ,
663- time_step .step_type ):
664- # 1) process
665- post_processed_exp_list = self .post_process_experience (exp )
666- effective_unroll_steps = len (post_processed_exp_list )
667- # 2) observe
668- t0 = time .time ()
669- for exp in post_processed_exp_list :
670- self .observe_for_replay (exp )
671- store_exp_time = time .time () - t0
672- else :
673- # effective unroll steps as 0 if ``should_post_process_experience condition`` is False
674- effective_unroll_steps = 0
641+ # 1) process
642+ post_processed_exp_list = self .post_process_experience (
643+ rollout_info , time_step .step_type , exp )
644+ effective_unroll_steps = len (post_processed_exp_list )
645+ # 2) observe
646+ t0 = time .time ()
647+ for exp in post_processed_exp_list :
648+ self .observe_for_replay (exp )
649+ store_exp_time = time .time () - t0
675650
676651 exp_for_training = Experience (
677652 time_step = transformed_time_step ,
0 commit comments