
Commit fb8f3c1

fix vispruner bugs and update holitom_merge (#431)
* fix vispruner bugs and update holitom_merge
* set lmms_eval==0.3.0 temporarily
1 parent 73c131d commit fb8f3c1

File tree

10 files changed (+179, -99 lines)


llmc/compression/token_reduction/dycoke.py

Lines changed: 2 additions & 2 deletions
@@ -7,8 +7,8 @@
 
 try:
     from llava.model.llava_arch import LlavaMetaForCausalLM
-except ModuleNotFoundError:
-    logger.info('LlavaMetaForCausalLM not found, if need, please install llava first.')
+except ImportError:
+    pass
 from transformers.cache_utils import Cache, DynamicCache
 
 from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY

llmc/compression/token_reduction/fastvid.py

Lines changed: 2 additions & 2 deletions
@@ -13,8 +13,8 @@
     from llava.model.multimodal_encoder.siglip_encoder import (
         SigLipVisionConfig, SigLipVisionModel)
     from llava.utils import rank0_print
-except ModuleNotFoundError:
-    logger.info('LlavaMetaForCausalLM not found, if need, please install llava first.')
+except ImportError:
+    pass
 
 from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY
 

llmc/compression/token_reduction/holitom.py

Lines changed: 2 additions & 2 deletions
@@ -20,8 +20,8 @@
     from llava.utils import rank0_print
     from transformers.modeling_outputs import (BaseModelOutput,
                                                 BaseModelOutputWithPooling)
-except ModuleNotFoundError:
-    logger.info('LlavaMetaForCausalLM not found, if need, please install llava first.')
+except ImportError:
+    pass
 
 from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY
 

llmc/compression/token_reduction/prunevid.py

Lines changed: 2 additions & 2 deletions
@@ -8,8 +8,8 @@
 
 try:
     from llava.model.llava_arch import LlavaMetaForCausalLM
-except ModuleNotFoundError:
-    logger.info('LlavaMetaForCausalLM not found, if need, please install llava first.')
+except ImportError:
+    pass
 
 from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY
 
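The same guard change is applied in dycoke.py, fastvid.py, holitom.py, prunevid.py (and utils.py below): catching ImportError rather than only ModuleNotFoundError also covers the case where llava is installed but one of its own transitive imports fails, and silently passing keeps llava an optional dependency. A minimal sketch of the pattern; the None fallback is illustrative and not part of the diff:

# Optional-dependency guard: the symbol is only needed for LLaVA models, so any
# import failure (missing package or broken transitive import) is ignored.
try:
    from llava.model.llava_arch import LlavaMetaForCausalLM
except ImportError:  # broader than ModuleNotFoundError
    LlavaMetaForCausalLM = None  # hypothetical fallback, not in the diff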

Lines changed: 162 additions & 60 deletions
@@ -1,9 +1,7 @@
 import functools
-from functools import wraps
 from types import MethodType
 
 import torch
-from loguru import logger
 
 from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY
 
@@ -21,61 +19,21 @@ def __init__(self, config, model, blocks):
     def add_sparse_config(self):
 
         self.pruning_loc = self.special_config['pruning_loc']
-        self.special_config['image_token_length'] = self.model.pruning_config[
-            'image_token_length'
-        ]
-
         self.pruning_paras = self.special_config
 
     def register_reduction_modules(self):
 
-        def input_hook_llava(fn, pruning_paras):
-            @wraps(fn)
-            def wrapper(self, *args, **kwargs):
-                if len(args) == 0:
-                    return fn(*args, **kwargs)
-                input_args = args[0]
-                if hasattr(input_args[0], 'shape') and input_args[0].shape[0] == 1:
-                    return fn(*args, **kwargs)
-
-                input_ids = args[0]
-                attention_mask = args[2]
-                token_indices = input_ids[0][attention_mask[0]] == IMAGE_TOKEN_INDEX
-                pruning_paras['image_token_start_index'] = torch.where(token_indices)[
-                    0
-                ][0].item()
-
-                outputs = fn(*args, **kwargs)
-                return outputs
-
-            return wrapper
-
-        @prefill_wrapper
-        def input_hook(module, input_args, pruning_paras):
-            input_ids = input_args[0]
-            image_token_idxs = (
-                input_ids[0] == pruning_paras['vision_token_index']
-            ).nonzero(as_tuple=True)[0]
-            pruning_paras['image_token_start_index'] = image_token_idxs[0].item()
-
-            return input_args
-
         @prefill_wrapper
         def random_pruning_hook(module, args, kwargs, pruning_paras):
 
-            logger.info(' ========random_pruning_hook======== ')
-
-            rate = pruning_paras['rate']
-            image_token_start_index = pruning_paras['image_token_start_index']
-            image_token_length = pruning_paras['image_token_length']
+            rate = pruning_paras['prune_ratio']
+            image_token_start_index = pruning_paras['vision_token_start_index']
+            image_token_length = pruning_paras['vision_token_length']
 
             hidden_states = args[0]
             causal_mask = kwargs['attention_mask']
 
-            logger.info(f'before hidden_states : {hidden_states.shape}')
-
             device = hidden_states.device
-
             vision_indexes = torch.arange(
                 image_token_start_index,
                 image_token_start_index + image_token_length,
@@ -130,25 +88,169 @@ def random_pruning_hook(module, args, kwargs, pruning_paras):
             position_embeddings[0].resize_as_(new_pe0).copy_(new_pe0)
             position_embeddings[1].resize_as_(new_pe0).copy_(new_pe1)
 
-            logger.info(f'after hidden_states : {hidden_states.shape}')
             return (hidden_states,), kwargs
 
-        if self.model.__class__.__name__ == 'LlavaHf':
-            self.model.embed_tokens.register_forward_pre_hook(
-                functools.partial(input_hook, pruning_paras=self.pruning_paras)
+        @prefill_wrapper
+        def holitom_merge_hook(module, args, kwargs, pruning_paras):
+
+            rate = pruning_paras['prune_ratio']
+            image_token_start_index = pruning_paras['vision_token_start_index']
+            image_token_length = pruning_paras['vision_token_length']
+
+            hidden_states = args[0]
+            causal_mask = kwargs['attention_mask']
+
+            device = hidden_states.device
+            last_layer_attention = pruning_paras['attn_scores']
+            # compute average attention over different head
+            last_layer_attention_avg = torch.mean(
+                last_layer_attention, dim=1
+            )[0]
+            # generate new attention mask based on the average attention,
+            # sample the top ATTENTION_RANK tokens with highest attention
+            last_layer_attention_avg_last_tok = (
+                last_layer_attention_avg[-1]
+            )
+            # get the attention in image token
+            last_layer_attention_avg_last_tok_image = \
+                last_layer_attention_avg_last_tok[
+                    image_token_start_index:
+                    image_token_start_index + image_token_length
+                ]
+            # get the indexes of the top ATTENTION_RANK tokens
+            top_attention_rank_index = (
+                last_layer_attention_avg_last_tok_image.topk(
+                    round(
+                        image_token_length * (1 - rate)
+                    )
+                ).indices
+                + image_token_start_index
             )
-        elif self.model.__class__.__name__ == 'Llava':
-            from llava.constants import IMAGE_TOKEN_INDEX
 
-            hook_fn = input_hook_llava(
-                self.model.vlm_model.prepare_inputs_labels_for_multimodal,
-                self.pruning_paras,
+            all_indices = torch.arange(
+                image_token_length, device=device
             )
-            self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
-                hook_fn, self.model.vlm_model
+            non_topk_mask = ~torch.isin(
+                all_indices,
+                top_attention_rank_index
+                - image_token_start_index,
+            )
+            non_topk_indices = (
+                all_indices[non_topk_mask]
+                + image_token_start_index
+            )
+            non_topk_states = hidden_states[
+                :, non_topk_indices, :
+            ]  # [batch_size, len(non_topk), hidden_size]
+            topk_states = hidden_states[
+                :, top_attention_rank_index, :
+            ]  # [batch_size, len(topk), hidden_size]
+            non_topk_norm = torch.norm(
+                non_topk_states, dim=-1, keepdim=True
+            )  # [batch_size, len(non_topk), 1]
+            topk_norm = torch.norm(
+                topk_states, dim=-1, keepdim=True
+            )  # [batch_size, len(topk), 1]
+            dot_product = torch.bmm(
+                non_topk_states, topk_states.transpose(1, 2)
+            )  # [batch_size, len(non_topk), len(topk)]
+            sim_matrix = dot_product / (
+                non_topk_norm * topk_norm.transpose(1, 2)
+            )
+            sim_max, sim_max_index = torch.max(sim_matrix, dim=-1)
+
+            batch_size = hidden_states.size(0)
+            num_topk = len(top_attention_rank_index)
+            num_non_topk = len(non_topk_indices)
+            topk_counter = torch.ones((batch_size, num_topk, 1), device=hidden_states.device)
+
+            for b in range(batch_size):
+                for i in range(num_non_topk):
+                    topk_rel_idx = sim_max_index[b, i].item()  # relative index within the topk set
+                    topk_abs_idx = top_attention_rank_index[topk_rel_idx]  # absolute index in the sequence
+                    non_topk_abs_idx = non_topk_indices[i]
+
+                    # accumulate the non-topk token onto its topk token (in place)
+                    hidden_states[b, topk_abs_idx, :] += hidden_states[b, non_topk_abs_idx, :]
+                    # increment the merge counter
+                    topk_counter[b, topk_rel_idx] += 1
+
+            # average every topk token (itself plus everything merged into it)
+            for b in range(batch_size):
+                for i in range(num_topk):
+                    topk_abs_idx = top_attention_rank_index[i]
+                    hidden_states[b, topk_abs_idx, :] /= topk_counter[b, i]
+
+            keep_indexs = torch.cat(
+                (
+                    torch.arange(
+                        image_token_start_index,
+                        device=device,
+                    ),
+                    top_attention_rank_index,
+                    torch.arange(
+                        image_token_start_index
+                        + image_token_length,
+                        hidden_states.shape[1],
+                        device=device,
+                    ),
+                )
             )
 
-        self.blocks[self.pruning_loc].register_forward_pre_hook(
-            functools.partial(random_pruning_hook, pruning_paras=self.pruning_paras),
-            with_kwargs=True,
-        )
+            # sort index
+            keep_indexs = keep_indexs.sort().values
+            # filter hidden states &
+            hidden_states = hidden_states[:, keep_indexs, :]
+            # update position ids
+            position_ids = keep_indexs.unsqueeze(0)
+            # update attention mask
+            if causal_mask is not None:
+                causal_mask = causal_mask[:, :, :hidden_states.shape[1], :hidden_states.shape[1]]
+                kwargs['attention_mask'].resize_as_(causal_mask).copy_(causal_mask.clone())
+            kwargs['cache_position'].resize_as_(position_ids.squeeze(0)).copy_(
+                position_ids.squeeze(0).clone())
+            kwargs['position_ids'].resize_as_(position_ids).copy_(position_ids.clone())
+
+            position_embeddings = kwargs['position_embeddings']
+            index_dim = 1 if position_embeddings[0].dim() == 3 else 2
+            new_pe0 = position_embeddings[0].index_select(index_dim, keep_indexs).clone()
+            new_pe1 = position_embeddings[1].index_select(index_dim, keep_indexs).clone()
+            position_embeddings[0].resize_as_(new_pe0).copy_(new_pe0)
+            position_embeddings[1].resize_as_(new_pe0).copy_(new_pe1)
+
+            return (hidden_states,), kwargs
+
+        def update_output_attentions_hook(module, args, kwargs):
+            kwargs['output_attentions'] = True
+            return args, kwargs
+
+        def store_attention_hook(m, x, layer_outputs, pruning_paras):
+            layer_attention = layer_outputs[1]
+            pruning_paras['attn_scores'] = layer_attention
+
+        if self.special_config['vision_token_length'] is None:
+            if self.model.__class__.__name__ == 'Llava':
+                self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
+                    self.vtoken_length_for_llava_hook(
+                        self.model.vlm_model.prepare_inputs_labels_for_multimodal,
+                        self.pruning_paras
+                    ), self.model.vlm_model
+                )
+
+        if self.special_config['metric'] == 'random':
+            self.blocks[self.pruning_loc].register_forward_pre_hook(
+                functools.partial(random_pruning_hook, pruning_paras=self.pruning_paras),
+                with_kwargs=True
+            )
+        elif self.special_config['metric'] == 'holitom_merge':
+            self.blocks[self.pruning_loc - 1].register_forward_pre_hook(
+                update_output_attentions_hook,
+                with_kwargs=True
+            )
+            self.blocks[self.pruning_loc - 1].register_forward_hook(
+                functools.partial(store_attention_hook, pruning_paras=self.pruning_paras),
+            )
+            self.blocks[self.pruning_loc].register_forward_pre_hook(
+                functools.partial(holitom_merge_hook, pruning_paras=self.pruning_paras),
+                with_kwargs=True
+            )
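For orientation, here is a condensed, self-contained sketch of what holitom_merge_hook does to a single sequence: keep the vision tokens that receive the most attention from the last token, merge every remaining vision token into its most similar kept token, then average. The function name, toy shapes, and the use of F.cosine_similarity are illustrative; the real hook additionally rewrites the attention mask, cache positions, position ids, and rotary embeddings in place, which is omitted here.

import torch
import torch.nn.functional as F

def holitom_merge(hidden_states, attn, start, length, prune_ratio):
    # hidden_states: [1, seq_len, hidden]; attn: [1, heads, seq_len, seq_len]
    attn_last = attn.mean(dim=1)[0, -1]                    # attention from the last token
    scores = attn_last[start:start + length]               # restrict to vision tokens
    keep = scores.topk(round(length * (1 - prune_ratio))).indices + start

    all_idx = torch.arange(length, device=hidden_states.device) + start
    drop = all_idx[~torch.isin(all_idx, keep)]

    # cosine similarity between dropped and kept vision tokens
    kept, dropped = hidden_states[0, keep], hidden_states[0, drop]
    sim = F.cosine_similarity(dropped.unsqueeze(1), kept.unsqueeze(0), dim=-1)
    target = sim.argmax(dim=-1)                            # most similar kept token per dropped token

    counter = torch.ones(keep.numel(), 1, device=hidden_states.device)
    for i, t in enumerate(target):
        hidden_states[0, keep[t]] += dropped[i]            # merge dropped token into its target
        counter[t] += 1
    hidden_states[0, keep] /= counter                      # average each kept token

    keep_all = torch.cat((
        torch.arange(0, start, device=hidden_states.device),
        keep,
        torch.arange(start + length, hidden_states.shape[1], device=hidden_states.device),
    )).sort().values
    return hidden_states[:, keep_all, :]

The registration code above wires this up in three steps: a pre-hook on the block before pruning_loc forces output_attentions, a forward hook on the same block stashes the returned attention in pruning_paras['attn_scores'], and the pre-hook on the pruning block then performs the selection and merge.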

llmc/compression/token_reduction/tome.py

Lines changed: 2 additions & 7 deletions
@@ -3,8 +3,6 @@
 from typing import Callable, Optional, Tuple
 
 import torch
-import torch.nn.functional as F
-from loguru import logger
 from transformers.models.clip.modeling_clip import CLIPEncoderLayer
 
 from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY
@@ -20,8 +18,7 @@ def __init__(self, config, model, blocks):
         self.patch_layer()
 
     def add_sparse_config(self):
-        special_config = self.config.get('special', {})
-        r_param = special_config.get('r', 0)
+        r_param = self.special_config.get('r', 0)
         if isinstance(r_param, int) or isinstance(r_param, float):
            self.r = [max(int(r_param), 0)] * len(self.blocks)
         elif isinstance(r_param, (tuple, list)):
@@ -36,19 +33,17 @@ def add_sparse_config(self):
         else:
             raise ValueError('Invalid r format. Expected int or (start, step) tuple.')
 
-        self.pruning_paras = special_config
+        self.pruning_paras = self.special_config
 
     def patch_layer(self):
         for idx, block in enumerate(self.blocks):
             if self.r[idx] > 0:
                 block.r = self.r[idx]
                 if isinstance(block, CLIPEncoderLayer):  # llava
-                    block.self_attn.original_forward = block.self_attn.forward
                     block.self_attn.forward = types.MethodType(
                         tome_CLIPSdpaAttention_forward,
                         block.self_attn
                     )
-                    block.original_forward = block.forward
                     block.forward = types.MethodType(
                         tome_CLIPEncoderLayer_forward,
                         block
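A minimal sketch of how the per-block r value is expanded from the config read above. Only the int/float branch is visible in this diff; the (start, step) tuple branch is elided, so the linear schedule below is an assumption, not the repository's implementation:

# Hypothetical helper mirroring add_sparse_config's handling of 'r'.
def expand_r(r_param, num_blocks):
    if isinstance(r_param, (int, float)):
        return [max(int(r_param), 0)] * num_blocks          # same merge count per block
    if isinstance(r_param, (tuple, list)):
        start, step = r_param                                # assumed linear schedule
        return [max(int(start + step * i), 0) for i in range(num_blocks)]
    raise ValueError('Invalid r format. Expected int or (start, step) tuple.')

# e.g. expand_r(2, 4) -> [2, 2, 2, 2]; expand_r((8, -2), 4) -> [8, 6, 4, 2]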

llmc/compression/token_reduction/utils.py

Lines changed: 2 additions & 2 deletions
@@ -9,8 +9,8 @@
 
 try:
     from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX
-except Exception as e:
-    logger.debug('LLaVA is not installed. Please install LLaVA to use this model.\nError: %s' % e)
+except ImportError:
+    pass
 import random
 

llmc/compression/token_reduction/visionzip.py

Lines changed: 2 additions & 13 deletions
@@ -585,19 +585,8 @@ def prune_qwenv25vl_hook(module, args, kwargs, pruning_paras):
    st_idx = torch.nonzero(img_mask, as_tuple=True)[0]
 
    if st_idx.numel() > 0:
-        discontinuities = torch.where(st_idx[1:] - st_idx[:-1] != 1)[0]
-        if discontinuities.numel() > 0:
-            raise ValueError('Visual tokens are not contiguous in input_ids!')
-            segment_starts = [st_idx[0].item()] + [st_idx[i + 1].item() for i in discontinuities.tolist()]  # noqa
-            segment_ends = [st_idx[i].item() for i in discontinuities.tolist()] + [st_idx[-1].item()]  # noqa
-            offset = 0
-            for first, last in zip(segment_starts, segment_ends):
-                length = last - first + 1
-                # [15 1502] [1505 3289]
-                img_mask[first: last + 1] = ~select_mask[offset: offset + length]
-        else:
-            first, last = st_idx[0].item(), st_idx[-1].item()
-            img_mask[first: last + 1] = ~select_mask
+        first, last = st_idx[0].item(), st_idx[-1].item()
+        img_mask[first: last + 1] = ~select_mask
        img_mask = ~img_mask
        contextual_input_idx = false_pos[target_indices] + first
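The simplified branch assumes the visual tokens form a single contiguous run in input_ids. A minimal sketch of the masking step under that assumption; the toy tensors are illustrative and select_mask/img_mask mirror the names in the diff:

import torch

img_mask = torch.tensor([False, True, True, True, True, False])  # visual-token positions
select_mask = torch.tensor([True, False, True, False])            # visual tokens to keep

st_idx = torch.nonzero(img_mask, as_tuple=True)[0]
if st_idx.numel() > 0:
    first, last = st_idx[0].item(), st_idx[-1].item()
    img_mask[first: last + 1] = ~select_mask                      # mark dropped visual tokens
img_mask = ~img_mask                                              # True = token survives
print(img_mask)  # tensor([ True,  True, False,  True, False,  True])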
603592
