Commit 0d78f03

Add YaRN
1 parent 7c7b199 commit 0d78f03

File tree

1 file changed (+53, -0 lines)


exllamav2/device.py

Lines changed: 53 additions & 0 deletions
@@ -182,6 +182,59 @@ def prepare_sincos(self):
                 cfg.l3_rope_original_max_position_embeddings,
             )
 
+        # YaRN
+        # Adapted from transformers: https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/modeling_rope_utils.py#L163
+
+        elif cfg.alt_rope_method == "yarn":
+
+            partial_rotary_factor = 1.0  # Placeholder, assume no partial_rotary_factor in config.
+            dim = int(head_dim * partial_rotary_factor)
+            yarn_max_position_embeddings = cfg.yarn_rope_original_max_position_embeddings
+            factor = cfg.yarn_rope_factor
+
+            # Sets the attention factor as suggested in the paper
+            # See: https://github.com/huggingface/transformers/blob/main/examples/modular-transformers/modeling_super.py#L190-L191
+            scaling_factor = 0.1 * math.log(factor) + 1.0
+
+            # Optional config options
+            # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
+            beta_fast = 32
+            beta_slow = 1
+
+            # Compute the inverse frequencies
+            def find_correction_dim(num_rotations, dim, base, yarn_max_position_embeddings):
+                """Inverse dimension formula to find the dimension based on the number of rotations"""
+                return (dim * math.log(yarn_max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
+
+            def find_correction_range(low_rot, high_rot, dim, base, yarn_max_position_embeddings):
+                """Find dimension range bounds based on rotations"""
+                low = math.floor(find_correction_dim(low_rot, dim, base, yarn_max_position_embeddings))
+                high = math.ceil(find_correction_dim(high_rot, dim, base, yarn_max_position_embeddings))
+                return max(low, 0), min(high, dim - 1)
+
+            def linear_ramp_factor(min, max, dim):
+                if min == max:
+                    max += 0.001  # Prevent singularity
+
+                linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
+                ramp_func = torch.clamp(linear_func, 0, 1)
+                return ramp_func
+
+            # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
+            # to expand the possible context length. In other words, interpolation = apply scaling factor.
+            pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim)
+            inv_freq_extrapolation = 1.0 / pos_freqs
+            inv_freq_interpolation = 1.0 / (factor * pos_freqs)
+
+            low, high = find_correction_range(beta_fast, beta_slow, dim, base, yarn_max_position_embeddings)
+
+            # Get n-dimensional rotational scaling corrected for extrapolation
+            inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device)
+            inv_freq = (
+                inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
+                + inv_freq_extrapolation * inv_freq_extrapolation_factor
+            )
+
         # Regular
 
         else:
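Note on the added branch: YaRN blends two sets of RoPE inverse frequencies, an unscaled "extrapolation" set for the fast-rotating dimensions and a position-interpolated set (divided by the scaling factor) for the slow-rotating ones, with a linear ramp between the two correction bounds. Below is a minimal standalone sketch of that computation for reference. The concrete numbers (head_dim = 128, base = 10000, factor = 4, original context length 4096) are illustrative assumptions only, not values read from any ExLlamaV2 config; the helper logic mirrors the diff above.

    # Standalone YaRN inverse-frequency sketch (illustrative values, not from a real config)
    import math
    import torch

    head_dim = 128                # assumed head dimension
    base = 10000.0                # assumed RoPE theta
    factor = 4.0                  # assumed YaRN scaling factor
    original_max_pos = 4096       # assumed original max_position_embeddings
    beta_fast, beta_slow = 32, 1  # defaults suggested in the YaRN paper
    dim = head_dim                # partial_rotary_factor assumed to be 1.0

    def find_correction_dim(num_rotations):
        # Dimension index whose frequency completes `num_rotations` rotations
        # over the original context length.
        return (dim * math.log(original_max_pos / (num_rotations * 2 * math.pi))) / (2 * math.log(base))

    low = max(math.floor(find_correction_dim(beta_fast)), 0)
    high = min(math.ceil(find_correction_dim(beta_slow)), dim - 1)

    # Linear ramp from 0 to 1 between the correction bounds, one value per frequency pair
    ramp = torch.clamp((torch.arange(dim // 2, dtype=torch.float32) - low) / (high - low), 0, 1)
    extrapolation_factor = 1 - ramp

    pos_freqs = base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)
    inv_freq_extrapolation = 1.0 / pos_freqs             # unscaled (high-frequency) path
    inv_freq_interpolation = 1.0 / (factor * pos_freqs)  # position-interpolated (low-frequency) path

    # Blend: fast-rotating dims keep their original frequencies, slow ones are interpolated
    inv_freq = (inv_freq_interpolation * (1 - extrapolation_factor)
                + inv_freq_extrapolation * extrapolation_factor)

    # Attention scaling suggested in the YaRN paper
    attn_scale = 0.1 * math.log(factor) + 1.0
    print(inv_freq.shape, attn_scale)  # torch.Size([64]) ~1.14

With these assumed numbers the correction range works out to roughly dimensions 20 to 46 of the 64 frequency pairs, so the highest-frequency pairs keep their original rotation speed while the lowest-frequency pairs are fully interpolated, and the attention scaling factor comes out to about 1.14.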
