 
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+import math
 from dataclasses import dataclass
 from typing import Optional
 
@@ -42,6 +43,7 @@ class ModelArgs:
     head_dim: int = 64
     rope_base: float = 10000
     norm_eps: float = 1e-5
+    rope_scaling: Optional[dict] = None
 
     def __post_init__(self):
         if self.n_local_heads == -1:
@@ -81,6 +83,9 @@ def from_name(cls, name: str):
 
     "llama-3-8b": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000),
     "llama-3-70b": dict(block_size=8192, n_layer=80, n_head=64, n_local_heads=8, dim=8192, intermediate_size=28672, vocab_size=128256, rope_base=500000),
+    "llama-3.1-405b": dict(block_size=131072, n_layer=126, n_head=128, n_local_heads=8, dim=16384, intermediate_size=53248, vocab_size=128256, rope_base=500000,
+        rope_scaling=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_position_embeddings=8192),
+    ),
 }
 
 class KVCache(nn.Module):
@@ -133,7 +138,7 @@ def setup_caches(self, max_batch_size, max_seq_length):
         for b in self.layers:
             b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype)
 
-        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype)
+        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype, self.config.rope_scaling)
 
     def forward(self, mask: BlockMask, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
         assert self.freqs_cis is not None, "Caches must be initialized first"
@@ -243,11 +248,36 @@ def forward(self, x: Tensor) -> Tensor:
         return output * self.weight
 
 
+def apply_rope_scaling(freqs: torch.Tensor, rope_scaling: Optional[dict] = None):
+    factor = rope_scaling["factor"]
+    low_freq_factor = rope_scaling["low_freq_factor"]
+    high_freq_factor = rope_scaling["high_freq_factor"]
+    old_context_len = rope_scaling["original_max_position_embeddings"]
+
+    low_freq_wavelen = old_context_len / low_freq_factor
+    high_freq_wavelen = old_context_len / high_freq_factor
+    new_freqs = []
+    for freq in freqs:
+        wavelen = 2 * math.pi / freq
+        if wavelen < high_freq_wavelen:
+            new_freqs.append(freq)
+        elif wavelen > low_freq_wavelen:
+            new_freqs.append(freq / factor)
+        else:
+            assert low_freq_wavelen != high_freq_wavelen
+            smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+            new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
+    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
+
+
 def precompute_freqs_cis(
     seq_len: int, n_elem: int, base: int = 10000,
-    dtype: torch.dtype = torch.bfloat16
+    dtype: torch.dtype = torch.bfloat16,
+    rope_scaling: Optional[dict] = None,
 ) -> Tensor:
     freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
+    if rope_scaling is not None:
+        freqs = apply_rope_scaling(freqs, rope_scaling)
     t = torch.arange(seq_len, device=freqs.device)
     freqs = torch.outer(t, freqs)
     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
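
For reference, a minimal sketch (not part of the diff) of what the new scaling path does, using the llama-3.1 rope_scaling parameters from the config table above. The `model` import path is an assumption; the two functions and their signatures come from the patched file.

# Minimal sketch, not part of the commit: exercise apply_rope_scaling with the
# llama-3.1 parameters added above. Import path is an assumption.
import torch
from model import apply_rope_scaling, precompute_freqs_cis

rope_base = 500000
head_dim = 16384 // 128  # dim // n_head for llama-3.1-405b
rope_scaling = dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0,
                    original_max_position_embeddings=8192)

# Base RoPE inverse frequencies, computed the same way as in precompute_freqs_cis.
freqs = 1.0 / (rope_base ** (torch.arange(0, head_dim, 2).float() / head_dim))
scaled = apply_rope_scaling(freqs, rope_scaling)

# Short-wavelength components are left untouched; long-wavelength ones are
# divided by factor=8.0, with a smooth interpolation in between.
print((scaled[0] / freqs[0]).item())    # ~1.0
print((scaled[-1] / freqs[-1]).item())  # ~0.125

# The same dict threads through the new precompute_freqs_cis argument,
# just as setup_caches now passes self.config.rope_scaling.
freqs_cis = precompute_freqs_cis(131072, head_dim, rope_base,
                                 torch.bfloat16, rope_scaling)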