Skip to content

Commit 734dae8

Browse files
committed
Handle fractional unroll
1 parent 26ab09a commit 734dae8

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

alf/algorithms/rl_algorithm.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@ def __init__(self,
235235
self._current_time_step = None
236236
self._current_policy_state = None
237237
self._current_transform_state = None
238+
238239
if self._env is not None and not self.on_policy:
239240
replay_buffer_length = adjust_replay_buffer_length(
240241
config, self._num_earliest_frames_ignored)
@@ -598,7 +599,9 @@ def _async_unroll(self, unroll_length: int):
598599

599600
self._current_transform_state = common.detach(trans_state)
600601

601-
effective_unroll_iters = effective_unroll_steps // unroll_length
602+
# if the input unroll_length is 0 (e.g. fractional unroll), then this it treated as
603+
# an effective unroll iter
604+
effective_unroll_iters = 1 if unroll_length == 0 else effective_unroll_steps // unroll_length
602605
return experience, effective_unroll_iters
603606

604607
def post_process_experience(self, rollout_info, step_type: StepType,
@@ -637,11 +640,12 @@ def _process_unroll_step(self, policy_step, action, time_step,
637640
effective_unroll_steps = 1
638641
store_exp_time = 0
639642
if not self.on_policy:
640-
rollout_info = policy_step.info
641-
# 1) process
643+
# 1) post process
642644
post_processed_exp_list = self.post_process_experience(
643-
rollout_info, time_step.step_type, exp)
644-
effective_unroll_steps = len(post_processed_exp_list)
645+
policy_step.info, time_step.step_type, exp)
646+
effective_unroll_steps = sum(
647+
exp.step_type.shape[0]
648+
for exp in post_processed_exp_list) / exp.step_type.shape[0]
645649
# 2) observe
646650
t0 = time.time()
647651
for exp in post_processed_exp_list:
@@ -764,7 +768,9 @@ def _sync_unroll(self, unroll_length: int):
764768
self._current_policy_state = common.detach(policy_state)
765769
self._current_transform_state = common.detach(trans_state)
766770

767-
effective_unroll_iters = effective_unroll_steps // unroll_length
771+
# if the input unroll_length is 0 (e.g. fractional unroll), then this it treated as
772+
# an effective unroll iter
773+
effective_unroll_iters = 1 if unroll_length == 0 else effective_unroll_steps // unroll_length
768774
return experience, effective_unroll_iters
769775

770776
def train_iter(self):

0 commit comments

Comments
 (0)