diff --git a/alf/algorithms/rl_algorithm.py b/alf/algorithms/rl_algorithm.py
index 773647233..81ae4a27b 100644
--- a/alf/algorithms/rl_algorithm.py
+++ b/alf/algorithms/rl_algorithm.py
@@ -19,7 +19,7 @@
 import os
 import time
 import torch
-from typing import Callable, Optional
+from typing import Callable, List, Optional, Tuple
 from absl import logging
 
 import alf
@@ -544,6 +544,7 @@ def _async_unroll(self, unroll_length: int):
         store_exp_time = 0.
         step_time = 0.
         max_step_time = 0.
+        effective_unroll_steps = 0
         qsize = self._async_unroller.get_queue_size()
         unroll_results = self._async_unroller.gather_unroll_results(
             unroll_length, self._config.max_unroll_length)
@@ -566,10 +567,12 @@ def _async_unroll(self, unroll_length: int):
             step_time += unroll_result.step_time
             max_step_time = max(max_step_time, unroll_result.step_time)
 
-            store_exp_time += self._process_unroll_step(
+            store_exp_time_i, effective_unroll_steps_i = self._process_unroll_step(
                 policy_step, policy_step.output, time_step,
                 transformed_time_step, policy_state, experience_list,
                 original_reward_list)
+            store_exp_time += store_exp_time_i
+            effective_unroll_steps += effective_unroll_steps_i
 
         alf.summary.scalar("time/unroll_env_step", env_step_time,
@@ -596,20 +599,105 @@ def _async_unroll(self, unroll_length: int):
         self._current_transform_state = common.detach(trans_state)
 
-        return experience
+        # If the input unroll_length is 0 (e.g. fractional unroll), then it is
+        # treated as one effective unroll iter.
+        effective_unroll_iters = (1 if unroll_length == 0 else
+                                  effective_unroll_steps // unroll_length)
+        return experience, effective_unroll_iters
+
+    def preprocess_unroll_experience(
+            self, rollout_info, step_type: StepType,
+            experiences: Experience) -> Tuple[List, float]:
+        """Process the experience obtained from an unroll step before it is
+        saved into the replay buffer.
+
+        By default, it returns the input experience unmodified. Users can
+        customize this function in a derived class to achieve different
+        effects. For example:
+
+        - per-step processing: return the current step of experience unmodified
+          (the default) or a modified version produced by the customized
+          ``preprocess_unroll_experience``. As another example, task filtering
+          can be achieved by simply returning ``[]`` for that particular task.
+        - per-episode processing: this can be achieved by returning a list of
+          processed experiences. For example, this can be used for success
+          episode labeling.
+
+        Args:
+            rollout_info: the rollout info.
+            step_type: the step type of the current experience.
+            experiences: one step of experience.
+
+        Returns:
+            - ``effective_experiences``: a list of experiences. Users can
+              customize this function in a derived class to achieve different
+              effects. For example:
+
+              * return a list that contains only the input experience (the
+                default behavior).
+              * return a list that contains a number of experiences. This can
+                be useful for episode-level processing such as success episode
+                labeling.
+            - ``effective_unroll_steps``: a value representing the effective
+              number of unroll steps per environment. The default value is 1,
+              meaning that the length of the effective experience after calling
+              ``preprocess_unroll_experience`` is 1, the same as the length of
+              the input experience. The value of ``effective_unroll_steps`` can
+              be set differently according to different scenarios, e.g.:
+
+              (1) per-step saving without delay: save each step of unroll
+                  experience into the replay buffer as we get it. Set
+                  ``effective_unroll_steps`` to 1 so that each step will be
+                  counted as valid and there will be no impact on the
+                  train/unroll ratio.
+              (2) all-step saving with delay: save all the steps of unroll
+                  experience into the replay buffer with a delay. This can
+                  happen when we want to annotate a trajectory based on some
+                  quantities that are not immediately available at the current
+                  step (e.g. task success/failure). In this case, we can simply
+                  cache the experiences and set ``effective_experiences=[]``
+                  until the quantities required for annotation are obtained.
+                  After that, we can set ``effective_experiences`` to the
+                  cached and annotated experiences. To maintain the original
+                  unroll/train iter ratio, we can set
+                  ``effective_unroll_steps=1``, meaning each unroll step is
+                  regarded as effective in terms of the unroll/train iter
+                  ratio, even though the pace of saving the unroll steps into
+                  the replay buffer has been altered.
+              (3) selective saving: exclude some of the unroll experiences and
+                  only save the rest. This can be useful when there are
+                  transitions that are irrelevant to the training (e.g. in the
+                  multi-task case, where we want to exclude data from certain
+                  subtasks). This can be achieved by setting
+                  ``effective_experiences=[]`` for the steps to be excluded and
+                  ``effective_experiences=[experiences]`` otherwise. If we do
+                  not want to trigger a train iter for the excluded unroll
+                  step, we can set ``effective_unroll_steps=0``; otherwise, we
+                  can set ``effective_unroll_steps=1``.
+              (4) parallel environments: in the case of parallel environments,
+                  the value of ``effective_unroll_steps`` can be set according
+                  to the modes described above and the status of each
+                  environment (e.g. ``effective_unroll_steps`` can be set to an
+                  average value across environments). Note that this could
+                  result in a floating-point number.
+        """
+        effective_experiences = [experiences]
+        effective_unroll_steps = 1
+        return effective_experiences, effective_unroll_steps
 
     def _process_unroll_step(self, policy_step, action, time_step,
                              transformed_time_step, policy_state,
-                             experience_list, original_reward_list):
+                             experience_list,
+                             original_reward_list) -> Tuple[float, float]:
+        """Process one unroll step and store its experience.
+
+        Returns:
+            - ``store_exp_time``: the time spent on storing the experience.
+            - ``effective_unroll_steps``: a value representing the effective
+              number of unroll steps per environment. The default value is 1,
+              meaning that the length of the effective experience after calling
+              ``preprocess_unroll_experience`` is 1, the same as the length of
+              the input experience. For more details, please refer to the
+              docstring of ``preprocess_unroll_experience``.
+        """
         self.observe_for_metrics(time_step.cpu())
         exp = make_experience(time_step.cpu(),
                               alf.layers.to_float32(policy_step),
                               alf.layers.to_float32(policy_state))
-
+        effective_unroll_steps = 1
         store_exp_time = 0
         if not self.on_policy:
+            # 1) pre-process the unroll experience
+            pre_processed_exp_list, effective_unroll_steps = self.preprocess_unroll_experience(
+                policy_step.info, time_step.step_type, exp)
+            # 2) observe for replay
             t0 = time.time()
-            self.observe_for_replay(exp)
+            for exp in pre_processed_exp_list:
+                self.observe_for_replay(exp)
             store_exp_time = time.time() - t0
 
         exp_for_training = Experience(
@@ -620,7 +708,7 @@ def _process_unroll_step(self, policy_step, action, time_step,
         experience_list.append(exp_for_training)
         original_reward_list.append(time_step.reward)
 
-        return store_exp_time
+        return store_exp_time, effective_unroll_steps
 
     def reset_state(self):
         """Reset the state of the algorithm.
@@ -644,6 +732,8 @@ def _sync_unroll(self, unroll_length: int):
         Returns:
             Experience: The stacked experience with shape :math:`[T, B, \ldots]`
                 for each of its members.
+            effective_unroll_iters: the effective number of unroll iterations.
+                Each unroll iteration contains ``unroll_length`` unroll steps.
         """
         if self._current_time_step is None:
             self._current_time_step = common.get_initial_time_step(self._env)
@@ -665,6 +755,7 @@ def _sync_unroll(self, unroll_length: int):
         policy_step_time = 0.
         env_step_time = 0.
         store_exp_time = 0.
+        effective_unroll_steps = 0
         for _ in range(unroll_length):
             policy_state = common.reset_state_if_necessary(
                 policy_state, initial_state, time_step.is_first())
@@ -693,9 +784,11 @@ def _sync_unroll(self, unroll_length: int):
             if self._overwrite_policy_output:
                 policy_step = policy_step._replace(
                     output=next_time_step.prev_action)
-            store_exp_time += self._process_unroll_step(
+            store_exp_time_i, effective_unroll_steps_i = self._process_unroll_step(
                 policy_step, action, time_step, transformed_time_step,
                 policy_state, experience_list, original_reward_list)
+            store_exp_time += store_exp_time_i
+            effective_unroll_steps += effective_unroll_steps_i
 
             time_step = next_time_step
             policy_state = policy_step.state
@@ -723,7 +816,12 @@ def _sync_unroll(self, unroll_length: int):
             self._current_policy_state = common.detach(policy_state)
             self._current_transform_state = common.detach(trans_state)
 
-        return experience
+        # If the input unroll_length is 0 (e.g. fractional unroll), then it is
+        # treated as one effective unroll iter. One effective unroll iter refers
+        # to ``unroll_length`` calls of ``rollout_step`` in the unroll phase.
+        effective_unroll_iters = (1 if unroll_length == 0 else
+                                  effective_unroll_steps // unroll_length)
+        return experience, effective_unroll_iters
 
     def train_iter(self):
         """Perform one iteration of training.
@@ -747,7 +845,7 @@ def _compute_train_info_and_loss_info_on_policy(self, unroll_length):
         with record_time("time/unroll"):
             with torch.cuda.amp.autocast(self._config.enable_amp,
                                          dtype=self._config.amp_dtype):
-                experience = self.unroll(self._config.unroll_length)
+                experience, _ = self.unroll(self._config.unroll_length)
             self.summarize_metrics()
 
         train_info = experience.rollout_info
@@ -788,6 +886,9 @@ def _unroll_iter_off_policy(self):
               unroll length, it may not have been called.
            - root_inputs: root-level time step returned by the unroll
            - rollout_info: rollout info returned by the unroll
+           - effective_unroll_iters: the effective number of unroll iterations.
+             ``train_from_replay_buffer`` will be run ``effective_unroll_iters``
+             times during ``_train_iter_off_policy``.
        """
        config: TrainerConfig = self._config
@@ -804,6 +905,7 @@ def _unroll_iter_off_policy(self):
         unrolled = False
         root_inputs = None
         rollout_info = None
+        effective_unroll_iters = 0
         if (alf.summary.get_global_counter() >=
                 self._rl_train_after_update_steps
                 and (unroll_length > 0 or config.unroll_length == 0) and
@@ -822,7 +924,8 @@ def _unroll_iter_off_policy(self):
                 # need to remember whether summary has been written between
                 # two unrolls.
                 with self._ensure_rollout_summary:
-                    experience = self.unroll(unroll_length)
+                    experience, effective_unroll_iters = self.unroll(
+                        unroll_length)
                     if experience:
                         self.summarize_rollout(experience)
                         self.summarize_metrics()
@@ -830,11 +933,12 @@ def _unroll_iter_off_policy(self):
                         if config.use_root_inputs_for_after_train_iter:
                             root_inputs = experience.time_step
                         del experience
-        return unrolled, root_inputs, rollout_info
+        return unrolled, root_inputs, rollout_info, effective_unroll_iters
 
     def _train_iter_off_policy(self):
         """User may override this for their own training procedure."""
-        unrolled, root_inputs, rollout_info = self._unroll_iter_off_policy()
+        unrolled, root_inputs, rollout_info, effective_unroll_iters = self._unroll_iter_off_policy(
+        )
 
         # replay buffer may not have been created for two different reasons:
         # 1. in online RL training (``has_offline`` is False), unroll is not
@@ -846,11 +950,12 @@ def _train_iter_off_policy(self):
             return 0
 
         self.train()
-        steps = self.train_from_replay_buffer(update_global_counter=True)
-
-        if unrolled:
-            with record_time("time/after_train_iter"):
-                self.after_train_iter(root_inputs, rollout_info)
+        steps = 0
+        for i in range(effective_unroll_iters):
+            steps += self.train_from_replay_buffer(update_global_counter=True)
+            if unrolled:
+                with record_time("time/after_train_iter"):
+                    self.after_train_iter(root_inputs, rollout_info)
 
         # For now, we only return the steps of the primary algorithm's training
         return steps
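
Example (not part of this patch): a minimal sketch of how a derived algorithm
might override the new ``preprocess_unroll_experience`` hook to implement the
"selective saving" scenario (3) described in the docstring above. The
``TaskFilteredSacAlgorithm`` class, its ``EXCLUDED_TASK_IDS`` attribute, and the
``task_id`` field on ``rollout_info`` are hypothetical and used only for
illustration.

from typing import List, Tuple

from alf.algorithms.sac_algorithm import SacAlgorithm
from alf.data_structures import Experience, StepType


class TaskFilteredSacAlgorithm(SacAlgorithm):
    """Hypothetical subclass illustrating selective saving of unroll steps."""

    # Task ids whose transitions should be kept out of the replay buffer
    # (placeholder values for illustration).
    EXCLUDED_TASK_IDS = {2, 5}

    def preprocess_unroll_experience(
            self, rollout_info, step_type: StepType,
            experiences: Experience) -> Tuple[List, float]:
        # ``rollout_info.task_id`` is an assumed field; a real multi-task setup
        # would carry the task index in its own rollout info structure.
        task_id = getattr(rollout_info, "task_id", None)
        if task_id in self.EXCLUDED_TASK_IDS:
            # Drop this step: nothing is stored, and returning 0 effective
            # unroll steps avoids triggering an extra train iteration for it.
            return [], 0
        # Default behavior: store the step as-is and count it as one effective
        # unroll step per environment.
        return [experiences], 1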