Fix long contexts in LoRA (#624)
#566 breaks the long-context + LoRA flow.

That change assumes caching the sin-cos buffer for the first decoder layer is
sufficient to handle all cases, which does not hold for long-context + LoRA.

This PR ignores the `_prepare_cos_sin` call made prior to the HpuModelAdapter
forward in the long-context + LoRA flow.
SanjuCSudhakaran authored Jan 2, 2025
1 parent 9555fef commit 2443ba9
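
A minimal illustration of why a sin-cos buffer prepared once cannot simply be reused here (hypothetical tensor names and shapes, not the vLLM API): with long-context + LoRA, each request carries its own rotary offset, so the rows gathered from the cos-sin cache depend on `positions + offsets` and change from batch to batch.

```python
import torch

# Hypothetical cache: one row of cos/sin values per position (max_position x rot_dim).
cos_sin_cache = torch.randn(8192, 128)

positions = torch.tensor([0, 1, 2, 3])       # token positions in a batch
offsets_plain = torch.tensor([0, 0, 0, 0])   # request without a long-LoRA offset
offsets_long = torch.tensor([4096] * 4)      # request with a long-LoRA offset

rows_plain = cos_sin_cache.index_select(0, positions + offsets_plain)
rows_long = cos_sin_cache.index_select(0, positions + offsets_long)

# A buffer cached from the first request would be wrong for the second one.
assert not torch.equal(rows_plain, rows_long)
```
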
Showing 3 changed files with 24 additions and 6 deletions.
1 change: 1 addition & 0 deletions tests/lora/conftest.py
@@ -64,6 +64,7 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool):

@pytest.fixture
def dist_init():
import habana_frameworks.torch.hpu # noqa: F401
temp_file = tempfile.mkstemp()[1]
backend_type = "hccl" if current_platform.is_hpu() else "nccl"
init_distributed_environment(
24 changes: 19 additions & 5 deletions vllm/lora/punica_wrapper/utils.py
@@ -2,6 +2,8 @@

import torch

from vllm.platforms import current_platform

if TYPE_CHECKING:
# avoid circular import
from vllm.lora.layers import LoRAMapping
@@ -86,10 +88,14 @@ def convert_mapping(
embedding_indices = index_mapping_indices.copy()
lora_indices = index_mapping_indices.copy()
long_lora_offsets: Optional[torch.Tensor] = None

if long_lora_context:
long_lora_offsets = torch.zeros(len(index_mapping_indices),
device=device,
dtype=torch.long)
if current_platform.is_hpu():
long_lora_offsets_list: List[int] = []
else:
long_lora_offsets = torch.zeros(len(index_mapping_indices),
device=device,
dtype=torch.long)
prompt_mapping: List[int] = [
lora_index_to_id.index(x) if x > 0 else -1
for x in mapping.prompt_mapping
@@ -102,10 +108,18 @@
embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
lora_indices[i] = lora_idx
if long_lora_context:
assert long_lora_offsets is not None
lora_offset: int = long_lora_context.offsets_by_lora_id.get(
index_mapping_indices[i], 0)
long_lora_offsets[i] = lora_offset
if current_platform.is_hpu():
long_lora_offsets_list.append(lora_offset)
else:
assert long_lora_offsets is not None
long_lora_offsets[i] = lora_offset

if long_lora_context and current_platform.is_hpu():
long_lora_offsets = torch.tensor(long_lora_offsets_list,
device=device,
dtype=torch.long)

indices_list: List[Union[List[int], torch.Tensor]] = [
index_mapping_indices,
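
A minimal sketch of the HPU-side pattern introduced above, with a plain dict and the helper name `build_long_lora_offsets` standing in for the real long-LoRA context object: the per-token offsets are collected in a Python list and materialized as a device tensor in a single call, instead of allocating a zero tensor up front and writing into it one index at a time.

```python
from typing import Dict, List

import torch


def build_long_lora_offsets(index_mapping_indices: List[int],
                            offsets_by_lora_id: Dict[int, int],
                            device: torch.device) -> torch.Tensor:
    # Gather the offsets on the host first ...
    offsets_list: List[int] = [
        offsets_by_lora_id.get(idx, 0) for idx in index_mapping_indices
    ]
    # ... then create the tensor once.
    return torch.tensor(offsets_list, device=device, dtype=torch.long)


# Example: two tokens mapped to LoRA 1, one token mapped to LoRA 2.
print(build_long_lora_offsets([1, 1, 2], {1: 4096, 2: 8192},
                              torch.device("cpu")))
# tensor([4096, 4096, 8192])
```
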
5 changes: 4 additions & 1 deletion vllm/model_executor/layers/rotary_embedding.py
@@ -232,7 +232,10 @@ def forward_hpu(
) -> Tuple[torch.Tensor, torch.Tensor]:
from habana_frameworks.torch.hpex.kernels import (
RotaryPosEmbeddingMode, apply_rotary_pos_emb)
if self.sin is None:

# Prepare cos-sin caches for long-context + LoRA with offsets for every
# forward, since the offset information wasn't available previously
if hasattr(self, "scaling_factors") or self.sin is None:
self.prepare_cos_sin(positions, offsets)
num_tokens = positions.shape[0] * positions.shape[1]
# HPU RoPE kernel requires hidden dimension for cos and sin to be equal
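
A one-line restatement of the guard added above (the helper name is hypothetical; the condition itself is taken from the hunk): RoPE modules that carry `scaling_factors`, i.e. the long-context + LoRA case, rebuild their cos-sin buffers with the current offsets on every forward, while plain RoPE keeps the buffer cached after the first call.

```python
def needs_fresh_cos_sin(rope) -> bool:
    # Mirrors `hasattr(self, "scaling_factors") or self.sin is None`
    # from the updated forward_hpu.
    return hasattr(rope, "scaling_factors") or rope.sin is None
```
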
