fix radix attention #2

Merged 4 commits on Jan 24, 2025
8 changes: 4 additions & 4 deletions README.md
@@ -24,7 +24,7 @@ export CHUNK_PREFILL=16384;
# Any RoPE-based attention model is theoretically supported.
# However, we currently support only `llama.py` models (the Llama family).
export MODEL="hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4";
# You can set upper limit of maximum extended context window.
# You can set upper limit of maximum extended context window.
# Training-free and unlimited.
export EXTENDED_CONTEXT_LEN=196608;
# You can set this flag to 1 if you want to test online cache updates. (experimental)
@@ -42,13 +42,13 @@ python -m sglang.launch_server \
--context-length $EXTENDED_CONTEXT_LEN \
--max-total-tokens $EXTENDED_CONTEXT_LEN \
--enable-hip-attention \
# You can turn off this flag to disable offloading.
# You can turn off this flag to disable offloading.
# Offloading may cause differences in decoding results.
--enable-hip-offload \
# For on-gpu offloading cache in masking kernel,
# For on-gpu offloading cache in masking kernel,
# allocate the cache size in number of tokens. This is shared by the whole batch.
--hip-max-mask-cache-token-size 32000 \
# For on-gpu offloading cache in block sparse attention kernel,
# For on-gpu offloading cache in block sparse attention kernel,
# allocate the cache size in number of tokens. This is shared by the whole batch.
--hip-max-sa-cache-token-size 10000;
```
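Once the server is up, the extended context window can be exercised through the HTTP API that `sglang.launch_server` exposes. The snippet below is a minimal sketch only, assuming the default port 30000 and sglang's native `/generate` endpoint; adapt the URL, prompt, and sampling parameters to your deployment.

```python
import requests

# Minimal sketch: send a long prompt to the locally launched sglang server.
# Assumes the default port (30000) and sglang's native /generate endpoint.
response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "Summarize the following document:\n" + "lorem ipsum " * 2000,
        "sampling_params": {"max_new_tokens": 256, "temperature": 0.0},
    },
    timeout=600,
)
print(response.json()["text"])
```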
1 change: 1 addition & 0 deletions python/sglang/srt/hf_transformers_utils.py
@@ -97,6 +97,7 @@ def get_config(
"max_position_embeddings",
]


def get_context_length(config):
"""Get the context length of a model from a huggingface model configs."""
text_config = config
@@ -1,3 +1,4 @@
from .hip_cuda_graph_runner import HiPCudaGraphRunner

# from .hip_offload_kv_pool_mha import MHATokenToHiPOffloadKVPool
from .hip_radix_attention import HiPRadixAttentionBackend
164 changes: 86 additions & 78 deletions python/sglang/srt/layers/attention/hip_attention/hip_config.py
@@ -1,10 +1,9 @@
from dataclasses import dataclass, field, InitVar
from typing import List, Optional, Union
import warnings
from dataclasses import InitVar, dataclass, field
from typing import List, Optional, Union

from hip.models.hip_attention.gen3.attention_metadata import ScanStage


_DEFAULT_STAGES = [
ScanStage(
stage_block_size_q=64,
@@ -35,7 +34,7 @@ class HiPAttentionPerLayerConfig:
second_stage_k: int = 2048
sliding_window_size: int = 1024
sink_token_size: int = 256
sa_extend_backend: str = 'streaming'
sa_extend_backend: str = "streaming"
scan_extend_backend: Optional[str] = None
stages: list[ScanStage] = field(default_factory=lambda: _DEFAULT_STAGES)

@@ -44,47 +43,50 @@ class HiPAttentionPerLayerConfig:
def __post_init__(self, parsed_json: dict | None):
super().__init__()
if parsed_json is not None:
if 'second_stage_k' in parsed_json:
self.second_stage_k = parsed_json['second_stage_k']
parsed_json.pop('second_stage_k')
if 'sliding_window_size' in parsed_json:
self.sliding_window_size = parsed_json['sliding_window_size']
parsed_json.pop('sliding_window_size')
if 'sink_token_size' in parsed_json:
self.sink_token_size = parsed_json['sink_token_size']
parsed_json.pop('sink_token_size')
if 'sa_extend_backend' in parsed_json:
self.sa_extend_backend = parsed_json['sa_extend_backend']
parsed_json.pop('sa_extend_backend')
if 'scan_extend_backend' in parsed_json:
self.scan_extend_backend = parsed_json['scan_extend_backend']
parsed_json.pop('scan_extend_backend')
if 'stages' in parsed_json:
self.stages = [
ScanStage(**stage)
for stage in parsed_json['stages']
]
parsed_json.pop('stages')
if "second_stage_k" in parsed_json:
self.second_stage_k = parsed_json["second_stage_k"]
parsed_json.pop("second_stage_k")
if "sliding_window_size" in parsed_json:
self.sliding_window_size = parsed_json["sliding_window_size"]
parsed_json.pop("sliding_window_size")
if "sink_token_size" in parsed_json:
self.sink_token_size = parsed_json["sink_token_size"]
parsed_json.pop("sink_token_size")
if "sa_extend_backend" in parsed_json:
self.sa_extend_backend = parsed_json["sa_extend_backend"]
parsed_json.pop("sa_extend_backend")
if "scan_extend_backend" in parsed_json:
self.scan_extend_backend = parsed_json["scan_extend_backend"]
parsed_json.pop("scan_extend_backend")
if "stages" in parsed_json:
self.stages = [ScanStage(**stage) for stage in parsed_json["stages"]]
parsed_json.pop("stages")
if parsed_json:
raise ValueError(f'Unknown keys in json: {parsed_json.keys()}')
raise ValueError(f"Unknown keys in json: {parsed_json.keys()}")


@dataclass
class HiPAttentionConfig:
dense_layers: list[int] = field(default_factory=lambda: [0, 1, 2])
block_sparse_block_size_q: int = 64
metadata_cache_max_batch_size: int = 32
mask_refresh_interval: Union[int, List[int]] = field(default_factory=lambda: [32, 16, 8])
mask_refresh_interval: Union[int, List[int]] = field(
default_factory=lambda: [32, 16, 8]
)
using_extend: bool = True
layers: list[HiPAttentionPerLayerConfig] = field(default_factory=lambda: [
HiPAttentionPerLayerConfig(parsed_json={"second_stage_k": 4096, "sliding_window_size": 1024, "sink_token_size": 256}),
HiPAttentionPerLayerConfig(),
])
prefill_layers: list[HiPAttentionPerLayerConfig] = field(default_factory=lambda: [
HiPAttentionPerLayerConfig(parsed_json={"second_stage_k": 4096, "sliding_window_size": 1024, "sink_token_size": 256}),
HiPAttentionPerLayerConfig(),
])

layers: list[HiPAttentionPerLayerConfig] = field(
default_factory=lambda: [
HiPAttentionPerLayerConfig(
parsed_json={
"second_stage_k": 4096,
"sliding_window_size": 1024,
"sink_token_size": 256,
}
),
HiPAttentionPerLayerConfig(),
]
)

# deprecated
apply_v_dot: bool = False
prefill_always_dense: bool = False
@@ -96,58 +98,64 @@ class HiPAttentionConfig:

def __post_init__(self, parsed_json: dict | None):
super().__init__()

if parsed_json is not None:
if 'apply_v_dot' in parsed_json:
self.apply_v_dot = parsed_json['apply_v_dot']
parsed_json.pop('apply_v_dot')
if 'dense_layers' in parsed_json:
self.dense_layers = parsed_json['dense_layers']
parsed_json.pop('dense_layers')
if 'prefill_always_dense' in parsed_json:
self.prefill_always_dense = parsed_json['prefill_always_dense']
parsed_json.pop('prefill_always_dense')
if 'decode_always_dense' in parsed_json:
self.decode_always_dense = parsed_json['decode_always_dense']
parsed_json.pop('decode_always_dense')
if 'force_dense' in parsed_json:
self.force_dense = parsed_json['force_dense']
parsed_json.pop('force_dense')
if 'prefill_dense_threshold' in parsed_json:
self.prefill_dense_threshold = parsed_json['prefill_dense_threshold']
parsed_json.pop('prefill_dense_threshold')
if 'block_sparse_block_size_q' in parsed_json:
self.block_sparse_block_size_q = parsed_json['block_sparse_block_size_q']
parsed_json.pop('block_sparse_block_size_q')
if 'metadata_cache_max_batch_size' in parsed_json:
self.metadata_cache_max_batch_size = parsed_json['metadata_cache_max_batch_size']
parsed_json.pop('metadata_cache_max_batch_size')
if 'mask_refresh_interval' in parsed_json:
assert isinstance(parsed_json['mask_refresh_interval'], (int, list))
self.mask_refresh_interval = parsed_json['mask_refresh_interval']
parsed_json.pop('mask_refresh_interval')
if 'using_extend' in parsed_json:
self.using_extend = parsed_json['using_extend']
parsed_json.pop('using_extend')
if 'layers' in parsed_json:
if "apply_v_dot" in parsed_json:
self.apply_v_dot = parsed_json["apply_v_dot"]
parsed_json.pop("apply_v_dot")
if "dense_layers" in parsed_json:
self.dense_layers = parsed_json["dense_layers"]
parsed_json.pop("dense_layers")
if "prefill_always_dense" in parsed_json:
self.prefill_always_dense = parsed_json["prefill_always_dense"]
parsed_json.pop("prefill_always_dense")
if "decode_always_dense" in parsed_json:
self.decode_always_dense = parsed_json["decode_always_dense"]
parsed_json.pop("decode_always_dense")
if "force_dense" in parsed_json:
self.force_dense = parsed_json["force_dense"]
parsed_json.pop("force_dense")
if "prefill_dense_threshold" in parsed_json:
self.prefill_dense_threshold = parsed_json["prefill_dense_threshold"]
parsed_json.pop("prefill_dense_threshold")
if "block_sparse_block_size_q" in parsed_json:
self.block_sparse_block_size_q = parsed_json[
"block_sparse_block_size_q"
]
parsed_json.pop("block_sparse_block_size_q")
if "metadata_cache_max_batch_size" in parsed_json:
self.metadata_cache_max_batch_size = parsed_json[
"metadata_cache_max_batch_size"
]
parsed_json.pop("metadata_cache_max_batch_size")
if "mask_refresh_interval" in parsed_json:
assert isinstance(parsed_json["mask_refresh_interval"], (int, list))
self.mask_refresh_interval = parsed_json["mask_refresh_interval"]
parsed_json.pop("mask_refresh_interval")
if "using_extend" in parsed_json:
self.using_extend = parsed_json["using_extend"]
parsed_json.pop("using_extend")
if "layers" in parsed_json:
self.layers = [
HiPAttentionPerLayerConfig(parsed_json=layer)
for layer in parsed_json['layers']
for layer in parsed_json["layers"]
]
self.prefill_layers = self.layers
parsed_json.pop('layers')
if 'prefill_layers' in parsed_json:
parsed_json.pop("layers")
if "prefill_layers" in parsed_json:
self.prefill_layers = [
HiPAttentionPerLayerConfig(parsed_json=layer)
for layer in parsed_json['prefill_layers']
for layer in parsed_json["prefill_layers"]
]
parsed_json.pop('prefill_layers')
parsed_json.pop("prefill_layers")
if parsed_json:
raise Exception(f'Unknown keys in json: {parsed_json.keys()}')
raise ValueError(f"Unknown keys in json: {parsed_json.keys()}")

num_stages = len(self.layers[0].stages)
for layer_config in self.layers:
assert num_stages == len(layer_config.stages)

if isinstance(self.mask_refresh_interval, int):
self.mask_refresh_interval = [self.mask_refresh_interval, ] * num_stages
self.mask_refresh_interval = [
self.mask_refresh_interval,
] * num_stages
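The JSON fed to these `__post_init__` hooks mirrors the attribute names handled above. Below is a minimal sketch of loading such a config, assuming this repository's `sglang` package is importable; the key names come from the parsing code in this diff, while the concrete values are purely illustrative.

```python
import json

from sglang.srt.layers.attention.hip_attention.hip_config import HiPAttentionConfig

# Illustrative config; key names mirror those consumed by __post_init__ above.
raw_json = """
{
    "dense_layers": [0, 1, 2],
    "mask_refresh_interval": 16,
    "layers": [
        {"second_stage_k": 4096, "sliding_window_size": 1024, "sink_token_size": 256},
        {"second_stage_k": 2048, "sliding_window_size": 1024, "sink_token_size": 256}
    ]
}
"""

config = HiPAttentionConfig(parsed_json=json.loads(raw_json))
print(config.dense_layers)                 # [0, 1, 2]
print(config.mask_refresh_interval)        # int is expanded to one entry per scan stage
print([layer.second_stage_k for layer in config.layers])  # [4096, 2048]
```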
@@ -8,19 +8,18 @@

from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.distributed.parallel_state import graph_capture
from sglang.srt.layers.torchao_utils import save_gemlite_cache

from sglang.srt.layers.logits_processor import (
LogitsMetadata,
LogitsProcessor,
LogitsProcessorOutput,
)
from sglang.srt.layers.torchao_utils import save_gemlite_cache
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner, patch_model
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
ForwardBatch,
ForwardMode,
)
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner, patch_model

if TYPE_CHECKING:
from sglang.srt.model_executor.hip_model_runner import HiPModelRunner
@@ -41,13 +40,18 @@ def can_run(self, forward_batch: ForwardBatch):
forward_batch.global_num_tokens
)
is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
(min_num_tokens == max_num_tokens and (max_num_tokens, use_cached_mask, num_stage_cached) in self.graphs)
(
min_num_tokens == max_num_tokens
and (max_num_tokens, use_cached_mask, num_stage_cached)
in self.graphs
)
if self.disable_padding
else max_num_tokens <= self.max_bs
)
else:
is_bs_supported = (
(forward_batch.batch_size, use_cached_mask, num_stage_cached) in self.graphs
(forward_batch.batch_size, use_cached_mask, num_stage_cached)
in self.graphs
if self.disable_padding
else forward_batch.batch_size <= self.max_bs
)
@@ -70,7 +74,7 @@ def capture(self):
cache_configs = [(True, None)]
for i_stage in range(num_stages):
cache_configs.append((False, i_stage))

self.stream = graph_capture_context.stream
capture_bs = (
tqdm.tqdm(self.capture_bs)
@@ -89,8 +93,7 @@
graph,
output_buffers,
) = self.capture_one_batch_size(
bs, forward,
use_cached_mask, num_cached_stages
bs, forward, use_cached_mask, num_cached_stages
)
graph_handle = (bs, use_cached_mask, num_cached_stages)
self.graphs[graph_handle] = graph
@@ -99,16 +102,16 @@
save_gemlite_cache()

def capture_one_batch_size(
self,
bs: int,
forward: Callable,
self,
bs: int,
forward: Callable,
hip_use_cached_mask: bool = False,
hip_num_cached_stages: int = 0,
):
graph = torch.cuda.CUDAGraph()
stream = self.stream
num_tokens = bs * self.num_tokens_per_bs

# Common inputs
input_ids = self.input_ids[:num_tokens]
req_pool_indices = self.req_pool_indices[:bs]
@@ -218,15 +221,15 @@ def replay(self, forward_batch: ForwardBatch):
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc)
self.positions[:raw_num_token].copy_(forward_batch.positions)

if self.is_encoder_decoder:
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
if forward_batch.mrope_positions is not None:
self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)

if hasattr(forward_batch.spec_info, "hidden_states"):
self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states

# Attention backend
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
bs,
Expand All @@ -239,7 +242,11 @@ def replay(self, forward_batch: ForwardBatch):
)

# Replay
key = (bs, forward_batch.hip_use_cached_mask, forward_batch.hip_metadata_cached_stage)
key = (
bs,
forward_batch.hip_use_cached_mask,
forward_batch.hip_metadata_cached_stage,
)
self.graphs[key].replay()
next_token_logits, hidden_states = self.output_buffers[key]

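For orientation, the capture/replay logic above stores one captured CUDA graph per `(batch_size, use_cached_mask, num_cached_stages)` key and, on replay, copies fresh inputs into static buffers before replaying the matching graph. The following is a minimal, self-contained sketch of that dict-of-graphs pattern in plain PyTorch; the toy matmul stands in for the model forward and is not the HiP runner itself.

```python
import torch

assert torch.cuda.is_available(), "CUDA graph capture needs a GPU"

weight = torch.randn(64, 64, device="cuda")

def forward(x: torch.Tensor) -> torch.Tensor:
    # Toy stand-in for the model forward pass.
    return x @ weight

graphs = {}          # key -> captured torch.cuda.CUDAGraph
static_buffers = {}  # key -> (static_input, static_output)

def capture(key):
    bs, _use_cached_mask = key  # the flag only illustrates the keying scheme here
    static_in = torch.zeros(bs, 64, device="cuda")

    # Warm up on a side stream before capture, as CUDA graphs require.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        for _ in range(3):
            forward(static_in)
    torch.cuda.current_stream().wait_stream(side)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_out = forward(static_in)

    graphs[key] = graph
    static_buffers[key] = (static_in, static_out)

def replay(key, new_input: torch.Tensor) -> torch.Tensor:
    static_in, static_out = static_buffers[key]
    static_in.copy_(new_input)  # refresh inputs in place, as replay() does above
    graphs[key].replay()
    return static_out.clone()

# One graph per (batch_size, use_cached_mask) combination, mirroring the runner's key.
for bs in (1, 8):
    for use_cached_mask in (False, True):
        capture((bs, use_cached_mask))

out = replay((8, True), torch.randn(8, 64, device="cuda"))
print(out.shape)  # torch.Size([8, 64])
```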