
Commit 5f67723

jpablomch and Yuan0320 committed
Update MultiPruner Patch
Signed-off-by: J. Pablo Muñoz <[email protected]> Co-authored-by: Yuan0320 <[email protected]>
1 parent c713ed7 commit 5f67723

File tree: 2 files changed (+40, -23 lines)

MultiPruner/patches/transformers-v4.45.0.patch

Lines changed: 37 additions & 17 deletions
@@ -25,10 +25,10 @@ index d41bc99ee..f74ee777f 100644
 return self.key_cache[layer_idx], self.value_cache[layer_idx]

 diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
-index 3e3d78908..fb17c952d 100755
+index 3e3d78908..4915644b6 100755
 --- a/src/transformers/modeling_utils.py
 +++ b/src/transformers/modeling_utils.py
-@@ -4024,6 +4024,40 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
+@@ -4024,6 +4024,45 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
 gguf_path=gguf_path,
 )

@@ -55,21 +55,26 @@ index 3e3d78908..fb17c952d 100755
 + layer = get_layer_by_key(model, key)
 + if ".self_attn" in key:
 + layer.mask_attn = True
++ layer.input_layernorm = None
++ layer.self_attn = None
 + logger.warning(
 + f"Some weights of MHA module in {layer.__class__.__name__} were not initialized from the model checkpoint at"
 + f" {pretrained_model_name_or_path} and the corresponding MHA module is pruned: {key}"
 + )
 + elif ".mlp" in key:
 + layer.mask_mlp = True
++ layer.post_attention_layernorm = None
++ layer.mlp = None
 + logger.warning(
 + f"Some weights of MLP module in {layer.__class__.__name__} were not initialized from the model checkpoint at"
 + f" {pretrained_model_name_or_path} and the corresponding MLP module is pruned: {key}"
 + )
++ torch.cuda.empty_cache()
 +
 # make sure token embedding weights are still tied if needed
 model.tie_weights()

-@@ -4403,6 +4437,33 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
+@@ -4403,6 +4442,33 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
 }
 else:
 offload_index = None
@@ -103,15 +108,15 @@ index 3e3d78908..fb17c952d 100755

 if state_dict is not None:
 # Whole checkpoint
-@@ -4414,6 +4475,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
+@@ -4414,6 +4480,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
 remove_prefix_from_model,
 ignore_mismatched_sizes,
 )
 + module_reshape(state_dict)

 # For GGUF models `state_dict` is never set to None as the state dict is always small
 if gguf_path:
-@@ -4485,6 +4547,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
+@@ -4485,6 +4552,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
 remove_prefix_from_model,
 ignore_mismatched_sizes,
 )
@@ -120,7 +125,7 @@ index 3e3d78908..fb17c952d 100755
 if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
 for key, param in model_to_load.state_dict().items():
 diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
-index 73b6bcd8b..2128f4148 100644
+index 73b6bcd8b..5d68e6c85 100644
 --- a/src/transformers/models/llama/modeling_llama.py
 +++ b/src/transformers/models/llama/modeling_llama.py
 @@ -393,9 +393,9 @@ class LlamaAttention(nn.Module):
@@ -174,7 +179,14 @@ index 73b6bcd8b..2128f4148 100644

 if position_embeddings is None:
 logger.warning_once(
-@@ -686,6 +686,8 @@ class LlamaDecoderLayer(nn.Module):
+@@ -680,12 +680,15 @@ class LlamaDecoderLayer(nn.Module):
+ def __init__(self, config: LlamaConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
++ self.layer_idx = layer_idx
+
+ self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
+
 self.mlp = LlamaMLP(config)
 self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -183,7 +195,7 @@ index 73b6bcd8b..2128f4148 100644

 def forward(
 self,
-@@ -721,29 +723,32 @@ class LlamaDecoderLayer(nn.Module):
+@@ -721,29 +724,32 @@ class LlamaDecoderLayer(nn.Module):
 Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
 into the model
 """
@@ -238,16 +250,16 @@ index 73b6bcd8b..2128f4148 100644

 outputs = (hidden_states,)

-@@ -751,6 +756,8 @@ class LlamaDecoderLayer(nn.Module):
+@@ -751,6 +757,8 @@ class LlamaDecoderLayer(nn.Module):
 outputs += (self_attn_weights,)

 if use_cache:
 + if self.mask_attn:
-+ past_key_value.update(None, None, self.self_attn.layer_idx)
++ past_key_value.update(None, None, self.layer_idx)
 outputs += (present_key_value,)

 return outputs
-@@ -1023,7 +1030,7 @@ class LlamaModel(LlamaPreTrainedModel):
+@@ -1023,7 +1031,7 @@ class LlamaModel(LlamaPreTrainedModel):
 all_hidden_states += (hidden_states,)

 next_cache = next_decoder_cache if use_cache else None
@@ -257,7 +269,7 @@ index 73b6bcd8b..2128f4148 100644

 if not return_dict:
 diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py
-index 10c0b6f38..763c5d813 100644
+index 10c0b6f38..aafb914d0 100644
 --- a/src/transformers/models/qwen2/modeling_qwen2.py
 +++ b/src/transformers/models/qwen2/modeling_qwen2.py
 @@ -345,9 +345,9 @@ class Qwen2Attention(nn.Module):
@@ -320,7 +332,15 @@ index 10c0b6f38..763c5d813 100644

 attn_output = self.o_proj(attn_output)

-@@ -660,6 +660,9 @@ class Qwen2DecoderLayer(nn.Module):
+@@ -648,6 +648,7 @@ class Qwen2DecoderLayer(nn.Module):
+ def __init__(self, config: Qwen2Config, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
++ self.layer_idx = layer_idx
+
+ if config.sliding_window and config._attn_implementation != "flash_attention_2":
+ logger.warning_once(
+@@ -660,6 +661,9 @@ class Qwen2DecoderLayer(nn.Module):
 self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

@@ -330,7 +350,7 @@ index 10c0b6f38..763c5d813 100644
 def forward(
 self,
 hidden_states: torch.Tensor,
-@@ -694,28 +697,31 @@ class Qwen2DecoderLayer(nn.Module):
+@@ -694,28 +698,31 @@ class Qwen2DecoderLayer(nn.Module):
 into the model
 """

@@ -381,16 +401,16 @@ index 10c0b6f38..763c5d813 100644

 outputs = (hidden_states,)

-@@ -723,6 +729,8 @@ class Qwen2DecoderLayer(nn.Module):
+@@ -723,6 +730,8 @@ class Qwen2DecoderLayer(nn.Module):
 outputs += (self_attn_weights,)

 if use_cache:
 + if self.mask_attn:
-+ past_key_value.update(None, None, self.self_attn.layer_idx)
++ past_key_value.update(None, None, self.layer_idx)
 outputs += (present_key_value,)

 return outputs
-@@ -999,7 +1007,7 @@ class Qwen2Model(Qwen2PreTrainedModel):
+@@ -999,7 +1008,7 @@ class Qwen2Model(Qwen2PreTrainedModel):
 all_hidden_states += (hidden_states,)

 next_cache = next_decoder_cache if use_cache else None
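A side note on the `layer_idx` change visible in the hunks above: the patch now stores `layer_idx` on the decoder layer itself and calls `past_key_value.update(None, None, self.layer_idx)`, so the cache bookkeeping no longer reaches through `self.self_attn`, which the patched `from_pretrained` sets to `None` when an attention block is pruned. The snippet below is a minimal, self-contained sketch of that mask-and-skip pattern; `ToyCache` and `ToyDecoderLayer` are invented stand-ins for illustration, not the patched transformers classes, and the cache behavior shown is an assumption.

```python
# Minimal sketch (assumed behavior, not the patched transformers code) of why
# layer_idx must live on the layer: the attention module may have been pruned away.
import torch
import torch.nn as nn


class ToyCache:
    """Per-layer KV store that tolerates pruned layers contributing nothing."""

    def __init__(self):
        self.key_cache, self.value_cache = [], []

    def update(self, key_states, value_states, layer_idx):
        # Grow the lists so layer_idx always indexes a valid slot.
        while len(self.key_cache) <= layer_idx:
            self.key_cache.append(None)
            self.value_cache.append(None)
        if key_states is not None:
            self.key_cache[layer_idx] = key_states
            self.value_cache[layer_idx] = value_states
        return self.key_cache[layer_idx], self.value_cache[layer_idx]


class ToyDecoderLayer(nn.Module):
    def __init__(self, hidden_size, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx   # stored on the layer, as in the updated patch
        self.mask_attn = False       # set True when the attention block is pruned
        self.self_attn = nn.Linear(hidden_size, hidden_size)  # stand-in for MHA

    def forward(self, hidden_states, past_key_value=None, use_cache=True):
        if self.mask_attn:
            # self_attn may be None here, so self.self_attn.layer_idx would fail;
            # self.layer_idx still lets us register an (empty) cache entry.
            if use_cache and past_key_value is not None:
                past_key_value.update(None, None, self.layer_idx)
            return hidden_states
        hidden_states = self.self_attn(hidden_states)
        if use_cache and past_key_value is not None:
            past_key_value.update(hidden_states, hidden_states, self.layer_idx)
        return hidden_states


layer = ToyDecoderLayer(hidden_size=8, layer_idx=0)
layer.mask_attn, layer.self_attn = True, None  # mimic a pruned attention block
_ = layer(torch.randn(1, 8), past_key_value=ToyCache())
```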

MultiPruner/results/README.md

Lines changed: 3 additions & 6 deletions
@@ -164,12 +164,11 @@ python run_multipruner.py \

 #### Baichuan2-7B-Base

-To enable pruning of Query, Key, and Value, we have deconstructed the linear module `W_pack` (which combines QKV into a single linear layer) in [Baichuan2-7B-Base](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base).
-The deconstructed model: [IntelLabs/Baichuan2-7B-Base-split_qkv](https://huggingface.co/IntelLabs/Baichuan2-7B-Base-split_qkv).
+To enable pruning of Query, Key, and Value, we have deconstructed the linear module `W_pack` (which combines QKV into a single linear layer) in [Baichuan2-7B-Base](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base).

 ```bash
 python run_multipruner.py \
-    --model_path IntelLabs/Baichuan2-7B-Base-split_qkv \
+    --model_path <path to processed baichuan model> \
     --output_path <path to pruning results> \
     --weight_reorder \
     --do_prune \
@@ -186,11 +185,9 @@ python run_multipruner.py \

 #### Baichuan2-13B-Base

-Similar to Baichuan2-7B-Base, the deconstructed model: [IntelLabs/Baichuan2-13B-Base-split_qkv](https://huggingface.co/IntelLabs/Baichuan2-13B-Base-split_qkv).
-
 ```bash
 python run_multipruner.py \
-    --model_path IntelLabs/Baichuan2-13B-Base-split_qkv \
+    --model_path <path to processed baichuan model> \
     --output_path <path to pruning results> \
     --do_prune \
     --target_ratio <pruning ratio> \
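The Baichuan2 sections above mention a checkpoint whose fused `W_pack` projection (QKV in one linear layer) has been deconstructed into separate Q, K, and V projections. Below is a rough, hypothetical sketch of that kind of split; `split_w_pack` is an invented helper, and the assumption that `W_pack` stacks the Q, K, and V output rows in that order is mine, not something the diff states.

```python
# Hypothetical sketch: split a fused W_pack-style projection into q/k/v linears.
# Assumes the fused weight stacks Q, K, V along the output dimension.
import torch
import torch.nn as nn


def split_w_pack(w_pack: nn.Linear, hidden_size: int):
    has_bias = w_pack.bias is not None
    q_proj = nn.Linear(hidden_size, hidden_size, bias=has_bias)
    k_proj = nn.Linear(hidden_size, hidden_size, bias=has_bias)
    v_proj = nn.Linear(hidden_size, hidden_size, bias=has_bias)
    q_w, k_w, v_w = w_pack.weight.chunk(3, dim=0)  # [3*hidden, hidden] -> 3 x [hidden, hidden]
    with torch.no_grad():
        q_proj.weight.copy_(q_w)
        k_proj.weight.copy_(k_w)
        v_proj.weight.copy_(v_w)
        if has_bias:
            q_b, k_b, v_b = w_pack.bias.chunk(3, dim=0)
            q_proj.bias.copy_(q_b)
            k_proj.bias.copy_(k_b)
            v_proj.bias.copy_(v_b)
    return q_proj, k_proj, v_proj


hidden = 16
fused = nn.Linear(hidden, 3 * hidden, bias=False)  # stand-in for a W_pack module
q, k, v = split_w_pack(fused, hidden)
x = torch.randn(2, hidden)
assert torch.allclose(fused(x).chunk(3, dim=-1)[0], q(x))  # Q slice matches
```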

0 commit comments
