
Commit e4cdb81

Address comments
1 parent 00efea8 commit e4cdb81

File tree

2 files changed: +22 -21 lines changed


alf/algorithms/algorithm.py

Lines changed: 2 additions & 5 deletions
@@ -1426,9 +1426,7 @@ def train_from_unroll(self, experience, train_info):
         return shape[0] * shape[1]
 
     @common.mark_replay
-    def train_from_replay_buffer(self,
-                                 effective_unroll_steps,
-                                 update_global_counter=False):
+    def train_from_replay_buffer(self, update_global_counter=False):
         """This function can be called by any algorithm that has its own
         replay buffer configured. There are several parameters specified in
         ``self._config`` that will affect how the training is performed:
@@ -1482,8 +1480,7 @@ def train_from_replay_buffer(self,
         # training is not started yet, ``_replay_buffer`` will be None since it
         # is only lazily created later when online RL training started.
         if (self._replay_buffer and self._replay_buffer.total_size
-                < config.initial_collect_steps) or (effective_unroll_steps
-                                                    == 0):
+                < config.initial_collect_steps):
             assert (
                 self._replay_buffer.num_environments *
                 self._replay_buffer.max_length >= config.initial_collect_steps
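
For orientation, a minimal caller-side sketch of the simplified entry point (not part of the commit): `alg` is assumed to be an off-policy algorithm instance that has its own replay buffer configured, and the helper below is purely illustrative; only the `train_from_replay_buffer(update_global_counter=...)` signature comes from this diff.

def train_for_n_updates(alg, num_updates: int) -> int:
    """Drive several replay-buffer updates; illustrative helper only."""
    total_steps = 0
    for _ in range(num_updates):
        # The per-call gating on effective_unroll_steps was removed here; the
        # method now only skips training while the replay buffer holds fewer
        # than config.initial_collect_steps steps.
        total_steps += alg.train_from_replay_buffer(update_global_counter=True)
    return total_steps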

alf/algorithms/rl_algorithm.py

Lines changed: 20 additions & 16 deletions
@@ -187,8 +187,10 @@ def __init__(self,
                 during deployment. In this case, the algorithm do not need to
                 create certain components such as value_network for ActorCriticAlgorithm,
                 critic_networks for SacAlgorithm.
-            episodic_annotation: if True, annotate the episode before being observed by the
-                replay buffer.
+            episodic_annotation: episodic annotation is an operation that annotates an
+                episode after it has been collected; the annotated episode is then
+                observed by the replay buffer. If True, annotate the episode before it is
+                observed by the replay buffer. Otherwise, episodic annotation is not applied.
             overwrite_policy_output (bool): if True, overwrite the policy output
                 with next_step.prev_action. This option can be used in some
                 cases such as data collection.
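
The flow this docstring describes, reduced to a minimal sketch (illustrative pseudocode, not the RLAlgorithm implementation; all three callables are stand-ins for the real unroll, annotation, and replay-buffer hooks):

from typing import Callable, List

def collect_annotate_observe(collect_episode: Callable[[], List],
                             annotate: Callable[[List], List],
                             observe: Callable[[object], None]) -> None:
    # 1) roll out one full episode
    episode = collect_episode()
    # 2) annotate it, e.g. via post_process_episode(...)
    annotated = annotate(episode)
    # 3) only now does the replay buffer observe the (annotated) experience
    for exp in annotated:
        observe(exp)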
@@ -244,9 +246,6 @@ def __init__(self,
         replay_buffer_length = adjust_replay_buffer_length(
             config, self._num_earliest_frames_ignored)
 
-        if self._episodic_annotation:
-            assert self._env.batch_size == 1, "only support non-batched environment"
-
         if config.whole_replay_buffer_training and config.clear_replay_buffer:
             # For whole replay buffer training, we would like to be sure
             # that the replay buffer have enough samples in it to perform
@@ -608,21 +607,27 @@ def _async_unroll(self, unroll_length: int):
     def should_post_process_episode(self, rollout_info, step_type: StepType):
         """A function that determines whether the ``post_process_episode`` function should
         be applied to the current list of experiences.
+        Users can customize this function in the derived class.
+        By default, it returns True for all time steps. When this is combined with
+        ``post_process_episode``, which simply returns the input unmodified (the default
+        implementation in this class), it is a dummy version of episodic annotation whose
+        logic is equivalent to the case of episodic_annotation=False.
         """
-        return False
+        return True
 
     def post_process_episode(self, experiences: List[Experience]):
         """A function for postprocessing a list of experience. It is called when
         ``should_post_process_episode`` is True.
-        It can be used to create a number of useful features such as 'hindsight relabeling'
-        of a trajectory etc.
+        By default, it returns the input unmodified.
+        Users can customize this function in the derived class to implement a number of
+        useful features such as 'hindsight relabeling' of a trajectory, etc.
 
         Args:
             experiences: a list of experience, containing the experience starting from the
                 initial time when ``should_post_process_episode`` is False to the step where
                 ``should_post_process_episode`` is True.
         """
-        return None
+        return experiences
 
     def _process_unroll_step(self, policy_step, action, time_step,
                              transformed_time_step, policy_state,
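
As a concrete illustration of overriding these hooks, here is a hedged sketch of a derived class that annotates an episode only when its LAST step arrives. The subclass, the scalar step_type assumption, and the _relabel_episode helper are all hypothetical; only the two method signatures come from this diff.

from typing import List

from alf.algorithms.rl_algorithm import RLAlgorithm
from alf.data_structures import Experience, StepType


class AnnotatingAlgorithm(RLAlgorithm):
    """Hypothetical sketch of the episodic-annotation hooks, not from the commit."""

    def should_post_process_episode(self, rollout_info, step_type: StepType):
        # Trigger annotation only at the end of an episode; assumes a
        # non-batched (scalar) step_type for the cached episode.
        return bool(step_type == StepType.LAST)

    def post_process_episode(self, experiences: List[Experience]):
        return self._relabel_episode(experiences)

    def _relabel_episode(self, experiences: List[Experience]) -> List[Experience]:
        # Hypothetical annotation step: rewrite goals/rewards across the
        # cached episode (e.g. hindsight relabeling) before it is observed
        # by the replay buffer. Shown here as a no-op placeholder.
        return experiences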
@@ -633,6 +638,7 @@ def _process_unroll_step(self, policy_step, action, time_step,
                                  alf.layers.to_float32(policy_state))
         effective_number_of_unroll_steps = 1
         if self._episodic_annotation:
+            assert not self.on_policy, "only support episodic annotation for off policy training"
             store_exp_time = 0
             # if last step, annotate
             rollout_info = policy_step.info
@@ -645,11 +651,10 @@ def _process_unroll_step(self, policy_step, action, time_step,
                     self._cached_exp)
                 effective_number_of_unroll_steps = len(annotated_exp_list)
                 # 2) observe
-                if not self.on_policy:
-                    t0 = time.time()
-                    for exp in annotated_exp_list:
-                        self.observe_for_replay(exp)
-                    store_exp_time = time.time() - t0
+                t0 = time.time()
+                for exp in annotated_exp_list:
+                    self.observe_for_replay(exp)
+                store_exp_time = time.time() - t0
             # clean up the exp cache
             self._cached_exp = []
         else:
@@ -903,8 +908,7 @@ def _train_iter_off_policy(self):
                 self.train()
             steps = 0
             for i in range(effective_unroll_steps):
-                steps += self.train_from_replay_buffer(effective_unroll_steps=1,
-                                                       update_global_counter=True)
+                steps += self.train_from_replay_buffer(update_global_counter=True)
         if unrolled:
             with record_time("time/after_train_iter"):
                 self.after_train_iter(root_inputs, rollout_info)
