@@ -439,7 +439,6 @@ def forward(
439
439
"""
440
440
# 1) Project key, value, and query.
441
441
# as a reminder at training layer_cache[0] remains False
442
- key_pad_mask = self .layer_cache [1 ].get ("key_pad_mask" , None )
443
442
if self .layer_cache [0 ]:
444
443
# Retrieve keys and values from the KV cache (decoding mode only).
445
444
if self .attn_type == "self" :
@@ -483,12 +482,16 @@ def forward(
483
482
if sliding_window > 0 and key .size (2 ) > sliding_window :
484
483
key = key [:, :, 1 :, :]
485
484
value = value [:, :, 1 :, :]
486
- if key_pad_mask is not None and step == 0 :
487
- x = key_pad_mask
488
- x = x .expand (- 1 , self .head_count // self .parallel_gpu , - 1 )
489
- x = x .unsqueeze (3 )
490
- x = x .expand (- 1 , - 1 , - 1 , value .size (3 ))
491
- value = value .masked_fill (x , 0 )
485
+
486
+ if step == 0 :
487
+ key_pad_mask = self .layer_cache [1 ].get ("key_pad_mask" , None )
488
+ if key_pad_mask is not None :
489
+ x = key_pad_mask .expand (
490
+ - 1 , self .head_count // self .parallel_gpu , - 1
491
+ )
492
+ x = x .unsqueeze (3 )
493
+ x = x .expand (- 1 , - 1 , - 1 , value .size (3 ))
494
+ value = value .masked_fill (x , 0 )
492
495
493
496
self .layer_cache [1 ]["keys" ] = key
494
497
self .layer_cache [1 ]["values" ] = value
@@ -571,19 +574,6 @@ def forward(
571
574
self .layer_cache [1 ]["keys" ] = key
572
575
self .layer_cache [1 ]["values" ] = value
573
576
574
- if key_pad_mask is not None :
575
- # Increase the cached key pad mask by concatenation.
576
- # For decoding only.
577
- if step > 0 :
578
- y = torch .zeros (
579
- (key_pad_mask .size (0 ), key_pad_mask .size (1 ), 1 ),
580
- dtype = torch .bool ,
581
- device = key_pad_mask .device ,
582
- )
583
- self .layer_cache [1 ]["key_pad_mask" ] = torch .cat (
584
- (key_pad_mask , y ), 2
585
- )
586
- key_pad_mask = self .layer_cache [1 ]["key_pad_mask" ]
587
577
else :
588
578
# Retrieve keys and values from linear layers (training mode).
589
579
key = self .maybe_ckpt (self .linear_keys , key )
@@ -712,8 +702,6 @@ def forward(
712
702
scores = self .alibi (scores )
713
703
714
704
scores = scores .float ()
715
- if key_pad_mask is not None and mask is None :
716
- mask = key_pad_mask .unsqueeze (1 )
717
705
718
706
if mask is not None :
719
707
# not 100% necessary but expand to nb of heads
0 commit comments