diff --git a/llmc/__main__.py b/llmc/__main__.py
index 65ea06a0..c8a804e8 100644
--- a/llmc/__main__.py
+++ b/llmc/__main__.py
@@ -45,7 +45,9 @@ def main(config):
         blockwise_opt.run_block_loop()
         dist.barrier()
     else:
-        dataset = BaseDataset(model.get_tokenizer(), config.calib, model.batch_process)
+        dataset = BaseDataset(
+            model.get_tokenizer(), config.calib, model.batch_process
+        )
         calib_data, padding_mask = dataset.get_calib_dataset()
         model.collect_first_block_input(calib_data, padding_mask)
         del calib_data
@@ -91,53 +93,60 @@ def main(config):
         blockwise_opt.deploy('fake_quant')
         blockwise_opt.save_model(save_fake_path)

-    if 'save' in config and config.save.get('save_vllm', False):
-        w, a = config.quant.weight, config.quant.get('act')
-        if isinstance(w.bit, str):
-            assert a, 'Only WA float quant is supported.'
-            assert w.symmetric and a.symmetric, 'Only symmetric quant is supported.'
-            assert w.bit == a.bit and w.bit in ['e4m3', 'e5m2'] and \
-                a.bit in ['e4m3', 'e5m2'], 'Only WA FP8 quant is supported'
-        else:
-            assert w.symmetric, 'Only symmetric quant is supported.'
-            assert w.bit in [4, 8], 'Supported quant: w4a16, w8a16, w8a8.'
-            if a:
-                assert a.symmetric, 'Only symmetric quant is supported.'
-                assert a.bit == 8, 'Supported quant: w4a16, w8a16, w8a8.'
-        blockwise_opt.deploy('vllm_quant')
-        blockwise_opt.save_model(save_quant_path)
-        update_vllm_quant_config(blockwise_opt.model, config, save_quant_path)
-
-    if 'save' in config and config.save.get('save_sgl', False):
-        w, a = config.quant.weight, config.quant.get('act')
-        if isinstance(w.bit, str):
-            assert a, 'Only WA float quant is supported.'
-            assert w.symmetric and a.symmetric, 'Only symmetric quant is supported.'
-            assert w.bit == a.bit and w.bit in ['e4m3', 'e5m2'] and \
-                a.bit in ['e4m3', 'e5m2'], 'Only WA FP8 quant is supported'
-        else:
-            assert w.symmetric, 'Only symmetric quant is supported.'
-            assert w.bit in [4, 8], 'Supported quant: w4a16, w8a16, w8a8.'
-            if a:
-                assert a.symmetric, 'Only symmetric quant is supported.'
-                assert a.bit == 8, 'Supported quant: w4a16, w8a16, w8a8.'
-        blockwise_opt.deploy('sgl_quant')
-        blockwise_opt.save_model(save_quant_path)
-        update_vllm_quant_config(blockwise_opt.model, config, save_quant_path)
+    if 'save' in config:
+        if (
+            config.save.get('save_vllm', False)
+            or config.save.get('save_sgl', False)
+            or config.save.get('save_lightllm', False)
+        ):
+            w, a = config.quant.weight, config.quant.get('act')
+
+            if isinstance(w.bit, str):
+                assert a, 'Only WA float quant is supported.'
+                assert (
+                    w.symmetric and a.symmetric
+                ), 'Only symmetric quant is supported.'
+                assert (
+                    w.bit == a.bit
+                    and w.bit in ['e4m3', 'e5m2']
+                    and a.bit in ['e4m3', 'e5m2']
+                ), 'Only WA FP8 quant is supported'
+            else:
+                assert w.symmetric, 'Only symmetric quant is supported.'
+                assert w.bit in [4, 8], 'Supported quant: w4a16, w8a16, w8a8.'
+                if a:
+                    assert a.symmetric, 'Only symmetric quant is supported.'
+                    assert a.bit == 8, 'Supported quant: w4a16, w8a16, w8a8.'
+
+            if config.save.get('save_vllm', False):
+                blockwise_opt.deploy('vllm_quant')
+            if config.save.get('save_lightllm', False):
+                blockwise_opt.deploy('lightllm_quant')
+            if config.save.get('save_sgl', False):
+                blockwise_opt.deploy('sgl_quant')
+
+            blockwise_opt.save_model(save_quant_path)
+            update_vllm_quant_config(blockwise_opt.model, config, save_quant_path)


     if 'save' in config and config.save.get('save_autoawq', False):
-        assert config.quant.weight.bit in [4] and 'act' not in config.quant, \
-            'AutoAWQ supports only 4-bit weight-only quantization.'
-        assert not config.quant.weight.symmetric, 'Only asymmetric quant is supported.'
+        assert (
+            config.quant.weight.bit in [4] and 'act' not in config.quant
+        ), 'AutoAWQ supports only 4-bit weight-only quantization.'
+        assert (
+            not config.quant.weight.symmetric
+        ), 'Only asymmetric quant is supported.'
         blockwise_opt.deploy('autoawq_quant')
         blockwise_opt.save_model(save_quant_path)
         update_autoawq_quant_config(config, save_quant_path)

     if 'save' in config and config.save.get('save_mlcllm', False):
-        assert config.quant.weight.bit in [4] and 'act' not in config.quant, \
-            'MlcLLM supports only 4-bit weight-only quantization.'
-        assert not config.quant.weight.symmetric, 'Only asymmetric quant is supported.'
+        assert (
+            config.quant.weight.bit in [4] and 'act' not in config.quant
+        ), 'MlcLLM supports only 4-bit weight-only quantization.'
+        assert (
+            not config.quant.weight.symmetric
+        ), 'Only asymmetric quant is supported.'
         blockwise_opt.deploy('mlcllm_quant')
         blockwise_opt.save_model(save_quant_path)

@@ -192,7 +201,9 @@ def main(config):
     if int(os.environ['RANK']) == 0:
         if 'save' in config:
             if config.save.get('save_trans', False):
-                save_trans_path = os.path.join(config.save.save_path, 'transformed_model')
+                save_trans_path = os.path.join(
+                    config.save.save_path, 'transformed_model'
+                )
                 mkdirs(save_trans_path)
             if config.save.get('save_trtllm', False):
                 save_trtllm_trans_path = os.path.join(
@@ -204,16 +215,27 @@
                 )
                 mkdirs(save_trtllm_engine_path)
             if config.save.get('save_vllm', False):
-                save_quant_path = os.path.join(config.save.save_path, 'vllm_quant_model')
+                save_quant_path = os.path.join(
+                    config.save.save_path, 'vllm_quant_model'
+                )
+                mkdirs(save_quant_path)
+            if config.save.get('save_lightllm', False):
+                save_quant_path = os.path.join(
+                    config.save.save_path, 'lightllm_quant_model'
+                )
                 mkdirs(save_quant_path)
             if config.save.get('save_sgl', False):
                 save_quant_path = os.path.join(config.save.save_path, 'sgl_quant_model')
                 mkdirs(save_quant_path)
             if config.save.get('save_autoawq', False):
-                save_quant_path = os.path.join(config.save.save_path, 'autoawq_quant_model')
+                save_quant_path = os.path.join(
+                    config.save.save_path, 'autoawq_quant_model'
+                )
                 mkdirs(save_quant_path)
             if config.save.get('save_mlcllm', False):
-                save_quant_path = os.path.join(config.save.save_path, 'mlcllm_quant_model')
+                save_quant_path = os.path.join(
+                    config.save.save_path, 'mlcllm_quant_model'
+                )
                 mkdirs(save_quant_path)
             if config.save.get('save_fake', False):
                 save_fake_path = os.path.join(config.save.save_path, 'fake_quant_model')
diff --git a/llmc/compression/quantization/module_utils.py b/llmc/compression/quantization/module_utils.py
index 5a51dc5c..06cc2006 100644
--- a/llmc/compression/quantization/module_utils.py
+++ b/llmc/compression/quantization/module_utils.py
@@ -972,6 +972,12 @@ def new(cls, module, w_q, quant_config):
             input_scale = module.buf_act_scales_0
         else:
             input_scale = None
+        if (
+            quant_config.act.get('static', False)
+            and quant_config.get('quant_type', 'int-quant') == 'int-quant'
+        ):
+            input_scale = input_scale.unsqueeze(0)
+
         if module.bias is not None:
             bias = module.bias.data
         else:
@@ -1043,9 +1049,28 @@ def __repr__(self):
         )


+class LightllmRealQuantLinear(VllmRealQuantLinear):
+    def __init__(self, weight, bias, scales, input_scale, need_pack):
+        super().__init__(weight, bias, scales, input_scale, need_pack)
+
+    def __repr__(self):
+        return (
+            'LightllmRealQuantLinear('
+            + f'in_features={self.in_features}, '
+            + f'out_features={self.out_features}, '
+            + f'bias={self.bias is not None}, '
+            + f'weight_shape={self.weight_shape}, '
+            + f'weight_dtype={self.weight_dtype}, '
+            + f'scales_shape={self.scales_shape}, '
+            + f'scales_dtype={self.scales_dtype}, '
+            + f'zeros_shape={self.zeros_shape}, '
+            + f'zeros_dtype={self.zeros_dtype})'
+        )
+
+
 class SglRealQuantLinear(VllmRealQuantLinear):
     def __init__(self, weight, bias, scales, input_scale, need_pack):
-        super().__init__(weight, bias, scales, need_pack)
+        super().__init__(weight, bias, scales, input_scale, need_pack)

     def __repr__(self):
         return (
@@ -1302,12 +1327,14 @@ def __repr__(self):
     SglRealQuantLinear,
     AutoawqRealQuantLinear,
     MlcllmRealQuantLinear,
+    LightllmRealQuantLinear,
 ]

 _LLMC_ATTN_MAP_ = {'Vit': LlmcViTSelfAttention, 'DeepseekV2': LlmcDeepseekAttention}

 _REALQUANT_LINEAR_MAP_ = {
     'vllm_quant': VllmRealQuantLinear,
+    'lightllm_quant': LightllmRealQuantLinear,
     'sgl_quant': SglRealQuantLinear,
     'autoawq_quant': AutoawqRealQuantLinear,
     'mlcllm_quant': MlcllmRealQuantLinear,
diff --git a/llmc/utils/export_vllm.py b/llmc/utils/export_vllm.py
index 7c5e1212..e3273101 100644
--- a/llmc/utils/export_vllm.py
+++ b/llmc/utils/export_vllm.py
@@ -48,17 +48,24 @@ def update_vllm_quant_config(
     else:
         group_size = None

+    if 'static' in config.quant.act:
+        dynamic = not config.quant.act.static
+    else:
+        dynamic = True
+
     quant_config = {
         'config_groups': {
             'group_0': {
                 'targets': ['Linear'],  # Now only support "Linear".
                 'input_activations': {
-                    'dynamic': True,
+                    'dynamic': dynamic,
                     'group_size': None,  # Don't support activations per-group quant.
                     'num_bits': a_num_bits,
                     'observer': 'minmax',
                     'observer_kwargs': {},
-                    'strategy': 'token',  # Now only support dynamic per-token
+                    'strategy': 'token'
+                    if config.quant.act.granularity == 'per_token'
+                    else 'tensor',
                     'symmetric': config.quant.act.symmetric,
                     'type': quant_type
                 } if 'act' in config.quant else None,
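For reference, here is a small, self-contained sketch of the activation-metadata selection that the `export_vllm.py` hunk introduces: `dynamic` is derived from the optional `static` flag, and `strategy` from `granularity`. This is an illustration only, not part of the patch; the helper name `act_quant_fields` is hypothetical, and a plain dict stands in for llmc's `config.quant.act` object.

```python
# Sketch only (not part of the patch). Mirrors the new 'dynamic'/'strategy'
# selection in update_vllm_quant_config, with a plain dict standing in for
# the project's config.quant.act object.
def act_quant_fields(act_cfg: dict) -> dict:
    # Activations are dynamic unless the config explicitly sets static=True.
    if 'static' in act_cfg:
        dynamic = not act_cfg['static']
    else:
        dynamic = True

    # Per-token activation quant maps to 'token'; otherwise fall back to 'tensor'.
    strategy = 'token' if act_cfg.get('granularity') == 'per_token' else 'tensor'
    return {'dynamic': dynamic, 'strategy': strategy}


if __name__ == '__main__':
    # Static per-tensor FP8 activations -> {'dynamic': False, 'strategy': 'tensor'}
    print(act_quant_fields({'bit': 'e4m3', 'symmetric': True,
                            'granularity': 'per_tensor', 'static': True}))
    # Dynamic per-token INT8 activations -> {'dynamic': True, 'strategy': 'token'}
    print(act_quant_fields({'bit': 8, 'symmetric': True,
                            'granularity': 'per_token'}))
```

In the patch itself, these two values land in `quant_config['config_groups']['group_0']['input_activations']` and are written out by `update_vllm_quant_config` for the vLLM, SGLang, and LightLLM export paths.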