Skip to content

Commit 9db5652

Browse files
committed
address comments
1 parent 2c7e7cf commit 9db5652

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

alf/algorithms/algorithm.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1369,6 +1369,11 @@ def train_step_offline(self, inputs, state, rollout_info, pre_train=False):
13691369
"""
13701370
try:
13711371
if isinstance(rollout_info, BasicRolloutInfo):
1372+
logging.log_first_n(
1373+
logging.WARNING,
1374+
"Detected offline buffer training without Agent wrapper. "
1375+
"For best compatibility, it is advised to use the Agent wrapper.",
1376+
n=1)
13721377
rollout_info = rollout_info.rl
13731378
return self.train_step(inputs, state, rollout_info)
13741379
except:

alf/algorithms/sac_algorithm.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@
2727
from alf.algorithms.config import TrainerConfig
2828
from alf.algorithms.off_policy_algorithm import OffPolicyAlgorithm
2929
from alf.algorithms.one_step_loss import OneStepTDLoss
30-
from alf.data_structures import TimeStep, LossInfo, namedtuple, \
31-
BasicRLInfo
30+
from alf.data_structures import TimeStep, LossInfo, namedtuple
3231
from alf.data_structures import AlgStep, StepType
3332
from alf.nest import nest
3433
import alf.nest.utils as nest_utils
@@ -845,9 +844,8 @@ def _select_q_value(self, action, q_values):
845844
return q_values.gather(2, action).squeeze(2)
846845

847846
def _critic_train_step(self, observation, target_observation,
848-
state: SacCriticState,
849-
rollout_info: SacInfo | BasicRLInfo, action,
850-
action_distribution):
847+
state: SacCriticState, rollout_info: SacInfo,
848+
action, action_distribution):
851849

852850
critics, critics_state = self._compute_critics(
853851
self._critic_networks,
@@ -899,7 +897,7 @@ def _alpha_train_step(self, log_pi):
899897
return sum(nest.flatten(alpha_loss))
900898

901899
def train_step(self, inputs: TimeStep, state: SacState,
902-
rollout_info: SacInfo | BasicRLInfo):
900+
rollout_info: SacInfo):
903901
assert not self._is_eval
904902
self._training_started = True
905903
if self._target_repr_alg is not None:

0 commit comments

Comments
 (0)