@@ -439,7 +439,6 @@ def forward(
439439 """
440440 # 1) Project key, value, and query.
441441 # as a reminder at training layer_cache[0] remains False
442- key_pad_mask = self .layer_cache [1 ].get ("key_pad_mask" , None )
443442 if self .layer_cache [0 ]:
444443 # Retrieve keys and values from the KV cache (decoding mode only).
445444 if self .attn_type == "self" :
@@ -483,12 +482,16 @@ def forward(
483482 if sliding_window > 0 and key .size (2 ) > sliding_window :
484483 key = key [:, :, 1 :, :]
485484 value = value [:, :, 1 :, :]
486- if key_pad_mask is not None and step == 0 :
487- x = key_pad_mask
488- x = x .expand (- 1 , self .head_count // self .parallel_gpu , - 1 )
489- x = x .unsqueeze (3 )
490- x = x .expand (- 1 , - 1 , - 1 , value .size (3 ))
491- value = value .masked_fill (x , 0 )
485+
486+ if step == 0 :
487+ key_pad_mask = self .layer_cache [1 ].get ("key_pad_mask" , None )
488+ if key_pad_mask is not None :
489+ x = key_pad_mask .expand (
490+ - 1 , self .head_count // self .parallel_gpu , - 1
491+ )
492+ x = x .unsqueeze (3 )
493+ x = x .expand (- 1 , - 1 , - 1 , value .size (3 ))
494+ value = value .masked_fill (x , 0 )
492495
493496 self .layer_cache [1 ]["keys" ] = key
494497 self .layer_cache [1 ]["values" ] = value
@@ -571,19 +574,6 @@ def forward(
571574 self .layer_cache [1 ]["keys" ] = key
572575 self .layer_cache [1 ]["values" ] = value
573576
574- if key_pad_mask is not None :
575- # Increase the cached key pad mask by concatenation.
576- # For decoding only.
577- if step > 0 :
578- y = torch .zeros (
579- (key_pad_mask .size (0 ), key_pad_mask .size (1 ), 1 ),
580- dtype = torch .bool ,
581- device = key_pad_mask .device ,
582- )
583- self .layer_cache [1 ]["key_pad_mask" ] = torch .cat (
584- (key_pad_mask , y ), 2
585- )
586- key_pad_mask = self .layer_cache [1 ]["key_pad_mask" ]
587577 else :
588578 # Retrieve keys and values from linear layers (training mode).
589579 key = self .maybe_ckpt (self .linear_keys , key )
@@ -712,8 +702,6 @@ def forward(
712702 scores = self .alibi (scores )
713703
714704 scores = scores .float ()
715- if key_pad_mask is not None and mask is None :
716- mask = key_pad_mask .unsqueeze (1 )
717705
718706 if mask is not None :
719707 # not 100% necessary but expand to nb of heads
0 commit comments