Skip to content

Commit d9d8b77

Browse files
authored
fix generation with large sequences when flash2 is False (#2564)
1 parent 0e72326 commit d9d8b77

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

onmt/modules/multi_headed_attn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -464,11 +464,11 @@ def forward(
464464
or query.dtype != torch.float16
465465
):
466466
if self.max_relative_positions == -1: # Rotary Embeddings
467-
if seqlen > self.rope.size(0):
468-
467+
if seqlen + start_pos > self.rope.size(0):
468+
# Resize rotary embeddings.
469469
self.rope, _, _ = rotaryembeddings(
470470
self.rotary_dim,
471-
maxseqlen=(seqlen + 2048),
471+
maxseqlen=(seqlen + start_pos + 2048),
472472
base=self.rotary_theta,
473473
device=self.rope.device,
474474
)

0 commit comments

Comments
 (0)