
Commit c909896

prune for qwen2.5vl(divprune,holitom,sparsevlm,vispruner) (#441)
1 parent 06bad3e · commit c909896


4 files changed: 353 additions, 28 deletions


llmc/compression/token_reduction/divprune.py

Lines changed: 57 additions & 0 deletions
@@ -1,3 +1,4 @@
+import functools
 from functools import wraps
 from types import MethodType
 
@@ -6,6 +7,7 @@
 from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY
 
 from .token_reduction_module import TokenReductionModule
+from .utils import prefill_wrapper
 
 
 def pairwise_cosine_similarity(matrix):
@@ -84,6 +86,41 @@ def divprune_post_hook(*args, pruning_paras=None):
     return tuple(args)
 
 
+def prune_qwenv25vl_hook(module, args, kwargs, pruning_paras):
+    if kwargs['position_ids'].shape[-1] == 1:
+        return args, kwargs
+    inputs_embeds = kwargs['inputs_embeds']
+    attention_mask = kwargs['attention_mask']
+    rate = pruning_paras['reduction_ratio']
+    SYS_TOKEN_LEN = pruning_paras['vision_token_start_index']
+    img_feature_len = pruning_paras['vision_token_length']
+    device = inputs_embeds.device
+
+    visual_tokens = inputs_embeds[0][SYS_TOKEN_LEN: SYS_TOKEN_LEN + img_feature_len]
+    selected_visual_tokens, cosine_matrix = divprune(
+        visual_tokens, img_feature_len, None, threshold_ratio=1 - rate
+    )
+    selected_visual_tokens += SYS_TOKEN_LEN
+    keep_indexs = torch.cat(
+        (
+            torch.arange(SYS_TOKEN_LEN, device=device),
+            selected_visual_tokens,
+            torch.arange(
+                SYS_TOKEN_LEN + img_feature_len, inputs_embeds.shape[1], device=device
+            ),
+        )
+    )
+    keep_indexs = keep_indexs.sort().values
+
+    kwargs['inputs_embeds'] = inputs_embeds[:, keep_indexs, :]
+    kwargs['position_ids'] = kwargs['position_ids'][:, :, keep_indexs]
+    if attention_mask is not None:
+        kwargs['attention_mask'] = attention_mask[:, keep_indexs]
+    kwargs['cache_position'] = keep_indexs
+
+    return args, kwargs
+
+
 @TOKEN_REDUCTION_REGISTRY.register('DivPrune')
 class DivPrune(TokenReductionModule):
     def __init__(self, config, model, blocks):
@@ -114,6 +151,14 @@ def wrapper(self, *args, **kwargs):
                 return divprune_post_hook(*outs, pruning_paras=pruning_paras)
             return wrapper
 
+        @prefill_wrapper
+        def vtoken_length_hook(module, args, pruning_paras):
+            input_ids = args[0]
+            token_indices = torch.where(
+                input_ids[0] == pruning_paras['vision_token_index']
+            )[0]
+            pruning_paras['vision_token_length'] = token_indices.shape[0]
+
         if self.model.__class__.__name__ == 'Llava':
 
             self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
@@ -123,3 +168,15 @@ def wrapper(self, *args, **kwargs):
                     llava_next=self.special_config['vision_token_length'] is None
                 ), self.model.vlm_model
             )
+        elif self.model.__class__.__name__ == 'Qwen2_5VL':
+
+            self.model.embed_tokens.register_forward_pre_hook(
+                functools.partial(vtoken_length_hook, pruning_paras=self.pruning_paras)
+            )
+            self.model.language_model.register_forward_pre_hook(
+                functools.partial(
+                    prune_qwenv25vl_hook,
+                    pruning_paras=self.pruning_paras,
+                ),
+                with_kwargs=True
+            )
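Note on the DivPrune hook above: for Qwen2.5-VL the pruning is applied to inputs_embeds before the language model runs, keeping the pre-image prompt and the post-image text intact and subsampling only the visual span chosen by divprune. A minimal, self-contained sketch of the index construction (toy sizes; selected_visual stands in for whatever indices divprune would return):

import torch

# Hypothetical shapes: 5 system/prompt tokens, 10 visual tokens, 4 trailing text tokens.
sys_len, img_len, seq_len = 5, 10, 19
# Stand-in for the divprune selection (relative to the visual span), e.g. keep 4 of 10.
selected_visual = torch.tensor([0, 3, 7, 9]) + sys_len

keep_indexs = torch.cat((
    torch.arange(sys_len),                     # pre-image prompt
    selected_visual,                           # retained visual tokens
    torch.arange(sys_len + img_len, seq_len),  # text after the image
)).sort().values

print(keep_indexs)  # tensor([ 0,  1,  2,  3,  4,  5,  8, 12, 14, 15, 16, 17, 18])

Sorting the concatenated indices restores the original token order before they are used to slice inputs_embeds, position_ids, attention_mask, and cache_position in the hook.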

llmc/compression/token_reduction/random.py

Lines changed: 12 additions & 0 deletions
@@ -228,6 +228,14 @@ def store_attention_hook(m, x, layer_outputs, pruning_paras):
             layer_attention = layer_outputs[1]
             pruning_paras['attn_scores'] = layer_attention
 
+        @prefill_wrapper
+        def vtoken_length_hook(module, args, pruning_paras):
+            input_ids = args[0]
+            token_indices = torch.where(
+                input_ids[0] == pruning_paras['vision_token_index']
+            )[0]
+            pruning_paras['vision_token_length'] = token_indices.shape[0]
+
         if self.special_config['vision_token_length'] is None:
             if self.model.__class__.__name__ == 'Llava':
                 self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
@@ -236,6 +244,10 @@ def store_attention_hook(m, x, layer_outputs, pruning_paras):
                         self.pruning_paras
                     ), self.model.vlm_model
                 )
+            elif self.model.__class__.__name__ == 'Qwen2_5VL':
+                self.model.embed_tokens.register_forward_pre_hook(
+                    functools.partial(vtoken_length_hook, pruning_paras=self.pruning_paras)
+                )
 
         if self.special_config['metric'] == 'random':
             self.blocks[self.pruning_loc].register_forward_pre_hook(
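For reference, the vtoken_length_hook registered on embed_tokens for Qwen2_5VL derives vision_token_length by counting prompt positions equal to the image placeholder id. A toy illustration with made-up ids (the placeholder value 151655 is an assumption for illustration, not taken from this diff):

import torch

pruning_paras = {'vision_token_index': 151655}  # assumed image-pad token id, illustrative
input_ids = torch.tensor([[101, 9, 151655, 151655, 151655, 42, 7]])

token_indices = torch.where(input_ids[0] == pruning_paras['vision_token_index'])[0]
pruning_paras['vision_token_length'] = token_indices.shape[0]  # 3 vision tokens
# The sparsevlm variant below additionally records where they start:
# pruning_paras['pre_prompt_length_list'] = [token_indices[0].item()]  # -> [2]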

llmc/compression/token_reduction/sparsevlm.py

Lines changed: 100 additions & 3 deletions
@@ -17,6 +17,7 @@
 sparse_token_list_192 = []
 sparse_token_list_128 = []
 sparse_token_list_64 = []
+sparse_token_list_960 = []
 sparse_token_list_640 = []
 sparse_token_list_320 = []
 sparse_token_list_160 = []
@@ -329,8 +330,9 @@ def decoder_attn_hook(module, inputs, kwargs, layer_outputs, pruning_paras, laye
 
         if attention_mask is not None:
             attention_mask = attention_mask[:, :, keep_indexs, keep_indexs]
-        new_pe0 = position_embeddings[0][:, keep_indexs, :].clone()
-        new_pe1 = position_embeddings[1][:, keep_indexs, :].clone()
+        index_dim = 1 if position_embeddings[0].dim() == 3 else 2
+        new_pe0 = position_embeddings[0].index_select(index_dim, keep_indexs).clone()
+        new_pe1 = position_embeddings[1].index_select(index_dim, keep_indexs).clone()
         position_embeddings = (new_pe0, new_pe1)
 
         pruning_paras['v_token_num'] = v_token_num
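The switch from a fixed [:, keep_indexs, :] slice to index_select on a rank-dependent axis presumably accommodates Qwen2.5-VL's mRoPE, whose cos/sin tensors carry an extra leading dimension compared with the 3-D tensors seen for LLaVA-style models. A shape-only sketch of the idea (the exact shapes are assumptions, not taken from this diff):

import torch

keep_indexs = torch.tensor([0, 1, 4, 5])

pe_llava = torch.randn(1, 8, 128)    # assumed (bsz, seq, dim): 3-D, so sequence axis is 1
pe_qwen = torch.randn(3, 1, 8, 128)  # assumed (mrope sections, bsz, seq, dim): 4-D, axis 2

for pe in (pe_llava, pe_qwen):
    index_dim = 1 if pe.dim() == 3 else 2
    print(pe.index_select(index_dim, keep_indexs).shape)
# torch.Size([1, 4, 128]) and torch.Size([3, 1, 4, 128])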
@@ -352,6 +354,75 @@ def read_parameter_hook(module, args, kwargs, pruning_paras):
 
             return args, kwargs
 
+        @prefill_wrapper
+        def vtoken_length_hook(module, args, pruning_paras):
+            input_ids = args[0]
+            token_indices = torch.where(
+                input_ids[0] == pruning_paras['vision_token_index']
+            )[0]
+            pruning_paras['vision_token_length'] = token_indices.shape[0]
+            pruning_paras['pre_prompt_length_list'] = [token_indices[0].item()]
+
+        def get_attn_logits_for_qwen25vl(
+            module,
+            args, kwargs, layer_outs,
+            pruning_paras, layer_idx
+        ):
+            if kwargs['position_ids'].shape[-1] == 1:
+                return layer_outs
+
+            from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+                apply_multimodal_rotary_pos_emb, repeat_kv)
+
+            hidden_states = kwargs['hidden_states']
+            position_embeddings = kwargs['position_embeddings']
+            past_key_value = layer_outs[2]
+            attention_mask = kwargs['attention_mask']
+
+            t_token_idx = pruning_paras['t_token_idx']
+            v_token_start = pruning_paras['v_token_start']
+            v_token_num = pruning_paras['v_token_num']
+
+            bsz, q_len, _ = hidden_states.size()
+
+            query_states = module.q_proj(hidden_states)
+            key_states = module.k_proj(hidden_states)
+            value_states = module.v_proj(hidden_states)
+
+            query_states = query_states.view(bsz, q_len, -1, module.head_dim).transpose(1, 2)
+            key_states = key_states.view(bsz, q_len, -1, module.head_dim).transpose(1, 2)
+            value_states = value_states.view(bsz, q_len, -1, module.head_dim).transpose(1, 2)
+
+            cos, sin = position_embeddings
+            query_states, key_states = apply_multimodal_rotary_pos_emb(
+                query_states, key_states, cos, sin, module.rope_scaling['mrope_section']
+            )
+
+            if past_key_value is not None:
+                key_states = past_key_value.key_cache[layer_idx]
+                value_states = past_key_value.value_cache[layer_idx]
+
+            key_states = repeat_kv(key_states, module.num_key_value_groups)
+            value_states = repeat_kv(value_states, module.num_key_value_groups)
+
+            t_token_idx = t_token_idx[1] + v_token_start + v_token_num
+            L, S = query_states.size(-2), key_states.size(-2)
+            scale_factor = 1 / math.sqrt(query_states.size(-1))
+            attn_bias = torch.zeros(L, S, dtype=query_states.dtype)
+            if module.is_causal:
+                assert attention_mask is None
+                temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
+                attn_bias.masked_fill_(temp_mask.logical_not(), float('-inf'))
+                attn_bias.to(query_states.dtype)
+
+            attn_logits = query_states @ key_states.transpose(2, 3) * scale_factor
+            attn_logits += attn_bias.to(query_states.device)
+            attn_logits = torch.softmax(attn_logits, dim=-1)
+
+            pruning_paras['attn_logits'] = attn_logits
+
+            return layer_outs
+
         if self.model.__class__.__name__ == 'LlavaHf':
             self.model.embed_tokens.register_forward_pre_hook(
                 functools.partial(input_hook, pruning_paras=self.pruning_paras)
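The get_attn_logits_for_qwen25vl hook above recomputes softmax(QK^T / sqrt(d)) with a causal bias from the layer's own q/k projections, presumably because the default SDPA/flash-attention path does not expose attention weights. A standalone sketch of that core computation (random tensors, one head, RoPE and the KV cache omitted):

import math
import torch

q = torch.randn(1, 1, 6, 8)  # (batch, heads, query_len, head_dim)
k = torch.randn(1, 1, 6, 8)

L, S = q.size(-2), k.size(-2)
scale = 1 / math.sqrt(q.size(-1))
bias = torch.zeros(L, S)
causal = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
bias.masked_fill_(causal.logical_not(), float('-inf'))  # mask future positions

attn = torch.softmax(q @ k.transpose(2, 3) * scale + bias, dim=-1)  # rows sum to 1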
@@ -364,11 +435,17 @@ def read_parameter_hook(module, args, kwargs, pruning_paras):
                     llava_next=self.special_config['vision_token_length'] is None
                 ), self.model.vlm_model
             )
+        elif self.model.__class__.__name__ == 'Qwen2_5VL':
+            self.model.embed_tokens.register_forward_pre_hook(
+                functools.partial(vtoken_length_hook, pruning_paras=self.pruning_paras)
+            )
 
         if self.model.__class__.__name__ == 'LlavaHf':
             llama_model = self.model.model
         elif self.model.__class__.__name__ == 'Llava':
             llama_model = self.model.model.model
+        elif self.model.__class__.__name__ == 'Qwen2_5VL':
+            llama_model = self.model.language_model
         llama_model.register_forward_pre_hook(
             functools.partial(register_module_paras, pruning_paras=self.pruning_paras),
             with_kwargs=True
@@ -405,6 +482,23 @@ def read_parameter_hook(module, args, kwargs, pruning_paras):
                 ),
                 with_kwargs=True
             )
+        elif self.model.__class__.__name__ == 'Qwen2_5VL':
+            self.blocks[block_idx].register_forward_pre_hook(
+                functools.partial(
+                    update_kwargs_hook,
+                    pruning_paras=self.pruning_paras,
+                    layer_idx=block_idx,
+                ),
+                with_kwargs=True
+            )
+            self.blocks[block_idx].self_attn.register_forward_hook(
+                functools.partial(
+                    get_attn_logits_for_qwen25vl,
+                    pruning_paras=self.pruning_paras,
+                    layer_idx=block_idx,
+                ),
+                with_kwargs=True
+            )
         self.blocks[block_idx].register_forward_hook(
             functools.partial(
                 decoder_attn_hook,
@@ -425,7 +519,7 @@ def read_parameter_hook(module, args, kwargs, pruning_paras):
 
 def update_list():
     global sparse_token_list_192, sparse_token_list_128, sparse_token_list_64
-    global sparse_token_list_640, sparse_token_list_320, sparse_token_list_160
+    global sparse_token_list_960, sparse_token_list_640, sparse_token_list_320, sparse_token_list_160 # noqa
     global prune_flag, merge_flag, sparse_token_dict
 
     if layer_dict == {2: 0, 6: 1, 15: 2}: # 2*576 4*300 10*200 16*110
@@ -437,13 +531,15 @@ def update_list():
         sparse_token_list_192 = [180]
         sparse_token_list_128 = [114]
         sparse_token_list_64 = [48]
+        sparse_token_list_960 = [0.3125]
         sparse_token_list_640 = [0.1979]
         sparse_token_list_320 = [0.0833]
         sparse_token_list_160 = [0.0261]
     elif prune_flag:
         sparse_token_list_192 = [192]
         sparse_token_list_128 = [128]
         sparse_token_list_64 = [64]
+        sparse_token_list_960 = [0.3333]
         sparse_token_list_640 = [0.2222]
         sparse_token_list_320 = [0.1111]
         sparse_token_list_160 = [0.0555]
@@ -460,6 +556,7 @@ def update_list():
         192: sparse_token_list_192,
         128: sparse_token_list_128,
         64: sparse_token_list_64,
+        960: sparse_token_list_960,
         640: sparse_token_list_640,
         320: sparse_token_list_320,
         160: sparse_token_list_160
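The new sparse_token_list_960 entries extend the existing table keyed by vision-token budget; like the 640/320/160 entries they appear to hold retention ratios rather than absolute counts (the 192/128/64 entries are counts). A toy lookup showing how such a table could be interpreted; retained_tokens is a hypothetical helper, not part of this commit:

sparse_token_dict = {
    192: [180], 128: [114], 64: [48],  # absolute counts (LLaVA-style budgets)
    960: [0.3125], 640: [0.1979], 320: [0.0833], 160: [0.0261],  # ratios
}

def retained_tokens(budget, vision_token_num, stage=0):
    # Hypothetical helper: treat ints as counts and floats as ratios of the visual span.
    v = sparse_token_dict[budget][stage]
    return v if isinstance(v, int) else int(v * vision_token_num)

print(retained_tokens(960, 1280))  # 400 under these assumptions

Under these assumptions, a 1280-token image at the 960 budget would keep 0.3125 × 1280 = 400 visual tokens in the merge-and-prune configuration.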
