@@ -25,10 +25,10 @@ index d41bc99ee..f74ee777f 100644
return self.key_cache[layer_idx], self.value_cache[layer_idx]

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
- index 3e3d78908..fb17c952d 100755
+ index 3e3d78908..4915644b6 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
- @@ -4024,6 +4024,40 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
+ @@ -4024,6 +4024,45 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
gguf_path=gguf_path,
)

@@ -55,21 +55,26 @@ index 3e3d78908..fb17c952d 100755
+ layer = get_layer_by_key(model, key)
+ if ".self_attn" in key:
+ layer.mask_attn = True
+ + layer.input_layernorm = None
+ + layer.self_attn = None
+ logger.warning(
+ f"Some weights of MHA module in {layer.__class__.__name__} were not initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path} and the corresponding MHA module is pruned: {key}"
+ )
+ elif ".mlp" in key:
+ layer.mask_mlp = True
+ + layer.post_attention_layernorm = None
+ + layer.mlp = None
+ logger.warning(
+ f"Some weights of MLP module in {layer.__class__.__name__} were not initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path} and the corresponding MLP module is pruned: {key}"
+ )
+ + torch.cuda.empty_cache()
+
# make sure token embedding weights are still tied if needed
model.tie_weights()

- @@ -4403,6 +4437,33 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
+ @@ -4403,6 +4442,33 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
}
else:
offload_index = None
@@ -103,15 +108,15 @@ index 3e3d78908..fb17c952d 100755

if state_dict is not None:
# Whole checkpoint
- @@ -4414,6 +4475,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
+ @@ -4414,6 +4480,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
remove_prefix_from_model,
ignore_mismatched_sizes,
)
+ module_reshape(state_dict)

# For GGUF models `state_dict` is never set to None as the state dict is always small
if gguf_path:
- @@ -4485,6 +4547,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
+ @@ -4485,6 +4552,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
remove_prefix_from_model,
ignore_mismatched_sizes,
)
@@ -120,7 +125,7 @@ index 3e3d78908..fb17c952d 100755
if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
for key, param in model_to_load.state_dict().items():
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
- index 73b6bcd8b..2128f4148 100644
+ index 73b6bcd8b..5d68e6c85 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -393,9 +393,9 @@ class LlamaAttention(nn.Module):
@@ -174,7 +179,14 @@ index 73b6bcd8b..2128f4148 100644

if position_embeddings is None:
logger.warning_once(
- @@ -686,6 +686,8 @@ class LlamaDecoderLayer(nn.Module):
+ @@ -680,12 +680,15 @@ class LlamaDecoderLayer(nn.Module):
+ def __init__(self, config: LlamaConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ + self.layer_idx = layer_idx
+
+ self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
+
self.mlp = LlamaMLP(config)
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -183,7 +195,7 @@ index 73b6bcd8b..2128f4148 100644

def forward(
self,
- @@ -721,29 +723,32 @@ class LlamaDecoderLayer(nn.Module):
+ @@ -721,29 +724,32 @@ class LlamaDecoderLayer(nn.Module):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
"""
@@ -238,16 +250,16 @@ index 73b6bcd8b..2128f4148 100644

outputs = (hidden_states,)

- @@ -751,6 +756,8 @@ class LlamaDecoderLayer(nn.Module):
+ @@ -751,6 +757,8 @@ class LlamaDecoderLayer(nn.Module):
outputs += (self_attn_weights,)

if use_cache:
+ if self.mask_attn:
- + past_key_value.update(None, None, self.self_attn.layer_idx)
+ + past_key_value.update(None, None, self.layer_idx)
outputs += (present_key_value,)

return outputs
- @@ -1023,7 +1030,7 @@ class LlamaModel(LlamaPreTrainedModel):
+ @@ -1023,7 +1031,7 @@ class LlamaModel(LlamaPreTrainedModel):
all_hidden_states += (hidden_states,)

next_cache = next_decoder_cache if use_cache else None
@@ -257,7 +269,7 @@ index 73b6bcd8b..2128f4148 100644

if not return_dict:
diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py
- index 10c0b6f38..763c5d813 100644
+ index 10c0b6f38..aafb914d0 100644
--- a/src/transformers/models/qwen2/modeling_qwen2.py
+++ b/src/transformers/models/qwen2/modeling_qwen2.py
@@ -345,9 +345,9 @@ class Qwen2Attention(nn.Module):
@@ -320,7 +332,15 @@ index 10c0b6f38..763c5d813 100644

attn_output = self.o_proj(attn_output)

- @@ -660,6 +660,9 @@ class Qwen2DecoderLayer(nn.Module):
+ @@ -648,6 +648,7 @@ class Qwen2DecoderLayer(nn.Module):
+ def __init__(self, config: Qwen2Config, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ + self.layer_idx = layer_idx
+
+ if config.sliding_window and config._attn_implementation != "flash_attention_2":
+ logger.warning_once(
+ @@ -660,6 +661,9 @@ class Qwen2DecoderLayer(nn.Module):
self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

@@ -330,7 +350,7 @@ index 10c0b6f38..763c5d813 100644
def forward(
self,
hidden_states: torch.Tensor,
- @@ -694,28 +697,31 @@ class Qwen2DecoderLayer(nn.Module):
+ @@ -694,28 +698,31 @@ class Qwen2DecoderLayer(nn.Module):
into the model
"""

@@ -381,16 +401,16 @@ index 10c0b6f38..763c5d813 100644

outputs = (hidden_states,)

- @@ -723,6 +729,8 @@ class Qwen2DecoderLayer(nn.Module):
+ @@ -723,6 +730,8 @@ class Qwen2DecoderLayer(nn.Module):
outputs += (self_attn_weights,)

if use_cache:
+ if self.mask_attn:
- + past_key_value.update(None, None, self.self_attn.layer_idx)
+ + past_key_value.update(None, None, self.layer_idx)
outputs += (present_key_value,)

return outputs
- @@ -999,7 +1007,7 @@ class Qwen2Model(Qwen2PreTrainedModel):
+ @@ -999,7 +1008,7 @@ class Qwen2Model(Qwen2PreTrainedModel):
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
