Post Process Experience with Customizable Modes #1768
base: pytorch
Changes from 1 commit
@@ -187,8 +187,10 @@ def __init__(self,
                 during deployment. In this case, the algorithm do not need to
                 create certain components such as value_network for ActorCriticAlgorithm,
                 critic_networks for SacAlgorithm.
-            episodic_annotation: if True, annotate the episode before being observed by the
-                replay buffer.
+            episodic_annotation: episodic annotation is an operation that annotates the
+                episode after it is collected; the annotated episode is then observed by
+                the replay buffer. If True, annotate the episode before it is observed by
+                the replay buffer. Otherwise, episodic annotation is not applied.
             overwrite_policy_output (bool): if True, overwrite the policy output
                 with next_step.prev_action. This option can be used in some
                 cases such as data collection.
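The docstring above describes an annotate-then-observe flow: with episodic_annotation=True, collected steps are held back until the episode ends, transformed once as a whole, and only then written to the replay buffer. The following standalone sketch illustrates that idea; the Experience dataclass, annotate_episode function, and buffer names are illustrative placeholders, not ALF's actual API.

from dataclasses import dataclass, replace
from typing import List


@dataclass
class Experience:
    """Illustrative stand-in for one collected step (not ALF's Experience)."""
    observation: float
    reward: float
    is_last: bool


def annotate_episode(episode: List[Experience]) -> List[Experience]:
    """Toy annotation: relabel every step's reward with the episode return."""
    episode_return = sum(e.reward for e in episode)
    return [replace(e, reward=episode_return) for e in episode]


replay_buffer: List[Experience] = []
episode_cache: List[Experience] = []

for step in [Experience(0.1, 0.0, False),
             Experience(0.2, 1.0, False),
             Experience(0.3, 2.0, True)]:
    # With episodic_annotation=True, steps are cached instead of being
    # observed immediately ...
    episode_cache.append(step)
    if step.is_last:
        # ... and the whole episode is annotated and observed at episode end.
        replay_buffer.extend(annotate_episode(episode_cache))
        episode_cache.clear()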
@@ -244,9 +246,6 @@ def __init__(self,
         replay_buffer_length = adjust_replay_buffer_length(
             config, self._num_earliest_frames_ignored)

-        if self._episodic_annotation:
-            assert self._env.batch_size == 1, "only support non-batched environment"
-
         if config.whole_replay_buffer_training and config.clear_replay_buffer:
             # For whole replay buffer training, we would like to be sure
             # that the replay buffer have enough samples in it to perform
@@ -608,21 +607,27 @@ def _async_unroll(self, unroll_length: int):
     def should_post_process_episode(self, rollout_info, step_type: StepType):
         """A function that determines whether the ``post_process_episode`` function should
         be applied to the current list of experiences.
+        Users can customize this function in the derived class.
+        By default, it returns True for all time steps. When this is combined with
+        ``post_process_episode``, which simply returns the input unmodified (as the default
+        implementation in this class), it is a dummy version of episodic annotation with
+        logic equivalent to the case of episodic_annotation=False.
         """
-        return False
+        return True

     def post_process_episode(self, experiences: List[Experience]):
         """A function for postprocessing a list of experience. It is called when
         ``should_post_process_episode`` is True.
-        It can be used to create a number of useful features such as 'hindsight relabeling'
-        of a trajectory etc.
+        By default, it returns the input unmodified.
+        Users can customize this function in the derived class, to create a number of
+        useful features such as 'hindsight relabeling' of a trajectory etc.

         Args:
             experiences: a list of experience, containing the experience starting from the
                 initial time when ``should_post_process_episode`` is False to the step where
                 ``should_post_process_episode`` is True.
         """
-        return None
+        return experiences

     def _process_unroll_step(self, policy_step, action, time_step,
                              transformed_time_step, policy_state,
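These two hooks form the customization point this PR exposes: should_post_process_episode decides when the cached episode is post-processed, and post_process_episode transforms it. Below is a hypothetical subclass sketch; the hook names and signatures follow the diff above, but the stand-in base class, the StepType/Experience stubs, and the final-reward relabeling rule are invented for illustration and are not part of ALF.

from typing import Any, List


class StepType:
    """Stand-in for alf.data_structures.StepType."""
    LAST = 2


class Experience:
    """Stand-in for ALF's Experience; only carries a reward here."""

    def __init__(self, reward: float):
        self.reward = reward


class RLAlgorithmBase:
    """Toy base class mirroring the default hooks in the diff."""

    def should_post_process_episode(self, rollout_info: Any, step_type: int) -> bool:
        # Default from the diff: post-process at every time step.
        return True

    def post_process_episode(self, experiences: List[Experience]) -> List[Experience]:
        # Default from the diff: return the input unmodified.
        return experiences


class FinalRewardRelabelAlgorithm(RLAlgorithmBase):
    """Hypothetical subclass: only post-process at episode end, and relabel
    every step's reward with the final step's reward (a toy relabeling rule)."""

    def should_post_process_episode(self, rollout_info: Any, step_type: int) -> bool:
        return step_type == StepType.LAST

    def post_process_episode(self, experiences: List[Experience]) -> List[Experience]:
        final_reward = experiences[-1].reward
        for exp in experiences:
            exp.reward = final_reward
        return experiences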
@@ -633,6 +638,7 @@ def _process_unroll_step(self, policy_step, action, time_step,
                 alf.layers.to_float32(policy_state))
         effective_number_of_unroll_steps = 1
         if self._episodic_annotation:
+            assert not self.on_policy, "only support episodic annotation for off policy training"
             store_exp_time = 0
             # if last step, annotate
             rollout_info = policy_step.info
@@ -645,11 +651,10 @@ def _process_unroll_step(self, policy_step, action, time_step,
                     self._cached_exp)
                 effective_number_of_unroll_steps = len(annotated_exp_list)
                 # 2) observe
-                if not self.on_policy:
-                    t0 = time.time()
-                    for exp in annotated_exp_list:
-                        self.observe_for_replay(exp)
-                    store_exp_time = time.time() - t0
+                t0 = time.time()
+                for exp in annotated_exp_list:
+                    self.observe_for_replay(exp)
+                store_exp_time = time.time() - t0
                 # clean up the exp cache
                 self._cached_exp = []
         else:
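The two hunks above cache experiences while an episode is in progress and, once the episode is flagged for post-processing, annotate the cached list, write every annotated experience to the replay buffer, and report how many were stored via effective_number_of_unroll_steps. The following is a simplified, ALF-free sketch of that control flow; the function, its parameters, and the episode_done flag are hypothetical, and returning 0 for cached-but-not-yet-stored steps is an assumption rather than something the diff shows.

import time
from typing import Any, Callable, List


def process_unroll_step(exp: Any,
                        cached_exp: List[Any],
                        episode_done: bool,
                        post_process_episode: Callable[[List[Any]], List[Any]],
                        observe_for_replay: Callable[[Any], None]) -> int:
    """Returns the number of experiences stored during this unroll step."""
    cached_exp.append(exp)
    # Assumption: nothing is stored while the episode is still being cached.
    effective_number_of_unroll_steps = 0
    if episode_done:
        # 1) annotate the whole cached episode
        annotated_exp_list = post_process_episode(list(cached_exp))
        effective_number_of_unroll_steps = len(annotated_exp_list)
        # 2) observe: write every annotated experience to the replay buffer
        t0 = time.time()
        for e in annotated_exp_list:
            observe_for_replay(e)
        store_exp_time = time.time() - t0  # timing bookkeeping, as in the diff
        # 3) clean up the experience cache
        cached_exp.clear()
    return effective_number_of_unroll_steps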
@@ -903,8 +908,7 @@ def _train_iter_off_policy(self):
                 self.train()
             steps = 0
             for i in range(effective_unroll_steps):
-                steps += self.train_from_replay_buffer(effective_unroll_steps=1,
-                                                       update_global_counter=True)
+                steps += self.train_from_replay_buffer(update_global_counter=True)
             if unrolled:
                 with record_time("time/after_train_iter"):
                     self.after_train_iter(root_inputs, rollout_info)
Review comment: Maybe assert this in the __init__ function?
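For reference, the reviewer's suggestion would move the new assert from _process_unroll_step into __init__, so that a misconfiguration fails at construction time rather than mid-unroll. A minimal sketch of that placement, assuming on_policy is already known when __init__ runs (which may not hold for every algorithm); ExampleAlgorithm is a toy stand-in, not the actual PR code.

class ExampleAlgorithm:
    """Toy stand-in used only to illustrate the reviewer's suggestion."""

    def __init__(self, episodic_annotation: bool = False, on_policy: bool = False):
        self._episodic_annotation = episodic_annotation
        self.on_policy = on_policy
        if self._episodic_annotation:
            # Fail fast at construction time instead of during _process_unroll_step.
            assert not self.on_policy, (
                "only support episodic annotation for off policy training")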