Skip to content

Commit

Permalink
further code cleaning
Browse files: browse the repository at this point in the history
  • Loading branch information
l-k-11235 committed May 28, 2024
1 parent 506b355 commit 4fa9721
Showing 1 changed file with 10 additions and 22 deletions.
32 changes: 10 additions & 22 deletions onmt/modules/multi_headed_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,6 @@ def forward(
"""
# 1) Project key, value, and query.
# As a reminder, layer_cache[0] remains False during training.
key_pad_mask = self.layer_cache[1].get("key_pad_mask", None)
if self.layer_cache[0]:
# Retrieve keys and values from the KV cache (decoding mode only).
if self.attn_type == "self":
Expand Down Expand Up @@ -483,12 +482,16 @@ def forward(
if sliding_window > 0 and key.size(2) > sliding_window:
key = key[:, :, 1:, :]
value = value[:, :, 1:, :]
if key_pad_mask is not None and step == 0:
x = key_pad_mask
x = x.expand(-1, self.head_count // self.parallel_gpu, -1)
x = x.unsqueeze(3)
x = x.expand(-1, -1, -1, value.size(3))
value = value.masked_fill(x, 0)

if step == 0:
key_pad_mask = self.layer_cache[1].get("key_pad_mask", None)
if key_pad_mask is not None:
x = key_pad_mask.expand(
-1, self.head_count // self.parallel_gpu, -1
)
x = x.unsqueeze(3)
x = x.expand(-1, -1, -1, value.size(3))
value = value.masked_fill(x, 0)

self.layer_cache[1]["keys"] = key
self.layer_cache[1]["values"] = value
Expand Down Expand Up @@ -571,19 +574,6 @@ def forward(
self.layer_cache[1]["keys"] = key
self.layer_cache[1]["values"] = value

if key_pad_mask is not None:
# Increase the cached key pad mask by concatenation.
# For decoding only.
if step > 0:
y = torch.zeros(
(key_pad_mask.size(0), key_pad_mask.size(1), 1),
dtype=torch.bool,
device=key_pad_mask.device,
)
self.layer_cache[1]["key_pad_mask"] = torch.cat(
(key_pad_mask, y), 2
)
key_pad_mask = self.layer_cache[1]["key_pad_mask"]
else:
# Retrieve keys and values from linear layers (training mode).
key = self.maybe_ckpt(self.linear_keys, key)
Expand Down Expand Up @@ -712,8 +702,6 @@ def forward(
scores = self.alibi(scores)

scores = scores.float()
if key_pad_mask is not None and mask is None:
mask = key_pad_mask.unsqueeze(1)

if mask is not None:
# Not strictly necessary, but expand to the number of heads.
Expand Down

0 comments on commit 4fa9721

Please sign in to comment.