diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py
index 18a81bf8de3f8..d7d5a2948d496 100644
--- a/vllm/attention/backends/hpu_attn.py
+++ b/vllm/attention/backends/hpu_attn.py
@@ -159,11 +159,11 @@ def __init__(
             logger().warning("Could not import HPU FusedSDPA kernel. "
                              "vLLM will use native implementation.")
 
-        suppored_head_sizes = HPUPagedAttention.get_supported_head_sizes()
-        if head_size not in suppored_head_sizes:
+        supported_head_sizes = HPUPagedAttention.get_supported_head_sizes()
+        if head_size not in supported_head_sizes:
             raise ValueError(
                 f"Head size {head_size} is not supported by PagedAttention. "
-                f"Supported head sizes are: {suppored_head_sizes}.")
+                f"Supported head sizes are: {supported_head_sizes}.")
 
         self.attn_type = attn_type
         if (self.attn_type != AttentionType.DECODER
diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index 8cd5d4e355048..613f7c27f3667 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -872,7 +872,27 @@ def _attention_with_mask(
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         # Skip writing kv-cache for the initial profiling run.
-        if len(kv_cache.shape) > 1:
+        if is_hpu and kv_cache is not None and isinstance(kv_cache, tuple):
+            assert self.attn.backend == _Backend.HPU_ATTN
+            # During cross-attention decode, key & value will be None,
+            # we don't need to cache them.
+            if (k is not None) and (v is not None):
+                from vllm_hpu_extension.utils import VLLMKVCache
+
+                from vllm.attention.ops.hpu_paged_attn import HPUPagedAttention
+                key_cache, value_cache = HPUPagedAttention.split_kv_cache(
+                    kv_cache, self.num_local_key_value_heads, self.head_dim)
+                cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
+                cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
+                block_indices = attn_metadata.cross_block_indices
+                block_offsets = attn_metadata.cross_block_offsets
+                k_cache = VLLMKVCache()
+                v_cache = VLLMKVCache()
+                key_cache = k_cache(cached_k, key_cache, block_indices,
+                                    block_offsets)
+                value_cache = v_cache(cached_v, value_cache, block_indices,
+                                      block_offsets)
+        elif len(kv_cache.shape) > 1:
             i = torch.ones(1, dtype=torch.float32)
             if self.attn.backend in (_Backend.FLASH_ATTN,
                                      _Backend.FLASH_ATTN_VLLM_V1):
@@ -929,14 +949,33 @@ def _attention_with_mask(
                                                     kv_len,
                                                     self.head_dim).contiguous()
         attention_mask = attention_mask.view(1, 1, q_len, kv_len)
-        output = F.scaled_dot_product_attention(q,
-                                                k,
-                                                v,
-                                                attn_mask=attention_mask,
-                                                is_causal=False)
-        output = output.permute(2, 0, 1, 3).reshape(
-            q_len, self.num_local_heads * self.head_dim)
-        return output
+        if current_platform.is_hpu():
+            from habana_frameworks.torch.hpex.kernels import FusedSDPA
+            from vllm_hpu_extension.utils import ModuleFusedSDPA
+            fsdpa_op = ModuleFusedSDPA(FusedSDPA)
+            # use fp32 as softmax_mode for better accuracy
+            output = fsdpa_op(q,
+                              k,
+                              v,
+                              attention_mask,
+                              dropout_p=0.0,
+                              is_causal=False,
+                              scale=None,
+                              softmax_mode="fast",
+                              recompute_mode=None,
+                              valid_sequence_lengths=None)
+            output = output.permute(2, 0, 1, 3).reshape(
+                q_len, self.num_local_heads * self.head_dim)
+            return output
+        else:
+            output = F.scaled_dot_product_attention(q,
+                                                    k,
+                                                    v,
+                                                    attn_mask=attention_mask,
+                                                    is_causal=False)
+            output = output.permute(2, 0, 1, 3).reshape(
+                q_len, self.num_local_heads * self.head_dim)
+            return output
 
 
 class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
@@ -1001,9 +1040,13 @@ def forward(
         # TODO: Change input_tokens tensor at the beginning of model execution
         # to 2D tensor to align with public vllm input_tokens shape. But this
         # will face the graph building failure issue, still need to investigate.
-        if len(hidden_states.shape) == 3:
-            full_text_row_masked_out_mask = full_text_row_masked_out_mask.view(
-                hidden_states.size(0), -1, 1)
+        assert len(residual.shape) == 3
+        if len(hidden_states.shape) == 2:
+            hidden_states = hidden_states.view(residual.size(0),
+                                               residual.size(1),
+                                               residual.size(2))
+        full_text_row_masked_out_mask = full_text_row_masked_out_mask.view(
+            hidden_states.size(0), -1, 1)
         hidden_states = full_text_row_masked_out_mask * hidden_states
         hidden_states = residual + self.cross_attn_attn_gate.tanh(
         ) * hidden_states
@@ -1320,7 +1363,12 @@ def get_cross_attention_mask(
     num_tokens_per_tile: int,
     dtype: torch.dtype,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    token_ids = input_ids.tolist()
+    token_ids = []
+    if is_hpu:
+        # input_ids is not flatten yet for hpu
+        token_ids = input_ids.flatten().tolist()
+    else:
+        token_ids = input_ids.tolist()
     start = 0
     batch_token_ids = []
     for seq_len in attn_metadata.seq_lens:
diff --git a/vllm/worker/hpu_enc_dec_model_runner.py b/vllm/worker/hpu_enc_dec_model_runner.py
index 1f9d7b8023ff5..191bf79f57881 100644
--- a/vllm/worker/hpu_enc_dec_model_runner.py
+++ b/vllm/worker/hpu_enc_dec_model_runner.py
@@ -9,6 +9,7 @@
 
 import habana_frameworks.torch as htorch
 import torch
+from PIL import Image
 from vllm_hpu_extension.ops import batch2block, block2batch
 
 from vllm.attention import AttentionMetadata
@@ -21,7 +22,7 @@
 from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
                            Logprob, SequenceData, SequenceGroupMetadata,
                            SequenceOutput)
-from vllm.utils import is_fake_hpu
+from vllm.utils import is_fake_hpu, is_list_of
 from vllm.worker.hpu_model_runner import (HpuModelAdapter, HPUModelRunnerBase,
                                           ModelInputForHPUWithSamplingMetadata,
                                           setup_profiler, subtuple)
@@ -358,9 +359,13 @@ def _prepare_encoder_model_input_tensors(
 
         return attn_metadata
 
+    @torch.inference_mode()
     def profile_run(self) -> None:
         num_layers = self.model_config.get_num_layers(self.parallel_config)
-        kv_caches = [None] * num_layers
+        kv_caches = [
+            torch.tensor([], dtype=torch.bfloat16, device=self.device)
+            for _ in range(num_layers)
+        ]
         max_batch_size = self.max_num_prefill_seqs
         _, max_seq_len = self.bucketing_ctx.get_max_prompt_shape()
         max_seq_len = min(self.max_num_batched_tokens // max_batch_size,
@@ -446,24 +451,15 @@ def create_dummy_seq_group_metadata(self,
         sampling_params = SamplingParams(temperature=temperature)
         num_blocks = math.ceil(seq_len / self.block_size)
         cross_block_table: Optional[List[int]] = None
-        seq_len = max(seq_len, 1)
-        mm_counts = self.mm_registry.get_mm_limits_per_prompt(
-            self.model_config)
-        num_images = mm_counts["image"]
         max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
-            self.model_config) * num_images
-        decoder_dummy_data \
-            = self.input_registry.dummy_data_for_profiling(
-            self.model_config,
-            seq_len,
-            self.mm_registry,
-            is_encoder_data=False)
+            self.model_config)
         encoder_dummy_data \
             = self.input_registry.dummy_data_for_profiling(
-            self.model_config,
-            max_mm_tokens,
-            self.mm_registry,
-            is_encoder_data=True)
+                self.model_config,
+                max_mm_tokens,
+                self.mm_registry,
+                is_encoder_data=True)
+        seq_len = max(seq_len, 1)
         if is_prompt:
             input_len = seq_len
             output_len = 0
@@ -477,11 +473,15 @@ def create_dummy_seq_group_metadata(self,
             num_cross_blocks = min(self.bucketing_ctx.num_hpu_blocks,
                                    max_mm_tokens) // self.block_size
             cross_block_table = [_PAD_BLOCK_ID] * num_cross_blocks
-        prompt_token_ids = [0] * input_len
         output_token_ids = [1] * output_len
-        prompt_token_ids_array = array('l', prompt_token_ids)  # noqa: F821
-        seq_data = SequenceData(prompt_token_ids_array)
+        decoder_dummy_data = self.input_registry \
+            .dummy_data_for_profiling(self.model_config,
+                                      seq_len,
+                                      self.mm_registry,
+                                      is_encoder_data=False)
+        seq_data = decoder_dummy_data.seq_data
         seq_data.output_token_ids = output_token_ids
+
         return SequenceGroupMetadata(
             request_id=str(group_id),
             is_prompt=is_prompt,
@@ -489,9 +489,9 @@ def create_dummy_seq_group_metadata(self,
             sampling_params=sampling_params,
             block_tables=block_tables,
             encoder_seq_data=encoder_dummy_data.seq_data,
-            multi_modal_data=decoder_dummy_data.multi_modal_data,
+            multi_modal_data=decoder_dummy_data.multi_modal_data or encoder_dummy_data.multi_modal_data,
             multi_modal_placeholders=decoder_dummy_data.
-            multi_modal_placeholders,
+            multi_modal_placeholders or encoder_dummy_data.multi_modal_placeholders,
             cross_block_table=cross_block_table)
 
     def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: