  class ContinuousQVAC(nn.Module):
  """
  Overview:
- The neural network and computation graph of algorithms related to Actor-Critic that have both Q-value and V-value critic, such as \
- IQL. This model now supports continuous and hybrid action space. The ContinuousQVAC is composed of \
- four parts: ``actor_encoder``, ``critic_encoder``, ``actor_head`` and ``critic_head``. Encoders are used to \
- extract the feature from various observation. Heads are used to predict corresponding Q-value and V-value or action logit. \
+ The neural network and computation graph of algorithms related to Actor-Critic that have both Q-value and \
+ V-value critic, such as IQL. This model now supports continuous and hybrid action space. The ContinuousQVAC is \
+ composed of four parts: ``actor_encoder``, ``critic_encoder``, ``actor_head`` and ``critic_head``. Encoders \
+ are used to extract the feature. Heads are used to predict corresponding value or action logit.
  In high-dimensional observation space like 2D image, we often use a shared encoder for both ``actor_encoder`` \
  and ``critic_encoder``. In low-dimensional observation space like 1D vector, we often use different encoders.
  Interfaces:
@@ -34,7 +34,7 @@ def __init__(
  actor_head_layer_num: int = 1,
  critic_head_hidden_size: int = 64,
  critic_head_layer_num: int = 1,
- activation: Optional[nn.Module] = nn.SiLU(), #nn.ReLU(),
+ activation: Optional[nn.Module] = nn.SiLU(),
  norm_type: Optional[str] = None,
  encoder_hidden_size_list: Optional[SequenceType] = None,
  share_encoder: Optional[bool] = False,
@@ -319,7 +319,7 @@ def compute_critic(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Ten
  - logit (:obj:`torch.Tensor`): Discrete action logit, only in hybrid action_space.
  - action_args (:obj:`torch.Tensor`): Continuous action arguments, only in hybrid action_space.
  Returns:
- - outputs (:obj:`Dict[str, torch.Tensor]`): The output dict of QVAC's forward computation graph for critic, \
+ - outputs (:obj:`Dict[str, torch.Tensor]`): The output of QVAC's forward computation graph for critic, \
  including ``q_value``.
  ReturnKeys:
  - q_value (:obj:`torch.Tensor`): Q value tensor with same size as batch size.
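
The Overview docstring in this diff describes separate or shared encoders feeding an actor head and Q/V critic heads. The following is a minimal sketch of that layout, not the ContinuousQVAC implementation. The class `TinyQVAC`, the attribute names `critic_q_head` and `critic_v_head`, the `v_value` output key, and the choice to concatenate the action onto the encoded feature are all assumptions made for illustration. It also covers only the continuous action case.

```python
from typing import Dict

import torch
import torch.nn as nn


class TinyQVAC(nn.Module):
    """Hypothetical, simplified stand-in for the model described in the docstring."""

    def __init__(self, obs_dim: int, action_dim: int, hidden: int = 64, share_encoder: bool = False) -> None:
        super().__init__()
        self.actor_encoder = nn.Sequential(nn.Linear(obs_dim, hidden), nn.SiLU())
        if share_encoder:
            # Shared case: both names point at the same module, so weights are shared.
            self.critic_encoder = self.actor_encoder
        else:
            self.critic_encoder = nn.Sequential(nn.Linear(obs_dim, hidden), nn.SiLU())
        self.actor_head = nn.Linear(hidden, action_dim)
        # Assumed layout: the Q head sees feature + action, the V head sees the feature only.
        self.critic_q_head = nn.Linear(hidden + action_dim, 1)
        self.critic_v_head = nn.Linear(hidden, 1)

    def compute_actor(self, obs: torch.Tensor) -> Dict[str, torch.Tensor]:
        return {'logit': self.actor_head(self.actor_encoder(obs))}

    def compute_critic(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        feat = self.critic_encoder(inputs['obs'])
        q_in = torch.cat([feat, inputs['action']], dim=-1)
        # squeeze(-1) drops the trailing unit dimension, leaving one value per batch item.
        return {
            'q_value': self.critic_q_head(q_in).squeeze(-1),
            'v_value': self.critic_v_head(feat).squeeze(-1),  # key name is an assumption
        }
```

With a batch of 4 observations, for example `TinyQVAC(obs_dim=8, action_dim=2).compute_critic({'obs': torch.randn(4, 8), 'action': torch.randn(4, 2)})`, both returned tensors have shape `(4,)`, which matches the "same size as batch size" wording in the ReturnKeys entry.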