Post Process Experience with Customizable Modes #1768
@@ -19,7 +19,7 @@
 import os
 import time
 import torch
-from typing import Callable, Optional
+from typing import Callable, List, Optional
 from absl import logging

 import alf
|
|
@@ -147,6 +147,7 @@ def __init__(self,
                  optimizer=None,
                  checkpoint=None,
                  is_eval: bool = False,
+                 episodic_annotation: bool = False,
                  overwrite_policy_output=False,
                  debug_summaries=False,
                  name="RLAlgorithm"):
|
|
@@ -186,6 +187,10 @@ def __init__(self,
                 during deployment. In this case, the algorithm do not need to
                 create certain components such as value_network for ActorCriticAlgorithm,
                 critic_networks for SacAlgorithm.
+            episodic_annotation: whether to annotate each episode after it has been
+                collected and before it is observed by the replay buffer. If True,
+                the collected episode is annotated before being observed by the
+                replay buffer; otherwise episodic annotation is not applied.
             overwrite_policy_output (bool): if True, overwrite the policy output
                 with next_step.prev_action. This option can be used in some
                 cases such as data collection.
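With the flag on, experience is observed lazily: unrolled steps are cached until the algorithm decides the cached segment is complete (the ``should_post_process_episode`` hook added later in this diff), the segment is annotated as a whole, and only then written to the replay buffer. A minimal toy model of that ordering (plain Python, independent of ALF; every name in it is illustrative, not from this diff):

```python
from typing import Callable, Dict, List


def collect(steps: List[Dict],
            episodic_annotation: bool,
            annotate: Callable[[List[Dict]], List[Dict]],
            observe: Callable[[Dict], None]) -> None:
    """Toy model of the observe path controlled by ``episodic_annotation``."""
    cache = []
    for step in steps:
        if not episodic_annotation:
            # Flag off: every step is observed immediately (previous behavior).
            observe(step)
        else:
            # Flag on: cache the step ("lazy observation") ...
            cache.append(step)
            if step["is_last"]:  # stands in for should_post_process_episode()
                # ... and once the episode is complete, annotate it as a whole
                # before anything reaches the replay buffer.
                for annotated in annotate(cache):
                    observe(annotated)
                cache = []
```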
|
|
@@ -203,6 +208,7 @@ def __init__(self,
             debug_summaries=debug_summaries,
             name=name)
         self._is_eval = is_eval
+        self._episodic_annotation = episodic_annotation

         self._env = env
         self._observation_spec = observation_spec
|
|
@@ -235,7 +241,7 @@ def __init__(self,
         self._current_time_step = None
         self._current_policy_state = None
         self._current_transform_state = None
-
+        self._cached_exp = []  # for lazy observation
         if self._env is not None and not self.on_policy:
             replay_buffer_length = adjust_replay_buffer_length(
                 config, self._num_earliest_frames_ignored)
|
|
@@ -598,19 +604,68 @@ def _async_unroll(self, unroll_length: int):

         return experience

+    def should_post_process_episode(self, rollout_info, step_type: StepType):
+        """A function that determines whether ``post_process_episode`` should be
+        applied to the current list of cached experiences.
+        Users can customize this function in a derived class.
+        By default, it returns True at every time step. When combined with the
+        default ``post_process_episode``, which simply returns the input
+        unmodified, this gives a dummy version of episodic annotation whose
+        behavior is equivalent to ``episodic_annotation=False``.
+        """
+        return True
+
+    def post_process_episode(self, experiences: List[Experience]):
+        """A function for post-processing a list of experiences. It is called
+        when ``should_post_process_episode`` returns True.
+        By default, it returns the input unmodified.
+        Users can customize this function in a derived class to implement
+        useful features such as 'hindsight relabeling' of a trajectory.
+
+        Args:
+            experiences: a list of experiences, containing every step cached
+                since the last time ``should_post_process_episode`` returned
+                True, up to and including the current step.
+        """
+        return experiences
+
     def _process_unroll_step(self, policy_step, action, time_step,
                              transformed_time_step, policy_state,
                              experience_list, original_reward_list):
         self.observe_for_metrics(time_step.cpu())
         exp = make_experience(time_step.cpu(),
                               alf.layers.to_float32(policy_step),
                               alf.layers.to_float32(policy_state))

-        store_exp_time = 0
-        if not self.on_policy:
-            t0 = time.time()
-            self.observe_for_replay(exp)
-            store_exp_time = time.time() - t0
+        effective_number_of_unroll_steps = 1
+        if self._episodic_annotation:
+            assert not self.on_policy, "only support episodic annotation for off policy training"
+            store_exp_time = 0
+            # If this is the last step of the episode, annotate it.
+            rollout_info = policy_step.info
+            self._cached_exp.append(exp)
+            if self.should_post_process_episode(rollout_info,
+                                                time_step.step_type):
+                # 1) process
+                annotated_exp_list = self.post_process_episode(
+                    self._cached_exp)
+                effective_number_of_unroll_steps = len(annotated_exp_list)
+                # 2) observe
+                t0 = time.time()
+                for exp in annotated_exp_list:
+                    self.observe_for_replay(exp)
+                store_exp_time = time.time() - t0
+                # clean up the exp cache
+                self._cached_exp = []
+            else:
+                # Not at a post_process_episode time point yet, so this call
+                # contributes 0 effective unroll steps.
+                effective_number_of_unroll_steps = 0
+        else:
+            store_exp_time = 0
+            if not self.on_policy:
+                t0 = time.time()
+                self.observe_for_replay(exp)
+                store_exp_time = time.time() - t0

         exp_for_training = Experience(
             time_step=transformed_time_step,
|
|
@@ -620,7 +675,7 @@ def _process_unroll_step(self, policy_step, action, time_step,

         experience_list.append(exp_for_training)
         original_reward_list.append(time_step.reward)
-        return store_exp_time
+        return store_exp_time, effective_number_of_unroll_steps

     def reset_state(self):
         """Reset the state of the algorithm.
|
|
@@ -665,6 +720,7 @@ def _sync_unroll(self, unroll_length: int):
         policy_step_time = 0.
         env_step_time = 0.
         store_exp_time = 0.
+        effective_unroll_steps = 0
         for _ in range(unroll_length):
             policy_state = common.reset_state_if_necessary(
                 policy_state, initial_state, time_step.is_first())
|
|
@@ -693,9 +749,10 @@ def _sync_unroll(self, unroll_length: int):
             if self._overwrite_policy_output:
                 policy_step = policy_step._replace(
                     output=next_time_step.prev_action)
-            store_exp_time += self._process_unroll_step(
+            store_exp_time_i, effective_unroll_steps = self._process_unroll_step(
                 policy_step, action, time_step, transformed_time_step,
                 policy_state, experience_list, original_reward_list)
+            store_exp_time += store_exp_time_i

             time_step = next_time_step
             policy_state = policy_step.state
|
|
@@ -723,7 +780,7 @@ def _sync_unroll(self, unroll_length: int):
         self._current_policy_state = common.detach(policy_state)
         self._current_transform_state = common.detach(trans_state)

-        return experience
+        return experience, effective_unroll_steps

     def train_iter(self):
         """Perform one iteration of training.
|
|
@@ -804,6 +861,7 @@ def _unroll_iter_off_policy(self):
         unrolled = False
         root_inputs = None
         rollout_info = None
+        effective_unroll_steps = 0
         if (alf.summary.get_global_counter()
                 >= self._rl_train_after_update_steps
                 and (unroll_length > 0 or config.unroll_length == 0) and
|
|
@@ -822,19 +880,21 @@ def _unroll_iter_off_policy(self):
             # need to remember whether summary has been written between
            # two unrolls.
             with self._ensure_rollout_summary:
-                experience = self.unroll(unroll_length)
+                experience, effective_unroll_steps = self.unroll(
+                    unroll_length)
                 if experience:
                     self.summarize_rollout(experience)
                     self.summarize_metrics()
                     rollout_info = experience.rollout_info
                     if config.use_root_inputs_for_after_train_iter:
                         root_inputs = experience.time_step
                     del experience
-        return unrolled, root_inputs, rollout_info
+        return unrolled, root_inputs, rollout_info, effective_unroll_steps

     def _train_iter_off_policy(self):
         """User may override this for their own training procedure."""
-        unrolled, root_inputs, rollout_info = self._unroll_iter_off_policy()
+        unrolled, root_inputs, rollout_info, effective_unroll_steps = self._unroll_iter_off_policy(
+        )

         # replay buffer may not have been created for two different reasons:
         # 1. in online RL training (``has_offline`` is False), unroll is not
|
|
@@ -846,11 +906,12 @@ def _train_iter_off_policy(self):
             return 0

         self.train()
-        steps = self.train_from_replay_buffer(update_global_counter=True)
-
-        if unrolled:
-            with record_time("time/after_train_iter"):
-                self.after_train_iter(root_inputs, rollout_info)
+        steps = 0
+        for i in range(effective_unroll_steps):
+            steps += self.train_from_replay_buffer(update_global_counter=True)
+            if unrolled:
+                with record_time("time/after_train_iter"):
+                    self.after_train_iter(root_inputs, rollout_info)

         # For now, we only return the steps of the primary algorithm's training
         return steps
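A consequence of the loop above: the number of ``train_from_replay_buffer`` calls per iteration now tracks ``effective_unroll_steps``, which stays 0 while an episode is still being cached and jumps to the episode length once it is annotated. A small self-contained trace of that bookkeeping (plain Python, not ALF code; assumes a single environment, one unroll step per iteration, and a 3-step episode annotated at its last step):

```python
# Toy trace of effective unroll steps across a 3-step episode.
episode = [{"is_last": False}, {"is_last": False}, {"is_last": True}]
cache = []
for t, step in enumerate(episode, start=1):
    cache.append(step)
    if step["is_last"]:          # should_post_process_episode() returns True
        effective = len(cache)   # every cached step is observed now
        cache = []
    else:                        # not an annotation point yet
        effective = 0
    # _train_iter_off_policy would call train_from_replay_buffer() `effective` times.
    print(f"unroll step {t}: effective_unroll_steps = {effective}")
# Prints 0, 0, 3: replay-buffer updates are deferred until the episode is annotated.
```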
|
|
why do you need to make this change?
not necessary anymore. removed