From fd8dda35390d058c1745b9495634ea0ddadf71ad Mon Sep 17 00:00:00 2001 From: Atul Kumar Date: Tue, 23 Oct 2018 11:49:59 -0700 Subject: [PATCH] Update eval.py --- training_ptr_gen/eval.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/training_ptr_gen/eval.py b/training_ptr_gen/eval.py index fff42a19..7d16eb80 100644 --- a/training_ptr_gen/eval.py +++ b/training_ptr_gen/eval.py @@ -47,7 +47,7 @@ def eval_one_batch(self, batch): step_losses = [] for di in range(min(max_dec_len, config.max_dec_steps)): y_t_1 = dec_batch[:, di] # Teacher forcing - final_dist, s_t_1, c_t_1,attn_dist, p_gen, coverage = self.model.decoder(y_t_1, s_t_1, + final_dist, s_t_1, c_t_1,attn_dist, p_gen, next_coverage = self.model.decoder(y_t_1, s_t_1, encoder_outputs, enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab, coverage, di) target = target_batch[:, di] @@ -56,6 +56,7 @@ def eval_one_batch(self, batch): if config.is_coverage: step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1) step_loss = step_loss + config.cov_loss_wt * step_coverage_loss + coverage = next_coverage step_mask = dec_padding_mask[:, di] step_loss = step_loss * step_mask