Commit
support modality quant (#293)
* support modality quant

* fix bug

* Update __main__.py

---------

Co-authored-by: Yang Yong <[email protected]>
chengtao-lv and helloyongyang authored Jan 13, 2025
1 parent 271de80 commit da4c436
Showing 6 changed files with 139 additions and 120 deletions.
45 changes: 30 additions & 15 deletions configs/quantization/methods/Awq/awq_w_only_vlm.yml
@@ -25,21 +25,36 @@ eval:
bs: 1
inference_per_block: False
quant:
method: Awq
quant_objects: [vision, language] # default is [language]
weight:
bit: 4
symmetric: False
granularity: per_group
group_size: 128
special:
trans: True
# The options for "trans_version" include "v1" and "v2".
# But their results don't differ significantly.
trans_version: v2
weight_clip: True
# For 2-bit quantization, setting "clip_sym: False" will yield better results.
clip_sym: True
vision:
method: Awq
weight:
bit: 4
symmetric: False
granularity: per_group
group_size: 128
special:
trans: True
# The options for "trans_version" include "v1" and "v2".
# But their results don't differ significantly.
trans_version: v2
weight_clip: True
# For 2-bit quantization, setting "clip_sym: False" will yield better results.
clip_sym: True
language:
method: Awq
weight:
bit: 4
symmetric: False
granularity: per_group
group_size: 128
special:
trans: True
# The options for "trans_version" include "v1" and "v2".
# But their results don't differ significantly.
trans_version: v2
weight_clip: True
# For 2-bit quantization, setting "clip_sym: False" will yield better results.
clip_sym: True
save:
save_trans: False
save_fake: False
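The quant_objects list is replaced by complete per-modality sub-configs, so vision and language can now carry independent methods, bit-widths, and special options. A minimal standalone sketch of reading the new layout — plain PyYAML for illustration, not necessarily the loader llmc itself uses:

```python
# Illustration only: parse the per-modality layout with plain PyYAML.
# The path and keys mirror the YAML above; llmc's own config loader may differ.
import yaml

with open('configs/quantization/methods/Awq/awq_w_only_vlm.yml') as f:
    cfg = yaml.safe_load(f)

for modality in ('vision', 'language'):
    sub = cfg['quant'].get(modality)
    if sub is not None:
        print(modality, sub['method'], sub['weight']['bit'], sub['weight']['granularity'])
```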
153 changes: 66 additions & 87 deletions llmc/__main__.py
@@ -17,21 +17,12 @@
from llmc.data import BaseDataset
from llmc.eval.utils import eval_model, get_eval_list
from llmc.models import *
from llmc.utils import (check_config, mkdirs, print_important_package_version,
seed_all, update_autoawq_quant_config,
update_vllm_quant_config)
from llmc.utils import (check_config, deploy_all_modality, get_modality,
mkdirs, print_important_package_version, seed_all,
update_autoawq_quant_config, update_vllm_quant_config)
from llmc.utils.registry_factory import ALGO_REGISTRY, MODEL_REGISTRY


def get_modality(config):
if 'quant' in config:
return config.quant.get('quant_objects', ['language'])
elif 'sparse' in config:
return config.sparse.get('sparse_objects', ['language'])
else:
return ['language']


def main(config):
model = MODEL_REGISTRY[config.model.type](config)

@@ -41,27 +32,20 @@ def main(config):
eval_list = get_eval_list(model, config)
eval_model(model, None, eval_list, eval_pos='pretrain')

# for modality in config.quant.get('quant_objects', ['language']):
for modality in get_modality(config):
blockwise_opts = []
modalities, modality_configs = get_modality(config)
for modality, modality_config in zip(modalities, modality_configs):
model.set_modality(modality)
if not config.get('calib', False):
if not config.get('sparse', False):
blockwise_opt = ALGO_REGISTRY[config.quant.method](
model,
config.quant,
None,
None,
config,
)
else:
blockwise_opt = ALGO_REGISTRY[config.sparse.method](
model,
config.sparse,
None,
None,
config,
)
blockwise_opt = ALGO_REGISTRY[modality_config.method](
model,
quant_config=modality_config,
input=None,
padding_mask=None,
config=config,
)
blockwise_opt.run_block_loop()
blockwise_opts.append(blockwise_opt)
dist.barrier()
else:
dataset = BaseDataset(
@@ -72,26 +56,18 @@
del calib_data
gc.collect()
torch.cuda.empty_cache()
if not config.get('sparse', False):
blockwise_opt = ALGO_REGISTRY[config.quant.method](
model,
config.quant,
model.get_first_block_input(),
model.get_padding_mask(),
config,
)
else:
blockwise_opt = ALGO_REGISTRY[config.sparse.method](
model,
config.sparse,
model.get_first_block_input(),
model.get_padding_mask(),
config,
)
blockwise_opt = ALGO_REGISTRY[modality_config.method](
model,
modality_config,
model.get_first_block_input(),
model.get_padding_mask(),
config,
)
blockwise_opt.run_block_loop()
blockwise_opts.append(blockwise_opt)
dist.barrier()

eval_model(model, blockwise_opt, eval_list, eval_pos='transformed')
eval_model(model, blockwise_opts, eval_list, eval_pos='transformed')
if int(os.environ['RANK']) == 0:
if 'save' in config and config.save.get('save_trans', False):
blockwise_opt.save_model(save_trans_path)
@@ -106,11 +82,11 @@
config.save.get('trtllm_cfg'),
)

eval_model(model, blockwise_opt, eval_list, eval_pos='fake_quant')
eval_model(model, blockwise_opt, eval_list, eval_pos='fake_quant_wo_kv')
eval_model(model, blockwise_opts, eval_list, eval_pos='fake_quant')
eval_model(model, blockwise_opts, eval_list, eval_pos='fake_quant_wo_kv')

if 'save' in config and config.save.get('save_fake', False):
blockwise_opt.deploy('fake_quant')
deploy_all_modality(blockwise_opts, 'fake_quant')
blockwise_opt.save_model(save_fake_path)

if 'save' in config:
@@ -119,56 +95,59 @@
or config.save.get('save_sgl', False)
or config.save.get('save_lightllm', False)
):
w, a = config.quant.weight, config.quant.get('act')

if isinstance(w.bit, str):
assert a, 'Only WA float quant is supported.'
assert (
w.symmetric and a.symmetric
), 'Only symmetric quant is supported.'
assert (
w.bit == a.bit
and w.bit in ['e4m3', 'e5m2']
and a.bit in ['e4m3', 'e5m2']
), 'Only WA FP8 quant is supported'
else:
assert w.symmetric, 'Only symmetric quant is supported.'
assert w.bit in [4, 8], 'Supported quant: w4a16, w8a16, w8a8.'
if a:
assert a.symmetric, 'Only symmetric quant is supported.'
assert a.bit == 8, 'Supported quant: w4a16, w8a16, w8a8.'
for modality_config in modality_configs:
w, a = modality_config.weight, modality_config.get('act')

if isinstance(w.bit, str):
assert a, 'Only WA float quant is supported.'
assert (
w.symmetric and a.symmetric
), 'Only symmetric quant is supported.'
assert (
w.bit == a.bit
and w.bit in ['e4m3', 'e5m2']
and a.bit in ['e4m3', 'e5m2']
), 'Only WA FP8 quant is supported'
else:
assert w.symmetric, 'Only symmetric quant is supported.'
assert w.bit in [4, 8], 'Supported quant: w4a16, w8a16, w8a8.'
if a:
assert a.symmetric, 'Only symmetric quant is supported.'
assert a.bit == 8, 'Supported quant: w4a16, w8a16, w8a8.'

if config.save.get('save_vllm', False):
blockwise_opt.deploy('vllm_quant')
deploy_all_modality(blockwise_opts, 'vllm_quant')
if config.save.get('save_lightllm', False):
blockwise_opt.deploy('lightllm_quant')
deploy_all_modality(blockwise_opts, 'lightllm_quant')
if config.save.get('save_sgl', False):
blockwise_opt.deploy('sgl_quant')
deploy_all_modality(blockwise_opts, 'sgl_quant')

blockwise_opt.save_model(save_quant_path)
update_vllm_quant_config(blockwise_opt.model, config, save_quant_path)

if 'save' in config and config.save.get('save_autoawq', False):
assert (
config.quant.weight.bit in [4] and 'act' not in config.quant
), 'AutoAWQ supports only 4-bit weight-only quantization.'
assert (
not config.quant.weight.symmetric
), 'Only asymmetric quant is supported.'

blockwise_opt.deploy('autoawq_quant')
for modality_config in modality_configs:
assert (
modality_config.weight.bit in [4] and 'act' not in modality_config
), 'AutoAWQ supports only 4-bit weight-only quantization.'
assert (
not modality_config.weight.symmetric
), 'Only asymmetric quant is supported.'

deploy_all_modality(blockwise_opts, 'autoawq_quant')
blockwise_opt.save_model(save_quant_path)
update_autoawq_quant_config(config, save_quant_path)

if 'save' in config and config.save.get('save_mlcllm', False):
assert (
config.quant.weight.bit in [4] and 'act' not in config.quant
), 'MlcLLM supports only 4-bit weight-only quantization.'
assert (
not config.quant.weight.symmetric
), 'Only asymmetric quant is supported.'

blockwise_opt.deploy('mlcllm_quant')
for modality_config in modality_configs:
assert (
modality_config.weight.bit in [4] and 'act' not in modality_config
), 'MlcLLM supports only 4-bit weight-only quantization.'
assert (
not modality_config.weight.symmetric
), 'Only asymmetric quant is supported.'

deploy_all_modality(blockwise_opts, 'mlcllm_quant')
blockwise_opt.save_model(save_quant_path)
update_autoawq_quant_config(config, save_quant_path)

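Stripped of calibration, evaluation, and export details, the reworked main() reduces to one loop that builds a blockwise optimizer per modality and a single fan-out deployment. A rough sketch of that flow (not a drop-in replacement for the function above):

```python
# Rough sketch of the per-modality flow in the new main(); calibration data,
# evaluation hooks, rank checks and dist.barrier() are omitted for brevity.
def run_all_modalities(model, config):
    modalities, modality_configs = get_modality(config)
    blockwise_opts = []
    for modality, modality_config in zip(modalities, modality_configs):
        model.set_modality(modality)                     # switch the active branch
        blockwise_opt = ALGO_REGISTRY[modality_config.method](
            model, modality_config, None, None, config,  # no calib inputs in this path
        )
        blockwise_opt.run_block_loop()
        blockwise_opts.append(blockwise_opt)
    # Every export/eval path now deploys all collected optimizers together.
    deploy_all_modality(blockwise_opts, 'fake_quant')
    return blockwise_opts
```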
8 changes: 4 additions & 4 deletions llmc/compression/quantization/base_blockwise_quantization.py
@@ -291,8 +291,8 @@ def set_quant_config(self):

self.hidden_size = self.model.model_config.hidden_size
self.set_model_config()
self.quant_objects = self.quant_config.get('quant_objects', ['language'])
logger.info(f'self.quant_objects : {self.quant_objects}')
self.modality = self.quant_config.modality
logger.info(f'self.quant_objects : {self.quant_config.modality}')

def set_model_config(self):
self.hidden_size = self.model.model_config.hidden_size
@@ -877,13 +877,13 @@ def deploy(self, quant_format, keep_device=False):
)

module = module_mapping[quant_format]
if 'vision' in self.quant_objects:
if self.modality == 'vision':
self.model.replace_vision_module_all(
module,
self.get_replacement_params(mode=quant_format, w_only=self.w_only),
keep_device=keep_device,
)
if 'language' in self.quant_objects:
if self.modality == 'language':
self.model.replace_language_module_all(
module,
self.get_replacement_params(mode=quant_format, w_only=self.w_only),
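Because each optimizer now owns exactly one modality, deploy() replaces either the vision or the language modules, never both. A hypothetical helper restating that dispatch (the function name is invented for illustration and is not part of llmc):

```python
# Hypothetical helper mirroring the modality dispatch in deploy() above.
def replace_modules_for(opt, module, params, keep_device=False):
    if opt.modality == 'vision':
        opt.model.replace_vision_module_all(module, params, keep_device=keep_device)
    elif opt.modality == 'language':
        opt.model.replace_language_module_all(module, params, keep_device=keep_device)
    else:
        raise ValueError(f'unexpected modality: {opt.modality}')
```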
7 changes: 4 additions & 3 deletions llmc/eval/utils.py
@@ -6,6 +6,7 @@
from llmc.eval import (AccuracyEval, CustomGenerate, DecodePerplexityEval,
HumanEval, PerplexityEval, TokenConsistencyEval,
VQAEval)
from llmc.utils import deploy_all_modality


def get_eval_list(model, config):
@@ -70,17 +71,17 @@ def get_eval_list(model, config):
return eval_list


def eval_model(model, blockwise_opt, eval_list, eval_pos):
def eval_model(model, blockwise_opts, eval_list, eval_pos):
if int(os.environ['RANK']) == 0:
do_eval = False
for _, config_for_eval in eval_list:
if eval_pos in config_for_eval.eval.eval_pos:
do_eval = True
if do_eval:
if eval_pos == 'transformed':
blockwise_opt.deploy('origin_float')
deploy_all_modality(blockwise_opts, 'origin_float')
elif eval_pos in ['fake_quant', 'fake_quant_wo_kv']:
blockwise_opt.deploy('fake_quant')
deploy_all_modality(blockwise_opts, 'fake_quant')
for eval_class, config_for_eval in eval_list:
if eval_pos in config_for_eval.eval.eval_pos:
res = eval_class.eval(model)
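eval_model now receives the whole list of per-modality optimizers and re-deploys them all before evaluating, so the call sites in __main__.py collapse to passing blockwise_opts at each evaluation position:

```python
# Call pattern after the change, as used in __main__.py above.
eval_model(model, blockwise_opts, eval_list, eval_pos='transformed')
eval_model(model, blockwise_opts, eval_list, eval_pos='fake_quant')
eval_model(model, blockwise_opts, eval_list, eval_pos='fake_quant_wo_kv')
```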
5 changes: 3 additions & 2 deletions llmc/utils/__init__.py
@@ -1,4 +1,5 @@
from .export_autoawq import update_autoawq_quant_config
from .export_vllm import update_vllm_quant_config
from .utils import (check_config, copy_files, mkdirs,
print_important_package_version, seed_all)
from .utils import (check_config, copy_files, deploy_all_modality,
get_modality, mkdirs, print_important_package_version,
seed_all)
41 changes: 32 additions & 9 deletions llmc/utils/utils.py
@@ -29,15 +29,18 @@ def check_weight_setting(weight_setting):
elif weight_setting.granularity == 'per_head':
assert weight_setting.head_num > 0

if config.quant.weight.get('granularity', False):
weight_setting = config.quant.weight
check_weight_setting(weight_setting)
if config.quant.weight.get('w_1', False):
weight_setting = config.quant.weight.w_1
check_weight_setting(weight_setting)
if config.quant.weight.get('w_2', False):
weight_setting = config.quant.weight.w_2
check_weight_setting(weight_setting)
for _, modality_config in config.quant.items():
if not isinstance(modality_config, dict) or not modality_config.get('weight', False):
continue
if modality_config.weight.get('granularity', False):
weight_setting = modality_config.weight
check_weight_setting(weight_setting)
if modality_config.weight.get('w_1', False):
weight_setting = modality_config.weight.w_1
check_weight_setting(weight_setting)
if modality_config.weight.get('w_2', False):
weight_setting = modality_config.weight.w_2
check_weight_setting(weight_setting)
if config.model.get('tokenizer_mode', False):
assert (
config.model.tokenizer_mode == 'slow'
@@ -72,3 +75,23 @@ def print_important_package_version():
logger.info(f"tokenizers : {version('tokenizers')}")
logger.info(f"huggingface-hub : {version('huggingface-hub')}")
logger.info(f"datasets : {version('datasets')}")


def get_modality(config):
modalities = []
modality_configs = []
compression_config = config.quant if 'quant' in config else config.sparse
for modality in ['vision', 'language']:
if modality in compression_config:
compression_config[modality].modality = modality
modalities.append(modality)
modality_configs.append(compression_config[modality])
if not modalities:
compression_config.modality = 'language'
return ['language'], [compression_config]
return modalities, modality_configs


def deploy_all_modality(blockwise_opts, quant_format):
for blockwise_opt in blockwise_opts:
blockwise_opt.deploy(quant_format)
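
get_modality walks the quant (or sparse) section, tags each vision/language sub-config with its modality, and falls back to a single language entry when no per-modality sections exist; deploy_all_modality simply fans a deploy call out over every collected optimizer. A toy usage sketch, assuming an attribute-style dict config (e.g. easydict), which matches how these configs are accessed elsewhere in the diff:

```python
# Toy demonstration; the easydict config below is an assumption made so that
# attribute assignment (`.modality = ...`) works the same way as above.
from easydict import EasyDict
from llmc.utils import deploy_all_modality, get_modality

config = EasyDict({
    'quant': {
        'vision':   {'method': 'Awq', 'weight': {'bit': 8}},
        'language': {'method': 'Awq', 'weight': {'bit': 4}},
    }
})

modalities, modality_configs = get_modality(config)
print(modalities)                              # ['vision', 'language']
print([c.modality for c in modality_configs])  # ['vision', 'language']

# After the per-modality optimizers have been built and run in main():
# deploy_all_modality(blockwise_opts, 'vllm_quant')
```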
