From 6b896014dc160232b1eb57c2a183042b583037d8 Mon Sep 17 00:00:00 2001 From: gushiqiao Date: Mon, 30 Dec 2024 14:51:18 +0800 Subject: [PATCH] fix bug --- llmc/compression/quantization/awq.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/llmc/compression/quantization/awq.py b/llmc/compression/quantization/awq.py index e7cd10eb..f0863a70 100644 --- a/llmc/compression/quantization/awq.py +++ b/llmc/compression/quantization/awq.py @@ -185,6 +185,7 @@ def search_scale_subset( torch.cuda.empty_cache() x_tmp = self.scaling_input(x, scales, is_gqa) + logger.info(f"x_tmp:{x_tmp.shape}") if not check_w_only( self.block_idx, @@ -205,6 +206,7 @@ def search_scale_subset( ).fake_quant_act_dynamic(_x) outs.append(_x) x_tmp = torch.stack(outs) + logger.info(f"x_tmp:{x_tmp.shape}") out = self.inspect_module_forward(x_tmp, inspect_module, kwargs) @@ -212,6 +214,9 @@ def search_scale_subset( org_out = org_out * self.padding_mask[i].unsqueeze(dim=-1).to(org_out.device) # noqa out = out * self.padding_mask[i].unsqueeze(dim=-1).to(out.device) + logger.info(f"org_out:{org_out.shape}") + logger.info(f"out:{out.shape}") + loss = self.calculate_loss(org_out, out) if len(input) == 1: