Skip to content

Commit 5a791e5

Browse files
authored
Optimize the performance of FlashBert Path for HPU (#575)
Signed-off-by: Liu, Kaixuan <[email protected]>
1 parent 2118688 commit 5a791e5

File tree

1 file changed

+7
-5
lines changed
  • backends/python/server/text_embeddings_server/models

1 file changed

+7
-5
lines changed

backends/python/server/text_embeddings_server/models/flash_bert.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -323,19 +323,21 @@ def batch_type(self) -> Union[FlashBatch, PaddedBatch]:
323323
def embed(self, batch: Union[FlashBatch, PaddedBatch]) -> List[Embedding]:
324324
if isinstance(batch, PaddedBatch):
325325
input_lens = batch.attention_mask.cumsum(-1)[:, -1].to(torch.int32)
326-
max_input_lens = input_lens.max().item()
326+
max_input_lens = 0 # This value will not be used
327327
cu_seqlens = torch.cat(
328328
(input_lens.new_tensor([0]), input_lens.cumsum(-1).int())
329329
)
330330
mask = batch.attention_mask.bool()
331-
batch_size = input_lens.size(0)
331+
bsz, tgt_len = mask.size()
332+
min_val = torch.finfo(self.dtype).min
332333
attn_mask = torch.full(
333-
[batch_size, 1, 1, mask.shape[-1]],
334-
fill_value=torch.finfo(self.dtype).min,
334+
[bsz, 1, tgt_len, tgt_len],
335+
fill_value=min_val,
335336
device=self.device,
336337
dtype=self.dtype,
337338
)
338-
attn_mask.masked_fill_(mask[:, None, None, :], 0)
339+
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, tgt_len)
340+
attn_mask = attn_mask.masked_fill(expanded_mask, 0.0)
339341
elif isinstance(batch, FlashBatch):
340342
cu_seqlens = batch.cu_seqlens
341343
mask = None

0 commit comments

Comments (0)