
Commit 78c8908

left padding for LM inference (#2525)
1 parent efd316d commit 78c8908

File tree

5 files changed: 65 additions & 33 deletions

onmt/decoders/transformer.py
onmt/inputters/dynamic_iterator.py
onmt/inputters/text_utils.py
onmt/modules/multi_headed_attn.py
onmt/translate/translator.py

onmt/decoders/transformer.py

Lines changed: 15 additions & 4 deletions
@@ -186,7 +186,8 @@ def _forward(self, *args, **kwargs):
 
     def _compute_dec_mask(self, tgt_pad_mask, future):
         tgt_len = tgt_pad_mask.size(-1)
-        if not future:  # apply future_mask, result mask in (B, T, T)
+        if not future:
+            # Add triangular future_mask and pad_mask, result mask in (B, T, T).
             future_mask = torch.ones(
                 [tgt_len, tgt_len],
                 device=tgt_pad_mask.device,
@@ -197,9 +198,14 @@ def _compute_dec_mask(self, tgt_pad_mask, future):
                 future_mask = future_mask.triu_(-self.sliding_window)
             future_mask = future_mask.bool()
             future_mask = ~future_mask.view(1, tgt_len, tgt_len)
-
+            # Patch for scaled dot product attention.
+            patch_mask = ~torch.all(
+                tgt_pad_mask + future_mask, dim=2, keepdim=True
+            ).expand_as(tgt_pad_mask + future_mask)
             dec_mask = torch.gt(tgt_pad_mask + future_mask, 0)
-        else:  # only mask padding, result mask in (B, 1, T)
+            dec_mask = torch.logical_and(dec_mask, patch_mask)
+        else:
+            # Only mask padding, result mask in (B, 1, T).
             dec_mask = tgt_pad_mask
         return dec_mask
 
@@ -717,7 +723,9 @@ def _forward(
         dec_mask = None
 
         if layer_in.size(1) > 1:
-            # masking is necessary when sequence length is greater than one
+            # Masking is necessary when sequence length is greater than one
+            # The decoding has not started yet,
+            # we compute the scores on the source tokens in one shot.
             dec_mask = self._compute_dec_mask(tgt_pad_mask, future)
             dec_mask = dec_mask.unsqueeze(1)
             dec_mask = dec_mask.expand(-1, -1, dec_mask.size(3), -1)
@@ -859,8 +867,11 @@ def detach_state(self):
     def forward(self, tgt, enc_out=None, step=None, **kwargs):
         """Decode, possibly stepwise."""
         if step == 0:
+            # decoding mode.
+            # Initialize KV cache.
             self._init_cache(tgt)
         elif step is None:
+            # training mode.
             for layer in self.transformer_layers:
                 layer.self_attn.layer_cache = (
                     False,
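With left padding, a padded query position can have every key masked (padding and future masks both apply), and the scaled dot product attention kernel would then softmax over a row of -inf and return NaN. The added patch_mask simply un-masks such rows; their outputs correspond to padding and are ignored downstream. A minimal standalone sketch of that logic on a toy left-padded batch (shapes and values are illustrative, not taken from the decoder):

```python
import torch

B, T = 2, 4
# True marks padding; the first example is left-padded with two pad tokens.
tgt_pad_mask = torch.tensor(
    [[True, True, False, False], [False, False, False, False]]
).view(B, 1, T)

# Causal mask: True marks future positions that must not be attended to.
future_mask = torch.ones(T, T, dtype=torch.uint8).tril_()
future_mask = ~future_mask.bool().view(1, T, T)

combined = tgt_pad_mask + future_mask  # (B, T, T), True = masked
# Rows where every key is masked would make the softmax produce NaN; un-mask them.
patch_mask = ~torch.all(combined, dim=2, keepdim=True).expand_as(combined)

dec_mask = torch.gt(combined, 0)
dec_mask = torch.logical_and(dec_mask, patch_mask)
assert not torch.all(dec_mask, dim=2).any()  # no fully masked row remains
```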

onmt/inputters/dynamic_iterator.py

Lines changed: 8 additions & 2 deletions
@@ -1,7 +1,7 @@
 """Module that contain iterator used for dynamic data."""
 import torch
 from itertools import cycle
-from onmt.constants import CorpusTask
+from onmt.constants import CorpusTask, ModelTask
 from onmt.inputters.text_corpus import get_corpora, build_corpora_iters
 from onmt.inputters.text_utils import (
     text_sort_key,
@@ -164,6 +164,10 @@ def __init__(
         self.skip_empty_level = skip_empty_level
         self.random_shuffler = RandomShuffler()
         self.bucket_idx = 0
+        if task != CorpusTask.TRAIN and vocabs["data_task"] == ModelTask.LANGUAGE_MODEL:
+            self.left_pad = True
+        else:
+            self.left_pad = False
 
     @classmethod
     def from_opt(
@@ -354,7 +358,9 @@ def __iter__(self):
             # within the batch
             if self.task == CorpusTask.TRAIN:
                 minibatch.sort(key=lambda x: self.sort_key(x[0]), reverse=True)
-            tensor_batch = tensorify(self.vocabs, minibatch, self.device)
+            tensor_batch = tensorify(
+                self.vocabs, minibatch, self.device, self.left_pad
+            )
             yield (tensor_batch, bucket_idx)
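Left padding is only switched on for non-training batches built for a language-model decoder; training and seq2seq batches keep the usual right padding. A self-contained sketch of that rule with stand-in enums (the real definitions live in onmt.constants; the member names shown here are illustrative):

```python
from enum import Enum, auto

class CorpusTask(Enum):      # stand-in for onmt.constants.CorpusTask
    TRAIN = auto()
    VALID = auto()
    INFER = auto()

class ModelTask(Enum):       # stand-in for onmt.constants.ModelTask
    SEQ2SEQ = auto()
    LANGUAGE_MODEL = auto()

def use_left_pad(task, data_task):
    """Mirror of the condition added to __init__ above."""
    return task != CorpusTask.TRAIN and data_task == ModelTask.LANGUAGE_MODEL

assert use_left_pad(CorpusTask.INFER, ModelTask.LANGUAGE_MODEL)
assert not use_left_pad(CorpusTask.TRAIN, ModelTask.LANGUAGE_MODEL)
assert not use_left_pad(CorpusTask.INFER, ModelTask.SEQ2SEQ)
```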

onmt/inputters/text_utils.py

Lines changed: 31 additions & 12 deletions
@@ -168,7 +168,7 @@ def parse_align_idx(align_pharaoh):
     return flatten_align_idx
 
 
-def tensorify(vocabs, minibatch, device):
+def tensorify(vocabs, minibatch, device, left_pad=False):
     """
     This function transforms a batch of example in tensors
     Each example looks like
@@ -193,21 +193,37 @@ def tensorify(vocabs, minibatch, device):
     }
     """
     tensor_batch = {}
-    tbatchsrc = [
-        torch.tensor(ex["src"]["src_ids"], dtype=torch.long, device=device)
-        for ex, indice in minibatch
-    ]
+    if left_pad:
+        tbatchsrc = [
+            torch.tensor(ex["src"]["src_ids"], dtype=torch.long, device=device).flip(
+                dims=[0]
+            )
+            for ex, indice in minibatch
+        ]
+    else:
+        tbatchsrc = [
+            torch.tensor(ex["src"]["src_ids"], dtype=torch.long, device=device)
+            for ex, indice in minibatch
+        ]
     padidx = vocabs["src"][DefaultTokens.PAD]
     tbatchsrc = pad_sequence(tbatchsrc, batch_first=True, padding_value=padidx)
     if "feats" in minibatch[0][0]["src"]:
         tbatchfs = [tbatchsrc]
         for feat_id in range(len(minibatch[0][0]["src"]["feats"])):
-            tbatchfeat = [
-                torch.tensor(
-                    ex["src"]["feats"][feat_id], dtype=torch.long, device=device
-                )
-                for ex, indice in minibatch
-            ]
+            if left_pad:
+                tbatchfeat = [
+                    torch.tensor(
+                        ex["src"]["feats"][feat_id], dtype=torch.long, device=device
+                    ).flip(dims=[0])
+                    for ex, indice in minibatch
+                ]
+            else:
+                tbatchfeat = [
+                    torch.tensor(
+                        ex["src"]["feats"][feat_id], dtype=torch.long, device=device
+                    )
+                    for ex, indice in minibatch
+                ]
             padidx = vocabs["src_feats"][feat_id][DefaultTokens.PAD]
             tbatchfeat = pad_sequence(
                 tbatchfeat, batch_first=True, padding_value=padidx
@@ -218,7 +234,10 @@ def tensorify(vocabs, minibatch, device):
         # Need to add features in last dimensions
         tbatchsrc = tbatchsrc[:, :, None]
 
-    tensor_batch["src"] = tbatchsrc
+    if left_pad:
+        tensor_batch["src"] = tbatchsrc.flip(dims=[1])
+    else:
+        tensor_batch["src"] = tbatchsrc
 
     tensor_batch["srclen"] = torch.tensor(
         [len(ex["src"]["src_ids"]) for ex, indice in minibatch],
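The left padding itself is done with two flips around pad_sequence: each sequence is reversed, right-padded as usual, and the padded batch is flipped back along the time dimension, which leaves the pad tokens on the left. A minimal sketch with made-up token ids (the PAD index is illustrative):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

PAD = 0  # illustrative pad index
seqs = [torch.tensor([5, 6, 7, 8]), torch.tensor([9, 10])]

# Default right padding: pad tokens trail each sequence.
right = pad_sequence(seqs, batch_first=True, padding_value=PAD)
# tensor([[ 5,  6,  7,  8],
#         [ 9, 10,  0,  0]])

# Left padding: reverse each sequence, right-pad, then flip the time axis back.
left = pad_sequence(
    [s.flip(dims=[0]) for s in seqs], batch_first=True, padding_value=PAD
).flip(dims=[1])
# tensor([[ 5,  6,  7,  8],
#         [ 0,  0,  9, 10]])
```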

onmt/modules/multi_headed_attn.py

Lines changed: 4 additions & 1 deletion
@@ -405,6 +405,7 @@ def forward(
         # 1) Project key, value, and query.
         # as a reminder at training layer_cache[0] remains False
         if self.layer_cache[0]:
+            # Retrieve keys and values from the KV cache (decoding mode only).
             if self.attn_type == "self":
                 query, key, value = (
                     self.linear_query(query),
@@ -451,6 +452,7 @@ def forward(
                 self.layer_cache[1]["keys"] = key
                 self.layer_cache[1]["values"] = value
         else:
+            # Retrieve keys and values from linear layers (training mode).
             key = self.maybe_ckpt(self.linear_keys, key)
             value = self.maybe_ckpt(self.linear_values, value)
             query = self.maybe_ckpt(self.linear_query, query)
@@ -491,12 +493,12 @@ def forward(
                 self.flash2
                 and l > 256  # https://github.com/Dao-AILab/flash-attention/issues/591
             )
-
             if (
                 self.max_relative_positions in [-1, 0]
                 and not return_attn
                 and query.device != torch.device("cpu")
             ):
+                # Apply flash2 attention.
                 causal = self.is_decoder and self.attn_type == "self" and mask is not None
                 if self.is_decoder and self.attn_type == "self" and flash2:
                     if causal:
@@ -514,6 +516,7 @@ def forward(
                         window_size=window_size,
                     ).transpose(1, 2)
             else:
+                # Apply scaled dot product attention.
                 with torch.backends.cuda.sdp_kernel(
                     enable_flash=False, enable_math=True, enable_mem_efficient=True
                 ):
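The comments added here label the two attention paths taken during decoding. The patch in the decoder mask above exists because of the scaled dot product attention path: when a boolean attention mask leaves a query position with no attendable key, the softmax runs over a row of -inf and that row of the output becomes NaN. A small standalone illustration (shapes are arbitrary and independent of OpenNMT):

```python
import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim)
q = torch.randn(1, 1, 2, 8)
k = torch.randn(1, 1, 2, 8)
v = torch.randn(1, 1, 2, 8)

# Boolean mask: True means the position may be attended to.
attn_mask = torch.tensor(
    [[False, False],   # query 0: every key masked out (e.g. a left-pad position)
     [True, True]]     # query 1: a normal row
).view(1, 1, 2, 2)

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out[0, 0, 0])  # typically all NaN: softmax over a fully masked row
print(out[0, 0, 1])  # finite values
```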

onmt/translate/translator.py

Lines changed: 7 additions & 14 deletions
@@ -658,7 +658,6 @@ def _decode_and_generate(
             step=step,
             return_attn=self.global_scorer.has_cov_pen or return_attn,
         )
-
         # Generator forward.
         if not self.copy_attn:
             if "std" in dec_attn:
@@ -988,16 +987,6 @@ def _align_forward(self, batch, predictions):
 
     def translate_batch(self, batch, attn_debug):
         """Translate a batch of sentences."""
-        batch_size = len(batch["srclen"])
-        if batch_size != 1:
-            warning_msg = (
-                "GeneratorLM does not support batch_size != 1"
-                " nicely. You can remove this limitation here."
-                " With batch_size > 1 the end of each input is"
-                " repeated until the input is finished. Then"
-                " generation will start."
-            )
-            self._log(warning_msg)
         with torch.no_grad():
             if self.sample_from_topk != 0 or self.sample_from_topp != 0:
                 decode_strategy = GreedySearchLM(
@@ -1061,7 +1050,7 @@ def tile_to_beam_size_after_initial_step(self, fn_map_state, log_probs):
         log_probs = log_probs[:, -1, :]
         return log_probs
 
-    def _translate_batch_with_strategy(self, batch, decode_strategy):
+    def _translate_batch_with_strategy(self, batch, decode_strategy, left_pad=True):
         """Translate a batch of sentences step by step using cache.
 
         Args:
@@ -1081,7 +1070,12 @@ def _translate_batch_with_strategy(self, batch, decode_strategy):
         src = batch["src"]
         src_len = batch["srclen"]
 
-        src, src_len, target_prefix = self.split_src_to_prevent_padding(src, src_len)
+        if left_pad:
+            target_prefix = None
+        else:
+            src, src_len, target_prefix = self.split_src_to_prevent_padding(
+                src, src_len
+            )
 
         # (2) init decoder
         self.model.decoder.init_state(src, None, None)
@@ -1109,7 +1103,6 @@ def _translate_batch_with_strategy(self, batch, decode_strategy):
             decoder_input = (
                 src if step == 0 else decode_strategy.current_predictions.view(-1, 1, 1)
            )
-
             log_probs, attn = self._decode_and_generate(
                 decoder_input,
                 None,
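With left-padded sources, the old workaround of splitting each prompt into a common-length src plus a per-example target_prefix is no longer needed, and the batch_size != 1 warning in translate_batch can go: every prompt now ends at the same time step, so the whole batch starts generating at once. A rough sketch of the two batching strategies on hypothetical token ids (the split below only approximates what split_src_to_prevent_padding does):

```python
PAD = 0  # illustrative pad id
prompts = [[5, 6, 7, 8], [9, 10]]

# Old strategy: truncate every prompt to the shortest one and keep the rest as a
# prefix that has to be force-decoded before free generation can begin.
min_len = min(len(p) for p in prompts)
src = [p[:min_len] for p in prompts]            # [[5, 6], [9, 10]]
target_prefix = [p[min_len:] for p in prompts]  # [[7, 8], []]

# New strategy: left-pad so every prompt ends at the same position; no prefix
# bookkeeping, and all examples reach the generation phase at the same step.
max_len = max(len(p) for p in prompts)
src_left = [[PAD] * (max_len - len(p)) + p for p in prompts]
# [[5, 6, 7, 8], [0, 0, 9, 10]]
```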
