@@ -23,7 +23,7 @@
from typing import List, Optional, Tuple, Union

import mindspore
- from mindnlp.core import nn, ops, get_default_dtype
+ from mindnlp.core import nn, ops, get_default_dtype, no_grad
from mindnlp.core.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from mindnlp.core.nn import functional as F

@@ -36,6 +36,7 @@
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
+ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel
from ....utils import logging
from ....configs import SUPPORT_VIEW, use_pyboost, ON_ORANGE_PI
@@ -132,39 +133,86 @@ def extra_repr(self):

# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2
class Qwen2RotaryEmbedding(nn.Module):
-     def __init__(self, dim, max_position_embeddings=2048, base=10000):
+     def __init__(
+         self,
+         dim=None,
+         max_position_embeddings=2048,
+         base=10000,
+         scaling_factor=1.0,
+         rope_type="default",
+         config: Optional[Qwen2Config] = None,
+     ):
        super().__init__()
+         # TODO (joao): remove the `if` below, only used for BC
+         self.rope_kwargs = {}
+         if config is None:
+             logger.warning_once(
+                 "`Qwen2RotaryEmbedding` can now be fully parameterized by passing the model config through the "
+                 "`config` argument. All other arguments will be removed"
+             )
+             self.rope_kwargs = {
+                 "rope_type": rope_type,
+                 "factor": scaling_factor,
+                 "dim": dim,
+                 "base": base,
+                 "max_position_embeddings": max_position_embeddings,
+             }
+             self.rope_type = rope_type
+             self.max_seq_len_cached = max_position_embeddings
+             self.original_max_seq_len = max_position_embeddings
+         else:
+             # BC: "rope_type" was originally "type"
+             if config.rope_scaling is not None:
+                 self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+             else:
+                 self.rope_type = "default"
+             self.max_seq_len_cached = config.max_position_embeddings
+             self.original_max_seq_len = config.max_position_embeddings

-         self.dim = dim
-         self.max_position_embeddings = max_position_embeddings
-         self.base = base
-         inv_freq = 1.0 / (self.base ** (ops.arange(0, self.dim, 2, dtype=mindspore.int64).float() / self.dim))
-         self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-         # Build here to make `torch.jit.trace` work.
-         self._set_cos_sin_cache(
-             seq_len=max_position_embeddings, dtype=get_default_dtype()
-         )
+         self.config = config
+         self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

-     def _set_cos_sin_cache(self, seq_len, dtype):
-         self.max_seq_len_cached = seq_len
-         t = ops.arange(self.max_seq_len_cached, dtype=mindspore.int64).type_as(self.inv_freq)
+         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, **self.rope_kwargs)
+         self.register_buffer("inv_freq", inv_freq, persistent=False)
+         self.original_inv_freq = self.inv_freq

-         freqs = ops.outer(t, self.inv_freq)
-         # Different from paper, but it uses a different permutation in order to obtain the same calculation
+     def _dynamic_frequency_update(self, position_ids):
+         """
+         dynamic RoPE layers should recompute `inv_freq` in the following situations:
+         1 - growing beyond the cached sequence length (allow scaling)
+         2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
+         """
+         seq_len = ops.max(position_ids) + 1
+         if seq_len > self.max_seq_len_cached:  # growth
+             inv_freq, self.attention_scaling = self.rope_init_fn(
+                 self.config, seq_len=seq_len, **self.rope_kwargs
+             )
+             self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
+             self.max_seq_len_cached = seq_len
+
+         if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
+             self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+             self.max_seq_len_cached = self.original_max_seq_len
+
+     @no_grad()
+     def forward(self, x, position_ids):
+         if "dynamic" in self.rope_type:
+             self._dynamic_frequency_update(position_ids)
+
+         # Core RoPE block
+         inv_freq_expanded = ops.broadcast_to(self.inv_freq.view(1, -1, 1).float(), (position_ids.shape[0], -1, 1))
+         position_ids_expanded = ops.unsqueeze(position_ids, 1).float()
+         # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
+         freqs = ops.transpose(ops.matmul(inv_freq_expanded.float(), position_ids_expanded.float()), 1, 2)
        emb = ops.cat((freqs, freqs), dim=-1)
-         self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-         self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+         cos = emb.cos()
+         sin = emb.sin()

-     def forward(self, x, seq_len=None):
-         # x: [bs, num_attention_heads, seq_len, head_size]
-         if seq_len > self.max_seq_len_cached:
-             self._set_cos_sin_cache(seq_len=seq_len, dtype=x.dtype)
+         # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
+         cos = cos * self.attention_scaling
+         sin = sin * self.attention_scaling

-         return (
-             self.cos_cached[:seq_len].to(dtype=x.dtype),
-             self.sin_cached[:seq_len].to(dtype=x.dtype),
-         )
+         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


# Copied from transformers.models.llama.modeling_llama.rotate_half
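A minimal usage sketch of the reworked rotary embedding (reviewer note, not part of the diff). It drives the backward-compatible path with an explicit `dim`/`base` so no `Qwen2Config` instance is needed; the `ops`/`mindspore` calls mirror the code above, while the concrete sizes are illustrative assumptions:

```python
import mindspore
from mindnlp.core import ops

# BC path: config=None fills rope_kwargs and selects ROPE_INIT_FUNCTIONS["default"]
rope = Qwen2RotaryEmbedding(dim=64, max_position_embeddings=2048, base=10000)

bsz, seq_len = 2, 16
position_ids = ops.broadcast_to(ops.arange(seq_len, dtype=mindspore.int64).view(1, -1), (bsz, seq_len))
x = ops.ones((bsz, seq_len, 896))  # forward() only reads x.dtype

cos, sin = rope(x, position_ids)   # each (bsz, seq_len, dim), already multiplied by attention_scaling
```

Inside the model the same module is instead built as `Qwen2RotaryEmbedding(config=self.config)`, as the `__init__` hunks further down show.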
@@ -176,7 +224,7 @@ def rotate_half(x):
    return ops.cat((-x2, x1), dim=-1)

# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
@@ -197,9 +245,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    Returns:
        `tuple(mindspore.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
-     position_ids = (position_ids + cos.shape[0]) % cos.shape[0]
-     cos = F.embedding(position_ids, cos).unsqueeze(unsqueeze_dim)
-     sin = F.embedding(position_ids, sin).unsqueeze(unsqueeze_dim)
+     cos = cos.unsqueeze(unsqueeze_dim)
+     sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
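Continuing the sketch above: with the new convention `cos`/`sin` arrive as `(bsz, seq_len, head_dim)` and `unsqueeze_dim=1` inserts the head axis so they broadcast against query/key tensors of shape `(bsz, num_heads, seq_len, head_dim)`. The tensor sizes below are again assumptions:

```python
# Grouped-query attention layout: queries may have more heads than keys/values
q = ops.ones((2, 14, 16, 64))
k = ops.ones((2, 2, 16, 64))

q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin)  # position_ids is no longer needed
assert q_embed.shape == (2, 14, 16, 64) and k_embed.shape == (2, 2, 16, 64)
```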
@@ -270,11 +317,12 @@ def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

-         self.rotary_emb = Qwen2RotaryEmbedding(
-             self.head_dim,
-             max_position_embeddings=self.max_position_embeddings,
-             base=self.rope_theta,
-         )
+         self.rotary_emb = Qwen2RotaryEmbedding(config=self.config)
+         # self.rotary_emb = Qwen2RotaryEmbedding(
+         #     self.head_dim,
+         #     max_position_embeddings=self.max_position_embeddings,
+         #     base=self.rope_theta,
+         # )

    def forward(
        self,
@@ -285,6 +333,7 @@ def forward(
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[mindspore.Tensor] = None,
+         position_embeddings: Optional[Tuple[mindspore.Tensor, mindspore.Tensor]] = None,
    ) -> Tuple[mindspore.Tensor, Optional[mindspore.Tensor], Optional[Tuple[mindspore.Tensor]]]:
        bsz, q_len, _ = hidden_states.shape

@@ -296,16 +345,25 @@ def forward(
        key_states = ops.transpose(key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim), 1, 2)
        value_states = ops.transpose(value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim), 1, 2)

-         kv_seq_len = key_states.shape[-2]
-         if past_key_value is not None:
-             if self.layer_idx is None:
-                 raise ValueError(
-                     f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
-                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
-                     "with a layer index."
-                 )
-             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+         # kv_seq_len = key_states.shape[-2]
+         # if past_key_value is not None:
+         #     if self.layer_idx is None:
+         #         raise ValueError(
+         #             f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+         #             "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+         #             "with a layer index."
+         #         )
+         #     kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+         # cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+         if position_embeddings is None:
+             logger.warning_once(
+                 "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+                 "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+                 "`position_embeddings` (Tuple of tensors, containing cos and sin)."
+             )
+             cos, sin = self.rotary_emb(value_states, position_ids)
+         else:
+             cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
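The `warning_once` added above documents the transition from per-layer RoPE computation (driven by `position_ids`) to caller-supplied `position_embeddings`. A sketch of both call styles against the attention layer, continuing the earlier imports; the import paths and config values are assumptions, not verified against the actual package layout:

```python
from mindnlp.transformers.models.qwen2 import Qwen2Config                    # assumed path
from mindnlp.transformers.models.qwen2.modeling_qwen2 import Qwen2Attention  # assumed path

config = Qwen2Config(hidden_size=128, num_attention_heads=4, num_key_value_heads=2,
                     intermediate_size=256, max_position_embeddings=2048)
attn = Qwen2Attention(config, layer_idx=0)

hidden_states = ops.ones((2, 16, 128))
position_ids = ops.broadcast_to(ops.arange(16, dtype=mindspore.int64).view(1, -1), (2, 16))

# Preferred: compute cos/sin once (as Qwen2Model.forward now does) and hand them in
cos, sin = attn.rotary_emb(hidden_states, position_ids)
out, _, _ = attn(hidden_states, position_ids=position_ids, position_embeddings=(cos, sin))

# Legacy: omit position_embeddings; the layer recomputes RoPE itself and logs the warning once
out_legacy, _, _ = attn(hidden_states, position_ids=position_ids)
```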
@@ -318,11 +376,11 @@ def forward(

        attn_weights = ops.matmul(query_states, ops.transpose(key_states, 2, 3)) / math.sqrt(self.head_dim)

-         if attn_weights.shape != (bsz, self.num_heads, q_len, kv_seq_len):
-             raise ValueError(
-                 f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                 f" {attn_weights.shape}"
-             )
+         # if attn_weights.shape != (bsz, self.num_heads, q_len, kv_seq_len):
+         #     raise ValueError(
+         #         f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+         #         f" {attn_weights.shape}"
+         #     )

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
@@ -380,6 +438,7 @@ def forward(
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[mindspore.Tensor] = None,
+         position_embeddings: Optional[Tuple[mindspore.Tensor, mindspore.Tensor]] = None,
        **kwargs,
    ) -> Tuple[mindspore.Tensor, Optional[Tuple[mindspore.Tensor, mindspore.Tensor]]]:
        """
@@ -414,6 +473,7 @@ def forward(
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
+             position_embeddings=position_embeddings,
        )
        hidden_states = residual + hidden_states

@@ -441,6 +501,8 @@ class Qwen2PreTrainedModel(PreTrainedModel):
    _no_split_modules = ["Qwen2DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_cache_class = True
+     # _supports_quantized_cache = True
+     _supports_static_cache = True

    def _init_weights(self, module):
        std = self.config.initializer_range
@@ -473,8 +535,9 @@ def __init__(self, config: Qwen2Config):
        )
        self._attn_implementation = config._attn_implementation
        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
+         self.rotary_emb = Qwen2RotaryEmbedding(config=config)
        self.gradient_checkpointing = False
+
        # Initialize weights and apply final processing
        self.post_init()

@@ -505,30 +568,32 @@ def forward(

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-         if (input_ids is None) ^ (inputs_embeds is not None):
-             raise ValueError(
-                 "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
-             )
-
-         if self.gradient_checkpointing and self.training:
-             if use_cache:
-                 logger.warning_once(
-                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+         if not self.skip_syntax:
+             if (input_ids is None) ^ (inputs_embeds is not None):
+                 raise ValueError(
+                     "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
                )
-                 use_cache = False

-         use_legacy_cache = False
-         if use_cache and not isinstance(past_key_values, Cache) and not self.training:
-             use_legacy_cache = True
-             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+             if self.gradient_checkpointing and self.training and use_cache:
                logger.warning_once(
-                 "We detected that you are passing `past_key_values` as a tuple and this is deprecated.43. "
-                 "Please use an appropriate `Cache` class"
+                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
                )
+                 use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

+         return_legacy_cache = False
+         if (
+             use_cache and not isinstance(past_key_values, Cache) and not self.training
+         ):  # kept for BC (non `Cache` `past_key_values` inputs)
+             return_legacy_cache = True
+             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+             logger.warning_once(
+                 "We detected that you are passing `past_key_values` as a tuple and this is deprecated. "
+                 "Please use an appropriate `Cache` class"
+             )
+
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = ops.arange(
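The deprecation warning kept above concerns tuple-style `past_key_values`; the new `return_legacy_cache` flag only records that the caller used the old format so the output can be converted back. A small sketch of that round trip, continuing the earlier imports and using the `DynamicCache` API already referenced in this file (the import path and tensor shapes are assumptions):

```python
from mindnlp.transformers.cache_utils import DynamicCache  # assumed path

# Legacy format: tuple of per-layer (key, value) tensors, (bsz, num_kv_heads, seq_len, head_dim)
legacy = tuple(
    (ops.ones((2, 2, 16, 64)), ops.ones((2, 2, 16, 64)))
    for _ in range(2)  # two decoder layers, purely illustrative
)

cache = DynamicCache.from_legacy_cache(legacy)  # wrap once on the way in
assert cache.get_seq_length() == 16             # same accessor the forward pass uses for cache_position
roundtrip = cache.to_legacy_cache()             # unwrap on the way out when return_legacy_cache is True
assert isinstance(roundtrip, tuple) and len(roundtrip) == 2
```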
@@ -540,15 +605,17 @@ def forward(
        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )
-
        hidden_states = inputs_embeds

+         # create position embeddings to be shared across the decoder layers
+         position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

-         for decoder_layer in self.layers:
+         for decoder_layer in self.layers._modules.values():
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

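The two added lines above compute `position_embeddings` once per forward pass; every decoder layer then receives the same `(cos, sin)` tuple instead of re-deriving it from `position_ids`, which is what the `position_embeddings=...` arguments added in the next hunks rely on. A toy continuation of the first sketch showing that share-once pattern (the layer body is a placeholder, not the real decoder layer):

```python
cos_sin = rope(x, position_ids)  # computed once, reusing the rotary module from the first sketch

def toy_layer(hidden, position_embeddings):
    cos, sin = position_embeddings
    # stand-in for attention + MLP; only demonstrates that the tuple is consumed per layer
    return hidden + 0.0 * (cos.sum() + sin.sum())

hidden = x
for _ in range(2):  # e.g. two decoder layers sharing the same cos/sin
    hidden = toy_layer(hidden, position_embeddings=cos_sin)
```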
@@ -562,6 +629,7 @@ def forward(
                    output_attentions,
                    use_cache,
                    cache_position,
+                     position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
@@ -572,6 +640,7 @@ def forward(
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
+                     position_embeddings=position_embeddings,
                )

            hidden_states = layer_outputs[0]
@@ -588,9 +657,9 @@ def forward(
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

-         next_cache = None
-         if use_cache:
-             next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+         next_cache = next_decoder_cache if use_cache else None
+         if return_legacy_cache:
+             next_cache = next_cache.to_legacy_cache()

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)