Commit 26ab09a

Address more comments
1 parent 9cfe6a5 commit 26ab09a

2 files changed: +24 -51 lines

alf/algorithms/algorithm.py

Lines changed: 2 additions & 4 deletions
@@ -1469,7 +1469,6 @@ def train_from_replay_buffer(self, update_global_counter=False):
                 ``True``, it will affect the counter only if
                 ``config.update_counter_every_mini_batch=True``.
         """
-
         config: TrainerConfig = self._config
 
         # returns 0 if haven't started training yet, when ``_replay_buffer`` is
@@ -1494,22 +1493,21 @@ def _replay():
             # ``_replay_buffer`` for training.
             # TODO: If this function can be called asynchronously, and using
             # prioritized replay, then make sure replay and train below is atomic.
-            effective_num_updates_per_train_iter = config.num_updates_per_train_iter
             with record_time("time/replay"):
                 mini_batch_size = config.mini_batch_size
                 if mini_batch_size is None:
                     mini_batch_size = self._replay_buffer.num_environments
                 if config.whole_replay_buffer_training:
                     experience, batch_info = self._replay_buffer.gather_all(
                         ignore_earliest_frames=True)
-                    num_updates = effective_num_updates_per_train_iter
+                    num_updates = config.num_updates_per_train_iter
                 else:
                     assert config.mini_batch_length is not None, (
                         "No mini_batch_length is specified for off-policy training"
                     )
                     experience, batch_info = self._replay_buffer.get_batch(
                         batch_size=(mini_batch_size *
-                                    effective_num_updates_per_train_iter),
+                                    config.num_updates_per_train_iter),
                         batch_length=config.mini_batch_length)
                     num_updates = 1
             return experience, batch_info, num_updates, mini_batch_size
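
To make the sizing in the ``get_batch`` branch concrete, here is a tiny illustration with hypothetical config values (not taken from this commit); it only spells out the multiplication that replaced the removed local alias.

# Hypothetical values, for illustration only.
mini_batch_size = 64                # stands in for config.mini_batch_size
num_updates_per_train_iter = 4      # stands in for config.num_updates_per_train_iter

# The get_batch() branch above samples every sequence needed for the
# train iteration in a single call:
batch_size = mini_batch_size * num_updates_per_train_iter
print(batch_size)  # 256 sequences, each of length config.mini_batch_length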

alf/algorithms/rl_algorithm.py

Lines changed: 22 additions & 47 deletions
@@ -19,7 +19,7 @@
 import os
 import time
 import torch
-from typing import Callable, List, Optional
+from typing import Callable, Optional
 from absl import logging
 
 import alf
@@ -601,35 +601,25 @@ def _async_unroll(self, unroll_length: int):
         effective_unroll_iters = effective_unroll_steps // unroll_length
         return experience, effective_unroll_iters
 
-    def should_post_process_experience(self, rollout_info,
-                                       step_type: StepType):
-        """A function that determines whether the ``post_process_experience`` function should
-        be called. Users can customize this pair of functions in the derived class to achieve
-        different effects. For example:
-        - per-step processing: ``should_post_process_experience``
-          returns True for all the steps (by default), and ``post_process_experience``
-          returns the current step of experience unmodified (by default) or a modified version
-          according to their customized ``post_process_experience`` function.
+    def post_process_experience(self, rollout_info, step_type: StepType,
+                                experiences: Experience):
+        """A function for postprocessing experience. By default, it returns the input
+        experience unmodified. Users can customize this function in the derived
+        class to achieve different effects. For example:
+        - per-step processing: return the current step of experience unmodified (by default)
+          or a modified version according to the customized ``post_process_experience``.
         As another example, task filtering can be simply achieved by returning ``[]``
-        in ``post_process_experience`` for that particular task.
-        - per-episode processing: ``should_post_process_experience`` returns True on episode
-        end and ``post_process_experience`` can return a list of cached and processed
+        for that particular task.
+        - per-episode processing: this can be achieved by returning a list of processed
          experiences. For example, this can be used for success episode labeling.
-        """
-        return True
-
-    def post_process_experience(self, experiences: Experience):
-        """A function for postprocessing a list of experience. It is called when
-        ``should_post_process_experience`` is True.
-        By default, it returns the input unmodified.
-        Users can customize this function in the derived class, to create a number of
-        useful features such as 'hindsight relabeling' of a trajectory etc.
 
         Args:
+            rollout_info: the rollout info.
+            step_type: the step type of the current experience.
             experiences: one step of experience.
 
         Returns:
-            A list of experiences. Users can customize this pair of functions in the
+            A list of experiences. Users can customize this functions in the
             derived class to achieve different effects. For example:
             - return a list that contains only the input experience (default behavior).
             - return a list that contains a number of experiences. This can be useful
@@ -640,17 +630,6 @@ def post_process_experience(self, experiences: Experience):
     def _process_unroll_step(self, policy_step, action, time_step,
                              transformed_time_step, policy_state,
                              experience_list, original_reward_list):
-        """A function for processing the unroll steps.
-        By default, it returns the input unmodified.
-        Users can customize this function in the derived class, to create a number of
-        useful features such as 'hindsight relabeling' of a trajectory etc.
-
-        Args:
-            experiences: a list of experience, containing the experience starting from the
-            initial time when ``should_post_process_experience`` is False to the step where
-            ``should_post_process_experience`` is True.
-        """
-
         self.observe_for_metrics(time_step.cpu())
         exp = make_experience(time_step.cpu(),
                               alf.layers.to_float32(policy_step),
@@ -659,19 +638,15 @@ def _process_unroll_step(self, policy_step, action, time_step,
         store_exp_time = 0
         if not self.on_policy:
             rollout_info = policy_step.info
-            if self.should_post_process_experience(rollout_info,
-                                                   time_step.step_type):
-                # 1) process
-                post_processed_exp_list = self.post_process_experience(exp)
-                effective_unroll_steps = len(post_processed_exp_list)
-                # 2) observe
-                t0 = time.time()
-                for exp in post_processed_exp_list:
-                    self.observe_for_replay(exp)
-                store_exp_time = time.time() - t0
-            else:
-                # effective unroll steps as 0 if ``should_post_process_experience condition`` is False
-                effective_unroll_steps = 0
+            # 1) process
+            post_processed_exp_list = self.post_process_experience(
+                rollout_info, time_step.step_type, exp)
+            effective_unroll_steps = len(post_processed_exp_list)
+            # 2) observe
+            t0 = time.time()
+            for exp in post_processed_exp_list:
+                self.observe_for_replay(exp)
+            store_exp_time = time.time() - t0
 
         exp_for_training = Experience(
             time_step=transformed_time_step,
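
To show how the merged interface is meant to be used, here is a hedged sketch of a derived class that does per-episode processing, as described in the new docstring. Everything except the ``post_process_experience`` signature is hypothetical: the class name, the ``_cached_exps`` attribute, the single-environment assumption, and the placement of any relabeling logic. The imports are assumed to mirror those already used by rl_algorithm.py, including a ``StepType.LAST`` value marking episode end.

from alf.algorithms.rl_algorithm import RLAlgorithm
from alf.data_structures import Experience, StepType


class EpisodeLabelingAlgorithm(RLAlgorithm):
    """Sketch: cache unroll steps and release them for replay at episode end."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._cached_exps = []  # hypothetical per-episode cache (single env assumed)

    def post_process_experience(self, rollout_info, step_type: StepType,
                                experiences: Experience):
        self._cached_exps.append(experiences)
        if step_type != StepType.LAST:
            # Returning [] stores nothing for replay on this call, so this
            # step contributes 0 effective unroll steps.
            return []
        # Episode finished: relabeling (e.g. marking the whole episode as a
        # success) would go here before releasing the cached steps.
        episode, self._cached_exps = self._cached_exps, []
        return episode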
