
move llmcompressor util is_model_path_quantized to ct #246

Open · wants to merge 5 commits into base: main
56 changes: 36 additions & 20 deletions src/compressed_tensors/quantization/utils/helpers.py
@@ -26,24 +26,26 @@
from torch import FloatTensor, IntTensor, Tensor
from torch.nn import Module
from tqdm import tqdm
from transformers import AutoConfig


__all__ = [
"infer_quantization_status",
"is_module_quantized",
"is_model_quantized",
"module_type",
"KV_CACHE_TARGETS",
Contributor Author: Alphabetical

"calculate_compression_ratio",
"get_torch_bit_depth",
"calculate_qparams",
"calculate_range",
"can_quantize",
"parse_out_kv_cache_args",
"KV_CACHE_TARGETS",
"compute_dynamic_scales_and_zp",
"get_torch_bit_depth",
"infer_quantization_status",
"is_kv_cache_quant_scheme",
"is_model_quantized",
"is_model_quantized_from_path",
"is_module_quantized",
"iter_named_leaf_modules",
"iter_named_quantizable_modules",
"compute_dynamic_scales_and_zp",
"calculate_range",
"calculate_qparams",
"module_type",
"parse_out_kv_cache_args",
]

# target the self_attn layer
@@ -170,22 +172,17 @@ def infer_quantization_status(model: Module) -> Optional["QuantizationStatus"]:
def is_module_quantized(module: Module) -> bool:
    """
    Check if a module is quantized, based on the existence of a non-empty quantization
    scheme
    scheme.

    :param module: pytorch module to check
    :param module: PyTorch module to check
    :return: True if module is quantized, False otherwise
    """
    if not hasattr(module, "quantization_scheme"):
        return False

    if module.quantization_scheme.weights is not None:
        return True

    if module.quantization_scheme.input_activations is not None:
        return True

    if module.quantization_scheme.output_activations is not None:
        return True
    for attr in ("weights", "input_activations", "output_activations"):
        if getattr(module.quantization_scheme, attr, None) is not None:
            return True

    return False
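
For context on the refactor above, a minimal sketch of how the loop-based check behaves. The SimpleNamespace below stands in for a real QuantizationScheme and is purely illustrative, not part of this diff:

```python
from types import SimpleNamespace

import torch.nn as nn

# Stand-in for a QuantizationScheme with only weights quantized; the real
# scheme object in compressed-tensors has more fields, this is illustrative only.
scheme = SimpleNamespace(
    weights=object(), input_activations=None, output_activations=None
)

layer = nn.Linear(4, 4)
layer.quantization_scheme = scheme

# Mirrors the refactored loop: any non-None field marks the module as quantized.
is_quantized = any(
    getattr(layer.quantization_scheme, attr, None) is not None
    for attr in ("weights", "input_activations", "output_activations")
)
print(is_quantized)  # True
```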

@@ -206,6 +203,25 @@ def is_model_quantized(model: Module) -> bool:
    return False


def is_model_quantized_from_path(path: str) -> bool:
    """
    Determine whether a model stub or path points to a quantized model,
    based on its config.

    :param path: path to the model or HF stub
    :return: True if the config loaded from the given path contains a
        compressed-tensors quantization_config, False otherwise
    """
    config = AutoConfig.from_pretrained(path)
    if config is not None:
        if (
            hasattr(config, "quantization_config")
            and config.quantization_config["quant_method"] == "compressed-tensors"
        ):
            return True
    return False


def module_type(module: Module) -> str:
    """
    Gets a string representation of a module type
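
As a usage sketch for the new helper, the snippet below shows how it might be called. The import path and model ids are assumptions for illustration; any checkpoint whose config carries a compressed-tensors quantization_config should return True:

```python
# Assumes the helper is re-exported alongside the other utils in this file.
from compressed_tensors.quantization.utils import is_model_quantized_from_path

# Hypothetical model ids; substitute a real local path or Hugging Face stub.
quantized_stub = "some-org/llama-w8a8-compressed-tensors"
dense_stub = "some-org/llama-fp16"

if is_model_quantized_from_path(quantized_stub):
    print("config reports compressed-tensors quantization")

if not is_model_quantized_from_path(dense_stub):
    print("no compressed-tensors quantization_config found in the model config")
```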