From c64d65a1df30064bd6ad9ce4fc13e24962d7e7cf Mon Sep 17 00:00:00 2001 From: Yee Man Chan Date: Tue, 28 Jan 2025 12:30:22 +0800 Subject: [PATCH 1/3] Llama-3_1-Nemotron-51B support --- exllamav2/architecture.py | 16 +++++ exllamav2/attn.py | 13 +++- exllamav2/cache.py | 33 ++++++---- exllamav2/config.py | 104 ++++++++++++++++++++++++++++--- exllamav2/conversion/compile.py | 11 +++- exllamav2/conversion/measure.py | 64 +++++++++++++++++-- exllamav2/conversion/optimize.py | 44 +++++++++++-- exllamav2/conversion/quantize.py | 24 +++++++ exllamav2/generator/dynamic.py | 2 + exllamav2/mlp.py | 11 +++- exllamav2/model.py | 13 ++++ 11 files changed, 298 insertions(+), 37 deletions(-) diff --git a/exllamav2/architecture.py b/exllamav2/architecture.py index f12f3501..3a0054cd 100644 --- a/exllamav2/architecture.py +++ b/exllamav2/architecture.py @@ -139,6 +139,7 @@ class Params: "attn_k": ".self_attn.k_proj", "attn_v": ".self_attn.v_proj", "attn_o": ".self_attn.o_proj", + "linear_attn": ".self_attn.linear_attn", "layers": "layers", "patch_conv": "patch_conv", }) @@ -692,6 +693,21 @@ class Params: self.lm.expect_keys += \ expect_keys_llama + # Deci + + if arch_string == "DeciLMForCausalLM": + arch_recognized = True + self.lm.layer_keys += \ + layer_keys_llama_norms + \ + layer_keys_llama_attn + \ + layer_keys_llama_mlp + self.lm.expect_keys += \ + expect_keys_llama +# self.lm.keys.update({ +# "attn_o": ".self_attn.linear_attn", +# }) + self.lm.supports_tp = True + # Llama (default + fallback) if arch_string != "LlamaForCausalLM" and not arch_recognized: diff --git a/exllamav2/attn.py b/exllamav2/attn.py index 3b6199e6..09fde7ae 100644 --- a/exllamav2/attn.py +++ b/exllamav2/attn.py @@ -161,8 +161,14 @@ def __init__( self.hidden_size = cfg.vision_hidden_size else: self.num_attention_heads = cfg.num_attention_heads - self.num_key_value_heads = cfg.num_key_value_heads - self.num_key_value_groups = cfg.num_key_value_groups + if type(cfg.num_key_value_heads) is list: + self.num_key_value_heads = cfg.num_key_value_heads[layer_idx] + else: + self.num_key_value_heads = cfg.num_key_value_heads + if type(cfg.num_key_value_groups) is list: + self.num_key_value_groups = cfg.num_key_value_groups[layer_idx] + else: + self.num_key_value_groups = cfg.num_key_value_groups self.head_dim = cfg.head_dim self.hidden_size = cfg.hidden_size @@ -204,6 +210,7 @@ def __init__( self.v_proj, self.o_proj ] + if self.pre_layernorm: self.submodules += [self.pre_layernorm] if self.post_layernorm: @@ -1599,4 +1606,4 @@ def amax(res: list[int]): add(ctx.get_temp_tensors_s(maxrows, 2, BROADCAST_KV, dim = self.head_dim)) add(ctx.get_temp_tensors_s(maxrows, 2, BROADCAST_Q, dim = self.head_dim)) - return scratch \ No newline at end of file + return scratch diff --git a/exllamav2/cache.py b/exllamav2/cache.py index de4dfe2c..058b6966 100644 --- a/exllamav2/cache.py +++ b/exllamav2/cache.py @@ -63,10 +63,19 @@ def __init__( self.head_dim = self.model.config.head_dim self.current_seq_len = 0 - self.shape_basic = (self.batch_size, self.max_seq_len, self.num_key_value_heads, self.head_dim) - self.shape_wk = (self.batch_size, self.max_seq_len, self.num_key_value_heads, self.head_dim // self.weights_per_element_k) - self.shape_wv = (self.batch_size, self.max_seq_len, self.num_key_value_heads, self.head_dim // self.weights_per_element_v) - self.shape_s = (self.batch_size, self.max_seq_len, self.num_key_value_heads, self.head_dim // 32) + + if type(self.num_key_value_heads) is list: + self.shape_basic = (self.batch_size, self.max_seq_len, 
max(self.num_key_value_heads), self.head_dim) + else: + self.shape_basic = (self.batch_size, self.max_seq_len, self.num_key_value_heads, self.head_dim) + self.num_key_value_heads = [self.num_key_value_heads] * self.num_hidden_layers + self.shape_wk = list() + self.shape_wv = list() + self.shape_s = list() + for il in range(self.num_hidden_layers): + self.shape_wk.append((self.batch_size, self.max_seq_len, self.num_key_value_heads[il], self.head_dim // self.weights_per_element_k)) + self.shape_wv.append((self.batch_size, self.max_seq_len, self.num_key_value_heads[il], self.head_dim // self.weights_per_element_v)) + self.shape_s.append((self.batch_size, self.max_seq_len, self.num_key_value_heads[il], self.head_dim // 32)) self.q_block = 0 self.fixed_device = fixed_device @@ -88,11 +97,11 @@ def create_state_tensors( if copy_from is None: device = self.model.cache_map.get(i, self.fixed_device) - p_key_states = torch.zeros(self.shape_wk, dtype = self.dtype, device = device).contiguous() - p_value_states = torch.zeros(self.shape_wv, dtype = self.dtype, device = device).contiguous() + p_key_states = torch.zeros(self.shape_wk[i], dtype = self.dtype, device = device).contiguous() + p_value_states = torch.zeros(self.shape_wv[i], dtype = self.dtype, device = device).contiguous() if self.has_scales: - p_key_scales = torch.zeros(self.shape_s, dtype = torch.float16, device = device).contiguous() - p_value_scales = torch.zeros(self.shape_s, dtype = torch.float16, device = device).contiguous() + p_key_scales = torch.zeros(self.shape_s[i], dtype = torch.float16, device = device).contiguous() + p_value_scales = torch.zeros(self.shape_s[i], dtype = torch.float16, device = device).contiguous() else: p_key_states = copy_from.key_states[i].clone() p_value_states = copy_from.value_states[i].clone() @@ -129,13 +138,13 @@ def update_cache_tensors(self): self.key_states[k] = None self.value_states[k] = None - p_key_states = torch.zeros(self.shape_wk, dtype = self.dtype, device = v).contiguous() - p_value_states = torch.zeros(self.shape_wv, dtype = self.dtype, device = v).contiguous() + p_key_states = torch.zeros(self.shape_wk[k], dtype = self.dtype, device = v).contiguous() + p_value_states = torch.zeros(self.shape_wv[k], dtype = self.dtype, device = v).contiguous() self.key_states[k] = p_key_states self.value_states[k] = p_value_states if self.has_scales: - p_key_scales = torch.zeros(self.shape_s, dtype = torch.float16, device = v).contiguous() - p_value_scales = torch.zeros(self.shape_s, dtype = torch.float16, device = v).contiguous() + p_key_scales = torch.zeros(self.shape_s[k], dtype = torch.float16, device = v).contiguous() + p_value_scales = torch.zeros(self.shape_s[k], dtype = torch.float16, device = v).contiguous() self.key_scales[k] = p_key_scales self.value_scales[k] = p_value_scales diff --git a/exllamav2/config.py b/exllamav2/config.py index 3a5296e4..b911671a 100644 --- a/exllamav2/config.py +++ b/exllamav2/config.py @@ -286,7 +286,8 @@ def prepare(self, no_tensors: bool = False): self.num_attention_heads, opt_subkey = "text_config", ) - self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads + self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads + self.use_qk_norm = read(read_config, bool, ["use_qk_norm"], False) self.query_pre_attn_scalar = read(read_config, float, "query_pre_attn_scalar", None) @@ -299,13 +300,55 @@ def prepare(self, no_tensors: bool = False): else: default_intermediate_size = no_default - self.intermediate_size = read( - 
read_config, - int, - ["intermediate_size", "ffn_config->ffn_hidden_size", "n_inner"], - default_intermediate_size, - opt_subkey = "text_config", - ) +# Deci overrides num_key_value_heads, num_key_value_groups and intermediate size + if self.architecture == "DeciLMForCausalLM": + if "block_configs" in read_config: # # Llama-3_1-Nemotron-51B + _block_configs: list[dict[str,Any]] = read_config["block_configs"] + assert self.num_hidden_layers == len(_block_configs) + self.num_key_value_heads = list() + self.num_key_value_groups = list() + self.intermediate_size = list() + self.arch.lm.layer_keys = list() + for il in range(len(_block_configs)): + if _block_configs[il]["attention"]["n_heads_in_group"] is None: + if _block_configs[il]["attention"]["replace_with_linear"] is True: + self.num_key_value_heads.append(0) + self.arch.lm.layer_keys.append([["input_layernorm"],["post_attention_layernorm"],["self_attn.linear_attn"],["mlp.down_proj"],["mlp.gate_proj"],["mlp.up_proj"]]) + else: + self.num_key_value_heads.append(0) + self.arch.lm.layer_keys.append([["mlp.down_proj"],["mlp.gate_proj"],["mlp.up_proj"]]) + else: + self.num_key_value_heads.append(self.num_attention_heads // _block_configs[il]["attention"]["n_heads_in_group"]) + self.arch.lm.layer_keys.append([["input_layernorm"],["post_attention_layernorm"],["self_attn.q_proj"], ["self_attn.k_proj"],["self_attn.v_proj"],["self_attn.o_proj"],["mlp.down_proj"],["mlp.gate_proj"],["mlp.up_proj"]]) + if self.num_key_value_heads[il] == 0: + self.num_key_value_groups.append(0) + else: + self.num_key_value_groups.append(self.num_attention_heads // self.num_key_value_heads[il]) + ffn_mult = _block_configs[il]["ffn"]["ffn_mult"] + intm_size = int(2 * ffn_mult * self.hidden_size / 3) + if intm_size % 256 != 0: + intm_size = intm_size + 256 - (intm_size % 256) + self.intermediate_size.append(intm_size) + else: # Deci-7B, no need to override intermediate_size + self.num_key_value_heads: list[int] = read_config["num_key_value_heads_per_layer"] + self.num_key_value_groups = list() + for il in range(len(self.num_key_value_heads)): + self.num_key_value_groups.append(self.num_attention_heads // self.num_key_value_heads[il]) + self.intermediate_size = read( + read_config, + int, + ["intermediate_size", "ffn_config->ffn_hidden_size", "n_inner"], + default_intermediate_size, + opt_subkey = "text_config", + ) + else: + self.intermediate_size = read( + read_config, + int, + ["intermediate_size", "ffn_config->ffn_hidden_size", "n_inner"], + default_intermediate_size, + opt_subkey = "text_config", + ) self.num_experts = read(read_config, int, ["num_local_experts", "ffn_config->moe_num_experts"], None) self.num_experts_per_token = read(read_config, int,["num_experts_per_tok", "ffn_config->moe_top_k"], None) @@ -450,7 +493,12 @@ def check_keys(archparams, prefix): all_keys = set(self.tensor_file_map.keys()) suffixes = [".q_weight", ".qweight", ".weight", ""] +# for k in all_keys: +# print(k) +# print("****End of all_keys****") + for prefixes in expect_keys: +# print(prefixes) match = False for prefix in prefixes: for suffix in suffixes: @@ -462,7 +510,45 @@ def check_keys(archparams, prefix): if not match: raise ValueError(f" ## Could not find {prefix}.* in model") - check_keys(self.arch.lm, self.arch.lm_prefix) + def check_deci_keys(archparams, prefix): + expect_keys = archparams.expect_keys.copy() + + per_layer_keys = archparams.layer_keys + + for layer_idx in range(self.num_hidden_layers): + for ks in per_layer_keys[layer_idx]: + prefixes = 
[f"model.layers.{layer_idx}.{k}" for k in ks] + expect_keys.append(prefixes) + + if self.arch.lm_prefix: + expect_keys = [ + [prefix + k for k in k2] + for k2 in expect_keys + ] + + all_keys = set(self.tensor_file_map.keys()) + suffixes = [".q_weight", ".qweight", ".weight", ""] + +# for k in all_keys: +# print(k) +# print("****End of all_keys****") + + for prefixes in expect_keys: + match = False + for prefix in prefixes: + for suffix in suffixes: + if (prefix + suffix) in all_keys: + match = True + break + if match: break + if match: break + if not match: + raise ValueError(f" ## Could not find {prefix}.* in model") + + if self.architecture == "DeciLMForCausalLM" and "block_configs" in read_config: # # Llama-3_1-Nemotron-51B + check_deci_keys(self.arch.lm, self.arch.lm_prefix) + else: + check_keys(self.arch.lm, self.arch.lm_prefix) check_keys(self.arch.mmp, self.arch.mmp_prefix) check_keys(self.arch.vt, self.arch.vt_prefix) diff --git a/exllamav2/conversion/compile.py b/exllamav2/conversion/compile.py index 4fcb8b86..5c8a4f17 100644 --- a/exllamav2/conversion/compile.py +++ b/exllamav2/conversion/compile.py @@ -3,6 +3,7 @@ ExLlamaV2Embedding, ExLlamaV2PosEmbedding, ExLlamaV2Attention, + ExLlamaV2LinearAttention, ExLlamaV2MLP, ExLlamaV2MoEMLP, ExLlamaV2ParallelDecoder, @@ -98,6 +99,14 @@ def compile_model(job, save_fn, model): d = get_q_module(job, module.v_proj); out_dict.update(d); current_size += _dsize(d) d = get_q_module(job, module.o_proj); out_dict.update(d); current_size += _dsize(d) + if isinstance(module, ExLlamaV2LinearAttention): + + d = get_f_module(job, module.pre_layernorm) + if d: out_dict.update(d); current_size += _dsize(d) + d = get_f_module(job, module.post_layernorm) + if d: out_dict.update(d); current_size += _dsize(d) + d = get_q_module(job, module.o_proj); out_dict.update(d); current_size += _dsize(d) + if isinstance(module, ExLlamaV2MLP): has_gate = model.config.arch.lm.mlp_gate @@ -277,4 +286,4 @@ def compile_model(job, save_fn, model): with open(config_json, "w") as f: f.write(json.dumps(config_dict, indent = 4)) - print_stage(job, "Compiling", len(model.modules), len(model.modules)) \ No newline at end of file + print_stage(job, "Compiling", len(model.modules), len(model.modules)) diff --git a/exllamav2/conversion/measure.py b/exllamav2/conversion/measure.py index 927fab2a..bcaca269 100644 --- a/exllamav2/conversion/measure.py +++ b/exllamav2/conversion/measure.py @@ -3,6 +3,7 @@ ExLlamaV2Embedding, ExLlamaV2PosEmbedding, ExLlamaV2Attention, + ExLlamaV2LinearAttention, ExLlamaV2MLP, ExLlamaV2MoEMLP, ExLlamaV2Linear, @@ -202,6 +203,48 @@ def measure_attn(module, hidden_states, target_states, quantizers, cache, attn_p return results +def measure_linear_attn(module, hidden_states, target_states, quantizers, cache, attn_params): + + qjobs, qmaps = get_qparams_reduced(qparams_attn) + results = [] + + quantizers["o_proj"].prepare() + + options_o, bits_o = test_quant(module.o_proj, quantizers["o_proj"], qjobs[3]) + + total_numel = module.o_proj.numel() + + max_accuracy = 0.0 + (q_, k_, v_, o_) = (-1, -1, -1, -1) + for (q, k, v, o) in qmaps: + + if o != o_: module.o_proj.linear.weight = nn.Parameter(options_o[o].weight.cuda()) + (q_, k_, v_, o_) = (q, k, v, o) + + total_bits = bits_o[o] + total_bpw = total_bits / total_numel + + accuracy = test_error(module, hidden_states, target_states, cache, attn_params) + print(f" -- {total_bpw:1.4f} bpw accuracy: {accuracy:1.8f}") + + max_accuracy = max(accuracy, max_accuracy) + + torch.cuda.empty_cache() + + r = { "accuracy": 
accuracy, + "total_bits": total_bits, + "o_proj": qjobs[3][o].get_dict() } + results.append(r) + + if max_accuracy < 0.1: + print(" ## Measurement/inference error (1)") + os._exit(1) + + for x in ["o_proj"]: + if x in quantizers: + del quantizers[x] + + return results def measure_mlp(module, hidden_states, target_states, quantizers, cache, attn_params, reuse_h_up_proj = None): @@ -484,9 +527,16 @@ def measure_quant(job, save_fn, model, hidden_state_offload_layers): if isinstance(module, ExLlamaV2Attention): mode = "self_attn" - quantizers["q_proj"] = AdaptiveGPTQ(module.q_proj.linear) - quantizers["k_proj"] = AdaptiveGPTQ(module.k_proj.linear) - quantizers["v_proj"] = AdaptiveGPTQ(module.v_proj.linear) + if module.q_proj.linear is not None: + quantizers["q_proj"] = AdaptiveGPTQ(module.q_proj.linear) + if module.k_proj.linear is not None: + quantizers["k_proj"] = AdaptiveGPTQ(module.k_proj.linear) + if module.v_proj.linear is not None: + quantizers["v_proj"] = AdaptiveGPTQ(module.v_proj.linear) + quantizers["o_proj"] = AdaptiveGPTQ(module.o_proj.linear) + + elif isinstance(module, ExLlamaV2LinearAttention): + mode = "linear_attn" quantizers["o_proj"] = AdaptiveGPTQ(module.o_proj.linear) elif isinstance(module, ExLlamaV2MLP): @@ -528,7 +578,7 @@ def measure_quant(job, save_fn, model, hidden_state_offload_layers): cache = None attn_params = ExLlamaV2Attention.Params(1, hidden_states[0].shape[1], 0, None, None) \ - if mode in ["self_attn", "parallel_decoder"] else None + if mode in ["self_attn", "linear_attn", "parallel_decoder"] else None target_states = [] target_states_attn = [] @@ -579,6 +629,10 @@ def measure_quant(job, save_fn, model, hidden_state_offload_layers): quantizers["o_proj"].add_batch(outputs["attn_output"]) target_states.append(outputs["hidden_states"].to(target_device)) + if mode == "linear_attn": + quantizers["o_proj"].add_batch(outputs["post_norm"]) + target_states.append(outputs["hidden_states"].to(target_device)) + if mode == "mlp": quantizers["up_proj"].add_batch(outputs["post_norm"]) # Reuse H for gate_proj quantizers["down_proj"].add_batch(outputs["pre_down"]) @@ -622,6 +676,8 @@ def measure_quant(job, save_fn, model, hidden_state_offload_layers): if mode == "self_attn": m = measure_attn(module, hidden_states, target_states, quantizers, cache, attn_params) + if mode == "linear_attn": + m = measure_linear_attn(module, hidden_states, target_states, quantizers, cache, attn_params) if mode == "mlp": m = measure_mlp(module, hidden_states, target_states, quantizers, cache, attn_params) diff --git a/exllamav2/conversion/optimize.py b/exllamav2/conversion/optimize.py index 33e6999f..dfb1e641 100644 --- a/exllamav2/conversion/optimize.py +++ b/exllamav2/conversion/optimize.py @@ -87,10 +87,22 @@ def optimize(job, save_fn, model): if cfg.arch.lm.parallel_decoder_blocks: m1 = measurement[cfg.arch.lm_prefix + "model.layers." + str(i) + ".parallel_decoder"]["attn"] m2 = measurement[cfg.arch.lm_prefix + "model.layers." + str(i) + ".parallel_decoder"]["mlp"] + elif type(cfg.arch.lm.layer_keys) is list and type(cfg.intermediate_size) is list: + if ["self_attn.linear_attn"] in cfg.arch.lm.layer_keys[i]: + m1 = measurement[cfg.arch.lm_prefix + "model.layers." + str(i) + ".linear_attn"] + m2 = measurement[cfg.arch.lm_prefix + "model.layers." + str(i) + "." + mlp_mode] + elif ["self_attn.o_proj"] in cfg.arch.lm.layer_keys[i]: + m1 = measurement[cfg.arch.lm_prefix + "model.layers." + str(i) + ".self_attn"] + m2 = measurement[cfg.arch.lm_prefix + "model.layers." + str(i) + "." 
+ mlp_mode] + else: + m1 = None + m2 = measurement[cfg.arch.lm_prefix + "model.layers." + str(i) + "." + mlp_mode] else: m1 = measurement[cfg.arch.lm_prefix + "model.layers." + str(i) + ".self_attn"] m2 = measurement[cfg.arch.lm_prefix + "model.layers." + str(i) + "." + mlp_mode] for m in [m1, m2]: + if m is None: + continue slot = [] param = [] for opt in m: @@ -153,14 +165,34 @@ def optimize(job, save_fn, model): logerr = 0 maxerr = 0 job["strategy"] = {} + deci_offset = 0 for layer_ in range(num_layers): - - k1 = cfg.arch.lm_prefix + "model.layers." + str(layer_) + ".self_attn" - k2 = cfg.arch.lm_prefix + "model.layers." + str(layer_) + "." + mlp_mode - p1 = params[layer_ * 2][solution_idx[layer_ * 2]] - p2 = params[layer_ * 2 + 1][solution_idx[layer_ * 2 + 1]] + if type(cfg.arch.lm.layer_keys) is list and type(cfg.intermediate_size) is list: + if ["self_attn.linear_attn"] in cfg.arch.lm.layer_keys[layer_]: + k1 = cfg.arch.lm_prefix + "model.layers." + str(layer_) + ".linear_attn" + k2 = cfg.arch.lm_prefix + "model.layers." + str(layer_) + "." + mlp_mode + p1 = params[layer_ * 2-deci_offset][solution_idx[layer_ * 2-deci_offset]] + p2 = params[layer_ * 2 + 1-deci_offset][solution_idx[layer_ * 2 + 1-deci_offset]] + elif ["self_attn.o_proj"] in cfg.arch.lm.layer_keys[layer_]: + k1 = cfg.arch.lm_prefix + "model.layers." + str(layer_) + ".self_attn" + k2 = cfg.arch.lm_prefix + "model.layers." + str(layer_) + "." + mlp_mode + p1 = params[layer_ * 2-deci_offset][solution_idx[layer_ * 2-deci_offset]] + p2 = params[layer_ * 2 + 1-deci_offset][solution_idx[layer_ * 2 + 1-deci_offset]] + else: + deci_offset = deci_offset + 1 + k1 = None + k2 = cfg.arch.lm_prefix + "model.layers." + str(layer_) + "." + mlp_mode + p1 = None + p2 = params[layer_ * 2 + 1-deci_offset][solution_idx[layer_ * 2 + 1-deci_offset]] + else: + k1 = cfg.arch.lm_prefix + "model.layers." + str(layer_) + ".self_attn" + k2 = cfg.arch.lm_prefix + "model.layers." + str(layer_) + "." 
+ mlp_mode + p1 = params[layer_ * 2][solution_idx[layer_ * 2]] + p2 = params[layer_ * 2 + 1][solution_idx[layer_ * 2 + 1]] for (k, p, n) in zip((k1, k2), (p1, p2), (numel_attn, numel_mlp)): + if k is None or p is None or n is None: + continue job["strategy"][k] = p bpw = p["total_bits"] / n err = 1 - p["accuracy"] @@ -171,4 +203,4 @@ def optimize(job, save_fn, model): print(f" -- sum(log(err)): {logerr:.6f}") print(f" -- max(err): {maxerr:.6f}") - xx = 0 \ No newline at end of file + xx = 0 diff --git a/exllamav2/conversion/quantize.py b/exllamav2/conversion/quantize.py index 5f38eb3a..19ac1ce6 100644 --- a/exllamav2/conversion/quantize.py +++ b/exllamav2/conversion/quantize.py @@ -3,6 +3,7 @@ ExLlamaV2Embedding, ExLlamaV2PosEmbedding, ExLlamaV2Attention, + ExLlamaV2LinearAttention, ExLlamaV2MLP, ExLlamaV2MoEMLP, ExLlamaV2ParallelDecoder, @@ -151,6 +152,17 @@ def quant_attn(job, module, hidden_states, target_states, quantizers, attn_param torch.cuda.empty_cache() +def quant_linear_attn(job, module, hidden_states, target_states, quantizers, attn_params, strat): + + quantizers["o_proj"].prepare() + + quant_linear(job, module.o_proj, quantizers["o_proj"], strat["o_proj"]) + del quantizers[f"o_proj"] + + gc.collect() + torch.cuda.empty_cache() + + def quant_mlp(job, module, hidden_states, target_states, quantizers, attn_params, strat, reuse_h_up_proj = None): has_mlp = module.model.config.arch.lm.mlp_gate @@ -309,6 +321,11 @@ def quant(job, save_fn, model): quantizers["v_proj"] = AdaptiveGPTQ(module.v_proj.linear) quantizers["o_proj"] = AdaptiveGPTQ(module.o_proj.linear) + elif isinstance(module, ExLlamaV2LinearAttention): + mode = "linear_attn" + # if index > 1: testc(module, hidden_states, hidden_i_states, module.input_layernorm, [module.q_proj, module.k_proj, module.v_proj]) + quantizers["o_proj"] = AdaptiveGPTQ(module.o_proj.linear) + elif isinstance(module, ExLlamaV2MLP): mode = "mlp" has_mlp = model.config.arch.lm.mlp_gate @@ -368,6 +385,9 @@ def quant(job, save_fn, model): quantizers["q_proj"].add_batch(outputs["post_norm"]) # Reuse H for K and V quantizers["o_proj"].add_batch(outputs["attn_output"]) + if mode == "linear_attn": + quantizers["o_proj"].add_batch(outputs["post_norm"]) + if mode == "mlp": quantizers["up_proj"].add_batch(outputs["post_norm"]) # Reuse H for gate_proj quantizers["down_proj"].add_batch(outputs["pre_down"]) @@ -409,6 +429,10 @@ def quant(job, save_fn, model): strat = strategy[module.key + "." + mode] quant_attn(job, module, hidden_states, target_states, quantizers, attn_params, strat) + if mode == "linear_attn": + strat = strategy[module.key + "." + mode] + quant_linear_attn(job, module, hidden_states, target_states, quantizers, attn_params, strat) + if mode == "mlp": strat = strategy[module.key + "." 
+ mode] quant_mlp(job, module, hidden_states, target_states, quantizers, attn_params, strat) diff --git a/exllamav2/generator/dynamic.py b/exllamav2/generator/dynamic.py index abac46ec..ebb952e2 100644 --- a/exllamav2/generator/dynamic.py +++ b/exllamav2/generator/dynamic.py @@ -471,6 +471,8 @@ def __init__( cache_tensors += self.draft_cache.all_tensors() for c in cache_tensors: + if c is None: + continue key = (c.device.index, c.dtype, c.shape[2], c.shape[3]) if key not in self.defrag_buffer: t = torch.empty((1, self.page_size, c.shape[2], c.shape[3]), dtype = c.dtype, device = c.device) diff --git a/exllamav2/mlp.py b/exllamav2/mlp.py index d0831573..5e4a2d8b 100644 --- a/exllamav2/mlp.py +++ b/exllamav2/mlp.py @@ -62,11 +62,18 @@ def __init__( self.intermediate_size = cfg.vision_intermediate_size else: self.hidden_size = cfg.hidden_size - self.intermediate_size = cfg.intermediate_size + if type(cfg.intermediate_size) is list: + self.intermediate_size = cfg.intermediate_size[layer_idx] + else: + self.intermediate_size = cfg.intermediate_size if in_features is None: in_features = self.hidden_size if out_features is None: out_features = self.hidden_size - if interm_features is None: interm_features = self.intermediate_size + if interm_features is None: + if type(self.intermediate_size) is list: + interm_features = self.intermediate_size[layer_idx] + else: + interm_features = self.intermediate_size self.in_features = in_features self.out_features = out_features self.interm_features = interm_features diff --git a/exllamav2/model.py b/exllamav2/model.py index 80ccf758..d7facfd1 100644 --- a/exllamav2/model.py +++ b/exllamav2/model.py @@ -37,6 +37,7 @@ from exllamav2.rmsnorm import ExLlamaV2RMSNorm from exllamav2.layernorm import ExLlamaV2LayerNorm from exllamav2.attn import ExLlamaV2Attention, has_flash_attn, has_xformers +from exllamav2.linear_attn import ExLlamaV2LinearAttention, has_flash_attn, has_xformers from exllamav2.lora import ExLlamaV2Lora from exllamav2.mlp import ExLlamaV2MLP from exllamav2.moe_mlp import ExLlamaV2MoEMLP @@ -117,6 +118,16 @@ def __init__( if cfg.arch.lm.parallel_decoder_blocks: pd = ExLlamaV2ParallelDecoder(self, layer_key, layer_idx, sliding_window = swa) self.modules += [pd] + elif type(cfg.arch.lm.layer_keys) is list and type(cfg.intermediate_size) is list: + mlp = ExLlamaV2MLP(self, layer_key, layer_idx) + if ["self_attn.linear_attn"] in cfg.arch.lm.layer_keys[layer_idx]: + attn = ExLlamaV2LinearAttention(self, layer_key, layer_idx) + self.modules += [attn, mlp] + elif ["self_attn.o_proj"] in cfg.arch.lm.layer_keys[layer_idx]: + attn = ExLlamaV2Attention(self, layer_key, layer_idx, sliding_window = swa) + self.modules += [attn, mlp] + else: + self.modules += [mlp] else: attn = ExLlamaV2Attention(self, layer_key, layer_idx, sliding_window = swa) if cfg.arch.lm.is_moe: mlp = ExLlamaV2MoEMLP(self, layer_key, layer_idx) @@ -161,6 +172,7 @@ def __init__( while True: layer_idx -= 1 if isinstance(self.modules[layer_idx], ExLlamaV2Attention) or \ + isinstance(self.modules[layer_idx], ExLlamaV2LinearAttention) or \ isinstance(self.modules[layer_idx], ExLlamaV2ParallelDecoder): break @@ -609,6 +621,7 @@ def load_autosplit_gen( try: if isinstance(module, ExLlamaV2Attention) or \ + isinstance(module, ExLlamaV2LinearAttention) or \ isinstance(module, ExLlamaV2ParallelDecoder): self.cache_map[module.layer_idx] = module.device() cache.update_cache_tensors() From bc370eb65e9d66f08735d93d7d99618a320cdf55 Mon Sep 17 00:00:00 2001 From: Yee Man Chan Date: Tue, 28 Jan 2025 
12:32:21 +0800 Subject: [PATCH 2/3] Llama-3_1-Nemotron-51B support --- exllamav2/linear_attn.py | 1329 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 1329 insertions(+) create mode 100644 exllamav2/linear_attn.py diff --git a/exllamav2/linear_attn.py b/exllamav2/linear_attn.py new file mode 100644 index 00000000..dc36d7ff --- /dev/null +++ b/exllamav2/linear_attn.py @@ -0,0 +1,1329 @@ +from __future__ import annotations + +import torch +from torch import nn +from exllamav2.module import ExLlamaV2Module +from exllamav2.rmsnorm import ExLlamaV2RMSNorm +from exllamav2.layernorm import ExLlamaV2LayerNorm +from exllamav2.headnorm import ExLlamaV2HeadNorm +from exllamav2.linear import ExLlamaV2Linear +from exllamav2.cache import ExLlamaV2CacheBase +from exllamav2.ext import exllamav2_ext as ext_c, none_tensor +from exllamav2.lora import ExLlamaV2Lora +from exllamav2.architecture import RopeStyle +from exllamav2.tensor_p import BROADCAST_KV, BROADCAST_Q +import math +import torch.nn.functional as F +import inspect +import os + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from exllamav2.model import ExLlamaV2 + +# Detect available options for attention + +has_flash_attn = False +has_flash_attn_with_paged = False +has_flash_attn_with_window = False +has_flash_attn_with_softcap = False +has_xformers = False +has_lower_right_sdpa = False + +if 'EXLLAMA_NO_FLASH_ATTN' not in os.environ: + + try: + import flash_attn + flash_attn_ver = [int(t) for t in flash_attn.__version__.split(".") if t.isdigit()] + is_ampere_or_newer_gpu = any(torch.cuda.get_device_properties(i).major >= 8 for i in range(torch.cuda.device_count())) + + if not is_ampere_or_newer_gpu: + print(" ## Warning: Flash Attention is installed but unsupported GPUs were detected.") + + if [2, 2, 1] <= flash_attn_ver < [2, 5, 7]: + from flash_attn import flash_attn_func + has_flash_attn = True + + if [2, 5, 7] <= flash_attn_ver: + from flash_attn import flash_attn_func, flash_attn_with_kvcache + # import flash_attn_2_cuda as flash_attn_cuda + + signature = list(inspect.signature(flash_attn_func).parameters) + has_flash_attn_with_window = "window_size" in signature + has_flash_attn_with_softcap = "softcap" in signature + + import flash_attn_2_cuda as flash_attn_cuda + # ext_c.set_flash_attn_func() + + has_flash_attn = True + has_flash_attn_with_paged = True + + except ModuleNotFoundError: + pass + except NameError: + pass + +if 'EXLLAMA_NO_XFORMERS' not in os.environ: + + try: + import xformers.ops as xops + # LowerTriangularFromBottomRightMask was added in xformers version 2.4 + from xformers.ops.fmha.attn_bias import LowerTriangularFromBottomRightMask + has_xformers = True + except ModuleNotFoundError: + pass + +if 'EXLLAMA_NO_SDPA' not in os.environ: + try: + from torch.nn.attention.bias import causal_lower_right + has_lower_right_sdpa = True + except ImportError: + pass + + +def assert_paged_attn(): + """ + Raise an exception if paged attention is not available. 
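+    Requires flash-attn 2.5.7 or later.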
+ """ + global has_flash_attn_with_paged + assert has_flash_attn_with_paged, \ + "Paged attention required Flash Attention 2.5.7 or later" + + +class ExLlamaV2LinearAttention(ExLlamaV2Module): + + name: str = "LinearAttention" + + layer_idx: int + pre_layernorm: ExLlamaV2RMSNorm | ExLlamaV2LayerNorm | None + post_layernorm: ExLlamaV2RMSNorm | ExLlamaV2LayerNorm | None + q_proj: ExLlamaV2Linear | None + k_proj: ExLlamaV2Linear | None + v_proj: ExLlamaV2Linear | None + o_proj: ExLlamaV2Linear | None + q_norm: ExLlamaV2HeadNorm | None + k_norm: ExLlamaV2HeadNorm | None + + q_handle: int | None + + temp_state: torch.tensor + temp_q: torch.tensor + temp_k: torch.tensor + temp_v: torch.tensor + temp_o: torch.tensor + temp_dq: torch.tensor + # temp_kv: torch.tensor + + temp_lora_size: int + + has_norm: bool + has_residual: bool + scaling: float + sliding_window: int + + is_tp: bool + tp_dq_size: list[int] | None + + from exllamav2.attn_params import Params + from exllamav2.attn_params import PagedParams + + def __init__( + self, + model: ExLlamaV2, + key: str, + layer_idx: int, + has_norm: bool = True, + has_residual: bool = True, + sliding_window: int = 0, + archparams = None + ): + super().__init__(model, key, archparams) + + cfg = self.model.config + ap = self.archparams + km = self.archparams.keys + + self.is_tp = False + self.tp_dq_size = None + + self.layer_idx = layer_idx + self.has_norm = has_norm + self.has_residual = has_residual + + self.q_handle = None + self.temp_lora_size = 0 + + if ap.is_vision: + self.num_attention_heads = cfg.vision_num_attention_heads + self.num_key_value_heads = cfg.vision_num_key_value_heads + self.num_key_value_groups = cfg.vision_num_key_value_groups + self.head_dim = cfg.vision_head_dim + self.hidden_size = cfg.vision_hidden_size + else: + self.num_attention_heads = cfg.num_attention_heads + if type(cfg.num_key_value_heads) is list: + self.num_key_value_heads = cfg.num_key_value_heads[layer_idx] + else: + self.num_key_value_heads = cfg.num_key_value_heads + if type(cfg.num_key_value_groups) is list: + self.num_key_value_groups = cfg.num_key_value_groups[layer_idx] + else: + self.num_key_value_groups = cfg.num_key_value_groups + self.head_dim = cfg.head_dim + self.hidden_size = cfg.hidden_size + + hidden_size = self.hidden_size + + if self.has_norm and (km["norm_1"] or km["norm_1_post"]): + if ap.norm == "layernorm": + self.pre_layernorm = ExLlamaV2LayerNorm(model, key + km["norm_1"], archparams) + self.post_layernorm = ExLlamaV2LayerNorm(model, key + km["norm_1_post"], archparams) if km["norm_1_post"] else None + elif ap.norm == "rmsnorm": + self.pre_layernorm = ExLlamaV2RMSNorm(model, key + km["norm_1"], archparams) + self.post_layernorm = ExLlamaV2RMSNorm(model, key + km["norm_1_post"], archparams) if km["norm_1_post"] else None + else: + self.pre_layernorm = None + self.post_layernorm = None + self.has_norm = False + + f_a = 0 + f_b = self.num_attention_heads * self.head_dim + f_c = f_b + self.num_key_value_heads * self.head_dim + f_d = f_c + self.num_key_value_heads * self.head_dim + f_key = (key + km["fused_qkv"]) if km["fused_qkv"] else None + + self.o_proj = ExLlamaV2Linear(model, key + km["linear_attn"], self.num_attention_heads * self.head_dim, hidden_size, ap.attention_bias_o, prescale = cfg.scale_depth) + + if cfg.use_qk_norm: + self.q_norm = ExLlamaV2HeadNorm(model, key + ".self_attn.q_norm", self.num_attention_heads, self.head_dim) + self.k_norm = ExLlamaV2HeadNorm(model, key + ".self_attn.k_norm", self.num_key_value_heads, self.head_dim) + 
else: + self.q_norm = None + self.k_norm = None + + self.submodules = [ + self.o_proj + ] + + if self.pre_layernorm: + self.submodules += [self.pre_layernorm] + if self.post_layernorm: + self.submodules += [self.post_layernorm] + if cfg.use_qk_norm: + self.submodules += [self.q_norm, self.k_norm] + + if cfg.attention_multiplier: + self.scaling = cfg.attention_multiplier + elif cfg.query_pre_attn_scalar: + self.scaling = cfg.query_pre_attn_scalar ** (-0.5) + else: + self.scaling = 1 / math.sqrt(self.head_dim) + + self.sliding_window = sliding_window + + + def numel(self) -> int: + + numel = self.o_proj.numel() + + if self.pre_layernorm is not None: numel += self.pre_layernorm.numel() + if self.post_layernorm is not None: numel += self.post_layernorm.numel() + if self.q_norm is not None: numel += self.q_norm.numel() + if self.k_norm is not None: numel += self.k_norm.numel() + + return numel + + + @torch.inference_mode + def load(self, device_context: bool = True): + + cfg = self.model.config + + if self.pre_layernorm is not None: self.pre_layernorm.load() + if self.post_layernorm is not None: self.post_layernorm.load() + self.o_proj.load(device_context = device_context) + + def unload(self): + if self.q_handle is not None: + ext_c.free_q_attn(self.q_handle) + self.q_handle = None + + if self.pre_layernorm is not None: self.pre_layernorm.unload() + if self.post_layernorm is not None: self.post_layernorm.unload() + self.o_proj.unload() + + self.temp_state = None + self.temp_dq = None + + def weight_footprint(self): + + fp = self.o_proj.weight_footprint() + if self.pre_layernorm is not None: + fp += self.pre_layernorm.weight_footprint() + if self.post_layernorm is not None: + fp += self.post_layernorm.weight_footprint() + + return fp + + + def scratch_space_fixed(self): + + return self.temp_state_size() + \ + self.temp_dq_size() + + + def scratch_space(self): + + return self.temp_state_size() + \ + self.temp_dq_size() + \ + self.temp_kv_size() + # self.temp_attn_size() + # Accounted for separately in model.set_device_map() + + + def temp_state_size(self): + + cfg = self.model.config + return cfg.max_input_len * cfg.max_batch_size * max(self.num_attention_heads * self.head_dim, self.hidden_size) * 2 + 128 + + + def temp_q_size(self): + + cfg = self.model.config + return cfg.max_input_len * cfg.max_batch_size * self.num_attention_heads * self.head_dim * 2 + 128 + + + def temp_k_size(self): + + cfg = self.model.config + return cfg.max_input_len * cfg.max_batch_size * self.num_key_value_heads * self.head_dim * 2 + 128 + + + def temp_v_size(self): + + cfg = self.model.config + return cfg.max_input_len * cfg.max_batch_size * self.num_key_value_heads * self.head_dim * 2 + 128 + + + def temp_dq_size(self): + + return self.o_proj.temp_dq_size() + + def temp_kv_size(self): + + cfg = self.model.config + if self.num_key_value_heads == self.num_attention_heads: return 0 + return 2 * cfg.max_seq_len * cfg.max_batch_size * self.num_attention_heads * self.head_dim * 2 + 128 + + + def temp_attn_size(self): + global has_flash_attn + global has_xformers + + cfg = self.model.config + att_max = min(cfg.max_attention_size, cfg.max_seq_len ** 2) + + if (has_flash_attn and not cfg.no_flash_attn) or (has_xformers and not cfg.no_xformers) : + #in sm>=80 devices, xformers uses the same memory as flash_attn + #todo: due to the different implementions. in sm<80 devices, xformers uses less memory than it in sm>=80. There may still be room for optimization. 
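+            # Fused attention kernels never materialize the full score matrix, so the scratch estimate below is scaled down accordingly.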
+ eff = cfg.max_attention_size ** 0.5 / 190 # based on supposed memory savings listed in flash-attn repo + some fudging + att_max //= eff + + return 2 * att_max * self.num_attention_heads * 2 + 128 + + + def set_device_idx(self, idx: int | None): + super().set_device_idx(idx) + + if self.pre_layernorm is not None: self.pre_layernorm.set_device_idx(idx) + if self.post_layernorm is not None: self.post_layernorm.set_device_idx(idx) + self.o_proj.set_device_idx(idx) + + def repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + + if n_rep == 1: return hidden_states + + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + hidden_states = hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + return hidden_states + + + # @profile + def forward_paged( + self, + hidden_states: torch.Tensor, + cache: ExLlamaV2CacheBase | None = None, + attn_params: ExLlamaV2Attention.PagedParams | None = None, + loras: list[ExLlamaV2Lora] | None = None, + **kwargs + ) -> torch.Tensor: + + if self.is_tp: + return self.forward_paged_tp( + hidden_states, + cache, + attn_params, + loras, + **kwargs, + ) + + is_q = self.q_handle is not None + cfg = self.model.config + constants = self.model.get_device_context(self.device_idx, scratch = is_q) + page_size = attn_params.page_size + batch_size, q_len, _ = hidden_states.shape + cache_seqlens = attn_params.get_cache_seqlens(self.device_idx) + block_table = attn_params.get_block_index(self.device_idx) + + # TODO: We only need keys/values when preprocess_only == True, so we could skip q projection and attention + # on the last layer. Would need custom kernel to update paged cache if not calling flash_attn_with_kvcache + # skip_attn = kwargs.get("kv_only") + + # TODO: Potentially we could emulate paged cache when in Q4 mode, since that requires copying the active part + # of the current cache layer anyway. Test if block diagonal masking works with lower-right aligned mask. 
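+        # Fetch this layer's paged K/V state; the arguments to get_kv_state differ for quantized caches (q_block > 1).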
+ + if cache.q_block > 1: + k_cache_f, v_cache_f = cache.get_kv_state(self.layer_idx, batch_size, 0, attn_params.max_cache_seqlen, page_size, cache_seqlens, block_table) + else: + k_cache_f, v_cache_f = cache.get_kv_state(self.layer_idx, batch_size, 0, 0, page_size, cache_seqlens, block_table) + + k_cache = k_cache_f.view(k_cache_f.shape[1] // page_size, page_size, k_cache_f.shape[2], k_cache_f.shape[3]) + v_cache = v_cache_f.view(v_cache_f.shape[1] // page_size, page_size, v_cache_f.shape[2], v_cache_f.shape[3]) + + sc = attn_params.get_alt_rope_embed(self.device_idx) + if not sc: + sin, cos = constants.sin, constants.cos + else: + sin, cos = sc + + cache_seqlens_rope = cache_seqlens + offsets = attn_params.get_rope_offsets(self.device_idx) + if offsets is not None: + cache_seqlens_rope = cache_seqlens_rope + offsets + + if is_q: + q = torch.empty((batch_size, q_len, self.num_attention_heads, self.head_dim), device = hidden_states.device, dtype = torch.half) + if attn_params.is_sequential: + assert batch_size == 1 + k = k_cache_f[:, attn_params.first_index : attn_params.first_index + q_len, :, :] + v = v_cache_f[:, attn_params.first_index : attn_params.first_index + q_len, :, :] + else: + k = torch.empty((batch_size, q_len, self.num_key_value_heads, self.head_dim), device = hidden_states.device, dtype = torch.half) + v = torch.empty((batch_size, q_len, self.num_key_value_heads, self.head_dim), device = hidden_states.device, dtype = torch.half) + + if loras is None or self.temp_lora_size == 0: + pass_loras = [] + pass_lora_temp = none_tensor + else: + pass_loras = [id(x) for x in loras] + pass_lora_temp = torch.empty((self.temp_lora_size,), dtype = torch.half, device = hidden_states.device) + + ext_c.q_attn_forward_1( + self.q_handle, + hidden_states, + batch_size, + q_len, + 0, + cache_seqlens_rope, + q, + k, + v, + sin, + cos, + pass_loras, + pass_lora_temp + ) + else: + residual = hidden_states + hidden_states = self.pre_layernorm.forward(hidden_states) if self.has_norm else hidden_states + + attn_output = hidden_states.view((batch_size, q_len, self.num_attention_heads * self.head_dim)) + + cache.store_kv_state(self.layer_idx, batch_size, 0, q_len, page_size, cache_seqlens, block_table) + + # Output projection + + if is_q: + ext_c.q_attn_forward_2( + self.q_handle, + hidden_states, + attn_output, + batch_size, + q_len, + pass_loras, + pass_lora_temp + ) + else: + hidden_states = self.o_proj.forward(attn_output, loras = loras) + if self.post_layernorm: + hidden_states = self.post_layernorm.forward(hidden_states) + if self.has_residual: + hidden_states += residual + + return hidden_states + + + # @profile + def forward_paged_tp( + self, + hidden_states: torch.Tensor, + cache: ExLlamaV2CacheBase | None = None, + attn_params: ExLlamaV2Attention.PagedParams | None = None, + loras: list[ExLlamaV2Lora] | None = None, + **kwargs + ) -> torch.Tensor: + + cfg = self.model.config + ctx = self.model.tp_context + + assert not self.sliding_window, \ + "Sliding window not supported in TP mode" + + attn_params.prep_tp(self.model) + page_size = attn_params.page_size + + batch_size, q_len, _ = hidden_states.shape + rows = batch_size * q_len + hidden_states = hidden_states.view(-1, self.hidden_size) + dtype = hidden_states.dtype + + k_cache_f, v_cache_f = cache.get_kv_state( + self.layer_idx, + batch_size, + 0, + attn_params.max_cache_seqlen, + page_size, + attn_params.cache_seqlens_tp, + attn_params.block_index_tp + ) + + k_cache = [x.view(x.shape[1] // page_size, page_size, x.shape[2], x.shape[3]) for x 
in k_cache_f] + v_cache = [x.view(x.shape[1] // page_size, page_size, x.shape[2], x.shape[3]) for x in v_cache_f] + + sin, cos = ctx.get_sin_cos() + + ext_c.tp_attn_forward_paged_( + self.model.tp_context.ext_tp_context, + hidden_states, + self.temp_bc0, + self.temp_bc1, + self.temp_bc2, + self.temp_q, + self.temp_k, + self.temp_v, + self.temp_o, + k_cache, + v_cache, + self.pre_layernorm.weight if self.pre_layernorm is not None else [], + self.pre_layernorm.variance_epsilon if self.pre_layernorm is not None else 0.0, + self.q_proj.q_handle, + self.k_proj.q_handle, + self.v_proj.q_handle, + self.o_proj.q_handle, + self.head_dim, + int(self.archparams.rope_style), + batch_size, + q_len, + sin, + cos, + attn_params.cache_seqlens_tp, + attn_params.block_index_tp, + self.scaling + ) + + cache.store_kv_state( + self.layer_idx, + batch_size, + 0, + q_len, + page_size, + attn_params.cache_seqlens_tp, + attn_params.block_index_tp + ) + + return ctx.get_pinned(0, batch_size, q_len, self.hidden_size) + + + # @profile + def forward_paged_tp_old( + self, + hidden_states: torch.Tensor, + cache: ExLlamaV2CacheBase | None = None, + attn_params: ExLlamaV2Attention.PagedParams | None = None, + loras: list[ExLlamaV2Lora] | None = None, + **kwargs + ) -> torch.Tensor: + + assert self.q_handle is not None + cfg = self.model.config + split = self.model.tp_context.get_split(BROADCAST_KV) + batch_size, q_len, _ = hidden_states.shape + attn_params.prep_tp(self.model) + page_size = attn_params.page_size + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + + k_cache_f, v_cache_f = cache.get_kv_state( + self.layer_idx, + batch_size, + 0, + attn_params.max_cache_seqlen, + page_size, + attn_params.cache_seqlens_tp, + attn_params.block_index_tp + ) + + k_cache = [x.view(x.shape[1] // page_size, page_size, x.shape[2], x.shape[3]) for x in k_cache_f] + v_cache = [x.view(x.shape[1] // page_size, page_size, x.shape[2], x.shape[3]) for x in v_cache_f] + + hidden_states = self.model.tp_context.broadcast(0, hidden_states, BROADCAST_KV, dim = self.head_dim) + + residual = hidden_states + + post_norm = self.pre_layernorm.forward_tp(hidden_states, output_split = True) if self.has_norm else hidden_states + q = self.q_proj.forward_tp(post_norm, loras = loras, output_split = True, dim = self.head_dim) + k = self.k_proj.forward_tp(post_norm, loras = loras, output_split = True, dim = self.head_dim) + v = self.v_proj.forward_tp(post_norm, loras = loras, output_split = True, dim = self.head_dim) + q = [q_.view(batch_size, q_len, q_.shape[1] // self.head_dim, self.head_dim) for q_ in q] + k = [k_.view(batch_size, q_len, k_.shape[1] // self.head_dim, self.head_dim) for k_ in k] + v = [v_.view(batch_size, q_len, v_.shape[1] // self.head_dim, self.head_dim) for v_ in v] + if cfg.use_qk_norm: + assert False, "TP not implemented for QK norm" # TODO: ... 
+ # q = self.q_norm.forward(q) + # k = self.k_norm.forward(k) + if self.archparams.rope_style != RopeStyle.NONE: + for idx, (dev, a, b) in enumerate(split): + context = self.model.get_device_context(dev) + torch.cuda.set_stream(context.stream) + for t, heads in [(q[idx], self.num_key_value_groups), (k[idx], 1)]: + ext_c.rope_( + t, + context.sin, + context.cos, + 0, + (b - a) * heads, + self.head_dim, + attn_params.cache_seqlens_tp[idx], + self.archparams.rope_style == RopeStyle.NEOX + ) + if attn_params.is_sequential: + k_ = [x[:, attn_params.first_index: attn_params.first_index + q_len, :, :] for x in k_cache_f] + v_ = [x[:, attn_params.first_index: attn_params.first_index + q_len, :, :] for x in v_cache_f] + for (dev, a, b), x_, x, y_, y in zip(split, k_, k, v_, v): + context = self.model.get_device_context(dev) + torch.cuda.set_stream(context.stream) + x_.copy_(x) + y_.copy_(y) + k = None + v = None + cache_seqlens_a = attn_params.cache_seqlens_after_tp + else: + cache_seqlens_a = attn_params.cache_seqlens_tp + + # if cache.q_block == 1: + # cache.get_kv_state( + # self.layer_idx, + # batch_size, + # 0, + # attn_params.max_cache_seqlen, + # page_size, + # attn_params.cache_seqlens_tp, + # attn_params.block_index_tp + # ) + + flash_kwargs = {} + if self.sliding_window: + # assert has_flash_attn_with_window, \ + # "Installed version of flash-attn does not support sliding window" + if has_flash_attn_with_window: + flash_kwargs["window_size"] = (self.sliding_window, self.sliding_window) + if cfg.attn_logit_softcapping: + # assert has_flash_attn_with_softcap, \ + # "Installed version of flash-attn does not support softcapping" + if has_flash_attn_with_softcap: + flash_kwargs["softcap"] = cfg.attn_logit_softcapping + + attn_outputs = [] + for idx in range(len(split)): + dev, a, b = split[idx] + context = self.model.get_device_context(dev) + torch.cuda.set_stream(context.stream) + + attn_output = flash_attn_with_kvcache( + q = q[idx], + k = k[idx] if k is not None else None, + v = v[idx] if v is not None else None, + k_cache = k_cache[idx], + v_cache = v_cache[idx], + cache_seqlens = cache_seqlens_a[idx], + block_table = attn_params.block_index_tp[idx], + causal = True, + softmax_scale = self.scaling, + **flash_kwargs + ) + attn_output = attn_output.view(batch_size * q_len, (b - a) * self.head_dim * self.num_key_value_groups) + attn_outputs.append(attn_output) + + cache.store_kv_state( + self.layer_idx, + batch_size, + 0, + q_len, + page_size, + attn_params.cache_seqlens_tp, + attn_params.block_index_tp + ) + + # Output projection + + attn_outputs = self.model.tp_context.allgather(1, attn_outputs, BROADCAST_Q, BROADCAST_Q, dim = self.head_dim) + + hidden_states = self.o_proj.forward_tp(attn_outputs, loras = loras, dim = self.head_dim, output_split = True) + + if self.has_residual: + self.model.tp_context.add_residual(hidden_states, residual, BROADCAST_Q, dim = self.head_dim) + + hidden_states = self.model.tp_context.gather(0, hidden_states, BROADCAST_Q, dim = self.head_dim) + + # if self.post_layernorm: # TODO: ... 
+ # hidden_states = self.post_layernorm.forward(hidden_states) + + hidden_states = hidden_states.view(batch_size, q_len, hidden_states.shape[-1]) + return hidden_states + + def _attn_torch(self, batch_size, q_len, q_states, k_states, v_states, attn_params, cfg, causal = True): + + num_attn_heads = q_states.shape[2] + head_dim = q_states.shape[3] + + q_states = q_states.transpose(1, 2) + k_states = k_states.transpose(1, 2) + v_states = v_states.transpose(1, 2) + + # SDPA + + if has_lower_right_sdpa and not cfg.no_sdpa and not cfg.attn_logit_softcapping: + + k_states = self.repeat_kv(k_states, self.num_key_value_groups) + v_states = self.repeat_kv(v_states, self.num_key_value_groups) + + if self.sliding_window and k_states.shape[2] >= self.sliding_window: + k_states = k_states[:, :, -self.sliding_window:, :] + v_states = v_states[:, :, -self.sliding_window:, :] + + if attn_params.is_causal(): + attn_mask_lr = causal_lower_right(q_len, k_states.shape[2]) + else: + attn_mask_lr = attn_params.get_attn_mask(q_states.device) + attn_output = F.scaled_dot_product_attention( + q_states, + k_states, + v_states, + attn_mask_lr if causal else None, + scale = self.scaling + ) + + # Matmul attn + + else: + + k_states = self.repeat_kv(k_states, self.num_key_value_groups) + k_states = k_states.transpose(-1, -2) + + attn_weights = torch.matmul(q_states, k_states) + + attn_weights *= self.scaling + if causal: + attn_mask = attn_params.get_attn_mask(attn_weights.device) + + if cfg.attn_logit_softcapping: + ext_c.softcap_(attn_weights, cfg.attn_logit_softcapping) + if causal and attn_mask is not None: + attn_weights = attn_weights + attn_mask + if self.sliding_window and k_states.shape[-1] >= self.sliding_window: + attn_weights = attn_weights[:, :, :, -self.sliding_window:] + v_states = v_states[:, :, -self.sliding_window:, :] + + attn_weights = nn.functional.softmax(attn_weights, dim = -1, dtype = torch.float16) + + v_states = self.repeat_kv(v_states, self.num_key_value_groups) + attn_output = torch.matmul(attn_weights, v_states) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape((batch_size, q_len, num_attn_heads * head_dim)) + return attn_output + + + def _attn_flash(self, batch_size, q_len, q_states, k_states, v_states, attn_params, cfg, causal = True): + + flash_kwargs = {} + if self.sliding_window: + # assert has_flash_attn_with_window, \ + # "Installed version of flash-attn does not support sliding window" + if has_flash_attn_with_window: + flash_kwargs["window_size"] = (self.sliding_window, self.sliding_window) + if cfg.attn_logit_softcapping: + # assert has_flash_attn_with_softcap, \ + # "Installed version of flash-attn does not support softcapping" + if has_flash_attn_with_softcap: + flash_kwargs["softcap"] = cfg.attn_logit_softcapping + + attn_output = flash_attn_func( + q_states, + k_states, + v_states, + causal = causal, + softmax_scale = self.scaling, + **flash_kwargs + ) + attn_output = attn_output.reshape((batch_size, q_len, self.num_attention_heads * self.head_dim)) + return attn_output + + + def _attn_xformers(self, batch_size, q_len, q_states, k_states, v_states, attn_params, cfg, causal = True): + + # assert not self.sliding_window, \ + # "Sliding window not currently supported for xformers" + + # assert not cfg.attn_logit_softcapping, \ + # "Softcap not yet supported for xformers" + + # xformers memory_efficient_attention, could be beneficial if your device's architecture is less than sm_80 are almost the same. 
But the martix operation + # make this implemention much slower. + + k_states = k_states.transpose(1, 2) + v_states = v_states.transpose(1, 2) + + k_states = self.repeat_kv(k_states, self.num_key_value_groups) + v_states = self.repeat_kv(v_states, self.num_key_value_groups) + + k_states = k_states.transpose(1, 2) + v_states = v_states.transpose(1, 2) + + attn_output = xops.memory_efficient_attention( + q_states, + k_states, + v_states, + attn_bias = LowerTriangularFromBottomRightMask() if causal else None, + scale = self.scaling + ) + attn_output = attn_output.reshape((batch_size, q_len, self.num_attention_heads * self.head_dim)) + + return attn_output + + + # @profile + def forward( + self, + hidden_states: torch.Tensor, + cache: ExLlamaV2CacheBase | None = None, + attn_params: ExLlamaV2Attention.Params | None = None, + past_len: int | None = None, + intermediates: bool = False, + loras: list[ExLlamaV2Lora] | None = None, + **kwargs + ) -> torch.Tensor | dict[str: torch.Tensor]: + + cfg = self.model.config + global has_flash_attn + global has_xformers + use_flash_attn = has_flash_attn and not cfg.no_flash_attn + + if isinstance(attn_params, ExLlamaV2LinearAttention.PagedParams): + return self.forward_paged( + hidden_states, + cache, + attn_params, + loras = loras, + **kwargs + ) + + if self.is_tp: + if cache is not None and use_flash_attn: + return self.forward_tp( + hidden_states, + cache, + attn_params, + past_len, + intermediates, + loras, + **kwargs, + ) + else: + # TODO: Can't use the optimized forward function because it writes directly to a fixed output + # tensor, and flash-attn currently has a bug that prevents that from working when q_len == 1 + return self.forward_tp_old( + hidden_states, + cache, + attn_params, + past_len, + intermediates, + loras, + **kwargs, + ) + + if self.q_handle is None or intermediates: + return self.forward_torch( + hidden_states, + cache, + attn_params, + past_len, + intermediates, + loras = loras, + **kwargs + ) + + constants = self.model.get_device_context(self.device_idx) + + batch_size, q_len, _ = hidden_states.shape + direct = (batch_size == 1 and cache is not None and isinstance(cache, ExLlamaV2CacheBase)) + + # If conditions are right we can write the K/V projections directly into the cache + + if direct: + batch_keys, batch_values = cache.get_kv_state(self.layer_idx, batch_size, 0, past_len) + + # RMS norm, Q/K/V projections, position embeddings + + if loras is None or self.temp_lora_size == 0: + pass_loras = [] + pass_lora_temp = none_tensor + else: + pass_loras = [id(x) for x in loras] + pass_lora_temp = torch.empty((self.temp_lora_size,), dtype = torch.half, device = hidden_states.device) + + if attn_params.position_offsets is not None: + pass_past_len_1 = past_len + pass_past_len_2 = attn_params.get_position_offsets(hidden_states.device) + if pass_past_len_1 == 0: + offsets = attn_params.get_rope_offsets(self.device_idx) + if offsets is not None: + pass_past_len_2 = pass_past_len_2 + offsets + else: + pass_past_len_1 = past_len + pass_past_len_2 = none_tensor + if attn_params.rope_offsets is not None: + offset = attn_params.rope_offsets.cpu().item() + pass_past_len_1 += offset + + ext_c.q_attn_forward_1( + self.q_handle, + hidden_states, + batch_size, + q_len, + pass_past_len_1, + pass_past_len_2, + q_states, + k_states, + v_states, + constants.sin, + constants.cos, + pass_loras, + pass_lora_temp + ) + + # Select attention function + + if (has_flash_attn and not cfg.no_flash_attn) and attn_params.is_causal(): + attn_func = self._attn_flash 
+ elif (has_xformers and not cfg.no_xformers) and attn_params.is_causal(): + attn_func = self._attn_xformers + else: + attn_func = self._attn_torch + + # Straight attention without cache + + if cache is None: + + q_states = q_states.view(batch_size, q_len, self.num_attention_heads, self.head_dim) + k_states = k_states.view(batch_size, q_len, self.num_key_value_heads, self.head_dim) + v_states = v_states.view(batch_size, q_len, self.num_key_value_heads, self.head_dim) + + attn_output = attn_func(batch_size, q_len, q_states, k_states, v_states, attn_params, cfg) + + # Regular cache (FP16, FP8, Q4) + + elif isinstance(cache, ExLlamaV2CacheBase): + + q_states = q_states.view(batch_size, q_len, self.num_attention_heads, self.head_dim) + k_states = k_states.view(batch_size, q_len, self.num_key_value_heads, self.head_dim) + v_states = v_states.view(batch_size, q_len, self.num_key_value_heads, self.head_dim) + + if not direct: + batch_keys, batch_values = cache.get_kv_state(self.layer_idx, batch_size, 0, past_len) + batch_keys[:batch_size, past_len:past_len + q_len, :].copy_(k_states) + batch_values[:batch_size, past_len:past_len + q_len, :].copy_(v_states) + + k_states = batch_keys[:batch_size, :past_len + q_len, :] + v_states = batch_values[:batch_size, :past_len + q_len, :] + + cache.store_kv_state(self.layer_idx, batch_size, past_len, q_len) + + attn_output = attn_func(batch_size, q_len, q_states, k_states, v_states, attn_params, cfg) + + # Output projection + + ext_c.q_attn_forward_2( + self.q_handle, + hidden_states, + attn_output, + batch_size, + q_len, + pass_loras, + pass_lora_temp + ) + + if self.archparams.clamp_hidden_states: + hidden_states.clamp_(-65504, 65504) + + return hidden_states + + def forward_tp( + self, + hidden_states: torch.Tensor, + cache: ExLlamaV2CacheBase | None = None, + attn_params: ExLlamaV2Attention.Params | None = None, + past_len: int | None = None, + intermediates: bool = False, + loras: list[ExLlamaV2Lora] | None = None, + **kwargs + ) -> torch.Tensor: + + cfg = self.model.config + ctx = self.model.tp_context + + assert not cache or cache.q_block != 1, \ + "Models with odd key/value dims not supported in TP mode with quantized cache" + assert not self.sliding_window, \ + "Sliding window not supported in TP mode" + + attn_params.prep_tp(self.model) + + batch_size, q_len, _ = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + past_len = 0 if cache is None else cache.current_seq_len + + k_cache, v_cache = cache.get_kv_state(self.layer_idx, batch_size, 0, past_len) if cache else ([], []) + + sin, cos = ctx.get_sin_cos() + + ext_c.tp_attn_forward_( + self.model.tp_context.ext_tp_context, + hidden_states, + self.temp_bc0, + self.temp_bc1, + self.temp_bc2, + self.temp_q, + self.temp_k, + self.temp_v, + self.temp_o, + k_cache, + v_cache, + self.pre_layernorm.weight if self.pre_layernorm is not None else [], + self.pre_layernorm.variance_epsilon if self.pre_layernorm is not None else 0.0, + self.q_proj.q_handle, + self.k_proj.q_handle, + self.v_proj.q_handle, + self.o_proj.q_handle, + self.head_dim, + int(self.archparams.rope_style), + batch_size, + q_len, + sin, + cos, + attn_params.past_len_tp, + self.scaling + ) + + if cache is not None: + cache.store_kv_state(self.layer_idx, batch_size, past_len, q_len) + + return ctx.get_pinned(0, batch_size, q_len, self.hidden_size) + + + def forward_tp_old( + self, + hidden_states: torch.Tensor, + cache: ExLlamaV2CacheBase | None = None, + attn_params: ExLlamaV2Attention.Params | None = None, + 
past_len: int | None = None, + intermediates: bool = False, + loras: list[ExLlamaV2Lora] | None = None, + **kwargs + ): + cfg = self.model.config + split = self.model.tp_context.get_split(BROADCAST_KV) + batch_size, q_len, _ = hidden_states.shape + attn_params.prep_tp(self.model) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + past_len = 0 if cache is None else cache.current_seq_len + + assert self.q_handle is not None + use_flash_attn = has_flash_attn and not cfg.no_flash_attn + if not use_flash_attn: + assert has_lower_right_sdpa and attn_params.is_causal() and not cfg.no_sdpa and not cfg.attn_logit_softcapping, \ + "TP attention without flash-attn must use Torch SDPA with lower-right attention mask " \ + "(use PyTorch 2.4.0+) and does not support logit softcapping." + + hidden_states = self.model.tp_context.broadcast(0, hidden_states, BROADCAST_KV, dim = self.head_dim) + + residual = hidden_states + + post_norm = self.pre_layernorm.forward(hidden_states) if self.has_norm else hidden_states + q = self.q_proj.forward_tp(post_norm, loras = loras, output_split = True, dim = self.head_dim) + k = self.k_proj.forward_tp(post_norm, loras = loras, output_split = True, dim = self.head_dim) + v = self.v_proj.forward_tp(post_norm, loras = loras, output_split = True, dim = self.head_dim) + + q = [q_.view(batch_size, q_len, q_.shape[1] // self.head_dim, self.head_dim) for q_ in q] + k = [k_.view(batch_size, q_len, k_.shape[1] // self.head_dim, self.head_dim) for k_ in k] + v = [v_.view(batch_size, q_len, v_.shape[1] // self.head_dim, self.head_dim) for v_ in v] + + if cache: + k_cache, v_cache = cache.get_kv_state(self.layer_idx, batch_size, 0, past_len) + else: + k_cache, v_cache = None, None + + if self.archparams.rope_style != RopeStyle.NONE: + for idx, (dev, a, b) in enumerate(split): + context = self.model.get_device_context(dev) + torch.cuda.set_stream(context.stream) + for t, heads in [(q[idx], self.num_key_value_groups), (k[idx], 1)]: + ext_c.rope_( + t, + context.sin, + context.cos, + past_len, + (b - a) * heads, + self.head_dim, + attn_params.position_offsets_tp[idx] if attn_params.position_offsets is not None else none_tensor, + self.archparams.rope_style == RopeStyle.NEOX + ) + + attn_outputs = [] + for idx in range(len(split)): + dev, a, b = split[idx] + context = self.model.get_device_context(dev) + torch.cuda.set_stream(context.stream) + + if k_cache is not None: + if use_flash_attn: + attn_output = flash_attn_with_kvcache( + q = q[idx], + k = k[idx], + v = v[idx], + k_cache = k_cache[idx], + v_cache = v_cache[idx], + causal = True, + softmax_scale = self.scaling, + cache_seqlens = attn_params.past_len_tp[idx] + ) + else: + cache_a = attn_params.past_len + cache_b = attn_params.past_len + q_len + k_cache[idx][:batch_size, cache_a:cache_b, :, :].copy_(k[idx]) + v_cache[idx][:batch_size, cache_a:cache_b, :, :].copy_(v[idx]) + attn_output = self._attn_torch( + batch_size, + q_len, + q[idx], + k_cache[idx][:batch_size, :cache_b, :, :], + v_cache[idx][:batch_size, :cache_b, :, :], + attn_params, + cfg + ) + else: + if use_flash_attn: + attn_output = flash_attn_func( + q[idx], + k[idx], + v[idx], + causal = True, + softmax_scale = self.scaling, + ) + else: + attn_output = self._attn_torch( + batch_size, + q_len, + q[idx], + k[idx], + v[idx], + attn_params, + cfg + ) + + attn_output = attn_output.view(batch_size * q_len, (b - a) * self.head_dim * self.num_key_value_groups) + attn_outputs.append(attn_output) + + if cache is not None: + cache.store_kv_state(self.layer_idx, 
batch_size, past_len, q_len) + + # Output projection + + attn_outputs = self.model.tp_context.allgather(1, attn_outputs, BROADCAST_Q, BROADCAST_Q, dim = self.head_dim) + + hidden_states = self.o_proj.forward_tp(attn_outputs, loras = loras, dim = self.head_dim, output_split = True) + + if self.has_residual: + self.model.tp_context.add_residual(hidden_states, residual, BROADCAST_Q, dim = self.head_dim) + + hidden_states = self.model.tp_context.gather(0, hidden_states, BROADCAST_Q, dim = self.head_dim) + + # if self.post_layernorm: # TODO: ... + # hidden_states = self.post_layernorm.forward(hidden_states) + + hidden_states = hidden_states.view(batch_size, q_len, hidden_states.shape[-1]) + return hidden_states + + + def forward_torch( + self, + hidden_states: torch.Tensor, + cache: ExLlamaV2CacheBase | None = None, + attn_params: ExLlamaV2Attention.Params | None = None, + past_len: int | None = None, + intermediates: bool = False, + loras: list[ExLlamaV2Lora] | None = None, + **kwargs + ) -> torch.Tensor | dict: + + global has_flash_attn + global has_xformers + + cfg = self.model.config + num_attention_heads = self.num_attention_heads + num_key_value_heads = self.num_key_value_heads + head_dim = self.head_dim + + batch_size, q_len, _ = hidden_states.size() + + past_len = 0 if cache is None else cache.current_seq_len + + # Project q, k, v + + residual = hidden_states + post_norm = self.pre_layernorm.forward(hidden_states) if self.has_norm else hidden_states + + # Output projection + + attn_proj = self.o_proj.forward(post_norm, loras = loras) + + # Post layernorm + + if self.post_layernorm: + attn_proj = self.post_layernorm.forward(attn_proj, output_fp32 = self.archparams.residual_stream_fp32) + + # Add residual connection + + hidden_states = (attn_proj + residual) if self.has_residual else attn_proj + + if self.archparams.residual_stream_fp32: + hidden_states = hidden_states.float() + elif self.archparams.clamp_hidden_states: + hidden_states.clamp_(-65504, 65504) + + if intermediates: + return {"post_norm": post_norm, + "hidden_states": hidden_states} + else: + return hidden_states + + + def update_loras(self): + + if self.q_handle is None: return + + cfg = self.model.config + + q_proj_lora_a = { id(k): v for k, v in self.q_proj.lora_a_tensors.items() } + q_proj_lora_b = { id(k): v for k, v in self.q_proj.lora_b_tensors.items() } + k_proj_lora_a = { id(k): v for k, v in self.k_proj.lora_a_tensors.items() } + k_proj_lora_b = { id(k): v for k, v in self.k_proj.lora_b_tensors.items() } + v_proj_lora_a = { id(k): v for k, v in self.v_proj.lora_a_tensors.items() } + v_proj_lora_b = { id(k): v for k, v in self.v_proj.lora_b_tensors.items() } + o_proj_lora_a = { id(k): v for k, v in self.o_proj.lora_a_tensors.items() } + o_proj_lora_b = { id(k): v for k, v in self.o_proj.lora_b_tensors.items() } + + temp_lora_size = ext_c.q_attn_set_loras( + self.q_handle, + q_proj_lora_a, + q_proj_lora_b, + k_proj_lora_a, + k_proj_lora_b, + v_proj_lora_a, + v_proj_lora_b, + o_proj_lora_a, + o_proj_lora_b + ) + + self.temp_lora_size = temp_lora_size * cfg.max_batch_size * cfg.max_input_len + + + def is_quant(self): + return self.q_handle is not None + + + def tp_split(self): + + cfg = self.model.config + ctx = self.model.tp_context + + if self.pre_layernorm is not None: + self.pre_layernorm.tp_split(BROADCAST_KV) + if self.post_layernorm is not None: + self.post_layernorm.tp_split(BROADCAST_KV) + + self.q_proj.tp_split(BROADCAST_Q, dim = self.head_dim) + self.k_proj.tp_split(BROADCAST_KV, dim = self.head_dim) + 
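# Illustrative sketch (not from this patch): a minimal, self-contained picture of what a
# Nemotron-style "linear_attn" layer computes, assuming hidden_size == num_heads * head_dim.
# As in the forward_torch path above, there are no Q/K/V projections and nothing is written
# to the KV cache: the block reduces to pre-norm -> single output projection -> residual add.
# Names (SimpleLinearAttnBlock, rms_norm) are illustrative, not ExLlamaV2 API.
import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Standard RMSNorm: scale by reciprocal RMS, then by a learned weight
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim = True) + eps) * weight

class SimpleLinearAttnBlock(torch.nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.norm_weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.linear_attn = torch.nn.Linear(hidden_size, hidden_size, bias = False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        post_norm = rms_norm(hidden_states, self.norm_weight)
        return self.linear_attn(post_norm) + residual   # has_residual assumed True

x = torch.randn(1, 5, 16)
print(SimpleLinearAttnBlock(16)(x).shape)               # torch.Size([1, 5, 16])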
self.v_proj.tp_split(BROADCAST_KV, dim = self.head_dim) + self.o_proj.tp_split(BROADCAST_Q, dim = self.head_dim) + + maxrows = cfg.max_batch_size * cfg.max_input_len + dtype = torch.half + + ctx.begin_scratch_alloc_tp() + ctx.reserve_scratch(self.tp_dq_size) + self.temp_bc0 = ctx.get_scratch_slice_tp_bc(maxrows, dtype, BROADCAST_Q, dim = self.head_dim) + self.temp_bc1 = ctx.get_scratch_slice_tp_bc(maxrows, dtype, BROADCAST_Q, dim = self.head_dim) + self.temp_bc2 = ctx.get_scratch_slice_tp_bc(maxrows, dtype, BROADCAST_Q, dim = self.head_dim) + self.temp_q = ctx.get_scratch_slice_tp(maxrows, dtype, BROADCAST_Q, dim = self.head_dim) + self.temp_k = ctx.get_scratch_slice_tp(maxrows, dtype, BROADCAST_KV, dim = self.head_dim) + self.temp_v = ctx.get_scratch_slice_tp(maxrows, dtype, BROADCAST_KV, dim = self.head_dim) + self.temp_o = ctx.get_scratch_slice_tp(maxrows, dtype, BROADCAST_Q, dim = self.head_dim) + + self.is_tp = True + self.set_device_idx(None) + + + def scratch_space_tp(self): + + cfg = self.model.config + ctx = self.model.tp_context + devs = ctx.num_devices + scratch = [0] * devs + + def add(res: list[int]): + for i, s in enumerate(res): + scratch[i] += s + + def amax(res: list[int]): + for i, s in enumerate(res): + scratch[i] = max(scratch[i], s) + + amax(self.q_proj.scratch_space_tp(BROADCAST_Q, self.head_dim)) + amax(self.k_proj.scratch_space_tp(BROADCAST_KV, self.head_dim)) + amax(self.v_proj.scratch_space_tp(BROADCAST_KV, self.head_dim)) + amax(self.o_proj.scratch_space_tp(BROADCAST_Q, self.head_dim)) + self.tp_dq_size = [s for s in scratch] + + maxrows = cfg.max_batch_size * cfg.max_input_len + + add(ctx.get_temp_tensors_bc_s(maxrows, 2, BROADCAST_Q, dim = self.head_dim)) + add(ctx.get_temp_tensors_bc_s(maxrows, 2, BROADCAST_Q, dim = self.head_dim)) + add(ctx.get_temp_tensors_bc_s(maxrows, 2, BROADCAST_Q, dim = self.head_dim)) + add(ctx.get_temp_tensors_s(maxrows, 2, BROADCAST_Q, dim = self.head_dim)) + add(ctx.get_temp_tensors_s(maxrows, 2, BROADCAST_KV, dim = self.head_dim)) + add(ctx.get_temp_tensors_s(maxrows, 2, BROADCAST_KV, dim = self.head_dim)) + add(ctx.get_temp_tensors_s(maxrows, 2, BROADCAST_Q, dim = self.head_dim)) + + return scratch From a62b0fd82e385e1f2cba955585d7c34b8e94d450 Mon Sep 17 00:00:00 2001 From: Yee Man Chan Date: Wed, 29 Jan 2025 14:26:44 +0800 Subject: [PATCH 3/3] removed linear_attn.py and ExLlamaV2LinearAttetnion class by merging it to ExLlammaV2Attention --- exllamav2/attn.py | 179 +++- exllamav2/conversion/compile.py | 18 +- exllamav2/conversion/measure.py | 28 +- exllamav2/conversion/optimize.py | 12 +- exllamav2/conversion/quantize.py | 33 +- exllamav2/linear_attn.py | 1329 ------------------------------ exllamav2/model.py | 8 +- 7 files changed, 189 insertions(+), 1418 deletions(-) delete mode 100644 exllamav2/linear_attn.py diff --git a/exllamav2/attn.py b/exllamav2/attn.py index 09fde7ae..c2e8ed5c 100644 --- a/exllamav2/attn.py +++ b/exllamav2/attn.py @@ -192,10 +192,16 @@ def __init__( f_d = f_c + self.num_key_value_heads * self.head_dim f_key = (key + km["fused_qkv"]) if km["fused_qkv"] else None - self.q_proj = ExLlamaV2Linear(model, key + km["attn_q"], hidden_size, self.num_attention_heads * self.head_dim, ap.attention_bias_qkv, f_key = f_key, f_beg = f_a, f_end = f_b, altpack_qkv = ap.fused_qkv_altpack) - self.k_proj = ExLlamaV2Linear(model, key + km["attn_k"], hidden_size, self.num_key_value_heads * self.head_dim, ap.attention_bias_qkv, f_key = f_key, f_beg = f_b, f_end = f_c, altpack_qkv = ap.fused_qkv_altpack) - self.v_proj = 
ExLlamaV2Linear(model, key + km["attn_v"], hidden_size, self.num_key_value_heads * self.head_dim, ap.attention_bias_qkv, f_key = f_key, f_beg = f_c, f_end = f_d, altpack_qkv = ap.fused_qkv_altpack) - self.o_proj = ExLlamaV2Linear(model, key + km["attn_o"], self.num_attention_heads * self.head_dim, hidden_size, ap.attention_bias_o, prescale = cfg.scale_depth) + if type(cfg.num_key_value_heads) is list and ["self_attn.linear_attn"] in cfg.arch.lm.layer_keys[layer_idx]: + self.q_proj = None + self.k_proj = None + self.v_proj = None + self.o_proj = ExLlamaV2Linear(model, key + km["linear_attn"], self.num_attention_heads * self.head_dim, hidden_size, ap.attention_bias_o, prescale = cfg.scale_depth) + else: + self.q_proj = ExLlamaV2Linear(model, key + km["attn_q"], hidden_size, self.num_attention_heads * self.head_dim, ap.attention_bias_qkv, f_key = f_key, f_beg = f_a, f_end = f_b, altpack_qkv = ap.fused_qkv_altpack) + self.k_proj = ExLlamaV2Linear(model, key + km["attn_k"], hidden_size, self.num_key_value_heads * self.head_dim, ap.attention_bias_qkv, f_key = f_key, f_beg = f_b, f_end = f_c, altpack_qkv = ap.fused_qkv_altpack) + self.v_proj = ExLlamaV2Linear(model, key + km["attn_v"], hidden_size, self.num_key_value_heads * self.head_dim, ap.attention_bias_qkv, f_key = f_key, f_beg = f_c, f_end = f_d, altpack_qkv = ap.fused_qkv_altpack) + self.o_proj = ExLlamaV2Linear(model, key + km["attn_o"], self.num_attention_heads * self.head_dim, hidden_size, ap.attention_bias_o, prescale = cfg.scale_depth) if cfg.use_qk_norm: self.q_norm = ExLlamaV2HeadNorm(model, key + ".self_attn.q_norm", self.num_attention_heads, self.head_dim) @@ -204,12 +210,17 @@ def __init__( self.q_norm = None self.k_norm = None - self.submodules = [ - self.q_proj, - self.k_proj, - self.v_proj, - self.o_proj - ] + if type(cfg.num_key_value_heads) is list and ["self_attn.linear_attn"] in cfg.arch.lm.layer_keys[layer_idx]: + self.submodules = [ + self.o_proj + ] + else: + self.submodules = [ + self.q_proj, + self.k_proj, + self.v_proj, + self.o_proj + ] if self.pre_layernorm: self.submodules += [self.pre_layernorm] @@ -230,10 +241,10 @@ def __init__( def numel(self) -> int: - numel = self.q_proj.numel() + \ - self.k_proj.numel() + \ - self.v_proj.numel() + \ - self.o_proj.numel() + numel = self.o_proj.numel() + if self.q_proj is not None: numel += self.q_proj.numel() + if self.k_proj is not None: numel += self.k_proj.numel() + if self.v_proj is not None: numel += self.v_proj.numel() if self.pre_layernorm is not None: numel += self.pre_layernorm.numel() if self.post_layernorm is not None: numel += self.post_layernorm.numel() @@ -250,10 +261,12 @@ def load(self, device_context: bool = True): if self.pre_layernorm is not None: self.pre_layernorm.load() if self.post_layernorm is not None: self.post_layernorm.load() + self.o_proj.load(device_context = device_context) + if self.q_proj is None and self.k_proj is None and self.v_proj is None: + return self.q_proj.load(device_context = device_context) self.k_proj.load(device_context = device_context) self.v_proj.load(device_context = device_context) - self.o_proj.load(device_context = device_context) if self.q_norm is not None: self.q_norm.load() if self.k_norm is not None: self.k_norm.load() @@ -354,10 +367,10 @@ def unload(self): def weight_footprint(self): - fp = self.q_proj.weight_footprint() + \ - self.k_proj.weight_footprint() + \ - self.v_proj.weight_footprint() + \ - self.o_proj.weight_footprint() + fp = self.o_proj.weight_footprint() + if self.q_proj is not None: fp += 
self.q_proj.weight_footprint() + if self.k_proj is not None: fp += self.k_proj.weight_footprint() + if self.v_proj is not None: fp += self.v_proj.weight_footprint() if self.pre_layernorm is not None: fp += self.pre_layernorm.weight_footprint() if self.post_layernorm is not None: @@ -413,10 +426,13 @@ def temp_v_size(self): def temp_dq_size(self): - return max(self.q_proj.temp_dq_size(), - self.k_proj.temp_dq_size(), - self.v_proj.temp_dq_size(), - self.o_proj.temp_dq_size()) + if self.q_proj is None and self.k_proj is None and self.v_proj is None: + return self.o_proj.temp_dq_size() + else: + return max(self.q_proj.temp_dq_size(), + self.k_proj.temp_dq_size(), + self.v_proj.temp_dq_size(), + self.o_proj.temp_dq_size()) def temp_kv_size(self): @@ -447,9 +463,9 @@ def set_device_idx(self, idx: int | None): if self.pre_layernorm is not None: self.pre_layernorm.set_device_idx(idx) if self.post_layernorm is not None: self.post_layernorm.set_device_idx(idx) - self.q_proj.set_device_idx(idx) - self.k_proj.set_device_idx(idx) - self.v_proj.set_device_idx(idx) + if self.q_proj is not None: self.q_proj.set_device_idx(idx) + if self.k_proj is not None: self.k_proj.set_device_idx(idx) + if self.v_proj is not None: self.v_proj.set_device_idx(idx) self.o_proj.set_device_idx(idx) if self.q_norm is not None: self.q_norm.set_device_idx(idx) if self.k_norm is not None: self.k_norm.set_device_idx(idx) @@ -475,6 +491,14 @@ def forward_paged( **kwargs ) -> torch.Tensor: + if self.q_proj is None and self.k_proj is None and self.v_proj is None: + return self.forward_paged_linear( + hidden_states, + cache, + attn_params, + loras, + **kwargs, + ) if self.is_tp: return self.forward_paged_tp( hidden_states, @@ -640,6 +664,50 @@ def forward_paged( return hidden_states + # @profile + def forward_paged_linear( + self, + hidden_states: torch.Tensor, + cache: ExLlamaV2CacheBase | None = None, + attn_params: ExLlamaV2Attention.PagedParams | None = None, + loras: list[ExLlamaV2Lora] | None = None, + **kwargs + ) -> torch.Tensor: + + is_q = self.q_handle is not None + cfg = self.model.config + constants = self.model.get_device_context(self.device_idx, scratch = is_q) + page_size = attn_params.page_size + batch_size, q_len, _ = hidden_states.shape + cache_seqlens = attn_params.get_cache_seqlens(self.device_idx) + block_table = attn_params.get_block_index(self.device_idx) + + sc = attn_params.get_alt_rope_embed(self.device_idx) + if not sc: + sin, cos = constants.sin, constants.cos + else: + sin, cos = sc + + cache_seqlens_rope = cache_seqlens + offsets = attn_params.get_rope_offsets(self.device_idx) + if offsets is not None: + cache_seqlens_rope = cache_seqlens_rope + offsets + + residual = hidden_states + hidden_states = self.pre_layernorm.forward(hidden_states) if self.has_norm else hidden_states + + attn_output = hidden_states.view((batch_size, q_len, self.num_attention_heads * self.head_dim)) + + cache.store_kv_state(self.layer_idx, batch_size, 0, q_len, page_size, cache_seqlens, block_table) + + hidden_states = self.o_proj.forward(attn_output, loras = loras) + if self.post_layernorm: + hidden_states = self.post_layernorm.forward(hidden_states) + if self.has_residual: + hidden_states += residual + + return hidden_states + # @profile def forward_paged_tp( @@ -1375,6 +1443,17 @@ def forward_torch( loras: list[ExLlamaV2Lora] | None = None, **kwargs ) -> torch.Tensor | dict: + + if self.q_proj is None and self.k_proj is None and self.v_proj is None: + return self.forward_torch_linear( + hidden_states, + cache, + attn_params, 
+ past_len, + intermediates, + loras, + **kwargs, + ) global has_flash_attn global has_xformers @@ -1508,6 +1587,54 @@ def forward_torch( else: return hidden_states + def forward_torch_linear( + self, + hidden_states: torch.Tensor, + cache: ExLlamaV2CacheBase | None = None, + attn_params: ExLlamaV2Attention.Params | None = None, + past_len: int | None = None, + intermediates: bool = False, + loras: list[ExLlamaV2Lora] | None = None, + **kwargs + ) -> torch.Tensor | dict: + + cfg = self.model.config + num_attention_heads = self.num_attention_heads + num_key_value_heads = self.num_key_value_heads + head_dim = self.head_dim + + batch_size, q_len, _ = hidden_states.size() + + past_len = 0 if cache is None else cache.current_seq_len + + # Project q, k, v + + residual = hidden_states + post_norm = self.pre_layernorm.forward(hidden_states) if self.has_norm else hidden_states + + # Output projection + + attn_proj = self.o_proj.forward(post_norm, loras = loras) + + # Post layernorm + + if self.post_layernorm: + attn_proj = self.post_layernorm.forward(attn_proj, output_fp32 = self.archparams.residual_stream_fp32) + + # Add residual connection + + hidden_states = (attn_proj + residual) if self.has_residual else attn_proj + + if self.archparams.residual_stream_fp32: + hidden_states = hidden_states.float() + elif self.archparams.clamp_hidden_states: + hidden_states.clamp_(-65504, 65504) + + if intermediates: + return {"post_norm": post_norm, + "hidden_states": hidden_states} + else: + return hidden_states def update_loras(self): diff --git a/exllamav2/conversion/compile.py b/exllamav2/conversion/compile.py index 5c8a4f17..02aa752d 100644 --- a/exllamav2/conversion/compile.py +++ b/exllamav2/conversion/compile.py @@ -3,7 +3,6 @@ ExLlamaV2Embedding, ExLlamaV2PosEmbedding, ExLlamaV2Attention, - ExLlamaV2LinearAttention, ExLlamaV2MLP, ExLlamaV2MoEMLP, ExLlamaV2ParallelDecoder, @@ -94,17 +93,12 @@ def compile_model(job, save_fn, model): if d: out_dict.update(d); current_size += _dsize(d) d = get_f_module(job, module.post_layernorm) if d: out_dict.update(d); current_size += _dsize(d) - d = get_q_module(job, module.q_proj); out_dict.update(d); current_size += _dsize(d) - d = get_q_module(job, module.k_proj); out_dict.update(d); current_size += _dsize(d) - d = get_q_module(job, module.v_proj); out_dict.update(d); current_size += _dsize(d) - d = get_q_module(job, module.o_proj); out_dict.update(d); current_size += _dsize(d) - - if isinstance(module, ExLlamaV2LinearAttention): - - d = get_f_module(job, module.pre_layernorm) - if d: out_dict.update(d); current_size += _dsize(d) - d = get_f_module(job, module.post_layernorm) - if d: out_dict.update(d); current_size += _dsize(d) + if module.q_proj is not None: + d = get_q_module(job, module.q_proj); out_dict.update(d); current_size += _dsize(d) + if module.k_proj is not None: + d = get_q_module(job, module.k_proj); out_dict.update(d); current_size += _dsize(d) + if module.v_proj is not None: + d = get_q_module(job, module.v_proj); out_dict.update(d); current_size += _dsize(d) d = get_q_module(job, module.o_proj); out_dict.update(d); current_size += _dsize(d) if isinstance(module, ExLlamaV2MLP): diff --git a/exllamav2/conversion/measure.py b/exllamav2/conversion/measure.py index bcaca269..d5f54ab7 100644 --- a/exllamav2/conversion/measure.py +++ b/exllamav2/conversion/measure.py @@ -3,7 +3,6 @@ ExLlamaV2Embedding, ExLlamaV2PosEmbedding, ExLlamaV2Attention, - ExLlamaV2LinearAttention, ExLlamaV2MLP, ExLlamaV2MoEMLP, ExLlamaV2Linear, @@ -143,6 +142,9 @@ def 
test_error(module, hidden_states, target_states, cache, attn_params): def measure_attn(module, hidden_states, target_states, quantizers, cache, attn_params, keep_q = False): +# return linear_attn call if q,k,v are not in this layer + if "q_proj" not in quantizers and "k_proj" not in quantizers and "v_proj" not in quantizers: + return measure_linear_attn(module, hidden_states, target_states, quantizers, cache, attn_params) qjobs, qmaps = get_qparams_reduced(qparams_attn) results = [] @@ -527,18 +529,14 @@ def measure_quant(job, save_fn, model, hidden_state_offload_layers): if isinstance(module, ExLlamaV2Attention): mode = "self_attn" - if module.q_proj.linear is not None: + if module.q_proj is not None: quantizers["q_proj"] = AdaptiveGPTQ(module.q_proj.linear) - if module.k_proj.linear is not None: + if module.k_proj is not None: quantizers["k_proj"] = AdaptiveGPTQ(module.k_proj.linear) - if module.v_proj.linear is not None: + if module.v_proj is not None: quantizers["v_proj"] = AdaptiveGPTQ(module.v_proj.linear) quantizers["o_proj"] = AdaptiveGPTQ(module.o_proj.linear) - elif isinstance(module, ExLlamaV2LinearAttention): - mode = "linear_attn" - quantizers["o_proj"] = AdaptiveGPTQ(module.o_proj.linear) - elif isinstance(module, ExLlamaV2MLP): mode = "mlp" has_gate = module.model.config.arch.lm.mlp_gate @@ -578,7 +576,7 @@ def measure_quant(job, save_fn, model, hidden_state_offload_layers): cache = None attn_params = ExLlamaV2Attention.Params(1, hidden_states[0].shape[1], 0, None, None) \ - if mode in ["self_attn", "linear_attn", "parallel_decoder"] else None + if mode in ["self_attn", "parallel_decoder"] else None target_states = [] target_states_attn = [] @@ -625,13 +623,13 @@ def measure_quant(job, save_fn, model, hidden_state_offload_layers): # Hessians if mode == "self_attn": - quantizers["q_proj"].add_batch(outputs["post_norm"]) # Reuse H for K and V - quantizers["o_proj"].add_batch(outputs["attn_output"]) + if module.q_proj is None and module.k_proj is None and module.v_proj is None: + quantizers["o_proj"].add_batch(outputs["post_norm"]) + else: + quantizers["q_proj"].add_batch(outputs["post_norm"]) # Reuse H for K and V + quantizers["o_proj"].add_batch(outputs["attn_output"]) target_states.append(outputs["hidden_states"].to(target_device)) - if mode == "linear_attn": - quantizers["o_proj"].add_batch(outputs["post_norm"]) - target_states.append(outputs["hidden_states"].to(target_device)) if mode == "mlp": quantizers["up_proj"].add_batch(outputs["post_norm"]) # Reuse H for gate_proj @@ -676,8 +674,6 @@ def measure_quant(job, save_fn, model, hidden_state_offload_layers): if mode == "self_attn": m = measure_attn(module, hidden_states, target_states, quantizers, cache, attn_params) - if mode == "linear_attn": - m = measure_linear_attn(module, hidden_states, target_states, quantizers, cache, attn_params) if mode == "mlp": m = measure_mlp(module, hidden_states, target_states, quantizers, cache, attn_params) diff --git a/exllamav2/conversion/optimize.py b/exllamav2/conversion/optimize.py index dfb1e641..51d69081 100644 --- a/exllamav2/conversion/optimize.py +++ b/exllamav2/conversion/optimize.py @@ -88,10 +88,7 @@ def optimize(job, save_fn, model): m1 = measurement[cfg.arch.lm_prefix + "model.layers." + str(i) + ".parallel_decoder"]["attn"] m2 = measurement[cfg.arch.lm_prefix + "model.layers." 
+ str(i) + ".parallel_decoder"]["mlp"] elif type(cfg.arch.lm.layer_keys) is list and type(cfg.intermediate_size) is list: - if ["self_attn.linear_attn"] in cfg.arch.lm.layer_keys[i]: - m1 = measurement[cfg.arch.lm_prefix + "model.layers." + str(i) + ".linear_attn"] - m2 = measurement[cfg.arch.lm_prefix + "model.layers." + str(i) + "." + mlp_mode] - elif ["self_attn.o_proj"] in cfg.arch.lm.layer_keys[i]: + if ["self_attn.linear_attn"] in cfg.arch.lm.layer_keys[i] or ["self_attn.o_proj"] in cfg.arch.lm.layer_keys[i]: m1 = measurement[cfg.arch.lm_prefix + "model.layers." + str(i) + ".self_attn"] m2 = measurement[cfg.arch.lm_prefix + "model.layers." + str(i) + "." + mlp_mode] else: @@ -168,12 +165,7 @@ def optimize(job, save_fn, model): deci_offset = 0 for layer_ in range(num_layers): if type(cfg.arch.lm.layer_keys) is list and type(cfg.intermediate_size) is list: - if ["self_attn.linear_attn"] in cfg.arch.lm.layer_keys[layer_]: - k1 = cfg.arch.lm_prefix + "model.layers." + str(layer_) + ".linear_attn" - k2 = cfg.arch.lm_prefix + "model.layers." + str(layer_) + "." + mlp_mode - p1 = params[layer_ * 2-deci_offset][solution_idx[layer_ * 2-deci_offset]] - p2 = params[layer_ * 2 + 1-deci_offset][solution_idx[layer_ * 2 + 1-deci_offset]] - elif ["self_attn.o_proj"] in cfg.arch.lm.layer_keys[layer_]: + if ["self_attn.linear_attn"] in cfg.arch.lm.layer_keys[layer_] or ["self_attn.o_proj"] in cfg.arch.lm.layer_keys[layer_]: k1 = cfg.arch.lm_prefix + "model.layers." + str(layer_) + ".self_attn" k2 = cfg.arch.lm_prefix + "model.layers." + str(layer_) + "." + mlp_mode p1 = params[layer_ * 2-deci_offset][solution_idx[layer_ * 2-deci_offset]] diff --git a/exllamav2/conversion/quantize.py b/exllamav2/conversion/quantize.py index 19ac1ce6..578be13f 100644 --- a/exllamav2/conversion/quantize.py +++ b/exllamav2/conversion/quantize.py @@ -3,7 +3,6 @@ ExLlamaV2Embedding, ExLlamaV2PosEmbedding, ExLlamaV2Attention, - ExLlamaV2LinearAttention, ExLlamaV2MLP, ExLlamaV2MoEMLP, ExLlamaV2ParallelDecoder, @@ -135,6 +134,10 @@ def quant_linear(job: dict, def quant_attn(job, module, hidden_states, target_states, quantizers, attn_params, strat): + # return quant_linear_attn in linear attention layer + if "q_proj" not in quantizers and "k_proj" not in quantizers and "v_proj" not in quantizers: + return quant_linear_attn(job, module, hidden_states, target_states, quantizers, attn_params, strat) + quantizers["q_proj"].prepare() quantizers["k_proj"].reuse_h(quantizers["q_proj"]) quantizers["v_proj"].reuse_h(quantizers["q_proj"]) @@ -316,14 +319,12 @@ def quant(job, save_fn, model): if isinstance(module, ExLlamaV2Attention): mode = "self_attn" # if index > 1: testc(module, hidden_states, hidden_i_states, module.input_layernorm, [module.q_proj, module.k_proj, module.v_proj]) - quantizers["q_proj"] = AdaptiveGPTQ(module.q_proj.linear) - quantizers["k_proj"] = AdaptiveGPTQ(module.k_proj.linear) - quantizers["v_proj"] = AdaptiveGPTQ(module.v_proj.linear) - quantizers["o_proj"] = AdaptiveGPTQ(module.o_proj.linear) - - elif isinstance(module, ExLlamaV2LinearAttention): - mode = "linear_attn" - # if index > 1: testc(module, hidden_states, hidden_i_states, module.input_layernorm, [module.q_proj, module.k_proj, module.v_proj]) + if module.q_proj is not None: + quantizers["q_proj"] = AdaptiveGPTQ(module.q_proj.linear) + if module.k_proj is not None: + quantizers["k_proj"] = AdaptiveGPTQ(module.k_proj.linear) + if module.v_proj is not None: + quantizers["v_proj"] = AdaptiveGPTQ(module.v_proj.linear) quantizers["o_proj"] = 
AdaptiveGPTQ(module.o_proj.linear) elif isinstance(module, ExLlamaV2MLP): @@ -382,11 +383,11 @@ def quant(job, save_fn, model): # Hessians if mode == "self_attn": - quantizers["q_proj"].add_batch(outputs["post_norm"]) # Reuse H for K and V - quantizers["o_proj"].add_batch(outputs["attn_output"]) - - if mode == "linear_attn": - quantizers["o_proj"].add_batch(outputs["post_norm"]) + if "q_proj" not in quantizers and "k_proj" not in quantizers and "v_proj" not in quantizers: + quantizers["o_proj"].add_batch(outputs["post_norm"]) + else: + quantizers["q_proj"].add_batch(outputs["post_norm"]) # Reuse H for K and V + quantizers["o_proj"].add_batch(outputs["attn_output"]) if mode == "mlp": quantizers["up_proj"].add_batch(outputs["post_norm"]) # Reuse H for gate_proj @@ -429,10 +430,6 @@ def quant(job, save_fn, model): strat = strategy[module.key + "." + mode] quant_attn(job, module, hidden_states, target_states, quantizers, attn_params, strat) - if mode == "linear_attn": - strat = strategy[module.key + "." + mode] - quant_linear_attn(job, module, hidden_states, target_states, quantizers, attn_params, strat) - if mode == "mlp": strat = strategy[module.key + "." + mode] quant_mlp(job, module, hidden_states, target_states, quantizers, attn_params, strat) diff --git a/exllamav2/linear_attn.py b/exllamav2/linear_attn.py deleted file mode 100644 index dc36d7ff..00000000 --- a/exllamav2/linear_attn.py +++ /dev/null @@ -1,1329 +0,0 @@ -from __future__ import annotations - -import torch -from torch import nn -from exllamav2.module import ExLlamaV2Module -from exllamav2.rmsnorm import ExLlamaV2RMSNorm -from exllamav2.layernorm import ExLlamaV2LayerNorm -from exllamav2.headnorm import ExLlamaV2HeadNorm -from exllamav2.linear import ExLlamaV2Linear -from exllamav2.cache import ExLlamaV2CacheBase -from exllamav2.ext import exllamav2_ext as ext_c, none_tensor -from exllamav2.lora import ExLlamaV2Lora -from exllamav2.architecture import RopeStyle -from exllamav2.tensor_p import BROADCAST_KV, BROADCAST_Q -import math -import torch.nn.functional as F -import inspect -import os - -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from exllamav2.model import ExLlamaV2 - -# Detect available options for attention - -has_flash_attn = False -has_flash_attn_with_paged = False -has_flash_attn_with_window = False -has_flash_attn_with_softcap = False -has_xformers = False -has_lower_right_sdpa = False - -if 'EXLLAMA_NO_FLASH_ATTN' not in os.environ: - - try: - import flash_attn - flash_attn_ver = [int(t) for t in flash_attn.__version__.split(".") if t.isdigit()] - is_ampere_or_newer_gpu = any(torch.cuda.get_device_properties(i).major >= 8 for i in range(torch.cuda.device_count())) - - if not is_ampere_or_newer_gpu: - print(" ## Warning: Flash Attention is installed but unsupported GPUs were detected.") - - if [2, 2, 1] <= flash_attn_ver < [2, 5, 7]: - from flash_attn import flash_attn_func - has_flash_attn = True - - if [2, 5, 7] <= flash_attn_ver: - from flash_attn import flash_attn_func, flash_attn_with_kvcache - # import flash_attn_2_cuda as flash_attn_cuda - - signature = list(inspect.signature(flash_attn_func).parameters) - has_flash_attn_with_window = "window_size" in signature - has_flash_attn_with_softcap = "softcap" in signature - - import flash_attn_2_cuda as flash_attn_cuda - # ext_c.set_flash_attn_func() - - has_flash_attn = True - has_flash_attn_with_paged = True - - except ModuleNotFoundError: - pass - except NameError: - pass - -if 'EXLLAMA_NO_XFORMERS' not in os.environ: - - try: - import xformers.ops 
as xops - # LowerTriangularFromBottomRightMask was added in xformers version 2.4 - from xformers.ops.fmha.attn_bias import LowerTriangularFromBottomRightMask - has_xformers = True - except ModuleNotFoundError: - pass - -if 'EXLLAMA_NO_SDPA' not in os.environ: - try: - from torch.nn.attention.bias import causal_lower_right - has_lower_right_sdpa = True - except ImportError: - pass - - -def assert_paged_attn(): - """ - Raise an exception if paged attention is not available. - """ - global has_flash_attn_with_paged - assert has_flash_attn_with_paged, \ - "Paged attention required Flash Attention 2.5.7 or later" - - -class ExLlamaV2LinearAttention(ExLlamaV2Module): - - name: str = "LinearAttention" - - layer_idx: int - pre_layernorm: ExLlamaV2RMSNorm | ExLlamaV2LayerNorm | None - post_layernorm: ExLlamaV2RMSNorm | ExLlamaV2LayerNorm | None - q_proj: ExLlamaV2Linear | None - k_proj: ExLlamaV2Linear | None - v_proj: ExLlamaV2Linear | None - o_proj: ExLlamaV2Linear | None - q_norm: ExLlamaV2HeadNorm | None - k_norm: ExLlamaV2HeadNorm | None - - q_handle: int | None - - temp_state: torch.tensor - temp_q: torch.tensor - temp_k: torch.tensor - temp_v: torch.tensor - temp_o: torch.tensor - temp_dq: torch.tensor - # temp_kv: torch.tensor - - temp_lora_size: int - - has_norm: bool - has_residual: bool - scaling: float - sliding_window: int - - is_tp: bool - tp_dq_size: list[int] | None - - from exllamav2.attn_params import Params - from exllamav2.attn_params import PagedParams - - def __init__( - self, - model: ExLlamaV2, - key: str, - layer_idx: int, - has_norm: bool = True, - has_residual: bool = True, - sliding_window: int = 0, - archparams = None - ): - super().__init__(model, key, archparams) - - cfg = self.model.config - ap = self.archparams - km = self.archparams.keys - - self.is_tp = False - self.tp_dq_size = None - - self.layer_idx = layer_idx - self.has_norm = has_norm - self.has_residual = has_residual - - self.q_handle = None - self.temp_lora_size = 0 - - if ap.is_vision: - self.num_attention_heads = cfg.vision_num_attention_heads - self.num_key_value_heads = cfg.vision_num_key_value_heads - self.num_key_value_groups = cfg.vision_num_key_value_groups - self.head_dim = cfg.vision_head_dim - self.hidden_size = cfg.vision_hidden_size - else: - self.num_attention_heads = cfg.num_attention_heads - if type(cfg.num_key_value_heads) is list: - self.num_key_value_heads = cfg.num_key_value_heads[layer_idx] - else: - self.num_key_value_heads = cfg.num_key_value_heads - if type(cfg.num_key_value_groups) is list: - self.num_key_value_groups = cfg.num_key_value_groups[layer_idx] - else: - self.num_key_value_groups = cfg.num_key_value_groups - self.head_dim = cfg.head_dim - self.hidden_size = cfg.hidden_size - - hidden_size = self.hidden_size - - if self.has_norm and (km["norm_1"] or km["norm_1_post"]): - if ap.norm == "layernorm": - self.pre_layernorm = ExLlamaV2LayerNorm(model, key + km["norm_1"], archparams) - self.post_layernorm = ExLlamaV2LayerNorm(model, key + km["norm_1_post"], archparams) if km["norm_1_post"] else None - elif ap.norm == "rmsnorm": - self.pre_layernorm = ExLlamaV2RMSNorm(model, key + km["norm_1"], archparams) - self.post_layernorm = ExLlamaV2RMSNorm(model, key + km["norm_1_post"], archparams) if km["norm_1_post"] else None - else: - self.pre_layernorm = None - self.post_layernorm = None - self.has_norm = False - - f_a = 0 - f_b = self.num_attention_heads * self.head_dim - f_c = f_b + self.num_key_value_heads * self.head_dim - f_d = f_c + self.num_key_value_heads * self.head_dim - 
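# Illustrative sketch (not from this patch): how the f_a..f_d offsets above partition a fused
# QKV weight when num_key_value_heads is given per layer, as this patch allows. All numbers
# and the (in_features, q+k+v) column layout are assumptions for the example only; the real
# slicing happens inside ExLlamaV2Linear via f_beg/f_end.
import torch

num_attention_heads = 8
head_dim = 16
num_key_value_heads = [4, 4, 2, 4]        # per-layer GQA widths (a plain int in most models)
layer_idx = 2
kv_heads = num_key_value_heads[layer_idx] if isinstance(num_key_value_heads, list) else num_key_value_heads

f_a = 0
f_b = f_a + num_attention_heads * head_dim    # columns [f_a, f_b) -> q_proj
f_c = f_b + kv_heads * head_dim               # columns [f_b, f_c) -> k_proj
f_d = f_c + kv_heads * head_dim               # columns [f_c, f_d) -> v_proj

hidden_size = 128
w_fused = torch.randn(hidden_size, f_d)
w_q, w_k, w_v = w_fused[:, f_a:f_b], w_fused[:, f_b:f_c], w_fused[:, f_c:f_d]
print(w_q.shape, w_k.shape, w_v.shape)        # (128, 128) (128, 32) (128, 32)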
f_key = (key + km["fused_qkv"]) if km["fused_qkv"] else None - - self.o_proj = ExLlamaV2Linear(model, key + km["linear_attn"], self.num_attention_heads * self.head_dim, hidden_size, ap.attention_bias_o, prescale = cfg.scale_depth) - - if cfg.use_qk_norm: - self.q_norm = ExLlamaV2HeadNorm(model, key + ".self_attn.q_norm", self.num_attention_heads, self.head_dim) - self.k_norm = ExLlamaV2HeadNorm(model, key + ".self_attn.k_norm", self.num_key_value_heads, self.head_dim) - else: - self.q_norm = None - self.k_norm = None - - self.submodules = [ - self.o_proj - ] - - if self.pre_layernorm: - self.submodules += [self.pre_layernorm] - if self.post_layernorm: - self.submodules += [self.post_layernorm] - if cfg.use_qk_norm: - self.submodules += [self.q_norm, self.k_norm] - - if cfg.attention_multiplier: - self.scaling = cfg.attention_multiplier - elif cfg.query_pre_attn_scalar: - self.scaling = cfg.query_pre_attn_scalar ** (-0.5) - else: - self.scaling = 1 / math.sqrt(self.head_dim) - - self.sliding_window = sliding_window - - - def numel(self) -> int: - - numel = self.o_proj.numel() - - if self.pre_layernorm is not None: numel += self.pre_layernorm.numel() - if self.post_layernorm is not None: numel += self.post_layernorm.numel() - if self.q_norm is not None: numel += self.q_norm.numel() - if self.k_norm is not None: numel += self.k_norm.numel() - - return numel - - - @torch.inference_mode - def load(self, device_context: bool = True): - - cfg = self.model.config - - if self.pre_layernorm is not None: self.pre_layernorm.load() - if self.post_layernorm is not None: self.post_layernorm.load() - self.o_proj.load(device_context = device_context) - - def unload(self): - if self.q_handle is not None: - ext_c.free_q_attn(self.q_handle) - self.q_handle = None - - if self.pre_layernorm is not None: self.pre_layernorm.unload() - if self.post_layernorm is not None: self.post_layernorm.unload() - self.o_proj.unload() - - self.temp_state = None - self.temp_dq = None - - def weight_footprint(self): - - fp = self.o_proj.weight_footprint() - if self.pre_layernorm is not None: - fp += self.pre_layernorm.weight_footprint() - if self.post_layernorm is not None: - fp += self.post_layernorm.weight_footprint() - - return fp - - - def scratch_space_fixed(self): - - return self.temp_state_size() + \ - self.temp_dq_size() - - - def scratch_space(self): - - return self.temp_state_size() + \ - self.temp_dq_size() + \ - self.temp_kv_size() - # self.temp_attn_size() + # Accounted for separately in model.set_device_map() - - - def temp_state_size(self): - - cfg = self.model.config - return cfg.max_input_len * cfg.max_batch_size * max(self.num_attention_heads * self.head_dim, self.hidden_size) * 2 + 128 - - - def temp_q_size(self): - - cfg = self.model.config - return cfg.max_input_len * cfg.max_batch_size * self.num_attention_heads * self.head_dim * 2 + 128 - - - def temp_k_size(self): - - cfg = self.model.config - return cfg.max_input_len * cfg.max_batch_size * self.num_key_value_heads * self.head_dim * 2 + 128 - - - def temp_v_size(self): - - cfg = self.model.config - return cfg.max_input_len * cfg.max_batch_size * self.num_key_value_heads * self.head_dim * 2 + 128 - - - def temp_dq_size(self): - - return self.o_proj.temp_dq_size() - - def temp_kv_size(self): - - cfg = self.model.config - if self.num_key_value_heads == self.num_attention_heads: return 0 - return 2 * cfg.max_seq_len * cfg.max_batch_size * self.num_attention_heads * self.head_dim * 2 + 128 - - - def temp_attn_size(self): - global has_flash_attn - global 
has_xformers - - cfg = self.model.config - att_max = min(cfg.max_attention_size, cfg.max_seq_len ** 2) - - if (has_flash_attn and not cfg.no_flash_attn) or (has_xformers and not cfg.no_xformers) : - #in sm>=80 devices, xformers uses the same memory as flash_attn - #todo: due to the different implementions. in sm<80 devices, xformers uses less memory than it in sm>=80. There may still be room for optimization. - eff = cfg.max_attention_size ** 0.5 / 190 # based on supposed memory savings listed in flash-attn repo + some fudging - att_max //= eff - - return 2 * att_max * self.num_attention_heads * 2 + 128 - - - def set_device_idx(self, idx: int | None): - super().set_device_idx(idx) - - if self.pre_layernorm is not None: self.pre_layernorm.set_device_idx(idx) - if self.post_layernorm is not None: self.post_layernorm.set_device_idx(idx) - self.o_proj.set_device_idx(idx) - - def repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - - if n_rep == 1: return hidden_states - - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - hidden_states = hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - return hidden_states - - - # @profile - def forward_paged( - self, - hidden_states: torch.Tensor, - cache: ExLlamaV2CacheBase | None = None, - attn_params: ExLlamaV2Attention.PagedParams | None = None, - loras: list[ExLlamaV2Lora] | None = None, - **kwargs - ) -> torch.Tensor: - - if self.is_tp: - return self.forward_paged_tp( - hidden_states, - cache, - attn_params, - loras, - **kwargs, - ) - - is_q = self.q_handle is not None - cfg = self.model.config - constants = self.model.get_device_context(self.device_idx, scratch = is_q) - page_size = attn_params.page_size - batch_size, q_len, _ = hidden_states.shape - cache_seqlens = attn_params.get_cache_seqlens(self.device_idx) - block_table = attn_params.get_block_index(self.device_idx) - - # TODO: We only need keys/values when preprocess_only == True, so we could skip q projection and attention - # on the last layer. Would need custom kernel to update paged cache if not calling flash_attn_with_kvcache - # skip_attn = kwargs.get("kv_only") - - # TODO: Potentially we could emulate paged cache when in Q4 mode, since that requires copying the active part - # of the current cache layer anyway. Test if block diagonal masking works with lower-right aligned mask. 
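# Illustrative sketch (not from this patch): a toy version of the paged-cache indexing this
# forward_paged path relies on. A flat cache of fixed-size pages plus a per-sequence block
# table gives each sequence a logically contiguous K/V view without contiguous storage.
# Shapes, page numbers and contents are invented for the example.
import torch

page_size, kv_heads, head_dim = 4, 2, 8
num_pages_total = 16
k_cache = torch.randn(num_pages_total, page_size, kv_heads, head_dim, dtype = torch.half)

block_table = torch.tensor([[3, 7]])     # one sequence spread over physical pages 3 and 7
cache_seqlens = torch.tensor([6])        # 6 tokens cached so far for that sequence

def gather_keys(k_cache, block_table, seqlen):
    # Concatenate the listed pages in order, then trim to the number of valid tokens
    pages = k_cache[block_table[0]]                  # (pages_per_seq, page_size, kv_heads, head_dim)
    return pages.reshape(-1, kv_heads, head_dim)[:seqlen]

print(gather_keys(k_cache, block_table, int(cache_seqlens[0])).shape)   # torch.Size([6, 2, 8])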
- - if cache.q_block > 1: - k_cache_f, v_cache_f = cache.get_kv_state(self.layer_idx, batch_size, 0, attn_params.max_cache_seqlen, page_size, cache_seqlens, block_table) - else: - k_cache_f, v_cache_f = cache.get_kv_state(self.layer_idx, batch_size, 0, 0, page_size, cache_seqlens, block_table) - - k_cache = k_cache_f.view(k_cache_f.shape[1] // page_size, page_size, k_cache_f.shape[2], k_cache_f.shape[3]) - v_cache = v_cache_f.view(v_cache_f.shape[1] // page_size, page_size, v_cache_f.shape[2], v_cache_f.shape[3]) - - sc = attn_params.get_alt_rope_embed(self.device_idx) - if not sc: - sin, cos = constants.sin, constants.cos - else: - sin, cos = sc - - cache_seqlens_rope = cache_seqlens - offsets = attn_params.get_rope_offsets(self.device_idx) - if offsets is not None: - cache_seqlens_rope = cache_seqlens_rope + offsets - - if is_q: - q = torch.empty((batch_size, q_len, self.num_attention_heads, self.head_dim), device = hidden_states.device, dtype = torch.half) - if attn_params.is_sequential: - assert batch_size == 1 - k = k_cache_f[:, attn_params.first_index : attn_params.first_index + q_len, :, :] - v = v_cache_f[:, attn_params.first_index : attn_params.first_index + q_len, :, :] - else: - k = torch.empty((batch_size, q_len, self.num_key_value_heads, self.head_dim), device = hidden_states.device, dtype = torch.half) - v = torch.empty((batch_size, q_len, self.num_key_value_heads, self.head_dim), device = hidden_states.device, dtype = torch.half) - - if loras is None or self.temp_lora_size == 0: - pass_loras = [] - pass_lora_temp = none_tensor - else: - pass_loras = [id(x) for x in loras] - pass_lora_temp = torch.empty((self.temp_lora_size,), dtype = torch.half, device = hidden_states.device) - - ext_c.q_attn_forward_1( - self.q_handle, - hidden_states, - batch_size, - q_len, - 0, - cache_seqlens_rope, - q, - k, - v, - sin, - cos, - pass_loras, - pass_lora_temp - ) - else: - residual = hidden_states - hidden_states = self.pre_layernorm.forward(hidden_states) if self.has_norm else hidden_states - - attn_output = hidden_states.view((batch_size, q_len, self.num_attention_heads * self.head_dim)) - - cache.store_kv_state(self.layer_idx, batch_size, 0, q_len, page_size, cache_seqlens, block_table) - - # Output projection - - if is_q: - ext_c.q_attn_forward_2( - self.q_handle, - hidden_states, - attn_output, - batch_size, - q_len, - pass_loras, - pass_lora_temp - ) - else: - hidden_states = self.o_proj.forward(attn_output, loras = loras) - if self.post_layernorm: - hidden_states = self.post_layernorm.forward(hidden_states) - if self.has_residual: - hidden_states += residual - - return hidden_states - - - # @profile - def forward_paged_tp( - self, - hidden_states: torch.Tensor, - cache: ExLlamaV2CacheBase | None = None, - attn_params: ExLlamaV2Attention.PagedParams | None = None, - loras: list[ExLlamaV2Lora] | None = None, - **kwargs - ) -> torch.Tensor: - - cfg = self.model.config - ctx = self.model.tp_context - - assert not self.sliding_window, \ - "Sliding window not supported in TP mode" - - attn_params.prep_tp(self.model) - page_size = attn_params.page_size - - batch_size, q_len, _ = hidden_states.shape - rows = batch_size * q_len - hidden_states = hidden_states.view(-1, self.hidden_size) - dtype = hidden_states.dtype - - k_cache_f, v_cache_f = cache.get_kv_state( - self.layer_idx, - batch_size, - 0, - attn_params.max_cache_seqlen, - page_size, - attn_params.cache_seqlens_tp, - attn_params.block_index_tp - ) - - k_cache = [x.view(x.shape[1] // page_size, page_size, x.shape[2], x.shape[3]) for x 
in k_cache_f] - v_cache = [x.view(x.shape[1] // page_size, page_size, x.shape[2], x.shape[3]) for x in v_cache_f] - - sin, cos = ctx.get_sin_cos() - - ext_c.tp_attn_forward_paged_( - self.model.tp_context.ext_tp_context, - hidden_states, - self.temp_bc0, - self.temp_bc1, - self.temp_bc2, - self.temp_q, - self.temp_k, - self.temp_v, - self.temp_o, - k_cache, - v_cache, - self.pre_layernorm.weight if self.pre_layernorm is not None else [], - self.pre_layernorm.variance_epsilon if self.pre_layernorm is not None else 0.0, - self.q_proj.q_handle, - self.k_proj.q_handle, - self.v_proj.q_handle, - self.o_proj.q_handle, - self.head_dim, - int(self.archparams.rope_style), - batch_size, - q_len, - sin, - cos, - attn_params.cache_seqlens_tp, - attn_params.block_index_tp, - self.scaling - ) - - cache.store_kv_state( - self.layer_idx, - batch_size, - 0, - q_len, - page_size, - attn_params.cache_seqlens_tp, - attn_params.block_index_tp - ) - - return ctx.get_pinned(0, batch_size, q_len, self.hidden_size) - - - # @profile - def forward_paged_tp_old( - self, - hidden_states: torch.Tensor, - cache: ExLlamaV2CacheBase | None = None, - attn_params: ExLlamaV2Attention.PagedParams | None = None, - loras: list[ExLlamaV2Lora] | None = None, - **kwargs - ) -> torch.Tensor: - - assert self.q_handle is not None - cfg = self.model.config - split = self.model.tp_context.get_split(BROADCAST_KV) - batch_size, q_len, _ = hidden_states.shape - attn_params.prep_tp(self.model) - page_size = attn_params.page_size - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - - k_cache_f, v_cache_f = cache.get_kv_state( - self.layer_idx, - batch_size, - 0, - attn_params.max_cache_seqlen, - page_size, - attn_params.cache_seqlens_tp, - attn_params.block_index_tp - ) - - k_cache = [x.view(x.shape[1] // page_size, page_size, x.shape[2], x.shape[3]) for x in k_cache_f] - v_cache = [x.view(x.shape[1] // page_size, page_size, x.shape[2], x.shape[3]) for x in v_cache_f] - - hidden_states = self.model.tp_context.broadcast(0, hidden_states, BROADCAST_KV, dim = self.head_dim) - - residual = hidden_states - - post_norm = self.pre_layernorm.forward_tp(hidden_states, output_split = True) if self.has_norm else hidden_states - q = self.q_proj.forward_tp(post_norm, loras = loras, output_split = True, dim = self.head_dim) - k = self.k_proj.forward_tp(post_norm, loras = loras, output_split = True, dim = self.head_dim) - v = self.v_proj.forward_tp(post_norm, loras = loras, output_split = True, dim = self.head_dim) - q = [q_.view(batch_size, q_len, q_.shape[1] // self.head_dim, self.head_dim) for q_ in q] - k = [k_.view(batch_size, q_len, k_.shape[1] // self.head_dim, self.head_dim) for k_ in k] - v = [v_.view(batch_size, q_len, v_.shape[1] // self.head_dim, self.head_dim) for v_ in v] - if cfg.use_qk_norm: - assert False, "TP not implemented for QK norm" # TODO: ... 
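# Illustrative sketch (not from this patch): a reference implementation of GPT-NeoX-style
# rotary position embedding, the operation the ext_c.rope_ calls below apply to q and k.
# The half-split layout and sin/cos table shapes are assumptions; the fused kernel's exact
# layout may differ.
import torch

def build_sin_cos(max_seq: int, head_dim: int, base: float = 10000.0):
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    freqs = torch.outer(torch.arange(max_seq).float(), inv_freq)   # (max_seq, head_dim // 2)
    return freqs.sin(), freqs.cos()

def rope_neox(x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor, past_len: int) -> torch.Tensor:
    # x: (batch, seq_len, heads, head_dim); rotate pairs split across the last dim
    b, s, h, d = x.shape
    sin = sin[past_len:past_len + s].view(1, s, 1, d // 2)
    cos = cos[past_len:past_len + s].view(1, s, 1, d // 2)
    x1, x2 = x[..., :d // 2], x[..., d // 2:]
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim = -1)

sin, cos = build_sin_cos(32, 8)
q = torch.randn(1, 5, 2, 8)
print(rope_neox(q, sin, cos, past_len = 0).shape)   # torch.Size([1, 5, 2, 8])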
- # q = self.q_norm.forward(q) - # k = self.k_norm.forward(k) - if self.archparams.rope_style != RopeStyle.NONE: - for idx, (dev, a, b) in enumerate(split): - context = self.model.get_device_context(dev) - torch.cuda.set_stream(context.stream) - for t, heads in [(q[idx], self.num_key_value_groups), (k[idx], 1)]: - ext_c.rope_( - t, - context.sin, - context.cos, - 0, - (b - a) * heads, - self.head_dim, - attn_params.cache_seqlens_tp[idx], - self.archparams.rope_style == RopeStyle.NEOX - ) - if attn_params.is_sequential: - k_ = [x[:, attn_params.first_index: attn_params.first_index + q_len, :, :] for x in k_cache_f] - v_ = [x[:, attn_params.first_index: attn_params.first_index + q_len, :, :] for x in v_cache_f] - for (dev, a, b), x_, x, y_, y in zip(split, k_, k, v_, v): - context = self.model.get_device_context(dev) - torch.cuda.set_stream(context.stream) - x_.copy_(x) - y_.copy_(y) - k = None - v = None - cache_seqlens_a = attn_params.cache_seqlens_after_tp - else: - cache_seqlens_a = attn_params.cache_seqlens_tp - - # if cache.q_block == 1: - # cache.get_kv_state( - # self.layer_idx, - # batch_size, - # 0, - # attn_params.max_cache_seqlen, - # page_size, - # attn_params.cache_seqlens_tp, - # attn_params.block_index_tp - # ) - - flash_kwargs = {} - if self.sliding_window: - # assert has_flash_attn_with_window, \ - # "Installed version of flash-attn does not support sliding window" - if has_flash_attn_with_window: - flash_kwargs["window_size"] = (self.sliding_window, self.sliding_window) - if cfg.attn_logit_softcapping: - # assert has_flash_attn_with_softcap, \ - # "Installed version of flash-attn does not support softcapping" - if has_flash_attn_with_softcap: - flash_kwargs["softcap"] = cfg.attn_logit_softcapping - - attn_outputs = [] - for idx in range(len(split)): - dev, a, b = split[idx] - context = self.model.get_device_context(dev) - torch.cuda.set_stream(context.stream) - - attn_output = flash_attn_with_kvcache( - q = q[idx], - k = k[idx] if k is not None else None, - v = v[idx] if v is not None else None, - k_cache = k_cache[idx], - v_cache = v_cache[idx], - cache_seqlens = cache_seqlens_a[idx], - block_table = attn_params.block_index_tp[idx], - causal = True, - softmax_scale = self.scaling, - **flash_kwargs - ) - attn_output = attn_output.view(batch_size * q_len, (b - a) * self.head_dim * self.num_key_value_groups) - attn_outputs.append(attn_output) - - cache.store_kv_state( - self.layer_idx, - batch_size, - 0, - q_len, - page_size, - attn_params.cache_seqlens_tp, - attn_params.block_index_tp - ) - - # Output projection - - attn_outputs = self.model.tp_context.allgather(1, attn_outputs, BROADCAST_Q, BROADCAST_Q, dim = self.head_dim) - - hidden_states = self.o_proj.forward_tp(attn_outputs, loras = loras, dim = self.head_dim, output_split = True) - - if self.has_residual: - self.model.tp_context.add_residual(hidden_states, residual, BROADCAST_Q, dim = self.head_dim) - - hidden_states = self.model.tp_context.gather(0, hidden_states, BROADCAST_Q, dim = self.head_dim) - - # if self.post_layernorm: # TODO: ... 
- # hidden_states = self.post_layernorm.forward(hidden_states) - - hidden_states = hidden_states.view(batch_size, q_len, hidden_states.shape[-1]) - return hidden_states - - def _attn_torch(self, batch_size, q_len, q_states, k_states, v_states, attn_params, cfg, causal = True): - - num_attn_heads = q_states.shape[2] - head_dim = q_states.shape[3] - - q_states = q_states.transpose(1, 2) - k_states = k_states.transpose(1, 2) - v_states = v_states.transpose(1, 2) - - # SDPA - - if has_lower_right_sdpa and not cfg.no_sdpa and not cfg.attn_logit_softcapping: - - k_states = self.repeat_kv(k_states, self.num_key_value_groups) - v_states = self.repeat_kv(v_states, self.num_key_value_groups) - - if self.sliding_window and k_states.shape[2] >= self.sliding_window: - k_states = k_states[:, :, -self.sliding_window:, :] - v_states = v_states[:, :, -self.sliding_window:, :] - - if attn_params.is_causal(): - attn_mask_lr = causal_lower_right(q_len, k_states.shape[2]) - else: - attn_mask_lr = attn_params.get_attn_mask(q_states.device) - attn_output = F.scaled_dot_product_attention( - q_states, - k_states, - v_states, - attn_mask_lr if causal else None, - scale = self.scaling - ) - - # Matmul attn - - else: - - k_states = self.repeat_kv(k_states, self.num_key_value_groups) - k_states = k_states.transpose(-1, -2) - - attn_weights = torch.matmul(q_states, k_states) - - attn_weights *= self.scaling - if causal: - attn_mask = attn_params.get_attn_mask(attn_weights.device) - - if cfg.attn_logit_softcapping: - ext_c.softcap_(attn_weights, cfg.attn_logit_softcapping) - if causal and attn_mask is not None: - attn_weights = attn_weights + attn_mask - if self.sliding_window and k_states.shape[-1] >= self.sliding_window: - attn_weights = attn_weights[:, :, :, -self.sliding_window:] - v_states = v_states[:, :, -self.sliding_window:, :] - - attn_weights = nn.functional.softmax(attn_weights, dim = -1, dtype = torch.float16) - - v_states = self.repeat_kv(v_states, self.num_key_value_groups) - attn_output = torch.matmul(attn_weights, v_states) - - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape((batch_size, q_len, num_attn_heads * head_dim)) - return attn_output - - - def _attn_flash(self, batch_size, q_len, q_states, k_states, v_states, attn_params, cfg, causal = True): - - flash_kwargs = {} - if self.sliding_window: - # assert has_flash_attn_with_window, \ - # "Installed version of flash-attn does not support sliding window" - if has_flash_attn_with_window: - flash_kwargs["window_size"] = (self.sliding_window, self.sliding_window) - if cfg.attn_logit_softcapping: - # assert has_flash_attn_with_softcap, \ - # "Installed version of flash-attn does not support softcapping" - if has_flash_attn_with_softcap: - flash_kwargs["softcap"] = cfg.attn_logit_softcapping - - attn_output = flash_attn_func( - q_states, - k_states, - v_states, - causal = causal, - softmax_scale = self.scaling, - **flash_kwargs - ) - attn_output = attn_output.reshape((batch_size, q_len, self.num_attention_heads * self.head_dim)) - return attn_output - - - def _attn_xformers(self, batch_size, q_len, q_states, k_states, v_states, attn_params, cfg, causal = True): - - # assert not self.sliding_window, \ - # "Sliding window not currently supported for xformers" - - # assert not cfg.attn_logit_softcapping, \ - # "Softcap not yet supported for xformers" - - # xformers memory_efficient_attention, could be beneficial if your device's architecture is less than sm_80 are almost the same. 
But the martix operation
-        # make this implemention much slower.
-
-        k_states = k_states.transpose(1, 2)
-        v_states = v_states.transpose(1, 2)
-
-        k_states = self.repeat_kv(k_states, self.num_key_value_groups)
-        v_states = self.repeat_kv(v_states, self.num_key_value_groups)
-
-        k_states = k_states.transpose(1, 2)
-        v_states = v_states.transpose(1, 2)
-
-        attn_output = xops.memory_efficient_attention(
-            q_states,
-            k_states,
-            v_states,
-            attn_bias = LowerTriangularFromBottomRightMask() if causal else None,
-            scale = self.scaling
-        )
-        attn_output = attn_output.reshape((batch_size, q_len, self.num_attention_heads * self.head_dim))
-
-        return attn_output
-
-
-    # @profile
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        cache: ExLlamaV2CacheBase | None = None,
-        attn_params: ExLlamaV2Attention.Params | None = None,
-        past_len: int | None = None,
-        intermediates: bool = False,
-        loras: list[ExLlamaV2Lora] | None = None,
-        **kwargs
-    ) -> torch.Tensor | dict[str: torch.Tensor]:
-
-        cfg = self.model.config
-        global has_flash_attn
-        global has_xformers
-        use_flash_attn = has_flash_attn and not cfg.no_flash_attn
-
-        if isinstance(attn_params, ExLlamaV2LinearAttention.PagedParams):
-            return self.forward_paged(
-                hidden_states,
-                cache,
-                attn_params,
-                loras = loras,
-                **kwargs
-            )
-
-        if self.is_tp:
-            if cache is not None and use_flash_attn:
-                return self.forward_tp(
-                    hidden_states,
-                    cache,
-                    attn_params,
-                    past_len,
-                    intermediates,
-                    loras,
-                    **kwargs,
-                )
-            else:
-                # TODO: Can't use the optimized forward function because it writes directly to a fixed output
-                # tensor, and flash-attn currently has a bug that prevents that from working when q_len == 1
-                return self.forward_tp_old(
-                    hidden_states,
-                    cache,
-                    attn_params,
-                    past_len,
-                    intermediates,
-                    loras,
-                    **kwargs,
-                )
-
-        if self.q_handle is None or intermediates:
-            return self.forward_torch(
-                hidden_states,
-                cache,
-                attn_params,
-                past_len,
-                intermediates,
-                loras = loras,
-                **kwargs
-            )
-
-        constants = self.model.get_device_context(self.device_idx)
-
-        batch_size, q_len, _ = hidden_states.shape
-        direct = (batch_size == 1 and cache is not None and isinstance(cache, ExLlamaV2CacheBase))
-
-        # If conditions are right we can write the K/V projections directly into the cache
-
-        if direct:
-            batch_keys, batch_values = cache.get_kv_state(self.layer_idx, batch_size, 0, past_len)
-
-        # RMS norm, Q/K/V projections, position embeddings
-
-        if loras is None or self.temp_lora_size == 0:
-            pass_loras = []
-            pass_lora_temp = none_tensor
-        else:
-            pass_loras = [id(x) for x in loras]
-            pass_lora_temp = torch.empty((self.temp_lora_size,), dtype = torch.half, device = hidden_states.device)
-
-        if attn_params.position_offsets is not None:
-            pass_past_len_1 = past_len
-            pass_past_len_2 = attn_params.get_position_offsets(hidden_states.device)
-            if pass_past_len_1 == 0:
-                offsets = attn_params.get_rope_offsets(self.device_idx)
-                if offsets is not None:
-                    pass_past_len_2 = pass_past_len_2 + offsets
-        else:
-            pass_past_len_1 = past_len
-            pass_past_len_2 = none_tensor
-            if attn_params.rope_offsets is not None:
-                offset = attn_params.rope_offsets.cpu().item()
-                pass_past_len_1 += offset
-
-        ext_c.q_attn_forward_1(
-            self.q_handle,
-            hidden_states,
-            batch_size,
-            q_len,
-            pass_past_len_1,
-            pass_past_len_2,
-            q_states,
-            k_states,
-            v_states,
-            constants.sin,
-            constants.cos,
-            pass_loras,
-            pass_lora_temp
-        )
-
-        # Select attention function
-
-        if (has_flash_attn and not cfg.no_flash_attn) and attn_params.is_causal():
-            attn_func = self._attn_flash
-        elif (has_xformers and not cfg.no_xformers) and attn_params.is_causal():
-            attn_func = self._attn_xformers
-        else:
-            attn_func = self._attn_torch
-
-        # Straight attention without cache
-
-        if cache is None:
-
-            q_states = q_states.view(batch_size, q_len, self.num_attention_heads, self.head_dim)
-            k_states = k_states.view(batch_size, q_len, self.num_key_value_heads, self.head_dim)
-            v_states = v_states.view(batch_size, q_len, self.num_key_value_heads, self.head_dim)
-
-            attn_output = attn_func(batch_size, q_len, q_states, k_states, v_states, attn_params, cfg)
-
-        # Regular cache (FP16, FP8, Q4)
-
-        elif isinstance(cache, ExLlamaV2CacheBase):
-
-            q_states = q_states.view(batch_size, q_len, self.num_attention_heads, self.head_dim)
-            k_states = k_states.view(batch_size, q_len, self.num_key_value_heads, self.head_dim)
-            v_states = v_states.view(batch_size, q_len, self.num_key_value_heads, self.head_dim)
-
-            if not direct:
-                batch_keys, batch_values = cache.get_kv_state(self.layer_idx, batch_size, 0, past_len)
-                batch_keys[:batch_size, past_len:past_len + q_len, :].copy_(k_states)
-                batch_values[:batch_size, past_len:past_len + q_len, :].copy_(v_states)
-
-            k_states = batch_keys[:batch_size, :past_len + q_len, :]
-            v_states = batch_values[:batch_size, :past_len + q_len, :]
-
-            cache.store_kv_state(self.layer_idx, batch_size, past_len, q_len)
-
-            attn_output = attn_func(batch_size, q_len, q_states, k_states, v_states, attn_params, cfg)
-
-        # Output projection
-
-        ext_c.q_attn_forward_2(
-            self.q_handle,
-            hidden_states,
-            attn_output,
-            batch_size,
-            q_len,
-            pass_loras,
-            pass_lora_temp
-        )
-
-        if self.archparams.clamp_hidden_states:
-            hidden_states.clamp_(-65504, 65504)
-
-        return hidden_states
-
-    def forward_tp(
-        self,
-        hidden_states: torch.Tensor,
-        cache: ExLlamaV2CacheBase | None = None,
-        attn_params: ExLlamaV2Attention.Params | None = None,
-        past_len: int | None = None,
-        intermediates: bool = False,
-        loras: list[ExLlamaV2Lora] | None = None,
-        **kwargs
-    ) -> torch.Tensor:
-
-        cfg = self.model.config
-        ctx = self.model.tp_context
-
-        assert not cache or cache.q_block != 1, \
-            "Models with odd key/value dims not supported in TP mode with quantized cache"
-        assert not self.sliding_window, \
-            "Sliding window not supported in TP mode"
-
-        attn_params.prep_tp(self.model)
-
-        batch_size, q_len, _ = hidden_states.shape
-        hidden_states = hidden_states.view(-1, self.hidden_size)
-        past_len = 0 if cache is None else cache.current_seq_len
-
-        k_cache, v_cache = cache.get_kv_state(self.layer_idx, batch_size, 0, past_len) if cache else ([], [])
-
-        sin, cos = ctx.get_sin_cos()
-
-        ext_c.tp_attn_forward_(
-            self.model.tp_context.ext_tp_context,
-            hidden_states,
-            self.temp_bc0,
-            self.temp_bc1,
-            self.temp_bc2,
-            self.temp_q,
-            self.temp_k,
-            self.temp_v,
-            self.temp_o,
-            k_cache,
-            v_cache,
-            self.pre_layernorm.weight if self.pre_layernorm is not None else [],
-            self.pre_layernorm.variance_epsilon if self.pre_layernorm is not None else 0.0,
-            self.q_proj.q_handle,
-            self.k_proj.q_handle,
-            self.v_proj.q_handle,
-            self.o_proj.q_handle,
-            self.head_dim,
-            int(self.archparams.rope_style),
-            batch_size,
-            q_len,
-            sin,
-            cos,
-            attn_params.past_len_tp,
-            self.scaling
-        )
-
-        if cache is not None:
-            cache.store_kv_state(self.layer_idx, batch_size, past_len, q_len)
-
-        return ctx.get_pinned(0, batch_size, q_len, self.hidden_size)
-
-
-    def forward_tp_old(
-        self,
-        hidden_states: torch.Tensor,
-        cache: ExLlamaV2CacheBase | None = None,
-        attn_params: ExLlamaV2Attention.Params | None = None,
-        past_len: int | None = None,
-        intermediates: bool = False,
-        loras: list[ExLlamaV2Lora] | None = None,
-        **kwargs
-    ):
-        cfg = self.model.config
-        split = self.model.tp_context.get_split(BROADCAST_KV)
-        batch_size, q_len, _ = hidden_states.shape
-        attn_params.prep_tp(self.model)
-        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
-        past_len = 0 if cache is None else cache.current_seq_len
-
-        assert self.q_handle is not None
-        use_flash_attn = has_flash_attn and not cfg.no_flash_attn
-        if not use_flash_attn:
-            assert has_lower_right_sdpa and attn_params.is_causal() and not cfg.no_sdpa and not cfg.attn_logit_softcapping, \
-                "TP attention without flash-attn must use Torch SDPA with lower-right attention mask " \
-                "(use PyTorch 2.4.0+) and does not support logit softcapping."
-
-        hidden_states = self.model.tp_context.broadcast(0, hidden_states, BROADCAST_KV, dim = self.head_dim)
-
-        residual = hidden_states
-
-        post_norm = self.pre_layernorm.forward(hidden_states) if self.has_norm else hidden_states
-        q = self.q_proj.forward_tp(post_norm, loras = loras, output_split = True, dim = self.head_dim)
-        k = self.k_proj.forward_tp(post_norm, loras = loras, output_split = True, dim = self.head_dim)
-        v = self.v_proj.forward_tp(post_norm, loras = loras, output_split = True, dim = self.head_dim)
-
-        q = [q_.view(batch_size, q_len, q_.shape[1] // self.head_dim, self.head_dim) for q_ in q]
-        k = [k_.view(batch_size, q_len, k_.shape[1] // self.head_dim, self.head_dim) for k_ in k]
-        v = [v_.view(batch_size, q_len, v_.shape[1] // self.head_dim, self.head_dim) for v_ in v]
-
-        if cache:
-            k_cache, v_cache = cache.get_kv_state(self.layer_idx, batch_size, 0, past_len)
-        else:
-            k_cache, v_cache = None, None
-
-        if self.archparams.rope_style != RopeStyle.NONE:
-            for idx, (dev, a, b) in enumerate(split):
-                context = self.model.get_device_context(dev)
-                torch.cuda.set_stream(context.stream)
-                for t, heads in [(q[idx], self.num_key_value_groups), (k[idx], 1)]:
-                    ext_c.rope_(
-                        t,
-                        context.sin,
-                        context.cos,
-                        past_len,
-                        (b - a) * heads,
-                        self.head_dim,
-                        attn_params.position_offsets_tp[idx] if attn_params.position_offsets is not None else none_tensor,
-                        self.archparams.rope_style == RopeStyle.NEOX
-                    )
-
-        attn_outputs = []
-        for idx in range(len(split)):
-            dev, a, b = split[idx]
-            context = self.model.get_device_context(dev)
-            torch.cuda.set_stream(context.stream)
-
-            if k_cache is not None:
-                if use_flash_attn:
-                    attn_output = flash_attn_with_kvcache(
-                        q = q[idx],
-                        k = k[idx],
-                        v = v[idx],
-                        k_cache = k_cache[idx],
-                        v_cache = v_cache[idx],
-                        causal = True,
-                        softmax_scale = self.scaling,
-                        cache_seqlens = attn_params.past_len_tp[idx]
-                    )
-                else:
-                    cache_a = attn_params.past_len
-                    cache_b = attn_params.past_len + q_len
-                    k_cache[idx][:batch_size, cache_a:cache_b, :, :].copy_(k[idx])
-                    v_cache[idx][:batch_size, cache_a:cache_b, :, :].copy_(v[idx])
-                    attn_output = self._attn_torch(
-                        batch_size,
-                        q_len,
-                        q[idx],
-                        k_cache[idx][:batch_size, :cache_b, :, :],
-                        v_cache[idx][:batch_size, :cache_b, :, :],
-                        attn_params,
-                        cfg
-                    )
-            else:
-                if use_flash_attn:
-                    attn_output = flash_attn_func(
-                        q[idx],
-                        k[idx],
-                        v[idx],
-                        causal = True,
-                        softmax_scale = self.scaling,
-                    )
-                else:
-                    attn_output = self._attn_torch(
-                        batch_size,
-                        q_len,
-                        q[idx],
-                        k[idx],
-                        v[idx],
-                        attn_params,
-                        cfg
-                    )
-
-            attn_output = attn_output.view(batch_size * q_len, (b - a) * self.head_dim * self.num_key_value_groups)
-            attn_outputs.append(attn_output)
-
-        if cache is not None:
-            cache.store_kv_state(self.layer_idx, batch_size, past_len, q_len)
-
-        # Output projection
-
-        attn_outputs = self.model.tp_context.allgather(1, attn_outputs, BROADCAST_Q, BROADCAST_Q, dim = self.head_dim)
-
-        hidden_states = self.o_proj.forward_tp(attn_outputs, loras = loras, dim = self.head_dim, output_split = True)
-
-        if self.has_residual:
-            self.model.tp_context.add_residual(hidden_states, residual, BROADCAST_Q, dim = self.head_dim)
-
-        hidden_states = self.model.tp_context.gather(0, hidden_states, BROADCAST_Q, dim = self.head_dim)
-
-        # if self.post_layernorm: # TODO: ...
-        #     hidden_states = self.post_layernorm.forward(hidden_states)
-
-        hidden_states = hidden_states.view(batch_size, q_len, hidden_states.shape[-1])
-        return hidden_states
-
-
-    def forward_torch(
-        self,
-        hidden_states: torch.Tensor,
-        cache: ExLlamaV2CacheBase | None = None,
-        attn_params: ExLlamaV2Attention.Params | None = None,
-        past_len: int | None = None,
-        intermediates: bool = False,
-        loras: list[ExLlamaV2Lora] | None = None,
-        **kwargs
-    ) -> torch.Tensor | dict:
-
-        global has_flash_attn
-        global has_xformers
-
-        cfg = self.model.config
-        num_attention_heads = self.num_attention_heads
-        num_key_value_heads = self.num_key_value_heads
-        head_dim = self.head_dim
-
-        batch_size, q_len, _ = hidden_states.size()
-
-        past_len = 0 if cache is None else cache.current_seq_len
-
-        # Project q, k, v
-
-        residual = hidden_states
-        post_norm = self.pre_layernorm.forward(hidden_states) if self.has_norm else hidden_states
-
-        # Output projection
-
-        attn_proj = self.o_proj.forward(post_norm, loras = loras)
-
-        # Post layernorm
-
-        if self.post_layernorm:
-            attn_proj = self.post_layernorm.forward(attn_proj, output_fp32 = self.archparams.residual_stream_fp32)
-
-        # Add residual connection
-
-        hidden_states = (attn_proj + residual) if self.has_residual else attn_proj
-
-        if self.archparams.residual_stream_fp32:
-            hidden_states = hidden_states.float()
-        elif self.archparams.clamp_hidden_states:
-            hidden_states.clamp_(-65504, 65504)
-
-        if intermediates:
-            return {"post_norm": post_norm,
-                    "hidden_states": hidden_states}
-        else:
-            return hidden_states
-
-
-    def update_loras(self):
-
-        if self.q_handle is None: return
-
-        cfg = self.model.config
-
-        q_proj_lora_a = { id(k): v for k, v in self.q_proj.lora_a_tensors.items() }
-        q_proj_lora_b = { id(k): v for k, v in self.q_proj.lora_b_tensors.items() }
-        k_proj_lora_a = { id(k): v for k, v in self.k_proj.lora_a_tensors.items() }
-        k_proj_lora_b = { id(k): v for k, v in self.k_proj.lora_b_tensors.items() }
-        v_proj_lora_a = { id(k): v for k, v in self.v_proj.lora_a_tensors.items() }
-        v_proj_lora_b = { id(k): v for k, v in self.v_proj.lora_b_tensors.items() }
-        o_proj_lora_a = { id(k): v for k, v in self.o_proj.lora_a_tensors.items() }
-        o_proj_lora_b = { id(k): v for k, v in self.o_proj.lora_b_tensors.items() }
-
-        temp_lora_size = ext_c.q_attn_set_loras(
-            self.q_handle,
-            q_proj_lora_a,
-            q_proj_lora_b,
-            k_proj_lora_a,
-            k_proj_lora_b,
-            v_proj_lora_a,
-            v_proj_lora_b,
-            o_proj_lora_a,
-            o_proj_lora_b
-        )
-
-        self.temp_lora_size = temp_lora_size * cfg.max_batch_size * cfg.max_input_len
-
-
-    def is_quant(self):
-        return self.q_handle is not None
-
-
-    def tp_split(self):
-
-        cfg = self.model.config
-        ctx = self.model.tp_context
-
-        if self.pre_layernorm is not None:
-            self.pre_layernorm.tp_split(BROADCAST_KV)
-        if self.post_layernorm is not None:
-            self.post_layernorm.tp_split(BROADCAST_KV)
-
-        self.q_proj.tp_split(BROADCAST_Q, dim = self.head_dim)
-        self.k_proj.tp_split(BROADCAST_KV, dim = self.head_dim)
-        self.v_proj.tp_split(BROADCAST_KV, dim = self.head_dim)
-        self.o_proj.tp_split(BROADCAST_Q, dim = self.head_dim)
-
-        maxrows = cfg.max_batch_size * cfg.max_input_len
-        dtype = torch.half
-
-        ctx.begin_scratch_alloc_tp()
-        ctx.reserve_scratch(self.tp_dq_size)
-        self.temp_bc0 = ctx.get_scratch_slice_tp_bc(maxrows, dtype, BROADCAST_Q, dim = self.head_dim)
-        self.temp_bc1 = ctx.get_scratch_slice_tp_bc(maxrows, dtype, BROADCAST_Q, dim = self.head_dim)
-        self.temp_bc2 = ctx.get_scratch_slice_tp_bc(maxrows, dtype, BROADCAST_Q, dim = self.head_dim)
-        self.temp_q = ctx.get_scratch_slice_tp(maxrows, dtype, BROADCAST_Q, dim = self.head_dim)
-        self.temp_k = ctx.get_scratch_slice_tp(maxrows, dtype, BROADCAST_KV, dim = self.head_dim)
-        self.temp_v = ctx.get_scratch_slice_tp(maxrows, dtype, BROADCAST_KV, dim = self.head_dim)
-        self.temp_o = ctx.get_scratch_slice_tp(maxrows, dtype, BROADCAST_Q, dim = self.head_dim)
-
-        self.is_tp = True
-        self.set_device_idx(None)
-
-
-    def scratch_space_tp(self):
-
-        cfg = self.model.config
-        ctx = self.model.tp_context
-        devs = ctx.num_devices
-        scratch = [0] * devs
-
-        def add(res: list[int]):
-            for i, s in enumerate(res):
-                scratch[i] += s
-
-        def amax(res: list[int]):
-            for i, s in enumerate(res):
-                scratch[i] = max(scratch[i], s)
-
-        amax(self.q_proj.scratch_space_tp(BROADCAST_Q, self.head_dim))
-        amax(self.k_proj.scratch_space_tp(BROADCAST_KV, self.head_dim))
-        amax(self.v_proj.scratch_space_tp(BROADCAST_KV, self.head_dim))
-        amax(self.o_proj.scratch_space_tp(BROADCAST_Q, self.head_dim))
-        self.tp_dq_size = [s for s in scratch]
-
-        maxrows = cfg.max_batch_size * cfg.max_input_len
-
-        add(ctx.get_temp_tensors_bc_s(maxrows, 2, BROADCAST_Q, dim = self.head_dim))
-        add(ctx.get_temp_tensors_bc_s(maxrows, 2, BROADCAST_Q, dim = self.head_dim))
-        add(ctx.get_temp_tensors_bc_s(maxrows, 2, BROADCAST_Q, dim = self.head_dim))
-        add(ctx.get_temp_tensors_s(maxrows, 2, BROADCAST_Q, dim = self.head_dim))
-        add(ctx.get_temp_tensors_s(maxrows, 2, BROADCAST_KV, dim = self.head_dim))
-        add(ctx.get_temp_tensors_s(maxrows, 2, BROADCAST_KV, dim = self.head_dim))
-        add(ctx.get_temp_tensors_s(maxrows, 2, BROADCAST_Q, dim = self.head_dim))
-
-        return scratch
diff --git a/exllamav2/model.py b/exllamav2/model.py
index d7facfd1..1bdf73da 100644
--- a/exllamav2/model.py
+++ b/exllamav2/model.py
@@ -37,7 +37,6 @@
 from exllamav2.rmsnorm import ExLlamaV2RMSNorm
 from exllamav2.layernorm import ExLlamaV2LayerNorm
 from exllamav2.attn import ExLlamaV2Attention, has_flash_attn, has_xformers
-from exllamav2.linear_attn import ExLlamaV2LinearAttention, has_flash_attn, has_xformers
 from exllamav2.lora import ExLlamaV2Lora
 from exllamav2.mlp import ExLlamaV2MLP
 from exllamav2.moe_mlp import ExLlamaV2MoEMLP
@@ -120,10 +119,7 @@ def __init__(
                 self.modules += [pd]
             elif type(cfg.arch.lm.layer_keys) is list and type(cfg.intermediate_size) is list:
                 mlp = ExLlamaV2MLP(self, layer_key, layer_idx)
-                if ["self_attn.linear_attn"] in cfg.arch.lm.layer_keys[layer_idx]:
-                    attn = ExLlamaV2LinearAttention(self, layer_key, layer_idx)
-                    self.modules += [attn, mlp]
-                elif ["self_attn.o_proj"] in cfg.arch.lm.layer_keys[layer_idx]:
+                if ["self_attn.linear_attn"] in cfg.arch.lm.layer_keys[layer_idx] or ["self_attn.o_proj"] in cfg.arch.lm.layer_keys[layer_idx]:
                     attn = ExLlamaV2Attention(self, layer_key, layer_idx, sliding_window = swa)
                     self.modules += [attn, mlp]
                 else:
@@ -172,7 +168,6 @@ def __init__(
             while True:
                 layer_idx -= 1
                 if isinstance(self.modules[layer_idx], ExLlamaV2Attention) or \
-                   isinstance(self.modules[layer_idx], ExLlamaV2LinearAttention) or \
                    isinstance(self.modules[layer_idx], ExLlamaV2ParallelDecoder):
                     break

@@ -621,7 +616,6 @@ def load_autosplit_gen(
             try:
                 if isinstance(module, ExLlamaV2Attention) or \
-                   isinstance(module, ExLlamaV2LinearAttention) or \
                    isinstance(module, ExLlamaV2ParallelDecoder):
                     self.cache_map[module.layer_idx] = module.device()
                     cache.update_cache_tensors()