Skip to content

Commit ea41d81

Browse files
Ubuntu authored and Ubuntu committed
maybe bugfix
1 parent 36f551a commit ea41d81

8 files changed

+18
-15
lines changed

Diff for: configs/attn_lstm_vocab_1k.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name: attn_lstm_vocab_1k
22
train:
3-
batch_size: 512
3+
batch_size: 256
44
LOAD_EPOCH:
55
epochs: 3
66
num_workers: 6

Diff for: configs/attn_lstm_vocab_50k.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
name: attn_lstm_vocab_50k
22
train:
3-
batch_size: 128
3+
batch_size: 64
44
LOAD_EPOCH:
5-
epochs: 6
5+
epochs: 1
66
num_workers: 6
77
eval_period: 1
88
checkpoint_period: 1

Diff for: configs/pointer_vocab_50k.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
name: pointer_vocab_50k
22
train:
3-
batch_size: 128
3+
batch_size: 64
44
LOAD_EPOCH:
5-
epochs: 6
5+
epochs: 1
66
num_workers: 6
77
eval_period: 1
88
checkpoint_period: 1

Diff for: configs/simple_lstm_vocab_50k.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
name: simple_lstm_vocab_50k
22
train:
3-
batch_size: 256
3+
batch_size: 128
44
LOAD_EPOCH:
5-
epochs: 5
5+
epochs: 1
66
num_workers: 6
77
eval_period: 1
88
checkpoint_period: 1

Diff for: model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ def forward(
336336
# cond = (t_tensor[:, iter] < self.vocab_sizeT + self.attn_size).long()
337337
# masked_target = cond * t_tensor[:, iter] + (1 - cond) * self.eof_T_id
338338
target = t_tensor[:, iter]
339-
target[target >= output.shape[1]] = self.eof_T_id # ignored index
339+
target[target >= output.shape[1]] = self.unk_id
340340
token_losses[:, iter] = self.criterion(output, t_tensor[:, iter].clone().detach())
341341

342342
loss = token_losses.sum() #/ batch_size

Diff for: run.sh

+7-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
python3 train.py --config=configs/pointer_vocab_10k.yml
2-
python3 train.py --config=configs/pointer_vocab_50k.yml
3-
python3 train.py --config=configs/attn_lstm_vocab_1k.yml
4-
python3 train.py --config=configs/attn_lstm_vocab_50k.yml
5-
python3 train.py --config=configs/simple_lstm_vocab_1k.yml
6-
python3 train.py --config=configs/simple_lstm_vocab_50k.yml
1+
rm -r logs/pointer_vocab_50k
2+
# python3 train.py --config=configs/pointer_vocab_10k.yml
3+
python train.py --config=configs/pointer_vocab_50k.yml
4+
python train.py --config=configs/attn_lstm_vocab_1k.yml
5+
python train.py --config=configs/attn_lstm_vocab_50k.yml
6+
#python3 train.py --config=configs/simple_lstm_vocab_1k.yml
7+
python train.py --config=configs/simple_lstm_vocab_50k.yml

Diff for: train.py

+1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def train(config):
4747
)
4848

4949
ignored_index = data_train.vocab_sizeT - 1
50+
unk_index = data_train.vocab_sizeT - 2
5051

5152
model = MixtureAttention(
5253
hidden_size = config.model.hidden_size,

Diff for: utils.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,10 @@ def forward(self, output, target):
3232

3333
return F.kl_div(output, model_prob, reduction='sum')
3434

35-
def accuracy(out, target, ignored_index):
35+
def accuracy(out, target, ignored_index, unk_index):
3636
out_ = out[target != ignored_index]
3737
target_ = target[target != ignored_index]
38+
out_ = out_[out_ == unk_index] = -1
3839
return accuracy_score(out_, target_)
3940

4041
class DotDict(dict):

0 commit comments

Comments (0)