
Commit 0399db8

Merge pull request #2 from DeepAuto-AI/feat/fix-short-ctx
fix radix attention
2 parents a917567 + 00e3aa7 commit 0399db8

25 files changed (+789, -534 lines)

README.md

Lines changed: 4 additions & 4 deletions
@@ -24,7 +24,7 @@ export CHUNK_PREFILL=16384;
 # Any RoPE based attention models are supported in theoritically.
 # However currently we are supports `llama.py` models. (Llama Family)
 export MODEL="hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4";
-# You can set upper limit of maximum extended context window.
+# You can set upper limit of maximum extended context window.
 # Training-free and unlimited.
 export EXTENDED_CONTEXT_LEN=196608;
 # You can change this flag into 1, if you want test online cache update. (exprimental)
@@ -42,13 +42,13 @@ python -m sglang.launch_server \
     --context-length $EXTENDED_CONTEXT_LEN \
     --max-total-tokens $EXTENDED_CONTEXT_LEN \
     --enable-hip-attention \
-    # You can turn off this flag to disable offloading.
+    # You can turn off this flag to disable offloading.
     # Offloading may have difference in decoding result.
     --enable-hip-offload \
-    # For on-gpu offloading cache in masking kernel,
+    # For on-gpu offloading cache in masking kernel,
     # allocate size of cache in num of tokens. This is shared by whole batch.
     --hip-max-mask-cache-token-size 32000 \
-    # For on-gpu offloading cache in block sparse attention kernel,
+    # For on-gpu offloading cache in block sparse attention kernel,
     # allocate size of cache in num of tokens. This is shared by whole batch.
     --hip-max-sa-cache-token-size 10000;
 ```
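The hunk above only touches the launch command. As a quick sanity check after the server comes up, you can hit its OpenAI-compatible endpoint; the sketch below assumes sglang's default bind address of http://127.0.0.1:30000 and reuses the model id from `$MODEL` (neither the port nor any client code appears in this diff):

```python
# Hypothetical client-side check for the server launched in the README snippet.
# Assumes the sglang default address http://127.0.0.1:30000; adjust if --host or
# --port were changed. Not part of this commit.
import requests

resp = requests.post(
    "http://127.0.0.1:30000/v1/chat/completions",
    json={
        "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4",
        "messages": [
            {"role": "user", "content": "Summarize HiP attention in one sentence."}
        ],
        "max_tokens": 64,
    },
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```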

python/sglang/srt/hf_transformers_utils.py

Lines changed: 1 addition & 0 deletions
@@ -97,6 +97,7 @@ def get_config(
     "max_position_embeddings",
 ]
 
+
 def get_context_length(config):
     """Get the context length of a model from a huggingface model configs."""
     text_config = config

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
 from .hip_cuda_graph_runner import HiPCudaGraphRunner
+
 # from .hip_offload_kv_pool_mha import MHATokenToHiPOffloadKVPool
 from .hip_radix_attention import HiPRadixAttentionBackend

Lines changed: 86 additions & 78 deletions
@@ -1,10 +1,9 @@
-from dataclasses import dataclass, field, InitVar
-from typing import List, Optional, Union
 import warnings
+from dataclasses import InitVar, dataclass, field
+from typing import List, Optional, Union
 
 from hip.models.hip_attention.gen3.attention_metadata import ScanStage
 
-
 _DEFAULT_STAGES = [
     ScanStage(
         stage_block_size_q=64,
@@ -35,7 +34,7 @@ class HiPAttentionPerLayerConfig:
     second_stage_k: int = 2048
     sliding_window_size: int = 1024
     sink_token_size: int = 256
-    sa_extend_backend: str = 'streaming'
+    sa_extend_backend: str = "streaming"
     scan_extend_backend: Optional[str] = None
     stages: list[ScanStage] = field(default_factory=lambda: _DEFAULT_STAGES)
 
@@ -44,47 +43,50 @@
     def __post_init__(self, parsed_json: dict | None):
         super().__init__()
         if parsed_json is not None:
-            if 'second_stage_k' in parsed_json:
-                self.second_stage_k = parsed_json['second_stage_k']
-                parsed_json.pop('second_stage_k')
-            if 'sliding_window_size' in parsed_json:
-                self.sliding_window_size = parsed_json['sliding_window_size']
-                parsed_json.pop('sliding_window_size')
-            if 'sink_token_size' in parsed_json:
-                self.sink_token_size = parsed_json['sink_token_size']
-                parsed_json.pop('sink_token_size')
-            if 'sa_extend_backend' in parsed_json:
-                self.sa_extend_backend = parsed_json['sa_extend_backend']
-                parsed_json.pop('sa_extend_backend')
-            if 'scan_extend_backend' in parsed_json:
-                self.scan_extend_backend = parsed_json['scan_extend_backend']
-                parsed_json.pop('scan_extend_backend')
-            if 'stages' in parsed_json:
-                self.stages = [
-                    ScanStage(**stage)
-                    for stage in parsed_json['stages']
-                ]
-                parsed_json.pop('stages')
+            if "second_stage_k" in parsed_json:
+                self.second_stage_k = parsed_json["second_stage_k"]
+                parsed_json.pop("second_stage_k")
+            if "sliding_window_size" in parsed_json:
+                self.sliding_window_size = parsed_json["sliding_window_size"]
+                parsed_json.pop("sliding_window_size")
+            if "sink_token_size" in parsed_json:
+                self.sink_token_size = parsed_json["sink_token_size"]
+                parsed_json.pop("sink_token_size")
+            if "sa_extend_backend" in parsed_json:
+                self.sa_extend_backend = parsed_json["sa_extend_backend"]
+                parsed_json.pop("sa_extend_backend")
+            if "scan_extend_backend" in parsed_json:
+                self.scan_extend_backend = parsed_json["scan_extend_backend"]
+                parsed_json.pop("scan_extend_backend")
+            if "stages" in parsed_json:
+                self.stages = [ScanStage(**stage) for stage in parsed_json["stages"]]
+                parsed_json.pop("stages")
             if parsed_json:
-                raise ValueError(f'Unknown keys in json: {parsed_json.keys()}')
+                raise ValueError(f"Unknown keys in json: {parsed_json.keys()}")
 
 
 @dataclass
 class HiPAttentionConfig:
     dense_layers: list[int] = field(default_factory=lambda: [0, 1, 2])
     block_sparse_block_size_q: int = 64
     metadata_cache_max_batch_size: int = 32
-    mask_refresh_interval: Union[int, List[int]] = field(default_factory=lambda: [32, 16, 8])
+    mask_refresh_interval: Union[int, List[int]] = field(
+        default_factory=lambda: [32, 16, 8]
+    )
     using_extend: bool = True
-    layers: list[HiPAttentionPerLayerConfig] = field(default_factory=lambda: [
-        HiPAttentionPerLayerConfig(parsed_json={"second_stage_k": 4096, "sliding_window_size": 1024, "sink_token_size": 256}),
-        HiPAttentionPerLayerConfig(),
-    ])
-    prefill_layers: list[HiPAttentionPerLayerConfig] = field(default_factory=lambda: [
-        HiPAttentionPerLayerConfig(parsed_json={"second_stage_k": 4096, "sliding_window_size": 1024, "sink_token_size": 256}),
-        HiPAttentionPerLayerConfig(),
-    ])
-
+    layers: list[HiPAttentionPerLayerConfig] = field(
+        default_factory=lambda: [
+            HiPAttentionPerLayerConfig(
+                parsed_json={
+                    "second_stage_k": 4096,
+                    "sliding_window_size": 1024,
+                    "sink_token_size": 256,
+                }
+            ),
+            HiPAttentionPerLayerConfig(),
+        ]
+    )
+
     # deprecated
     apply_v_dot: bool = False
     prefill_always_dense: bool = False
@@ -96,58 +98,64 @@ class HiPAttentionConfig:
 
     def __post_init__(self, parsed_json: dict | None):
         super().__init__()
-
+
         if parsed_json is not None:
-            if 'apply_v_dot' in parsed_json:
-                self.apply_v_dot = parsed_json['apply_v_dot']
-                parsed_json.pop('apply_v_dot')
-            if 'dense_layers' in parsed_json:
-                self.dense_layers = parsed_json['dense_layers']
-                parsed_json.pop('dense_layers')
-            if 'prefill_always_dense' in parsed_json:
-                self.prefill_always_dense = parsed_json['prefill_always_dense']
-                parsed_json.pop('prefill_always_dense')
-            if 'decode_always_dense' in parsed_json:
-                self.decode_always_dense = parsed_json['decode_always_dense']
-                parsed_json.pop('decode_always_dense')
-            if 'force_dense' in parsed_json:
-                self.force_dense = parsed_json['force_dense']
-                parsed_json.pop('force_dense')
-            if 'prefill_dense_threshold' in parsed_json:
-                self.prefill_dense_threshold = parsed_json['prefill_dense_threshold']
-                parsed_json.pop('prefill_dense_threshold')
-            if 'block_sparse_block_size_q' in parsed_json:
-                self.block_sparse_block_size_q = parsed_json['block_sparse_block_size_q']
-                parsed_json.pop('block_sparse_block_size_q')
-            if 'metadata_cache_max_batch_size' in parsed_json:
-                self.metadata_cache_max_batch_size = parsed_json['metadata_cache_max_batch_size']
-                parsed_json.pop('metadata_cache_max_batch_size')
-            if 'mask_refresh_interval' in parsed_json:
-                assert isinstance(parsed_json['mask_refresh_interval'], (int, list))
-                self.mask_refresh_interval = parsed_json['mask_refresh_interval']
-                parsed_json.pop('mask_refresh_interval')
-            if 'using_extend' in parsed_json:
-                self.using_extend = parsed_json['using_extend']
-                parsed_json.pop('using_extend')
-            if 'layers' in parsed_json:
+            if "apply_v_dot" in parsed_json:
+                self.apply_v_dot = parsed_json["apply_v_dot"]
+                parsed_json.pop("apply_v_dot")
+            if "dense_layers" in parsed_json:
+                self.dense_layers = parsed_json["dense_layers"]
+                parsed_json.pop("dense_layers")
+            if "prefill_always_dense" in parsed_json:
+                self.prefill_always_dense = parsed_json["prefill_always_dense"]
+                parsed_json.pop("prefill_always_dense")
+            if "decode_always_dense" in parsed_json:
+                self.decode_always_dense = parsed_json["decode_always_dense"]
+                parsed_json.pop("decode_always_dense")
+            if "force_dense" in parsed_json:
+                self.force_dense = parsed_json["force_dense"]
+                parsed_json.pop("force_dense")
+            if "prefill_dense_threshold" in parsed_json:
+                self.prefill_dense_threshold = parsed_json["prefill_dense_threshold"]
+                parsed_json.pop("prefill_dense_threshold")
+            if "block_sparse_block_size_q" in parsed_json:
+                self.block_sparse_block_size_q = parsed_json[
+                    "block_sparse_block_size_q"
+                ]
+                parsed_json.pop("block_sparse_block_size_q")
+            if "metadata_cache_max_batch_size" in parsed_json:
+                self.metadata_cache_max_batch_size = parsed_json[
+                    "metadata_cache_max_batch_size"
+                ]
+                parsed_json.pop("metadata_cache_max_batch_size")
+            if "mask_refresh_interval" in parsed_json:
+                assert isinstance(parsed_json["mask_refresh_interval"], (int, list))
+                self.mask_refresh_interval = parsed_json["mask_refresh_interval"]
+                parsed_json.pop("mask_refresh_interval")
+            if "using_extend" in parsed_json:
+                self.using_extend = parsed_json["using_extend"]
+                parsed_json.pop("using_extend")
+            if "layers" in parsed_json:
                 self.layers = [
                     HiPAttentionPerLayerConfig(parsed_json=layer)
-                    for layer in parsed_json['layers']
+                    for layer in parsed_json["layers"]
                 ]
                 self.prefill_layers = self.layers
-                parsed_json.pop('layers')
-            if 'prefill_layers' in parsed_json:
+                parsed_json.pop("layers")
+            if "prefill_layers" in parsed_json:
                 self.prefill_layers = [
                     HiPAttentionPerLayerConfig(parsed_json=layer)
-                    for layer in parsed_json['prefill_layers']
+                    for layer in parsed_json["prefill_layers"]
                 ]
-                parsed_json.pop('prefill_layers')
+                parsed_json.pop("prefill_layers")
             if parsed_json:
-                raise Exception(f'Unknown keys in json: {parsed_json.keys()}')
-
+                raise ValueError(f"Unknown keys in json: {parsed_json.keys()}")
+
         num_stages = len(self.layers[0].stages)
         for layer_config in self.layers:
            assert num_stages == len(layer_config.stages)
-
+
         if isinstance(self.mask_refresh_interval, int):
-            self.mask_refresh_interval = [self.mask_refresh_interval, ] * num_stages
+            self.mask_refresh_interval = [
+                self.mask_refresh_interval,
+            ] * num_stages
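The diff above is mostly a Black-style reformat of the config dataclasses, plus a switch from a bare `Exception` to `ValueError` for unknown keys. For reference, a minimal sketch of the JSON shape that `HiPAttentionConfig.__post_init__` accepts after this change; the values are illustrative examples rather than defaults from the commit, and since the module's import path is not shown in this view the constructor call is left commented out:

```python
# Illustrative JSON layout for HiPAttentionConfig(parsed_json=...), based on the
# keys handled in __post_init__ above. Values are made-up examples.
import json

raw = json.dumps(
    {
        "dense_layers": [0, 1, 2],
        "mask_refresh_interval": [32, 16, 8],  # int or list, per the assert above
        "using_extend": True,
        "layers": [
            {
                "second_stage_k": 4096,
                "sliding_window_size": 1024,
                "sink_token_size": 256,
                "sa_extend_backend": "streaming",
            },
            {},  # falls back to HiPAttentionPerLayerConfig defaults
        ],
    }
)

# config = HiPAttentionConfig(parsed_json=json.loads(raw))
# Any key not consumed above now raises ValueError instead of a bare Exception.
```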

python/sglang/srt/layers/attention/hip_attention/hip_cuda_graph_runner.py

Lines changed: 22 additions & 15 deletions
@@ -8,19 +8,18 @@
 
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import graph_capture
-from sglang.srt.layers.torchao_utils import save_gemlite_cache
-
 from sglang.srt.layers.logits_processor import (
     LogitsMetadata,
     LogitsProcessor,
     LogitsProcessorOutput,
 )
+from sglang.srt.layers.torchao_utils import save_gemlite_cache
+from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner, patch_model
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner, patch_model
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.hip_model_runner import HiPModelRunner
@@ -41,13 +40,18 @@ def can_run(self, forward_batch: ForwardBatch):
                 forward_batch.global_num_tokens
             )
             is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
-                (min_num_tokens == max_num_tokens and (max_num_tokens, use_cached_mask, num_stage_cached) in self.graphs)
+                (
+                    min_num_tokens == max_num_tokens
+                    and (max_num_tokens, use_cached_mask, num_stage_cached)
+                    in self.graphs
+                )
                 if self.disable_padding
                 else max_num_tokens <= self.max_bs
             )
         else:
             is_bs_supported = (
-                (forward_batch.batch_size, use_cached_mask, num_stage_cached) in self.graphs
+                (forward_batch.batch_size, use_cached_mask, num_stage_cached)
+                in self.graphs
                 if self.disable_padding
                 else forward_batch.batch_size <= self.max_bs
             )
@@ -70,7 +74,7 @@ def capture(self):
             cache_configs = [(True, None)]
             for i_stage in range(num_stages):
                 cache_configs.append((False, i_stage))
-
+
             self.stream = graph_capture_context.stream
             capture_bs = (
                 tqdm.tqdm(self.capture_bs)
@@ -89,8 +93,7 @@ def capture(self):
                         graph,
                         output_buffers,
                     ) = self.capture_one_batch_size(
-                        bs, forward,
-                        use_cached_mask, num_cached_stages
+                        bs, forward, use_cached_mask, num_cached_stages
                     )
                     graph_handle = (bs, use_cached_mask, num_cached_stages)
                     self.graphs[graph_handle] = graph
@@ -99,16 +102,16 @@
             save_gemlite_cache()
 
    def capture_one_batch_size(
-        self,
-        bs: int,
-        forward: Callable,
+        self,
+        bs: int,
+        forward: Callable,
         hip_use_cached_mask: bool = False,
         hip_num_cached_stages: int = 0,
     ):
         graph = torch.cuda.CUDAGraph()
         stream = self.stream
         num_tokens = bs * self.num_tokens_per_bs
-
+
         # Common inputs
         input_ids = self.input_ids[:num_tokens]
         req_pool_indices = self.req_pool_indices[:bs]
@@ -218,15 +221,15 @@ def replay(self, forward_batch: ForwardBatch):
         self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
         self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc)
         self.positions[:raw_num_token].copy_(forward_batch.positions)
-
+
         if self.is_encoder_decoder:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
 
         if hasattr(forward_batch.spec_info, "hidden_states"):
             self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states
-
+
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
             bs,
@@ -239,7 +242,11 @@ def replay(self, forward_batch: ForwardBatch):
         )
 
         # Replay
-        key = (bs, forward_batch.hip_use_cached_mask, forward_batch.hip_metadata_cached_stage)
+        key = (
+            bs,
+            forward_batch.hip_use_cached_mask,
+            forward_batch.hip_metadata_cached_stage,
+        )
         self.graphs[key].replay()
         next_token_logits, hidden_states = self.output_buffers[key]
 