From d89a970db38e411ebbe864427bca5d94ac134388 Mon Sep 17 00:00:00 2001 From: Geon Park Date: Mon, 27 Jan 2025 20:35:10 +0900 Subject: [PATCH 01/16] cleanup a bit --- .gitignore | 7 ++++- python/sglang/bench_one_batch.py | 1 + python/sglang/srt/configs/model_config.py | 30 +++++++++---------- python/sglang/srt/layers/radix_attention.py | 6 ++-- python/sglang/srt/managers/scheduler.py | 16 +--------- .../sglang/srt/managers/tokenizer_manager.py | 11 +------ python/sglang/srt/managers/tp_worker.py | 1 + .../srt/managers/tp_worker_overlap_thread.py | 24 ++------------- .../srt/model_executor/cuda_graph_runner.py | 1 - .../srt/model_executor/forward_batch_info.py | 5 ---- .../sglang/srt/model_executor/model_runner.py | 6 ++-- python/sglang/srt/models/llama.py | 8 ++--- 12 files changed, 35 insertions(+), 81 deletions(-) diff --git a/.gitignore b/.gitignore index 5df31ae882..1c3119c7a1 100644 --- a/.gitignore +++ b/.gitignore @@ -223,6 +223,11 @@ compile_commands.json *.iml -.vscode/ +# VSCode +.vscode + +1 + +# Profiling data *.nsys-rep *.ncu-rep diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py index bc7a9c7a1a..b35a9ff063 100644 --- a/python/sglang/bench_one_batch.py +++ b/python/sglang/bench_one_batch.py @@ -131,6 +131,7 @@ def load_model(server_args, port_args, tp_rank): is_embedding=server_args.is_embedding, dtype=server_args.dtype, quantization=server_args.quantization, + is_context_extended=server_args.enable_hip_attention, ) model_runner = ModelRunner( model_config=model_config, diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index a8daaeaa36..3d3c49d5b2 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -43,6 +43,7 @@ def __init__( is_embedding: Optional[bool] = None, dtype: str = "auto", quantization: Optional[str] = None, + is_context_extended: Optional[bool] = None, ) -> None: self.model_path = model_path self.revision = revision @@ -70,21 +71,20 @@ def __init__( derived_context_len = get_context_length(self.hf_text_config) if context_length is not None: if context_length > derived_context_len: - # FIXME: ignore this env flag only when HiP + context extension activated - logger.warning( - f"Warning: User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). " - f"This may lead to incorrect model outputs or CUDA errors." - ) - self.context_len = context_length - # if get_bool_env_var("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"): - # else: - # raise ValueError( - # f"User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). " - # f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config. " - # f"To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1" - # ) - else: - self.context_len = context_length + if is_context_extended: + pass + elif get_bool_env_var("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"): + logger.warning( + f"Warning: User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). " + f"This may lead to incorrect model outputs or CUDA errors." + ) + else: + raise ValueError( + f"User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). 
" + f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config. " + f"To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1" + ) + self.context_len = context_length else: self.context_len = derived_context_len diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index 9409442059..bdeb518cba 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -16,11 +16,11 @@ from typing import TYPE_CHECKING, Optional -import torch from torch import nn -from sglang.srt.layers.rotary_embedding import RotaryEmbedding, get_rope -from sglang.srt.model_executor.forward_batch_info import ForwardBatch +if TYPE_CHECKING: + from sglang.srt.layers.rotary_embedding import RotaryEmbedding + from sglang.srt.model_executor.forward_batch_info import ForwardBatch class RadixAttention(nn.Module): diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index fde224e9a5..d8ae7dfb97 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -196,6 +196,7 @@ def __init__( is_embedding=server_args.is_embedding, dtype=server_args.dtype, quantization=server_args.quantization, + is_context_extended=server_args.enable_hip_attention, ) self.is_generation = self.model_config.is_generation @@ -1049,21 +1050,6 @@ def run_batch( model_worker_batch = batch.get_model_worker_batch() logits_output, next_token_ids = ( self.tp_worker.forward_batch_generation(model_worker_batch) - # model_worker_batch = batch.get_model_worker_batch() - # if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0: - # # FIXME(geon): handle hip refresh_interval here - # logits_output, next_token_ids = self.tp_worker.forward_batch_generation( - # model_worker_batch - # ) - # elif batch.forward_mode.is_idle(): - # model_worker_batch = batch.get_model_worker_batch() - # self.tp_worker.forward_batch_idle(model_worker_batch) - # return - # else: - # logits_output = None - # if self.skip_tokenizer_init: - # next_token_ids = torch.full( - # (batch.batch_size(),), self.tokenizer.eos_token_id ) else: ( diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 66be456784..aaa8e1241f 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -140,6 +140,7 @@ def __init__( is_embedding=server_args.is_embedding, dtype=server_args.dtype, quantization=server_args.quantization, + is_context_extended=server_args.enable_hip_attention, ) self.is_generation = self.model_config.is_generation @@ -315,16 +316,6 @@ async def _tokenize_one_request( ) input_embeds = obj.input_embeds input_ids = obj.input_ids - elif obj.input_ids is None: - input_ids = self.tokenizer.encode(input_text) - - # HACK: Remove duplicate bos tokens - while ( - (len(input_ids) > 1) - and (input_ids[0] == self.tokenizer.bos_token_id) - and (input_ids[1] == self.tokenizer.bos_token_id) - ): - input_ids.pop(0) elif obj.input_ids is not None: input_ids = obj.input_ids else: diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 891cacd34b..3806e7b3c1 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -65,6 +65,7 @@ def __init__( is_embedding=server_args.is_embedding, 
dtype=server_args.dtype, quantization=server_args.quantization, + is_context_extended=server_args.enable_hip_attention, ) ModelRunnerClass = ModelRunner if server_args.enable_hip_attention: diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index 1973f8306b..a93b8bb6ad 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -124,6 +124,7 @@ def forward_thread_func_(self): batch_pt = 0 batch_lists = [None] * 2 + # For keeping track of HiP attention mask refresh decode_index = 0 while True: @@ -134,28 +135,6 @@ def forward_thread_func_(self): model_worker_batch: ModelWorkerBatch if model_worker_batch.forward_mode.is_decode(): if self.hip_mask_refresh_interval is not None: - # NOTE: for debug - # if decode_index % self.hip_mask_refresh_interval == 0: - # model_worker_batch.hip_use_cached_mask = False - # # logger.info(f"Refreshing attention mask for decode index {decode_index}.") - # else: - # model_worker_batch.hip_use_cached_mask = True - # # logger.info(f"Using cached attention mask for decode index {decode_index}.") - - # NOTE: for debug - # if decode_index % 8 == 0: # first stage refresh interval - # model_worker_batch.hip_use_cached_mask = False - # model_worker_batch.hip_metadata_cached_stages = 0 # NOTE: no cached stages - # elif decode_index % 4 == 0: # second stage refresh interval - # model_worker_batch.hip_use_cached_mask = False - # model_worker_batch.hip_metadata_cached_stages = 1 - # elif decode_index % 2 == 0: # third stage refresh interval - # model_worker_batch.hip_use_cached_mask = False - # model_worker_batch.hip_metadata_cached_stages = 2 - # else: - # model_worker_batch.hip_use_cached_mask = True - # model_worker_batch.hip_metadata_cached_stages = None # NOTE: use cache every stage - require_refresh = False for i_stage, refresh_inteval in enumerate( self.hip_mask_refresh_interval @@ -172,6 +151,7 @@ def forward_thread_func_(self): None # NOTE: use cache every stage ) decode_index += 1 + elif model_worker_batch.forward_mode.is_extend(): decode_index = 0 diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index b7e7ea2cf4..169b643436 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -16,7 +16,6 @@ from __future__ import annotations import bisect -import os from contextlib import contextmanager from typing import TYPE_CHECKING, Callable diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 89c3261ec7..aff37773bb 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -345,11 +345,6 @@ def init_new( if model_runner.model_is_mrope: ret.compute_mrope_positions(model_runner, batch) - # Init attention information - ret.req_to_token_pool = model_runner.req_to_token_pool - ret.token_to_kv_pool = model_runner.token_to_kv_pool - ret.attn_backend = model_runner.attn_backend - # Init HiP attention information if hasattr(model_runner, "hip_metadata_cache_pool"): ret.hip_metadata_cache_pool = model_runner.hip_metadata_cache_pool diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 205cab30aa..7e1f119034 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ 
b/python/sglang/srt/model_executor/model_runner.py @@ -624,10 +624,9 @@ def init_memory_pool( logging.warning( f"max_total_tokens={max_total_tokens} is larger than the profiled value " f"{self.max_total_num_tokens}. " - # f"Use the given value instead." + f"Use the profiled value instead." ) - # self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens) - self.max_total_num_tokens = max_total_tokens + self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens) if self.max_total_num_tokens <= 0: raise RuntimeError( @@ -689,7 +688,6 @@ def init_memory_pool( device=self.device, enable_memory_saver=self.server_args.enable_memory_saver, ) - logger.info( f"Memory pool end. " f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index 956e6550f0..e8cfd77956 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -38,12 +38,12 @@ from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.rotary_embedding import RotaryEmbedding, get_rope +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import ( default_weight_loader, kv_cache_scales_loader, @@ -170,9 +170,7 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, - orig_context_len=getattr( - config, "orig_context_len", max_position_embeddings - ), + orig_context_len=getattr(config, "orig_context_len", max_position_embeddings), rope=self.rotary_emb, ) From e35638ea0e223eeb581805641eddef73dac22455 Mon Sep 17 00:00:00 2001 From: Geon Park Date: Mon, 27 Jan 2025 21:30:29 +0900 Subject: [PATCH 02/16] fix minor bugs --- .../sglang/srt/layers/attention/hip_attention/hip_config.py | 4 +++- python/sglang/srt/layers/radix_attention.py | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/attention/hip_attention/hip_config.py b/python/sglang/srt/layers/attention/hip_attention/hip_config.py index a4bd28aa99..3deb6a469c 100644 --- a/python/sglang/srt/layers/attention/hip_attention/hip_config.py +++ b/python/sglang/srt/layers/attention/hip_attention/hip_config.py @@ -86,6 +86,7 @@ class HiPAttentionConfig: HiPAttentionPerLayerConfig(), ] ) + prefill_layers: Optional[list[HiPAttentionPerLayerConfig]] = None # deprecated apply_v_dot: bool = False @@ -140,8 +141,9 @@ def __post_init__(self, parsed_json: dict | None): HiPAttentionPerLayerConfig(parsed_json=layer) for layer in parsed_json["layers"] ] - self.prefill_layers = self.layers parsed_json.pop("layers") + if self.prefill_layers is None: + self.prefill_layers = self.layers if "prefill_layers" in parsed_json: self.prefill_layers = [ HiPAttentionPerLayerConfig(parsed_json=layer) diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index bdeb518cba..d065d14864 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -18,8 +18,9 @@ from torch import nn +from sglang.srt.layers.rotary_embedding import RotaryEmbedding + if 
TYPE_CHECKING: - from sglang.srt.layers.rotary_embedding import RotaryEmbedding from sglang.srt.model_executor.forward_batch_info import ForwardBatch From f4fee38fa03f4e144e898f143964f51af268b11b Mon Sep 17 00:00:00 2001 From: Geon Park Date: Mon, 27 Jan 2025 22:45:16 +0900 Subject: [PATCH 03/16] merge hip_cuda_graph_runner --- .../attention/hip_attention/__init__.py | 3 - .../hip_attention/hip_cuda_graph_runner.py | 260 ------------------ .../srt/model_executor/cuda_graph_runner.py | 77 ++++-- .../sglang/srt/model_executor/model_runner.py | 6 +- 4 files changed, 54 insertions(+), 292 deletions(-) delete mode 100644 python/sglang/srt/layers/attention/hip_attention/hip_cuda_graph_runner.py diff --git a/python/sglang/srt/layers/attention/hip_attention/__init__.py b/python/sglang/srt/layers/attention/hip_attention/__init__.py index f3c998575d..99c783c073 100644 --- a/python/sglang/srt/layers/attention/hip_attention/__init__.py +++ b/python/sglang/srt/layers/attention/hip_attention/__init__.py @@ -1,4 +1 @@ -from .hip_cuda_graph_runner import HiPCudaGraphRunner - -# from .hip_offload_kv_pool_mha import MHATokenToHiPOffloadKVPool from .hip_radix_attention import HiPRadixAttentionBackend diff --git a/python/sglang/srt/layers/attention/hip_attention/hip_cuda_graph_runner.py b/python/sglang/srt/layers/attention/hip_attention/hip_cuda_graph_runner.py deleted file mode 100644 index a89a5d760b..0000000000 --- a/python/sglang/srt/layers/attention/hip_attention/hip_cuda_graph_runner.py +++ /dev/null @@ -1,260 +0,0 @@ -from __future__ import annotations - -import bisect -from typing import TYPE_CHECKING, Callable - -import torch -import tqdm - -from sglang.srt.distributed import get_tensor_model_parallel_rank -from sglang.srt.distributed.parallel_state import graph_capture -from sglang.srt.layers.logits_processor import ( - LogitsMetadata, - LogitsProcessor, - LogitsProcessorOutput, -) -from sglang.srt.layers.torchao_utils import save_gemlite_cache -from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner, patch_model -from sglang.srt.model_executor.forward_batch_info import ( - CaptureHiddenMode, - ForwardBatch, - ForwardMode, -) - -if TYPE_CHECKING: - from sglang.srt.model_executor.hip_model_runner import HiPModelRunner - - -class HiPCudaGraphRunner(CudaGraphRunner): - model_runner: "HiPModelRunner" - - def __init__(self, model_runner: "HiPModelRunner"): - super().__init__(model_runner) - - def can_run(self, forward_batch: ForwardBatch): - use_cached_mask = forward_batch.hip_use_cached_mask - num_stage_cached = forward_batch.hip_metadata_cached_stage - - if self.enable_dp_attention: - min_num_tokens, max_num_tokens = min(forward_batch.global_num_tokens), max( - forward_batch.global_num_tokens - ) - is_bs_supported = forward_batch.can_run_dp_cuda_graph and ( - ( - min_num_tokens == max_num_tokens - and (max_num_tokens, use_cached_mask, num_stage_cached) - in self.graphs - ) - if self.disable_padding - else max_num_tokens <= self.max_bs - ) - else: - is_bs_supported = ( - (forward_batch.batch_size, use_cached_mask, num_stage_cached) - in self.graphs - if self.disable_padding - else forward_batch.batch_size <= self.max_bs - ) - - # NOTE: cuda graph cannot handle mixed batch (encoder_len = 0) - # If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph - # because the full_text_row_masked_out_mask tensor will always be ones - is_encoder_lens_supported = ( - torch.all(forward_batch.encoder_lens > 0) - if self.is_encoder_decoder - else True - ) - return 
is_bs_supported and is_encoder_lens_supported - - def capture(self): - with graph_capture() as graph_capture_context: - num_stages = len(self.model_runner.hip_attention_config.layers[0].stages) - for layer_config in self.model_runner.hip_attention_config.layers: - assert num_stages == len(layer_config.stages) - cache_configs = [(True, None)] - for i_stage in range(num_stages): - cache_configs.append((False, i_stage)) - - self.stream = graph_capture_context.stream - capture_bs = ( - tqdm.tqdm(self.capture_bs) - if get_tensor_model_parallel_rank() == 0 - else self.capture_bs - ) - for bs in capture_bs: - for use_cached_mask, num_cached_stages in cache_configs: - with patch_model( - self.model_runner.model, - bs in self.compile_bs, - bs, - self.model_runner.tp_group, - ) as forward: - ( - graph, - output_buffers, - ) = self.capture_one_batch_size( - bs, forward, use_cached_mask, num_cached_stages - ) - graph_handle = (bs, use_cached_mask, num_cached_stages) - self.graphs[graph_handle] = graph - self.output_buffers[graph_handle] = output_buffers - # Save gemlite cache after each capture - save_gemlite_cache() - - def capture_one_batch_size( - self, - bs: int, - forward: Callable, - hip_use_cached_mask: bool = False, - hip_num_cached_stages: int = 0, - ): - graph = torch.cuda.CUDAGraph() - stream = self.stream - num_tokens = bs * self.num_tokens_per_bs - - # Common inputs - input_ids = self.input_ids[:num_tokens] - req_pool_indices = self.req_pool_indices[:bs] - seq_lens = self.seq_lens[:bs] - out_cache_loc = self.out_cache_loc[:num_tokens] - positions = self.positions[:num_tokens] - if self.is_encoder_decoder: - encoder_lens = self.encoder_lens[:bs] - else: - encoder_lens = None - mrope_positions = self.mrope_positions[:, :bs] - - if self.enable_dp_attention: - global_num_tokens = [bs] * self.tp_size - gathered_buffer = self.gathered_buffer[: bs * self.tp_size] - else: - global_num_tokens = None - gathered_buffer = None - - spec_info = self.get_spec_info(num_tokens, positions) - - forward_batch = ForwardBatch( - forward_mode=ForwardMode.DECODE, - batch_size=bs, - input_ids=input_ids, - req_pool_indices=req_pool_indices, - seq_lens=seq_lens, - req_to_token_pool=self.model_runner.req_to_token_pool, - token_to_kv_pool=self.model_runner.token_to_kv_pool, - attn_backend=self.model_runner.attn_backend, - hip_metadata_cache_pool=self.model_runner.hip_metadata_cache_pool, - hip_use_cached_mask=hip_use_cached_mask, - hip_metadata_cached_stage=hip_num_cached_stages, - out_cache_loc=out_cache_loc, - seq_lens_sum=seq_lens.sum(), - encoder_lens=encoder_lens, - return_logprob=False, - top_logprobs_nums=[0] * bs, - positions=positions, - global_num_tokens=global_num_tokens, - gathered_buffer=gathered_buffer, - mrope_positions=mrope_positions, - spec_algorithm=self.model_runner.spec_algorithm, - spec_info=spec_info, - capture_hidden_mode=( - spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL - ), - ) - - # Attention backend - self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph( - bs, - num_tokens, - req_pool_indices, - seq_lens, - encoder_lens, - forward_batch.forward_mode, - forward_batch.spec_info, - ) - - # Run and capture - def run_once(): - logits_output = forward(input_ids, forward_batch.positions, forward_batch) - return logits_output.next_token_logits, logits_output.hidden_states - - for _ in range(2): - torch.cuda.synchronize() - self.model_runner.tp_group.barrier() - - run_once() - - torch.cuda.synchronize() - self.model_runner.tp_group.barrier() - - 
torch.cuda.synchronize() - self.model_runner.tp_group.barrier() - - with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream): - out = run_once() - - torch.cuda.synchronize() - self.model_runner.tp_group.barrier() - - self.graph_memory_pool = graph.pool() - return graph, out - - def replay(self, forward_batch: ForwardBatch): - assert forward_batch.out_cache_loc is not None - raw_bs = forward_batch.batch_size - raw_num_token = raw_bs * self.num_tokens_per_bs - - # Pad - if self.enable_dp_attention: - index = bisect.bisect_left( - self.capture_bs, max(forward_batch.global_num_tokens) - ) - else: - index = bisect.bisect_left(self.capture_bs, raw_bs) - bs = self.capture_bs[index] - if bs != raw_bs: - self.seq_lens.fill_(1) - self.out_cache_loc.zero_() - - # Common inputs - self.input_ids[:raw_bs].copy_(forward_batch.input_ids) - self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices) - self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens) - self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc) - self.positions[:raw_num_token].copy_(forward_batch.positions) - - if self.is_encoder_decoder: - self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens) - if forward_batch.mrope_positions is not None: - self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions) - - if hasattr(forward_batch.spec_info, "hidden_states"): - self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states - - # Attention backend - self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph( - bs, - self.req_pool_indices, - self.seq_lens, - forward_batch.seq_lens_sum + (bs - raw_bs), - self.encoder_lens, - forward_batch.forward_mode, - forward_batch.spec_info, - ) - - # Replay - key = ( - bs, - forward_batch.hip_use_cached_mask, - forward_batch.hip_metadata_cached_stage, - ) - self.graphs[key].replay() - next_token_logits, hidden_states = self.output_buffers[key] - - # Extract logprobs - logits_output = LogitsProcessorOutput( - next_token_logits=next_token_logits[:raw_num_token], - hidden_states=( - hidden_states[:raw_num_token] if hidden_states is not None else None - ), - ) - return logits_output diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 169b643436..096396daee 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -118,6 +118,9 @@ def __init__(self, model_runner: "ModelRunner"): self.disable_padding = model_runner.server_args.disable_cuda_graph_padding self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder self.enable_dp_attention = self.model_runner.server_args.enable_dp_attention + self.enable_hip_attention = self.model_runner.server_args.enable_hip_attention + if self.enable_hip_attention: + self.hip_config = self.model_runner.hip_attention_config self.tp_size = self.model_runner.tp_size self.dp_size = self.model_runner.server_args.dp_size @@ -250,16 +253,20 @@ def can_run(self, forward_batch: ForwardBatch): forward_batch.global_num_tokens ) is_bs_supported = forward_batch.can_run_dp_cuda_graph and ( - (min_num_tokens == max_num_tokens and max_num_tokens in self.graphs) + (min_num_tokens == max_num_tokens and (max_num_tokens,) in self.graphs) if self.disable_padding else max_num_tokens <= self.max_bs ) else: - is_bs_supported = ( - forward_batch.batch_size in self.graphs - if self.disable_padding - else forward_batch.batch_size <= self.max_bs - ) + if 
self.disable_padding: + index = bisect.bisect_left(self.capture_bs, forward_batch.batch_size) + if index < len(self.capture_bs): + found_bs = self.capture_bs[index] + is_bs_supported = found_bs == forward_batch.batch_size + else: + is_bs_supported = False + else: + is_bs_supported = forward_batch.batch_size <= self.max_bs # NOTE: cuda graph cannot handle mixed batch (encoder_len = 0) # If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph @@ -280,23 +287,35 @@ def capture(self): else self.capture_bs ) for bs in capture_range: - with patch_model( - self.model_runner.model, - bs in self.compile_bs, - bs, - self.model_runner.tp_group, - ) as forward: - ( - graph, - output_buffers, - ) = self.capture_one_batch_size(bs, forward) - self.graphs[bs] = graph - self.output_buffers[bs] = output_buffers - - # Save gemlite cache after each capture - save_gemlite_cache() + for capture_config in self.capture_configs(): + with patch_model( + self.model_runner.model, + bs in self.compile_bs, + bs, + self.model_runner.tp_group, + ) as forward: + ( + graph, + output_buffers, + ) = self.capture_one_batch_size(bs, forward, capture_config) + graph_handle = (bs, *capture_config) + self.graphs[graph_handle] = graph + self.output_buffers[graph_handle] = output_buffers + + # Save gemlite cache after each capture + save_gemlite_cache() + + def capture_configs(self): + if self.enable_hip_attention: + num_stages = len(self.hip_config.layers[0].stages) + cache_configs = [(True, None)] # (use_cached_mask, num_stage_cached) + for i_stage in range(num_stages): + cache_configs.append((False, i_stage)) + return cache_configs + else: + return [()] - def capture_one_batch_size(self, bs: int, forward: Callable): + def capture_one_batch_size(self, bs: int, forward: Callable, capture_config: tuple): graph = torch.cuda.CUDAGraph() stream = self.stream num_tokens = bs * self.num_tokens_per_bs @@ -322,6 +341,10 @@ def capture_one_batch_size(self, bs: int, forward: Callable): spec_info = self.get_spec_info(num_tokens, positions) + hip_use_cached_mask = hip_num_cached_stages = None + if self.enable_hip_attention: + hip_use_cached_mask, hip_num_cached_stages = capture_config + forward_batch = ForwardBatch( forward_mode=self.capture_forward_mode, batch_size=bs, @@ -331,6 +354,9 @@ def capture_one_batch_size(self, bs: int, forward: Callable): req_to_token_pool=self.model_runner.req_to_token_pool, token_to_kv_pool=self.model_runner.token_to_kv_pool, attn_backend=self.model_runner.attn_backend, + hip_metadata_cache_pool=self.model_runner.hip_metadata_cache_pool, + hip_use_cached_mask=hip_use_cached_mask, + hip_metadata_cached_stage=hip_num_cached_stages, out_cache_loc=out_cache_loc, seq_lens_sum=seq_lens.sum(), encoder_lens=encoder_lens, @@ -428,8 +454,11 @@ def replay(self, forward_batch: ForwardBatch): ) # Replay - self.graphs[bs].replay() - next_token_logits, hidden_states = self.output_buffers[bs] + graph_handle = (bs,) + if self.enable_hip_attention: + graph_handle = (bs, forward_batch.hip_use_cached_mask, forward_batch.hip_metadata_cached_stage) + self.graphs[graph_handle].replay() + next_token_logits, hidden_states = self.output_buffers[graph_handle] logits_output = LogitsProcessorOutput( next_token_logits=next_token_logits[:raw_num_token], diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 7e1f119034..a26a40c9de 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ 
-745,7 +745,6 @@ def init_double_sparsity_channel_config(self, selected_channel): def init_cuda_graphs(self): """Capture cuda graphs.""" - from sglang.srt.layers.attention.hip_attention import HiPCudaGraphRunner from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner self.cuda_graph_runner = None @@ -758,11 +757,8 @@ def init_cuda_graphs(self): return tic = time.time() - CudaGraphRunnerClass = CudaGraphRunner - if self.server_args.enable_hip_attention: - CudaGraphRunnerClass = HiPCudaGraphRunner logger.info("Capture cuda graph begin. This can take up to several minutes.") - self.cuda_graph_runner = CudaGraphRunnerClass(self) + self.cuda_graph_runner = CudaGraphRunner(self) logger.info( f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s, " f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" From aac3379f889eaed2f967d480801f51ab74035f32 Mon Sep 17 00:00:00 2001 From: Geon Park Date: Tue, 28 Jan 2025 00:10:26 +0900 Subject: [PATCH 04/16] merge hip_model_runner into model_runner --- .../hip_attention/hip_radix_attention.py | 5 +- python/sglang/srt/managers/tp_worker.py | 6 +- .../srt/model_executor/hip_model_runner.py | 86 ------------------- .../sglang/srt/model_executor/model_runner.py | 33 ++++++- 4 files changed, 35 insertions(+), 95 deletions(-) delete mode 100644 python/sglang/srt/model_executor/hip_model_runner.py diff --git a/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py b/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py index f65160508b..b1260eefd5 100644 --- a/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py +++ b/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py @@ -17,11 +17,10 @@ from sglang.srt.layers.attention import AttentionBackend from sglang.srt.mem_cache.hip_offload_kv_pool_mha import MHATokenToHiPOffloadKVPool -from sglang.srt.model_executor.forward_batch_info import ForwardBatch if TYPE_CHECKING: from sglang.srt.layers.radix_attention import RadixAttention - from sglang.srt.model_executor.hip_model_runner import HiPModelRunner + from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.layers.attention.hip_attention.hip_config import HiPAttentionConfig from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.speculative.spec_info import SpecInfo @@ -45,7 +44,7 @@ class WrapperDispatch(Enum): class HiPRadixAttentionBackend(AttentionBackend): - def __init__(self, model_runner: HiPModelRunner): + def __init__(self, model_runner: ModelRunner): super().__init__() self.hip_config: HiPAttentionConfig = model_runner.hip_attention_config diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 3806e7b3c1..45160d66b0 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -28,7 +28,6 @@ ) from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch -from sglang.srt.model_executor.hip_model_runner import HiPModelRunner from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.server_args import ServerArgs from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed @@ -67,10 +66,7 @@ def __init__( quantization=server_args.quantization, is_context_extended=server_args.enable_hip_attention, ) - ModelRunnerClass = ModelRunner - if 
server_args.enable_hip_attention: - ModelRunnerClass = HiPModelRunner - self.model_runner = ModelRunnerClass( + self.model_runner = ModelRunner( model_config=self.model_config, mem_fraction_static=server_args.mem_fraction_static, gpu_id=gpu_id, diff --git a/python/sglang/srt/model_executor/hip_model_runner.py b/python/sglang/srt/model_executor/hip_model_runner.py deleted file mode 100644 index e77972481f..0000000000 --- a/python/sglang/srt/model_executor/hip_model_runner.py +++ /dev/null @@ -1,86 +0,0 @@ -import json -import logging -from typing import Optional - -import torch - -from sglang.srt.configs.model_config import ModelConfig -from sglang.srt.layers.attention.hip_attention import HiPRadixAttentionBackend -from sglang.srt.layers.attention.hip_attention.hip_config import HiPAttentionConfig -from sglang.srt.layers.logits_processor import LogitsProcessorOutput -from sglang.srt.mem_cache.hip_memory_pool import HiPMetadataCachePool -from sglang.srt.model_executor.forward_batch_info import ForwardBatch -from sglang.srt.model_executor.model_runner import ModelRunner -from sglang.srt.server_args import ServerArgs -from sglang.srt.utils import get_available_gpu_memory - -logger = logging.getLogger(__name__) - - -class HiPModelRunner(ModelRunner): - hip_attention_config: HiPAttentionConfig - - def __init__( - self, - model_config: ModelConfig, - mem_fraction_static: float, - gpu_id: int, - tp_rank: int, - tp_size: int, - nccl_port: int, - server_args: ServerArgs, - is_draft_worker: bool = False, - ): - if server_args.enable_hip_attention: - logger.info("HIP attention is turned on.") - server_args.attention_backend = "hip_attention" - self.init_hip_attention_config(server_args.hip_attention_config) - - super().__init__( - model_config=model_config, - mem_fraction_static=mem_fraction_static, - gpu_id=gpu_id, - tp_rank=tp_rank, - tp_size=tp_size, - nccl_port=nccl_port, - server_args=server_args, - is_draft_worker=is_draft_worker, - ) - - def init_attention_backend(self): - if self.server_args.enable_hip_attention: - self.attn_backend = HiPRadixAttentionBackend(self) - else: - super().init_attention_backend() - - def init_hip_attention_config(self, hip_attention_config): - if hip_attention_config is None: - hip_attention_config = {} - elif hip_attention_config.startswith("{"): - hip_attention_config = json.loads(hip_attention_config) - else: - with open(hip_attention_config, "r") as f: - hip_attention_config = json.load(f) - self.hip_attention_config = HiPAttentionConfig(parsed_json=hip_attention_config) - - def init_memory_pool( - self, - total_gpu_memory: int, - max_num_reqs: Optional[int] = None, - max_total_tokens: Optional[int] = None, - ): - super().init_memory_pool(total_gpu_memory, max_num_reqs, max_total_tokens) - - if self.server_args.enable_hip_attention: - self.hip_metadata_cache_pool = HiPMetadataCachePool( - query_head_num=self.model_config.num_attention_heads - // self.server_args.tp_size, - layer_num=self.model_config.num_hidden_layers, - context_length=self.model_config.context_len, - device=self.device, - hip_config=self.hip_attention_config, - ) - logger.info( - f"Memory + HiP pool end. 
" - f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" - ) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index a26a40c9de..e2d2c69f62 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -36,6 +36,8 @@ from sglang.srt.hf_transformers_utils import get_context_length, update_context_length from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend +from sglang.srt.layers.attention.hip_attention import HiPRadixAttentionBackend +from sglang.srt.layers.attention.hip_attention.hip_config import HiPAttentionConfig from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend from sglang.srt.layers.attention.triton_backend import TritonAttnBackend from sglang.srt.layers.dp_attention import ( @@ -48,6 +50,7 @@ from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model from sglang.srt.lora.lora_manager import LoRAManager from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.mem_cache.hip_memory_pool import HiPMetadataCachePool from sglang.srt.mem_cache.hip_offload_kv_pool_mha import MHATokenToHiPOffloadKVPool from sglang.srt.mem_cache.memory_pool import ( DoubleSparseTokenToKVPool, @@ -129,6 +132,11 @@ def __init__( self.server_args.ds_heavy_channel_type ) + elif server_args.enable_hip_attention: + logger.info("HIP attention is turned on.") + server_args.attention_backend = "hip_attention" + self.init_hip_attention_config(server_args.hip_attention_config) + if self.is_multimodal: self.mem_fraction_static *= 0.95 logger.info( @@ -688,6 +696,17 @@ def init_memory_pool( device=self.device, enable_memory_saver=self.server_args.enable_memory_saver, ) + + self.hip_metadata_cache_pool = None + if self.server_args.enable_hip_attention: + self.hip_metadata_cache_pool = HiPMetadataCachePool( + query_head_num=self.model_config.num_attention_heads // self.server_args.tp_size, + layer_num=self.model_config.num_hidden_layers, + context_length=self.model_config.context_len, + device=self.device, + hip_config=self.hip_attention_config, + ) + logger.info( f"Memory pool end. 
" f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" @@ -704,7 +723,9 @@ def init_cublas(self): def init_attention_backend(self): """Init attention kernel backend.""" - if self.server_args.attention_backend == "flashinfer": + if self.server_args.enable_hip_attention: + self.attn_backend = HiPRadixAttentionBackend(self) + elif self.server_args.attention_backend == "flashinfer": self.attn_backend = FlashInferAttnBackend(self) elif self.server_args.attention_backend == "triton": assert self.sliding_window_size is None, ( @@ -743,6 +764,16 @@ def init_double_sparsity_channel_config(self, selected_channel): .cuda() ) + def init_hip_attention_config(self, hip_attention_config): + if hip_attention_config is None: + hip_attention_config = {} + elif hip_attention_config.startswith("{"): + hip_attention_config = json.loads(hip_attention_config) + else: + with open(hip_attention_config, "r") as f: + hip_attention_config = json.load(f) + self.hip_attention_config = HiPAttentionConfig(parsed_json=hip_attention_config) + def init_cuda_graphs(self): """Capture cuda graphs.""" from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner From 746c737c3a70d9e957623f929cb10d0cadd85ed1 Mon Sep 17 00:00:00 2001 From: Geon Park Date: Tue, 28 Jan 2025 00:58:06 +0900 Subject: [PATCH 05/16] move hip_config and hip_memory_pool to hip-attention repo --- .../attention/hip_attention/hip_config.py | 163 -------- .../hip_attention/hip_radix_attention.py | 25 +- .../sglang/srt/mem_cache/hip_memory_pool.py | 371 ------------------ .../srt/mem_cache/hip_offload_kv_pool_mha.py | 11 +- .../srt/model_executor/forward_batch_info.py | 3 +- .../sglang/srt/model_executor/model_runner.py | 9 +- 6 files changed, 15 insertions(+), 567 deletions(-) delete mode 100644 python/sglang/srt/layers/attention/hip_attention/hip_config.py delete mode 100644 python/sglang/srt/mem_cache/hip_memory_pool.py diff --git a/python/sglang/srt/layers/attention/hip_attention/hip_config.py b/python/sglang/srt/layers/attention/hip_attention/hip_config.py deleted file mode 100644 index 3deb6a469c..0000000000 --- a/python/sglang/srt/layers/attention/hip_attention/hip_config.py +++ /dev/null @@ -1,163 +0,0 @@ -import warnings -from dataclasses import InitVar, dataclass, field -from typing import List, Optional, Union - -from hip.models.hip_attention.gen3.attention_metadata import ScanStage - -_DEFAULT_STAGES = [ - ScanStage( - stage_block_size_q=64, - stage_block_stride_q=4, - stage_chunk_size=256, - stage_k=None, - stage_stride=1, - ), - ScanStage( - stage_block_size_q=64, - stage_block_stride_q=4, - stage_chunk_size=32, - stage_k=32768, - stage_stride=1, - ), - ScanStage( - stage_block_size_q=64, - stage_block_stride_q=1, - stage_chunk_size=8, - stage_k=8192, - stage_stride=1, - ), -] - - -@dataclass -class HiPAttentionPerLayerConfig: - second_stage_k: int = 2048 - sliding_window_size: int = 1024 - sink_token_size: int = 256 - sa_extend_backend: str = "streaming" - scan_extend_backend: Optional[str] = None - stages: list[ScanStage] = field(default_factory=lambda: _DEFAULT_STAGES) - - parsed_json: InitVar[dict | None] = None - - def __post_init__(self, parsed_json: dict | None): - super().__init__() - if parsed_json is not None: - if "second_stage_k" in parsed_json: - self.second_stage_k = parsed_json["second_stage_k"] - parsed_json.pop("second_stage_k") - if "sliding_window_size" in parsed_json: - self.sliding_window_size = parsed_json["sliding_window_size"] - parsed_json.pop("sliding_window_size") - if "sink_token_size" in 
parsed_json: - self.sink_token_size = parsed_json["sink_token_size"] - parsed_json.pop("sink_token_size") - if "sa_extend_backend" in parsed_json: - self.sa_extend_backend = parsed_json["sa_extend_backend"] - parsed_json.pop("sa_extend_backend") - if "scan_extend_backend" in parsed_json: - self.scan_extend_backend = parsed_json["scan_extend_backend"] - parsed_json.pop("scan_extend_backend") - if "stages" in parsed_json: - self.stages = [ScanStage(**stage) for stage in parsed_json["stages"]] - parsed_json.pop("stages") - if parsed_json: - raise ValueError(f"Unknown keys in json: {parsed_json.keys()}") - - -@dataclass -class HiPAttentionConfig: - dense_layers: list[int] = field(default_factory=lambda: [0, 1, 2]) - block_sparse_block_size_q: int = 64 - metadata_cache_max_batch_size: int = 32 - mask_refresh_interval: Union[int, List[int]] = field( - default_factory=lambda: [32, 16, 8] - ) - using_extend: bool = True - layers: list[HiPAttentionPerLayerConfig] = field( - default_factory=lambda: [ - HiPAttentionPerLayerConfig( - parsed_json={ - "second_stage_k": 4096, - "sliding_window_size": 1024, - "sink_token_size": 256, - } - ), - HiPAttentionPerLayerConfig(), - ] - ) - prefill_layers: Optional[list[HiPAttentionPerLayerConfig]] = None - - # deprecated - apply_v_dot: bool = False - prefill_always_dense: bool = False - decode_always_dense: bool = False - force_dense: bool = False - prefill_dense_threshold: int = 8192 - - parsed_json: InitVar[dict | None] = None - - def __post_init__(self, parsed_json: dict | None): - super().__init__() - - if parsed_json is not None: - if "apply_v_dot" in parsed_json: - self.apply_v_dot = parsed_json["apply_v_dot"] - parsed_json.pop("apply_v_dot") - if "dense_layers" in parsed_json: - self.dense_layers = parsed_json["dense_layers"] - parsed_json.pop("dense_layers") - if "prefill_always_dense" in parsed_json: - self.prefill_always_dense = parsed_json["prefill_always_dense"] - parsed_json.pop("prefill_always_dense") - if "decode_always_dense" in parsed_json: - self.decode_always_dense = parsed_json["decode_always_dense"] - parsed_json.pop("decode_always_dense") - if "force_dense" in parsed_json: - self.force_dense = parsed_json["force_dense"] - parsed_json.pop("force_dense") - if "prefill_dense_threshold" in parsed_json: - self.prefill_dense_threshold = parsed_json["prefill_dense_threshold"] - parsed_json.pop("prefill_dense_threshold") - if "block_sparse_block_size_q" in parsed_json: - self.block_sparse_block_size_q = parsed_json[ - "block_sparse_block_size_q" - ] - parsed_json.pop("block_sparse_block_size_q") - if "metadata_cache_max_batch_size" in parsed_json: - self.metadata_cache_max_batch_size = parsed_json[ - "metadata_cache_max_batch_size" - ] - parsed_json.pop("metadata_cache_max_batch_size") - if "mask_refresh_interval" in parsed_json: - assert isinstance(parsed_json["mask_refresh_interval"], (int, list)) - self.mask_refresh_interval = parsed_json["mask_refresh_interval"] - parsed_json.pop("mask_refresh_interval") - if "using_extend" in parsed_json: - self.using_extend = parsed_json["using_extend"] - parsed_json.pop("using_extend") - if "layers" in parsed_json: - self.layers = [ - HiPAttentionPerLayerConfig(parsed_json=layer) - for layer in parsed_json["layers"] - ] - parsed_json.pop("layers") - if self.prefill_layers is None: - self.prefill_layers = self.layers - if "prefill_layers" in parsed_json: - self.prefill_layers = [ - HiPAttentionPerLayerConfig(parsed_json=layer) - for layer in parsed_json["prefill_layers"] - ] - 
parsed_json.pop("prefill_layers") - if parsed_json: - raise ValueError(f"Unknown keys in json: {parsed_json.keys()}") - - num_stages = len(self.layers[0].stages) - for layer_config in self.layers: - assert num_stages == len(layer_config.stages) - - if isinstance(self.mask_refresh_interval, int): - self.mask_refresh_interval = [ - self.mask_refresh_interval, - ] * num_stages diff --git a/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py b/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py index b1260eefd5..4dd895a1ed 100644 --- a/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py +++ b/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py @@ -21,19 +21,9 @@ if TYPE_CHECKING: from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.model_runner import ModelRunner - from sglang.srt.layers.attention.hip_attention.hip_config import HiPAttentionConfig from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.speculative.spec_info import SpecInfo -from hip.models.hip_attention.gen3.attention_extend import ( - dual_stage_quadratic_hip_attention, -) -from hip.models.hip_attention.gen3.attention_metadata import ( - HiPAttentionArgs, - HiPAttentionOutputMetadata, -) -from hip.models.hip_attention.gen3.uvm_gpu_cache import HiPOffloadCache - logger = logging.getLogger(__name__) @@ -47,7 +37,7 @@ class HiPRadixAttentionBackend(AttentionBackend): def __init__(self, model_runner: ModelRunner): super().__init__() - self.hip_config: HiPAttentionConfig = model_runner.hip_attention_config + self.hip_config: "HiPAttentionConfig" = model_runner.hip_attention_config self.max_context_len = model_runner.model_config.context_len @@ -99,8 +89,6 @@ def forward_extend( else forward_batch.encoder_out_cache_loc ) - # logger.info(f'HiP attention is used in prompting (layer {layer.layer_id})!', stacklevel=0) - is_offload_cache = isinstance( forward_batch.token_to_kv_pool, MHATokenToHiPOffloadKVPool ) @@ -293,8 +281,6 @@ def forward_decode( else forward_batch.encoder_out_cache_loc ) - # logger.info(f'HiP attention is used in decoding (layer {layer.layer_id})!', stacklevel=0) - is_offload_cache = isinstance( forward_batch.token_to_kv_pool, MHATokenToHiPOffloadKVPool ) @@ -588,13 +574,13 @@ def forward_paged_hip( batch_size: int, k_cache: Optional[torch.Tensor], v_cache: Optional[torch.Tensor], - offload_cache: Optional[HiPOffloadCache], + offload_cache: Optional["HiPOffloadCache"], positions: torch.Tensor, seq_lens: torch.Tensor, req_to_tokens: torch.Tensor, req_pool_indices: torch.Tensor, layer: RadixAttention, - cached_metadata: Optional[HiPAttentionOutputMetadata] = None, + cached_metadata: Optional["HiPAttentionOutputMetadata"] = None, is_dense: bool = False, k: Optional[torch.Tensor] = None, v: Optional[torch.Tensor] = None, @@ -653,6 +639,11 @@ def forward_paged_hip( elif os.getenv("HIP_DISABLE_COMPUTE_STATISTICS", "1") == "0": require_cache_statistics = True + from hip.models.hip_attention.gen3.attention_extend import ( + dual_stage_quadratic_hip_attention, + ) + from hip.models.hip_attention.gen3.attention_metadata import HiPAttentionArgs + args = HiPAttentionArgs( k_cache=( k_cache.view(torch.uint8) diff --git a/python/sglang/srt/mem_cache/hip_memory_pool.py b/python/sglang/srt/mem_cache/hip_memory_pool.py deleted file mode 100644 index 5ed18b6714..0000000000 --- a/python/sglang/srt/mem_cache/hip_memory_pool.py +++ /dev/null @@ -1,371 +0,0 @@ -from __future__ 
import annotations - -import logging -import math -from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union - -import torch -import triton -from hip.models.hip_attention.gen3.attention_metadata import ( - HiPAttentionCacheAccessStatistics, - HiPAttentionOutputMetadata, - HiPAttentionStageInputCache, -) - -if TYPE_CHECKING: - from sglang.srt.layers.attention.hip_attention.hip_config import HiPAttentionConfig - -logger = logging.getLogger(__name__) - - -@dataclass -class CachedBuffer: - buffer: torch.Tensor - batch_format: Literal["BH", "B,1,H"] - dtype: torch.dtype - - def get(self, batch_size: int, head_size: int) -> torch.Tensor: - if self.batch_format == "BH": - return self.buffer[: batch_size * head_size].to(self.dtype, copy=True) - elif self.batch_format == "B,1,H": - return self.buffer[:batch_size, :, :head_size].to(self.dtype, copy=True) - else: - raise Exception() - - def set(self, value: torch.Tensor): - if self.batch_format == "BH": - self.buffer[: value.shape[0]].copy_(value.to(self.buffer.dtype)) - elif self.batch_format == "B,1,H": - self.buffer[: value.shape[0], :, : value.shape[2]].copy_( - value.to(self.buffer.dtype) - ) - else: - raise Exception() - - -class HiPMetadataCachePool: - cache: List[Dict[str, CachedBuffer]] - - def __init__( - self, - query_head_num: int, - layer_num: int, - context_length: int, - device: str, - hip_config: HiPAttentionConfig, - ): - self.hip_config = hip_config - self.layer_num = layer_num - self.cache = [{} for _ in range(layer_num)] - self.head_num = query_head_num - self.max_batch_size = hip_config.metadata_cache_max_batch_size - self.device = device - self.allocated_gpu_bytes = 0 - self.layer_configs = {} - - for layer_idx in range(layer_num): - require_dense = layer_idx in hip_config.dense_layers - if len(hip_config.layers) == 2: - layer_config = hip_config.layers[0 if require_dense else 1] - else: - layer_config = hip_config.layers[layer_idx] - self.layer_configs[layer_idx] = layer_config - - n_chunks = triton.cdiv( - layer_config.second_stage_k, layer_config.stages[-1].stage_chunk_size - ) - - num_q_blocks = 1 - self.init_buffer( - layer_idx, - "indices", - ( - num_q_blocks, - n_chunks, - ), - torch.int64, - store_dtype=torch.uint32, - ) - self.init_buffer(layer_idx, "ks", (num_q_blocks,), torch.int64) - self.init_buffer( - layer_idx, - "ks_count", - ( - num_q_blocks, - 1, - ), - torch.int64, - ) - self.init_buffer( - layer_idx, - "ks_start_end", - ( - num_q_blocks, - 2, - ), - torch.int64, - ) - - self.init_buffer( - layer_idx, "mask_access_count", (num_q_blocks,), torch.int64 - ) - self.init_buffer( - layer_idx, "mask_unique_access_count", (num_q_blocks,), torch.int64 - ) - self.init_buffer( - layer_idx, "mask_cache_miss_count", (num_q_blocks,), torch.int64 - ) - - self.init_buffer(layer_idx, "sa_access_count", (num_q_blocks,), torch.int64) - self.init_buffer( - layer_idx, "sa_unique_access_count", (num_q_blocks,), torch.int64 - ) - self.init_buffer( - layer_idx, "sa_cache_miss_count", (num_q_blocks,), torch.int64 - ) - - for i_stage, stage in enumerate(layer_config.stages): - if i_stage > 0: - max_context_length = ( - context_length - - layer_config.sliding_window_size - - layer_config.sink_token_size - ) - chunk_count = ( - min(stage.stage_k, max_context_length) // stage.stage_chunk_size - ) - self.init_buffer( - layer_idx, - f"stage_{i_stage}_indices_left", - [ - chunk_count, - ], - torch.int64, - "B,1,H", - torch.uint32, - ) - self.init_buffer( - layer_idx, - 
f"stage_{i_stage}_indices_right", - [ - chunk_count, - ], - torch.int64, - "B,1,H", - torch.uint32, - ) - self.init_buffer( - layer_idx, - f"stage_{i_stage}_out_scores", - [ - chunk_count, - ], - torch.float32, - "B,1,H", - torch.bfloat16, - ) - - self.allocated_gpu_bytes = self.compute_allocated_bytes() - logger.info( - f"Allocated HiP metadata cache pool size: {self.allocated_gpu_bytes / 1024 / 1024:.2f} MB" - ) - - def compute_allocated_bytes(self): - t = 0 - for layer_buffer in self.cache: - for v in layer_buffer.values(): - t += v.buffer.numel() * v.buffer.element_size() - return t - - def init_buffer( - self, - layer_idx: int, - name: str, - shape: List[int], - dtype: torch.dtype, - batch_format: Literal["BH", "B,1,H"] = "BH", - store_dtype: Optional[torch.dtype] = None, - ): - layer_buffer = self.cache[layer_idx] - if batch_format == "BH": - layer_buffer[name] = CachedBuffer( - buffer=torch.zeros( - (self.max_batch_size * self.head_num, *shape), - device=self.device, - dtype=dtype if store_dtype is None else store_dtype, - ), - batch_format=batch_format, - dtype=dtype, - ) - elif batch_format == "B,1,H": - layer_buffer[name] = CachedBuffer( - buffer=torch.zeros( - (self.max_batch_size, 1, self.head_num, *shape), - device=self.device, - dtype=dtype if store_dtype is None else store_dtype, - ), - batch_format=batch_format, - dtype=dtype, - ) - else: - raise Exception() - - def get_buffer(self, layer_idx: int, name: str, batch_size: int): - if not layer_idx in range(len(self.cache)): - raise Exception(f"{layer_idx} is not in range({len(self.cache)})") - if not name in self.cache[layer_idx]: - raise Exception(f"{name} is not in {self.cache[layer_idx].keys()}") - return self.cache[layer_idx][name].get(batch_size, self.head_num) - - def set_buffer(self, layer_idx: int, name: str, value: torch.Tensor): - if not layer_idx in range(len(self.cache)): - raise Exception(f"{layer_idx} is not in range({len(self.cache)})") - if not name in self.cache[layer_idx]: - raise Exception(f"{name} is not in {self.cache[layer_idx].keys()}") - self.cache[layer_idx][name].set(value) - - def get_hip_metadata_cache( - self, - layer_id: int, - size: int, - batch_size: int, - cached_stages: Optional[int], - ) -> Optional[HiPAttentionOutputMetadata]: - assert size == batch_size - - if (cached_stages is None) or ( - cached_stages == len(self.layer_configs[layer_id].stages) - ): - return HiPAttentionOutputMetadata( - indices=self.get_buffer(layer_id, "indices", batch_size), - ks=self.get_buffer(layer_id, "ks", batch_size), - ks_count=self.get_buffer(layer_id, "ks_count", batch_size), - ks_start_end=self.get_buffer(layer_id, "ks_start_end", batch_size), - mask_cache_statistics=None, - sa_cache_statistics=None, - stage_caches=None, - ) - elif cached_stages == 0: - # NOTE: reset the cache, let hip attention compute everything from scratch - return - else: - stage_caches = [] - for i_stage in range(cached_stages + 1): - if i_stage == 0: - stage_caches.append( - HiPAttentionStageInputCache( - indices_left=None, - indices_right=None, - out_scores=None, - ) - ) - else: - stage_caches.append( - HiPAttentionStageInputCache( - indices_left=self.get_buffer( - layer_id, f"stage_{i_stage}_indices_left", batch_size - ), - indices_right=self.get_buffer( - layer_id, f"stage_{i_stage}_indices_right", batch_size - ), - out_scores=self.get_buffer( - layer_id, f"stage_{i_stage}_out_scores", batch_size - ), - ) - ) - return HiPAttentionOutputMetadata( - indices=None, - ks=None, - ks_count=None, - ks_start_end=None, - 
mask_cache_statistics=None, - sa_cache_statistics=None, - stage_caches=stage_caches, - ) - - def set_hip_metadata_cache( - self, - layer_id: int, - size: int, - batch_size: int, - metadata: HiPAttentionOutputMetadata, - ): - assert size == batch_size - - self.set_buffer(layer_id, "indices", metadata.indices) - self.set_buffer(layer_id, "ks", metadata.ks) - self.set_buffer(layer_id, "ks_count", metadata.ks_count) - self.set_buffer(layer_id, "ks_start_end", metadata.ks_start_end) - - def update_cache_stats(stats: HiPAttentionCacheAccessStatistics, prefix: str): - if stats is None: - access_count = torch.zeros((1,), dtype=torch.int64, device=self.device) - unique_access_count = torch.zeros( - (1,), dtype=torch.int64, device=self.device - ) - cache_miss_count = torch.zeros( - (1,), dtype=torch.int64, device=self.device - ) - else: - computed_statistics = stats.compute_statistics() - access_count = computed_statistics["access_count"] - unique_access_count = computed_statistics["unique_access_count"] - cache_miss_count = computed_statistics["cache_miss_count"] - - if access_count is not None: - self.set_buffer( - layer_id, - f"{prefix}_access_count", - access_count.view(1, 1).expand(self.max_batch_size, 1), - ) - self.set_buffer( - layer_id, - f"{prefix}_unique_access_count", - unique_access_count.view(1, 1).expand(self.max_batch_size, 1), - ) - self.set_buffer( - layer_id, - f"{prefix}_cache_miss_count", - cache_miss_count.view(1, 1).expand(self.max_batch_size, 1), - ) - - update_cache_stats(metadata.sa_cache_statistics, "sa") - update_cache_stats(metadata.mask_cache_statistics, "mask") - - if metadata.stage_caches is not None: - for i_stage, cache in enumerate(metadata.stage_caches): - if i_stage > 0: - self.set_buffer( - layer_id, f"stage_{i_stage}_indices_left", cache.indices_left - ) - self.set_buffer( - layer_id, f"stage_{i_stage}_indices_right", cache.indices_right - ) - self.set_buffer( - layer_id, f"stage_{i_stage}_out_scores", cache.out_scores - ) - - def compute_cache_statistics(self, batch_size: int): - def compute(prefix): - total_access = 0 - total_miss = 0 - for idx_layer in range(self.layer_num): - access_count = self.get_buffer( - idx_layer, f"{prefix}_access_count", batch_size - ) - miss_count = self.get_buffer( - idx_layer, f"{prefix}_cache_miss_count", batch_size - ) - total_access += access_count.sum() - total_miss += miss_count.sum() - return { - f"{prefix}_access": total_access, - f"{prefix}_miss": total_miss, - f"{prefix}_hit_ratio": 1 - (total_miss / total_access), - } - - result = {} - result.update(compute("sa")) - result.update(compute("mask")) - return result diff --git a/python/sglang/srt/mem_cache/hip_offload_kv_pool_mha.py b/python/sglang/srt/mem_cache/hip_offload_kv_pool_mha.py index 2950f4fd84..5af37dfadb 100644 --- a/python/sglang/srt/mem_cache/hip_offload_kv_pool_mha.py +++ b/python/sglang/srt/mem_cache/hip_offload_kv_pool_mha.py @@ -2,18 +2,15 @@ import os import threading import time -from typing import Dict, Optional, Set, Tuple, Union +from typing import Dict, Set, Tuple import torch from hip.models.hip_attention.gen3.uvm_gpu_cache import ( - GPUCache, HiPOffloadCache, - UVMCache, format_size_bytes, ) from torch import Tensor -from sglang.srt.layers.attention.hip_attention.hip_config import HiPAttentionConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, MHATokenToKVPool from sglang.srt.model_executor.forward_batch_info import ForwardBatch @@ -33,7 +30,7 @@ def __init__( 
head_dim: int, layer_num: int, device: torch.device, - hip_config: HiPAttentionConfig, + hip_config: "HiPAttentionConfig", ): assert isinstance(device, torch.device) assert device.index is not None @@ -140,9 +137,7 @@ def get_value_buffer(self, layer_id: int): def get_kv_buffer(self, layer_id: int) -> HiPOffloadCache: # Use this function for decode, pass this to `k` if self.require_validation: - return self.layer_buffer[layer_id], *self.validation_cache.get_kv_buffer( - layer_id - ) + return self.layer_buffer[layer_id], *self.validation_cache.get_kv_buffer(layer_id) return self.layer_buffer[layer_id] def prefetch_prefix_kv_buffer( diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index aff37773bb..1b7e90d021 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -43,7 +43,6 @@ if TYPE_CHECKING: from sglang.srt.layers.attention import AttentionBackend from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch - from sglang.srt.mem_cache.hip_memory_pool import HiPMetadataCachePool from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo @@ -186,7 +185,7 @@ class ForwardBatch: attn_backend: AttentionBackend = None # For HiP attention - hip_metadata_cache_pool: Optional[HiPMetadataCachePool] = None + hip_metadata_cache_pool: Optional["HiPMetadataCachePool"] = None hip_use_cached_mask: Optional[bool] = None hip_metadata_cached_stage: Optional[int] = None diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index e2d2c69f62..af7e440ffb 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -37,7 +37,6 @@ from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend from sglang.srt.layers.attention.hip_attention import HiPRadixAttentionBackend -from sglang.srt.layers.attention.hip_attention.hip_config import HiPAttentionConfig from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend from sglang.srt.layers.attention.triton_backend import TritonAttnBackend from sglang.srt.layers.dp_attention import ( @@ -50,7 +49,6 @@ from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model from sglang.srt.lora.lora_manager import LoRAManager from sglang.srt.managers.schedule_batch import global_server_args_dict -from sglang.srt.mem_cache.hip_memory_pool import HiPMetadataCachePool from sglang.srt.mem_cache.hip_offload_kv_pool_mha import MHATokenToHiPOffloadKVPool from sglang.srt.mem_cache.memory_pool import ( DoubleSparseTokenToKVPool, @@ -699,6 +697,7 @@ def init_memory_pool( self.hip_metadata_cache_pool = None if self.server_args.enable_hip_attention: + from hip.models.hip_attention.gen3.hip_memory_pool import HiPMetadataCachePool self.hip_metadata_cache_pool = HiPMetadataCachePool( query_head_num=self.model_config.num_attention_heads // self.server_args.tp_size, layer_num=self.model_config.num_hidden_layers, @@ -765,6 +764,7 @@ def init_double_sparsity_channel_config(self, selected_channel): ) def init_hip_attention_config(self, hip_attention_config): + from hip.models.hip_attention.gen3.hip_config import 
HiPAttentionConfig if hip_attention_config is None: hip_attention_config = {} elif hip_attention_config.startswith("{"): @@ -790,10 +790,7 @@ def init_cuda_graphs(self): tic = time.time() logger.info("Capture cuda graph begin. This can take up to several minutes.") self.cuda_graph_runner = CudaGraphRunner(self) - logger.info( - f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s, " - f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" - ) + logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s") def apply_torch_tp(self): logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.") From 93d27d5741b2a61fffcc23ff8dfb3a197a6e7861 Mon Sep 17 00:00:00 2001 From: Geon Park Date: Tue, 28 Jan 2025 01:35:01 +0900 Subject: [PATCH 06/16] remove redundant use_cached_mask --- .../hip_attention/hip_radix_attention.py | 4 +--- python/sglang/srt/layers/radix_attention.py | 8 ++------ python/sglang/srt/managers/schedule_batch.py | 1 - .../srt/managers/tp_worker_overlap_thread.py | 15 +++++---------- .../srt/model_executor/cuda_graph_runner.py | 11 +++++------ .../srt/model_executor/forward_batch_info.py | 2 -- 6 files changed, 13 insertions(+), 28 deletions(-) diff --git a/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py b/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py index 4dd895a1ed..493687fc86 100644 --- a/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py +++ b/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py @@ -286,9 +286,7 @@ def forward_decode( ) metadata = None - if forward_batch.hip_use_cached_mask or ( - forward_batch.hip_metadata_cached_stage is not None - ): + if forward_batch.hip_metadata_cached_stage is not None: metadata = forward_batch.hip_metadata_cache_pool.get_hip_metadata_cache( layer.layer_id, q.shape[0], diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index d065d14864..772d0fa928 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -12,16 +12,12 @@ # limitations under the License. 
# ============================================================================== """Radix attention.""" -from __future__ import annotations - -from typing import TYPE_CHECKING, Optional +from typing import Optional from torch import nn from sglang.srt.layers.rotary_embedding import RotaryEmbedding - -if TYPE_CHECKING: - from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_executor.forward_batch_info import ForwardBatch class RadixAttention(nn.Module): diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index bd84970d73..17423cdbc4 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1274,7 +1274,6 @@ class ModelWorkerBatch: capture_hidden_mode: CaptureHiddenMode = None # Use cached mask for HiP Attention - hip_use_cached_mask: Optional[bool] = None hip_metadata_cached_stages: Optional[int] = None diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index a93b8bb6ad..3c81fdea85 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -72,12 +72,10 @@ def __init__( (self.max_running_requests * 5,), dtype=torch.int32, device=self.device ) - # Init hip mask refresh interval - self.hip_mask_refresh_interval = None + # Init hip attention config + self.hip_attention_config = None if server_args.enable_hip_attention: - self.hip_mask_refresh_interval = ( - self.worker.model_runner.hip_attention_config.mask_refresh_interval - ) + self.hip_attention_config = self.worker.model_runner.hip_attention_config # Launch threads self.input_queue = Queue() @@ -132,21 +130,18 @@ def forward_thread_func_(self): if not model_worker_batch: break - model_worker_batch: ModelWorkerBatch if model_worker_batch.forward_mode.is_decode(): - if self.hip_mask_refresh_interval is not None: + if self.hip_attention_config.mask_refresh_interval is not None: require_refresh = False for i_stage, refresh_inteval in enumerate( - self.hip_mask_refresh_interval + self.hip_attention_config.mask_refresh_interval ): if (decode_index % refresh_inteval == 0) and ( not require_refresh ): - model_worker_batch.hip_use_cached_mask = False model_worker_batch.hip_metadata_cached_stages = i_stage require_refresh = True if not require_refresh: - model_worker_batch.hip_use_cached_mask = True model_worker_batch.hip_metadata_cached_stages = ( None # NOTE: use cache every stage ) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 096396daee..d9c5fb5e1f 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -308,9 +308,9 @@ def capture(self): def capture_configs(self): if self.enable_hip_attention: num_stages = len(self.hip_config.layers[0].stages) - cache_configs = [(True, None)] # (use_cached_mask, num_stage_cached) + cache_configs = [(None,)] # (num_stage_cached,) for i_stage in range(num_stages): - cache_configs.append((False, i_stage)) + cache_configs.append((i_stage,)) return cache_configs else: return [()] @@ -341,9 +341,9 @@ def capture_one_batch_size(self, bs: int, forward: Callable, capture_config: tup spec_info = self.get_spec_info(num_tokens, positions) - hip_use_cached_mask = hip_num_cached_stages = None + hip_num_cached_stages = None if self.enable_hip_attention: - hip_use_cached_mask, 
hip_num_cached_stages = capture_config + hip_num_cached_stages, = capture_config forward_batch = ForwardBatch( forward_mode=self.capture_forward_mode, @@ -355,7 +355,6 @@ def capture_one_batch_size(self, bs: int, forward: Callable, capture_config: tup token_to_kv_pool=self.model_runner.token_to_kv_pool, attn_backend=self.model_runner.attn_backend, hip_metadata_cache_pool=self.model_runner.hip_metadata_cache_pool, - hip_use_cached_mask=hip_use_cached_mask, hip_metadata_cached_stage=hip_num_cached_stages, out_cache_loc=out_cache_loc, seq_lens_sum=seq_lens.sum(), @@ -456,7 +455,7 @@ def replay(self, forward_batch: ForwardBatch): # Replay graph_handle = (bs,) if self.enable_hip_attention: - graph_handle = (bs, forward_batch.hip_use_cached_mask, forward_batch.hip_metadata_cached_stage) + graph_handle = (bs, forward_batch.hip_metadata_cached_stage) self.graphs[graph_handle].replay() next_token_logits, hidden_states = self.output_buffers[graph_handle] diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 1b7e90d021..61aacafd0c 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -186,7 +186,6 @@ class ForwardBatch: # For HiP attention hip_metadata_cache_pool: Optional["HiPMetadataCachePool"] = None - hip_use_cached_mask: Optional[bool] = None hip_metadata_cached_stage: Optional[int] = None # For DP attention @@ -347,7 +346,6 @@ def init_new( # Init HiP attention information if hasattr(model_runner, "hip_metadata_cache_pool"): ret.hip_metadata_cache_pool = model_runner.hip_metadata_cache_pool - ret.hip_use_cached_mask = batch.hip_use_cached_mask ret.hip_metadata_cached_stage = batch.hip_metadata_cached_stages # Init lora information From 906016d2f8e827ab0415eadc746481d450e891bd Mon Sep 17 00:00:00 2001 From: Geon Park Date: Tue, 28 Jan 2025 01:56:11 +0900 Subject: [PATCH 07/16] move mask refresh mechanism to hip-attention --- .../srt/managers/tp_worker_overlap_thread.py | 33 +++++++------------ .../srt/model_executor/forward_batch_info.py | 2 +- python/sglang/srt/models/llama.py | 7 ++-- 3 files changed, 17 insertions(+), 25 deletions(-) diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index 3c81fdea85..9eb1c57ed0 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -122,33 +122,24 @@ def forward_thread_func_(self): batch_pt = 0 batch_lists = [None] * 2 - # For keeping track of HiP attention mask refresh - decode_index = 0 + hip_mask_refresh_state = None + if self.hip_attention_config is not None: + from hip.models.hip_attention.gen3.mask_refresh_interval import HiPMaskRefreshState + + # For keeping track of HiP attention mask refresh cycles + hip_mask_refresh_state = HiPMaskRefreshState() while True: model_worker_batch, future_token_ids_ct = self.input_queue.get() if not model_worker_batch: break - if model_worker_batch.forward_mode.is_decode(): - if self.hip_attention_config.mask_refresh_interval is not None: - require_refresh = False - for i_stage, refresh_inteval in enumerate( - self.hip_attention_config.mask_refresh_interval - ): - if (decode_index % refresh_inteval == 0) and ( - not require_refresh - ): - model_worker_batch.hip_metadata_cached_stages = i_stage - require_refresh = True - if not require_refresh: - model_worker_batch.hip_metadata_cached_stages = ( - None # NOTE: use cache 
every stage - ) - decode_index += 1 - - elif model_worker_batch.forward_mode.is_extend(): - decode_index = 0 + if hip_mask_refresh_state is not None: + model_worker_batch.hip_metadata_cached_stages = hip_mask_refresh_state.update( + model_worker_batch.forward_mode.is_decode(), + model_worker_batch.forward_mode.is_extend(), + self.hip_attention_config, + ) # Keep a reference of model_worker_batch by storing it into a list. # Otherwise, the tensor members of model_worker_batch will be released diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 61aacafd0c..c0bf582e7e 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -344,7 +344,7 @@ def init_new( ret.compute_mrope_positions(model_runner, batch) # Init HiP attention information - if hasattr(model_runner, "hip_metadata_cache_pool"): + if model_runner.hip_metadata_cache_pool is not None: ret.hip_metadata_cache_pool = model_runner.hip_metadata_cache_pool ret.hip_metadata_cached_stage = batch.hip_metadata_cached_stages diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index e8cfd77956..f5b561927f 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -183,9 +183,10 @@ def forward( qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - # FIXME(geon): find better way to detect if HIP is enabled - if (forward_batch.hip_metadata_cache_pool is None) or ( - not forward_batch.hip_metadata_cache_pool.hip_config.using_extend + # RoPE is applied inside the attention kernel in HiP Attention + if not ( + forward_batch.hip_metadata_cache_pool is not None + and forward_batch.hip_metadata_cache_pool.hip_config.using_extend ): q, k = self.rotary_emb(positions, q, k) From d3e533ea159627930479c60f3eb5909549eef82c Mon Sep 17 00:00:00 2001 From: Geon Park Date: Tue, 28 Jan 2025 03:16:03 +0900 Subject: [PATCH 08/16] move forward_paged_hip to hip-attention repo --- .../hip_attention/hip_radix_attention.py | 406 ++++++------------ 1 file changed, 128 insertions(+), 278 deletions(-) diff --git a/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py b/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py index 493687fc86..5f80aa0e2b 100644 --- a/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py +++ b/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py @@ -1,7 +1,5 @@ from __future__ import annotations -import os - """ Support different attention backends. Now there are two backends: FlashInfer and Triton. 
@@ -27,20 +25,19 @@ logger = logging.getLogger(__name__) -class WrapperDispatch(Enum): - SLIDING_WINDOW = auto() - CROSS_ATTENTION = auto() - - class HiPRadixAttentionBackend(AttentionBackend): def __init__(self, model_runner: ModelRunner): super().__init__() self.hip_config: "HiPAttentionConfig" = model_runner.hip_attention_config + self.is_offload_enabled = model_runner.server_args.enable_hip_offload self.max_context_len = model_runner.model_config.context_len + from hip.models.hip_attention.gen3 import forward_paged_hip + self.forward_paged_hip = forward_paged_hip + def init_forward_metadata(self, forward_batch: ForwardBatch): pass @@ -89,14 +86,16 @@ def forward_extend( else forward_batch.encoder_out_cache_loc ) - is_offload_cache = isinstance( - forward_batch.token_to_kv_pool, MHATokenToHiPOffloadKVPool - ) + if not self.is_offload_enabled: + if k is not None: + assert v is not None + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) + k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) + offload_cache = None - if is_offload_cache: - assert isinstance( - forward_batch.token_to_kv_pool, MHATokenToHiPOffloadKVPool - ) + else: # Offloading enabled + assert isinstance(forward_batch.token_to_kv_pool, MHATokenToHiPOffloadKVPool) if k is not None: assert v is not None if save_kv_cache: @@ -106,15 +105,6 @@ def forward_extend( k_cache = v_cache = None # offload_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) offload_cache = None - else: - if k is not None: - assert v is not None - if save_kv_cache: - forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) - k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer( - layer.layer_id - ) - offload_cache = None q_reshaped = q.reshape(-1, layer.tp_q_head_num, layer.head_dim) @@ -123,26 +113,26 @@ def forward_extend( start_len = 0 decoding_reqs = [] - decoding_reqs_poistions = [] + decoding_reqs_positions = [] for idx_batch, seq_len in enumerate(forward_batch.extend_seq_lens_cpu): if seq_len == 0: # Skip empty sequences decoding_reqs.append(idx_batch) - decoding_reqs_poistions.append(start_len) + decoding_reqs_positions.append(start_len) + else: - if is_offload_cache: - assert isinstance( - forward_batch.token_to_kv_pool, MHATokenToHiPOffloadKVPool - ) - require_validation = ( - forward_batch.token_to_kv_pool.require_validation - ) + if not self.is_offload_enabled: + k_chunk = v_chunk = None + + else: # Offloading enabled + assert isinstance(forward_batch.token_to_kv_pool, MHATokenToHiPOffloadKVPool) + require_validation = forward_batch.token_to_kv_pool.require_validation if require_validation: k_chunk, v_chunk, k_pages, v_pages = ( forward_batch.token_to_kv_pool.get_fetched_prefix_kv_buffer( layer_id=layer.layer_id, batch_id=idx_batch, - cache_k=k[start_len : start_len + seq_len].unsqueeze(0), - cache_v=v[start_len : start_len + seq_len].unsqueeze(0), + cache_k=k[start_len: start_len + seq_len].unsqueeze(0), + cache_v=v[start_len: start_len + seq_len].unsqueeze(0), ) ) else: @@ -150,15 +140,38 @@ def forward_extend( forward_batch.token_to_kv_pool.get_fetched_prefix_kv_buffer( layer_id=layer.layer_id, batch_id=idx_batch, - cache_k=k[start_len : start_len + seq_len].unsqueeze(0), - cache_v=v[start_len : start_len + seq_len].unsqueeze(0), + cache_k=k[start_len: start_len + seq_len].unsqueeze(0), + cache_v=v[start_len: start_len + seq_len].unsqueeze(0), ) ) offload_cache = k_cache = v_cache = None - else: - k_chunk = v_chunk = None - if 
is_offload_cache: + if not self.is_offload_enabled: + o_req, _ = self.forward_paged_hip( + query=q_reshaped[start_len:start_len + seq_len], + sm_scale=layer.scaling, + batch_size=1, + k_cache=k_cache, + v_cache=v_cache, + offload_cache=offload_cache, + positions=forward_batch.positions[start_len:start_len + seq_len], + seq_lens=forward_batch.seq_lens[idx_batch:idx_batch + 1], + req_to_tokens=forward_batch.req_to_token_pool.req_to_token, + req_pool_indices=forward_batch.req_pool_indices[idx_batch:idx_batch + 1], + rope_cos=layer.rope_cos, + rope_sin=layer.rope_sin, + layer_id=layer.layer_id, + logit_cap=layer.logit_cap, + orig_context_len=layer.orig_context_len, + max_context_len=self.max_context_len, + hip_config=self.hip_config, + k=k_chunk, + v=v_chunk, + ) + + o[start_len:start_len + seq_len] = o_req + + else: # Offloading enabled # BUG: this padding is neccesary to match non offload scenario. why? pad_size = self.max_context_len if k_chunk.shape[1] != pad_size: @@ -190,78 +203,56 @@ def forward_extend( v_chunk = v_chunk_padded o_req, _ = self.forward_paged_hip( - query=q_reshaped[start_len : start_len + seq_len], + query=q_reshaped[start_len:start_len + seq_len], sm_scale=layer.scaling, batch_size=1, k_cache=k_cache, v_cache=v_cache, offload_cache=offload_cache, - positions=forward_batch.positions[ - start_len : start_len + seq_len - ], - seq_lens=forward_batch.seq_lens[idx_batch : idx_batch + 1], + positions=forward_batch.positions[start_len:start_len + seq_len], + seq_lens=forward_batch.seq_lens[idx_batch:idx_batch + 1], req_to_tokens=forward_batch.req_to_token_pool.req_to_token, - req_pool_indices=forward_batch.req_pool_indices[ - idx_batch : idx_batch + 1 - ], - layer=layer, - is_dense=layer.layer_id in self.hip_config.dense_layers, + req_pool_indices=forward_batch.req_pool_indices[idx_batch:idx_batch + 1], + rope_cos=layer.rope_cos, + rope_sin=layer.rope_sin, + layer_id=layer.layer_id, + logit_cap=layer.logit_cap, + orig_context_len=layer.orig_context_len, + max_context_len=self.max_context_len, + hip_config=self.hip_config, k=k_chunk, v=v_chunk, online_update_cache=forward_batch.token_to_kv_pool.online_update_cache, - is_decode=False, ) if require_validation: o_req_valid, _ = self.forward_paged_hip( - query=q_reshaped[start_len : start_len + seq_len], + query=q_reshaped[start_len: start_len + seq_len], sm_scale=layer.scaling, batch_size=1, k_cache=k_pages, v_cache=v_pages, offload_cache=None, - positions=forward_batch.positions[ - start_len : start_len + seq_len - ], - seq_lens=forward_batch.seq_lens[idx_batch : idx_batch + 1], + positions=forward_batch.positions[start_len:start_len + seq_len], + seq_lens=forward_batch.seq_lens[idx_batch:idx_batch + 1], req_to_tokens=forward_batch.req_to_token_pool.req_to_token, - req_pool_indices=forward_batch.req_pool_indices[ - idx_batch : idx_batch + 1 - ], - layer=layer, - is_dense=layer.layer_id in self.hip_config.dense_layers, - is_decode=False, + req_pool_indices=forward_batch.req_pool_indices[idx_batch:idx_batch + 1], + rope_cos=layer.rope_cos, + rope_sin=layer.rope_sin, + layer_id=layer.layer_id, + logit_cap=layer.logit_cap, + orig_context_len=layer.orig_context_len, + max_context_len=self.max_context_len, + hip_config=self.hip_config, ) o_err = ((o_req - o_req_valid) ** 2).sum() assert o_err < 1e-6, o_err - o[start_len : start_len + seq_len] = o_req - else: - o_req, _ = self.forward_paged_hip( - query=q_reshaped[start_len : start_len + seq_len], - sm_scale=layer.scaling, - batch_size=1, - k_cache=k_cache, - v_cache=v_cache, - 
offload_cache=offload_cache, - positions=forward_batch.positions[ - start_len : start_len + seq_len - ], - seq_lens=forward_batch.seq_lens[idx_batch : idx_batch + 1], - req_to_tokens=forward_batch.req_to_token_pool.req_to_token, - req_pool_indices=forward_batch.req_pool_indices[ - idx_batch : idx_batch + 1 - ], - layer=layer, - is_dense=layer.layer_id in self.hip_config.dense_layers, - k=k_chunk, - v=v_chunk, - is_decode=False, - ) + o[start_len:start_len + seq_len] = o_req - o[start_len : start_len + seq_len] = o_req start_len += seq_len + assert len(decoding_reqs) == 0 return o.view(-1, layer.tp_q_head_num * layer.head_dim) @@ -281,10 +272,6 @@ def forward_decode( else forward_batch.encoder_out_cache_loc ) - is_offload_cache = isinstance( - forward_batch.token_to_kv_pool, MHATokenToHiPOffloadKVPool - ) - metadata = None if forward_batch.hip_metadata_cached_stage is not None: metadata = forward_batch.hip_metadata_cache_pool.get_hip_metadata_cache( @@ -295,10 +282,15 @@ def forward_decode( ) require_validation = False - if is_offload_cache: - assert isinstance( - forward_batch.token_to_kv_pool, MHATokenToHiPOffloadKVPool - ) + if not self.is_offload_enabled: + if k is not None: + assert v is not None + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) + k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) + offload_cache = None + else: + assert isinstance(forward_batch.token_to_kv_pool, MHATokenToHiPOffloadKVPool) require_validation = forward_batch.token_to_kv_pool.require_validation if k is not None: @@ -315,22 +307,11 @@ def forward_decode( if not require_validation: k_cache = v_cache = None - offload_cache = forward_batch.token_to_kv_pool.get_kv_buffer( - layer.layer_id - ) + offload_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) else: offload_cache, k_cache, v_cache = ( forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) ) - else: - if k is not None: - assert v is not None - if save_kv_cache: - forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) - k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer( - layer.layer_id - ) - offload_cache = None if not require_validation: o, metadata = self.forward_paged_hip( @@ -344,17 +325,18 @@ def forward_decode( seq_lens=forward_batch.seq_lens, req_to_tokens=forward_batch.req_to_token_pool.req_to_token, req_pool_indices=forward_batch.req_pool_indices, - layer=layer, + rope_cos=layer.rope_cos, + rope_sin=layer.rope_sin, + layer_id=layer.layer_id, + logit_cap=layer.logit_cap, + orig_context_len=layer.orig_context_len, + max_context_len=self.max_context_len, + hip_config=self.hip_config, cached_metadata=metadata, - is_dense=layer.layer_id in self.hip_config.dense_layers, online_update_cache=( forward_batch.token_to_kv_pool.online_update_cache - if isinstance( - forward_batch.token_to_kv_pool, MHATokenToHiPOffloadKVPool - ) - else None + if self.is_offload_enabled else None ), - is_decode=True, ) else: @@ -384,17 +366,18 @@ def sse(a: torch.Tensor, b: torch.Tensor): seq_lens=forward_batch.seq_lens, req_to_tokens=forward_batch.req_to_token_pool.req_to_token, req_pool_indices=forward_batch.req_pool_indices, - layer=layer, + rope_cos=layer.rope_cos, + rope_sin=layer.rope_sin, + layer_id=layer.layer_id, + logit_cap=layer.logit_cap, + orig_context_len=layer.orig_context_len, + max_context_len=self.max_context_len, + hip_config=self.hip_config, cached_metadata=metadata, - is_dense=layer.layer_id in self.hip_config.dense_layers, 
online_update_cache=( forward_batch.token_to_kv_pool.online_update_cache - if isinstance( - forward_batch.token_to_kv_pool, MHATokenToHiPOffloadKVPool - ) - else None + if self.is_offload_enabled else None ), - is_decode=True, ) o_valid, metadata_valid = self.forward_paged_hip( @@ -408,17 +391,18 @@ def sse(a: torch.Tensor, b: torch.Tensor): seq_lens=forward_batch.seq_lens, req_to_tokens=forward_batch.req_to_token_pool.req_to_token, req_pool_indices=forward_batch.req_pool_indices, - layer=layer, + rope_cos=layer.rope_cos, + rope_sin=layer.rope_sin, + layer_id=layer.layer_id, + logit_cap=layer.logit_cap, + orig_context_len=layer.orig_context_len, + max_context_len=self.max_context_len, + hip_config=self.hip_config, cached_metadata=metadata, - is_dense=layer.layer_id in self.hip_config.dense_layers, online_update_cache=( forward_batch.token_to_kv_pool.online_update_cache - if isinstance( - forward_batch.token_to_kv_pool, MHATokenToHiPOffloadKVPool - ) - else None + if self.is_offload_enabled else None ), - is_decode=True, ) err_thresh = 1e-7 @@ -466,10 +450,7 @@ def sse(a: torch.Tensor, b: torch.Tensor): ) = stage2_right_err = stage2_score_err = None online_update = ( forward_batch.token_to_kv_pool.online_update_cache - if isinstance( - forward_batch.token_to_kv_pool, MHATokenToHiPOffloadKVPool - ) - else None + if self.is_offload_enabled else None ) o_uvm, metadata_uvm = self.forward_paged_hip( @@ -483,17 +464,18 @@ def sse(a: torch.Tensor, b: torch.Tensor): seq_lens=forward_batch.seq_lens, req_to_tokens=forward_batch.req_to_token_pool.req_to_token, req_pool_indices=forward_batch.req_pool_indices, - layer=layer, + rope_cos=layer.rope_cos, + rope_sin=layer.rope_sin, + layer_id=layer.layer_id, + logit_cap=layer.logit_cap, + orig_context_len=layer.orig_context_len, + max_context_len=self.max_context_len, + hip_config=self.hip_config, cached_metadata=metadata, - is_dense=layer.layer_id in self.hip_config.dense_layers, online_update_cache=( forward_batch.token_to_kv_pool.online_update_cache - if isinstance( - forward_batch.token_to_kv_pool, MHATokenToHiPOffloadKVPool - ) - else None + if self.is_offload_enabled else None ), - is_decode=True, ) offload_cache.sa_kv_cache.flush() @@ -510,17 +492,18 @@ def sse(a: torch.Tensor, b: torch.Tensor): seq_lens=forward_batch.seq_lens, req_to_tokens=forward_batch.req_to_token_pool.req_to_token, req_pool_indices=forward_batch.req_pool_indices, - layer=layer, + rope_cos=layer.rope_cos, + rope_sin=layer.rope_sin, + layer_id=layer.layer_id, + logit_cap=layer.logit_cap, + orig_context_len=layer.orig_context_len, + max_context_len=self.max_context_len, + hip_config=self.hip_config, cached_metadata=metadata, - is_dense=layer.layer_id in self.hip_config.dense_layers, online_update_cache=( forward_batch.token_to_kv_pool.online_update_cache - if isinstance( - forward_batch.token_to_kv_pool, MHATokenToHiPOffloadKVPool - ) - else None + if self.is_offload_enabled else None ), - is_decode=True, ) err_uvm = sse(o, o_uvm) err_retry = sse(o_valid, o_retry) @@ -560,140 +543,7 @@ def sse(a: torch.Tensor, b: torch.Tensor): metadata=metadata, ) - if is_offload_cache: + if self.is_offload_enabled: offload_cache.handle_cache_miss(metadata) return o.view(-1, layer.tp_q_head_num * layer.head_dim) - - def forward_paged_hip( - self, - query: torch.Tensor, - sm_scale: float, - batch_size: int, - k_cache: Optional[torch.Tensor], - v_cache: Optional[torch.Tensor], - offload_cache: Optional["HiPOffloadCache"], - positions: torch.Tensor, - seq_lens: torch.Tensor, - req_to_tokens: 
torch.Tensor, - req_pool_indices: torch.Tensor, - layer: RadixAttention, - cached_metadata: Optional["HiPAttentionOutputMetadata"] = None, - is_dense: bool = False, - k: Optional[torch.Tensor] = None, - v: Optional[torch.Tensor] = None, - online_update_cache: bool = False, - is_decode: bool = False, - ) -> tuple[torch.Tensor, "HiPAttentionOutputMetadata"]: - N, num_heads, hidden_dims = query.shape - dst_seq_len = N // batch_size - - is_decode = dst_seq_len == 1 - is_dense = layer.layer_id in self.hip_config.dense_layers - if not is_decode: - if len(self.hip_config.prefill_layers) == 2: - layer_config = self.hip_config.prefill_layers[0 if is_dense else 1] - else: - layer_config = self.hip_config.prefill_layers[layer.layer_id] - else: - if len(self.hip_config.layers) == 2: - layer_config = self.hip_config.layers[0 if is_dense else 1] - else: - layer_config = self.hip_config.layers[layer.layer_id] - - query = query.view(batch_size, dst_seq_len, num_heads, hidden_dims) - - if k_cache is not None: - N_PAGE, num_heads_kv, hidden_dims_kv = k_cache.shape - assert v_cache.shape == k_cache.shape - assert hidden_dims_kv == hidden_dims - - k_cache = k_cache.view(N_PAGE, 1, num_heads_kv, hidden_dims) - v_cache = v_cache.view(N_PAGE, 1, num_heads_kv, hidden_dims) - - # FIXME: this operation is linear during decoding - block_table = req_to_tokens.index_select(dim=0, index=req_pool_indices) - - BLOCK_TABLE_BSZ, MODEL_SEQ_LEN = block_table.shape - assert batch_size == BLOCK_TABLE_BSZ - - # NOTE(heejun): the whole point to need to find gemma is large size of hidden size - # FIXME: find better way to detect Gemma - if k_cache is not None: - hidden_size = k_cache.shape[-1] - elif k is not None: - hidden_size = k.shape[-1] - elif offload_cache is not None: - hidden_size = offload_cache.k_uvm.bank_cpu.shape[-1] - else: - raise Exception() - is_gemma = hidden_size > 128 - - require_cache_statistics = False - if cached_metadata is None: - require_cache_statistics = True - elif cached_metadata.indices is None: - require_cache_statistics = True - elif os.getenv("HIP_DISABLE_COMPUTE_STATISTICS", "1") == "0": - require_cache_statistics = True - - from hip.models.hip_attention.gen3.attention_extend import ( - dual_stage_quadratic_hip_attention, - ) - from hip.models.hip_attention.gen3.attention_metadata import HiPAttentionArgs - - args = HiPAttentionArgs( - k_cache=( - k_cache.view(torch.uint8) - if isinstance(k_cache, torch.Tensor) - and k_cache.dtype == torch.float8_e5m2 - else k_cache - ), - v_cache=( - v_cache.view(torch.uint8) - if isinstance(k_cache, torch.Tensor) - and v_cache.dtype == torch.float8_e5m2 - else v_cache - ), - offload_cache=offload_cache, - block_table=block_table, - cache_seq_lens=seq_lens, - position_ids=positions.view(batch_size, dst_seq_len), - block_size_k=32 if is_gemma else 64, # BLOCK_CHUNK - sliding_window_size=layer_config.sliding_window_size, - sink_token_size=layer_config.sink_token_size, - using_extend=self.hip_config.using_extend, - need_apply_rope=self.hip_config.using_extend, - rope_cos=layer.rope_cos, - rope_sin=layer.rope_sin, - logit_softcap=layer.logit_cap if layer.logit_cap != 0.0 else None, - second_stage_k=layer_config.second_stage_k, - stages=layer_config.stages, - model_context_length=layer.orig_context_len, - extend_context_length=self.max_context_len, - block_sparse_block_size_q=self.hip_config.block_sparse_block_size_q, - scan_extend_backend=( - ( - "relative" - if self.hip_config.apply_v_dot - else ("streaming" if is_dense else "relative") - ) - if 
layer_config.scan_extend_backend is None - else layer_config.scan_extend_backend - ), - sa_extend_backend=layer_config.sa_extend_backend, - online_update_cache=online_update_cache, - require_cache_statistics=require_cache_statistics, - disable_flashdecode=not is_decode, - ) - - context, metadata = dual_stage_quadratic_hip_attention( - (query * sm_scale).to(query.dtype), - k, - v, - args=args, - cached_metadata=cached_metadata, - ) - context = context.to(query.dtype) - - return context.view(N, num_heads, hidden_dims), metadata From 5ceba035e41b0f1f35652f539121fd6642181fbf Mon Sep 17 00:00:00 2001 From: Geon Park Date: Tue, 28 Jan 2025 03:49:36 +0900 Subject: [PATCH 09/16] add enable_memory_saver arg --- python/sglang/srt/mem_cache/hip_offload_kv_pool_mha.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/mem_cache/hip_offload_kv_pool_mha.py b/python/sglang/srt/mem_cache/hip_offload_kv_pool_mha.py index 5af37dfadb..e780771b16 100644 --- a/python/sglang/srt/mem_cache/hip_offload_kv_pool_mha.py +++ b/python/sglang/srt/mem_cache/hip_offload_kv_pool_mha.py @@ -112,6 +112,7 @@ def __init__( head_dim=head_dim, layer_num=layer_num, device=self.device, + enable_memory_saver=False, ) else: self.validation_cache = None From d8d683b90ec83abd2494e6d583aa60ddcdf3198c Mon Sep 17 00:00:00 2001 From: Geon Park Date: Wed, 29 Jan 2025 01:03:41 +0900 Subject: [PATCH 10/16] move validation logic to hip-attention repo --- .../hip_attention/hip_radix_attention.py | 436 +++--------------- .../srt/mem_cache/hip_offload_kv_pool_mha.py | 19 +- 2 files changed, 79 insertions(+), 376 deletions(-) diff --git a/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py b/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py index 5f80aa0e2b..329b79ebee 100644 --- a/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py +++ b/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py @@ -30,14 +30,15 @@ class HiPRadixAttentionBackend(AttentionBackend): def __init__(self, model_runner: ModelRunner): super().__init__() - self.hip_config: "HiPAttentionConfig" = model_runner.hip_attention_config + from hip.models.hip_attention.gen3 import forward_paged_hip, HiPAttentionConfig + + self.forward_paged_hip = forward_paged_hip + + self.hip_config: HiPAttentionConfig = model_runner.hip_attention_config self.is_offload_enabled = model_runner.server_args.enable_hip_offload self.max_context_len = model_runner.model_config.context_len - from hip.models.hip_attention.gen3 import forward_paged_hip - self.forward_paged_hip = forward_paged_hip - def init_forward_metadata(self, forward_batch: ForwardBatch): pass @@ -103,7 +104,6 @@ def forward_extend( layer, cache_loc, k, v, async_copy=True, push_to_gpu_cache=False ) k_cache = v_cache = None - # offload_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) offload_cache = None q_reshaped = q.reshape(-1, layer.tp_q_head_num, layer.head_dim) @@ -122,134 +122,48 @@ def forward_extend( else: if not self.is_offload_enabled: k_chunk = v_chunk = None + offloading_metadata = None else: # Offloading enabled - assert isinstance(forward_batch.token_to_kv_pool, MHATokenToHiPOffloadKVPool) - require_validation = forward_batch.token_to_kv_pool.require_validation - if require_validation: - k_chunk, v_chunk, k_pages, v_pages = ( - forward_batch.token_to_kv_pool.get_fetched_prefix_kv_buffer( - layer_id=layer.layer_id, - batch_id=idx_batch, - cache_k=k[start_len: start_len + seq_len].unsqueeze(0), - 
cache_v=v[start_len: start_len + seq_len].unsqueeze(0), - ) - ) - else: - k_chunk, v_chunk = ( - forward_batch.token_to_kv_pool.get_fetched_prefix_kv_buffer( - layer_id=layer.layer_id, - batch_id=idx_batch, - cache_k=k[start_len: start_len + seq_len].unsqueeze(0), - cache_v=v[start_len: start_len + seq_len].unsqueeze(0), - ) - ) - offload_cache = k_cache = v_cache = None - - if not self.is_offload_enabled: - o_req, _ = self.forward_paged_hip( - query=q_reshaped[start_len:start_len + seq_len], - sm_scale=layer.scaling, - batch_size=1, - k_cache=k_cache, - v_cache=v_cache, - offload_cache=offload_cache, - positions=forward_batch.positions[start_len:start_len + seq_len], - seq_lens=forward_batch.seq_lens[idx_batch:idx_batch + 1], - req_to_tokens=forward_batch.req_to_token_pool.req_to_token, - req_pool_indices=forward_batch.req_pool_indices[idx_batch:idx_batch + 1], - rope_cos=layer.rope_cos, - rope_sin=layer.rope_sin, - layer_id=layer.layer_id, - logit_cap=layer.logit_cap, - orig_context_len=layer.orig_context_len, - max_context_len=self.max_context_len, - hip_config=self.hip_config, - k=k_chunk, - v=v_chunk, - ) - - o[start_len:start_len + seq_len] = o_req - - else: # Offloading enabled - # BUG: this padding is neccesary to match non offload scenario. why? - pad_size = self.max_context_len - if k_chunk.shape[1] != pad_size: - k_chunk_padded = torch.zeros( - ( - k_chunk.shape[0], - pad_size, - k_chunk.shape[2], - k_chunk.shape[3], - ), - dtype=k_chunk.dtype, - device=k_chunk.device, - ) - k_chunk_padded[:, : k_chunk.shape[1]] = k_chunk - del k_chunk - v_chunk_padded = torch.zeros( - ( - v_chunk.shape[0], - pad_size, - v_chunk.shape[2], - v_chunk.shape[3], - ), - dtype=v_chunk.dtype, - device=v_chunk.device, - ) - v_chunk_padded[:, : v_chunk.shape[1]] = v_chunk - del v_chunk - k_chunk = k_chunk_padded - v_chunk = v_chunk_padded - - o_req, _ = self.forward_paged_hip( - query=q_reshaped[start_len:start_len + seq_len], - sm_scale=layer.scaling, - batch_size=1, - k_cache=k_cache, - v_cache=v_cache, - offload_cache=offload_cache, - positions=forward_batch.positions[start_len:start_len + seq_len], - seq_lens=forward_batch.seq_lens[idx_batch:idx_batch + 1], - req_to_tokens=forward_batch.req_to_token_pool.req_to_token, - req_pool_indices=forward_batch.req_pool_indices[idx_batch:idx_batch + 1], - rope_cos=layer.rope_cos, - rope_sin=layer.rope_sin, - layer_id=layer.layer_id, - logit_cap=layer.logit_cap, - orig_context_len=layer.orig_context_len, - max_context_len=self.max_context_len, - hip_config=self.hip_config, - k=k_chunk, - v=v_chunk, - online_update_cache=forward_batch.token_to_kv_pool.online_update_cache, - ) - - if require_validation: - o_req_valid, _ = self.forward_paged_hip( - query=q_reshaped[start_len: start_len + seq_len], - sm_scale=layer.scaling, - batch_size=1, - k_cache=k_pages, - v_cache=v_pages, - offload_cache=None, - positions=forward_batch.positions[start_len:start_len + seq_len], - seq_lens=forward_batch.seq_lens[idx_batch:idx_batch + 1], - req_to_tokens=forward_batch.req_to_token_pool.req_to_token, - req_pool_indices=forward_batch.req_pool_indices[idx_batch:idx_batch + 1], - rope_cos=layer.rope_cos, - rope_sin=layer.rope_sin, + k_chunk, v_chunk, offloading_metadata = ( + forward_batch.token_to_kv_pool.get_fetched_prefix_kv_buffer( layer_id=layer.layer_id, - logit_cap=layer.logit_cap, - orig_context_len=layer.orig_context_len, - max_context_len=self.max_context_len, - hip_config=self.hip_config, + batch_id=idx_batch, + cache_k=k[start_len: start_len + seq_len].unsqueeze(0), + 
cache_v=v[start_len: start_len + seq_len].unsqueeze(0), ) + ) + offload_cache = k_cache = v_cache = None - o_err = ((o_req - o_req_valid) ** 2).sum() - assert o_err < 1e-6, o_err + o_req, _ = self.forward_paged_hip( + query=q_reshaped[start_len:start_len + seq_len], + sm_scale=layer.scaling, + batch_size=1, + k_cache=k_cache, + v_cache=v_cache, + offload_cache=offload_cache, + positions=forward_batch.positions[start_len:start_len + seq_len], + seq_lens=forward_batch.seq_lens[idx_batch:idx_batch + 1], + req_to_tokens=forward_batch.req_to_token_pool.req_to_token, + req_pool_indices=forward_batch.req_pool_indices[idx_batch:idx_batch + 1], + rope_cos=layer.rope_cos, + rope_sin=layer.rope_sin, + layer_id=layer.layer_id, + logit_cap=layer.logit_cap, + orig_context_len=layer.orig_context_len, + max_context_len=self.max_context_len, + is_prefill=True, + hip_config=self.hip_config, + k=k_chunk, + v=v_chunk, + online_update_cache=( + forward_batch.token_to_kv_pool.online_update_cache + if self.is_offload_enabled else None + ), + offloading_metadata=offloading_metadata, + ) - o[start_len:start_len + seq_len] = o_req + o[start_len:start_len + seq_len] = o_req start_len += seq_len @@ -281,7 +195,6 @@ def forward_decode( forward_batch.hip_metadata_cached_stage, ) - require_validation = False if not self.is_offload_enabled: if k is not None: assert v is not None @@ -289,252 +202,45 @@ def forward_decode( forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) offload_cache = None - else: + + else: # Offloading enabled assert isinstance(forward_batch.token_to_kv_pool, MHATokenToHiPOffloadKVPool) - require_validation = forward_batch.token_to_kv_pool.require_validation if k is not None: assert v is not None if save_kv_cache: forward_batch.token_to_kv_pool.set_kv_buffer( - layer, - cache_loc, - k, - v, - async_copy=False, - push_to_gpu_cache=True, - ) - - if not require_validation: - k_cache = v_cache = None - offload_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) - else: - offload_cache, k_cache, v_cache = ( - forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) - ) - - if not require_validation: - o, metadata = self.forward_paged_hip( - query=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), - sm_scale=layer.scaling, - batch_size=forward_batch.batch_size, - k_cache=k_cache, - v_cache=v_cache, - offload_cache=offload_cache, - positions=forward_batch.positions, - seq_lens=forward_batch.seq_lens, - req_to_tokens=forward_batch.req_to_token_pool.req_to_token, - req_pool_indices=forward_batch.req_pool_indices, - rope_cos=layer.rope_cos, - rope_sin=layer.rope_sin, - layer_id=layer.layer_id, - logit_cap=layer.logit_cap, - orig_context_len=layer.orig_context_len, - max_context_len=self.max_context_len, - hip_config=self.hip_config, - cached_metadata=metadata, - online_update_cache=( - forward_batch.token_to_kv_pool.online_update_cache - if self.is_offload_enabled else None - ), - ) - else: - - def sse(a: torch.Tensor, b: torch.Tensor): - assert a.dtype == b.dtype - return ((a - b) ** 2).sum().item() - - err_k = sse(offload_cache.k_uvm.bank_gpu, k_cache) - err_v = sse(offload_cache.v_uvm.bank_gpu, v_cache) - - o, metadata_new = self.forward_paged_hip( - query=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), - sm_scale=layer.scaling, - batch_size=forward_batch.batch_size, - k_cache=None, - v_cache=None, - offload_cache=offload_cache, - # NOTE: to test uvm only - # 
k_cache=offload_cache.k_uvm.bank_gpu, - # v_cache=offload_cache.v_uvm.bank_gpu, - # offload_cache=None, - # NOTE: to test on gpu only - # k_cache=k_cache, - # v_cache=v_cache, - # offload_cache=None, - positions=forward_batch.positions, - seq_lens=forward_batch.seq_lens, - req_to_tokens=forward_batch.req_to_token_pool.req_to_token, - req_pool_indices=forward_batch.req_pool_indices, - rope_cos=layer.rope_cos, - rope_sin=layer.rope_sin, - layer_id=layer.layer_id, - logit_cap=layer.logit_cap, - orig_context_len=layer.orig_context_len, - max_context_len=self.max_context_len, - hip_config=self.hip_config, - cached_metadata=metadata, - online_update_cache=( - forward_batch.token_to_kv_pool.online_update_cache - if self.is_offload_enabled else None - ), - ) - - o_valid, metadata_valid = self.forward_paged_hip( - query=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), - sm_scale=layer.scaling, - batch_size=forward_batch.batch_size, - k_cache=k_cache, - v_cache=v_cache, - offload_cache=None, - positions=forward_batch.positions, - seq_lens=forward_batch.seq_lens, - req_to_tokens=forward_batch.req_to_token_pool.req_to_token, - req_pool_indices=forward_batch.req_pool_indices, - rope_cos=layer.rope_cos, - rope_sin=layer.rope_sin, - layer_id=layer.layer_id, - logit_cap=layer.logit_cap, - orig_context_len=layer.orig_context_len, - max_context_len=self.max_context_len, - hip_config=self.hip_config, - cached_metadata=metadata, - online_update_cache=( - forward_batch.token_to_kv_pool.online_update_cache - if self.is_offload_enabled else None - ), - ) - - err_thresh = 1e-7 - - o_sse = sse(o, o_valid) - err_retry = -1 - err_uvm = None - if o_sse >= err_thresh: - indices_err = sse(metadata_new.indices, metadata_valid.indices) - ks_err = sse(metadata_new.ks, metadata_valid.ks) - ks_count_err = sse(metadata_new.ks_count, metadata_valid.ks_count) - ks_start_end_err = sse( - metadata_new.ks_start_end, metadata_valid.ks_start_end - ) - if (metadata_valid.stage_caches is not None) and ( - len(metadata_valid.stage_caches) > 0 - ): - stage1_left_err = sse( - metadata_new.stage_caches[1].indices_left, - metadata_valid.stage_caches[1].indices_left, - ) - stage1_right_err = sse( - metadata_new.stage_caches[1].indices_right, - metadata_valid.stage_caches[1].indices_right, - ) - stage1_score_err = sse( - metadata_new.stage_caches[1].out_scores, - metadata_valid.stage_caches[1].out_scores, - ) - stage2_left_err = sse( - metadata_new.stage_caches[2].indices_left, - metadata_valid.stage_caches[2].indices_left, + layer, cache_loc, k, v, async_copy=False, push_to_gpu_cache=True ) - stage2_right_err = sse( - metadata_new.stage_caches[2].indices_right, - metadata_valid.stage_caches[2].indices_right, - ) - stage2_score_err = sse( - metadata_new.stage_caches[2].out_scores, - metadata_valid.stage_caches[2].out_scores, - ) - else: - stage1_left_err = stage1_right_err = stage1_score_err = ( - stage2_left_err - ) = stage2_right_err = stage2_score_err = None - online_update = ( - forward_batch.token_to_kv_pool.online_update_cache - if self.is_offload_enabled else None - ) - o_uvm, metadata_uvm = self.forward_paged_hip( - query=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), - sm_scale=layer.scaling, - batch_size=forward_batch.batch_size, - k_cache=offload_cache.k_uvm.bank_gpu, - v_cache=offload_cache.v_uvm.bank_gpu, - offload_cache=None, - positions=forward_batch.positions, - seq_lens=forward_batch.seq_lens, - req_to_tokens=forward_batch.req_to_token_pool.req_to_token, - 
req_pool_indices=forward_batch.req_pool_indices, - rope_cos=layer.rope_cos, - rope_sin=layer.rope_sin, - layer_id=layer.layer_id, - logit_cap=layer.logit_cap, - orig_context_len=layer.orig_context_len, - max_context_len=self.max_context_len, - hip_config=self.hip_config, - cached_metadata=metadata, - online_update_cache=( - forward_batch.token_to_kv_pool.online_update_cache - if self.is_offload_enabled else None - ), - ) - - offload_cache.sa_kv_cache.flush() - offload_cache.mask_k_cache.flush() - - o_retry, metadata_retry = self.forward_paged_hip( - query=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), - sm_scale=layer.scaling, - batch_size=forward_batch.batch_size, - k_cache=None, - v_cache=None, - offload_cache=offload_cache, - positions=forward_batch.positions, - seq_lens=forward_batch.seq_lens, - req_to_tokens=forward_batch.req_to_token_pool.req_to_token, - req_pool_indices=forward_batch.req_pool_indices, - rope_cos=layer.rope_cos, - rope_sin=layer.rope_sin, - layer_id=layer.layer_id, - logit_cap=layer.logit_cap, - orig_context_len=layer.orig_context_len, - max_context_len=self.max_context_len, - hip_config=self.hip_config, - cached_metadata=metadata, - online_update_cache=( - forward_batch.token_to_kv_pool.online_update_cache - if self.is_offload_enabled else None - ), - ) - err_uvm = sse(o, o_uvm) - err_retry = sse(o_valid, o_retry) - - print(o) - print(o_valid) - print(metadata_new.indices) - print(metadata_valid.indices) - - assert ( - o_sse < err_thresh - ), f""" -sse={o_sse} -err_k (uvm_k <=> valid_k) = {err_k} -err_v (uvm_v <=> valid_v) ={err_v} -err_retry (o_valid <=> o_retry) = {err_retry} -err_uvm (o_first <=> o_uvm_retry) = {err_uvm} -indices_err={indices_err} -ks_err={ks_err} -ks_count_err={ks_count_err} -ks_start_end_err={ks_start_end_err} -stage1_left_err={stage1_left_err} -stage1_right_err={stage1_right_err} -stage1_score_err={stage1_score_err} -stage2_left_err={stage2_left_err} -stage2_right_err={stage2_right_err} -stage2_score_err={stage2_score_err} -online_update={online_update} -""" + k_cache = v_cache = None + offload_cache, offloading_metadata = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) - metadata = metadata_new + o, metadata = self.forward_paged_hip( + query=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + sm_scale=layer.scaling, + batch_size=forward_batch.batch_size, + k_cache=k_cache, + v_cache=v_cache, + offload_cache=offload_cache, + positions=forward_batch.positions, + seq_lens=forward_batch.seq_lens, + req_to_tokens=forward_batch.req_to_token_pool.req_to_token, + req_pool_indices=forward_batch.req_pool_indices, + rope_cos=layer.rope_cos, + rope_sin=layer.rope_sin, + layer_id=layer.layer_id, + logit_cap=layer.logit_cap, + orig_context_len=layer.orig_context_len, + max_context_len=self.max_context_len, + is_prefill=False, + hip_config=self.hip_config, + cached_metadata=metadata, + online_update_cache=( + forward_batch.token_to_kv_pool.online_update_cache + if self.is_offload_enabled else None + ), + ) forward_batch.hip_metadata_cache_pool.set_hip_metadata_cache( layer_id=layer.layer_id, diff --git a/python/sglang/srt/mem_cache/hip_offload_kv_pool_mha.py b/python/sglang/srt/mem_cache/hip_offload_kv_pool_mha.py index e780771b16..74cf67f809 100644 --- a/python/sglang/srt/mem_cache/hip_offload_kv_pool_mha.py +++ b/python/sglang/srt/mem_cache/hip_offload_kv_pool_mha.py @@ -2,7 +2,7 @@ import os import threading import time -from typing import Dict, Set, Tuple +from typing import Dict, Set, Tuple, Any import torch from 
hip.models.hip_attention.gen3.uvm_gpu_cache import ( @@ -135,11 +135,11 @@ def get_key_buffer(self, layer_id: int): def get_value_buffer(self, layer_id: int): raise NotImplementedError() - def get_kv_buffer(self, layer_id: int) -> HiPOffloadCache: + def get_kv_buffer(self, layer_id: int) -> Tuple[HiPOffloadCache, Any]: # Use this function for decode, pass this to `k` if self.require_validation: - return self.layer_buffer[layer_id], *self.validation_cache.get_kv_buffer(layer_id) - return self.layer_buffer[layer_id] + return self.layer_buffer[layer_id], self.validation_cache.get_kv_buffer(layer_id) + return self.layer_buffer[layer_id], None def prefetch_prefix_kv_buffer( self, layer_id: int, batch_id: int, table: Tensor, prefix_seq_len: int @@ -147,10 +147,7 @@ def prefetch_prefix_kv_buffer( # you must call before get fetched prefix assert table.ndim == 1 - if self.require_validation: - hip_offload_cache, _, _ = self.get_kv_buffer(layer_id) - else: - hip_offload_cache = self.get_kv_buffer(layer_id) + hip_offload_cache, _ = self.get_kv_buffer(layer_id) handle_id = (layer_id, batch_id) assert handle_id not in self.prefetch_threads, handle_id @@ -205,7 +202,7 @@ def get_fetched_prefix_kv_buffer( # you need to pass KV for extend cache_k: Tensor, cache_v: Tensor, - ) -> Tuple[Tensor, Tensor]: + ) -> Tuple[Tensor, Tensor, Any]: # return cache_k, cache_v # Use this function for prefill @@ -279,9 +276,9 @@ def get_fetched_prefix_kv_buffer( assert k_err < 1e-5, k_err assert v_err < 1e-5, v_err - return k, v, k_valid, v_valid + return k, v, (k_valid, v_valid) else: - return k, v + return k, v, None def set_kv_buffer( self, From 6bd523ccd89b8363c884ba9e85bb286556089b2c Mon Sep 17 00:00:00 2001 From: Geon Park Date: Wed, 29 Jan 2025 04:52:56 +0900 Subject: [PATCH 11/16] move offloading logic to hip-attention and cleanup imports --- .../hip_attention/hip_radix_attention.py | 10 +- .../srt/managers/tp_worker_overlap_thread.py | 2 +- .../srt/mem_cache/hip_offload_kv_pool_mha.py | 413 +++--------------- .../srt/model_executor/forward_batch_info.py | 4 +- .../sglang/srt/model_executor/model_runner.py | 4 +- 5 files changed, 64 insertions(+), 369 deletions(-) diff --git a/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py b/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py index 329b79ebee..eb2f64bd67 100644 --- a/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py +++ b/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py @@ -8,7 +8,6 @@ """ import logging -from enum import Enum, auto from typing import TYPE_CHECKING, Optional import torch @@ -21,6 +20,7 @@ from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.speculative.spec_info import SpecInfo + from hip.models.hip_attention.gen3 import HiPAttentionConfig logger = logging.getLogger(__name__) @@ -30,8 +30,7 @@ class HiPRadixAttentionBackend(AttentionBackend): def __init__(self, model_runner: ModelRunner): super().__init__() - from hip.models.hip_attention.gen3 import forward_paged_hip, HiPAttentionConfig - + from hip.models.hip_attention.gen3 import forward_paged_hip self.forward_paged_hip = forward_paged_hip self.hip_config: HiPAttentionConfig = model_runner.hip_attention_config @@ -157,7 +156,7 @@ def forward_extend( k=k_chunk, v=v_chunk, online_update_cache=( - forward_batch.token_to_kv_pool.online_update_cache + 
forward_batch.token_to_kv_pool.is_online_cache_update_enabled() if self.is_offload_enabled else None ), offloading_metadata=offloading_metadata, @@ -205,7 +204,6 @@ def forward_decode( else: # Offloading enabled assert isinstance(forward_batch.token_to_kv_pool, MHATokenToHiPOffloadKVPool) - if k is not None: assert v is not None if save_kv_cache: @@ -237,7 +235,7 @@ def forward_decode( hip_config=self.hip_config, cached_metadata=metadata, online_update_cache=( - forward_batch.token_to_kv_pool.online_update_cache + forward_batch.token_to_kv_pool.is_online_cache_update_enabled() if self.is_offload_enabled else None ), ) diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index 9eb1c57ed0..69ee4c8ef7 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -124,7 +124,7 @@ def forward_thread_func_(self): hip_mask_refresh_state = None if self.hip_attention_config is not None: - from hip.models.hip_attention.gen3.mask_refresh_interval import HiPMaskRefreshState + from hip.models.hip_attention.gen3 import HiPMaskRefreshState # For keeping track of HiP attention mask refresh cycles hip_mask_refresh_state = HiPMaskRefreshState() diff --git a/python/sglang/srt/mem_cache/hip_offload_kv_pool_mha.py b/python/sglang/srt/mem_cache/hip_offload_kv_pool_mha.py index 74cf67f809..4626d15249 100644 --- a/python/sglang/srt/mem_cache/hip_offload_kv_pool_mha.py +++ b/python/sglang/srt/mem_cache/hip_offload_kv_pool_mha.py @@ -1,20 +1,18 @@ +from __future__ import annotations import logging -import os -import threading -import time -from typing import Dict, Set, Tuple, Any +from typing import Tuple, Any, TYPE_CHECKING import torch -from hip.models.hip_attention.gen3.uvm_gpu_cache import ( - HiPOffloadCache, - format_size_bytes, -) from torch import Tensor from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, MHATokenToKVPool from sglang.srt.model_executor.forward_batch_info import ForwardBatch +if TYPE_CHECKING: + from hip.models.hip_attention.gen3 import HiPOffloadCache + from hip.models.hip_attention.gen3 import HiPAttentionConfig + logger = logging.getLogger(__name__) @@ -30,104 +28,25 @@ def __init__( head_dim: int, layer_num: int, device: torch.device, - hip_config: "HiPAttentionConfig", + hip_config: HiPAttentionConfig, ): + super().__init__(max_token_size, dtype, device) assert isinstance(device, torch.device) assert device.index is not None - super().__init__(max_token_size, dtype, device) - - # TODO: derive token sizes from size - self.head_num = head_num - self.head_dim = head_dim - self.layer_num = layer_num - self.max_mask_cache_token_size = max_mask_cache_token_size * head_num - self.max_sa_cache_token_size = max_sa_cache_token_size * head_num - - self.online_update_cache = os.getenv("DEBUG_ONLINE", "0") == "1" - self.layer_buffer = [] - for layer_id in range(layer_num): - self.layer_buffer.append( - HiPOffloadCache( - layer_id=layer_id, - max_token_size=max_token_size + 1, - max_mask_cache_token_size=min( - max_token_size * head_num, self.max_mask_cache_token_size - ), - max_sa_cache_token_size=min( - max_token_size * head_num, self.max_sa_cache_token_size - ), - head_num=head_num, - head_dim=head_dim, - dtype=dtype, - device=device, - online_cache_update=self.online_update_cache, - ) - if layer_id not in hip_config.dense_layers - else HiPOffloadCache( - layer_id=layer_id, - 
max_token_size=max_token_size + 1, - max_mask_cache_token_size=min( - max_token_size * head_num, self.max_mask_cache_token_size * 2 - ), - max_sa_cache_token_size=min( - max_token_size * head_num, self.max_sa_cache_token_size * 2 - ), - head_num=head_num, - head_dim=head_dim, - dtype=dtype, - device=device, - online_cache_update=self.online_update_cache, - ) - ) - - uvm_allocated_bytes, gpu_allocated_bytes = self.calc_allocated_bytes() - logger.info( - f"[{layer_id + 1}/{layer_num}] " - f"CPU (UVM): {format_size_bytes(uvm_allocated_bytes)} and " - f"GPU: {format_size_bytes(gpu_allocated_bytes)} are allocated. " - f"({self.dtype} on {self.device})" - ) - - # (layer_id, batch_id) -> (K, V, seq_len) - self.prefetch_threads: Dict[Tuple[int, int], threading.Thread] = {} - self.prefetched_kv: Dict[Tuple[int, int], Tuple[Tensor, Tensor, int]] = {} - self.async_set_threads: Set[threading.Thread] = set() - - self.enable_async = os.getenv("HIP_DISABLE_AYSNC", "0") == "0" - - # uvm_allocated_bytes, gpu_allocated_bytes = self.calc_allocated_bytes() - # logger.info( - # f'Allocated total CPU (UVM) bytes: {format_size_bytes(uvm_allocated_bytes)}, ' - # f'Allocated total GPU bytes: {format_size_bytes(gpu_allocated_bytes)}, ' - # f'{self.dtype} on {self.device}' - # ) - - self.require_validation = os.getenv("HIP_OFFLOAD_CACHE_VALIDATION", "0") == "1" - if self.require_validation: - self.validation_cache = MHATokenToKVPool( - max_token_size, - dtype=dtype, - head_num=head_num, - head_dim=head_dim, - layer_num=layer_num, - device=self.device, - enable_memory_saver=False, - ) - else: - self.validation_cache = None - - def calc_allocated_bytes(self): - uvm_allocated_bytes = 0 - gpu_allocated_bytes = 0 - for cache in self.layer_buffer: - uvm_allocated_bytes += cache.k_uvm.allocated_cpu_bytes - gpu_allocated_bytes += cache.k_uvm.allocated_gpu_bytes - uvm_allocated_bytes += cache.v_uvm.allocated_cpu_bytes - gpu_allocated_bytes += cache.v_uvm.allocated_gpu_bytes - gpu_allocated_bytes += cache.mask_k_cache.allocated_gpu_bytes - gpu_allocated_bytes += cache.sa_kv_cache.allocated_gpu_bytes - return uvm_allocated_bytes, gpu_allocated_bytes + from hip.models.hip_attention.gen3 import HiPModelOffloadCache + + self.offload_cache = HiPModelOffloadCache( + max_token_size=max_token_size, + max_mask_cache_token_size=max_mask_cache_token_size, + max_sa_cache_token_size=max_sa_cache_token_size, + dtype=dtype, + head_num=head_num, + head_dim=head_dim, + layer_num=layer_num, + device=device, + hip_config=hip_config, + ) def get_key_buffer(self, layer_id: int): raise NotImplementedError() @@ -136,64 +55,7 @@ def get_value_buffer(self, layer_id: int): raise NotImplementedError() def get_kv_buffer(self, layer_id: int) -> Tuple[HiPOffloadCache, Any]: - # Use this function for decode, pass this to `k` - if self.require_validation: - return self.layer_buffer[layer_id], self.validation_cache.get_kv_buffer(layer_id) - return self.layer_buffer[layer_id], None - - def prefetch_prefix_kv_buffer( - self, layer_id: int, batch_id: int, table: Tensor, prefix_seq_len: int - ) -> threading.Thread: - # you must call before get fetched prefix - assert table.ndim == 1 - - hip_offload_cache, _ = self.get_kv_buffer(layer_id) - - handle_id = (layer_id, batch_id) - assert handle_id not in self.prefetch_threads, handle_id - assert handle_id not in self.prefetched_kv, handle_id - - if self.enable_async: - start_event = torch.cuda.Event() - table = table.to(torch.int64).to("cpu") - start_event.record() - - # torch.cuda.synchronize() - def thread_main(): 
- try: - # BUG(heejun): i think this line is quite suspicious hmm - start_event.synchronize() - stream = torch.cuda.Stream(device=self.device, priority=0) - - with torch.cuda.stream(stream): - k, v = hip_offload_cache.prefetch_prefix_kv_buffer( - table=table, - device=self.device, - ) - assert k.device == self.device - assert v.device == self.device - - stream.synchronize() - self.prefetched_kv[handle_id] = (k, v, prefix_seq_len, table) - except Exception as ex: - print(f"{handle_id} thread dead") - raise Exception("thread dead") from ex - finally: - self.prefetch_threads.pop(handle_id) - - t = threading.Thread(target=thread_main, daemon=True) - self.prefetch_threads[handle_id] = t - t.start() - else: - k, v = hip_offload_cache.prefetch_prefix_kv_buffer( - table=table.to(torch.int64), - device=self.device, - ) - assert k.device == self.device - assert v.device == self.device - torch.cuda.synchronize() - self.prefetched_kv[handle_id] = (k, v, prefix_seq_len, table) - return + return self.offload_cache.get_kv_buffer(layer_id) def get_fetched_prefix_kv_buffer( self, @@ -203,82 +65,9 @@ def get_fetched_prefix_kv_buffer( cache_k: Tensor, cache_v: Tensor, ) -> Tuple[Tensor, Tensor, Any]: - # return cache_k, cache_v - - # Use this function for prefill - handle_id = (layer_id, batch_id) - prefetch_thread = self.prefetch_threads.get(handle_id, None) - if prefetch_thread is not None: - while handle_id not in self.prefetched_kv: - time.sleep(0.0001) - # print('start join', flush=True) - # while True: - # try: - # prefetch_thread.join(timeout=1.0) - # print('joined') - # break - # except TimeoutError: - # print('timeout', layer_id, batch_id) - # except RuntimeError: - # print('runtime error wtf') - # raise RuntimeError('deadlock') - - assert handle_id in self.prefetched_kv, "did prefetch successed?" 
- k, v, prefix_seq_len, table = self.prefetched_kv.pop(handle_id) - - assert isinstance(k, Tensor) - assert isinstance(v, Tensor) - assert isinstance(prefix_seq_len, int) - assert k.shape == v.shape - assert k.ndim == 4, f"{k.shape}" - assert k.shape[0] == 1 - assert k.shape[1] >= prefix_seq_len - assert k.shape[2] == self.head_num - assert k.shape[3] == self.head_dim - assert k.dtype == v.dtype - assert k.dtype == self.dtype - assert cache_k.ndim == 4 - assert cache_k.shape[0] == 1 - assert cache_k.shape[2] == self.head_num - assert cache_k.shape[3] == self.head_dim - assert k.shape[1] == prefix_seq_len + cache_k.shape[1] - assert k.dtype in [ - torch.float8_e5m2, - torch.float16, - torch.bfloat16, - torch.float32, - ] - - if cache_k.dtype != self.dtype: - cache_k = cache_k.to(self.dtype) - cache_v = cache_v.to(self.dtype) - # if self.dtype not in [torch.float8_e5m2]: - # assert cache_k.dtype == self.dtype - # else: - # if cache_k.dtype != self.dtype: - # cache_k = cache_k.to(self.dtype) - # cache_v = cache_v.to(self.dtype) - - k[:, prefix_seq_len:, :, :] = cache_k - v[:, prefix_seq_len:, :, :] = cache_v - - if self.require_validation: - k_valid, v_valid = self.validation_cache.get_kv_buffer(layer_id) - - assert k.dtype == k_valid.dtype - - k_valid_packed = k_valid[table].unsqueeze(0) - v_valid_packed = v_valid[table].unsqueeze(0) - - k_err = ((k_valid_packed - k) ** 2).sum() - v_err = ((v_valid_packed - v) ** 2).sum() - - assert k_err < 1e-5, k_err - assert v_err < 1e-5, v_err - - return k, v, (k_valid, v_valid) - else: - return k, v, None + return self.offload_cache.get_fetched_prefix_kv_buffer( + layer_id, batch_id, cache_k, cache_v + ) def set_kv_buffer( self, @@ -289,143 +78,49 @@ def set_kv_buffer( async_copy: bool = False, push_to_gpu_cache: bool = False, ): - if self.require_validation: - self.validation_cache.set_kv_buffer( - layer, - table, - cache_k, - cache_v, - ) - - if not self.enable_async: - async_copy = False - - layer_id = layer.layer_id - # pass async_copy=True when only prefill (eager mode) - assert (not async_copy) or ( - async_copy and (not torch.cuda.is_current_stream_capturing()) + self.offload_cache.set_kv_buffer( + layer.layer_id, table, cache_k, cache_v, async_copy, push_to_gpu_cache ) - if cache_k.dtype != self.dtype: - cache_k = cache_k.to(self.dtype) - cache_v = cache_v.to(self.dtype) - - if async_copy: - start_event = torch.cuda.Event() - start_event.record() - - def thread_main(): - try: - start_event.synchronize() - stream = torch.cuda.Stream(device=self.device) - - with torch.cuda.stream(stream): - table_gpu = table.to(torch.int64) - table_cpu = table.to("cpu", non_blocking=False) - cache_k_cpu = cache_k.to("cpu", non_blocking=False) - cache_v_cpu = cache_v.to("cpu", non_blocking=False) - self.layer_buffer[layer_id].set_kv_buffer( - table=table_cpu, - table_gpu=table_gpu, - cache_k=cache_k_cpu, - cache_v=cache_v_cpu, - ) - stream.synchronize() - finally: - self.async_set_threads.remove(t) - - t = threading.Thread(target=thread_main, daemon=True) - self.async_set_threads.add(t) - t.start() - else: - self.layer_buffer[layer_id].set_kv_buffer( - table=table, - table_gpu=table, - cache_k=cache_k, - cache_v=cache_v, - ) - - def synchronize(self): - torch.cuda.synchronize(device=self.device) - t = time.time() - # you must call this function when finish prefill, before decode - while (len(self.prefetch_threads) > 0) or (len(self.async_set_threads) > 0): - time.sleep(0.001) - assert len(self.prefetch_threads) == 0 - assert len(self.async_set_threads) == 0 - 
assert len(self.prefetched_kv) == 0 - elapsed = time.time() - t - logger.debug(f"Final layer sync took {elapsed * 1024:.4f} ms") - - def prefetch_layer(self, forward_batch: ForwardBatch, layer_id: int): - assert isinstance(forward_batch.token_to_kv_pool, MHATokenToHiPOffloadKVPool) - assert forward_batch.token_to_kv_pool == self - - for ibatch in range(forward_batch.batch_size): - req_to_tokens = forward_batch.req_to_token_pool.req_to_token - req_pool_indices = forward_batch.req_pool_indices[ibatch : ibatch + 1] - block_table = req_to_tokens.index_select(dim=0, index=req_pool_indices)[ - 0, - : forward_batch.extend_prefix_lens_cpu[ibatch] - + forward_batch.extend_seq_lens_cpu[ibatch], - ] - # print(block_table, block_table.shape) - self.prefetch_prefix_kv_buffer( - layer_id=layer_id, - batch_id=ibatch, - table=block_table, - prefix_seq_len=forward_batch.extend_prefix_lens_cpu[ibatch], - ) - - def wait_prefetch_layer(self, forward_batch: ForwardBatch, layer_id: int): - assert isinstance(forward_batch.token_to_kv_pool, MHATokenToHiPOffloadKVPool) - assert forward_batch.token_to_kv_pool == self - - for ibatch in range(forward_batch.batch_size): - while (layer_id, ibatch) not in self.prefetched_kv: - time.sleep(0.0001) - def on_model_start(self, forward_batch: ForwardBatch): - require_prefetch = forward_batch.forward_mode.is_extend() assert forward_batch.token_to_kv_pool == self - if require_prefetch: - # FIXME: find better way to detect this. - is_first_chunk = forward_batch.extend_prefix_lens_cpu[0] == 0 - # FIXME: find better way to detect this. - is_inter_chunk = forward_batch.extend_seq_lens_cpu[0] in map( - lambda x: 2**x, range(0, 20) - ) - # BUG(heejun): at the last chunk of prefill, prefetch layer sometimes failes... so disable async - if not ( - forward_batch.batch_size == 1 and (is_first_chunk or is_inter_chunk) - ): - self.onetime_disable = self.enable_async - self.enable_async = False - else: - self.onetime_disable = False - self.prefetch_layer(forward_batch, 0) - # self.wait_prefetch_layer(forward_batch, 0) + self.offload_cache.on_model_start( + forward_batch.forward_mode.is_extend(), + forward_batch.batch_size, + forward_batch.req_to_token_pool.req_to_token, + forward_batch.req_pool_indices, + forward_batch.extend_prefix_lens_cpu, + forward_batch.extend_seq_lens_cpu, + ) def on_model_end(self, forward_batch: ForwardBatch): - require_prefetch = forward_batch.forward_mode.is_extend() assert forward_batch.token_to_kv_pool == self - if require_prefetch: - self.synchronize() - self.enable_async = self.enable_async or self.onetime_disable - self.onetime_disable = False + self.offload_cache.on_model_end( + forward_batch.forward_mode.is_extend(), + ) def on_layer_start(self, forward_batch: ForwardBatch, layer_id: int): - require_prefetch = forward_batch.forward_mode.is_extend() assert forward_batch.token_to_kv_pool == self - if require_prefetch and (layer_id < (self.layer_num - 1)): - self.prefetch_layer(forward_batch, layer_id + 1) + self.offload_cache.on_layer_start( + layer_id, + forward_batch.forward_mode.is_extend(), + forward_batch.batch_size, + forward_batch.req_to_token_pool.req_to_token, + forward_batch.req_pool_indices, + forward_batch.extend_prefix_lens_cpu, + forward_batch.extend_seq_lens_cpu, + ) def on_layer_end(self, forward_batch: ForwardBatch, layer_id: int): - require_prefetch = forward_batch.forward_mode.is_extend() assert forward_batch.token_to_kv_pool == self - if require_prefetch and (layer_id < (self.layer_num - 1)): - 
torch.cuda.current_stream(self.device).synchronize() + self.offload_cache.on_layer_end( + layer_id, + forward_batch.forward_mode.is_extend(), + ) + + def is_online_cache_update_enabled(self): + return self.offload_cache.is_online_cache_update_enabled() diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index c0bf582e7e..baffd666dc 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -48,6 +48,8 @@ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm + from hip.models.hip_attention.gen3 import HiPMetadataCachePool + class ForwardMode(IntEnum): # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case. @@ -185,7 +187,7 @@ class ForwardBatch: attn_backend: AttentionBackend = None # For HiP attention - hip_metadata_cache_pool: Optional["HiPMetadataCachePool"] = None + hip_metadata_cache_pool: Optional[HiPMetadataCachePool] = None hip_metadata_cached_stage: Optional[int] = None # For DP attention diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index d909d4b520..b765d591af 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -697,7 +697,7 @@ def init_memory_pool( self.hip_metadata_cache_pool = None if self.server_args.enable_hip_attention: - from hip.models.hip_attention.gen3.hip_memory_pool import HiPMetadataCachePool + from hip.models.hip_attention.gen3 import HiPMetadataCachePool self.hip_metadata_cache_pool = HiPMetadataCachePool( query_head_num=self.model_config.num_attention_heads // self.server_args.tp_size, layer_num=self.model_config.num_hidden_layers, @@ -764,7 +764,7 @@ def init_double_sparsity_channel_config(self, selected_channel): ) def init_hip_attention_config(self, hip_attention_config): - from hip.models.hip_attention.gen3.hip_config import HiPAttentionConfig + from hip.models.hip_attention.gen3 import HiPAttentionConfig if hip_attention_config is None: hip_attention_config = {} elif hip_attention_config.startswith("{"): From cbd4705fdbbfc0322e33cdd5139d183e7e98f84d Mon Sep 17 00:00:00 2001 From: Geon Park Date: Wed, 29 Jan 2025 04:53:46 +0900 Subject: [PATCH 12/16] cleanup --- python/sglang/srt/mem_cache/hip_offload_kv_pool_mha.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/mem_cache/hip_offload_kv_pool_mha.py b/python/sglang/srt/mem_cache/hip_offload_kv_pool_mha.py index 4626d15249..08ab504625 100644 --- a/python/sglang/srt/mem_cache/hip_offload_kv_pool_mha.py +++ b/python/sglang/srt/mem_cache/hip_offload_kv_pool_mha.py @@ -6,7 +6,7 @@ from torch import Tensor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, MHATokenToKVPool +from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool from sglang.srt.model_executor.forward_batch_info import ForwardBatch if TYPE_CHECKING: From 09fa13978fa306f2abc06faffaceb82d62869460 Mon Sep 17 00:00:00 2001 From: Geon Park Date: Wed, 29 Jan 2025 05:51:05 +0900 Subject: [PATCH 13/16] run pre_commit --- .../hip_attention/hip_radix_attention.py | 48 ++++++++++++------- .../srt/managers/tp_worker_overlap_thread.py | 10 ++-- .../srt/mem_cache/hip_offload_kv_pool_mha.py | 6 +-- .../srt/model_executor/cuda_graph_runner.py | 2 +- 
.../srt/model_executor/forward_batch_info.py | 4 +- .../sglang/srt/model_executor/model_runner.py | 5 +- python/sglang/srt/models/llama.py | 4 +- 7 files changed, 51 insertions(+), 28 deletions(-) diff --git a/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py b/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py index eb2f64bd67..9686a2cca2 100644 --- a/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py +++ b/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py @@ -16,11 +16,12 @@ from sglang.srt.mem_cache.hip_offload_kv_pool_mha import MHATokenToHiPOffloadKVPool if TYPE_CHECKING: + from hip.models.hip_attention.gen3 import HiPAttentionConfig + from sglang.srt.layers.radix_attention import RadixAttention - from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode + from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.speculative.spec_info import SpecInfo - from hip.models.hip_attention.gen3 import HiPAttentionConfig logger = logging.getLogger(__name__) @@ -31,6 +32,7 @@ def __init__(self, model_runner: ModelRunner): super().__init__() from hip.models.hip_attention.gen3 import forward_paged_hip + self.forward_paged_hip = forward_paged_hip self.hip_config: HiPAttentionConfig = model_runner.hip_attention_config @@ -91,11 +93,15 @@ def forward_extend( assert v is not None if save_kv_cache: forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) - k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) + k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) offload_cache = None else: # Offloading enabled - assert isinstance(forward_batch.token_to_kv_pool, MHATokenToHiPOffloadKVPool) + assert isinstance( + forward_batch.token_to_kv_pool, MHATokenToHiPOffloadKVPool + ) if k is not None: assert v is not None if save_kv_cache: @@ -128,23 +134,25 @@ def forward_extend( forward_batch.token_to_kv_pool.get_fetched_prefix_kv_buffer( layer_id=layer.layer_id, batch_id=idx_batch, - cache_k=k[start_len: start_len + seq_len].unsqueeze(0), - cache_v=v[start_len: start_len + seq_len].unsqueeze(0), + cache_k=k[start_len : start_len + seq_len].unsqueeze(0), + cache_v=v[start_len : start_len + seq_len].unsqueeze(0), ) ) offload_cache = k_cache = v_cache = None o_req, _ = self.forward_paged_hip( - query=q_reshaped[start_len:start_len + seq_len], + query=q_reshaped[start_len : start_len + seq_len], sm_scale=layer.scaling, batch_size=1, k_cache=k_cache, v_cache=v_cache, offload_cache=offload_cache, - positions=forward_batch.positions[start_len:start_len + seq_len], - seq_lens=forward_batch.seq_lens[idx_batch:idx_batch + 1], + positions=forward_batch.positions[start_len : start_len + seq_len], + seq_lens=forward_batch.seq_lens[idx_batch : idx_batch + 1], req_to_tokens=forward_batch.req_to_token_pool.req_to_token, - req_pool_indices=forward_batch.req_pool_indices[idx_batch:idx_batch + 1], + req_pool_indices=forward_batch.req_pool_indices[ + idx_batch : idx_batch + 1 + ], rope_cos=layer.rope_cos, rope_sin=layer.rope_sin, layer_id=layer.layer_id, @@ -157,12 +165,13 @@ def forward_extend( v=v_chunk, online_update_cache=( forward_batch.token_to_kv_pool.is_online_cache_update_enabled() - if self.is_offload_enabled else None + if self.is_offload_enabled + else None ), offloading_metadata=offloading_metadata, ) - o[start_len:start_len + seq_len] = o_req + 
o[start_len : start_len + seq_len] = o_req start_len += seq_len @@ -199,11 +208,15 @@ def forward_decode( assert v is not None if save_kv_cache: forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) - k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) + k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) offload_cache = None else: # Offloading enabled - assert isinstance(forward_batch.token_to_kv_pool, MHATokenToHiPOffloadKVPool) + assert isinstance( + forward_batch.token_to_kv_pool, MHATokenToHiPOffloadKVPool + ) if k is not None: assert v is not None if save_kv_cache: @@ -212,7 +225,9 @@ def forward_decode( ) k_cache = v_cache = None - offload_cache, offloading_metadata = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) + offload_cache, offloading_metadata = ( + forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) + ) o, metadata = self.forward_paged_hip( query=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), @@ -236,7 +251,8 @@ def forward_decode( cached_metadata=metadata, online_update_cache=( forward_batch.token_to_kv_pool.is_online_cache_update_enabled() - if self.is_offload_enabled else None + if self.is_offload_enabled + else None ), ) diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index 69ee4c8ef7..fe9ad102c1 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -135,10 +135,12 @@ def forward_thread_func_(self): break if hip_mask_refresh_state is not None: - model_worker_batch.hip_metadata_cached_stages = hip_mask_refresh_state.update( - model_worker_batch.forward_mode.is_decode(), - model_worker_batch.forward_mode.is_extend(), - self.hip_attention_config, + model_worker_batch.hip_metadata_cached_stages = ( + hip_mask_refresh_state.update( + model_worker_batch.forward_mode.is_decode(), + model_worker_batch.forward_mode.is_extend(), + self.hip_attention_config, + ) ) # Keep a reference of model_worker_batch by storing it into a list. 
diff --git a/python/sglang/srt/mem_cache/hip_offload_kv_pool_mha.py b/python/sglang/srt/mem_cache/hip_offload_kv_pool_mha.py index 08ab504625..06f2af1476 100644 --- a/python/sglang/srt/mem_cache/hip_offload_kv_pool_mha.py +++ b/python/sglang/srt/mem_cache/hip_offload_kv_pool_mha.py @@ -1,6 +1,7 @@ from __future__ import annotations + import logging -from typing import Tuple, Any, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Tuple import torch from torch import Tensor @@ -10,8 +11,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch if TYPE_CHECKING: - from hip.models.hip_attention.gen3 import HiPOffloadCache - from hip.models.hip_attention.gen3 import HiPAttentionConfig + from hip.models.hip_attention.gen3 import HiPAttentionConfig, HiPOffloadCache logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index c3fe39a554..d87b5fb37d 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -343,7 +343,7 @@ def capture_one_batch_size(self, bs: int, forward: Callable, capture_config: tup hip_num_cached_stages = None if self.enable_hip_attention: - hip_num_cached_stages, = capture_config + (hip_num_cached_stages,) = capture_config forward_batch = ForwardBatch( forward_mode=self.capture_forward_mode, diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index baffd666dc..b1adfba4b8 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -41,6 +41,8 @@ from sglang.srt.utils import get_compiler_backend if TYPE_CHECKING: + from hip.models.hip_attention.gen3 import HiPMetadataCachePool + from sglang.srt.layers.attention import AttentionBackend from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool @@ -48,8 +50,6 @@ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm - from hip.models.hip_attention.gen3 import HiPMetadataCachePool - class ForwardMode(IntEnum): # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case. 
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index b765d591af..41a0390101 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -698,8 +698,10 @@ def init_memory_pool( self.hip_metadata_cache_pool = None if self.server_args.enable_hip_attention: from hip.models.hip_attention.gen3 import HiPMetadataCachePool + self.hip_metadata_cache_pool = HiPMetadataCachePool( - query_head_num=self.model_config.num_attention_heads // self.server_args.tp_size, + query_head_num=self.model_config.num_attention_heads + // self.server_args.tp_size, layer_num=self.model_config.num_hidden_layers, context_length=self.model_config.context_len, device=self.device, @@ -765,6 +767,7 @@ def init_double_sparsity_channel_config(self, selected_channel): def init_hip_attention_config(self, hip_attention_config): from hip.models.hip_attention.gen3 import HiPAttentionConfig + if hip_attention_config is None: hip_attention_config = {} elif hip_attention_config.startswith("{"): diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index f5b561927f..a2619516db 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -170,7 +170,9 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, - orig_context_len=getattr(config, "orig_context_len", max_position_embeddings), + orig_context_len=getattr( + config, "orig_context_len", max_position_embeddings + ), rope=self.rotary_emb, ) From 0a77144589f8f585811a5cb6a3a44a1bbf60538e Mon Sep 17 00:00:00 2001 From: Geon Park Date: Wed, 29 Jan 2025 06:51:19 +0900 Subject: [PATCH 14/16] move hip_radix_attention to upper dir --- python/sglang/srt/layers/attention/hip_attention/__init__.py | 1 - .../layers/attention/{hip_attention => }/hip_radix_attention.py | 0 2 files changed, 1 deletion(-) delete mode 100644 python/sglang/srt/layers/attention/hip_attention/__init__.py rename python/sglang/srt/layers/attention/{hip_attention => }/hip_radix_attention.py (100%) diff --git a/python/sglang/srt/layers/attention/hip_attention/__init__.py b/python/sglang/srt/layers/attention/hip_attention/__init__.py deleted file mode 100644 index 99c783c073..0000000000 --- a/python/sglang/srt/layers/attention/hip_attention/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .hip_radix_attention import HiPRadixAttentionBackend diff --git a/python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py b/python/sglang/srt/layers/attention/hip_radix_attention.py similarity index 100% rename from python/sglang/srt/layers/attention/hip_attention/hip_radix_attention.py rename to python/sglang/srt/layers/attention/hip_radix_attention.py From a9592ca246bcbbfb0f3c068876ba4ff5cb806cbc Mon Sep 17 00:00:00 2001 From: Geon Park Date: Wed, 29 Jan 2025 06:53:10 +0900 Subject: [PATCH 15/16] fix imports --- python/sglang/srt/model_executor/model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 41a0390101..214e7a0e01 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -36,7 +36,7 @@ from sglang.srt.hf_transformers_utils import get_context_length, update_context_length from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend from sglang.srt.layers.attention.flashinfer_backend import 
FlashInferAttnBackend -from sglang.srt.layers.attention.hip_attention import HiPRadixAttentionBackend +from sglang.srt.layers.attention.hip_radix_attention import HiPRadixAttentionBackend from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend from sglang.srt.layers.attention.triton_backend import TritonAttnBackend from sglang.srt.layers.dp_attention import ( From 63fee4f9c2a43429c11d8f31c9e9d61c63dbd7ff Mon Sep 17 00:00:00 2001 From: Geon Park Date: Wed, 29 Jan 2025 06:58:39 +0900 Subject: [PATCH 16/16] fix copy-pasted block comment --- python/sglang/srt/layers/attention/hip_radix_attention.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/layers/attention/hip_radix_attention.py b/python/sglang/srt/layers/attention/hip_radix_attention.py index 9686a2cca2..85cf7f45d3 100644 --- a/python/sglang/srt/layers/attention/hip_radix_attention.py +++ b/python/sglang/srt/layers/attention/hip_radix_attention.py @@ -1,10 +1,8 @@ from __future__ import annotations """ -Support different attention backends. -Now there are two backends: FlashInfer and Triton. -FlashInfer is faster and Triton is easier to customize. -Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode. +HiP Attention Backend for SGLang +https://arxiv.org/pdf/2406.09827 """ import logging
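
Usage note for the offload refactor in this series: the sketch below summarizes the call order that the refactored MHATokenToHiPOffloadKVPool expects from an attention backend. It is illustrative only. The method names, argument layout, and return arity mirror the call sites in the diffs above; the driver functions, the loop structure, and the "..." placeholder tensors are hypothetical and do not appear in any patch.

# Illustrative sketch: lifecycle of the refactored offload KV pool.
# Method names and signatures follow the call sites in this series;
# everything else (function names, loops, placeholders) is hypothetical.
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from sglang.srt.mem_cache.hip_offload_kv_pool_mha import (
        MHATokenToHiPOffloadKVPool,
    )
    from sglang.srt.model_executor.forward_batch_info import ForwardBatch


def extend_lifecycle_sketch(forward_batch: ForwardBatch, layer_num: int) -> None:
    pool: MHATokenToHiPOffloadKVPool = forward_batch.token_to_kv_pool

    # Starts the prefix-KV prefetch for layer 0 (extend/prefill batches only).
    pool.on_model_start(forward_batch)

    for layer_id in range(layer_num):
        # Lets the pool begin prefetching the next layer while this one runs.
        pool.on_layer_start(forward_batch, layer_id)

        for batch_id in range(forward_batch.batch_size):
            # cache_k / cache_v hold the new tokens of this chunk; the returned
            # k / v cover prefix + chunk for the request (the pre-refactor code
            # asserted shape [1, seq_len, head_num, head_dim]).
            k, v, offloading_metadata = pool.get_fetched_prefix_kv_buffer(
                layer_id=layer_id,
                batch_id=batch_id,
                cache_k=...,  # placeholder tensor, not from the patches
                cache_v=...,  # placeholder tensor, not from the patches
            )
            # ... run forward_paged_hip on (k, v) for this request here ...

        pool.on_layer_end(forward_batch, layer_id)

    # Drains any outstanding asynchronous prefetch / write-back work.
    pool.on_model_end(forward_batch)


def decode_kv_sketch(forward_batch: ForwardBatch, layer_id: int):
    pool: MHATokenToHiPOffloadKVPool = forward_batch.token_to_kv_pool
    # get_kv_buffer() now returns a pair; in the pre-refactor pool the second
    # element carried the validation KV and was None otherwise.
    offload_cache, offloading_metadata = pool.get_kv_buffer(layer_id)
    return offload_cache, offloading_metadata, pool.is_online_cache_update_enabled()

Per the comments in the pre-refactor pool, get_fetched_prefix_kv_buffer is the prefill-side entry point and get_kv_buffer the decode-side one, which matches how forward_extend and forward_decode unpack them in the backend diffs above.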