
Commit bedbcc4

v3.5 hotfix (#2566)
1 parent b9a60d6 commit bedbcc4

File tree: 5 files changed, +78 −17 lines

eval_llm/WIKITEXT2/readme.md (new file, +54 lines)
These are perplexities computed on wikitext2.

The numbers are not comparable to lm-evaluation-harness, which reports word / byte / bit perplexity like this (a sketch of how these metrics relate follows the tables):
hf-auto (pretrained=mistralai/Mistral-7B-Instruct-v0.2), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 8

| Tasks    | Version | Filter | n-shot | Metric          |  Value |   | Stderr |
|----------|--------:|--------|--------|-----------------|-------:|---|--------|
| wikitext |       2 | none   | None   | word_perplexity | 9.8183 | ± | N/A    |
|          |         | none   | None   | byte_perplexity | 1.5329 | ± | N/A    |
|          |         | none   | None   | bits_per_byte   | 0.6163 | ± | N/A    |

hf-auto (pretrained=meta-llama/Llama-2-7b-hf), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 1

| Tasks    | Version | Filter | n-shot | Metric          |  Value |   | Stderr |
|----------|--------:|--------|--------|-----------------|-------:|---|--------|
| wikitext |       2 | none   | None   | word_perplexity | 8.7921 | ± | N/A    |
|          |         | none   | None   | byte_perplexity | 1.5016 | ± | N/A    |
|          |         | none   | None   | bits_per_byte   | 0.5865 | ± | N/A    |
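For reference, all three metrics above derive from the same summed token-level negative log-likelihood, just normalized differently. A minimal sketch of the conversion (illustrative names, not the harness's actual API):

```python
import math

def harness_style_metrics(total_nll_nats: float, text: str) -> dict:
    """Derive harness-style corpus metrics from a summed token NLL (in nats)."""
    num_words = len(text.split())          # whitespace-split word count
    num_bytes = len(text.encode("utf-8"))  # UTF-8 byte count
    return {
        "word_perplexity": math.exp(total_nll_nats / num_words),
        "byte_perplexity": math.exp(total_nll_nats / num_bytes),
        # bits_per_byte is just log2 of the byte perplexity
        "bits_per_byte": total_nll_nats / (num_bytes * math.log(2)),
    }
```

As a sanity check, log2(1.5329) ≈ 0.6163, which matches the bits_per_byte row for Mistral above.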
The numbers are not comparable to the perplexity reported by llama.cpp either: we use a smaller context window, but we also detokenize the raw corpus (which they should do, but don't).
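A minimal sketch of that evaluation loop, assuming an HF-style `model`/`tokenizer`; the `detokenize` helper is a crude stand-in for real detokenization of wikitext's tokenized format:

```python
import math
import torch

def detokenize(text: str) -> str:
    # crude stand-in: undo wikitext's " @-@ ", " @,@ ", " @.@ " artifacts
    return text.replace(" @-@ ", "-").replace(" @,@ ", ",").replace(" @.@ ", ".")

def fixed_window_ppl(model, tokenizer, raw_text: str, window: int = 512) -> float:
    """Token-level perplexity over non-overlapping fixed-size context windows."""
    ids = tokenizer.encode(detokenize(raw_text))
    nll, count = 0.0, 0
    for i in range(0, len(ids) - 1, window):
        chunk = torch.tensor([ids[i : i + window + 1]])
        with torch.no_grad():
            logits = model(chunk[:, :-1]).logits      # (1, <=window, vocab)
        logp = torch.log_softmax(logits, dim=-1)
        tgt = chunk[:, 1:]                            # next-token targets
        nll -= logp.gather(-1, tgt.unsqueeze(-1)).sum().item()
        count += tgt.numel()
    return math.exp(nll / count)
```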
| 7B Family        |                       | PPL   | Time (sec) |
| ---------------- | --------------------- | ----- | ---------- |
| Base             | llama2                | 5.78  | 152        |
|                  | mistral v0.1          | 5.70  | 162        |
|                  | awq                   | 5.81  | 165        |
|                  | Yi-6B-200K            | 7.76  | 133        |
|                  | xgen-7B               | 8.64  | 129        |
|                  | mpt-7B                | 8.43  | 147        |
|                  |                       |       |            |
| Instruct / Tuned | llama2-chat           | 7.37  | 148        |
|                  | mistral-instr-v0.2    | 6.98  | 160        |
|                  | gemm-awq              | 7.07  | 164        |
|                  | gemv-awq              | 7.07  | 237        |
|                  |                       |       |            |
|                  | Alma-7B-R             | 6.82  | 156        |
|                  | TowerInstruct-7B      | 6.45  | 157        |
|                  | codellama-7B          | 8.56  | 154        |
|                  |                       |       |            |
| 3B Family        | Phi-2                 | 9.74  | 52         |
|                  | Phi-2-psy             | 10.44 | 53         |
|                  |                       |       |            |
| 13B Family       | llama2 (4-bit)        | 5.31  | 296        |
|                  | llama2-chat (4-bit)   | 6.59  | 292        |
|                  |                       |       |            |
| 34B Family       | codellama-34B (4-bit) | 6.00  | 706        |
We note that llama2 and Mistral are in fact very close for their base models. However, there is a clear gap between their chat models.

All the others lag noticeably behind, which is surprising for Yi given its results on the Open LLM leaderboard.

I still need to check why Mistral seems a little slower than llama2; it should be the opposite.

onmt/models/model.py (+6 −2)
@@ -157,7 +157,9 @@ def load_state_dict(
                     )
                     param.data = checkpoint["generator"][keyname]
                     del checkpoint["generator"][keyname]
-                elif strict and "lora" not in param_name:
+                elif strict and (
+                    "lora" not in param_name and "slopes" not in param_name
+                ):
                     raise ValueError(
                         "Missing key in checkpoint: %s" % name + "." + param_name
                     )
@@ -234,7 +236,9 @@ def load_safe_state_dict(
                         name, module, param_name, param, buf_list, ckpt_t, offset
                     )
                     keyfound[name + "." + param_name] = True
-                elif strict and "lora" not in param_name:
+                elif strict and (
+                    "lora" not in param_name and "slopes" not in param_name
+                ):
                     raise ValueError(
                         "Missing key in safetensors checkpoint: %s" % name
                         + "."

onmt/modules/multi_headed_attn.py (+1 −1)
@@ -599,7 +599,7 @@ def forward(
                     base=self.rotary_theta,
                     device=query.device,
                 )
-            rope = self.rope[start_pos : start_pos + seqlen]
+            rope = self.rope[start_pos : start_pos + seqlen].to(query.device)
             query, key = apply_rotary_emb(
                 query, key, rope, interleave=self.rotary_interleave
             )
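The one-line change guards against a device mismatch: the cached rotary table can live on a different device than the current query (for example, built earlier on CPU or another GPU), and slicing alone does not move it. A rough sketch of the pattern, with illustrative shapes and names:

```python
import torch

def rotary_table(dim: int, max_len: int, theta: float = 10000.0) -> torch.Tensor:
    """Precompute the complex cos+i*sin rotary table once, then slice per step."""
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    angles = torch.outer(torch.arange(max_len).float(), inv_freq)
    return torch.polar(torch.ones_like(angles), angles)  # e^{i*angle}

rope_cache = rotary_table(dim=128, max_len=4096)  # may sit on CPU / another GPU
start_pos, seqlen = 32, 16
query_device = "cpu"                              # stand-in for query.device
rope = rope_cache[start_pos : start_pos + seqlen].to(query_device)  # the hotfix
```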

onmt/train_single.py (+14 −14)
@@ -27,20 +27,6 @@
 
 def prepare_transforms_vocabs(opt, transforms_cls):
     """Prepare or dump transforms before training."""
-    # if transform + options set in 'valid' we need to copy in main
-    # transform / options for scoring considered as inference
-    validset_transforms = opt.data.get("valid", {}).get("transforms", None)
-    if validset_transforms:
-        opt.transforms = validset_transforms
-    if opt.data.get("valid", {}).get("tgt_prefix", None):
-        opt.tgt_prefix = opt.data.get("valid", {}).get("tgt_prefix", None)
-        opt.tgt_file_prefix = True
-    if opt.data.get("valid", {}).get("src_prefix", None):
-        opt.src_prefix = opt.data.get("valid", {}).get("src_prefix", None)
-    if opt.data.get("valid", {}).get("tgt_suffix", None):
-        opt.tgt_suffix = opt.data.get("valid", {}).get("tgt_suffix", None)
-    if opt.data.get("valid", {}).get("src_suffix", None):
-        opt.src_suffix = opt.data.get("valid", {}).get("src_suffix", None)
     specials = get_specials(opt, transforms_cls)
 
     vocabs = build_vocab(opt, specials)
@@ -77,6 +63,20 @@ def _init_train(opt):
     """
     ArgumentParser.validate_prepare_opts(opt)
     transforms_cls = get_transforms_cls(opt._all_transform)
+    # if transform + options set in 'valid' we need to copy in main
+    # transform / options for scoring considered as inference
+    validset_transforms = opt.data.get("valid", {}).get("transforms", None)
+    if validset_transforms:
+        opt.transforms = validset_transforms
+    if opt.data.get("valid", {}).get("tgt_prefix", None):
+        opt.tgt_prefix = opt.data.get("valid", {}).get("tgt_prefix", None)
+        opt.tgt_file_prefix = True
+    if opt.data.get("valid", {}).get("src_prefix", None):
+        opt.src_prefix = opt.data.get("valid", {}).get("src_prefix", None)
+    if opt.data.get("valid", {}).get("tgt_suffix", None):
+        opt.tgt_suffix = opt.data.get("valid", {}).get("tgt_suffix", None)
+    if opt.data.get("valid", {}).get("src_suffix", None):
+        opt.src_suffix = opt.data.get("valid", {}).get("src_suffix", None)
     if opt.train_from:
         # Load checkpoint if we resume from a previous training.
         checkpoint = load_checkpoint(ckpt_path=opt.train_from)
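The block itself is unchanged; moving it from `prepare_transforms_vocabs` into `_init_train` means the valid-set options are promoted on every training entry path, including `train_from` resumes. Roughly, given options like the following (illustrative stand-in, mirroring the YAML config structure), the `valid` entries become the global settings used for validation scoring:

```python
from types import SimpleNamespace

# Illustrative stand-in for the parsed options object.
opt = SimpleNamespace(
    transforms=[],
    src_prefix="",
    data={
        "corpus_1": {"path_src": "train.src", "path_tgt": "train.tgt"},
        "valid": {
            "path_src": "valid.src",
            "path_tgt": "valid.tgt",
            "transforms": ["sentencepiece"],  # applied when scoring valid
            "src_prefix": "<to_en>",          # hypothetical prefix token
        },
    },
)

# Same promotion as the moved block: valid-set settings become global
# so that validation scoring behaves like inference.
validset_transforms = opt.data.get("valid", {}).get("transforms", None)
if validset_transforms:
    opt.transforms = validset_transforms
if opt.data.get("valid", {}).get("src_prefix", None):
    opt.src_prefix = opt.data["valid"]["src_prefix"]
```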

tools/convert_llama.py (+3)
@@ -428,6 +428,9 @@ def __init__(self, model_path: str):
             global_attention_function="softmax",
             self_attn_type="scaled-dot",
             max_relative_positions=-1,
+            rotary_interleave=True,
+            rotary_theta=10000,
+            rotary_dim=0,
             heads=heads,
             sliding_window=sliding_window,
             transformer_ff=transformer_ff,
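The three added arguments make the converter explicit about the rotary setup Llama checkpoints expect: interleaved channel pairing, base theta = 10000, and rotary_dim = 0 (which, as far as I can tell from the option's usage, falls back to the full per-head dimension). A quick sketch of what the base controls:

```python
import torch

# Hedged sketch: theta is the RoPE frequency base. Each channel pair p
# rotates at rate theta^(-2p/dim); a larger base slows the rotation and
# stretches the usable context. dim stands for the per-head size.
def rope_inv_freq(dim: int = 128, theta: float = 10000.0) -> torch.Tensor:
    return 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))

print(rope_inv_freq()[:3])  # fastest-rotating pairs: 1.0, ~0.866, ~0.75

# rotary_interleave=True pairs channels (0,1), (2,3), ... (Meta/Llama layout);
# False would pair (0, dim/2), (1, dim/2 + 1), ... (GPT-NeoX-style layout).
```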
