diff --git a/configs/quantization/methods/KVQuant/rtn_w_a_kivi_quant_kv.yml b/configs/quantization/methods/KVQuant/rtn_w_a_kivi_quant_kv.yml new file mode 100644 index 00000000..0a745dbc --- /dev/null +++ b/configs/quantization/methods/KVQuant/rtn_w_a_kivi_quant_kv.yml @@ -0,0 +1,35 @@ +base: + seed: &seed 42 +model: + type: model_type + path: model path + torch_dtype: auto +eval: + eval_pos: [transformed, fake_quant, fake_quant_wo_kv] #long_ppl eval not support pretrain eval pos + name: wikitext2 + type: decode_ppl + download: False + path: eval_data_path + bs: 1 + inference_per_block: False + num_samples: 10 + # num_eval_tokens: 3 +quant: + method: RTN + weight: + bit: 8 + symmetric: True + granularity: per_channel + group_size: -1 + act: + bit: 8 + symmetric: True + granularity: per_token + kvcache: + method: Kivi + bit: 8 + symmetric: True + granularity: per_token +save: + save_fake: False + save_path: /path/to/save/ diff --git a/configs/quantization/methods/RTN/rtn_w_a_kv_human_eval.yml b/configs/quantization/methods/KVQuant/rtn_w_a_naive_quant_kv.yml similarity index 59% rename from configs/quantization/methods/RTN/rtn_w_a_kv_human_eval.yml rename to configs/quantization/methods/KVQuant/rtn_w_a_naive_quant_kv.yml index f0e0685c..ade36ebc 100644 --- a/configs/quantization/methods/RTN/rtn_w_a_kv_human_eval.yml +++ b/configs/quantization/methods/KVQuant/rtn_w_a_naive_quant_kv.yml @@ -5,16 +5,15 @@ model: path: model path torch_dtype: auto eval: - eval_pos: [pretrain, fake_quant] - type: code - name: human_eval - res_path: ./human_eval/ - # For 7B / 13B model eval, bs can be set to "1", and inference_per_block can be set to "False". - # For 70B model eval, bs can be set to "20", and inference_per_block can be set to "True". + eval_pos: [transformed, fake_quant, fake_quant_wo_kv] #long_ppl eval not support pretrain eval pos + name: wikitext2 + type: decode_ppl + download: False + path: eval_data_path bs: 1 - format_tabs: True inference_per_block: False - # add_chat_temp: True + num_samples: 10 + # num_eval_tokens: 3 quant: method: RTN weight: diff --git a/configs/quantization/methods/RTN/rtn_w_a_pertensor_static_kv.yml b/configs/quantization/methods/KVQuant/rtn_w_a_pertensor_static_naive_quant_kv.yml similarity index 73% rename from configs/quantization/methods/RTN/rtn_w_a_pertensor_static_kv.yml rename to configs/quantization/methods/KVQuant/rtn_w_a_pertensor_static_naive_quant_kv.yml index fc8417b1..f2bbda67 100644 --- a/configs/quantization/methods/RTN/rtn_w_a_pertensor_static_kv.yml +++ b/configs/quantization/methods/KVQuant/rtn_w_a_pertensor_static_naive_quant_kv.yml @@ -14,15 +14,15 @@ calib: preproc: txt_general_preproc seed: *seed eval: - eval_pos: [pretrain, fake_quant] + eval_pos: [transformed, fake_quant, fake_quant_wo_kv] #long_ppl eval not support pretrain eval pos name: wikitext2 + type: decode_ppl download: False - path: eval data path - seq_len: 2048 - # For 7B / 13B model eval, bs can be set to "1", and inference_per_block can be set to "False". - # For 70B model eval, bs can be set to "20", and inference_per_block can be set to "True". 
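+    # decode_ppl feeds the model one token at a time through the KV cache; num_samples limits the eval texts, num_eval_tokens (optional) caps the total decoded tokens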
+ path: eval_data_path bs: 1 inference_per_block: False + num_samples: 10 + # num_eval_tokens: 3 quant: method: RTN weight: diff --git a/configs/quantization/methods/KVQuant/rtn_w_a_sink_quant_kv.yml b/configs/quantization/methods/KVQuant/rtn_w_a_sink_quant_kv.yml new file mode 100644 index 00000000..040f260c --- /dev/null +++ b/configs/quantization/methods/KVQuant/rtn_w_a_sink_quant_kv.yml @@ -0,0 +1,38 @@ +base: + seed: &seed 42 +model: + type: model_type + path: model path + torch_dtype: auto +eval: + eval_pos: [transformed, fake_quant, fake_quant_wo_kv] #long_ppl eval not support pretrain eval pos + name: wikitext2 + type: decode_ppl + download: False + path: eval_data_path + bs: 1 + inference_per_block: False + num_samples: 10 + # num_eval_tokens: 3 +quant: + method: RTN + weight: + bit: 8 + symmetric: True + granularity: per_channel + group_size: -1 + act: + bit: 8 + symmetric: True + granularity: per_token + kvcache: + method: Sink + bit: 4 + symmetric: True + granularity: per_token + special: + window_length: 512 + num_sink_tokens: 4 +save: + save_fake: False + save_path: /path/to/save/ diff --git a/llmc/compression/blockwise_optimization.py b/llmc/compression/blockwise_optimization.py index b6162cf2..4bc0a7a1 100644 --- a/llmc/compression/blockwise_optimization.py +++ b/llmc/compression/blockwise_optimization.py @@ -45,7 +45,10 @@ def run_block_loop(self): if hasattr(self, 'save_clip') and self.save_clip: os.makedirs(self.clip_path, exist_ok=True) - torch.save(self.auto_clipper.weight_clips, os.path.join(self.clip_path, 'clips.pth')) + torch.save( + self.auto_clipper.weight_clips, + os.path.join(self.clip_path, 'clips.pth'), + ) def cache_input_hook(self, m, x, y, name, feat_dict): inputs = [i.detach().cpu() for i in x] @@ -63,14 +66,31 @@ def hook_fn(module, args, kwargs): kwargs['past_key_value'] = kvcache kwargs['use_cache'] = True if kwargs['hidden_states'].shape[1] == 1: - if self.config['model']['type'] in ['DeepseekV2']: - kwargs['position_ids'] = kwargs['position_ids'][:, -1].unsqueeze(1) + if kwargs['position_ids'].shape[1] == 1: + # For eval decoding PPL (Perplexity), it will be removed in future versions. 
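+                    # Rebuild cache_position and position_ids from the cached sequence length so single-token decode steps stay aligned with the KV cache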
+ past_seen_tokens = kvcache.get_seq_length() + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + kwargs['hidden_states'].shape[1], + device=kwargs['hidden_states'].device, + ) + kwargs['cache_position'] = cache_position + position_ids = cache_position.unsqueeze(0) + kwargs['position_ids'] = position_ids + if 'position_embeddings' in kwargs: + kwargs['position_embeddings'] = self.model.rotary_emb( + kwargs['hidden_states'], position_ids + ) else: - kwargs['position_ids'] = kwargs['position_ids'][:, -1].unsqueeze(0).unsqueeze(0) - if 'position_embeddings' in kwargs: - cos = kwargs['position_embeddings'][0][:, -1, :].unsqueeze(1) - sin = kwargs['position_embeddings'][1][:, -1, :].unsqueeze(1) - kwargs['position_embeddings'] = (cos, sin) + if self.config['model']['type'] in ['DeepseekV2']: + kwargs['position_ids'] = kwargs['position_ids'][:, -1].unsqueeze(1) + else: + kwargs['position_ids'] = \ + kwargs['position_ids'][:, -1].unsqueeze(0).unsqueeze(0) + if 'position_embeddings' in kwargs: + cos = kwargs['position_embeddings'][0][:, -1, :].unsqueeze(1) + sin = kwargs['position_embeddings'][1][:, -1, :].unsqueeze(1) + kwargs['position_embeddings'] = (cos, sin) return args, kwargs diff --git a/llmc/compression/quantization/base_blockwise_quantization.py b/llmc/compression/quantization/base_blockwise_quantization.py index 3ca66400..5925c111 100644 --- a/llmc/compression/quantization/base_blockwise_quantization.py +++ b/llmc/compression/quantization/base_blockwise_quantization.py @@ -58,7 +58,7 @@ def a_qdq(self, act, module, aquantizer, input_index=0): def get_replacement_params(self, mode='fake_quant', w_only=False, name=None): params_dict = {} - if mode == 'fake_quant': + if mode in ['fake_quant', 'fake_quant_wo_kv']: if not self.mix_bits: params_dict['a_qdq'] = ( partial(self.a_qdq, aquantizer=self.aquantizer) @@ -229,17 +229,16 @@ def set_quant_config(self): # set kv cache quant config if 'kvcache' in self.quant_config: self.quant_config['kvcache']['static'] = self.act_static + kv_special_cfg = self.quant_config['kvcache'].get('special', {}) + logger.info(kv_special_cfg) + act_static_cfg = {} if self.act_static: - self.kv_module = KV_REGISTRY[self.quant_config['kvcache']['method']]( - self.quant_type, self.quant_config['kvcache'], - self.model.model_config.num_hidden_layers, self.config.calib.n_samples, - self.config.calib.bs - ) - else: - self.kv_module = KV_REGISTRY[self.quant_config['kvcache']['method']]( - self.quant_type, self.quant_config['kvcache'], - self.model.model_config.num_hidden_layers - ) + act_static_cfg.update(self.config.calib.n_sample) + act_static_cfg.update(self.config.calib.bs) + self.kv_module = KV_REGISTRY[self.quant_config['kvcache']['method']]( + self.quant_type, self.quant_config['kvcache'], + self.model.model_config.num_hidden_layers, **kv_special_cfg, **act_static_cfg + ) self.quant_kvcache = True self.model.kvcache_buffer.append(self.kv_module) else: @@ -860,6 +859,7 @@ def deploy(self, quant_format, keep_device=False): module_mapping = { 'origin_float': OriginFloatLinear, 'fake_quant': EffcientFakeQuantLinear, + 'fake_quant_wo_kv': EffcientFakeQuantLinear, } module_mapping.update(_REALQUANT_LINEAR_MAP_) @@ -884,10 +884,12 @@ def deploy(self, quant_format, keep_device=False): self.set_non_linear_mode(quant_format, self.model.model, False) if self.quant_kvcache: - if quant_format == 'transformed': - self.kv_module.transformed = True + if quant_format == 'origin_float': + self.kv_module.use_org_kv = True + elif quant_format == 
'fake_quant_wo_kv': + self.kv_module.use_org_kv = True elif quant_format == 'fake_quant': - self.kv_module.transformed = False + self.kv_module.use_org_kv = False if self.act_static: self.kv_module.calib = False diff --git a/llmc/compression/quantization/kvquant.py b/llmc/compression/quantization/kvquant.py index 9a8ada7a..ed67a04c 100644 --- a/llmc/compression/quantization/kvquant.py +++ b/llmc/compression/quantization/kvquant.py @@ -27,7 +27,7 @@ def __init__(self, quant_type, kvquant_cfg, num_hidden_layers, num_samples=128, self.static = kvquant_cfg.get('static', False) self._quantized_key_cache = [] self._quantized_value_cache = [] - self.transformed = False + self.use_org_kv = False if self.static: self._reset_buffers() @@ -48,9 +48,8 @@ def update( layer_idx, cache_kwargs, ): - if self.transformed: - super().update(key_states, value_states, layer_idx, cache_kwargs) - + if self.use_org_kv: + return super().update(key_states, value_states, layer_idx, cache_kwargs) elif self.static and self.calib: self._calibration(layer_idx, key_states, value_states) keys_to_return, values_to_return = key_states, value_states @@ -217,6 +216,8 @@ def get_qparams(self, tensor): return scales, zeros, qmin, qmax def get_seq_length(self, layer_idx=0): + if self.use_org_kv: + return super().get_seq_length() if len(self._quantized_key_cache) <= layer_idx: return 0 return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1 @@ -224,12 +225,10 @@ def get_seq_length(self, layer_idx=0): @KV_REGISTRY.register('Kivi') class KiviQuantKVCache(NaiveQuantKVCache): - def __init__(self, quant_type, kvquant_cfg, num_hidden_layers, num_samples, bsz): + def __init__(self, quant_type, kvquant_cfg, num_hidden_layers, num_samples=128, bsz=1): super().__init__(quant_type, kvquant_cfg, num_hidden_layers, num_samples, bsz) assert not self.static, 'Only support dynamic quantization for KIVI' self.residual_length = kvquant_cfg.get('residual_length', 128) - self.key_cache = [] - self.value_cache = [] def update( self, @@ -238,51 +237,218 @@ def update( layer_idx, cache_kwargs, ): - if layer_idx == 0: - self._seen_tokens += key_states.shape[-2] + if self.use_org_kv: + return super().update(key_states, value_states, layer_idx, cache_kwargs) + else: + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + if len(self.key_cache) <= layer_idx: + self._quantized_key_cache.append(self._quantize(key_states.contiguous(), + layer_idx, + is_key=True)) + self._quantized_value_cache.append(self._quantize(value_states.contiguous(), + layer_idx, + is_key=False)) + self.key_cache.append(torch.zeros(0, + dtype=key_states.dtype, + device=key_states.device)) + self.value_cache.append(torch.zeros(0, + dtype=key_states.dtype, + device=key_states.device)) + keys_to_return, values_to_return = key_states, value_states + else: + dequant_key = self._dequantize(self._quantized_key_cache[layer_idx]) + dequant_value = self._dequantize(self._quantized_value_cache[layer_idx]) + keys_to_return = [dequant_key, self.key_cache[layer_idx], key_states] + values_to_return = [dequant_value, self.value_cache[layer_idx], value_states] + + keys_to_return = torch.cat(keys_to_return, dim=-2) + values_to_return = torch.cat(values_to_return, dim=-2) + if ( + self.key_cache[layer_idx].dim() == 4 + and self.key_cache[layer_idx].shape[-2] + 1 >= self.residual_length + ): + self._quantized_key_cache[layer_idx] = \ + self._quantize(keys_to_return.contiguous(), layer_idx, is_key=True) + self._quantized_value_cache[layer_idx] = self._quantize( + 
values_to_return.contiguous(), layer_idx, is_key=False + ) + self.key_cache[layer_idx] = torch.zeros(0, + dtype=key_states.dtype, + device=key_states.device) + self.value_cache[layer_idx] = torch.zeros(0, + dtype=key_states.dtype, + device=key_states.device) + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], + dim=-2) + self.value_cache[layer_idx] = \ + torch.cat([self.value_cache[layer_idx], value_states], dim=-2) + + return keys_to_return, values_to_return + + +@KV_REGISTRY.register('Sink') +class SinkQuantKVCache(NaiveQuantKVCache): + def __init__( + self, + quant_type, + kvquant_cfg, + num_hidden_layers, + window_length, + num_sink_tokens, + num_samples=128, + bsz=1 + ): + super().__init__(quant_type, kvquant_cfg, num_hidden_layers, num_samples, bsz) + assert not self.static, 'Only support dynamic quantization for Sink' + self.window_length = window_length + self.num_sink_tokens = num_sink_tokens + self.cos_sin_rerotation_cache = {} + self._cos_cache = None + self._sin_cache = None + + @staticmethod + def _rotate_half(x): + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + def _apply_key_rotary_pos_emb( + self, key_states, cos, sin + ): + rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin) + return rotated_key_states + + def _get_rerotation_cos_sin( + self, key_states, cos, sin + ): + if key_states.shape[-2] not in self.cos_sin_rerotation_cache: + # Upcast to float32 temporarily for better accuracy + cos = cos.to(torch.float32) + sin = sin.to(torch.float32) + + original_cos = cos[self.num_sink_tokens + key_states.shape[-2]:] + shifted_cos = cos[self.num_sink_tokens:-key_states.shape[-2]] + original_sin = sin[self.num_sink_tokens + key_states.shape[-2]:] + shifted_sin = sin[self.num_sink_tokens:-key_states.shape[-2]] + rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin + rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin + + self.cos_sin_rerotation_cache[key_states.shape[-2]] = ( + rerotation_cos.to(key_states.dtype).unsqueeze(0), + rerotation_sin.to(key_states.dtype).unsqueeze(0), + ) + return self.cos_sin_rerotation_cache[key_states.shape[-2]] + + def get_seq_length(self, layer_idx=0): + """Returns the sequence length of the cached states. + + A layer index can be optionally passed. 
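+        Returns 0 if nothing has been cached for that layer.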
+ """ if len(self.key_cache) <= layer_idx: - self._quantized_key_cache.append(self._quantize(key_states.contiguous(), - layer_idx, - is_key=True)) - self._quantized_value_cache.append(self._quantize(value_states.contiguous(), - layer_idx, - is_key=False)) - self.key_cache.append(torch.zeros(0, - dtype=key_states.dtype, - device=key_states.device)) - self.value_cache.append(torch.zeros(0, - dtype=key_states.dtype, - device=key_states.device)) - keys_to_return, values_to_return = key_states, value_states + return 0 + return self.key_cache[layer_idx].shape[-2] + + def get_max_cache_shape(self): + """Returns the maximum sequence length of the cache object, in case of + SinkCache it is the window length.""" + return self.window_length + + def update( + self, + key_states, + value_states, + layer_idx, + cache_kwargs, + ): + + if self.use_org_kv: + return super().update(key_states, value_states, layer_idx, cache_kwargs) else: - dequant_key = self._dequantize(self._quantized_key_cache[layer_idx]) - dequant_value = self._dequantize(self._quantized_value_cache[layer_idx]) - keys_to_return = [dequant_key, self.key_cache[layer_idx], key_states] - values_to_return = [dequant_value, self.value_cache[layer_idx], value_states] - - keys_to_return = torch.cat(keys_to_return, dim=-2) - values_to_return = torch.cat(values_to_return, dim=-2) - if ( - self.key_cache[layer_idx].dim() == 4 - and self.key_cache[layer_idx].shape[-2] + 1 >= self.residual_length - ): - self._quantized_key_cache[layer_idx] = self._quantize(keys_to_return.contiguous(), - layer_idx, - is_key=True) - self._quantized_value_cache[layer_idx] = self._quantize( - values_to_return.contiguous(), layer_idx, is_key=False - ) - self.key_cache[layer_idx] = torch.zeros(0, - dtype=key_states.dtype, - device=key_states.device) - self.value_cache[layer_idx] = torch.zeros(0, - dtype=key_states.dtype, - device=key_states.device) + sin = cache_kwargs.get('sin') + cos = cache_kwargs.get('cos') + partial_rotation_size = cache_kwargs.get('partial_rotation_size') + using_rope = cos is not None and sin is not None + + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + if using_rope and layer_idx == 0: + + if cos.dim() == 2: + self._cos_cache = cos + self._sin_cache = sin + else: + if self._cos_cache is None: + self._cos_cache = cos[0, ...] + self._sin_cache = sin[0, ...] 
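+                    # Keep appending cos/sin rows for newly seen positions until the cache spans the full attention window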
+ elif self._cos_cache.shape[0] < self.window_length: + self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0) + self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0) + + # [bsz, num_heads, seq_len, head_dim] + if len(self.key_cache) <= layer_idx: + # Empty cache + self.key_cache.append(key_states) + self.value_cache.append(value_states) + + elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length: + # Growing cache + self.key_cache[layer_idx] = \ + torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + self.value_cache[layer_idx] = \ + torch.cat([self.value_cache[layer_idx], value_states], dim=-2) + else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], - dim=-2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], - dim=-2) + # Shifting cache + keys_to_keep = self.key_cache[layer_idx][ + :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2]: + ] + + if using_rope: + rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin( + key_states, + self._cos_cache[: self.window_length], + self._sin_cache[: self.window_length] + ) + if partial_rotation_size is not None: + keys_to_keep, keys_pass = ( + keys_to_keep[..., :partial_rotation_size], + keys_to_keep[..., partial_rotation_size:], + ) + keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, + rerotation_cos, + rerotation_sin) + if partial_rotation_size is not None: + keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1) + + # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens + sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens] + + dq_keys_to_keep = self._dequantize(self._quantize(keys_to_keep.contiguous(), + layer_idx, + is_key=True)) + dq_keys = self._dequantize(self._quantize(key_states.contiguous(), + layer_idx, + is_key=True)) + + self.key_cache[layer_idx] = torch.cat([sink_keys, dq_keys_to_keep, dq_keys], dim=-2) + + sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens] + values_to_keep = self.value_cache[layer_idx][ + :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2]: + ] + dq_values_to_keep = self._dequantize(self._quantize(values_to_keep.contiguous(), + layer_idx, + is_key=True)) + dq_values = self._dequantize(self._quantize(value_states.contiguous(), + layer_idx, + is_key=True)) - return keys_to_return, values_to_return + self.value_cache[layer_idx] = torch.cat([sink_values, + dq_values_to_keep, + dq_values], dim=-2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] diff --git a/llmc/eval/__init__.py b/llmc/eval/__init__.py index 836b7922..ee5ee971 100644 --- a/llmc/eval/__init__.py +++ b/llmc/eval/__init__.py @@ -1,6 +1,6 @@ from .eval_acc import AccuracyEval from .eval_code import HumanEval from .eval_custom_generate import CustomGenerate -from .eval_ppl import PerplexityEval +from .eval_ppl import DecodePerplexityEval, PerplexityEval from .eval_token_consist import TokenConsistencyEval from .eval_vqa import VQAEval diff --git a/llmc/eval/eval_base.py b/llmc/eval/eval_base.py index a76500af..4c86e077 100644 --- a/llmc/eval/eval_base.py +++ b/llmc/eval/eval_base.py @@ -20,6 +20,7 @@ def __init__(self, model, config): self.model_type = config.model.type logger.info(f'eval_cfg : {self.eval_cfg}') self.dataset = self.eval_cfg['name'] + self.dataset_type = self.eval_cfg.get('type', 'ppl') assert self.dataset in [ 'wikitext2', 'c4', @@ -31,6 +32,8 @@ def __init__(self, model, config): 
'custom_gen', ], 'Eval only support wikitext2, c4, ptb, custom, human_eval dataset now.' self.seq_len = self.eval_cfg.get('seq_len', None) + self.num_samples = self.eval_cfg.get('num_samples', None) + self.num_eval_tokens = self.eval_cfg.get('num_eval_tokens', None) self.eval_dataset_bs = self.eval_cfg['bs'] self.eval_dataset_path = self.eval_cfg.get('path', None) self.apply_chat_template = self.eval_cfg.get('apply_chat_template', False) @@ -70,7 +73,10 @@ def build_data(self): testdata = load_from_disk(self.eval_dataset_path) self.testdata = testdata # encode data - if self.dataset == 'wikitext2': + if self.dataset_type == 'decode_ppl': + assert self.dataset == 'wikitext2' + testenc = testdata['text'] + elif self.dataset == 'wikitext2': testenc = self.tokenizer( '\n\n'.join(testdata['text']), return_tensors='pt' ) diff --git a/llmc/eval/eval_ppl.py b/llmc/eval/eval_ppl.py index 0e3d55cb..fe1613bd 100644 --- a/llmc/eval/eval_ppl.py +++ b/llmc/eval/eval_ppl.py @@ -5,12 +5,12 @@ import torch.nn as nn from datasets import load_dataset, load_from_disk from loguru import logger +from tqdm import tqdm from .eval_base import BaseEval class PerplexityEval(BaseEval): - @torch.no_grad() def eval_func(self, model, testenc, seq_len, bs, eval_pos): testenc = testenc.input_ids @@ -57,3 +57,38 @@ def eval_func(self, model, testenc, seq_len, bs, eval_pos): torch.cuda.empty_cache() return ppl.item() + + +class DecodePerplexityEval(BaseEval): + @torch.no_grad() + def eval_func(self, model, testenc, seq_len, bs, eval_pos): + num_eval_tokens = 0 + num_samples = 1 if self.num_samples is None else self.num_samples + loss_fn = torch.nn.CrossEntropyLoss(reduction='none') + nlls = [] + + for text in testenc[: num_samples]: + logger.info(text) + encodings = self.tokenizer(text, return_tensors='pt') + seq_len = encodings.input_ids.size(1) + logger.info(f'seq_len: {seq_len}') + pbar = tqdm(range(0, seq_len - 1)) + + for idx in pbar: + input_ids = encodings.input_ids[:, idx:idx + 1].cuda() + with torch.no_grad(): + outputs = model.model( + input_ids, + ) + logits = outputs.logits.view(-1, model.model.config.vocab_size) + label = encodings.input_ids[:, idx + 1:idx + 2].to(logits.device).view(-1) + neg_log_likelihood = loss_fn(logits, label) + nlls.append(neg_log_likelihood) + num_eval_tokens += 1 + if self.num_eval_tokens is not None and num_eval_tokens >= self.num_eval_tokens: + break + if self.num_eval_tokens is not None and num_eval_tokens >= self.num_eval_tokens: + break + model.reset_kv() + ppl = torch.exp(torch.stack(nlls).mean()) + return ppl.item() diff --git a/llmc/eval/utils.py b/llmc/eval/utils.py index e173fc9f..26af99fb 100644 --- a/llmc/eval/utils.py +++ b/llmc/eval/utils.py @@ -3,14 +3,24 @@ from loguru import logger -from llmc.eval import (AccuracyEval, CustomGenerate, HumanEval, PerplexityEval, - TokenConsistencyEval, VQAEval) +from llmc.eval import (AccuracyEval, CustomGenerate, DecodePerplexityEval, + HumanEval, PerplexityEval, TokenConsistencyEval, + VQAEval) def get_eval_list(model, config): eval_list = [] if int(os.environ['RANK']) == 0: if 'eval' in config: + if 'type' in config.eval and config.eval.type == 'decode_ppl': + if 'pretrain' in config.eval.eval_pos: + raise ValueError( + 'Unsupported: Evaluating decode_ppl with a pretrained model. ' + ) + # Pretrained models do not use key-value caching. + # Please use a transformed model to evaluate decode_ppl + # for the original model. 
+ if not isinstance(config.eval, list): eval_config_list = [config.eval] else: @@ -50,8 +60,12 @@ def get_eval_list(model, config): eval_class = TokenConsistencyEval(model, config_for_eval) elif config_tmp.eval.type == 'ppl': eval_class = PerplexityEval(model, config_for_eval) + elif config_tmp.eval.type == 'decode_ppl': + eval_class = DecodePerplexityEval(model, config_for_eval) else: - raise ValueError(f'Unsupported eval type: {config_tmp.eval.type}') + raise ValueError( + f'Unsupported eval type: {config_tmp.eval.type}' + ) eval_list.append((eval_class, config_for_eval)) return eval_list diff --git a/llmc/models/qwen.py b/llmc/models/qwen.py index e5ae52dd..88402235 100644 --- a/llmc/models/qwen.py +++ b/llmc/models/qwen.py @@ -27,6 +27,9 @@ def get_head_layers(self): def get_pre_head_layernorm_layers(self): return [self.model.transformer.ln_f] + def get_attn_in_block(self, block): + return {'self_attn': block.self_attn} + def get_layers_except_blocks(self): return [self.wte, self.rotary_emb, diff --git a/llmc/models/qwen2.py b/llmc/models/qwen2.py index d78ba11d..4d9d9551 100644 --- a/llmc/models/qwen2.py +++ b/llmc/models/qwen2.py @@ -27,6 +27,9 @@ def find_block_name(self): def get_embed_layers(self): return [self.embed_tokens] + def get_attn_in_block(self, block): + return {'self_attn': block.self_attn} + def get_attention_rotary_layers(self): if packaging.version.parse(version('transformers')) >= packaging.version.parse('4.45.0'): return [self.rotary_emb]
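
Note: the snippet below is a minimal, standalone sketch (not llmc code) of the per-token decode-perplexity loop that DecodePerplexityEval implements above, written against the plain transformers API. The model path and eval text are placeholders; a quantized cache such as KiviQuantKVCache could be supplied as past_key_values instead of the default cache.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained('path/to/model', torch_dtype='auto').cuda().eval()
tokenizer = AutoTokenizer.from_pretrained('path/to/model')
loss_fn = torch.nn.CrossEntropyLoss(reduction='none')

text = 'placeholder evaluation text'
input_ids = tokenizer(text, return_tensors='pt').input_ids.cuda()

past_key_values = None  # a quantized KV cache could be plugged in here
nlls = []
with torch.no_grad():
    for idx in range(input_ids.size(1) - 1):
        # Feed exactly one token per step so every later step reads keys/values from the cache.
        out = model(input_ids[:, idx:idx + 1], past_key_values=past_key_values, use_cache=True)
        past_key_values = out.past_key_values
        logits = out.logits.view(-1, model.config.vocab_size)
        label = input_ids[:, idx + 1].view(-1)
        nlls.append(loss_fn(logits, label))

ppl = torch.exp(torch.stack(nlls).mean())
print(f'decode ppl: {ppl.item():.4f}')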