Skip to content

Commit a05e8da

Browse files
committed
Update async unroll
1 parent e4cdb81 commit a05e8da

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

alf/algorithms/rl_algorithm.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -572,10 +572,11 @@ def _async_unroll(self, unroll_length: int):
572572
step_time += unroll_result.step_time
573573
max_step_time = max(max_step_time, unroll_result.step_time)
574574

575-
store_exp_time += self._process_unroll_step(
575+
store_exp_time_i, effective_unroll_steps = self._process_unroll_step(
576576
policy_step, policy_step.output, time_step,
577577
transformed_time_step, policy_state, experience_list,
578578
original_reward_list)
579+
store_exp_time += store_exp_time_i
579580

580581
alf.summary.scalar("time/unroll_env_step",
581582
env_step_time,
@@ -602,7 +603,7 @@ def _async_unroll(self, unroll_length: int):
602603

603604
self._current_transform_state = common.detach(trans_state)
604605

605-
return experience
606+
return experience, effective_unroll_steps
606607

607608
def should_post_process_episode(self, rollout_info, step_type: StepType):
608609
"""A function that determines whether the ``post_process_episode`` function should
@@ -804,7 +805,7 @@ def _compute_train_info_and_loss_info_on_policy(self, unroll_length):
804805
with record_time("time/unroll"):
805806
with torch.cuda.amp.autocast(self._config.enable_amp,
806807
dtype=self._config.amp_dtype):
807-
experience = self.unroll(self._config.unroll_length)
808+
experience, _ = self.unroll(self._config.unroll_length)
808809
self.summarize_metrics()
809810

810811
train_info = experience.rollout_info

0 commit comments

Comments
 (0)