Skip to content

Commit

Permalink
Fix awq trans_v1 gqa bug
Browse files Browse the repository at this point in the history
  • Loading branch information
gushiqiao committed Jan 2, 2025
1 parent fae1f49 commit e70579a
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions llmc/compression/quantization/awq.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,13 @@ def get_scales(self, prev_op, x, w_max, is_gqa, ratio):

x_tmp = self.get_act_scale(x_tmp)

- if self.trans_version == 'v1':
+ if self.trans_version == 'v1' and not is_gqa:
scales = (
(x_tmp.pow(ratio) / w_tmp.pow(1 - ratio))
.clamp(min=1e-4)
.view(-1)
)
- elif self.trans_version == 'v2':
+ elif self.trans_version == 'v2' or is_gqa:
scales = x_tmp.pow(ratio).clamp(min=1e-4).view(-1)

scales = scales / (scales.max() * scales.min()).sqrt()
Expand Down

0 comments on commit e70579a

Please sign in to comment.