diff --git a/exllamav2/architecture.py b/exllamav2/architecture.py
index f12f3501..3a0054cd 100644
--- a/exllamav2/architecture.py
+++ b/exllamav2/architecture.py
@@ -139,6 +139,7 @@ class Params:
             "attn_k": ".self_attn.k_proj",
             "attn_v": ".self_attn.v_proj",
             "attn_o": ".self_attn.o_proj",
+            "linear_attn": ".self_attn.linear_attn",
             "layers": "layers",
             "patch_conv": "patch_conv",
         })
@@ -692,6 +693,21 @@ class Params:
             self.lm.expect_keys += \
                 expect_keys_llama
 
+        # Deci
+
+        if arch_string == "DeciLMForCausalLM":
+            arch_recognized = True
+            self.lm.layer_keys += \
+                layer_keys_llama_norms + \
+                layer_keys_llama_attn + \
+                layer_keys_llama_mlp
+            self.lm.expect_keys += \
+                expect_keys_llama
+            self.lm.supports_tp = True
+
         # Llama (default + fallback)
 
         if arch_string != "LlamaForCausalLM" and not arch_recognized:
diff --git a/exllamav2/attn.py b/exllamav2/attn.py
index 3b6199e6..c2e8ed5c 100644
--- a/exllamav2/attn.py
+++ b/exllamav2/attn.py
@@ -161,8 +161,14 @@ def __init__(
             self.hidden_size = cfg.vision_hidden_size
         else:
             self.num_attention_heads = cfg.num_attention_heads
-            self.num_key_value_heads = cfg.num_key_value_heads
-            self.num_key_value_groups = cfg.num_key_value_groups
+            if type(cfg.num_key_value_heads) is list:
+                self.num_key_value_heads = cfg.num_key_value_heads[layer_idx]
+            else:
+                self.num_key_value_heads = cfg.num_key_value_heads
+            if type(cfg.num_key_value_groups) is list:
+                self.num_key_value_groups = cfg.num_key_value_groups[layer_idx]
+            else:
+                self.num_key_value_groups = cfg.num_key_value_groups
             self.head_dim = cfg.head_dim
             self.hidden_size = cfg.hidden_size
@@ -186,10 +192,16 @@ def __init__(
             f_d = f_c + self.num_key_value_heads * self.head_dim
             f_key = (key + km["fused_qkv"]) if km["fused_qkv"] else None
 
-            self.q_proj = ExLlamaV2Linear(model, key + km["attn_q"], hidden_size, self.num_attention_heads * self.head_dim, ap.attention_bias_qkv, f_key = f_key, f_beg = f_a, f_end = f_b, altpack_qkv = ap.fused_qkv_altpack)
-            self.k_proj = ExLlamaV2Linear(model, key + km["attn_k"], hidden_size, self.num_key_value_heads * self.head_dim, ap.attention_bias_qkv, f_key = f_key, f_beg = f_b, f_end = f_c, altpack_qkv = ap.fused_qkv_altpack)
-            self.v_proj = ExLlamaV2Linear(model, key + km["attn_v"], hidden_size, self.num_key_value_heads * self.head_dim, ap.attention_bias_qkv, f_key = f_key, f_beg = f_c, f_end = f_d, altpack_qkv = ap.fused_qkv_altpack)
-            self.o_proj = ExLlamaV2Linear(model, key + km["attn_o"], self.num_attention_heads * self.head_dim, hidden_size, ap.attention_bias_o, prescale = cfg.scale_depth)
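+            # DeciLM/Nemotron blocks that replace attention with a plain linear layer have no
+            # q/k/v projections; only the "linear_attn" output projection is loaded for them.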
+            if type(cfg.num_key_value_heads) is list and ["self_attn.linear_attn"] in cfg.arch.lm.layer_keys[layer_idx]:
+                self.q_proj = None
+                self.k_proj = None
+                self.v_proj = None
+                self.o_proj = ExLlamaV2Linear(model, key + km["linear_attn"], self.num_attention_heads * self.head_dim, hidden_size, ap.attention_bias_o, prescale = cfg.scale_depth)
+            else:
+                self.q_proj = ExLlamaV2Linear(model, key + km["attn_q"], hidden_size, self.num_attention_heads * self.head_dim, ap.attention_bias_qkv, f_key = f_key, f_beg = f_a, f_end = f_b, altpack_qkv = ap.fused_qkv_altpack)
+                self.k_proj = ExLlamaV2Linear(model, key + km["attn_k"], hidden_size, self.num_key_value_heads * self.head_dim, ap.attention_bias_qkv, f_key = f_key, f_beg = f_b, f_end = f_c, altpack_qkv = ap.fused_qkv_altpack)
+                self.v_proj = ExLlamaV2Linear(model, key + km["attn_v"], hidden_size, self.num_key_value_heads * self.head_dim, ap.attention_bias_qkv, f_key = f_key, f_beg = f_c, f_end = f_d, altpack_qkv = ap.fused_qkv_altpack)
+                self.o_proj = ExLlamaV2Linear(model, key + km["attn_o"], self.num_attention_heads * self.head_dim, hidden_size, ap.attention_bias_o, prescale = cfg.scale_depth)
 
             if cfg.use_qk_norm:
                 self.q_norm = ExLlamaV2HeadNorm(model, key + ".self_attn.q_norm", self.num_attention_heads, self.head_dim)
@@ -198,12 +210,18 @@ def __init__(
             self.q_norm = None
             self.k_norm = None
 
-        self.submodules = [
-            self.q_proj,
-            self.k_proj,
-            self.v_proj,
-            self.o_proj
-        ]
+        if type(cfg.num_key_value_heads) is list and ["self_attn.linear_attn"] in cfg.arch.lm.layer_keys[layer_idx]:
+            self.submodules = [
+                self.o_proj
+            ]
+        else:
+            self.submodules = [
+                self.q_proj,
+                self.k_proj,
+                self.v_proj,
+                self.o_proj
+            ]
+
         if self.pre_layernorm:
             self.submodules += [self.pre_layernorm]
         if self.post_layernorm:
@@ -223,10 +241,10 @@ def __init__(
 
     def numel(self) -> int:
 
-        numel = self.q_proj.numel() + \
-                self.k_proj.numel() + \
-                self.v_proj.numel() + \
-                self.o_proj.numel()
+        numel = self.o_proj.numel()
+        if self.q_proj is not None: numel += self.q_proj.numel()
+        if self.k_proj is not None: numel += self.k_proj.numel()
+        if self.v_proj is not None: numel += self.v_proj.numel()
 
         if self.pre_layernorm is not None: numel += self.pre_layernorm.numel()
         if self.post_layernorm is not None: numel += self.post_layernorm.numel()
@@ -243,10 +261,12 @@ def load(self, device_context: bool = True):
 
         if self.pre_layernorm is not None: self.pre_layernorm.load()
         if self.post_layernorm is not None: self.post_layernorm.load()
+        self.o_proj.load(device_context = device_context)
+        if self.q_proj is None and self.k_proj is None and self.v_proj is None:
+            return
         self.q_proj.load(device_context = device_context)
         self.k_proj.load(device_context = device_context)
         self.v_proj.load(device_context = device_context)
-        self.o_proj.load(device_context = device_context)
         if self.q_norm is not None: self.q_norm.load()
         if self.k_norm is not None: self.k_norm.load()
@@ -347,10 +367,10 @@ def unload(self):
 
     def weight_footprint(self):
 
-        fp = self.q_proj.weight_footprint() + \
-             self.k_proj.weight_footprint() + \
-             self.v_proj.weight_footprint() + \
-             self.o_proj.weight_footprint()
+        fp = self.o_proj.weight_footprint()
+        if self.q_proj is not None: fp += self.q_proj.weight_footprint()
+        if self.k_proj is not None: fp += self.k_proj.weight_footprint()
+        if self.v_proj is not None: fp += self.v_proj.weight_footprint()
 
         if self.pre_layernorm is not None:
             fp += self.pre_layernorm.weight_footprint()
         if self.post_layernorm is not None:
@@ -406,10 +426,13 @@ def temp_v_size(self):
 
     def temp_dq_size(self):
 
-        return max(self.q_proj.temp_dq_size(),
-                   self.k_proj.temp_dq_size(),
-                   self.v_proj.temp_dq_size(),
-                   self.o_proj.temp_dq_size())
+        if self.q_proj is None and self.k_proj is None and self.v_proj is None:
+            return self.o_proj.temp_dq_size()
+        else:
+            return max(self.q_proj.temp_dq_size(),
+                       self.k_proj.temp_dq_size(),
+                       self.v_proj.temp_dq_size(),
+                       self.o_proj.temp_dq_size())
 
     def temp_kv_size(self):
@@ -440,9 +463,9 @@ def set_device_idx(self, idx: int | None):
 
         if self.pre_layernorm is not None: self.pre_layernorm.set_device_idx(idx)
         if self.post_layernorm is not None: self.post_layernorm.set_device_idx(idx)
-        self.q_proj.set_device_idx(idx)
-        self.k_proj.set_device_idx(idx)
-        self.v_proj.set_device_idx(idx)
+        if self.q_proj is not None: self.q_proj.set_device_idx(idx)
+        if self.k_proj is not None: self.k_proj.set_device_idx(idx)
+        if self.v_proj is not None: self.v_proj.set_device_idx(idx)
         self.o_proj.set_device_idx(idx)
         if self.q_norm is not None: self.q_norm.set_device_idx(idx)
         if self.k_norm is not None: self.k_norm.set_device_idx(idx)
@@ -468,6 +491,14 @@ def forward_paged(
         **kwargs
     ) -> torch.Tensor:
 
+        if self.q_proj is None and self.k_proj is None and self.v_proj is None:
+            return self.forward_paged_linear(
+                hidden_states,
+                cache,
+                attn_params,
+                loras,
+                **kwargs,
+            )
         if self.is_tp:
             return self.forward_paged_tp(
                 hidden_states,
@@ -633,6 +664,50 @@ def forward_paged(
 
         return hidden_states
 
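+    # Paged forward for layers whose attention is a single linear projection: the block
+    # reduces to o_proj(pre_layernorm(x)) plus the residual, so no attention scores are
+    # computed and the layer has no K/V heads of its own.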
+    # @profile
+    def forward_paged_linear(
+        self,
+        hidden_states: torch.Tensor,
+        cache: ExLlamaV2CacheBase | None = None,
+        attn_params: ExLlamaV2Attention.PagedParams | None = None,
+        loras: list[ExLlamaV2Lora] | None = None,
+        **kwargs
+    ) -> torch.Tensor:
+
+        is_q = self.q_handle is not None
+        cfg = self.model.config
+        constants = self.model.get_device_context(self.device_idx, scratch = is_q)
+        page_size = attn_params.page_size
+        batch_size, q_len, _ = hidden_states.shape
+        cache_seqlens = attn_params.get_cache_seqlens(self.device_idx)
+        block_table = attn_params.get_block_index(self.device_idx)
+
+        sc = attn_params.get_alt_rope_embed(self.device_idx)
+        if not sc:
+            sin, cos = constants.sin, constants.cos
+        else:
+            sin, cos = sc
+
+        cache_seqlens_rope = cache_seqlens
+        offsets = attn_params.get_rope_offsets(self.device_idx)
+        if offsets is not None:
+            cache_seqlens_rope = cache_seqlens_rope + offsets
+
+        residual = hidden_states
+        hidden_states = self.pre_layernorm.forward(hidden_states) if self.has_norm else hidden_states
+
+        attn_output = hidden_states.view((batch_size, q_len, self.num_attention_heads * self.head_dim))
+
+        cache.store_kv_state(self.layer_idx, batch_size, 0, q_len, page_size, cache_seqlens, block_table)
+
+        hidden_states = self.o_proj.forward(attn_output, loras = loras)
+        if self.post_layernorm:
+            hidden_states = self.post_layernorm.forward(hidden_states)
+        if self.has_residual:
+            hidden_states += residual
+
+        return hidden_states
+
     # @profile
 
     def forward_paged_tp(
@@ -1368,6 +1443,17 @@ def forward_torch(
         loras: list[ExLlamaV2Lora] | None = None,
         **kwargs
     ) -> torch.Tensor | dict:
+
+        if self.q_proj is None and self.k_proj is None and self.v_proj is None:
+            return self.forward_torch_linear(
+                hidden_states,
+                cache,
+                attn_params,
+                past_len,
+                intermediates,
+                loras,
+                **kwargs,
+            )
 
         global has_flash_attn
         global has_xformers
@@ -1501,6 +1587,54 @@ def forward_torch(
         else:
             return hidden_states
 
+    def forward_torch_linear(
+        self,
+        hidden_states: torch.Tensor,
+        cache: ExLlamaV2CacheBase | None = None,
+        attn_params: ExLlamaV2Attention.Params | None = None,
+        past_len: int | None = None,
+        intermediates: bool = False,
+        loras: list[ExLlamaV2Lora] | None = None,
+        **kwargs
+    ) -> torch.Tensor | dict:
+
+        cfg = self.model.config
+        num_attention_heads = self.num_attention_heads
+        num_key_value_heads = self.num_key_value_heads
+        head_dim = self.head_dim
+
+        batch_size, q_len, _ = hidden_states.size()
+
+        past_len = 0 if cache is None else cache.current_seq_len
+
+        residual = hidden_states
+        post_norm = self.pre_layernorm.forward(hidden_states) if self.has_norm else hidden_states
+
+        # Output projection
+
+        attn_proj = self.o_proj.forward(post_norm, loras = loras)
+
+        # Post layernorm
+
+        if self.post_layernorm:
+            attn_proj = self.post_layernorm.forward(attn_proj, output_fp32 = self.archparams.residual_stream_fp32)
+
+        # Add residual connection
+
+        hidden_states = (attn_proj + residual) if self.has_residual else attn_proj
+
+        if self.archparams.residual_stream_fp32:
+            hidden_states = hidden_states.float()
+        elif self.archparams.clamp_hidden_states:
+            hidden_states.clamp_(-65504, 65504)
+
+        if intermediates:
+            return {"post_norm": post_norm,
+                    "hidden_states": hidden_states}
+        else:
+            return hidden_states
 
     def update_loras(self):
@@ -1599,4 +1733,4 @@ def amax(res: list[int]):
         add(ctx.get_temp_tensors_s(maxrows, 2, BROADCAST_KV, dim = self.head_dim))
         add(ctx.get_temp_tensors_s(maxrows, 2, BROADCAST_Q, dim = self.head_dim))
 
-    return scratch
\ No newline at end of file
+    return scratch
diff --git a/exllamav2/cache.py b/exllamav2/cache.py
index de4dfe2c..058b6966 100644
--- a/exllamav2/cache.py
+++ b/exllamav2/cache.py
@@ -63,10 +63,19 @@ def __init__(
         self.head_dim = self.model.config.head_dim
         self.current_seq_len = 0
 
-        self.shape_basic = (self.batch_size, self.max_seq_len, self.num_key_value_heads, self.head_dim)
-        self.shape_wk = (self.batch_size, self.max_seq_len, self.num_key_value_heads, self.head_dim // self.weights_per_element_k)
-        self.shape_wv = (self.batch_size, self.max_seq_len, self.num_key_value_heads, self.head_dim // self.weights_per_element_v)
-        self.shape_s = (self.batch_size, self.max_seq_len, self.num_key_value_heads, self.head_dim // 32)
+
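+        # num_key_value_heads may be a per-layer list (DeciLM); layers without standard
+        # attention report 0 KV heads, so their cache tensors are allocated with zero size.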
+        if type(self.num_key_value_heads) is list:
+            self.shape_basic = (self.batch_size, self.max_seq_len, max(self.num_key_value_heads), self.head_dim)
+        else:
+            self.shape_basic = (self.batch_size, self.max_seq_len, self.num_key_value_heads, self.head_dim)
+            self.num_key_value_heads = [self.num_key_value_heads] * self.num_hidden_layers
+        self.shape_wk = list()
+        self.shape_wv = list()
+        self.shape_s = list()
+        for il in range(self.num_hidden_layers):
+            self.shape_wk.append((self.batch_size, self.max_seq_len, self.num_key_value_heads[il], self.head_dim // self.weights_per_element_k))
+            self.shape_wv.append((self.batch_size, self.max_seq_len, self.num_key_value_heads[il], self.head_dim // self.weights_per_element_v))
+            self.shape_s.append((self.batch_size, self.max_seq_len, self.num_key_value_heads[il], self.head_dim // 32))
 
         self.q_block = 0
         self.fixed_device = fixed_device
@@ -88,11 +97,11 @@ def create_state_tensors(
 
             if copy_from is None:
                 device = self.model.cache_map.get(i, self.fixed_device)
-                p_key_states = torch.zeros(self.shape_wk, dtype = self.dtype, device = device).contiguous()
-                p_value_states = torch.zeros(self.shape_wv, dtype = self.dtype, device = device).contiguous()
+                p_key_states = torch.zeros(self.shape_wk[i], dtype = self.dtype, device = device).contiguous()
+                p_value_states = torch.zeros(self.shape_wv[i], dtype = self.dtype, device = device).contiguous()
                 if self.has_scales:
-                    p_key_scales = torch.zeros(self.shape_s, dtype = torch.float16, device = device).contiguous()
-                    p_value_scales = torch.zeros(self.shape_s, dtype = torch.float16, device = device).contiguous()
+                    p_key_scales = torch.zeros(self.shape_s[i], dtype = torch.float16, device = device).contiguous()
+                    p_value_scales = torch.zeros(self.shape_s[i], dtype = torch.float16, device = device).contiguous()
             else:
                 p_key_states = copy_from.key_states[i].clone()
                 p_value_states = copy_from.value_states[i].clone()
@@ -129,13 +138,13 @@ def update_cache_tensors(self):
             self.key_states[k] = None
             self.value_states[k] = None
 
-            p_key_states = torch.zeros(self.shape_wk, dtype = self.dtype, device = v).contiguous()
-            p_value_states = torch.zeros(self.shape_wv, dtype = self.dtype, device = v).contiguous()
+            p_key_states = torch.zeros(self.shape_wk[k], dtype = self.dtype, device = v).contiguous()
+            p_value_states = torch.zeros(self.shape_wv[k], dtype = self.dtype, device = v).contiguous()
             self.key_states[k] = p_key_states
             self.value_states[k] = p_value_states
 
             if self.has_scales:
-                p_key_scales = torch.zeros(self.shape_s, dtype = torch.float16, device = v).contiguous()
-                p_value_scales = torch.zeros(self.shape_s, dtype = torch.float16, device = v).contiguous()
+                p_key_scales = torch.zeros(self.shape_s[k], dtype = torch.float16, device = v).contiguous()
+                p_value_scales = torch.zeros(self.shape_s[k], dtype = torch.float16, device = v).contiguous()
                 self.key_scales[k] = p_key_scales
                 self.value_scales[k] = p_value_scales
diff --git a/exllamav2/config.py b/exllamav2/config.py
index 3a5296e4..b911671a 100644
--- a/exllamav2/config.py
+++ b/exllamav2/config.py
@@ -286,7 +286,8 @@ def prepare(self, no_tensors: bool = False):
             self.num_attention_heads,
             opt_subkey = "text_config",
         )
-        self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
+        self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
+
         self.use_qk_norm = read(read_config, bool, ["use_qk_norm"], False)
 
         self.query_pre_attn_scalar = read(read_config, float, "query_pre_attn_scalar", None)
@@ -299,13 +300,55 @@ def prepare(self, no_tensors: bool = False):
         else:
             default_intermediate_size = no_default
 
-        self.intermediate_size = read(
-            read_config,
-            int,
-            ["intermediate_size", "ffn_config->ffn_hidden_size", "n_inner"],
-            default_intermediate_size,
-            opt_subkey = "text_config",
-        )
+        # Deci overrides num_key_value_heads, num_key_value_groups and intermediate_size
+        if self.architecture == "DeciLMForCausalLM":
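+            # Llama-3_1-Nemotron-51B describes each layer with a "block_configs" entry, e.g.
+            # (illustrative): {"attention": {"n_heads_in_group": 8, "replace_with_linear": false},
+            #                  "ffn": {"ffn_mult": 1.3}}
+            # The FFN width is recovered as int(2 * ffn_mult * hidden_size / 3) rounded up to a
+            # multiple of 256, e.g. ffn_mult = 1.3 with hidden_size = 8192 gives 7099 -> 7168.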
+            if "block_configs" in read_config:  # Llama-3_1-Nemotron-51B
+                _block_configs: list[dict[str, Any]] = read_config["block_configs"]
+                assert self.num_hidden_layers == len(_block_configs)
+                self.num_key_value_heads = list()
+                self.num_key_value_groups = list()
+                self.intermediate_size = list()
+                self.arch.lm.layer_keys = list()
+                for il in range(len(_block_configs)):
+                    if _block_configs[il]["attention"]["n_heads_in_group"] is None:
+                        if _block_configs[il]["attention"]["replace_with_linear"] is True:
+                            self.num_key_value_heads.append(0)
+                            self.arch.lm.layer_keys.append([["input_layernorm"], ["post_attention_layernorm"], ["self_attn.linear_attn"], ["mlp.down_proj"], ["mlp.gate_proj"], ["mlp.up_proj"]])
+                        else:
+                            self.num_key_value_heads.append(0)
+                            self.arch.lm.layer_keys.append([["mlp.down_proj"], ["mlp.gate_proj"], ["mlp.up_proj"]])
+                    else:
+                        self.num_key_value_heads.append(self.num_attention_heads // _block_configs[il]["attention"]["n_heads_in_group"])
+                        self.arch.lm.layer_keys.append([["input_layernorm"], ["post_attention_layernorm"], ["self_attn.q_proj"], ["self_attn.k_proj"], ["self_attn.v_proj"], ["self_attn.o_proj"], ["mlp.down_proj"], ["mlp.gate_proj"], ["mlp.up_proj"]])
+                    if self.num_key_value_heads[il] == 0:
+                        self.num_key_value_groups.append(0)
+                    else:
+                        self.num_key_value_groups.append(self.num_attention_heads // self.num_key_value_heads[il])
+                    ffn_mult = _block_configs[il]["ffn"]["ffn_mult"]
+                    intm_size = int(2 * ffn_mult * self.hidden_size / 3)
+                    if intm_size % 256 != 0:
+                        intm_size = intm_size + 256 - (intm_size % 256)
+                    self.intermediate_size.append(intm_size)
+            else:  # Deci-7B, no need to override intermediate_size
+                self.num_key_value_heads: list[int] = read_config["num_key_value_heads_per_layer"]
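+                # DeciLM-7B-style configs list the KV heads for every layer directly, e.g.
+                # (illustrative) "num_key_value_heads_per_layer": [4, 4, 4, 2, ..., 1]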
+                self.num_key_value_groups = list()
+                for il in range(len(self.num_key_value_heads)):
+                    self.num_key_value_groups.append(self.num_attention_heads // self.num_key_value_heads[il])
+                self.intermediate_size = read(
+                    read_config,
+                    int,
+                    ["intermediate_size", "ffn_config->ffn_hidden_size", "n_inner"],
+                    default_intermediate_size,
+                    opt_subkey = "text_config",
+                )
+        else:
+            self.intermediate_size = read(
+                read_config,
+                int,
+                ["intermediate_size", "ffn_config->ffn_hidden_size", "n_inner"],
+                default_intermediate_size,
+                opt_subkey = "text_config",
+            )
 
         self.num_experts = read(read_config, int, ["num_local_experts", "ffn_config->moe_num_experts"], None)
         self.num_experts_per_token = read(read_config, int,["num_experts_per_tok", "ffn_config->moe_top_k"], None)
@@ -462,7 +510,45 @@ def check_keys(archparams, prefix):
                 if not match:
                     raise ValueError(f" ## Could not find {prefix}.* in model")
 
-        check_keys(self.arch.lm, self.arch.lm_prefix)
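+        # Per-layer variant of check_keys: Deci layer layouts differ per layer, so the
+        # expected keys are expanded from archparams.layer_keys[layer_idx] for each layer.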
+        def check_deci_keys(archparams, prefix):
+            expect_keys = archparams.expect_keys.copy()
+
+            per_layer_keys = archparams.layer_keys
+
+            for layer_idx in range(self.num_hidden_layers):
+                for ks in per_layer_keys[layer_idx]:
+                    prefixes = [f"model.layers.{layer_idx}.{k}" for k in ks]
+                    expect_keys.append(prefixes)
+
+            if self.arch.lm_prefix:
+                expect_keys = [
+                    [prefix + k for k in k2]
+                    for k2 in expect_keys
+                ]
+
+            all_keys = set(self.tensor_file_map.keys())
+            suffixes = [".q_weight", ".qweight", ".weight", ""]
+
+            for prefixes in expect_keys:
+                match = False
+                for prefix in prefixes:
+                    for suffix in suffixes:
+                        if (prefix + suffix) in all_keys:
+                            match = True
+                            break
+                    if match: break
+                if match: break
+                if not match:
+                    raise ValueError(f" ## Could not find {prefix}.* in model")
+
+        if self.architecture == "DeciLMForCausalLM" and "block_configs" in read_config:  # Llama-3_1-Nemotron-51B
+            check_deci_keys(self.arch.lm, self.arch.lm_prefix)
+        else:
+            check_keys(self.arch.lm, self.arch.lm_prefix)
         check_keys(self.arch.mmp, self.arch.mmp_prefix)
         check_keys(self.arch.vt, self.arch.vt_prefix)
diff --git a/exllamav2/conversion/compile.py b/exllamav2/conversion/compile.py
index 4fcb8b86..02aa752d 100644
--- a/exllamav2/conversion/compile.py
+++ b/exllamav2/conversion/compile.py
@@ -93,9 +93,12 @@ def compile_model(job, save_fn, model):
             if d: out_dict.update(d); current_size += _dsize(d)
             d = get_f_module(job, module.post_layernorm)
             if d: out_dict.update(d); current_size += _dsize(d)
-            d = get_q_module(job, module.q_proj); out_dict.update(d); current_size += _dsize(d)
-            d = get_q_module(job, module.k_proj); out_dict.update(d); current_size += _dsize(d)
-            d = get_q_module(job, module.v_proj); out_dict.update(d); current_size += _dsize(d)
+            if module.q_proj is not None:
+                d = get_q_module(job, module.q_proj); out_dict.update(d); current_size += _dsize(d)
+            if module.k_proj is not None:
+                d = get_q_module(job, module.k_proj); out_dict.update(d); current_size += _dsize(d)
+            if module.v_proj is not None:
+                d = get_q_module(job, module.v_proj); out_dict.update(d); current_size += _dsize(d)
             d = get_q_module(job, module.o_proj); out_dict.update(d); current_size += _dsize(d)
 
         if isinstance(module, ExLlamaV2MLP):
@@ -277,4 +280,4 @@ def compile_model(job, save_fn, model):
         with open(config_json, "w") as f:
             f.write(json.dumps(config_dict, indent = 4))
 
-    print_stage(job, "Compiling", len(model.modules), len(model.modules))
\ No newline at end of file
+    print_stage(job, "Compiling", len(model.modules), len(model.modules))
diff --git a/exllamav2/conversion/measure.py b/exllamav2/conversion/measure.py
index 927fab2a..d5f54ab7 100644
--- a/exllamav2/conversion/measure.py
+++ b/exllamav2/conversion/measure.py
@@ -142,6 +142,9 @@ def test_error(module, hidden_states, target_states, cache, attn_params):
 
 def measure_attn(module, hidden_states, target_states, quantizers, cache, attn_params, keep_q = False):
 
+    # Layers without q/k/v projections (DeciLM linear attention) are measured separately
+    if "q_proj" not in quantizers and "k_proj" not in quantizers and "v_proj" not in quantizers:
+        return measure_linear_attn(module, hidden_states, target_states, quantizers, cache, attn_params)
     qjobs, qmaps = get_qparams_reduced(qparams_attn)
     results = []
@@ -202,6 +205,48 @@ def measure_attn(module, hidden_states, target_states, quantizers, cache, attn_p
 
     return results
 
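+# Measurement for linear-attention layers: only o_proj exists, so only its quantization
+# options are swept.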
quantizers["up_proj"].add_batch(outputs["post_norm"]) # Reuse H for gate_proj quantizers["down_proj"].add_batch(outputs["pre_down"]) diff --git a/exllamav2/conversion/optimize.py b/exllamav2/conversion/optimize.py index 33e6999f..51d69081 100644 --- a/exllamav2/conversion/optimize.py +++ b/exllamav2/conversion/optimize.py @@ -87,10 +87,19 @@ def optimize(job, save_fn, model): if cfg.arch.lm.parallel_decoder_blocks: m1 = measurement[cfg.arch.lm_prefix + "model.layers." + str(i) + ".parallel_decoder"]["attn"] m2 = measurement[cfg.arch.lm_prefix + "model.layers." + str(i) + ".parallel_decoder"]["mlp"] + elif type(cfg.arch.lm.layer_keys) is list and type(cfg.intermediate_size) is list: + if ["self_attn.linear_attn"] in cfg.arch.lm.layer_keys[i] or ["self_attn.o_proj"] in cfg.arch.lm.layer_keys[i]: + m1 = measurement[cfg.arch.lm_prefix + "model.layers." + str(i) + ".self_attn"] + m2 = measurement[cfg.arch.lm_prefix + "model.layers." + str(i) + "." + mlp_mode] + else: + m1 = None + m2 = measurement[cfg.arch.lm_prefix + "model.layers." + str(i) + "." + mlp_mode] else: m1 = measurement[cfg.arch.lm_prefix + "model.layers." + str(i) + ".self_attn"] m2 = measurement[cfg.arch.lm_prefix + "model.layers." + str(i) + "." + mlp_mode] for m in [m1, m2]: + if m is None: + continue slot = [] param = [] for opt in m: @@ -153,14 +162,29 @@ def optimize(job, save_fn, model): logerr = 0 maxerr = 0 job["strategy"] = {} + deci_offset = 0 for layer_ in range(num_layers): - - k1 = cfg.arch.lm_prefix + "model.layers." + str(layer_) + ".self_attn" - k2 = cfg.arch.lm_prefix + "model.layers." + str(layer_) + "." + mlp_mode - p1 = params[layer_ * 2][solution_idx[layer_ * 2]] - p2 = params[layer_ * 2 + 1][solution_idx[layer_ * 2 + 1]] + if type(cfg.arch.lm.layer_keys) is list and type(cfg.intermediate_size) is list: + if ["self_attn.linear_attn"] in cfg.arch.lm.layer_keys[layer_] or ["self_attn.o_proj"] in cfg.arch.lm.layer_keys[layer_]: + k1 = cfg.arch.lm_prefix + "model.layers." + str(layer_) + ".self_attn" + k2 = cfg.arch.lm_prefix + "model.layers." + str(layer_) + "." + mlp_mode + p1 = params[layer_ * 2-deci_offset][solution_idx[layer_ * 2-deci_offset]] + p2 = params[layer_ * 2 + 1-deci_offset][solution_idx[layer_ * 2 + 1-deci_offset]] + else: + deci_offset = deci_offset + 1 + k1 = None + k2 = cfg.arch.lm_prefix + "model.layers." + str(layer_) + "." + mlp_mode + p1 = None + p2 = params[layer_ * 2 + 1-deci_offset][solution_idx[layer_ * 2 + 1-deci_offset]] + else: + k1 = cfg.arch.lm_prefix + "model.layers." + str(layer_) + ".self_attn" + k2 = cfg.arch.lm_prefix + "model.layers." + str(layer_) + "." 
+    deci_offset = 0
     for layer_ in range(num_layers):
-
-        k1 = cfg.arch.lm_prefix + "model.layers." + str(layer_) + ".self_attn"
-        k2 = cfg.arch.lm_prefix + "model.layers." + str(layer_) + "." + mlp_mode
-        p1 = params[layer_ * 2][solution_idx[layer_ * 2]]
-        p2 = params[layer_ * 2 + 1][solution_idx[layer_ * 2 + 1]]
+        if type(cfg.arch.lm.layer_keys) is list and type(cfg.intermediate_size) is list:
+            if ["self_attn.linear_attn"] in cfg.arch.lm.layer_keys[layer_] or ["self_attn.o_proj"] in cfg.arch.lm.layer_keys[layer_]:
+                k1 = cfg.arch.lm_prefix + "model.layers." + str(layer_) + ".self_attn"
+                k2 = cfg.arch.lm_prefix + "model.layers." + str(layer_) + "." + mlp_mode
+                p1 = params[layer_ * 2 - deci_offset][solution_idx[layer_ * 2 - deci_offset]]
+                p2 = params[layer_ * 2 + 1 - deci_offset][solution_idx[layer_ * 2 + 1 - deci_offset]]
+            else:
+                deci_offset = deci_offset + 1
+                k1 = None
+                k2 = cfg.arch.lm_prefix + "model.layers." + str(layer_) + "." + mlp_mode
+                p1 = None
+                p2 = params[layer_ * 2 + 1 - deci_offset][solution_idx[layer_ * 2 + 1 - deci_offset]]
+        else:
+            k1 = cfg.arch.lm_prefix + "model.layers." + str(layer_) + ".self_attn"
+            k2 = cfg.arch.lm_prefix + "model.layers." + str(layer_) + "." + mlp_mode
+            p1 = params[layer_ * 2][solution_idx[layer_ * 2]]
+            p2 = params[layer_ * 2 + 1][solution_idx[layer_ * 2 + 1]]
 
         for (k, p, n) in zip((k1, k2), (p1, p2), (numel_attn, numel_mlp)):
+            if k is None or p is None or n is None:
+                continue
             job["strategy"][k] = p
             bpw = p["total_bits"] / n
             err = 1 - p["accuracy"]
@@ -171,4 +195,4 @@ def optimize(job, save_fn, model):
     print(f" -- sum(log(err)): {logerr:.6f}")
     print(f" -- max(err): {maxerr:.6f}")
 
-    xx = 0
\ No newline at end of file
+    xx = 0
diff --git a/exllamav2/conversion/quantize.py b/exllamav2/conversion/quantize.py
index 5f38eb3a..578be13f 100644
--- a/exllamav2/conversion/quantize.py
+++ b/exllamav2/conversion/quantize.py
@@ -134,6 +134,10 @@ def quant_linear(job: dict,
 
 def quant_attn(job, module, hidden_states, target_states, quantizers, attn_params, strat):
 
+    # Layers without q/k/v projections (DeciLM linear attention) are quantized separately
+    if "q_proj" not in quantizers and "k_proj" not in quantizers and "v_proj" not in quantizers:
+        return quant_linear_attn(job, module, hidden_states, target_states, quantizers, attn_params, strat)
+
     quantizers["q_proj"].prepare()
     quantizers["k_proj"].reuse_h(quantizers["q_proj"])
     quantizers["v_proj"].reuse_h(quantizers["q_proj"])
@@ -151,6 +155,17 @@ def quant_attn(job, module, hidden_states, target_states, quantizers, attn_param
     torch.cuda.empty_cache()
 
 
+def quant_linear_attn(job, module, hidden_states, target_states, quantizers, attn_params, strat):
+
+    quantizers["o_proj"].prepare()
+
+    quant_linear(job, module.o_proj, quantizers["o_proj"], strat["o_proj"])
+    del quantizers["o_proj"]
+
+    gc.collect()
+    torch.cuda.empty_cache()
+
+
 def quant_mlp(job, module, hidden_states, target_states, quantizers, attn_params, strat, reuse_h_up_proj = None):
 
     has_mlp = module.model.config.arch.lm.mlp_gate
@@ -304,9 +319,12 @@ def quant(job, save_fn, model):
 
         if isinstance(module, ExLlamaV2Attention):
             mode = "self_attn"
             # if index > 1: testc(module, hidden_states, hidden_i_states, module.input_layernorm, [module.q_proj, module.k_proj, module.v_proj])
-            quantizers["q_proj"] = AdaptiveGPTQ(module.q_proj.linear)
-            quantizers["k_proj"] = AdaptiveGPTQ(module.k_proj.linear)
-            quantizers["v_proj"] = AdaptiveGPTQ(module.v_proj.linear)
+            if module.q_proj is not None:
+                quantizers["q_proj"] = AdaptiveGPTQ(module.q_proj.linear)
+            if module.k_proj is not None:
+                quantizers["k_proj"] = AdaptiveGPTQ(module.k_proj.linear)
+            if module.v_proj is not None:
+                quantizers["v_proj"] = AdaptiveGPTQ(module.v_proj.linear)
             quantizers["o_proj"] = AdaptiveGPTQ(module.o_proj.linear)
 
         elif isinstance(module, ExLlamaV2MLP):
@@ -365,8 +383,11 @@ def quant(job, save_fn, model):
             # Hessians
 
             if mode == "self_attn":
-                quantizers["q_proj"].add_batch(outputs["post_norm"])  # Reuse H for K and V
-                quantizers["o_proj"].add_batch(outputs["attn_output"])
+                if "q_proj" not in quantizers and "k_proj" not in quantizers and "v_proj" not in quantizers:
+                    quantizers["o_proj"].add_batch(outputs["post_norm"])
+                else:
+                    quantizers["q_proj"].add_batch(outputs["post_norm"])  # Reuse H for K and V
+                    quantizers["o_proj"].add_batch(outputs["attn_output"])
 
             if mode == "mlp":
                 quantizers["up_proj"].add_batch(outputs["post_norm"])  # Reuse H for gate_proj
diff --git a/exllamav2/generator/dynamic.py b/exllamav2/generator/dynamic.py
index abac46ec..ebb952e2 100644
--- a/exllamav2/generator/dynamic.py
+++ b/exllamav2/generator/dynamic.py
@@ -471,6 +471,8 @@ def __init__(
             cache_tensors += self.draft_cache.all_tensors()
 
         for c in cache_tensors:
+            if c is None:
+                continue
             key = (c.device.index, c.dtype, c.shape[2], c.shape[3])
             if key not in self.defrag_buffer:
                 t = torch.empty((1, self.page_size, c.shape[2], c.shape[3]), dtype = c.dtype, device = c.device)
diff --git a/exllamav2/mlp.py b/exllamav2/mlp.py
index d0831573..5e4a2d8b 100644
--- a/exllamav2/mlp.py
+++ b/exllamav2/mlp.py
@@ -62,11 +62,18 @@ def __init__(
             self.intermediate_size = cfg.vision_intermediate_size
         else:
             self.hidden_size = cfg.hidden_size
-            self.intermediate_size = cfg.intermediate_size
+            if type(cfg.intermediate_size) is list:
+                self.intermediate_size = cfg.intermediate_size[layer_idx]
+            else:
+                self.intermediate_size = cfg.intermediate_size
 
         if in_features is None: in_features = self.hidden_size
         if out_features is None: out_features = self.hidden_size
         if interm_features is None: interm_features = self.intermediate_size
         self.in_features = in_features
         self.out_features = out_features
         self.interm_features = interm_features
diff --git a/exllamav2/model.py b/exllamav2/model.py
index 80ccf758..1bdf73da 100644
--- a/exllamav2/model.py
+++ b/exllamav2/model.py
@@ -117,6 +117,13 @@ def __init__(
             if cfg.arch.lm.parallel_decoder_blocks:
                 pd = ExLlamaV2ParallelDecoder(self, layer_key, layer_idx, sliding_window = swa)
                 self.modules += [pd]
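+            # Per-layer construction for DeciLM/Nemotron: attention-free blocks get only an MLP,
+            # all other blocks get an attention (or linear_attn) module followed by an MLP.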
+            elif type(cfg.arch.lm.layer_keys) is list and type(cfg.intermediate_size) is list:
+                mlp = ExLlamaV2MLP(self, layer_key, layer_idx)
+                if ["self_attn.linear_attn"] in cfg.arch.lm.layer_keys[layer_idx] or ["self_attn.o_proj"] in cfg.arch.lm.layer_keys[layer_idx]:
+                    attn = ExLlamaV2Attention(self, layer_key, layer_idx, sliding_window = swa)
+                    self.modules += [attn, mlp]
+                else:
+                    self.modules += [mlp]
             else:
                 attn = ExLlamaV2Attention(self, layer_key, layer_idx, sliding_window = swa)
                 if cfg.arch.lm.is_moe: mlp = ExLlamaV2MoEMLP(self, layer_key, layer_idx)