update vlm modality #263

Merged
merged 1 commit on Dec 15, 2024
8 changes: 2 additions & 6 deletions llmc/__main__.py
@@ -9,7 +9,6 @@
import torch.distributed as dist
import yaml
from easydict import EasyDict
-from lmms_eval.utils import make_table
from loguru import logger
from torch.distributed import destroy_process_group, init_process_group

@@ -34,22 +33,21 @@ def main(config):
eval_model(model, None, eval_list, eval_pos='pretrain')

for modality in config.quant.get('quant_objects', ['language']):
-model.get_key_info(modality)
+model.set_modality(modality)
if not config.get('calib', False):
blockwise_opt = ALGO_REGISTRY[config.quant.method](
model,
quant_config=config.quant,
input=None,
padding_mask=None,
config=config,
-modality=modality,
)
blockwise_opt.run_block_loop()
dist.barrier()
else:
dataset = BaseDataset(model.get_tokenizer(), config.calib, model.batch_process)
calib_data, padding_mask = dataset.get_calib_dataset()
-model.collect_first_block_input(calib_data, padding_mask, modality)
+model.collect_first_block_input(calib_data, padding_mask)
del calib_data
gc.collect()
torch.cuda.empty_cache()
@@ -60,7 +58,6 @@ def main(config):
model.get_first_block_input(),
model.get_padding_mask(),
config,
-modality
)
else:
blockwise_opt = ALGO_REGISTRY[config.sparse.method](
@@ -69,7 +66,6 @@
model.get_first_block_input(),
model.get_padding_mask(),
config,
-modality
)
blockwise_opt.run_block_loop()
dist.barrier()
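
The change in llmc/__main__.py moves modality selection out of the compression constructors and onto the model itself. A condensed sketch of the resulting calibration loop, assuming the imports and config handling already present in that file (the non-calibration and sparsity branches are elided):

for modality in config.quant.get('quant_objects', ['language']):
    # The model now owns the active modality and refreshes its block/layer
    # bookkeeping internally (see set_modality in base_model.py below).
    model.set_modality(modality)

    dataset = BaseDataset(model.get_tokenizer(), config.calib, model.batch_process)
    calib_data, padding_mask = dataset.get_calib_dataset()
    model.collect_first_block_input(calib_data, padding_mask)  # no modality argument

    # Compression algorithms are built without a modality parameter; keyword
    # names follow the constructor signatures shown in this PR.
    blockwise_opt = ALGO_REGISTRY[config.quant.method](
        model,
        quant_config=config.quant,
        input=model.get_first_block_input(),
        padding_mask=model.get_padding_mask(),
        config=config,
    )
    blockwise_opt.run_block_loop()
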
4 changes: 1 addition & 3 deletions llmc/compression/blockwise_optimization.py
@@ -6,10 +6,8 @@


class BlockwiseOpt(metaclass=ABCMeta):
-def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
+def __init__(self, model, quant_config, input, padding_mask, config):
self.model = model
-self.modality = modality
-self.model.find_blocks(modality)
self.blocks = model.get_blocks()
self.quant_config = quant_config
self.sparsity_config = quant_config
4 changes: 2 additions & 2 deletions llmc/compression/quantization/awq.py
@@ -17,8 +17,8 @@

@ALGO_REGISTRY
class Awq(BaseBlockwiseQuantization):
-def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
-super().__init__(model, quant_config, input, padding_mask, config, modality)
+def __init__(self, model, quant_config, input, padding_mask, config):
+super().__init__(model, quant_config, input, padding_mask, config)
special_config = self.quant_config.get('special', {})
self.trans = special_config.get('trans', True)
self.trans_version = special_config.get('trans_version', 'v2')
12 changes: 3 additions & 9 deletions llmc/compression/quantization/base_blockwise_quantization.py
@@ -28,10 +28,8 @@


class BaseBlockwiseQuantization(BlockwiseOpt):
-def __init__(
-self, model, quant_config, input, padding_mask, config, modality='language'
-):
-super().__init__(model, quant_config, input, padding_mask, config, modality)
+def __init__(self, model, quant_config, input, padding_mask, config):
+super().__init__(model, quant_config, input, padding_mask, config)
self.set_quant_config()

def w_qdq(self, module, wquantizer):
@@ -465,11 +463,7 @@ def run(self, block, input_feat, handles):

def block_transform(self, block, input_feat, block_kwargs):
logger.info(f'Start transform the {self.block_idx}-th block')
-subsets = (
-self.model.get_subsets_in_block(block)
-if self.modality == 'language'
-else self.model.get_vision_subsets_in_block(block)
-)
+subsets = self.model.get_subsets_in_block(block)

if self.act_static:
self.register_non_linear_qparams(block, input_feat)
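
With the language/vision branch removed from block_transform, subset lookup is presumably resolved inside the model against the modality set via set_modality. A hypothetical sketch of that dispatch; only get_vision_subsets_in_block is taken from the deleted branch above, the rest is an assumption rather than code from this PR:

def get_subsets_in_block(self, block):
    # Assumed dispatch on the modality stored by set_modality().
    if self.get_modality() == 'vision':
        return self.get_vision_subsets_in_block(block)
    return self.get_language_subsets_in_block(block)  # hypothetical helper name
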
4 changes: 2 additions & 2 deletions llmc/compression/quantization/dgq.py
@@ -13,8 +13,8 @@

@ALGO_REGISTRY
class DGQ(BaseBlockwiseQuantization):
-def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
-super().__init__(model, quant_config, input, padding_mask, config, modality)
+def __init__(self, model, quant_config, input, padding_mask, config):
+super().__init__(model, quant_config, input, padding_mask, config)
self.model_dtype = next(self.model.model.parameters()).dtype

def w_qdq(self, module, wquantizer):
2 changes: 1 addition & 1 deletion llmc/compression/quantization/gptq.py
@@ -23,7 +23,7 @@ class GPTQ(BaseBlockwiseQuantization):
def __init__(
self, model, quant_config, input, padding_mask, config, modality='language'
):
-super().__init__(model, quant_config, input, padding_mask, config, modality)
+super().__init__(model, quant_config, input, padding_mask, config)
self.dev = torch.device('cuda')
self.model_dtype = next(self.model.model.parameters()).dtype
self.add_quant_config()
4 changes: 2 additions & 2 deletions llmc/compression/quantization/hqq.py
@@ -11,8 +11,8 @@

@ALGO_REGISTRY
class HQQ(BaseBlockwiseQuantization):
-def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
-super().__init__(model, quant_config, input, padding_mask, config, modality)
+def __init__(self, model, quant_config, input, padding_mask, config):
+super().__init__(model, quant_config, input, padding_mask, config)
self.add_quant_config()

@torch.no_grad()
4 changes: 2 additions & 2 deletions llmc/compression/quantization/llmint8.py
@@ -9,8 +9,8 @@

@ALGO_REGISTRY
class LlmInt8(BaseBlockwiseQuantization):
-def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
-super().__init__(model, quant_config, input, padding_mask, config, modality)
+def __init__(self, model, quant_config, input, padding_mask, config):
+super().__init__(model, quant_config, input, padding_mask, config)
self.add_quant_config()

@torch.no_grad()
4 changes: 2 additions & 2 deletions llmc/compression/quantization/ntweak.py
@@ -19,8 +19,8 @@

@ALGO_REGISTRY
class NormTweaking(BaseBlockwiseQuantization):
-def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
-super().__init__(model, quant_config, input, padding_mask, config, modality)
+def __init__(self, model, quant_config, input, padding_mask, config):
+super().__init__(model, quant_config, input, padding_mask, config)
self.add_quant_config()

model_type = self.config['model']['type']
4 changes: 2 additions & 2 deletions llmc/compression/quantization/omniq.py
@@ -25,8 +25,8 @@

@ALGO_REGISTRY
class OmniQuant(BaseBlockwiseQuantization):
-def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
-super().__init__(model, quant_config, input, padding_mask, config, modality)
+def __init__(self, model, quant_config, input, padding_mask, config):
+super().__init__(model, quant_config, input, padding_mask, config)
self.add_quant_config()

model_type = self.config['model']['type']
4 changes: 2 additions & 2 deletions llmc/compression/quantization/osplus.py
@@ -17,9 +17,9 @@

@ALGO_REGISTRY
class OsPlus(BaseBlockwiseQuantization):
-def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
+def __init__(self, model, quant_config, input, padding_mask, config):
torch.set_grad_enabled(False)
-super().__init__(model, quant_config, input, padding_mask, config, modality)
+super().__init__(model, quant_config, input, padding_mask, config)

@torch.no_grad()
def filter_subset(self, prev_op):
4 changes: 2 additions & 2 deletions llmc/compression/quantization/quarot.py
@@ -16,8 +16,8 @@

@ALGO_REGISTRY
class Quarot(BaseBlockwiseQuantization):
-def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
-super().__init__(model, quant_config, input, padding_mask, config, modality)
+def __init__(self, model, quant_config, input, padding_mask, config):
+super().__init__(model, quant_config, input, padding_mask, config)
self.dev = torch.device('cuda')
self.add_quant_config()
self.preprocess()
4 changes: 2 additions & 2 deletions llmc/compression/quantization/quik.py
@@ -12,8 +12,8 @@

@ALGO_REGISTRY
class QUIK(BaseBlockwiseQuantization):
-def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
-super().__init__(model, quant_config, input, padding_mask, config, modality)
+def __init__(self, model, quant_config, input, padding_mask, config):
+super().__init__(model, quant_config, input, padding_mask, config)
self.add_quant_config()

def add_quant_config(self):
4 changes: 2 additions & 2 deletions llmc/compression/quantization/rtn.py
@@ -8,8 +8,8 @@

@ALGO_REGISTRY
class RTN(BaseBlockwiseQuantization):
-def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
-super().__init__(model, quant_config, input, padding_mask, config, modality)
+def __init__(self, model, quant_config, input, padding_mask, config):
+super().__init__(model, quant_config, input, padding_mask, config)

@torch.no_grad()
def block_opt(self, block, *opt_kwargs):
4 changes: 2 additions & 2 deletions llmc/compression/quantization/smoothquant.py
@@ -12,8 +12,8 @@

@ALGO_REGISTRY
class SmoothQuant(BaseBlockwiseQuantization):
-def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
-super().__init__(model, quant_config, input, padding_mask, config, modality)
+def __init__(self, model, quant_config, input, padding_mask, config):
+super().__init__(model, quant_config, input, padding_mask, config)
special_config = self.quant_config.get('special', {})
self.alpha = special_config.get('alpha', 0.5)

4 changes: 2 additions & 2 deletions llmc/compression/quantization/spqr.py
@@ -18,8 +18,8 @@

@ALGO_REGISTRY
class SpQR(BaseBlockwiseQuantization):
-def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
-super().__init__(model, quant_config, input, padding_mask, config, modality)
+def __init__(self, model, quant_config, input, padding_mask, config):
+super().__init__(model, quant_config, input, padding_mask, config)
assert (
self.wquantizer.granularity == 'per_group'
), 'SpQR only supports per_group quantization'
4 changes: 2 additions & 2 deletions llmc/compression/quantization/tesseraq.py
@@ -23,8 +23,8 @@

@ALGO_REGISTRY
class TesseraQ(BaseBlockwiseQuantization):
-def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
-super().__init__(model, quant_config, input, padding_mask, config, modality)
+def __init__(self, model, quant_config, input, padding_mask, config):
+super().__init__(model, quant_config, input, padding_mask, config)
self.add_quant_config()

self.attention_mask = self.input['kwargs'][0].get('attention_mask')
5 changes: 3 additions & 2 deletions llmc/eval/eval_vqa.py
@@ -7,7 +7,8 @@
from lmms_eval.evaluator_utils import run_task_tests
from lmms_eval.loggers.evaluation_tracker import EvaluationTracker
from lmms_eval.tasks import TaskManager, get_task_dict
-from lmms_eval.utils import get_datetime_str, simple_parse_args_string
+from lmms_eval.utils import (get_datetime_str, make_table,
+simple_parse_args_string)
from loguru import logger

from llmc.utils.registry_factory import MODEL_REGISTRY
@@ -231,6 +232,6 @@ def _adjust_config(task_dict):
results['date'] = datetime_str
# add_env_info(results) # additional environment info to results
# add_tokenizer_info(results, lm) # additional info about tokenizer
-return results
+return make_table(results)
else:
return None
29 changes: 20 additions & 9 deletions llmc/models/base_model.py
@@ -34,24 +34,35 @@ def __init__(self, config, device_map=None, use_cache=False):
self.vision_projector = None
self.audio_model = None
self.audio_projector = None
+self.modality = None
+self.kvcache_buffer = []
self.build_tokenizer()
self.build_model()
self.model.eval()
-self.kvcache_buffer = []
-self.get_key_info(modality='language')
if self.mm_model:
self.mm_model.eval()

+def set_modality(self, modality='language'):
+assert modality in ['audio', 'vision', 'language']
+self.modality = modality
+self.update_key_info()

+def get_modality(self):
+assert self.modality in ['audio', 'vision', 'language']
+return self.modality

-def get_key_info(self, modality='language'):
-self.find_blocks(modality=modality)
+def update_key_info(self):
+self.find_blocks()
self.find_embed_layers()
self.find_block_name()
-self.add_layernorms_class(modality=modality)
+self.add_layernorms_class()

def reset_kv(self):
for kvcache in self.kvcache_buffer:
kvcache._reset_states()

@abstractmethod
-def find_blocks(self, modality='language'):
+def find_blocks(self):
pass

def find_block_name(self):
@@ -193,10 +204,10 @@ def build_model(self):
)
logger.info(f'self.model : {self.model}')

-def add_layernorms_class(self, modality='language'):
+def add_layernorms_class(self):
ln_class_list = []
single_block = self.blocks[0]
-ln_dict = self.get_layernorms_in_block(single_block, modality=modality)
+ln_dict = self.get_layernorms_in_block(single_block)
for ln_name in ln_dict:
ln_class = ln_dict[ln_name].__class__
if ln_class not in ln_class_list:
@@ -207,7 +218,7 @@ def add_layernorms_class(self, modality='language'):
logger.info(f'_TRANSFORMERS_LN_TYPES_ : {_TRANSFORMERS_LN_TYPES_}')

@torch.no_grad()
-def collect_first_block_input(self, calib_data, padding_mask=None, modality='language'):
+def collect_first_block_input(self, calib_data, padding_mask=None):
first_block_input = defaultdict(list)

Catcher = self.get_catcher(first_block_input)
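
The new accessor pair makes switching a multimodal model between compression targets explicit while keeping the block and layernorm bookkeeping in sync. A minimal usage sketch, where model is any concrete BaseModel subclass and the method names are those added above:

model.set_modality('vision')       # re-runs update_key_info(): blocks, embeds, layernorms
assert model.get_modality() == 'vision'

model.set_modality('language')     # switch back before compressing the language blocks
blocks = model.get_blocks()        # now refers to the language decoder blocks
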
4 changes: 2 additions & 2 deletions llmc/models/bloom.py
@@ -8,7 +8,7 @@ class Bloom(BaseModel):
def __init__(self, config, device_map=None, use_cache=False):
super().__init__(config, device_map, use_cache)

-def find_blocks(self, modality='language'):
+def find_blocks(self):
self.blocks = self.model.transformer.h

def find_embed_layers(self):
@@ -37,7 +37,7 @@ def skip_layer_name(self):
def has_bias(self):
return True

-def get_layernorms_in_block(self, block, modality='language'):
+def get_layernorms_in_block(self, block):
return {
'input_layernorm': block.input_layernorm,
'post_attention_layernorm': block.post_attention_layernorm,
4 changes: 2 additions & 2 deletions llmc/models/chatglm.py
@@ -12,7 +12,7 @@ class ChatGLM(BaseModel):
def __init__(self, config, device_map=None, use_cache=False):
super().__init__(config, device_map, use_cache)

-def find_blocks(self, modality='language'):
+def find_blocks(self):
self.blocks = self.model.transformer.encoder.layers

def find_embed_layers(self):
@@ -43,7 +43,7 @@ def skip_layer_name(self):
def has_bias(self):
return False

-def get_layernorms_in_block(self, block, modality='language'):
+def get_layernorms_in_block(self, block):
return {
'input_layernorm': block.input_layernorm,
'post_attention_layernorm': block.post_attention_layernorm,
4 changes: 2 additions & 2 deletions llmc/models/deepseekv2.py
@@ -8,7 +8,7 @@ class DeepseekV2(BaseModel):
def __init__(self, config, device_map=None, use_cache=False):
super().__init__(config, device_map, use_cache)

-def find_blocks(self, modality='language'):
+def find_blocks(self):
self.blocks = self.model.model.layers

def find_embed_layers(self):
@@ -34,7 +34,7 @@ def skip_layer_name(self):
def has_bias(self):
return False

-def get_layernorms_in_block(self, block, modality='language'):
+def get_layernorms_in_block(self, block):
return {
'input_layernorm': block.input_layernorm,
'post_attention_layernorm': block.post_attention_layernorm,
4 changes: 2 additions & 2 deletions llmc/models/falcon.py
@@ -8,7 +8,7 @@ class Falcon(BaseModel):
def __init__(self, config, device_map=None, use_cache=False):
super().__init__(config, device_map, use_cache)

-def find_blocks(self, modality='language'):
+def find_blocks(self):
self.blocks = self.model.transformer.h

def find_embed_layers(self):
@@ -30,7 +30,7 @@ def get_layers_except_blocks(self):
def has_bias(self):
return False

-def get_layernorms_in_block(self, block, modality='language'):
+def get_layernorms_in_block(self, block):
if block.config.architectures[0] == 'RWForCausalLM':
new_decoder_architecture = False
elif block.config.architectures[0] == 'FalconForCausalLM':