From da4c436c2aa9b311dc62335d047a287cebacceff Mon Sep 17 00:00:00 2001
From: chengtao-lv <897674362@qq.com>
Date: Mon, 13 Jan 2025 20:14:43 +0800
Subject: [PATCH] support modality quant (#293)

* support modality quant

* fix bug

* Update __main__.py

---------

Co-authored-by: Yang Yong
---
 .../methods/Awq/awq_w_only_vlm.yml      |  45 ++++--
 llmc/__main__.py                        | 153 ++++++++----------
 .../base_blockwise_quantization.py      |   8 +-
 llmc/eval/utils.py                      |   7 +-
 llmc/utils/__init__.py                  |   5 +-
 llmc/utils/utils.py                     |  41 +++--
 6 files changed, 139 insertions(+), 120 deletions(-)

diff --git a/configs/quantization/methods/Awq/awq_w_only_vlm.yml b/configs/quantization/methods/Awq/awq_w_only_vlm.yml
index 4607a8d2..6621a18e 100644
--- a/configs/quantization/methods/Awq/awq_w_only_vlm.yml
+++ b/configs/quantization/methods/Awq/awq_w_only_vlm.yml
@@ -25,21 +25,36 @@ eval:
     bs: 1
     inference_per_block: False
 quant:
-    method: Awq
-    quant_objects: [vision, language] # default is [language]
-    weight:
-        bit: 4
-        symmetric: False
-        granularity: per_group
-        group_size: 128
-    special:
-        trans: True
-        # The options for "trans_version" include "v1" and "v2".
-        # But their results don't differ significantly.
-        trans_version: v2
-        weight_clip: True
-        # For 2-bit quantization, setting "clip_sym: False" will yield better results.
-        clip_sym: True
+    vision:
+        method: Awq
+        weight:
+            bit: 4
+            symmetric: False
+            granularity: per_group
+            group_size: 128
+        special:
+            trans: True
+            # The options for "trans_version" include "v1" and "v2".
+            # But their results don't differ significantly.
+            trans_version: v2
+            weight_clip: True
+            # For 2-bit quantization, setting "clip_sym: False" will yield better results.
+            clip_sym: True
+    language:
+        method: Awq
+        weight:
+            bit: 4
+            symmetric: False
+            granularity: per_group
+            group_size: 128
+        special:
+            trans: True
+            # The options for "trans_version" include "v1" and "v2".
+            # But their results don't differ significantly.
+            trans_version: v2
+            weight_clip: True
+            # For 2-bit quantization, setting "clip_sym: False" will yield better results.
+            clip_sym: True
 save:
     save_trans: False
     save_fake: False
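A note on the new config layout: the old quant_objects list is gone, and each modality under quant now carries its own complete AWQ settings. Below is a minimal, runnable Python sketch of that shape; plain dicts stand in for llmc's EasyDict config, and the 8-bit vision setting is purely hypothetical, only to show that the two blocks may now diverge.

# Illustrative only: the per-modality quant sections as plain Python dicts.
quant_cfg = {
    'vision': {'method': 'Awq',
               'weight': {'bit': 8, 'symmetric': False,  # hypothetical 8-bit vision tower
                          'granularity': 'per_group', 'group_size': 128}},
    'language': {'method': 'Awq',
                 'weight': {'bit': 4, 'symmetric': False,
                            'granularity': 'per_group', 'group_size': 128}},
}

for modality, cfg in quant_cfg.items():
    # Each block is self-contained, so settings no longer have to match across modalities.
    print(modality, cfg['method'], cfg['weight']['bit'])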
diff --git a/llmc/__main__.py b/llmc/__main__.py
index 01077eb6..5efa2be7 100644
--- a/llmc/__main__.py
+++ b/llmc/__main__.py
@@ -17,21 +17,12 @@
 from llmc.data import BaseDataset
 from llmc.eval.utils import eval_model, get_eval_list
 from llmc.models import *
-from llmc.utils import (check_config, mkdirs, print_important_package_version,
-                        seed_all, update_autoawq_quant_config,
-                        update_vllm_quant_config)
+from llmc.utils import (check_config, deploy_all_modality, get_modality,
+                        mkdirs, print_important_package_version, seed_all,
+                        update_autoawq_quant_config, update_vllm_quant_config)
 from llmc.utils.registry_factory import ALGO_REGISTRY, MODEL_REGISTRY


-def get_modality(config):
-    if 'quant' in config:
-        return config.quant.get('quant_objects', ['language'])
-    elif 'sparse' in config:
-        return config.sparse.get('sparse_objects', ['language'])
-    else:
-        return ['language']
-
-
 def main(config):
     model = MODEL_REGISTRY[config.model.type](config)

@@ -41,27 +32,20 @@ def main(config):
     eval_list = get_eval_list(model, config)
     eval_model(model, None, eval_list, eval_pos='pretrain')

-    # for modality in config.quant.get('quant_objects', ['language']):
-    for modality in get_modality(config):
+    blockwise_opts = []
+    modalities, modality_configs = get_modality(config)
+    for modality, modality_config in zip(modalities, modality_configs):
         model.set_modality(modality)
         if not config.get('calib', False):
-            if not config.get('sparse', False):
-                blockwise_opt = ALGO_REGISTRY[config.quant.method](
-                    model,
-                    config.quant,
-                    None,
-                    None,
-                    config,
-                )
-            else:
-                blockwise_opt = ALGO_REGISTRY[config.sparse.method](
-                    model,
-                    config.sparse,
-                    None,
-                    None,
-                    config,
-                )
+            blockwise_opt = ALGO_REGISTRY[modality_config.method](
+                model,
+                quant_config=modality_config,
+                input=None,
+                padding_mask=None,
+                config=config,
+            )
             blockwise_opt.run_block_loop()
+            blockwise_opts.append(blockwise_opt)
             dist.barrier()
         else:
             dataset = BaseDataset(
@@ -72,26 +56,18 @@ def main(config):
             del calib_data
             gc.collect()
             torch.cuda.empty_cache()
-            if not config.get('sparse', False):
-                blockwise_opt = ALGO_REGISTRY[config.quant.method](
-                    model,
-                    config.quant,
-                    model.get_first_block_input(),
-                    model.get_padding_mask(),
-                    config,
-                )
-            else:
-                blockwise_opt = ALGO_REGISTRY[config.sparse.method](
-                    model,
-                    config.sparse,
-                    model.get_first_block_input(),
-                    model.get_padding_mask(),
-                    config,
-                )
+            blockwise_opt = ALGO_REGISTRY[modality_config.method](
+                model,
+                modality_config,
+                model.get_first_block_input(),
+                model.get_padding_mask(),
+                config,
+            )
             blockwise_opt.run_block_loop()
+            blockwise_opts.append(blockwise_opt)
             dist.barrier()

-    eval_model(model, blockwise_opt, eval_list, eval_pos='transformed')
+    eval_model(model, blockwise_opts, eval_list, eval_pos='transformed')
     if int(os.environ['RANK']) == 0:
         if 'save' in config and config.save.get('save_trans', False):
             blockwise_opt.save_model(save_trans_path)
@@ -106,11 +82,11 @@ def main(config):
                 config.save.get('trtllm_cfg'),
             )

-    eval_model(model, blockwise_opt, eval_list, eval_pos='fake_quant')
-    eval_model(model, blockwise_opt, eval_list, eval_pos='fake_quant_wo_kv')
+    eval_model(model, blockwise_opts, eval_list, eval_pos='fake_quant')
+    eval_model(model, blockwise_opts, eval_list, eval_pos='fake_quant_wo_kv')

     if 'save' in config and config.save.get('save_fake', False):
-        blockwise_opt.deploy('fake_quant')
+        deploy_all_modality(blockwise_opts, 'fake_quant')
         blockwise_opt.save_model(save_fake_path)

     if 'save' in config:
@@ -119,56 +95,59 @@ def main(config):
             or config.save.get('save_sgl', False)
             or config.save.get('save_lightllm', False)
         ):
-            w, a = config.quant.weight, config.quant.get('act')
-
-            if isinstance(w.bit, str):
-                assert a, 'Only WA float quant is supported.'
-                assert (
-                    w.symmetric and a.symmetric
-                ), 'Only symmetric quant is supported.'
-                assert (
-                    w.bit == a.bit
-                    and w.bit in ['e4m3', 'e5m2']
-                    and a.bit in ['e4m3', 'e5m2']
-                ), 'Only WA FP8 quant is supported'
-            else:
-                assert w.symmetric, 'Only symmetric quant is supported.'
-                assert w.bit in [4, 8], 'Supported quant: w4a16, w8a16, w8a8.'
-                if a:
-                    assert a.symmetric, 'Only symmetric quant is supported.'
-                    assert a.bit == 8, 'Supported quant: w4a16, w8a16, w8a8.'
+            for modality_config in modality_configs:
+                w, a = modality_config.weight, modality_config.get('act')
+
+                if isinstance(w.bit, str):
+                    assert a, 'Only WA float quant is supported.'
+                    assert (
+                        w.symmetric and a.symmetric
+                    ), 'Only symmetric quant is supported.'
+                    assert (
+                        w.bit == a.bit
+                        and w.bit in ['e4m3', 'e5m2']
+                        and a.bit in ['e4m3', 'e5m2']
+                    ), 'Only WA FP8 quant is supported'
+                else:
+                    assert w.symmetric, 'Only symmetric quant is supported.'
+                    assert w.bit in [4, 8], 'Supported quant: w4a16, w8a16, w8a8.'
+                    if a:
+                        assert a.symmetric, 'Only symmetric quant is supported.'
+                        assert a.bit == 8, 'Supported quant: w4a16, w8a16, w8a8.'

             if config.save.get('save_vllm', False):
-                blockwise_opt.deploy('vllm_quant')
+                deploy_all_modality(blockwise_opts, 'vllm_quant')
             if config.save.get('save_lightllm', False):
-                blockwise_opt.deploy('lightllm_quant')
+                deploy_all_modality(blockwise_opts, 'lightllm_quant')
             if config.save.get('save_sgl', False):
-                blockwise_opt.deploy('sgl_quant')
+                deploy_all_modality(blockwise_opts, 'sgl_quant')
             blockwise_opt.save_model(save_quant_path)
             update_vllm_quant_config(blockwise_opt.model, config, save_quant_path)

     if 'save' in config and config.save.get('save_autoawq', False):
-        assert (
-            config.quant.weight.bit in [4] and 'act' not in config.quant
-        ), 'AutoAWQ supports only 4-bit weight-only quantization.'
-        assert (
-            not config.quant.weight.symmetric
-        ), 'Only asymmetric quant is supported.'
-
-        blockwise_opt.deploy('autoawq_quant')
+        for modality_config in modality_configs:
+            assert (
+                modality_config.weight.bit in [4] and 'act' not in modality_config
+            ), 'AutoAWQ supports only 4-bit weight-only quantization.'
+            assert (
+                not modality_config.weight.symmetric
+            ), 'Only asymmetric quant is supported.'
+
+        deploy_all_modality(blockwise_opts, 'autoawq_quant')
         blockwise_opt.save_model(save_quant_path)
         update_autoawq_quant_config(config, save_quant_path)

     if 'save' in config and config.save.get('save_mlcllm', False):
-        assert (
-            config.quant.weight.bit in [4] and 'act' not in config.quant
-        ), 'MlcLLM supports only 4-bit weight-only quantization.'
-        assert (
-            not config.quant.weight.symmetric
-        ), 'Only asymmetric quant is supported.'
-
-        blockwise_opt.deploy('mlcllm_quant')
+        for modality_config in modality_configs:
+            assert (
+                modality_config.weight.bit in [4] and 'act' not in modality_config
+            ), 'MlcLLM supports only 4-bit weight-only quantization.'
+            assert (
+                not modality_config.weight.symmetric
+            ), 'Only asymmetric quant is supported.'
+
+        deploy_all_modality(blockwise_opts, 'mlcllm_quant')
         blockwise_opt.save_model(save_quant_path)
         update_autoawq_quant_config(config, save_quant_path)
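For orientation, a self-contained mock of the control flow main() now follows (MockBlockwiseOpt and the printed messages are invented for illustration and are not llmc APIs): one blockwise optimizer is built and run per modality, collected in blockwise_opts, and deployment is later fanned out over the whole list exactly as deploy_all_modality does.

# Mock objects only -- stand-ins for llmc's blockwise optimizers.
class MockBlockwiseOpt:
    def __init__(self, modality):
        self.modality = modality

    def run_block_loop(self):
        print(f'quantizing {self.modality} blocks')

    def deploy(self, quant_format):
        print(f'deploying {self.modality} as {quant_format}')


modalities = ['vision', 'language']        # what get_modality yields for the YAML above
blockwise_opts = [MockBlockwiseOpt(m) for m in modalities]
for opt in blockwise_opts:
    opt.run_block_loop()                   # per-modality block loop, as in main()
for opt in blockwise_opts:                 # same fan-out as deploy_all_modality
    opt.deploy('fake_quant')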
diff --git a/llmc/compression/quantization/base_blockwise_quantization.py b/llmc/compression/quantization/base_blockwise_quantization.py
index d4773030..d4e97b54 100644
--- a/llmc/compression/quantization/base_blockwise_quantization.py
+++ b/llmc/compression/quantization/base_blockwise_quantization.py
@@ -291,8 +291,8 @@ def set_quant_config(self):
             self.hidden_size = self.model.model_config.hidden_size
         self.set_model_config()

-        self.quant_objects = self.quant_config.get('quant_objects', ['language'])
-        logger.info(f'self.quant_objects : {self.quant_objects}')
+        self.modality = self.quant_config.modality
+        logger.info(f'self.quant_objects : {self.quant_config.modality}')

     def set_model_config(self):
         self.hidden_size = self.model.model_config.hidden_size
@@ -877,13 +877,13 @@ def deploy(self, quant_format, keep_device=False):
             )

         module = module_mapping[quant_format]
-        if 'vision' in self.quant_objects:
+        if self.modality == 'vision':
             self.model.replace_vision_module_all(
                 module,
                 self.get_replacement_params(mode=quant_format, w_only=self.w_only),
                 keep_device=keep_device,
             )
-        if 'language' in self.quant_objects:
+        if self.modality == 'language':
             self.model.replace_language_module_all(
                 module,
                 self.get_replacement_params(mode=quant_format, w_only=self.w_only),
diff --git a/llmc/eval/utils.py b/llmc/eval/utils.py
index c6afae1a..191d54d4 100644
--- a/llmc/eval/utils.py
+++ b/llmc/eval/utils.py
@@ -6,6 +6,7 @@
 from llmc.eval import (AccuracyEval, CustomGenerate, DecodePerplexityEval,
                        HumanEval, PerplexityEval, TokenConsistencyEval,
                        VQAEval)
+from llmc.utils import deploy_all_modality


 def get_eval_list(model, config):
@@ -70,7 +71,7 @@ def get_eval_list(model, config):
     return eval_list


-def eval_model(model, blockwise_opt, eval_list, eval_pos):
+def eval_model(model, blockwise_opts, eval_list, eval_pos):
     if int(os.environ['RANK']) == 0:
         do_eval = False
         for _, config_for_eval in eval_list:
@@ -78,9 +79,9 @@ def eval_model(model, blockwise_opt, eval_list, eval_pos):
                 do_eval = True
         if do_eval:
             if eval_pos == 'transformed':
-                blockwise_opt.deploy('origin_float')
+                deploy_all_modality(blockwise_opts, 'origin_float')
             elif eval_pos in ['fake_quant', 'fake_quant_wo_kv']:
-                blockwise_opt.deploy('fake_quant')
+                deploy_all_modality(blockwise_opts, 'fake_quant')
             for eval_class, config_for_eval in eval_list:
                 if eval_pos in config_for_eval.eval.eval_pos:
                     res = eval_class.eval(model)
diff --git a/llmc/utils/__init__.py b/llmc/utils/__init__.py
index 574b2717..cdd10a9e 100644
--- a/llmc/utils/__init__.py
+++ b/llmc/utils/__init__.py
@@ -1,4 +1,5 @@
 from .export_autoawq import update_autoawq_quant_config
 from .export_vllm import update_vllm_quant_config
-from .utils import (check_config, copy_files, mkdirs,
-                    print_important_package_version, seed_all)
+from .utils import (check_config, copy_files, deploy_all_modality,
+                    get_modality, mkdirs, print_important_package_version,
+                    seed_all)
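The evaluation path follows the same pattern: eval_model now receives the whole list of blockwise optimizers and re-deploys every modality before scoring, so a VLM gets both its vision and language modules swapped to the requested format. A simplified sketch of that dispatch (deploy_for_eval is an illustrative name, not llmc API):

# Simplified sketch of the eval-time dispatch in eval_model: every modality's
# optimizer is re-deployed to the same format before evaluation runs.
def deploy_for_eval(blockwise_opts, eval_pos):
    if eval_pos == 'transformed':
        fmt = 'origin_float'
    elif eval_pos in ('fake_quant', 'fake_quant_wo_kv'):
        fmt = 'fake_quant'
    else:
        return
    for opt in blockwise_opts:   # mirrors deploy_all_modality
        opt.deploy(fmt)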
diff --git a/llmc/utils/utils.py b/llmc/utils/utils.py
index c802a043..fb287fca 100644
--- a/llmc/utils/utils.py
+++ b/llmc/utils/utils.py
@@ -29,15 +29,18 @@ def check_weight_setting(weight_setting):
         elif weight_setting.granularity == 'per_head':
             assert weight_setting.head_num > 0

-    if config.quant.weight.get('granularity', False):
-        weight_setting = config.quant.weight
-        check_weight_setting(weight_setting)
-    if config.quant.weight.get('w_1', False):
-        weight_setting = config.quant.weight.w_1
-        check_weight_setting(weight_setting)
-    if config.quant.weight.get('w_2', False):
-        weight_setting = config.quant.weight.w_2
-        check_weight_setting(weight_setting)
+    for _, modality_config in config.quant.items():
+        if not isinstance(modality_config, dict) or not modality_config.get('weight', False):
+            continue
+        if modality_config.weight.get('granularity', False):
+            weight_setting = modality_config.weight
+            check_weight_setting(weight_setting)
+        if modality_config.weight.get('w_1', False):
+            weight_setting = modality_config.weight.w_1
+            check_weight_setting(weight_setting)
+        if modality_config.weight.get('w_2', False):
+            weight_setting = modality_config.weight.w_2
+            check_weight_setting(weight_setting)
     if config.model.get('tokenizer_mode', False):
         assert (
             config.model.tokenizer_mode == 'slow'
@@ -72,3 +75,23 @@ def print_important_package_version():
     logger.info(f"tokenizers : {version('tokenizers')}")
     logger.info(f"huggingface-hub : {version('huggingface-hub')}")
     logger.info(f"datasets : {version('datasets')}")
+
+
+def get_modality(config):
+    modalities = []
+    modality_configs = []
+    compression_config = config.quant if 'quant' in config else config.sparse
+    for modality in ['vision', 'language']:
+        if modality in compression_config:
+            compression_config[modality].modality = modality
+            modalities.append(modality)
+            modality_configs.append(compression_config[modality])
+    if not modalities:
+        compression_config.modality = 'language'
+        return ['language'], [compression_config]
+    return modalities, modality_configs
+
+
+def deploy_all_modality(blockwise_opts, quant_format):
+    for blockwise_opt in blockwise_opts:
+        blockwise_opt.deploy(quant_format)
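To make the new helper concrete, here is a simplified, runnable restatement of get_modality over plain dicts (llmc configs are EasyDicts, where key and attribute access are interchangeable; get_modality_sketch is illustrative and not part of the patch):

def get_modality_sketch(config):
    # Pick the compression section, then collect any per-modality sub-configs.
    compression_config = config['quant'] if 'quant' in config else config['sparse']
    modalities, modality_configs = [], []
    for modality in ['vision', 'language']:
        if modality in compression_config:
            compression_config[modality]['modality'] = modality
            modalities.append(modality)
            modality_configs.append(compression_config[modality])
    if not modalities:
        # Legacy flat config with no per-modality blocks: treat it as language-only.
        compression_config['modality'] = 'language'
        return ['language'], [compression_config]
    return modalities, modality_configs


# VLM-style config, as in the YAML above -> both modalities are returned.
print(get_modality_sketch({'quant': {'vision': {'method': 'Awq'},
                                     'language': {'method': 'Awq'}}})[0])
# Old-style flat config -> falls back to ['language'] with the whole quant section.
print(get_modality_sketch({'quant': {'method': 'Awq', 'weight': {'bit': 4}}})[0])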