@@ -27,8 +27,7 @@
 from alf.algorithms.config import TrainerConfig
 from alf.algorithms.off_policy_algorithm import OffPolicyAlgorithm
 from alf.algorithms.one_step_loss import OneStepTDLoss
-from alf.data_structures import TimeStep, LossInfo, namedtuple, \
-    BasicRLInfo
+from alf.data_structures import TimeStep, LossInfo, namedtuple
 from alf.data_structures import AlgStep, StepType
 from alf.nest import nest
 import alf.nest.utils as nest_utils
@@ -845,9 +844,8 @@ def _select_q_value(self, action, q_values):
         return q_values.gather(2, action).squeeze(2)

     def _critic_train_step(self, observation, target_observation,
-                            state: SacCriticState,
-                            rollout_info: SacInfo | BasicRLInfo, action,
-                            action_distribution):
+                            state: SacCriticState, rollout_info: SacInfo,
+                            action, action_distribution):

         critics, critics_state = self._compute_critics(
             self._critic_networks,
@@ -899,7 +897,7 @@ def _alpha_train_step(self, log_pi):
         return sum(nest.flatten(alpha_loss))

     def train_step(self, inputs: TimeStep, state: SacState,
-                   rollout_info: SacInfo | BasicRLInfo):
+                   rollout_info: SacInfo):
        assert not self._is_eval
        self._training_started = True
        if self._target_repr_alg is not None:
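
The net effect of these hunks is to drop the `SacInfo | BasicRLInfo` union annotations (and the now-unused `BasicRLInfo` import) so that `_critic_train_step` and `train_step` accept `SacInfo` alone. A minimal, self-contained sketch of that pattern, using hypothetical `FooInfo`/`BarInfo` stand-ins rather than the real alf data structures:

from typing import NamedTuple

class FooInfo(NamedTuple):
    # Hypothetical stand-in for SacInfo.
    reward: float

class BarInfo(NamedTuple):
    # Hypothetical stand-in for the removed BasicRLInfo.
    reward: float

class ToyAlgorithm:
    # Before the patch the annotation was the union `FooInfo | BarInfo`;
    # once callers only ever pass FooInfo, the union (and the BarInfo
    # import) can be dropped, which is what this diff does for SacInfo.
    def train_step(self, rollout_info: FooInfo) -> float:
        return rollout_info.reward

print(ToyAlgorithm().train_step(FooInfo(reward=1.0)))  # prints 1.0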