Skip to content

Commit 4fa9721

Browse files
committed
further code cleaning
1 parent 506b355 commit 4fa9721

File tree

1 file changed

+10
-22
lines changed

1 file changed

+10
-22
lines changed

onmt/modules/multi_headed_attn.py

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,6 @@ def forward(
439439
"""
440440
# 1) Project key, value, and query.
441441
# as a reminder at training layer_cache[0] remains False
442-
key_pad_mask = self.layer_cache[1].get("key_pad_mask", None)
443442
if self.layer_cache[0]:
444443
# Retrieve keys and values from the KV cache (decoding mode only).
445444
if self.attn_type == "self":
@@ -483,12 +482,16 @@ def forward(
483482
if sliding_window > 0 and key.size(2) > sliding_window:
484483
key = key[:, :, 1:, :]
485484
value = value[:, :, 1:, :]
486-
if key_pad_mask is not None and step == 0:
487-
x = key_pad_mask
488-
x = x.expand(-1, self.head_count // self.parallel_gpu, -1)
489-
x = x.unsqueeze(3)
490-
x = x.expand(-1, -1, -1, value.size(3))
491-
value = value.masked_fill(x, 0)
485+
486+
if step == 0:
487+
key_pad_mask = self.layer_cache[1].get("key_pad_mask", None)
488+
if key_pad_mask is not None:
489+
x = key_pad_mask.expand(
490+
-1, self.head_count // self.parallel_gpu, -1
491+
)
492+
x = x.unsqueeze(3)
493+
x = x.expand(-1, -1, -1, value.size(3))
494+
value = value.masked_fill(x, 0)
492495

493496
self.layer_cache[1]["keys"] = key
494497
self.layer_cache[1]["values"] = value
@@ -571,19 +574,6 @@ def forward(
571574
self.layer_cache[1]["keys"] = key
572575
self.layer_cache[1]["values"] = value
573576

574-
if key_pad_mask is not None:
575-
# Increase the cached key pad mask by concatenation.
576-
# For decoding only.
577-
if step > 0:
578-
y = torch.zeros(
579-
(key_pad_mask.size(0), key_pad_mask.size(1), 1),
580-
dtype=torch.bool,
581-
device=key_pad_mask.device,
582-
)
583-
self.layer_cache[1]["key_pad_mask"] = torch.cat(
584-
(key_pad_mask, y), 2
585-
)
586-
key_pad_mask = self.layer_cache[1]["key_pad_mask"]
587577
else:
588578
# Retrieve keys and values from linear layers (training mode).
589579
key = self.maybe_ckpt(self.linear_keys, key)
@@ -712,8 +702,6 @@ def forward(
712702
scores = self.alibi(scores)
713703

714704
scores = scores.float()
715-
if key_pad_mask is not None and mask is None:
716-
mask = key_pad_mask.unsqueeze(1)
717705

718706
if mask is not None:
719707
# not 100% necessary but expand to nb of heads

0 commit comments

Comments
 (0)