
Commit d62c941

Gaudi: clean cuda/rocm code in hpu backend, enable flat_hpu (#3113)
* clean cuda/rocm code in hpu backend, enable flat_hpu
* fix TP in pageattn
* adjust block table in hpu to improve performance
* enable all the models (not tested yet)
* use tensor cache in hpu graph to avoid replay issue
* add moe support, fix qwen/mistral/mixtral crash
* fix phimoe issue
* gpt_bigcode can also use pageattn
* enable dbrx, remove some unused code
* multi-modality initial PR
* adjust warmup and enable vlm
* fix incorrect output in qwen2 idefics if hpu graph is used
* remove unused quantization code and enable awq/gptq int4
* fix gptq issue
* enable fp8
* warmup prefill; remove models where pageattn is not used; set block table to None since it is not used
* add warmup_decode
* warmup decode
* remove block_tables and prefill_cache_indices, which would lead to dynamic shapes
* fix comment
* missing gptj change...
* fix some issues
* remove torch.where to fix incorrect output in hpu graph model
* match the latest vllm_extension ops

---------

Signed-off-by: Wang, Yi A <[email protected]>
1 parent: 9a8d046 · commit: d62c941

File tree: 91 files changed (+8849, -11858 lines)


Dockerfile_gaudi

+1 -1

@@ -95,7 +95,7 @@ RUN cd server && \
     pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \
     BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
     pip install . --no-cache-dir
-
+RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git
 # Install benchmarker
 COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
 # Install router

backends/gaudi/server/text_generation_server/cli.py

+4 -7

@@ -16,15 +16,9 @@
 
 
 class Quantization(str, Enum):
-    bitsandbytes = "bitsandbytes"
-    bitsandbytes_nf4 = "bitsandbytes-nf4"
-    bitsandbytes_fp4 = "bitsandbytes-fp4"
     gptq = "gptq"
     awq = "awq"
-    eetq = "eetq"
-    exl2 = "exl2"
     fp8 = "fp8"
-    marlin = "marlin"
 
 
 class Dtype(str, Enum):

@@ -105,14 +99,17 @@ def serve(
         "bitsandbytes",
         "bitsandbytes-nf4",
         "bitsandbytes-fp4",
+        "gptq",
+        "awq",
+        "fp8",
     }:
         raise RuntimeError(
             "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
         )
 
     logger.info("CLI SHARDED = {} DTYPE = {}".format(sharded, dtype))
 
-    if sharded:
+    if sharded and os.getenv("ATTENTION", "default") not in {"paged"}:
         tgi_file = Path(__file__).resolve().parent / "tgi_service.py"
         num_shard = int(os.getenv("WORLD_SIZE", "1"))
         logger.info("CLI SHARDED = {}".format(num_shard))
(file name hidden)

@@ -1,43 +1,28 @@
-from text_generation_server.utils.import_utils import SYSTEM
-import os
+from .common import (
+    Seqlen,
+    HPUPagedAttentionMetadata,
+    trim_attn_metadata,
+    trim_seqlen_metadata,
+)
 
-from .common import Seqlen
+from .hpu import (
+    SUPPORTS_WINDOWING,
+    attention,
+    paged_attention,
+)
 
-if os.getenv("USE_FLASH_ATTENTION", "true").lower() == "false":
-    raise ImportError("`USE_FLASH_ATTENTION` is false.")
-if SYSTEM == "cuda":
-    from .cuda import (
-        attention,
-        paged_attention,
-        reshape_and_cache,
-        SUPPORTS_WINDOWING,
-        PREFILL_IN_KV_CACHE,
-    )
-elif SYSTEM == "rocm":
-    from .rocm import (
-        attention,
-        paged_attention,
-        reshape_and_cache,
-        PREFILL_IN_KV_CACHE,
-        SUPPORTS_WINDOWING,
-    )
-elif SYSTEM == "ipex":
-    from .ipex import (
-        attention,
-        paged_attention,
-        reshape_and_cache,
-        PREFILL_IN_KV_CACHE,
-        SUPPORTS_WINDOWING,
-    )
-else:
-    raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
 
 
+# KVCache needs `reshape_and_cache`, so ensure that it is defined already.
+from .kv_cache import KVCache, get_kv_scales
 
 __all__ = [
     "attention",
+    "get_kv_scales",
     "paged_attention",
-    "reshape_and_cache",
-    "PREFILL_IN_KV_CACHE",
     "SUPPORTS_WINDOWING",
+    "KVCache",
     "Seqlen",
+    "HPUPagedAttentionMetadata",
+    "trim_seqlen_metadata",
+    "trim_attn_metadata",
 ]
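A hedged usage sketch of the slimmed-down import surface: on the Gaudi backend there is no longer any SYSTEM / USE_FLASH_ATTENTION branching, so callers import everything unconditionally. The module path below is an assumption inferred from the package layout (the file name is hidden on this page), and running it requires the Gaudi server package and its HPU dependencies to be installed.

# Sketch only: assumes the attention package lives at this module path and
# that the Gaudi server package plus its HPU dependencies are installed.
from text_generation_server.layers.attention import (
    KVCache,
    Seqlen,
    HPUPagedAttentionMetadata,
    attention,
    paged_attention,
    get_kv_scales,
    trim_attn_metadata,
    trim_seqlen_metadata,
    SUPPORTS_WINDOWING,
)

# The former per-system branches (cuda/rocm/ipex) are gone, as are the
# reshape_and_cache and PREFILL_IN_KV_CACHE exports in __all__.
print(SUPPORTS_WINDOWING)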
(file name hidden)

@@ -1,72 +1,147 @@
 from dataclasses import dataclass
-from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.models.globals import ATTENTION
 import torch
-from typing import Optional
-
-
-if ATTENTION in {"flashinfer", "flashdecoding"}:
-
-    @dataclass
-    class Seqlen:
-        input_lengths: torch.Tensor
-        prefix_lengths: torch.Tensor
-        cu_seqlen_q: Optional[torch.Tensor]
-        cu_seqlen_k: Optional[torch.Tensor]
-        max_q: int
-        max_k: int
-
-        def __init__(
-            self,
-            input_lengths,
-            prefix_lengths,
-            cu_seqlen_q=None,
-            max_q=None,
-            max_k=None,
-        ):
-            self.input_lengths = input_lengths
-            self.prefix_lengths = prefix_lengths
-            device = self.input_lengths.device
-            shape = self.input_lengths.shape
-            if cu_seqlen_q is None:
-                cu_seqlen_q = torch.arange(
-                    shape[0] + 1,
-                    device=device,
-                    dtype=torch.int32,
-                )
-                max_q = 1
-            else:
-                assert max_q is not None
-                assert max_k is not None
-            cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
-
-            # cuda graphs don't like this and this is necessary to clamp within mistral
-            # Although FA2 might not want the clamping
-            # cu_seqlen_k[0] = 0
-            total = self.input_lengths + self.prefix_lengths
-            torch.cumsum(total, -1, out=cu_seqlen_k[1:])
-
-            self.cu_seqlen_q = cu_seqlen_q
-            self.cu_seqlen_k = cu_seqlen_k
-            self.max_q = max_q
-            self.max_k = max_k
-
-        def clamp(self, max):
-            # Flash decoding doesn't need to clamp
-            return self
-
-else:
-
-    @dataclass
-    class Seqlen:
-        input_lengths: torch.Tensor
-        prefix_lengths: torch.Tensor
-        cu_seqlen_q: torch.Tensor
-        max_q: int
-        max_k: int
-
-        def clamp(self, max):
-            if SYSTEM == "rocm":
-                return self
-            raise NotImplementedError("Not implemented seqlen for paged")
-            return Seqlen(torch.clamp(self.input_lengths, max=max))
+from typing import Optional, List, Dict
+import collections
+
+_TYPE_CACHE = {}
+
+
+@dataclass
+class HPUPagedAttentionMetadata:
+    """Metadata for PagedAttention."""
+
+    block_list: Optional[torch.Tensor]
+    block_mapping: Optional[torch.Tensor]
+    block_usage: Optional[torch.Tensor]
+    block_scales: Optional[torch.Tensor]
+    block_groups: Optional[torch.Tensor]
+    attn_bias: Optional[torch.Tensor]
+
+
+def subtuple(
+    obj: object,
+    typename: str,
+    to_copy: List[str],
+    to_override: Optional[Dict[str, object]] = None,
+):
+    if obj is None:
+        return None
+    if to_override is None:
+        to_override = {}
+    fields = set(to_copy) | set(to_override.keys())
+    if isinstance(obj, dict):
+        values = {key: obj[key] for key in fields if key in obj}
+    else:
+        values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
+    if typename not in _TYPE_CACHE:
+        _TYPE_CACHE[typename] = collections.namedtuple(typename, " ".join(fields))
+    return _TYPE_CACHE[typename](**values)
+
+
+def trim_attn_metadata(metadata: HPUPagedAttentionMetadata) -> object:
+    # NOTE(kzawora): To anyone working on this in the future:
+    # Trimming metadata is required when using HPUGraphs.
+    # Attention metadata is going to be hashed by PT bridge, and
+    # appropriate HPUGraphs will be matched based on all inputs' hash.
+
+    # Before you put more keys in here, make sure you know their
+    # value type and make sure you know how it's going to be hashed.
+    # You can find that information in input_hash function
+    # in habana_frameworks/torch/hpu/graphs.py. You can also hash
+    # it manually with torch.hpu.graphs.input_hash(attention_metadata)
+
+    # If you use primitive types here - they will get hashed based
+    # on their value. You *will* get lots of excessive graph captures
+    # (and an OOM eventually) if you decide to put something like
+    # seq_len int here.
+    # If you absolutely need a scalar, put it in a tensor. Tensors
+    # get hashed using their metadata, not their values:
+    # input_hash(torch.tensor(123)) == input_hash(torch.tensor(321))
+    # input_hash(123) != input_hash(321)
+    # input_hash("abc") != input_hash("cba")
+    attention_metadata = subtuple(
+        metadata,
+        "TrimmedAttentionMetadata",
+        [
+            "block_list",
+            "block_mapping",
+            "block_usage",
+            "block_scales",
+            "block_groups",
+            "attn_bias",
+        ],
+    )
+    return attention_metadata
+
+
+@dataclass
+class Seqlen:
+    input_lengths: torch.Tensor
+    cache_lengths: torch.Tensor
+    cu_seqlen_q: Optional[torch.Tensor]
+    cu_seqlen_k: Optional[torch.Tensor]
+
+    def __init__(
+        self,
+        input_lengths,
+        cache_lengths,
+        cu_seqlen_q=None,
+    ):
+        self.input_lengths = input_lengths
+        self.cache_lengths = cache_lengths
+        device = self.input_lengths.device
+        shape = self.input_lengths.shape
+        if cu_seqlen_q is None:
+            cu_seqlen_q = torch.arange(
+                shape[0] + 1,
+                device=device,
+                dtype=torch.int32,
+            )
+        cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
+
+        # cuda graphs don't like this and this is necessary to clamp within mistral
+        # Although FA2 might not want the clamping
+        # cu_seqlen_k[0] = 0
+        total = self.input_lengths + self.cache_lengths
+        torch.cumsum(total, -1, out=cu_seqlen_k[1:])
+
+        self.cu_seqlen_q = cu_seqlen_q
+        self.cu_seqlen_k = cu_seqlen_k
+
+    def clamp(self, max):
+        # Flash decoding doesn't need to clamp
+        return self
+
+
+def trim_seqlen_metadata(metadata: Seqlen) -> object:
+    # NOTE(kzawora): To anyone working on this in the future:
+    # Trimming metadata is required when using HPUGraphs.
+    # Attention metadata is going to be hashed by PT bridge, and
+    # appropriate HPUGraphs will be matched based on all inputs' hash.
+
+    # Before you put more keys in here, make sure you know their
+    # value type and make sure you know how it's going to be hashed.
+    # You can find that information in input_hash function
+    # in habana_frameworks/torch/hpu/graphs.py. You can also hash
+    # it manually with torch.hpu.graphs.input_hash(attention_metadata)
+
+    # If you use primitive types here - they will get hashed based
+    # on their value. You *will* get lots of excessive graph captures
+    # (and an OOM eventually) if you decide to put something like
+    # seq_len int here.
+    # If you absolutely need a scalar, put it in a tensor. Tensors
+    # get hashed using their metadata, not their values:
+    # input_hash(torch.tensor(123)) == input_hash(torch.tensor(321))
+    # input_hash(123) != input_hash(321)
+    # input_hash("abc") != input_hash("cba")
+    attention_metadata = subtuple(
+        metadata,
+        "TrimmedSeqlen",
+        [
+            "input_lengths",
+            "cache_lengths",
+            "cu_seqlen_q",
+            "cu_seqlen_k",
+        ],
+    )
+    return attention_metadata
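To make the trimming pattern concrete, here is a minimal CPU-only sketch (no HPU or text-generation-server install needed) of what `subtuple` and `trim_seqlen_metadata` do: project the `Seqlen` object onto a namedtuple type cached per name, keeping only tensor fields, so the hash used for HPUGraph matching sees tensor metadata rather than Python scalars. The `Seqlen` stand-in below is a simplified copy of the class in the hunk above.

# Self-contained illustration of the trimming pattern (plain CPU tensors).
import collections
from typing import Dict, List, Optional

import torch

_TYPE_CACHE: Dict[str, type] = {}


def subtuple(obj, typename: str, to_copy: List[str],
             to_override: Optional[Dict[str, object]] = None):
    # Copy the named fields into a namedtuple whose type is cached per
    # `typename`, so repeated calls produce structurally identical objects.
    if obj is None:
        return None
    to_override = to_override or {}
    fields = set(to_copy) | set(to_override.keys())
    values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
    if typename not in _TYPE_CACHE:
        _TYPE_CACHE[typename] = collections.namedtuple(typename, " ".join(fields))
    return _TYPE_CACHE[typename](**values)


class Seqlen:
    # Simplified stand-in: cu_seqlen_k is the cumulative sum of
    # input_lengths + cache_lengths with a leading zero.
    def __init__(self, input_lengths: torch.Tensor, cache_lengths: torch.Tensor):
        self.input_lengths = input_lengths
        self.cache_lengths = cache_lengths
        n = input_lengths.shape[0]
        self.cu_seqlen_q = torch.arange(n + 1, dtype=torch.int32)
        self.cu_seqlen_k = torch.zeros(n + 1, dtype=torch.int32)
        torch.cumsum(input_lengths + cache_lengths, -1, out=self.cu_seqlen_k[1:])


seqlen = Seqlen(torch.tensor([3, 5], dtype=torch.int32),
                torch.tensor([0, 2], dtype=torch.int32))
trimmed = subtuple(seqlen, "TrimmedSeqlen",
                   ["input_lengths", "cache_lengths", "cu_seqlen_q", "cu_seqlen_k"])
print(trimmed.cu_seqlen_k.tolist())  # [0, 3, 10]
print(type(trimmed).__name__)        # TrimmedSeqlen

Calling `subtuple` twice with the same typename reuses the cached namedtuple class, so successive decode steps hand the graph runtime metadata objects of identical structure.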
