Skip to content

Commit 6b89601

Browse files
author
gushiqiao
committed
fix bug
1 parent 13463e7 commit 6b89601

File tree

1 file changed

+5
-0
lines changed
  • llmc/compression/quantization

1 file changed

+5
-0
lines changed

llmc/compression/quantization/awq.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ def search_scale_subset(
185185
torch.cuda.empty_cache()
186186

187187
x_tmp = self.scaling_input(x, scales, is_gqa)
188+
logger.info(f"x_tmp:{x_tmp.shape}")
188189

189190
if not check_w_only(
190191
self.block_idx,
@@ -205,13 +206,17 @@ def search_scale_subset(
205206
).fake_quant_act_dynamic(_x)
206207
outs.append(_x)
207208
x_tmp = torch.stack(outs)
209+
logger.info(f"x_tmp:{x_tmp.shape}")
208210

209211
out = self.inspect_module_forward(x_tmp, inspect_module, kwargs)
210212

211213
if self.padding_mask and org_out.shape[1] == self.padding_mask[i].shape[-1]:
212214
org_out = org_out * self.padding_mask[i].unsqueeze(dim=-1).to(org_out.device) # noqa
213215
out = out * self.padding_mask[i].unsqueeze(dim=-1).to(out.device)
214216

217+
logger.info(f"org_out:{org_out.shape}")
218+
logger.info(f"out:{out.shape}")
219+
215220
loss = self.calculate_loss(org_out, out)
216221

217222
if len(input) == 1:

0 commit comments

Comments
 (0)