Skip to content

Commit a917567

Browse files
committed
supports different layer settings
1 parent 39d177b commit a917567

File tree

2 files changed

+26
-8
lines changed

2 files changed

+26
-8
lines changed

python/sglang/srt/layers/attention/hip_attention/hip_config.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from dataclasses import dataclass, field, InitVar
22
from typing import List, Optional, Union
3+
import warnings
34

45
from hip.models.hip_attention.gen3.attention_metadata import ScanStage
56

@@ -79,6 +80,10 @@ class HiPAttentionConfig:
7980
HiPAttentionPerLayerConfig(parsed_json={"second_stage_k": 4096, "sliding_window_size": 1024, "sink_token_size": 256}),
8081
HiPAttentionPerLayerConfig(),
8182
])
83+
prefill_layers: list[HiPAttentionPerLayerConfig] = field(default_factory=lambda: [
84+
HiPAttentionPerLayerConfig(parsed_json={"second_stage_k": 4096, "sliding_window_size": 1024, "sink_token_size": 256}),
85+
HiPAttentionPerLayerConfig(),
86+
])
8287

8388
# deprecated
8489
apply_v_dot: bool = False
@@ -129,9 +134,16 @@ def __post_init__(self, parsed_json: dict | None):
129134
HiPAttentionPerLayerConfig(parsed_json=layer)
130135
for layer in parsed_json['layers']
131136
]
137+
self.prefill_layers = self.layers
132138
parsed_json.pop('layers')
139+
if 'prefill_layers' in parsed_json:
140+
self.prefill_layers = [
141+
HiPAttentionPerLayerConfig(parsed_json=layer)
142+
for layer in parsed_json['prefill_layers']
143+
]
144+
parsed_json.pop('prefill_layers')
133145
if parsed_json:
134-
raise ValueError(f'Unknown keys in json: {parsed_json.keys()}')
146+
raise Exception(f'Unknown keys in json: {parsed_json.keys()}')
135147

136148
num_stages = len(self.layers[0].stages)
137149
for layer_config in self.layers:

python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -528,16 +528,22 @@ def forward_paged_hip(
528528

529529
online_update_cache: bool = False,
530530
) -> tuple[torch.Tensor, "HiPAttentionOutputMetadata"]:
531-
is_dense = layer.layer_id in self.hip_config.dense_layers
532-
533-
if len(self.hip_config.layers) == 2:
534-
layer_config = self.hip_config.layers[0 if is_dense else 1]
535-
else:
536-
layer_config = self.hip_config.layers[layer.layer_id]
537-
538531
N, num_heads, hidden_dims = query.shape
539532
dst_seq_len = N // batch_size
540533

534+
is_decode = dst_seq_len == 1
535+
is_dense = layer.layer_id in self.hip_config.dense_layers
536+
if not is_decode:
537+
if len(self.hip_config.prefill_layers) == 2:
538+
layer_config = self.hip_config.prefill_layers[0 if is_dense else 1]
539+
else:
540+
layer_config = self.hip_config.prefill_layers[layer.layer_id]
541+
else:
542+
if len(self.hip_config.layers) == 2:
543+
layer_config = self.hip_config.layers[0 if is_dense else 1]
544+
else:
545+
layer_config = self.hip_config.layers[layer.layer_id]
546+
541547
query = query.view(batch_size, dst_seq_len, num_heads, hidden_dims)
542548

543549
if k_cache is not None:

0 commit comments

Comments
 (0)