update vlm modality and simplify vlm model code
helloyongyang committed Dec 15, 2024
1 parent 5c0af04 commit 6ddbf82
Showing 42 changed files with 260 additions and 230 deletions.
8 changes: 2 additions & 6 deletions llmc/__main__.py
@@ -9,7 +9,6 @@
import torch.distributed as dist
import yaml
from easydict import EasyDict
-from lmms_eval.utils import make_table
from loguru import logger
from torch.distributed import destroy_process_group, init_process_group

@@ -34,22 +33,21 @@ def main(config):
eval_model(model, None, eval_list, eval_pos='pretrain')

for modality in config.quant.get('quant_objects', ['language']):
-        model.get_key_info(modality)
+        model.set_modality(modality)
if not config.get('calib', False):
blockwise_opt = ALGO_REGISTRY[config.quant.method](
model,
quant_config=config.quant,
input=None,
padding_mask=None,
config=config,
-                modality=modality,
)
blockwise_opt.run_block_loop()
dist.barrier()
else:
dataset = BaseDataset(model.get_tokenizer(), config.calib, model.batch_process)
calib_data, padding_mask = dataset.get_calib_dataset()
-            model.collect_first_block_input(calib_data, padding_mask, modality)
+            model.collect_first_block_input(calib_data, padding_mask)
del calib_data
gc.collect()
torch.cuda.empty_cache()
@@ -60,7 +58,6 @@ def main(config):
model.get_first_block_input(),
model.get_padding_mask(),
config,
-                    modality
)
else:
blockwise_opt = ALGO_REGISTRY[config.sparse.method](
@@ -69,7 +66,6 @@
model.get_first_block_input(),
model.get_padding_mask(),
config,
-                    modality
)
blockwise_opt.run_block_loop()
dist.barrier()
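In short, modality selection now happens once on the model itself, and the blockwise optimizers no longer take a modality argument. Below is a condensed sketch of the new per-modality calibration path; it is an illustration only, assuming model and config are already constructed as elsewhere in __main__.py, and it elides the non-calibration branch and sparsification:

    # Sketch: one pass per modality listed under config.quant.quant_objects.
    for modality in config.quant.get('quant_objects', ['language']):
        model.set_modality(modality)  # model refreshes its blocks/layernorms for this modality

        dataset = BaseDataset(model.get_tokenizer(), config.calib, model.batch_process)
        calib_data, padding_mask = dataset.get_calib_dataset()
        model.collect_first_block_input(calib_data, padding_mask)  # no modality argument anymore
        del calib_data
        gc.collect()
        torch.cuda.empty_cache()

        blockwise_opt = ALGO_REGISTRY[config.quant.method](
            model,
            config.quant,
            model.get_first_block_input(),
            model.get_padding_mask(),
            config,  # the trailing modality argument is gone
        )
        blockwise_opt.run_block_loop()
        dist.barrier()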
4 changes: 1 addition & 3 deletions llmc/compression/blockwise_optimization.py
@@ -6,10 +6,8 @@


class BlockwiseOpt(metaclass=ABCMeta):
-    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
+    def __init__(self, model, quant_config, input, padding_mask, config):
self.model = model
-        self.modality = modality
-        self.model.find_blocks(modality)
self.blocks = model.get_blocks()
self.quant_config = quant_config
self.sparsity_config = quant_config
4 changes: 2 additions & 2 deletions llmc/compression/quantization/awq.py
@@ -17,8 +17,8 @@

@ALGO_REGISTRY
class Awq(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
-        super().__init__(model, quant_config, input, padding_mask, config, modality)
+    def __init__(self, model, quant_config, input, padding_mask, config):
+        super().__init__(model, quant_config, input, padding_mask, config)
special_config = self.quant_config.get('special', {})
self.trans = special_config.get('trans', True)
self.trans_version = special_config.get('trans_version', 'v2')
12 changes: 3 additions & 9 deletions llmc/compression/quantization/base_blockwise_quantization.py
@@ -28,10 +28,8 @@


class BaseBlockwiseQuantization(BlockwiseOpt):
-    def __init__(
-        self, model, quant_config, input, padding_mask, config, modality='language'
-    ):
-        super().__init__(model, quant_config, input, padding_mask, config, modality)
+    def __init__(self, model, quant_config, input, padding_mask, config):
+        super().__init__(model, quant_config, input, padding_mask, config)
self.set_quant_config()

def w_qdq(self, module, wquantizer):
@@ -465,11 +463,7 @@ def run(self, block, input_feat, handles):

def block_transform(self, block, input_feat, block_kwargs):
logger.info(f'Start transform the {self.block_idx}-th block')
-        subsets = (
-            self.model.get_subsets_in_block(block)
-            if self.modality == 'language'
-            else self.model.get_vision_subsets_in_block(block)
-        )
+        subsets = self.model.get_subsets_in_block(block)

if self.act_static:
self.register_non_linear_qparams(block, input_feat)
4 changes: 2 additions & 2 deletions llmc/compression/quantization/dgq.py
@@ -13,8 +13,8 @@

@ALGO_REGISTRY
class DGQ(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
-        super().__init__(model, quant_config, input, padding_mask, config, modality)
+    def __init__(self, model, quant_config, input, padding_mask, config):
+        super().__init__(model, quant_config, input, padding_mask, config)
self.model_dtype = next(self.model.model.parameters()).dtype

def w_qdq(self, module, wquantizer):
2 changes: 1 addition & 1 deletion llmc/compression/quantization/gptq.py
@@ -23,7 +23,7 @@ class GPTQ(BaseBlockwiseQuantization):
def __init__(
self, model, quant_config, input, padding_mask, config, modality='language'
):
-        super().__init__(model, quant_config, input, padding_mask, config, modality)
+        super().__init__(model, quant_config, input, padding_mask, config)
self.dev = torch.device('cuda')
self.model_dtype = next(self.model.model.parameters()).dtype
self.add_quant_config()
4 changes: 2 additions & 2 deletions llmc/compression/quantization/hqq.py
@@ -11,8 +11,8 @@

@ALGO_REGISTRY
class HQQ(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
-        super().__init__(model, quant_config, input, padding_mask, config, modality)
+    def __init__(self, model, quant_config, input, padding_mask, config):
+        super().__init__(model, quant_config, input, padding_mask, config)
self.add_quant_config()

@torch.no_grad()
4 changes: 2 additions & 2 deletions llmc/compression/quantization/llmint8.py
@@ -9,8 +9,8 @@

@ALGO_REGISTRY
class LlmInt8(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
-        super().__init__(model, quant_config, input, padding_mask, config, modality)
+    def __init__(self, model, quant_config, input, padding_mask, config):
+        super().__init__(model, quant_config, input, padding_mask, config)
self.add_quant_config()

@torch.no_grad()
4 changes: 2 additions & 2 deletions llmc/compression/quantization/ntweak.py
@@ -19,8 +19,8 @@

@ALGO_REGISTRY
class NormTweaking(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
-        super().__init__(model, quant_config, input, padding_mask, config, modality)
+    def __init__(self, model, quant_config, input, padding_mask, config):
+        super().__init__(model, quant_config, input, padding_mask, config)
self.add_quant_config()

model_type = self.config['model']['type']
4 changes: 2 additions & 2 deletions llmc/compression/quantization/omniq.py
@@ -25,8 +25,8 @@

@ALGO_REGISTRY
class OmniQuant(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
-        super().__init__(model, quant_config, input, padding_mask, config, modality)
+    def __init__(self, model, quant_config, input, padding_mask, config):
+        super().__init__(model, quant_config, input, padding_mask, config)
self.add_quant_config()

model_type = self.config['model']['type']
4 changes: 2 additions & 2 deletions llmc/compression/quantization/osplus.py
@@ -17,9 +17,9 @@

@ALGO_REGISTRY
class OsPlus(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
+    def __init__(self, model, quant_config, input, padding_mask, config):
torch.set_grad_enabled(False)
-        super().__init__(model, quant_config, input, padding_mask, config, modality)
+        super().__init__(model, quant_config, input, padding_mask, config)

@torch.no_grad()
def filter_subset(self, prev_op):
4 changes: 2 additions & 2 deletions llmc/compression/quantization/quarot.py
@@ -16,8 +16,8 @@

@ALGO_REGISTRY
class Quarot(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
-        super().__init__(model, quant_config, input, padding_mask, config, modality)
+    def __init__(self, model, quant_config, input, padding_mask, config):
+        super().__init__(model, quant_config, input, padding_mask, config)
self.dev = torch.device('cuda')
self.add_quant_config()
self.preprocess()
4 changes: 2 additions & 2 deletions llmc/compression/quantization/quik.py
@@ -12,8 +12,8 @@

@ALGO_REGISTRY
class QUIK(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
-        super().__init__(model, quant_config, input, padding_mask, config, modality)
+    def __init__(self, model, quant_config, input, padding_mask, config):
+        super().__init__(model, quant_config, input, padding_mask, config)
self.add_quant_config()

def add_quant_config(self):
4 changes: 2 additions & 2 deletions llmc/compression/quantization/rtn.py
@@ -8,8 +8,8 @@

@ALGO_REGISTRY
class RTN(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
-        super().__init__(model, quant_config, input, padding_mask, config, modality)
+    def __init__(self, model, quant_config, input, padding_mask, config):
+        super().__init__(model, quant_config, input, padding_mask, config)

@torch.no_grad()
def block_opt(self, block, *opt_kwargs):
4 changes: 2 additions & 2 deletions llmc/compression/quantization/smoothquant.py
@@ -12,8 +12,8 @@

@ALGO_REGISTRY
class SmoothQuant(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
-        super().__init__(model, quant_config, input, padding_mask, config, modality)
+    def __init__(self, model, quant_config, input, padding_mask, config):
+        super().__init__(model, quant_config, input, padding_mask, config)
special_config = self.quant_config.get('special', {})
self.alpha = special_config.get('alpha', 0.5)

4 changes: 2 additions & 2 deletions llmc/compression/quantization/spqr.py
@@ -18,8 +18,8 @@

@ALGO_REGISTRY
class SpQR(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
-        super().__init__(model, quant_config, input, padding_mask, config, modality)
+    def __init__(self, model, quant_config, input, padding_mask, config):
+        super().__init__(model, quant_config, input, padding_mask, config)
assert (
self.wquantizer.granularity == 'per_group'
), 'SpQR only supports per_group quantization'
4 changes: 2 additions & 2 deletions llmc/compression/quantization/tesseraq.py
@@ -23,8 +23,8 @@

@ALGO_REGISTRY
class TesseraQ(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
-        super().__init__(model, quant_config, input, padding_mask, config, modality)
+    def __init__(self, model, quant_config, input, padding_mask, config):
+        super().__init__(model, quant_config, input, padding_mask, config)
self.add_quant_config()

self.attention_mask = self.input['kwargs'][0].get('attention_mask')
5 changes: 3 additions & 2 deletions llmc/eval/eval_vqa.py
@@ -7,7 +7,8 @@
from lmms_eval.evaluator_utils import run_task_tests
from lmms_eval.loggers.evaluation_tracker import EvaluationTracker
from lmms_eval.tasks import TaskManager, get_task_dict
-from lmms_eval.utils import get_datetime_str, simple_parse_args_string
+from lmms_eval.utils import (get_datetime_str, make_table,
+                             simple_parse_args_string)
from loguru import logger

from llmc.utils.registry_factory import MODEL_REGISTRY
@@ -231,6 +232,6 @@ def _adjust_config(task_dict):
results['date'] = datetime_str
# add_env_info(results) # additional environment info to results
# add_tokenizer_info(results, lm) # additional info about tokenizer
-        return results
+        return make_table(results)
else:
return None
29 changes: 20 additions & 9 deletions llmc/models/base_model.py
@@ -34,24 +34,35 @@ def __init__(self, config, device_map=None, use_cache=False):
self.vision_projector = None
self.audio_model = None
self.audio_projector = None
+        self.modality = None
+        self.kvcache_buffer = []
self.build_tokenizer()
self.build_model()
self.model.eval()
-        self.kvcache_buffer = []
-        self.get_key_info(modality='language')
+        if self.mm_model:
+            self.mm_model.eval()

+    def set_modality(self, modality='language'):
+        assert modality in ['audio', 'vision', 'language']
+        self.modality = modality
+        self.update_key_info()
+
+    def get_modality(self):
+        assert self.modality in ['audio', 'vision', 'language']
+        return self.modality

-    def get_key_info(self, modality='language'):
-        self.find_blocks(modality=modality)
+    def update_key_info(self):
+        self.find_blocks()
self.find_embed_layers()
self.find_block_name()
-        self.add_layernorms_class(modality=modality)
+        self.add_layernorms_class()

def reset_kv(self):
for kvcache in self.kvcache_buffer:
kvcache._reset_states()

@abstractmethod
-    def find_blocks(self, modality='language'):
+    def find_blocks(self):
pass

def find_block_name(self):
@@ -193,10 +204,10 @@ def build_model(self):
)
logger.info(f'self.model : {self.model}')

-    def add_layernorms_class(self, modality='language'):
+    def add_layernorms_class(self):
ln_class_list = []
single_block = self.blocks[0]
-        ln_dict = self.get_layernorms_in_block(single_block, modality=modality)
+        ln_dict = self.get_layernorms_in_block(single_block)
for ln_name in ln_dict:
ln_class = ln_dict[ln_name].__class__
if ln_class not in ln_class_list:
@@ -207,7 +218,7 @@ def add_layernorms_class(self, modality='language'):
logger.info(f'_TRANSFORMERS_LN_TYPES_ : {_TRANSFORMERS_LN_TYPES_}')

@torch.no_grad()
-    def collect_first_block_input(self, calib_data, padding_mask=None, modality='language'):
+    def collect_first_block_input(self, calib_data, padding_mask=None):
first_block_input = defaultdict(list)

Catcher = self.get_catcher(first_block_input)
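The net effect on model subclasses: hooks such as find_blocks and get_layernorms_in_block drop their modality parameters and instead rely on whatever modality the model currently holds. A minimal, self-contained toy (not llmc code; the class and block names are made up for illustration) showing that pattern:

    # Toy stand-in for the new BaseModel contract: set_modality() stores the active
    # modality and refreshes derived state; subclass hooks take no modality argument.
    class ToyBaseModel:
        def __init__(self):
            self.modality = None
            self.blocks = None

        def set_modality(self, modality='language'):
            assert modality in ['audio', 'vision', 'language']
            self.modality = modality
            self.update_key_info()

        def get_modality(self):
            assert self.modality in ['audio', 'vision', 'language']
            return self.modality

        def update_key_info(self):
            self.find_blocks()  # hooks consult self.modality instead of a parameter

        def find_blocks(self):
            raise NotImplementedError

    class ToyVLM(ToyBaseModel):
        def __init__(self):
            super().__init__()
            self.language_blocks = ['lang_block_0', 'lang_block_1']
            self.vision_blocks = ['vit_block_0']

        def find_blocks(self):
            self.blocks = (self.vision_blocks if self.get_modality() == 'vision'
                           else self.language_blocks)

    model = ToyVLM()
    model.set_modality('vision')
    print(model.blocks)  # ['vit_block_0']
    model.set_modality('language')
    print(model.blocks)  # ['lang_block_0', 'lang_block_1']

Compression algorithms built on BlockwiseOpt see the same contract: they receive a model whose blocks already match the caller's set_modality() choice, which is why the modality keyword disappears from every constructor in this commit.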
4 changes: 2 additions & 2 deletions llmc/models/bloom.py
@@ -8,7 +8,7 @@ class Bloom(BaseModel):
def __init__(self, config, device_map=None, use_cache=False):
super().__init__(config, device_map, use_cache)

-    def find_blocks(self, modality='language'):
+    def find_blocks(self):
self.blocks = self.model.transformer.h

def find_embed_layers(self):
@@ -37,7 +37,7 @@ def skip_layer_name(self):
def has_bias(self):
return True

-    def get_layernorms_in_block(self, block, modality='language'):
+    def get_layernorms_in_block(self, block):
return {
'input_layernorm': block.input_layernorm,
'post_attention_layernorm': block.post_attention_layernorm,
4 changes: 2 additions & 2 deletions llmc/models/chatglm.py
@@ -12,7 +12,7 @@ class ChatGLM(BaseModel):
def __init__(self, config, device_map=None, use_cache=False):
super().__init__(config, device_map, use_cache)

-    def find_blocks(self, modality='language'):
+    def find_blocks(self):
self.blocks = self.model.transformer.encoder.layers

def find_embed_layers(self):
@@ -43,7 +43,7 @@ def skip_layer_name(self):
def has_bias(self):
return False

-    def get_layernorms_in_block(self, block, modality='language'):
+    def get_layernorms_in_block(self, block):
return {
'input_layernorm': block.input_layernorm,
'post_attention_layernorm': block.post_attention_layernorm,
4 changes: 2 additions & 2 deletions llmc/models/deepseekv2.py
@@ -8,7 +8,7 @@ class DeepseekV2(BaseModel):
def __init__(self, config, device_map=None, use_cache=False):
super().__init__(config, device_map, use_cache)

-    def find_blocks(self, modality='language'):
+    def find_blocks(self):
self.blocks = self.model.model.layers

def find_embed_layers(self):
@@ -34,7 +34,7 @@ def skip_layer_name(self):
def has_bias(self):
return False

-    def get_layernorms_in_block(self, block, modality='language'):
+    def get_layernorms_in_block(self, block):
return {
'input_layernorm': block.input_layernorm,
'post_attention_layernorm': block.post_attention_layernorm,
4 changes: 2 additions & 2 deletions llmc/models/falcon.py
@@ -8,7 +8,7 @@ class Falcon(BaseModel):
def __init__(self, config, device_map=None, use_cache=False):
super().__init__(config, device_map, use_cache)

-    def find_blocks(self, modality='language'):
+    def find_blocks(self):
self.blocks = self.model.transformer.h

def find_embed_layers(self):
@@ -30,7 +30,7 @@ def get_layers_except_blocks(self):
def has_bias(self):
return False

-    def get_layernorms_in_block(self, block, modality='language'):
+    def get_layernorms_in_block(self, block):
if block.config.architectures[0] == 'RWForCausalLM':
new_decoder_architecture = False
elif block.config.architectures[0] == 'FalconForCausalLM':
