From 43c2f957c839c9aa5dc7081d5b8ec81055c869b4 Mon Sep 17 00:00:00 2001 From: AinL Date: Fri, 24 Jan 2025 07:15:14 +0000 Subject: [PATCH 1/9] update --- .../attention/hip_attention/hip_config.py | 22 +++---- .../hip_attention/hip_radix_attention.py | 65 +++++++++++++++++-- .../sglang/srt/mem_cache/hip_memory_pool.py | 12 ++-- 3 files changed, 78 insertions(+), 21 deletions(-) diff --git a/python/sglang/srt/layers/attention/hip_attention/hip_config.py b/python/sglang/srt/layers/attention/hip_attention/hip_config.py index a4bd28aa990..24ceb5aead0 100644 --- a/python/sglang/srt/layers/attention/hip_attention/hip_config.py +++ b/python/sglang/srt/layers/attention/hip_attention/hip_config.py @@ -74,19 +74,15 @@ class HiPAttentionConfig: default_factory=lambda: [32, 16, 8] ) using_extend: bool = True - layers: list[HiPAttentionPerLayerConfig] = field( - default_factory=lambda: [ - HiPAttentionPerLayerConfig( - parsed_json={ - "second_stage_k": 4096, - "sliding_window_size": 1024, - "sink_token_size": 256, - } - ), - HiPAttentionPerLayerConfig(), - ] - ) - + layers: list[HiPAttentionPerLayerConfig] = field(default_factory=lambda: [ + HiPAttentionPerLayerConfig(parsed_json={"second_stage_k": 4096, "sliding_window_size": 1024, "sink_token_size": 256}), + HiPAttentionPerLayerConfig(), + ]) + prefill_layers: list[HiPAttentionPerLayerConfig] = field(default_factory=lambda: [ + HiPAttentionPerLayerConfig(parsed_json={"second_stage_k": 4096, "sliding_window_size": 1024, "sink_token_size": 256}), + HiPAttentionPerLayerConfig(), + ]) + # deprecated apply_v_dot: bool = False prefill_always_dense: bool = False diff --git a/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py b/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py index f65160508bb..211e7e25af1 100644 --- a/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py +++ b/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py @@ -44,13 +44,40 @@ class WrapperDispatch(Enum): class HiPRadixAttentionBackend(AttentionBackend): - def __init__(self, model_runner: HiPModelRunner): super().__init__() + + # NOTE: this backend instance is only one time creation. self.hip_config: HiPAttentionConfig = model_runner.hip_attention_config self.max_context_len = model_runner.model_config.context_len + + # NOTE: this is quite temporary one. 
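+        # Per-layer rolling buffers of the most recent decode queries; they are used as the
+        # masking query (q_mask) during HiP mask estimation (see push_q_buffer/get_q_buffer below).
+        # Shape: (batch=1, query window=64, tp query heads=32, head_dim=128), hardcoded here.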
+ self.q_buffers = [ + torch.zeros( + (1, 64, 32, 128), + device=torch.device(model_runner.device), + dtype=model_runner.dtype + ) + for layer_id in range(32) + ] + + def push_q_buffer(self, q: torch.Tensor, layer_id: int, batch_size: int): + assert batch_size == 1 + q = q.unsqueeze(0) + layer_q_buffer = self.q_buffers[layer_id] + q_buffer = torch.cat([layer_q_buffer, q[:, -layer_q_buffer.shape[1]:]], dim=1) + layer_q_buffer.copy_(q_buffer[:, -layer_q_buffer.shape[1]:]) + # if layer_id == 0: print(layer_q_buffer[0, :, 0, 0]) + + def get_q_buffer(self, layer_id: int, batch_size: int) -> torch.Tensor: + if self.q_buffers is not None: + assert batch_size == 1 + # if layer_id == 0: print(self.q_buffers[layer_id][0, :, 0, 0]) + return self.q_buffers[layer_id].flatten(0, 1) + else: + return None def init_forward_metadata(self, forward_batch: ForwardBatch): pass @@ -128,8 +155,9 @@ def forward_extend( layer.layer_id ) offload_cache = None - + q_reshaped = q.reshape(-1, layer.tp_q_head_num, layer.head_dim) + self.push_q_buffer(q_reshaped, layer_id=layer.layer_id, batch_size=forward_batch.batch_size) # Output tensor o = torch.empty_like(q_reshaped) @@ -348,10 +376,20 @@ def forward_decode( layer.layer_id ) offload_cache = None - + + self.push_q_buffer( + q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + layer_id=layer.layer_id, + batch_size=forward_batch.batch_size, + ) + q_for_masking = self.get_q_buffer( + layer.layer_id, forward_batch.batch_size + ) + if not require_validation: o, metadata = self.forward_paged_hip( query=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + query_for_mask=q_for_masking, sm_scale=layer.scaling, batch_size=forward_batch.batch_size, k_cache=k_cache, @@ -384,6 +422,7 @@ def sse(a: torch.Tensor, b: torch.Tensor): o, metadata_new = self.forward_paged_hip( query=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + query_for_mask=q_for_masking, sm_scale=layer.scaling, batch_size=forward_batch.batch_size, k_cache=None, @@ -416,6 +455,7 @@ def sse(a: torch.Tensor, b: torch.Tensor): o_valid, metadata_valid = self.forward_paged_hip( query=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + query_for_mask=q_for_masking, sm_scale=layer.scaling, batch_size=forward_batch.batch_size, k_cache=k_cache, @@ -491,6 +531,7 @@ def sse(a: torch.Tensor, b: torch.Tensor): o_uvm, metadata_uvm = self.forward_paged_hip( query=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + query_for_mask=q_for_masking, sm_scale=layer.scaling, batch_size=forward_batch.batch_size, k_cache=offload_cache.k_uvm.bank_gpu, @@ -518,6 +559,7 @@ def sse(a: torch.Tensor, b: torch.Tensor): o_retry, metadata_retry = self.forward_paged_hip( query=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + query_for_mask=q_for_masking, sm_scale=layer.scaling, batch_size=forward_batch.batch_size, k_cache=None, @@ -599,6 +641,8 @@ def forward_paged_hip( is_dense: bool = False, k: Optional[torch.Tensor] = None, v: Optional[torch.Tensor] = None, + query_for_mask: Optional[torch.Tensor] = None, + online_update_cache: bool = False, is_decode: bool = False, ) -> tuple[torch.Tensor, "HiPAttentionOutputMetadata"]: @@ -619,6 +663,8 @@ def forward_paged_hip( layer_config = self.hip_config.layers[layer.layer_id] query = query.view(batch_size, dst_seq_len, num_heads, hidden_dims) + if query_for_mask is not None: + query_for_mask = query_for_mask.view(batch_size, -1, num_heads, hidden_dims) if k_cache is not None: N_PAGE, num_heads_kv, hidden_dims_kv = k_cache.shape @@ 
-653,6 +699,14 @@ def forward_paged_hip( require_cache_statistics = True elif os.getenv("HIP_DISABLE_COMPUTE_STATISTICS", "1") == "0": require_cache_statistics = True + + if query_for_mask is not None: + query_position_ids = positions.view(batch_size, dst_seq_len) + position_ids = torch.arange( + 0, query_for_mask.shape[1], device=query.device + )[None, :] - (query_for_mask.shape[1] - 1) + query_position_ids + else: + position_ids = positions.view(batch_size, dst_seq_len) args = HiPAttentionArgs( k_cache=( @@ -670,7 +724,8 @@ def forward_paged_hip( offload_cache=offload_cache, block_table=block_table, cache_seq_lens=seq_lens, - position_ids=positions.view(batch_size, dst_seq_len), + position_ids=position_ids, + block_size_k=32 if is_gemma else 64, # BLOCK_CHUNK sliding_window_size=layer_config.sliding_window_size, sink_token_size=layer_config.sink_token_size, @@ -697,6 +752,7 @@ def forward_paged_hip( online_update_cache=online_update_cache, require_cache_statistics=require_cache_statistics, disable_flashdecode=not is_decode, + q_mask=(query_for_mask * sm_scale).to(query.dtype) if query_for_mask is not None else None, ) context, metadata = dual_stage_quadratic_hip_attention( @@ -707,5 +763,6 @@ def forward_paged_hip( cached_metadata=cached_metadata, ) context = context.to(query.dtype) + context = context[:, -query.shape[1]:, :, :].contiguous() return context.view(N, num_heads, hidden_dims), metadata diff --git a/python/sglang/srt/mem_cache/hip_memory_pool.py b/python/sglang/srt/mem_cache/hip_memory_pool.py index 5ed18b67146..e2341037024 100644 --- a/python/sglang/srt/mem_cache/hip_memory_pool.py +++ b/python/sglang/srt/mem_cache/hip_memory_pool.py @@ -1,5 +1,8 @@ from __future__ import annotations +import os +from dataclasses import dataclass +from typing import TYPE_CHECKING, List, Literal, Tuple, Union, Optional, Dict import logging import math from dataclasses import dataclass @@ -72,10 +75,11 @@ def __init__( layer_config = hip_config.layers[layer_idx] self.layer_configs[layer_idx] = layer_config - n_chunks = triton.cdiv( - layer_config.second_stage_k, layer_config.stages[-1].stage_chunk_size - ) - + if os.getenv('HIP_DEBUG_SNAP_KV', '0') == '1': + n_chunks = triton.cdiv(layer_config.second_stage_k + 2048, layer_config.stages[-1].stage_chunk_size) + else: + n_chunks = triton.cdiv(layer_config.second_stage_k, layer_config.stages[-1].stage_chunk_size) + num_q_blocks = 1 self.init_buffer( layer_idx, From 49d00658fb69999c4aed26adb622027c28197b9d Mon Sep 17 00:00:00 2001 From: AinL Date: Fri, 24 Jan 2025 10:48:09 +0000 Subject: [PATCH 2/9] fix bug --- .../sglang/srt/mem_cache/hip_memory_pool.py | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/mem_cache/hip_memory_pool.py b/python/sglang/srt/mem_cache/hip_memory_pool.py index e2341037024..29a7ffacec1 100644 --- a/python/sglang/srt/mem_cache/hip_memory_pool.py +++ b/python/sglang/srt/mem_cache/hip_memory_pool.py @@ -5,9 +5,6 @@ from typing import TYPE_CHECKING, List, Literal, Tuple, Union, Optional, Dict import logging import math -from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union - import torch import triton from hip.models.hip_attention.gen3.attention_metadata import ( @@ -75,11 +72,17 @@ def __init__( layer_config = hip_config.layers[layer_idx] self.layer_configs[layer_idx] = layer_config - if os.getenv('HIP_DEBUG_SNAP_KV', '0') == '1': - n_chunks = triton.cdiv(layer_config.second_stage_k + 2048, 
layer_config.stages[-1].stage_chunk_size) + if os.getenv("HIP_DEBUG_SNAP_KV", "0") == "1": + n_chunks = triton.cdiv( + layer_config.second_stage_k + 2048, + layer_config.stages[-1].stage_chunk_size, + ) else: - n_chunks = triton.cdiv(layer_config.second_stage_k, layer_config.stages[-1].stage_chunk_size) - + n_chunks = triton.cdiv( + layer_config.second_stage_k, + layer_config.stages[-1].stage_chunk_size, + ) + num_q_blocks = 1 self.init_buffer( layer_idx, @@ -136,8 +139,10 @@ def __init__( - layer_config.sliding_window_size - layer_config.sink_token_size ) - chunk_count = ( - min(stage.stage_k, max_context_length) // stage.stage_chunk_size + chunk_count = max( + 0, + min(stage.stage_k, max_context_length) + // stage.stage_chunk_size, ) self.init_buffer( layer_idx, From a1aa68cdd4ca208c74e0f030f480bbb16f3bf130 Mon Sep 17 00:00:00 2001 From: AinL Date: Fri, 24 Jan 2025 07:15:14 +0000 Subject: [PATCH 3/9] update --- .../srt/layers/attention/hip_attention/hip_config.py | 8 ++++++-- python/sglang/srt/mem_cache/hip_memory_pool.py | 1 + 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/attention/hip_attention/hip_config.py b/python/sglang/srt/layers/attention/hip_attention/hip_config.py index 24ceb5aead0..dddd52a8492 100644 --- a/python/sglang/srt/layers/attention/hip_attention/hip_config.py +++ b/python/sglang/srt/layers/attention/hip_attention/hip_config.py @@ -75,8 +75,12 @@ class HiPAttentionConfig: ) using_extend: bool = True layers: list[HiPAttentionPerLayerConfig] = field(default_factory=lambda: [ - HiPAttentionPerLayerConfig(parsed_json={"second_stage_k": 4096, "sliding_window_size": 1024, "sink_token_size": 256}), - HiPAttentionPerLayerConfig(), + HiPAttentionPerLayerConfig(parsed_json={ + "second_stage_k": 4096, + "sliding_window_size": 32768, + "sink_token_size": 256, + }), + HiPAttentionPerLayerConfig(parsed_json={"second_stage_k": 1024, "sliding_window_size": 1024, "sink_token_size": 256}), ]) prefill_layers: list[HiPAttentionPerLayerConfig] = field(default_factory=lambda: [ HiPAttentionPerLayerConfig(parsed_json={"second_stage_k": 4096, "sliding_window_size": 1024, "sink_token_size": 256}), diff --git a/python/sglang/srt/mem_cache/hip_memory_pool.py b/python/sglang/srt/mem_cache/hip_memory_pool.py index 29a7ffacec1..946c39307f1 100644 --- a/python/sglang/srt/mem_cache/hip_memory_pool.py +++ b/python/sglang/srt/mem_cache/hip_memory_pool.py @@ -2,6 +2,7 @@ import os from dataclasses import dataclass +import os from typing import TYPE_CHECKING, List, Literal, Tuple, Union, Optional, Dict import logging import math From e861a54df51072750cb378270ada75878929761b Mon Sep 17 00:00:00 2001 From: AinL Date: Sat, 25 Jan 2025 11:18:02 +0000 Subject: [PATCH 4/9] feat(need fix): support query caching --- .../hip_attention/hip_radix_attention.py | 62 +++++++++++-------- 1 file changed, 35 insertions(+), 27 deletions(-) diff --git a/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py b/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py index 211e7e25af1..452f968e5f1 100644 --- a/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py +++ b/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py @@ -46,35 +46,39 @@ class WrapperDispatch(Enum): class HiPRadixAttentionBackend(AttentionBackend): def __init__(self, model_runner: HiPModelRunner): super().__init__() - + # NOTE: this backend instance is only one time creation. 
self.hip_config: HiPAttentionConfig = model_runner.hip_attention_config self.max_context_len = model_runner.model_config.context_len - + # NOTE: this is quite temporary one. self.q_buffers = [ torch.zeros( - (1, 64, 32, 128), - device=torch.device(model_runner.device), - dtype=model_runner.dtype + ( + 1, + self.hip_config.block_sparse_block_size_q, + model_runner.model_config.num_attention_heads + // model_runner.tp_size, + model_runner.model_config.head_dim, + ), + device=torch.device(model_runner.device), + dtype=model_runner.dtype, ) - for layer_id in range(32) + for _ in range(model_runner.model_config.num_hidden_layers) ] - + def push_q_buffer(self, q: torch.Tensor, layer_id: int, batch_size: int): assert batch_size == 1 q = q.unsqueeze(0) layer_q_buffer = self.q_buffers[layer_id] - q_buffer = torch.cat([layer_q_buffer, q[:, -layer_q_buffer.shape[1]:]], dim=1) - layer_q_buffer.copy_(q_buffer[:, -layer_q_buffer.shape[1]:]) - # if layer_id == 0: print(layer_q_buffer[0, :, 0, 0]) - + q_buffer = torch.cat([layer_q_buffer, q[:, -layer_q_buffer.shape[1] :]], dim=1) + layer_q_buffer.copy_(q_buffer[:, -layer_q_buffer.shape[1] :]) + def get_q_buffer(self, layer_id: int, batch_size: int) -> torch.Tensor: if self.q_buffers is not None: assert batch_size == 1 - # if layer_id == 0: print(self.q_buffers[layer_id][0, :, 0, 0]) return self.q_buffers[layer_id].flatten(0, 1) else: return None @@ -155,9 +159,11 @@ def forward_extend( layer.layer_id ) offload_cache = None - + q_reshaped = q.reshape(-1, layer.tp_q_head_num, layer.head_dim) - self.push_q_buffer(q_reshaped, layer_id=layer.layer_id, batch_size=forward_batch.batch_size) + self.push_q_buffer( + q_reshaped, layer_id=layer.layer_id, batch_size=forward_batch.batch_size + ) # Output tensor o = torch.empty_like(q_reshaped) @@ -376,16 +382,14 @@ def forward_decode( layer.layer_id ) offload_cache = None - + self.push_q_buffer( q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), layer_id=layer.layer_id, batch_size=forward_batch.batch_size, ) - q_for_masking = self.get_q_buffer( - layer.layer_id, forward_batch.batch_size - ) - + q_for_masking = self.get_q_buffer(layer.layer_id, forward_batch.batch_size) + if not require_validation: o, metadata = self.forward_paged_hip( query=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), @@ -642,7 +646,6 @@ def forward_paged_hip( k: Optional[torch.Tensor] = None, v: Optional[torch.Tensor] = None, query_for_mask: Optional[torch.Tensor] = None, - online_update_cache: bool = False, is_decode: bool = False, ) -> tuple[torch.Tensor, "HiPAttentionOutputMetadata"]: @@ -699,12 +702,14 @@ def forward_paged_hip( require_cache_statistics = True elif os.getenv("HIP_DISABLE_COMPUTE_STATISTICS", "1") == "0": require_cache_statistics = True - + if query_for_mask is not None: query_position_ids = positions.view(batch_size, dst_seq_len) - position_ids = torch.arange( - 0, query_for_mask.shape[1], device=query.device - )[None, :] - (query_for_mask.shape[1] - 1) + query_position_ids + position_ids = ( + torch.arange(0, query_for_mask.shape[1], device=query.device)[None, :] + - (query_for_mask.shape[1] - 1) + + query_position_ids + ) else: position_ids = positions.view(batch_size, dst_seq_len) @@ -725,7 +730,6 @@ def forward_paged_hip( block_table=block_table, cache_seq_lens=seq_lens, position_ids=position_ids, - block_size_k=32 if is_gemma else 64, # BLOCK_CHUNK sliding_window_size=layer_config.sliding_window_size, sink_token_size=layer_config.sink_token_size, @@ -752,7 +756,11 @@ def forward_paged_hip( 
online_update_cache=online_update_cache, require_cache_statistics=require_cache_statistics, disable_flashdecode=not is_decode, - q_mask=(query_for_mask * sm_scale).to(query.dtype) if query_for_mask is not None else None, + q_mask=( + (query_for_mask * sm_scale).to(query.dtype) + if query_for_mask is not None + else None + ), ) context, metadata = dual_stage_quadratic_hip_attention( @@ -763,6 +771,6 @@ def forward_paged_hip( cached_metadata=cached_metadata, ) context = context.to(query.dtype) - context = context[:, -query.shape[1]:, :, :].contiguous() + context = context[:, -query.shape[1] :, :, :].contiguous() return context.view(N, num_heads, hidden_dims), metadata From 3266890ab6345132de6a36d3e3f202fb07496a50 Mon Sep 17 00:00:00 2001 From: AinL Date: Sat, 25 Jan 2025 15:48:47 +0000 Subject: [PATCH 5/9] roll back to default --- .../attention/hip_attention/hip_config.py | 38 ++++++++++++------- .../sglang/srt/mem_cache/hip_memory_pool.py | 1 - 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/python/sglang/srt/layers/attention/hip_attention/hip_config.py b/python/sglang/srt/layers/attention/hip_attention/hip_config.py index dddd52a8492..6f8f7740c91 100644 --- a/python/sglang/srt/layers/attention/hip_attention/hip_config.py +++ b/python/sglang/srt/layers/attention/hip_attention/hip_config.py @@ -74,19 +74,31 @@ class HiPAttentionConfig: default_factory=lambda: [32, 16, 8] ) using_extend: bool = True - layers: list[HiPAttentionPerLayerConfig] = field(default_factory=lambda: [ - HiPAttentionPerLayerConfig(parsed_json={ - "second_stage_k": 4096, - "sliding_window_size": 32768, - "sink_token_size": 256, - }), - HiPAttentionPerLayerConfig(parsed_json={"second_stage_k": 1024, "sliding_window_size": 1024, "sink_token_size": 256}), - ]) - prefill_layers: list[HiPAttentionPerLayerConfig] = field(default_factory=lambda: [ - HiPAttentionPerLayerConfig(parsed_json={"second_stage_k": 4096, "sliding_window_size": 1024, "sink_token_size": 256}), - HiPAttentionPerLayerConfig(), - ]) - + layers: list[HiPAttentionPerLayerConfig] = field( + default_factory=lambda: [ + HiPAttentionPerLayerConfig( + parsed_json={ + "second_stage_k": 4096, + "sliding_window_size": 1024, + "sink_token_size": 256, + } + ), + HiPAttentionPerLayerConfig(), + ] + ) + prefill_layers: list[HiPAttentionPerLayerConfig] = field( + default_factory=lambda: [ + HiPAttentionPerLayerConfig( + parsed_json={ + "second_stage_k": 4096, + "sliding_window_size": 1024, + "sink_token_size": 256, + } + ), + HiPAttentionPerLayerConfig(), + ] + ) + # deprecated apply_v_dot: bool = False prefill_always_dense: bool = False diff --git a/python/sglang/srt/mem_cache/hip_memory_pool.py b/python/sglang/srt/mem_cache/hip_memory_pool.py index 946c39307f1..29a7ffacec1 100644 --- a/python/sglang/srt/mem_cache/hip_memory_pool.py +++ b/python/sglang/srt/mem_cache/hip_memory_pool.py @@ -2,7 +2,6 @@ import os from dataclasses import dataclass -import os from typing import TYPE_CHECKING, List, Literal, Tuple, Union, Optional, Dict import logging import math From 55d028625d2cc32d71b055f91189f22e5cbe22c0 Mon Sep 17 00:00:00 2001 From: AinL Date: Sat, 25 Jan 2025 16:38:56 +0000 Subject: [PATCH 6/9] RTRT by default (need revert) --- .../layers/attention/hip_attention/hip_radix_attention.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py b/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py index 452f968e5f1..9a73cb6ef1a 100644 --- 
a/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py +++ b/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py @@ -68,8 +68,11 @@ def __init__(self, model_runner: HiPModelRunner): ) for _ in range(model_runner.model_config.num_hidden_layers) ] + self.q_buffers = None def push_q_buffer(self, q: torch.Tensor, layer_id: int, batch_size: int): + if self.q_buffers is None: + return assert batch_size == 1 q = q.unsqueeze(0) layer_q_buffer = self.q_buffers[layer_id] @@ -747,7 +750,7 @@ def forward_paged_hip( ( "relative" if self.hip_config.apply_v_dot - else ("streaming" if is_dense else "relative") + else ("relative" if is_dense else "relative") ) if layer_config.scan_extend_backend is None else layer_config.scan_extend_backend From c8c03aa60caeadf916e0a1cb0196633063b6b9bb Mon Sep 17 00:00:00 2001 From: AinL Date: Sat, 25 Jan 2025 17:05:10 +0000 Subject: [PATCH 7/9] reverted to STRT --- .../srt/layers/attention/hip_attention/hip_radix_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py b/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py index 9a73cb6ef1a..7a4b403e8b4 100644 --- a/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py +++ b/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py @@ -750,7 +750,7 @@ def forward_paged_hip( ( "relative" if self.hip_config.apply_v_dot - else ("relative" if is_dense else "relative") + else ("streaming" if is_dense else "relative") ) if layer_config.scan_extend_backend is None else layer_config.scan_extend_backend From 04922a8c2a5c193d30ac9bd9f63fd9003ff1fb33 Mon Sep 17 00:00:00 2001 From: AinL Date: Mon, 27 Jan 2025 04:05:20 +0000 Subject: [PATCH 8/9] support qwen2; --- .../hip_attention/hip_radix_attention.py | 1 + python/sglang/srt/models/qwen2.py | 59 +++++++++++++++---- 2 files changed, 48 insertions(+), 12 deletions(-) diff --git a/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py b/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py index 7a4b403e8b4..1ab0b1e2103 100644 --- a/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py +++ b/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py @@ -68,6 +68,7 @@ def __init__(self, model_runner: HiPModelRunner): ) for _ in range(model_runner.model_config.num_hidden_layers) ] + # NOTE: disable q caching self.q_buffers = None def push_q_buffer(self, q: torch.Tensor, layer_id: int, batch_size: int): diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index 9ad9588c97d..d06b5b5be4b 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -46,7 +46,7 @@ default_weight_loader, kv_cache_scales_loader, ) -from sglang.srt.utils import make_layers +from sglang.srt.utils import make_layers_with_previous_layer Qwen2Config = None @@ -89,6 +89,7 @@ def forward(self, x): class Qwen2Attention(nn.Module): def __init__( self, + config: Qwen2Config, hidden_size: int, num_heads: int, num_kv_heads: int, @@ -97,6 +98,7 @@ def __init__( rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 32768, quant_config: Optional[QuantizationConfig] = None, + previous_layer: Optional["Qwen2Attention"] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -136,19 +138,26 @@ def __init__( quant_config=quant_config, ) - self.rotary_emb = get_rope( - self.head_dim, - 
rotary_dim=self.head_dim, - max_position=max_position_embeddings, - base=rope_theta, - rope_scaling=rope_scaling, - ) + if previous_layer is None: + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + else: + assert self.head_dim == previous_layer.head_dim + self.rotary_emb = previous_layer.rotary_emb self.attn = RadixAttention( self.num_heads, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, + orig_context_len=getattr( + config, "orig_context_len", max_position_embeddings + ), rope=self.rotary_emb, ) @@ -160,7 +169,13 @@ def forward( ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self.rotary_emb(positions, q, k) + + # FIXME(geon): find better way to detect if HIP is enabled + if (forward_batch.hip_metadata_cache_pool is None) or ( + not forward_batch.hip_metadata_cache_pool.hip_config.using_extend + ): + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, forward_batch) output, _ = self.o_proj(attn_output) return output @@ -172,13 +187,21 @@ def __init__( config: Qwen2Config, layer_id: int = 0, quant_config: Optional[QuantizationConfig] = None, + previous_layer: Optional["Qwen2DecoderLayer"] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 1000000) rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None + ): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings + ) max_position_embeddings = getattr(config, "max_position_embeddings", 32768) self.self_attn = Qwen2Attention( + config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, num_kv_heads=config.num_key_value_heads, @@ -187,6 +210,9 @@ def __init__( rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, quant_config=quant_config, + previous_layer=( + previous_layer.self_attn if previous_layer is not None else None + ), ) self.mlp = Qwen2MLP( hidden_size=self.hidden_size, @@ -239,13 +265,15 @@ def __init__( config.hidden_size, quant_config=quant_config, ) - self.layers = make_layers( + self.layers = make_layers_with_previous_layer( config.num_hidden_layers, - lambda idx, prefix: Qwen2DecoderLayer( - layer_id=idx, + lambda idx, prefix, previous_layer: Qwen2DecoderLayer( config=config, + layer_id=idx, quant_config=quant_config, + previous_layer=previous_layer, ), + prefix="model.layers", ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -264,7 +292,10 @@ def forward( else: hidden_states = input_embeds residual = None + + forward_batch.on_model_start() for i in range(len(self.layers)): + forward_batch.on_layer_start(i) layer = self.layers[i] hidden_states, residual = layer( positions, @@ -272,7 +303,11 @@ def forward( forward_batch, residual, ) + forward_batch.on_layer_end(i) + forward_batch.on_model_end() + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states # If this function is called, it should always initialize KV cache scale From e26012732d472c0789b104d595560f1b7a1cb251 Mon Sep 17 00:00:00 2001 From: AinL Date: Mon, 27 Jan 2025 04:39:09 +0000 Subject: [PATCH 9/9] fmt --- python/sglang/srt/mem_cache/hip_memory_pool.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git 
a/python/sglang/srt/mem_cache/hip_memory_pool.py b/python/sglang/srt/mem_cache/hip_memory_pool.py index 29a7ffacec1..e5bf52fc795 100644 --- a/python/sglang/srt/mem_cache/hip_memory_pool.py +++ b/python/sglang/srt/mem_cache/hip_memory_pool.py @@ -1,10 +1,11 @@ from __future__ import annotations -import os -from dataclasses import dataclass -from typing import TYPE_CHECKING, List, Literal, Tuple, Union, Optional, Dict import logging import math +import os +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union + import torch import triton from hip.models.hip_attention.gen3.attention_metadata import (
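Note on the mechanism introduced in patches 1 and 4 (and disabled by default again in patch 6 by setting q_buffers to None): during decode, each layer keeps a rolling buffer of its most recent queries, replays that window as the masking query (q_mask), rebuilds position ids relative to the current token, and finally keeps only the attention-output rows belonging to the genuinely new queries. The sketch below is a minimal, standalone illustration of that bookkeeping, not the backend code itself: the layer count, window size, head count, and head dimension are assumed example values, and no HiP kernels are invoked.

import torch

# Assumed example sizes; the backend derives these from the model config and hip_config.
NUM_LAYERS, WINDOW, NUM_HEADS, HEAD_DIM = 4, 64, 8, 128

# One rolling buffer per layer: (batch=1, window, heads, head_dim), zero-initialized.
q_buffers = [torch.zeros(1, WINDOW, NUM_HEADS, HEAD_DIM) for _ in range(NUM_LAYERS)]

def push_q_buffer(q: torch.Tensor, layer_id: int) -> None:
    """Append the newest decode query (new_tokens, heads, head_dim) and keep the last WINDOW positions."""
    buf = q_buffers[layer_id]
    merged = torch.cat([buf, q.unsqueeze(0)], dim=1)
    buf.copy_(merged[:, -buf.shape[1]:])

def get_q_buffer(layer_id: int) -> torch.Tensor:
    """Return the cached query window flattened to (window, heads, head_dim)."""
    return q_buffers[layer_id].flatten(0, 1)

def relative_position_ids(window: int, current_pos: torch.Tensor) -> torch.Tensor:
    """Position ids of the cached queries, counted back from the current token position."""
    return torch.arange(window)[None, :] - (window - 1) + current_pos[:, None]

# One decode step for layer 0 at absolute position 100:
new_q = torch.randn(1, NUM_HEADS, HEAD_DIM)              # a single new token's query
push_q_buffer(new_q, layer_id=0)
q_mask = get_q_buffer(layer_id=0)                        # would be scaled by sm_scale before use
pos_ids = relative_position_ids(WINDOW, torch.tensor([100]))
print(q_mask.shape)                                      # torch.Size([64, 8, 128])
print(pos_ids[0, -3:])                                   # tensor([ 98,  99, 100])
# The attention output computed over the whole window is then sliced back to the new
# queries only, mirroring `context[:, -query.shape[1]:]` in forward_paged_hip.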