@@ -22,21 +22,19 @@ def train(self, cuda):
2222 batch_num = self .replay_buffer .get_batch_num ()
2323
2424 for i in range (batch_num ):
25- obs , feat , obs_next , feat_next , dones , rewards , actions , masks = self .replay_buffer .sample ()
25+ obs , feat , obs_next , feat_next , dones , rewards , acts , masks = self .replay_buffer .sample ()
2626
2727 obs = torch .FloatTensor (obs ).permute ([0 , 3 , 1 , 2 ]).cuda () if cuda else torch .FloatTensor (obs ).permute ([0 , 3 , 1 , 2 ])
2828 obs_next = torch .FloatTensor (obs_next ).permute ([0 , 3 , 1 , 2 ]).cuda () if cuda else torch .FloatTensor (obs_next ).permute ([0 , 3 , 1 , 2 ])
2929 feat = torch .FloatTensor (feat ).cuda () if cuda else torch .FloatTensor (feat )
3030 feat_next = torch .FloatTensor (feat_next ).cuda () if cuda else torch .FloatTensor (feat_next )
3131 acts = torch .LongTensor (acts ).cuda () if cuda else torch .LongTensor (acts )
32- act_prob = torch .FloatTensor (act_prob ).cuda () if cuda else torch .FloatTensor (act_prob )
33- act_prob_next = torch .FloatTensor (act_prob_next ).cuda () if cuda else torch .FloatTensor (act_prob_next )
3432 rewards = torch .FloatTensor (rewards ).cuda () if cuda else torch .FloatTensor (rewards )
3533 dones = torch .FloatTensor (dones ).cuda () if cuda else torch .FloatTensor (dones )
3634 masks = torch .FloatTensor (masks ).cuda () if cuda else torch .FloatTensor (masks )
3735
3836 target_q = self .calc_target_q (obs = obs_next , feature = feat_next , rewards = rewards , dones = dones )
39- loss , q = super ().train (obs = obs , feature = feat , target_q = target_q , acts = actions , masks = masks )
37+ loss , q = super ().train (obs = obs , feature = feat , target_q = target_q , acts = acts , mask = masks )
4038
4139 self .update ()
4240
0 commit comments