@@ -235,6 +235,7 @@ def __init__(self,
235235 self ._current_time_step = None
236236 self ._current_policy_state = None
237237 self ._current_transform_state = None
238+
238239 if self ._env is not None and not self .on_policy :
239240 replay_buffer_length = adjust_replay_buffer_length (
240241 config , self ._num_earliest_frames_ignored )
@@ -598,7 +599,9 @@ def _async_unroll(self, unroll_length: int):
598599
599600 self ._current_transform_state = common .detach (trans_state )
600601
601- effective_unroll_iters = effective_unroll_steps // unroll_length
602+ # if the input unroll_length is 0 (e.g. fractional unroll), then this is treated as
603+ # an effective unroll iter
604+ effective_unroll_iters = 1 if unroll_length == 0 else effective_unroll_steps // unroll_length
602605 return experience , effective_unroll_iters
603606
604607 def post_process_experience (self , rollout_info , step_type : StepType ,
@@ -637,11 +640,12 @@ def _process_unroll_step(self, policy_step, action, time_step,
637640 effective_unroll_steps = 1
638641 store_exp_time = 0
639642 if not self .on_policy :
640- rollout_info = policy_step .info
641- # 1) process
643+ # 1) post process
642644 post_processed_exp_list = self .post_process_experience (
643- rollout_info , time_step .step_type , exp )
644- effective_unroll_steps = len (post_processed_exp_list )
645+ policy_step .info , time_step .step_type , exp )
646+ effective_unroll_steps = sum (
647+ exp .step_type .shape [0 ]
648+ for exp in post_processed_exp_list ) / exp .step_type .shape [0 ]
645649 # 2) observe
646650 t0 = time .time ()
647651 for exp in post_processed_exp_list :
@@ -764,7 +768,9 @@ def _sync_unroll(self, unroll_length: int):
764768 self ._current_policy_state = common .detach (policy_state )
765769 self ._current_transform_state = common .detach (trans_state )
766770
767- effective_unroll_iters = effective_unroll_steps // unroll_length
771+ # if the input unroll_length is 0 (e.g. fractional unroll), then this is treated as
772+ # an effective unroll iter
773+ effective_unroll_iters = 1 if unroll_length == 0 else effective_unroll_steps // unroll_length
768774 return experience , effective_unroll_iters
769775
770776 def train_iter (self ):
0 commit comments