Skip to content

Commit 08bf4c4

Browse files
committed
fix some bugs
1 parent 4a43798 commit 08bf4c4

File tree

2 files changed

+3
-5
lines changed

2 files changed

+3
-5
lines changed

algo/q_learning.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,19 @@ def train(self, cuda):
2222
batch_num = self.replay_buffer.get_batch_num()
2323

2424
for i in range(batch_num):
25-
obs, feat, obs_next, feat_next, dones, rewards, actions, masks = self.replay_buffer.sample()
25+
obs, feat, obs_next, feat_next, dones, rewards, acts, masks = self.replay_buffer.sample()
2626

2727
obs = torch.FloatTensor(obs).permute([0, 3, 1, 2]).cuda() if cuda else torch.FloatTensor(obs).permute([0, 3, 1, 2])
2828
obs_next = torch.FloatTensor(obs_next).permute([0, 3, 1, 2]).cuda() if cuda else torch.FloatTensor(obs_next).permute([0, 3, 1, 2])
2929
feat = torch.FloatTensor(feat).cuda() if cuda else torch.FloatTensor(feat)
3030
feat_next = torch.FloatTensor(feat_next).cuda() if cuda else torch.FloatTensor(feat_next)
3131
acts = torch.LongTensor(acts).cuda() if cuda else torch.LongTensor(acts)
32-
act_prob = torch.FloatTensor(act_prob).cuda() if cuda else torch.FloatTensor(act_prob)
33-
act_prob_next = torch.FloatTensor(act_prob_next).cuda() if cuda else torch.FloatTensor(act_prob_next)
3432
rewards = torch.FloatTensor(rewards).cuda() if cuda else torch.FloatTensor(rewards)
3533
dones = torch.FloatTensor(dones).cuda() if cuda else torch.FloatTensor(dones)
3634
masks = torch.FloatTensor(masks).cuda() if cuda else torch.FloatTensor(masks)
3735

3836
target_q = self.calc_target_q(obs=obs_next, feature=feat_next, rewards=rewards, dones=dones)
39-
loss, q = super().train(obs=obs, feature=feat, target_q=target_q, acts=actions, masks=masks)
37+
loss, q = super().train(obs=obs, feature=feat, target_q=target_q, acts=acts, mask=masks)
4038

4139
self.update()
4240

train_battle.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def linear_decay(epoch, x, y):
3636

3737
if __name__ == '__main__':
3838
parser = argparse.ArgumentParser()
39-
parser.add_argument('--algo', type=str, choices={'ac', 'mfac', 'mfq', 'il'}, help='choose an algorithm from the preset', required=True)
39+
parser.add_argument('--algo', type=str, choices={'ac', 'mfac', 'mfq', 'iql'}, help='choose an algorithm from the preset', required=True)
4040
parser.add_argument('--save_every', type=int, default=20, help='decide the self-play update interval')
4141
parser.add_argument('--update_every', type=int, default=5, help='decide the udpate interval for q-learning, optional')
4242
parser.add_argument('--n_round', type=int, default=2000, help='set the trainning round')

0 commit comments

Comments
 (0)