Commit 08854d6

gushiqiao and a co-author authored

Dev kv (#283)

* Fix omniq clip bugs
* Fix omniq clip bugs
* Fix awq trans_v1 gqa bug
* Add streamllm(Sink) support and add eval decode PPL support

Co-authored-by: gushiqiao <[email protected]>

1 parent 18bdbdd · commit 08854d6

13 files changed (+412 -91 lines changed)
Lines changed: 35 additions & 0 deletions

@@ -0,0 +1,35 @@
+base:
+    seed: &seed 42
+model:
+    type: model_type
+    path: model path
+    torch_dtype: auto
+eval:
+    eval_pos: [transformed, fake_quant, fake_quant_wo_kv] #long_ppl eval not support pretrain eval pos
+    name: wikitext2
+    type: decode_ppl
+    download: False
+    path: eval_data_path
+    bs: 1
+    inference_per_block: False
+    num_samples: 10
+    # num_eval_tokens: 3
+quant:
+    method: RTN
+    weight:
+        bit: 8
+        symmetric: True
+        granularity: per_channel
+        group_size: -1
+    act:
+        bit: 8
+        symmetric: True
+        granularity: per_token
+    kvcache:
+        method: Kivi
+        bit: 8
+        symmetric: True
+        granularity: per_token
+save:
+    save_fake: False
+    save_path: /path/to/save/
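
The decode_ppl eval selected above scores perplexity during the decode phase: the prompt is prefilled once, and every following token is evaluated with single-token forward passes that reuse the KV cache, so the quantized cache directly influences the number. The snippet below is a minimal sketch of that idea with Hugging Face transformers, not llmc's evaluator; model_path, the sample text, and the prefill/num_eval_tokens values are assumptions for illustration.

# Minimal decode-PPL sketch (not llmc's evaluator): prefill once, then score
# each following token with single-token steps that reuse the KV cache, and
# report exp(mean NLL) over those decode steps.
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = 'model path'  # assumption: any causal-LM checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype='auto').eval()

text = 'A long evaluation passage, e.g. a wikitext2 article, goes here ...'
ids = tokenizer(text, return_tensors='pt').input_ids
prefill_len = max(1, ids.shape[1] // 2)   # assumption: first half as prefill
num_eval_tokens = 64                      # cap, analogous to num_eval_tokens above

nlls = []
with torch.no_grad():
    out = model(ids[:, :prefill_len], use_cache=True)
    past = out.past_key_values
    for i in range(prefill_len, min(ids.shape[1] - 1, prefill_len + num_eval_tokens)):
        out = model(ids[:, i:i + 1], past_key_values=past, use_cache=True)
        past = out.past_key_values
        # NLL of the ground-truth next token under this single decode step
        nlls.append(F.cross_entropy(out.logits[:, -1, :], ids[:, i + 1]))

print('decode PPL:', torch.exp(torch.stack(nlls).mean()).item())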

configs/quantization/methods/RTN/rtn_w_a_kv_human_eval.yml renamed to configs/quantization/methods/KVQuant/rtn_w_a_naive_quant_kv.yml

Lines changed: 7 additions & 8 deletions

@@ -5,16 +5,15 @@ model:
     path: model path
     torch_dtype: auto
 eval:
-    eval_pos: [pretrain, fake_quant]
-    type: code
-    name: human_eval
-    res_path: ./human_eval/
-    # For 7B / 13B model eval, bs can be set to "1", and inference_per_block can be set to "False".
-    # For 70B model eval, bs can be set to "20", and inference_per_block can be set to "True".
+    eval_pos: [transformed, fake_quant, fake_quant_wo_kv] #long_ppl eval not support pretrain eval pos
+    name: wikitext2
+    type: decode_ppl
+    download: False
+    path: eval_data_path
     bs: 1
-    format_tabs: True
     inference_per_block: False
-    # add_chat_temp: True
+    num_samples: 10
+    # num_eval_tokens: 3
 quant:
     method: RTN
     weight:
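
As the new file name suggests, this config pairs RTN weight/activation quantization with a naive KV-cache scheme: the cached keys and values are fake-quantized per token on a symmetric grid. A rough per-token symmetric fake-quantization sketch, independent of llmc's actual KV-cache classes, is:

# Sketch of per-token symmetric b-bit fake quantization, as a naive KV-cache
# quantizer might apply it to a key/value tensor of shape
# (batch, heads, seq_len, head_dim). Illustrative only, not llmc's code.
import torch

def fake_quant_per_token(x: torch.Tensor, bit: int = 8) -> torch.Tensor:
    qmax = 2 ** (bit - 1) - 1                               # 127 for 8-bit
    # one scale per token: reduce over the head_dim axis only
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / qmax
    q = torch.clamp(torch.round(x / scale), -qmax - 1, qmax)
    return q * scale                                        # dequantized values

k = torch.randn(1, 8, 16, 64)                               # toy K tensor
print((k - fake_quant_per_token(k)).abs().max())            # small error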

configs/quantization/methods/RTN/rtn_w_a_pertensor_static_kv.yml renamed to configs/quantization/methods/KVQuant/rtn_w_a_pertensor_static_naive_quant_kv.yml

Lines changed: 5 additions & 5 deletions

@@ -14,15 +14,15 @@ calib:
     preproc: txt_general_preproc
     seed: *seed
 eval:
-    eval_pos: [pretrain, fake_quant]
+    eval_pos: [transformed, fake_quant, fake_quant_wo_kv] #long_ppl eval not support pretrain eval pos
     name: wikitext2
+    type: decode_ppl
     download: False
-    path: eval data path
-    seq_len: 2048
-    # For 7B / 13B model eval, bs can be set to "1", and inference_per_block can be set to "False".
-    # For 70B model eval, bs can be set to "20", and inference_per_block can be set to "True".
+    path: eval_data_path
     bs: 1
     inference_per_block: False
+    num_samples: 10
+    # num_eval_tokens: 3
 quant:
     method: RTN
     weight:
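
This variant keeps the static per-tensor activation setup from the original file name: a single activation scale is calibrated offline on sample data and reused for every input, unlike the dynamic per-token scheme. A toy calibrate-then-reuse sketch (illustrative only, not llmc's calibration pipeline):

# Toy static per-tensor activation quantization: calibrate one scale from
# sample activations, then reuse it for all later inputs. Not llmc's code.
import torch

def calibrate_scale(calib_batches, bit=8):
    qmax = 2 ** (bit - 1) - 1
    absmax = max(x.abs().max().item() for x in calib_batches)
    return absmax / qmax                    # a single scale for the whole tensor

def fake_quant_static(x, scale, bit=8):
    qmax = 2 ** (bit - 1) - 1
    return torch.clamp(torch.round(x / scale), -qmax - 1, qmax) * scale

calib = [torch.randn(4, 128) for _ in range(8)]   # stand-in calibration batches
scale = calibrate_scale(calib)
x = torch.randn(4, 128)
print((x - fake_quant_static(x, scale)).abs().max())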
Lines changed: 38 additions & 0 deletions

@@ -0,0 +1,38 @@
+base:
+    seed: &seed 42
+model:
+    type: model_type
+    path: model path
+    torch_dtype: auto
+eval:
+    eval_pos: [transformed, fake_quant, fake_quant_wo_kv] #long_ppl eval not support pretrain eval pos
+    name: wikitext2
+    type: decode_ppl
+    download: False
+    path: eval_data_path
+    bs: 1
+    inference_per_block: False
+    num_samples: 10
+    # num_eval_tokens: 3
+quant:
+    method: RTN
+    weight:
+        bit: 8
+        symmetric: True
+        granularity: per_channel
+        group_size: -1
+    act:
+        bit: 8
+        symmetric: True
+        granularity: per_token
+    kvcache:
+        method: Sink
+        bit: 4
+        symmetric: True
+        granularity: per_token
+        special:
+            window_length: 512
+            num_sink_tokens: 4
+save:
+    save_fake: False
+    save_path: /path/to/save/
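
The Sink method is the StreamingLLM-style support named in the commit message: a handful of initial attention-sink tokens is always retained together with a sliding window of the most recent tokens. The toy sketch below only shows which cache positions such a policy keeps, under assumed semantics for window_length and num_sink_tokens; how llmc's Sink cache treats the remaining entries (eviction vs. low-bit quantization) is not modelled here.

# Toy sink-plus-sliding-window retention policy over KV-cache positions.
# Assumed semantics: keep the first num_sink_tokens positions plus the last
# window_length positions. Illustrative only, not llmc's Sink implementation.
import torch

def kept_positions(seq_len: int, window_length: int = 512, num_sink_tokens: int = 4):
    if seq_len <= window_length + num_sink_tokens:
        return torch.arange(seq_len)
    sinks = torch.arange(num_sink_tokens)
    recent = torch.arange(seq_len - window_length, seq_len)
    return torch.cat([sinks, recent])

print(kept_positions(20, window_length=8, num_sink_tokens=4))
# tensor([ 0,  1,  2,  3, 12, 13, 14, 15, 16, 17, 18, 19])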

llmc/compression/blockwise_optimization.py

Lines changed: 28 additions & 8 deletions

@@ -45,7 +45,10 @@ def run_block_loop(self):
 
         if hasattr(self, 'save_clip') and self.save_clip:
             os.makedirs(self.clip_path, exist_ok=True)
-            torch.save(self.auto_clipper.weight_clips, os.path.join(self.clip_path, 'clips.pth'))
+            torch.save(
+                self.auto_clipper.weight_clips,
+                os.path.join(self.clip_path, 'clips.pth'),
+            )
 
     def cache_input_hook(self, m, x, y, name, feat_dict):
         inputs = [i.detach().cpu() for i in x]

@@ -63,14 +66,31 @@ def hook_fn(module, args, kwargs):
             kwargs['past_key_value'] = kvcache
             kwargs['use_cache'] = True
             if kwargs['hidden_states'].shape[1] == 1:
-                if self.config['model']['type'] in ['DeepseekV2']:
-                    kwargs['position_ids'] = kwargs['position_ids'][:, -1].unsqueeze(1)
+                if kwargs['position_ids'].shape[1] == 1:
+                    # For eval decoding PPL (Perplexity), it will be removed in future versions.
+                    past_seen_tokens = kvcache.get_seq_length()
+                    cache_position = torch.arange(
+                        past_seen_tokens,
+                        past_seen_tokens + kwargs['hidden_states'].shape[1],
+                        device=kwargs['hidden_states'].device,
+                    )
+                    kwargs['cache_position'] = cache_position
+                    position_ids = cache_position.unsqueeze(0)
+                    kwargs['position_ids'] = position_ids
+                    if 'position_embeddings' in kwargs:
+                        kwargs['position_embeddings'] = self.model.rotary_emb(
+                            kwargs['hidden_states'], position_ids
+                        )
                 else:
-                    kwargs['position_ids'] = kwargs['position_ids'][:, -1].unsqueeze(0).unsqueeze(0)
-                if 'position_embeddings' in kwargs:
-                    cos = kwargs['position_embeddings'][0][:, -1, :].unsqueeze(1)
-                    sin = kwargs['position_embeddings'][1][:, -1, :].unsqueeze(1)
-                    kwargs['position_embeddings'] = (cos, sin)
+                    if self.config['model']['type'] in ['DeepseekV2']:
+                        kwargs['position_ids'] = kwargs['position_ids'][:, -1].unsqueeze(1)
+                    else:
+                        kwargs['position_ids'] = \
+                            kwargs['position_ids'][:, -1].unsqueeze(0).unsqueeze(0)
+                    if 'position_embeddings' in kwargs:
+                        cos = kwargs['position_embeddings'][0][:, -1, :].unsqueeze(1)
+                        sin = kwargs['position_embeddings'][1][:, -1, :].unsqueeze(1)
+                        kwargs['position_embeddings'] = (cos, sin)
 
             return args, kwargs
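
The new branch in hook_fn covers the single-token decode steps used by the decode-PPL eval: when only one new token arrives, cache_position and position_ids are derived from how many tokens the KV cache has already seen. The standalone snippet below illustrates that bookkeeping with transformers' DynamicCache; the tensor sizes are toy values.

# After N cached tokens, the next single-token step should see
# cache_position == [N] and position_ids == [[N]]. Toy sizes, for illustration.
import torch
from transformers import DynamicCache

cache = DynamicCache()
k = v = torch.randn(1, 2, 5, 4)            # pretend 5 tokens are already cached
cache.update(k, v, layer_idx=0)

hidden_states = torch.randn(1, 1, 8)       # one new token enters the block
past_seen_tokens = cache.get_seq_length()
cache_position = torch.arange(
    past_seen_tokens,
    past_seen_tokens + hidden_states.shape[1],
    device=hidden_states.device,
)
position_ids = cache_position.unsqueeze(0)
print(cache_position, position_ids)        # tensor([5]) tensor([[5]])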
llmc/compression/quantization/base_blockwise_quantization.py

Lines changed: 16 additions & 14 deletions

@@ -58,7 +58,7 @@ def a_qdq(self, act, module, aquantizer, input_index=0):
 
     def get_replacement_params(self, mode='fake_quant', w_only=False, name=None):
         params_dict = {}
-        if mode == 'fake_quant':
+        if mode in ['fake_quant', 'fake_quant_wo_kv']:
             if not self.mix_bits:
                 params_dict['a_qdq'] = (
                     partial(self.a_qdq, aquantizer=self.aquantizer)

@@ -229,17 +229,16 @@ def set_quant_config(self):
         # set kv cache quant config
         if 'kvcache' in self.quant_config:
             self.quant_config['kvcache']['static'] = self.act_static
+            kv_special_cfg = self.quant_config['kvcache'].get('special', {})
+            logger.info(kv_special_cfg)
+            act_static_cfg = {}
             if self.act_static:
-                self.kv_module = KV_REGISTRY[self.quant_config['kvcache']['method']](
-                    self.quant_type, self.quant_config['kvcache'],
-                    self.model.model_config.num_hidden_layers, self.config.calib.n_samples,
-                    self.config.calib.bs
-                )
-            else:
-                self.kv_module = KV_REGISTRY[self.quant_config['kvcache']['method']](
-                    self.quant_type, self.quant_config['kvcache'],
-                    self.model.model_config.num_hidden_layers
-                )
+                act_static_cfg.update(self.config.calib.n_sample)
+                act_static_cfg.update(self.config.calib.bs)
+            self.kv_module = KV_REGISTRY[self.quant_config['kvcache']['method']](
+                self.quant_type, self.quant_config['kvcache'],
+                self.model.model_config.num_hidden_layers, **kv_special_cfg, **act_static_cfg
+            )
             self.quant_kvcache = True
             self.model.kvcache_buffer.append(self.kv_module)
         else:

@@ -860,6 +859,7 @@ def deploy(self, quant_format, keep_device=False):
         module_mapping = {
             'origin_float': OriginFloatLinear,
             'fake_quant': EffcientFakeQuantLinear,
+            'fake_quant_wo_kv': EffcientFakeQuantLinear,
         }
         module_mapping.update(_REALQUANT_LINEAR_MAP_)

@@ -884,10 +884,12 @@ def deploy(self, quant_format, keep_device=False):
         self.set_non_linear_mode(quant_format, self.model.model, False)
 
         if self.quant_kvcache:
-            if quant_format == 'transformed':
-                self.kv_module.transformed = True
+            if quant_format == 'origin_float':
+                self.kv_module.use_org_kv = True
+            elif quant_format == 'fake_quant_wo_kv':
+                self.kv_module.use_org_kv = True
             elif quant_format == 'fake_quant':
-                self.kv_module.transformed = False
+                self.kv_module.use_org_kv = False
             if self.act_static:
                 self.kv_module.calib = False
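
The set_quant_config change above collapses the two KV_REGISTRY call sites into one by forwarding the config's special block (for example window_length and num_sink_tokens for Sink) and any static-calibration options as keyword arguments. The registry-plus-kwargs pattern it relies on looks roughly like this sketch; the class names and keys are illustrative, not llmc's real ones.

# Rough illustration of registry-plus-kwargs dispatch: each KV-cache class
# accepts its own optional keywords, so a single call site can forward the
# config's special block as **kwargs. Names are illustrative, not llmc's.
KV_REGISTRY = {}

def register_kv(name):
    def wrap(cls):
        KV_REGISTRY[name] = cls
        return cls
    return wrap

@register_kv('Naive')
class NaiveKV:
    def __init__(self, quant_type, cfg, num_layers, **kwargs):
        self.quant_type, self.cfg, self.num_layers = quant_type, cfg, num_layers

@register_kv('Sink')
class SinkKV(NaiveKV):
    def __init__(self, quant_type, cfg, num_layers,
                 window_length=512, num_sink_tokens=4, **kwargs):
        super().__init__(quant_type, cfg, num_layers, **kwargs)
        self.window_length, self.num_sink_tokens = window_length, num_sink_tokens

special = {'window_length': 512, 'num_sink_tokens': 4}   # from the config's special block
kv = KV_REGISTRY['Sink']('fake_quant', {'bit': 4}, 32, **special)
print(type(kv).__name__, kv.window_length, kv.num_sink_tokens)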