
Commit 0005b7e

Merge pull request #3 from DeepAuto-AI/feat/temp-query-caching
Feat/temp query caching
2 parents cb2d74e + e260127 commit 0005b7e


4 files changed: +145 -19 lines changed


python/sglang/srt/layers/attention/hip_attention/hip_config.py

Lines changed: 12 additions & 0 deletions
@@ -86,6 +86,18 @@ class HiPAttentionConfig:
             HiPAttentionPerLayerConfig(),
         ]
     )
+    prefill_layers: list[HiPAttentionPerLayerConfig] = field(
+        default_factory=lambda: [
+            HiPAttentionPerLayerConfig(
+                parsed_json={
+                    "second_stage_k": 4096,
+                    "sliding_window_size": 1024,
+                    "sink_token_size": 256,
+                }
+            ),
+            HiPAttentionPerLayerConfig(),
+        ]
+    )

     # deprecated
     apply_v_dot: bool = False
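The new prefill_layers field mirrors the existing layers field but is intended for the prefill phase; its default gives one entry a larger budget (second_stage_k=4096, sliding_window_size=1024, sink_token_size=256) alongside a default entry. A minimal sketch of how a backend could pick between the two lists, assuming prefill_layers is expanded to one entry per layer the same way layers is (the helper name is hypothetical):

def select_layer_config(hip_config, layer_id: int, is_prefill: bool):
    # Mirrors how the backend indexes hip_config.layers[layer.layer_id];
    # prefill_layers is assumed to follow the same per-layer expansion.
    configs = hip_config.prefill_layers if is_prefill else hip_config.layers
    return configs[layer_id]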

python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py

Lines changed: 71 additions & 2 deletions
@@ -44,14 +44,49 @@ class WrapperDispatch(Enum):


 class HiPRadixAttentionBackend(AttentionBackend):
-
     def __init__(self, model_runner: HiPModelRunner):
         super().__init__()

+        # NOTE: this backend instance is created only once.
         self.hip_config: HiPAttentionConfig = model_runner.hip_attention_config

         self.max_context_len = model_runner.model_config.context_len

+        # NOTE: this query buffer is a temporary mechanism.
+        self.q_buffers = [
+            torch.zeros(
+                (
+                    1,
+                    self.hip_config.block_sparse_block_size_q,
+                    model_runner.model_config.num_attention_heads
+                    // model_runner.tp_size,
+                    model_runner.model_config.head_dim,
+                ),
+                device=torch.device(model_runner.device),
+                dtype=model_runner.dtype,
+            )
+            for _ in range(model_runner.model_config.num_hidden_layers)
+        ]
+        # NOTE: disable q caching
+        self.q_buffers = None
+
+    def push_q_buffer(self, q: torch.Tensor, layer_id: int, batch_size: int):
+        if self.q_buffers is None:
+            return
+        assert batch_size == 1
+        q = q.unsqueeze(0)
+        layer_q_buffer = self.q_buffers[layer_id]
+        q_buffer = torch.cat([layer_q_buffer, q[:, -layer_q_buffer.shape[1] :]], dim=1)
+        layer_q_buffer.copy_(q_buffer[:, -layer_q_buffer.shape[1] :])
+
+    def get_q_buffer(self, layer_id: int, batch_size: int) -> torch.Tensor:
+        if self.q_buffers is not None:
+            assert batch_size == 1
+            return self.q_buffers[layer_id].flatten(0, 1)
+        else:
+            return None
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         pass

@@ -130,6 +165,9 @@ def forward_extend(
             offload_cache = None

         q_reshaped = q.reshape(-1, layer.tp_q_head_num, layer.head_dim)
+        self.push_q_buffer(
+            q_reshaped, layer_id=layer.layer_id, batch_size=forward_batch.batch_size
+        )

         # Output tensor
         o = torch.empty_like(q_reshaped)
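The buffer populated above is a per-layer rolling window: push_q_buffer keeps only the most recent block_sparse_block_size_q query vectors (batch size 1 only), and get_q_buffer hands them back flattened for masking; since the buffers are immediately set to None in __init__, the caching path is effectively disabled in this commit. A standalone sketch of the rolling-window behaviour, with illustrative sizes:

import torch

window, heads, head_dim = 64, 4, 128
q_buffer = torch.zeros(1, window, heads, head_dim)

def push(q: torch.Tensor) -> None:
    # q: [new_tokens, heads, head_dim]; keep only the last `window` entries.
    q = q.unsqueeze(0)  # add the batch dimension (batch size 1)
    merged = torch.cat([q_buffer, q[:, -window:]], dim=1)
    q_buffer.copy_(merged[:, -window:])

push(torch.randn(3, heads, head_dim))    # a few decode steps append to the tail
push(torch.randn(200, heads, head_dim))  # a long prefill keeps only the last 64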
@@ -349,9 +387,17 @@ def forward_decode(
             )
             offload_cache = None

+        self.push_q_buffer(
+            q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+            layer_id=layer.layer_id,
+            batch_size=forward_batch.batch_size,
+        )
+        q_for_masking = self.get_q_buffer(layer.layer_id, forward_batch.batch_size)
+
         if not require_validation:
             o, metadata = self.forward_paged_hip(
                 query=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                query_for_mask=q_for_masking,
                 sm_scale=layer.scaling,
                 batch_size=forward_batch.batch_size,
                 k_cache=k_cache,
@@ -384,6 +430,7 @@ def sse(a: torch.Tensor, b: torch.Tensor):

            o, metadata_new = self.forward_paged_hip(
                query=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                query_for_mask=q_for_masking,
                sm_scale=layer.scaling,
                batch_size=forward_batch.batch_size,
                k_cache=None,
@@ -416,6 +463,7 @@ def sse(a: torch.Tensor, b: torch.Tensor):

            o_valid, metadata_valid = self.forward_paged_hip(
                query=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                query_for_mask=q_for_masking,
                sm_scale=layer.scaling,
                batch_size=forward_batch.batch_size,
                k_cache=k_cache,
@@ -491,6 +539,7 @@ def sse(a: torch.Tensor, b: torch.Tensor):

            o_uvm, metadata_uvm = self.forward_paged_hip(
                query=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                query_for_mask=q_for_masking,
                sm_scale=layer.scaling,
                batch_size=forward_batch.batch_size,
                k_cache=offload_cache.k_uvm.bank_gpu,
@@ -518,6 +567,7 @@ def sse(a: torch.Tensor, b: torch.Tensor):

            o_retry, metadata_retry = self.forward_paged_hip(
                query=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                query_for_mask=q_for_masking,
                sm_scale=layer.scaling,
                batch_size=forward_batch.batch_size,
                k_cache=None,
@@ -599,6 +649,7 @@ def forward_paged_hip(
         is_dense: bool = False,
         k: Optional[torch.Tensor] = None,
         v: Optional[torch.Tensor] = None,
+        query_for_mask: Optional[torch.Tensor] = None,
         online_update_cache: bool = False,
         is_decode: bool = False,
     ) -> tuple[torch.Tensor, "HiPAttentionOutputMetadata"]:
@@ -619,6 +670,8 @@
         layer_config = self.hip_config.layers[layer.layer_id]

         query = query.view(batch_size, dst_seq_len, num_heads, hidden_dims)
+        if query_for_mask is not None:
+            query_for_mask = query_for_mask.view(batch_size, -1, num_heads, hidden_dims)

         if k_cache is not None:
             N_PAGE, num_heads_kv, hidden_dims_kv = k_cache.shape
@@ -654,6 +707,16 @@
         elif os.getenv("HIP_DISABLE_COMPUTE_STATISTICS", "1") == "0":
             require_cache_statistics = True

+        if query_for_mask is not None:
+            query_position_ids = positions.view(batch_size, dst_seq_len)
+            position_ids = (
+                torch.arange(0, query_for_mask.shape[1], device=query.device)[None, :]
+                - (query_for_mask.shape[1] - 1)
+                + query_position_ids
+            )
+        else:
+            position_ids = positions.view(batch_size, dst_seq_len)
+
         args = HiPAttentionArgs(
             k_cache=(
                 k_cache.view(torch.uint8)
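When a widened masking query is present, its position ids are reconstructed by counting backwards from the current query position, so the cached queries are treated as the tokens immediately preceding it. A standalone numeric check of the arithmetic above (W and pos are illustrative values):

import torch

W = 4                        # width of the cached masking-query window
pos = torch.tensor([[100]])  # positions.view(batch_size, dst_seq_len) for one decode step

position_ids = torch.arange(0, W)[None, :] - (W - 1) + pos
print(position_ids)  # tensor([[ 97,  98,  99, 100]])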
@@ -670,7 +733,7 @@
             offload_cache=offload_cache,
             block_table=block_table,
             cache_seq_lens=seq_lens,
-            position_ids=positions.view(batch_size, dst_seq_len),
+            position_ids=position_ids,
             block_size_k=32 if is_gemma else 64,  # BLOCK_CHUNK
             sliding_window_size=layer_config.sliding_window_size,
             sink_token_size=layer_config.sink_token_size,
@@ -697,6 +760,11 @@
             online_update_cache=online_update_cache,
             require_cache_statistics=require_cache_statistics,
             disable_flashdecode=not is_decode,
+            q_mask=(
+                (query_for_mask * sm_scale).to(query.dtype)
+                if query_for_mask is not None
+                else None
+            ),
         )

         context, metadata = dual_stage_quadratic_hip_attention(
context, metadata = dual_stage_quadratic_hip_attention(
@@ -707,5 +775,6 @@ def forward_paged_hip(
707775
cached_metadata=cached_metadata,
708776
)
709777
context = context.to(query.dtype)
778+
context = context[:, -query.shape[1] :, :, :].contiguous()
710779

711780
return context.view(N, num_heads, hidden_dims), metadata
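Two details of the new query_for_mask path: the masking query is pre-scaled by sm_scale before being passed to the kernel as q_mask, and the kernel output is sliced to its trailing query.shape[1] positions, which drops any rows belonging to the widened masking window and is a no-op when the two coincide. A shape-only illustration under that assumption (all sizes are made up):

import torch

B, W, dst_seq_len, heads, dim = 1, 4, 1, 2, 8
context = torch.randn(B, W, heads, dim)          # assumed output over the widened window
query = torch.randn(B, dst_seq_len, heads, dim)  # the real queries for this step

context = context[:, -query.shape[1]:, :, :].contiguous()
assert context.shape == (B, dst_seq_len, heads, dim)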

python/sglang/srt/mem_cache/hip_memory_pool.py

Lines changed: 15 additions & 5 deletions
@@ -2,6 +2,7 @@

 import logging
 import math
+import os
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union

@@ -72,9 +73,16 @@ def __init__(
             layer_config = hip_config.layers[layer_idx]
             self.layer_configs[layer_idx] = layer_config

-            n_chunks = triton.cdiv(
-                layer_config.second_stage_k, layer_config.stages[-1].stage_chunk_size
-            )
+            if os.getenv("HIP_DEBUG_SNAP_KV", "0") == "1":
+                n_chunks = triton.cdiv(
+                    layer_config.second_stage_k + 2048,
+                    layer_config.stages[-1].stage_chunk_size,
+                )
+            else:
+                n_chunks = triton.cdiv(
+                    layer_config.second_stage_k,
+                    layer_config.stages[-1].stage_chunk_size,
+                )

             num_q_blocks = 1
             self.init_buffer(
@@ -132,8 +140,10 @@ def __init__(
                 - layer_config.sliding_window_size
                 - layer_config.sink_token_size
             )
-            chunk_count = (
-                min(stage.stage_k, max_context_length) // stage.stage_chunk_size
+            chunk_count = max(
+                0,
+                min(stage.stage_k, max_context_length)
+                // stage.stage_chunk_size,
             )
             self.init_buffer(
                 layer_idx,
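Two sizing changes in the metadata pool: with HIP_DEBUG_SNAP_KV=1 the per-layer chunk count reserves room for an extra 2048 tokens of second-stage k, and chunk_count is now clamped at zero so that a context shorter than the sliding window plus sink tokens can no longer yield a negative buffer size. A quick numeric illustration of the clamp (values are made up):

stage_k, stage_chunk_size = 8192, 64
max_context_length = 1024 - 1024 - 256  # e.g. context shorter than window + sink tokens

old_chunk_count = min(stage_k, max_context_length) // stage_chunk_size          # -4
new_chunk_count = max(0, min(stage_k, max_context_length) // stage_chunk_size)  # 0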

python/sglang/srt/models/qwen2.py

Lines changed: 47 additions & 12 deletions
@@ -46,7 +46,7 @@
     default_weight_loader,
     kv_cache_scales_loader,
 )
-from sglang.srt.utils import make_layers
+from sglang.srt.utils import make_layers_with_previous_layer

 Qwen2Config = None

@@ -89,6 +89,7 @@ def forward(self, x):
 class Qwen2Attention(nn.Module):
     def __init__(
         self,
+        config: Qwen2Config,
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
@@ -97,6 +98,7 @@ def __init__(
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 32768,
         quant_config: Optional[QuantizationConfig] = None,
+        previous_layer: Optional["Qwen2Attention"] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -136,19 +138,26 @@ def __init__(
             quant_config=quant_config,
         )

-        self.rotary_emb = get_rope(
-            self.head_dim,
-            rotary_dim=self.head_dim,
-            max_position=max_position_embeddings,
-            base=rope_theta,
-            rope_scaling=rope_scaling,
-        )
+        if previous_layer is None:
+            self.rotary_emb = get_rope(
+                self.head_dim,
+                rotary_dim=self.head_dim,
+                max_position=max_position_embeddings,
+                base=rope_theta,
+                rope_scaling=rope_scaling,
+            )
+        else:
+            assert self.head_dim == previous_layer.head_dim
+            self.rotary_emb = previous_layer.rotary_emb
         self.attn = RadixAttention(
             self.num_heads,
             self.head_dim,
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            orig_context_len=getattr(
+                config, "orig_context_len", max_position_embeddings
+            ),
             rope=self.rotary_emb,
         )

@@ -160,7 +169,13 @@ def forward(
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        q, k = self.rotary_emb(positions, q, k)
+
+        # FIXME(geon): find a better way to detect if HIP is enabled
+        if (forward_batch.hip_metadata_cache_pool is None) or (
+            not forward_batch.hip_metadata_cache_pool.hip_config.using_extend
+        ):
+            q, k = self.rotary_emb(positions, q, k)
+
         attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output
@@ -172,13 +187,21 @@ def __init__(
         config: Qwen2Config,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        previous_layer: Optional["Qwen2DecoderLayer"] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 1000000)
         rope_scaling = getattr(config, "rope_scaling", None)
+        if rope_scaling is not None and getattr(
+            config, "original_max_position_embeddings", None
+        ):
+            rope_scaling["original_max_position_embeddings"] = (
+                config.original_max_position_embeddings
+            )
         max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
         self.self_attn = Qwen2Attention(
+            config=config,
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             num_kv_heads=config.num_key_value_heads,
@@ -187,6 +210,9 @@ def __init__(
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
+            previous_layer=(
+                previous_layer.self_attn if previous_layer is not None else None
+            ),
         )
         self.mlp = Qwen2MLP(
             hidden_size=self.hidden_size,
@@ -239,13 +265,15 @@ def __init__(
             config.hidden_size,
             quant_config=quant_config,
         )
-        self.layers = make_layers(
+        self.layers = make_layers_with_previous_layer(
             config.num_hidden_layers,
-            lambda idx, prefix: Qwen2DecoderLayer(
-                layer_id=idx,
+            lambda idx, prefix, previous_layer: Qwen2DecoderLayer(
                 config=config,
+                layer_id=idx,
                 quant_config=quant_config,
+                previous_layer=previous_layer,
             ),
+            prefix="model.layers",
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

@@ -264,15 +292,22 @@ def forward(
         else:
             hidden_states = input_embeds
         residual = None
+
+        forward_batch.on_model_start()
         for i in range(len(self.layers)):
+            forward_batch.on_layer_start(i)
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 forward_batch,
                 residual,
             )
+            forward_batch.on_layer_end(i)
+        forward_batch.on_model_end()
+
         hidden_states, _ = self.norm(hidden_states, residual)
+
         return hidden_states

     # If this function is called, it should always initialize KV cache scale
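qwen2.py now builds each decoder layer with a reference to the previously constructed one, so the rotary embedding module is created once and shared (previous_layer.rotary_emb) instead of rebuilt per layer, and RoPE is applied in forward only when the HiP extend path is inactive. A minimal sketch of what a helper like make_layers_with_previous_layer could look like (hypothetical implementation; the real helper lives in sglang.srt.utils):

from typing import Callable, Optional

import torch.nn as nn

def make_layers_with_previous_layer_sketch(
    num_layers: int,
    layer_fn: Callable[[int, str, Optional[nn.Module]], nn.Module],
    prefix: str = "model.layers",
) -> nn.ModuleList:
    layers = []
    previous_layer: Optional[nn.Module] = None
    for idx in range(num_layers):
        # Pass the previously built layer so modules (e.g. rotary_emb) can be shared.
        layer = layer_fn(idx, f"{prefix}.{idx}", previous_layer)
        layers.append(layer)
        previous_layer = layer
    return nn.ModuleList(layers)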
