 from typing import List, Optional, Tuple, Union
 
 import mindspore
-from mindnlp.core import nn, ops, get_default_dtype
+from mindnlp.core import nn, ops, get_default_dtype, no_grad
 from mindnlp.core.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from mindnlp.core.nn import functional as F
 
@@ -36,6 +36,7 @@
     SequenceClassifierOutputWithPast,
     TokenClassifierOutput,
 )
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from ...modeling_utils import PreTrainedModel
 from ....utils import logging
 from ....configs import SUPPORT_VIEW, use_pyboost, ON_ORANGE_PI
@@ -132,39 +133,86 @@ def extra_repr(self):
 
 # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2
 class Qwen2RotaryEmbedding(nn.Module):
-    def __init__(self, dim, max_position_embeddings=2048, base=10000):
+    def __init__(
+        self,
+        dim=None,
+        max_position_embeddings=2048,
+        base=10000,
+        scaling_factor=1.0,
+        rope_type="default",
+        config: Optional[Qwen2Config] = None,
+    ):
         super().__init__()
+        # TODO (joao): remove the `if` below, only used for BC
+        self.rope_kwargs = {}
+        if config is None:
+            logger.warning_once(
+                "`Qwen2RotaryEmbedding` can now be fully parameterized by passing the model config through the "
+                "`config` argument. All other arguments will be removed"
+            )
+            self.rope_kwargs = {
+                "rope_type": rope_type,
+                "factor": scaling_factor,
+                "dim": dim,
+                "base": base,
+                "max_position_embeddings": max_position_embeddings,
+            }
+            self.rope_type = rope_type
+            self.max_seq_len_cached = max_position_embeddings
+            self.original_max_seq_len = max_position_embeddings
+        else:
+            # BC: "rope_type" was originally "type"
+            if config.rope_scaling is not None:
+                self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+            else:
+                self.rope_type = "default"
+            self.max_seq_len_cached = config.max_position_embeddings
+            self.original_max_seq_len = config.max_position_embeddings
 
-        self.dim = dim
-        self.max_position_embeddings = max_position_embeddings
-        self.base = base
-        inv_freq = 1.0 / (self.base ** (ops.arange(0, self.dim, 2, dtype=mindspore.int64).float() / self.dim))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-        # Build here to make `torch.jit.trace` work.
-        self._set_cos_sin_cache(
-            seq_len=max_position_embeddings, dtype=get_default_dtype()
-        )
+        self.config = config
+        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
 
-    def _set_cos_sin_cache(self, seq_len, dtype):
-        self.max_seq_len_cached = seq_len
-        t = ops.arange(self.max_seq_len_cached, dtype=mindspore.int64).type_as(self.inv_freq)
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, **self.rope_kwargs)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.original_inv_freq = self.inv_freq
 
-        freqs = ops.outer(t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+    def _dynamic_frequency_update(self, position_ids):
+        """
+        dynamic RoPE layers should recompute `inv_freq` in the following situations:
+        1 - growing beyond the cached sequence length (allow scaling)
+        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
+        """
+        seq_len = ops.max(position_ids) + 1
+        if seq_len > self.max_seq_len_cached:  # growth
+            inv_freq, self.attention_scaling = self.rope_init_fn(
+                self.config, seq_len=seq_len, **self.rope_kwargs
+            )
+            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
+            self.max_seq_len_cached = seq_len
+
+        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
+            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+            self.max_seq_len_cached = self.original_max_seq_len
+
+    @no_grad()
+    def forward(self, x, position_ids):
+        if "dynamic" in self.rope_type:
+            self._dynamic_frequency_update(position_ids)
+
+        # Core RoPE block
+        inv_freq_expanded = ops.broadcast_to(self.inv_freq.view(1, -1, 1).float(), (position_ids.shape[0], -1, 1))
+        position_ids_expanded = ops.unsqueeze(position_ids, 1).float()
+        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
+        freqs = ops.transpose(ops.matmul(inv_freq_expanded.float(), position_ids_expanded.float()), 1, 2)
         emb = ops.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+        cos = emb.cos()
+        sin = emb.sin()
 
-    def forward(self, x, seq_len=None):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        if seq_len > self.max_seq_len_cached:
-            self._set_cos_sin_cache(seq_len=seq_len, dtype=x.dtype)
+        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
+        cos = cos * self.attention_scaling
+        sin = sin * self.attention_scaling
 
-        return (
-            self.cos_cached[:seq_len].to(dtype=x.dtype),
-            self.sin_cached[:seq_len].to(dtype=x.dtype),
-        )
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
 # Copied from transformers.models.llama.modeling_llama.rotate_half
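Taken together, the rewritten `Qwen2RotaryEmbedding` drops the `cos_cached`/`sin_cached` tables: `forward(x, position_ids)` builds cos/sin for exactly the positions it is given and folds in `attention_scaling` (1.0 for the default rope type; adjusted by types such as yarn, and refreshed by `_dynamic_frequency_update` for dynamic types). A minimal sketch of what the `"default"` path is assumed to reduce to — the inverse-frequency formula is the one from the removed `__init__` above, and the sizes are illustrative only:

```python
# Hedged sketch of the default RoPE path, using only ops that already appear in this file.
import mindspore
from mindnlp.core import ops

dim, base = 128, 10000.0                 # assumed head_dim / rope_theta, illustrative values
# assumed "default" init: inv_freq[i] = 1 / base**(2i/dim), attention_scaling = 1.0
inv_freq = 1.0 / (base ** (ops.arange(0, dim, 2, dtype=mindspore.int64).float() / dim))
attention_scaling = 1.0

position_ids = ops.arange(0, 7, dtype=mindspore.int64).view(1, -1)           # (bs=1, seq_len=7)
inv_freq_expanded = ops.broadcast_to(inv_freq.view(1, -1, 1).float(),
                                     (position_ids.shape[0], -1, 1))         # (bs, dim/2, 1)
position_ids_expanded = ops.unsqueeze(position_ids, 1).float()               # (bs, 1, seq_len)
freqs = ops.transpose(ops.matmul(inv_freq_expanded, position_ids_expanded), 1, 2)  # (bs, seq_len, dim/2)
emb = ops.cat((freqs, freqs), dim=-1)                                        # (bs, seq_len, dim)
cos, sin = emb.cos() * attention_scaling, emb.sin() * attention_scaling
```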
@@ -176,7 +224,7 @@ def rotate_half(x):
     return ops.cat((-x2, x1), dim=-1)
 
 # Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.
 
     Args:
@@ -197,9 +245,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     Returns:
         `tuple(mindspore.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
     """
-    position_ids = (position_ids + cos.shape[0]) % cos.shape[0]
-    cos = F.embedding(position_ids, cos).unsqueeze(unsqueeze_dim)
-    sin = F.embedding(position_ids, sin).unsqueeze(unsqueeze_dim)
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
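With cos/sin now arriving already gathered per position (roughly `(bsz, seq_len, head_dim)`), the `F.embedding` table lookup through `position_ids` is gone and `apply_rotary_pos_emb` only inserts a heads axis so the pair broadcasts against `(bsz, num_heads, seq_len, head_dim)` query/key tensors. A small, hedged shape check; it assumes it runs inside this module (for `apply_rotary_pos_emb`/`rotate_half`) and uses plain `mindspore.ops.ones`, which is not part of this diff, only to build dummy inputs:

```python
import mindspore
from mindspore import ops as ms_ops  # standard MindSpore ops, only for the dummy tensors

bsz, num_heads, seq_len, head_dim = 2, 4, 7, 8   # illustrative sizes
q = ms_ops.ones((bsz, num_heads, seq_len, head_dim), mindspore.float32)
k = ms_ops.ones((bsz, num_heads, seq_len, head_dim), mindspore.float32)
cos = ms_ops.ones((bsz, seq_len, head_dim), mindspore.float32)
sin = ms_ops.ones((bsz, seq_len, head_dim), mindspore.float32)

# position_ids is now optional: cos/sin already correspond to the requested positions
q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin)
assert q_embed.shape == (bsz, num_heads, seq_len, head_dim)
assert k_embed.shape == (bsz, num_heads, seq_len, head_dim)
```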
@@ -270,11 +317,12 @@ def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
 
-        self.rotary_emb = Qwen2RotaryEmbedding(
-            self.head_dim,
-            max_position_embeddings=self.max_position_embeddings,
-            base=self.rope_theta,
-        )
+        self.rotary_emb = Qwen2RotaryEmbedding(config=self.config)
+        # self.rotary_emb = Qwen2RotaryEmbedding(
+        #     self.head_dim,
+        #     max_position_embeddings=self.max_position_embeddings,
+        #     base=self.rope_theta,
+        # )
 
     def forward(
         self,
@@ -285,6 +333,7 @@ def forward(
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[mindspore.Tensor] = None,
+        position_embeddings: Optional[Tuple[mindspore.Tensor, mindspore.Tensor]] = None,
     ) -> Tuple[mindspore.Tensor, Optional[mindspore.Tensor], Optional[Tuple[mindspore.Tensor]]]:
         bsz, q_len, _ = hidden_states.shape
 
@@ -296,16 +345,25 @@ def forward(
         key_states = ops.transpose(key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim), 1, 2)
         value_states = ops.transpose(value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim), 1, 2)
 
-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            if self.layer_idx is None:
-                raise ValueError(
-                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
-                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
-                    "with a layer index."
-                )
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        # kv_seq_len = key_states.shape[-2]
+        # if past_key_value is not None:
+        #     if self.layer_idx is None:
+        #         raise ValueError(
+        #             f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+        #             "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+        #             "with a layer index."
+        #         )
+        #     kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+        # cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        if position_embeddings is None:
+            logger.warning_once(
+                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+                "`position_embeddings` (Tuple of tensors, containing cos and sin)."
+            )
+            cos, sin = self.rotary_emb(value_states, position_ids)
+        else:
+            cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:
@@ -318,11 +376,11 @@ def forward(
 
         attn_weights = ops.matmul(query_states, ops.transpose(key_states, 2, 3)) / math.sqrt(self.head_dim)
 
-        if attn_weights.shape != (bsz, self.num_heads, q_len, kv_seq_len):
-            raise ValueError(
-                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                f" {attn_weights.shape}"
-            )
+        # if attn_weights.shape != (bsz, self.num_heads, q_len, kv_seq_len):
+        #     raise ValueError(
+        #         f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+        #         f" {attn_weights.shape}"
+        #     )
 
         if attention_mask is not None:  # no matter the length, we just slice it
             causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
@@ -380,6 +438,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[mindspore.Tensor] = None,
+        position_embeddings: Optional[Tuple[mindspore.Tensor, mindspore.Tensor]] = None,
         **kwargs,
     ) -> Tuple[mindspore.Tensor, Optional[Tuple[mindspore.Tensor, mindspore.Tensor]]]:
         """
@@ -414,6 +473,7 @@ def forward(
             output_attentions=output_attentions,
             use_cache=use_cache,
             cache_position=cache_position,
+            position_embeddings=position_embeddings,
         )
         hidden_states = residual + hidden_states
 
@@ -441,6 +501,8 @@ class Qwen2PreTrainedModel(PreTrainedModel):
     _no_split_modules = ["Qwen2DecoderLayer"]
     _skip_keys_device_placement = "past_key_values"
     _supports_cache_class = True
+    # _supports_quantized_cache = True
+    _supports_static_cache = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
@@ -473,8 +535,9 @@ def __init__(self, config: Qwen2Config):
         )
         self._attn_implementation = config._attn_implementation
         self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
+        self.rotary_emb = Qwen2RotaryEmbedding(config=config)
         self.gradient_checkpointing = False
+
         # Initialize weights and apply final processing
         self.post_init()
 
@@ -505,30 +568,32 @@ def forward(
 
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError(
-                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
-            )
-
-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+        if not self.skip_syntax:
+            if (input_ids is None) ^ (inputs_embeds is not None):
+                raise ValueError(
+                    "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
                 )
-                use_cache = False
 
-        use_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache) and not self.training:
-            use_legacy_cache = True
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+        if self.gradient_checkpointing and self.training and use_cache:
             logger.warning_once(
-                "We detected that you are passing `past_key_values` as a tuple and this is deprecated.43. "
-                "Please use an appropriate `Cache` class"
+                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
             )
+            use_cache = False
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
+        return_legacy_cache = False
+        if (
+            use_cache and not isinstance(past_key_values, Cache) and not self.training
+        ):  # kept for BC (non `Cache` `past_key_values` inputs)
+            return_legacy_cache = True
+            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+            logger.warning_once(
+                "We detected that you are passing `past_key_values` as a tuple and this is deprecated. "
+                "Please use an appropriate `Cache` class"
+            )
+
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
             cache_position = ops.arange(
@@ -540,15 +605,17 @@ def forward(
         causal_mask = self._update_causal_mask(
             attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
         )
-
         hidden_states = inputs_embeds
 
+        # create position embeddings to be shared across the decoder layers
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None
 
-        for decoder_layer in self.layers:
+        for decoder_layer in self.layers._modules.values():
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
@@ -562,6 +629,7 @@ def forward(
                     output_attentions,
                     use_cache,
                     cache_position,
+                    position_embeddings,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -572,6 +640,7 @@ def forward(
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     cache_position=cache_position,
+                    position_embeddings=position_embeddings,
                 )
 
             hidden_states = layer_outputs[0]
@@ -588,9 +657,9 @@ def forward(
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
 
-        next_cache = None
-        if use_cache:
-            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = next_cache.to_legacy_cache()
 
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
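A behavioural note on the cache handling above, for callers that still pass tuple-style caches: the tuple is wrapped into a `DynamicCache` on the way in, and `to_legacy_cache()` is applied on the way out only when a tuple actually came in (`return_legacy_cache`). A hedged sketch of that round trip; the import path is an assumption (mindnlp generally mirrors the `transformers` layout) and is not part of this diff:

```python
# Hedged sketch of the backwards-compatibility path: tuple cache in -> Cache object
# internally -> tuple cache out. Only from_legacy_cache/to_legacy_cache are taken from
# the diff above; the import location is assumed.
from mindnlp.transformers.cache_utils import DynamicCache

legacy_past = None                                    # e.g. first step of tuple-style decoding
cache = DynamicCache.from_legacy_cache(legacy_past)   # what the model does for non-`Cache` inputs
# ... decoder layers read from and update `cache` during forward ...
next_cache = cache.to_legacy_cache()                  # converted back only because a tuple came in
```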