Commit e82b688 — "revert submodule to module" (1 parent: d8aad41)

File tree

4 files changed: +10 −10 lines changed

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
from compressed_tensors.quantization.lifecycle import expand_sparse_target_names
4343
from compressed_tensors.quantization.quant_args import QuantizationArgs
4444
from compressed_tensors.quantization.utils import (
45-
is_submodule_quantized,
45+
is_model_quantized,
4646
iter_named_leaf_modules,
4747
)
4848
from compressed_tensors.utils import get_safetensors_folder, update_parameter_data
@@ -426,7 +426,7 @@ def map_modules_to_quant_args(model: Module) -> Dict[str, QuantizationArgs]:
426426
"""
427427
quantized_modules_to_args = {}
428428
for name, submodule in iter_named_leaf_modules(model):
429-
if is_submodule_quantized(submodule):
429+
if is_model_quantized(submodule):
430430
if submodule.quantization_scheme.weights is not None:
431431
name = fix_fsdp_module_name(name)
432432
quantized_modules_to_args[name] = submodule.quantization_scheme.weights

src/compressed_tensors/quantization/lifecycle/apply.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
"is_sparse_target",
5757
]
5858

59-
from compressed_tensors.quantization.utils.helpers import is_submodule_quantized
59+
from compressed_tensors.quantization.utils.helpers import is_model_quantized
6060
from compressed_tensors.utils.safetensors_load import get_quantization_state_dict
6161

6262

@@ -76,7 +76,7 @@ def load_pretrained_quantization(model: Module, model_name_or_path: str):
7676
state_dict = get_quantization_state_dict(model_path)
7777

7878
for name, submodule in iter_named_leaf_modules(model):
79-
if not is_submodule_quantized(submodule):
79+
if not is_model_quantized(submodule):
8080
continue
8181
if submodule.quantization_scheme.weights is not None:
8282
base_name = "weight"

src/compressed_tensors/quantization/quant_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
)
2424
from compressed_tensors.quantization.utils import (
2525
calculate_compression_ratio,
26-
is_submodule_quantized,
26+
is_model_quantized,
2727
iter_named_quantizable_modules,
2828
module_type,
2929
parse_out_kv_cache_args,
@@ -181,7 +181,7 @@ def from_pretrained(
181181
model, include_children=True, include_attn=True
182182
):
183183
layer_type = module_type(submodule)
184-
if not is_submodule_quantized(submodule):
184+
if not is_model_quantized(submodule):
185185
if layer_type not in ignore:
186186
ignore[layer_type] = []
187187
ignore[layer_type].append(name)

src/compressed_tensors/quantization/utils/helpers.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
"is_kv_cache_quant_scheme",
4242
"is_model_quantized",
4343
"is_model_quantized_from_path",
44-
"is_submodule_quantized",
44+
"is_model_quantized",
4545
"iter_named_leaf_modules",
4646
"iter_named_quantizable_modules",
4747
"module_type",
@@ -169,7 +169,7 @@ def infer_quantization_status(model: Module) -> Optional["QuantizationStatus"]:
169169
return None
170170

171171

172-
def is_submodule_quantized(module: Module) -> bool:
172+
def is_model_quantized(module: Module) -> bool:
173173
"""
174174
Check if a module is quantized, based on the existence of a non-empty quantization
175175
scheme
@@ -202,7 +202,7 @@ def is_model_quantized(model: Module) -> bool:
202202
"""
203203

204204
for _, submodule in iter_named_leaf_modules(model):
205-
if is_submodule_quantized(submodule):
205+
if is_model_quantized(submodule):
206206
return True
207207

208208
return False
@@ -353,7 +353,7 @@ def calculate_compression_ratio(model: Module) -> float:
353353
uncompressed_bits = get_torch_bit_depth(parameter)
354354
compressed_bits = uncompressed_bits
355355
if (
356-
is_submodule_quantized(submodule)
356+
is_model_quantized(submodule)
357357
and submodule.quantization_scheme.weights
358358
):
359359
compressed_bits = submodule.quantization_scheme.weights.num_bits

Comments (0)