@@ -182,6 +182,66 @@ def prepare_sincos(self):
182
182
cfg .l3_rope_original_max_position_embeddings ,
183
183
)
184
184
185
+ # YaRN
186
+ # Adapted from transformers: https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/modeling_rope_utils.py#L163
187
+
188
+ elif cfg .alt_rope_method == "yarn" :
189
+
190
+ yarn_max_position_embeddings = cfg .max_seq_len
191
+
192
+ # Only activate if longer than original ctx
193
+ if cfg .max_seq_len > cfg .yarn_rope_original_max_position_embeddings :
194
+
195
+ partial_rotary_factor = 1.0 # Placeholder, assume no partial_rotary_factor in config.
196
+ dim = int (head_dim * partial_rotary_factor )
197
+
198
+ factor = cfg .yarn_rope_factor
199
+
200
+ # Sets the attention factor as suggested in the paper
201
+ # See: https://github.com/huggingface/transformers/blob/main/examples/modular-transformers/modeling_super.py#L190-L191
202
+ scaling_factor = 0.1 * math .log (factor ) + 1.0
203
+
204
+ # Optional config options
205
+ # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
206
+ beta_fast = 32
207
+ beta_slow = 1
208
+
209
+ # Compute the inverse frequencies
210
def find_correction_dim(num_rotations, dim, base, yarn_max_position_embeddings):
    """Inverse of the rotation-count formula: return the (fractional) dimension
    index at which the rotary embedding completes `num_rotations` full rotations
    over `yarn_max_position_embeddings` positions, for the given frequency `base`.
    """
    rotations_term = math.log(yarn_max_position_embeddings / (num_rotations * 2 * math.pi))
    return dim * rotations_term / (2 * math.log(base))
213
+
214
def find_correction_range(low_rot, high_rot, dim, base, yarn_max_position_embeddings):
    """Return the integer dimension-index bounds of the interpolation/extrapolation
    ramp, derived from the target rotation counts and clamped to [0, dim - 1].
    """
    lo = find_correction_dim(low_rot, dim, base, yarn_max_position_embeddings)
    hi = find_correction_dim(high_rot, dim, base, yarn_max_position_embeddings)
    return max(math.floor(lo), 0), min(math.ceil(hi), dim - 1)
219
+
220
def linear_ramp_factor(min_val, max_val, dim):
    """Build a length-`dim` float32 tensor that ramps linearly from 0 (at index
    `min_val`) to 1 (at index `max_val`), clamped to [0, 1] outside that window.

    Parameters were renamed from `min`/`max` to avoid shadowing the Python
    builtins; the only call site passes arguments positionally, so the rename
    is backward-compatible.
    """
    if min_val == max_val:
        max_val += 0.001  # Prevent singularity (division by zero below)

    linear_func = (torch.arange(dim, dtype=torch.float32) - min_val) / (max_val - min_val)
    ramp_func = torch.clamp(linear_func, 0, 1)
    return ramp_func
227
+
228
+ # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
229
+ # to expand the possible context length. In other words, interpolation = apply scaling factor.
230
+ pos_freqs = base ** (torch .arange (0 , dim , 2 ).float ().to (device ) / dim )
231
+ inv_freq_extrapolation = 1.0 / pos_freqs
232
+ inv_freq_interpolation = 1.0 / (factor * pos_freqs )
233
+
234
+ low , high = find_correction_range (beta_fast , beta_slow , dim , base , yarn_max_position_embeddings )
235
+
236
+ # Get n-dimensional rotational scaling corrected for extrapolation
237
+ inv_freq_extrapolation_factor = 1 - linear_ramp_factor (low , high , dim // 2 ).float ().to (device )
238
+ inv_freq = (
239
+ inv_freq_interpolation * (1 - inv_freq_extrapolation_factor )
240
+ + inv_freq_extrapolation * inv_freq_extrapolation_factor
241
+ )
242
+ else :
243
+ inv_freq = 1.0 / (base ** (torch .arange (0 , head_dim , 2 , device = device ).float () / head_dim ))
244
+
185
245
# Regular
186
246
187
247
else :
0 commit comments