@@ -572,10 +572,11 @@ def _async_unroll(self, unroll_length: int):
             step_time += unroll_result.step_time
             max_step_time = max(max_step_time, unroll_result.step_time)

-            store_exp_time += self._process_unroll_step(
+            store_exp_time_i, effective_unroll_steps = self._process_unroll_step(
                 policy_step, policy_step.output, time_step,
                 transformed_time_step, policy_state, experience_list,
                 original_reward_list)
+            store_exp_time += store_exp_time_i

         alf.summary.scalar("time/unroll_env_step",
                            env_step_time,
@@ -602,7 +603,7 @@ def _async_unroll(self, unroll_length: int):

         self._current_transform_state = common.detach(trans_state)

-        return experience
+        return experience, effective_unroll_steps

     def should_post_process_episode(self, rollout_info, step_type: StepType):
         """A function that determines whether the ``post_process_episode`` function should
@@ -804,7 +805,7 @@ def _compute_train_info_and_loss_info_on_policy(self, unroll_length):
         with record_time("time/unroll"):
             with torch.cuda.amp.autocast(self._config.enable_amp,
                                          dtype=self._config.amp_dtype):
-                experience = self.unroll(self._config.unroll_length)
+                experience, _ = self.unroll(self._config.unroll_length)
             self.summarize_metrics()

         train_info = experience.rollout_info
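
For context, a minimal, self-contained sketch of the accumulation pattern this change introduces: the per-step helper now returns a tuple of (per-step store time, effective unroll steps), the caller accumulates the time while keeping the most recent step count, and `unroll` propagates that count to callers, which may discard it. All names and bodies below are illustrative stand-ins, not the actual ALF implementation.

```python
import time
from typing import List, Tuple


def process_unroll_step(step: int,
                        experience_list: List[int]) -> Tuple[float, int]:
    # Stand-in for ALF's _process_unroll_step: time the experience store
    # and report how many effective steps have been collected so far.
    start = time.perf_counter()
    experience_list.append(step)  # stand-in for storing the experience
    store_time = time.perf_counter() - start
    effective_steps = len(experience_list)  # stand-in for the real count
    return store_time, effective_steps


def unroll(unroll_length: int) -> Tuple[List[int], int]:
    experience_list: List[int] = []
    store_exp_time = 0.0
    effective_unroll_steps = 0
    for step in range(unroll_length):
        # Accumulate the per-step time; keep only the latest step count.
        store_exp_time_i, effective_unroll_steps = process_unroll_step(
            step, experience_list)
        store_exp_time += store_exp_time_i
    return experience_list, effective_unroll_steps


# Callers that only need the experience discard the count, as the
# on-policy path does in this diff:
experience, _ = unroll(8)
```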