Skip to content

Commit b147fc5

Browse files
authored
Update train.py
1 parent 5e51169 commit b147fc5

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

training_ptr_gen/train.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def train_one_batch(self, batch):
8888
step_losses = []
8989
for di in range(min(max_dec_len, config.max_dec_steps)):
9090
y_t_1 = dec_batch[:, di] # Teacher forcing
91-
final_dist, s_t_1, c_t_1, attn_dist, p_gen, coverage = self.model.decoder(y_t_1, s_t_1,
91+
final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(y_t_1, s_t_1,
9292
encoder_outputs, enc_padding_mask, c_t_1,
9393
extra_zeros, enc_batch_extend_vocab,
9494
coverage, di)
@@ -98,6 +98,8 @@ def train_one_batch(self, batch):
9898
if config.is_coverage:
9999
step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
100100
step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
101+
coverage = next_coverage
102+
101103
step_mask = dec_padding_mask[:, di]
102104
step_loss = step_loss * step_mask
103105
step_losses.append(step_loss)

0 commit comments

Comments
 (0)