1919import os
2020import time
2121import torch
22- from typing import Callable , Optional
22+ from typing import Callable , List , Optional , Tuple
2323from absl import logging
2424
2525import alf
@@ -605,7 +605,7 @@ def _async_unroll(self, unroll_length: int):
605605 return experience , effective_unroll_iters
606606
607607 def post_process_experience (self , rollout_info , step_type : StepType ,
608- experiences : Experience ):
608+ experiences : Experience ) -> Tuple [ List , int ] :
609609 """A function for postprocessing experience. By default, it returns the input
610610 experience unmodified. Users can customize this function in the derived
611611 class to achieve different effects. For example:
@@ -622,17 +622,22 @@ class to achieve different effects. For example:
622622 experiences: one step of experience.
623623
624624 Returns:
625- A list of experiences. Users can customize this functions in the
626- derived class to achieve different effects. For example:
627- - return a list that contains only the input experience (default behavior).
628- - return a list that contains a number of experiences. This can be useful
629- for episode processing such as success episode labeling.
625+             - a list of experiences. Users can customize this function in the
626+ derived class to achieve different effects. For example:
627+ * return a list that contains only the input experience (default behavior).
628+ * return a list that contains a number of experiences. This can be useful
629+ for episode processing such as success episode labeling.
630+             - an integer representing the effective number of unroll steps per env.
631+               The default value is 1, meaning the length of effective experience is 1
632+ after calling ``post_process_experience``, the same as the input length
633+ of experience.
630634 """
631- return [experiences ]
635+ return [experiences ], 1
632636
633637 def _process_unroll_step (self , policy_step , action , time_step ,
634638 transformed_time_step , policy_state ,
635- experience_list , original_reward_list ):
639+ experience_list ,
640+ original_reward_list ) -> Tuple [int , int ]:
636641 self .observe_for_metrics (time_step .cpu ())
637642 exp = make_experience (time_step .cpu (),
638643 alf .layers .to_float32 (policy_step ),
@@ -641,11 +646,8 @@ def _process_unroll_step(self, policy_step, action, time_step,
641646 store_exp_time = 0
642647 if not self .on_policy :
643648 # 1) post process
644- post_processed_exp_list = self .post_process_experience (
649+ post_processed_exp_list , effective_unroll_steps = self .post_process_experience (
645650 policy_step .info , time_step .step_type , exp )
646- effective_unroll_steps = sum (
647- exp .step_type .shape [0 ]
648- for exp in post_processed_exp_list ) / exp .step_type .shape [0 ]
649651 # 2) observe
650652 t0 = time .time ()
651653 for exp in post_processed_exp_list :
0 commit comments