Commit b771d0a

modify qwen2 for jit (#2028)
1 parent 36a31ed commit b771d0a

File tree

mindnlp/transformers/models/qwen2/configuration_qwen2.py
mindnlp/transformers/models/qwen2/modeling_qwen2.py

2 files changed: +150, -73 lines


mindnlp/transformers/models/qwen2/configuration_qwen2.py (+8)

@@ -15,6 +15,7 @@
 """ Qwen2 model configuration"""

 from mindnlp.utils import logging
+from ...modeling_rope_utils import rope_config_validation
 from ...configuration_utils import PretrainedConfig


@@ -112,6 +113,7 @@ def __init__(
         use_cache=True,
         tie_word_embeddings=False,
         rope_theta=10000.0,
+        rope_scaling=None,
         use_sliding_window=False,
         sliding_window=4096,
         max_window_layers=28,
@@ -169,7 +171,13 @@ def __init__(
         self.rms_norm_eps = rms_norm_eps
         self.use_cache = use_cache
         self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
         self.attention_dropout = attention_dropout
+        # Validate the correctness of rotary position embeddings parameters
+        # BC: if there is a 'type' field, move it to 'rope_type'.
+        if self.rope_scaling is not None and "type" in self.rope_scaling:
+            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+        rope_config_validation(self)

         super().__init__(
             tie_word_embeddings=tie_word_embeddings,
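Note: a minimal sketch (not part of this commit) of how the new `rope_scaling` argument flows through the backward-compatibility shim above: a legacy "type" key is copied to "rope_type" before `rope_config_validation` runs. The import path and the example scaling dict are assumptions that mirror HF Transformers.

from mindnlp.transformers import Qwen2Config  # import path assumed (mirrors HF Transformers)

# Legacy-style dict still using "type"; __init__ copies it into "rope_type"
# and then calls rope_config_validation(self).
config = Qwen2Config(rope_scaling={"type": "dynamic", "factor": 2.0})
assert config.rope_scaling["rope_type"] == "dynamic"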

mindnlp/transformers/models/qwen2/modeling_qwen2.py (+142, -73)

@@ -23,7 +23,7 @@
 from typing import List, Optional, Tuple, Union

 import mindspore
-from mindnlp.core import nn, ops, get_default_dtype
+from mindnlp.core import nn, ops, get_default_dtype, no_grad
 from mindnlp.core.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from mindnlp.core.nn import functional as F

@@ -36,6 +36,7 @@
     SequenceClassifierOutputWithPast,
     TokenClassifierOutput,
 )
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from ...modeling_utils import PreTrainedModel
 from ....utils import logging
 from ....configs import SUPPORT_VIEW, use_pyboost, ON_ORANGE_PI
@@ -132,39 +133,86 @@ def extra_repr(self):

 # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2
 class Qwen2RotaryEmbedding(nn.Module):
-    def __init__(self, dim, max_position_embeddings=2048, base=10000):
+    def __init__(
+        self,
+        dim=None,
+        max_position_embeddings=2048,
+        base=10000,
+        scaling_factor=1.0,
+        rope_type="default",
+        config: Optional[Qwen2Config] = None,
+    ):
         super().__init__()
+        # TODO (joao): remove the `if` below, only used for BC
+        self.rope_kwargs = {}
+        if config is None:
+            logger.warning_once(
+                "`Qwen2RotaryEmbedding` can now be fully parameterized by passing the model config through the "
+                "`config` argument. All other arguments will be removed"
+            )
+            self.rope_kwargs = {
+                "rope_type": rope_type,
+                "factor": scaling_factor,
+                "dim": dim,
+                "base": base,
+                "max_position_embeddings": max_position_embeddings,
+            }
+            self.rope_type = rope_type
+            self.max_seq_len_cached = max_position_embeddings
+            self.original_max_seq_len = max_position_embeddings
+        else:
+            # BC: "rope_type" was originally "type"
+            if config.rope_scaling is not None:
+                self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+            else:
+                self.rope_type = "default"
+            self.max_seq_len_cached = config.max_position_embeddings
+            self.original_max_seq_len = config.max_position_embeddings

-        self.dim = dim
-        self.max_position_embeddings = max_position_embeddings
-        self.base = base
-        inv_freq = 1.0 / (self.base ** (ops.arange(0, self.dim, 2, dtype=mindspore.int64).float() / self.dim))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-        # Build here to make `torch.jit.trace` work.
-        self._set_cos_sin_cache(
-            seq_len=max_position_embeddings, dtype=get_default_dtype()
-        )
+        self.config = config
+        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

-    def _set_cos_sin_cache(self, seq_len, dtype):
-        self.max_seq_len_cached = seq_len
-        t = ops.arange(self.max_seq_len_cached, dtype=mindspore.int64).type_as(self.inv_freq)
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, **self.rope_kwargs)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.original_inv_freq = self.inv_freq

-        freqs = ops.outer(t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+    def _dynamic_frequency_update(self, position_ids):
+        """
+        dynamic RoPE layers should recompute `inv_freq` in the following situations:
+        1 - growing beyond the cached sequence length (allow scaling)
+        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
+        """
+        seq_len = ops.max(position_ids) + 1
+        if seq_len > self.max_seq_len_cached:  # growth
+            inv_freq, self.attention_scaling = self.rope_init_fn(
+                self.config, seq_len=seq_len, **self.rope_kwargs
+            )
+            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
+            self.max_seq_len_cached = seq_len
+
+        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
+            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+            self.max_seq_len_cached = self.original_max_seq_len
+
+    @no_grad()
+    def forward(self, x, position_ids):
+        if "dynamic" in self.rope_type:
+            self._dynamic_frequency_update(position_ids)
+
+        # Core RoPE block
+        inv_freq_expanded = ops.broadcast_to(self.inv_freq.view(1, -1, 1).float(), (position_ids.shape[0], -1, 1))
+        position_ids_expanded = ops.unsqueeze(position_ids, 1).float()
+        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
+        freqs = ops.transpose(ops.matmul(inv_freq_expanded.float(), position_ids_expanded.float()), 1, 2)
         emb = ops.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+        cos = emb.cos()
+        sin = emb.sin()

-    def forward(self, x, seq_len=None):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        if seq_len > self.max_seq_len_cached:
-            self._set_cos_sin_cache(seq_len=seq_len, dtype=x.dtype)
+        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
+        cos = cos * self.attention_scaling
+        sin = sin * self.attention_scaling

-        return (
-            self.cos_cached[:seq_len].to(dtype=x.dtype),
-            self.sin_cached[:seq_len].to(dtype=x.dtype),
-        )
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


 # Copied from transformers.models.llama.modeling_llama.rotate_half
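Note: a minimal call sketch for the rewritten rotary module (hypothetical tensors, default config, import paths assumed to mirror HF Transformers; not part of the commit). The forward now takes `position_ids` instead of `seq_len` and returns the per-position cos/sin tables, so there is no data-dependent cache rebuilding inside the module, which is presumably what makes it friendlier to jit/graph mode (the commit's stated goal).

import numpy as np
import mindspore
from mindnlp.core import ops
from mindnlp.transformers import Qwen2Config  # import path assumed
from mindnlp.transformers.models.qwen2.modeling_qwen2 import Qwen2RotaryEmbedding

config = Qwen2Config()                            # default config, no rope_scaling -> rope_type "default"
rotary_emb = Qwen2RotaryEmbedding(config=config)  # new config-driven construction

hidden_states = mindspore.Tensor(np.zeros((2, 5, config.hidden_size), dtype=np.float32))
position_ids = ops.broadcast_to(ops.arange(5, dtype=mindspore.int64).view(1, -1), (2, 5))

cos, sin = rotary_emb(hidden_states, position_ids)  # x contributes only its dtype
# cos/sin: (batch, seq_len, head_dim), already multiplied by attention_scaling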
@@ -176,7 +224,7 @@ def rotate_half(x):
     return ops.cat((-x2, x1), dim=-1)

 # Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.

     Args:
@@ -197,9 +245,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     Returns:
         `tuple(mindspore.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
     """
-    position_ids = (position_ids + cos.shape[0]) % cos.shape[0]
-    cos = F.embedding(position_ids, cos).unsqueeze(unsqueeze_dim)
-    sin = F.embedding(position_ids, sin).unsqueeze(unsqueeze_dim)
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
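Note: continuing the sketch above, `apply_rotary_pos_emb` now consumes the precomputed cos/sin directly; `position_ids` is kept only for backward compatibility and defaults to `None`. Shapes below are hypothetical.

from mindnlp.transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb  # import path assumed

num_heads = config.num_attention_heads
head_dim = config.hidden_size // config.num_attention_heads
q = mindspore.Tensor(np.random.randn(2, num_heads, 5, head_dim).astype(np.float32))
k = mindspore.Tensor(np.random.randn(2, num_heads, 5, head_dim).astype(np.float32))

# cos/sin come from the rotary sketch above with shape (batch, seq_len, head_dim);
# unsqueeze_dim=1 inserts the missing head axis so they broadcast against q and k.
q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin)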
@@ -270,11 +317,12 @@ def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

-        self.rotary_emb = Qwen2RotaryEmbedding(
-            self.head_dim,
-            max_position_embeddings=self.max_position_embeddings,
-            base=self.rope_theta,
-        )
+        self.rotary_emb = Qwen2RotaryEmbedding(config=self.config)
+        # self.rotary_emb = Qwen2RotaryEmbedding(
+        #     self.head_dim,
+        #     max_position_embeddings=self.max_position_embeddings,
+        #     base=self.rope_theta,
+        # )

     def forward(
         self,
@@ -285,6 +333,7 @@ def forward(
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[mindspore.Tensor] = None,
+        position_embeddings: Optional[Tuple[mindspore.Tensor, mindspore.Tensor]] = None,
     ) -> Tuple[mindspore.Tensor, Optional[mindspore.Tensor], Optional[Tuple[mindspore.Tensor]]]:
         bsz, q_len, _ = hidden_states.shape

@@ -296,16 +345,25 @@
         key_states = ops.transpose(key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim), 1, 2)
         value_states = ops.transpose(value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim), 1, 2)

-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            if self.layer_idx is None:
-                raise ValueError(
-                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
-                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
-                    "with a layer index."
-                )
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        # kv_seq_len = key_states.shape[-2]
+        # if past_key_value is not None:
+        #     if self.layer_idx is None:
+        #         raise ValueError(
+        #             f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+        #             "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+        #             "with a layer index."
+        #         )
+        #     kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+        # cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        if position_embeddings is None:
+            logger.warning_once(
+                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+                "`position_embeddings` (Tuple of tensors, containing cos and sin)."
+            )
+            cos, sin = self.rotary_emb(value_states, position_ids)
+        else:
+            cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

         if past_key_value is not None:
@@ -318,11 +376,11 @@

         attn_weights = ops.matmul(query_states, ops.transpose(key_states, 2, 3)) / math.sqrt(self.head_dim)

-        if attn_weights.shape != (bsz, self.num_heads, q_len, kv_seq_len):
-            raise ValueError(
-                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                f" {attn_weights.shape}"
-            )
+        # if attn_weights.shape != (bsz, self.num_heads, q_len, kv_seq_len):
+        #     raise ValueError(
+        #         f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+        #         f" {attn_weights.shape}"
+        #     )

         if attention_mask is not None:  # no matter the length, we just slice it
             causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
@@ -380,6 +438,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[mindspore.Tensor] = None,
+        position_embeddings: Optional[Tuple[mindspore.Tensor, mindspore.Tensor]] = None,
         **kwargs,
     ) -> Tuple[mindspore.Tensor, Optional[Tuple[mindspore.Tensor, mindspore.Tensor]]]:
         """
@@ -414,6 +473,7 @@ def forward(
             output_attentions=output_attentions,
             use_cache=use_cache,
             cache_position=cache_position,
+            position_embeddings=position_embeddings,
         )
         hidden_states = residual + hidden_states

@@ -441,6 +501,8 @@ class Qwen2PreTrainedModel(PreTrainedModel):
     _no_split_modules = ["Qwen2DecoderLayer"]
     _skip_keys_device_placement = "past_key_values"
     _supports_cache_class = True
+    # _supports_quantized_cache = True
+    _supports_static_cache = True

     def _init_weights(self, module):
         std = self.config.initializer_range
@@ -473,8 +535,9 @@ def __init__(self, config: Qwen2Config):
         )
         self._attn_implementation = config._attn_implementation
         self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
+        self.rotary_emb = Qwen2RotaryEmbedding(config=config)
         self.gradient_checkpointing = False
+
         # Initialize weights and apply final processing
         self.post_init()

@@ -505,30 +568,32 @@ def forward(

         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-        if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError(
-                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
-            )
-
-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+        if not self.skip_syntax:
+            if (input_ids is None) ^ (inputs_embeds is not None):
+                raise ValueError(
+                    "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
                 )
-                use_cache = False

-        use_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache) and not self.training:
-            use_legacy_cache = True
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+        if self.gradient_checkpointing and self.training and use_cache:
             logger.warning_once(
-                "We detected that you are passing `past_key_values` as a tuple and this is deprecated.43. "
-                "Please use an appropriate `Cache` class"
+                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
             )
+            use_cache = False

         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

+        return_legacy_cache = False
+        if (
+            use_cache and not isinstance(past_key_values, Cache) and not self.training
+        ):  # kept for BC (non `Cache` `past_key_values` inputs)
+            return_legacy_cache = True
+            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+            logger.warning_once(
+                "We detected that you are passing `past_key_values` as a tuple and this is deprecated. "
+                "Please use an appropriate `Cache` class"
+            )
+
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
             cache_position = ops.arange(
@@ -540,15 +605,17 @@ def forward(
         causal_mask = self._update_causal_mask(
             attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
         )
-
         hidden_states = inputs_embeds

+        # create position embeddings to be shared across the decoder layers
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None

-        for decoder_layer in self.layers:
+        for decoder_layer in self.layers._modules.values():
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)

@@ -562,6 +629,7 @@
                     output_attentions,
                     use_cache,
                     cache_position,
+                    position_embeddings,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -572,6 +640,7 @@
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     cache_position=cache_position,
+                    position_embeddings=position_embeddings,
                 )

             hidden_states = layer_outputs[0]
@@ -588,9 +657,9 @@ def forward(
         if output_hidden_states:
             all_hidden_states += (hidden_states,)

-        next_cache = None
-        if use_cache:
-            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = next_cache.to_legacy_cache()

         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
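Note: to make the cache backward-compatibility path above concrete, a small round-trip sketch with hypothetical tensors (the `DynamicCache` import path is assumed to mirror HF Transformers). When a legacy tuple is passed, the model wraps it in a `DynamicCache` on the way in and, because `return_legacy_cache` is set, converts the updated cache back to a tuple on the way out.

import numpy as np
import mindspore
from mindnlp.transformers.cache_utils import DynamicCache  # import path assumed

# One layer of (key, value), each shaped (batch, kv_heads, seq_len, head_dim)
key = mindspore.Tensor(np.zeros((1, 2, 3, 64), dtype=np.float32))
value = mindspore.Tensor(np.zeros((1, 2, 3, 64), dtype=np.float32))
legacy_past = ((key, value),)

cache = DynamicCache.from_legacy_cache(legacy_past)  # what the model does on entry
round_trip = cache.to_legacy_cache()                 # what it does before returning
assert len(round_trip) == 1 and round_trip[0][0].shape == key.shape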
