
Commit 64efcb3

polish(pu): delete unused enable_fast_timestep argument (#855)
* polish(pu): delete unused enable_fast_timestep argument
* polish(pu): delete unused empty lines
* polish(pu): delete unused empty lines
* style(pu): polish comment's format
* style(pu): polish comment's format
1 parent 3292384 commit 64efcb3


10 files changed, +86 −87 lines changed


ding/model/template/collaq.py

Lines changed: 8 additions & 15 deletions
@@ -411,27 +411,20 @@ def forward(self, data: dict, single_step: bool = True) -> dict:
         agent_alone_state = agent_alone_state.reshape(T, -1, *agent_alone_state.shape[3:])
         agent_alone_padding_state = agent_alone_padding_state.reshape(T, -1, *agent_alone_padding_state.shape[3:])

-        colla_output = self._q_network(
-            {
-                'obs': agent_state,
-                'prev_state': colla_prev_state,
-                'enable_fast_timestep': True
-            }
-        )
+        colla_output = self._q_network({
+            'obs': agent_state,
+            'prev_state': colla_prev_state,
+        })
         colla_alone_output = self._q_network(
             {
                 'obs': agent_alone_padding_state,
                 'prev_state': colla_alone_prev_state,
-                'enable_fast_timestep': True
-            }
-        )
-        alone_output = self._q_alone_network(
-            {
-                'obs': agent_alone_state,
-                'prev_state': alone_prev_state,
-                'enable_fast_timestep': True
             }
         )
+        alone_output = self._q_alone_network({
+            'obs': agent_alone_state,
+            'prev_state': alone_prev_state,
+        })

         agent_alone_q, alone_next_state = alone_output['logit'], alone_output['next_state']
         agent_colla_alone_q, colla_alone_next_state = colla_alone_output['logit'], colla_alone_output['next_state']
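For reference, ``DRQN.forward`` (see ding/model/template/q_learning.py below) only reads the 'obs' and 'prev_state' keys of its input dict, so the removed 'enable_fast_timestep' entry was never consumed. A minimal sketch of the call pattern after this change (illustrative only; the tensor sizes and the zero-initialized prev_state are assumptions, not part of the commit):

import torch
from ding.model.template.q_learning import DRQN

T, B, obs_dim, act_dim = 8, 4, 32, 6
q_network = DRQN(obs_shape=obs_dim, action_shape=act_dim)
obs = torch.randn(T, B, obs_dim)  # sequence input: (T, B, N)
# prev_state=None lets the wrapped LSTM start from zero hidden states (assumed default behaviour)
output = q_network({'obs': obs, 'prev_state': None})
print(output['logit'].shape)      # expected: (T, B, act_dim)
print(len(output['next_state']))  # expected: B per-sample states, structure depends on lstm_type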

ding/model/template/coma.py

Lines changed: 1 addition & 1 deletion
@@ -70,7 +70,7 @@ def forward(self, inputs: Dict) -> Dict:
         T, B, A = agent_state.shape[:3]
         agent_state = agent_state.reshape(T, -1, *agent_state.shape[3:])
         prev_state = reduce(lambda x, y: x + y, prev_state)
-        output = self.main({'obs': agent_state, 'prev_state': prev_state, 'enable_fast_timestep': True})
+        output = self.main({'obs': agent_state, 'prev_state': prev_state})
         logit, next_state = output['logit'], output['next_state']
         next_state, _ = list_split(next_state, step=A)
         logit = logit.reshape(T, B, A, -1)

ding/model/template/q_learning.py

Lines changed: 65 additions & 54 deletions
@@ -855,18 +855,22 @@ def reshape(d):
 class DRQN(nn.Module):
     """
     Overview:
-        The neural network structure and computation graph of DRQN (DQN + RNN = DRQN) algorithm, which is the most \
-        common DQN variant for sequential data and paratially observable environment. The DRQN is composed of three \
-        parts: ``encoder``, ``head`` and ``rnn``. The ``encoder`` is used to extract the feature from various \
-        observation, the ``rnn`` is used to process the sequential observation and other data, and the ``head`` is \
-        used to compute the Q value of each action dimension.
+        The DRQN (Deep Recurrent Q-Network) is a neural network model combining DQN with RNN to handle sequential
+        data and partially observable environments. It consists of three main components: ``encoder``, ``rnn``,
+        and ``head``.
+            - **Encoder**: Extracts features from various observation inputs.
+            - **RNN**: Processes sequential observations and other data.
+            - **Head**: Computes Q-values for each action dimension.
+
     Interfaces:
         ``__init__``, ``forward``.

     .. note::
-        Current ``DRQN`` supports two types of encoder: ``FCEncoder`` and ``ConvEncoder``, two types of head: \
-        ``DiscreteHead`` and ``DuelingHead``, three types of rnn: ``normal (LSTM with LayerNorm)``, ``pytorch`` and \
-        ``gru``. You can customize your own encoder, rnn or head by inheriting this class.
+        The current implementation supports:
+            - Two encoder types: ``FCEncoder`` and ``ConvEncoder``.
+            - Two head types: ``DiscreteHead`` and ``DuelingHead``.
+            - Three RNN types: ``normal (LSTM with LayerNorm)``, ``pytorch`` (PyTorch's native LSTM), and ``gru``.
+        You can extend the model by customizing your own encoder, RNN, or head by inheriting this class.
     """

     def __init__(
@@ -884,43 +888,48 @@ def __init__(
     ) -> None:
         """
         Overview:
-            Initialize the DRQN Model according to the corresponding input arguments.
+            Initialize the DRQN model with specified parameters.
         Arguments:
-            - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape, such as 8 or [4, 84, 84].
-            - action_shape (:obj:`Union[int, SequenceType]`): Action space shape, such as 6 or [2, 3, 3].
-            - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``, \
-                the last element must match ``head_hidden_size``.
-            - dueling (:obj:`Optional[bool]`): Whether choose ``DuelingHead`` or ``DiscreteHead (default)``.
-            - head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of head network, defaults to None, \
-                then it will be set to the last element of ``encoder_hidden_size_list``.
-            - head_layer_num (:obj:`int`): The number of layers used in the head network to compute Q value output.
-            - lstm_type (:obj:`Optional[str]`): The type of RNN module, now support ['normal', 'pytorch', 'gru'].
-            - activation (:obj:`Optional[nn.Module]`): The type of activation function in networks \
-                if ``None`` then default set it to ``nn.ReLU()``.
-            - norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \
-                ``ding.torch_utils.fc_block`` for more details. you can choose one of ['BN', 'IN', 'SyncBN', 'LN']
-            - res_link (:obj:`bool`): Whether to enable the residual link, which is the skip connnection between \
-                single frame data and the sequential data, defaults to False.
+            - obs_shape (:obj:`Union[int, SequenceType]`): Shape of the observation space, e.g., 8 or [4, 84, 84].
+            - action_shape (:obj:`Union[int, SequenceType]`): Shape of the action space, e.g., 6 or [2, 3, 3].
+            - encoder_hidden_size_list (:obj:`SequenceType`): List of hidden sizes for the encoder. The last element \
+                must match ``head_hidden_size``.
+            - dueling (:obj:`Optional[bool]`): Use ``DuelingHead`` if True, otherwise use ``DiscreteHead``.
+            - head_hidden_size (:obj:`Optional[int]`): Hidden size for the head network. Defaults to the last \
+                element of ``encoder_hidden_size_list`` if None.
+            - head_layer_num (:obj:`int`): Number of layers in the head network to compute Q-value outputs.
+            - lstm_type (:obj:`Optional[str]`): Type of RNN module. Supported types are ``normal``, ``pytorch``, \
+                and ``gru``.
+            - activation (:obj:`Optional[nn.Module]`): Activation function used in the network. Defaults to \
+                ``nn.ReLU()``.
+            - norm_type (:obj:`Optional[str]`): Normalization type for the networks. Supported types are: \
+                ['BN', 'IN', 'SyncBN', 'LN']. See ``ding.torch_utils.fc_block`` for more details.
+            - res_link (:obj:`bool`): Enables residual connections between single-frame data and sequential data. \
+                Defaults to False.
         """
         super(DRQN, self).__init__()
-        # For compatibility: 1, (1, ), [4, 32, 32]
+        # Compatibility for obs_shape/action_shape: Handles scalar, tuple, or multi-dimensional inputs.
         obs_shape, action_shape = squeeze(obs_shape), squeeze(action_shape)
         if head_hidden_size is None:
             head_hidden_size = encoder_hidden_size_list[-1]
-        # FC Encoder
+
+        # Encoder: Determines the encoder type based on the observation shape.
         if isinstance(obs_shape, int) or len(obs_shape) == 1:
+            # FC Encoder
             self.encoder = FCEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
-        # Conv Encoder
         elif len(obs_shape) == 3:
+            # Conv Encoder
             self.encoder = ConvEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
         else:
             raise RuntimeError(
-                "not support obs_shape for pre-defined encoder: {}, please customize your own DRQN".format(obs_shape)
+                f"Unsupported obs_shape for pre-defined encoder: {obs_shape}. Please customize your own DRQN."
             )
-        # LSTM Type
+
+        # RNN: Initializes the RNN module based on the specified lstm_type.
         self.rnn = get_lstm(lstm_type, input_size=head_hidden_size, hidden_size=head_hidden_size)
         self.res_link = res_link
-        # Head Type
+
+        # Head: Determines the head type (Dueling or Discrete) and its configuration.
         if dueling:
             head_cls = DuelingHead
         else:
@@ -943,31 +952,32 @@ def __init__(
     def forward(self, inputs: Dict, inference: bool = False, saved_state_timesteps: Optional[list] = None) -> Dict:
         """
         Overview:
-            DRQN forward computation graph, input observation tensor to predict q_value.
+            Defines the forward pass of the DRQN model. Takes observation and previous RNN states as inputs \
+            and predicts Q-values.
         Arguments:
-            - inputs (:obj:`torch.Tensor`): The dict of input data, including observation and previous rnn state.
-            - inference: (:obj:'bool'): Whether to enable inference forward mode, if True, we unroll the one timestep \
-                transition, otherwise, we unroll the eentire sequence transitions.
-            - saved_state_timesteps: (:obj:'Optional[list]'): When inference is False, we unroll the sequence \
-                transitions, then we would use this list to indicate how to save and return hidden state.
+            - inputs (:obj:`Dict`): Input data dictionary containing observation and previous RNN state.
+            - inference (:obj:`bool`): If True, unrolls one timestep (used during evaluation). If False, unrolls \
+                the entire sequence (used during training).
+            - saved_state_timesteps (:obj:`Optional[list]`): When inference is False, specifies the timesteps \
+                whose hidden states are saved and returned.
         ArgumentsKeys:
-            - obs (:obj:`torch.Tensor`): The raw observation tensor.
-            - prev_state (:obj:`list`): The previous rnn state tensor, whose structure depends on ``lstm_type``.
+            - obs (:obj:`torch.Tensor`): Raw observation tensor.
+            - prev_state (:obj:`list`): Previous RNN state tensor, structure depends on ``lstm_type``.
         Returns:
             - outputs (:obj:`Dict`): The output of DRQN's forward, including logit (q_value) and next state.
         ReturnsKeys:
-            - logit (:obj:`torch.Tensor`): Discrete Q-value output of each possible action dimension.
-            - next_state (:obj:`list`): The next rnn state tensor, whose structure depends on ``lstm_type``.
+            - logit (:obj:`torch.Tensor`): Discrete Q-value output for each action dimension.
+            - next_state (:obj:`list`): Next RNN state tensor.
         Shapes:
-            - obs (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape``
-            - logit (:obj:`torch.Tensor`): :math:`(B, M)`, where B is batch size and M is ``action_shape``
+            - obs (:obj:`torch.Tensor`): :math:`(B, N)` where B is batch size and N is ``obs_shape``.
+            - logit (:obj:`torch.Tensor`): :math:`(B, M)` where B is batch size and M is ``action_shape``.
         Examples:
-            >>> # Init input's Keys:
+            >>> # Initialize input keys
            >>> prev_state = [[torch.randn(1, 1, 64) for __ in range(2)] for _ in range(4)] # B=4
             >>> obs = torch.randn(4,64)
             >>> model = DRQN(64, 64) # arguments: 'obs_shape' and 'action_shape'
             >>> outputs = model({'obs': inputs, 'prev_state': prev_state}, inference=True)
-            >>> # Check outputs's Keys
+            >>> # Validate output keys and shapes
             >>> assert isinstance(outputs, dict)
             >>> assert outputs['logit'].shape == (4, 64)
             >>> assert len(outputs['next_state']) == 4
@@ -976,9 +986,9 @@ def forward(self, inputs: Dict, inference: bool = False, saved_state_timesteps:
         """

         x, prev_state = inputs['obs'], inputs['prev_state']
-        # for both inference and other cases, the network structure is encoder -> rnn network -> head
-        # the difference is inference take the data with seq_len=1 (or T = 1)
-        # NOTE(rjy): in most situations, set inference=True when evaluate and inference=False when training
+        # Forward pass: Encoder -> RNN -> Head
+        # in most situations, set inference=True when evaluate and inference=False when training
+        # Inference mode: Processes one timestep (seq_len=1).
         if inference:
             x = self.encoder(x)
             if self.res_link:
@@ -992,27 +1002,28 @@ def forward(self, inputs: Dict, inference: bool = False, saved_state_timesteps:
             x = self.head(x)
             x['next_state'] = next_state
             return x
+        # Training mode: Processes the entire sequence.
         else:
             # In order to better explain why rnn needs saved_state and which states need to be stored,
             # let's take r2d2 as an example
             # in r2d2,
             # 1) data['burnin_nstep_obs'] = data['obs'][:bs + self._nstep]
             # 2) data['main_obs'] = data['obs'][bs:-self._nstep]
             # 3) data['target_obs'] = data['obs'][bs + self._nstep:]
-            # NOTE(rjy): (T, B, N) or (T, B, C, H, W)
-            assert len(x.shape) in [3, 5], x.shape
+            assert len(x.shape) in [3, 5], f"Expected shape (T, B, N) or (T, B, C, H, W), got {x.shape}"
             x = parallel_wrapper(self.encoder)(x)  # (T, B, N)
             if self.res_link:
                 a = x
-            # NOTE(rjy) lstm_embedding stores all hidden_state
+            # lstm_embedding stores all hidden_state
             lstm_embedding = []
             # TODO(nyz) how to deal with hidden_size key-value
             hidden_state_list = []
+
             if saved_state_timesteps is not None:
                 saved_state = []
-            for t in range(x.shape[0]):  # T timesteps
-                # NOTE(rjy) use x[t:t+1] but not x[t] can keep original dimension
-                output, prev_state = self.rnn(x[t:t + 1], prev_state)  # output: (1,B, head_hidden_size)
+            for t in range(x.shape[0]):  # Iterate over timesteps (T).
+                # use x[t:t+1] but not x[t] can keep the original dimension
+                output, prev_state = self.rnn(x[t:t + 1], prev_state)  # RNN step output: (1, B, hidden_size)
                 if saved_state_timesteps is not None and t + 1 in saved_state_timesteps:
                     saved_state.append(prev_state)
                 lstm_embedding.append(output)
@@ -1023,7 +1034,7 @@ def forward(self, inputs: Dict, inference: bool = False, saved_state_timesteps:
             if self.res_link:
                 x = x + a
             x = parallel_wrapper(self.head)(x)  # (T, B, action_shape)
-            # NOTE(rjy): x['next_state'] is the hidden state of the last timestep inputted to lstm
+            # x['next_state'] is the hidden state of the last timestep inputted to lstm
             # the last timestep state including the hidden state (h) and the cell state (c)
             # shape: {list: B{dict: 2{Tensor:(1, 1, head_hidden_size}}}
             x['next_state'] = prev_state
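To complement the inference-mode example in the docstring above, here is a hedged sketch of the training-mode unroll with ``saved_state_timesteps``, following the r2d2-style comments (burn-in / main / target segments). The concrete sizes and the zero-initialized ``prev_state`` are illustrative assumptions, not part of this commit:

import torch
from ding.model.template.q_learning import DRQN

T, B, obs_dim, act_dim = 10, 4, 64, 6
burnin_step, nstep = 2, 3
model = DRQN(obs_shape=obs_dim, action_shape=act_dim)

seq_obs = torch.randn(T, B, obs_dim)  # (T, B, N)
outputs = model(
    {'obs': seq_obs, 'prev_state': None},  # None -> zero-initialized hidden states (assumed)
    inference=False,
    saved_state_timesteps=[burnin_step, burnin_step + nstep],
)
print(outputs['logit'].shape)       # expected: (T, B, act_dim)
print(len(outputs['saved_state']))  # expected: 2, the hidden states after timesteps 2 and 5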

ding/model/template/qmix.py

Lines changed: 1 addition & 1 deletion
@@ -227,7 +227,7 @@ def forward(self, data: dict, single_step: bool = True) -> dict:
         ), '{}-{}-{}-{}'.format([type(p) for p in prev_state], B, A, len(prev_state[0]))
         prev_state = reduce(lambda x, y: x + y, prev_state)
         agent_state = agent_state.reshape(T, -1, *agent_state.shape[3:])
-        output = self._q_network({'obs': agent_state, 'prev_state': prev_state, 'enable_fast_timestep': True})
+        output = self._q_network({'obs': agent_state, 'prev_state': prev_state})
         agent_q, next_state = output['logit'], output['next_state']
         next_state, _ = list_split(next_state, step=A)
         agent_q = agent_q.reshape(T, B, A, -1)

ding/model/template/qtran.py

Lines changed: 1 addition & 1 deletion
@@ -100,7 +100,7 @@ def forward(self, data: dict, single_step: bool = True) -> dict:
         ), '{}-{}-{}-{}'.format([type(p) for p in prev_state], B, A, len(prev_state[0]))
         prev_state = reduce(lambda x, y: x + y, prev_state)
         agent_state = agent_state.reshape(T, -1, *agent_state.shape[3:])
-        output = self._q_network({'obs': agent_state, 'prev_state': prev_state, 'enable_fast_timestep': True})
+        output = self._q_network({'obs': agent_state, 'prev_state': prev_state})
         agent_q, next_state = output['logit'], output['next_state']
         next_state, _ = list_split(next_state, step=A)
         agent_q = agent_q.reshape(T, B, A, -1)

ding/model/template/wqmix.py

Lines changed: 0 additions & 2 deletions
@@ -177,7 +177,6 @@ def forward(self, data: dict, single_step: bool = True, q_star: bool = False) ->
             {
                 'obs': agent_state,
                 'prev_state': prev_state,
-                'enable_fast_timestep': True
             }
         )  # here is the forward pass of the agent networks of Q_star
         agent_q, next_state = output['logit'], output['next_state']
@@ -223,7 +222,6 @@ def forward(self, data: dict, single_step: bool = True, q_star: bool = False) ->
             {
                 'obs': agent_state,
                 'prev_state': prev_state,
-                'enable_fast_timestep': True
             }
         )  # here is the forward pass of the agent networks of Q
         agent_q, next_state = output['logit'], output['next_state']

ding/policy/ngu.py

Lines changed: 0 additions & 3 deletions
@@ -290,7 +290,6 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]:
             'action': data['burnin_nstep_action'],
             'reward': data['burnin_nstep_reward'],
             'beta': data['burnin_nstep_beta'],
-            'enable_fast_timestep': True
         }
         tmp = self._learn_model.forward(
             inputs, saved_state_timesteps=[self._burnin_step, self._burnin_step + self._nstep]
@@ -304,7 +303,6 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]:
             'action': data['main_action'],
             'reward': data['main_reward'],
             'beta': data['main_beta'],
-            'enable_fast_timestep': True
         }
         self._learn_model.reset(data_id=None, state=tmp['saved_state'][0])
         q_value = self._learn_model.forward(inputs)['logit']
@@ -317,7 +315,6 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]:
             'action': data['target_action'],
             'reward': data['target_reward'],
             'beta': data['target_beta'],
-            'enable_fast_timestep': True
         }
         with torch.no_grad():
             target_q_value = self._target_model.forward(next_inputs)['logit']
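A schematic sketch of how the three trimmed input dicts above are consumed in NGU's _forward_learn (burn-in, main and target segments). The wrapper function and the 'obs' keys are illustrative assumptions; the forward, reset and saved_state usages mirror the surrounding context lines of this diff:

import torch

def forward_learn_sketch(learn_model, target_model, data, burnin_step, nstep):
    # 1) Burn-in segment: unroll the recurrent model to recover hidden states, saving the
    #    states reached after `burnin_step` and `burnin_step + nstep` timesteps.
    inputs = {
        'obs': data['burnin_nstep_obs'],  # assumed key, following the r2d2-style comment above
        'action': data['burnin_nstep_action'],
        'reward': data['burnin_nstep_reward'],
        'beta': data['burnin_nstep_beta'],
    }
    tmp = learn_model.forward(inputs, saved_state_timesteps=[burnin_step, burnin_step + nstep])

    # 2) Main segment: reset the learn model to the first saved state, then compute current Q-values.
    inputs = {
        'obs': data['main_obs'],  # assumed key
        'action': data['main_action'],
        'reward': data['main_reward'],
        'beta': data['main_beta'],
    }
    learn_model.reset(data_id=None, state=tmp['saved_state'][0])
    q_value = learn_model.forward(inputs)['logit']

    # 3) Target segment: the target network evaluates the shifted observations for the n-step target.
    next_inputs = {
        'obs': data['target_obs'],  # assumed key
        'action': data['target_action'],
        'reward': data['target_reward'],
        'beta': data['target_beta'],
    }
    with torch.no_grad():
        target_q_value = target_model.forward(next_inputs)['logit']
    return q_value, target_q_value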
