Skip to content

Commit 8dca1ab

Browse files
Only trigger if long context config is set
1 parent b195503 commit 8dca1ab

File tree

1 file changed

+53
-46
lines changed

1 file changed

+53
-46
lines changed

exllamav2/device.py

Lines changed: 53 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -187,53 +187,60 @@ def prepare_sincos(self):
187187

188188
elif cfg.alt_rope_method == "yarn":
189189

190-
partial_rotary_factor = 1.0 # Placeholder, assume no partial_rotary_factor in config.
191-
dim = int(head_dim * partial_rotary_factor)
192190
yarn_max_position_embeddings = cfg.yarn_rope_original_max_position_embeddings
193-
factor = cfg.yarn_rope_factor
194-
195-
# Sets the attention factor as suggested in the paper
196-
# See: https://github.com/huggingface/transformers/blob/main/examples/modular-transformers/modeling_super.py#L190-L191
197-
scaling_factor = 0.1 * math.log(factor) + 1.0
198-
199-
# Optional config options
200-
# beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
201-
beta_fast = 32
202-
beta_slow = 1
203-
204-
# Compute the inverse frequencies
205-
def find_correction_dim(num_rotations, dim, base, yarn_max_position_embeddings):
206-
"""Inverse dimension formula to find the dimension based on the number of rotations"""
207-
return (dim * math.log(yarn_max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
208-
209-
def find_correction_range(low_rot, high_rot, dim, base, yarn_max_position_embeddings):
210-
"""Find dimension range bounds based on rotations"""
211-
low = math.floor(find_correction_dim(low_rot, dim, base, yarn_max_position_embeddings))
212-
high = math.ceil(find_correction_dim(high_rot, dim, base, yarn_max_position_embeddings))
213-
return max(low, 0), min(high, dim - 1)
214-
215-
def linear_ramp_factor(min, max, dim):
216-
if min == max:
217-
max += 0.001 # Prevent singularity
218-
219-
linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
220-
ramp_func = torch.clamp(linear_func, 0, 1)
221-
return ramp_func
222-
223-
# Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
224-
# to expand the possible context length. In other words, interpolation = apply scaling factor.
225-
pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim)
226-
inv_freq_extrapolation = 1.0 / pos_freqs
227-
inv_freq_interpolation = 1.0 / (factor * pos_freqs)
228-
229-
low, high = find_correction_range(beta_fast, beta_slow, dim, base, yarn_max_position_embeddings)
230-
231-
# Get n-dimensional rotational scaling corrected for extrapolation
232-
inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device)
233-
inv_freq = (
234-
inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
235-
+ inv_freq_extrapolation * inv_freq_extrapolation_factor
236-
)
191+
192+
# Only activate if longer than original ctx
193+
if cfg.max_seq_len > cfg.yarn_rope_original_max_position_embeddings:
194+
195+
partial_rotary_factor = 1.0 # Placeholder, assume no partial_rotary_factor in config.
196+
dim = int(head_dim * partial_rotary_factor)
197+
198+
factor = cfg.yarn_rope_factor
199+
200+
# Sets the attention factor as suggested in the paper
201+
# See: https://github.com/huggingface/transformers/blob/main/examples/modular-transformers/modeling_super.py#L190-L191
202+
scaling_factor = 0.1 * math.log(factor) + 1.0
203+
204+
# Optional config options
205+
# beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
206+
beta_fast = 32
207+
beta_slow = 1
208+
209+
# Compute the inverse frequencies
210+
def find_correction_dim(num_rotations, dim, base, yarn_max_position_embeddings):
211+
"""Inverse dimension formula to find the dimension based on the number of rotations"""
212+
return (dim * math.log(yarn_max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
213+
214+
def find_correction_range(low_rot, high_rot, dim, base, yarn_max_position_embeddings):
215+
"""Find dimension range bounds based on rotations"""
216+
low = math.floor(find_correction_dim(low_rot, dim, base, yarn_max_position_embeddings))
217+
high = math.ceil(find_correction_dim(high_rot, dim, base, yarn_max_position_embeddings))
218+
return max(low, 0), min(high, dim - 1)
219+
220+
def linear_ramp_factor(min, max, dim):
221+
if min == max:
222+
max += 0.001 # Prevent singularity
223+
224+
linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
225+
ramp_func = torch.clamp(linear_func, 0, 1)
226+
return ramp_func
227+
228+
# Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
229+
# to expand the possible context length. In other words, interpolation = apply scaling factor.
230+
pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim)
231+
inv_freq_extrapolation = 1.0 / pos_freqs
232+
inv_freq_interpolation = 1.0 / (factor * pos_freqs)
233+
234+
low, high = find_correction_range(beta_fast, beta_slow, dim, base, yarn_max_position_embeddings)
235+
236+
# Get n-dimensional rotational scaling corrected for extrapolation
237+
inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device)
238+
inv_freq = (
239+
inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
240+
+ inv_freq_extrapolation * inv_freq_extrapolation_factor
241+
)
242+
else:
243+
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device = device).float() / head_dim))
237244

238245
# Regular
239246

0 commit comments

Comments
 (0)