Support ShadowKV and fix bugs (#291)

Co-authored-by: gushiqiao <[email protected]>
gushiqiao authored Jan 13, 2025
1 parent 7ad5a05 commit 271de80
Showing 26 changed files with 1,479 additions and 701 deletions.
46 changes: 46 additions & 0 deletions config.yml
@@ -0,0 +1,46 @@
base:
    seed: &seed 42
model:
    type: Qwen2
    path: /home/gushiqiao/nvme/gushiqiao/bussinesss/code_72b/SenseChat-Code-Tmp
    tokenizer_mode: fast
    torch_dtype: auto
calib:
    name: pileval
    download: False
    path: /home/gushiqiao/nvme/gushiqiao/llm_datasets/calib/pileval
    n_samples: 256
    bs: -1
    seq_len: 512
    preproc: txt_general_preproc
    seed: *seed
# eval:
#     - eval_pos: [fake_quant]
#       name: wikitext2
#       download: False
#       path: /home/gushiqiao/nvme/gushiqiao/llm_datasets/eval/wikitext2
#       seq_len: 2048
#       # For 7B / 13B model eval, bs can be set to "1", and inference_per_block can be set to "False".
#       # For 70B model eval, bs can be set to "20", and inference_per_block can be set to "True".
#       bs: 10
#       inference_per_block: True
quant:
    method: Awq
    weight:
        bit: 8
        symmetric: True
        granularity: per_channel
        group_size: -1
    act:
        bit: 8
        symmetric: True
        granularity: per_token
    special:
        trans: True
        trans_version: v2
        weight_clip: False
        awq_bs: 128
    quant_out: True
save:
    save_trans: True
    save_path: ./awq_test_new_pileval_down_ov/
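
A reader's note on the W8A8 settings above: per_channel symmetric weight quantization keeps one scale per output channel, and per_token activation quantization keeps one scale per token vector. A minimal PyTorch sketch of that fake-quant round trip, with an illustrative helper that is not part of llmc's API:

import torch

def fake_quant_symmetric(x: torch.Tensor, bit: int, dim: int) -> torch.Tensor:
    # Symmetric fake quantization: scale, round to the int grid, dequantize.
    qmax = 2 ** (bit - 1) - 1  # 127 for 8-bit
    scale = x.abs().amax(dim=dim, keepdim=True).clamp(min=1e-8) / qmax
    return (x / scale).round().clamp(-qmax, qmax) * scale

weight = torch.randn(4096, 4096)
acts = torch.randn(1, 16, 4096)
w_fq = fake_quant_symmetric(weight, bit=8, dim=1)  # per_channel: one scale per output row
a_fq = fake_quant_symmetric(acts, bit=8, dim=-1)   # per_token: one scale per token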
@@ -5,14 +5,14 @@ model:
     path: model path
     torch_dtype: auto
 eval:
-    eval_pos: [transformed, fake_quant, fake_quant_wo_kv] #long_ppl eval not support pretrain eval pos
+    eval_pos: [transformed, fake_quant, fake_quant_wo_kv] #decode_ppl eval not support pretrain eval pos
     name: wikitext2
     type: decode_ppl
     download: False
     path: eval_data_path
     bs: 1
     inference_per_block: False
-    num_samples: 10
+    num_samples: 50
     # num_eval_tokens: 3
 quant:
     method: RTN
@@ -5,14 +5,14 @@ model:
     path: model path
     torch_dtype: auto
 eval:
-    eval_pos: [transformed, fake_quant, fake_quant_wo_kv] #long_ppl eval not support pretrain eval pos
+    eval_pos: [transformed, fake_quant, fake_quant_wo_kv] #decode_ppl eval not support pretrain eval pos
     name: wikitext2
     type: decode_ppl
     download: False
     path: eval_data_path
     bs: 1
     inference_per_block: False
-    num_samples: 10
+    num_samples: 50
     # num_eval_tokens: 3
 quant:
     method: RTN
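
Both configs above switch the comment from long_ppl to decode_ppl and raise num_samples from 10 to 50. As a rough sketch of what a decode-PPL evaluation measures (decode token by token through the KV cache and accumulate each next-token negative log-likelihood); this illustrates the metric, not llmc's evaluator:

import torch

@torch.no_grad()
def decode_ppl(model, input_ids):
    # Feed one token per step so every forward pass exercises the KV cache,
    # then report exp(mean NLL) over the ground-truth continuations.
    past, nlls = None, []
    for t in range(input_ids.shape[1] - 1):
        out = model(input_ids[:, t:t + 1], past_key_values=past, use_cache=True)
        past = out.past_key_values
        logp = out.logits[:, -1].log_softmax(-1)
        nlls.append(-logp[0, input_ids[0, t + 1]])
    return torch.stack(nlls).mean().exp().item()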
38 changes: 0 additions & 38 deletions configs/quantization/methods/KVQuant/rtn_w_a_sink_quant_kv.yml

This file was deleted.

@@ -35,6 +35,7 @@ quant:
         symmetric: True
         granularity: per_tensor
         static: True
+        calib_algo: static_hist
 save:
     save_fake: False
     save_path: /path/to/save/
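
The new calib_algo: static_hist presumably derives the static per-tensor range from a histogram of calibration activations instead of a raw min/max. A generic sketch of histogram-based clipping (keep the smallest range covering, say, 99.9% of observed magnitudes); llmc's actual algorithm may differ:

import torch

def hist_clip_absmax(samples: torch.Tensor, coverage: float = 0.999, bins: int = 2048) -> float:
    # Histogram |x|, then pick the smallest clip threshold whose bins cover
    # `coverage` of the calibration mass; rarer outliers get clipped away.
    mags = samples.abs().flatten().float()
    hist = torch.histc(mags, bins=bins, min=0.0, max=mags.max().item())
    cdf = hist.cumsum(0) / hist.sum()
    idx = int(torch.searchsorted(cdf, torch.tensor(coverage)).item())
    return (idx + 1) * mags.max().item() / bins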
22 changes: 22 additions & 0 deletions configs/sparsification/methods/Kvsparse/shadowkv.yml
@@ -0,0 +1,22 @@
base:
    seed: &seed 42
model:
    type: model_type
    path: model path
    torch_dtype: torch.bfloat16
eval:
    eval_pos: [transformed]
    name: wikitext2
    download: False
    path: eval_data_path
    bs: 1
    seq_len: 2048
sparse:
    method: Dense
    kvcache:
        method: ShadowKV
        replace_attn: True
    sparsity_out: False
save:
    save_trans: False
    save_path: ./save
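
For orientation: ShadowKV, per its paper, keeps a low-rank copy of pre-RoPE keys on GPU, offloads full values to CPU, and at each decode step fetches only the KV positions the current query is likely to attend to. A heavily simplified sketch of the chunk-selection idea behind get_retrieval_position_ids (used in blockwise_optimization.py further down); shapes and names are illustrative:

import torch

def select_chunks(query, keys, chunk_size=8, topk=256):
    # query: [heads, 1, dim]; keys: [heads, seq, dim], seq divisible by chunk_size.
    heads, seq, dim = keys.shape
    # Score each chunk by its mean-key "landmark" against the current query
    # and return the flattened positions of the top-k chunks.
    landmarks = keys.view(heads, seq // chunk_size, chunk_size, dim).mean(2)
    scores = torch.einsum('hqd,hcd->hqc', query, landmarks).squeeze(1)
    top_chunks = scores.topk(min(topk, seq // chunk_size), dim=-1).indices
    offsets = torch.arange(chunk_size, device=keys.device)
    return (top_chunks.unsqueeze(-1) * chunk_size + offsets).flatten(1)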
25 changes: 25 additions & 0 deletions configs/sparsification/methods/Kvsparse/sinkkv.yml
@@ -0,0 +1,25 @@
base:
    seed: &seed 42
model:
    type: model_type
    path: model path
    torch_dtype: torch.bfloat16
eval:
    eval_pos: [transformed]
    name: wikitext2
    type: decode_ppl
    download: False
    path: eval_data_path
    bs: 1
    inference_per_block: False
    num_samples: 50
    # num_eval_tokens: 3
sparse:
    method: Dense
    kvcache:
        method: SinkKV
        window_length: 256
        num_sink_tokens: 4
save:
    save_fake: False
    save_path: /path/to/save/
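
SinkKV here matches the attention-sink recipe from StreamingLLM: retain the first num_sink_tokens cache entries plus a sliding window of the window_length most recent ones, evicting everything in between. A minimal sketch of that eviction on one layer's cache; the tensor layout is illustrative:

import torch

def evict_sink_window(k, v, num_sink_tokens=4, window_length=256):
    # k, v: [batch, heads, seq, head_dim]. Keep sinks + the recent window.
    seq = k.shape[2]
    if seq <= num_sink_tokens + window_length:
        return k, v
    keep = torch.cat([
        torch.arange(num_sink_tokens),
        torch.arange(seq - window_length, seq),
    ])
    return k[:, :, keep], v[:, :, keep]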
3 changes: 1 addition & 2 deletions configs/sparsification/methods/Magnitude/magnitude.yml
@@ -25,6 +25,5 @@ sparse:
     weight:
         sparsity: 0.5
 save:
-    save_fp: False
-    save_lightllm: False
+    save_trans: False
     save_path: ./save
3 changes: 1 addition & 2 deletions configs/sparsification/methods/ShortGPT/shortgpt.yml
@@ -24,7 +24,6 @@ sparse:
     weight:
         n_prune_layers: 9
 save:
-    save_trans: True
-    save_fp: False
+    save_trans: False
     save_lightllm: False
     save_path: ./save
3 changes: 1 addition & 2 deletions configs/sparsification/methods/Wanda/wanda.yml
@@ -26,6 +26,5 @@ sparse:
         sparsity: 0.5
         sparsity_out: False
 save:
-    save_fp: False
-    save_lightllm: False
+    save_trans: False
     save_path: ./save
24 changes: 17 additions & 7 deletions llmc/__main__.py
@@ -45,13 +45,22 @@ def main(config):
     for modality in get_modality(config):
         model.set_modality(modality)
         if not config.get('calib', False):
-            blockwise_opt = ALGO_REGISTRY[config.quant.method](
-                model,
-                quant_config=config.quant,
-                input=None,
-                padding_mask=None,
-                config=config,
-            )
+            if not config.get('sparse', False):
+                blockwise_opt = ALGO_REGISTRY[config.quant.method](
+                    model,
+                    config.quant,
+                    None,
+                    None,
+                    config,
+                )
+            else:
+                blockwise_opt = ALGO_REGISTRY[config.sparse.method](
+                    model,
+                    config.sparse,
+                    None,
+                    None,
+                    config,
+                )
             blockwise_opt.run_block_loop()
             dist.barrier()
         else:
@@ -98,6 +107,7 @@ def main(config):
         )
 
         eval_model(model, blockwise_opt, eval_list, eval_pos='fake_quant')
+        eval_model(model, blockwise_opt, eval_list, eval_pos='fake_quant_wo_kv')
 
         if 'save' in config and config.save.get('save_fake', False):
             blockwise_opt.deploy('fake_quant')
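
The reworked dispatch keys off which top-level section the YAML defines: a quant section routes to ALGO_REGISTRY[config.quant.method], a sparse section to ALGO_REGISTRY[config.sparse.method], and both algorithm families now share one positional constructor signature. A toy sketch of that single-registry pattern (class names illustrative, not llmc's real registry):

ALGO_REGISTRY = {}

def register(name):
    # Decorator that maps a method name from the YAML onto its class.
    def deco(cls):
        ALGO_REGISTRY[name] = cls
        return cls
    return deco

@register('Awq')
class Awq:
    pass

@register('Dense')
class Dense:
    pass

config = {'sparse': {'method': 'Dense'}}
section = 'sparse' if config.get('sparse', False) else 'quant'
algo_cls = ALGO_REGISTRY[config[section]['method']]  # -> Dense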
64 changes: 34 additions & 30 deletions llmc/compression/blockwise_optimization.py
@@ -6,11 +6,11 @@
 
 
 class BlockwiseOpt(metaclass=ABCMeta):
-    def __init__(self, model, quant_config, input, padding_mask, config):
+    def __init__(self, model, compress_config, input, padding_mask, config):
         self.model = model
         self.blocks = model.get_blocks()
-        self.quant_config = quant_config
-        self.sparsity_config = quant_config
+        self.quant_config = compress_config
+        self.sparsity_config = compress_config
         self.input = input
         self.padding_mask = padding_mask
         self.data_free = False if self.input else True
@@ -60,37 +60,41 @@ def cache_input_hook(self, m, x, y, name, feat_dict):
         else:
             feat_dict[name].append(tuple(inputs))
 
-    def kv_cache_input_hook(self):
+    def kv_cache_input_hook(self, attn_layer):
         def hook_fn(module, args, kwargs):
             kvcache = getattr(module, 'kvcache')
             kwargs['past_key_value'] = kvcache
             kwargs['use_cache'] = True
-            if kwargs['hidden_states'].shape[1] == 1:
-                # For eval decoding PPL (Perplexity), it will be removed in future versions.
-                past_seen_tokens = kvcache.get_seq_length()
-                cache_position = torch.arange(
-                    past_seen_tokens,
-                    past_seen_tokens + kwargs['hidden_states'].shape[1],
-                    device=kwargs['hidden_states'].device,
-                )
-                kwargs['cache_position'] = cache_position
-                position_ids = cache_position.unsqueeze(0)
-                kwargs['position_ids'] = position_ids
-                if 'position_embeddings' in kwargs:
-                    kwargs['position_embeddings'] = self.model.rotary_emb(
-                        kwargs['hidden_states'], position_ids
-                    )
+            if kwargs['position_ids'].shape[1] == 1:
+                if self.config.eval.get('type', None) == 'decode_ppl':
+                    # For eval decoding PPL (Perplexity).
+                    past_seen_tokens = kvcache.get_seq_length()
+                    cache_position = torch.arange(
+                        past_seen_tokens,
+                        past_seen_tokens + kwargs['hidden_states'].shape[1],
+                        device=kwargs['hidden_states'].device,
+                    )
+                    kwargs['cache_position'] = cache_position
+                    position_ids = cache_position.unsqueeze(0)
+                    kwargs['position_ids'] = position_ids
+                    if 'position_embeddings' in kwargs:
+                        kwargs['position_embeddings'] = self.model.rotary_emb(
+                            kwargs['hidden_states'], position_ids
+                        )
+                else:
+                    if self.config['model']['type'] in ['DeepseekV2']:
+                        kwargs['position_ids'] = kwargs['position_ids'][:, -1].unsqueeze(1)
+                    else:
+                        kwargs['position_ids'] = \
+                            kwargs['position_ids'][:, -1].unsqueeze(0).unsqueeze(0)
+                    if 'position_embeddings' in kwargs:
+                        cos = kwargs['position_embeddings'][0][:, -1, :].unsqueeze(1)
+                        sin = kwargs['position_embeddings'][1][:, -1, :].unsqueeze(1)
+                        kwargs['position_embeddings'] = (cos, sin)
+            if kwargs['hidden_states'].shape[1] == 1:
+                from .sparsification.kvsparse import ShadowKVCache
+                if isinstance(kvcache, ShadowKVCache):
+                    hidden_states = kwargs['hidden_states'][:, -1, :].unsqueeze(0)
+                    kwargs['hidden_states'] = hidden_states
+                    bsz, q_len, _ = hidden_states.size()
+                    tmp_query_states = \
+                        attn_layer.q_proj(hidden_states).view(bsz,
+                                                              q_len,
+                                                              -1,
+                                                              attn_layer.head_dim).transpose(1, 2)
+                    retrieval_position_ids = \
+                        kvcache.get_retrieval_position_ids(layer_idx=attn_layer.layer_idx,
+                                                           query_states=tmp_query_states)
+                    kwargs['retrieval_position_ids'] = retrieval_position_ids
+                    kwargs['cos_sin_cache'] = self.cos_sin_cache
 
             return args, kwargs
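
kv_cache_input_hook returns a forward-pre-hook that mutates the attention layer's keyword arguments: it injects the replacement KV cache, repairs position ids during decode, and (new in this commit) computes ShadowKV retrieval ids from a throwaway query projection. Presumably it is attached via PyTorch's kwargs-aware pre-hook API; a minimal sketch of that mechanism:

import torch
import torch.nn as nn

class Attn(nn.Module):
    # Stand-in for a decoder attention layer that accepts cache kwargs.
    def forward(self, hidden_states, past_key_value=None, use_cache=False):
        return hidden_states

def make_pre_hook(kvcache):
    def hook_fn(module, args, kwargs):
        # Rewrite kwargs before the wrapped forward runs, as hook_fn above does.
        kwargs['past_key_value'] = kvcache
        kwargs['use_cache'] = True
        return args, kwargs
    return hook_fn

layer = Attn()
layer.register_forward_pre_hook(make_pre_hook(object()), with_kwargs=True)
out = layer(torch.randn(1, 4, 8))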
2 changes: 1 addition & 1 deletion llmc/compression/quantization/__init__.py
@@ -4,7 +4,7 @@
 from .dgq import DGQ
 from .gptq import GPTQ
 from .hqq import HQQ
-from .kvquant import NaiveQuantKVCache
+from .kvquant import KiviQuantKVCache, NaiveQuantKVCache
 from .llmint8 import LlmInt8
 from .module_utils import FakeQuantLinear
 from .ntweak import NormTweaking
(Remaining changed files not shown.)
