Commit 45c8321 ("Update")
1 parent: a05e8da
1 file changed: alf/algorithms/rl_algorithm.py (59 additions, 50 deletions)
@@ -147,7 +147,6 @@ def __init__(self,
                 optimizer=None,
                 checkpoint=None,
                 is_eval: bool = False,
-                episodic_annotation: bool = False,
                 overwrite_policy_output=False,
                 debug_summaries=False,
                 name="RLAlgorithm"):
@@ -187,10 +186,6 @@ def __init__(self,
                 during deployment. In this case, the algorithm do not need to
                 create certain components such as value_network for ActorCriticAlgorithm,
                 critic_networks for SacAlgorithm.
-            episodic_annotation: episodic annotation is an operation that annotates the
-                episode after it being collected, and then the annotated episode will be
-                observed by the replay buffer. If True, annotate the episode before being
-                observed by the replay buffer. Otherwise, episodic annotation is not applied.
             overwrite_policy_output (bool): if True, overwrite the policy output
                 with next_step.prev_action. This option can be used in some
                 cases such as data collection.
@@ -208,7 +203,6 @@ def __init__(self,
             debug_summaries=debug_summaries,
             name=name)
         self._is_eval = is_eval
-        self._episodic_annotation = episodic_annotation

         self._env = env
         self._observation_spec = observation_spec
@@ -241,7 +235,6 @@ def __init__(self,
         self._current_time_step = None
         self._current_policy_state = None
         self._current_transform_state = None
-        self._cached_exp = []  # for lazy observation
         if self._env is not None and not self.on_policy:
             replay_buffer_length = adjust_replay_buffer_length(
                 config, self._num_earliest_frames_ignored)
@@ -550,6 +543,7 @@ def _async_unroll(self, unroll_length: int):
         store_exp_time = 0.
         step_time = 0.
         max_step_time = 0.
+        effective_unroll_steps = 0
         qsize = self._async_unroller.get_queue_size()
         unroll_results = self._async_unroller.gather_unroll_results(
             unroll_length, self._config.max_unroll_length)
@@ -572,11 +566,12 @@ def _async_unroll(self, unroll_length: int):
             step_time += unroll_result.step_time
             max_step_time = max(max_step_time, unroll_result.step_time)

-            store_exp_time_i, effective_unroll_steps = self._process_unroll_step(
+            store_exp_time_i, effective_unroll_steps_i = self._process_unroll_step(
                 policy_step, policy_step.output, time_step,
                 transformed_time_step, policy_state, experience_list,
                 original_reward_list)
             store_exp_time += store_exp_time_i
+            effective_unroll_steps += effective_unroll_steps_i

         alf.summary.scalar("time/unroll_env_step",
                            env_step_time,
@@ -603,70 +598,80 @@ def _async_unroll(self, unroll_length: int):

         self._current_transform_state = common.detach(trans_state)

-        return experience, effective_unroll_steps
-
-    def should_post_process_episode(self, rollout_info, step_type: StepType):
-        """A function that determines whether the ``post_process_episode`` function should
-        be applied to the current list of experiences.
-        Users can customize this function in the derived class.
-        Bu default, it returns True all the time steps. When this is combined with
-        ``post_process_episode`` which simply return the input unmodified (as the default
-        implementation in this class), it is a dummy version of eposodic annotation with
-        logic equivalent to the case of episodic_annotation=False.
+        effective_unroll_iters = effective_unroll_steps // unroll_length
+        return experience, effective_unroll_iters
+
+    def should_post_process_experience(self, rollout_info,
+                                       step_type: StepType):
+        """A function that determines whether the ``post_process_experience`` function should
+        be called. Users can customize this pair of functions in the derived class to achieve
+        different effects. For example:
+        - per-step processing: ``should_post_process_experience``
+          returns True for all the steps (by default), and ``post_process_experience``
+          returns the current step of experience unmodified (by default) or a modified version
+          according to their customized ``post_process_experience`` function.
+          As another example, task filtering can be simply achieved by returning ``[]``
+          in ``post_process_experience`` for that particular task.
+        - per-episode processing: ``should_post_process_experience`` returns True on episode
+          end and ``post_process_experience`` can return a list of cached and processed
+          experiences. For example, this can be used for success episode labeling.
         """
         return True

-    def post_process_episode(self, experiences: List[Experience]):
+    def post_process_experience(self, experiences: Experience):
         """A function for postprocessing a list of experience. It is called when
-        ``should_post_process_episode`` is True.
+        ``should_post_process_experience`` is True.
         By default, it returns the input unmodified.
         Users can customize this function in the derived class, to create a number of
         useful features such as 'hindsight relabeling' of a trajectory etc.

         Args:
-            experiences: a list of experience, containing the experience starting from the
-                initial time when ``should_post_process_episode`` is False to the step where
-                ``should_post_process_episode`` is True.
+            experiences: one step of experience.
+
+        Returns:
+            A list of experiences. Users can customize this pair of functions in the
+            derived class to achieve different effects. For example:
+            - return a list that contains only the input experience (default behavior).
+            - return a list that contains a number of experiences. This can be useful
+              for episode processing such as success episode labeling.
         """
-        return experiences
+        return [experiences]

     def _process_unroll_step(self, policy_step, action, time_step,
                              transformed_time_step, policy_state,
                              experience_list, original_reward_list):
+        """A function for processing the unroll steps.
+        By default, it returns the input unmodified.
+        Users can customize this function in the derived class, to create a number of
+        useful features such as 'hindsight relabeling' of a trajectory etc.
+
+        Args:
+            experiences: a list of experience, containing the experience starting from the
+                initial time when ``should_post_process_experience`` is False to the step where
+                ``should_post_process_experience`` is True.
+        """
+
         self.observe_for_metrics(time_step.cpu())
         exp = make_experience(time_step.cpu(),
                               alf.layers.to_float32(policy_step),
                               alf.layers.to_float32(policy_state))
-        effective_number_of_unroll_steps = 1
-        if self._episodic_annotation:
-            assert not self.on_policy, "only support episodic annotation for off policy training"
-            store_exp_time = 0
-            # if last step, annotate
+        effective_unroll_steps = 1
+        store_exp_time = 0
+        if not self.on_policy:
             rollout_info = policy_step.info
-            self._cached_exp.append(exp)
-            if self.should_post_process_episode(rollout_info,
-                                                time_step.step_type):
-
+            if self.should_post_process_experience(rollout_info,
+                                                   time_step.step_type):
                 # 1) process
-                annotated_exp_list = self.post_process_episode(
-                    self._cached_exp)
-                effective_number_of_unroll_steps = len(annotated_exp_list)
+                post_processed_exp_list = self.post_process_experience(exp)
+                effective_unroll_steps = len(post_processed_exp_list)
                 # 2) observe
                 t0 = time.time()
-                for exp in annotated_exp_list:
+                for exp in post_processed_exp_list:
                     self.observe_for_replay(exp)
                 store_exp_time = time.time() - t0
-                # clean up the exp cache
-                self._cached_exp = []
             else:
-                # effective unroll steps as 0 if not post_process_episode timepoint yet
-                effective_number_of_unroll_steps = 0
-        else:
-            store_exp_time = 0
-            if not self.on_policy:
-                t0 = time.time()
-                self.observe_for_replay(exp)
-                store_exp_time = time.time() - t0
+                # effective unroll steps as 0 if ``should_post_process_experience condition`` is False
+                effective_unroll_steps = 0

         exp_for_training = Experience(
             time_step=transformed_time_step,
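
To show how a derived algorithm might use the new should_post_process_experience / post_process_experience pair for the per-episode processing mentioned in the docstring above, here is a minimal sketch. It is not part of this commit: the subclass name, the single-environment assumption, and the "positive final reward means success" rule are all hypothetical, and the _replace calls assume ALF's namedtuple-style Experience and TimeStep structures.

    # Hypothetical sketch (not from this commit): episode-level success labeling
    # built on the should_post_process_experience / post_process_experience pair.
    # Assumes a single environment and that a positive final reward marks success.
    from alf.algorithms.sac_algorithm import SacAlgorithm
    from alf.data_structures import Experience, StepType


    class SuccessLabelingSacAlgorithm(SacAlgorithm):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self._episode_cache = []  # experiences of the episode in progress

        def should_post_process_experience(self, rollout_info,
                                           step_type: StepType):
            # Fire on every step; post_process_experience decides what to release.
            return True

        def post_process_experience(self, experiences: Experience):
            self._episode_cache.append(experiences)
            if experiences.time_step.step_type[0] != StepType.LAST:
                # Hold the step back; the caller counts 0 effective unroll steps.
                return []
            # Episode finished: add a success bonus to every cached step
            # (assumption: positive final reward marks a successful episode).
            bonus = float(experiences.time_step.reward[0] > 0)
            labeled = [
                exp._replace(time_step=exp.time_step._replace(
                    reward=exp.time_step.reward + bonus))
                for exp in self._episode_cache
            ]
            self._episode_cache = []
            return labeled

With an override like this, the base class counts zero effective unroll steps for the held-back steps and the full episode length when the labeled steps are finally released, which is exactly what the new effective_unroll_steps bookkeeping measures.
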
@@ -676,7 +681,7 @@ def _process_unroll_step(self, policy_step, action, time_step,

         experience_list.append(exp_for_training)
         original_reward_list.append(time_step.reward)
-        return store_exp_time, effective_number_of_unroll_steps
+        return store_exp_time, effective_unroll_steps

     def reset_state(self):
         """Reset the state of the algorithm.
@@ -700,6 +705,8 @@ def _sync_unroll(self, unroll_length: int):
         Returns:
             Experience: The stacked experience with shape :math:`[T, B, \ldots]`
                 for each of its members.
+            effective_unroll_iters: the effective number of unroll iterations.
+                Each unroll iteration contains ``unroll_length`` unroll steps.
         """
         if self._current_time_step is None:
             self._current_time_step = common.get_initial_time_step(self._env)
@@ -750,10 +757,11 @@ def _sync_unroll(self, unroll_length: int):
             if self._overwrite_policy_output:
                 policy_step = policy_step._replace(
                     output=next_time_step.prev_action)
-            store_exp_time_i, effective_unroll_steps = self._process_unroll_step(
+            store_exp_time_i, effective_unroll_steps_i = self._process_unroll_step(
                 policy_step, action, time_step, transformed_time_step,
                 policy_state, experience_list, original_reward_list)
             store_exp_time += store_exp_time_i
+            effective_unroll_steps += effective_unroll_steps_i

             time_step = next_time_step
             policy_state = policy_step.state
@@ -781,7 +789,8 @@ def _sync_unroll(self, unroll_length: int):
         self._current_policy_state = common.detach(policy_state)
         self._current_transform_state = common.detach(trans_state)

-        return experience, effective_unroll_steps
+        effective_unroll_iters = effective_unroll_steps // unroll_length
+        return experience, effective_unroll_iters

     def train_iter(self):
         """Perform one iteration of training.
