Llama-3_1-Nemotron 51B support #726


Open
wants to merge 3 commits into base: master
16 changes: 16 additions & 0 deletions exllamav2/architecture.py
@@ -139,6 +139,7 @@ class Params:
"attn_k": ".self_attn.k_proj",
"attn_v": ".self_attn.v_proj",
"attn_o": ".self_attn.o_proj",
"linear_attn": ".self_attn.linear_attn",
"layers": "layers",
"patch_conv": "patch_conv",
})
@@ -692,6 +693,21 @@ class Params:
self.lm.expect_keys += \
expect_keys_llama

# Deci

if arch_string == "DeciLMForCausalLM":
arch_recognized = True
self.lm.layer_keys += \
layer_keys_llama_norms + \
layer_keys_llama_attn + \
layer_keys_llama_mlp
self.lm.expect_keys += \
expect_keys_llama
# self.lm.keys.update({
# "attn_o": ".self_attn.linear_attn",
# })
self.lm.supports_tp = True

# Llama (default + fallback)

if arch_string != "LlamaForCausalLM" and not arch_recognized:
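For context on the per-layer lists consumed by the attention and cache changes below: in the Nemotron 51B checkpoints (arch string DeciLMForCausalLM), each decoder layer carries its own attention configuration, so scalar config values such as num_key_value_heads become per-layer lists, and some layers replace the whole attention block with a single linear projection. The sketch below shows one way such a list could be derived; the field names ("block_configs", "n_heads_in_group", "no_op", "replace_with_linear") follow the published Nemotron 51B config.json and are an assumption relative to this PR, since the config-parsing side is not part of the diff shown here.

# Hedged sketch, not part of this PR: deriving per-layer KV-head counts from a
# DeciLM/Nemotron-style config. Field names are assumptions taken from the published
# Nemotron 51B config.json.
def derive_kv_heads_per_layer(hf_config: dict) -> list[int]:
    num_heads = hf_config["num_attention_heads"]
    kv_heads = []
    for block in hf_config["block_configs"]:
        attn = block["attention"]
        if attn.get("no_op") or attn.get("replace_with_linear"):
            # Attention is skipped or replaced by a single linear layer, so no KV
            # cache is needed; 0 is only a placeholder sentinel here.
            kv_heads.append(0)
        else:
            kv_heads.append(num_heads // attn["n_heads_in_group"])
    return kv_heads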
192 changes: 163 additions & 29 deletions exllamav2/attn.py
@@ -161,8 +161,14 @@ def __init__(
self.hidden_size = cfg.vision_hidden_size
else:
self.num_attention_heads = cfg.num_attention_heads
self.num_key_value_heads = cfg.num_key_value_heads
self.num_key_value_groups = cfg.num_key_value_groups
if type(cfg.num_key_value_heads) is list:
self.num_key_value_heads = cfg.num_key_value_heads[layer_idx]
else:
self.num_key_value_heads = cfg.num_key_value_heads
if type(cfg.num_key_value_groups) is list:
self.num_key_value_groups = cfg.num_key_value_groups[layer_idx]
else:
self.num_key_value_groups = cfg.num_key_value_groups
self.head_dim = cfg.head_dim
self.hidden_size = cfg.hidden_size

@@ -186,10 +192,16 @@ def __init__(
f_d = f_c + self.num_key_value_heads * self.head_dim
f_key = (key + km["fused_qkv"]) if km["fused_qkv"] else None

self.q_proj = ExLlamaV2Linear(model, key + km["attn_q"], hidden_size, self.num_attention_heads * self.head_dim, ap.attention_bias_qkv, f_key = f_key, f_beg = f_a, f_end = f_b, altpack_qkv = ap.fused_qkv_altpack)
self.k_proj = ExLlamaV2Linear(model, key + km["attn_k"], hidden_size, self.num_key_value_heads * self.head_dim, ap.attention_bias_qkv, f_key = f_key, f_beg = f_b, f_end = f_c, altpack_qkv = ap.fused_qkv_altpack)
self.v_proj = ExLlamaV2Linear(model, key + km["attn_v"], hidden_size, self.num_key_value_heads * self.head_dim, ap.attention_bias_qkv, f_key = f_key, f_beg = f_c, f_end = f_d, altpack_qkv = ap.fused_qkv_altpack)
self.o_proj = ExLlamaV2Linear(model, key + km["attn_o"], self.num_attention_heads * self.head_dim, hidden_size, ap.attention_bias_o, prescale = cfg.scale_depth)
if type(cfg.num_key_value_heads) is list and ["self_attn.linear_attn"] in cfg.arch.lm.layer_keys[layer_idx]:
self.q_proj = None
self.k_proj = None
self.v_proj = None
self.o_proj = ExLlamaV2Linear(model, key + km["linear_attn"], self.num_attention_heads * self.head_dim, hidden_size, ap.attention_bias_o, prescale = cfg.scale_depth)
else:
self.q_proj = ExLlamaV2Linear(model, key + km["attn_q"], hidden_size, self.num_attention_heads * self.head_dim, ap.attention_bias_qkv, f_key = f_key, f_beg = f_a, f_end = f_b, altpack_qkv = ap.fused_qkv_altpack)
self.k_proj = ExLlamaV2Linear(model, key + km["attn_k"], hidden_size, self.num_key_value_heads * self.head_dim, ap.attention_bias_qkv, f_key = f_key, f_beg = f_b, f_end = f_c, altpack_qkv = ap.fused_qkv_altpack)
self.v_proj = ExLlamaV2Linear(model, key + km["attn_v"], hidden_size, self.num_key_value_heads * self.head_dim, ap.attention_bias_qkv, f_key = f_key, f_beg = f_c, f_end = f_d, altpack_qkv = ap.fused_qkv_altpack)
self.o_proj = ExLlamaV2Linear(model, key + km["attn_o"], self.num_attention_heads * self.head_dim, hidden_size, ap.attention_bias_o, prescale = cfg.scale_depth)

if cfg.use_qk_norm:
self.q_norm = ExLlamaV2HeadNorm(model, key + ".self_attn.q_norm", self.num_attention_heads, self.head_dim)
@@ -198,12 +210,18 @@ def __init__(
self.q_norm = None
self.k_norm = None

self.submodules = [
self.q_proj,
self.k_proj,
self.v_proj,
self.o_proj
]
if type(cfg.num_key_value_heads) is list and ["self_attn.linear_attn"] in cfg.arch.lm.layer_keys[layer_idx]:
self.submodules = [
self.o_proj
]
else:
self.submodules = [
self.q_proj,
self.k_proj,
self.v_proj,
self.o_proj
]

if self.pre_layernorm:
self.submodules += [self.pre_layernorm]
if self.post_layernorm:
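The constructor changes above follow two related patterns: config values such as num_key_value_heads and num_key_value_groups may now be per-layer lists rather than a single int, and a layer whose attention was replaced by a single linear projection registers only o_proj (loaded from ".self_attn.linear_attn"), leaving q_proj/k_proj/v_proj as None. That is why the size, load, and device-placement helpers further down all guard on None. As a reading aid only (not code from this PR), the scalar-or-list resolution amounts to:

# Reading aid, not part of this PR: a config value is either a single int shared by all
# layers (vanilla Llama) or a list indexed by layer (Nemotron 51B / DeciLM).
def per_layer(value, layer_idx: int):
    return value[layer_idx] if isinstance(value, list) else value

# e.g. num_key_value_heads = per_layer(cfg.num_key_value_heads, layer_idx)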
@@ -223,10 +241,10 @@ def __init__(

def numel(self) -> int:

numel = self.q_proj.numel() + \
self.k_proj.numel() + \
self.v_proj.numel() + \
self.o_proj.numel()
numel = self.o_proj.numel()
if self.q_proj is not None: numel += self.q_proj.numel()
if self.k_proj is not None: numel += self.k_proj.numel()
if self.v_proj is not None: numel += self.v_proj.numel()

if self.pre_layernorm is not None: numel += self.pre_layernorm.numel()
if self.post_layernorm is not None: numel += self.post_layernorm.numel()
@@ -243,10 +261,12 @@ def load(self, device_context: bool = True):

if self.pre_layernorm is not None: self.pre_layernorm.load()
if self.post_layernorm is not None: self.post_layernorm.load()
self.o_proj.load(device_context = device_context)
if self.q_proj is None and self.k_proj is None and self.v_proj is None:
return
self.q_proj.load(device_context = device_context)
self.k_proj.load(device_context = device_context)
self.v_proj.load(device_context = device_context)
self.o_proj.load(device_context = device_context)
if self.q_norm is not None: self.q_norm.load()
if self.k_norm is not None: self.k_norm.load()

@@ -347,10 +367,10 @@ def unload(self):

def weight_footprint(self):

fp = self.q_proj.weight_footprint() + \
self.k_proj.weight_footprint() + \
self.v_proj.weight_footprint() + \
self.o_proj.weight_footprint()
fp = self.o_proj.weight_footprint()
if self.q_proj is not None: fp += self.q_proj.weight_footprint()
if self.k_proj is not None: fp += self.k_proj.weight_footprint()
if self.v_proj is not None: fp += self.v_proj.weight_footprint()
if self.pre_layernorm is not None:
fp += self.pre_layernorm.weight_footprint()
if self.post_layernorm is not None:
@@ -406,10 +426,13 @@ def temp_v_size(self):

def temp_dq_size(self):

return max(self.q_proj.temp_dq_size(),
self.k_proj.temp_dq_size(),
self.v_proj.temp_dq_size(),
self.o_proj.temp_dq_size())
if self.q_proj is None and self.k_proj is None and self.v_proj is None:
return self.o_proj.temp_dq_size()
else:
return max(self.q_proj.temp_dq_size(),
self.k_proj.temp_dq_size(),
self.v_proj.temp_dq_size(),
self.o_proj.temp_dq_size())


def temp_kv_size(self):
@@ -440,9 +463,9 @@ def set_device_idx(self, idx: int | None):

if self.pre_layernorm is not None: self.pre_layernorm.set_device_idx(idx)
if self.post_layernorm is not None: self.post_layernorm.set_device_idx(idx)
self.q_proj.set_device_idx(idx)
self.k_proj.set_device_idx(idx)
self.v_proj.set_device_idx(idx)
if self.q_proj is not None: self.q_proj.set_device_idx(idx)
if self.k_proj is not None: self.k_proj.set_device_idx(idx)
if self.v_proj is not None: self.v_proj.set_device_idx(idx)
self.o_proj.set_device_idx(idx)
if self.q_norm is not None: self.q_norm.set_device_idx(idx)
if self.k_norm is not None: self.k_norm.set_device_idx(idx)
@@ -468,6 +491,14 @@ def forward_paged(
**kwargs
) -> torch.Tensor:

if self.q_proj is None and self.k_proj is None and self.v_proj is None:
return self.forward_paged_linear(
hidden_states,
cache,
attn_params,
loras,
**kwargs,
)
if self.is_tp:
return self.forward_paged_tp(
hidden_states,
@@ -633,6 +664,50 @@ def forward_paged(

return hidden_states

# @profile
def forward_paged_linear(
self,
hidden_states: torch.Tensor,
cache: ExLlamaV2CacheBase | None = None,
attn_params: ExLlamaV2Attention.PagedParams | None = None,
loras: list[ExLlamaV2Lora] | None = None,
**kwargs
) -> torch.Tensor:

is_q = self.q_handle is not None
cfg = self.model.config
constants = self.model.get_device_context(self.device_idx, scratch = is_q)
page_size = attn_params.page_size
batch_size, q_len, _ = hidden_states.shape
cache_seqlens = attn_params.get_cache_seqlens(self.device_idx)
block_table = attn_params.get_block_index(self.device_idx)

sc = attn_params.get_alt_rope_embed(self.device_idx)
if not sc:
sin, cos = constants.sin, constants.cos
else:
sin, cos = sc

cache_seqlens_rope = cache_seqlens
offsets = attn_params.get_rope_offsets(self.device_idx)
if offsets is not None:
cache_seqlens_rope = cache_seqlens_rope + offsets

residual = hidden_states
hidden_states = self.pre_layernorm.forward(hidden_states) if self.has_norm else hidden_states

attn_output = hidden_states.view((batch_size, q_len, self.num_attention_heads * self.head_dim))

cache.store_kv_state(self.layer_idx, batch_size, 0, q_len, page_size, cache_seqlens, block_table)

hidden_states = self.o_proj.forward(attn_output, loras = loras)
if self.post_layernorm:
hidden_states = self.post_layernorm.forward(hidden_states)
if self.has_residual:
hidden_states += residual

return hidden_states


# @profile
def forward_paged_tp(
@@ -1368,6 +1443,17 @@ def forward_torch(
loras: list[ExLlamaV2Lora] | None = None,
**kwargs
) -> torch.Tensor | dict:

if self.q_proj is None and self.k_proj is None and self.v_proj is None:
return self.forward_torch_linear(
hidden_states,
cache,
attn_params,
past_len,
intermediates,
loras,
**kwargs,
)

global has_flash_attn
global has_xformers
@@ -1501,6 +1587,54 @@ def forward_torch(
else:
return hidden_states

def forward_torch_linear(
self,
hidden_states: torch.Tensor,
cache: ExLlamaV2CacheBase | None = None,
attn_params: ExLlamaV2Attention.Params | None = None,
past_len: int | None = None,
intermediates: bool = False,
loras: list[ExLlamaV2Lora] | None = None,
**kwargs
) -> torch.Tensor | dict:

cfg = self.model.config
num_attention_heads = self.num_attention_heads
num_key_value_heads = self.num_key_value_heads
head_dim = self.head_dim

batch_size, q_len, _ = hidden_states.size()

past_len = 0 if cache is None else cache.current_seq_len

# Project q, k, v

residual = hidden_states
post_norm = self.pre_layernorm.forward(hidden_states) if self.has_norm else hidden_states

# Output projection

attn_proj = self.o_proj.forward(post_norm, loras = loras)

# Post layernorm

if self.post_layernorm:
attn_proj = self.post_layernorm.forward(attn_proj, output_fp32 = self.archparams.residual_stream_fp32)

# Add residual connection

hidden_states = (attn_proj + residual) if self.has_residual else attn_proj

if self.archparams.residual_stream_fp32:
hidden_states = hidden_states.float()
elif self.archparams.clamp_hidden_states:
hidden_states.clamp_(-65504, 65504)

if intermediates:
return {"post_norm": post_norm,
"hidden_states": hidden_states}
else:
return hidden_states

def update_loras(self):

@@ -1599,4 +1733,4 @@ def amax(res: list[int]):
add(ctx.get_temp_tensors_s(maxrows, 2, BROADCAST_KV, dim = self.head_dim))
add(ctx.get_temp_tensors_s(maxrows, 2, BROADCAST_Q, dim = self.head_dim))

return scratch
return scratch
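Taken together, forward_paged_linear and forward_torch_linear above reduce the attention block of a linear-replacement layer to a pre-norm, one projection, an optional post-norm, and the residual add; no Q/K/V projections or attention scores are computed. A condensed sketch of that data path, with module names assumed to mirror the class above:

# Condensed sketch of the linear-replacement path implemented in forward_paged_linear /
# forward_torch_linear: the whole attention sub-block reduces to norm -> single linear
# map (weights loaded from ".self_attn.linear_attn") -> residual add.
import torch

def linear_attn_block(x: torch.Tensor, pre_layernorm, o_proj,
                      post_layernorm = None, has_residual: bool = True) -> torch.Tensor:
    residual = x
    h = pre_layernorm.forward(x) if pre_layernorm is not None else x
    h = o_proj.forward(h)
    if post_layernorm is not None:
        h = post_layernorm.forward(h)
    return h + residual if has_residual else h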
33 changes: 21 additions & 12 deletions exllamav2/cache.py
@@ -63,10 +63,19 @@ def __init__(
self.head_dim = self.model.config.head_dim

self.current_seq_len = 0
self.shape_basic = (self.batch_size, self.max_seq_len, self.num_key_value_heads, self.head_dim)
self.shape_wk = (self.batch_size, self.max_seq_len, self.num_key_value_heads, self.head_dim // self.weights_per_element_k)
self.shape_wv = (self.batch_size, self.max_seq_len, self.num_key_value_heads, self.head_dim // self.weights_per_element_v)
self.shape_s = (self.batch_size, self.max_seq_len, self.num_key_value_heads, self.head_dim // 32)

if type(self.num_key_value_heads) is list:
self.shape_basic = (self.batch_size, self.max_seq_len, max(self.num_key_value_heads), self.head_dim)
else:
self.shape_basic = (self.batch_size, self.max_seq_len, self.num_key_value_heads, self.head_dim)
self.num_key_value_heads = [self.num_key_value_heads] * self.num_hidden_layers
self.shape_wk = list()
self.shape_wv = list()
self.shape_s = list()
for il in range(self.num_hidden_layers):
self.shape_wk.append((self.batch_size, self.max_seq_len, self.num_key_value_heads[il], self.head_dim // self.weights_per_element_k))
self.shape_wv.append((self.batch_size, self.max_seq_len, self.num_key_value_heads[il], self.head_dim // self.weights_per_element_v))
self.shape_s.append((self.batch_size, self.max_seq_len, self.num_key_value_heads[il], self.head_dim // 32))

self.q_block = 0
self.fixed_device = fixed_device
@@ -88,11 +97,11 @@ def create_state_tensors(

if copy_from is None:
device = self.model.cache_map.get(i, self.fixed_device)
p_key_states = torch.zeros(self.shape_wk, dtype = self.dtype, device = device).contiguous()
p_value_states = torch.zeros(self.shape_wv, dtype = self.dtype, device = device).contiguous()
p_key_states = torch.zeros(self.shape_wk[i], dtype = self.dtype, device = device).contiguous()
p_value_states = torch.zeros(self.shape_wv[i], dtype = self.dtype, device = device).contiguous()
if self.has_scales:
p_key_scales = torch.zeros(self.shape_s, dtype = torch.float16, device = device).contiguous()
p_value_scales = torch.zeros(self.shape_s, dtype = torch.float16, device = device).contiguous()
p_key_scales = torch.zeros(self.shape_s[i], dtype = torch.float16, device = device).contiguous()
p_value_scales = torch.zeros(self.shape_s[i], dtype = torch.float16, device = device).contiguous()
else:
p_key_states = copy_from.key_states[i].clone()
p_value_states = copy_from.value_states[i].clone()
@@ -129,13 +138,13 @@ def update_cache_tensors(self):
self.key_states[k] = None
self.value_states[k] = None

p_key_states = torch.zeros(self.shape_wk, dtype = self.dtype, device = v).contiguous()
p_value_states = torch.zeros(self.shape_wv, dtype = self.dtype, device = v).contiguous()
p_key_states = torch.zeros(self.shape_wk[k], dtype = self.dtype, device = v).contiguous()
p_value_states = torch.zeros(self.shape_wv[k], dtype = self.dtype, device = v).contiguous()
self.key_states[k] = p_key_states
self.value_states[k] = p_value_states
if self.has_scales:
p_key_scales = torch.zeros(self.shape_s, dtype = torch.float16, device = v).contiguous()
p_value_scales = torch.zeros(self.shape_s, dtype = torch.float16, device = v).contiguous()
p_key_scales = torch.zeros(self.shape_s[k], dtype = torch.float16, device = v).contiguous()
p_value_scales = torch.zeros(self.shape_s[k], dtype = torch.float16, device = v).contiguous()
self.key_scales[k] = p_key_scales
self.value_scales[k] = p_value_scales

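With shape_wk / shape_wv / shape_s turned into per-layer lists above, layers with fewer key/value heads allocate proportionally smaller cache tensors. A rough sizing sketch, for orientation only, under the assumption of an unquantized FP16 cache (2 bytes per element, keys plus values, no scale tensors):

# Rough sizing sketch for the per-layer cache shapes above; assumes FP16 and no
# quantization scales, so it is an upper-level estimate rather than the exact allocation.
def fp16_cache_bytes(batch_size: int, max_seq_len: int, head_dim: int,
                     kv_heads_per_layer: list[int]) -> int:
    total = 0
    for kv_heads in kv_heads_per_layer:
        total += 2 * batch_size * max_seq_len * kv_heads * head_dim * 2  # K and V, 2 bytes each
    return total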