[Hardware][Intel-Gaudi] enable text embedding for Intel-Gaudi backend #17920

Closed
3 changes: 2 additions & 1 deletion vllm/attention/backends/hpu_attn.py
@@ -158,7 +158,8 @@ def __init__(
f"Supported head sizes are: {supported_head_sizes}.")

self.attn_type = attn_type
if self.attn_type != AttentionType.DECODER:
if attn_type != AttentionType.DECODER and \
self.attn_type != AttentionType.ENCODER_ONLY:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
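For context, a minimal self-contained sketch of what the relaxed guard permits (AttentionType is mocked as plain strings purely for illustration; the real enum lives in vLLM's attention layer): encoder-only, bidirectional attention is now accepted alongside decoder self-attention, which is what BERT-style embedding models need, while encoder/decoder cross-attention still raises.

```python
# Illustrative sketch only, not the PR code: AttentionType is mocked so the
# snippet is self-contained; vLLM's real enum provides the same members.
class AttentionType:
    DECODER = "decoder"
    ENCODER_ONLY = "encoder_only"
    ENCODER_DECODER = "encoder_decoder"

def check_attn_type(attn_type: str) -> None:
    # Decoder self-attention and encoder-only attention are both accepted;
    # encoder/decoder cross-attention remains unimplemented on HPU.
    if attn_type not in (AttentionType.DECODER, AttentionType.ENCODER_ONLY):
        raise NotImplementedError(
            "Encoder self-attention and encoder/decoder cross-attention "
            "are not implemented for HPUAttentionImpl")

check_attn_type(AttentionType.ENCODER_ONLY)  # no longer raises
```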
49 changes: 36 additions & 13 deletions vllm/model_executor/layers/pooler.py
@@ -79,7 +79,9 @@ def get_prompt_lens(
pooling_metadata: PoolingMetadata,
) -> torch.Tensor:
return PoolingTensors.from_pooling_metadata(
pooling_metadata, hidden_states.device).prompt_lens
pooling_metadata, hidden_states.device
).prompt_lens, PoolingTensors.from_pooling_metadata(
pooling_metadata, hidden_states.device).prompt_offsets

def extract_states(
self,
@@ -109,10 +111,14 @@ def extract_states(
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]:
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)

first_token_flat_indices = torch.zeros_like(prompt_lens)
first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
prompt_lens, prompt_offsets = self.get_prompt_lens(
hidden_states, pooling_metadata)
if prompt_offsets is not None:
first_token_flat_indices = prompt_offsets
else:
first_token_flat_indices = torch.zeros_like(prompt_lens)
first_token_flat_indices[1:] += torch.cumsum(prompt_lens,
dim=0)[:-1]
return hidden_states[first_token_flat_indices]


@@ -123,9 +129,15 @@ def extract_states(
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]:
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)

last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
prompt_lens, prompt_offsets = self.get_prompt_lens(
hidden_states, pooling_metadata)
if prompt_offsets is not None:
last_token_flat_indices = (torch.sum(torch.cat(
(prompt_lens.unsqueeze(0), prompt_offsets.unsqueeze(0)), 0),
dim=0,
keepdim=True) - 1).squeeze(0)
else:
last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
return hidden_states[last_token_flat_indices]
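To see why the offsets matter, here is a small hypothetical example of the padded layout the Gaudi runner produces (two prompts, each padded to a four-token bucket); this is an illustration, not code from the PR. Indices derived from cumsum(prompt_lens) alone would land on padding rows, whereas prompt_offsets recovers each prompt's real first and last token.

```python
import torch

# Two prompts of lengths 3 and 2, each padded to a 4-token bucket before
# flattening, so the batch has 8 rows instead of 5.
hidden_states = torch.arange(8, dtype=torch.float32).unsqueeze(1)  # [8, 1]
prompt_lens = torch.tensor([3, 2])
prompt_offsets = torch.tensor([0, 4])  # real start row of each prompt

# CLS pooling: the first token of each prompt is simply its offset row.
cls_states = hidden_states[prompt_offsets]

# LAST pooling: the last real (non-padding) token of each prompt.
last_states = hidden_states[prompt_offsets + prompt_lens - 1]

print(cls_states.squeeze(1))   # tensor([0., 4.])
print(last_states.squeeze(1))  # tensor([2., 5.])
```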


@@ -136,7 +148,8 @@ def extract_states(
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]:
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
prompt_lens, prompt_offsets = self.get_prompt_lens(
hidden_states, pooling_metadata)

offset = 0
pooled_data = list[torch.Tensor]()
@@ -154,14 +167,23 @@ def extract_states(
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]:
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
prompt_lens, prompt_offsets = self.get_prompt_lens(
hidden_states, pooling_metadata)

cumsum = torch.cumsum(hidden_states, dim=0)
start_indices = torch.cat([
torch.tensor([0], device=hidden_states.device),
torch.cumsum(prompt_lens[:-1], dim=0)
])
end_indices = torch.cumsum(prompt_lens, dim=0)
if prompt_offsets is not None:
end_indices = prompt_offsets + prompt_lens
start_indices = prompt_offsets
else:
start_indices = torch.cat([
torch.tensor([0], device=hidden_states.device),
torch.cumsum(prompt_lens[:-1], dim=0)
])
end_indices = torch.cumsum(prompt_lens, dim=0)
return (cumsum[end_indices - 1] - cumsum[start_indices] +
hidden_states[start_indices]) / prompt_lens.unsqueeze(1)
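The MEAN branch follows the same idea: a prefix sum over the flattened rows yields each prompt's mean without a Python loop, and with offsets the start/end indices skip the padding between prompts. A sketch using the same hypothetical padded batch as above:

```python
import torch

hidden_states = torch.arange(8, dtype=torch.float32).unsqueeze(1)  # padded batch
prompt_lens = torch.tensor([3, 2])
prompt_offsets = torch.tensor([0, 4])

# Prefix sum over all rows; the sum of rows [start, end) is
# cumsum[end - 1] - cumsum[start] + hidden_states[start], so rows outside
# that range, including padding, do not contribute.
cumsum = torch.cumsum(hidden_states, dim=0)
start_indices = prompt_offsets
end_indices = prompt_offsets + prompt_lens
mean_states = (cumsum[end_indices - 1] - cumsum[start_indices] +
               hidden_states[start_indices]) / prompt_lens.unsqueeze(1)

print(mean_states.squeeze(1))  # tensor([1.0000, 4.5000])
```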

@@ -186,7 +208,8 @@ def extract_states(
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]:
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
prompt_lens, prompt_offsets = self.get_prompt_lens(
hidden_states, pooling_metadata)

returned_token_ids = self.returned_token_ids
if returned_token_ids is not None and len(returned_token_ids) > 0:
@@ -221,7 +244,7 @@ def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
pooling_metadata: PoolingMetadata):

dimensions_list = [
pooling_param.dimensions
pooling_param.dimensions if pooling_param is not None else None
for _, pooling_param in pooling_metadata.seq_groups
]
if any(d is not None for d in dimensions_list):
24 changes: 21 additions & 3 deletions vllm/model_executor/pooling_metadata.py
@@ -1,7 +1,8 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
from typing import Any

from typing import Any, Optional

import torch

@@ -19,30 +20,37 @@ class PoolingMetadata:
seq_groups: List of (seq_ids, pooling_params).
seq_data: A mapping of sequence ID to additional sequence data.
prompt_lens: List of the lengths of each prompt.
prompt_offsets: List of prompt start offsets for each prompt
when the batch is flattened with padding
"""

def __init__(
self,
seq_groups: list[tuple[list[int], PoolingParams]],
seq_data: dict[int, Any], # Specific data related to sequences
prompt_lens: list[int],
prompt_offsets: Optional[list[int]] = None,

) -> None:
self.seq_groups = seq_groups
self.seq_data = seq_data
self.prompt_lens = prompt_lens
self.prompt_offsets = prompt_offsets

def __repr__(self) -> str:
return ("PoolingMetadata("
f"seq_groups={self.seq_groups}, "
f"seq_data={self.seq_data}, "
f"prompt_lens={self.prompt_lens})")
f"prompt_lens={self.prompt_lens}, "
f"prompt_offsets={self.prompt_offsets})")


@dataclass
class PoolingTensors:
"""Tensors for pooling."""

prompt_lens: torch.Tensor
prompt_offsets: torch.Tensor

@classmethod
def from_pooling_metadata(
@@ -67,5 +75,15 @@ def from_pooling_metadata(
pin_memory=pin_memory,
)

if pooling_metadata.prompt_offsets is not None:
prompt_offsets_t = torch.tensor(
pooling_metadata.prompt_offsets,
device="cpu",
dtype=torch.long,
pin_memory=pin_memory,
).to(device=device, non_blocking=True)
else:
prompt_offsets_t = None
return cls(prompt_lens=prompt_lens_t.to(device=device,
non_blocking=True), )
non_blocking=True),
prompt_offsets=prompt_offsets_t)
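And a hedged usage sketch of the extended metadata, assuming the existing PoolingParams class and the from_pooling_metadata(metadata, device) signature shown above; the seq_data placeholder and the single-sequence group are made up for illustration, since in practice the HPU model runner fills these in from its padded batch.

```python
import torch

from vllm.model_executor.pooling_metadata import (PoolingMetadata,
                                                  PoolingTensors)
from vllm.pooling_params import PoolingParams

# One sequence group containing sequence id 0, pooled with default params.
metadata = PoolingMetadata(
    seq_groups=[([0], PoolingParams())],
    seq_data={0: None},      # placeholder; real runs pass per-sequence data
    prompt_lens=[3],
    prompt_offsets=[0],      # start row of the prompt in the padded batch
)

tensors = PoolingTensors.from_pooling_metadata(metadata, torch.device("cpu"))
print(tensors.prompt_lens)     # tensor([3])
print(tensors.prompt_offsets)  # tensor([0])
```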