@@ -44,14 +44,49 @@ class WrapperDispatch(Enum):
 
 
 class HiPRadixAttentionBackend(AttentionBackend):
-
     def __init__(self, model_runner: HiPModelRunner):
         super().__init__()
 
+        # NOTE: this backend instance is created only once.
+
         self.hip_config: HiPAttentionConfig = model_runner.hip_attention_config
 
         self.max_context_len = model_runner.model_config.context_len
 
+        # NOTE: this is a temporary workaround.
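+        # Per-layer rolling buffers holding the most recent block_sparse_block_size_q
+        # query vectors; they are read back via get_q_buffer during decode.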
+        self.q_buffers = [
+            torch.zeros(
+                (
+                    1,
+                    self.hip_config.block_sparse_block_size_q,
+                    model_runner.model_config.num_attention_heads
+                    // model_runner.tp_size,
+                    model_runner.model_config.head_dim,
+                ),
+                device=torch.device(model_runner.device),
+                dtype=model_runner.dtype,
+            )
+            for _ in range(model_runner.model_config.num_hidden_layers)
+        ]
+        # NOTE: disable q caching
+        self.q_buffers = None
+
+    def push_q_buffer(self, q: torch.Tensor, layer_id: int, batch_size: int):
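+        # Append the newest queries to the fixed-size per-layer buffer in place,
+        # keeping only the most recent window.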
+        if self.q_buffers is None:
+            return
+        assert batch_size == 1
+        q = q.unsqueeze(0)
+        layer_q_buffer = self.q_buffers[layer_id]
+        q_buffer = torch.cat([layer_q_buffer, q[:, -layer_q_buffer.shape[1] :]], dim=1)
+        layer_q_buffer.copy_(q_buffer[:, -layer_q_buffer.shape[1] :])
+
+    def get_q_buffer(self, layer_id: int, batch_size: int) -> torch.Tensor:
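+        # Return the buffered queries for this layer as (window, num_heads, head_dim),
+        # or None when q caching is disabled.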
+        if self.q_buffers is not None:
+            assert batch_size == 1
+            return self.q_buffers[layer_id].flatten(0, 1)
+        else:
+            return None
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         pass
 
@@ -130,6 +165,9 @@ def forward_extend(
             offload_cache = None
 
         q_reshaped = q.reshape(-1, layer.tp_q_head_num, layer.head_dim)
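+        # Store the newest prefill queries in the per-layer buffer
+        # (consumed later via get_q_buffer).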
+        self.push_q_buffer(
+            q_reshaped, layer_id=layer.layer_id, batch_size=forward_batch.batch_size
+        )
 
         # Output tensor
         o = torch.empty_like(q_reshaped)
@@ -349,9 +387,17 @@ def forward_decode(
             )
             offload_cache = None
 
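+        # Push the current decode query into the buffer and fetch the buffered
+        # query window that is used for mask construction.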
+        self.push_q_buffer(
+            q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+            layer_id=layer.layer_id,
+            batch_size=forward_batch.batch_size,
+        )
+        q_for_masking = self.get_q_buffer(layer.layer_id, forward_batch.batch_size)
+
         if not require_validation:
             o, metadata = self.forward_paged_hip(
                 query=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                query_for_mask=q_for_masking,
                 sm_scale=layer.scaling,
                 batch_size=forward_batch.batch_size,
                 k_cache=k_cache,
@@ -384,6 +430,7 @@ def sse(a: torch.Tensor, b: torch.Tensor):
 
             o, metadata_new = self.forward_paged_hip(
                 query=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                query_for_mask=q_for_masking,
                 sm_scale=layer.scaling,
                 batch_size=forward_batch.batch_size,
                 k_cache=None,
@@ -416,6 +463,7 @@ def sse(a: torch.Tensor, b: torch.Tensor):
 
             o_valid, metadata_valid = self.forward_paged_hip(
                 query=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                query_for_mask=q_for_masking,
                 sm_scale=layer.scaling,
                 batch_size=forward_batch.batch_size,
                 k_cache=k_cache,
@@ -491,6 +539,7 @@ def sse(a: torch.Tensor, b: torch.Tensor):
 
             o_uvm, metadata_uvm = self.forward_paged_hip(
                 query=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                query_for_mask=q_for_masking,
                 sm_scale=layer.scaling,
                 batch_size=forward_batch.batch_size,
                 k_cache=offload_cache.k_uvm.bank_gpu,
@@ -518,6 +567,7 @@ def sse(a: torch.Tensor, b: torch.Tensor):
 
             o_retry, metadata_retry = self.forward_paged_hip(
                 query=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                query_for_mask=q_for_masking,
                 sm_scale=layer.scaling,
                 batch_size=forward_batch.batch_size,
                 k_cache=None,
@@ -599,6 +649,7 @@ def forward_paged_hip(
         is_dense: bool = False,
         k: Optional[torch.Tensor] = None,
         v: Optional[torch.Tensor] = None,
+        query_for_mask: Optional[torch.Tensor] = None,
         online_update_cache: bool = False,
         is_decode: bool = False,
     ) -> tuple[torch.Tensor, "HiPAttentionOutputMetadata"]:
@@ -619,6 +670,8 @@ def forward_paged_hip(
         layer_config = self.hip_config.layers[layer.layer_id]
 
         query = query.view(batch_size, dst_seq_len, num_heads, hidden_dims)
+        if query_for_mask is not None:
+            query_for_mask = query_for_mask.view(batch_size, -1, num_heads, hidden_dims)
 
         if k_cache is not None:
             N_PAGE, num_heads_kv, hidden_dims_kv = k_cache.shape
@@ -654,6 +707,16 @@ def forward_paged_hip(
         elif os.getenv("HIP_DISABLE_COMPUTE_STATISTICS", "1") == "0":
             require_cache_statistics = True
 
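+        # Give every buffered query its original position id; the last entry maps
+        # to the current decode position.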
+        if query_for_mask is not None:
+            query_position_ids = positions.view(batch_size, dst_seq_len)
+            position_ids = (
+                torch.arange(0, query_for_mask.shape[1], device=query.device)[None, :]
+                - (query_for_mask.shape[1] - 1)
+                + query_position_ids
+            )
+        else:
+            position_ids = positions.view(batch_size, dst_seq_len)
+
         args = HiPAttentionArgs(
             k_cache=(
                 k_cache.view(torch.uint8)
@@ -670,7 +733,7 @@ def forward_paged_hip(
             offload_cache=offload_cache,
             block_table=block_table,
             cache_seq_lens=seq_lens,
-            position_ids=positions.view(batch_size, dst_seq_len),
+            position_ids=position_ids,
             block_size_k=32 if is_gemma else 64,  # BLOCK_CHUNK
             sliding_window_size=layer_config.sliding_window_size,
             sink_token_size=layer_config.sink_token_size,
@@ -697,6 +760,11 @@ def forward_paged_hip(
             online_update_cache=online_update_cache,
             require_cache_statistics=require_cache_statistics,
             disable_flashdecode=not is_decode,
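+            # Pre-scale the buffered masking queries by sm_scale before handing
+            # them to the HiP kernel.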
+            q_mask=(
+                (query_for_mask * sm_scale).to(query.dtype)
+                if query_for_mask is not None
+                else None
+            ),
         )
 
         context, metadata = dual_stage_quadratic_hip_attention(
@@ -707,5 +775,6 @@ def forward_paged_hip(
             cached_metadata=cached_metadata,
         )
         context = context.to(query.dtype)
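+        # The kernel may return one output row per masking query; keep only the
+        # rows that correspond to the current queries.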
+        context = context[:, -query.shape[1] :, :, :].contiguous()
 
         return context.view(N, num_heads, hidden_dims), metadata