Add lightllm support #265

Merged · 2 commits · Dec 16, 2024
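In short, this PR registers lightllm as a third real-quant export backend next to vLLM and SGLang: `llmc/__main__.py` gains a `save_lightllm` switch plus a `lightllm_quant_model` output directory and folds the three backends' weight/activation checks into one block, `llmc/compression/quantization/module_utils.py` adds a `LightllmRealQuantLinear` wrapper and registers it under the `lightllm_quant` key (also fixing `SglRealQuantLinear`, which previously dropped `input_scale`), and `llmc/utils/export_vllm.py` can now describe static, per-tensor activation quantization. The snippet below is a minimal, hypothetical config fragment that would pass the consolidated checks for a symmetric w8a8 export; llmc's real configs are YAML, and only the keys actually read in this diff are taken from the source, so treat the overall shape as an assumption.

```python
# Hypothetical config fragment for a symmetric w8a8 lightllm export.
# Only the keys read by the new checks in __main__.py come from the diff;
# the dict layout itself is illustrative, not llmc's real (YAML) schema.
config = {
    'quant': {
        'weight': {'bit': 8, 'symmetric': True},
        'act': {'bit': 8, 'symmetric': True, 'granularity': 'per_token', 'static': False},
    },
    'save': {
        'save_lightllm': True,        # new flag added by this PR
        'save_path': './outputs',     # model is written to ./outputs/lightllm_quant_model
    },
}


def passes_export_checks(quant):
    """Mirror of the consolidated weight/activation checks guarding
    the vLLM / SGLang / lightllm export paths."""
    w, a = quant['weight'], quant.get('act')
    if isinstance(w['bit'], str):
        # FP8 path: both weights and activations must be symmetric e4m3/e5m2.
        return (a is not None and w['symmetric'] and a['symmetric']
                and w['bit'] == a['bit'] and w['bit'] in ('e4m3', 'e5m2'))
    # Integer path: w4a16, w8a16 or w8a8, always symmetric weights.
    if not (w['symmetric'] and w['bit'] in (4, 8)):
        return False
    return a is None or (a['symmetric'] and a['bit'] == 8)


assert passes_export_checks(config['quant'])
```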
110 changes: 66 additions & 44 deletions llmc/__main__.py
@@ -45,7 +45,9 @@ def main(config):
blockwise_opt.run_block_loop()
dist.barrier()
else:
dataset = BaseDataset(model.get_tokenizer(), config.calib, model.batch_process)
dataset = BaseDataset(
model.get_tokenizer(), config.calib, model.batch_process
)
calib_data, padding_mask = dataset.get_calib_dataset()
model.collect_first_block_input(calib_data, padding_mask)
del calib_data
@@ -91,53 +93,60 @@ def main(config):
blockwise_opt.deploy('fake_quant')
blockwise_opt.save_model(save_fake_path)

if 'save' in config and config.save.get('save_vllm', False):
w, a = config.quant.weight, config.quant.get('act')
if isinstance(w.bit, str):
assert a, 'Only WA float quant is supported.'
assert w.symmetric and a.symmetric, 'Only symmetric quant is supported.'
assert w.bit == a.bit and w.bit in ['e4m3', 'e5m2'] and \
a.bit in ['e4m3', 'e5m2'], 'Only WA FP8 quant is supported'
else:
assert w.symmetric, 'Only symmetric quant is supported.'
assert w.bit in [4, 8], 'Supported quant: w4a16, w8a16, w8a8.'
if a:
assert a.symmetric, 'Only symmetric quant is supported.'
assert a.bit == 8, 'Supported quant: w4a16, w8a16, w8a8.'
blockwise_opt.deploy('vllm_quant')
blockwise_opt.save_model(save_quant_path)
update_vllm_quant_config(blockwise_opt.model, config, save_quant_path)

if 'save' in config and config.save.get('save_sgl', False):
w, a = config.quant.weight, config.quant.get('act')
if isinstance(w.bit, str):
assert a, 'Only WA float quant is supported.'
assert w.symmetric and a.symmetric, 'Only symmetric quant is supported.'
assert w.bit == a.bit and w.bit in ['e4m3', 'e5m2'] and \
a.bit in ['e4m3', 'e5m2'], 'Only WA FP8 quant is supported'
else:
assert w.symmetric, 'Only symmetric quant is supported.'
assert w.bit in [4, 8], 'Supported quant: w4a16, w8a16, w8a8.'
if a:
assert a.symmetric, 'Only symmetric quant is supported.'
assert a.bit == 8, 'Supported quant: w4a16, w8a16, w8a8.'
blockwise_opt.deploy('sgl_quant')
blockwise_opt.save_model(save_quant_path)
update_vllm_quant_config(blockwise_opt.model, config, save_quant_path)
if 'save' in config:
if (
config.save.get('save_vllm', False)
or config.save.get('save_sgl', False)
or config.save.get('save_lightllm', False)
):
w, a = config.quant.weight, config.quant.get('act')

if isinstance(w.bit, str):
assert a, 'Only WA float quant is supported.'
assert (
w.symmetric and a.symmetric
), 'Only symmetric quant is supported.'
assert (
w.bit == a.bit
and w.bit in ['e4m3', 'e5m2']
and a.bit in ['e4m3', 'e5m2']
), 'Only WA FP8 quant is supported'
else:
assert w.symmetric, 'Only symmetric quant is supported.'
assert w.bit in [4, 8], 'Supported quant: w4a16, w8a16, w8a8.'
if a:
assert a.symmetric, 'Only symmetric quant is supported.'
assert a.bit == 8, 'Supported quant: w4a16, w8a16, w8a8.'

if config.save.get('save_vllm', False):
blockwise_opt.deploy('vllm_quant')
if config.save.get('save_lightllm', False):
blockwise_opt.deploy('lightllm_quant')
if config.save.get('save_sgl', False):
blockwise_opt.deploy('sgl_quant')

blockwise_opt.save_model(save_quant_path)
update_vllm_quant_config(blockwise_opt.model, config, save_quant_path)

if 'save' in config and config.save.get('save_autoawq', False):
assert config.quant.weight.bit in [4] and 'act' not in config.quant, \
'AutoAWQ supports only 4-bit weight-only quantization.'
assert not config.quant.weight.symmetric, 'Only asymmetric quant is supported.'
assert (
config.quant.weight.bit in [4] and 'act' not in config.quant
), 'AutoAWQ supports only 4-bit weight-only quantization.'
assert (
not config.quant.weight.symmetric
), 'Only asymmetric quant is supported.'

blockwise_opt.deploy('autoawq_quant')
blockwise_opt.save_model(save_quant_path)
update_autoawq_quant_config(config, save_quant_path)

if 'save' in config and config.save.get('save_mlcllm', False):
assert config.quant.weight.bit in [4] and 'act' not in config.quant, \
'MlcLLM supports only 4-bit weight-only quantization.'
assert not config.quant.weight.symmetric, 'Only asymmetric quant is supported.'
assert (
config.quant.weight.bit in [4] and 'act' not in config.quant
), 'MlcLLM supports only 4-bit weight-only quantization.'
assert (
not config.quant.weight.symmetric
), 'Only asymmetric quant is supported.'

blockwise_opt.deploy('mlcllm_quant')
blockwise_opt.save_model(save_quant_path)
@@ -192,7 +201,9 @@ def main(config):
if int(os.environ['RANK']) == 0:
if 'save' in config:
if config.save.get('save_trans', False):
save_trans_path = os.path.join(config.save.save_path, 'transformed_model')
save_trans_path = os.path.join(
config.save.save_path, 'transformed_model'
)
mkdirs(save_trans_path)
if config.save.get('save_trtllm', False):
save_trtllm_trans_path = os.path.join(
# ... (lines collapsed in the diff view) ...
)
mkdirs(save_trtllm_engine_path)
if config.save.get('save_vllm', False):
save_quant_path = os.path.join(config.save.save_path, 'vllm_quant_model')
save_quant_path = os.path.join(
config.save.save_path, 'vllm_quant_model'
)
mkdirs(save_quant_path)
if config.save.get('save_lightllm', False):
save_quant_path = os.path.join(
config.save.save_path, 'lightllm_quant_model'
)
mkdirs(save_quant_path)
if config.save.get('save_sgl', False):
save_quant_path = os.path.join(config.save.save_path, 'sgl_quant_model')
mkdirs(save_quant_path)
if config.save.get('save_autoawq', False):
save_quant_path = os.path.join(config.save.save_path, 'autoawq_quant_model')
save_quant_path = os.path.join(
config.save.save_path, 'autoawq_quant_model'
)
mkdirs(save_quant_path)
if config.save.get('save_mlcllm', False):
save_quant_path = os.path.join(config.save.save_path, 'mlcllm_quant_model')
save_quant_path = os.path.join(
config.save.save_path, 'mlcllm_quant_model'
)
mkdirs(save_quant_path)
if config.save.get('save_fake', False):
save_fake_path = os.path.join(config.save.save_path, 'fake_quant_model')
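For orientation, the last hunk above only decides where each enabled backend writes its artifacts under `config.save.save_path`. The sub-directory names below are copied from the diff; the helper wrapping them is an illustrative sketch, not code from the PR.

```python
import os

# Sub-directories __main__.py creates under config.save.save_path, one per flag
# (names taken from the diff above; this helper itself is illustrative only).
_SAVE_DIRS = {
    'save_trans': 'transformed_model',
    'save_vllm': 'vllm_quant_model',
    'save_lightllm': 'lightllm_quant_model',   # new in this PR
    'save_sgl': 'sgl_quant_model',
    'save_autoawq': 'autoawq_quant_model',
    'save_mlcllm': 'mlcllm_quant_model',
    'save_fake': 'fake_quant_model',
}


def resolve_save_dirs(save_cfg):
    """Map every enabled save_* flag to the directory its exported model lands in."""
    return {
        flag: os.path.join(save_cfg['save_path'], subdir)
        for flag, subdir in _SAVE_DIRS.items()
        if save_cfg.get(flag, False)
    }


# {'save_lightllm': './outputs/lightllm_quant_model'}
print(resolve_save_dirs({'save_path': './outputs', 'save_lightllm': True}))
```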
29 changes: 28 additions & 1 deletion llmc/compression/quantization/module_utils.py
@@ -972,6 +972,12 @@ def new(cls, module, w_q, quant_config):
input_scale = module.buf_act_scales_0
else:
input_scale = None
if (
quant_config.act.get('static', False)
and quant_config.get('quant_type', 'int-quant') == 'int-quant'
):
input_scale = input_scale.unsqueeze(0)

if module.bias is not None:
bias = module.bias.data
else:
@@ -1043,9 +1049,28 @@ def __repr__(self):
)


class LightllmRealQuantLinear(VllmRealQuantLinear):
def __init__(self, weight, bias, scales, input_scale, need_pack):
super().__init__(weight, bias, scales, input_scale, need_pack)

def __repr__(self):
return (
'LightllmRealQuantLinear('
+ f'in_features={self.in_features}, '
+ f'out_features={self.out_features}, '
+ f'bias={self.bias is not None}, '
+ f'weight_shape={self.weight_shape}, '
+ f'weight_dtype={self.weight_dtype}, '
+ f'scales_shape={self.scales_shape}, '
+ f'scales_dtype={self.scales_dtype}, '
+ f'zeros_shape={self.zeros_shape}, '
+ f'zeros_dtype={self.zeros_dtype})'
)


class SglRealQuantLinear(VllmRealQuantLinear):
def __init__(self, weight, bias, scales, input_scale, need_pack):
super().__init__(weight, bias, scales, need_pack)
super().__init__(weight, bias, scales, input_scale, need_pack)

def __repr__(self):
return (
@@ -1302,12 +1327,14 @@ def __repr__(self):
SglRealQuantLinear,
AutoawqRealQuantLinear,
MlcllmRealQuantLinear,
LightllmRealQuantLinear,
]

_LLMC_ATTN_MAP_ = {'Vit': LlmcViTSelfAttention, 'DeepseekV2': LlmcDeepseekAttention}

_REALQUANT_LINEAR_MAP_ = {
'vllm_quant': VllmRealQuantLinear,
'lightllm_quant': LightllmRealQuantLinear,
'sgl_quant': SglRealQuantLinear,
'autoawq_quant': AutoawqRealQuantLinear,
'mlcllm_quant': MlcllmRealQuantLinear,
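The new `LightllmRealQuantLinear` reuses `VllmRealQuantLinear` wholesale and only customizes `__repr__`; what actually makes `deploy('lightllm_quant')` reach it is the new `_REALQUANT_LINEAR_MAP_` entry. The stand-alone sketch below shows that lookup pattern with heavily trimmed stand-in classes; the constructor payloads and everything other than the class and registry names are simplified assumptions.

```python
# Trimmed stand-ins for the registry dispatch in module_utils.py. The real
# classes carry packed weights, scales, zeros, etc.; here the payload is cut
# down so the lookup pattern runs on its own.
class VllmRealQuantLinear:
    def __init__(self, weight, bias, scales, input_scale, need_pack):
        self.weight, self.bias, self.scales = weight, bias, scales
        self.input_scale, self.need_pack = input_scale, need_pack


class LightllmRealQuantLinear(VllmRealQuantLinear):
    """Same storage format as the vLLM export; upstream it only overrides __repr__."""


_REALQUANT_LINEAR_MAP_ = {
    'vllm_quant': VllmRealQuantLinear,
    'lightllm_quant': LightllmRealQuantLinear,   # new registry entry from this PR
}


def build_quant_linear(deploy_key, **tensors):
    # blockwise_opt.deploy('lightllm_quant') boils down to a lookup like this.
    return _REALQUANT_LINEAR_MAP_[deploy_key](**tensors)


layer = build_quant_linear(
    'lightllm_quant',
    weight=[[1]], bias=None, scales=[1.0], input_scale=None, need_pack=False,
)
assert isinstance(layer, LightllmRealQuantLinear)
```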
11 changes: 9 additions & 2 deletions llmc/utils/export_vllm.py
@@ -48,17 +48,24 @@ def update_vllm_quant_config(
else:
group_size = None

if 'static' in config.quant.act:
dynamic = not config.quant.act.static
else:
dynamic = True

quant_config = {
'config_groups': {
'group_0': {
'targets': ['Linear'], # Now only support "Linear".
'input_activations': {
'dynamic': True,
'dynamic': dynamic,
'group_size': None, # Don't support activations per-group quant.
'num_bits': a_num_bits,
'observer': 'minmax',
'observer_kwargs': {},
'strategy': 'token', # Now only support dynamic per-token
'strategy': 'token'
if config.quant.act.granularity == 'per_token'
else 'tensor',
'symmetric': config.quant.act.symmetric,
'type': quant_type
} if 'act' in config.quant else None,
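With this change, the exported quantization config can describe statically calibrated and per-tensor activation quantization instead of hard-coding dynamic per-token. Below is a small sketch of just that derivation, using only the config keys read in this hunk; the helper name, the `'per_tensor'` spelling in the second example, and the free-standing `a_num_bits`/`quant_type` parameters are assumptions for illustration.

```python
def input_activation_entry(act_cfg, a_num_bits, quant_type):
    """Build the 'input_activations' block the way the updated exporter does.

    act_cfg mirrors config.quant.act; pass None for weight-only schemes.
    """
    if act_cfg is None:
        return None
    # 'static' in the llmc config means calibrated scales, i.e. NOT dynamic.
    dynamic = not act_cfg.get('static', False)
    # Per-token stays 'token'; any other granularity is exported as 'tensor'.
    strategy = 'token' if act_cfg.get('granularity') == 'per_token' else 'tensor'
    return {
        'dynamic': dynamic,
        'group_size': None,          # per-group activation quant is not supported
        'num_bits': a_num_bits,
        'observer': 'minmax',
        'observer_kwargs': {},
        'strategy': strategy,
        'symmetric': act_cfg.get('symmetric', True),
        'type': quant_type,
    }


# Dynamic per-token int8 activations (the pre-PR behaviour):
print(input_activation_entry({'granularity': 'per_token', 'symmetric': True}, 8, 'int'))
# Statically calibrated per-tensor activations, now expressible after this PR:
print(input_activation_entry(
    {'static': True, 'granularity': 'per_tensor', 'symmetric': True}, 8, 'int'))
```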