@@ -45,6 +45,7 @@
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
+from text_generation_server.models.globals import ATTENTION
 from text_generation_server.utils.weights import UnquantizedWeight
 from transformers.activations import ACT2FN
 from text_generation_server.layers.attention import (
@@ -248,7 +249,7 @@ def forward(

         # Prefill
         if cu_seqlen_prefill is not None:
-            if attention_mask is None:
+            if attention_mask is None or ATTENTION == "flashinfer":
                 # flash attention
                 attn_output = attention(
                     query=query,
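
One way to read this hunk (an inference from the diff, not stated in it): when the backend is flashinfer, prefill keeps going through the regular attention call even if a custom mask exists, presumably because flashinfer consumes a boolean mask itself rather than an additive float mask. The minimal torch-only sketch below contrasts the two mask conventions that the bool_mask flag introduced further down toggles between:

import torch

# Boolean convention: True where attention is allowed.
bool_mask = torch.tensor([[True, False], [True, True]])

# Additive float convention used by the fallback path: 0 where allowed,
# dtype-min where masked, mirroring torch.where(mask, 0, min_dtype) below.
min_dtype = torch.finfo(torch.float16).min
float_mask = torch.where(bool_mask, 0, min_dtype).to(torch.float16)
# float_mask: tensor([[0., -65504.], [0., 0.]], dtype=torch.float16)
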
@@ -701,8 +702,16 @@ def __init__(self, prefix, config, weights):
         )

     def get_attention_mask(
-        self, input_ids, max_s, cu_seqlen_prefill, dtype, image_token_mask
+        self,
+        input_ids: torch.Tensor,
+        cu_seqlen_prefill: Optional[torch.Tensor],
+        dtype: torch.dtype,
+        bool_mask: bool = False,
     ):
+        image_token_mask = (input_ids == self.config.image_token_index).to(
+            input_ids.device
+        )
+
         device = input_ids.device
         min_dtype = torch.finfo(dtype).min

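
For reference, a minimal sketch of what the newly internalized image_token_mask computation produces; the image_token_index value here is hypothetical and would come from the model config:

import torch

image_token_index = 262144  # hypothetical; self.config.image_token_index in the model
input_ids = torch.tensor([2, image_token_index, image_token_index, 17])
image_token_mask = (input_ids == image_token_index).to(input_ids.device)
# image_token_mask: tensor([False,  True,  True, False])
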
@@ -748,9 +757,10 @@ def get_attention_mask(
         )
         full_attention_mask[:, :, :, :sequence_length] = combined_mask

-        final_attention_mask = torch.where(full_attention_mask, 0, min_dtype).to(device)
-
-        return final_attention_mask
+        if bool_mask:
+            return full_attention_mask
+        else:
+            return torch.where(full_attention_mask, 0, min_dtype).to(device)

     def forward(
         self,
@@ -793,10 +803,8 @@ def forward(
             )
             attention_mask = self.get_attention_mask(
                 input_ids,
-                max_s,
                 cu_seqlen_prefill,
                 inputs_embeds.dtype,
-                image_token_mask,
             )
             # Use flash attention for text-only input
             # else:
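
Net effect at the call site: callers no longer pass max_s or a precomputed image_token_mask; the method derives the image mask itself, and the new bool_mask flag picks the output convention. A hedged usage sketch, with the surrounding model object and input tensors assumed to exist:

# Default: additive float mask for the masked-attention fallback.
attention_mask = model.get_attention_mask(
    input_ids,
    cu_seqlen_prefill,
    inputs_embeds.dtype,
)

# bool_mask=True: raw boolean mask, e.g. for a backend that accepts one.
attention_mask_bool = model.get_attention_mask(
    input_ids,
    cu_seqlen_prefill,
    inputs_embeds.dtype,
    bool_mask=True,
)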