Commit 454a2f6

fix context calculation at decode steps
1 parent 3232666 commit 454a2f6

4 files changed, 16 insertions(+), 5 deletions(-)

training_ptr_gen/decode.py

Lines changed: 1 addition & 1 deletion
@@ -160,7 +160,7 @@ def beam_search(self, batch):
             final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(y_t_1, s_t_1,
                                                         encoder_outputs, enc_padding_mask, c_t_1,
-                                                        extra_zeros, enc_batch_extend_vocab, coverage_t_1)
+                                                        extra_zeros, enc_batch_extend_vocab, coverage_t_1, steps)
 
             topk_log_probs, topk_ids = torch.topk(final_dist, config.beam_size * 2)

training_ptr_gen/eval.py

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@ def eval_one_batch(self, batch):
             y_t_1 = dec_batch[:, di]  # Teacher forcing
             final_dist, s_t_1, c_t_1, attn_dist, p_gen, coverage = self.model.decoder(y_t_1, s_t_1,
                                                         encoder_outputs, enc_padding_mask, c_t_1,
-                                                        extra_zeros, enc_batch_extend_vocab, coverage)
+                                                        extra_zeros, enc_batch_extend_vocab, coverage, di)
             target = target_batch[:, di]
             gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze()
             step_loss = -torch.log(gold_probs + config.eps)

training_ptr_gen/model.py

Lines changed: 13 additions & 2 deletions
@@ -145,7 +145,15 @@ def __init__(self):
         init_linear_wt(self.out2)
 
     def forward(self, y_t_1, s_t_1, encoder_outputs, enc_padding_mask,
-                c_t_1, extra_zeros, enc_batch_extend_vocab, coverage):
+                c_t_1, extra_zeros, enc_batch_extend_vocab, coverage, step):
+
+        if not self.training and step == 0:
+            h_decoder, c_decoder = s_t_1
+            s_t_hat = torch.cat((h_decoder.view(-1, config.hidden_dim),
+                                 c_decoder.view(-1, config.hidden_dim)), 1)  # B x 2*hidden_dim
+            c_t, _, coverage_next = self.attention_network(s_t_hat, encoder_outputs,
+                                                           enc_padding_mask, coverage)
+            coverage = coverage_next
 
         y_t_1_embd = self.embedding(y_t_1)
         x = self.x_context(torch.cat((c_t_1, y_t_1_embd), 1))
@@ -154,9 +162,12 @@ def forward(self, y_t_1, s_t_1, encoder_outputs, enc_padding_mask,
         h_decoder, c_decoder = s_t
         s_t_hat = torch.cat((h_decoder.view(-1, config.hidden_dim),
                              c_decoder.view(-1, config.hidden_dim)), 1)  # B x 2*hidden_dim
-        c_t, attn_dist, coverage = self.attention_network(s_t_hat, encoder_outputs,
+        c_t, attn_dist, coverage_next = self.attention_network(s_t_hat, encoder_outputs,
                                                                enc_padding_mask, coverage)
 
+        if self.training or step > 0:
+            coverage = coverage_next
+
         p_gen = None
         if config.pointer_gen:
             p_gen_input = torch.cat((c_t, s_t_hat, x), 1)  # B x (2*2*hidden_dim + emb_dim)
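Note on the model.py change: at inference time the decoder now runs attention once from the initial state s_t_1 before the LSTM step, so coverage is initialized exactly once, and the main attention call's coverage update is skipped at step 0 to avoid counting that step twice. Below is a minimal, self-contained sketch of just that bookkeeping; the uniform toy_attention stands in for the repo's attention_network, and every name outside the diff is hypothetical.

import torch

def toy_attention(s_t_hat, encoder_outputs, coverage):
    # Uniform attention over source positions; a real attention network
    # would score encoder_outputs against s_t_hat.
    attn_dist = torch.full_like(coverage, 1.0 / coverage.size(1))
    c_t = torch.bmm(attn_dist.unsqueeze(1), encoder_outputs).squeeze(1)
    coverage_next = coverage + attn_dist  # accumulate attention mass
    return c_t, attn_dist, coverage_next

def decode_step(step, training, s_t_hat, encoder_outputs, coverage):
    if not training and step == 0:
        # First inference step: initialize coverage from the initial
        # decoder state, mirroring the new block in forward().
        _, _, coverage = toy_attention(s_t_hat, encoder_outputs, coverage)
    c_t, attn_dist, coverage_next = toy_attention(s_t_hat, encoder_outputs, coverage)
    if training or step > 0:
        coverage = coverage_next  # skipped at inference step 0
    return c_t, attn_dist, coverage

B, T, H = 2, 5, 4
enc = torch.randn(B, T, H)
cov = torch.zeros(B, T)
s = torch.randn(B, 2 * H)  # stand-in for s_t_hat
_, _, cov = decode_step(0, False, s, enc, cov)
print(cov.sum(1))  # tensor([1., 1.]): attention counted exactly once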

training_ptr_gen/train.py

Lines changed: 1 addition & 1 deletion
@@ -91,7 +91,7 @@ def train_one_batch(self, batch):
             final_dist, s_t_1, c_t_1, attn_dist, p_gen, coverage = self.model.decoder(y_t_1, s_t_1,
                                                         encoder_outputs, enc_padding_mask, c_t_1,
                                                         extra_zeros, enc_batch_extend_vocab,
-                                                        coverage)
+                                                        coverage, di)
             target = target_batch[:, di]
             gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze()
             step_loss = -torch.log(gold_probs + config.eps)
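All three call sites now thread the decoding-step index into the decoder: the loop variable di in train.py and eval.py, and the steps counter in beam search. A runnable toy of that calling pattern follows, with a hypothetical StubDecoder that keeps only the coverage bookkeeping; it shows coverage accumulating one attention distribution per emitted token at inference, the same invariant the training loop produces.

import torch

class StubDecoder(torch.nn.Module):
    def forward(self, coverage, step):
        attn = torch.full_like(coverage, 1.0 / coverage.size(1))
        if not self.training and step == 0:
            coverage = coverage + attn  # initialize coverage at step 0
        coverage_next = coverage + attn
        if self.training or step > 0:
            coverage = coverage_next    # normal per-step accumulation
        return coverage

decoder = StubDecoder().eval()          # inference mode
coverage = torch.zeros(1, 6)
for di in range(3):                     # the decode loop passes di
    coverage = decoder(coverage, di)
print(coverage.sum())                   # tensor(3.): once per step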
