diff --git a/llmc/compression/quantization/base_blockwise_quantization.py b/llmc/compression/quantization/base_blockwise_quantization.py index 514b5434..9cedd214 100644 --- a/llmc/compression/quantization/base_blockwise_quantization.py +++ b/llmc/compression/quantization/base_blockwise_quantization.py @@ -840,7 +840,7 @@ def scaling_input(self, x, scales, is_gqa): batch_scale = scales_tmp.view(1, -1) x_tmp[i] = batch / batch_scale else: - x_tmp = x / scales.view(1, -1) + x_tmp = x / scales_tmp.view(1, -1) return x_tmp @torch.no_grad()