diff --git a/auto_round/auto_quantizer.py b/auto_round/auto_quantizer.py
index 4856c253..305f64ab 100644
--- a/auto_round/auto_quantizer.py
+++ b/auto_round/auto_quantizer.py
@@ -41,7 +41,6 @@
 from transformers.quantizers import AutoQuantizationConfig, HfQuantizer
 from transformers.quantizers.auto import AUTO_QUANTIZER_MAPPING
 from transformers.utils.quantization_config import AwqConfig, GPTQConfig, QuantizationConfigMixin, QuantizationMethod
-
 from auto_round.utils import (get_module, set_module, is_hpu_supported, get_block_names,
                               get_multimodal_block_names, find_matching_blocks)
 
@@ -191,6 +190,30 @@ def merge_quantization_configs(
             warnings.warn(warning_msg)
 
         return quantization_config
 
+    @staticmethod
+    def supports_quant_method(quantization_config_dict):
+        from transformers.quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING
+        AUTO_QUANTIZATION_CONFIG_MAPPING['intel/auto-round'] = AutoRoundConfig
+        AUTO_QUANTIZATION_CONFIG_MAPPING['intel/auto_round'] = AutoRoundConfig
+        quant_method = quantization_config_dict.get("quant_method", None)
+        if quantization_config_dict.get("load_in_8bit", False) or quantization_config_dict.get("load_in_4bit", False):
+            suffix = "_4bit" if quantization_config_dict.get("load_in_4bit", False) else "_8bit"
+            quant_method = QuantizationMethod.BITS_AND_BYTES + suffix
+        elif quant_method is None:
+            raise ValueError(
+                "The model's quantization config from the arguments has no `quant_method` attribute. "
+                "Make sure that the model has been correctly quantized"
+            )
+
+        if quant_method not in AUTO_QUANTIZATION_CONFIG_MAPPING.keys():
+            logger.warning(
+                f"Unknown quantization type, got {quant_method} - supported types are:"
+                f" {list(AUTO_QUANTIZER_MAPPING.keys())}. Hence, we will skip the quantization. "
+                "To remove the warning, you can delete the quantization_config attribute in config.json"
+            )
+            return False
+        return True
 
 class AutoRoundQuantizationMethod(str, Enum):
@@ -758,3 +781,4 @@ def is_serializable(self):
 
 transformers.quantizers.auto.AutoHfQuantizer = AutoHfQuantizer
 transformers.modeling_utils.AutoHfQuantizer = AutoHfQuantizer
+
diff --git a/auto_round/script/llm.py b/auto_round/script/llm.py
index e44cd9f0..44559916 100644
--- a/auto_round/script/llm.py
+++ b/auto_round/script/llm.py
@@ -463,7 +463,7 @@ def tune(args):
         if args.quant_lm_head:
             layer_config[lm_head_layer_name] = {"bits": args.bits}
             for format in formats:
-                if "auto_round" not in format:
+                if "auto_round" not in format and "fake" not in format:
                     auto_round_formats = [s for s in supported_formats if s.startswith("auto_round")]
                     raise ValueError(
                         f"{format} is not supported for lm-head quantization, please change to {auto_round_formats}")
diff --git a/auto_round/script/mllm.py b/auto_round/script/mllm.py
index 884ac920..f45eab28 100644
--- a/auto_round/script/mllm.py
+++ b/auto_round/script/mllm.py
@@ -418,7 +418,7 @@ def tune(args):
         if args.quant_lm_head:
             layer_config[lm_head_layer_name] = {"bits": args.bits}
             for format in formats:
-                if "auto_round" not in format:
+                if "auto_round" not in format and "fake" not in format:
                     auto_round_formats = [s for s in supported_formats if s.startswith("auto_round")]
                     raise ValueError(
                         f"{format} is not supported for lm-head quantization, please change to {auto_round_formats}")
@@ -587,3 +587,4 @@ def lmms_eval(args):
         apply_chat_template=False,
     )
     return results
+
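
Reviewer note (not part of the patch): a minimal sketch of how the new supports_quant_method override is expected to behave once auto_round.auto_quantizer has been imported and its monkey-patch of transformers applied. The config dicts below are illustrative, not taken from a real checkpoint.

# Illustrative only: exercises the patched AutoHfQuantizer.supports_quant_method.
# Assumes that importing auto_round.auto_quantizer applies the replacement shown
# at the bottom of auto_quantizer.py (transformers.quantizers.auto.AutoHfQuantizer
# is swapped for auto_round's AutoHfQuantizer).
import auto_round.auto_quantizer  # noqa: F401  (imported for its side effects)
from transformers.quantizers.auto import AutoHfQuantizer

# "intel/auto-round" is registered into AUTO_QUANTIZATION_CONFIG_MAPPING inside
# supports_quant_method, so a config that uses it is now recognized.
assert AutoHfQuantizer.supports_quant_method({"quant_method": "intel/auto-round", "bits": 4})

# An unrecognized method logs a warning and returns False, so loading falls back
# to the unquantized path instead of raising.
assert not AutoHfQuantizer.supports_quant_method({"quant_method": "not-a-real-method"})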