@@ -57,6 +57,47 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    return q_embed, k_embed


+def su_scaled_rope_forward(self, x: torch.Tensor, position_ids: torch.Tensor, seq_len=None):
+    if self.inv_freq is None:
+        short_ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
+        inv_freq_shape = torch.arange(0, self.dim, 2,
+                                      dtype=torch.int64, device=x.device).float() / self.dim
+        self.inv_freq = 1.0 / (short_ext_factors * self.base ** inv_freq_shape)
+
+        long_ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
+        self.register_buffer("long_inv_freq", None, persistent=False)
+        self.long_inv_freq = 1.0 / (long_ext_factors * self.base ** inv_freq_shape)
+
+    seq_len = seq_len if seq_len is not None else torch.max(position_ids) + 1
+    if seq_len > self.original_max_position_embeddings:
+        inv_freq = self.long_inv_freq
+    else:
+        inv_freq = self.inv_freq
+
+    inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+    position_ids_expanded = position_ids[:, None, :].float()
+
+    # Force float32 since bfloat16 loses precision on long contexts
+    # See https://github.com/huggingface/transformers/pull/29285
+    device_type = x.device.type
+    device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+    with torch.autocast(device_type=device_type, enabled=False):
+        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+        emb = torch.cat((freqs, freqs), dim=-1)
+
+    scale = self.max_position_embeddings / self.original_max_position_embeddings
+    if scale <= 1.0:
+        scaling_factor = 1.0
+    else:
+        scaling_factor = math.sqrt(
+            1 + math.log(scale) / math.log(self.original_max_position_embeddings)
+        )
+
+    cos = emb.cos() * scaling_factor
+    sin = emb.sin() * scaling_factor
+    return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
def attention_forward(
    self,
    hidden_states: torch.Tensor,
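For reference, a minimal usage sketch of the patched forward. It assumes `su_scaled_rope_forward` and `apply_rotary_pos_emb` from this file are in scope; the `SuScaledRotaryEmbedding` wrapper class, its constructor defaults, and the dummy factor values are illustrative only (in practice the attributes come from the model's rotary-embedding module and its config):

```python
import math    # su_scaled_rope_forward relies on math.sqrt / math.log
import torch


class SuScaledRotaryEmbedding(torch.nn.Module):
    """Hypothetical holder for the attributes su_scaled_rope_forward expects."""

    def __init__(self, dim, short_factor, long_factor, base=10000.0,
                 max_position_embeddings=131072,
                 original_max_position_embeddings=4096):
        super().__init__()
        self.dim = dim
        self.base = base
        self.short_factor = short_factor    # per-frequency rescaling for short contexts
        self.long_factor = long_factor      # per-frequency rescaling for long contexts
        self.max_position_embeddings = max_position_embeddings
        self.original_max_position_embeddings = original_max_position_embeddings
        self.inv_freq = None                # built lazily on the first forward call

    forward = su_scaled_rope_forward        # bind the patched forward from the diff above


head_dim = 32
rope = SuScaledRotaryEmbedding(
    dim=head_dim,
    short_factor=[1.0] * (head_dim // 2),   # dummy factors; real values come from the model config
    long_factor=[2.0] * (head_dim // 2),
)
x = torch.randn(1, 8, head_dim)              # only dtype/device are read from x
position_ids = torch.arange(8).unsqueeze(0)  # (batch, seq_len)
cos, sin = rope(x, position_ids)             # each (batch, seq_len, head_dim)

q = torch.randn(1, 4, 8, head_dim)           # (batch, num_heads, seq_len, head_dim)
k = torch.randn(1, 4, 8, head_dim)
q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin)
```

With these defaults `max_position_embeddings / original_max_position_embeddings > 1`, so the `sqrt(1 + log(scale) / log(original_max_position_embeddings))` attention-scaling branch is exercised even for short inputs; positions beyond `original_max_position_embeddings` would switch the frequencies to `long_inv_freq`.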