diff --git a/python/sglang/srt/layers/attention/hip_attention/hip_config.py b/python/sglang/srt/layers/attention/hip_attention/hip_config.py
index a4bd28aa990..6f8f7740c91 100644
--- a/python/sglang/srt/layers/attention/hip_attention/hip_config.py
+++ b/python/sglang/srt/layers/attention/hip_attention/hip_config.py
@@ -86,6 +86,18 @@ class HiPAttentionConfig:
             HiPAttentionPerLayerConfig(),
         ]
     )
+    prefill_layers: list[HiPAttentionPerLayerConfig] = field(
+        default_factory=lambda: [
+            HiPAttentionPerLayerConfig(
+                parsed_json={
+                    "second_stage_k": 4096,
+                    "sliding_window_size": 1024,
+                    "sink_token_size": 256,
+                }
+            ),
+            HiPAttentionPerLayerConfig(),
+        ]
+    )

     # deprecated
     apply_v_dot: bool = False
diff --git a/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py b/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py
index f65160508bb..1ab0b1e2103 100644
--- a/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py
+++ b/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py
@@ -44,14 +44,49 @@ class WrapperDispatch(Enum):


 class HiPRadixAttentionBackend(AttentionBackend):
-
     def __init__(self, model_runner: HiPModelRunner):
         super().__init__()

+        # NOTE: this backend instance is created only once.
+
         self.hip_config: HiPAttentionConfig = model_runner.hip_attention_config
         self.max_context_len = model_runner.model_config.context_len

+        # NOTE: this is quite temporary.
+        self.q_buffers = [
+            torch.zeros(
+                (
+                    1,
+                    self.hip_config.block_sparse_block_size_q,
+                    model_runner.model_config.num_attention_heads
+                    // model_runner.tp_size,
+                    model_runner.model_config.head_dim,
+                ),
+                device=torch.device(model_runner.device),
+                dtype=model_runner.dtype,
+            )
+            for _ in range(model_runner.model_config.num_hidden_layers)
+        ]
+        # NOTE: disable q caching
+        self.q_buffers = None
+
+    def push_q_buffer(self, q: torch.Tensor, layer_id: int, batch_size: int):
+        if self.q_buffers is None:
+            return
+        assert batch_size == 1
+        q = q.unsqueeze(0)
+        layer_q_buffer = self.q_buffers[layer_id]
+        q_buffer = torch.cat([layer_q_buffer, q[:, -layer_q_buffer.shape[1] :]], dim=1)
+        layer_q_buffer.copy_(q_buffer[:, -layer_q_buffer.shape[1] :])
+
+    def get_q_buffer(self, layer_id: int, batch_size: int) -> Optional[torch.Tensor]:
+        if self.q_buffers is not None:
+            assert batch_size == 1
+            return self.q_buffers[layer_id].flatten(0, 1)
+        else:
+            return None
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         pass
@@ -130,6 +165,9 @@ def forward_extend(
             offload_cache = None

         q_reshaped = q.reshape(-1, layer.tp_q_head_num, layer.head_dim)
+        self.push_q_buffer(
+            q_reshaped, layer_id=layer.layer_id, batch_size=forward_batch.batch_size
+        )

         # Output tensor
         o = torch.empty_like(q_reshaped)
@@ -349,9 +387,17 @@ def forward_decode(
             )
             offload_cache = None

+        self.push_q_buffer(
+            q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+            layer_id=layer.layer_id,
+            batch_size=forward_batch.batch_size,
+        )
+        q_for_masking = self.get_q_buffer(layer.layer_id, forward_batch.batch_size)
+
         if not require_validation:
             o, metadata = self.forward_paged_hip(
                 query=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                query_for_mask=q_for_masking,
                 sm_scale=layer.scaling,
                 batch_size=forward_batch.batch_size,
                 k_cache=k_cache,
@@ -384,6 +430,7 @@ def sse(a: torch.Tensor, b: torch.Tensor):
                 o, metadata_new = self.forward_paged_hip(
                     query=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                    query_for_mask=q_for_masking,
                     sm_scale=layer.scaling,
                     batch_size=forward_batch.batch_size,
                     k_cache=None,
@@ -416,6 +463,7 @@ def sse(a: torch.Tensor, b: torch.Tensor):
             o_valid, metadata_valid = self.forward_paged_hip(
                 query=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                query_for_mask=q_for_masking,
                 sm_scale=layer.scaling,
                 batch_size=forward_batch.batch_size,
                 k_cache=k_cache,
@@ -491,6 +539,7 @@ def sse(a: torch.Tensor, b: torch.Tensor):
                 o_uvm, metadata_uvm = self.forward_paged_hip(
                     query=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                    query_for_mask=q_for_masking,
                     sm_scale=layer.scaling,
                     batch_size=forward_batch.batch_size,
                     k_cache=offload_cache.k_uvm.bank_gpu,
@@ -518,6 +567,7 @@ def sse(a: torch.Tensor, b: torch.Tensor):
                     o_retry, metadata_retry = self.forward_paged_hip(
                         query=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                        query_for_mask=q_for_masking,
                         sm_scale=layer.scaling,
                         batch_size=forward_batch.batch_size,
                         k_cache=None,
@@ -599,6 +649,7 @@ def forward_paged_hip(
         is_dense: bool = False,
         k: Optional[torch.Tensor] = None,
         v: Optional[torch.Tensor] = None,
+        query_for_mask: Optional[torch.Tensor] = None,
         online_update_cache: bool = False,
         is_decode: bool = False,
     ) -> tuple[torch.Tensor, "HiPAttentionOutputMetadata"]:
@@ -619,6 +670,8 @@ def forward_paged_hip(
         layer_config = self.hip_config.layers[layer.layer_id]

         query = query.view(batch_size, dst_seq_len, num_heads, hidden_dims)
+        if query_for_mask is not None:
+            query_for_mask = query_for_mask.view(batch_size, -1, num_heads, hidden_dims)

         if k_cache is not None:
             N_PAGE, num_heads_kv, hidden_dims_kv = k_cache.shape
@@ -654,6 +707,16 @@ def forward_paged_hip(
         elif os.getenv("HIP_DISABLE_COMPUTE_STATISTICS", "1") == "0":
             require_cache_statistics = True

+        if query_for_mask is not None:
+            query_position_ids = positions.view(batch_size, dst_seq_len)
+            position_ids = (
+                torch.arange(0, query_for_mask.shape[1], device=query.device)[None, :]
+                - (query_for_mask.shape[1] - 1)
+                + query_position_ids
+            )
+        else:
+            position_ids = positions.view(batch_size, dst_seq_len)
+
         args = HiPAttentionArgs(
             k_cache=(
                 k_cache.view(torch.uint8)
@@ -670,7 +733,7 @@
             offload_cache=offload_cache,
             block_table=block_table,
             cache_seq_lens=seq_lens,
-            position_ids=positions.view(batch_size, dst_seq_len),
+            position_ids=position_ids,
             block_size_k=32 if is_gemma else 64,  # BLOCK_CHUNK
             sliding_window_size=layer_config.sliding_window_size,
             sink_token_size=layer_config.sink_token_size,
@@ -697,6 +760,11 @@
             online_update_cache=online_update_cache,
             require_cache_statistics=require_cache_statistics,
             disable_flashdecode=not is_decode,
+            q_mask=(
+                (query_for_mask * sm_scale).to(query.dtype)
+                if query_for_mask is not None
+                else None
+            ),
         )

         context, metadata = dual_stage_quadratic_hip_attention(
@@ -707,5 +775,6 @@
             cached_metadata=cached_metadata,
         )
         context = context.to(query.dtype)
+        context = context[:, -query.shape[1] :, :, :].contiguous()

         return context.view(N, num_heads, hidden_dims), metadata
diff --git a/python/sglang/srt/mem_cache/hip_memory_pool.py b/python/sglang/srt/mem_cache/hip_memory_pool.py
index 5ed18b67146..e5bf52fc795 100644
--- a/python/sglang/srt/mem_cache/hip_memory_pool.py
+++ b/python/sglang/srt/mem_cache/hip_memory_pool.py
@@ -2,6 +2,7 @@

 import logging
 import math
+import os
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
@@ -72,9 +73,16 @@ def __init__(
             layer_config = hip_config.layers[layer_idx]
             self.layer_configs[layer_idx] = layer_config

-            n_chunks = triton.cdiv(
-                layer_config.second_stage_k, layer_config.stages[-1].stage_chunk_size
-            )
+            if os.getenv("HIP_DEBUG_SNAP_KV", "0") == "1":
+                n_chunks = triton.cdiv(
+                    layer_config.second_stage_k + 2048,
+                    layer_config.stages[-1].stage_chunk_size,
+                )
+            else:
+                n_chunks = triton.cdiv(
+                    layer_config.second_stage_k,
+                    layer_config.stages[-1].stage_chunk_size,
+                )

             num_q_blocks = 1
             self.init_buffer(
@@ -132,8 +140,10 @@ def __init__(
                     - layer_config.sliding_window_size
                     - layer_config.sink_token_size
                 )
-                chunk_count = (
-                    min(stage.stage_k, max_context_length) // stage.stage_chunk_size
+                chunk_count = max(
+                    0,
+                    min(stage.stage_k, max_context_length)
+                    // stage.stage_chunk_size,
                 )
                 self.init_buffer(
                     layer_idx,
diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py
index 9ad9588c97d..d06b5b5be4b 100644
--- a/python/sglang/srt/models/qwen2.py
+++ b/python/sglang/srt/models/qwen2.py
@@ -46,7 +46,7 @@
     default_weight_loader,
     kv_cache_scales_loader,
 )
-from sglang.srt.utils import make_layers
+from sglang.srt.utils import make_layers_with_previous_layer

 Qwen2Config = None
@@ -89,6 +89,7 @@ def forward(self, x):
 class Qwen2Attention(nn.Module):
     def __init__(
         self,
+        config: Qwen2Config,
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
@@ -97,6 +98,7 @@ def __init__(
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 32768,
         quant_config: Optional[QuantizationConfig] = None,
+        previous_layer: Optional["Qwen2Attention"] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -136,19 +138,26 @@ def __init__(
             quant_config=quant_config,
         )

-        self.rotary_emb = get_rope(
-            self.head_dim,
-            rotary_dim=self.head_dim,
-            max_position=max_position_embeddings,
-            base=rope_theta,
-            rope_scaling=rope_scaling,
-        )
+        if previous_layer is None:
+            self.rotary_emb = get_rope(
+                self.head_dim,
+                rotary_dim=self.head_dim,
+                max_position=max_position_embeddings,
+                base=rope_theta,
+                rope_scaling=rope_scaling,
+            )
+        else:
+            assert self.head_dim == previous_layer.head_dim
+            self.rotary_emb = previous_layer.rotary_emb

         self.attn = RadixAttention(
             self.num_heads,
             self.head_dim,
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            orig_context_len=getattr(
+                config, "orig_context_len", max_position_embeddings
+            ),
             rope=self.rotary_emb,
         )
@@ -160,7 +169,13 @@ def forward(
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        q, k = self.rotary_emb(positions, q, k)
+
+        # FIXME(geon): find a better way to detect if HIP is enabled
+        if (forward_batch.hip_metadata_cache_pool is None) or (
+            not forward_batch.hip_metadata_cache_pool.hip_config.using_extend
+        ):
+            q, k = self.rotary_emb(positions, q, k)
+
         attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output
@@ -172,13 +187,21 @@ def __init__(
         config: Qwen2Config,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        previous_layer: Optional["Qwen2DecoderLayer"] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 1000000)
         rope_scaling = getattr(config, "rope_scaling", None)
+        if rope_scaling is not None and getattr(
+            config, "original_max_position_embeddings", None
+        ):
+            rope_scaling["original_max_position_embeddings"] = (
+                config.original_max_position_embeddings
+            )
         max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
         self.self_attn = Qwen2Attention(
+            config=config,
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             num_kv_heads=config.num_key_value_heads,
@@ -187,6 +210,9 @@ def __init__(
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
+            previous_layer=(
+                previous_layer.self_attn if previous_layer is not None else None
+            ),
         )
         self.mlp = Qwen2MLP(
             hidden_size=self.hidden_size,
@@ -239,13 +265,15 @@ def __init__(
             config.hidden_size,
             quant_config=quant_config,
         )
-        self.layers = make_layers(
+        self.layers = make_layers_with_previous_layer(
             config.num_hidden_layers,
-            lambda idx, prefix: Qwen2DecoderLayer(
-                layer_id=idx,
+            lambda idx, prefix, previous_layer: Qwen2DecoderLayer(
                 config=config,
+                layer_id=idx,
                 quant_config=quant_config,
+                previous_layer=previous_layer,
             ),
+            prefix="model.layers",
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -264,7 +292,10 @@ def forward(
         else:
             hidden_states = input_embeds
         residual = None
+
+        forward_batch.on_model_start()
         for i in range(len(self.layers)):
+            forward_batch.on_layer_start(i)
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions,
@@ -272,7 +303,11 @@
                 forward_batch,
                 residual,
             )
+            forward_batch.on_layer_end(i)
+        forward_batch.on_model_end()
+
         hidden_states, _ = self.norm(hidden_states, residual)
+
         return hidden_states

     # If this function is called, it should always initialize KV cache scale