
Commit a79e7e8

Merge branch 'main' into bf/flex-decoding-integrate

2 parents d983e3d + 8354eba

File tree: 3 files changed (+38, −3 lines) — README.md, model.py, scripts/convert_hf_checkpoint.py

README.md

Lines changed: 2 additions & 0 deletions

@@ -73,6 +73,7 @@ mistralai/Mistral-7B-v0.1
 mistralai/Mistral-7B-Instruct-v0.1
 mistralai/Mistral-7B-Instruct-v0.2
 meta-llama/Meta-Llama-3-8B
+meta-llama/Meta-Llama-3.1-405B
 ```
 
 For example, to convert Llama-2-7b-chat-hf

@@ -120,6 +121,7 @@ Benchmarks run on an 8xA100-80GB, power limited to 330W with a hybrid cube mesh
 | Llama-2-70B | Base | 62.50 | 1135.29 |
 | | 8-bit | 80.44 | 752.04 |
 | | 4-bit (G=32) | 90.77 | 548.10 |
+| Llama-3.1-405B | 8-bit | 15.60 | 815.87 |
 
 ### AMD
 Benchmarks run on one GCD of a MI-250x.

model.py

Lines changed: 32 additions & 2 deletions

@@ -3,6 +3,7 @@
 
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+import math
 from dataclasses import dataclass
 from typing import Optional
 

@@ -42,6 +43,7 @@ class ModelArgs:
     head_dim: int = 64
     rope_base: float = 10000
     norm_eps: float = 1e-5
+    rope_scaling: Optional[dict] = None
 
     def __post_init__(self):
         if self.n_local_heads == -1:

@@ -81,6 +83,9 @@ def from_name(cls, name: str):
 
     "llama-3-8b": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000),
     "llama-3-70b": dict(block_size=8192, n_layer=80, n_head=64, n_local_heads=8, dim=8192, intermediate_size=28672, vocab_size=128256, rope_base=500000),
+    "llama-3.1-405b": dict(block_size=131072, n_layer=126, n_head=128, n_local_heads=8, dim=16384, intermediate_size=53248, vocab_size=128256, rope_base=500000,
+        rope_scaling=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_position_embeddings=8192),
+    ),
 }
 
 class KVCache(nn.Module):

@@ -133,7 +138,7 @@ def setup_caches(self, max_batch_size, max_seq_length):
         for b in self.layers:
             b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype)
 
-        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype)
+        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype, self.config.rope_scaling)
 
     def forward(self, mask: BlockMask, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
         assert self.freqs_cis is not None, "Caches must be initialized first"

@@ -243,11 +248,36 @@ def forward(self, x: Tensor) -> Tensor:
         return output * self.weight
 
 
+def apply_rope_scaling(freqs: torch.Tensor, rope_scaling: Optional[dict] = None):
+    factor = rope_scaling["factor"]
+    low_freq_factor = rope_scaling["low_freq_factor"]
+    high_freq_factor = rope_scaling["high_freq_factor"]
+    old_context_len = rope_scaling["original_max_position_embeddings"]
+
+    low_freq_wavelen = old_context_len / low_freq_factor
+    high_freq_wavelen = old_context_len / high_freq_factor
+    new_freqs = []
+    for freq in freqs:
+        wavelen = 2 * math.pi / freq
+        if wavelen < high_freq_wavelen:
+            new_freqs.append(freq)
+        elif wavelen > low_freq_wavelen:
+            new_freqs.append(freq / factor)
+        else:
+            assert low_freq_wavelen != high_freq_wavelen
+            smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+            new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
+    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
+
+
 def precompute_freqs_cis(
     seq_len: int, n_elem: int, base: int = 10000,
-    dtype: torch.dtype = torch.bfloat16
+    dtype: torch.dtype = torch.bfloat16,
+    rope_scaling: Optional[dict] = None,
 ) -> Tensor:
     freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
+    if rope_scaling is not None:
+        freqs = apply_rope_scaling(freqs, rope_scaling)
     t = torch.arange(seq_len, device=freqs.device)
     freqs = torch.outer(t, freqs)
     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)

scripts/convert_hf_checkpoint.py

Lines changed: 4 additions & 1 deletion

@@ -116,7 +116,10 @@ def permute(w, n_head):
     print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
     torch.save(final_result, checkpoint_dir / "model.pth")
     if 'llama-3' in model_name.lower():
-        original_dir = checkpoint_dir / "original"
+        if 'llama-3.1' in model_name.lower():
+            original_dir = checkpoint_dir / "original" / "mp16"
+        else:
+            original_dir = checkpoint_dir / "original"
         tokenizer_model = original_dir / "tokenizer.model"
         tokenizer_model_tiktoken = checkpoint_dir / "tokenizer.model"
         print(f"Copying {tokenizer_model} to {tokenizer_model_tiktoken}")
