@@ -855,18 +855,22 @@ def reshape(d):
855
855
class DRQN (nn .Module ):
856
856
"""
857
857
Overview:
858
- The neural network structure and computation graph of DRQN (DQN + RNN = DRQN) algorithm, which is the most \
859
- common DQN variant for sequential data and paratially observable environment. The DRQN is composed of three \
860
- parts: ``encoder``, ``head`` and ``rnn``. The ``encoder`` is used to extract the feature from various \
861
- observation, the ``rnn`` is used to process the sequential observation and other data, and the ``head`` is \
862
- used to compute the Q value of each action dimension.
858
+ The DRQN (Deep Recurrent Q-Network) is a neural network model combining DQN with RNN to handle sequential
859
+ data and partially observable environments. It consists of three main components: ``encoder``, ``rnn``,
860
+ and ``head``.
861
+ - **Encoder**: Extracts features from various observation inputs.
862
+ - **RNN**: Processes sequential observations and other data.
863
+ - **Head**: Computes Q-values for each action dimension.
864
+
863
865
Interfaces:
864
866
``__init__``, ``forward``.
865
867
866
868
.. note::
867
- Current ``DRQN`` supports two types of encoder: ``FCEncoder`` and ``ConvEncoder``, two types of head: \
868
- ``DiscreteHead`` and ``DuelingHead``, three types of rnn: ``normal (LSTM with LayerNorm)``, ``pytorch`` and \
869
- ``gru``. You can customize your own encoder, rnn or head by inheriting this class.
869
+ The current implementation supports:
870
+ - Two encoder types: ``FCEncoder`` and ``ConvEncoder``.
871
+ - Two head types: ``DiscreteHead`` and ``DuelingHead``.
872
+ - Three RNN types: ``normal (LSTM with LayerNorm)``, ``pytorch`` (PyTorch's native LSTM), and ``gru``.
873
+ You can extend the model by customizing your own encoder, RNN, or head by inheriting this class.
870
874
"""
871
875
872
876
def __init__ (
@@ -884,43 +888,48 @@ def __init__(
884
888
) -> None :
885
889
"""
886
890
Overview:
887
- Initialize the DRQN Model according to the corresponding input arguments .
891
+ Initialize the DRQN model with the specified parameters.
888
892
Arguments:
889
- - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape, such as 8 or [4, 84, 84].
890
- - action_shape (:obj:`Union[int, SequenceType]`): Action space shape, such as 6 or [2, 3, 3].
891
- - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``, \
892
- the last element must match ``head_hidden_size``.
893
- - dueling (:obj:`Optional[bool]`): Whether choose ``DuelingHead`` or ``DiscreteHead (default)``.
894
- - head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of head network, defaults to None, \
895
- then it will be set to the last element of ``encoder_hidden_size_list``.
896
- - head_layer_num (:obj:`int`): The number of layers used in the head network to compute Q value output.
897
- - lstm_type (:obj:`Optional[str]`): The type of RNN module, now support ['normal', 'pytorch', 'gru'].
898
- - activation (:obj:`Optional[nn.Module]`): The type of activation function in networks \
899
- if ``None`` then default set it to ``nn.ReLU()``.
900
- - norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \
901
- ``ding.torch_utils.fc_block`` for more details. you can choose one of ['BN', 'IN', 'SyncBN', 'LN']
902
- - res_link (:obj:`bool`): Whether to enable the residual link, which is the skip connnection between \
903
- single frame data and the sequential data, defaults to False.
893
+ - obs_shape (:obj:`Union[int, SequenceType]`): Shape of the observation space, e.g., 8 or [4, 84, 84].
894
+ - action_shape (:obj:`Union[int, SequenceType]`): Shape of the action space, e.g., 6 or [2, 3, 3].
895
+ - encoder_hidden_size_list (:obj:`SequenceType`): List of hidden sizes for the encoder. The last element \
896
+ must match ``head_hidden_size``.
897
+ - dueling (:obj:`Optional[bool]`): Use ``DuelingHead`` if True, otherwise use ``DiscreteHead``.
898
+ - head_hidden_size (:obj:`Optional[int]`): Hidden size for the head network. Defaults to the last \
899
+ element of ``encoder_hidden_size_list`` if None.
900
+ - head_layer_num (:obj:`int`): Number of layers in the head network to compute Q-value outputs.
901
+ - lstm_type (:obj:`Optional[str]`): Type of RNN module. Supported types are ``normal``, ``pytorch``, \
902
+ and ``gru``.
903
+ - activation (:obj:`Optional[nn.Module]`): Activation function used in the network. Defaults to \
904
+ ``nn.ReLU()``.
905
+ - norm_type (:obj:`Optional[str]`): Normalization type for the networks. Supported types are: \
906
+ ['BN', 'IN', 'SyncBN', 'LN']. See ``ding.torch_utils.fc_block`` for more details.
907
+ - res_link (:obj:`bool`): Enables residual connections between single-frame data and sequential data. \
908
+ Defaults to False.
904
909
"""
905
910
super (DRQN , self ).__init__ ()
906
- # For compatibility: 1, (1, ), [4, 32, 32]
911
+ # Compatibility for obs_shape/action_shape: Handles scalar, tuple, or multi-dimensional inputs.
907
912
obs_shape , action_shape = squeeze (obs_shape ), squeeze (action_shape )
908
913
if head_hidden_size is None :
909
914
head_hidden_size = encoder_hidden_size_list [- 1 ]
910
- # FC Encoder
915
+
916
+ # Encoder: Determines the encoder type based on the observation shape.
911
917
if isinstance (obs_shape , int ) or len (obs_shape ) == 1 :
918
+ # FC Encoder
912
919
self .encoder = FCEncoder (obs_shape , encoder_hidden_size_list , activation = activation , norm_type = norm_type )
913
- # Conv Encoder
914
920
elif len (obs_shape ) == 3 :
921
+ # Conv Encoder
915
922
self .encoder = ConvEncoder (obs_shape , encoder_hidden_size_list , activation = activation , norm_type = norm_type )
916
923
else :
917
924
raise RuntimeError (
918
- "not support obs_shape for pre-defined encoder: {}, please customize your own DRQN" . format ( obs_shape )
925
+ f"Unsupported obs_shape for pre-defined encoder: { obs_shape } . Please customize your own DRQN."
919
926
)
920
- # LSTM Type
927
+
928
+ # RNN: Initializes the RNN module based on the specified lstm_type.
921
929
self .rnn = get_lstm (lstm_type , input_size = head_hidden_size , hidden_size = head_hidden_size )
922
930
self .res_link = res_link
923
- # Head Type
931
+
932
+ # Head: Determines the head type (Dueling or Discrete) and its configuration.
924
933
if dueling :
925
934
head_cls = DuelingHead
926
935
else :
@@ -943,31 +952,32 @@ def __init__(
943
952
def forward (self , inputs : Dict , inference : bool = False , saved_state_timesteps : Optional [list ] = None ) -> Dict :
944
953
"""
945
954
Overview:
946
- DRQN forward computation graph, input observation tensor to predict q_value.
955
+ Defines the forward pass of the DRQN model. Takes observation and previous RNN states as inputs \
956
+ and predicts Q-values.
947
957
Arguments:
948
- - inputs (:obj:`torch.Tensor `): The dict of input data, including observation and previous rnn state.
949
- - inference: (:obj:' bool' ): Whether to enable inference forward mode, if True, we unroll the one timestep \
950
- transition, otherwise, we unroll the eentire sequence transitions .
951
- - saved_state_timesteps: (:obj:' Optional[list]' ): When inference is False, we unroll the sequence \
952
- transitions, then we would use this list to indicate how to save and return hidden state .
958
+ - inputs (:obj:`Dict`): Input data dictionary containing observation and previous RNN state.
959
+ - inference (:obj:`bool`): If True, unrolls one timestep (used during evaluation). If False, unrolls \
960
+ the entire sequence (used during training).
961
+ - saved_state_timesteps (:obj:`Optional[list]`): When inference is False, specifies the timesteps \
962
+ whose hidden states are saved and returned.
953
963
ArgumentsKeys:
954
- - obs (:obj:`torch.Tensor`): The raw observation tensor.
955
- - prev_state (:obj:`list`): The previous rnn state tensor, whose structure depends on ``lstm_type``.
964
+ - obs (:obj:`torch.Tensor`): Raw observation tensor.
965
+ - prev_state (:obj:`list`): Previous RNN state tensor, structure depends on ``lstm_type``.
956
966
Returns:
957
967
- outputs (:obj:`Dict`): The output of DRQN's forward, including logit (q_value) and next state.
958
968
ReturnsKeys:
959
- - logit (:obj:`torch.Tensor`): Discrete Q-value output of each possible action dimension.
960
- - next_state (:obj:`list`): The next rnn state tensor, whose structure depends on ``lstm_type`` .
969
+ - logit (:obj:`torch.Tensor`): Discrete Q-value output for each action dimension.
970
+ - next_state (:obj:`list`): Next RNN state tensor, structure depends on ``lstm_type``.
961
971
Shapes:
962
- - obs (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape``
963
- - logit (:obj:`torch.Tensor`): :math:`(B, M)`, where B is batch size and M is ``action_shape``
972
+ - obs (:obj:`torch.Tensor`): :math:`(B, N)` where B is batch size and N is ``obs_shape``.
973
+ - logit (:obj:`torch.Tensor`): :math:`(B, M)` where B is batch size and M is ``action_shape``.
964
974
Examples:
965
- >>> # Init input's Keys:
975
+ >>> # Initialize input keys
966
976
>>> prev_state = [[torch.randn(1, 1, 64) for __ in range(2)] for _ in range(4)] # B=4
967
977
>>> obs = torch.randn(4,64)
968
978
>>> model = DRQN(64, 64) # arguments: 'obs_shape' and 'action_shape'
969
979
>>> outputs = model({'obs': obs, 'prev_state': prev_state}, inference=True)
970
- >>> # Check outputs's Keys
980
+ >>> # Validate output keys and shapes
971
981
>>> assert isinstance(outputs, dict)
972
982
>>> assert outputs['logit'].shape == (4, 64)
973
983
>>> assert len(outputs['next_state']) == 4
@@ -976,9 +986,9 @@ def forward(self, inputs: Dict, inference: bool = False, saved_state_timesteps:
976
986
"""
977
987
978
988
x , prev_state = inputs ['obs' ], inputs ['prev_state' ]
979
- # for both inference and other cases, the network structure is encoder -> rnn network -> head
980
- # the difference is inference take the data with seq_len=1 (or T = 1)
981
- # NOTE(rjy): in most situations, set inference=True when evaluate and inference=False when training
989
+ # Forward pass: Encoder -> RNN -> Head
990
+ # In most situations, set inference=True when evaluating and inference=False when training.
991
+ # Inference mode: Processes one timestep (seq_len=1).
982
992
if inference :
983
993
x = self .encoder (x )
984
994
if self .res_link :
@@ -992,27 +1002,28 @@ def forward(self, inputs: Dict, inference: bool = False, saved_state_timesteps:
992
1002
x = self .head (x )
993
1003
x ['next_state' ] = next_state
994
1004
return x
1005
+ # Training mode: Processes the entire sequence.
995
1006
else :
996
1007
# In order to better explain why rnn needs saved_state and which states need to be stored,
997
1008
# let's take r2d2 as an example
998
1009
# in r2d2,
999
1010
# 1) data['burnin_nstep_obs'] = data['obs'][:bs + self._nstep]
1000
1011
# 2) data['main_obs'] = data['obs'][bs:-self._nstep]
1001
1012
# 3) data['target_obs'] = data['obs'][bs + self._nstep:]
1002
- # NOTE(rjy): (T, B, N) or (T, B, C, H, W)
1003
- assert len (x .shape ) in [3 , 5 ], x .shape
1013
+ assert len (x .shape ) in [3 , 5 ], f"Expected shape (T, B, N) or (T, B, C, H, W), got { x .shape } "
1004
1014
x = parallel_wrapper (self .encoder )(x ) # (T, B, N)
1005
1015
if self .res_link :
1006
1016
a = x
1007
- # NOTE(rjy) lstm_embedding stores all hidden_state
1017
+ # lstm_embedding stores all hidden_state
1008
1018
lstm_embedding = []
1009
1019
# TODO(nyz) how to deal with hidden_size key-value
1010
1020
hidden_state_list = []
1021
+
1011
1022
if saved_state_timesteps is not None :
1012
1023
saved_state = []
1013
- for t in range (x .shape [0 ]): # T timesteps
1014
- # NOTE(rjy) use x[t:t+1] but not x[t] can keep original dimension
1015
- output , prev_state = self .rnn (x [t :t + 1 ], prev_state ) # output: (1,B, head_hidden_size )
1024
+ for t in range (x .shape [0 ]): # Iterate over timesteps (T).
1025
+ # Slice with x[t:t+1] rather than indexing x[t] so that the original dimension is kept.
1026
+ output , prev_state = self .rnn (x [t :t + 1 ], prev_state ) # RNN step output: (1, B, hidden_size )
1016
1027
if saved_state_timesteps is not None and t + 1 in saved_state_timesteps :
1017
1028
saved_state .append (prev_state )
1018
1029
lstm_embedding .append (output )
@@ -1023,7 +1034,7 @@ def forward(self, inputs: Dict, inference: bool = False, saved_state_timesteps:
1023
1034
if self .res_link :
1024
1035
x = x + a
1025
1036
x = parallel_wrapper (self .head )(x ) # (T, B, action_shape)
1026
- # NOTE(rjy): x['next_state'] is the hidden state of the last timestep inputted to lstm
1037
+ # x['next_state'] is the hidden state of the last timestep fed into the LSTM
1027
1038
# the last timestep state including the hidden state (h) and the cell state (c)
1028
1039
# shape: {list: B{dict: 2{Tensor:(1, 1, head_hidden_size)}}}
1029
1040
x ['next_state' ] = prev_state
0 commit comments