
align auto_quantizer with main branch in Transformers #437

Merged 7 commits on Feb 14, 2025
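For context, once the patched AutoHfQuantizer below is installed, an AutoRound-quantized checkpoint loads through the standard Transformers entry point. A minimal sketch, assuming a hypothetical checkpoint directory; the auto_round import is what applies the monkey-patch shown at the bottom of auto_quantizer.py:

from transformers import AutoModelForCausalLM, AutoTokenizer

# Importing AutoRoundConfig registers AutoRound with transformers' AutoHfQuantizer
# (see the `transformers.quantizers.auto.AutoHfQuantizer = AutoHfQuantizer` patch below).
from auto_round import AutoRoundConfig

model_dir = "path/to/autoround-quantized-model"  # hypothetical checkpoint directory

# from_pretrained reads quantization_config from config.json and dispatches to the
# AutoRound quantizer when quant_method resolves to "intel/auto-round".
model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_dir)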
26 changes: 25 additions & 1 deletion auto_round/auto_quantizer.py
@@ -41,7 +41,7 @@
from transformers.quantizers import AutoQuantizationConfig, HfQuantizer
from transformers.quantizers.auto import AUTO_QUANTIZER_MAPPING
from transformers.utils.quantization_config import AwqConfig, GPTQConfig, QuantizationConfigMixin, QuantizationMethod

from transformers.quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING
from auto_round.utils import (get_module, set_module, is_hpu_supported, get_block_names,
                              get_multimodal_block_names, find_matching_blocks)

@@ -191,6 +191,29 @@ def merge_quantization_configs(
            warnings.warn(warning_msg)

        return quantization_config

    @staticmethod
    def supports_quant_method(quantization_config_dict):
        AUTO_QUANTIZATION_CONFIG_MAPPING['intel/auto-round'] = AutoRoundConfig
        AUTO_QUANTIZATION_CONFIG_MAPPING['intel/auto_round'] = AutoRoundConfig
        quant_method = quantization_config_dict.get("quant_method", None)
        if quantization_config_dict.get("load_in_8bit", False) or quantization_config_dict.get("load_in_4bit", False):
            suffix = "_4bit" if quantization_config_dict.get("load_in_4bit", False) else "_8bit"
            quant_method = QuantizationMethod.BITS_AND_BYTES + suffix
        elif quant_method is None:
            raise ValueError(
                "The model's quantization config from the arguments has no `quant_method` attribute. "
                "Make sure that the model has been correctly quantized."
            )

        if quant_method not in AUTO_QUANTIZATION_CONFIG_MAPPING.keys():
            logger.warning(
                f"Unknown quantization type, got {quant_method} - supported types are:"
                f" {list(AUTO_QUANTIZER_MAPPING.keys())}. Hence, we will skip the quantization. "
                "To remove the warning, you can delete the quantization_config attribute in config.json"
            )
            return False
        return True


class AutoRoundQuantizationMethod(str, Enum):
@@ -758,3 +781,4 @@ def is_serializable(self):

transformers.quantizers.auto.AutoHfQuantizer = AutoHfQuantizer
transformers.modeling_utils.AutoHfQuantizer = AutoHfQuantizer

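As an aside, a minimal sketch of how the new supports_quant_method hook could be exercised directly; the config dict below is illustrative, mirroring the quantization_config entry of a quantized model's config.json:

# Illustrative config dict; keys mirror a quantized model's config.json "quantization_config".
config_dict = {"quant_method": "intel/auto-round", "bits": 4, "group_size": 128}

# Returns True when quant_method maps to a registered quantization config
# (AutoRoundConfig is registered under both "intel/auto-round" and "intel/auto_round");
# unknown methods log a warning and return False so loading proceeds without quantization.
if AutoHfQuantizer.supports_quant_method(config_dict):
    print("quantization_config will be dispatched to the AutoRound quantizer")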
2 changes: 1 addition & 1 deletion auto_round/script/llm.py
@@ -463,7 +463,7 @@ def tune(args):
    if args.quant_lm_head:
        layer_config[lm_head_layer_name] = {"bits": args.bits}
        for format in formats:
            if "auto_round" not in format:
            if "auto_round" not in format and "fake" not in format:
                auto_round_formats = [s for s in supported_formats if s.startswith("auto_round")]
                raise ValueError(
                    f"{format} is not supported for lm-head quantization, please change to {auto_round_formats}")
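For clarity, a standalone sketch of the lm-head format gate above; the supported_formats tuple here is an illustrative subset rather than the full list defined in auto-round:

# Illustrative subset of export formats; the real list lives in auto-round.
supported_formats = ("auto_round", "auto_round:auto_gptq", "auto_gptq", "auto_awq", "fake")

def check_lm_head_formats(formats):
    # After this PR, both auto_round-style and fake formats may quantize the lm-head.
    for fmt in formats:
        if "auto_round" not in fmt and "fake" not in fmt:
            auto_round_formats = [s for s in supported_formats if s.startswith("auto_round")]
            raise ValueError(
                f"{fmt} is not supported for lm-head quantization, please change to {auto_round_formats}")

check_lm_head_formats(["auto_round", "fake"])  # accepted after this change
check_lm_head_formats(["auto_gptq"])           # still raises ValueError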
3 changes: 2 additions & 1 deletion auto_round/script/mllm.py
@@ -418,7 +418,7 @@ def tune(args):
    if args.quant_lm_head:
        layer_config[lm_head_layer_name] = {"bits": args.bits}
        for format in formats:
            if "auto_round" not in format:
            if "auto_round" not in format and "fake" not in format:
                auto_round_formats = [s for s in supported_formats if s.startswith("auto_round")]
                raise ValueError(
                    f"{format} is not supported for lm-head quantization, please change to {auto_round_formats}")
@@ -587,3 +587,4 @@ def lmms_eval(args):
        apply_chat_template=False,
    )
    return results