@@ -323,19 +323,21 @@ def batch_type(self) -> Union[FlashBatch, PaddedBatch]:
     def embed(self, batch: Union[FlashBatch, PaddedBatch]) -> List[Embedding]:
         if isinstance(batch, PaddedBatch):
             input_lens = batch.attention_mask.cumsum(-1)[:, -1].to(torch.int32)
-            max_input_lens = input_lens.max().item()
+            max_input_lens = 0  # This value will not be used
             cu_seqlens = torch.cat(
                 (input_lens.new_tensor([0]), input_lens.cumsum(-1).int())
             )
             mask = batch.attention_mask.bool()
-            batch_size = input_lens.size(0)
+            bsz, tgt_len = mask.size()
+            min_val = torch.finfo(self.dtype).min
             attn_mask = torch.full(
-                [batch_size, 1, 1, mask.shape[-1]],
-                fill_value=torch.finfo(self.dtype).min,
+                [bsz, 1, tgt_len, tgt_len],
+                fill_value=min_val,
                 device=self.device,
                 dtype=self.dtype,
             )
-            attn_mask.masked_fill_(mask[:, None, None, :], 0)
+            expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, tgt_len)
+            attn_mask = attn_mask.masked_fill(expanded_mask, 0.0)
         elif isinstance(batch, FlashBatch):
             cu_seqlens = batch.cu_seqlens
             mask = None
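
For context, here is a minimal, self-contained sketch of the padded-batch mask construction the new lines perform: a [bsz, 1, tgt_len, tgt_len] additive mask filled with the dtype's most negative value and zeroed wherever a key position is a real token. The toy attention_mask, dtype, and CPU-only setup below are assumptions for illustration, not part of the PR:

import torch

# Hypothetical inputs: 2 sequences of lengths 3 and 2, padded to length 4.
dtype = torch.float32
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

mask = attention_mask.bool()
bsz, tgt_len = mask.size()
min_val = torch.finfo(dtype).min

# Start from an all-masked [bsz, 1, tgt_len, tgt_len] additive mask.
attn_mask = torch.full([bsz, 1, tgt_len, tgt_len], fill_value=min_val, dtype=dtype)
# Broadcast the key-side padding mask across every query position,
# then zero out the entries that correspond to real tokens.
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, tgt_len)
attn_mask = attn_mask.masked_fill(expanded_mask, 0.0)

# Each query row now holds 0.0 for real key positions and min_val for
# padding, so adding attn_mask to raw attention scores suppresses
# padded positions before softmax.
print(attn_mask[0, 0])

Compared to the previous broadcastable [batch_size, 1, 1, seq_len] shape, materializing the full square mask with an out-of-place masked_fill presumably matches what the downstream attention kernel expects here; the in-place masked_fill_ variant would also work on a fully materialized tensor.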