Dev kv (#283)
* Fix omniq clip bugs

* Fix omniq clip bugs

* Fix awq trans_v1 gqa bug

* Add streamllm(Sink) support and add eval decode PPL support

---------

Co-authored-by: gushiqiao <[email protected]>
gushiqiao and gushiqiao authored Jan 3, 2025
1 parent 18bdbdd commit 08854d6
Showing 13 changed files with 412 additions and 91 deletions.
35 changes: 35 additions & 0 deletions configs/quantization/methods/KVQuant/rtn_w_a_kivi_quant_kv.yml
@@ -0,0 +1,35 @@
base:
seed: &seed 42
model:
type: model_type
path: model path
torch_dtype: auto
eval:
eval_pos: [transformed, fake_quant, fake_quant_wo_kv] # long_ppl eval does not support the pretrain eval_pos
name: wikitext2
type: decode_ppl
download: False
path: eval_data_path
bs: 1
inference_per_block: False
num_samples: 10
# num_eval_tokens: 3
quant:
method: RTN
weight:
bit: 8
symmetric: True
granularity: per_channel
group_size: -1
act:
bit: 8
symmetric: True
granularity: per_token
kvcache:
method: Kivi
bit: 8
symmetric: True
granularity: per_token
save:
save_fake: False
save_path: /path/to/save/
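For orientation, here is a minimal sketch of what symmetric per-token 8-bit fake quantization of the KV cache (as configured above) amounts to. It is illustrative only, not the Kivi cache implementation in llmc; the helper name and tensor shapes are assumptions.

```python
import torch

def fake_quant_per_token(kv: torch.Tensor, bit: int = 8) -> torch.Tensor:
    """Symmetric per-token fake quantization of a KV tensor.

    kv is assumed to have shape (batch, heads, seq_len, head_dim);
    "per_token" here means one scale per (batch, head, position).
    """
    qmax = 2 ** (bit - 1) - 1                       # 127 for 8-bit
    scale = kv.abs().amax(dim=-1, keepdim=True) / qmax
    scale = scale.clamp(min=1e-8)                   # avoid division by zero
    q = torch.clamp(torch.round(kv / scale), -qmax - 1, qmax)
    return q * scale                                # dequantize back to float

# Example: fake-quantize a dummy key cache with bit=8 as in this config.
keys = torch.randn(1, 32, 128, 128)
keys_fq = fake_quant_per_token(keys, bit=8)
print((keys - keys_fq).abs().max())
```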
@@ -5,16 +5,15 @@ model:
path: model path
torch_dtype: auto
eval:
eval_pos: [pretrain, fake_quant]
type: code
name: human_eval
res_path: ./human_eval/
# For 7B / 13B model eval, bs can be set to "1", and inference_per_block can be set to "False".
# For 70B model eval, bs can be set to "20", and inference_per_block can be set to "True".
eval_pos: [transformed, fake_quant, fake_quant_wo_kv] # long_ppl eval does not support the pretrain eval_pos
name: wikitext2
type: decode_ppl
download: False
path: eval_data_path
bs: 1
format_tabs: True
inference_per_block: False
# add_chat_temp: True
num_samples: 10
# num_eval_tokens: 3
quant:
method: RTN
weight:
@@ -14,15 +14,15 @@ calib:
preproc: txt_general_preproc
seed: *seed
eval:
eval_pos: [pretrain, fake_quant]
eval_pos: [transformed, fake_quant, fake_quant_wo_kv] # long_ppl eval does not support the pretrain eval_pos
name: wikitext2
type: decode_ppl
download: False
path: eval data path
seq_len: 2048
# For 7B / 13B model eval, bs can be set to "1", and inference_per_block can be set to "False".
# For 70B model eval, bs can be set to "20", and inference_per_block can be set to "True".
path: eval_data_path
bs: 1
inference_per_block: False
num_samples: 10
# num_eval_tokens: 3
quant:
method: RTN
weight:
38 changes: 38 additions & 0 deletions configs/quantization/methods/KVQuant/rtn_w_a_sink_quant_kv.yml
@@ -0,0 +1,38 @@
base:
seed: &seed 42
model:
type: model_type
path: model path
torch_dtype: auto
eval:
eval_pos: [transformed, fake_quant, fake_quant_wo_kv] # long_ppl eval does not support the pretrain eval_pos
name: wikitext2
type: decode_ppl
download: False
path: eval_data_path
bs: 1
inference_per_block: False
num_samples: 10
# num_eval_tokens: 3
quant:
method: RTN
weight:
bit: 8
symmetric: True
granularity: per_channel
group_size: -1
act:
bit: 8
symmetric: True
granularity: per_token
kvcache:
method: Sink
bit: 4
symmetric: True
granularity: per_token
special:
window_length: 512
num_sink_tokens: 4
save:
save_fake: False
save_path: /path/to/save/
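The `special` block above configures a StreamingLLM-style policy: the first `num_sink_tokens` positions are always retained alongside a sliding window of the most recent `window_length` positions. Below is a conceptual sketch of that retention rule, not the llmc Sink cache class; the function name and shapes are assumptions.

```python
import torch

def sink_window_indices(seq_len: int, num_sink_tokens: int = 4,
                        window_length: int = 512) -> torch.Tensor:
    """Positions kept under an attention-sink policy: the first
    `num_sink_tokens` tokens plus the last `window_length` tokens."""
    if seq_len <= num_sink_tokens + window_length:
        return torch.arange(seq_len)               # nothing to evict yet
    sink = torch.arange(num_sink_tokens)           # always keep the initial "sink" tokens
    recent = torch.arange(seq_len - window_length, seq_len)
    return torch.cat([sink, recent])

# Example with the values from this config: window_length=512, num_sink_tokens=4.
keep = sink_window_indices(seq_len=2000, num_sink_tokens=4, window_length=512)
print(keep.shape)  # torch.Size([516]) -> 4 sink tokens + 512 recent tokens
# A cached key tensor of shape (batch, heads, seq_len, head_dim) would then be
# sliced along the sequence dimension: key_cache[:, :, keep, :]
```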
36 changes: 28 additions & 8 deletions llmc/compression/blockwise_optimization.py
@@ -45,7 +45,10 @@ def run_block_loop(self):

if hasattr(self, 'save_clip') and self.save_clip:
os.makedirs(self.clip_path, exist_ok=True)
torch.save(self.auto_clipper.weight_clips, os.path.join(self.clip_path, 'clips.pth'))
torch.save(
self.auto_clipper.weight_clips,
os.path.join(self.clip_path, 'clips.pth'),
)

def cache_input_hook(self, m, x, y, name, feat_dict):
inputs = [i.detach().cpu() for i in x]
@@ -63,14 +66,31 @@ def hook_fn(module, args, kwargs):
kwargs['past_key_value'] = kvcache
kwargs['use_cache'] = True
if kwargs['hidden_states'].shape[1] == 1:
if self.config['model']['type'] in ['DeepseekV2']:
kwargs['position_ids'] = kwargs['position_ids'][:, -1].unsqueeze(1)
if kwargs['position_ids'].shape[1] == 1:
# For eval decoding PPL (Perplexity), it will be removed in future versions.
past_seen_tokens = kvcache.get_seq_length()
cache_position = torch.arange(
past_seen_tokens,
past_seen_tokens + kwargs['hidden_states'].shape[1],
device=kwargs['hidden_states'].device,
)
kwargs['cache_position'] = cache_position
position_ids = cache_position.unsqueeze(0)
kwargs['position_ids'] = position_ids
if 'position_embeddings' in kwargs:
kwargs['position_embeddings'] = self.model.rotary_emb(
kwargs['hidden_states'], position_ids
)
else:
kwargs['position_ids'] = kwargs['position_ids'][:, -1].unsqueeze(0).unsqueeze(0)
if 'position_embeddings' in kwargs:
cos = kwargs['position_embeddings'][0][:, -1, :].unsqueeze(1)
sin = kwargs['position_embeddings'][1][:, -1, :].unsqueeze(1)
kwargs['position_embeddings'] = (cos, sin)
if self.config['model']['type'] in ['DeepseekV2']:
kwargs['position_ids'] = kwargs['position_ids'][:, -1].unsqueeze(1)
else:
kwargs['position_ids'] = \
kwargs['position_ids'][:, -1].unsqueeze(0).unsqueeze(0)
if 'position_embeddings' in kwargs:
cos = kwargs['position_embeddings'][0][:, -1, :].unsqueeze(1)
sin = kwargs['position_embeddings'][1][:, -1, :].unsqueeze(1)
kwargs['position_embeddings'] = (cos, sin)

return args, kwargs

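The patched `position_ids`/`cache_position` kwargs above let blocks run single-token (decode) forward passes against a growing KV cache, which is what the new `decode_ppl` eval measures. A hedged sketch of that evaluation idea for a generic Hugging Face causal LM follows; it is not the llmc eval code, and `model`, `input_ids`, and `num_eval_tokens` (mirroring the commented option in the configs) are placeholders.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def decode_ppl(model, input_ids: torch.Tensor, num_eval_tokens=None) -> float:
    """Perplexity measured in decode mode: feed one token per step, reuse the KV cache."""
    past_key_values = None
    nlls = []
    seq_len = input_ids.shape[1]
    if num_eval_tokens is not None:
        seq_len = min(num_eval_tokens, seq_len)
    for idx in range(seq_len - 1):
        out = model(input_ids[:, idx: idx + 1],        # a single token per forward pass
                    past_key_values=past_key_values,
                    use_cache=True)
        past_key_values = out.past_key_values          # KV cache grows by one position
        log_probs = F.log_softmax(out.logits[:, -1, :], dim=-1)
        target = input_ids[:, idx + 1]                 # next-token label
        nlls.append(-log_probs.gather(-1, target.unsqueeze(-1)))
    return torch.exp(torch.stack(nlls).mean()).item()
```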
30 changes: 16 additions & 14 deletions llmc/compression/quantization/base_blockwise_quantization.py
@@ -58,7 +58,7 @@ def a_qdq(self, act, module, aquantizer, input_index=0):

def get_replacement_params(self, mode='fake_quant', w_only=False, name=None):
params_dict = {}
if mode == 'fake_quant':
if mode in ['fake_quant', 'fake_quant_wo_kv']:
if not self.mix_bits:
params_dict['a_qdq'] = (
partial(self.a_qdq, aquantizer=self.aquantizer)
@@ -229,17 +229,16 @@ def set_quant_config(self):
# set kv cache quant config
if 'kvcache' in self.quant_config:
self.quant_config['kvcache']['static'] = self.act_static
kv_special_cfg = self.quant_config['kvcache'].get('special', {})
logger.info(kv_special_cfg)
act_static_cfg = {}
if self.act_static:
self.kv_module = KV_REGISTRY[self.quant_config['kvcache']['method']](
self.quant_type, self.quant_config['kvcache'],
self.model.model_config.num_hidden_layers, self.config.calib.n_samples,
self.config.calib.bs
)
else:
self.kv_module = KV_REGISTRY[self.quant_config['kvcache']['method']](
self.quant_type, self.quant_config['kvcache'],
self.model.model_config.num_hidden_layers
)
act_static_cfg.update(self.config.calib.n_sample)
act_static_cfg.update(self.config.calib.bs)
self.kv_module = KV_REGISTRY[self.quant_config['kvcache']['method']](
self.quant_type, self.quant_config['kvcache'],
self.model.model_config.num_hidden_layers, **kv_special_cfg, **act_static_cfg
)
self.quant_kvcache = True
self.model.kvcache_buffer.append(self.kv_module)
else:
@@ -860,6 +859,7 @@ def deploy(self, quant_format, keep_device=False):
module_mapping = {
'origin_float': OriginFloatLinear,
'fake_quant': EffcientFakeQuantLinear,
'fake_quant_wo_kv': EffcientFakeQuantLinear,
}
module_mapping.update(_REALQUANT_LINEAR_MAP_)

@@ -884,10 +884,12 @@ def deploy(self, quant_format, keep_device=False):
self.set_non_linear_mode(quant_format, self.model.model, False)

if self.quant_kvcache:
if quant_format == 'transformed':
self.kv_module.transformed = True
if quant_format == 'origin_float':
self.kv_module.use_org_kv = True
elif quant_format == 'fake_quant_wo_kv':
self.kv_module.use_org_kv = True
elif quant_format == 'fake_quant':
self.kv_module.transformed = False
self.kv_module.use_org_kv = False
if self.act_static:
self.kv_module.calib = False

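To make the `set_quant_config` change concrete: the YAML `special` block (e.g. `window_length`/`num_sink_tokens` in the Sink config) is now forwarded to the KV-cache class constructor as keyword arguments, which removes the separate static/dynamic constructor branches. A toy sketch of that pattern follows; the registry and class here are stand-ins, not the real `KV_REGISTRY` entries in llmc.

```python
KV_REGISTRY = {}  # stand-in for llmc's real registry

def register_kv(name):
    def wrap(cls):
        KV_REGISTRY[name] = cls
        return cls
    return wrap

@register_kv('Sink')
class ToySinkCache:
    def __init__(self, quant_type, kv_quant_cfg, num_hidden_layers,
                 window_length=512, num_sink_tokens=4, **kwargs):
        # Keys from the YAML `special` section arrive here as keyword arguments.
        self.window_length = window_length
        self.num_sink_tokens = num_sink_tokens

# What the new code effectively does:
kvcache_cfg = {'method': 'Sink', 'bit': 4,
               'special': {'window_length': 512, 'num_sink_tokens': 4}}
kv_special_cfg = kvcache_cfg.get('special', {})
kv_module = KV_REGISTRY[kvcache_cfg['method']]('fake_quant', kvcache_cfg, 32, **kv_special_cfg)
print(kv_module.window_length, kv_module.num_sink_tokens)  # 512 4
```

Optional settings simply fall back to their defaults when a config has no `special` block, so the same construction call works for Kivi and Sink alike.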