 sparse_token_list_192 = []
 sparse_token_list_128 = []
 sparse_token_list_64 = []
+sparse_token_list_960 = []
 sparse_token_list_640 = []
 sparse_token_list_320 = []
 sparse_token_list_160 = []
@@ -329,8 +330,9 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_paras, laye

             if attention_mask is not None:
                 attention_mask = attention_mask[:, :, keep_indexs, keep_indexs]
-            new_pe0 = position_embeddings[0][:, keep_indexs, :].clone()
-            new_pe1 = position_embeddings[1][:, keep_indexs, :].clone()
+            index_dim = 1 if position_embeddings[0].dim() == 3 else 2
+            new_pe0 = position_embeddings[0].index_select(index_dim, keep_indexs).clone()
+            new_pe1 = position_embeddings[1].index_select(index_dim, keep_indexs).clone()
             position_embeddings = (new_pe0, new_pe1)

             pruning_paras['v_token_num'] = v_token_num
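Aside on the hunk above: the old slice `[:, keep_indexs, :]` assumed 3D cos/sin tensors, whereas `index_select` on the detected sequence axis also covers embeddings that carry an extra leading dimension. A minimal standalone sketch of that behavior; the tensor shapes, including the 4D mrope-style layout, are assumptions for illustration and not taken from this patch:

```python
import torch

def keep_positions(pe, keep_indexs):
    # dim 1 holds the sequence for a 3D (bsz, seq, dim) tensor;
    # a 4D tensor (sections, bsz, seq, dim) keeps it at dim 2.
    index_dim = 1 if pe.dim() == 3 else 2
    return pe.index_select(index_dim, keep_indexs).clone()

keep = torch.tensor([0, 2, 5])
cos_3d = torch.randn(1, 8, 64)      # assumed Llama-style cos
cos_4d = torch.randn(3, 1, 8, 64)   # assumed mrope-style cos

assert torch.equal(keep_positions(cos_3d, keep), cos_3d[:, keep, :])
assert torch.equal(keep_positions(cos_4d, keep), cos_4d[:, :, keep, :])
print(keep_positions(cos_3d, keep).shape, keep_positions(cos_4d, keep).shape)
```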
@@ -352,6 +354,75 @@ def read_parameter_hook(module, args, kwargs, pruning_paras):

             return args, kwargs

+        @prefill_wrapper
+        def vtoken_length_hook(module, args, pruning_paras):
+            input_ids = args[0]
+            token_indices = torch.where(
+                input_ids[0] == pruning_paras['vision_token_index']
+            )[0]
+            pruning_paras['vision_token_length'] = token_indices.shape[0]
+            pruning_paras['pre_prompt_length_list'] = [token_indices[0].item()]
+
+        def get_attn_logits_for_qwen25vl(
+            module,
+            args, kwargs, layer_outs,
+            pruning_paras, layer_idx
+        ):
+            if kwargs['position_ids'].shape[-1] == 1:
+                return layer_outs
+
+            from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+                apply_multimodal_rotary_pos_emb, repeat_kv)
+
+            hidden_states = kwargs['hidden_states']
+            position_embeddings = kwargs['position_embeddings']
+            past_key_value = layer_outs[2]
+            attention_mask = kwargs['attention_mask']
+
+            t_token_idx = pruning_paras['t_token_idx']
+            v_token_start = pruning_paras['v_token_start']
+            v_token_num = pruning_paras['v_token_num']
+
+            bsz, q_len, _ = hidden_states.size()
+
+            query_states = module.q_proj(hidden_states)
+            key_states = module.k_proj(hidden_states)
+            value_states = module.v_proj(hidden_states)
+
+            query_states = query_states.view(bsz, q_len, -1, module.head_dim).transpose(1, 2)
+            key_states = key_states.view(bsz, q_len, -1, module.head_dim).transpose(1, 2)
+            value_states = value_states.view(bsz, q_len, -1, module.head_dim).transpose(1, 2)
+
+            cos, sin = position_embeddings
+            query_states, key_states = apply_multimodal_rotary_pos_emb(
+                query_states, key_states, cos, sin, module.rope_scaling['mrope_section']
+            )
+
+            if past_key_value is not None:
+                key_states = past_key_value.key_cache[layer_idx]
+                value_states = past_key_value.value_cache[layer_idx]
+
+            key_states = repeat_kv(key_states, module.num_key_value_groups)
+            value_states = repeat_kv(value_states, module.num_key_value_groups)
+
+            t_token_idx = t_token_idx[1] + v_token_start + v_token_num
+            L, S = query_states.size(-2), key_states.size(-2)
+            scale_factor = 1 / math.sqrt(query_states.size(-1))
+            attn_bias = torch.zeros(L, S, dtype=query_states.dtype)
+            if module.is_causal:
+                assert attention_mask is None
+                temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
+                attn_bias.masked_fill_(temp_mask.logical_not(), float('-inf'))
+                attn_bias = attn_bias.to(query_states.dtype)
+
+            attn_logits = query_states @ key_states.transpose(2, 3) * scale_factor
+            attn_logits += attn_bias.to(query_states.device)
+            attn_logits = torch.softmax(attn_logits, dim=-1)
+
+            pruning_paras['attn_logits'] = attn_logits
+
+            return layer_outs
+
         if self.model.__class__.__name__ == 'LlavaHf':
             self.model.embed_tokens.register_forward_pre_hook(
                 functools.partial(input_hook, pruning_paras=self.pruning_paras)
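Note on `get_attn_logits_for_qwen25vl` above: it recomputes per-head attention probabilities with an explicit causal bias instead of relying on a fused attention kernel, so the full logits matrix is available for later pruning. A minimal sketch of that scaled-dot-product-with-causal-bias recipe on toy tensors; the batch size, head count and head dimension here are made up:

```python
import math
import torch

# Toy shapes (made up): 1 sequence of 6 tokens, 2 heads, head_dim 4.
bsz, n_heads, seq, head_dim = 1, 2, 6, 4
query_states = torch.randn(bsz, n_heads, seq, head_dim)
key_states = torch.randn(bsz, n_heads, seq, head_dim)

L, S = query_states.size(-2), key_states.size(-2)
scale_factor = 1 / math.sqrt(query_states.size(-1))

# Causal bias: 0 on/below the diagonal, -inf above it.
attn_bias = torch.zeros(L, S, dtype=query_states.dtype)
temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
attn_bias.masked_fill_(temp_mask.logical_not(), float('-inf'))

attn_logits = query_states @ key_states.transpose(2, 3) * scale_factor
attn_logits = torch.softmax(attn_logits + attn_bias, dim=-1)

# Each row is a probability distribution over the visible (past) positions.
assert torch.allclose(attn_logits.sum(dim=-1), torch.ones(bsz, n_heads, L))
print(attn_logits.shape)  # torch.Size([1, 2, 6, 6])
```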
@@ -364,11 +435,17 @@ def read_parameter_hook(module, args, kwargs, pruning_paras):
                     llava_next=self.special_config['vision_token_length'] is None
                 ), self.model.vlm_model
             )
+        elif self.model.__class__.__name__ == 'Qwen2_5VL':
+            self.model.embed_tokens.register_forward_pre_hook(
+                functools.partial(vtoken_length_hook, pruning_paras=self.pruning_paras)
+            )

         if self.model.__class__.__name__ == 'LlavaHf':
             llama_model = self.model.model
         elif self.model.__class__.__name__ == 'Llava':
             llama_model = self.model.model.model
+        elif self.model.__class__.__name__ == 'Qwen2_5VL':
+            llama_model = self.model.language_model
         llama_model.register_forward_pre_hook(
             functools.partial(register_module_paras, pruning_paras=self.pruning_paras),
             with_kwargs=True
@@ -405,6 +482,23 @@ def read_parameter_hook(module, args, kwargs, pruning_paras):
                     ),
                     with_kwargs=True
                 )
+            elif self.model.__class__.__name__ == 'Qwen2_5VL':
+                self.blocks[block_idx].register_forward_pre_hook(
+                    functools.partial(
+                        update_kwargs_hook,
+                        pruning_paras=self.pruning_paras,
+                        layer_idx=block_idx,
+                    ),
+                    with_kwargs=True
+                )
+                self.blocks[block_idx].self_attn.register_forward_hook(
+                    functools.partial(
+                        get_attn_logits_for_qwen25vl,
+                        pruning_paras=self.pruning_paras,
+                        layer_idx=block_idx,
+                    ),
+                    with_kwargs=True
+                )
             self.blocks[block_idx].register_forward_hook(
                 functools.partial(
                     decoder_attn_hook,
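The Qwen2_5VL branch above relies on PyTorch's `with_kwargs=True` hook contract, with `functools.partial` binding `pruning_paras` and `layer_idx` into each hook. A minimal standalone sketch of that registration pattern; the toy `nn.Linear` module, the hook name and the `tag` argument are illustrative only:

```python
import functools
import torch
from torch import nn

def log_output_hook(module, args, kwargs, output, tag):
    # with_kwargs=True delivers both positional args and keyword args;
    # returning None leaves the module's output untouched.
    print(tag, output.shape)

layer = nn.Linear(8, 8)
handle = layer.register_forward_hook(
    functools.partial(log_output_hook, tag='block-0'),  # extra state bound via partial
    with_kwargs=True,
)
layer(torch.randn(2, 8))   # prints: block-0 torch.Size([2, 8])
handle.remove()
```

Returning the original output (or `None`) keeps the forward result unchanged, which matches how `get_attn_logits_for_qwen25vl` passes `layer_outs` straight through after stashing the attention logits.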
@@ -425,7 +519,7 @@ def read_parameter_hook(module, args, kwargs, pruning_paras):

 def update_list():
     global sparse_token_list_192, sparse_token_list_128, sparse_token_list_64
-    global sparse_token_list_640, sparse_token_list_320, sparse_token_list_160
+    global sparse_token_list_960, sparse_token_list_640, sparse_token_list_320, sparse_token_list_160  # noqa
     global prune_flag, merge_flag, sparse_token_dict

     if layer_dict == {2: 0, 6: 1, 15: 2}:  # 2*576 4*300 10*200 16*110
@@ -437,13 +531,15 @@ def update_list():
         sparse_token_list_192 = [180]
         sparse_token_list_128 = [114]
         sparse_token_list_64 = [48]
+        sparse_token_list_960 = [0.3125]
         sparse_token_list_640 = [0.1979]
         sparse_token_list_320 = [0.0833]
         sparse_token_list_160 = [0.0261]
     elif prune_flag:
         sparse_token_list_192 = [192]
         sparse_token_list_128 = [128]
         sparse_token_list_64 = [64]
+        sparse_token_list_960 = [0.3333]
         sparse_token_list_640 = [0.2222]
         sparse_token_list_320 = [0.1111]
         sparse_token_list_160 = [0.0555]
@@ -460,6 +556,7 @@ def update_list():
         192: sparse_token_list_192,
         128: sparse_token_list_128,
         64: sparse_token_list_64,
+        960: sparse_token_list_960,
         640: sparse_token_list_640,
         320: sparse_token_list_320,
         160: sparse_token_list_160