diff --git a/llmc/compression/quantization/awq.py b/llmc/compression/quantization/awq.py index 1ea48b33..256d8d66 100644 --- a/llmc/compression/quantization/awq.py +++ b/llmc/compression/quantization/awq.py @@ -80,13 +80,13 @@ def get_scales(self, prev_op, x, w_max, is_gqa, ratio): x_tmp = self.get_act_scale(x_tmp) - if self.trans_version == 'v1': + if self.trans_version == 'v1' and not is_gqa: scales = ( (x_tmp.pow(ratio) / w_tmp.pow(1 - ratio)) .clamp(min=1e-4) .view(-1) ) - elif self.trans_version == 'v2': + elif self.trans_version == 'v2' or is_gqa: scales = x_tmp.pow(ratio).clamp(min=1e-4).view(-1) scales = scales / (scales.max() * scales.min()).sqrt()