
Commit c84f597

Merge branch 'refs/heads/dev-yarn' into dev
2 parents: f1adff9 + 6b73184

File tree (2 files changed: +66 −0)

  exllamav2/config.py (+6)
  exllamav2/device.py (+60)


exllamav2/config.py: +6 −0

@@ -111,6 +111,8 @@ class ExLlamaV2Config:
     l3_rope_low_freq_factor: float | None
     l3_rope_high_freq_factor: float | None
     l3_rope_original_max_position_embeddings: int | None
+    yarn_rope_factor: float | None
+    yarn_rope_original_max_position_embeddings: int | None
     checkpoint_fused_mlp: bool
     checkpoint_offset_qzeros: bool

@@ -306,6 +308,10 @@ def prepare(self, no_tensors: bool = False):
             self.l3_rope_low_freq_factor = rs["low_freq_factor"]
             self.l3_rope_high_freq_factor = rs["high_freq_factor"]
             self.l3_rope_original_max_position_embeddings = rs["original_max_position_embeddings"]
+        if scaling_type == "yarn":
+            self.alt_rope_method = "yarn"
+            self.yarn_rope_factor = rs["factor"]
+            self.yarn_rope_original_max_position_embeddings = rs["original_max_position_embeddings"]

        # Checkpoint format (for GPTQ models)

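For context, the new branch in prepare() is driven by the rope_scaling block of a model's config.json. A minimal sketch of the dict rs that would take the new path, with illustrative values (the factor, the context length, and the exact key that populates scaling_type are assumptions, not taken from the commit):

# Hypothetical rope_scaling entry as it would arrive in `rs`; the values
# below are illustrative, not from the commit.
rs = {
    "type": "yarn",                             # assumed key behind scaling_type
    "factor": 4.0,                              # becomes cfg.yarn_rope_factor
    "original_max_position_embeddings": 32768,  # becomes cfg.yarn_rope_original_max_position_embeddings
}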
exllamav2/device.py: +60 −0

@@ -182,6 +182,66 @@ def prepare_sincos(self):
                 cfg.l3_rope_original_max_position_embeddings,
             )

+        # YaRN
+        # Adapted from transformers: https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/modeling_rope_utils.py#L163
+
+        elif cfg.alt_rope_method == "yarn":
+
+            yarn_max_position_embeddings = cfg.max_seq_len
+
+            # Only activate if longer than original ctx
+            if cfg.max_seq_len > cfg.yarn_rope_original_max_position_embeddings:
+
+                partial_rotary_factor = 1.0  # Placeholder, assume no partial_rotary_factor in config.
+                dim = int(head_dim * partial_rotary_factor)
+
+                factor = cfg.yarn_rope_factor
+
+                # Sets the attention factor as suggested in the paper
+                # See: https://github.com/huggingface/transformers/blob/main/examples/modular-transformers/modeling_super.py#L190-L191
+                scaling_factor = 0.1 * math.log(factor) + 1.0
+
+                # Optional config options
+                # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
+                beta_fast = 32
+                beta_slow = 1
+
+                # Compute the inverse frequencies
+                def find_correction_dim(num_rotations, dim, base, yarn_max_position_embeddings):
+                    """Inverse dimension formula to find the dimension based on the number of rotations"""
+                    return (dim * math.log(yarn_max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
+
+                def find_correction_range(low_rot, high_rot, dim, base, yarn_max_position_embeddings):
+                    """Find dimension range bounds based on rotations"""
+                    low = math.floor(find_correction_dim(low_rot, dim, base, yarn_max_position_embeddings))
+                    high = math.ceil(find_correction_dim(high_rot, dim, base, yarn_max_position_embeddings))
+                    return max(low, 0), min(high, dim - 1)
+
+                def linear_ramp_factor(min, max, dim):
+                    if min == max:
+                        max += 0.001  # Prevent singularity
+
+                    linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
+                    ramp_func = torch.clamp(linear_func, 0, 1)
+                    return ramp_func
+
+                # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
+                # to expand the possible context length. In other words, interpolation = apply scaling factor.
+                pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim)
+                inv_freq_extrapolation = 1.0 / pos_freqs
+                inv_freq_interpolation = 1.0 / (factor * pos_freqs)
+
+                low, high = find_correction_range(beta_fast, beta_slow, dim, base, yarn_max_position_embeddings)
+
+                # Get n-dimensional rotational scaling corrected for extrapolation
+                inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device)
+                inv_freq = (
+                    inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
+                    + inv_freq_extrapolation * inv_freq_extrapolation_factor
+                )
+            else:
+                inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device = device).float() / head_dim))
+
         # Regular

         else:

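The blended frequencies are easy to sanity-check outside the engine. Below is a minimal standalone sketch of the same YaRN math with illustrative Llama-style numbers (head_dim = 128, base = 10000, a 4x factor over a 4096-token original context); none of these values come from the commit, and the helpers simply mirror the diff:

import math
import torch

head_dim = 128       # illustrative rotary dimension
base = 10000.0       # illustrative RoPE base
factor = 4.0         # YaRN scaling factor (assumed)
max_seq_len = 16384  # 4x the assumed 4096-token original context
beta_fast, beta_slow = 32, 1

def find_correction_dim(num_rotations, dim, base, max_pos):
    # Dimension index at which the given number of rotations occurs over max_pos
    return (dim * math.log(max_pos / (num_rotations * 2 * math.pi))) / (2 * math.log(base))

def find_correction_range(low_rot, high_rot, dim, base, max_pos):
    low = math.floor(find_correction_dim(low_rot, dim, base, max_pos))
    high = math.ceil(find_correction_dim(high_rot, dim, base, max_pos))
    return max(low, 0), min(high, dim - 1)

def linear_ramp_factor(lo, hi, dim):
    if lo == hi:
        hi += 0.001  # prevent division by zero
    return torch.clamp((torch.arange(dim, dtype=torch.float32) - lo) / (hi - lo), 0, 1)

pos_freqs = base ** (torch.arange(0, head_dim, 2).float() / head_dim)
inv_freq_extrapolation = 1.0 / pos_freqs             # original frequencies
inv_freq_interpolation = 1.0 / (factor * pos_freqs)  # frequencies divided by the factor

low, high = find_correction_range(beta_fast, beta_slow, head_dim, base, max_seq_len)
extrapolation_factor = 1 - linear_ramp_factor(low, high, head_dim // 2)
inv_freq = (inv_freq_interpolation * (1 - extrapolation_factor)
            + inv_freq_extrapolation * extrapolation_factor)

attention_factor = 0.1 * math.log(factor) + 1.0  # ~1.139 for factor = 4

print(low, high)      # correction range, here roughly dims 30..55 of 64
print(inv_freq[:2])   # fast dims keep their original frequencies
print(inv_freq[-2:])  # slow dims are fully interpolated (divided by 4)

Per frequency pair, everything below the correction range keeps its original frequency (pure extrapolation), everything above it is divided by the factor (pure interpolation), and the range in between ramps linearly between the two, which is exactly the inv_freq blend the diff computes. The scaling_factor term is the paper's attention-temperature adjustment; how it is applied to the sin/cos tables falls outside the hunk shown here.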