@@ -287,31 +287,30 @@ def prepare(self, no_tensors: bool = False):
         rs = read(read_config, dict, "rope_scaling", None)
         if rs:
             scaling_type = rs.get("type", None)
+            rope_type = rs.get("rope_type", None)
+            assert not (scaling_type and rope_type), "rope_scaling key has both `type` and `rope_type` subkeys"
             if scaling_type == "linear":
                 assert "factor" in rs, "'factor' missing from 'rope_scaling' config"
                 self.scale_pos_emb = rs.get("factor", 1.0)
             if scaling_type == "su" or scaling_type == "longrope":
-                assert "long_factor" in rs, "'long_factor' missing from 'rope_scaling' config"
-                assert "short_factor" in rs, "'short_factor' missing from 'rope_scaling' config"
+                assert "long_factor" in rs, "'long_factor' missing from 'rope_scaling' config ('su' mode)"
+                assert "short_factor" in rs, "'short_factor' missing from 'rope_scaling' config ('su' mode)"
                 assert "original_max_position_embeddings" in read_config, \
                     "'original_max_position_embeddings' required for 'su' scaling"
                 self.scale_long_factor = rs["long_factor"]
                 self.scale_short_factor = rs["short_factor"]
                 self.original_max_seq_len = read_config["original_max_position_embeddings"]
                 self.alt_rope_method = "su"
-            # if scaling_type == "yarn":
-            #     self.scale_alpha_value = factor
-            rope_type = rs.get("rope_type", None)
+            if scaling_type == "yarn":
+                self.alt_rope_method = "yarn"
+                self.yarn_rope_factor = rs["factor"]
+                self.yarn_rope_original_max_position_embeddings = rs["original_max_position_embeddings"]
             if rope_type == "llama3":
                 self.alt_rope_method = "llama3"
                 self.l3_rope_factor = rs["factor"]
                 self.l3_rope_low_freq_factor = rs["low_freq_factor"]
                 self.l3_rope_high_freq_factor = rs["high_freq_factor"]
                 self.l3_rope_original_max_position_embeddings = rs["original_max_position_embeddings"]

     # Checkpoint format (for GPTQ models)
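For reference, a minimal standalone sketch of the dispatch behaviour after this change, assuming `rs` is the "rope_scaling" dict read from a model's config.json. The function name and the example configs below are hypothetical, made up for illustration; they are not part of the codebase.

# Sketch: pick the alternative RoPE method from a "rope_scaling" dict,
# mirroring the branches in the hunk above.
def resolve_alt_rope_method(rs: dict) -> str:
    scaling_type = rs.get("type", None)
    rope_type = rs.get("rope_type", None)
    # The two subkeys are mutually exclusive, as asserted in the diff
    assert not (scaling_type and rope_type), \
        "rope_scaling key has both `type` and `rope_type` subkeys"
    if scaling_type == "su" or scaling_type == "longrope":
        return "su"
    if scaling_type == "yarn":
        return "yarn"
    if rope_type == "llama3":
        return "llama3"
    # "linear" only sets a position-embedding scale factor, not an alt method
    return "none"

# Made-up example configs exercising the branches:
print(resolve_alt_rope_method({"type": "yarn", "factor": 4.0,
                               "original_max_position_embeddings": 32768}))  # -> "yarn"
print(resolve_alt_rope_method({"rope_type": "llama3", "factor": 8.0,
                               "low_freq_factor": 1.0, "high_freq_factor": 4.0,
                               "original_max_position_embeddings": 8192}))   # -> "llama3"

Note the ordering fix the diff makes: the "yarn" branch now runs before the "llama3" check instead of after it, and `rope_type` is read once at the top alongside `scaling_type`, with an assertion that a config does not supply both subkeys.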