Skip to content
This repository was archived by the owner on Aug 18, 2021. It is now read-only.

Commit 31fdb61

Browse files
committed
beginning of batched seq2seq
1 parent 9d7ab1a commit 31fdb61

File tree

2 files changed

+1801
-0
lines changed

2 files changed

+1801
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import torch
2+
from torch.nn import functional
3+
from torch.autograd import Variable
4+
5+
def sequence_mask(sequence_length, max_len=None):
6+
if max_len is None:
7+
max_len = sequence_length.data.max()
8+
batch_size = sequence_length.size(0)
9+
seq_range = torch.range(0, max_len - 1).long()
10+
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
11+
seq_range_expand = Variable(seq_range_expand)
12+
if sequence_length.is_cuda:
13+
seq_range_expand = seq_range_expand.cuda()
14+
seq_length_expand = (sequence_length.unsqueeze(1)
15+
.expand_as(seq_range_expand))
16+
return seq_range_expand < seq_length_expand
17+
18+
19+
def masked_cross_entropy(logits, target, length):
20+
length = Variable(torch.LongTensor(length)).cuda()
21+
22+
"""
23+
Args:
24+
logits: A Variable containing a FloatTensor of size
25+
(batch, max_len, num_classes) which contains the
26+
unnormalized probability for each class.
27+
target: A Variable containing a LongTensor of size
28+
(batch, max_len) which contains the index of the true
29+
class for each corresponding step.
30+
length: A Variable containing a LongTensor of size (batch,)
31+
which contains the length of each data in a batch.
32+
33+
Returns:
34+
loss: An average loss value masked by the length.
35+
"""
36+
37+
# logits_flat: (batch * max_len, num_classes)
38+
logits_flat = logits.view(-1, logits.size(-1))
39+
# log_probs_flat: (batch * max_len, num_classes)
40+
log_probs_flat = functional.log_softmax(logits_flat)
41+
# target_flat: (batch * max_len, 1)
42+
target_flat = target.view(-1, 1)
43+
# losses_flat: (batch * max_len, 1)
44+
losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)
45+
# losses: (batch, max_len)
46+
losses = losses_flat.view(*target.size())
47+
# mask: (batch, max_len)
48+
mask = sequence_mask(sequence_length=length, max_len=target.size(1))
49+
losses = losses * mask.float()
50+
loss = losses.sum() / length.float().sum()
51+
return loss

0 commit comments

Comments
 (0)