 import functools
-from functools import wraps
 from types import MethodType

 import torch
-from loguru import logger

 from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY

@@ -21,61 +19,21 @@ def __init__(self, config, model, blocks):
     def add_sparse_config(self):

         self.pruning_loc = self.special_config['pruning_loc']
-        self.special_config['image_token_length'] = self.model.pruning_config[
-            'image_token_length'
-        ]
-
         self.pruning_paras = self.special_config

     def register_reduction_modules(self):

-        def input_hook_llava(fn, pruning_paras):
-            @wraps(fn)
-            def wrapper(self, *args, **kwargs):
-                if len(args) == 0:
-                    return fn(*args, **kwargs)
-                input_args = args[0]
-                if hasattr(input_args[0], 'shape') and input_args[0].shape[0] == 1:
-                    return fn(*args, **kwargs)
-
-                input_ids = args[0]
-                attention_mask = args[2]
-                token_indices = input_ids[0][attention_mask[0]] == IMAGE_TOKEN_INDEX
-                pruning_paras['image_token_start_index'] = torch.where(token_indices)[
-                    0
-                ][0].item()
-
-                outputs = fn(*args, **kwargs)
-                return outputs
-
-            return wrapper
-
-        @prefill_wrapper
-        def input_hook(module, input_args, pruning_paras):
-            input_ids = input_args[0]
-            image_token_idxs = (
-                input_ids[0] == pruning_paras['vision_token_index']
-            ).nonzero(as_tuple=True)[0]
-            pruning_paras['image_token_start_index'] = image_token_idxs[0].item()
-
-            return input_args
-
         @prefill_wrapper
         def random_pruning_hook(module, args, kwargs, pruning_paras):

-            logger.info('========random_pruning_hook========')
-
-            rate = pruning_paras['rate']
-            image_token_start_index = pruning_paras['image_token_start_index']
-            image_token_length = pruning_paras['image_token_length']
+            rate = pruning_paras['prune_ratio']
+            image_token_start_index = pruning_paras['vision_token_start_index']
+            image_token_length = pruning_paras['vision_token_length']

             hidden_states = args[0]
             causal_mask = kwargs['attention_mask']

-            logger.info(f'before hidden_states: {hidden_states.shape}')
-
             device = hidden_states.device
-
             vision_indexes = torch.arange(
                 image_token_start_index,
                 image_token_start_index + image_token_length,
@@ -130,25 +88,169 @@ def random_pruning_hook(module, args, kwargs, pruning_paras):
             position_embeddings[0].resize_as_(new_pe0).copy_(new_pe0)
             position_embeddings[1].resize_as_(new_pe0).copy_(new_pe1)

-            logger.info(f'after hidden_states: {hidden_states.shape}')
             return (hidden_states,), kwargs

-        if self.model.__class__.__name__ == 'LlavaHf':
-            self.model.embed_tokens.register_forward_pre_hook(
-                functools.partial(input_hook, pruning_paras=self.pruning_paras)
+        @prefill_wrapper
+        def holitom_merge_hook(module, args, kwargs, pruning_paras):
+
+            rate = pruning_paras['prune_ratio']
+            image_token_start_index = pruning_paras['vision_token_start_index']
+            image_token_length = pruning_paras['vision_token_length']
+
+            hidden_states = args[0]
+            causal_mask = kwargs['attention_mask']
+
+            device = hidden_states.device
+            last_layer_attention = pruning_paras['attn_scores']
+            # compute the average attention over the different heads
+            last_layer_attention_avg = torch.mean(
+                last_layer_attention, dim=1
+            )[0]
+            # based on this average attention, keep the image tokens that
+            # receive the highest attention from the last token
+            last_layer_attention_avg_last_tok = (
+                last_layer_attention_avg[-1]
+            )
+            # get the attention over the image tokens
+            last_layer_attention_avg_last_tok_image = \
+                last_layer_attention_avg_last_tok[
+                    image_token_start_index:
+                    image_token_start_index + image_token_length
+                ]
+            # get the indices of the top-attention image tokens
+            top_attention_rank_index = (
+                last_layer_attention_avg_last_tok_image.topk(
+                    round(
+                        image_token_length * (1 - rate)
+                    )
+                ).indices
+                + image_token_start_index
             )
-        elif self.model.__class__.__name__ == 'Llava':
-            from llava.constants import IMAGE_TOKEN_INDEX

-            hook_fn = input_hook_llava(
-                self.model.vlm_model.prepare_inputs_labels_for_multimodal,
-                self.pruning_paras,
+            all_indices = torch.arange(
+                image_token_length, device=device
             )
-            self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
-                hook_fn, self.model.vlm_model
+            non_topk_mask = ~torch.isin(
+                all_indices,
+                top_attention_rank_index
+                - image_token_start_index,
+            )
+            non_topk_indices = (
+                all_indices[non_topk_mask]
+                + image_token_start_index
+            )
+            non_topk_states = hidden_states[
+                :, non_topk_indices, :
+            ]  # [batch_size, len(non_topk), hidden_size]
+            topk_states = hidden_states[
+                :, top_attention_rank_index, :
+            ]  # [batch_size, len(topk), hidden_size]
+            non_topk_norm = torch.norm(
+                non_topk_states, dim=-1, keepdim=True
+            )  # [batch_size, len(non_topk), 1]
+            topk_norm = torch.norm(
+                topk_states, dim=-1, keepdim=True
+            )  # [batch_size, len(topk), 1]
+            dot_product = torch.bmm(
+                non_topk_states, topk_states.transpose(1, 2)
+            )  # [batch_size, len(non_topk), len(topk)]
+            sim_matrix = dot_product / (
+                non_topk_norm * topk_norm.transpose(1, 2)
+            )
+            sim_max, sim_max_index = torch.max(sim_matrix, dim=-1)
+
+            batch_size = hidden_states.size(0)
+            num_topk = len(top_attention_rank_index)
+            num_non_topk = len(non_topk_indices)
+            topk_counter = torch.ones((batch_size, num_topk, 1), device=hidden_states.device)
+
+            for b in range(batch_size):
+                for i in range(num_non_topk):
+                    topk_rel_idx = sim_max_index[b, i].item()  # relative index within the topk set
+                    topk_abs_idx = top_attention_rank_index[topk_rel_idx]  # absolute index in the sequence
+                    non_topk_abs_idx = non_topk_indices[i]
+
+                    # accumulate the non-topk token onto its most similar topk token (in place)
+                    hidden_states[b, topk_abs_idx, :] += hidden_states[b, non_topk_abs_idx, :]
+                    # increment the merge counter
+                    topk_counter[b, topk_rel_idx] += 1
+
+            # average every topk token over itself and all tokens merged into it
+            for b in range(batch_size):
+                for i in range(num_topk):
+                    topk_abs_idx = top_attention_rank_index[i]
+                    hidden_states[b, topk_abs_idx, :] /= topk_counter[b, i]
+
+            keep_indexs = torch.cat(
+                (
+                    torch.arange(
+                        image_token_start_index,
+                        device=device,
+                    ),
+                    top_attention_rank_index,
+                    torch.arange(
+                        image_token_start_index
+                        + image_token_length,
+                        hidden_states.shape[1],
+                        device=device,
+                    ),
+                )
             )

-        self.blocks[self.pruning_loc].register_forward_pre_hook(
-            functools.partial(random_pruning_hook, pruning_paras=self.pruning_paras),
-            with_kwargs=True,
-        )
+            # sort the kept indices to preserve the original token order
+            keep_indexs = keep_indexs.sort().values
+            # filter the hidden states
+            hidden_states = hidden_states[:, keep_indexs, :]
+            # update position ids
+            position_ids = keep_indexs.unsqueeze(0)
+            # update attention mask
+            if causal_mask is not None:
+                causal_mask = causal_mask[:, :, :hidden_states.shape[1], :hidden_states.shape[1]]
+                kwargs['attention_mask'].resize_as_(causal_mask).copy_(causal_mask.clone())
+            kwargs['cache_position'].resize_as_(position_ids.squeeze(0)).copy_(
+                position_ids.squeeze(0).clone())
+            kwargs['position_ids'].resize_as_(position_ids).copy_(position_ids.clone())
+
+            position_embeddings = kwargs['position_embeddings']
+            index_dim = 1 if position_embeddings[0].dim() == 3 else 2
+            new_pe0 = position_embeddings[0].index_select(index_dim, keep_indexs).clone()
+            new_pe1 = position_embeddings[1].index_select(index_dim, keep_indexs).clone()
+            position_embeddings[0].resize_as_(new_pe0).copy_(new_pe0)
+            position_embeddings[1].resize_as_(new_pe0).copy_(new_pe1)
+
+            return (hidden_states,), kwargs
+
+        def update_output_attentions_hook(module, args, kwargs):
+            kwargs['output_attentions'] = True
+            return args, kwargs
+
+        def store_attention_hook(m, x, layer_outputs, pruning_paras):
+            layer_attention = layer_outputs[1]
+            pruning_paras['attn_scores'] = layer_attention
+
+        if self.special_config['vision_token_length'] is None:
+            if self.model.__class__.__name__ == 'Llava':
+                self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
+                    self.vtoken_length_for_llava_hook(
+                        self.model.vlm_model.prepare_inputs_labels_for_multimodal,
+                        self.pruning_paras
+                    ), self.model.vlm_model
+                )
+
+        if self.special_config['metric'] == 'random':
+            self.blocks[self.pruning_loc].register_forward_pre_hook(
+                functools.partial(random_pruning_hook, pruning_paras=self.pruning_paras),
+                with_kwargs=True
+            )
+        elif self.special_config['metric'] == 'holitom_merge':
+            self.blocks[self.pruning_loc - 1].register_forward_pre_hook(
+                update_output_attentions_hook,
+                with_kwargs=True
+            )
+            self.blocks[self.pruning_loc - 1].register_forward_hook(
+                functools.partial(store_attention_hook, pruning_paras=self.pruning_paras),
+            )
+            self.blocks[self.pruning_loc].register_forward_pre_hook(
+                functools.partial(holitom_merge_hook, pruning_paras=self.pruning_paras),
+                with_kwargs=True
+            )
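
The core of this change is `holitom_merge_hook`: keep the image tokens that receive the most attention from the last token, fold every dropped token into its most cosine-similar kept token, then average. The snippet below is only a minimal sketch, not part of llmc: `demo_merge`, the toy tensor shapes, and the random attention are invented for illustration, but the selection-and-merge arithmetic mirrors the hook above, so the indexing can be sanity-checked on dummy data.

```python
import torch


def demo_merge(hidden_states, attn, start, length, rate):
    """Toy re-implementation of the holitom_merge selection + merge step."""
    # attn: [batch, heads, seq, seq] attention from the block before the pruning location
    attn_avg = attn.mean(dim=1)[0]                     # average over heads, batch element 0
    image_attn = attn_avg[-1][start:start + length]    # attention the last token pays to image tokens

    keep = round(length * (1 - rate))
    topk_idx = image_attn.topk(keep).indices + start   # kept vision tokens (absolute indices)

    all_idx = torch.arange(length)
    non_topk_idx = all_idx[~torch.isin(all_idx, topk_idx - start)] + start  # dropped vision tokens

    topk_states = hidden_states[:, topk_idx, :]
    non_topk_states = hidden_states[:, non_topk_idx, :]

    # cosine similarity between every dropped token and every kept token
    sim = torch.bmm(non_topk_states, topk_states.transpose(1, 2)) / (
        non_topk_states.norm(dim=-1, keepdim=True)
        * topk_states.norm(dim=-1, keepdim=True).transpose(1, 2)
    )
    assign = sim.argmax(dim=-1)                        # [batch, num_dropped]: nearest kept token

    counter = torch.ones(hidden_states.size(0), keep, 1)
    for b in range(hidden_states.size(0)):
        for i, src in enumerate(non_topk_idx):
            dst = topk_idx[assign[b, i]]
            hidden_states[b, dst, :] += hidden_states[b, src, :]   # fold the dropped token in
            counter[b, assign[b, i]] += 1
        hidden_states[b, topk_idx, :] /= counter[b]                # average each kept token

    keep_idx = torch.cat((
        torch.arange(start),                                   # text tokens before the image
        topk_idx.sort().values,                                # surviving vision tokens
        torch.arange(start + length, hidden_states.shape[1]),  # text tokens after the image
    ))
    return hidden_states[:, keep_idx, :], keep_idx


hs = torch.randn(1, 40, 16)        # 40 tokens, hidden size 16
attn = torch.rand(1, 8, 40, 40)    # 8 attention heads
out, kept = demo_merge(hs.clone(), attn, start=5, length=24, rate=0.75)
print(out.shape)                   # torch.Size([1, 22, 16]): 6 of 24 vision tokens survive
```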
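The `holitom_merge` branch of the registration also depends on a three-hook pattern: a pre-hook on block `pruning_loc - 1` forces it to emit attention, a forward hook on the same block caches that attention in `pruning_paras`, and the pre-hook on block `pruning_loc` consumes it. Below is a minimal sketch of that wiring using plain `torch.nn` modules instead of llmc decoder blocks; `ToyBlock`, `return_stats`, and the stats dict are stand-ins, not real llmc or HuggingFace names.

```python
import functools

import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    def forward(self, x, return_stats=False):
        out = torch.relu(x)
        if return_stats:
            return out, {'mean_act': out.mean()}   # stand-in for attention weights
        return out


blocks = nn.ModuleList([ToyBlock(), ToyBlock()])
paras = {}


def force_stats_hook(module, args, kwargs):
    kwargs['return_stats'] = True                  # analogous to forcing output_attentions=True
    return args, kwargs


def store_stats_hook(module, args, output, paras):
    paras['stats'] = output[1]                     # analogous to caching attn_scores


def consume_stats_hook(module, args, kwargs, paras):
    print('pruning block sees stats from the previous block:', paras['stats'])
    return args, kwargs


# block N-1: force extra outputs and cache them; block N: read the cache in its pre-hook
blocks[0].register_forward_pre_hook(force_stats_hook, with_kwargs=True)
blocks[0].register_forward_hook(functools.partial(store_stats_hook, paras=paras))
blocks[1].register_forward_pre_hook(
    functools.partial(consume_stats_hook, paras=paras), with_kwargs=True
)

x = torch.randn(2, 4)
hidden, _ = blocks[0](x)   # returns (tensor, stats) because the pre-hook injected return_stats=True
_ = blocks[1](hidden)
```

The same mechanism carries over to HuggingFace decoder layers because `register_forward_pre_hook(..., with_kwargs=True)` lets the hook rewrite the keyword arguments (here `output_attentions`) before the block runs.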