Skip to content

Commit

Permalink
Add debug logging of tensor shapes (x_tmp, org_out, out) in search_scale_subset
Browse files Browse the repository at this point in the history
  • Loading branch information
gushiqiao committed Dec 30, 2024
1 parent 13463e7 commit 6b89601
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions llmc/compression/quantization/awq.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ def search_scale_subset(
torch.cuda.empty_cache()

x_tmp = self.scaling_input(x, scales, is_gqa)
logger.info(f"x_tmp:{x_tmp.shape}")

if not check_w_only(
self.block_idx,
Expand All @@ -205,13 +206,17 @@ def search_scale_subset(
).fake_quant_act_dynamic(_x)
outs.append(_x)
x_tmp = torch.stack(outs)
logger.info(f"x_tmp:{x_tmp.shape}")

out = self.inspect_module_forward(x_tmp, inspect_module, kwargs)

if self.padding_mask and org_out.shape[1] == self.padding_mask[i].shape[-1]:
org_out = org_out * self.padding_mask[i].unsqueeze(dim=-1).to(org_out.device) # noqa
out = out * self.padding_mask[i].unsqueeze(dim=-1).to(out.device)

logger.info(f"org_out:{org_out.shape}")
logger.info(f"out:{out.shape}")

loss = self.calculate_loss(org_out, out)

if len(input) == 1:
Expand Down

0 comments on commit 6b89601

Please sign in to comment.