We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 0e72326 commit d9d8b77Copy full SHA for d9d8b77
onmt/modules/multi_headed_attn.py
@@ -464,11 +464,11 @@ def forward(
464
or query.dtype != torch.float16
465
):
466
if self.max_relative_positions == -1: # Rotary Embeddings
467
- if seqlen > self.rope.size(0):
468
-
+ if seqlen + start_pos > self.rope.size(0):
+ # Resize rotary embeddings.
469
self.rope, _, _ = rotaryembeddings(
470
self.rotary_dim,
471
- maxseqlen=(seqlen + 2048),
+ maxseqlen=(seqlen + start_pos + 2048),
472
base=self.rotary_theta,
473
device=self.rope.device,
474
)
0 commit comments