Feat/temp query caching #3

Merged · 9 commits · Jan 27, 2025
python/sglang/srt/layers/attention/hip_attention/hip_config.py (12 additions, 0 deletions)
@@ -86,6 +86,18 @@ class HiPAttentionConfig:
HiPAttentionPerLayerConfig(),
]
)
prefill_layers: list[HiPAttentionPerLayerConfig] = field(
default_factory=lambda: [
HiPAttentionPerLayerConfig(
parsed_json={
"second_stage_k": 4096,
"sliding_window_size": 1024,
"sink_token_size": 256,
}
),
HiPAttentionPerLayerConfig(),
]
)

# deprecated
apply_v_dot: bool = False
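
Note: prefill_layers mirrors the existing layers field but with a larger budget (second_stage_k=4096, sliding_window_size=1024, sink_token_size=256), presumably so prefill can select more of the context than decode. The diff does not show where the field is consumed; a minimal selection sketch, assuming a hypothetical pick_layer_config helper and the same direct per-layer indexing that forward_paged_hip uses for hip_config.layers:

def pick_layer_config(hip_config, layer_id: int, is_decode: bool):
    # Hypothetical helper: decode keeps the regular per-layer configs,
    # while prefill switches to the wider prefill_layers added in this PR.
    configs = hip_config.layers if is_decode else hip_config.prefill_layers
    return configs[layer_id]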
(second changed file; path not shown in this view)
@@ -44,14 +44,49 @@ class WrapperDispatch(Enum):


class HiPRadixAttentionBackend(AttentionBackend):

def __init__(self, model_runner: HiPModelRunner):
super().__init__()

# NOTE: this backend instance is created only once.

self.hip_config: HiPAttentionConfig = model_runner.hip_attention_config

self.max_context_len = model_runner.model_config.context_len

# NOTE: this query caching is a temporary mechanism.
self.q_buffers = [
torch.zeros(
(
1,
self.hip_config.block_sparse_block_size_q,
model_runner.model_config.num_attention_heads
// model_runner.tp_size,
model_runner.model_config.head_dim,
),
device=torch.device(model_runner.device),
dtype=model_runner.dtype,
)
for _ in range(model_runner.model_config.num_hidden_layers)
]
# NOTE: q caching is disabled for now; the buffers above are discarded.
self.q_buffers = None

def push_q_buffer(self, q: torch.Tensor, layer_id: int, batch_size: int):
if self.q_buffers is None:
return
assert batch_size == 1
q = q.unsqueeze(0)
layer_q_buffer = self.q_buffers[layer_id]
q_buffer = torch.cat([layer_q_buffer, q[:, -layer_q_buffer.shape[1] :]], dim=1)
layer_q_buffer.copy_(q_buffer[:, -layer_q_buffer.shape[1] :])

def get_q_buffer(self, layer_id: int, batch_size: int) -> torch.Tensor:
if self.q_buffers is not None:
assert batch_size == 1
return self.q_buffers[layer_id].flatten(0, 1)
else:
return None

def init_forward_metadata(self, forward_batch: ForwardBatch):
pass
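
Note: push_q_buffer maintains, per layer, a sliding window of the most recent query vectors (window length block_sparse_block_size_q, batch size 1 only), and get_q_buffer returns that window flattened so decode can build the sparse mask from several recent queries rather than only the current one. A self-contained sketch of the same rolling-buffer update, with illustrative sizes:

import torch

window, heads, head_dim = 4, 2, 8            # illustrative sizes
q_buffer = torch.zeros(1, window, heads, head_dim)

def push(q_step: torch.Tensor) -> None:
    # q_step: (new_tokens, heads, head_dim), matching the reshaped q
    # pushed from forward_extend / forward_decode for batch size 1.
    q_step = q_step.unsqueeze(0)             # -> (1, new_tokens, heads, head_dim)
    merged = torch.cat([q_buffer, q_step[:, -window:]], dim=1)
    q_buffer.copy_(merged[:, -window:])      # keep only the newest `window` queries

push(torch.randn(3, heads, head_dim))        # a 3-token extend step
push(torch.randn(1, heads, head_dim))        # a decode step
q_for_masking = q_buffer.flatten(0, 1)       # as returned by get_q_buffer: (window, heads, head_dim)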

@@ -130,6 +165,9 @@ def forward_extend(
offload_cache = None

q_reshaped = q.reshape(-1, layer.tp_q_head_num, layer.head_dim)
self.push_q_buffer(
q_reshaped, layer_id=layer.layer_id, batch_size=forward_batch.batch_size
)

# Output tensor
o = torch.empty_like(q_reshaped)
@@ -349,9 +387,17 @@ def forward_decode(
)
offload_cache = None

self.push_q_buffer(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
layer_id=layer.layer_id,
batch_size=forward_batch.batch_size,
)
q_for_masking = self.get_q_buffer(layer.layer_id, forward_batch.batch_size)

if not require_validation:
o, metadata = self.forward_paged_hip(
query=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
query_for_mask=q_for_masking,
sm_scale=layer.scaling,
batch_size=forward_batch.batch_size,
k_cache=k_cache,
@@ -384,6 +430,7 @@ def sse(a: torch.Tensor, b: torch.Tensor):

o, metadata_new = self.forward_paged_hip(
query=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
query_for_mask=q_for_masking,
sm_scale=layer.scaling,
batch_size=forward_batch.batch_size,
k_cache=None,
@@ -416,6 +463,7 @@ def sse(a: torch.Tensor, b: torch.Tensor):

o_valid, metadata_valid = self.forward_paged_hip(
query=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
query_for_mask=q_for_masking,
sm_scale=layer.scaling,
batch_size=forward_batch.batch_size,
k_cache=k_cache,
@@ -491,6 +539,7 @@ def sse(a: torch.Tensor, b: torch.Tensor):

o_uvm, metadata_uvm = self.forward_paged_hip(
query=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
query_for_mask=q_for_masking,
sm_scale=layer.scaling,
batch_size=forward_batch.batch_size,
k_cache=offload_cache.k_uvm.bank_gpu,
@@ -518,6 +567,7 @@ def sse(a: torch.Tensor, b: torch.Tensor):

o_retry, metadata_retry = self.forward_paged_hip(
query=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
query_for_mask=q_for_masking,
sm_scale=layer.scaling,
batch_size=forward_batch.batch_size,
k_cache=None,
@@ -599,6 +649,7 @@ def forward_paged_hip(
is_dense: bool = False,
k: Optional[torch.Tensor] = None,
v: Optional[torch.Tensor] = None,
query_for_mask: Optional[torch.Tensor] = None,
online_update_cache: bool = False,
is_decode: bool = False,
) -> tuple[torch.Tensor, "HiPAttentionOutputMetadata"]:
@@ -619,6 +670,8 @@ def forward_paged_hip(
layer_config = self.hip_config.layers[layer.layer_id]

query = query.view(batch_size, dst_seq_len, num_heads, hidden_dims)
if query_for_mask is not None:
query_for_mask = query_for_mask.view(batch_size, -1, num_heads, hidden_dims)

if k_cache is not None:
N_PAGE, num_heads_kv, hidden_dims_kv = k_cache.shape
@@ -654,6 +707,16 @@
elif os.getenv("HIP_DISABLE_COMPUTE_STATISTICS", "1") == "0":
require_cache_statistics = True

if query_for_mask is not None:
query_position_ids = positions.view(batch_size, dst_seq_len)
position_ids = (
torch.arange(0, query_for_mask.shape[1], device=query.device)[None, :]
- (query_for_mask.shape[1] - 1)
+ query_position_ids
)
else:
position_ids = positions.view(batch_size, dst_seq_len)

args = HiPAttentionArgs(
k_cache=(
k_cache.view(torch.uint8)
@@ -670,7 +733,7 @@
offload_cache=offload_cache,
block_table=block_table,
cache_seq_lens=seq_lens,
-            position_ids=positions.view(batch_size, dst_seq_len),
+            position_ids=position_ids,
block_size_k=32 if is_gemma else 64, # BLOCK_CHUNK
sliding_window_size=layer_config.sliding_window_size,
sink_token_size=layer_config.sink_token_size,
@@ -697,6 +760,11 @@
online_update_cache=online_update_cache,
require_cache_statistics=require_cache_statistics,
disable_flashdecode=not is_decode,
q_mask=(
(query_for_mask * sm_scale).to(query.dtype)
if query_for_mask is not None
else None
),
)

context, metadata = dual_stage_quadratic_hip_attention(
@@ -707,5 +775,6 @@
cached_metadata=cached_metadata,
)
context = context.to(query.dtype)
context = context[:, -query.shape[1] :, :, :].contiguous()

return context.view(N, num_heads, hidden_dims), metadata
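
Note: when a cached query window of width W is passed as query_for_mask, the position ids handed to HiPAttentionArgs are widened to run from pos - (W - 1) up to pos for each decode position, and the attention output is sliced back to the current queries via context[:, -query.shape[1]:]. A small numeric check of that arithmetic for batch size 1 and a single decode token:

import torch

W = 4                                      # width of the cached query window
positions = torch.tensor([123])            # current decode position
query_position_ids = positions.view(1, 1)  # (batch_size, dst_seq_len)

# Same arithmetic as the position_ids computation above.
position_ids = torch.arange(0, W)[None, :] - (W - 1) + query_position_ids
print(position_ids)                        # tensor([[120, 121, 122, 123]])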
python/sglang/srt/mem_cache/hip_memory_pool.py (15 additions, 5 deletions)
@@ -2,6 +2,7 @@

import logging
import math
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union

@@ -72,9 +73,16 @@ def __init__(
layer_config = hip_config.layers[layer_idx]
self.layer_configs[layer_idx] = layer_config

-            n_chunks = triton.cdiv(
-                layer_config.second_stage_k, layer_config.stages[-1].stage_chunk_size
-            )
+            if os.getenv("HIP_DEBUG_SNAP_KV", "0") == "1":
+                n_chunks = triton.cdiv(
+                    layer_config.second_stage_k + 2048,
+                    layer_config.stages[-1].stage_chunk_size,
+                )
+            else:
+                n_chunks = triton.cdiv(
+                    layer_config.second_stage_k,
+                    layer_config.stages[-1].stage_chunk_size,
+                )

num_q_blocks = 1
self.init_buffer(
@@ -132,8 +140,10 @@ def __init__(
- layer_config.sliding_window_size
- layer_config.sink_token_size
)
-                chunk_count = (
-                    min(stage.stage_k, max_context_length) // stage.stage_chunk_size
+                chunk_count = max(
+                    0,
+                    min(stage.stage_k, max_context_length)
+                    // stage.stage_chunk_size,
)
self.init_buffer(
layer_idx,
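
Note: the metadata pool sizes its per-layer buffers in chunks of stage_chunk_size, so HIP_DEBUG_SNAP_KV=1 simply reserves room for 2048 extra selected tokens per layer. A worked example of the same ceiling division, using a plain-Python stand-in for triton.cdiv and illustrative values:

import os

def cdiv(a: int, b: int) -> int:
    # Equivalent to triton.cdiv for positive integers.
    return -(-a // b)

second_stage_k, stage_chunk_size = 4096, 64   # illustrative values

if os.getenv("HIP_DEBUG_SNAP_KV", "0") == "1":
    n_chunks = cdiv(second_stage_k + 2048, stage_chunk_size)   # 96 chunks
else:
    n_chunks = cdiv(second_stage_k, stage_chunk_size)          # 64 chunks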
python/sglang/srt/models/qwen2.py (47 additions, 12 deletions)
@@ -46,7 +46,7 @@
default_weight_loader,
kv_cache_scales_loader,
)
-from sglang.srt.utils import make_layers
+from sglang.srt.utils import make_layers_with_previous_layer

Qwen2Config = None

@@ -89,6 +89,7 @@ def forward(self, x):
class Qwen2Attention(nn.Module):
def __init__(
self,
config: Qwen2Config,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
@@ -97,6 +98,7 @@ def __init__(
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 32768,
quant_config: Optional[QuantizationConfig] = None,
previous_layer: Optional["Qwen2Attention"] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
@@ -136,19 +138,26 @@ def __init__(
quant_config=quant_config,
)

-        self.rotary_emb = get_rope(
-            self.head_dim,
-            rotary_dim=self.head_dim,
-            max_position=max_position_embeddings,
-            base=rope_theta,
-            rope_scaling=rope_scaling,
-        )
+        if previous_layer is None:
+            self.rotary_emb = get_rope(
+                self.head_dim,
+                rotary_dim=self.head_dim,
+                max_position=max_position_embeddings,
+                base=rope_theta,
+                rope_scaling=rope_scaling,
+            )
+        else:
+            assert self.head_dim == previous_layer.head_dim
+            self.rotary_emb = previous_layer.rotary_emb
self.attn = RadixAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
orig_context_len=getattr(
config, "orig_context_len", max_position_embeddings
),
rope=self.rotary_emb,
)

@@ -160,7 +169,13 @@ def forward(
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        q, k = self.rotary_emb(positions, q, k)
+
+        # FIXME(geon): find a better way to detect whether HiP is enabled
+        if (forward_batch.hip_metadata_cache_pool is None) or (
+            not forward_batch.hip_metadata_cache_pool.hip_config.using_extend
+        ):
+            q, k = self.rotary_emb(positions, q, k)

attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output)
return output
@@ -172,13 +187,21 @@ def __init__(
config: Qwen2Config,
layer_id: int = 0,
quant_config: Optional[QuantizationConfig] = None,
previous_layer: Optional["Qwen2DecoderLayer"] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 1000000)
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is not None and getattr(
config, "original_max_position_embeddings", None
):
rope_scaling["original_max_position_embeddings"] = (
config.original_max_position_embeddings
)
max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
self.self_attn = Qwen2Attention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
Expand All @@ -187,6 +210,9 @@ def __init__(
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
previous_layer=(
previous_layer.self_attn if previous_layer is not None else None
),
)
self.mlp = Qwen2MLP(
hidden_size=self.hidden_size,
@@ -239,13 +265,15 @@ def __init__(
config.hidden_size,
quant_config=quant_config,
)
-        self.layers = make_layers(
+        self.layers = make_layers_with_previous_layer(
            config.num_hidden_layers,
-            lambda idx, prefix: Qwen2DecoderLayer(
-                layer_id=idx,
+            lambda idx, prefix, previous_layer: Qwen2DecoderLayer(
                config=config,
+                layer_id=idx,
                quant_config=quant_config,
+                previous_layer=previous_layer,
            ),
+            prefix="model.layers",
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

@@ -264,15 +292,22 @@ def forward(
else:
hidden_states = input_embeds
residual = None

forward_batch.on_model_start()
for i in range(len(self.layers)):
forward_batch.on_layer_start(i)
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
forward_batch,
residual,
)
forward_batch.on_layer_end(i)
forward_batch.on_model_end()

hidden_states, _ = self.norm(hidden_states, residual)

return hidden_states

# If this function is called, it should always initialize KV cache scale
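
Note: make_layers_with_previous_layer is imported from sglang.srt.utils but its body is not shown in this diff; from the call site, the layer factory takes (idx, prefix, previous_layer) and each decoder layer receives the previously built one so Qwen2Attention can reuse its rotary_emb. A plausible sketch of such a helper, under that assumption only (the real utility may also handle pipeline-parallel offsets):

from typing import Callable, Optional

import torch.nn as nn

def make_layers_with_previous_layer(
    num_layers: int,
    layer_fn: Callable[..., nn.Module],
    prefix: str = "",
) -> nn.ModuleList:
    # Sketch: build layers in order and hand each factory call the layer
    # constructed just before it (None for layer 0), so submodules such as
    # rotary embeddings can be shared across layers.
    layers = []
    previous_layer: Optional[nn.Module] = None
    for idx in range(num_layers):
        layer = layer_fn(idx, f"{prefix}.{idx}", previous_layer)
        layers.append(layer)
        previous_layer = layer
    return nn.ModuleList(layers)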