Skip to content

Commit 94a50bf

Browse files
committed
Let user set effective_unroll_steps
1 parent 734dae8 commit 94a50bf

File tree

1 file changed

+15
-13
lines changed

1 file changed

+15
-13
lines changed

alf/algorithms/rl_algorithm.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import os
2020
import time
2121
import torch
22-
from typing import Callable, Optional
22+
from typing import Callable, List, Optional, Tuple
2323
from absl import logging
2424

2525
import alf
@@ -605,7 +605,7 @@ def _async_unroll(self, unroll_length: int):
605605
return experience, effective_unroll_iters
606606

607607
def post_process_experience(self, rollout_info, step_type: StepType,
608-
experiences: Experience):
608+
experiences: Experience) -> Tuple[List, int]:
609609
"""A function for postprocessing experience. By default, it returns the input
610610
experience unmodified. Users can customize this function in the derived
611611
class to achieve different effects. For example:
@@ -622,17 +622,22 @@ class to achieve different effects. For example:
622622
experiences: one step of experience.
623623
624624
Returns:
625-
A list of experiences. Users can customize this functions in the
626-
derived class to achieve different effects. For example:
627-
- return a list that contains only the input experience (default behavior).
628-
- return a list that contains a number of experiences. This can be useful
629-
for episode processing such as success episode labeling.
625+
- a list of experiences. Users can customize this functions in the
626+
derived class to achieve different effects. For example:
627+
* return a list that contains only the input experience (default behavior).
628+
* return a list that contains a number of experiences. This can be useful
629+
for episode processing such as success episode labeling.
630+
- an integer representing the effective number of unroll steps per env. The
631+
default value of 1, meaning the length of effective experience is 1
632+
after calling ``post_process_experience``, the same as the input length
633+
of experience.
630634
"""
631-
return [experiences]
635+
return [experiences], 1
632636

633637
def _process_unroll_step(self, policy_step, action, time_step,
634638
transformed_time_step, policy_state,
635-
experience_list, original_reward_list):
639+
experience_list,
640+
original_reward_list) -> Tuple[int, int]:
636641
self.observe_for_metrics(time_step.cpu())
637642
exp = make_experience(time_step.cpu(),
638643
alf.layers.to_float32(policy_step),
@@ -641,11 +646,8 @@ def _process_unroll_step(self, policy_step, action, time_step,
641646
store_exp_time = 0
642647
if not self.on_policy:
643648
# 1) post process
644-
post_processed_exp_list = self.post_process_experience(
649+
post_processed_exp_list, effective_unroll_steps = self.post_process_experience(
645650
policy_step.info, time_step.step_type, exp)
646-
effective_unroll_steps = sum(
647-
exp.step_type.shape[0]
648-
for exp in post_processed_exp_list) / exp.step_type.shape[0]
649651
# 2) observe
650652
t0 = time.time()
651653
for exp in post_processed_exp_list:

0 commit comments

Comments
 (0)