diff --git a/src/compressed_tensors/linear/compressed_linear.py b/src/compressed_tensors/linear/compressed_linear.py
index 3e2b2f5f..fcbe4606 100644
--- a/src/compressed_tensors/linear/compressed_linear.py
+++ b/src/compressed_tensors/linear/compressed_linear.py
@@ -21,6 +21,7 @@
     QuantizationStatus,
     initialize_module_for_quantization,
 )
+from compressed_tensors.utils import register_offload_parameter
 from torch import Tensor
 from torch.nn import Parameter
 from torch.nn.functional import linear
@@ -68,7 +69,7 @@ def from_linear(
             param = Parameter(
                 torch.empty(shape, device=device, dtype=dtype), requires_grad=False
             )
-            module.register_parameter(name, param)
+            register_offload_parameter(module, name, param)
 
         # mark module as compressed
         module.quantization_status = QuantizationStatus.COMPRESSED
diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py
index 8dd8fc51..9ce7a000 100644
--- a/src/compressed_tensors/quantization/lifecycle/initialize.py
+++ b/src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -190,24 +190,35 @@ def _initialize_scale_zero_point(
         register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)
 
 
 def _initialize_attn_scales(module: Module) -> None:
-    """Initlaize k_scale, v_scale for  self_attn"""
+    """
+    Initialize k_scale and v_scale parameters for a self_attn module
+
+    :param module: module on which the per-tensor key and value scale
+        parameters are registered
+    """
 
     expected_shape = 1  # per tensor
 
-    param = next(module.parameters())
-    scale_dtype = param.dtype
-    device = param.device
+    # read dtype/device from the module's own weight when it has one;
+    # otherwise fall back (lazily) to the first parameter found
+    weight_param = getattr(module, "weight", None)
+    if weight_param is None:
+        weight_param = next(module.parameters())
+    scale_dtype = weight_param.dtype
+    device = weight_param.device
 
     init_scale = Parameter(
         torch.empty(expected_shape, dtype=scale_dtype, device=device),
         requires_grad=False,
     )
 
-    module.register_parameter(KVCacheScaleType.KEY.value, init_scale)
+    register_offload_parameter(module, KVCacheScaleType.KEY.value, init_scale)
 
+    # a separate Parameter is required for the value scale:
+    # Parameter.clone() returns a plain Tensor, which cannot be registered
     init_scale = Parameter(
         torch.empty(expected_shape, dtype=scale_dtype, device=device),
         requires_grad=False,
     )
-    module.register_parameter(KVCacheScaleType.VALUE.value, init_scale)
+    register_offload_parameter(module, KVCacheScaleType.VALUE.value, init_scale)