Skip to content

Commit

Permalink
Update eval.py
Browse files Browse the repository at this point in the history
  • Loading branch information
atulkum authored Oct 23, 2018
1 parent b147fc5 commit fd8dda3
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion training_ptr_gen/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def eval_one_batch(self, batch):
step_losses = []
for di in range(min(max_dec_len, config.max_dec_steps)):
y_t_1 = dec_batch[:, di] # Teacher forcing
final_dist, s_t_1, c_t_1,attn_dist, p_gen, coverage = self.model.decoder(y_t_1, s_t_1,
final_dist, s_t_1, c_t_1,attn_dist, p_gen, next_coverage = self.model.decoder(y_t_1, s_t_1,
encoder_outputs, enc_padding_mask, c_t_1,
extra_zeros, enc_batch_extend_vocab, coverage, di)
target = target_batch[:, di]
Expand All @@ -56,6 +56,7 @@ def eval_one_batch(self, batch):
if config.is_coverage:
step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
coverage = next_coverage

step_mask = dec_padding_mask[:, di]
step_loss = step_loss * step_mask
Expand Down

0 comments on commit fd8dda3

Please sign in to comment.