6 changes: 4 additions & 2 deletions alf/algorithms/algorithm.py
@@ -1469,6 +1469,7 @@ def train_from_replay_buffer(self, update_global_counter=False):
``True``, it will affect the counter only if
``config.update_counter_every_mini_batch=True``.
"""

config: TrainerConfig = self._config

# returns 0 if haven't started training yet, when ``_replay_buffer`` is
@@ -1493,21 +1494,22 @@ def _replay():
# ``_replay_buffer`` for training.
# TODO: If this function can be called asynchronously, and using
# prioritized replay, then make sure replay and train below is atomic.
effective_num_updates_per_train_iter = config.num_updates_per_train_iter
with record_time("time/replay"):
mini_batch_size = config.mini_batch_size
if mini_batch_size is None:
mini_batch_size = self._replay_buffer.num_environments
if config.whole_replay_buffer_training:
experience, batch_info = self._replay_buffer.gather_all(
ignore_earliest_frames=True)
num_updates = config.num_updates_per_train_iter
num_updates = effective_num_updates_per_train_iter
Contributor

why do you need to make this change?

Contributor Author (@Haichao-Zhang, May 23, 2025)

not necessary anymore. removed

else:
assert config.mini_batch_length is not None, (
"No mini_batch_length is specified for off-policy training"
)
experience, batch_info = self._replay_buffer.get_batch(
batch_size=(mini_batch_size *
config.num_updates_per_train_iter),
effective_num_updates_per_train_iter),
batch_length=config.mini_batch_length)
num_updates = 1
return experience, batch_info, num_updates, mini_batch_size
99 changes: 80 additions & 19 deletions alf/algorithms/rl_algorithm.py
@@ -19,7 +19,7 @@
import os
import time
import torch
from typing import Callable, Optional
from typing import Callable, List, Optional
from absl import logging

import alf
@@ -147,6 +147,7 @@ def __init__(self,
optimizer=None,
checkpoint=None,
is_eval: bool = False,
episodic_annotation: bool = False,
overwrite_policy_output=False,
debug_summaries=False,
name="RLAlgorithm"):
@@ -186,6 +187,10 @@ def __init__(self,
during deployment. In this case, the algorithm does not need to
create certain components such as value_network for ActorCriticAlgorithm,
critic_networks for SacAlgorithm.
episodic_annotation: episodic annotation is an operation that annotates an
episode after it has been collected, so that the annotated episode is what
gets observed by the replay buffer. If True, annotate each episode before it
is observed by the replay buffer. Otherwise, episodic annotation is not applied.
overwrite_policy_output (bool): if True, overwrite the policy output
with next_step.prev_action. This option can be used in some
cases such as data collection.
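As a rough illustration of the new flag documented above (not part of this PR; the subclass name is hypothetical), a derived algorithm could opt in simply by forwarding ``episodic_annotation=True`` to ``RLAlgorithm.__init__``:

```python
from alf.algorithms.rl_algorithm import RLAlgorithm


class AnnotatedReplayAlgorithm(RLAlgorithm):
    """Hypothetical sketch: store annotated episodes in the replay buffer."""

    def __init__(self, *args, **kwargs):
        # Turn on the episode post-processing path; every other constructor
        # argument is forwarded unchanged to RLAlgorithm.
        super().__init__(*args, episodic_annotation=True, **kwargs)
```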
@@ -203,6 +208,7 @@ def __init__(self,
debug_summaries=debug_summaries,
name=name)
self._is_eval = is_eval
self._episodic_annotation = episodic_annotation

self._env = env
self._observation_spec = observation_spec
@@ -235,7 +241,7 @@ def __init__(self,
self._current_time_step = None
self._current_policy_state = None
self._current_transform_state = None

self._cached_exp = [] # for lazy observation
if self._env is not None and not self.on_policy:
replay_buffer_length = adjust_replay_buffer_length(
config, self._num_earliest_frames_ignored)
@@ -598,19 +604,68 @@ def _async_unroll(self, unroll_length: int):

return experience

def should_post_process_episode(self, rollout_info, step_type: StepType):
"""A function that determines whether the ``post_process_episode`` function should
be applied to the current list of experiences.
Users can customize this function in the derived class.
By default, it returns True for all time steps. When this is combined with
``post_process_episode``, which simply returns the input unmodified (as the default
implementation in this class), it is a dummy version of episodic annotation with
logic equivalent to the case of ``episodic_annotation=False``.
"""
return True

def post_process_episode(self, experiences: List[Experience]):
"""A function for postprocessing a list of experience. It is called when
``should_post_process_episode`` is True.
By default, it returns the input unmodified.
Users can customize this function in the derived class, to create a number of
useful features such as 'hindsight relabeling' of a trajectory etc.

Args:
experiences: a list of experiences, containing the experience from the
initial time step when ``should_post_process_episode`` is False up to the step where
``should_post_process_episode`` is True.
"""
return experiences
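A minimal sketch of how these two hooks could be customized in a derived class (the class name and the identity ``_annotate`` helper are hypothetical, not code from this PR): it caches experience until an episode's LAST step and only then annotates and stores it, assuming all parallel environments finish together (e.g. a single environment), which matches the episode-end use case discussed in the review thread below.

```python
from typing import List

from alf.algorithms.rl_algorithm import RLAlgorithm
from alf.data_structures import Experience, StepType


class EpisodeAnnotatingAlgorithm(RLAlgorithm):
    """Hypothetical subclass; assumes ``episodic_annotation=True`` at construction."""

    def should_post_process_episode(self, rollout_info, step_type: StepType):
        # Post-process only at episode end. ``step_type`` covers the parallel
        # environments of this unroll step; this sketch assumes they all reach
        # LAST together (e.g. batch_size 1).
        return bool((step_type == StepType.LAST).all())

    def _annotate(self, exp: Experience) -> Experience:
        # Hypothetical per-step annotation. The identity keeps the sketch
        # runnable; a real implementation could relabel rewards, attach
        # episode-level statistics, or drop uninformative steps.
        return exp

    def post_process_episode(self, experiences: List[Experience]):
        # ``experiences`` holds every step cached since the last post-processing
        # point; whatever is returned here is what the replay buffer observes.
        return [self._annotate(exp) for exp in experiences]
```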

def _process_unroll_step(self, policy_step, action, time_step,
transformed_time_step, policy_state,
experience_list, original_reward_list):
self.observe_for_metrics(time_step.cpu())
exp = make_experience(time_step.cpu(),
alf.layers.to_float32(policy_step),
alf.layers.to_float32(policy_state))

store_exp_time = 0
if not self.on_policy:
t0 = time.time()
self.observe_for_replay(exp)
store_exp_time = time.time() - t0
effective_number_of_unroll_steps = 1
if self._episodic_annotation:
assert not self.on_policy, "only support episodic annotation for off policy training"
Contributor

Maybe assert this in the __init__ function?

store_exp_time = 0
# if last step, annotate
rollout_info = policy_step.info
self._cached_exp.append(exp)
if self.should_post_process_episode(rollout_info,
time_step.step_type):

# 1) process
annotated_exp_list = self.post_process_episode(
self._cached_exp)
effective_number_of_unroll_steps = len(annotated_exp_list)
# 2) observe
t0 = time.time()
for exp in annotated_exp_list:
self.observe_for_replay(exp)
store_exp_time = time.time() - t0
# clean up the exp cache
self._cached_exp = []
Contributor

This seems to assume that all envs end on the same step? What if some envs are LAST, some are MID? cached_exp will be cleared even for those with MID steps?

Even when doing this for an env with batch_size 1, this annotation mode will delay experience from being stored into the replay buffer.

Ok to submit the change as is, but may need to do two things:

  1. rename the feature to something like store_experience_on_episode_end, and document its behavior clearly in the docstr.
    experience relabel should be done when reading data out of replay buffer as in hindsight relabel.

  2. assert that batch_size is 1 when enabled.

Also, delaying train_step because of delayed experience storage can have unexpected side effects, e.g. if episodes are 100 steps long, and unroll once per train iter, then summary will only happen every 100 train iters. It will also shift the distribution of the data training sees due to the delay.

Overall I think doing this episode level relabeling at the DataTransformer stage, after reading from replay_buffer is perhaps a better way, and a cleaner way as well (less scattered code). That would require the replay buffer to keep track of episode begin and end, which I think it already does.

Contributor Author (@Haichao-Zhang, May 23, 2025)

> This seems to assume that all envs end on the same step? What if some envs are LAST, some are MID? cached_exp will be cleared even for those with MID steps?

There is no such assumption. It is totally up to the users to inject their own assumptions.
By default, the behavior is the same as before.
Sorry that the function names are a bit misleading and their role has been extended to handle the per-step case as well. Changed the function names and added more comments.

> Even when doing this for an env with batch_size 1, this annotation mode will delay experience from being stored into the replay buffer.

No it won't. By default, the behavior is the same as before.

> Ok to submit the change as is, but may need to do two things:
>
>   1. rename the feature to something like store_experience_on_episode_end, and document its behavior clearly in the docstr.

The suggested name is not appropriate.

> experience relabel should be done when reading data out of replay buffer as in hindsight relabel.

Different use cases. This is an alternative interface that can support more than pure relabeling (e.g. excluding data), which is not directly supported by the replay buffer hindsight relabel.

>   2. assert that batch_size is 1 when enabled.

There is no such assumption in the current PR. It is up to the user.

> Also, delaying train_step because of delayed experience storage can have unexpected side effects, e.g. if episodes are 100 steps long, and unroll once per train iter, then summary will only happen every 100 train iters. It will also shift the distribution of the data training sees due to the delay.

There is no delay.

> Overall I think doing this episode level relabeling at the DataTransformer stage, after reading from replay_buffer is perhaps a better way, and a cleaner way as well (less scattered code). That would require the replay buffer to keep track of episode begin and end, which I think it already does.

As explained, it is more than pure relabeling.

else:
# effective unroll steps is 0 if this is not yet a post_process_episode time point
effective_number_of_unroll_steps = 0
else:
store_exp_time = 0
if not self.on_policy:
t0 = time.time()
self.observe_for_replay(exp)
store_exp_time = time.time() - t0

exp_for_training = Experience(
time_step=transformed_time_step,
Expand All @@ -620,7 +675,7 @@ def _process_unroll_step(self, policy_step, action, time_step,

experience_list.append(exp_for_training)
original_reward_list.append(time_step.reward)
return store_exp_time
return store_exp_time, effective_number_of_unroll_steps

def reset_state(self):
"""Reset the state of the algorithm.
@@ -665,6 +720,7 @@ def _sync_unroll(self, unroll_length: int):
policy_step_time = 0.
env_step_time = 0.
store_exp_time = 0.
effective_unroll_steps = 0
for _ in range(unroll_length):
policy_state = common.reset_state_if_necessary(
policy_state, initial_state, time_step.is_first())
@@ -693,9 +749,10 @@ def _sync_unroll(self, unroll_length: int):
if self._overwrite_policy_output:
policy_step = policy_step._replace(
output=next_time_step.prev_action)
store_exp_time += self._process_unroll_step(
store_exp_time_i, effective_unroll_steps = self._process_unroll_step(
policy_step, action, time_step, transformed_time_step,
policy_state, experience_list, original_reward_list)
store_exp_time += store_exp_time_i

time_step = next_time_step
policy_state = policy_step.state
@@ -723,7 +780,7 @@ def _sync_unroll(self, unroll_length: int):
self._current_policy_state = common.detach(policy_state)
self._current_transform_state = common.detach(trans_state)

return experience
return experience, effective_unroll_steps

def train_iter(self):
"""Perform one iteration of training.
@@ -804,6 +861,7 @@ def _unroll_iter_off_policy(self):
unrolled = False
root_inputs = None
rollout_info = None
effective_unroll_steps = 0
if (alf.summary.get_global_counter()
>= self._rl_train_after_update_steps
and (unroll_length > 0 or config.unroll_length == 0) and
@@ -822,19 +880,21 @@
# need to remember whether summary has been written between
# two unrolls.
with self._ensure_rollout_summary:
experience = self.unroll(unroll_length)
experience, effective_unroll_steps = self.unroll(
unroll_length)
if experience:
self.summarize_rollout(experience)
self.summarize_metrics()
rollout_info = experience.rollout_info
if config.use_root_inputs_for_after_train_iter:
root_inputs = experience.time_step
del experience
return unrolled, root_inputs, rollout_info
return unrolled, root_inputs, rollout_info, effective_unroll_steps

def _train_iter_off_policy(self):
"""User may override this for their own training procedure."""
unrolled, root_inputs, rollout_info = self._unroll_iter_off_policy()
unrolled, root_inputs, rollout_info, effective_unroll_steps = self._unroll_iter_off_policy(
)

# replay buffer may not have been created for two different reasons:
# 1. in online RL training (``has_offline`` is False), unroll is not
@@ -846,11 +906,12 @@ def _train_iter_off_policy(self):
return 0

self.train()
steps = self.train_from_replay_buffer(update_global_counter=True)

if unrolled:
with record_time("time/after_train_iter"):
self.after_train_iter(root_inputs, rollout_info)
steps = 0
for i in range(effective_unroll_steps):
Contributor

unroll_steps is the wrong name? It should be called unroll_iterations to indicate training iterations, not env steps?

also rename effective_number_of_unroll_steps to be effective_unroll_iters to be consistent. (i.e. remove "number_of_")

Contributor Author

Thanks for the comments. Changed.

steps += self.train_from_replay_buffer(update_global_counter=True)
if unrolled:
with record_time("time/after_train_iter"):
self.after_train_iter(root_inputs, rollout_info)

# For now, we only return the steps of the primary algorithm's training
return steps