@@ -187,53 +187,60 @@ def prepare_sincos(self):
187
187
188
188
elif cfg .alt_rope_method == "yarn" :
189
189
190
- partial_rotary_factor = 1.0 # Placeholder, assume no partial_rotary_factor in config.
191
- dim = int (head_dim * partial_rotary_factor )
192
190
yarn_max_position_embeddings = cfg .yarn_rope_original_max_position_embeddings
193
- factor = cfg .yarn_rope_factor
194
-
195
- # Sets the attention factor as suggested in the paper
196
- # See: https://github.com/huggingface/transformers/blob/main/examples/modular-transformers/modeling_super.py#L190-L191
197
- scaling_factor = 0.1 * math .log (factor ) + 1.0
198
-
199
- # Optional config options
200
- # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
201
- beta_fast = 32
202
- beta_slow = 1
203
-
204
- # Compute the inverse frequencies
205
- def find_correction_dim (num_rotations , dim , base , yarn_max_position_embeddings ):
206
- """Inverse dimension formula to find the dimension based on the number of rotations"""
207
- return (dim * math .log (yarn_max_position_embeddings / (num_rotations * 2 * math .pi ))) / (2 * math .log (base ))
208
-
209
- def find_correction_range (low_rot , high_rot , dim , base , yarn_max_position_embeddings ):
210
- """Find dimension range bounds based on rotations"""
211
- low = math .floor (find_correction_dim (low_rot , dim , base , yarn_max_position_embeddings ))
212
- high = math .ceil (find_correction_dim (high_rot , dim , base , yarn_max_position_embeddings ))
213
- return max (low , 0 ), min (high , dim - 1 )
214
-
215
- def linear_ramp_factor (min , max , dim ):
216
- if min == max :
217
- max += 0.001 # Prevent singularity
218
-
219
- linear_func = (torch .arange (dim , dtype = torch .float32 ) - min ) / (max - min )
220
- ramp_func = torch .clamp (linear_func , 0 , 1 )
221
- return ramp_func
222
-
223
- # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
224
- # to expand the possible context length. In other words, interpolation = apply scaling factor.
225
- pos_freqs = base ** (torch .arange (0 , dim , 2 ).float ().to (device ) / dim )
226
- inv_freq_extrapolation = 1.0 / pos_freqs
227
- inv_freq_interpolation = 1.0 / (factor * pos_freqs )
228
-
229
- low , high = find_correction_range (beta_fast , beta_slow , dim , base , yarn_max_position_embeddings )
230
-
231
- # Get n-dimensional rotational scaling corrected for extrapolation
232
- inv_freq_extrapolation_factor = 1 - linear_ramp_factor (low , high , dim // 2 ).float ().to (device )
233
- inv_freq = (
234
- inv_freq_interpolation * (1 - inv_freq_extrapolation_factor )
235
- + inv_freq_extrapolation * inv_freq_extrapolation_factor
236
- )
191
+
192
+ # Only activate if longer than original ctx
193
+ if cfg .max_seq_len > cfg .yarn_rope_original_max_position_embeddings :
194
+
195
+ partial_rotary_factor = 1.0 # Placeholder, assume no partial_rotary_factor in config.
196
+ dim = int (head_dim * partial_rotary_factor )
197
+
198
+ factor = cfg .yarn_rope_factor
199
+
200
+ # Sets the attention factor as suggested in the paper
201
+ # See: https://github.com/huggingface/transformers/blob/main/examples/modular-transformers/modeling_super.py#L190-L191
202
+ scaling_factor = 0.1 * math .log (factor ) + 1.0
203
+
204
+ # Optional config options
205
+ # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
206
+ beta_fast = 32
207
+ beta_slow = 1
208
+
209
+ # Compute the inverse frequencies
210
+ def find_correction_dim (num_rotations , dim , base , yarn_max_position_embeddings ):
211
+ """Inverse dimension formula to find the dimension based on the number of rotations"""
212
+ return (dim * math .log (yarn_max_position_embeddings / (num_rotations * 2 * math .pi ))) / (2 * math .log (base ))
213
+
214
+ def find_correction_range (low_rot , high_rot , dim , base , yarn_max_position_embeddings ):
215
+ """Find dimension range bounds based on rotations"""
216
+ low = math .floor (find_correction_dim (low_rot , dim , base , yarn_max_position_embeddings ))
217
+ high = math .ceil (find_correction_dim (high_rot , dim , base , yarn_max_position_embeddings ))
218
+ return max (low , 0 ), min (high , dim - 1 )
219
+
220
+ def linear_ramp_factor (min , max , dim ):
221
+ if min == max :
222
+ max += 0.001 # Prevent singularity
223
+
224
+ linear_func = (torch .arange (dim , dtype = torch .float32 ) - min ) / (max - min )
225
+ ramp_func = torch .clamp (linear_func , 0 , 1 )
226
+ return ramp_func
227
+
228
+ # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
229
+ # to expand the possible context length. In other words, interpolation = apply scaling factor.
230
+ pos_freqs = base ** (torch .arange (0 , dim , 2 ).float ().to (device ) / dim )
231
+ inv_freq_extrapolation = 1.0 / pos_freqs
232
+ inv_freq_interpolation = 1.0 / (factor * pos_freqs )
233
+
234
+ low , high = find_correction_range (beta_fast , beta_slow , dim , base , yarn_max_position_embeddings )
235
+
236
+ # Get n-dimensional rotational scaling corrected for extrapolation
237
+ inv_freq_extrapolation_factor = 1 - linear_ramp_factor (low , high , dim // 2 ).float ().to (device )
238
+ inv_freq = (
239
+ inv_freq_interpolation * (1 - inv_freq_extrapolation_factor )
240
+ + inv_freq_extrapolation * inv_freq_extrapolation_factor
241
+ )
242
+ else :
243
+ inv_freq = 1.0 / (base ** (torch .arange (0 , head_dim , 2 , device = device ).float () / head_dim ))
237
244
238
245
# Regular
239
246
0 commit comments