Commit 8fc3ff2

Address comments
1 parent 94a50bf commit 8fc3ff2

1 file changed: +64 -16 lines changed
alf/algorithms/rl_algorithm.py

Lines changed: 64 additions & 16 deletions
@@ -604,13 +604,15 @@ def _async_unroll(self, unroll_length: int):
         effective_unroll_iters = 1 if unroll_length == 0 else effective_unroll_steps // unroll_length
         return experience, effective_unroll_iters
 
-    def post_process_experience(self, rollout_info, step_type: StepType,
-                                experiences: Experience) -> Tuple[List, int]:
-        """A function for postprocessing experience. By default, it returns the input
+    def preprocess_unroll_experience(
+            self, rollout_info, step_type: StepType,
+            experiences: Experience) -> Tuple[List, float]:
+        """A function for processing the experience obtained from an unroll step before
+        being saved into the replay buffer. By default, it returns the input
         experience unmodified. Users can customize this function in the derived
         class to achieve different effects. For example:
         - per-step processing: return the current step of experience unmodified (by default)
-          or a modified version according to the customized ``post_process_experience``.
+          or a modified version according to the customized ``preprocess_unroll_experience``.
          As another example, task filtering can be simply achieved by returning ``[]``
          for that particular task.
         - per-episode processing: this can be achieved by returning a list of processed
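Editor's note, as an illustration of the per-step task filtering mentioned in the docstring above: a derived class could override the new method roughly as in the sketch below. This is not part of the commit; ``rollout_info.task_id`` and ``self._excluded_task_id`` are purely hypothetical fields, and a single (non-parallel) environment is assumed so they can be compared as scalars.

# Illustrative sketch only (not in the commit): per-step task filtering,
# written as a method of a hypothetical derived RLAlgorithm subclass.
def preprocess_unroll_experience(self, rollout_info, step_type, experiences):
    if int(rollout_info.task_id) == self._excluded_task_id:
        # Drop this step: nothing goes to the replay buffer, but it still
        # counts as one effective unroll step, so the unroll/train ratio
        # is unaffected.
        return [], 1
    # Default behavior: save the step as-is.
    return [experiences], 1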
@@ -622,35 +624,79 @@ class to achieve different effects. For example:
             experiences: one step of experience.
 
         Returns:
-            - a list of experiences. Users can customize this functions in the
-              derived class to achieve different effects. For example:
+            - ``effective_experiences``: a list of experiences. Users can customize this
+              function in the derived class to achieve different effects. For example:
                 * return a list that contains only the input experience (default behavior).
                 * return a list that contains a number of experiences. This can be useful
                   for episode processing such as success episode labeling.
-            - an integer representing the effective number of unroll steps per env. The
-              default value of 1, meaning the length of effective experience is 1
-              after calling ``post_process_experience``, the same as the input length
-              of experience.
+            - ``effective_unroll_steps``: a value representing the effective number of
+              unroll steps per env. The default value is 1, meaning the length of
+              effective experience is 1 after calling ``preprocess_unroll_experience``,
+              the same as the input length of experience.
+              The value of ``effective_unroll_steps`` can be set differently according
+              to different scenarios, e.g.:
+              (1) per-step saving without delay: saving each step of unroll experience
+                  into the replay buffer as we get it. Set ``effective_unroll_steps``
+                  to 1 so that each step will be counted as valid and there will be no
+                  impact on the train/unroll ratio.
+              (2) all-step saving with delay: saving all the steps of unroll experience into
+                  the replay buffer with delay. This can happen in the case where we want to
+                  annotate a trajectory based on some quantities that are not immediately
+                  available in the current step (e.g. task success/failure). In this case,
+                  we can simply cache the experiences and set ``effective_experiences=[]``
+                  before obtaining the quantities required for annotation.
+                  After obtaining the quantities required for annotation, we can
+                  set ``effective_experiences`` as the cached and annotated experience.
+                  To maintain the original unroll/train iter ratio, we can set
+                  ``effective_unroll_steps=1``, meaning each unroll step is regarded as
+                  effective in terms of the unroll/train iter ratio, even though the
+                  pace of saving the unroll steps into the replay buffer has been altered.
+              (3) selective saving: exclude some of the unroll experiences and only save
+                  the rest. This could be useful in the case where there are transitions
+                  that are irrelevant to the training (e.g. in the multi-task case, where
+                  we want to exclude data from certain subtasks).
+                  This can be achieved by setting ``effective_experiences=[]`` for the
+                  steps to be excluded, while ``effective_experiences = [experiences]``
+                  otherwise. If we do not want to trigger a train iter for the unroll
+                  step that will be excluded, we can simply set ``effective_unroll_steps=0``.
+                  Otherwise, we can set ``effective_unroll_steps=1``.
+              (4) parallel environments: in the case of parallel environments, the value
+                  of ``effective_unroll_steps`` can be set according to the modes described
+                  above and the status of each environment (e.g. ``effective_unroll_steps``
+                  can be set to an average value across environments). Note that this could
+                  result in a floating-point number.
         """
-        return [experiences], 1
+        effective_experiences = [experiences]
+        effective_unroll_steps = 1
+        return effective_experiences, effective_unroll_steps
 
     def _process_unroll_step(self, policy_step, action, time_step,
                              transformed_time_step, policy_state,
                              experience_list,
-                             original_reward_list) -> Tuple[int, int]:
+                             original_reward_list) -> Tuple[int, float]:
+        """
+
+        Returns:
+            - ``store_exp_time``: the time spent on storing the experience
+            - ``effective_unroll_steps``: a value representing the effective number
+              of unroll steps per env. The default value is 1, meaning the length of
+              effective experience is 1 after calling ``preprocess_unroll_experience``,
+              the same as the input length of experience. For more details on it,
+              please refer to the docstring of ``preprocess_unroll_experience``.
+        """
         self.observe_for_metrics(time_step.cpu())
         exp = make_experience(time_step.cpu(),
                               alf.layers.to_float32(policy_step),
                               alf.layers.to_float32(policy_state))
         effective_unroll_steps = 1
         store_exp_time = 0
         if not self.on_policy:
-            # 1) post process
-            post_processed_exp_list, effective_unroll_steps = self.post_process_experience(
+            # 1) pre-process unroll experience
+            pre_processed_exp_list, effective_unroll_steps = self.preprocess_unroll_experience(
                 policy_step.info, time_step.step_type, exp)
             # 2) observe
             t0 = time.time()
-            for exp in post_processed_exp_list:
+            for exp in pre_processed_exp_list:
                 self.observe_for_replay(exp)
             store_exp_time = time.time() - t0
 
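Editor's note, to make scenario (2) from the docstring above more concrete: a hypothetical derived-class override might look like the sketch below. None of this is defined by the commit; the ``self._episode_cache`` list is an illustrative assumption, a single (non-parallel) environment is assumed so ``step_type`` can be compared as a scalar, and the annotation step is only indicated by a comment.

# Illustrative sketch only (not in the commit): all-step saving with delay,
# as in scenario (2) of the docstring. Written as a method of a hypothetical
# derived RLAlgorithm subclass that maintains ``self._episode_cache``.
def preprocess_unroll_experience(self, rollout_info, step_type, experiences):
    # Hold every step of the current episode back from the replay buffer.
    self._episode_cache.append(experiences)
    if step_type != StepType.LAST:
        # Episode not finished yet: save nothing, but still count this step
        # so the unroll/train iter ratio is unchanged.
        return [], 1
    # Episode finished: quantities such as success/failure are now known, so
    # the cached steps could be annotated here before being released together.
    cached, self._episode_cache = self._episode_cache, []
    return cached, 1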

@@ -771,7 +817,9 @@ def _sync_unroll(self, unroll_length: int):
         self._current_transform_state = common.detach(trans_state)
 
         # if the input unroll_length is 0 (e.g. fractional unroll), then it is treated as
-        # an effective unroll iter
+        # an effective unroll iter.
+        # One ``effective_unroll_iter`` refers to calling ``rollout_step``
+        # ``unroll_length`` times during the unroll phase.
         effective_unroll_iters = 1 if unroll_length == 0 else effective_unroll_steps // unroll_length
         return experience, effective_unroll_iters
 
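Editor's note, a small worked illustration of the formula in the last changed line above, with made-up numbers (nothing here is part of the commit):

# Illustration only: relate the per-step values returned by
# preprocess_unroll_experience to effective_unroll_iters.
unroll_length = 5

# Per-step effective counts collected over one unroll, e.g. scenario (3)
# where two steps were excluded without triggering training:
per_step_values = [1, 1, 0, 1, 0]
effective_unroll_steps = sum(per_step_values)  # 3

# Same formula as in _sync_unroll above: 3 // 5 == 0, so this unroll does
# not count as a full effective unroll iteration.
effective_unroll_iters = (1 if unroll_length == 0
                          else effective_unroll_steps // unroll_length)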
