Commit: Fix bugs (#295)
Co-authored-by: gushiqiao <[email protected]>
gushiqiao and gushiqiao authored Jan 14, 2025
1 parent 0de0040 commit 74387fe
Showing 3 changed files with 104 additions and 106 deletions.
131 changes: 78 additions & 53 deletions llmc/compression/quantization/awq.py
@@ -31,8 +31,8 @@ def scaling_weight(self, w, scales, is_gqa):
scales_tmp = self.repeat_gqa_scales(scales)
else:
scales_tmp = scales
w_tmp = w.mul_(scales_tmp.view(1, -1))
return w_tmp
w.mul_(scales_tmp.view(1, -1))
return w

@torch.no_grad()
def get_weight_scale(self, layers_dict):
@@ -60,14 +60,17 @@ def get_weight_scale(self, layers_dict):
return scale

def get_act_scale(self, x):
batch_means = []
b_num = x.shape[0] // self._bs
for num in range(b_num):
batch_x = x[num * self._bs:(num + 1) * self._bs]
batch_mean = batch_x.abs().view(-1, batch_x.shape[-1]).mean(0)
batch_means.append(batch_mean)
final_mean = sum(batch_means) / len(batch_means)
return final_mean
if x.shape[0] == self._bs:
return x.abs().view(-1, x.shape[-1]).mean(0)
else:
batch_means = []
b_num = x.shape[0] // self._bs
for num in range(b_num):
batch_x = x[num * self._bs:(num + 1) * self._bs]
batch_mean = batch_x.abs().view(-1, batch_x.shape[-1]).mean(0)
batch_means.append(batch_mean)
final_mean = sum(batch_means) / len(batch_means)
return final_mean

@torch.no_grad()
def get_scales(self, prev_op, x, w_max, is_gqa, ratio):
@@ -93,15 +96,22 @@ def get_scales(self, prev_op, x, w_max, is_gqa, ratio):
return scales

def inspect_module_forward(self, x, inspect_module, kwargs):
outs = []
b_num = x.shape[0] // self._bs
for num in range(b_num):
_x = x[num * self._bs:(num + 1) * self._bs]
out = inspect_module(_x, **kwargs)
if isinstance(out, tuple):
out = out[0]
outs.append(out)
return torch.cat(outs, dim=0)
if self._bs == x.shape[0]:
with torch.no_grad():
out = inspect_module(x, **kwargs)
if isinstance(out, tuple):
out = out[0]
return out
else:
outs = []
b_num = x.shape[0] // self._bs
for num in range(b_num):
_x = x[num * self._bs:(num + 1) * self._bs]
out = inspect_module(_x, **kwargs)
if isinstance(out, tuple):
out = out[0]
outs.append(out)
return torch.cat(outs, dim=0)

@torch.no_grad()
def get_original_out(self, x, inspect_module, subset_kwargs):
@@ -110,14 +120,53 @@ def get_original_out(self, x, inspect_module, subset_kwargs):
return org_out

def calculate_loss(self, org_out, out):
total_loss = 0.0
b_num = org_out.shape[0] // self._bs
for num in range(b_num):
_org_out = org_out[num * self._bs:(num + 1) * self._bs]
_out = out[num * self._bs:(num + 1) * self._bs]
single_loss = (_org_out - _out).float().pow(2).mean().item()
total_loss += single_loss
return total_loss / b_num
if out.shape[0] == self._bs:
return (org_out - out).float().pow(2).mean().item()
else:
total_loss = 0.0
b_num = org_out.shape[0] // self._bs
for num in range(b_num):
_org_out = org_out[num * self._bs:(num + 1) * self._bs]
_out = out[num * self._bs:(num + 1) * self._bs]
single_loss = (_org_out - _out).float().pow(2).mean().item()
total_loss += single_loss
return total_loss / b_num

def fake_quantize_weight(self, weight, scales, is_gqa, layer_name):
weight = self.scaling_weight(weight, scales, is_gqa)
weight.data = get_wquantizer(
self.block_idx,
layer_name,
self.mix_bits_map,
self.quantizer_mix_bits,
self.wquantizer,
).fake_quant_weight_dynamic(weight.data)

return weight

def fake_quantize_input(self, x_tmp, layers_dict):
if self._bs == x_tmp.shape[0]:
x_tmp = get_aquantizer(
self.block_idx,
list(layers_dict.keys())[0],
self.mix_bits_map,
self.quantizer_mix_bits,
self.aquantizer,
).fake_quant_act_dynamic(x_tmp)
else:
outs = []
for i in range(x_tmp.shape[0]):
_x = x_tmp[i]
_x = get_aquantizer(
self.block_idx,
list(layers_dict.keys())[0],
self.mix_bits_map,
self.quantizer_mix_bits,
self.aquantizer,
).fake_quant_act_dynamic(_x)
outs.append(_x)
x_tmp = torch.stack(outs)
return x_tmp

@torch.no_grad()
def search_scale_subset(
@@ -158,25 +207,12 @@ def search_scale_subset(
else:
org_out = self.get_original_out(x, inspect_module, kwargs)
org_out_dict[i] = org_out
x_max = self.get_act_scale(x)

ratio = n * 1 / n_grid
scales = self.get_scales(prev_op, x, w_max, is_gqa, ratio)
for layer_name in layers_dict:
fc = layers_dict[layer_name]
fc.weight = self.scaling_weight(fc.weight, scales, is_gqa)

fc.weight.data = get_wquantizer(
self.block_idx,
layer_name,
self.mix_bits_map,
self.quantizer_mix_bits,
self.wquantizer,
).fake_quant_weight_dynamic(fc.weight.data)

del x_max
gc.collect()
torch.cuda.empty_cache()
fc.weight = self.fake_quantize_weight(fc.weight, scales, is_gqa, layer_name)

x_tmp = self.scaling_input(x, scales, is_gqa)

@@ -187,18 +223,7 @@
self.quantizer_mix_bits,
self.w_only,
):
outs = []
for i in range(x_tmp.shape[0]):
_x = x_tmp[i]
_x = get_aquantizer(
self.block_idx,
list(layers_dict.keys())[0],
self.mix_bits_map,
self.quantizer_mix_bits,
self.aquantizer,
).fake_quant_act_dynamic(_x)
outs.append(_x)
x_tmp = torch.stack(outs)
x_tmp = self.fake_quantize_input(x_tmp, layers_dict)

out = self.inspect_module_forward(x_tmp, inspect_module, kwargs)

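The awq.py changes above follow one pattern: when the calibration tensor already matches the batch size self._bs, run a single fused operation; otherwise process it in batch-sized chunks and average the per-chunk results. Below is a minimal standalone sketch of that chunked-statistics pattern, assuming (as the code does) that the leading dimension divides evenly by the batch size; the function name and shapes are illustrative, not the repository's API.

import torch

def chunked_abs_mean(x: torch.Tensor, bs: int) -> torch.Tensor:
    # Per-channel mean of |x|, mirroring get_act_scale: a fast path for a
    # single batch, else batch-sized chunks whose means are averaged.
    # With equal-sized chunks, the mean of chunk means equals the global mean.
    if x.shape[0] == bs:
        return x.abs().view(-1, x.shape[-1]).mean(0)
    chunk_means = [
        x[i:i + bs].abs().view(-1, x.shape[-1]).mean(0)
        for i in range(0, x.shape[0], bs)
    ]
    return torch.stack(chunk_means).mean(0)

x = torch.randn(8, 32, 16)   # 8 samples, sequence length 32, hidden size 16
assert torch.allclose(chunked_abs_mean(x, 4), x.abs().view(-1, 16).mean(0), atol=1e-6)

The same structure appears in inspect_module_forward (chunked forward passes concatenated with torch.cat) and calculate_loss (per-chunk MSE averaged over the chunks).
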
23 changes: 13 additions & 10 deletions llmc/compression/quantization/base_blockwise_quantization.py
@@ -243,8 +243,9 @@ def set_quant_config(self):
if self.act_static:
act_static_cfg.update(self.config.calib.n_sample)
act_static_cfg.update(self.config.calib.bs)
kv_quant_type = self.quant_config['kvcache'].get('quant_type', 'int-quant')
self.kv_module = KV_REGISTRY[self.quant_config['kvcache']['method']](
self.quant_type, self.quant_config['kvcache'],
kv_quant_type, self.quant_config['kvcache'],
self.model.model_config.num_hidden_layers, **kv_special_cfg, **act_static_cfg
)
self.quant_kvcache = True
@@ -287,8 +288,9 @@ def set_quant_config(self):
# set online-rotation config
self.online_rotate = special_config.get('online_rotate', False)
if self.online_rotate:
assert self.config['model']['type'] in ['Opt', 'Llama']

assert (
self.config['model']['type'] in ['Opt', 'Llama']
), 'Please set online_rotate=False'
self.hidden_size = self.model.model_config.hidden_size
self.set_model_config()
self.modality = self.quant_config.modality
@@ -832,12 +834,13 @@ def scaling_input(self, x, scales, is_gqa):
scales_tmp = self.repeat_gqa_scales(scales)
else:
scales_tmp = scales

x_tmp = torch.empty_like(x)
for i, batch in enumerate(x):
batch_scale = scales_tmp.view(1, -1)
x_tmp[i] = batch / batch_scale

if hasattr(self, '_bs') and self._bs < x.shape[0]:
x_tmp = torch.empty_like(x)
for i, batch in enumerate(x):
batch_scale = scales_tmp.view(1, -1)
x_tmp[i] = batch / batch_scale
else:
x_tmp = x / scales.view(1, -1)
return x_tmp

@torch.no_grad()
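
Similarly, scaling_input above now divides the whole tensor in a single vectorized op when it is no larger than the calibration batch size, and otherwise scales sample by sample into a preallocated buffer. A rough sketch of the two paths, with an illustrative function name and a bs argument standing in for self._bs:

from typing import Optional

import torch

def scale_activations(x: torch.Tensor, scales: torch.Tensor, bs: Optional[int]) -> torch.Tensor:
    # Divide activations by per-channel scales.
    if bs is not None and bs < x.shape[0]:
        # Per-sample loop into a preallocated buffer for large calibration tensors.
        out = torch.empty_like(x)
        for i, sample in enumerate(x):
            out[i] = sample / scales.view(1, -1)
        return out
    # Single vectorized division otherwise.
    return x / scales.view(1, -1)

x = torch.randn(8, 32, 16)
scales = torch.rand(16) + 0.5
assert torch.allclose(scale_activations(x, scales, 4), scale_activations(x, scales, None))
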
@@ -846,7 +849,7 @@ def update_input_feat(self, scale, input_feat, layers_dict, is_gqa):
for i in range(len(input_feat[layer_name])):
inp = input_feat[layer_name][i]
scale = scale.to(inp.device)
inp = self.scaling_input(inp, scale, is_gqa)
input_feat[layer_name][i] = self.scaling_input(inp, scale, is_gqa)

@torch.no_grad()
def set_non_linear_mode(self, quant_format, module, mode):
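
The update_input_feat change above is a genuine bug fix: the scaled activation is now written back into input_feat[layer_name][i], whereas before it was only bound to the local name inp and then discarded. A small self-contained illustration of why rebinding a loop variable does not update the stored list (the names here are made up for the example):

import torch

feats = {'proj': [torch.ones(2, 4), torch.ones(2, 4)]}
scale = torch.full((4,), 2.0)

# Buggy pattern: rebinding `inp` leaves the stored tensors untouched.
for i in range(len(feats['proj'])):
    inp = feats['proj'][i]
    inp = inp / scale.view(1, -1)
assert torch.equal(feats['proj'][0], torch.ones(2, 4))

# Fixed pattern: assign the result back through the index.
for i in range(len(feats['proj'])):
    feats['proj'][i] = feats['proj'][i] / scale.view(1, -1)
assert torch.equal(feats['proj'][0], torch.full((2, 4), 0.5))
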
56 changes: 13 additions & 43 deletions llmc/compression/quantization/quant.py
@@ -1,6 +1,14 @@
import torch
from loguru import logger

try:
from qtorch.quant import float_quantize
except Exception:
logger.warning(
'qtorch not found, please install qtorch.'
'Please install qtorch (pip install qtorch).'
)
float_quantize = None

class BaseQuantizer(object):
def __init__(self, bit, symmetric, granularity, **kwargs):
@@ -36,8 +44,6 @@ def __init__(self, bit, symmetric, granularity, **kwargs):

# hist config
self.bins = self.kwargs.get('bins', 2048)
self.hist_threshold = self.kwargs.get('hist_threshold', 1)
self.dst_nbins = 2**bit if isinstance(bit, int) else None
self.upsample_rate = (
16 # used to reduce quantization errors when upscaling histogram
)
@@ -87,36 +93,6 @@ def get_tensor_range(self, tensor, args={}):
else:
return self.get_minmax_range(tensor)

def get_hist_range(self, stats_min_max, act_stats_hist):
clip_val = {}
for input_idx, hist in act_stats_hist.items():
hist = hist.float() / hist.sum()
data_max = max(
-torch.min(stats_min_max[input_idx]['min']),
torch.max(stats_min_max[input_idx]['max']),
)
accum = 0
for i in range(len(hist)):
accum += hist[i]
if accum >= self.hist_threshold:
clip_value = (i + 0.5) * (data_max / self.bins)
clip_val[input_idx] = [
max(-clip_value, torch.min(stats_min_max[input_idx]['min'])),
min(clip_value, torch.max(stats_min_max[input_idx]['max'])),
]
break
if input_idx not in clip_val:
clip_val[input_idx] = [
torch.min(stats_min_max[input_idx]['min']),
torch.max(stats_min_max[input_idx]['max']),
]

moving_min_vals, moving_max_vals = [], []
for input_idx, tensor_range in clip_val.items():
moving_min_vals.append(tensor_range[0])
moving_max_vals.append(tensor_range[1])
return moving_min_vals, moving_max_vals

def get_minmax_range(self, tensor):
if self.granularity == 'per_tensor':
max_val = torch.max(tensor)
@@ -556,7 +532,7 @@ def get_batch_tensors_qparams(self, act_tensors, alpha=0.01, args={}):
if self.calib_algo == 'static_hist':
assert (
self.sym is True and self.granularity == 'per_tensor'
), 'Only support per tensor static symmetric.'
), 'Only support per tensor static symmetric int quantize.'
min_vals, max_vals = self.get_static_hist_range(act_tensors)
elif self.calib_algo == 'static_minmax':
min_vals, max_vals = self.get_static_minmax_range(act_tensors)
@@ -657,6 +633,7 @@ def __init__(self, bit, symmetric, granularity, **kwargs):

self.qmin = torch.tensor(self.qmin)
self.qmax = torch.tensor(self.qmax)
self.dst_nbins = 2**bit

def get_hqq_qparams(self, tensor, args):
tensor = tensor.float()
@@ -947,17 +924,10 @@ def __init__(self, bit, symmetric, granularity, **kwargs):
self.sign_bits = 1
self.num_bits = self.e_bits + self.m_bits + self.sign_bits
self.default_bias = 2 ** (self.e_bits - 1)

self.dst_nbins = 2**self.num_bits
self.use_qtorch = self.kwargs.get('use_qtorch')
if self.use_qtorch:
try:
from qtorch.quant import float_quantize
except ImportError:
logger.error('qtorch not found, please install qtorch.')
raise ImportError('Please install qtorch (pip install qtorch).')

self.float_quantize = float_quantize

assert float_quantize is not None, 'Please install qtorch (pip install qtorch). Or set use_qtorch=False'
if 'float_range' in self.kwargs:
self.qmin, self.qmax = self.kwargs['float_range']
else:
@@ -1045,7 +1015,7 @@ def quant(self, tensor, scales, zeros, qmax, qmin):
scaled_tensor = tensor / scales + zeros
if self.use_qtorch:
org_dtype = scaled_tensor.dtype
q_tensor = self.float_quantize(
q_tensor = float_quantize(
scaled_tensor.float(), self.e_bits, self.m_bits, rounding='nearest'
)
q_tensor.to(org_dtype)
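
quant.py now imports float_quantize once at module level and degrades gracefully when qtorch is missing, instead of importing it inside the float quantizer's __init__ and keeping a self.float_quantize alias. A minimal sketch of that optional-dependency pattern; the fp_round wrapper below is hypothetical, not part of llmc:

import torch
from loguru import logger

try:
    from qtorch.quant import float_quantize
except Exception:
    logger.warning('qtorch not found. Please install qtorch (pip install qtorch).')
    float_quantize = None

def fp_round(x: torch.Tensor, e_bits: int, m_bits: int) -> torch.Tensor:
    # Round to a low-precision float format via qtorch, guarding the optional
    # dependency at call time rather than at import time.
    assert float_quantize is not None, 'Please install qtorch (pip install qtorch).'
    org_dtype = x.dtype
    # qtorch operates on float32 inputs; cast back to the original dtype afterwards.
    return float_quantize(x.float(), e_bits, m_bits, rounding='nearest').to(org_dtype)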
