
Commit 5d43593

Add YaRN factor override to model_init
1 parent c84f597

2 files changed: +16 -10

exllamav2/config.py

+8 -9

@@ -287,31 +287,30 @@ def prepare(self, no_tensors: bool = False):
         rs = read(read_config, dict, "rope_scaling", None)
         if rs:
             scaling_type = rs.get("type", None)
+            rope_type = rs.get("rope_type", None)
+            assert not (scaling_type and rope_type), "rope_scaling key has both `type` and `rope_type` subkeys"
             if scaling_type == "linear":
                 assert "factor" in rs, "'factor' missing from 'rope_scaling' config"
                 self.scale_pos_emb = rs.get("factor", 1.0)
             if scaling_type == "su" or scaling_type == "longrope":
-                assert "long_factor" in rs, "'long_factor' missing from 'rope_scaling' config"
-                assert "short_factor" in rs, "'short_factor' missing from 'rope_scaling' config"
+                assert "long_factor" in rs, "'long_factor' missing from 'rope_scaling' config ('su' mode)"
+                assert "short_factor" in rs, "'short_factor' missing from 'rope_scaling' config ('su' mode)"
                 assert "original_max_position_embeddings" in read_config, \
                     "'original_max_position_embeddings' required for 'su' scaling"
                 self.scale_long_factor = rs["long_factor"]
                 self.scale_short_factor = rs["short_factor"]
                 self.original_max_seq_len = read_config["original_max_position_embeddings"]
                 self.alt_rope_method = "su"
-            # if scaling_type == "yarn":
-            #     self.scale_alpha_value = factor
-            rope_type = rs.get("rope_type", None)
+            if scaling_type == "yarn":
+                self.alt_rope_method = "yarn"
+                self.yarn_rope_factor = rs["factor"]
+                self.yarn_rope_original_max_position_embeddings = rs["original_max_position_embeddings"]
             if rope_type == "llama3":
                 self.alt_rope_method = "llama3"
                 self.l3_rope_factor = rs["factor"]
                 self.l3_rope_low_freq_factor = rs["low_freq_factor"]
                 self.l3_rope_high_freq_factor = rs["high_freq_factor"]
                 self.l3_rope_original_max_position_embeddings = rs["original_max_position_embeddings"]
-            if scaling_type == "yarn":
-                self.alt_rope_method = "yarn"
-                self.yarn_rope_factor = rs["factor"]
-                self.yarn_rope_original_max_position_embeddings = rs["original_max_position_embeddings"]

         # Checkpoint format (for GPTQ models)
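
For reference, a minimal sketch of the two rope_scaling shapes in a model's config.json that the branches above distinguish. The key names come from the diff; the numeric values are illustrative placeholders, not taken from any particular model:

# YaRN scaling, keyed by "type" (the new guard rejects configs that set
# both "type" and "rope_type"):
yarn_rs = {
    "type": "yarn",
    "factor": 4.0,                              # -> self.yarn_rope_factor
    "original_max_position_embeddings": 32768,  # -> self.yarn_rope_original_max_position_embeddings
}

# Llama-3-style scaling, keyed by "rope_type":
llama3_rs = {
    "rope_type": "llama3",
    "factor": 8.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,
}

for rs in (yarn_rs, llama3_rs):
    scaling_type = rs.get("type", None)
    rope_type = rs.get("rope_type", None)
    # The commit's new guard: at most one of the two key spellings may appear
    assert not (scaling_type and rope_type)

Note that the rope_type lookup is hoisted to the top, next to scaling_type, so the new assert can reject configs carrying both key spellings before either branch consumes them.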
exllamav2/model_init.py

+8 -1

@@ -15,6 +15,7 @@ def add_args(parser):
     parser.add_argument("-l", "--length", type = int, help = "Maximum sequence length")
     parser.add_argument("-rs", "--rope_scale", type = float, help = "RoPE scaling factor")
     parser.add_argument("-ra", "--rope_alpha", type = float, help = "RoPE alpha value (NTK)")
+    parser.add_argument("-ry", "--rope_yarn", type = float, help = "Set RoPE YaRN factor (use default max_seq_len as original_max_position_embeddings if not configured)")
     parser.add_argument("-nfa", "--no_flash_attn", action = "store_true", help = "Disable Flash Attention")
     parser.add_argument("-nxf", "--no_xformers", action = "store_true", help = "Disable xformers, an alternative to flash attn for older devices")
     parser.add_argument("-nsdpa", "--no_sdpa", action = "store_true", help = "Disable Torch SDPA")
@@ -27,7 +28,6 @@ def add_args(parser):
     parser.add_argument("-chunk", "--chunk_size", type = int, help = "Chunk size ('input length')")
 
 
-
 def print_options(args):
 
     print(f" -- Model: {args.model_dir}")
@@ -38,6 +38,7 @@ def print_options(args):
     if args.length is not None: print_opts += [f"length: {args.length}"]
     if args.rope_scale is not None: print_opts += [f"rope_scale: {args.rope_scale}"]
     if args.rope_alpha is not None: print_opts += [f"rope_alpha: {args.rope_alpha}"]
+    if args.rope_yarn is not None: print_opts += [f"rope_yarn: {args.rope_yarn}"]
     if args.no_flash_attn: print_opts += ["no_flash_attn"]
     if args.no_xformers: print_opts += ["no_xformers"]
     if args.no_sdpa: print_opts += ["no_sdpa"]
@@ -97,6 +98,12 @@ def init(
 
     # Set config options
 
+    if args.rope_yarn:
+        if config.alt_rope_method != "yarn":
+            config.yarn_rope_original_max_position_embeddings = config.max_seq_len
+        config.alt_rope_method = "yarn"
+        config.yarn_rope_factor = args.rope_yarn
+
     if args.length: config.max_seq_len = args.length
     if args.rope_scale: config.scale_pos_emb = args.rope_scale
     if args.rope_alpha: config.scale_alpha_value = args.rope_alpha
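
A usage sketch of the new flag, assuming the argparse plumbing above and the init() entry point as used by the bundled example scripts; the model path is a hypothetical placeholder:

import argparse
from exllamav2 import model_init

parser = argparse.ArgumentParser()
model_init.add_args(parser)

# Equivalent CLI: <script>.py -m /models/my-model -ry 4.0 -l 65536
args = parser.parse_args(["-m", "/models/my-model", "-ry", "4.0", "-l", "65536"])

# In the example scripts, init() builds the config, applies the overrides
# above, loads the model, and returns the model/tokenizer pair
model, tokenizer = model_init.init(args)

Because the rope_yarn block runs before the `if args.length:` override, the original_max_position_embeddings fallback captures the model's native max_seq_len rather than the -l value, which is exactly what the flag's help text promises; a config that already specifies YaRN keeps its own value.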
