Commit 74387fe

gushiqiao authored
Fix bugs (#295)
Co-authored-by: gushiqiao <[email protected]>
1 parent 0de0040 commit 74387fe

File tree: 3 files changed, +104 −106 lines

3 files changed

+104
-106
lines changed

llmc/compression/quantization/awq.py

Lines changed: 78 additions & 53 deletions
```diff
@@ -31,8 +31,8 @@ def scaling_weight(self, w, scales, is_gqa):
             scales_tmp = self.repeat_gqa_scales(scales)
         else:
             scales_tmp = scales
-        w_tmp = w.mul_(scales_tmp.view(1, -1))
-        return w_tmp
+        w.mul_(scales_tmp.view(1, -1))
+        return w
 
     @torch.no_grad()
     def get_weight_scale(self, layers_dict):
```
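Why this rewrite is safe: `Tensor.mul_` mutates its receiver in place and returns that same object, so the old `w_tmp` alias changed nothing; the new form just makes the in-place update explicit while preserving the parameter's object identity. A minimal check of that PyTorch contract:

```python
import torch

w = torch.ones(2, 3)
scales = torch.tensor([2.0, 3.0, 4.0])
out = w.mul_(scales.view(1, -1))  # in-place column-wise scaling
assert out is w                   # mul_ returns the very tensor it mutated
```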
```diff
@@ -60,14 +60,17 @@ def get_weight_scale(self, layers_dict):
         return scale
 
     def get_act_scale(self, x):
-        batch_means = []
-        b_num = x.shape[0] // self._bs
-        for num in range(b_num):
-            batch_x = x[num * self._bs:(num + 1) * self._bs]
-            batch_mean = batch_x.abs().view(-1, batch_x.shape[-1]).mean(0)
-            batch_means.append(batch_mean)
-        final_mean = sum(batch_means) / len(batch_means)
-        return final_mean
+        if x.shape[0] == self._bs:
+            return x.abs().view(-1, x.shape[-1]).mean(0)
+        else:
+            batch_means = []
+            b_num = x.shape[0] // self._bs
+            for num in range(b_num):
+                batch_x = x[num * self._bs:(num + 1) * self._bs]
+                batch_mean = batch_x.abs().view(-1, batch_x.shape[-1]).mean(0)
+                batch_means.append(batch_mean)
+            final_mean = sum(batch_means) / len(batch_means)
+            return final_mean
 
     @torch.no_grad()
     def get_scales(self, prev_op, x, w_max, is_gqa, ratio):
```
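The added fast path skips the chunking loop when the calibration tensor is already a single batch; both branches compute a per-channel mean of absolute values over the last dimension. A self-contained sketch of the same logic (the helper name `per_channel_abs_mean` is hypothetical, not part of the patch):

```python
import torch

def per_channel_abs_mean(x: torch.Tensor, bs: int) -> torch.Tensor:
    # Fast path: the tensor is exactly one batch, reduce it directly.
    if x.shape[0] == bs:
        return x.abs().view(-1, x.shape[-1]).mean(0)
    # Slow path: reduce chunk by chunk to bound peak memory, then average.
    means = [
        x[i * bs:(i + 1) * bs].abs().view(-1, x.shape[-1]).mean(0)
        for i in range(x.shape[0] // bs)
    ]
    return torch.stack(means).mean(0)

x = torch.randn(8, 4, 16)
# Equal-size chunks average back to the full-tensor mean.
assert torch.allclose(per_channel_abs_mean(x, 8), per_channel_abs_mean(x, 2))
```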
```diff
@@ -93,15 +96,22 @@ def get_scales(self, prev_op, x, w_max, is_gqa, ratio):
         return scales
 
     def inspect_module_forward(self, x, inspect_module, kwargs):
-        outs = []
-        b_num = x.shape[0] // self._bs
-        for num in range(b_num):
-            _x = x[num * self._bs:(num + 1) * self._bs]
-            out = inspect_module(_x, **kwargs)
-            if isinstance(out, tuple):
-                out = out[0]
-            outs.append(out)
-        return torch.cat(outs, dim=0)
+        if self._bs == x.shape[0]:
+            with torch.no_grad():
+                out = inspect_module(x, **kwargs)
+                if isinstance(out, tuple):
+                    out = out[0]
+            return out
+        else:
+            outs = []
+            b_num = x.shape[0] // self._bs
+            for num in range(b_num):
+                _x = x[num * self._bs:(num + 1) * self._bs]
+                out = inspect_module(_x, **kwargs)
+                if isinstance(out, tuple):
+                    out = out[0]
+                outs.append(out)
+            return torch.cat(outs, dim=0)
 
     @torch.no_grad()
     def get_original_out(self, x, inspect_module, subset_kwargs):
```
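The two branches can coexist because, for a deterministic batch-independent module, concatenating per-chunk outputs reproduces the full-batch forward. A small sanity check under that assumption (the `nn.Linear` stand-in is illustrative, not the module the patch inspects):

```python
import torch
import torch.nn as nn

mod = nn.Linear(16, 16).eval()
x = torch.randn(4, 10, 16)  # 4 samples, chunked into two batches of 2
with torch.no_grad():
    full = mod(x)
    chunked = torch.cat([mod(x[i * 2:(i + 1) * 2]) for i in range(2)], dim=0)
assert torch.allclose(full, chunked)
```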
```diff
@@ -110,14 +120,53 @@ def get_original_out(self, x, inspect_module, subset_kwargs):
         return org_out
 
     def calculate_loss(self, org_out, out):
-        total_loss = 0.0
-        b_num = org_out.shape[0] // self._bs
-        for num in range(b_num):
-            _org_out = org_out[num * self._bs:(num + 1) * self._bs]
-            _out = out[num * self._bs:(num + 1) * self._bs]
-            single_loss = (_org_out - _out).float().pow(2).mean().item()
-            total_loss += single_loss
-        return total_loss / b_num
+        if out.shape[0] == self._bs:
+            return (org_out - out).float().pow(2).mean().item()
+        else:
+            total_loss = 0.0
+            b_num = org_out.shape[0] // self._bs
+            for num in range(b_num):
+                _org_out = org_out[num * self._bs:(num + 1) * self._bs]
+                _out = out[num * self._bs:(num + 1) * self._bs]
+                single_loss = (_org_out - _out).float().pow(2).mean().item()
+                total_loss += single_loss
+            return total_loss / b_num
+
+    def fake_quantize_weight(self, weight, scales, is_gqa, layer_name):
+        weight = self.scaling_weight(weight, scales, is_gqa)
+        weight.data = get_wquantizer(
+            self.block_idx,
+            layer_name,
+            self.mix_bits_map,
+            self.quantizer_mix_bits,
+            self.wquantizer,
+        ).fake_quant_weight_dynamic(weight.data)
+
+        return weight
+
+    def fake_quantize_input(self, x_tmp, layers_dict):
+        if self._bs == x_tmp.shape[0]:
+            x_tmp = get_aquantizer(
+                self.block_idx,
+                list(layers_dict.keys())[0],
+                self.mix_bits_map,
+                self.quantizer_mix_bits,
+                self.aquantizer,
+            ).fake_quant_act_dynamic(x_tmp)
+        else:
+            outs = []
+            for i in range(x_tmp.shape[0]):
+                _x = x_tmp[i]
+                _x = get_aquantizer(
+                    self.block_idx,
+                    list(layers_dict.keys())[0],
+                    self.mix_bits_map,
+                    self.quantizer_mix_bits,
+                    self.aquantizer,
+                ).fake_quant_act_dynamic(_x)
+                outs.append(_x)
+            x_tmp = torch.stack(outs)
+        return x_tmp
 
     @torch.no_grad()
     def search_scale_subset(
```
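The single-batch branch of `calculate_loss` returns the MSE over the whole tensor; the chunked branch averages per-chunk MSEs, which matches the full-tensor MSE when the leading dimension is an exact multiple of `self._bs` (the floor division `// self._bs` would otherwise drop a remainder). A quick numerical check:

```python
import torch

org, out = torch.randn(6, 5), torch.randn(6, 5)
bs, b_num = 2, 6 // 2
full = (org - out).float().pow(2).mean().item()
chunked = sum(
    (org[i * bs:(i + 1) * bs] - out[i * bs:(i + 1) * bs]).float().pow(2).mean().item()
    for i in range(b_num)
) / b_num
assert abs(full - chunked) < 1e-6  # equal-size chunks average to the full MSE
```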
```diff
@@ -158,25 +207,12 @@ def search_scale_subset(
                 else:
                     org_out = self.get_original_out(x, inspect_module, kwargs)
                     org_out_dict[i] = org_out
-                x_max = self.get_act_scale(x)
 
                 ratio = n * 1 / n_grid
                 scales = self.get_scales(prev_op, x, w_max, is_gqa, ratio)
                 for layer_name in layers_dict:
                     fc = layers_dict[layer_name]
-                    fc.weight = self.scaling_weight(fc.weight, scales, is_gqa)
-
-                    fc.weight.data = get_wquantizer(
-                        self.block_idx,
-                        layer_name,
-                        self.mix_bits_map,
-                        self.quantizer_mix_bits,
-                        self.wquantizer,
-                    ).fake_quant_weight_dynamic(fc.weight.data)
-
-                del x_max
-                gc.collect()
-                torch.cuda.empty_cache()
+                    fc.weight = self.fake_quantize_weight(fc.weight, scales, is_gqa, layer_name)
 
                 x_tmp = self.scaling_input(x, scales, is_gqa)
 
@@ -187,18 +223,7 @@ def search_scale_subset(
                     self.quantizer_mix_bits,
                     self.w_only,
                 ):
-                    outs = []
-                    for i in range(x_tmp.shape[0]):
-                        _x = x_tmp[i]
-                        _x = get_aquantizer(
-                            self.block_idx,
-                            list(layers_dict.keys())[0],
-                            self.mix_bits_map,
-                            self.quantizer_mix_bits,
-                            self.aquantizer,
-                        ).fake_quant_act_dynamic(_x)
-                        outs.append(_x)
-                    x_tmp = torch.stack(outs)
+                    x_tmp = self.fake_quantize_input(x_tmp, layers_dict)
 
                 out = self.inspect_module_forward(x_tmp, inspect_module, kwargs)
 
```

llmc/compression/quantization/base_blockwise_quantization.py

Lines changed: 13 additions & 10 deletions
```diff
@@ -243,8 +243,9 @@ def set_quant_config(self):
             if self.act_static:
                 act_static_cfg.update(self.config.calib.n_sample)
                 act_static_cfg.update(self.config.calib.bs)
+            kv_quant_type = self.quant_config['kvcache'].get('quant_type', 'int-quant')
             self.kv_module = KV_REGISTRY[self.quant_config['kvcache']['method']](
-                self.quant_type, self.quant_config['kvcache'],
+                kv_quant_type, self.quant_config['kvcache'],
                 self.model.model_config.num_hidden_layers, **kv_special_cfg, **act_static_cfg
             )
             self.quant_kvcache = True
```
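The KV-cache module now takes its quantizer type from its own config section, falling back to `'int-quant'` when the key is absent, instead of inheriting `self.quant_type` from the weight/activation settings. A sketch of the lookup with a hypothetical minimal config dict:

```python
quant_config = {'kvcache': {'method': 'Naive'}}  # hypothetical config; no quant_type set
kv_quant_type = quant_config['kvcache'].get('quant_type', 'int-quant')
assert kv_quant_type == 'int-quant'  # default applies when the key is missing
```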
```diff
@@ -287,8 +288,9 @@ def set_quant_config(self):
         # set online-rotation config
         self.online_rotate = special_config.get('online_rotate', False)
         if self.online_rotate:
-            assert self.config['model']['type'] in ['Opt', 'Llama']
-
+            assert (
+                self.config['model']['type'] in ['Opt', 'Llama']
+            ), 'Please set online_rotate=False'
         self.hidden_size = self.model.model_config.hidden_size
         self.set_model_config()
         self.modality = self.quant_config.modality
@@ -832,12 +834,13 @@ def scaling_input(self, x, scales, is_gqa):
             scales_tmp = self.repeat_gqa_scales(scales)
         else:
             scales_tmp = scales
-
-        x_tmp = torch.empty_like(x)
-        for i, batch in enumerate(x):
-            batch_scale = scales_tmp.view(1, -1)
-            x_tmp[i] = batch / batch_scale
-
+        if hasattr(self, '_bs') and self._bs < x.shape[0]:
+            x_tmp = torch.empty_like(x)
+            for i, batch in enumerate(x):
+                batch_scale = scales_tmp.view(1, -1)
+                x_tmp[i] = batch / batch_scale
+        else:
+            x_tmp = x / scales.view(1, -1)
         return x_tmp
 
     @torch.no_grad()
```
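When everything fits in one batch, dividing by the broadcast scale vector in a single expression is equivalent to the old per-batch loop and avoids the `empty_like` allocation plus Python-level iteration. A minimal equivalence check:

```python
import torch

x = torch.randn(4, 7, 16)
scales = torch.rand(16) + 0.5
loop = torch.empty_like(x)
for i, batch in enumerate(x):
    loop[i] = batch / scales.view(1, -1)  # old low-memory, per-batch path
vec = x / scales.view(1, -1)              # new single-shot broadcast path
assert torch.allclose(loop, vec)
```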
```diff
@@ -846,7 +849,7 @@ def update_input_feat(self, scale, input_feat, layers_dict, is_gqa):
             for i in range(len(input_feat[layer_name])):
                 inp = input_feat[layer_name][i]
                 scale = scale.to(inp.device)
-                inp = self.scaling_input(inp, scale, is_gqa)
+                input_feat[layer_name][i] = self.scaling_input(inp, scale, is_gqa)
 
     @torch.no_grad()
     def set_non_linear_mode(self, quant_format, module, mode):
```
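This one-line fix matters because `inp = self.scaling_input(...)` only rebound a local name, so the scaled tensor was computed and then discarded; writing through `input_feat[layer_name][i]` actually persists it. The same Python semantics in miniature:

```python
vals = [1, 2, 3]
for i in range(len(vals)):
    v = vals[i]
    v = v * 10              # rebinds the local name; the list is untouched
assert vals == [1, 2, 3]

for i in range(len(vals)):
    vals[i] = vals[i] * 10  # writes back through the index; the list changes
assert vals == [10, 20, 30]
```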

llmc/compression/quantization/quant.py

Lines changed: 13 additions & 43 deletions
```diff
@@ -1,6 +1,14 @@
 import torch
 from loguru import logger
 
+try:
+    from qtorch.quant import float_quantize
+except Exception:
+    logger.warning(
+        'qtorch not found, please install qtorch.'
+        'Please install qtorch (pip install qtorch).'
+    )
+    float_quantize = None
 
 class BaseQuantizer(object):
     def __init__(self, bit, symmetric, granularity, **kwargs):
```
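Moving the import to module scope behind a guard converts a hard failure at construction time into a deferred one: `float_quantize` becomes `None`, and only the code paths that actually need qtorch assert on it later. The pattern in general form, with a hypothetical optional dependency:

```python
try:
    from some_optional_lib import fast_op  # hypothetical optional dependency
except Exception:
    fast_op = None  # remember the failure; complain only if actually used

def run(x, use_fast=False):
    if use_fast:
        assert fast_op is not None, 'install some_optional_lib or set use_fast=False'
        return fast_op(x)
    return x
```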
```diff
@@ -36,8 +44,6 @@ def __init__(self, bit, symmetric, granularity, **kwargs):
 
         # hist config
         self.bins = self.kwargs.get('bins', 2048)
-        self.hist_threshold = self.kwargs.get('hist_threshold', 1)
-        self.dst_nbins = 2**bit if isinstance(bit, int) else None
         self.upsample_rate = (
             16  # used to reduce quantization errors when upscaling histogram
         )
@@ -87,36 +93,6 @@ def get_tensor_range(self, tensor, args={}):
         else:
             return self.get_minmax_range(tensor)
 
-    def get_hist_range(self, stats_min_max, act_stats_hist):
-        clip_val = {}
-        for input_idx, hist in act_stats_hist.items():
-            hist = hist.float() / hist.sum()
-            data_max = max(
-                -torch.min(stats_min_max[input_idx]['min']),
-                torch.max(stats_min_max[input_idx]['max']),
-            )
-            accum = 0
-            for i in range(len(hist)):
-                accum += hist[i]
-                if accum >= self.hist_threshold:
-                    clip_value = (i + 0.5) * (data_max / self.bins)
-                    clip_val[input_idx] = [
-                        max(-clip_value, torch.min(stats_min_max[input_idx]['min'])),
-                        min(clip_value, torch.max(stats_min_max[input_idx]['max'])),
-                    ]
-                    break
-            if input_idx not in clip_val:
-                clip_val[input_idx] = [
-                    torch.min(stats_min_max[input_idx]['min']),
-                    torch.max(stats_min_max[input_idx]['max']),
-                ]
-
-        moving_min_vals, moving_max_vals = [], []
-        for input_idx, tensor_range in clip_val.items():
-            moving_min_vals.append(tensor_range[0])
-            moving_max_vals.append(tensor_range[1])
-        return moving_min_vals, moving_max_vals
-
     def get_minmax_range(self, tensor):
         if self.granularity == 'per_tensor':
             max_val = torch.max(tensor)
```
```diff
@@ -556,7 +532,7 @@ def get_batch_tensors_qparams(self, act_tensors, alpha=0.01, args={}):
         if self.calib_algo == 'static_hist':
             assert (
                 self.sym is True and self.granularity == 'per_tensor'
-            ), 'Only support per tensor static symmetric.'
+            ), 'Only support per tensor static symmetric int quantize.'
             min_vals, max_vals = self.get_static_hist_range(act_tensors)
         elif self.calib_algo == 'static_minmax':
             min_vals, max_vals = self.get_static_minmax_range(act_tensors)
@@ -657,6 +633,7 @@ def __init__(self, bit, symmetric, granularity, **kwargs):
 
         self.qmin = torch.tensor(self.qmin)
         self.qmax = torch.tensor(self.qmax)
+        self.dst_nbins = 2**bit
 
     def get_hqq_qparams(self, tensor, args):
         tensor = tensor.float()
```
```diff
@@ -947,17 +924,10 @@ def __init__(self, bit, symmetric, granularity, **kwargs):
         self.sign_bits = 1
         self.num_bits = self.e_bits + self.m_bits + self.sign_bits
         self.default_bias = 2 ** (self.e_bits - 1)
-
+        self.dst_nbins = 2**self.num_bits
         self.use_qtorch = self.kwargs.get('use_qtorch')
         if self.use_qtorch:
-            try:
-                from qtorch.quant import float_quantize
-            except ImportError:
-                logger.error('qtorch not found, please install qtorch.')
-                raise ImportError('Please install qtorch (pip install qtorch).')
-
-            self.float_quantize = float_quantize
-
+            assert float_quantize is not None, 'Please install qtorch (pip install qtorch). Or set use_qtorch=False'
         if 'float_range' in self.kwargs:
             self.qmin, self.qmax = self.kwargs['float_range']
         else:
```
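For context on the two `dst_nbins` additions: a float format with `e_bits` exponent bits, `m_bits` mantissa bits, and one sign bit spends `e_bits + m_bits + 1` bits per value, so `dst_nbins`, the number of representable codes, is `2 ** num_bits`. Worked numbers for an FP8 E4M3-style layout (illustrative only, using the formulas in the patch):

```python
e_bits, m_bits, sign_bits = 4, 3, 1
num_bits = e_bits + m_bits + sign_bits   # 8 bits total
dst_nbins = 2 ** num_bits                # 256 representable codes
default_bias = 2 ** (e_bits - 1)         # 8, the bias formula used in the patch
assert (num_bits, dst_nbins, default_bias) == (8, 256, 8)
```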
```diff
@@ -1045,7 +1015,7 @@ def quant(self, tensor, scales, zeros, qmax, qmin):
         scaled_tensor = tensor / scales + zeros
         if self.use_qtorch:
             org_dtype = scaled_tensor.dtype
-            q_tensor = self.float_quantize(
+            q_tensor = float_quantize(
                 scaled_tensor.float(), self.e_bits, self.m_bits, rounding='nearest'
            )
             q_tensor.to(org_dtype)
```
